更新
This commit is contained in:
12
README.md
12
README.md
@@ -5,7 +5,7 @@
|
||||
> 📹 **上传人物** · 🎙️ **输入文案** · 🎬 **一键成片**
|
||||
|
||||
基于 **LatentSync 1.6 + EdgeTTS** 的开源数字人口播视频生成系统。
|
||||
集成 **Qwen3-TTS** 声音克隆与自动社交媒体发布功能。
|
||||
集成 **CosyVoice 3.0** 声音克隆与自动社交媒体发布功能。
|
||||
|
||||
[功能特性](#-功能特性) • [技术栈](#-技术栈) • [文档中心](#-文档中心) • [部署指南](Docs/DEPLOY_MANUAL.md)
|
||||
|
||||
@@ -17,7 +17,7 @@
|
||||
|
||||
### 核心能力
|
||||
- 🎬 **高清唇形同步** - LatentSync 1.6 驱动,512×512 高分辨率 Latent Diffusion 模型。
|
||||
- 🎙️ **多模态配音** - 支持 **EdgeTTS** (微软超自然语音, 10 语言) 和 **Qwen3-TTS** (3秒极速声音克隆)。配音前置工作流:先生成配音 → 选素材 → 生成视频。
|
||||
- 🎙️ **多模态配音** - 支持 **EdgeTTS** (微软超自然语音, 10 语言) 和 **CosyVoice 3.0** (3秒极速声音克隆, 9语言+18方言, 语速可调)。上传参考音频自动 Whisper 转写 + 智能截取。配音前置工作流:先生成配音 → 选素材 → 生成视频。
|
||||
- 📝 **智能字幕** - 集成 faster-whisper + Remotion,自动生成逐字高亮 (卡拉OK效果) 字幕。
|
||||
- 🎨 **样式预设** - 标题/字幕样式选择 + 预览 + 字号调节,支持自定义字体库。
|
||||
- 🖼️ **作品预览一致性** - 标题/字幕预览按素材分辨率缩放,效果更接近成片。
|
||||
@@ -45,7 +45,7 @@
|
||||
| **后端** | FastAPI | Python 3.10, AsyncIO, PM2 |
|
||||
| **数据库** | Supabase | PostgreSQL, Storage (本地/S3), Auth |
|
||||
| **唇形同步** | LatentSync 1.6 | PyTorch 2.5, Diffusers, DeepCache |
|
||||
| **声音克隆** | Qwen3-TTS | 1.7B 参数量,Flash Attention 2 加速 |
|
||||
| **声音克隆** | CosyVoice 3.0 | 0.5B 参数量,9 语言 + 18 方言 |
|
||||
| **自动化** | Playwright | 社交媒体无头浏览器自动化 |
|
||||
| **部署** | Docker & PM2 | 混合部署架构 |
|
||||
|
||||
@@ -57,7 +57,7 @@
|
||||
|
||||
### 部署运维
|
||||
- **[部署手册 (DEPLOY_MANUAL.md)](Docs/DEPLOY_MANUAL.md)** - 👈 **部署请看这里**!包含完整的环境搭建步骤。
|
||||
- [参考音频服务部署 (QWEN3_TTS_DEPLOY.md)](Docs/QWEN3_TTS_DEPLOY.md) - 声音克隆模型部署指南。
|
||||
- [参考音频服务部署 (COSYVOICE3_DEPLOY.md)](Docs/COSYVOICE3_DEPLOY.md) - 声音克隆模型部署指南。
|
||||
- [LatentSync 部署指南](models/LatentSync/DEPLOY.md) - 唇形同步模型独立部署。
|
||||
- [Supabase 部署指南 (SUPABASE_DEPLOY.md)](Docs/SUPABASE_DEPLOY.md) - Supabase 与认证系统配置。
|
||||
|
||||
@@ -82,7 +82,7 @@ ViGent2/
|
||||
├── remotion/ # Remotion 视频渲染 (标题/字幕合成)
|
||||
├── models/ # AI 模型仓库
|
||||
│ ├── LatentSync/ # 唇形同步服务
|
||||
│ └── Qwen3-TTS/ # 声音克隆服务
|
||||
│ └── CosyVoice/ # 声音克隆服务
|
||||
└── Docs/ # 项目文档
|
||||
```
|
||||
|
||||
@@ -97,7 +97,7 @@ ViGent2/
|
||||
| **Web UI** | 3002 | 用户访问入口 (Next.js) |
|
||||
| **Backend API** | 8006 | 核心业务接口 (FastAPI) |
|
||||
| **LatentSync** | 8007 | 唇形同步推理服务 |
|
||||
| **Qwen3-TTS** | 8009 | 声音克隆推理服务 |
|
||||
| **CosyVoice 3.0** | 8010 | 声音克隆推理服务 |
|
||||
| **Supabase** | 8008 | 数据库与认证网关 |
|
||||
|
||||
---
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
"""
|
||||
from typing import Optional, Any, Dict, cast
|
||||
from fastapi import Request, HTTPException, Depends, status
|
||||
from app.core.security import decode_access_token, TokenData
|
||||
from app.repositories.sessions import get_session
|
||||
from app.repositories.users import get_user_by_id
|
||||
from app.core.security import decode_access_token
|
||||
from app.repositories.sessions import get_session, delete_sessions
|
||||
from app.repositories.users import get_user_by_id, deactivate_user_if_expired
|
||||
from loguru import logger
|
||||
|
||||
|
||||
@@ -35,8 +35,12 @@ async def get_current_user_optional(
|
||||
logger.warning(f"Session token 无效: user_id={token_data.user_id}")
|
||||
return None
|
||||
|
||||
user = get_user_by_id(token_data.user_id)
|
||||
return cast(Optional[Dict[str, Any]], user)
|
||||
user = cast(Optional[Dict[str, Any]], get_user_by_id(token_data.user_id))
|
||||
if user and deactivate_user_if_expired(user):
|
||||
delete_sessions(token_data.user_id)
|
||||
return None
|
||||
|
||||
return user
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户信息失败: {e}")
|
||||
return None
|
||||
@@ -82,14 +86,12 @@ async def get_current_user(
|
||||
)
|
||||
user = cast(Dict[str, Any], user)
|
||||
|
||||
if user.get("expires_at"):
|
||||
from datetime import datetime, timezone
|
||||
expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="授权已过期,请联系管理员续期"
|
||||
)
|
||||
if deactivate_user_if_expired(user):
|
||||
delete_sessions(token_data.user_id)
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="会员已到期,请续费"
|
||||
)
|
||||
|
||||
return user
|
||||
except HTTPException:
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
认证 API:注册、登录、登出、修改密码
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Response, status, Request
|
||||
from fastapi import APIRouter, HTTPException, Response, status, Request, Depends
|
||||
from pydantic import BaseModel, field_validator
|
||||
from app.core.security import (
|
||||
get_password_hash,
|
||||
@@ -13,7 +13,15 @@ from app.core.security import (
|
||||
decode_access_token
|
||||
)
|
||||
from app.repositories.sessions import create_session, delete_sessions
|
||||
from app.repositories.users import create_user, get_user_by_id, get_user_by_phone, user_exists_by_phone, update_user
|
||||
from app.repositories.users import (
|
||||
create_user,
|
||||
get_user_by_id,
|
||||
get_user_by_phone,
|
||||
user_exists_by_phone,
|
||||
update_user,
|
||||
deactivate_user_if_expired,
|
||||
)
|
||||
from app.core.deps import get_current_user
|
||||
from app.core.response import success_response
|
||||
from loguru import logger
|
||||
from typing import Optional, Any, cast
|
||||
@@ -130,6 +138,14 @@ async def login(request: LoginRequest, response: Response):
|
||||
detail="手机号或密码错误"
|
||||
)
|
||||
|
||||
# 授权过期时自动停用账号
|
||||
if deactivate_user_if_expired(user):
|
||||
delete_sessions(user["id"])
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="会员已到期,请续费"
|
||||
)
|
||||
|
||||
# 检查是否激活
|
||||
if not user["is_active"]:
|
||||
raise HTTPException(
|
||||
@@ -137,16 +153,6 @@ async def login(request: LoginRequest, response: Response):
|
||||
detail="账号未激活,请等待管理员审核"
|
||||
)
|
||||
|
||||
# 检查授权是否过期
|
||||
if user.get("expires_at"):
|
||||
from datetime import datetime, timezone
|
||||
expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="授权已过期,请联系管理员续期"
|
||||
)
|
||||
|
||||
# 生成新的 session_token (后踢前)
|
||||
session_token = generate_session_token()
|
||||
|
||||
@@ -259,30 +265,8 @@ async def change_password(request: ChangePasswordRequest, req: Request, response
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_me(request: Request):
|
||||
async def get_me(user: dict = Depends(get_current_user)):
|
||||
"""获取当前用户信息"""
|
||||
# 从 Cookie 获取用户
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未登录"
|
||||
)
|
||||
|
||||
token_data = decode_access_token(token)
|
||||
if not token_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token 无效"
|
||||
)
|
||||
|
||||
user = cast(dict[str, Any], get_user_by_id(token_data.user_id) or {})
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
return success_response(UserResponse(
|
||||
id=user["id"],
|
||||
phone=user["phone"],
|
||||
|
||||
@@ -9,6 +9,7 @@ class GenerateAudioRequest(BaseModel):
|
||||
ref_audio_id: Optional[str] = None
|
||||
ref_text: Optional[str] = None
|
||||
language: str = "zh-CN"
|
||||
speed: float = 1.0
|
||||
|
||||
|
||||
class RenameAudioRequest(BaseModel):
|
||||
|
||||
@@ -25,7 +25,7 @@ from app.modules.generated_audios.schemas import (
|
||||
BUCKET = "generated-audios"
|
||||
|
||||
|
||||
def _locale_to_qwen_lang(locale: str) -> str:
|
||||
def _locale_to_tts_lang(locale: str) -> str:
|
||||
mapping = {"zh": "Chinese", "en": "English"}
|
||||
return mapping.get(locale.split("-")[0], "Auto")
|
||||
|
||||
@@ -73,19 +73,20 @@ async def generate_audio_task(task_id: str, req: GenerateAudioRequest, user_id:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
f.write(chunk)
|
||||
|
||||
task_store.update(task_id, {"progress": 40, "message": "正在克隆声音 (Qwen3-TTS)..."})
|
||||
task_store.update(task_id, {"progress": 40, "message": "正在克隆声音..."})
|
||||
await voice_clone_service.generate_audio(
|
||||
text=req.text,
|
||||
ref_audio_path=ref_local,
|
||||
ref_text=req.ref_text,
|
||||
output_path=audio_path,
|
||||
language=_locale_to_qwen_lang(req.language),
|
||||
language=_locale_to_tts_lang(req.language),
|
||||
speed=req.speed,
|
||||
)
|
||||
finally:
|
||||
if os.path.exists(ref_local):
|
||||
os.unlink(ref_local)
|
||||
else:
|
||||
task_store.update(task_id, {"progress": 30, "message": "正在生成语音 (EdgeTTS)..."})
|
||||
task_store.update(task_id, {"progress": 30, "message": "正在生成语音..."})
|
||||
tts = TTSService()
|
||||
await tts.generate_audio(req.text, req.voice, audio_path)
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ router = APIRouter()
|
||||
@router.post("")
|
||||
async def upload_ref_audio(
|
||||
file: UploadFile = File(...),
|
||||
ref_text: str = Form(...),
|
||||
ref_text: str = Form(""),
|
||||
user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""上传参考音频"""
|
||||
@@ -68,3 +68,21 @@ async def rename_ref_audio(
|
||||
except Exception as e:
|
||||
logger.error(f"重命名失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"重命名失败: {str(e)}")
|
||||
|
||||
|
||||
@router.post("/{audio_id:path}/retranscribe")
|
||||
async def retranscribe_ref_audio(
|
||||
audio_id: str,
|
||||
user: dict = Depends(get_current_user)
|
||||
):
|
||||
"""重新识别参考音频的文字内容"""
|
||||
try:
|
||||
result = await service.retranscribe_ref_audio(audio_id, user["id"])
|
||||
return success_response(result, message="识别完成")
|
||||
except PermissionError as e:
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
except Exception as e:
|
||||
logger.error(f"重新识别失败: {e}")
|
||||
raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}")
|
||||
|
||||
@@ -41,16 +41,40 @@ def _get_audio_duration(file_path: str) -> float:
|
||||
return 0.0
|
||||
|
||||
|
||||
def _convert_to_wav(input_path: str, output_path: str) -> bool:
|
||||
"""将音频转换为 WAV 格式 (16kHz, mono)"""
|
||||
def _find_silence_cut_point(file_path: str, max_duration: float) -> float:
|
||||
"""在 max_duration 附近找一个静音点作为截取位置,找不到则回退到 max_duration"""
|
||||
try:
|
||||
subprocess.run([
|
||||
'ffmpeg', '-y', '-i', input_path,
|
||||
'-ar', '16000',
|
||||
'-ac', '1',
|
||||
'-acodec', 'pcm_s16le',
|
||||
output_path
|
||||
], capture_output=True, timeout=60, check=True)
|
||||
# 用 silencedetect 找所有静音段(阈值 -30dB,最短 0.3 秒)
|
||||
result = subprocess.run(
|
||||
['ffmpeg', '-i', file_path, '-af',
|
||||
'silencedetect=noise=-30dB:d=0.3', '-f', 'null', '-'],
|
||||
capture_output=True, text=True, timeout=30
|
||||
)
|
||||
# 解析 silence_end 时间点
|
||||
import re as _re
|
||||
ends = [float(m) for m in _re.findall(r'silence_end:\s*([\d.]+)', result.stderr)]
|
||||
# 找 max_duration 之前最后一个静音结束点(至少 3 秒)
|
||||
candidates = [t for t in ends if 3.0 <= t <= max_duration]
|
||||
if candidates:
|
||||
cut = candidates[-1]
|
||||
logger.info(f"Found silence cut point at {cut:.1f}s (max={max_duration}s)")
|
||||
return cut
|
||||
except Exception as e:
|
||||
logger.warning(f"Silence detection failed: {e}")
|
||||
return max_duration
|
||||
|
||||
|
||||
def _convert_to_wav(input_path: str, output_path: str, max_duration: float = 0) -> bool:
|
||||
"""将音频转换为 WAV 格式 (16kHz, mono),可选截取前 max_duration 秒并淡出"""
|
||||
try:
|
||||
cmd = ['ffmpeg', '-y', '-i', input_path]
|
||||
if max_duration > 0:
|
||||
cmd += ['-t', str(max_duration)]
|
||||
# 末尾 0.1 秒淡出,避免截断爆音
|
||||
fade_start = max(0, max_duration - 0.1)
|
||||
cmd += ['-af', f'afade=t=out:st={fade_start}:d=0.1']
|
||||
cmd += ['-ar', '16000', '-ac', '1', '-acodec', 'pcm_s16le', output_path]
|
||||
subprocess.run(cmd, capture_output=True, timeout=60, check=True)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"音频转换失败: {e}")
|
||||
@@ -67,9 +91,6 @@ async def upload_ref_audio(file, ref_text: str, user_id: str) -> dict:
|
||||
if ext not in ALLOWED_AUDIO_EXTENSIONS:
|
||||
raise ValueError(f"不支持的音频格式: {ext}。支持的格式: {', '.join(ALLOWED_AUDIO_EXTENSIONS)}")
|
||||
|
||||
if not ref_text or len(ref_text.strip()) < 2:
|
||||
raise ValueError("参考文字不能为空")
|
||||
|
||||
# 创建临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_input:
|
||||
content = await file.read()
|
||||
@@ -86,8 +107,31 @@ async def upload_ref_audio(file, ref_text: str, user_id: str) -> dict:
|
||||
duration = _get_audio_duration(tmp_wav_path)
|
||||
if duration < 1.0:
|
||||
raise ValueError("音频时长过短,至少需要 1 秒")
|
||||
if duration > 60.0:
|
||||
raise ValueError("音频时长过长,最多 60 秒")
|
||||
|
||||
# 超过 10 秒自动在静音点截取(CosyVoice 对 3-10 秒效果最好)
|
||||
MAX_REF_DURATION = 10.0
|
||||
if duration > MAX_REF_DURATION:
|
||||
cut_point = _find_silence_cut_point(tmp_wav_path, MAX_REF_DURATION)
|
||||
logger.info(f"Ref audio {duration:.1f}s > {MAX_REF_DURATION}s, trimming at {cut_point:.1f}s")
|
||||
trimmed_path = tmp_input_path + "_trimmed.wav"
|
||||
if not _convert_to_wav(tmp_wav_path, trimmed_path, max_duration=cut_point):
|
||||
raise RuntimeError("音频截取失败")
|
||||
os.unlink(tmp_wav_path)
|
||||
tmp_wav_path = trimmed_path
|
||||
duration = _get_audio_duration(tmp_wav_path)
|
||||
|
||||
# 自动转写参考音频内容
|
||||
try:
|
||||
from app.services.whisper_service import whisper_service
|
||||
transcribed = await whisper_service.transcribe(tmp_wav_path)
|
||||
if transcribed.strip():
|
||||
ref_text = transcribed.strip()
|
||||
logger.info(f"Auto-transcribed ref audio: {ref_text[:50]}...")
|
||||
except Exception as e:
|
||||
logger.warning(f"Auto-transcribe failed: {e}")
|
||||
|
||||
if not ref_text or not ref_text.strip():
|
||||
raise ValueError("无法识别音频内容,请确保音频包含清晰的语音")
|
||||
|
||||
# 检查重名
|
||||
existing_files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)
|
||||
@@ -267,3 +311,85 @@ async def rename_ref_audio(audio_id: str, new_name: str, user_id: str) -> dict:
|
||||
)
|
||||
|
||||
return {"name": new_name}
|
||||
|
||||
|
||||
async def retranscribe_ref_audio(audio_id: str, user_id: str) -> dict:
|
||||
"""重新转写参考音频的 ref_text,并截取前 10 秒重新上传(用于迁移旧数据)"""
|
||||
if not audio_id.startswith(f"{user_id}/"):
|
||||
raise PermissionError("无权修改此文件")
|
||||
|
||||
# 下载音频到临时文件
|
||||
audio_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, audio_id)
|
||||
tmp_wav_path = None
|
||||
trimmed_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
tmp_wav_path = tmp.name
|
||||
timeout = httpx.Timeout(None)
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream("GET", audio_url) as resp:
|
||||
resp.raise_for_status()
|
||||
async for chunk in resp.aiter_bytes():
|
||||
tmp.write(chunk)
|
||||
|
||||
# 超过 10 秒则截取前 10 秒并重新上传音频
|
||||
MAX_REF_DURATION = 10.0
|
||||
duration = _get_audio_duration(tmp_wav_path)
|
||||
transcribe_path = tmp_wav_path
|
||||
need_reupload = False
|
||||
|
||||
if duration > MAX_REF_DURATION:
|
||||
cut_point = _find_silence_cut_point(tmp_wav_path, MAX_REF_DURATION)
|
||||
logger.info(f"Retranscribe: trimming {audio_id} from {duration:.1f}s at {cut_point:.1f}s")
|
||||
trimmed_path = tmp_wav_path + "_trimmed.wav"
|
||||
if _convert_to_wav(tmp_wav_path, trimmed_path, max_duration=cut_point):
|
||||
transcribe_path = trimmed_path
|
||||
duration = _get_audio_duration(trimmed_path)
|
||||
need_reupload = True
|
||||
|
||||
# Whisper 转写
|
||||
from app.services.whisper_service import whisper_service
|
||||
transcribed = await whisper_service.transcribe(transcribe_path)
|
||||
if not transcribed or not transcribed.strip():
|
||||
raise ValueError("无法识别音频内容")
|
||||
|
||||
ref_text = transcribed.strip()
|
||||
logger.info(f"Re-transcribed ref audio {audio_id}: {ref_text[:50]}...")
|
||||
|
||||
# 截取过的音频重新上传覆盖原文件
|
||||
if need_reupload and trimmed_path:
|
||||
with open(trimmed_path, "rb") as f:
|
||||
await storage_service.upload_file(
|
||||
bucket=BUCKET_REF_AUDIOS, path=audio_id,
|
||||
file_data=f.read(), content_type="audio/wav",
|
||||
)
|
||||
logger.info(f"Re-uploaded trimmed audio: {audio_id} ({duration:.1f}s)")
|
||||
|
||||
# 更新 metadata
|
||||
metadata_path = audio_id.replace(".wav", ".json")
|
||||
try:
|
||||
meta_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
resp = await client.get(meta_url)
|
||||
if resp.status_code == 200:
|
||||
metadata = resp.json()
|
||||
else:
|
||||
raise Exception(f"status {resp.status_code}")
|
||||
except Exception:
|
||||
metadata = {}
|
||||
|
||||
metadata["ref_text"] = ref_text
|
||||
metadata["duration_sec"] = duration
|
||||
await storage_service.upload_file(
|
||||
bucket=BUCKET_REF_AUDIOS,
|
||||
path=metadata_path,
|
||||
file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
|
||||
content_type="application/json"
|
||||
)
|
||||
|
||||
return {"ref_text": ref_text, "duration_sec": duration}
|
||||
finally:
|
||||
if tmp_wav_path and os.path.exists(tmp_wav_path):
|
||||
os.unlink(tmp_wav_path)
|
||||
if trimmed_path and os.path.exists(trimmed_path):
|
||||
os.unlink(trimmed_path)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Literal
|
||||
|
||||
|
||||
class CustomAssignment(BaseModel):
|
||||
@@ -7,6 +7,7 @@ class CustomAssignment(BaseModel):
|
||||
start: float # 音频时间轴起点
|
||||
end: float # 音频时间轴终点
|
||||
source_start: float = 0.0 # 源视频截取起点
|
||||
source_end: Optional[float] = None # 源视频截取终点(可选)
|
||||
|
||||
|
||||
class GenerateRequest(BaseModel):
|
||||
@@ -30,3 +31,4 @@ class GenerateRequest(BaseModel):
|
||||
bgm_id: Optional[str] = None
|
||||
bgm_volume: Optional[float] = 0.2
|
||||
custom_assignments: Optional[List[CustomAssignment]] = None
|
||||
output_aspect_ratio: Literal["9:16", "16:9"] = "9:16"
|
||||
|
||||
@@ -29,7 +29,7 @@ def _locale_to_whisper_lang(locale: str) -> str:
|
||||
return locale.split("-")[0] if "-" in locale else locale
|
||||
|
||||
|
||||
def _locale_to_qwen_lang(locale: str) -> str:
|
||||
def _locale_to_tts_lang(locale: str) -> str:
|
||||
"""'zh-CN' → 'Chinese', 'en-US' → 'English', 其他 → 'Auto'"""
|
||||
mapping = {"zh": "Chinese", "en": "English"}
|
||||
return mapping.get(locale.split("-")[0], "Auto")
|
||||
@@ -174,17 +174,27 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
|
||||
# ── 确定素材列表 ──
|
||||
material_paths: List[str] = []
|
||||
if req.material_paths and len(req.material_paths) > 1:
|
||||
if req.custom_assignments and len(req.custom_assignments) > 1:
|
||||
material_paths = [a.material_path for a in req.custom_assignments if a.material_path]
|
||||
elif req.material_paths and len(req.material_paths) > 1:
|
||||
material_paths = req.material_paths
|
||||
else:
|
||||
material_paths = [req.material_path]
|
||||
|
||||
is_multi = len(material_paths) > 1
|
||||
target_resolution = (1080, 1920) if req.output_aspect_ratio == "9:16" else (1920, 1080)
|
||||
|
||||
logger.info(
|
||||
f"[Render] 输出画面比例: {req.output_aspect_ratio}, "
|
||||
f"目标分辨率: {target_resolution[0]}x{target_resolution[1]}"
|
||||
)
|
||||
|
||||
_update_task(task_id, status="processing", progress=5, message="正在下载素材...")
|
||||
|
||||
temp_dir = settings.UPLOAD_DIR / "temp"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
video = VideoService()
|
||||
input_material_path: Optional[Path] = None
|
||||
|
||||
# 单素材模式:下载主素材
|
||||
if not is_multi:
|
||||
@@ -192,6 +202,16 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
temp_files.append(input_material_path)
|
||||
await _download_material(material_paths[0], input_material_path)
|
||||
|
||||
# 归一化旋转元数据(如 iPhone MOV 1920x1080 + rotation=-90)
|
||||
normalized_input_path = temp_dir / f"{task_id}_input_norm.mp4"
|
||||
normalized_result = video.normalize_orientation(
|
||||
str(input_material_path),
|
||||
str(normalized_input_path),
|
||||
)
|
||||
if normalized_result != str(input_material_path):
|
||||
temp_files.append(normalized_input_path)
|
||||
input_material_path = normalized_input_path
|
||||
|
||||
_update_task(task_id, message="正在生成语音...", progress=10)
|
||||
|
||||
audio_path = temp_dir / f"{task_id}_audio.wav"
|
||||
@@ -218,8 +238,10 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
if resp.status_code == 200:
|
||||
meta = resp.json()
|
||||
req.language = meta.get("language", req.language)
|
||||
if not req.text.strip():
|
||||
req.text = meta.get("text", req.text)
|
||||
# 无条件用配音元数据覆盖文案,确保字幕与配音语言一致
|
||||
meta_text = meta.get("text", "")
|
||||
if meta_text:
|
||||
req.text = meta_text
|
||||
except Exception as e:
|
||||
logger.warning(f"读取配音元数据失败: {e}")
|
||||
|
||||
@@ -238,13 +260,13 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
)
|
||||
await _download_material(ref_audio_url, ref_audio_local)
|
||||
|
||||
_update_task(task_id, message="正在克隆声音 (Qwen3-TTS)...")
|
||||
_update_task(task_id, message="正在克隆声音...")
|
||||
await voice_clone_service.generate_audio(
|
||||
text=req.text,
|
||||
ref_audio_path=str(ref_audio_local),
|
||||
ref_text=req.ref_text,
|
||||
output_path=str(audio_path),
|
||||
language=_locale_to_qwen_lang(req.language)
|
||||
language=_locale_to_tts_lang(req.language)
|
||||
)
|
||||
else:
|
||||
_update_task(task_id, message="正在生成语音 (EdgeTTS)...")
|
||||
@@ -258,7 +280,6 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
lipsync_video_path = temp_dir / f"{task_id}_lipsync.mp4"
|
||||
temp_files.append(lipsync_video_path)
|
||||
|
||||
video = VideoService()
|
||||
captions_path = None
|
||||
|
||||
if is_multi:
|
||||
@@ -267,7 +288,7 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
# ══════════════════════════════════════
|
||||
_update_task(task_id, progress=12, message="正在分配素材...")
|
||||
|
||||
if req.custom_assignments:
|
||||
if req.custom_assignments and len(req.custom_assignments) == len(material_paths):
|
||||
# 用户自定义分配,跳过 Whisper 均分
|
||||
assignments = [
|
||||
{
|
||||
@@ -275,6 +296,7 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
"start": a.start,
|
||||
"end": a.end,
|
||||
"source_start": a.source_start,
|
||||
"source_end": a.source_end,
|
||||
"index": i,
|
||||
}
|
||||
for i, a in enumerate(req.custom_assignments)
|
||||
@@ -290,6 +312,7 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
text=req.text,
|
||||
output_path=str(captions_path),
|
||||
language=_locale_to_whisper_lang(req.language),
|
||||
original_text=req.text,
|
||||
)
|
||||
print(f"[Pipeline] Whisper alignment completed (custom assignments)")
|
||||
except Exception as e:
|
||||
@@ -297,6 +320,49 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
captions_path = None
|
||||
else:
|
||||
captions_path = None
|
||||
elif req.custom_assignments:
|
||||
logger.warning(
|
||||
f"[MultiMat] custom_assignments 数量({len(req.custom_assignments)})"
|
||||
f" 与素材数量({len(material_paths)})不一致,回退自动分配"
|
||||
)
|
||||
|
||||
# 原有逻辑:Whisper → _split_equal
|
||||
_update_task(task_id, message="正在生成字幕 (Whisper)...")
|
||||
|
||||
captions_path = temp_dir / f"{task_id}_captions.json"
|
||||
temp_files.append(captions_path)
|
||||
|
||||
try:
|
||||
captions_data = await whisper_service.align(
|
||||
audio_path=str(audio_path),
|
||||
text=req.text,
|
||||
output_path=str(captions_path),
|
||||
language=_locale_to_whisper_lang(req.language),
|
||||
original_text=req.text,
|
||||
)
|
||||
print(f"[Pipeline] Whisper alignment completed (multi-material)")
|
||||
except Exception as e:
|
||||
logger.warning(f"Whisper alignment failed: {e}")
|
||||
captions_data = None
|
||||
captions_path = None
|
||||
|
||||
_update_task(task_id, progress=15, message="正在分配素材...")
|
||||
|
||||
if captions_data and captions_data.get("segments"):
|
||||
assignments = _split_equal(captions_data["segments"], material_paths)
|
||||
else:
|
||||
# Whisper 失败 → 按时长均分(不依赖字符对齐)
|
||||
logger.warning("[MultiMat] Whisper 无数据,按时长均分")
|
||||
audio_dur = video._get_duration(str(audio_path))
|
||||
if audio_dur <= 0:
|
||||
audio_dur = 30.0 # 安全兜底
|
||||
seg_dur = audio_dur / len(material_paths)
|
||||
assignments = [
|
||||
{"material_path": material_paths[i], "start": i * seg_dur,
|
||||
"end": (i + 1) * seg_dur, "index": i}
|
||||
for i in range(len(material_paths))
|
||||
]
|
||||
|
||||
else:
|
||||
# 原有逻辑:Whisper → _split_equal
|
||||
_update_task(task_id, message="正在生成字幕 (Whisper)...")
|
||||
@@ -310,6 +376,7 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
text=req.text,
|
||||
output_path=str(captions_path),
|
||||
language=_locale_to_whisper_lang(req.language),
|
||||
original_text=req.text,
|
||||
)
|
||||
print(f"[Pipeline] Whisper alignment completed (multi-material)")
|
||||
except Exception as e:
|
||||
@@ -356,12 +423,23 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
material_local = temp_dir / f"{task_id}_material_{i}.mp4"
|
||||
temp_files.append(material_local)
|
||||
await _download_material(assignment["material_path"], material_local)
|
||||
|
||||
# 归一化旋转元数据,确保分辨率判断与后续推理一致
|
||||
normalized_material = temp_dir / f"{task_id}_material_{i}_norm.mp4"
|
||||
normalized_result = video.normalize_orientation(
|
||||
str(material_local),
|
||||
str(normalized_material),
|
||||
)
|
||||
if normalized_result != str(material_local):
|
||||
temp_files.append(normalized_material)
|
||||
material_local = normalized_material
|
||||
|
||||
material_locals.append(material_local)
|
||||
resolutions.append(video.get_resolution(str(material_local)))
|
||||
|
||||
# 分辨率不一致时,统一到第一个素材的分辨率
|
||||
base_res = resolutions[0] if resolutions else (0, 0)
|
||||
need_scale = any(r != base_res for r in resolutions) and base_res[0] > 0
|
||||
# 按用户选择的画面比例统一分辨率
|
||||
base_res = target_resolution
|
||||
need_scale = any(r != base_res for r in resolutions)
|
||||
if need_scale:
|
||||
logger.info(f"[MultiMat] 素材分辨率不一致,统一到 {base_res[0]}x{base_res[1]}")
|
||||
|
||||
@@ -381,8 +459,11 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
temp_files.append(prepared_path)
|
||||
video.prepare_segment(
|
||||
str(material_locals[i]), seg_dur, str(prepared_path),
|
||||
target_resolution=base_res if need_scale else None,
|
||||
# 多素材拼接前统一重编码为同分辨率/同编码,避免 concat 仅保留首段
|
||||
target_resolution=base_res,
|
||||
source_start=assignment.get("source_start", 0.0),
|
||||
source_end=assignment.get("source_end"),
|
||||
target_fps=25,
|
||||
)
|
||||
prepared_segments.append(prepared_path)
|
||||
|
||||
@@ -392,7 +473,8 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
temp_files.append(concat_path)
|
||||
video.concat_videos(
|
||||
[str(p) for p in prepared_segments],
|
||||
str(concat_path)
|
||||
str(concat_path),
|
||||
target_fps=25,
|
||||
)
|
||||
|
||||
# ── 第三步:一次 LatentSync 推理 ──
|
||||
@@ -425,23 +507,31 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
# 单素材流水线(原有逻辑)
|
||||
# ══════════════════════════════════════
|
||||
|
||||
# 单素材 + source_start:先截取片段
|
||||
if input_material_path is None:
|
||||
raise RuntimeError("单素材流程缺少输入素材")
|
||||
|
||||
# 单素材:按用户选择画面比例统一到目标分辨率,并应用 source_start
|
||||
single_source_start = 0.0
|
||||
single_source_end = None
|
||||
if req.custom_assignments and len(req.custom_assignments) == 1:
|
||||
single_source_start = req.custom_assignments[0].source_start
|
||||
single_source_end = req.custom_assignments[0].source_end
|
||||
|
||||
if single_source_start > 0:
|
||||
_update_task(task_id, progress=20, message="正在截取素材片段...")
|
||||
audio_dur = video._get_duration(str(audio_path))
|
||||
if audio_dur <= 0:
|
||||
audio_dur = 30.0
|
||||
trimmed_path = temp_dir / f"{task_id}_trimmed.mp4"
|
||||
temp_files.append(trimmed_path)
|
||||
video.prepare_segment(
|
||||
str(input_material_path), audio_dur, str(trimmed_path),
|
||||
source_start=single_source_start,
|
||||
)
|
||||
input_material_path = trimmed_path
|
||||
_update_task(task_id, progress=20, message="正在准备素材片段...")
|
||||
audio_dur = video._get_duration(str(audio_path))
|
||||
if audio_dur <= 0:
|
||||
audio_dur = 30.0
|
||||
prepared_single_path = temp_dir / f"{task_id}_prepared_single.mp4"
|
||||
temp_files.append(prepared_single_path)
|
||||
video.prepare_segment(
|
||||
str(input_material_path),
|
||||
audio_dur,
|
||||
str(prepared_single_path),
|
||||
target_resolution=target_resolution,
|
||||
source_start=single_source_start,
|
||||
source_end=single_source_end,
|
||||
)
|
||||
input_material_path = prepared_single_path
|
||||
|
||||
_update_task(task_id, progress=25)
|
||||
_update_task(task_id, message="正在合成唇形 (LatentSync)...", progress=30)
|
||||
@@ -476,6 +566,7 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
text=req.text,
|
||||
output_path=str(captions_path),
|
||||
language=_locale_to_whisper_lang(req.language),
|
||||
original_text=req.text,
|
||||
)
|
||||
print(f"[Pipeline] Whisper alignment completed")
|
||||
except Exception as e:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
||||
from app.core.supabase import get_supabase
|
||||
@@ -37,3 +38,33 @@ def update_user(user_id: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
|
||||
supabase = get_supabase()
|
||||
result = supabase.table("users").update(payload).eq("id", user_id).execute()
|
||||
return cast(List[Dict[str, Any]], result.data or [])
|
||||
|
||||
|
||||
def _parse_expires_at(expires_at: Any) -> Optional[datetime]:
|
||||
try:
|
||||
expires_at_dt = datetime.fromisoformat(str(expires_at).replace("Z", "+00:00"))
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
if expires_at_dt.tzinfo is None:
|
||||
expires_at_dt = expires_at_dt.replace(tzinfo=timezone.utc)
|
||||
return expires_at_dt.astimezone(timezone.utc)
|
||||
|
||||
|
||||
def deactivate_user_if_expired(user: Dict[str, Any]) -> bool:
|
||||
expires_at = user.get("expires_at")
|
||||
if not expires_at:
|
||||
return False
|
||||
|
||||
expires_at_dt = _parse_expires_at(expires_at)
|
||||
if not expires_at_dt:
|
||||
return False
|
||||
|
||||
if datetime.now(timezone.utc) <= expires_at_dt:
|
||||
return False
|
||||
|
||||
user_id = user.get("id")
|
||||
if user.get("is_active") and user_id:
|
||||
update_user(cast(str, user_id), {"is_active": False})
|
||||
|
||||
return True
|
||||
|
||||
@@ -13,6 +13,107 @@ class VideoService:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def get_video_metadata(self, file_path: str) -> dict:
|
||||
"""获取视频元信息(含旋转角与有效显示分辨率)"""
|
||||
cmd = [
|
||||
"ffprobe", "-v", "error",
|
||||
"-select_streams", "v:0",
|
||||
"-show_entries", "stream=width,height:stream_side_data=rotation",
|
||||
"-of", "json",
|
||||
file_path,
|
||||
]
|
||||
default_info = {
|
||||
"width": 0,
|
||||
"height": 0,
|
||||
"rotation": 0,
|
||||
"effective_width": 0,
|
||||
"effective_height": 0,
|
||||
}
|
||||
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
if result.returncode != 0:
|
||||
return default_info
|
||||
|
||||
payload = json.loads(result.stdout or "{}")
|
||||
streams = payload.get("streams") or []
|
||||
if not streams:
|
||||
return default_info
|
||||
|
||||
stream = streams[0]
|
||||
width = int(stream.get("width") or 0)
|
||||
height = int(stream.get("height") or 0)
|
||||
|
||||
rotation = 0
|
||||
for side_data in stream.get("side_data_list") or []:
|
||||
if not isinstance(side_data, dict):
|
||||
continue
|
||||
raw_rotation = side_data.get("rotation")
|
||||
if raw_rotation is None:
|
||||
continue
|
||||
try:
|
||||
rotation = int(round(float(str(raw_rotation))))
|
||||
except Exception:
|
||||
rotation = 0
|
||||
break
|
||||
|
||||
norm_rotation = rotation % 360
|
||||
if norm_rotation > 180:
|
||||
norm_rotation -= 360
|
||||
swap_wh = abs(norm_rotation) == 90
|
||||
|
||||
effective_width = height if swap_wh else width
|
||||
effective_height = width if swap_wh else height
|
||||
|
||||
return {
|
||||
"width": width,
|
||||
"height": height,
|
||||
"rotation": norm_rotation,
|
||||
"effective_width": effective_width,
|
||||
"effective_height": effective_height,
|
||||
}
|
||||
except Exception as e:
|
||||
logger.warning(f"获取视频元信息失败: {e}")
|
||||
return default_info
|
||||
|
||||
def normalize_orientation(self, video_path: str, output_path: str) -> str:
|
||||
"""将带旋转元数据的视频转为物理方向,避免后续流程忽略 rotation。"""
|
||||
info = self.get_video_metadata(video_path)
|
||||
rotation = int(info.get("rotation") or 0)
|
||||
if rotation == 0:
|
||||
return video_path
|
||||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
logger.info(
|
||||
f"检测到旋转元数据 rotation={rotation},归一化方向: "
|
||||
f"{info.get('effective_width', 0)}x{info.get('effective_height', 0)}"
|
||||
)
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", video_path,
|
||||
"-map", "0:v:0",
|
||||
"-map", "0:a?",
|
||||
"-c:v", "libx264",
|
||||
"-preset", "fast",
|
||||
"-crf", "18",
|
||||
"-c:a", "copy",
|
||||
"-movflags", "+faststart",
|
||||
output_path,
|
||||
]
|
||||
|
||||
if self._run_ffmpeg(cmd):
|
||||
normalized = self.get_video_metadata(output_path)
|
||||
logger.info(
|
||||
"视频方向归一化完成: "
|
||||
f"coded={normalized.get('width', 0)}x{normalized.get('height', 0)}, "
|
||||
f"rotation={normalized.get('rotation', 0)}"
|
||||
)
|
||||
return output_path
|
||||
|
||||
logger.warning("视频方向归一化失败,回退使用原视频")
|
||||
return video_path
|
||||
|
||||
def _run_ffmpeg(self, cmd: list) -> bool:
|
||||
cmd_str = ' '.join(shlex.quote(str(c)) for c in cmd)
|
||||
logger.debug(f"FFmpeg CMD: {cmd_str}")
|
||||
@@ -139,7 +240,7 @@ class VideoService:
|
||||
else:
|
||||
raise RuntimeError("FFmpeg composition failed")
|
||||
|
||||
def concat_videos(self, video_paths: list, output_path: str) -> str:
|
||||
def concat_videos(self, video_paths: list, output_path: str, target_fps: int = 25) -> str:
|
||||
"""使用 FFmpeg concat demuxer 拼接多个视频片段"""
|
||||
if not video_paths:
|
||||
raise ValueError("No video segments to concat")
|
||||
@@ -156,8 +257,16 @@ class VideoService:
|
||||
"ffmpeg", "-y",
|
||||
"-f", "concat",
|
||||
"-safe", "0",
|
||||
"-fflags", "+genpts",
|
||||
"-i", str(list_path),
|
||||
"-c", "copy",
|
||||
"-an",
|
||||
"-vsync", "cfr",
|
||||
"-r", str(target_fps),
|
||||
"-c:v", "libx264",
|
||||
"-preset", "fast",
|
||||
"-crf", "18",
|
||||
"-pix_fmt", "yuv420p",
|
||||
"-movflags", "+faststart",
|
||||
output_path,
|
||||
]
|
||||
|
||||
@@ -193,27 +302,22 @@ class VideoService:
|
||||
return output_path
|
||||
raise RuntimeError(f"FFmpeg audio split failed: {start}-{end}")
|
||||
|
||||
def get_resolution(self, file_path: str) -> tuple:
|
||||
"""获取视频分辨率,返回 (width, height)"""
|
||||
cmd = [
|
||||
'ffprobe', '-v', 'error',
|
||||
'-select_streams', 'v:0',
|
||||
'-show_entries', 'stream=width,height',
|
||||
'-of', 'csv=p=0',
|
||||
file_path
|
||||
]
|
||||
try:
|
||||
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
|
||||
parts = result.stdout.strip().split(',')
|
||||
return (int(parts[0]), int(parts[1]))
|
||||
except Exception:
|
||||
return (0, 0)
|
||||
def get_resolution(self, file_path: str) -> tuple[int, int]:
|
||||
"""获取视频有效显示分辨率(考虑旋转元数据)。"""
|
||||
info = self.get_video_metadata(file_path)
|
||||
return (
|
||||
int(info.get("effective_width") or 0),
|
||||
int(info.get("effective_height") or 0),
|
||||
)
|
||||
|
||||
def prepare_segment(self, video_path: str, target_duration: float, output_path: str,
|
||||
target_resolution: tuple = None, source_start: float = 0.0) -> str:
|
||||
target_resolution: Optional[tuple] = None, source_start: float = 0.0,
|
||||
source_end: Optional[float] = None, target_fps: Optional[int] = None) -> str:
|
||||
"""将素材视频裁剪或循环到指定时长(无音频)。
|
||||
target_resolution: (width, height) 如需统一分辨率则传入,否则保持原分辨率。
|
||||
source_start: 源视频截取起点(秒),默认 0。
|
||||
source_end: 源视频截取终点(秒),默认到素材结尾。
|
||||
target_fps: 输出帧率(可选),用于多素材拼接前统一时间基。
|
||||
"""
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
@@ -221,16 +325,27 @@ class VideoService:
|
||||
if video_dur <= 0:
|
||||
video_dur = target_duration
|
||||
|
||||
clip_end = video_dur
|
||||
if source_end is not None:
|
||||
try:
|
||||
source_end_value = float(source_end)
|
||||
if source_end_value > source_start:
|
||||
clip_end = min(source_end_value, video_dur)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 可用时长 = 从 source_start 到视频结尾
|
||||
available = max(video_dur - source_start, 0.1)
|
||||
available = max(clip_end - source_start, 0.1)
|
||||
needs_loop = target_duration > available
|
||||
needs_scale = target_resolution is not None
|
||||
needs_fps = bool(target_fps and target_fps > 0)
|
||||
has_source_end = clip_end < video_dur
|
||||
|
||||
# 当需要循环且有 source_start 时,先裁剪出片段,再循环裁剪后的文件
|
||||
# 避免 stream_loop 循环整个视频(而不是从 source_start 开始的片段)
|
||||
# 当需要循环且存在截取范围时,先裁剪出片段,再循环裁剪后的文件
|
||||
# 避免 stream_loop 循环整个视频(而不是截取后的片段)
|
||||
actual_input = video_path
|
||||
trim_temp = None
|
||||
if needs_loop and source_start > 0:
|
||||
if needs_loop and (source_start > 0 or has_source_end):
|
||||
trim_temp = str(Path(output_path).parent / (Path(output_path).stem + "_trim_tmp.mp4"))
|
||||
trim_cmd = [
|
||||
"ffmpeg", "-y",
|
||||
@@ -257,12 +372,20 @@ class VideoService:
|
||||
cmd.extend(["-ss", str(source_start)])
|
||||
cmd.extend(["-i", actual_input, "-t", str(target_duration), "-an"])
|
||||
|
||||
filters = []
|
||||
if needs_fps:
|
||||
filters.append(f"fps={int(target_fps)}")
|
||||
if needs_scale:
|
||||
w, h = target_resolution
|
||||
cmd.extend(["-vf", f"scale={w}:{h}:force_original_aspect_ratio=decrease,pad={w}:{h}:(ow-iw)/2:(oh-ih)/2"])
|
||||
filters.append(f"scale={w}:{h}:force_original_aspect_ratio=decrease,pad={w}:{h}:(ow-iw)/2:(oh-ih)/2")
|
||||
|
||||
if filters:
|
||||
cmd.extend(["-vf", ",".join(filters)])
|
||||
if needs_fps:
|
||||
cmd.extend(["-vsync", "cfr", "-r", str(int(target_fps))])
|
||||
|
||||
# 需要循环、缩放或指定起点时必须重编码,否则用 stream copy 保持原画质
|
||||
if needs_loop or needs_scale or source_start > 0:
|
||||
if needs_loop or needs_scale or source_start > 0 or has_source_end or needs_fps:
|
||||
cmd.extend(["-c:v", "libx264", "-preset", "fast", "-crf", "18"])
|
||||
else:
|
||||
cmd.extend(["-c:v", "copy"])
|
||||
|
||||
@@ -1,37 +1,104 @@
|
||||
"""
|
||||
声音克隆服务
|
||||
通过 HTTP 调用 Qwen3-TTS 独立服务 (端口 8009)
|
||||
通过 HTTP 调用 CosyVoice 3.0 独立服务 (端口 8010)
|
||||
"""
|
||||
import httpx
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from loguru import logger
|
||||
|
||||
from app.core.config import settings
|
||||
|
||||
# Qwen3-TTS 服务地址
|
||||
QWEN_TTS_URL = "http://localhost:8009"
|
||||
# CosyVoice 3.0 服务地址
|
||||
VOICE_CLONE_URL = "http://localhost:8010"
|
||||
|
||||
|
||||
class VoiceCloneService:
|
||||
"""声音克隆服务 - 调用 Qwen3-TTS HTTP API"""
|
||||
"""声音克隆服务 - 调用 CosyVoice 3.0 HTTP API"""
|
||||
|
||||
def __init__(self):
|
||||
self.base_url = QWEN_TTS_URL
|
||||
self.base_url = VOICE_CLONE_URL
|
||||
# 健康状态缓存
|
||||
self._health_cache: Optional[dict] = None
|
||||
self._health_cache_time: float = 0
|
||||
# GPU 并发锁 (Serial Queue)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
async def _generate_once(
|
||||
self,
|
||||
*,
|
||||
text: str,
|
||||
ref_audio_data: bytes,
|
||||
ref_text: str,
|
||||
language: str,
|
||||
speed: float = 1.0,
|
||||
max_retries: int = 4,
|
||||
) -> bytes:
|
||||
timeout = httpx.Timeout(240.0)
|
||||
|
||||
for attempt in range(max_retries):
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/generate",
|
||||
files={"ref_audio": ("ref.wav", ref_audio_data, "audio/wav")},
|
||||
data={
|
||||
"text": text,
|
||||
"ref_text": ref_text,
|
||||
"language": language,
|
||||
"speed": str(speed),
|
||||
},
|
||||
)
|
||||
|
||||
retryable = False
|
||||
reason = ""
|
||||
|
||||
if response.status_code in (429, 502, 503, 504):
|
||||
retryable = True
|
||||
reason = f"HTTP {response.status_code}"
|
||||
elif response.status_code == 500 and (
|
||||
"生成超时" in response.text or "timeout" in response.text.lower()
|
||||
):
|
||||
retryable = True
|
||||
reason = "upstream timeout"
|
||||
|
||||
if retryable and attempt < max_retries - 1:
|
||||
wait = 8 * (attempt + 1)
|
||||
logger.warning(
|
||||
f"Voice clone retryable error ({reason}), retrying in {wait}s "
|
||||
f"(attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
|
||||
response.raise_for_status()
|
||||
return response.content
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Voice clone API error: {e.response.status_code} - {e.response.text}")
|
||||
raise RuntimeError(f"声音克隆服务错误: {e.response.text}")
|
||||
except httpx.RequestError as e:
|
||||
if attempt < max_retries - 1:
|
||||
wait = 6 * (attempt + 1)
|
||||
logger.warning(
|
||||
f"Voice clone connection error: {e}; retrying in {wait}s "
|
||||
f"(attempt {attempt + 1}/{max_retries})"
|
||||
)
|
||||
await asyncio.sleep(wait)
|
||||
continue
|
||||
logger.error(f"Voice clone connection error: {e}")
|
||||
raise RuntimeError("无法连接声音克隆服务,请检查服务是否启动")
|
||||
|
||||
raise RuntimeError("声音克隆服务繁忙,请稍后重试")
|
||||
|
||||
async def generate_audio(
|
||||
self,
|
||||
text: str,
|
||||
ref_audio_path: str,
|
||||
ref_text: str,
|
||||
output_path: str,
|
||||
language: str = "Chinese"
|
||||
language: str = "Chinese",
|
||||
speed: float = 1.0,
|
||||
) -> str:
|
||||
"""
|
||||
使用声音克隆生成语音
|
||||
@@ -51,60 +118,49 @@ class VoiceCloneService:
|
||||
logger.info(f"🎤 Voice Clone: {text[:30]}... (language={language})")
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 读取参考音频
|
||||
text = text.strip()
|
||||
if not text:
|
||||
raise RuntimeError("文本为空,无法生成语音")
|
||||
|
||||
with open(ref_audio_path, "rb") as f:
|
||||
ref_audio_data = f.read()
|
||||
|
||||
# 调用 Qwen3-TTS 服务
|
||||
timeout = httpx.Timeout(300.0) # 5分钟超时
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
try:
|
||||
response = await client.post(
|
||||
f"{self.base_url}/generate",
|
||||
files={"ref_audio": ("ref.wav", ref_audio_data, "audio/wav")},
|
||||
data={
|
||||
"text": text,
|
||||
"ref_text": ref_text,
|
||||
"language": language
|
||||
}
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# 保存返回的音频
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(response.content)
|
||||
|
||||
logger.info(f"✅ Voice clone saved: {output_path}")
|
||||
return output_path
|
||||
|
||||
except httpx.HTTPStatusError as e:
|
||||
logger.error(f"Qwen3-TTS API error: {e.response.status_code} - {e.response.text}")
|
||||
raise RuntimeError(f"声音克隆服务错误: {e.response.text}")
|
||||
except httpx.RequestError as e:
|
||||
logger.error(f"Qwen3-TTS connection error: {e}")
|
||||
raise RuntimeError("无法连接声音克隆服务,请检查服务是否启动")
|
||||
# CosyVoice 内部自带 text_normalize 分段,无需客户端切分
|
||||
audio_bytes = await self._generate_once(
|
||||
text=text,
|
||||
ref_audio_data=ref_audio_data,
|
||||
ref_text=ref_text,
|
||||
language=language,
|
||||
speed=speed,
|
||||
)
|
||||
with open(output_path, "wb") as f:
|
||||
f.write(audio_bytes)
|
||||
logger.info(f"✅ Voice clone saved: {output_path}")
|
||||
return output_path
|
||||
|
||||
async def check_health(self) -> dict:
|
||||
"""健康检查"""
|
||||
import time
|
||||
|
||||
# 5分钟缓存
|
||||
# 30秒缓存
|
||||
now = time.time()
|
||||
if self._health_cache and (now - self._health_cache_time) < 300:
|
||||
return self._health_cache
|
||||
cached = self._health_cache
|
||||
if cached is not None and (now - self._health_cache_time) < 30:
|
||||
return cached
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
response = await client.get(f"{self.base_url}/health")
|
||||
response.raise_for_status()
|
||||
self._health_cache = response.json()
|
||||
payload = response.json()
|
||||
self._health_cache = payload
|
||||
self._health_cache_time = now
|
||||
return self._health_cache
|
||||
return payload
|
||||
except Exception as e:
|
||||
logger.warning(f"Qwen3-TTS health check failed: {e}")
|
||||
logger.warning(f"Voice clone health check failed: {e}")
|
||||
return {
|
||||
"service": "Qwen3-TTS Voice Clone",
|
||||
"model": "0.6B-Base",
|
||||
"service": "CosyVoice 3.0 Voice Clone",
|
||||
"model": "unknown",
|
||||
"ready": False,
|
||||
"gpu_id": 0,
|
||||
"error": str(e)
|
||||
|
||||
@@ -39,12 +39,22 @@ def split_word_to_chars(word: str, start: float, end: float) -> list:
|
||||
|
||||
tokens = []
|
||||
ascii_buffer = ""
|
||||
pending_space = False # 记录是否有待处理的空格(用于英文单词间距)
|
||||
|
||||
for char in word:
|
||||
if not char.strip():
|
||||
# 空格:flush ascii_buffer,标记下一个 token 需要前导空格
|
||||
if ascii_buffer:
|
||||
tokens.append(ascii_buffer)
|
||||
ascii_buffer = ""
|
||||
if tokens: # 仅在已有 token 时标记(避免开头重复空格)
|
||||
pending_space = True
|
||||
continue
|
||||
|
||||
if char.isascii() and char.isalnum():
|
||||
if pending_space and not ascii_buffer:
|
||||
ascii_buffer = " " # 将空格前置到新英文单词
|
||||
pending_space = False
|
||||
ascii_buffer += char
|
||||
continue
|
||||
|
||||
@@ -52,7 +62,9 @@ def split_word_to_chars(word: str, start: float, end: float) -> list:
|
||||
tokens.append(ascii_buffer)
|
||||
ascii_buffer = ""
|
||||
|
||||
tokens.append(char)
|
||||
prefix = " " if pending_space else ""
|
||||
pending_space = False
|
||||
tokens.append(prefix + char)
|
||||
|
||||
if ascii_buffer:
|
||||
tokens.append(ascii_buffer)
|
||||
@@ -175,6 +187,7 @@ class WhisperService:
|
||||
text: str,
|
||||
output_path: Optional[str] = None,
|
||||
language: str = "zh",
|
||||
original_text: Optional[str] = None,
|
||||
) -> dict:
|
||||
"""
|
||||
对音频进行转录,生成字级别时间戳
|
||||
@@ -184,6 +197,8 @@ class WhisperService:
|
||||
text: 原始文本(用于参考,但实际使用 whisper 转录结果)
|
||||
output_path: 可选,输出 JSON 文件路径
|
||||
language: 语言代码 (zh/en 等)
|
||||
original_text: 原始文案。非空时,Whisper 仅用于检测总时间范围,
|
||||
字幕文字用此原文替换(解决语言不匹配问题)
|
||||
|
||||
Returns:
|
||||
包含字级别时间戳的字典
|
||||
@@ -208,16 +223,19 @@ class WhisperService:
|
||||
|
||||
logger.info(f"Detected language: {info.language} (prob: {info.language_probability:.2f})")
|
||||
|
||||
# 收集 Whisper 转录结果(始终需要,用于获取时间范围)
|
||||
all_segments = []
|
||||
whisper_first_start = None
|
||||
whisper_last_end = None
|
||||
for segment in segments_iter:
|
||||
# 提取每个字的时间戳,并拆分成单字
|
||||
all_words = []
|
||||
if segment.words:
|
||||
for word_info in segment.words:
|
||||
word_text = word_info.word
|
||||
if word_text.strip():
|
||||
# 将词拆分成单字,时间戳线性插值
|
||||
# 保留前导空格用于英文词间距
|
||||
if whisper_first_start is None:
|
||||
whisper_first_start = word_info.start
|
||||
whisper_last_end = word_info.end
|
||||
chars = split_word_to_chars(
|
||||
word_text,
|
||||
word_info.start,
|
||||
@@ -225,11 +243,24 @@ class WhisperService:
|
||||
)
|
||||
all_words.extend(chars)
|
||||
|
||||
# 将长段落按标点和字数拆分成多行
|
||||
if all_words:
|
||||
line_segments = split_segment_to_lines(all_words, max_chars)
|
||||
all_segments.extend(line_segments)
|
||||
|
||||
# 如果提供了 original_text,用原文替换 Whisper 转录文字
|
||||
if original_text and original_text.strip() and whisper_first_start is not None:
|
||||
logger.info(f"Using original_text for subtitles (len={len(original_text)}), "
|
||||
f"Whisper time range: {whisper_first_start:.2f}-{whisper_last_end:.2f}s")
|
||||
# 用 split_word_to_chars 拆分原文
|
||||
orig_chars = split_word_to_chars(
|
||||
original_text.strip(),
|
||||
whisper_first_start,
|
||||
whisper_last_end
|
||||
)
|
||||
if orig_chars:
|
||||
all_segments = split_segment_to_lines(orig_chars, max_chars)
|
||||
logger.info(f"Rebuilt {len(all_segments)} subtitle segments from original text")
|
||||
|
||||
logger.info(f"Generated {len(all_segments)} subtitle segments")
|
||||
return {"segments": all_segments}
|
||||
|
||||
@@ -247,12 +278,13 @@ class WhisperService:
|
||||
|
||||
return result
|
||||
|
||||
async def transcribe(self, audio_path: str) -> str:
|
||||
async def transcribe(self, audio_path: str, language: str | None = None) -> str:
|
||||
"""
|
||||
仅转录文本(用于提取文案)
|
||||
|
||||
Args:
|
||||
audio_path: 音频/视频文件路径
|
||||
language: 语言代码,None 表示自动检测
|
||||
|
||||
Returns:
|
||||
纯文本内容
|
||||
@@ -266,7 +298,7 @@ class WhisperService:
|
||||
# 转录 (无需字级时间戳)
|
||||
segments_iter, _ = model.transcribe(
|
||||
audio_path,
|
||||
language="zh",
|
||||
language=language,
|
||||
word_timestamps=False,
|
||||
vad_filter=True,
|
||||
)
|
||||
|
||||
@@ -20,14 +20,14 @@ logger = logging.getLogger("Watchdog")
|
||||
# 服务配置
|
||||
SERVICES = [
|
||||
{
|
||||
"name": "vigent2-qwen-tts",
|
||||
"url": "http://localhost:8009/health",
|
||||
"name": "vigent2-cosyvoice",
|
||||
"url": "http://localhost:8010/health",
|
||||
"failures": 0,
|
||||
"threshold": 5, # 连续5次失败才重启(5×30s = 2.5分钟容忍期)
|
||||
"threshold": 3, # 连续3次失败才重启(3×15s ≈ 45秒容忍期)
|
||||
"timeout": 10.0,
|
||||
"restart_cmd": ["pm2", "restart", "vigent2-qwen-tts"],
|
||||
"restart_cmd": ["pm2", "restart", "vigent2-cosyvoice"],
|
||||
"cooldown_until": 0, # 重启后的冷却截止时间戳
|
||||
"cooldown_sec": 120, # 重启后等待120秒再开始检查
|
||||
"cooldown_sec": 45, # 重启后等待45秒再开始检查
|
||||
}
|
||||
]
|
||||
|
||||
@@ -45,10 +45,20 @@ async def check_service(service):
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
response = await client.get(service["url"])
|
||||
if response.status_code == 200:
|
||||
if service["failures"] > 0:
|
||||
logger.info(f"✅ 服务 {service['name']} 已恢复正常")
|
||||
service["failures"] = 0
|
||||
return True
|
||||
ready = True
|
||||
try:
|
||||
payload = response.json()
|
||||
ready = bool(payload.get("ready", True))
|
||||
except Exception:
|
||||
payload = {}
|
||||
|
||||
if ready:
|
||||
if service["failures"] > 0:
|
||||
logger.info(f"✅ 服务 {service['name']} 已恢复正常")
|
||||
service["failures"] = 0
|
||||
return True
|
||||
|
||||
logger.warning(f"⚠️ 服务 {service['name']} ready=false,健康检查未通过: {payload}")
|
||||
else:
|
||||
logger.warning(f"⚠️ 服务 {service['name']} 返回状态码 {response.status_code}")
|
||||
except Exception as e:
|
||||
@@ -83,8 +93,8 @@ async def main():
|
||||
for service in SERVICES:
|
||||
await check_service(service)
|
||||
|
||||
# 每 30 秒检查一次
|
||||
await asyncio.sleep(30)
|
||||
# 每 15 秒检查一次
|
||||
await asyncio.sleep(15)
|
||||
|
||||
if __name__ == "__main__":
|
||||
try:
|
||||
|
||||
@@ -126,6 +126,7 @@ export const useGeneratedAudios = ({
|
||||
ref_audio_id?: string;
|
||||
ref_text?: string;
|
||||
language: string;
|
||||
speed?: number;
|
||||
}) => {
|
||||
setIsGeneratingAudio(true);
|
||||
setAudioTask({ status: "pending", progress: 0, message: "正在提交..." });
|
||||
|
||||
@@ -89,9 +89,6 @@ const LANG_TO_LOCALE: Record<string, string> = {
|
||||
|
||||
|
||||
|
||||
const FIXED_REF_TEXT =
|
||||
"其实生活中有许多美好的瞬间,比如清晨的阳光,或者一杯温热的清茶。希望这次生成的音色能够自然、流畅,完美还原出我最真实的声音状态。";
|
||||
|
||||
const scrollContainerToItem = (container: HTMLDivElement, item: HTMLDivElement) => {
|
||||
const containerRect = container.getBoundingClientRect();
|
||||
const itemRect = item.getBoundingClientRect();
|
||||
@@ -153,6 +150,7 @@ export const useHomeController = () => {
|
||||
const [titleSizeLocked, setTitleSizeLocked] = useState<boolean>(false);
|
||||
const [titleTopMargin, setTitleTopMargin] = useState<number>(62);
|
||||
const [subtitleBottomMargin, setSubtitleBottomMargin] = useState<number>(80);
|
||||
const [outputAspectRatio, setOutputAspectRatio] = useState<"9:16" | "16:9">("9:16");
|
||||
const [showStylePreview, setShowStylePreview] = useState<boolean>(false);
|
||||
const [materialDimensions, setMaterialDimensions] = useState<{ width: number; height: number } | null>(null);
|
||||
|
||||
@@ -165,11 +163,14 @@ export const useHomeController = () => {
|
||||
// 声音克隆相关状态
|
||||
const [ttsMode, setTtsMode] = useState<"edgetts" | "voiceclone">("edgetts");
|
||||
const [selectedRefAudio, setSelectedRefAudio] = useState<RefAudio | null>(null);
|
||||
const [refText, setRefText] = useState(FIXED_REF_TEXT);
|
||||
const [refText, setRefText] = useState("");
|
||||
|
||||
// 预生成配音选中 ID
|
||||
const [selectedAudioId, setSelectedAudioId] = useState<string | null>(null);
|
||||
|
||||
// 语速控制
|
||||
const [speed, setSpeed] = useState<number>(1.0);
|
||||
|
||||
// ClipTrimmer 模态框状态
|
||||
const [clipTrimmerOpen, setClipTrimmerOpen] = useState(false);
|
||||
const [clipTrimmerSegmentId, setClipTrimmerSegmentId] = useState<string | null>(null);
|
||||
@@ -286,7 +287,6 @@ export const useHomeController = () => {
|
||||
setUploadError,
|
||||
fetchMaterials,
|
||||
toggleMaterial,
|
||||
reorderMaterials,
|
||||
deleteMaterial,
|
||||
handleUpload,
|
||||
} = useMaterials({
|
||||
@@ -314,8 +314,9 @@ export const useHomeController = () => {
|
||||
fetchRefAudios,
|
||||
uploadRefAudio,
|
||||
deleteRefAudio,
|
||||
retranscribeRefAudio,
|
||||
retranscribingId,
|
||||
} = useRefAudios({
|
||||
fixedRefText: FIXED_REF_TEXT,
|
||||
selectedRefAudio,
|
||||
setSelectedRefAudio,
|
||||
setRefText,
|
||||
@@ -448,6 +449,8 @@ export const useHomeController = () => {
|
||||
setTitleTopMargin,
|
||||
subtitleBottomMargin,
|
||||
setSubtitleBottomMargin,
|
||||
outputAspectRatio,
|
||||
setOutputAspectRatio,
|
||||
selectedBgmId,
|
||||
setSelectedBgmId,
|
||||
bgmVolume,
|
||||
@@ -459,6 +462,8 @@ export const useHomeController = () => {
|
||||
selectedRefAudio,
|
||||
selectedAudioId,
|
||||
setSelectedAudioId,
|
||||
speed,
|
||||
setSpeed,
|
||||
});
|
||||
|
||||
const { savedScripts, saveScript, deleteScript: deleteSavedScript } = useSavedScripts(storageKey);
|
||||
@@ -523,7 +528,6 @@ export const useHomeController = () => {
|
||||
|
||||
let isActive = true;
|
||||
const video = document.createElement("video");
|
||||
video.crossOrigin = "anonymous";
|
||||
video.preload = "metadata";
|
||||
video.src = url;
|
||||
video.load();
|
||||
@@ -610,7 +614,7 @@ export const useHomeController = () => {
|
||||
setSelectedVideoId(firstId);
|
||||
setGeneratedVideo(resolveMediaUrl(generatedVideos[0].path));
|
||||
}
|
||||
}, [isRestored, generatedVideos, selectedVideoId, setSelectedVideoId, setGeneratedVideo, resolveMediaUrl]);
|
||||
}, [isRestored, generatedVideos, selectedVideoId, setSelectedVideoId, setGeneratedVideo]);
|
||||
|
||||
// 【修复】BGM 默认选中逻辑
|
||||
useEffect(() => {
|
||||
@@ -619,8 +623,14 @@ export const useHomeController = () => {
|
||||
}
|
||||
}, [isRestored, bgmList, selectedBgmId, enableBgm, setSelectedBgmId]);
|
||||
|
||||
const videoScrollReady = useRef(false);
|
||||
useEffect(() => {
|
||||
if (!selectedVideoId) return;
|
||||
if (!videoScrollReady.current) {
|
||||
videoScrollReady.current = true;
|
||||
return;
|
||||
}
|
||||
|
||||
const target = videoItemRefs.current[selectedVideoId];
|
||||
if (target) {
|
||||
target.scrollIntoView({ block: "nearest", behavior: "smooth" });
|
||||
@@ -815,6 +825,7 @@ export const useHomeController = () => {
|
||||
ref_audio_id: ttsMode === "voiceclone" ? selectedRefAudio!.id : undefined,
|
||||
ref_text: ttsMode === "voiceclone" ? refText : undefined,
|
||||
language: textLang,
|
||||
speed: ttsMode === "voiceclone" ? speed : undefined,
|
||||
};
|
||||
await generateAudio(params);
|
||||
};
|
||||
@@ -854,22 +865,59 @@ export const useHomeController = () => {
|
||||
language: selectedAudio.language || textLang,
|
||||
title: videoTitle.trim() || undefined,
|
||||
enable_subtitles: true,
|
||||
output_aspect_ratio: outputAspectRatio,
|
||||
};
|
||||
|
||||
// 多素材
|
||||
if (selectedMaterials.length > 1) {
|
||||
payload.material_paths = selectedMaterials
|
||||
const timelineOrderedIds = timelineSegments
|
||||
.map((seg) => seg.materialId)
|
||||
.filter((id, index, arr) => arr.indexOf(id) === index);
|
||||
const orderedMaterialIds = [
|
||||
...timelineOrderedIds.filter((id) => selectedMaterials.includes(id)),
|
||||
...selectedMaterials.filter((id) => !timelineOrderedIds.includes(id)),
|
||||
];
|
||||
|
||||
const materialPaths = orderedMaterialIds
|
||||
.map((id) => materials.find((x) => x.id === id)?.path)
|
||||
.filter((path): path is string => !!path);
|
||||
|
||||
if (materialPaths.length === 0) {
|
||||
toast.error("多素材解析失败,请刷新素材后重试");
|
||||
return;
|
||||
}
|
||||
|
||||
payload.material_paths = materialPaths;
|
||||
payload.material_path = materialPaths[0];
|
||||
|
||||
// 发送自定义时间轴分配
|
||||
const assignments = toCustomAssignments();
|
||||
if (assignments.length > 0) {
|
||||
const assignmentPaths = assignments
|
||||
.map((a) => a.material_path)
|
||||
.filter((path): path is string => !!path);
|
||||
|
||||
if (assignmentPaths.length === assignments.length) {
|
||||
// 以时间轴可见段为准:超出时间轴的素材不会参与本次生成
|
||||
payload.material_paths = assignmentPaths;
|
||||
payload.material_path = assignmentPaths[0];
|
||||
}
|
||||
payload.custom_assignments = assignments;
|
||||
} else {
|
||||
console.warn(
|
||||
"[Timeline] custom_assignments 为空,回退后端自动分配",
|
||||
{ materials: materialPaths.length }
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// 单素材 + 截取起点
|
||||
if (selectedMaterials.length === 1 && timelineSegments[0]?.sourceStart > 0) {
|
||||
// 单素材 + 截取范围
|
||||
const singleSeg = timelineSegments[0];
|
||||
if (
|
||||
selectedMaterials.length === 1
|
||||
&& singleSeg
|
||||
&& (singleSeg.sourceStart > 0 || singleSeg.sourceEnd > 0)
|
||||
) {
|
||||
payload.custom_assignments = toCustomAssignments();
|
||||
}
|
||||
|
||||
@@ -1002,6 +1050,8 @@ export const useHomeController = () => {
|
||||
setTitleTopMargin,
|
||||
subtitleBottomMargin,
|
||||
setSubtitleBottomMargin,
|
||||
outputAspectRatio,
|
||||
setOutputAspectRatio,
|
||||
resolveAssetUrl,
|
||||
getFontFormat,
|
||||
buildTextShadow,
|
||||
@@ -1029,6 +1079,8 @@ export const useHomeController = () => {
|
||||
saveEditing,
|
||||
cancelEditing,
|
||||
deleteRefAudio,
|
||||
retranscribeRefAudio,
|
||||
retranscribingId,
|
||||
recordedBlob,
|
||||
isRecording,
|
||||
recordingTime,
|
||||
@@ -1036,7 +1088,6 @@ export const useHomeController = () => {
|
||||
stopRecording,
|
||||
useRecording,
|
||||
formatRecordingTime,
|
||||
fixedRefText: FIXED_REF_TEXT,
|
||||
bgmList,
|
||||
bgmLoading,
|
||||
bgmError,
|
||||
@@ -1072,6 +1123,8 @@ export const useHomeController = () => {
|
||||
deleteAudio,
|
||||
renameAudio,
|
||||
selectAudio,
|
||||
speed,
|
||||
setSpeed,
|
||||
timelineSegments,
|
||||
reorderSegments,
|
||||
setSourceRange,
|
||||
|
||||
@@ -39,6 +39,8 @@ interface UseHomePersistenceOptions {
|
||||
setTitleTopMargin: React.Dispatch<React.SetStateAction<number>>;
|
||||
subtitleBottomMargin: number;
|
||||
setSubtitleBottomMargin: React.Dispatch<React.SetStateAction<number>>;
|
||||
outputAspectRatio: '9:16' | '16:9';
|
||||
setOutputAspectRatio: React.Dispatch<React.SetStateAction<'9:16' | '16:9'>>;
|
||||
selectedBgmId: string;
|
||||
setSelectedBgmId: React.Dispatch<React.SetStateAction<string>>;
|
||||
bgmVolume: number;
|
||||
@@ -50,6 +52,8 @@ interface UseHomePersistenceOptions {
|
||||
selectedRefAudio: RefAudio | null;
|
||||
selectedAudioId: string | null;
|
||||
setSelectedAudioId: React.Dispatch<React.SetStateAction<string | null>>;
|
||||
speed: number;
|
||||
setSpeed: React.Dispatch<React.SetStateAction<number>>;
|
||||
}
|
||||
|
||||
export const useHomePersistence = ({
|
||||
@@ -81,6 +85,8 @@ export const useHomePersistence = ({
|
||||
setTitleTopMargin,
|
||||
subtitleBottomMargin,
|
||||
setSubtitleBottomMargin,
|
||||
outputAspectRatio,
|
||||
setOutputAspectRatio,
|
||||
selectedBgmId,
|
||||
setSelectedBgmId,
|
||||
bgmVolume,
|
||||
@@ -92,6 +98,8 @@ export const useHomePersistence = ({
|
||||
selectedRefAudio,
|
||||
selectedAudioId,
|
||||
setSelectedAudioId,
|
||||
speed,
|
||||
setSpeed,
|
||||
}: UseHomePersistenceOptions) => {
|
||||
const [isRestored, setIsRestored] = useState(false);
|
||||
|
||||
@@ -115,6 +123,8 @@ export const useHomePersistence = ({
|
||||
const savedEnableBgm = localStorage.getItem(`vigent_${storageKey}_enableBgm`);
|
||||
const savedTitleTopMargin = localStorage.getItem(`vigent_${storageKey}_titleTopMargin`);
|
||||
const savedSubtitleBottomMargin = localStorage.getItem(`vigent_${storageKey}_subtitleBottomMargin`);
|
||||
const savedOutputAspectRatio = localStorage.getItem(`vigent_${storageKey}_outputAspectRatio`);
|
||||
const savedSpeed = localStorage.getItem(`vigent_${storageKey}_speed`);
|
||||
|
||||
setText(savedText || "大家好,欢迎来到我的频道,今天给大家分享一些有趣的内容。");
|
||||
setVideoTitle(savedTitle ? clampTitle(savedTitle) : "");
|
||||
@@ -169,6 +179,15 @@ export const useHomePersistence = ({
|
||||
if (!Number.isNaN(parsed)) setSubtitleBottomMargin(parsed);
|
||||
}
|
||||
|
||||
if (savedOutputAspectRatio === '9:16' || savedOutputAspectRatio === '16:9') {
|
||||
setOutputAspectRatio(savedOutputAspectRatio);
|
||||
}
|
||||
|
||||
if (savedSpeed) {
|
||||
const parsed = parseFloat(savedSpeed);
|
||||
if (!Number.isNaN(parsed)) setSpeed(parsed);
|
||||
}
|
||||
|
||||
// eslint-disable-next-line react-hooks/set-state-in-effect
|
||||
setIsRestored(true);
|
||||
}, [
|
||||
@@ -181,6 +200,7 @@ export const useHomePersistence = ({
|
||||
setSelectedTitleStyleId,
|
||||
setSelectedVideoId,
|
||||
setSelectedAudioId,
|
||||
setSpeed,
|
||||
setSubtitleFontSize,
|
||||
setSubtitleSizeLocked,
|
||||
setText,
|
||||
@@ -189,6 +209,7 @@ export const useHomePersistence = ({
|
||||
setTitleSizeLocked,
|
||||
setTitleTopMargin,
|
||||
setSubtitleBottomMargin,
|
||||
setOutputAspectRatio,
|
||||
setTtsMode,
|
||||
setVideoTitle,
|
||||
setVoice,
|
||||
@@ -265,6 +286,12 @@ export const useHomePersistence = ({
|
||||
}
|
||||
}, [subtitleBottomMargin, storageKey, isRestored]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isRestored) {
|
||||
localStorage.setItem(`vigent_${storageKey}_outputAspectRatio`, outputAspectRatio);
|
||||
}
|
||||
}, [outputAspectRatio, storageKey, isRestored]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isRestored) {
|
||||
localStorage.setItem(`vigent_${storageKey}_bgmId`, selectedBgmId);
|
||||
@@ -309,5 +336,11 @@ export const useHomePersistence = ({
|
||||
}
|
||||
}, [selectedRefAudio, storageKey, isRestored]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isRestored) {
|
||||
localStorage.setItem(`vigent_${storageKey}_speed`, String(speed));
|
||||
}
|
||||
}, [speed, storageKey, isRestored]);
|
||||
|
||||
return { isRestored };
|
||||
};
|
||||
|
||||
@@ -185,11 +185,14 @@ export const useMaterials = ({
|
||||
).then((enriched) => setMaterials(enriched));
|
||||
}
|
||||
|
||||
// 找出新增的素材 ID 并自动选中
|
||||
// 找出新增素材并默认仅选中新上传项,避免误触发多素材模式
|
||||
const oldIds = new Set(materials.map((m) => m.id));
|
||||
const newIds = nextMaterials.filter((m) => !oldIds.has(m.id)).map((m) => m.id);
|
||||
if (newIds.length > 0) {
|
||||
setSelectedMaterials((prev) => [...prev, ...newIds]);
|
||||
setSelectedMaterials([newIds[0]]);
|
||||
} else if (nextMaterials[0]?.id) {
|
||||
// 兜底:即使未识别到新增项,也保持单素材默认选择最新一个
|
||||
setSelectedMaterials([nextMaterials[0].id]);
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
console.error("Upload failed:", err);
|
||||
@@ -200,7 +203,7 @@ export const useMaterials = ({
|
||||
}
|
||||
|
||||
e.target.value = '';
|
||||
}, [fetchMaterials]);
|
||||
}, [materials, setSelectedMaterials]);
|
||||
|
||||
return {
|
||||
materials,
|
||||
|
||||
@@ -13,14 +13,12 @@ interface RefAudio {
|
||||
}
|
||||
|
||||
interface UseRefAudiosOptions {
|
||||
fixedRefText: string;
|
||||
selectedRefAudio: RefAudio | null;
|
||||
setSelectedRefAudio: React.Dispatch<React.SetStateAction<RefAudio | null>>;
|
||||
setRefText: React.Dispatch<React.SetStateAction<string>>;
|
||||
}
|
||||
|
||||
export const useRefAudios = ({
|
||||
fixedRefText,
|
||||
selectedRefAudio,
|
||||
setSelectedRefAudio,
|
||||
setRefText,
|
||||
@@ -28,6 +26,7 @@ export const useRefAudios = ({
|
||||
const [refAudios, setRefAudios] = useState<RefAudio[]>([]);
|
||||
const [isUploadingRef, setIsUploadingRef] = useState(false);
|
||||
const [uploadRefError, setUploadRefError] = useState<string | null>(null);
|
||||
const [retranscribingId, setRetranscribingId] = useState<string | null>(null);
|
||||
|
||||
const fetchRefAudios = useCallback(async () => {
|
||||
try {
|
||||
@@ -42,15 +41,12 @@ export const useRefAudios = ({
|
||||
}, []);
|
||||
|
||||
const uploadRefAudio = useCallback(async (file: File) => {
|
||||
const refTextInput = fixedRefText;
|
||||
|
||||
setIsUploadingRef(true);
|
||||
setUploadRefError(null);
|
||||
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
formData.append('ref_text', refTextInput);
|
||||
|
||||
const { data: res } = await api.post<ApiResponse<RefAudio>>('/api/ref-audios', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
@@ -68,7 +64,7 @@ export const useRefAudios = ({
|
||||
const errorMsg = axiosErr.response?.data?.message || axiosErr.message || String(err);
|
||||
setUploadRefError(`上传失败: ${errorMsg}`);
|
||||
}
|
||||
}, [fetchRefAudios, fixedRefText, setRefText, setSelectedRefAudio]);
|
||||
}, [fetchRefAudios, setRefText, setSelectedRefAudio]);
|
||||
|
||||
const deleteRefAudio = useCallback(async (audioId: string) => {
|
||||
if (!confirm("确定要删除这个参考音频吗?")) return;
|
||||
@@ -84,6 +80,28 @@ export const useRefAudios = ({
|
||||
}
|
||||
}, [fetchRefAudios, selectedRefAudio, setRefText, setSelectedRefAudio]);
|
||||
|
||||
const retranscribeRefAudio = useCallback(async (audioId: string) => {
|
||||
setRetranscribingId(audioId);
|
||||
try {
|
||||
const { data: res } = await api.post<ApiResponse<{ ref_text: string }>>(
|
||||
`/api/ref-audios/${encodeURIComponent(audioId)}/retranscribe`
|
||||
);
|
||||
const payload = unwrap(res);
|
||||
toast.success("识别完成");
|
||||
// 更新列表和当前选中
|
||||
await fetchRefAudios();
|
||||
if (selectedRefAudio?.id === audioId) {
|
||||
setRefText(payload.ref_text);
|
||||
}
|
||||
} catch (err: unknown) {
|
||||
const axiosErr = err as { response?: { data?: { message?: string } }; message?: string };
|
||||
const errorMsg = axiosErr.response?.data?.message || axiosErr.message || String(err);
|
||||
toast.error(`识别失败: ${errorMsg}`);
|
||||
} finally {
|
||||
setRetranscribingId(null);
|
||||
}
|
||||
}, [fetchRefAudios, selectedRefAudio, setRefText]);
|
||||
|
||||
return {
|
||||
refAudios,
|
||||
isUploadingRef,
|
||||
@@ -92,5 +110,7 @@ export const useRefAudios = ({
|
||||
fetchRefAudios,
|
||||
uploadRefAudio,
|
||||
deleteRefAudio,
|
||||
retranscribeRefAudio,
|
||||
retranscribingId,
|
||||
};
|
||||
};
|
||||
|
||||
@@ -17,6 +17,7 @@ export interface CustomAssignment {
|
||||
start: number;
|
||||
end: number;
|
||||
source_start: number;
|
||||
source_end?: number;
|
||||
}
|
||||
|
||||
const COLORS = ["#8b5cf6", "#ec4899", "#06b6d4", "#f59e0b", "#10b981", "#f97316"];
|
||||
@@ -35,9 +36,11 @@ function getEffectiveDuration(
|
||||
seg: { sourceStart: number; sourceEnd: number; materialId: string },
|
||||
mats: Material[]
|
||||
): number {
|
||||
if (seg.sourceEnd > seg.sourceStart) return seg.sourceEnd - seg.sourceStart;
|
||||
const mat = mats.find((m) => m.id === seg.materialId);
|
||||
return mat?.duration_sec ?? 0;
|
||||
const matDur = mat?.duration_sec ?? 0;
|
||||
if (seg.sourceEnd > seg.sourceStart) return seg.sourceEnd - seg.sourceStart;
|
||||
if (seg.sourceStart > 0) return Math.max(matDur - seg.sourceStart, 0);
|
||||
return matDur;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -99,9 +102,15 @@ export const useTimelineEditor = ({
|
||||
|
||||
// Refs for stable callbacks (avoid recreating on every materials/duration change)
|
||||
const materialsRef = useRef(materials);
|
||||
materialsRef.current = materials;
|
||||
const audioDurationRef = useRef(audioDuration);
|
||||
audioDurationRef.current = audioDuration;
|
||||
|
||||
useEffect(() => {
|
||||
materialsRef.current = materials;
|
||||
}, [materials]);
|
||||
|
||||
useEffect(() => {
|
||||
audioDurationRef.current = audioDuration;
|
||||
}, [audioDuration]);
|
||||
|
||||
// Build a durationsKey so segments re-init when material durations become available
|
||||
const durationsKey = selectedMaterials
|
||||
@@ -232,6 +241,7 @@ export const useTimelineEditor = ({
|
||||
start: seg.start,
|
||||
end: seg.end,
|
||||
source_start: seg.sourceStart,
|
||||
source_end: seg.sourceEnd > seg.sourceStart ? seg.sourceEnd : undefined,
|
||||
};
|
||||
});
|
||||
}, [segments]);
|
||||
|
||||
@@ -86,6 +86,8 @@ export function FloatingStylePreview({
|
||||
|
||||
const previewScale = windowWidth / previewBaseWidth;
|
||||
const previewHeight = previewBaseHeight * previewScale;
|
||||
const widthScale = Math.min(1, previewBaseWidth / 1080);
|
||||
const responsiveScale = Math.max(0.55, widthScale);
|
||||
|
||||
const activeSubtitleStyle = subtitleStyles.find((s) => s.id === selectedSubtitleStyleId)
|
||||
|| subtitleStyles.find((s) => s.is_default)
|
||||
@@ -102,8 +104,8 @@ export function FloatingStylePreview({
|
||||
const subtitleHighlightColor = activeSubtitleStyle?.highlight_color || "#FFE600";
|
||||
const subtitleNormalColor = activeSubtitleStyle?.normal_color || "#FFFFFF";
|
||||
const subtitleStrokeColor = activeSubtitleStyle?.stroke_color || "#000000";
|
||||
const subtitleStrokeSize = activeSubtitleStyle?.stroke_size ?? 3;
|
||||
const subtitleLetterSpacing = activeSubtitleStyle?.letter_spacing ?? 2;
|
||||
const subtitleStrokeSize = Math.max(1, Math.round((activeSubtitleStyle?.stroke_size ?? 3) * responsiveScale));
|
||||
const subtitleLetterSpacing = Math.max(0, (activeSubtitleStyle?.letter_spacing ?? 2) * responsiveScale);
|
||||
const subtitleFontFamilyName = `SubtitlePreview-${activeSubtitleStyle?.id || "default"}`;
|
||||
const subtitleFontUrl = activeSubtitleStyle?.font_file
|
||||
? resolveAssetUrl(`fonts/${activeSubtitleStyle.font_file}`)
|
||||
@@ -111,14 +113,19 @@ export function FloatingStylePreview({
|
||||
|
||||
const titleColor = activeTitleStyle?.color || "#FFFFFF";
|
||||
const titleStrokeColor = activeTitleStyle?.stroke_color || "#000000";
|
||||
const titleStrokeSize = activeTitleStyle?.stroke_size ?? 8;
|
||||
const titleLetterSpacing = activeTitleStyle?.letter_spacing ?? 4;
|
||||
const titleStrokeSize = Math.max(1, Math.round((activeTitleStyle?.stroke_size ?? 8) * responsiveScale));
|
||||
const titleLetterSpacing = Math.max(0, (activeTitleStyle?.letter_spacing ?? 4) * responsiveScale);
|
||||
const titleFontWeight = activeTitleStyle?.font_weight ?? 900;
|
||||
const titleFontFamilyName = `TitlePreview-${activeTitleStyle?.id || "default"}`;
|
||||
const titleFontUrl = activeTitleStyle?.font_file
|
||||
? resolveAssetUrl(`fonts/${activeTitleStyle.font_file}`)
|
||||
: null;
|
||||
|
||||
const scaledTitleFontSize = Math.max(36, Math.round(titleFontSize * responsiveScale));
|
||||
const scaledSubtitleFontSize = Math.max(28, Math.round(subtitleFontSize * responsiveScale));
|
||||
const scaledTitleTopMargin = Math.max(0, Math.round(titleTopMargin * responsiveScale));
|
||||
const scaledSubtitleBottomMargin = Math.max(0, Math.round(subtitleBottomMargin * responsiveScale));
|
||||
|
||||
const content = (
|
||||
<div
|
||||
style={{
|
||||
@@ -172,11 +179,11 @@ export function FloatingStylePreview({
|
||||
className="w-full text-center"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
top: `${titleTopMargin}px`,
|
||||
top: `${scaledTitleTopMargin}px`,
|
||||
left: 0,
|
||||
right: 0,
|
||||
color: titleColor,
|
||||
fontSize: `${titleFontSize}px`,
|
||||
fontSize: `${scaledTitleFontSize}px`,
|
||||
fontWeight: titleFontWeight,
|
||||
fontFamily: titleFontUrl
|
||||
? `'${titleFontFamilyName}', "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif`
|
||||
@@ -184,6 +191,10 @@ export function FloatingStylePreview({
|
||||
textShadow: buildTextShadow(titleStrokeColor, titleStrokeSize),
|
||||
letterSpacing: `${titleLetterSpacing}px`,
|
||||
lineHeight: 1.2,
|
||||
whiteSpace: 'normal',
|
||||
wordBreak: 'break-word',
|
||||
overflowWrap: 'anywhere',
|
||||
boxSizing: 'border-box',
|
||||
opacity: videoTitle.trim() ? 1 : 0.7,
|
||||
padding: '0 5%',
|
||||
}}
|
||||
@@ -195,16 +206,20 @@ export function FloatingStylePreview({
|
||||
className="w-full text-center"
|
||||
style={{
|
||||
position: 'absolute',
|
||||
bottom: `${subtitleBottomMargin}px`,
|
||||
bottom: `${scaledSubtitleBottomMargin}px`,
|
||||
left: 0,
|
||||
right: 0,
|
||||
fontSize: `${subtitleFontSize}px`,
|
||||
fontSize: `${scaledSubtitleFontSize}px`,
|
||||
fontFamily: subtitleFontUrl
|
||||
? `'${subtitleFontFamilyName}', "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif`
|
||||
: '"PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif',
|
||||
textShadow: buildTextShadow(subtitleStrokeColor, subtitleStrokeSize),
|
||||
letterSpacing: `${subtitleLetterSpacing}px`,
|
||||
lineHeight: 1.35,
|
||||
whiteSpace: 'normal',
|
||||
wordBreak: 'break-word',
|
||||
overflowWrap: 'anywhere',
|
||||
boxSizing: 'border-box',
|
||||
padding: '0 6%',
|
||||
}}
|
||||
>
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import { useState, useRef, useCallback, useEffect } from "react";
|
||||
import { Play, Pause, Pencil, Trash2, Check, X, RefreshCw, Mic } from "lucide-react";
|
||||
import { Play, Pause, Pencil, Trash2, Check, X, RefreshCw, Mic, ChevronDown } from "lucide-react";
|
||||
import type { GeneratedAudio } from "@/features/home/model/useGeneratedAudios";
|
||||
|
||||
interface AudioTask {
|
||||
@@ -19,6 +19,10 @@ interface GeneratedAudiosPanelProps {
|
||||
onDeleteAudio: (id: string) => void;
|
||||
onRenameAudio: (id: string, newName: string) => void;
|
||||
hasText: boolean;
|
||||
missingRefAudio?: boolean;
|
||||
speed: number;
|
||||
onSpeedChange: (speed: number) => void;
|
||||
ttsMode: string;
|
||||
}
|
||||
|
||||
export function GeneratedAudiosPanel({
|
||||
@@ -32,11 +36,17 @@ export function GeneratedAudiosPanel({
|
||||
onDeleteAudio,
|
||||
onRenameAudio,
|
||||
hasText,
|
||||
missingRefAudio = false,
|
||||
speed,
|
||||
onSpeedChange,
|
||||
ttsMode,
|
||||
}: GeneratedAudiosPanelProps) {
|
||||
const [editingId, setEditingId] = useState<string | null>(null);
|
||||
const [editName, setEditName] = useState("");
|
||||
const [playingId, setPlayingId] = useState<string | null>(null);
|
||||
const [speedOpen, setSpeedOpen] = useState(false);
|
||||
const audioRef = useRef<HTMLAudioElement | null>(null);
|
||||
const speedRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const stopPlaying = useCallback(() => {
|
||||
if (audioRef.current) {
|
||||
@@ -57,6 +67,17 @@ export function GeneratedAudiosPanel({
|
||||
};
|
||||
}, []);
|
||||
|
||||
// Close speed dropdown on click outside
|
||||
useEffect(() => {
|
||||
const handler = (e: MouseEvent) => {
|
||||
if (speedRef.current && !speedRef.current.contains(e.target as Node)) {
|
||||
setSpeedOpen(false);
|
||||
}
|
||||
};
|
||||
if (speedOpen) document.addEventListener("mousedown", handler);
|
||||
return () => document.removeEventListener("mousedown", handler);
|
||||
}, [speedOpen]);
|
||||
|
||||
const togglePlay = (audio: GeneratedAudio, e: React.MouseEvent) => {
|
||||
e.stopPropagation();
|
||||
if (playingId === audio.id) {
|
||||
@@ -91,19 +112,60 @@ export function GeneratedAudiosPanel({
|
||||
setEditName("");
|
||||
};
|
||||
|
||||
const canGenerate = hasText && !missingRefAudio;
|
||||
|
||||
const speedOptions = [
|
||||
{ value: 0.8, label: "较慢" },
|
||||
{ value: 0.9, label: "稍慢" },
|
||||
{ value: 1.0, label: "正常" },
|
||||
{ value: 1.1, label: "稍快" },
|
||||
{ value: 1.2, label: "较快" },
|
||||
] as const;
|
||||
const currentSpeedLabel = speedOptions.find((o) => o.value === speed)?.label ?? "正常";
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm relative z-10">
|
||||
<div className="flex justify-between items-center gap-2 mb-4">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2 whitespace-nowrap">
|
||||
<Mic className="h-4 w-4 text-purple-400" />
|
||||
配音列表
|
||||
</h2>
|
||||
<div className="flex gap-1.5">
|
||||
{/* 语速下拉 (仅声音克隆模式) */}
|
||||
{ttsMode === "voiceclone" && (
|
||||
<div ref={speedRef} className="relative">
|
||||
<button
|
||||
onClick={() => setSpeedOpen((v) => !v)}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1 transition-all"
|
||||
>
|
||||
语速: {currentSpeedLabel}
|
||||
<ChevronDown className={`h-3 w-3 transition-transform ${speedOpen ? "rotate-180" : ""}`} />
|
||||
</button>
|
||||
{speedOpen && (
|
||||
<div className="absolute right-0 top-full mt-1 bg-gray-800 border border-white/20 rounded-lg shadow-xl py-1 z-50 min-w-[80px]">
|
||||
{speedOptions.map((opt) => (
|
||||
<button
|
||||
key={opt.value}
|
||||
onClick={() => { onSpeedChange(opt.value); setSpeedOpen(false); }}
|
||||
className={`w-full text-left px-3 py-1.5 text-xs transition-colors ${
|
||||
speed === opt.value
|
||||
? "bg-purple-600/40 text-purple-200"
|
||||
: "text-gray-300 hover:bg-white/10"
|
||||
}`}
|
||||
>
|
||||
{opt.label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={onGenerateAudio}
|
||||
disabled={isGeneratingAudio || !hasText}
|
||||
disabled={isGeneratingAudio || !canGenerate}
|
||||
title={missingRefAudio ? "请先选择参考音频" : !hasText ? "请先输入文案" : ""}
|
||||
className={`px-2 py-1 text-xs rounded transition-all whitespace-nowrap flex items-center gap-1 ${
|
||||
isGeneratingAudio || !hasText
|
||||
isGeneratingAudio || !canGenerate
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white"
|
||||
}`}
|
||||
@@ -120,6 +182,13 @@ export function GeneratedAudiosPanel({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 缺少参考音频提示 */}
|
||||
{missingRefAudio && (
|
||||
<div className="mb-3 px-3 py-2 bg-yellow-500/10 border border-yellow-500/30 rounded-lg text-yellow-300 text-xs">
|
||||
声音克隆模式需要先选择参考音频
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 生成进度 */}
|
||||
{isGeneratingAudio && audioTask && (
|
||||
<div className="mb-4 p-3 bg-purple-500/10 rounded-xl border border-purple-500/30">
|
||||
|
||||
@@ -80,10 +80,11 @@ export function HomePage() {
|
||||
setTitleTopMargin,
|
||||
subtitleBottomMargin,
|
||||
setSubtitleBottomMargin,
|
||||
outputAspectRatio,
|
||||
setOutputAspectRatio,
|
||||
resolveAssetUrl,
|
||||
getFontFormat,
|
||||
buildTextShadow,
|
||||
materialDimensions,
|
||||
ttsMode,
|
||||
setTtsMode,
|
||||
voices,
|
||||
@@ -106,6 +107,8 @@ export function HomePage() {
|
||||
saveEditing,
|
||||
cancelEditing,
|
||||
deleteRefAudio,
|
||||
retranscribeRefAudio,
|
||||
retranscribingId,
|
||||
recordedBlob,
|
||||
isRecording,
|
||||
recordingTime,
|
||||
@@ -113,7 +116,6 @@ export function HomePage() {
|
||||
stopRecording,
|
||||
useRecording,
|
||||
formatRecordingTime,
|
||||
fixedRefText,
|
||||
bgmList,
|
||||
bgmLoading,
|
||||
bgmError,
|
||||
@@ -149,6 +151,8 @@ export function HomePage() {
|
||||
deleteAudio,
|
||||
renameAudio,
|
||||
selectAudio,
|
||||
speed,
|
||||
setSpeed,
|
||||
timelineSegments,
|
||||
reorderSegments,
|
||||
setSourceRange,
|
||||
@@ -162,6 +166,11 @@ export function HomePage() {
|
||||
router.prefetch("/publish");
|
||||
}, [router]);
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window === "undefined") return;
|
||||
window.scrollTo({ top: 0, left: 0, behavior: "auto" });
|
||||
}, []);
|
||||
|
||||
const clipTrimmerSegment = useMemo(
|
||||
() => timelineSegments.find((s) => s.id === clipTrimmerSegmentId) ?? null,
|
||||
[timelineSegments, clipTrimmerSegmentId]
|
||||
@@ -229,8 +238,8 @@ export function HomePage() {
|
||||
resolveAssetUrl={resolveAssetUrl}
|
||||
getFontFormat={getFontFormat}
|
||||
buildTextShadow={buildTextShadow}
|
||||
previewBaseWidth={materialDimensions?.width || 1080}
|
||||
previewBaseHeight={materialDimensions?.height || 1920}
|
||||
previewBaseWidth={outputAspectRatio === "16:9" ? 1920 : 1080}
|
||||
previewBaseHeight={outputAspectRatio === "16:9" ? 1080 : 1920}
|
||||
/>
|
||||
|
||||
{/* 3. 配音方式选择 */}
|
||||
@@ -259,6 +268,8 @@ export function HomePage() {
|
||||
onSaveEditing={saveEditing}
|
||||
onCancelEditing={cancelEditing}
|
||||
onDeleteRefAudio={deleteRefAudio}
|
||||
onRetranscribe={retranscribeRefAudio}
|
||||
retranscribingId={retranscribingId}
|
||||
recordedBlob={recordedBlob}
|
||||
isRecording={isRecording}
|
||||
recordingTime={recordingTime}
|
||||
@@ -266,7 +277,6 @@ export function HomePage() {
|
||||
onStopRecording={stopRecording}
|
||||
onUseRecording={useRecording}
|
||||
formatRecordingTime={formatRecordingTime}
|
||||
fixedRefText={fixedRefText}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
@@ -283,6 +293,10 @@ export function HomePage() {
|
||||
onDeleteAudio={deleteAudio}
|
||||
onRenameAudio={renameAudio}
|
||||
hasText={!!text.trim()}
|
||||
missingRefAudio={ttsMode === "voiceclone" && !selectedRefAudio}
|
||||
speed={speed}
|
||||
onSpeedChange={setSpeed}
|
||||
ttsMode={ttsMode}
|
||||
/>
|
||||
|
||||
{/* 5. 视频素材 */}
|
||||
@@ -325,6 +339,8 @@ export function HomePage() {
|
||||
audioUrl={selectedAudio ? (resolveMediaUrl(selectedAudio.path) || "") : ""}
|
||||
segments={timelineSegments}
|
||||
materials={materials}
|
||||
outputAspectRatio={outputAspectRatio}
|
||||
onOutputAspectRatioChange={setOutputAspectRatio}
|
||||
onReorderSegment={reorderSegments}
|
||||
onClickSegment={(seg) => {
|
||||
setClipTrimmerSegmentId(seg.id);
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import { useEffect, useState } from "react";
|
||||
import type { MouseEvent } from "react";
|
||||
import { Upload, RefreshCw, Play, Pause, Pencil, Trash2, Check, X, Mic, Square } from "lucide-react";
|
||||
import { Upload, RefreshCw, Play, Pause, Pencil, Trash2, Check, X, Mic, Square, RotateCw } from "lucide-react";
|
||||
|
||||
interface RefAudio {
|
||||
id: string;
|
||||
@@ -29,6 +29,8 @@ interface RefAudioPanelProps {
|
||||
onSaveEditing: (id: string, event: MouseEvent) => void;
|
||||
onCancelEditing: (event: MouseEvent) => void;
|
||||
onDeleteRefAudio: (id: string) => void;
|
||||
onRetranscribe: (id: string) => void;
|
||||
retranscribingId: string | null;
|
||||
recordedBlob: Blob | null;
|
||||
isRecording: boolean;
|
||||
recordingTime: number;
|
||||
@@ -36,9 +38,10 @@ interface RefAudioPanelProps {
|
||||
onStopRecording: () => void;
|
||||
onUseRecording: () => void;
|
||||
formatRecordingTime: (seconds: number) => string;
|
||||
fixedRefText: string;
|
||||
}
|
||||
|
||||
const OLD_FIXED_REF_TEXT = "其实生活中有许多美好的瞬间";
|
||||
|
||||
export function RefAudioPanel({
|
||||
refAudios,
|
||||
selectedRefAudio,
|
||||
@@ -57,6 +60,8 @@ export function RefAudioPanel({
|
||||
onSaveEditing,
|
||||
onCancelEditing,
|
||||
onDeleteRefAudio,
|
||||
onRetranscribe,
|
||||
retranscribingId,
|
||||
recordedBlob,
|
||||
isRecording,
|
||||
recordingTime,
|
||||
@@ -64,7 +69,6 @@ export function RefAudioPanel({
|
||||
onStopRecording,
|
||||
onUseRecording,
|
||||
formatRecordingTime,
|
||||
fixedRefText,
|
||||
}: RefAudioPanelProps) {
|
||||
const [recordedUrl, setRecordedUrl] = useState<string | null>(null);
|
||||
|
||||
@@ -81,6 +85,9 @@ export function RefAudioPanel({
|
||||
};
|
||||
}, [recordedBlob]);
|
||||
|
||||
const needsRetranscribe = (audio: RefAudio) =>
|
||||
audio.ref_text.startsWith(OLD_FIXED_REF_TEXT);
|
||||
|
||||
return (
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
@@ -122,7 +129,7 @@ export function RefAudioPanel({
|
||||
|
||||
{isUploadingRef && (
|
||||
<div className="mb-2 p-2 bg-purple-500/10 rounded text-sm text-purple-300">
|
||||
⏳ 上传中...
|
||||
⏳ 上传并识别中...
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -192,6 +199,17 @@ export function RefAudioPanel({
|
||||
<Play className="h-3.5 w-3.5" />
|
||||
)}
|
||||
</button>
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
onRetranscribe(audio.id);
|
||||
}}
|
||||
disabled={retranscribingId === audio.id}
|
||||
className="text-gray-400 hover:text-cyan-400 text-xs disabled:opacity-50"
|
||||
title="重新识别文字"
|
||||
>
|
||||
<RotateCw className={`h-3.5 w-3.5 ${retranscribingId === audio.id ? 'animate-spin' : ''}`} />
|
||||
</button>
|
||||
<button
|
||||
onClick={(e) => onStartEditing(audio, e)}
|
||||
className="text-gray-400 hover:text-blue-400 text-xs"
|
||||
@@ -211,7 +229,12 @@ export function RefAudioPanel({
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="text-gray-400 text-xs">{audio.duration_sec.toFixed(1)}s</div>
|
||||
<div className="text-gray-400 text-xs">
|
||||
{audio.duration_sec.toFixed(1)}s
|
||||
{needsRetranscribe(audio) && (
|
||||
<span className="text-yellow-500 ml-1" title="需要重新识别文字">⚠</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
@@ -221,7 +244,7 @@ export function RefAudioPanel({
|
||||
</div>
|
||||
|
||||
<div className="border-t border-white/10 pt-4">
|
||||
<span className="text-sm text-gray-300 mb-2 block">🎤 或在线录音</span>
|
||||
<span className="text-sm text-gray-300 mb-2 block">🎤 或在线录音 <span className="text-xs text-gray-500">(建议 3-10 秒,超出将自动截取)</span></span>
|
||||
<div className="flex gap-2 items-center">
|
||||
{!isRecording ? (
|
||||
<button
|
||||
@@ -264,15 +287,9 @@ export function RefAudioPanel({
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="border-t border-white/10 pt-4">
|
||||
<label className="text-sm text-gray-300 mb-2 block">📝 录音/上传时请朗读以下内容:</label>
|
||||
<div className="w-full bg-black/30 border border-white/10 rounded-lg p-3 text-white text-sm">
|
||||
{fixedRefText}
|
||||
</div>
|
||||
<p className="text-xs text-gray-500 mt-1">
|
||||
请清晰朗读上述内容完成录音,系统将以此为参考克隆您的声音
|
||||
</p>
|
||||
</div>
|
||||
<p className="text-xs text-gray-500 mt-2 border-t border-white/10 pt-3">
|
||||
上传任意语音样本(3-10秒),系统将自动识别内容并克隆声音
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import { useEffect, useRef, useCallback, useState } from "react";
|
||||
import WaveSurfer from "wavesurfer.js";
|
||||
import { ChevronDown } from "lucide-react";
|
||||
import type { TimelineSegment } from "@/features/home/model/useTimelineEditor";
|
||||
import type { Material } from "@/shared/types/material";
|
||||
|
||||
@@ -8,6 +9,8 @@ interface TimelineEditorProps {
|
||||
audioUrl: string;
|
||||
segments: TimelineSegment[];
|
||||
materials: Material[];
|
||||
outputAspectRatio: "9:16" | "16:9";
|
||||
onOutputAspectRatioChange: (ratio: "9:16" | "16:9") => void;
|
||||
onReorderSegment: (fromIdx: number, toIdx: number) => void;
|
||||
onClickSegment: (segment: TimelineSegment) => void;
|
||||
}
|
||||
@@ -23,6 +26,8 @@ export function TimelineEditor({
|
||||
audioUrl,
|
||||
segments,
|
||||
materials,
|
||||
outputAspectRatio,
|
||||
onOutputAspectRatioChange,
|
||||
onReorderSegment,
|
||||
onClickSegment,
|
||||
}: TimelineEditorProps) {
|
||||
@@ -35,16 +40,42 @@ export function TimelineEditor({
|
||||
const playheadRef = useRef<HTMLDivElement>(null);
|
||||
const timeRef = useRef<HTMLSpanElement>(null);
|
||||
const audioDurationRef = useRef(audioDuration);
|
||||
audioDurationRef.current = audioDuration;
|
||||
|
||||
useEffect(() => {
|
||||
audioDurationRef.current = audioDuration;
|
||||
}, [audioDuration]);
|
||||
|
||||
// Drag-to-reorder state
|
||||
const [dragFromIdx, setDragFromIdx] = useState<number | null>(null);
|
||||
const [dragOverIdx, setDragOverIdx] = useState<number | null>(null);
|
||||
|
||||
// Aspect ratio dropdown
|
||||
const [ratioOpen, setRatioOpen] = useState(false);
|
||||
const ratioRef = useRef<HTMLDivElement>(null);
|
||||
const ratioOptions = [
|
||||
{ value: "9:16" as const, label: "竖屏 9:16" },
|
||||
{ value: "16:9" as const, label: "横屏 16:9" },
|
||||
];
|
||||
const currentRatioLabel =
|
||||
ratioOptions.find((opt) => opt.value === outputAspectRatio)?.label ?? "竖屏 9:16";
|
||||
|
||||
useEffect(() => {
|
||||
const handler = (e: MouseEvent) => {
|
||||
if (ratioRef.current && !ratioRef.current.contains(e.target as Node)) {
|
||||
setRatioOpen(false);
|
||||
}
|
||||
};
|
||||
if (ratioOpen) document.addEventListener("mousedown", handler);
|
||||
return () => document.removeEventListener("mousedown", handler);
|
||||
}, [ratioOpen]);
|
||||
|
||||
// Create / recreate wavesurfer when audioUrl changes
|
||||
useEffect(() => {
|
||||
if (!waveRef.current || !audioUrl) return;
|
||||
|
||||
const playheadEl = playheadRef.current;
|
||||
const timeEl = timeRef.current;
|
||||
|
||||
// Destroy previous instance
|
||||
if (wsRef.current) {
|
||||
wsRef.current.destroy();
|
||||
@@ -92,8 +123,8 @@ export function TimelineEditor({
|
||||
ws.destroy();
|
||||
wsRef.current = null;
|
||||
setIsPlaying(false);
|
||||
if (playheadRef.current) playheadRef.current.style.display = "none";
|
||||
if (timeRef.current) timeRef.current.textContent = formatTime(0);
|
||||
if (playheadEl) playheadEl.style.display = "none";
|
||||
if (timeEl) timeEl.textContent = formatTime(0);
|
||||
};
|
||||
}, [audioUrl, waveReady]);
|
||||
|
||||
@@ -150,20 +181,55 @@ export function TimelineEditor({
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2">
|
||||
🎞️ 时间轴编辑
|
||||
</h2>
|
||||
{audioUrl && (
|
||||
<div className="flex items-center gap-2 text-xs text-gray-400">
|
||||
<div className="flex items-center gap-2 text-xs text-gray-400">
|
||||
<div ref={ratioRef} className="relative">
|
||||
<button
|
||||
onClick={handlePlayPause}
|
||||
className="w-7 h-7 flex items-center justify-center rounded-full bg-white/10 hover:bg-white/20 text-white transition-colors"
|
||||
title={isPlaying ? "暂停" : "播放"}
|
||||
type="button"
|
||||
onClick={() => setRatioOpen((v) => !v)}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1 transition-all"
|
||||
title="设置输出画面比例"
|
||||
>
|
||||
{isPlaying ? "⏸" : "▶"}
|
||||
画面: {currentRatioLabel}
|
||||
<ChevronDown className={`h-3 w-3 transition-transform ${ratioOpen ? "rotate-180" : ""}`} />
|
||||
</button>
|
||||
<span ref={timeRef} className="tabular-nums">00:00.0</span>
|
||||
<span className="text-gray-600">/</span>
|
||||
<span className="tabular-nums">{formatTime(audioDuration)}</span>
|
||||
{ratioOpen && (
|
||||
<div className="absolute right-0 top-full mt-1 bg-gray-800 border border-white/20 rounded-lg shadow-xl py-1 z-50 min-w-[106px]">
|
||||
{ratioOptions.map((opt) => (
|
||||
<button
|
||||
key={opt.value}
|
||||
type="button"
|
||||
onClick={() => {
|
||||
onOutputAspectRatioChange(opt.value);
|
||||
setRatioOpen(false);
|
||||
}}
|
||||
className={`w-full text-left px-3 py-1.5 text-xs transition-colors ${
|
||||
outputAspectRatio === opt.value
|
||||
? "bg-purple-600/40 text-purple-200"
|
||||
: "text-gray-300 hover:bg-white/10"
|
||||
}`}
|
||||
>
|
||||
{opt.label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{audioUrl && (
|
||||
<>
|
||||
<button
|
||||
onClick={handlePlayPause}
|
||||
className="w-7 h-7 flex items-center justify-center rounded-full bg-white/10 hover:bg-white/20 text-white transition-colors"
|
||||
title={isPlaying ? "暂停" : "播放"}
|
||||
>
|
||||
{isPlaying ? "⏸" : "▶"}
|
||||
</button>
|
||||
<span ref={timeRef} className="tabular-nums">00:00.0</span>
|
||||
<span className="text-gray-600">/</span>
|
||||
<span className="tabular-nums">{formatTime(audioDuration)}</span>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Waveform — always rendered so ref stays mounted */}
|
||||
@@ -195,7 +261,7 @@ export function TimelineEditor({
|
||||
const matDur = mat?.duration_sec ?? 0;
|
||||
const effDur = (seg.sourceEnd > seg.sourceStart)
|
||||
? (seg.sourceEnd - seg.sourceStart)
|
||||
: matDur;
|
||||
: Math.max(matDur - seg.sourceStart, 0);
|
||||
if (effDur > 0 && segDur > effDur + 0.1) {
|
||||
loopPercent = ((segDur - effDur) / segDur) * 100;
|
||||
}
|
||||
|
||||
76
models/CosyVoice/CODE_OF_CONDUCT.md
Normal file
76
models/CosyVoice/CODE_OF_CONDUCT.md
Normal file
@@ -0,0 +1,76 @@
|
||||
# Contributor Covenant Code of Conduct
|
||||
|
||||
## Our Pledge
|
||||
|
||||
In the interest of fostering an open and welcoming environment, we as
|
||||
contributors and maintainers pledge to making participation in our project and
|
||||
our community a harassment-free experience for everyone, regardless of age, body
|
||||
size, disability, ethnicity, sex characteristics, gender identity and expression,
|
||||
level of experience, education, socio-economic status, nationality, personal
|
||||
appearance, race, religion, or sexual identity and orientation.
|
||||
|
||||
## Our Standards
|
||||
|
||||
Examples of behavior that contributes to creating a positive environment
|
||||
include:
|
||||
|
||||
* Using welcoming and inclusive language
|
||||
* Being respectful of differing viewpoints and experiences
|
||||
* Gracefully accepting constructive criticism
|
||||
* Focusing on what is best for the community
|
||||
* Showing empathy towards other community members
|
||||
|
||||
Examples of unacceptable behavior by participants include:
|
||||
|
||||
* The use of sexualized language or imagery and unwelcome sexual attention or
|
||||
advances
|
||||
* Trolling, insulting/derogatory comments, and personal or political attacks
|
||||
* Public or private harassment
|
||||
* Publishing others' private information, such as a physical or electronic
|
||||
address, without explicit permission
|
||||
* Other conduct which could reasonably be considered inappropriate in a
|
||||
professional setting
|
||||
|
||||
## Our Responsibilities
|
||||
|
||||
Project maintainers are responsible for clarifying the standards of acceptable
|
||||
behavior and are expected to take appropriate and fair corrective action in
|
||||
response to any instances of unacceptable behavior.
|
||||
|
||||
Project maintainers have the right and responsibility to remove, edit, or
|
||||
reject comments, commits, code, wiki edits, issues, and other contributions
|
||||
that are not aligned to this Code of Conduct, or to ban temporarily or
|
||||
permanently any contributor for other behaviors that they deem inappropriate,
|
||||
threatening, offensive, or harmful.
|
||||
|
||||
## Scope
|
||||
|
||||
This Code of Conduct applies both within project spaces and in public spaces
|
||||
when an individual is representing the project or its community. Examples of
|
||||
representing a project or community include using an official project e-mail
|
||||
address, posting via an official social media account, or acting as an appointed
|
||||
representative at an online or offline event. Representation of a project may be
|
||||
further defined and clarified by project maintainers.
|
||||
|
||||
## Enforcement
|
||||
|
||||
Instances of abusive, harassing, or otherwise unacceptable behavior may be
|
||||
reported by contacting the project team at mikelei@mobvoi.com. All
|
||||
complaints will be reviewed and investigated and will result in a response that
|
||||
is deemed necessary and appropriate to the circumstances. The project team is
|
||||
obligated to maintain confidentiality with regard to the reporter of an incident.
|
||||
Further details of specific enforcement policies may be posted separately.
|
||||
|
||||
Project maintainers who do not follow or enforce the Code of Conduct in good
|
||||
faith may face temporary or permanent repercussions as determined by other
|
||||
members of the project's leadership.
|
||||
|
||||
## Attribution
|
||||
|
||||
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
|
||||
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
|
||||
|
||||
[homepage]: https://www.contributor-covenant.org
|
||||
|
||||
For answers to common questions about this code of conduct, see
|
||||
https://www.contributor-covenant.org/faq
|
||||
16
models/CosyVoice/FAQ.md
Normal file
16
models/CosyVoice/FAQ.md
Normal file
@@ -0,0 +1,16 @@
|
||||
## ModuleNotFoundError: No module named 'matcha'
|
||||
|
||||
Matcha-TTS is a third_party module. Please check `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`.
|
||||
|
||||
run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in python script.
|
||||
|
||||
## cannot find resource.zip or cannot unzip resource.zip
|
||||
|
||||
Please make sure you have git-lfs installed. Execute
|
||||
|
||||
```sh
|
||||
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
|
||||
cd pretrained_models/CosyVoice-ttsfrd/
|
||||
unzip resource.zip -d .
|
||||
pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
|
||||
```
|
||||
201
models/CosyVoice/LICENSE
Normal file
201
models/CosyVoice/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright [yyyy] [name of copyright owner]
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
264
models/CosyVoice/README.md
Normal file
264
models/CosyVoice/README.md
Normal file
@@ -0,0 +1,264 @@
|
||||

|
||||
|
||||
## 👉🏻 CosyVoice 👈🏻
|
||||
|
||||
**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/pdf/2505.17589); [Modelscope](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [Huggingface](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
|
||||
|
||||
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/pdf/2412.10117); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)
|
||||
|
||||
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice-300M); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice-300M)
|
||||
|
||||
## Highlight🔥
|
||||
|
||||
**Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLM), surpassing its predecessor (CosyVoice 2.0) in content consistency, speaker similarity, and prosody naturalness. It is designed for zero-shot multilingual speech synthesis in the wild.
|
||||
### Key Features
|
||||
- **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian), 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shan3xi, Shan1xi, Shanghai, Tianjin, Shandong, Ningxia, Gansu, etc.) and meanwhile supports both multi-lingual/cross-lingual zero-shot voice cloning.
|
||||
- **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
|
||||
- **Pronunciation Inpainting**: Supports pronunciation inpainting of Chinese Pinyin and English CMU phonemes, providing more controllability and thus suitable for production use.
|
||||
- **Text Normalization**: Supports reading of numbers, special symbols and various text formats without a traditional frontend module.
|
||||
- **Bi-Streaming**: Support both text-in streaming and audio-out streaming, and achieves latency as low as 150ms while maintaining high-quality audio output.
|
||||
- **Instruct Support**: Supports various instructions such as languages, dialects, emotions, speed, volume, etc.
|
||||
|
||||
|
||||
## Roadmap
|
||||
|
||||
- [x] 2025/12
|
||||
|
||||
- [x] release Fun-CosyVoice3-0.5B-2512 base model, rl model and its training/inference script
|
||||
- [x] release Fun-CosyVoice3-0.5B modelscope gradio space
|
||||
|
||||
- [x] 2025/08
|
||||
|
||||
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
|
||||
|
||||
- [x] 2025/07
|
||||
|
||||
- [x] release Fun-CosyVoice 3.0 eval set
|
||||
|
||||
- [x] 2025/05
|
||||
|
||||
- [x] add CosyVoice2-0.5B vllm support
|
||||
|
||||
- [x] 2024/12
|
||||
|
||||
- [x] 25hz CosyVoice2-0.5B released
|
||||
|
||||
- [x] 2024/09
|
||||
|
||||
- [x] 25hz CosyVoice-300M base model
|
||||
- [x] 25hz CosyVoice-300M voice conversion function
|
||||
|
||||
- [x] 2024/08
|
||||
|
||||
- [x] Repetition Aware Sampling(RAS) inference for llm stability
|
||||
- [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization
|
||||
|
||||
- [x] 2024/07
|
||||
|
||||
- [x] Flow matching training support
|
||||
- [x] WeTextProcessing support when ttsfrd is not available
|
||||
- [x] Fastapi server and client
|
||||
|
||||
## Evaluation
|
||||
|
||||
| Model | Open-Source | Model Size | test-zh<br>CER (%) ↓ | test-zh<br>SS (%) ↑ | test-en<br>WER (%) ↓ | test-en<br>SS (%) ↑ | test-hard<br>CER (%) ↓ | test-hard<br>SS (%) ↑ |
|
||||
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
|
||||
| Human | - | - | 1.26 | 75.5 | 2.14 | 73.4 | - | - |
|
||||
| Seed-TTS | ❌ | - | 1.12 | 79.6 | 2.25 | 76.2 | 7.59 | 77.6 |
|
||||
| MiniMax-Speech | ❌ | - | 0.83 | 78.3 | 1.65 | 69.2 | - | - |
|
||||
| F5-TTS | ✅ | 0.3B | 1.52 | 74.1 | 2.00 | 64.7 | 8.67 | 71.3 |
|
||||
| Spark TTS | ✅ | 0.5B | 1.2 | 66.0 | 1.98 | 57.3 | - | - |
|
||||
| CosyVoice2 | ✅ | 0.5B | 1.45 | 75.7 | 2.57 | 65.9 | 6.83 | 72.4 |
|
||||
| FireRedTTS2 | ✅ | 1.5B | 1.14 | 73.2 | 1.95 | 66.5 | - | - |
|
||||
| Index-TTS2 | ✅ | 1.5B | 1.03 | 76.5 | 2.23 | 70.6 | 7.12 | 75.5 |
|
||||
| VibeVoice-1.5B | ✅ | 1.5B | 1.16 | 74.4 | 3.04 | 68.9 | - | - |
|
||||
| VibeVoice-Realtime | ✅ | 0.5B | - | - | 2.05 | 63.3 | - | - |
|
||||
| HiggsAudio-v2 | ✅ | 3B | 1.50 | 74.0 | 2.44 | 67.7 | - | - |
|
||||
| VoxCPM | ✅ | 0.5B | 0.93 | 77.2 | 1.85 | 72.9 | 8.87 | 73.0 |
|
||||
| GLM-TTS | ✅ | 1.5B | 1.03 | 76.1 | - | - | - | - |
|
||||
| GLM-TTS RL | ✅ | 1.5B | 0.89 | 76.4 | - | - | - | - |
|
||||
| Fun-CosyVoice3-0.5B-2512 | ✅ | 0.5B | 1.21 | 78.0 | 2.24 | 71.8 | 6.71 | 75.8 |
|
||||
| Fun-CosyVoice3-0.5B-2512_RL | ✅ | 0.5B | 0.81 | 77.4 | 1.68 | 69.5 | 5.44 | 75.0 |
|
||||
|
||||
|
||||
## Install
|
||||
|
||||
### Clone and install
|
||||
|
||||
- Clone the repo
|
||||
``` sh
|
||||
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||
# If you failed to clone the submodule due to network failures, please run the following command until success
|
||||
cd CosyVoice
|
||||
git submodule update --init --recursive
|
||||
```
|
||||
|
||||
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
|
||||
- Create Conda env:
|
||||
|
||||
``` sh
|
||||
conda create -n cosyvoice -y python=3.10
|
||||
conda activate cosyvoice
|
||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
|
||||
# If you encounter sox compatibility issues
|
||||
# ubuntu
|
||||
sudo apt-get install sox libsox-dev
|
||||
# centos
|
||||
sudo yum install sox sox-devel
|
||||
```
|
||||
|
||||
### Model download
|
||||
|
||||
We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
|
||||
|
||||
``` python
|
||||
# modelscope SDK model download
|
||||
from modelscope import snapshot_download
|
||||
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
|
||||
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
||||
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
||||
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
||||
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
||||
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
||||
|
||||
# for oversea users, huggingface SDK model download
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
|
||||
snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
|
||||
snapshot_download('FunAudioLLM/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
|
||||
snapshot_download('FunAudioLLM/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
|
||||
snapshot_download('FunAudioLLM/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
|
||||
snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
|
||||
```
|
||||
|
||||
Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.
|
||||
|
||||
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use wetext by default.
|
||||
|
||||
``` sh
|
||||
cd pretrained_models/CosyVoice-ttsfrd/
|
||||
unzip resource.zip -d .
|
||||
pip install ttsfrd_dependency-0.1-py3-none-any.whl
|
||||
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
|
||||
```
|
||||
|
||||
### Basic Usage
|
||||
|
||||
We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
|
||||
Follow the code in `example.py` for detailed usage of each model.
|
||||
```sh
|
||||
python example.py
|
||||
```
|
||||
|
||||
#### vLLM Usage
|
||||
CosyVoice2/3 now supports **vLLM 0.11.x+ (V1 engine)** and **vLLM 0.9.0 (legacy)**.
|
||||
Older vllm version(<0.9.0) do not support CosyVoice inference, and versions in between (e.g., 0.10.x) are not tested.
|
||||
|
||||
Notice that `vllm` has a lot of specific requirements. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
|
||||
|
||||
``` sh
|
||||
conda create -n cosyvoice_vllm --clone cosyvoice
|
||||
conda activate cosyvoice_vllm
|
||||
# for vllm==0.9.0
|
||||
pip install vllm==v0.9.0 transformers==4.51.3 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
# for vllm>=0.11.0
|
||||
pip install vllm==v0.11.0 transformers==4.57.1 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
|
||||
python vllm_example.py
|
||||
```
|
||||
|
||||
#### Start web demo
|
||||
|
||||
You can use our web demo page to get familiar with CosyVoice quickly.
|
||||
|
||||
Please see the demo website for details.
|
||||
|
||||
``` python
|
||||
# change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
|
||||
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
|
||||
```
|
||||
|
||||
#### Advanced Usage
|
||||
|
||||
For advanced users, we have provided training and inference scripts in `examples/libritts`.
|
||||
|
||||
#### Build for deployment
|
||||
|
||||
Optionally, if you want service deployment,
|
||||
You can run the following steps.
|
||||
|
||||
``` sh
|
||||
cd runtime/python
|
||||
docker build -t cosyvoice:v1.0 .
|
||||
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
|
||||
# for grpc usage
|
||||
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
|
||||
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||
# for fastapi usage
|
||||
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
|
||||
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
|
||||
```
|
||||
|
||||
#### Using Nvidia TensorRT-LLM for deployment
|
||||
|
||||
Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation.
|
||||
To quick start:
|
||||
|
||||
``` sh
|
||||
cd runtime/triton_trtllm
|
||||
docker compose up -d
|
||||
```
|
||||
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)
|
||||
|
||||
## Discussion & Communication
|
||||
|
||||
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
|
||||
|
||||
You can also scan the QR code to join our official Dingding chat group.
|
||||
|
||||
<img src="./asset/dingding.png" width="250px">
|
||||
|
||||
## Acknowledge
|
||||
|
||||
1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
|
||||
2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
|
||||
3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
|
||||
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
|
||||
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
|
||||
|
||||
## Citations
|
||||
|
||||
``` bibtex
|
||||
@article{du2024cosyvoice,
|
||||
title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
|
||||
author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
|
||||
journal={arXiv preprint arXiv:2407.05407},
|
||||
year={2024}
|
||||
}
|
||||
|
||||
@article{du2024cosyvoice,
|
||||
title={Cosyvoice 2: Scalable streaming speech synthesis with large language models},
|
||||
author={Du, Zhihao and Wang, Yuxuan and Chen, Qian and Shi, Xian and Lv, Xiang and Zhao, Tianyu and Gao, Zhifu and Yang, Yexin and Gao, Changfeng and Wang, Hui and others},
|
||||
journal={arXiv preprint arXiv:2412.10117},
|
||||
year={2024}
|
||||
}
|
||||
|
||||
@article{du2025cosyvoice,
|
||||
title={CosyVoice 3: Towards In-the-wild Speech Generation via Scaling-up and Post-training},
|
||||
author={Du, Zhihao and Gao, Changfeng and Wang, Yuxuan and Yu, Fan and Zhao, Tianyu and Wang, Hao and Lv, Xiang and Wang, Hui and Shi, Xian and An, Keyu and others},
|
||||
journal={arXiv preprint arXiv:2505.17589},
|
||||
year={2025}
|
||||
}
|
||||
|
||||
@inproceedings{lyu2025build,
|
||||
title={Build LLM-Based Zero-Shot Streaming TTS System with Cosyvoice},
|
||||
author={Lyu, Xiang and Wang, Yuxuan and Zhao, Tianyu and Wang, Hao and Liu, Huadai and Du, Zhihao},
|
||||
booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
|
||||
pages={1--2},
|
||||
year={2025},
|
||||
organization={IEEE}
|
||||
}
|
||||
```
|
||||
|
||||
## Disclaimer
|
||||
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.
|
||||
0
models/CosyVoice/cosyvoice/__init__.py
Normal file
0
models/CosyVoice/cosyvoice/__init__.py
Normal file
93
models/CosyVoice/cosyvoice/bin/average_model.py
Normal file
93
models/CosyVoice/cosyvoice/bin/average_model.py
Normal file
@@ -0,0 +1,93 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import argparse
|
||||
import glob
|
||||
|
||||
import yaml
|
||||
import torch
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='average model')
|
||||
parser.add_argument('--dst_model', required=True, help='averaged model')
|
||||
parser.add_argument('--src_path',
|
||||
required=True,
|
||||
help='src model path for average')
|
||||
parser.add_argument('--val_best',
|
||||
action="store_true",
|
||||
help='averaged model')
|
||||
parser.add_argument('--num',
|
||||
default=5,
|
||||
type=int,
|
||||
help='nums for averaged model')
|
||||
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
val_scores = []
|
||||
if args.val_best:
|
||||
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
|
||||
yamls = [
|
||||
f for f in yamls
|
||||
if not (os.path.basename(f).startswith('train')
|
||||
or os.path.basename(f).startswith('init'))
|
||||
]
|
||||
for y in yamls:
|
||||
with open(y, 'r') as f:
|
||||
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
|
||||
loss = float(dic_yaml['loss_dict']['loss'])
|
||||
epoch = int(dic_yaml['epoch'])
|
||||
step = int(dic_yaml['step'])
|
||||
tag = dic_yaml['tag']
|
||||
val_scores += [[epoch, step, loss, tag]]
|
||||
sorted_val_scores = sorted(val_scores,
|
||||
key=lambda x: x[2],
|
||||
reverse=False)
|
||||
print("best val (epoch, step, loss, tag) = " +
|
||||
str(sorted_val_scores[:args.num]))
|
||||
path_list = [
|
||||
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
|
||||
for score in sorted_val_scores[:args.num]
|
||||
]
|
||||
print(path_list)
|
||||
avg = {}
|
||||
num = args.num
|
||||
assert num == len(path_list)
|
||||
for path in path_list:
|
||||
print('Processing {}'.format(path))
|
||||
states = torch.load(path, map_location=torch.device('cpu'))
|
||||
for k in states.keys():
|
||||
if k not in ['step', 'epoch']:
|
||||
if k not in avg.keys():
|
||||
avg[k] = states[k].clone()
|
||||
else:
|
||||
avg[k] += states[k]
|
||||
# average
|
||||
for k in avg.keys():
|
||||
if avg[k] is not None:
|
||||
# pytorch 1.6 use true_divide instead of /=
|
||||
avg[k] = torch.true_divide(avg[k], num)
|
||||
print('Saving to {}'.format(args.dst_model))
|
||||
torch.save(avg, args.dst_model)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
99
models/CosyVoice/cosyvoice/bin/export_jit.py
Normal file
99
models/CosyVoice/cosyvoice/bin/export_jit.py
Normal file
@@ -0,0 +1,99 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='export your model for deployment')
|
||||
parser.add_argument('--model_dir',
|
||||
type=str,
|
||||
default='pretrained_models/CosyVoice-300M',
|
||||
help='local path')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
def get_optimized_script(model, preserved_attrs=[]):
|
||||
script = torch.jit.script(model)
|
||||
if preserved_attrs != []:
|
||||
script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
|
||||
else:
|
||||
script = torch.jit.freeze(script)
|
||||
script = torch.jit.optimize_for_inference(script)
|
||||
return script
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
|
||||
torch._C._jit_set_profiling_mode(False)
|
||||
torch._C._jit_set_profiling_executor(False)
|
||||
|
||||
model = AutoModel(model_dir=args.model_dir)
|
||||
|
||||
if model.__class__.__name__ == 'CosyVoice':
|
||||
# 1. export llm text_encoder
|
||||
llm_text_encoder = model.model.llm.text_encoder
|
||||
script = get_optimized_script(llm_text_encoder)
|
||||
script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(llm_text_encoder.half())
|
||||
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
|
||||
logging.info('successfully export llm_text_encoder')
|
||||
|
||||
# 2. export llm llm
|
||||
llm_llm = model.model.llm.llm
|
||||
script = get_optimized_script(llm_llm, ['forward_chunk'])
|
||||
script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
|
||||
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
|
||||
logging.info('successfully export llm_llm')
|
||||
|
||||
# 3. export flow encoder
|
||||
flow_encoder = model.model.flow.encoder
|
||||
script = get_optimized_script(flow_encoder)
|
||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(flow_encoder.half())
|
||||
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||
logging.info('successfully export flow_encoder')
|
||||
elif model.__class__.__name__ == 'CosyVoice2':
|
||||
# 1. export flow encoder
|
||||
flow_encoder = model.model.flow.encoder
|
||||
script = get_optimized_script(flow_encoder)
|
||||
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
|
||||
script = get_optimized_script(flow_encoder.half())
|
||||
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
|
||||
logging.info('successfully export flow_encoder')
|
||||
else:
|
||||
raise ValueError('unsupported model type')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
114
models/CosyVoice/cosyvoice/bin/export_onnx.py
Normal file
114
models/CosyVoice/cosyvoice/bin/export_onnx.py
Normal file
@@ -0,0 +1,114 @@
|
||||
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
import os
|
||||
import sys
|
||||
import onnxruntime
|
||||
import random
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
sys.path.append('{}/../..'.format(ROOT_DIR))
|
||||
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
|
||||
|
||||
def get_dummy_input(batch_size, seq_len, out_channels, device):
|
||||
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
|
||||
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||
t = torch.rand((batch_size), dtype=torch.float32, device=device)
|
||||
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
|
||||
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
|
||||
return x, mask, mu, t, spks, cond
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='export your model for deployment')
|
||||
parser.add_argument('--model_dir',
|
||||
type=str,
|
||||
default='pretrained_models/CosyVoice-300M',
|
||||
help='local path')
|
||||
args = parser.parse_args()
|
||||
print(args)
|
||||
return args
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def main():
|
||||
args = get_args()
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
model = AutoModel(model_dir=args.model_dir)
|
||||
|
||||
# 1. export flow decoder estimator
|
||||
estimator = model.model.flow.decoder.estimator
|
||||
estimator.eval()
|
||||
|
||||
device = model.model.device
|
||||
batch_size, seq_len = 2, 256
|
||||
out_channels = model.model.flow.decoder.estimator.out_channels
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
|
||||
torch.onnx.export(
|
||||
estimator,
|
||||
(x, mask, mu, t, spks, cond),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
export_params=True,
|
||||
opset_version=18,
|
||||
do_constant_folding=True,
|
||||
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
|
||||
output_names=['estimator_out'],
|
||||
dynamic_axes={
|
||||
'x': {2: 'seq_len'},
|
||||
'mask': {2: 'seq_len'},
|
||||
'mu': {2: 'seq_len'},
|
||||
'cond': {2: 'seq_len'},
|
||||
'estimator_out': {2: 'seq_len'},
|
||||
}
|
||||
)
|
||||
|
||||
# 2. test computation consistency
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
|
||||
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
|
||||
sess_options=option, providers=providers)
|
||||
|
||||
for _ in tqdm(range(10)):
|
||||
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
|
||||
output_pytorch = estimator(x, mask, mu, t, spks, cond)
|
||||
ort_inputs = {
|
||||
'x': x.cpu().numpy(),
|
||||
'mask': mask.cpu().numpy(),
|
||||
'mu': mu.cpu().numpy(),
|
||||
't': t.cpu().numpy(),
|
||||
'spks': spks.cpu().numpy(),
|
||||
'cond': cond.cpu().numpy()
|
||||
}
|
||||
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
|
||||
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
|
||||
logging.info('successfully export estimator')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
195
models/CosyVoice/cosyvoice/bin/train.py
Normal file
195
models/CosyVoice/cosyvoice/bin/train.py
Normal file
@@ -0,0 +1,195 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import print_function
|
||||
import argparse
|
||||
import datetime
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
from copy import deepcopy
|
||||
import os
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import deepspeed
|
||||
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
|
||||
from torch.distributed.elastic.multiprocessing.errors import record
|
||||
|
||||
from cosyvoice.utils.losses import DPOLoss
|
||||
from cosyvoice.utils.executor import Executor
|
||||
from cosyvoice.utils.train_utils import (
|
||||
init_distributed,
|
||||
init_dataset_and_dataloader,
|
||||
init_optimizer_and_scheduler,
|
||||
init_summarywriter, save_model,
|
||||
wrap_cuda_model, check_modify_and_save_config)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description='training your network')
|
||||
parser.add_argument('--train_engine',
|
||||
default='torch_ddp',
|
||||
choices=['torch_ddp', 'deepspeed'],
|
||||
help='Engine for paralleled training')
|
||||
parser.add_argument('--model', required=True, help='model which will be trained')
|
||||
parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
|
||||
parser.add_argument('--config', required=True, help='config file')
|
||||
parser.add_argument('--train_data', required=True, help='train data file')
|
||||
parser.add_argument('--cv_data', required=True, help='cv data file')
|
||||
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
|
||||
parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
|
||||
parser.add_argument('--checkpoint', help='checkpoint model')
|
||||
parser.add_argument('--model_dir', required=True, help='save model dir')
|
||||
parser.add_argument('--tensorboard_dir',
|
||||
default='tensorboard',
|
||||
help='tensorboard log dir')
|
||||
parser.add_argument('--ddp.dist_backend',
|
||||
dest='dist_backend',
|
||||
default='nccl',
|
||||
choices=['nccl', 'gloo'],
|
||||
help='distributed backend')
|
||||
parser.add_argument('--num_workers',
|
||||
default=0,
|
||||
type=int,
|
||||
help='num of subprocess workers for reading')
|
||||
parser.add_argument('--prefetch',
|
||||
default=100,
|
||||
type=int,
|
||||
help='prefetch number')
|
||||
parser.add_argument('--pin_memory',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use pinned memory buffers used for reading')
|
||||
parser.add_argument('--use_amp',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use automatic mixed precision training')
|
||||
parser.add_argument('--dpo',
|
||||
action='store_true',
|
||||
default=False,
|
||||
help='Use Direct Preference Optimization')
|
||||
parser.add_argument('--deepspeed.save_states',
|
||||
dest='save_states',
|
||||
default='model_only',
|
||||
choices=['model_only', 'model+optimizer'],
|
||||
help='save model/optimizer states')
|
||||
parser.add_argument('--timeout',
|
||||
default=60,
|
||||
type=int,
|
||||
help='timeout (in seconds) of cosyvoice_join.')
|
||||
parser = deepspeed.add_config_arguments(parser)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
@record
|
||||
def main():
|
||||
args = get_args()
|
||||
os.environ['onnx_path'] = args.onnx_path
|
||||
logging.basicConfig(level=logging.DEBUG,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
# gan train has some special initialization logic
|
||||
gan = True if args.model == 'hifigan' else False
|
||||
|
||||
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
|
||||
if gan is True:
|
||||
override_dict.pop('hift')
|
||||
if args.qwen_pretrain_path is not None:
|
||||
override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
|
||||
with open(args.config, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides=override_dict)
|
||||
if gan is True:
|
||||
configs['train_conf'] = configs['train_conf_gan']
|
||||
configs['train_conf'].update(vars(args))
|
||||
|
||||
# Init env for ddp
|
||||
init_distributed(args)
|
||||
|
||||
# Get dataset & dataloader
|
||||
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
|
||||
init_dataset_and_dataloader(args, configs, gan, args.dpo)
|
||||
|
||||
# Do some sanity checks and save config to arsg.model_dir
|
||||
configs = check_modify_and_save_config(args, configs)
|
||||
|
||||
# Tensorboard summary
|
||||
writer = init_summarywriter(args)
|
||||
|
||||
# load checkpoint
|
||||
if args.dpo is True:
|
||||
configs[args.model].forward = configs[args.model].forward_dpo
|
||||
model = configs[args.model]
|
||||
start_step, start_epoch = 0, -1
|
||||
if args.checkpoint is not None:
|
||||
if os.path.exists(args.checkpoint):
|
||||
state_dict = torch.load(args.checkpoint, map_location='cpu')
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
if 'step' in state_dict:
|
||||
start_step = state_dict['step']
|
||||
if 'epoch' in state_dict:
|
||||
start_epoch = state_dict['epoch']
|
||||
else:
|
||||
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
|
||||
|
||||
# Dispatch model from cpu to gpu
|
||||
model = wrap_cuda_model(args, model)
|
||||
|
||||
# Get optimizer & scheduler
|
||||
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
|
||||
scheduler.set_step(start_step)
|
||||
if scheduler_d is not None:
|
||||
scheduler_d.set_step(start_step)
|
||||
|
||||
# Save init checkpoints
|
||||
info_dict = deepcopy(configs['train_conf'])
|
||||
info_dict['step'] = start_step
|
||||
info_dict['epoch'] = start_epoch
|
||||
save_model(model, 'init', info_dict)
|
||||
|
||||
# DPO related
|
||||
if args.dpo is True:
|
||||
ref_model = deepcopy(configs[args.model])
|
||||
state_dict = torch.load(args.ref_model, map_location='cpu')
|
||||
ref_model.load_state_dict(state_dict, strict=False)
|
||||
dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
|
||||
# NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
|
||||
ref_model = wrap_cuda_model(args, ref_model)
|
||||
else:
|
||||
ref_model, dpo_loss = None, None
|
||||
|
||||
# Get executor
|
||||
executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
|
||||
executor.step = start_step
|
||||
|
||||
# Init scaler, used for pytorch amp mixed precision training
|
||||
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
|
||||
print('start step {} start epoch {}'.format(start_step, start_epoch))
|
||||
|
||||
# Start training loop
|
||||
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
|
||||
executor.epoch = epoch
|
||||
train_dataset.set_epoch(epoch)
|
||||
dist.barrier()
|
||||
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
|
||||
if gan is True:
|
||||
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, scaler, group_join)
|
||||
else:
|
||||
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
|
||||
dist.destroy_process_group(group_join)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
0
models/CosyVoice/cosyvoice/cli/__init__.py
Normal file
0
models/CosyVoice/cosyvoice/cli/__init__.py
Normal file
240
models/CosyVoice/cosyvoice/cli/cosyvoice.py
Normal file
240
models/CosyVoice/cosyvoice/cli/cosyvoice.py
Normal file
@@ -0,0 +1,240 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import time
|
||||
from typing import Generator
|
||||
from tqdm import tqdm
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
from modelscope import snapshot_download
|
||||
import torch
|
||||
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.utils.class_utils import get_model_type
|
||||
|
||||
|
||||
class CosyVoice:
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
if not os.path.exists(model_dir):
|
||||
model_dir = snapshot_download(model_dir)
|
||||
hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
|
||||
if not os.path.exists(hyper_yaml_path):
|
||||
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||
with open(hyper_yaml_path, 'r') as f:
|
||||
configs = load_hyperpyyaml(f)
|
||||
assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
|
||||
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
||||
configs['feat_extractor'],
|
||||
'{}/campplus.onnx'.format(model_dir),
|
||||
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
|
||||
'{}/spk2info.pt'.format(model_dir),
|
||||
configs['allowed_special'])
|
||||
self.sample_rate = configs['sample_rate']
|
||||
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
|
||||
load_jit, load_trt, fp16 = False, False, False
|
||||
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
|
||||
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
del configs
|
||||
|
||||
def list_available_spks(self):
|
||||
spks = list(self.frontend.spk2info.keys())
|
||||
return spks
|
||||
|
||||
def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
|
||||
assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
|
||||
model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
|
||||
del model_input['text']
|
||||
del model_input['text_len']
|
||||
self.frontend.spk2info[zero_shot_spk_id] = model_input
|
||||
return True
|
||||
|
||||
def save_spkinfo(self):
|
||||
torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
|
||||
|
||||
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
model_input = self.frontend.frontend_sft(i, spk_id)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
|
||||
if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
|
||||
logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
|
||||
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
|
||||
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
|
||||
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
|
||||
assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
|
||||
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
|
||||
model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
|
||||
start_time = time.time()
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
class CosyVoice2(CosyVoice):
|
||||
|
||||
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
if not os.path.exists(model_dir):
|
||||
model_dir = snapshot_download(model_dir)
|
||||
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
|
||||
if not os.path.exists(hyper_yaml_path):
|
||||
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||
with open(hyper_yaml_path, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||
assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
|
||||
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
||||
configs['feat_extractor'],
|
||||
'{}/campplus.onnx'.format(model_dir),
|
||||
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
|
||||
'{}/spk2info.pt'.format(model_dir),
|
||||
configs['allowed_special'])
|
||||
self.sample_rate = configs['sample_rate']
|
||||
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
|
||||
load_jit, load_trt, load_vllm, fp16 = False, False, False, False
|
||||
logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
|
||||
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
if load_vllm:
|
||||
self.model.load_vllm('{}/vllm'.format(model_dir))
|
||||
if load_jit:
|
||||
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
|
||||
if load_trt:
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
del configs
|
||||
|
||||
def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
|
||||
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
|
||||
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
|
||||
start_time = time.time()
|
||||
logging.info('synthesis text {}'.format(i))
|
||||
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
|
||||
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
|
||||
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
|
||||
yield model_output
|
||||
start_time = time.time()
|
||||
|
||||
|
||||
class CosyVoice3(CosyVoice2):
|
||||
|
||||
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
|
||||
self.model_dir = model_dir
|
||||
self.fp16 = fp16
|
||||
if not os.path.exists(model_dir):
|
||||
model_dir = snapshot_download(model_dir)
|
||||
hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
|
||||
if not os.path.exists(hyper_yaml_path):
|
||||
raise ValueError('{} not found!'.format(hyper_yaml_path))
|
||||
with open(hyper_yaml_path, 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
|
||||
assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
|
||||
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
|
||||
configs['feat_extractor'],
|
||||
'{}/campplus.onnx'.format(model_dir),
|
||||
'{}/speech_tokenizer_v3.onnx'.format(model_dir),
|
||||
'{}/spk2info.pt'.format(model_dir),
|
||||
configs['allowed_special'])
|
||||
self.sample_rate = configs['sample_rate']
|
||||
if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
|
||||
load_trt, fp16 = False, False
|
||||
logging.warning('no cuda device, set load_trt/fp16 to False')
|
||||
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
|
||||
self.model.load('{}/llm.pt'.format(model_dir),
|
||||
'{}/flow.pt'.format(model_dir),
|
||||
'{}/hift.pt'.format(model_dir))
|
||||
if load_vllm:
|
||||
self.model.load_vllm('{}/vllm'.format(model_dir))
|
||||
if load_trt:
|
||||
if self.fp16 is True:
|
||||
logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
|
||||
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
|
||||
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
|
||||
trt_concurrent,
|
||||
self.fp16)
|
||||
del configs
|
||||
|
||||
|
||||
def AutoModel(**kwargs):
|
||||
if not os.path.exists(kwargs['model_dir']):
|
||||
kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
|
||||
if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
|
||||
return CosyVoice(**kwargs)
|
||||
elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
|
||||
return CosyVoice2(**kwargs)
|
||||
elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
|
||||
return CosyVoice3(**kwargs)
|
||||
else:
|
||||
raise TypeError('No valid model type found!')
|
||||
224
models/CosyVoice/cosyvoice/cli/frontend.py
Normal file
224
models/CosyVoice/cosyvoice/cli/frontend.py
Normal file
@@ -0,0 +1,224 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from functools import partial
|
||||
from typing import Generator
|
||||
import json
|
||||
import onnxruntime
|
||||
import torch
|
||||
import numpy as np
|
||||
import whisper
|
||||
from typing import Callable
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
import os
|
||||
import re
|
||||
import inflect
|
||||
from cosyvoice.utils.file_utils import logging, load_wav
|
||||
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
|
||||
|
||||
|
||||
class CosyVoiceFrontEnd:
|
||||
|
||||
def __init__(self,
|
||||
get_tokenizer: Callable,
|
||||
feat_extractor: Callable,
|
||||
campplus_model: str,
|
||||
speech_tokenizer_model: str,
|
||||
spk2info: str = '',
|
||||
allowed_special: str = 'all'):
|
||||
self.tokenizer = get_tokenizer()
|
||||
self.feat_extractor = feat_extractor
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
|
||||
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
|
||||
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
|
||||
"CPUExecutionProvider"])
|
||||
if os.path.exists(spk2info):
|
||||
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
|
||||
else:
|
||||
self.spk2info = {}
|
||||
self.allowed_special = allowed_special
|
||||
self.inflect_parser = inflect.engine()
|
||||
# NOTE compatible when no text frontend tool is avaliable
|
||||
try:
|
||||
import ttsfrd
|
||||
self.frd = ttsfrd.TtsFrontendEngine()
|
||||
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
||||
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
|
||||
'failed to initialize ttsfrd resource'
|
||||
self.frd.set_lang_type('pinyinvg')
|
||||
self.text_frontend = 'ttsfrd'
|
||||
logging.info('use ttsfrd frontend')
|
||||
except:
|
||||
try:
|
||||
from wetext import Normalizer as ZhNormalizer
|
||||
from wetext import Normalizer as EnNormalizer
|
||||
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
|
||||
self.en_tn_model = EnNormalizer()
|
||||
self.text_frontend = 'wetext'
|
||||
logging.info('use wetext frontend')
|
||||
except:
|
||||
self.text_frontend = ''
|
||||
logging.info('no frontend is avaliable')
|
||||
|
||||
|
||||
def _extract_text_token(self, text):
|
||||
if isinstance(text, Generator):
|
||||
logging.info('get tts_text generator, will return _extract_text_token_generator!')
|
||||
# NOTE add a dummy text_token_len for compatibility
|
||||
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
|
||||
else:
|
||||
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
|
||||
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
|
||||
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
|
||||
return text_token, text_token_len
|
||||
|
||||
def _extract_text_token_generator(self, text_generator):
|
||||
for text in text_generator:
|
||||
text_token, _ = self._extract_text_token(text)
|
||||
for i in range(text_token.shape[1]):
|
||||
yield text_token[:, i: i + 1]
|
||||
|
||||
def _extract_speech_token(self, prompt_wav):
|
||||
speech = load_wav(prompt_wav, 16000)
|
||||
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
|
||||
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
|
||||
speech_token = self.speech_tokenizer_session.run(None,
|
||||
{self.speech_tokenizer_session.get_inputs()[0].name:
|
||||
feat.detach().cpu().numpy(),
|
||||
self.speech_tokenizer_session.get_inputs()[1].name:
|
||||
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
|
||||
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
|
||||
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
|
||||
return speech_token, speech_token_len
|
||||
|
||||
def _extract_spk_embedding(self, prompt_wav):
|
||||
speech = load_wav(prompt_wav, 16000)
|
||||
feat = kaldi.fbank(speech,
|
||||
num_mel_bins=80,
|
||||
dither=0,
|
||||
sample_frequency=16000)
|
||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||
embedding = self.campplus_session.run(None,
|
||||
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
||||
embedding = torch.tensor([embedding]).to(self.device)
|
||||
return embedding
|
||||
|
||||
def _extract_speech_feat(self, prompt_wav):
|
||||
speech = load_wav(prompt_wav, 24000)
|
||||
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
|
||||
speech_feat = speech_feat.unsqueeze(dim=0)
|
||||
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
|
||||
return speech_feat, speech_feat_len
|
||||
|
||||
def text_normalize(self, text, split=True, text_frontend=True):
|
||||
if isinstance(text, Generator):
|
||||
logging.info('get tts_text generator, will skip text_normalize!')
|
||||
return [text]
|
||||
# NOTE skip text_frontend when ssml symbol in text
|
||||
if '<|' in text and '|>' in text:
|
||||
text_frontend = False
|
||||
if text_frontend is False or text == '':
|
||||
return [text] if split is True else text
|
||||
text = text.strip()
|
||||
if self.text_frontend == 'ttsfrd':
|
||||
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
|
||||
text = ''.join(texts)
|
||||
else:
|
||||
if contains_chinese(text):
|
||||
if self.text_frontend == 'wetext':
|
||||
text = self.zh_tn_model.normalize(text)
|
||||
text = text.replace("\n", "")
|
||||
text = replace_blank(text)
|
||||
text = replace_corner_mark(text)
|
||||
text = text.replace(".", "。")
|
||||
text = text.replace(" - ", ",")
|
||||
text = remove_bracket(text)
|
||||
text = re.sub(r'[,,、]+$', '。', text)
|
||||
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
|
||||
token_min_n=60, merge_len=20, comma_split=False))
|
||||
else:
|
||||
if self.text_frontend == 'wetext':
|
||||
text = self.en_tn_model.normalize(text)
|
||||
text = spell_out_number(text, self.inflect_parser)
|
||||
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
|
||||
token_min_n=60, merge_len=20, comma_split=False))
|
||||
texts = [i for i in texts if not is_only_punctuation(i)]
|
||||
return texts if split is True else text
|
||||
|
||||
def frontend_sft(self, tts_text, spk_id):
|
||||
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
||||
embedding = self.spk2info[spk_id]['embedding']
|
||||
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
|
||||
return model_input
|
||||
|
||||
def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
|
||||
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
|
||||
if zero_shot_spk_id == '':
|
||||
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
|
||||
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
|
||||
speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
|
||||
if resample_rate == 24000:
|
||||
# cosyvoice2, force speech_feat % speech_token = 2
|
||||
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
|
||||
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
|
||||
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
|
||||
embedding = self._extract_spk_embedding(prompt_wav)
|
||||
model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
|
||||
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
|
||||
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
|
||||
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
|
||||
'llm_embedding': embedding, 'flow_embedding': embedding}
|
||||
else:
|
||||
model_input = {**self.spk2info[zero_shot_spk_id]}
|
||||
model_input['text'] = tts_text_token
|
||||
model_input['text_len'] = tts_text_token_len
|
||||
return model_input
|
||||
|
||||
def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
|
||||
model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
|
||||
# in cross lingual mode, we remove prompt in llm
|
||||
del model_input['prompt_text']
|
||||
del model_input['prompt_text_len']
|
||||
del model_input['llm_prompt_speech_token']
|
||||
del model_input['llm_prompt_speech_token_len']
|
||||
return model_input
|
||||
|
||||
def frontend_instruct(self, tts_text, spk_id, instruct_text):
|
||||
model_input = self.frontend_sft(tts_text, spk_id)
|
||||
# in instruct mode, we remove spk_embedding in llm due to information leakage
|
||||
del model_input['llm_embedding']
|
||||
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
|
||||
model_input['prompt_text'] = instruct_text_token
|
||||
model_input['prompt_text_len'] = instruct_text_token_len
|
||||
return model_input
|
||||
|
||||
def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
|
||||
model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
|
||||
del model_input['llm_prompt_speech_token']
|
||||
del model_input['llm_prompt_speech_token_len']
|
||||
return model_input
|
||||
|
||||
def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
|
||||
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
|
||||
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
|
||||
embedding = self._extract_spk_embedding(prompt_wav)
|
||||
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
|
||||
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
|
||||
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
|
||||
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
|
||||
'flow_embedding': embedding}
|
||||
return model_input
|
||||
450
models/CosyVoice/cosyvoice/cli/model.py
Normal file
450
models/CosyVoice/cosyvoice/cli/model.py
Normal file
@@ -0,0 +1,450 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
from typing import Generator
|
||||
import torch
|
||||
import numpy as np
|
||||
import threading
|
||||
import time
|
||||
from torch.nn import functional as F
|
||||
from contextlib import nullcontext
|
||||
import uuid
|
||||
from cosyvoice.utils.common import fade_in_out
|
||||
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
|
||||
from cosyvoice.utils.common import TrtContextWrapper
|
||||
|
||||
|
||||
class CosyVoiceModel:
|
||||
|
||||
def __init__(self,
|
||||
llm: torch.nn.Module,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.llm = llm
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
self.token_min_hop_len = 2 * self.flow.input_frame_rate
|
||||
self.token_max_hop_len = 4 * self.flow.input_frame_rate
|
||||
self.token_overlap_len = 20
|
||||
# mel fade in out
|
||||
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
|
||||
self.mel_window = np.hamming(2 * self.mel_overlap_len)
|
||||
# hift cache
|
||||
self.mel_cache_len = 20
|
||||
self.source_cache_len = int(self.mel_cache_len * 256)
|
||||
# speech fade in out
|
||||
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||
# rtf and decoding related
|
||||
self.stream_scale_factor = 1
|
||||
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||
self.lock = threading.Lock()
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
self.llm_end_dict = {}
|
||||
self.mel_overlap_dict = {}
|
||||
self.flow_cache_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
self.silent_tokens = []
|
||||
|
||||
def load(self, llm_model, flow_model, hift_model):
|
||||
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
|
||||
self.llm.to(self.device).eval()
|
||||
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
|
||||
self.flow.to(self.device).eval()
|
||||
# in case hift_model is a hifigan model
|
||||
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
|
||||
self.hift.load_state_dict(hift_state_dict, strict=True)
|
||||
self.hift.to(self.device).eval()
|
||||
|
||||
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
|
||||
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
|
||||
self.llm.text_encoder = llm_text_encoder
|
||||
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
|
||||
self.llm.llm = llm_llm
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
|
||||
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
|
||||
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
|
||||
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
|
||||
del self.flow.decoder.estimator
|
||||
import tensorrt as trt
|
||||
with open(flow_decoder_estimator_model, 'rb') as f:
|
||||
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
|
||||
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
|
||||
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
|
||||
|
||||
def get_trt_kwargs(self):
|
||||
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
|
||||
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
|
||||
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
|
||||
input_names = ["x", "mask", "mu", "cond"]
|
||||
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
|
||||
|
||||
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
|
||||
cur_silent_token_num, max_silent_token_num = 0, 5
|
||||
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
|
||||
if isinstance(text, Generator):
|
||||
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
|
||||
token_generator = self.llm.inference_bistream(text=text,
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device))
|
||||
else:
|
||||
token_generator = self.llm.inference(text=text.to(self.device),
|
||||
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_text=prompt_text.to(self.device),
|
||||
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_speech_token=llm_prompt_speech_token.to(self.device),
|
||||
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=llm_embedding.to(self.device),
|
||||
uuid=uuid)
|
||||
for i in token_generator:
|
||||
if i in self.silent_tokens:
|
||||
cur_silent_token_num += 1
|
||||
if cur_silent_token_num > max_silent_token_num:
|
||||
continue
|
||||
else:
|
||||
cur_silent_token_num = 0
|
||||
self.tts_speech_token_dict[uuid].append(i)
|
||||
self.llm_end_dict[uuid] = True
|
||||
|
||||
def vc_job(self, source_speech_token, uuid):
|
||||
self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
|
||||
self.llm_end_dict[uuid] = True
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
flow_cache=self.flow_cache_dict[uuid])
|
||||
|
||||
# mel overlap fade in out
|
||||
if self.mel_overlap_dict[uuid].shape[2] != 0:
|
||||
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
|
||||
# append hift cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||
else:
|
||||
hift_cache_source = torch.zeros(1, 1, 0)
|
||||
# keep overlap mel and hift cache
|
||||
if finalize is False:
|
||||
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
|
||||
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||
else:
|
||||
if speed != 1.0:
|
||||
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
return tts_speech
|
||||
|
||||
def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
|
||||
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
||||
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
|
||||
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
|
||||
if source_speech_token.shape[1] == 0:
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
else:
|
||||
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
token_hop_len = self.token_min_hop_len
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
|
||||
.unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=False)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
|
||||
# increase token_hop_len for better speech quality
|
||||
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
uuid=this_uuid,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
self.mel_overlap_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
self.flow_cache_dict.pop(this_uuid)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
|
||||
class CosyVoice2Model(CosyVoiceModel):
|
||||
|
||||
def __init__(self,
|
||||
llm: torch.nn.Module,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.llm = llm
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
# NOTE must matching training static_chunk_size
|
||||
self.token_hop_len = 25
|
||||
# NOTE increase token_hop_len incrementally to avoid duplicate inference
|
||||
self.token_max_hop_len = 4 * self.token_hop_len
|
||||
self.stream_scale_factor = 2
|
||||
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||
# hift cache
|
||||
self.mel_cache_len = 8
|
||||
self.source_cache_len = int(self.mel_cache_len * 480)
|
||||
# speech fade in out
|
||||
self.speech_window = np.hamming(2 * self.source_cache_len)
|
||||
# rtf and decoding related
|
||||
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||
self.lock = threading.Lock()
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
self.llm_end_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
self.silent_tokens = []
|
||||
|
||||
def load_jit(self, flow_encoder_model):
|
||||
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
|
||||
self.flow.encoder = flow_encoder
|
||||
|
||||
def load_vllm(self, model_dir):
|
||||
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
|
||||
from vllm import EngineArgs, LLMEngine
|
||||
engine_args = EngineArgs(model=model_dir,
|
||||
skip_tokenizer_init=True,
|
||||
enable_prompt_embeds=True,
|
||||
gpu_memory_utilization=0.2)
|
||||
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
|
||||
self.llm.lock = threading.Lock()
|
||||
del self.llm.llm.model.model.layers
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
streaming=stream,
|
||||
finalize=finalize)
|
||||
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||
# append hift cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
|
||||
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||
else:
|
||||
hift_cache_source = torch.zeros(1, 1, 0)
|
||||
# keep overlap mel and hift cache
|
||||
if finalize is False:
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
|
||||
'source': tts_source[:, :, -self.source_cache_len:],
|
||||
'speech': tts_speech[:, -self.source_cache_len:]}
|
||||
tts_speech = tts_speech[:, :-self.source_cache_len]
|
||||
else:
|
||||
if speed != 1.0:
|
||||
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
|
||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
|
||||
return tts_speech
|
||||
|
||||
def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
|
||||
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
|
||||
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
|
||||
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
|
||||
# this_uuid is used to track variables related to this inference thread
|
||||
this_uuid = str(uuid.uuid1())
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
|
||||
self.hift_cache_dict[this_uuid] = None
|
||||
if source_speech_token.shape[1] == 0:
|
||||
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
|
||||
else:
|
||||
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
|
||||
p.start()
|
||||
if stream is True:
|
||||
token_offset = 0
|
||||
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
|
||||
while True:
|
||||
time.sleep(0.1)
|
||||
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
|
||||
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
token_offset=token_offset,
|
||||
uuid=this_uuid,
|
||||
stream=stream,
|
||||
finalize=False)
|
||||
token_offset += this_token_hop_len
|
||||
self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
|
||||
break
|
||||
p.join()
|
||||
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
token_offset=token_offset,
|
||||
uuid=this_uuid,
|
||||
finalize=True)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
else:
|
||||
# deal with all tokens
|
||||
p.join()
|
||||
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
|
||||
this_tts_speech = self.token2wav(token=this_tts_speech_token,
|
||||
prompt_token=flow_prompt_speech_token,
|
||||
prompt_feat=prompt_speech_feat,
|
||||
embedding=flow_embedding,
|
||||
token_offset=0,
|
||||
uuid=this_uuid,
|
||||
finalize=True,
|
||||
speed=speed)
|
||||
yield {'tts_speech': this_tts_speech.cpu()}
|
||||
with self.lock:
|
||||
self.tts_speech_token_dict.pop(this_uuid)
|
||||
self.llm_end_dict.pop(this_uuid)
|
||||
self.hift_cache_dict.pop(this_uuid)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.current_stream().synchronize()
|
||||
|
||||
|
||||
class CosyVoice3Model(CosyVoice2Model):
|
||||
|
||||
def __init__(self,
|
||||
llm: torch.nn.Module,
|
||||
flow: torch.nn.Module,
|
||||
hift: torch.nn.Module,
|
||||
fp16: bool = False):
|
||||
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
||||
self.llm = llm
|
||||
self.flow = flow
|
||||
self.hift = hift
|
||||
self.fp16 = fp16
|
||||
# NOTE must matching training static_chunk_size
|
||||
self.token_hop_len = 25
|
||||
# NOTE increase token_hop_len incrementally to avoid duplicate inference
|
||||
self.token_max_hop_len = 4 * self.token_hop_len
|
||||
self.stream_scale_factor = 2
|
||||
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
|
||||
# rtf and decoding related
|
||||
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
|
||||
self.lock = threading.Lock()
|
||||
# dict used to store session related variable
|
||||
self.tts_speech_token_dict = {}
|
||||
self.llm_end_dict = {}
|
||||
self.hift_cache_dict = {}
|
||||
# FSQ silent and breath token
|
||||
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
|
||||
|
||||
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
|
||||
with torch.cuda.amp.autocast(self.fp16):
|
||||
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
|
||||
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_token=prompt_token.to(self.device),
|
||||
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
|
||||
prompt_feat=prompt_feat.to(self.device),
|
||||
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
|
||||
embedding=embedding.to(self.device),
|
||||
streaming=stream,
|
||||
finalize=finalize)
|
||||
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
|
||||
# append mel cache
|
||||
if self.hift_cache_dict[uuid] is not None:
|
||||
hift_cache_mel = self.hift_cache_dict[uuid]['mel']
|
||||
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
|
||||
self.hift_cache_dict[uuid]['mel'] = tts_mel
|
||||
else:
|
||||
self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
|
||||
if speed != 1.0:
|
||||
assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
|
||||
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
|
||||
tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
|
||||
tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
|
||||
self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
|
||||
return tts_speech
|
||||
0
models/CosyVoice/cosyvoice/dataset/__init__.py
Normal file
0
models/CosyVoice/cosyvoice/dataset/__init__.py
Normal file
155
models/CosyVoice/cosyvoice/dataset/dataset.py
Normal file
155
models/CosyVoice/cosyvoice/dataset/dataset.py
Normal file
@@ -0,0 +1,155 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import random
|
||||
import math
|
||||
from functools import partial
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.utils.data import IterableDataset
|
||||
from cosyvoice.utils.file_utils import read_lists
|
||||
|
||||
|
||||
class Processor(IterableDataset):
|
||||
|
||||
def __init__(self, source, f, *args, **kw):
|
||||
assert callable(f)
|
||||
self.source = source
|
||||
self.f = f
|
||||
self.args = args
|
||||
self.kw = kw
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.source.set_epoch(epoch)
|
||||
|
||||
def __iter__(self):
|
||||
""" Return an iterator over the source dataset processed by the
|
||||
given processor.
|
||||
"""
|
||||
assert self.source is not None
|
||||
assert callable(self.f)
|
||||
return self.f(iter(self.source), *self.args, **self.kw)
|
||||
|
||||
def apply(self, f):
|
||||
assert callable(f)
|
||||
return Processor(self, f, *self.args, **self.kw)
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
|
||||
def __init__(self, shuffle=True, partition=True):
|
||||
self.epoch = -1
|
||||
self.update()
|
||||
self.shuffle = shuffle
|
||||
self.partition = partition
|
||||
|
||||
def update(self):
|
||||
assert dist.is_available()
|
||||
if dist.is_initialized():
|
||||
self.rank = dist.get_rank()
|
||||
self.world_size = dist.get_world_size()
|
||||
else:
|
||||
self.rank = 0
|
||||
self.world_size = 1
|
||||
worker_info = torch.utils.data.get_worker_info()
|
||||
if worker_info is None:
|
||||
self.worker_id = 0
|
||||
self.num_workers = 1
|
||||
else:
|
||||
self.worker_id = worker_info.id
|
||||
self.num_workers = worker_info.num_workers
|
||||
return dict(rank=self.rank,
|
||||
world_size=self.world_size,
|
||||
worker_id=self.worker_id,
|
||||
num_workers=self.num_workers)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.epoch = epoch
|
||||
|
||||
def sample(self, data):
|
||||
""" Sample data according to rank/world_size/num_workers
|
||||
|
||||
Args:
|
||||
data(List): input data list
|
||||
|
||||
Returns:
|
||||
List: data list after sample
|
||||
"""
|
||||
data = list(range(len(data)))
|
||||
# force datalist even
|
||||
if self.partition:
|
||||
if self.shuffle:
|
||||
random.Random(self.epoch).shuffle(data)
|
||||
if len(data) < self.world_size:
|
||||
data = data * math.ceil(self.world_size / len(data))
|
||||
data = data[:self.world_size]
|
||||
data = data[self.rank::self.world_size]
|
||||
if len(data) < self.num_workers:
|
||||
data = data * math.ceil(self.num_workers / len(data))
|
||||
data = data[:self.num_workers]
|
||||
data = data[self.worker_id::self.num_workers]
|
||||
return data
|
||||
|
||||
|
||||
class DataList(IterableDataset):
|
||||
|
||||
def __init__(self, lists, shuffle=True, partition=True):
|
||||
self.lists = lists
|
||||
self.sampler = DistributedSampler(shuffle, partition)
|
||||
|
||||
def set_epoch(self, epoch):
|
||||
self.sampler.set_epoch(epoch)
|
||||
|
||||
def __iter__(self):
|
||||
sampler_info = self.sampler.update()
|
||||
indexes = self.sampler.sample(self.lists)
|
||||
for index in indexes:
|
||||
data = dict(src=self.lists[index])
|
||||
data.update(sampler_info)
|
||||
yield data
|
||||
|
||||
|
||||
def Dataset(data_list_file,
|
||||
data_pipeline,
|
||||
mode='train',
|
||||
gan=False,
|
||||
dpo=False,
|
||||
shuffle=True,
|
||||
partition=True):
|
||||
""" Construct dataset from arguments
|
||||
|
||||
We have two shuffle stage in the Dataset. The first is global
|
||||
shuffle at shards tar/raw file level. The second is global shuffle
|
||||
at training samples level.
|
||||
|
||||
Args:
|
||||
data_type(str): raw/shard
|
||||
tokenizer (BaseTokenizer): tokenizer to tokenize
|
||||
partition(bool): whether to do data partition in terms of rank
|
||||
"""
|
||||
lists = read_lists(data_list_file)
|
||||
dataset = DataList(lists,
|
||||
shuffle=shuffle,
|
||||
partition=partition)
|
||||
# map partial arg to padding func
|
||||
for i in range(1, len(data_pipeline)):
|
||||
if data_pipeline[i].func.__name__ == 'compute_fbank' and gan is True:
|
||||
data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
|
||||
if data_pipeline[i].func.__name__ == 'padding':
|
||||
data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
|
||||
for func in data_pipeline:
|
||||
dataset = Processor(dataset, func, mode=mode)
|
||||
return dataset
|
||||
431
models/CosyVoice/cosyvoice/dataset/processor.py
Normal file
431
models/CosyVoice/cosyvoice/dataset/processor.py
Normal file
@@ -0,0 +1,431 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
import random
|
||||
|
||||
import pyarrow.parquet as pq
|
||||
from io import BytesIO
|
||||
import numpy as np
|
||||
import whisper
|
||||
import torch
|
||||
import torchaudio
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
import torch.nn.functional as F
|
||||
import pyworld as pw
|
||||
from cosyvoice.utils.onnx import embedding_extractor, online_feature
|
||||
|
||||
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
|
||||
|
||||
|
||||
def parquet_opener(data, mode='train'):
|
||||
""" Give url or local file, return file descriptor
|
||||
Inplace operation.
|
||||
|
||||
Args:
|
||||
data(Iterable[str]): url or local file list
|
||||
|
||||
Returns:
|
||||
Iterable[{src, stream}]
|
||||
"""
|
||||
for sample in data:
|
||||
assert 'src' in sample
|
||||
url = sample['src']
|
||||
try:
|
||||
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
|
||||
df = df.to_pandas()
|
||||
for i in range(len(df)):
|
||||
sample.update(dict(df.loc[i]))
|
||||
# NOTE do not return sample directly, must initialize a new dict
|
||||
yield {**sample}
|
||||
except Exception as ex:
|
||||
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
|
||||
|
||||
|
||||
def filter(data,
|
||||
max_length=10240,
|
||||
min_length=10,
|
||||
token_max_length=200,
|
||||
token_min_length=1,
|
||||
min_output_input_ratio=0.0005,
|
||||
max_output_input_ratio=1,
|
||||
mode='train'):
|
||||
""" Filter sample according to feature and label length
|
||||
Inplace operation.
|
||||
|
||||
Args::
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
max_length: drop utterance which is greater than max_length(10ms)
|
||||
min_length: drop utterance which is less than min_length(10ms)
|
||||
token_max_length: drop utterance which is greater than
|
||||
token_max_length, especially when use char unit for
|
||||
english modeling
|
||||
token_min_length: drop utterance which is
|
||||
less than token_max_length
|
||||
min_output_input_ratio: minimal ration of
|
||||
token_length / feats_length(10ms)
|
||||
max_output_input_ratio: maximum ration of
|
||||
token_length / feats_length(10ms)
|
||||
|
||||
Returns:
|
||||
Iterable[{key, wav, label, sample_rate}]
|
||||
"""
|
||||
for sample in data:
|
||||
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
|
||||
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
|
||||
del sample['audio_data']
|
||||
# sample['wav'] is torch.Tensor, we have 100 frames every second
|
||||
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
|
||||
if num_frames < min_length:
|
||||
continue
|
||||
if num_frames > max_length:
|
||||
continue
|
||||
if len(sample['text_token']) < token_min_length:
|
||||
continue
|
||||
if len(sample['text_token']) > token_max_length:
|
||||
continue
|
||||
if online_feature is False and len(sample['speech_token']) == 0:
|
||||
continue
|
||||
if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
|
||||
continue
|
||||
if num_frames != 0:
|
||||
if len(sample['text_token']) / num_frames < min_output_input_ratio:
|
||||
continue
|
||||
if len(sample['text_token']) / num_frames > max_output_input_ratio:
|
||||
continue
|
||||
yield sample
|
||||
|
||||
|
||||
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
|
||||
""" Resample data.
|
||||
Inplace operation.
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
resample_rate: target resample rate
|
||||
|
||||
Returns:
|
||||
Iterable[{key, wav, label, sample_rate}]
|
||||
"""
|
||||
for sample in data:
|
||||
assert 'sample_rate' in sample
|
||||
assert 'speech' in sample
|
||||
sample_rate = sample['sample_rate']
|
||||
waveform = sample['speech']
|
||||
if sample_rate != resample_rate:
|
||||
if sample_rate < min_sample_rate:
|
||||
continue
|
||||
sample['sample_rate'] = resample_rate
|
||||
sample['speech'] = torchaudio.transforms.Resample(
|
||||
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
|
||||
max_val = sample['speech'].abs().max()
|
||||
if max_val > 1:
|
||||
sample['speech'] /= max_val
|
||||
yield sample
|
||||
|
||||
|
||||
def truncate(data, truncate_length=24576, mode='train'):
|
||||
""" Truncate data.
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
truncate_length: truncate length
|
||||
|
||||
Returns:
|
||||
Iterable[{key, wav, label, sample_rate}]
|
||||
"""
|
||||
for sample in data:
|
||||
waveform = sample['speech']
|
||||
if waveform.shape[1] > truncate_length:
|
||||
start = random.randint(0, waveform.shape[1] - truncate_length)
|
||||
waveform = waveform[:, start: start + truncate_length]
|
||||
else:
|
||||
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
|
||||
sample['speech'] = waveform
|
||||
yield sample
|
||||
|
||||
|
||||
def compute_fbank(data,
|
||||
feat_extractor,
|
||||
num_frames=-1,
|
||||
mode='train'):
|
||||
""" Extract fbank
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
assert 'sample_rate' in sample
|
||||
assert 'speech' in sample
|
||||
assert 'utt' in sample
|
||||
assert 'text_token' in sample
|
||||
# NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
|
||||
if num_frames != -1:
|
||||
index = int(np.ceil(sample['speech'].shape[1] / num_frames))
|
||||
sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
|
||||
sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
|
||||
yield sample
|
||||
|
||||
|
||||
def compute_whisper_fbank(data, num_frames=-1, mode='train'):
|
||||
""" Extract whisper fbank
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
if num_frames != -1:
|
||||
assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
|
||||
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
|
||||
sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
|
||||
yield sample
|
||||
|
||||
|
||||
def compute_f0(data, sample_rate, hop_size, mode='train'):
|
||||
""" Extract f0
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
frame_period = hop_size * 1000 / sample_rate
|
||||
for sample in data:
|
||||
assert 'sample_rate' in sample
|
||||
assert 'speech' in sample
|
||||
assert 'utt' in sample
|
||||
assert 'text_token' in sample
|
||||
waveform = sample['speech']
|
||||
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
|
||||
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
|
||||
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
|
||||
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
|
||||
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
|
||||
sample['pitch_feat'] = f0
|
||||
yield sample
|
||||
|
||||
|
||||
def parse_embedding(data, normalize, mode='train'):
|
||||
""" Parse utt_embedding/spk_embedding
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, label, sample_rate}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
for sample in data:
|
||||
if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
|
||||
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
|
||||
embedding = embedding_extractor.inference(sample['speech_16k'])
|
||||
sample['spk_embedding'] = sample['utt_embedding'] = embedding
|
||||
else:
|
||||
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
|
||||
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
|
||||
if normalize:
|
||||
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
|
||||
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
|
||||
yield sample
|
||||
|
||||
|
||||
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
|
||||
""" Decode text to chars or BPE
|
||||
Inplace operation
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, wav, txt, sample_rate}]
|
||||
|
||||
Returns:
|
||||
Iterable[{key, wav, txt, tokens, label, sample_rate}]
|
||||
"""
|
||||
tokenizer = get_tokenizer()
|
||||
for sample in data:
|
||||
assert 'text' in sample
|
||||
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
|
||||
if 'instruct' in sample:
|
||||
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
|
||||
yield sample
|
||||
|
||||
|
||||
def shuffle(data, shuffle_size=10000, mode='train'):
|
||||
""" Local shuffle the data
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, feat, label}]
|
||||
shuffle_size: buffer size for shuffle
|
||||
|
||||
Returns:
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
buf = []
|
||||
yield_size = int(shuffle_size / 2)
|
||||
for sample in data:
|
||||
buf.append(sample)
|
||||
if len(buf) >= shuffle_size:
|
||||
random.shuffle(buf)
|
||||
for x in buf[:yield_size]:
|
||||
yield x
|
||||
buf = buf[yield_size:]
|
||||
# The sample left over
|
||||
random.shuffle(buf)
|
||||
for x in buf:
|
||||
yield x
|
||||
|
||||
|
||||
def sort(data, sort_size=500, mode='train'):
|
||||
""" Sort the data by feature length.
|
||||
Sort is used after shuffle and before batch, so we can group
|
||||
utts with similar lengths into a batch, and `sort_size` should
|
||||
be less than `shuffle_size`
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, feat, label}]
|
||||
sort_size: buffer size for sort
|
||||
|
||||
Returns:
|
||||
Iterable[{key, feat, label}]
|
||||
"""
|
||||
|
||||
buf = []
|
||||
for sample in data:
|
||||
buf.append(sample)
|
||||
if len(buf) >= sort_size:
|
||||
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
||||
for x in buf:
|
||||
yield x
|
||||
buf = []
|
||||
# The sample left over
|
||||
buf.sort(key=lambda x: x['speech_feat'].size(0))
|
||||
for x in buf:
|
||||
yield x
|
||||
|
||||
|
||||
def static_batch(data, batch_size=16):
|
||||
""" Static batch the data by `batch_size`
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, feat, label}]
|
||||
batch_size: batch size
|
||||
|
||||
Returns:
|
||||
Iterable[List[{key, feat, label}]]
|
||||
"""
|
||||
buf = []
|
||||
for sample in data:
|
||||
buf.append(sample)
|
||||
if len(buf) >= batch_size:
|
||||
yield buf
|
||||
buf = []
|
||||
if len(buf) > 0:
|
||||
yield buf
|
||||
|
||||
|
||||
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
|
||||
""" Dynamic batch the data until the total frames in batch
|
||||
reach `max_frames_in_batch`
|
||||
|
||||
Args:
|
||||
data: Iterable[{key, feat, label}]
|
||||
max_frames_in_batch: max_frames in one batch
|
||||
|
||||
Returns:
|
||||
Iterable[List[{key, feat, label}]]
|
||||
"""
|
||||
buf = []
|
||||
longest_frames = 0
|
||||
for sample in data:
|
||||
assert 'speech_feat' in sample
|
||||
assert isinstance(sample['speech_feat'], torch.Tensor)
|
||||
new_sample_frames = sample['speech_feat'].size(0)
|
||||
longest_frames = max(longest_frames, new_sample_frames)
|
||||
frames_after_padding = longest_frames * (len(buf) + 1)
|
||||
if frames_after_padding > max_frames_in_batch:
|
||||
yield buf
|
||||
buf = [sample]
|
||||
longest_frames = new_sample_frames
|
||||
else:
|
||||
buf.append(sample)
|
||||
if len(buf) > 0:
|
||||
yield buf
|
||||
|
||||
|
||||
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
|
||||
""" Wrapper for static/dynamic batch
|
||||
"""
|
||||
if batch_type == 'static':
|
||||
return static_batch(data, batch_size)
|
||||
elif batch_type == 'dynamic':
|
||||
return dynamic_batch(data, max_frames_in_batch)
|
||||
else:
|
||||
logging.fatal('Unsupported batch type {}'.format(batch_type))
|
||||
|
||||
|
||||
def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
|
||||
""" Padding the data into training data
|
||||
|
||||
Args:
|
||||
data: Iterable[List[{key, feat, label}]]
|
||||
|
||||
Returns:
|
||||
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
|
||||
"""
|
||||
for sample in data:
|
||||
assert isinstance(sample, list)
|
||||
order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
|
||||
batch = {}
|
||||
batch['utts'] = [sample[i]['utt'] for i in order]
|
||||
batch['text'] = [sample[i]['text'] for i in order]
|
||||
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
|
||||
batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
|
||||
batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
|
||||
speech_feat = [sample[i]['speech_feat'] for i in order]
|
||||
batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
|
||||
batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
|
||||
batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
|
||||
batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
|
||||
if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
|
||||
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
|
||||
batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
|
||||
batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
|
||||
if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
|
||||
whisper_feat = [sample[i]['whisper_feat'] for i in order]
|
||||
batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
|
||||
batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
|
||||
if torch.tensor(['speech_token' in sample[i] for i in order]).all():
|
||||
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
|
||||
batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
|
||||
batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
|
||||
if gan is True:
|
||||
# in gan train, we need speech/pitch_feat
|
||||
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
|
||||
batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
|
||||
batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
|
||||
pitch_feat = [sample[i]['pitch_feat'] for i in order]
|
||||
batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
|
||||
batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
|
||||
if dpo is True:
|
||||
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
|
||||
batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
|
||||
batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
|
||||
if use_spk_embedding is True:
|
||||
batch["embedding"] = batch["spk_embedding"]
|
||||
else:
|
||||
batch["embedding"] = batch["utt_embedding"]
|
||||
yield batch
|
||||
176
models/CosyVoice/cosyvoice/flow/DiT/dit.py
Normal file
176
models/CosyVoice/cosyvoice/flow/DiT/dit.py
Normal file
@@ -0,0 +1,176 @@
|
||||
|
||||
"""
|
||||
ein notation:
|
||||
b - batch
|
||||
n - sequence
|
||||
nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from einops import repeat
|
||||
from x_transformers.x_transformers import RotaryEmbedding
|
||||
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||
from cosyvoice.flow.DiT.modules import (
|
||||
TimestepEmbedding,
|
||||
ConvNeXtV2Block,
|
||||
CausalConvPositionEmbedding,
|
||||
DiTBlock,
|
||||
AdaLayerNormZero_Final,
|
||||
precompute_freqs_cis,
|
||||
get_pos_embed_indices,
|
||||
)
|
||||
|
||||
|
||||
# Text embedding
|
||||
|
||||
|
||||
class TextEmbedding(nn.Module):
|
||||
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
|
||||
super().__init__()
|
||||
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
|
||||
|
||||
if conv_layers > 0:
|
||||
self.extra_modeling = True
|
||||
self.precompute_max_pos = 4096 # ~44s of 24khz audio
|
||||
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
|
||||
self.text_blocks = nn.Sequential(
|
||||
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
|
||||
)
|
||||
else:
|
||||
self.extra_modeling = False
|
||||
|
||||
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
|
||||
batch, text_len = text.shape[0], text.shape[1]
|
||||
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
|
||||
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
|
||||
text = F.pad(text, (0, seq_len - text_len), value=0)
|
||||
|
||||
if drop_text: # cfg for text
|
||||
text = torch.zeros_like(text)
|
||||
|
||||
text = self.text_embed(text) # b n -> b n d
|
||||
|
||||
# possible extra modeling
|
||||
if self.extra_modeling:
|
||||
# sinus pos emb
|
||||
batch_start = torch.zeros((batch,), dtype=torch.long)
|
||||
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
|
||||
text_pos_embed = self.freqs_cis[pos_idx]
|
||||
text = text + text_pos_embed
|
||||
|
||||
# convnextv2 blocks
|
||||
text = self.text_blocks(text)
|
||||
|
||||
return text
|
||||
|
||||
|
||||
# noised input audio and context mixing embedding
|
||||
|
||||
|
||||
class InputEmbedding(nn.Module):
|
||||
def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
|
||||
super().__init__()
|
||||
spk_dim = 0 if spk_dim is None else spk_dim
|
||||
self.spk_dim = spk_dim
|
||||
self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
|
||||
self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"],
|
||||
cond: float["b n d"],
|
||||
text_embed: float["b n d"],
|
||||
spks: float["b d"],
|
||||
):
|
||||
to_cat = [x, cond, text_embed]
|
||||
if self.spk_dim > 0:
|
||||
spks = repeat(spks, "b c -> b t c", t=x.shape[1])
|
||||
to_cat.append(spks)
|
||||
|
||||
x = self.proj(torch.cat(to_cat, dim=-1))
|
||||
x = self.conv_pos_embed(x) + x
|
||||
return x
|
||||
|
||||
|
||||
# Transformer backbone using DiT blocks
|
||||
|
||||
|
||||
class DiT(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
dim,
|
||||
depth=8,
|
||||
heads=8,
|
||||
dim_head=64,
|
||||
dropout=0.1,
|
||||
ff_mult=4,
|
||||
mel_dim=80,
|
||||
mu_dim=None,
|
||||
long_skip_connection=False,
|
||||
spk_dim=None,
|
||||
out_channels=None,
|
||||
static_chunk_size=50,
|
||||
num_decoding_left_chunks=2
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.time_embed = TimestepEmbedding(dim)
|
||||
if mu_dim is None:
|
||||
mu_dim = mel_dim
|
||||
self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
|
||||
|
||||
self.rotary_embed = RotaryEmbedding(dim_head)
|
||||
|
||||
self.dim = dim
|
||||
self.depth = depth
|
||||
|
||||
self.transformer_blocks = nn.ModuleList(
|
||||
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
|
||||
)
|
||||
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
|
||||
|
||||
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
|
||||
self.proj_out = nn.Linear(dim, mel_dim)
|
||||
self.out_channels = out_channels
|
||||
self.static_chunk_size = static_chunk_size
|
||||
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
||||
x = x.transpose(1, 2)
|
||||
mu = mu.transpose(1, 2)
|
||||
cond = cond.transpose(1, 2)
|
||||
spks = spks.unsqueeze(dim=1)
|
||||
batch, seq_len = x.shape[0], x.shape[1]
|
||||
if t.ndim == 0:
|
||||
t = t.repeat(batch)
|
||||
|
||||
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
|
||||
t = self.time_embed(t)
|
||||
x = self.input_embed(x, cond, mu, spks.squeeze(1))
|
||||
|
||||
rope = self.rotary_embed.forward_from_seq_len(seq_len)
|
||||
|
||||
if self.long_skip_connection is not None:
|
||||
residual = x
|
||||
|
||||
if streaming is True:
|
||||
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
|
||||
else:
|
||||
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
|
||||
|
||||
for block in self.transformer_blocks:
|
||||
x = block(x, t, mask=attn_mask.bool(), rope=rope)
|
||||
|
||||
if self.long_skip_connection is not None:
|
||||
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
|
||||
|
||||
x = self.norm_out(x, t)
|
||||
output = self.proj_out(x).transpose(1, 2)
|
||||
return output
|
||||
616
models/CosyVoice/cosyvoice/flow/DiT/modules.py
Normal file
616
models/CosyVoice/cosyvoice/flow/DiT/modules.py
Normal file
@@ -0,0 +1,616 @@
|
||||
|
||||
"""
|
||||
ein notation:
|
||||
b - batch
|
||||
n - sequence
|
||||
nt - text sequence
|
||||
nw - raw wave length
|
||||
d - dimension
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
from typing import Optional
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
|
||||
from x_transformers.x_transformers import apply_rotary_pos_emb
|
||||
|
||||
|
||||
# raw wav to mel spec
|
||||
class MelSpec(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
filter_length=1024,
|
||||
hop_length=256,
|
||||
win_length=1024,
|
||||
n_mel_channels=100,
|
||||
target_sample_rate=24_000,
|
||||
normalize=False,
|
||||
power=1,
|
||||
norm=None,
|
||||
center=True,
|
||||
):
|
||||
super().__init__()
|
||||
self.n_mel_channels = n_mel_channels
|
||||
|
||||
self.mel_stft = torchaudio.transforms.MelSpectrogram(
|
||||
sample_rate=target_sample_rate,
|
||||
n_fft=filter_length,
|
||||
win_length=win_length,
|
||||
hop_length=hop_length,
|
||||
n_mels=n_mel_channels,
|
||||
power=power,
|
||||
center=center,
|
||||
normalized=normalize,
|
||||
norm=norm,
|
||||
)
|
||||
|
||||
self.register_buffer("dummy", torch.tensor(0), persistent=False)
|
||||
|
||||
def forward(self, inp):
|
||||
if len(inp.shape) == 3:
|
||||
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
|
||||
|
||||
assert len(inp.shape) == 2
|
||||
|
||||
if self.dummy.device != inp.device:
|
||||
self.to(inp.device)
|
||||
|
||||
mel = self.mel_stft(inp)
|
||||
mel = mel.clamp(min=1e-5).log()
|
||||
return mel
|
||||
|
||||
|
||||
# sinusoidal position embedding
|
||||
|
||||
|
||||
class SinusPositionEmbedding(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.dim = dim
|
||||
|
||||
def forward(self, x, scale=1000):
|
||||
device = x.device
|
||||
half_dim = self.dim // 2
|
||||
emb = math.log(10000) / (half_dim - 1)
|
||||
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
|
||||
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
|
||||
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
||||
return emb
|
||||
|
||||
|
||||
# convolutional position embedding
|
||||
|
||||
|
||||
class ConvPositionEmbedding(nn.Module):
|
||||
def __init__(self, dim, kernel_size=31, groups=16):
|
||||
super().__init__()
|
||||
assert kernel_size % 2 != 0
|
||||
self.conv1d = nn.Sequential(
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||
nn.Mish(),
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||
if mask is not None:
|
||||
mask = mask[..., None]
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
x = x.permute(0, 2, 1)
|
||||
x = self.conv1d(x)
|
||||
out = x.permute(0, 2, 1)
|
||||
|
||||
if mask is not None:
|
||||
out = out.masked_fill(~mask, 0.0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class CausalConvPositionEmbedding(nn.Module):
|
||||
def __init__(self, dim, kernel_size=31, groups=16):
|
||||
super().__init__()
|
||||
assert kernel_size % 2 != 0
|
||||
self.kernel_size = kernel_size
|
||||
self.conv1 = nn.Sequential(
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
|
||||
nn.Mish(),
|
||||
)
|
||||
self.conv2 = nn.Sequential(
|
||||
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
|
||||
if mask is not None:
|
||||
mask = mask[..., None]
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
x = x.permute(0, 2, 1)
|
||||
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
|
||||
x = self.conv1(x)
|
||||
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
|
||||
x = self.conv2(x)
|
||||
out = x.permute(0, 2, 1)
|
||||
|
||||
if mask is not None:
|
||||
out = out.masked_fill(~mask, 0.0)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
# rotary positional embedding related
|
||||
|
||||
|
||||
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
|
||||
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
|
||||
# has some connection to NTK literature
|
||||
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
|
||||
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
|
||||
theta *= theta_rescale_factor ** (dim / (dim - 2))
|
||||
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
||||
t = torch.arange(end, device=freqs.device) # type: ignore
|
||||
freqs = torch.outer(t, freqs).float() # type: ignore
|
||||
freqs_cos = torch.cos(freqs) # real part
|
||||
freqs_sin = torch.sin(freqs) # imaginary part
|
||||
return torch.cat([freqs_cos, freqs_sin], dim=-1)
|
||||
|
||||
|
||||
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
|
||||
# length = length if isinstance(length, int) else length.max()
|
||||
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
|
||||
pos = (
|
||||
start.unsqueeze(1)
|
||||
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
|
||||
)
|
||||
# avoid extra long error.
|
||||
pos = torch.where(pos < max_pos, pos, max_pos - 1)
|
||||
return pos
|
||||
|
||||
|
||||
# Global Response Normalization layer (Instance Normalization ?)
|
||||
|
||||
|
||||
class GRN(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
|
||||
|
||||
def forward(self, x):
|
||||
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
|
||||
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
|
||||
return self.gamma * (x * Nx) + self.beta + x
|
||||
|
||||
|
||||
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
|
||||
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
|
||||
|
||||
|
||||
class ConvNeXtV2Block(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
intermediate_dim: int,
|
||||
dilation: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
padding = (dilation * (7 - 1)) // 2
|
||||
self.dwconv = nn.Conv1d(
|
||||
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
|
||||
) # depthwise conv
|
||||
self.norm = nn.LayerNorm(dim, eps=1e-6)
|
||||
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
|
||||
self.act = nn.GELU()
|
||||
self.grn = GRN(intermediate_dim)
|
||||
self.pwconv2 = nn.Linear(intermediate_dim, dim)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
residual = x
|
||||
x = x.transpose(1, 2) # b n d -> b d n
|
||||
x = self.dwconv(x)
|
||||
x = x.transpose(1, 2) # b d n -> b n d
|
||||
x = self.norm(x)
|
||||
x = self.pwconv1(x)
|
||||
x = self.act(x)
|
||||
x = self.grn(x)
|
||||
x = self.pwconv2(x)
|
||||
return residual + x
|
||||
|
||||
|
||||
# AdaLayerNormZero
|
||||
# return with modulated x for attn input, and params for later mlp modulation
|
||||
|
||||
|
||||
class AdaLayerNormZero(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(dim, dim * 6)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb=None):
|
||||
emb = self.linear(self.silu(emb))
|
||||
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
|
||||
|
||||
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
|
||||
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
|
||||
|
||||
|
||||
# AdaLayerNormZero for final layer
|
||||
# return only with modulated x for attn input, cuz no more mlp modulation
|
||||
|
||||
|
||||
class AdaLayerNormZero_Final(nn.Module):
|
||||
def __init__(self, dim):
|
||||
super().__init__()
|
||||
|
||||
self.silu = nn.SiLU()
|
||||
self.linear = nn.Linear(dim, dim * 2)
|
||||
|
||||
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
|
||||
def forward(self, x, emb):
|
||||
emb = self.linear(self.silu(emb))
|
||||
scale, shift = torch.chunk(emb, 2, dim=1)
|
||||
|
||||
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
|
||||
return x
|
||||
|
||||
|
||||
# FeedForward
|
||||
|
||||
|
||||
class FeedForward(nn.Module):
|
||||
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
|
||||
super().__init__()
|
||||
inner_dim = int(dim * mult)
|
||||
dim_out = dim_out if dim_out is not None else dim
|
||||
|
||||
activation = nn.GELU(approximate=approximate)
|
||||
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
|
||||
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
|
||||
|
||||
def forward(self, x):
|
||||
return self.ff(x)
|
||||
|
||||
|
||||
# Attention with possible joint part
|
||||
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
processor: JointAttnProcessor | AttnProcessor,
|
||||
dim: int,
|
||||
heads: int = 8,
|
||||
dim_head: int = 64,
|
||||
dropout: float = 0.0,
|
||||
context_dim: Optional[int] = None, # if not None -> joint attention
|
||||
context_pre_only=None,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
if not hasattr(F, "scaled_dot_product_attention"):
|
||||
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
|
||||
|
||||
self.processor = processor
|
||||
|
||||
self.dim = dim
|
||||
self.heads = heads
|
||||
self.inner_dim = dim_head * heads
|
||||
self.dropout = dropout
|
||||
|
||||
self.context_dim = context_dim
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
self.to_q = nn.Linear(dim, self.inner_dim)
|
||||
self.to_k = nn.Linear(dim, self.inner_dim)
|
||||
self.to_v = nn.Linear(dim, self.inner_dim)
|
||||
|
||||
if self.context_dim is not None:
|
||||
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
|
||||
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
|
||||
if self.context_pre_only is not None:
|
||||
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
|
||||
|
||||
self.to_out = nn.ModuleList([])
|
||||
self.to_out.append(nn.Linear(self.inner_dim, dim))
|
||||
self.to_out.append(nn.Dropout(dropout))
|
||||
|
||||
if self.context_pre_only is not None and not self.context_pre_only:
|
||||
self.to_out_c = nn.Linear(self.inner_dim, dim)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b n d"] = None, # context c # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.Tensor:
|
||||
if c is not None:
|
||||
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
|
||||
else:
|
||||
return self.processor(self, x, mask=mask, rope=rope)
|
||||
|
||||
|
||||
# Attention processor
|
||||
|
||||
|
||||
class AttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding
|
||||
) -> torch.FloatTensor:
|
||||
batch_size = x.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
|
||||
# apply rotary position embedding
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
|
||||
# attention
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = mask
|
||||
if attn_mask.dim() == 2:
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# linear proj
|
||||
x = attn.to_out[0](x)
|
||||
# dropout
|
||||
x = attn.to_out[1](x)
|
||||
|
||||
if mask is not None:
|
||||
if mask.dim() == 2:
|
||||
mask = mask.unsqueeze(-1)
|
||||
else:
|
||||
mask = mask[:, 0, -1].unsqueeze(-1)
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# Joint Attention processor for MM-DiT
|
||||
# modified from diffusers/src/diffusers/models/attention_processor.py
|
||||
|
||||
|
||||
class JointAttnProcessor:
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
def __call__(
|
||||
self,
|
||||
attn: Attention,
|
||||
x: float["b n d"], # noised input x # noqa: F722
|
||||
c: float["b nt d"] = None, # context c, here text # noqa: F722
|
||||
mask: bool["b n"] | None = None, # noqa: F722
|
||||
rope=None, # rotary position embedding for x
|
||||
c_rope=None, # rotary position embedding for c
|
||||
) -> torch.FloatTensor:
|
||||
residual = x
|
||||
|
||||
batch_size = c.shape[0]
|
||||
|
||||
# `sample` projections.
|
||||
query = attn.to_q(x)
|
||||
key = attn.to_k(x)
|
||||
value = attn.to_v(x)
|
||||
|
||||
# `context` projections.
|
||||
c_query = attn.to_q_c(c)
|
||||
c_key = attn.to_k_c(c)
|
||||
c_value = attn.to_v_c(c)
|
||||
|
||||
# apply rope for context and noised input independently
|
||||
if rope is not None:
|
||||
freqs, xpos_scale = rope
|
||||
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
|
||||
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
|
||||
if c_rope is not None:
|
||||
freqs, xpos_scale = c_rope
|
||||
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
|
||||
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
|
||||
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
|
||||
|
||||
# attention
|
||||
query = torch.cat([query, c_query], dim=1)
|
||||
key = torch.cat([key, c_key], dim=1)
|
||||
value = torch.cat([value, c_value], dim=1)
|
||||
|
||||
inner_dim = key.shape[-1]
|
||||
head_dim = inner_dim // attn.heads
|
||||
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
|
||||
|
||||
# mask. e.g. inference got a batch with different target durations, mask out the padding
|
||||
if mask is not None:
|
||||
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
|
||||
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
|
||||
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
|
||||
x = x.to(query.dtype)
|
||||
|
||||
# Split the attention outputs.
|
||||
x, c = (
|
||||
x[:, : residual.shape[1]],
|
||||
x[:, residual.shape[1]:],
|
||||
)
|
||||
|
||||
# linear proj
|
||||
x = attn.to_out[0](x)
|
||||
# dropout
|
||||
x = attn.to_out[1](x)
|
||||
if not attn.context_pre_only:
|
||||
c = attn.to_out_c(c)
|
||||
|
||||
if mask is not None:
|
||||
mask = mask.unsqueeze(-1)
|
||||
x = x.masked_fill(~mask, 0.0)
|
||||
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
|
||||
|
||||
return x, c
|
||||
|
||||
|
||||
# DiT Block
|
||||
|
||||
|
||||
class DiTBlock(nn.Module):
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
|
||||
super().__init__()
|
||||
|
||||
self.attn_norm = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=AttnProcessor(),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
)
|
||||
|
||||
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
|
||||
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
|
||||
# pre-norm & modulation for attention input
|
||||
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
|
||||
|
||||
# attention
|
||||
attn_output = self.attn(x=norm, mask=mask, rope=rope)
|
||||
|
||||
# process attention output for input x
|
||||
x = x + gate_msa.unsqueeze(1) * attn_output
|
||||
|
||||
ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
|
||||
ff_output = self.ff(ff_norm)
|
||||
x = x + gate_mlp.unsqueeze(1) * ff_output
|
||||
|
||||
return x
|
||||
|
||||
|
||||
# MMDiT Block https://arxiv.org/abs/2403.03206
|
||||
|
||||
|
||||
class MMDiTBlock(nn.Module):
|
||||
r"""
|
||||
modified from diffusers/src/diffusers/models/attention.py
|
||||
|
||||
notes.
|
||||
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
|
||||
_x: noised input related. (right part)
|
||||
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
|
||||
"""
|
||||
|
||||
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
|
||||
super().__init__()
|
||||
|
||||
self.context_pre_only = context_pre_only
|
||||
|
||||
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
|
||||
self.attn_norm_x = AdaLayerNormZero(dim)
|
||||
self.attn = Attention(
|
||||
processor=JointAttnProcessor(),
|
||||
dim=dim,
|
||||
heads=heads,
|
||||
dim_head=dim_head,
|
||||
dropout=dropout,
|
||||
context_dim=dim,
|
||||
context_pre_only=context_pre_only,
|
||||
)
|
||||
|
||||
if not context_pre_only:
|
||||
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
else:
|
||||
self.ff_norm_c = None
|
||||
self.ff_c = None
|
||||
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
|
||||
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
|
||||
|
||||
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
|
||||
# pre-norm & modulation for attention input
|
||||
if self.context_pre_only:
|
||||
norm_c = self.attn_norm_c(c, t)
|
||||
else:
|
||||
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
|
||||
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
|
||||
|
||||
# attention
|
||||
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
|
||||
|
||||
# process attention output for context c
|
||||
if self.context_pre_only:
|
||||
c = None
|
||||
else: # if not last layer
|
||||
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
|
||||
|
||||
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
|
||||
c_ff_output = self.ff_c(norm_c)
|
||||
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
|
||||
|
||||
# process attention output for input x
|
||||
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
|
||||
|
||||
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
|
||||
x_ff_output = self.ff_x(norm_x)
|
||||
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
|
||||
|
||||
return c, x
|
||||
|
||||
|
||||
# time step conditioning embedding
|
||||
|
||||
|
||||
class TimestepEmbedding(nn.Module):
|
||||
def __init__(self, dim, freq_embed_dim=256):
|
||||
super().__init__()
|
||||
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
|
||||
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
|
||||
|
||||
def forward(self, timestep: float["b"]): # noqa: F821
|
||||
time_hidden = self.time_embed(timestep)
|
||||
time_hidden = time_hidden.to(timestep.dtype)
|
||||
time = self.time_mlp(time_hidden) # b d
|
||||
return time
|
||||
494
models/CosyVoice/cosyvoice/flow/decoder.py
Normal file
494
models/CosyVoice/cosyvoice/flow/decoder.py
Normal file
@@ -0,0 +1,494 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from einops import pack, rearrange, repeat
|
||||
from cosyvoice.utils.common import mask_to_bias
|
||||
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
|
||||
from matcha.models.components.transformer import BasicTransformerBlock
|
||||
|
||||
|
||||
class Transpose(torch.nn.Module):
|
||||
def __init__(self, dim0: int, dim1: int):
|
||||
super().__init__()
|
||||
self.dim0 = dim0
|
||||
self.dim1 = dim1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = torch.transpose(x, self.dim0, self.dim1)
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv1d(torch.nn.Conv1d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
||||
kernel_size, stride,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype)
|
||||
assert stride == 1
|
||||
self.causal_padding = kernel_size - 1
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||
x = super(CausalConv1d, self).forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalBlock1D(Block1D):
|
||||
def __init__(self, dim: int, dim_out: int):
|
||||
super(CausalBlock1D, self).__init__(dim, dim_out)
|
||||
self.block = torch.nn.Sequential(
|
||||
CausalConv1d(dim, dim_out, 3),
|
||||
Transpose(1, 2),
|
||||
nn.LayerNorm(dim_out),
|
||||
Transpose(1, 2),
|
||||
nn.Mish(),
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
output = self.block(x * mask)
|
||||
return output * mask
|
||||
|
||||
|
||||
class CausalResnetBlock1D(ResnetBlock1D):
|
||||
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
|
||||
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
|
||||
self.block1 = CausalBlock1D(dim, dim_out)
|
||||
self.block2 = CausalBlock1D(dim_out, dim_out)
|
||||
|
||||
|
||||
class ConditionalDecoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
channels=(256, 256),
|
||||
dropout=0.05,
|
||||
attention_head_dim=64,
|
||||
n_blocks=1,
|
||||
num_mid_blocks=2,
|
||||
num_heads=4,
|
||||
act_fn="snake",
|
||||
):
|
||||
"""
|
||||
This decoder requires an input with the same shape of the target. So, if your text content
|
||||
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
||||
"""
|
||||
super().__init__()
|
||||
channels = tuple(channels)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
|
||||
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
||||
time_embed_dim = channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=in_channels,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn="silu",
|
||||
)
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
output_channel = in_channels
|
||||
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
||||
input_channel = output_channel
|
||||
output_channel = channels[i]
|
||||
is_last = i == len(channels) - 1
|
||||
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
downsample = (
|
||||
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||
)
|
||||
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
||||
|
||||
for _ in range(num_mid_blocks):
|
||||
input_channel = channels[-1]
|
||||
out_channels = channels[-1]
|
||||
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
||||
|
||||
channels = channels[::-1] + (channels[0],)
|
||||
for i in range(len(channels) - 1):
|
||||
input_channel = channels[i] * 2
|
||||
output_channel = channels[i + 1]
|
||||
is_last = i == len(channels) - 2
|
||||
resnet = ResnetBlock1D(
|
||||
dim=input_channel,
|
||||
dim_out=output_channel,
|
||||
time_emb_dim=time_embed_dim,
|
||||
)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
upsample = (
|
||||
Upsample1D(output_channel, use_conv_transpose=True)
|
||||
if not is_last
|
||||
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
|
||||
)
|
||||
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
||||
self.final_block = Block1D(channels[-1], channels[-1])
|
||||
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
||||
self.initialize_weights()
|
||||
|
||||
def initialize_weights(self):
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv1d):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.GroupNorm):
|
||||
nn.init.constant_(m.weight, 1)
|
||||
nn.init.constant_(m.bias, 0)
|
||||
elif isinstance(m, nn.Linear):
|
||||
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
|
||||
if m.bias is not None:
|
||||
nn.init.constant_(m.bias, 0)
|
||||
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
||||
"""Forward pass of the UNet1DConditional model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): shape (batch_size, in_channels, time)
|
||||
mask (_type_): shape (batch_size, 1, time)
|
||||
t (_type_): shape (batch_size)
|
||||
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
||||
cond (_type_, optional): placeholder for future use. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: _description_
|
||||
ValueError: _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
|
||||
t = self.time_embeddings(t).to(t.dtype)
|
||||
t = self.time_mlp(t)
|
||||
|
||||
x = pack([x, mu], "b * t")[0]
|
||||
|
||||
if spks is not None:
|
||||
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
||||
x = pack([x, spks], "b * t")[0]
|
||||
if cond is not None:
|
||||
x = pack([x, cond], "b * t")[0]
|
||||
|
||||
hiddens = []
|
||||
masks = [mask]
|
||||
for resnet, transformer_blocks, downsample in self.down_blocks:
|
||||
mask_down = masks[-1]
|
||||
x = resnet(x, mask_down, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
hiddens.append(x) # Save hidden states for skip connections
|
||||
x = downsample(x * mask_down)
|
||||
masks.append(mask_down[:, :, ::2])
|
||||
masks = masks[:-1]
|
||||
mask_mid = masks[-1]
|
||||
|
||||
for resnet, transformer_blocks in self.mid_blocks:
|
||||
x = resnet(x, mask_mid, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
|
||||
for resnet, transformer_blocks, upsample in self.up_blocks:
|
||||
mask_up = masks.pop()
|
||||
skip = hiddens.pop()
|
||||
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
||||
x = resnet(x, mask_up, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
x = upsample(x * mask_up)
|
||||
x = self.final_block(x, mask_up)
|
||||
output = self.final_proj(x * mask_up)
|
||||
return output * mask
|
||||
|
||||
|
||||
class CausalConditionalDecoder(ConditionalDecoder):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels,
|
||||
out_channels,
|
||||
channels=(256, 256),
|
||||
dropout=0.05,
|
||||
attention_head_dim=64,
|
||||
n_blocks=1,
|
||||
num_mid_blocks=2,
|
||||
num_heads=4,
|
||||
act_fn="snake",
|
||||
static_chunk_size=50,
|
||||
num_decoding_left_chunks=2,
|
||||
):
|
||||
"""
|
||||
This decoder requires an input with the same shape of the target. So, if your text content
|
||||
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
|
||||
"""
|
||||
torch.nn.Module.__init__(self)
|
||||
channels = tuple(channels)
|
||||
self.in_channels = in_channels
|
||||
self.out_channels = out_channels
|
||||
self.time_embeddings = SinusoidalPosEmb(in_channels)
|
||||
time_embed_dim = channels[0] * 4
|
||||
self.time_mlp = TimestepEmbedding(
|
||||
in_channels=in_channels,
|
||||
time_embed_dim=time_embed_dim,
|
||||
act_fn="silu",
|
||||
)
|
||||
self.static_chunk_size = static_chunk_size
|
||||
self.num_decoding_left_chunks = num_decoding_left_chunks
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
self.mid_blocks = nn.ModuleList([])
|
||||
self.up_blocks = nn.ModuleList([])
|
||||
|
||||
output_channel = in_channels
|
||||
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
|
||||
input_channel = output_channel
|
||||
output_channel = channels[i]
|
||||
is_last = i == len(channels) - 1
|
||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
downsample = (
|
||||
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
|
||||
)
|
||||
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
|
||||
|
||||
for _ in range(num_mid_blocks):
|
||||
input_channel = channels[-1]
|
||||
out_channels = channels[-1]
|
||||
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
|
||||
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
|
||||
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
|
||||
|
||||
channels = channels[::-1] + (channels[0],)
|
||||
for i in range(len(channels) - 1):
|
||||
input_channel = channels[i] * 2
|
||||
output_channel = channels[i + 1]
|
||||
is_last = i == len(channels) - 2
|
||||
resnet = CausalResnetBlock1D(
|
||||
dim=input_channel,
|
||||
dim_out=output_channel,
|
||||
time_emb_dim=time_embed_dim,
|
||||
)
|
||||
transformer_blocks = nn.ModuleList(
|
||||
[
|
||||
BasicTransformerBlock(
|
||||
dim=output_channel,
|
||||
num_attention_heads=num_heads,
|
||||
attention_head_dim=attention_head_dim,
|
||||
dropout=dropout,
|
||||
activation_fn=act_fn,
|
||||
)
|
||||
for _ in range(n_blocks)
|
||||
]
|
||||
)
|
||||
upsample = (
|
||||
Upsample1D(output_channel, use_conv_transpose=True)
|
||||
if not is_last
|
||||
else CausalConv1d(output_channel, output_channel, 3)
|
||||
)
|
||||
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
|
||||
self.final_block = CausalBlock1D(channels[-1], channels[-1])
|
||||
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
|
||||
self.initialize_weights()
|
||||
|
||||
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
|
||||
"""Forward pass of the UNet1DConditional model.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): shape (batch_size, in_channels, time)
|
||||
mask (_type_): shape (batch_size, 1, time)
|
||||
t (_type_): shape (batch_size)
|
||||
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
|
||||
cond (_type_, optional): placeholder for future use. Defaults to None.
|
||||
|
||||
Raises:
|
||||
ValueError: _description_
|
||||
ValueError: _description_
|
||||
|
||||
Returns:
|
||||
_type_: _description_
|
||||
"""
|
||||
t = self.time_embeddings(t).to(t.dtype)
|
||||
t = self.time_mlp(t)
|
||||
|
||||
x = pack([x, mu], "b * t")[0]
|
||||
|
||||
if spks is not None:
|
||||
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
|
||||
x = pack([x, spks], "b * t")[0]
|
||||
if cond is not None:
|
||||
x = pack([x, cond], "b * t")[0]
|
||||
|
||||
hiddens = []
|
||||
masks = [mask]
|
||||
for resnet, transformer_blocks, downsample in self.down_blocks:
|
||||
mask_down = masks[-1]
|
||||
x = resnet(x, mask_down, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
if streaming is True:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||
else:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
hiddens.append(x) # Save hidden states for skip connections
|
||||
x = downsample(x * mask_down)
|
||||
masks.append(mask_down[:, :, ::2])
|
||||
masks = masks[:-1]
|
||||
mask_mid = masks[-1]
|
||||
|
||||
for resnet, transformer_blocks in self.mid_blocks:
|
||||
x = resnet(x, mask_mid, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
if streaming is True:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||
else:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
|
||||
for resnet, transformer_blocks, upsample in self.up_blocks:
|
||||
mask_up = masks.pop()
|
||||
skip = hiddens.pop()
|
||||
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
|
||||
x = resnet(x, mask_up, t)
|
||||
x = rearrange(x, "b c t -> b t c").contiguous()
|
||||
if streaming is True:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
|
||||
else:
|
||||
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
|
||||
attn_mask = mask_to_bias(attn_mask, x.dtype)
|
||||
for transformer_block in transformer_blocks:
|
||||
x = transformer_block(
|
||||
hidden_states=x,
|
||||
attention_mask=attn_mask,
|
||||
timestep=t,
|
||||
)
|
||||
x = rearrange(x, "b t c -> b c t").contiguous()
|
||||
x = upsample(x * mask_up)
|
||||
x = self.final_block(x, mask_up)
|
||||
output = self.final_proj(x * mask_up)
|
||||
return output * mask
|
||||
443
models/CosyVoice/cosyvoice/flow/flow.py
Normal file
443
models/CosyVoice/cosyvoice/flow/flow.py
Normal file
@@ -0,0 +1,443 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os, logging
|
||||
import random
|
||||
from typing import Dict, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch.nn import functional as F
|
||||
from omegaconf import DictConfig
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
|
||||
|
||||
|
||||
class MaskedDiffWithXvec(torch.nn.Module):
|
||||
def __init__(self,
|
||||
input_size: int = 512,
|
||||
output_size: int = 80,
|
||||
spk_embed_dim: int = 192,
|
||||
output_type: str = "mel",
|
||||
vocab_size: int = 4096,
|
||||
input_frame_rate: int = 50,
|
||||
only_mask_loss: bool = True,
|
||||
encoder: torch.nn.Module = None,
|
||||
length_regulator: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.decoder_conf = decoder_conf
|
||||
self.vocab_size = vocab_size
|
||||
self.output_type = output_type
|
||||
self.input_frame_rate = input_frame_rate
|
||||
logging.info(f"input frame rate={self.input_frame_rate}")
|
||||
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
||||
self.decoder = decoder
|
||||
self.length_regulator = length_regulator
|
||||
self.only_mask_loss = only_mask_loss
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
h, h_lengths = self.length_regulator(h, feat_len)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :index] = feat[i, :index]
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(feat_len)).to(h)
|
||||
# NOTE this is unnecessary, feat/h already same shape
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
h.transpose(1, 2).contiguous(),
|
||||
embedding,
|
||||
cond=conds
|
||||
)
|
||||
return {'loss': loss}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
token,
|
||||
token_len,
|
||||
prompt_token,
|
||||
prompt_token_len,
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
flow_cache):
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat speech token and prompt speech token
|
||||
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len)
|
||||
h = self.encoder_proj(h)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
|
||||
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat, flow_cache = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10,
|
||||
prompt_len=mel_len1,
|
||||
cache=flow_cache
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), flow_cache
|
||||
|
||||
|
||||
class CausalMaskedDiffWithXvec(torch.nn.Module):
|
||||
def __init__(self,
|
||||
input_size: int = 512,
|
||||
output_size: int = 80,
|
||||
spk_embed_dim: int = 192,
|
||||
output_type: str = "mel",
|
||||
vocab_size: int = 4096,
|
||||
input_frame_rate: int = 50,
|
||||
only_mask_loss: bool = True,
|
||||
token_mel_ratio: int = 2,
|
||||
pre_lookahead_len: int = 3,
|
||||
encoder: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.decoder_conf = decoder_conf
|
||||
self.vocab_size = vocab_size
|
||||
self.output_type = output_type
|
||||
self.input_frame_rate = input_frame_rate
|
||||
logging.info(f"input frame rate={self.input_frame_rate}")
|
||||
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
||||
self.encoder = encoder
|
||||
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
|
||||
self.decoder = decoder
|
||||
self.only_mask_loss = only_mask_loss
|
||||
self.token_mel_ratio = token_mel_ratio
|
||||
self.pre_lookahead_len = pre_lookahead_len
|
||||
if online_feature is True:
|
||||
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
if 'speech_token' not in batch:
|
||||
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||
else:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
# NOTE unified training, static_chunk_size > 0 or = 0
|
||||
streaming = True if random.random() < 0.5 else False
|
||||
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :index] = feat[i, :index]
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
h.transpose(1, 2).contiguous(),
|
||||
embedding,
|
||||
cond=conds,
|
||||
streaming=streaming,
|
||||
)
|
||||
return {'loss': loss}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
token,
|
||||
token_len,
|
||||
prompt_token,
|
||||
prompt_token_len,
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
streaming,
|
||||
finalize):
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
if finalize is True:
|
||||
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
|
||||
else:
|
||||
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
|
||||
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||
h = self.encoder_proj(h)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat, _ = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10,
|
||||
streaming=streaming
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), None
|
||||
|
||||
|
||||
class CausalMaskedDiffWithDiT(torch.nn.Module):
|
||||
def __init__(self,
|
||||
input_size: int = 512,
|
||||
output_size: int = 80,
|
||||
spk_embed_dim: int = 192,
|
||||
output_type: str = "mel",
|
||||
vocab_size: int = 4096,
|
||||
input_frame_rate: int = 50,
|
||||
only_mask_loss: bool = True,
|
||||
token_mel_ratio: int = 2,
|
||||
pre_lookahead_len: int = 3,
|
||||
pre_lookahead_layer: torch.nn.Module = None,
|
||||
decoder: torch.nn.Module = None,
|
||||
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
|
||||
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
|
||||
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
|
||||
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
|
||||
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.decoder_conf = decoder_conf
|
||||
self.vocab_size = vocab_size
|
||||
self.output_type = output_type
|
||||
self.input_frame_rate = input_frame_rate
|
||||
logging.info(f"input frame rate={self.input_frame_rate}")
|
||||
self.input_embedding = nn.Embedding(vocab_size, input_size)
|
||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
|
||||
self.pre_lookahead_len = pre_lookahead_len
|
||||
self.pre_lookahead_layer = pre_lookahead_layer
|
||||
self.decoder = decoder
|
||||
self.only_mask_loss = only_mask_loss
|
||||
self.token_mel_ratio = token_mel_ratio
|
||||
if online_feature is True:
|
||||
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
if 'speech_token' not in batch:
|
||||
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||
else:
|
||||
token = batch['speech_token'].to(device)
|
||||
token_len = batch['speech_token_len'].to(device)
|
||||
feat = batch['speech_feat'].to(device)
|
||||
feat_len = batch['speech_feat_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
# NOTE unified training, static_chunk_size > 0 or = 0
|
||||
streaming = True if random.random() < 0.5 else False
|
||||
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
h = self.pre_lookahead_layer(token)
|
||||
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
|
||||
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros(feat.shape, device=token.device)
|
||||
for i, j in enumerate(feat_len):
|
||||
if random.random() < 0.5:
|
||||
continue
|
||||
index = random.randint(0, int(0.3 * j))
|
||||
conds[i, :index] = feat[i, :index]
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
loss, _ = self.decoder.compute_loss(
|
||||
feat.transpose(1, 2).contiguous(),
|
||||
mask.unsqueeze(1),
|
||||
h.transpose(1, 2).contiguous(),
|
||||
embedding,
|
||||
cond=conds,
|
||||
streaming=streaming,
|
||||
)
|
||||
return {'loss': loss}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self,
|
||||
token,
|
||||
token_len,
|
||||
prompt_token,
|
||||
prompt_token_len,
|
||||
prompt_feat,
|
||||
prompt_feat_len,
|
||||
embedding,
|
||||
streaming,
|
||||
finalize):
|
||||
assert token.shape[0] == 1
|
||||
# xvec projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
|
||||
# concat text and prompt_text
|
||||
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
|
||||
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
|
||||
token = self.input_embedding(torch.clamp(token, min=0)) * mask
|
||||
|
||||
# text encode
|
||||
if finalize is True:
|
||||
h = self.pre_lookahead_layer(token)
|
||||
else:
|
||||
h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
|
||||
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
|
||||
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
|
||||
|
||||
# get conditions
|
||||
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
|
||||
conds[:, :mel_len1] = prompt_feat
|
||||
conds = conds.transpose(1, 2)
|
||||
|
||||
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
|
||||
feat, _ = self.decoder(
|
||||
mu=h.transpose(1, 2).contiguous(),
|
||||
mask=mask.unsqueeze(1),
|
||||
spks=embedding,
|
||||
cond=conds,
|
||||
n_timesteps=10,
|
||||
streaming=streaming
|
||||
)
|
||||
feat = feat[:, :, mel_len1:]
|
||||
assert feat.shape[2] == mel_len2
|
||||
return feat.float(), None
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
|
||||
model = configs['flow']
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_len = 10 * model.decoder.estimator.static_chunk_size
|
||||
chunk_size = model.decoder.estimator.static_chunk_size
|
||||
context_size = model.pre_lookahead_layer.pre_lookahead_len
|
||||
token = torch.randint(0, 6561, size=(1, max_len)).to(device)
|
||||
token_len = torch.tensor([max_len]).to(device)
|
||||
prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
|
||||
prompt_token_len = torch.tensor([chunk_size]).to(device)
|
||||
prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
|
||||
prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
|
||||
prompt_embedding = torch.rand(1, 192).to(device)
|
||||
pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
|
||||
for i in range(0, max_len, chunk_size):
|
||||
finalize = True if i + chunk_size + context_size >= max_len else False
|
||||
pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
|
||||
prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
|
||||
pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
|
||||
print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())
|
||||
227
models/CosyVoice/cosyvoice/flow/flow_matching.py
Normal file
227
models/CosyVoice/cosyvoice/flow/flow_matching.py
Normal file
@@ -0,0 +1,227 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from matcha.models.components.flow_matching import BASECFM
|
||||
from cosyvoice.utils.common import set_all_random_seed
|
||||
|
||||
|
||||
class ConditionalCFM(BASECFM):
|
||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||
super().__init__(
|
||||
n_feats=in_channels,
|
||||
cfm_params=cfm_params,
|
||||
n_spks=n_spks,
|
||||
spk_emb_dim=spk_emb_dim,
|
||||
)
|
||||
self.t_scheduler = cfm_params.t_scheduler
|
||||
self.training_cfg_rate = cfm_params.training_cfg_rate
|
||||
self.inference_cfg_rate = cfm_params.inference_cfg_rate
|
||||
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
|
||||
# Just change the architecture of the estimator here
|
||||
self.estimator = estimator
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): output_mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
n_timesteps (int): number of diffusion steps
|
||||
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
|
||||
cache_size = cache.shape[2]
|
||||
# fix prompt and overlap part mu and z
|
||||
if cache_size != 0:
|
||||
z[:, :, :cache_size] = cache[:, :, :, 0]
|
||||
mu[:, :, :cache_size] = cache[:, :, :, 1]
|
||||
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
|
||||
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
|
||||
cache = torch.stack([z_cache, mu_cache], dim=-1)
|
||||
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
|
||||
|
||||
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
|
||||
"""
|
||||
Fixed euler solver for ODEs.
|
||||
Args:
|
||||
x (torch.Tensor): random noise
|
||||
t_span (torch.Tensor): n_timesteps interpolated
|
||||
shape: (n_timesteps + 1,)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): output_mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
"""
|
||||
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
|
||||
t = t.unsqueeze(dim=0)
|
||||
|
||||
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
|
||||
# Or in future might add like a return_all_steps flag
|
||||
sol = []
|
||||
|
||||
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
|
||||
# NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
|
||||
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
|
||||
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
|
||||
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
|
||||
for step in range(1, len(t_span)):
|
||||
# Classifier-Free Guidance inference introduced in VoiceBox
|
||||
x_in[:] = x
|
||||
mask_in[:] = mask
|
||||
mu_in[0] = mu
|
||||
t_in[:] = t.unsqueeze(0)
|
||||
spks_in[0] = spks
|
||||
cond_in[0] = cond
|
||||
dphi_dt = self.forward_estimator(
|
||||
x_in, mask_in,
|
||||
mu_in, t_in,
|
||||
spks_in,
|
||||
cond_in,
|
||||
streaming
|
||||
)
|
||||
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
|
||||
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
|
||||
x = x + dt * dphi_dt
|
||||
t = t + dt
|
||||
sol.append(x)
|
||||
if step < len(t_span) - 1:
|
||||
dt = t_span[step + 1] - t
|
||||
|
||||
return sol[-1].float()
|
||||
|
||||
def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
|
||||
if isinstance(self.estimator, torch.nn.Module):
|
||||
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
|
||||
else:
|
||||
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
|
||||
# NOTE need to synchronize when switching stream
|
||||
torch.cuda.current_stream().synchronize()
|
||||
with stream:
|
||||
estimator.set_input_shape('x', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('mask', (2, 1, x.size(2)))
|
||||
estimator.set_input_shape('mu', (2, 80, x.size(2)))
|
||||
estimator.set_input_shape('t', (2,))
|
||||
estimator.set_input_shape('spks', (2, 80))
|
||||
estimator.set_input_shape('cond', (2, 80, x.size(2)))
|
||||
data_ptrs = [x.contiguous().data_ptr(),
|
||||
mask.contiguous().data_ptr(),
|
||||
mu.contiguous().data_ptr(),
|
||||
t.contiguous().data_ptr(),
|
||||
spks.contiguous().data_ptr(),
|
||||
cond.contiguous().data_ptr(),
|
||||
x.data_ptr()]
|
||||
for i, j in enumerate(data_ptrs):
|
||||
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
|
||||
# run trt engine
|
||||
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
|
||||
torch.cuda.current_stream().synchronize()
|
||||
self.estimator.release_estimator(estimator, stream)
|
||||
return x
|
||||
|
||||
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
|
||||
"""Computes diffusion loss
|
||||
|
||||
Args:
|
||||
x1 (torch.Tensor): Target
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): target mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
|
||||
Returns:
|
||||
loss: conditional flow matching loss
|
||||
y: conditional flow
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
b, _, t = mu.shape
|
||||
|
||||
# random timestep
|
||||
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
|
||||
|
||||
# sample noise p(x_0)
|
||||
z = torch.randn_like(x1)
|
||||
|
||||
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
|
||||
u = x1 - (1 - self.sigma_min) * z
|
||||
|
||||
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
|
||||
if self.training_cfg_rate > 0:
|
||||
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
|
||||
mu = mu * cfg_mask.view(-1, 1, 1)
|
||||
spks = spks * cfg_mask.view(-1, 1)
|
||||
cond = cond * cfg_mask.view(-1, 1, 1)
|
||||
|
||||
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
|
||||
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
|
||||
return loss, y
|
||||
|
||||
|
||||
class CausalConditionalCFM(ConditionalCFM):
|
||||
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
|
||||
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
|
||||
set_all_random_seed(0)
|
||||
self.rand_noise = torch.randn([1, 80, 50 * 300])
|
||||
|
||||
@torch.inference_mode()
|
||||
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
|
||||
"""Forward diffusion
|
||||
|
||||
Args:
|
||||
mu (torch.Tensor): output of encoder
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
mask (torch.Tensor): output_mask
|
||||
shape: (batch_size, 1, mel_timesteps)
|
||||
n_timesteps (int): number of diffusion steps
|
||||
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
|
||||
spks (torch.Tensor, optional): speaker ids. Defaults to None.
|
||||
shape: (batch_size, spk_emb_dim)
|
||||
cond: Not used but kept for future purposes
|
||||
|
||||
Returns:
|
||||
sample: generated mel-spectrogram
|
||||
shape: (batch_size, n_feats, mel_timesteps)
|
||||
"""
|
||||
|
||||
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
|
||||
# fix prompt and overlap part mu and z
|
||||
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
|
||||
if self.t_scheduler == 'cosine':
|
||||
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
|
||||
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None
|
||||
70
models/CosyVoice/cosyvoice/flow/length_regulator.py
Normal file
70
models/CosyVoice/cosyvoice/flow/length_regulator.py
Normal file
@@ -0,0 +1,70 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Tuple
|
||||
import torch.nn as nn
|
||||
import torch
|
||||
from torch.nn import functional as F
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
|
||||
|
||||
class InterpolateRegulator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
channels: int,
|
||||
sampling_ratios: Tuple,
|
||||
out_channels: int = None,
|
||||
groups: int = 1,
|
||||
):
|
||||
super().__init__()
|
||||
self.sampling_ratios = sampling_ratios
|
||||
out_channels = out_channels or channels
|
||||
model = nn.ModuleList([])
|
||||
if len(sampling_ratios) > 0:
|
||||
for _ in sampling_ratios:
|
||||
module = nn.Conv1d(channels, channels, 3, 1, 1)
|
||||
norm = nn.GroupNorm(groups, channels)
|
||||
act = nn.Mish()
|
||||
model.extend([module, norm, act])
|
||||
model.append(
|
||||
nn.Conv1d(channels, out_channels, 1, 1)
|
||||
)
|
||||
self.model = nn.Sequential(*model)
|
||||
|
||||
def forward(self, x, ylens=None):
|
||||
# x in (B, T, D)
|
||||
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
|
||||
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
|
||||
out = self.model(x).transpose(1, 2).contiguous()
|
||||
olens = ylens
|
||||
return out * mask, olens
|
||||
|
||||
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
|
||||
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
|
||||
# NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
|
||||
# x in (B, T, D)
|
||||
if x2.shape[1] > 40:
|
||||
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
||||
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
|
||||
mode='linear')
|
||||
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
|
||||
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
|
||||
else:
|
||||
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
|
||||
if x1.shape[1] != 0:
|
||||
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
|
||||
x = torch.concat([x1, x2], dim=2)
|
||||
else:
|
||||
x = x2
|
||||
out = self.model(x).transpose(1, 2).contiguous()
|
||||
return out, mel_len1 + mel_len2
|
||||
230
models/CosyVoice/cosyvoice/hifigan/discriminator.py
Normal file
230
models/CosyVoice/cosyvoice/hifigan/discriminator.py
Normal file
@@ -0,0 +1,230 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
try:
|
||||
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
|
||||
except ImportError:
|
||||
from torch.nn.utils import weight_norm, spectral_norm
|
||||
from typing import List, Optional, Tuple
|
||||
from einops import rearrange
|
||||
from torchaudio.transforms import Spectrogram
|
||||
|
||||
LRELU_SLOPE = 0.1
|
||||
|
||||
|
||||
class MultipleDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self, mpd: nn.Module, mrd: nn.Module
|
||||
):
|
||||
super().__init__()
|
||||
self.mpd = mpd
|
||||
self.mrd = mrd
|
||||
|
||||
def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
|
||||
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
||||
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
|
||||
y_d_rs += this_y_d_rs
|
||||
y_d_gs += this_y_d_gs
|
||||
fmap_rs += this_fmap_rs
|
||||
fmap_gs += this_fmap_gs
|
||||
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
|
||||
y_d_rs += this_y_d_rs
|
||||
y_d_gs += this_y_d_gs
|
||||
fmap_rs += this_fmap_rs
|
||||
fmap_gs += this_fmap_gs
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class MultiResolutionDiscriminator(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
|
||||
num_embeddings: Optional[int] = None,
|
||||
):
|
||||
"""
|
||||
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
|
||||
Additionally, it allows incorporating conditional information with a learned embeddings table.
|
||||
|
||||
Args:
|
||||
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
|
||||
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
|
||||
Defaults to None.
|
||||
"""
|
||||
|
||||
super().__init__()
|
||||
self.discriminators = nn.ModuleList(
|
||||
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
|
||||
)
|
||||
|
||||
def forward(
|
||||
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
|
||||
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
|
||||
for d in self.discriminators:
|
||||
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
|
||||
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
class DiscriminatorR(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
window_length: int,
|
||||
num_embeddings: Optional[int] = None,
|
||||
channels: int = 32,
|
||||
hop_factor: float = 0.25,
|
||||
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
|
||||
):
|
||||
super().__init__()
|
||||
self.window_length = window_length
|
||||
self.hop_factor = hop_factor
|
||||
self.spec_fn = Spectrogram(
|
||||
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
|
||||
)
|
||||
n_fft = window_length // 2 + 1
|
||||
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
|
||||
self.bands = bands
|
||||
convs = lambda: nn.ModuleList(
|
||||
[
|
||||
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
|
||||
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
|
||||
]
|
||||
)
|
||||
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
|
||||
|
||||
if num_embeddings is not None:
|
||||
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
|
||||
torch.nn.init.zeros_(self.emb.weight)
|
||||
|
||||
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
|
||||
|
||||
def spectrogram(self, x):
|
||||
# Remove DC offset
|
||||
x = x - x.mean(dim=-1, keepdims=True)
|
||||
# Peak normalize the volume of input audio
|
||||
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
|
||||
x = self.spec_fn(x)
|
||||
x = torch.view_as_real(x)
|
||||
x = rearrange(x, "b f t c -> b c t f")
|
||||
# Split into bands
|
||||
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
|
||||
return x_bands
|
||||
|
||||
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
|
||||
x_bands = self.spectrogram(x)
|
||||
fmap = []
|
||||
x = []
|
||||
for band, stack in zip(x_bands, self.band_convs):
|
||||
for i, layer in enumerate(stack):
|
||||
band = layer(band)
|
||||
band = torch.nn.functional.leaky_relu(band, 0.1)
|
||||
if i > 0:
|
||||
fmap.append(band)
|
||||
x.append(band)
|
||||
x = torch.cat(x, dim=-1)
|
||||
if cond_embedding_id is not None:
|
||||
emb = self.emb(cond_embedding_id)
|
||||
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
|
||||
else:
|
||||
h = 0
|
||||
x = self.conv_post(x)
|
||||
fmap.append(x)
|
||||
x += h
|
||||
|
||||
return x, fmap
|
||||
|
||||
|
||||
class MultiResSpecDiscriminator(torch.nn.Module):
|
||||
|
||||
def __init__(self,
|
||||
fft_sizes=[1024, 2048, 512],
|
||||
hop_sizes=[120, 240, 50],
|
||||
win_lengths=[600, 1200, 240],
|
||||
window="hann_window"):
|
||||
|
||||
super(MultiResSpecDiscriminator, self).__init__()
|
||||
self.discriminators = nn.ModuleList([
|
||||
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
|
||||
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
|
||||
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
|
||||
|
||||
def forward(self, y, y_hat):
|
||||
y_d_rs = []
|
||||
y_d_gs = []
|
||||
fmap_rs = []
|
||||
fmap_gs = []
|
||||
for _, d in enumerate(self.discriminators):
|
||||
y_d_r, fmap_r = d(y)
|
||||
y_d_g, fmap_g = d(y_hat)
|
||||
y_d_rs.append(y_d_r)
|
||||
fmap_rs.append(fmap_r)
|
||||
y_d_gs.append(y_d_g)
|
||||
fmap_gs.append(fmap_g)
|
||||
|
||||
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
||||
|
||||
|
||||
def stft(x, fft_size, hop_size, win_length, window):
|
||||
"""Perform STFT and convert to magnitude spectrogram.
|
||||
Args:
|
||||
x (Tensor): Input signal tensor (B, T).
|
||||
fft_size (int): FFT size.
|
||||
hop_size (int): Hop size.
|
||||
win_length (int): Window length.
|
||||
window (str): Window function type.
|
||||
Returns:
|
||||
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
|
||||
"""
|
||||
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
|
||||
|
||||
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
|
||||
return torch.abs(x_stft).transpose(2, 1)
|
||||
|
||||
|
||||
class SpecDiscriminator(nn.Module):
|
||||
"""docstring for Discriminator."""
|
||||
|
||||
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
|
||||
super(SpecDiscriminator, self).__init__()
|
||||
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
|
||||
self.fft_size = fft_size
|
||||
self.shift_size = shift_size
|
||||
self.win_length = win_length
|
||||
self.window = getattr(torch, window)(win_length)
|
||||
self.discriminators = nn.ModuleList([
|
||||
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
|
||||
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
|
||||
])
|
||||
|
||||
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
|
||||
|
||||
def forward(self, y):
|
||||
|
||||
fmap = []
|
||||
y = y.squeeze(1)
|
||||
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
|
||||
y = y.unsqueeze(1)
|
||||
for _, d in enumerate(self.discriminators):
|
||||
y = d(y)
|
||||
y = F.leaky_relu(y, LRELU_SLOPE)
|
||||
fmap.append(y)
|
||||
|
||||
y = self.out(y)
|
||||
fmap.append(y)
|
||||
|
||||
return torch.flatten(y, 1, -1), fmap
|
||||
103
models/CosyVoice/cosyvoice/hifigan/f0_predictor.py
Normal file
103
models/CosyVoice/cosyvoice/hifigan/f0_predictor.py
Normal file
@@ -0,0 +1,103 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
try:
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
except ImportError:
|
||||
from torch.nn.utils import weight_norm
|
||||
from cosyvoice.transformer.convolution import CausalConv1d
|
||||
|
||||
|
||||
class ConvRNNF0Predictor(nn.Module):
|
||||
def __init__(self,
|
||||
num_class: int = 1,
|
||||
in_channels: int = 80,
|
||||
cond_channels: int = 512
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_class = num_class
|
||||
self.condnet = nn.Sequential(
|
||||
weight_norm(
|
||||
nn.Conv1d(in_channels, cond_channels, kernel_size=3, padding=1)
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
nn.Conv1d(cond_channels, cond_channels, kernel_size=3, padding=1)
|
||||
),
|
||||
nn.ELU(),
|
||||
)
|
||||
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
x = self.condnet(x)
|
||||
x = x.transpose(1, 2)
|
||||
return torch.abs(self.classifier(x).squeeze(-1))
|
||||
|
||||
|
||||
class CausalConvRNNF0Predictor(nn.Module):
|
||||
def __init__(self,
|
||||
num_class: int = 1,
|
||||
in_channels: int = 80,
|
||||
cond_channels: int = 512
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.num_class = num_class
|
||||
self.condnet = nn.Sequential(
|
||||
weight_norm(
|
||||
CausalConv1d(in_channels, cond_channels, kernel_size=4, causal_type='right')
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||
),
|
||||
nn.ELU(),
|
||||
weight_norm(
|
||||
CausalConv1d(cond_channels, cond_channels, kernel_size=3, causal_type='left')
|
||||
),
|
||||
nn.ELU(),
|
||||
)
|
||||
self.classifier = nn.Linear(in_features=cond_channels, out_features=self.num_class)
|
||||
|
||||
def forward(self, x: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
||||
if finalize is True:
|
||||
x = self.condnet[0](x)
|
||||
else:
|
||||
x = self.condnet[0](x[:, :, :-self.condnet[0].causal_padding], x[:, :, -self.condnet[0].causal_padding:])
|
||||
for i in range(1, len(self.condnet)):
|
||||
x = self.condnet[i](x)
|
||||
x = x.transpose(1, 2)
|
||||
return torch.abs(self.classifier(x).squeeze(-1))
|
||||
746
models/CosyVoice/cosyvoice/hifigan/generator.py
Normal file
746
models/CosyVoice/cosyvoice/hifigan/generator.py
Normal file
@@ -0,0 +1,746 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Kai Hu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""HIFI-GAN"""
|
||||
|
||||
from typing import Dict, Optional, List
|
||||
import numpy as np
|
||||
from scipy.signal import get_window
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from torch.nn import Conv1d
|
||||
from torch.nn import ConvTranspose1d
|
||||
from torch.nn.utils import remove_weight_norm
|
||||
try:
|
||||
from torch.nn.utils.parametrizations import weight_norm
|
||||
except ImportError:
|
||||
from torch.nn.utils import weight_norm
|
||||
from torch.distributions.uniform import Uniform
|
||||
from cosyvoice.transformer.convolution import CausalConv1d, CausalConv1dDownSample, CausalConv1dUpsample
|
||||
from cosyvoice.transformer.activation import Snake
|
||||
from cosyvoice.utils.common import get_padding
|
||||
from cosyvoice.utils.common import init_weights
|
||||
|
||||
|
||||
"""hifigan based generator implementation.
|
||||
|
||||
This code is modified from https://github.com/jik876/hifi-gan
|
||||
,https://github.com/kan-bayashi/ParallelWaveGAN and
|
||||
https://github.com/NVIDIA/BigVGAN
|
||||
|
||||
"""
|
||||
|
||||
|
||||
class ResBlock(torch.nn.Module):
|
||||
"""Residual block module in HiFiGAN/BigVGAN."""
|
||||
def __init__(
|
||||
self,
|
||||
channels: int = 512,
|
||||
kernel_size: int = 3,
|
||||
dilations: List[int] = [1, 3, 5],
|
||||
causal: bool = False,
|
||||
):
|
||||
super(ResBlock, self).__init__()
|
||||
self.causal = causal
|
||||
self.convs1 = nn.ModuleList()
|
||||
self.convs2 = nn.ModuleList()
|
||||
|
||||
for dilation in dilations:
|
||||
self.convs1.append(
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation,
|
||||
padding=get_padding(kernel_size, dilation)) if causal is False else
|
||||
CausalConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=dilation,
|
||||
causal_type='left'
|
||||
)
|
||||
)
|
||||
)
|
||||
self.convs2.append(
|
||||
weight_norm(
|
||||
Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
padding=get_padding(kernel_size, 1)) if causal is False else
|
||||
CausalConv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
1,
|
||||
dilation=1,
|
||||
causal_type='left'
|
||||
)
|
||||
)
|
||||
)
|
||||
self.convs1.apply(init_weights)
|
||||
self.convs2.apply(init_weights)
|
||||
self.activations1 = nn.ModuleList([
|
||||
Snake(channels, alpha_logscale=False)
|
||||
for _ in range(len(self.convs1))
|
||||
])
|
||||
self.activations2 = nn.ModuleList([
|
||||
Snake(channels, alpha_logscale=False)
|
||||
for _ in range(len(self.convs2))
|
||||
])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
for idx in range(len(self.convs1)):
|
||||
xt = self.activations1[idx](x)
|
||||
xt = self.convs1[idx](xt)
|
||||
xt = self.activations2[idx](xt)
|
||||
xt = self.convs2[idx](xt)
|
||||
x = xt + x
|
||||
return x
|
||||
|
||||
def remove_weight_norm(self):
|
||||
for idx in range(len(self.convs1)):
|
||||
remove_weight_norm(self.convs1[idx])
|
||||
remove_weight_norm(self.convs2[idx])
|
||||
|
||||
|
||||
class SineGen(torch.nn.Module):
|
||||
""" Definition of sine generator
|
||||
SineGen(samp_rate, harmonic_num = 0,
|
||||
sine_amp = 0.1, noise_std = 0.003,
|
||||
voiced_threshold = 0,
|
||||
flag_for_pulse=False)
|
||||
samp_rate: sampling rate in Hz
|
||||
harmonic_num: number of harmonic overtones (default 0)
|
||||
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||
noise_std: std of Gaussian noise (default 0.003)
|
||||
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||
segment is always sin(np.pi) or cos(0)
|
||||
"""
|
||||
|
||||
def __init__(self, samp_rate, harmonic_num=0,
|
||||
sine_amp=0.1, noise_std=0.003,
|
||||
voiced_threshold=0):
|
||||
super(SineGen, self).__init__()
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = noise_std
|
||||
self.harmonic_num = harmonic_num
|
||||
self.sampling_rate = samp_rate
|
||||
self.voiced_threshold = voiced_threshold
|
||||
|
||||
def _f02uv(self, f0):
|
||||
# generate uv signal
|
||||
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
||||
return uv
|
||||
|
||||
@torch.no_grad()
|
||||
def forward(self, f0):
|
||||
""" sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, dim=1, length)
|
||||
f0 for unvoiced steps should be 0
|
||||
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||
output uv: tensor(batchsize=1, length, 1)
|
||||
"""
|
||||
f0 = f0.transpose(1, 2)
|
||||
F_mat = torch.zeros((f0.size(0), self.harmonic_num + 1, f0.size(-1))).to(f0.device)
|
||||
for i in range(self.harmonic_num + 1):
|
||||
F_mat[:, i: i + 1, :] = f0 * (i + 1) / self.sampling_rate
|
||||
|
||||
theta_mat = 2 * np.pi * (torch.cumsum(F_mat, dim=-1) % 1)
|
||||
u_dist = Uniform(low=-np.pi, high=np.pi)
|
||||
phase_vec = u_dist.sample(sample_shape=(f0.size(0), self.harmonic_num + 1, 1)).to(F_mat.device)
|
||||
phase_vec[:, 0, :] = 0
|
||||
|
||||
# generate sine waveforms
|
||||
sine_waves = self.sine_amp * torch.sin(theta_mat + phase_vec)
|
||||
|
||||
# generate uv signal
|
||||
uv = self._f02uv(f0)
|
||||
|
||||
# noise: for unvoiced should be similar to sine_amp
|
||||
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||
# . for voiced regions is self.noise_std
|
||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
|
||||
# first: set the unvoiced part to 0 by uv
|
||||
# then: additive noise
|
||||
sine_waves = sine_waves * uv + noise
|
||||
return sine_waves.transpose(1, 2), uv.transpose(1, 2), noise
|
||||
|
||||
|
||||
class SineGen2(torch.nn.Module):
|
||||
""" Definition of sine generator
|
||||
SineGen(samp_rate, harmonic_num = 0,
|
||||
sine_amp = 0.1, noise_std = 0.003,
|
||||
voiced_threshold = 0,
|
||||
flag_for_pulse=False)
|
||||
samp_rate: sampling rate in Hz
|
||||
harmonic_num: number of harmonic overtones (default 0)
|
||||
sine_amp: amplitude of sine-wavefrom (default 0.1)
|
||||
noise_std: std of Gaussian noise (default 0.003)
|
||||
voiced_thoreshold: F0 threshold for U/V classification (default 0)
|
||||
flag_for_pulse: this SinGen is used inside PulseGen (default False)
|
||||
Note: when flag_for_pulse is True, the first time step of a voiced
|
||||
segment is always sin(np.pi) or cos(0)
|
||||
"""
|
||||
|
||||
def __init__(self, samp_rate, upsample_scale, harmonic_num=0,
|
||||
sine_amp=0.1, noise_std=0.003,
|
||||
voiced_threshold=0,
|
||||
flag_for_pulse=False,
|
||||
causal=False):
|
||||
super(SineGen2, self).__init__()
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = noise_std
|
||||
self.harmonic_num = harmonic_num
|
||||
self.dim = self.harmonic_num + 1
|
||||
self.sampling_rate = samp_rate
|
||||
self.voiced_threshold = voiced_threshold
|
||||
self.flag_for_pulse = flag_for_pulse
|
||||
self.upsample_scale = upsample_scale
|
||||
self.causal = causal
|
||||
if causal is True:
|
||||
self.rand_ini = torch.rand(1, 9)
|
||||
self.rand_ini[:, 0] = 0
|
||||
self.sine_waves = torch.rand(1, 300 * 24000, 9)
|
||||
|
||||
def _f02uv(self, f0):
|
||||
# generate uv signal
|
||||
uv = (f0 > self.voiced_threshold).type(torch.float32)
|
||||
return uv
|
||||
|
||||
def _f02sine(self, f0_values):
|
||||
""" f0_values: (batchsize, length, dim)
|
||||
where dim indicates fundamental tone and overtones
|
||||
"""
|
||||
# convert to F0 in rad. The interger part n can be ignored
|
||||
# because 2 * np.pi * n doesn't affect phase
|
||||
rad_values = (f0_values / self.sampling_rate) % 1
|
||||
|
||||
# initial phase noise (no noise for fundamental component)
|
||||
if self.training is False and self.causal is True:
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + self.rand_ini.to(rad_values.device)
|
||||
else:
|
||||
rand_ini = torch.rand(f0_values.shape[0], f0_values.shape[2], device=f0_values.device)
|
||||
rand_ini[:, 0] = 0
|
||||
rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
|
||||
|
||||
# instantanouse phase sine[t] = sin(2*pi \sum_i=1 ^{t} rad)
|
||||
if not self.flag_for_pulse:
|
||||
rad_values = torch.nn.functional.interpolate(rad_values.transpose(1, 2),
|
||||
scale_factor=1 / self.upsample_scale,
|
||||
mode="linear").transpose(1, 2)
|
||||
|
||||
phase = torch.cumsum(rad_values, dim=1) * 2 * np.pi
|
||||
phase = torch.nn.functional.interpolate(phase.transpose(1, 2) * self.upsample_scale,
|
||||
scale_factor=self.upsample_scale, mode="nearest" if self.causal is True else 'linear').transpose(1, 2)
|
||||
sines = torch.sin(phase)
|
||||
else:
|
||||
# If necessary, make sure that the first time step of every
|
||||
# voiced segments is sin(pi) or cos(0)
|
||||
# This is used for pulse-train generation
|
||||
|
||||
# identify the last time step in unvoiced segments
|
||||
uv = self._f02uv(f0_values)
|
||||
uv_1 = torch.roll(uv, shifts=-1, dims=1)
|
||||
uv_1[:, -1, :] = 1
|
||||
u_loc = (uv < 1) * (uv_1 > 0)
|
||||
|
||||
# get the instantanouse phase
|
||||
tmp_cumsum = torch.cumsum(rad_values, dim=1)
|
||||
# different batch needs to be processed differently
|
||||
for idx in range(f0_values.shape[0]):
|
||||
temp_sum = tmp_cumsum[idx, u_loc[idx, :, 0], :]
|
||||
temp_sum[1:, :] = temp_sum[1:, :] - temp_sum[0:-1, :]
|
||||
# stores the accumulation of i.phase within
|
||||
# each voiced segments
|
||||
tmp_cumsum[idx, :, :] = 0
|
||||
tmp_cumsum[idx, u_loc[idx, :, 0], :] = temp_sum
|
||||
|
||||
# rad_values - tmp_cumsum: remove the accumulation of i.phase
|
||||
# within the previous voiced segment.
|
||||
i_phase = torch.cumsum(rad_values - tmp_cumsum, dim=1)
|
||||
|
||||
# get the sines
|
||||
sines = torch.cos(i_phase * 2 * np.pi)
|
||||
return sines
|
||||
|
||||
def forward(self, f0):
|
||||
""" sine_tensor, uv = forward(f0)
|
||||
input F0: tensor(batchsize=1, length, dim=1)
|
||||
f0 for unvoiced steps should be 0
|
||||
output sine_tensor: tensor(batchsize=1, length, dim)
|
||||
output uv: tensor(batchsize=1, length, 1)
|
||||
"""
|
||||
# fundamental component
|
||||
fn = torch.multiply(f0, torch.FloatTensor([[range(1, self.harmonic_num + 2)]]).to(f0.device))
|
||||
|
||||
# generate sine waveforms
|
||||
sine_waves = self._f02sine(fn) * self.sine_amp
|
||||
|
||||
# generate uv signal
|
||||
uv = self._f02uv(f0)
|
||||
|
||||
# noise: for unvoiced should be similar to sine_amp
|
||||
# std = self.sine_amp/3 -> max value ~ self.sine_amp
|
||||
# . for voiced regions is self.noise_std
|
||||
noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
|
||||
if self.training is False and self.causal is True:
|
||||
noise = noise_amp * self.sine_waves[:, :sine_waves.shape[1]].to(sine_waves.device)
|
||||
else:
|
||||
noise = noise_amp * torch.randn_like(sine_waves)
|
||||
|
||||
# first: set the unvoiced part to 0 by uv
|
||||
# then: additive noise
|
||||
sine_waves = sine_waves * uv + noise
|
||||
return sine_waves, uv, noise
|
||||
|
||||
|
||||
class SourceModuleHnNSF(torch.nn.Module):
|
||||
""" SourceModule for hn-nsf
|
||||
SourceModule(sampling_rate, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0)
|
||||
sampling_rate: sampling_rate in Hz
|
||||
harmonic_num: number of harmonic above F0 (default: 0)
|
||||
sine_amp: amplitude of sine source signal (default: 0.1)
|
||||
add_noise_std: std of additive Gaussian noise (default: 0.003)
|
||||
note that amplitude of noise in unvoiced is decided
|
||||
by sine_amp
|
||||
voiced_threshold: threhold to set U/V given F0 (default: 0)
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
uv (batchsize, length, 1)
|
||||
"""
|
||||
|
||||
def __init__(self, sampling_rate, upsample_scale, harmonic_num=0, sine_amp=0.1,
|
||||
add_noise_std=0.003, voiced_threshod=0, sinegen_type='1', causal=False):
|
||||
super(SourceModuleHnNSF, self).__init__()
|
||||
|
||||
self.sine_amp = sine_amp
|
||||
self.noise_std = add_noise_std
|
||||
|
||||
# to produce sine waveforms
|
||||
if sinegen_type == '1':
|
||||
self.l_sin_gen = SineGen(sampling_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
|
||||
else:
|
||||
self.l_sin_gen = SineGen2(sampling_rate, upsample_scale, harmonic_num, sine_amp, add_noise_std, voiced_threshod, causal=causal)
|
||||
|
||||
# to merge source harmonics into a single excitation
|
||||
self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
|
||||
self.l_tanh = torch.nn.Tanh()
|
||||
self.causal = causal
|
||||
if causal is True:
|
||||
self.uv = torch.rand(1, 300 * 24000, 1)
|
||||
|
||||
def forward(self, x):
|
||||
"""
|
||||
Sine_source, noise_source = SourceModuleHnNSF(F0_sampled)
|
||||
F0_sampled (batchsize, length, 1)
|
||||
Sine_source (batchsize, length, 1)
|
||||
noise_source (batchsize, length 1)
|
||||
"""
|
||||
# source for harmonic branch
|
||||
with torch.no_grad():
|
||||
sine_wavs, uv, _ = self.l_sin_gen(x)
|
||||
sine_merge = self.l_tanh(self.l_linear(sine_wavs))
|
||||
|
||||
# source for noise branch, in the same shape as uv
|
||||
if self.training is False and self.causal is True:
|
||||
noise = self.uv[:, :uv.shape[1]] * self.sine_amp / 3
|
||||
else:
|
||||
noise = torch.randn_like(uv) * self.sine_amp / 3
|
||||
return sine_merge, noise, uv
|
||||
|
||||
|
||||
class HiFTGenerator(nn.Module):
|
||||
"""
|
||||
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
||||
https://arxiv.org/abs/2309.09493
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 80,
|
||||
base_channels: int = 512,
|
||||
nb_harmonics: int = 8,
|
||||
sampling_rate: int = 22050,
|
||||
nsf_alpha: float = 0.1,
|
||||
nsf_sigma: float = 0.003,
|
||||
nsf_voiced_threshold: float = 10,
|
||||
upsample_rates: List[int] = [8, 8],
|
||||
upsample_kernel_sizes: List[int] = [16, 16],
|
||||
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
||||
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
source_resblock_kernel_sizes: List[int] = [7, 11],
|
||||
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
||||
lrelu_slope: float = 0.1,
|
||||
audio_limit: float = 0.99,
|
||||
f0_predictor: torch.nn.Module = None,
|
||||
):
|
||||
super(HiFTGenerator, self).__init__()
|
||||
|
||||
self.out_channels = 1
|
||||
self.nb_harmonics = nb_harmonics
|
||||
self.sampling_rate = sampling_rate
|
||||
self.istft_params = istft_params
|
||||
self.lrelu_slope = lrelu_slope
|
||||
self.audio_limit = audio_limit
|
||||
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
# NOTE in CosyVoice2, we use the original SineGen implementation
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=sampling_rate,
|
||||
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
||||
harmonic_num=nb_harmonics,
|
||||
sine_amp=nsf_alpha,
|
||||
add_noise_std=nsf_sigma,
|
||||
voiced_threshod=nsf_voiced_threshold,
|
||||
sinegen_type='1' if self.sampling_rate == 22050 else '2',
|
||||
causal=False)
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
||||
|
||||
self.conv_pre = weight_norm(
|
||||
Conv1d(in_channels, base_channels, 7, 1, padding=3)
|
||||
)
|
||||
|
||||
# Up
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
ConvTranspose1d(
|
||||
base_channels // (2**i),
|
||||
base_channels // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
padding=(k - u) // 2,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Down
|
||||
self.source_downs = nn.ModuleList()
|
||||
self.source_resblocks = nn.ModuleList()
|
||||
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
||||
downsample_cum_rates = np.cumprod(downsample_rates)
|
||||
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
||||
if u == 1:
|
||||
self.source_downs.append(
|
||||
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
|
||||
)
|
||||
else:
|
||||
self.source_downs.append(
|
||||
Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u, padding=(u // 2))
|
||||
)
|
||||
|
||||
self.source_resblocks.append(
|
||||
ResBlock(base_channels // (2 ** (i + 1)), k, d)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = base_channels // (2**(i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(ResBlock(ch, k, d))
|
||||
|
||||
self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
||||
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
||||
self.f0_predictor = f0_predictor
|
||||
|
||||
def remove_weight_norm(self):
|
||||
print('Removing weight norm...')
|
||||
for l in self.ups:
|
||||
remove_weight_norm(l)
|
||||
for l in self.resblocks:
|
||||
l.remove_weight_norm()
|
||||
remove_weight_norm(self.conv_pre)
|
||||
remove_weight_norm(self.conv_post)
|
||||
self.m_source.remove_weight_norm()
|
||||
for l in self.source_downs:
|
||||
remove_weight_norm(l)
|
||||
for l in self.source_resblocks:
|
||||
l.remove_weight_norm()
|
||||
|
||||
def _stft(self, x):
|
||||
spec = torch.stft(
|
||||
x,
|
||||
self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(x.device),
|
||||
return_complex=True)
|
||||
spec = torch.view_as_real(spec) # [B, F, TT, 2]
|
||||
return spec[..., 0], spec[..., 1]
|
||||
|
||||
def _istft(self, magnitude, phase):
|
||||
magnitude = torch.clip(magnitude, max=1e2)
|
||||
real = magnitude * torch.cos(phase)
|
||||
img = magnitude * torch.sin(phase)
|
||||
inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
|
||||
self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
|
||||
return inverse_transform
|
||||
|
||||
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
||||
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
||||
|
||||
x = self.conv_pre(x)
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, self.lrelu_slope)
|
||||
x = self.ups[i](x)
|
||||
|
||||
if i == self.num_upsamples - 1:
|
||||
x = self.reflection_pad(x)
|
||||
|
||||
# fusion
|
||||
si = self.source_downs[i](s_stft)
|
||||
si = self.source_resblocks[i](si)
|
||||
x = x + si
|
||||
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
||||
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
|
||||
|
||||
x = self._istft(magnitude, phase)
|
||||
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
speech_feat = batch['speech_feat'].transpose(1, 2).to(device)
|
||||
# mel->f0
|
||||
f0 = self.f0_predictor(speech_feat)
|
||||
# f0->source
|
||||
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
s, _, _ = self.m_source(s)
|
||||
s = s.transpose(1, 2)
|
||||
# mel+source->speech
|
||||
generated_speech = self.decode(x=speech_feat, s=s)
|
||||
return generated_speech, f0
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self, speech_feat: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
|
||||
# mel->f0
|
||||
f0 = self.f0_predictor(speech_feat)
|
||||
# f0->source
|
||||
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
s, _, _ = self.m_source(s)
|
||||
s = s.transpose(1, 2)
|
||||
# use cache_source to avoid glitch
|
||||
if cache_source.shape[2] != 0:
|
||||
s[:, :, :cache_source.shape[2]] = cache_source
|
||||
generated_speech = self.decode(x=speech_feat, s=s)
|
||||
return generated_speech, s
|
||||
|
||||
|
||||
class CausalHiFTGenerator(HiFTGenerator):
|
||||
"""
|
||||
HiFTNet Generator: Neural Source Filter + ISTFTNet
|
||||
https://arxiv.org/abs/2309.09493
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int = 80,
|
||||
base_channels: int = 512,
|
||||
nb_harmonics: int = 8,
|
||||
sampling_rate: int = 22050,
|
||||
nsf_alpha: float = 0.1,
|
||||
nsf_sigma: float = 0.003,
|
||||
nsf_voiced_threshold: float = 10,
|
||||
upsample_rates: List[int] = [8, 8],
|
||||
upsample_kernel_sizes: List[int] = [16, 16],
|
||||
istft_params: Dict[str, int] = {"n_fft": 16, "hop_len": 4},
|
||||
resblock_kernel_sizes: List[int] = [3, 7, 11],
|
||||
resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
source_resblock_kernel_sizes: List[int] = [7, 11],
|
||||
source_resblock_dilation_sizes: List[List[int]] = [[1, 3, 5], [1, 3, 5]],
|
||||
lrelu_slope: float = 0.1,
|
||||
audio_limit: float = 0.99,
|
||||
conv_pre_look_right: int = 4,
|
||||
f0_predictor: torch.nn.Module = None,
|
||||
):
|
||||
torch.nn.Module.__init__(self)
|
||||
|
||||
self.out_channels = 1
|
||||
self.nb_harmonics = nb_harmonics
|
||||
self.sampling_rate = sampling_rate
|
||||
self.istft_params = istft_params
|
||||
self.lrelu_slope = lrelu_slope
|
||||
self.audio_limit = audio_limit
|
||||
|
||||
self.num_kernels = len(resblock_kernel_sizes)
|
||||
self.num_upsamples = len(upsample_rates)
|
||||
self.m_source = SourceModuleHnNSF(
|
||||
sampling_rate=sampling_rate,
|
||||
upsample_scale=np.prod(upsample_rates) * istft_params["hop_len"],
|
||||
harmonic_num=nb_harmonics,
|
||||
sine_amp=nsf_alpha,
|
||||
add_noise_std=nsf_sigma,
|
||||
voiced_threshod=nsf_voiced_threshold,
|
||||
sinegen_type='1' if self.sampling_rate == 22050 else '2',
|
||||
causal=True)
|
||||
self.upsample_rates = upsample_rates
|
||||
self.f0_upsamp = torch.nn.Upsample(scale_factor=np.prod(upsample_rates) * istft_params["hop_len"])
|
||||
|
||||
self.conv_pre = weight_norm(
|
||||
CausalConv1d(in_channels, base_channels, conv_pre_look_right + 1, 1, causal_type='right')
|
||||
)
|
||||
|
||||
# Up
|
||||
self.ups = nn.ModuleList()
|
||||
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
|
||||
self.ups.append(
|
||||
weight_norm(
|
||||
CausalConv1dUpsample(
|
||||
base_channels // (2**i),
|
||||
base_channels // (2**(i + 1)),
|
||||
k,
|
||||
u,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
# Down
|
||||
self.source_downs = nn.ModuleList()
|
||||
self.source_resblocks = nn.ModuleList()
|
||||
downsample_rates = [1] + upsample_rates[::-1][:-1]
|
||||
downsample_cum_rates = np.cumprod(downsample_rates)
|
||||
for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
|
||||
if u == 1:
|
||||
self.source_downs.append(
|
||||
CausalConv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1, causal_type='left')
|
||||
)
|
||||
else:
|
||||
self.source_downs.append(
|
||||
CausalConv1dDownSample(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), u * 2, u)
|
||||
)
|
||||
|
||||
self.source_resblocks.append(
|
||||
ResBlock(base_channels // (2 ** (i + 1)), k, d, causal=True)
|
||||
)
|
||||
|
||||
self.resblocks = nn.ModuleList()
|
||||
for i in range(len(self.ups)):
|
||||
ch = base_channels // (2**(i + 1))
|
||||
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
|
||||
self.resblocks.append(ResBlock(ch, k, d, causal=True))
|
||||
|
||||
self.conv_post = weight_norm(CausalConv1d(ch, istft_params["n_fft"] + 2, 7, 1, causal_type='left'))
|
||||
self.ups.apply(init_weights)
|
||||
self.conv_post.apply(init_weights)
|
||||
self.reflection_pad = nn.ReflectionPad1d((1, 0))
|
||||
self.stft_window = torch.from_numpy(get_window("hann", istft_params["n_fft"], fftbins=True).astype(np.float32))
|
||||
self.conv_pre_look_right = conv_pre_look_right
|
||||
self.f0_predictor = f0_predictor
|
||||
|
||||
def decode(self, x: torch.Tensor, s: torch.Tensor = torch.zeros(1, 1, 0), finalize: bool = True) -> torch.Tensor:
|
||||
s_stft_real, s_stft_imag = self._stft(s.squeeze(1))
|
||||
if finalize is True:
|
||||
x = self.conv_pre(x)
|
||||
else:
|
||||
x = self.conv_pre(x[:, :, :-self.conv_pre_look_right], x[:, :, -self.conv_pre_look_right:])
|
||||
s_stft_real = s_stft_real[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
|
||||
s_stft_imag = s_stft_imag[:, :, :-int(np.prod(self.upsample_rates) * self.conv_pre_look_right)]
|
||||
s_stft = torch.cat([s_stft_real, s_stft_imag], dim=1)
|
||||
|
||||
for i in range(self.num_upsamples):
|
||||
x = F.leaky_relu(x, self.lrelu_slope)
|
||||
x = self.ups[i](x)
|
||||
|
||||
if i == self.num_upsamples - 1:
|
||||
x = self.reflection_pad(x)
|
||||
|
||||
# fusion
|
||||
si = self.source_downs[i](s_stft)
|
||||
si = self.source_resblocks[i](si)
|
||||
x = x + si
|
||||
|
||||
xs = None
|
||||
for j in range(self.num_kernels):
|
||||
if xs is None:
|
||||
xs = self.resblocks[i * self.num_kernels + j](x)
|
||||
else:
|
||||
xs += self.resblocks[i * self.num_kernels + j](x)
|
||||
x = xs / self.num_kernels
|
||||
|
||||
x = F.leaky_relu(x)
|
||||
x = self.conv_post(x)
|
||||
magnitude = torch.exp(x[:, :self.istft_params["n_fft"] // 2 + 1, :])
|
||||
phase = torch.sin(x[:, self.istft_params["n_fft"] // 2 + 1:, :]) # actually, sin is redundancy
|
||||
|
||||
x = self._istft(magnitude, phase)
|
||||
if finalize is False:
|
||||
x = x[:, :-int(np.prod(self.upsample_rates) * self.istft_params['hop_len'])]
|
||||
x = torch.clamp(x, -self.audio_limit, self.audio_limit)
|
||||
return x
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(self, speech_feat: torch.Tensor, finalize: bool = True) -> torch.Tensor:
|
||||
# mel->f0 NOTE f0_predictor precision is crucial for causal inference, move self.f0_predictor to cpu if necessary
|
||||
self.f0_predictor.to(torch.float64)
|
||||
f0 = self.f0_predictor(speech_feat.to(torch.float64), finalize=finalize).to(speech_feat)
|
||||
# f0->source
|
||||
s = self.f0_upsamp(f0[:, None]).transpose(1, 2) # bs,n,t
|
||||
s, _, _ = self.m_source(s)
|
||||
s = s.transpose(1, 2)
|
||||
if finalize is True:
|
||||
generated_speech = self.decode(x=speech_feat, s=s, finalize=finalize)
|
||||
else:
|
||||
generated_speech = self.decode(x=speech_feat[:, :, :-self.f0_predictor.condnet[0].causal_padding], s=s, finalize=finalize)
|
||||
return generated_speech, s
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
from hyperpyyaml import load_hyperpyyaml
|
||||
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
|
||||
configs = load_hyperpyyaml(f, overrides={'llm': None, 'flow': None})
|
||||
model = configs['hift']
|
||||
device = 'cuda' if torch.cuda.is_available() else 'cpu'
|
||||
model.to(device)
|
||||
model.eval()
|
||||
max_len, chunk_size, context_size = 300, 30, 8
|
||||
mel = torch.rand(1, 80, max_len).to(device)
|
||||
pred_gt, _ = model.inference(mel)
|
||||
for i in range(0, max_len, chunk_size):
|
||||
finalize = True if i + chunk_size + context_size >= max_len else False
|
||||
pred_chunk, _ = model.inference(mel[:, :, : i + chunk_size + context_size], finalize=finalize)
|
||||
pred_chunk = pred_chunk[:, i * 480:]
|
||||
print((pred_gt[:, i * 480:i * 480 + pred_chunk.shape[1]] - pred_chunk).abs().max().item())
|
||||
67
models/CosyVoice/cosyvoice/hifigan/hifigan.py
Normal file
67
models/CosyVoice/cosyvoice/hifigan/hifigan.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import Dict, Optional
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from matcha.hifigan.models import feature_loss, generator_loss, discriminator_loss
|
||||
from cosyvoice.utils.losses import tpr_loss, mel_loss
|
||||
|
||||
|
||||
class HiFiGan(nn.Module):
|
||||
def __init__(self, generator, discriminator, mel_spec_transform,
|
||||
multi_mel_spectral_recon_loss_weight=45, feat_match_loss_weight=2.0,
|
||||
tpr_loss_weight=1.0, tpr_loss_tau=0.04):
|
||||
super(HiFiGan, self).__init__()
|
||||
self.generator = generator
|
||||
self.discriminator = discriminator
|
||||
self.mel_spec_transform = mel_spec_transform
|
||||
self.multi_mel_spectral_recon_loss_weight = multi_mel_spectral_recon_loss_weight
|
||||
self.feat_match_loss_weight = feat_match_loss_weight
|
||||
self.tpr_loss_weight = tpr_loss_weight
|
||||
self.tpr_loss_tau = tpr_loss_tau
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
if batch['turn'] == 'generator':
|
||||
return self.forward_generator(batch, device)
|
||||
else:
|
||||
return self.forward_discriminator(batch, device)
|
||||
|
||||
def forward_generator(self, batch, device):
|
||||
real_speech = batch['speech'].to(device)
|
||||
pitch_feat = batch['pitch_feat'].to(device)
|
||||
# 1. calculate generator outputs
|
||||
generated_speech, generated_f0 = self.generator(batch, device)
|
||||
# 2. calculate discriminator outputs
|
||||
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech)
|
||||
# 3. calculate generator losses, feature loss, mel loss, tpr losses [Optional]
|
||||
loss_gen, _ = generator_loss(y_d_gs)
|
||||
loss_fm = feature_loss(fmap_rs, fmap_gs)
|
||||
loss_mel = mel_loss(real_speech, generated_speech, self.mel_spec_transform)
|
||||
if self.tpr_loss_weight != 0:
|
||||
loss_tpr = tpr_loss(y_d_gs, y_d_rs, self.tpr_loss_tau)
|
||||
else:
|
||||
loss_tpr = torch.zeros(1).to(device)
|
||||
loss_f0 = F.l1_loss(generated_f0, pitch_feat)
|
||||
loss = loss_gen + self.feat_match_loss_weight * loss_fm + \
|
||||
self.multi_mel_spectral_recon_loss_weight * loss_mel + \
|
||||
self.tpr_loss_weight * loss_tpr + loss_f0
|
||||
return {'loss': loss, 'loss_gen': loss_gen, 'loss_fm': loss_fm, 'loss_mel': loss_mel, 'loss_tpr': loss_tpr, 'loss_f0': loss_f0}
|
||||
|
||||
def forward_discriminator(self, batch, device):
|
||||
real_speech = batch['speech'].to(device)
|
||||
# 1. calculate generator outputs
|
||||
with torch.no_grad():
|
||||
generated_speech, generated_f0 = self.generator(batch, device)
|
||||
# 2. calculate discriminator outputs
|
||||
y_d_rs, y_d_gs, fmap_rs, fmap_gs = self.discriminator(real_speech, generated_speech.detach())
|
||||
# 3. calculate discriminator losses, tpr losses [Optional]
|
||||
loss_disc, _, _ = discriminator_loss(y_d_rs, y_d_gs)
|
||||
if self.tpr_loss_weight != 0:
|
||||
loss_tpr = tpr_loss(y_d_rs, y_d_gs, self.tpr_loss_tau)
|
||||
else:
|
||||
loss_tpr = torch.zeros(1).to(device)
|
||||
loss = loss_disc + self.tpr_loss_weight * loss_tpr
|
||||
return {'loss': loss, 'loss_disc': loss_disc, 'loss_tpr': loss_tpr}
|
||||
767
models/CosyVoice/cosyvoice/llm/llm.py
Normal file
767
models/CosyVoice/cosyvoice/llm/llm.py
Normal file
@@ -0,0 +1,767 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li, Qihua, Shengqiang Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os, queue
|
||||
import random
|
||||
import time
|
||||
import threading
|
||||
from typing import Dict, Optional, Callable, List, Generator
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
from transformers import Qwen2ForCausalLM
|
||||
from torch.nn.utils.rnn import pad_sequence, unpad_sequence
|
||||
from cosyvoice.utils.common import IGNORE_ID
|
||||
from cosyvoice.transformer.label_smoothing_loss import LabelSmoothingLoss
|
||||
from cosyvoice.utils.common import th_accuracy
|
||||
from cosyvoice.utils.file_utils import logging
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
|
||||
|
||||
|
||||
class TransformerLM(torch.nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
text_encoder_input_size: int,
|
||||
llm_input_size: int,
|
||||
llm_output_size: int,
|
||||
text_token_size: int,
|
||||
speech_token_size: int,
|
||||
text_encoder: torch.nn.Module,
|
||||
llm: torch.nn.Module,
|
||||
sampling: Callable,
|
||||
length_normalized_loss: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
spk_embed_dim: int = 192,
|
||||
):
|
||||
super().__init__()
|
||||
self.llm_input_size = llm_input_size
|
||||
self.speech_token_size = speech_token_size
|
||||
# 1. build text token inputs related modules
|
||||
self.text_embedding = torch.nn.Embedding(text_token_size, text_encoder_input_size)
|
||||
self.text_encoder = text_encoder
|
||||
self.text_encoder_affine_layer = nn.Linear(
|
||||
self.text_encoder.output_size(),
|
||||
llm_input_size
|
||||
)
|
||||
|
||||
# 2. build speech token language model related modules
|
||||
self.sos = 0
|
||||
self.task_id = 1
|
||||
self.eos_token = self.speech_token_size
|
||||
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
||||
self.llm = llm
|
||||
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 1)
|
||||
self.criterion_ce = LabelSmoothingLoss(
|
||||
size=speech_token_size + 1,
|
||||
padding_idx=IGNORE_ID,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
|
||||
# 3. [Optional] build speech token related modules
|
||||
self.speech_embedding = torch.nn.Embedding(speech_token_size, llm_input_size)
|
||||
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, llm_input_size)
|
||||
|
||||
# 4. sampling method
|
||||
self.sampling = sampling
|
||||
|
||||
def encode(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_lengths: torch.Tensor,
|
||||
):
|
||||
encoder_out, encoder_mask = self.text_encoder(text, text_lengths, decoding_chunk_size=1, num_decoding_left_chunks=-1)
|
||||
encoder_out_lens = encoder_mask.squeeze(1).sum(1)
|
||||
encoder_out = self.text_encoder_affine_layer(encoder_out)
|
||||
return encoder_out, encoder_out_lens
|
||||
|
||||
def pad_unpad_sequence(self, sos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
|
||||
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
||||
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||
lm_input = [torch.concat([sos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
|
||||
for i in range(len(text_token))]
|
||||
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
||||
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
||||
return lm_input, lm_input_len
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
text: (B, L, D)
|
||||
text_lengths: (B,)
|
||||
audio: (B, T, N) or (B, T)
|
||||
audio_lengths: (B,)
|
||||
"""
|
||||
text_token = batch['text_token'].to(device)
|
||||
text_token_len = batch['text_token_len'].to(device)
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
embedding = batch['embedding'].to(device)
|
||||
|
||||
# 1. prepare llm_target
|
||||
lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
|
||||
[self.speech_token_size]) for i in range(text_token.size(0))]
|
||||
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
|
||||
|
||||
# 1. encode text_token
|
||||
text_token = self.text_embedding(text_token)
|
||||
text_token, text_token_len = self.encode(text_token, text_token_len)
|
||||
|
||||
# 2. embedding projection
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
embedding = embedding.unsqueeze(1)
|
||||
|
||||
# 3. sos and task_id
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
# 4. encode speech_token
|
||||
speech_token = self.speech_embedding(speech_token)
|
||||
|
||||
# 5. unpad and pad
|
||||
lm_input, lm_input_len = self.pad_unpad_sequence(sos_emb, embedding, text_token, text_token_len,
|
||||
task_id_emb, speech_token, speech_token_len)
|
||||
|
||||
# 6. run lm forward
|
||||
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||
logits = self.llm_decoder(lm_output)
|
||||
loss = self.criterion_ce(logits, lm_target)
|
||||
acc = th_accuracy(logits.view(-1, self.speech_token_size + 1), lm_target, ignore_label=IGNORE_ID)
|
||||
return {'loss': loss, 'acc': acc}
|
||||
|
||||
def sampling_ids(
|
||||
self,
|
||||
weighted_scores: torch.Tensor,
|
||||
decoded_tokens: List,
|
||||
sampling: int,
|
||||
ignore_eos: bool = True,
|
||||
):
|
||||
num_trials, max_trials = 0, 100
|
||||
while True:
|
||||
top_ids = self.sampling(weighted_scores, decoded_tokens, sampling)
|
||||
if (not ignore_eos) or (top_ids < self.speech_token_size):
|
||||
break
|
||||
num_trials += 1
|
||||
if num_trials > max_trials:
|
||||
raise RuntimeError('sampling reaches max_trials {} and still get eos when ignore_eos is True, check your input!'.format(max_trials))
|
||||
return top_ids
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_len: torch.Tensor,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = '',
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
text_len += prompt_text_len
|
||||
text = self.text_embedding(text)
|
||||
|
||||
# 1. encode text
|
||||
text, text_len = self.encode(text, text_len)
|
||||
|
||||
# 2. encode embedding
|
||||
if embedding.shape[0] != 0:
|
||||
embedding = F.normalize(embedding, dim=1)
|
||||
embedding = self.spk_embed_affine_layer(embedding)
|
||||
embedding = embedding.unsqueeze(dim=1)
|
||||
else:
|
||||
embedding = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device).to(text.dtype)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_emb, embedding, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
|
||||
# 4. cal min/max_length
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||
|
||||
# 5. step by step decode
|
||||
out_tokens = []
|
||||
offset = 0
|
||||
att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
|
||||
for i in range(max_len):
|
||||
y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=offset, required_cache_size=-1,
|
||||
att_cache=att_cache, cnn_cache=cnn_cache,
|
||||
att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
|
||||
device=lm_input.device)).to(torch.bool))
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
|
||||
if top_ids == self.eos_token:
|
||||
break
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
offset += lm_input.size(1)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
|
||||
class Qwen2Encoder(torch.nn.Module):
|
||||
def __init__(self, pretrain_path):
|
||||
super().__init__()
|
||||
attn_impl = os.getenv('COSYVOICE_ATTN_IMPL', 'eager')
|
||||
try:
|
||||
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path, attn_implementation=attn_impl)
|
||||
logging.info(f'Qwen2ForCausalLM loaded with attn_implementation={attn_impl}')
|
||||
except TypeError:
|
||||
# transformers 旧版本无 attn_implementation 参数
|
||||
self.model = Qwen2ForCausalLM.from_pretrained(pretrain_path)
|
||||
logging.info('Qwen2ForCausalLM loaded without attn_implementation override')
|
||||
|
||||
def forward(self, xs: torch.Tensor, xs_lens: torch.Tensor):
|
||||
T = xs.size(1)
|
||||
masks = ~make_pad_mask(xs_lens, T)
|
||||
outs = self.model(
|
||||
inputs_embeds=xs,
|
||||
attention_mask=masks,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
)
|
||||
return outs.hidden_states[-1], masks.unsqueeze(1)
|
||||
|
||||
def forward_one_step(self, xs, masks, cache=None):
|
||||
input_masks = masks[:, -1, :]
|
||||
outs = self.model(
|
||||
inputs_embeds=xs,
|
||||
attention_mask=input_masks,
|
||||
output_hidden_states=True,
|
||||
return_dict=True,
|
||||
use_cache=True,
|
||||
past_key_values=cache,
|
||||
)
|
||||
xs = outs.hidden_states[-1]
|
||||
new_cache = outs.past_key_values
|
||||
return xs, new_cache
|
||||
|
||||
|
||||
class Qwen2LM(TransformerLM):
|
||||
def __init__(
|
||||
self,
|
||||
llm_input_size: int,
|
||||
llm_output_size: int,
|
||||
speech_token_size: int,
|
||||
llm: torch.nn.Module,
|
||||
sampling: Callable,
|
||||
length_normalized_loss: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
mix_ratio: List[int] = [5, 15],
|
||||
):
|
||||
torch.nn.Module.__init__(self)
|
||||
self.llm_input_size = llm_input_size
|
||||
self.llm_output_size = llm_output_size
|
||||
self.speech_token_size = speech_token_size
|
||||
# 2. build speech token language model related modules
|
||||
self.sos = 0
|
||||
self.task_id = 1
|
||||
self.eos_token = speech_token_size
|
||||
self.fill_token = speech_token_size + 2
|
||||
|
||||
self.llm_embedding = torch.nn.Embedding(2, llm_input_size)
|
||||
self.llm = llm
|
||||
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 3)
|
||||
self.criterion_ce = LabelSmoothingLoss(
|
||||
size=speech_token_size + 3,
|
||||
padding_idx=IGNORE_ID,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
|
||||
# 3. [Optional] build speech token related modules
|
||||
self.speech_embedding = torch.nn.Embedding(speech_token_size + 3, llm_input_size)
|
||||
|
||||
# 4. sampling method
|
||||
self.sampling = sampling
|
||||
self.mix_ratio = mix_ratio
|
||||
|
||||
# 5. vllm related
|
||||
self.stop_token_ids = [speech_token_size + i for i in range(3)]
|
||||
self.vllm_output_queue = {}
|
||||
if online_feature is True:
|
||||
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
|
||||
|
||||
def prepare_lm_input_target(self, sos_emb, text_token, text_token_emb, text_token_len, task_id_emb, speech_token, speech_token_emb, speech_token_len, instruct_token=None, instruct_token_emb=None, instruct_token_len=None):
|
||||
lm_target, lm_input = [], []
|
||||
text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
|
||||
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||
text_token_emb = unpad_sequence(text_token_emb, text_token_len.cpu(), batch_first=True)
|
||||
speech_token_emb = unpad_sequence(speech_token_emb, speech_token_len.cpu(), batch_first=True)
|
||||
# NOTE add instruct_token in CosyVoice3
|
||||
if instruct_token is not None and instruct_token_emb is not None and instruct_token_len is not None:
|
||||
instruct_token = unpad_sequence(instruct_token, instruct_token_len.cpu(), batch_first=True)
|
||||
instruct_token_emb = unpad_sequence(instruct_token_emb, instruct_token_len.cpu(), batch_first=True)
|
||||
else:
|
||||
instruct_token = [torch.empty(0).to(text_token[0])] * len(text_token)
|
||||
instruct_token_emb = [torch.empty(0, 896).to(text_token_emb[0])] * len(text_token)
|
||||
instruct_token_len = torch.zeros(len(text_token)).to(text_token_len)
|
||||
for i in range(len(text_token)):
|
||||
# bistream sequence
|
||||
if random.random() < 0.5 and speech_token_len[i] / text_token_len[i] > self.mix_ratio[1] / self.mix_ratio[0]:
|
||||
this_lm_target, this_lm_input = [IGNORE_ID], [sos_emb.squeeze(dim=0)]
|
||||
this_lm_target += [IGNORE_ID] * instruct_token_len[i]
|
||||
this_lm_input.append(instruct_token_emb[i])
|
||||
for j in range(((text_token_len[i] + 1) / self.mix_ratio[0]).ceil().int().item()):
|
||||
this_text_token = text_token[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]].tolist()
|
||||
this_speech_token = speech_token[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]].tolist()
|
||||
if len(this_text_token) == self.mix_ratio[0]:
|
||||
assert len(this_speech_token) == self.mix_ratio[1]
|
||||
this_lm_target += [IGNORE_ID] * (self.mix_ratio[0] - 1)
|
||||
this_lm_target += this_speech_token
|
||||
this_lm_target.append(self.fill_token)
|
||||
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]: (j + 1) * self.mix_ratio[0]])
|
||||
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]: (j + 1) * self.mix_ratio[1]])
|
||||
else:
|
||||
this_lm_target += [-1] * len(this_text_token)
|
||||
this_lm_target += speech_token[i][j * self.mix_ratio[1]:].tolist()
|
||||
this_lm_target.append(self.eos_token)
|
||||
this_lm_input.append(text_token_emb[i][j * self.mix_ratio[0]:])
|
||||
this_lm_input.append(task_id_emb.squeeze(dim=0))
|
||||
this_lm_input.append(speech_token_emb[i][j * self.mix_ratio[1]:])
|
||||
this_lm_target, this_lm_input = torch.tensor(this_lm_target), torch.concat(this_lm_input, dim=0)
|
||||
# unistream sequence
|
||||
else:
|
||||
this_lm_target = torch.tensor([IGNORE_ID] * (1 + instruct_token_len[i] + text_token_len[i]) + speech_token[i].tolist() + [self.eos_token])
|
||||
this_lm_input = torch.concat([sos_emb.squeeze(dim=0), instruct_token_emb[i], text_token_emb[i], task_id_emb.squeeze(dim=0), speech_token_emb[i]], dim=0)
|
||||
lm_target.append(this_lm_target)
|
||||
lm_input.append(this_lm_input)
|
||||
lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
|
||||
lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
|
||||
lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID)
|
||||
return lm_target, lm_input, lm_input_len
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
text: (B, L, D)
|
||||
text_lengths: (B,)
|
||||
audio: (B, T, N) or (B, T)
|
||||
audio_lengths: (B,)
|
||||
"""
|
||||
text_token = batch['text_token'].to(device)
|
||||
text_token_len = batch['text_token_len'].to(device)
|
||||
if 'speech_token' not in batch:
|
||||
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||
else:
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
|
||||
# 1. encode text_token
|
||||
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||
|
||||
# 3. sos and task_id
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
# 2. encode speech_token
|
||||
speech_token_emb = self.speech_embedding(speech_token)
|
||||
|
||||
# 3. prepare llm_input/target
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
|
||||
speech_token, speech_token_emb, speech_token_len)
|
||||
lm_target = lm_target.to(device)
|
||||
|
||||
# 4. run lm forward
|
||||
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||
logits = self.llm_decoder(lm_output)
|
||||
loss = self.criterion_ce(logits, lm_target.to(device))
|
||||
acc = th_accuracy(logits.view(-1, self.speech_token_size + 3), lm_target, ignore_label=IGNORE_ID)
|
||||
return {'loss': loss, 'acc': acc}
|
||||
|
||||
def forward_dpo(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
text_token = batch['text_token'].to(device)
|
||||
text_token_len = batch['text_token_len'].to(device)
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
reject_speech_token = batch['reject_speech_token'].to(device)
|
||||
reject_speech_token_len = batch['reject_speech_token_len'].to(device)
|
||||
|
||||
# 1. encode text_token
|
||||
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||
|
||||
# 3. sos and task_id
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
# 2. encode speech_token
|
||||
speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
|
||||
reject_speech_token = unpad_sequence(reject_speech_token, reject_speech_token_len.cpu(), batch_first=True)
|
||||
speech_token_combined = speech_token + reject_speech_token
|
||||
speech_token_combined = pad_sequence(speech_token_combined, batch_first=True, padding_value=0)
|
||||
speech_token_combined_len = torch.concat([speech_token_len, reject_speech_token_len], dim=0)
|
||||
speech_token_combined_emb = self.speech_embedding(speech_token_combined)
|
||||
|
||||
# 3. prepare llm_input/target
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token.repeat(2, 1), text_token_emb.repeat(2, 1, 1), text_token_len.repeat(2),
|
||||
task_id_emb, speech_token_combined, speech_token_combined_emb, speech_token_combined_len)
|
||||
lm_target = lm_target.to(device)
|
||||
|
||||
# 4. run lm forward
|
||||
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||
logits = self.llm_decoder(lm_output)
|
||||
chosen_logits = logits[:text_token.shape[0]]
|
||||
rejected_logits = logits[text_token.shape[0]:]
|
||||
chosen_lm_target = lm_target[:text_token.shape[0]]
|
||||
rejected_lm_target = lm_target[text_token.shape[0]:]
|
||||
loss = self.criterion_ce(chosen_logits, chosen_lm_target.to(device))
|
||||
acc = th_accuracy(chosen_logits.view(-1, self.speech_token_size + 3), chosen_lm_target, ignore_label=IGNORE_ID)
|
||||
|
||||
# 5. calculate dpo logits
|
||||
chosen_lm_mask = chosen_lm_target == IGNORE_ID
|
||||
rejected_lm_mask = rejected_lm_target == IGNORE_ID
|
||||
chosen_logps = torch.gather(chosen_logits.log_softmax(dim=-1), dim=2, index=chosen_lm_target.masked_fill(chosen_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||
rejected_logps = torch.gather(rejected_logits.log_softmax(dim=-1), dim=2, index=rejected_lm_target.masked_fill(rejected_lm_mask, 0).unsqueeze(dim=-1)).squeeze(dim=-1)
|
||||
chosen_logps = (chosen_logps * chosen_lm_mask).sum(dim=-1) / chosen_lm_mask.sum(dim=-1)
|
||||
rejected_logps = (rejected_logps * rejected_lm_mask).sum(dim=-1) / rejected_lm_mask.sum(dim=-1)
|
||||
return {'loss': loss, 'acc': acc, 'chosen_logps': chosen_logps, 'rejected_logps': rejected_logps}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_len: torch.Tensor,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = '',
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
text_len += prompt_text_len
|
||||
text = self.llm.model.model.embed_tokens(text)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
|
||||
# 4. cal min/max_length
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||
|
||||
# 5. step by step decode
|
||||
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
|
||||
yield token
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_wrapper(self, lm_input, sampling, min_len, max_len, uuid):
|
||||
if hasattr(self, 'vllm'):
|
||||
from vllm import SamplingParams, RequestOutput
|
||||
sampling_params = SamplingParams(top_k=sampling,
|
||||
stop_token_ids=self.stop_token_ids,
|
||||
min_tokens=min_len,
|
||||
max_tokens=max_len)
|
||||
with self.lock:
|
||||
self.vllm.add_request(uuid, {"prompt_embeds": lm_input.squeeze(0).to(torch.bfloat16).to(lm_input.device)}, sampling_params)
|
||||
self.vllm_output_queue[uuid] = queue.Queue()
|
||||
out_tokens = []
|
||||
while True:
|
||||
with self.lock:
|
||||
if self.vllm_output_queue[uuid].empty() is True:
|
||||
request_outputs: List[RequestOutput] = self.vllm.step()
|
||||
for request_output in request_outputs:
|
||||
top_ids = list(request_output.outputs[0].token_ids)[-1]
|
||||
self.vllm_output_queue[request_output.request_id].put(top_ids)
|
||||
if self.vllm_output_queue[uuid].empty() is False:
|
||||
top_ids = self.vllm_output_queue[uuid].get()
|
||||
if top_ids in self.stop_token_ids:
|
||||
break
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
if len(out_tokens) == max_len:
|
||||
break
|
||||
time.sleep(0.001)
|
||||
with self.lock:
|
||||
self.vllm_output_queue.pop(uuid)
|
||||
else:
|
||||
out_tokens = []
|
||||
cache = None
|
||||
for i in range(max_len):
|
||||
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||
masks=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False)
|
||||
if top_ids in self.stop_token_ids:
|
||||
break
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
out_tokens.append(top_ids)
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference_bistream(
|
||||
self,
|
||||
text: Generator,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
|
||||
device = prompt_text.device
|
||||
# 1. prepare input
|
||||
sos_emb = self.llm_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.llm_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=prompt_text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_emb], dim=1)
|
||||
|
||||
# 2. iterate text
|
||||
out_tokens = []
|
||||
cache = None
|
||||
# NOTE init prompt_text as text_cache as it is basically impossible prompt_speech_token/prompt_text < 15/5
|
||||
text_cache = self.llm.model.model.embed_tokens(prompt_text)
|
||||
next_fill_index = (int(prompt_speech_token.shape[1] / self.mix_ratio[1]) + 1) * self.mix_ratio[1] - prompt_speech_token.shape[1]
|
||||
for this_text in text:
|
||||
text_cache = torch.concat([text_cache, self.llm.model.model.embed_tokens(this_text)], dim=1)
|
||||
# prompt_speech_token_emb not empty, try append to lm_input
|
||||
while prompt_speech_token_emb.size(1) != 0:
|
||||
if text_cache.size(1) >= self.mix_ratio[0]:
|
||||
lm_input_text, lm_input_speech = text_cache[:, :self.mix_ratio[0]], prompt_speech_token_emb[:, :self.mix_ratio[1]]
|
||||
logging.info('append {} text token {} speech token'.format(lm_input_text.size(1), lm_input_speech.size(1)))
|
||||
lm_input = torch.concat([lm_input, lm_input_text, lm_input_speech], dim=1)
|
||||
text_cache, prompt_speech_token_emb = text_cache[:, self.mix_ratio[0]:], prompt_speech_token_emb[:, self.mix_ratio[1]:]
|
||||
else:
|
||||
logging.info('not enough text token to decode, wait for more')
|
||||
break
|
||||
# no prompt_speech_token_emb remain, can decode some speech token
|
||||
if prompt_speech_token_emb.size(1) == 0:
|
||||
if (len(out_tokens) != 0 and out_tokens[-1] == self.fill_token) or (len(out_tokens) == 0 and lm_input.size(1) == 1):
|
||||
logging.info('get fill token, need to append more text token')
|
||||
if text_cache.size(1) >= self.mix_ratio[0]:
|
||||
lm_input_text = text_cache[:, :self.mix_ratio[0]]
|
||||
logging.info('append {} text token'.format(lm_input_text.size(1)))
|
||||
if len(out_tokens) != 0 and out_tokens[-1] == self.fill_token:
|
||||
lm_input = lm_input_text
|
||||
else:
|
||||
lm_input = torch.concat([lm_input, lm_input_text], dim=1)
|
||||
text_cache = text_cache[:, self.mix_ratio[0]:]
|
||||
else:
|
||||
logging.info('not enough text token to decode, wait for more')
|
||||
continue
|
||||
while True:
|
||||
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
|
||||
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
if next_fill_index != -1 and len(out_tokens) == next_fill_index:
|
||||
top_ids = self.fill_token
|
||||
next_fill_index += (self.mix_ratio[1] + 1)
|
||||
else:
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True)
|
||||
if top_ids == self.fill_token:
|
||||
next_fill_index = len(out_tokens) + self.mix_ratio[1] + 1
|
||||
logging.info('fill_token index {} next fill_token index {}'.format(len(out_tokens), next_fill_index))
|
||||
out_tokens.append(top_ids)
|
||||
if top_ids >= self.speech_token_size:
|
||||
if top_ids == self.fill_token:
|
||||
break
|
||||
else:
|
||||
raise ValueError('should not get token {}'.format(top_ids))
|
||||
yield top_ids
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
# 3. final decode
|
||||
lm_input = torch.concat([lm_input, text_cache, task_id_emb], dim=1)
|
||||
logging.info('no more text token, decode until met eos')
|
||||
while True:
|
||||
seq_len = lm_input.shape[1] if cache is None else lm_input.shape[1] + cache[0][0].size(2)
|
||||
y_pred, cache = self.llm.forward_one_step(lm_input,
|
||||
masks=torch.tril(torch.ones((1, seq_len, seq_len), device=lm_input.device)).to(torch.bool),
|
||||
cache=cache)
|
||||
logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
|
||||
top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=False)
|
||||
out_tokens.append(top_ids)
|
||||
if top_ids >= self.speech_token_size:
|
||||
if top_ids == self.eos_token:
|
||||
break
|
||||
else:
|
||||
raise ValueError('should not get token {}'.format(top_ids))
|
||||
# in stream mode, yield token one by one
|
||||
yield top_ids
|
||||
lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)
|
||||
|
||||
|
||||
class CosyVoice3LM(Qwen2LM):
|
||||
def __init__(
|
||||
self,
|
||||
llm_input_size: int,
|
||||
llm_output_size: int,
|
||||
speech_token_size: int,
|
||||
llm: torch.nn.Module,
|
||||
sampling: Callable,
|
||||
length_normalized_loss: bool = True,
|
||||
lsm_weight: float = 0.0,
|
||||
mix_ratio: List[int] = [5, 15],
|
||||
):
|
||||
torch.nn.Module.__init__(self)
|
||||
self.llm_input_size = llm_input_size
|
||||
self.llm_output_size = llm_output_size
|
||||
self.speech_token_size = speech_token_size
|
||||
# 2. build speech token language model related modules
|
||||
self.sos = speech_token_size + 0
|
||||
self.eos_token = speech_token_size + 1
|
||||
self.task_id = speech_token_size + 2
|
||||
self.fill_token = speech_token_size + 3
|
||||
|
||||
self.llm = llm
|
||||
self.llm_decoder = nn.Linear(llm_output_size, speech_token_size + 200, bias=False)
|
||||
self.criterion_ce = LabelSmoothingLoss(
|
||||
size=speech_token_size + 200,
|
||||
padding_idx=IGNORE_ID,
|
||||
smoothing=lsm_weight,
|
||||
normalize_length=length_normalized_loss,
|
||||
)
|
||||
|
||||
# 3. [Optional] build speech token related modules
|
||||
self.speech_embedding = torch.nn.Embedding(speech_token_size + 200, llm_input_size)
|
||||
|
||||
# 4. sampling method
|
||||
self.sampling = sampling
|
||||
self.mix_ratio = mix_ratio
|
||||
|
||||
# 5. vllm related
|
||||
self.stop_token_ids = [speech_token_size + i for i in range(200)]
|
||||
self.vllm_output_queue = {}
|
||||
if online_feature is True:
|
||||
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict,
|
||||
device: torch.device,
|
||||
) -> Dict[str, Optional[torch.Tensor]]:
|
||||
"""
|
||||
Args:
|
||||
text: (B, L, D)
|
||||
text_lengths: (B,)
|
||||
audio: (B, T, N) or (B, T)
|
||||
audio_lengths: (B,)
|
||||
"""
|
||||
text_token = batch['text_token'].to(device)
|
||||
text_token_len = batch['text_token_len'].to(device)
|
||||
if 'speech_token' not in batch:
|
||||
speech_token, speech_token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
|
||||
else:
|
||||
speech_token = batch['speech_token'].to(device)
|
||||
speech_token_len = batch['speech_token_len'].to(device)
|
||||
|
||||
# NOTE should append instruct_token to sequence, not implemented yet
|
||||
instruct_token = batch['instruct_token'].to(device)
|
||||
instruct_token_len = batch['instruct_token_len'].to(device)
|
||||
|
||||
# 1. encode text_token
|
||||
text_token_emb = self.llm.model.model.embed_tokens(text_token)
|
||||
instruct_token_emb = self.llm.model.model.embed_tokens(instruct_token)
|
||||
|
||||
# 3. sos and task_id
|
||||
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
|
||||
# 2. encode speech_token
|
||||
speech_token_emb = self.speech_embedding(speech_token)
|
||||
|
||||
# 3. prepare llm_input/target
|
||||
lm_target, lm_input, lm_input_len = self.prepare_lm_input_target(sos_emb, text_token, text_token_emb, text_token_len, task_id_emb,
|
||||
speech_token, speech_token_emb, speech_token_len, instruct_token, instruct_token_emb, instruct_token_len)
|
||||
lm_target = lm_target.to(device)
|
||||
|
||||
# 4. run lm forward
|
||||
lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
|
||||
logits = self.llm_decoder(lm_output)
|
||||
loss = self.criterion_ce(logits, lm_target.to(device))
|
||||
acc = th_accuracy(logits.view(-1, self.speech_token_size + 200), lm_target, ignore_label=IGNORE_ID)
|
||||
return {'loss': loss, 'acc': acc}
|
||||
|
||||
@torch.inference_mode()
|
||||
def inference(
|
||||
self,
|
||||
text: torch.Tensor,
|
||||
text_len: torch.Tensor,
|
||||
prompt_text: torch.Tensor,
|
||||
prompt_text_len: torch.Tensor,
|
||||
prompt_speech_token: torch.Tensor,
|
||||
prompt_speech_token_len: torch.Tensor,
|
||||
embedding: torch.Tensor,
|
||||
sampling: int = 25,
|
||||
max_token_text_ratio: float = 20,
|
||||
min_token_text_ratio: float = 2,
|
||||
uuid: str = '',
|
||||
) -> Generator[torch.Tensor, None, None]:
|
||||
device = text.device
|
||||
text = torch.concat([prompt_text, text], dim=1)
|
||||
text_len += prompt_text_len
|
||||
text = self.llm.model.model.embed_tokens(text)
|
||||
|
||||
# 3. concat llm_input
|
||||
sos_emb = self.speech_embedding.weight[self.sos].reshape(1, 1, -1)
|
||||
task_id_emb = self.speech_embedding.weight[self.task_id].reshape(1, 1, -1)
|
||||
if prompt_speech_token_len != 0:
|
||||
prompt_speech_token_emb = self.speech_embedding(prompt_speech_token)
|
||||
else:
|
||||
prompt_speech_token_emb = torch.zeros(1, 0, self.llm_input_size, dtype=text.dtype).to(device)
|
||||
lm_input = torch.concat([sos_emb, text, task_id_emb, prompt_speech_token_emb], dim=1)
|
||||
|
||||
# 4. cal min/max_length
|
||||
min_len = int((text_len - prompt_text_len) * min_token_text_ratio)
|
||||
max_len = int((text_len - prompt_text_len) * max_token_text_ratio)
|
||||
|
||||
# 5. step by step decode
|
||||
for token in self.inference_wrapper(lm_input, sampling, min_len, max_len, uuid):
|
||||
yield token
|
||||
File diff suppressed because it is too large
Load Diff
327
models/CosyVoice/cosyvoice/tokenizer/tokenizer.py
Normal file
327
models/CosyVoice/cosyvoice/tokenizer/tokenizer.py
Normal file
@@ -0,0 +1,327 @@
|
||||
import base64
|
||||
import os
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
import torch
|
||||
from transformers import AutoTokenizer
|
||||
from whisper.tokenizer import Tokenizer
|
||||
|
||||
import tiktoken
|
||||
|
||||
LANGUAGES = {
|
||||
"en": "english",
|
||||
"zh": "chinese",
|
||||
"de": "german",
|
||||
"es": "spanish",
|
||||
"ru": "russian",
|
||||
"ko": "korean",
|
||||
"fr": "french",
|
||||
"ja": "japanese",
|
||||
"pt": "portuguese",
|
||||
"tr": "turkish",
|
||||
"pl": "polish",
|
||||
"ca": "catalan",
|
||||
"nl": "dutch",
|
||||
"ar": "arabic",
|
||||
"sv": "swedish",
|
||||
"it": "italian",
|
||||
"id": "indonesian",
|
||||
"hi": "hindi",
|
||||
"fi": "finnish",
|
||||
"vi": "vietnamese",
|
||||
"he": "hebrew",
|
||||
"uk": "ukrainian",
|
||||
"el": "greek",
|
||||
"ms": "malay",
|
||||
"cs": "czech",
|
||||
"ro": "romanian",
|
||||
"da": "danish",
|
||||
"hu": "hungarian",
|
||||
"ta": "tamil",
|
||||
"no": "norwegian",
|
||||
"th": "thai",
|
||||
"ur": "urdu",
|
||||
"hr": "croatian",
|
||||
"bg": "bulgarian",
|
||||
"lt": "lithuanian",
|
||||
"la": "latin",
|
||||
"mi": "maori",
|
||||
"ml": "malayalam",
|
||||
"cy": "welsh",
|
||||
"sk": "slovak",
|
||||
"te": "telugu",
|
||||
"fa": "persian",
|
||||
"lv": "latvian",
|
||||
"bn": "bengali",
|
||||
"sr": "serbian",
|
||||
"az": "azerbaijani",
|
||||
"sl": "slovenian",
|
||||
"kn": "kannada",
|
||||
"et": "estonian",
|
||||
"mk": "macedonian",
|
||||
"br": "breton",
|
||||
"eu": "basque",
|
||||
"is": "icelandic",
|
||||
"hy": "armenian",
|
||||
"ne": "nepali",
|
||||
"mn": "mongolian",
|
||||
"bs": "bosnian",
|
||||
"kk": "kazakh",
|
||||
"sq": "albanian",
|
||||
"sw": "swahili",
|
||||
"gl": "galician",
|
||||
"mr": "marathi",
|
||||
"pa": "punjabi",
|
||||
"si": "sinhala",
|
||||
"km": "khmer",
|
||||
"sn": "shona",
|
||||
"yo": "yoruba",
|
||||
"so": "somali",
|
||||
"af": "afrikaans",
|
||||
"oc": "occitan",
|
||||
"ka": "georgian",
|
||||
"be": "belarusian",
|
||||
"tg": "tajik",
|
||||
"sd": "sindhi",
|
||||
"gu": "gujarati",
|
||||
"am": "amharic",
|
||||
"yi": "yiddish",
|
||||
"lo": "lao",
|
||||
"uz": "uzbek",
|
||||
"fo": "faroese",
|
||||
"ht": "haitian creole",
|
||||
"ps": "pashto",
|
||||
"tk": "turkmen",
|
||||
"nn": "nynorsk",
|
||||
"mt": "maltese",
|
||||
"sa": "sanskrit",
|
||||
"lb": "luxembourgish",
|
||||
"my": "myanmar",
|
||||
"bo": "tibetan",
|
||||
"tl": "tagalog",
|
||||
"mg": "malagasy",
|
||||
"as": "assamese",
|
||||
"tt": "tatar",
|
||||
"haw": "hawaiian",
|
||||
"ln": "lingala",
|
||||
"ha": "hausa",
|
||||
"ba": "bashkir",
|
||||
"jw": "javanese",
|
||||
"su": "sundanese",
|
||||
"yue": "cantonese",
|
||||
"minnan": "minnan",
|
||||
"wuyu": "wuyu",
|
||||
"dialect": "dialect",
|
||||
"zh/en": "zh/en",
|
||||
"en/zh": "en/zh",
|
||||
}
|
||||
|
||||
# language code lookup by name, with a few language aliases
|
||||
TO_LANGUAGE_CODE = {
|
||||
**{language: code for code, language in LANGUAGES.items()},
|
||||
"burmese": "my",
|
||||
"valencian": "ca",
|
||||
"flemish": "nl",
|
||||
"haitian": "ht",
|
||||
"letzeburgesch": "lb",
|
||||
"pushto": "ps",
|
||||
"panjabi": "pa",
|
||||
"moldavian": "ro",
|
||||
"moldovan": "ro",
|
||||
"sinhalese": "si",
|
||||
"castilian": "es",
|
||||
"mandarin": "zh",
|
||||
}
|
||||
|
||||
AUDIO_EVENT = {
|
||||
"ASR": "ASR",
|
||||
"AED": "AED",
|
||||
"SER": "SER",
|
||||
"Speech": "Speech",
|
||||
"/Speech": "/Speech",
|
||||
"BGM": "BGM",
|
||||
"/BGM": "/BGM",
|
||||
"Laughter": "Laughter",
|
||||
"/Laughter": "/Laughter",
|
||||
"Applause": "Applause",
|
||||
"/Applause": "/Applause",
|
||||
}
|
||||
|
||||
EMOTION = {
|
||||
"HAPPY": "HAPPY",
|
||||
"SAD": "SAD",
|
||||
"ANGRY": "ANGRY",
|
||||
"NEUTRAL": "NEUTRAL",
|
||||
}
|
||||
|
||||
TTS_Vocal_Token = {
|
||||
"TTS/B": "TTS/B",
|
||||
"TTS/O": "TTS/O",
|
||||
"TTS/Q": "TTS/Q",
|
||||
"TTS/A": "TTS/A",
|
||||
"TTS/CO": "TTS/CO",
|
||||
"TTS/CL": "TTS/CL",
|
||||
"TTS/H": "TTS/H",
|
||||
**{f"TTS/SP{i:02d}": f"TTS/SP{i:02d}" for i in range(1, 14)}
|
||||
}
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_encoding(name: str = "gpt2", num_languages: int = 99):
|
||||
vocab_path = os.path.join(os.path.dirname(__file__), "assets", f"{name}.tiktoken")
|
||||
ranks = {
|
||||
base64.b64decode(token): int(rank)
|
||||
for token, rank in (line.split() for line in open(vocab_path) if line)
|
||||
}
|
||||
n_vocab = len(ranks)
|
||||
special_tokens = {}
|
||||
|
||||
specials = [
|
||||
"<|endoftext|>",
|
||||
"<|startoftranscript|>",
|
||||
*[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]],
|
||||
*[f"<|{audio_event}|>" for audio_event in list(AUDIO_EVENT.keys())],
|
||||
*[f"<|{emotion}|>" for emotion in list(EMOTION.keys())],
|
||||
"<|translate|>",
|
||||
"<|transcribe|>",
|
||||
"<|startoflm|>",
|
||||
"<|startofprev|>",
|
||||
"<|nospeech|>",
|
||||
"<|notimestamps|>",
|
||||
*[f"<|SPECIAL_TOKEN_{i}|>" for i in range(1, 31)], # register special tokens for ASR
|
||||
*[f"<|{tts}|>" for tts in list(TTS_Vocal_Token.keys())], # register special tokens for TTS
|
||||
*[f"<|{i * 0.02:.2f}|>" for i in range(1501)],
|
||||
]
|
||||
|
||||
for token in specials:
|
||||
special_tokens[token] = n_vocab
|
||||
n_vocab += 1
|
||||
|
||||
return tiktoken.Encoding(
|
||||
name=os.path.basename(vocab_path),
|
||||
explicit_n_vocab=n_vocab,
|
||||
pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""",
|
||||
mergeable_ranks=ranks,
|
||||
special_tokens=special_tokens,
|
||||
)
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_tokenizer(
|
||||
multilingual: bool,
|
||||
*,
|
||||
num_languages: int = 99,
|
||||
language: Optional[str] = None,
|
||||
task: Optional[str] = None, # Literal["transcribe", "translate", None]
|
||||
) -> Tokenizer:
|
||||
if language is not None:
|
||||
language = language.lower()
|
||||
if language not in LANGUAGES:
|
||||
if language in TO_LANGUAGE_CODE:
|
||||
language = TO_LANGUAGE_CODE[language]
|
||||
else:
|
||||
raise ValueError(f"Unsupported language: {language}")
|
||||
|
||||
if multilingual:
|
||||
encoding_name = "multilingual_zh_ja_yue_char_del"
|
||||
language = language or "en"
|
||||
task = task or "transcribe"
|
||||
else:
|
||||
encoding_name = "gpt2"
|
||||
language = None
|
||||
task = None
|
||||
|
||||
encoding = get_encoding(name=encoding_name, num_languages=num_languages)
|
||||
|
||||
return Tokenizer(
|
||||
encoding=encoding, num_languages=num_languages, language=language, task=task
|
||||
)
|
||||
|
||||
|
||||
class CosyVoice2Tokenizer():
|
||||
def __init__(self, token_path, skip_special_tokens=True):
|
||||
super().__init__()
|
||||
# NOTE: non-chat model, all these special tokens keep randomly initialized.
|
||||
special_tokens = {
|
||||
'eos_token': '<|endoftext|>',
|
||||
'pad_token': '<|endoftext|>',
|
||||
'additional_special_tokens': [
|
||||
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||
'[quick_breath]',
|
||||
"<laughter>", "</laughter>",
|
||||
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||
"[lipsmack]", "[mn]"
|
||||
]
|
||||
}
|
||||
self.special_tokens = special_tokens
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
||||
self.tokenizer.add_special_tokens(special_tokens)
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
|
||||
def encode(self, text, **kwargs):
|
||||
tokens = self.tokenizer([text], return_tensors="pt")
|
||||
tokens = tokens["input_ids"][0].cpu().tolist()
|
||||
return tokens
|
||||
|
||||
def decode(self, tokens):
|
||||
tokens = torch.tensor(tokens, dtype=torch.int64)
|
||||
text = self.tokenizer.batch_decode([tokens], skip_special_tokens=self.skip_special_tokens)[0]
|
||||
return text
|
||||
|
||||
|
||||
class CosyVoice3Tokenizer(CosyVoice2Tokenizer):
|
||||
def __init__(self, token_path, skip_special_tokens=True):
|
||||
# NOTE: non-chat model, all these special tokens keep randomly initialized.
|
||||
special_tokens = {
|
||||
'eos_token': '<|endoftext|>',
|
||||
'pad_token': '<|endoftext|>',
|
||||
'additional_special_tokens': [
|
||||
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||
'[quick_breath]',
|
||||
"<laughter>", "</laughter>",
|
||||
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||
"[lipsmack]", "[mn]", "<|endofsystem|>",
|
||||
"[AA]", "[AA0]", "[AA1]", "[AA2]", "[AE]", "[AE0]", "[AE1]", "[AE2]", "[AH]", "[AH0]", "[AH1]", "[AH2]",
|
||||
"[AO]", "[AO0]", "[AO1]", "[AO2]", "[AW]", "[AW0]", "[AW1]", "[AW2]", "[AY]", "[AY0]", "[AY1]", "[AY2]",
|
||||
"[B]", "[CH]", "[D]", "[DH]", "[EH]", "[EH0]", "[EH1]", "[EH2]", "[ER]", "[ER0]", "[ER1]", "[ER2]", "[EY]",
|
||||
"[EY0]", "[EY1]", "[EY2]", "[F]", "[G]", "[HH]", "[IH]", "[IH0]", "[IH1]", "[IH2]", "[IY]", "[IY0]", "[IY1]",
|
||||
"[IY2]", "[JH]", "[K]", "[L]", "[M]", "[N]", "[NG]", "[OW]", "[OW0]", "[OW1]", "[OW2]", "[OY]", "[OY0]",
|
||||
"[OY1]", "[OY2]", "[P]", "[R]", "[S]", "[SH]", "[T]", "[TH]", "[UH]", "[UH0]", "[UH1]", "[UH2]", "[UW]",
|
||||
"[UW0]", "[UW1]", "[UW2]", "[V]", "[W]", "[Y]", "[Z]", "[ZH]",
|
||||
"[a]", "[ai]", "[an]", "[ang]", "[ao]", "[b]", "[c]", "[ch]", "[d]", "[e]", "[ei]", "[en]", "[eng]", "[f]",
|
||||
"[g]", "[h]", "[i]", "[ian]", "[in]", "[ing]", "[iu]", "[ià]", "[iàn]", "[iàng]", "[iào]", "[iá]", "[ián]",
|
||||
"[iáng]", "[iáo]", "[iè]", "[ié]", "[iòng]", "[ióng]", "[iù]", "[iú]", "[iā]", "[iān]", "[iāng]", "[iāo]",
|
||||
"[iē]", "[iě]", "[iōng]", "[iū]", "[iǎ]", "[iǎn]", "[iǎng]", "[iǎo]", "[iǒng]", "[iǔ]", "[j]", "[k]", "[l]",
|
||||
"[m]", "[n]", "[o]", "[ong]", "[ou]", "[p]", "[q]", "[r]", "[s]", "[sh]", "[t]", "[u]", "[uang]", "[ue]",
|
||||
"[un]", "[uo]", "[uà]", "[uài]", "[uàn]", "[uàng]", "[uá]", "[uái]", "[uán]", "[uáng]", "[uè]", "[ué]", "[uì]",
|
||||
"[uí]", "[uò]", "[uó]", "[uā]", "[uāi]", "[uān]", "[uāng]", "[uē]", "[uě]", "[uī]", "[uō]", "[uǎ]", "[uǎi]",
|
||||
"[uǎn]", "[uǎng]", "[uǐ]", "[uǒ]", "[vè]", "[w]", "[x]", "[y]", "[z]", "[zh]", "[à]", "[ài]", "[àn]", "[àng]",
|
||||
"[ào]", "[á]", "[ái]", "[án]", "[áng]", "[áo]", "[è]", "[èi]", "[èn]", "[èng]", "[èr]", "[é]", "[éi]", "[én]",
|
||||
"[éng]", "[ér]", "[ì]", "[ìn]", "[ìng]", "[í]", "[ín]", "[íng]", "[ò]", "[òng]", "[òu]", "[ó]", "[óng]", "[óu]",
|
||||
"[ù]", "[ùn]", "[ú]", "[ún]", "[ā]", "[āi]", "[ān]", "[āng]", "[āo]", "[ē]", "[ēi]", "[ēn]", "[ēng]", "[ě]",
|
||||
"[ěi]", "[ěn]", "[ěng]", "[ěr]", "[ī]", "[īn]", "[īng]", "[ō]", "[ōng]", "[ōu]", "[ū]", "[ūn]", "[ǎ]", "[ǎi]",
|
||||
"[ǎn]", "[ǎng]", "[ǎo]", "[ǐ]", "[ǐn]", "[ǐng]", "[ǒ]", "[ǒng]", "[ǒu]", "[ǔ]", "[ǔn]", "[ǘ]", "[ǚ]", "[ǜ]"
|
||||
]
|
||||
}
|
||||
self.special_tokens = special_tokens
|
||||
self.tokenizer = AutoTokenizer.from_pretrained(token_path)
|
||||
self.tokenizer.add_special_tokens(special_tokens)
|
||||
self.skip_special_tokens = skip_special_tokens
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def get_qwen_tokenizer(
|
||||
token_path: str,
|
||||
skip_special_tokens: bool,
|
||||
version: str = 'cosyvoice2'
|
||||
):
|
||||
if version == 'cosyvoice2':
|
||||
return CosyVoice2Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
||||
elif version == 'cosyvoice3':
|
||||
return CosyVoice3Tokenizer(token_path=token_path, skip_special_tokens=skip_special_tokens)
|
||||
else:
|
||||
raise ValueError
|
||||
0
models/CosyVoice/cosyvoice/transformer/__init__.py
Normal file
0
models/CosyVoice/cosyvoice/transformer/__init__.py
Normal file
84
models/CosyVoice/cosyvoice/transformer/activation.py
Normal file
84
models/CosyVoice/cosyvoice/transformer/activation.py
Normal file
@@ -0,0 +1,84 @@
|
||||
# Copyright (c) 2020 Johns Hopkins University (Shinji Watanabe)
|
||||
# 2020 Northwestern Polytechnical University (Pengcheng Guo)
|
||||
# 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Swish() activation function for Conformer."""
|
||||
|
||||
import torch
|
||||
from torch import nn, sin, pow
|
||||
from torch.nn import Parameter
|
||||
|
||||
|
||||
class Swish(torch.nn.Module):
|
||||
"""Construct an Swish object."""
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Return Swish activation function."""
|
||||
return x * torch.sigmoid(x)
|
||||
|
||||
|
||||
# Implementation adapted from https://github.com/EdwardDixon/snake under the MIT license.
|
||||
# LICENSE is in incl_licenses directory.
|
||||
class Snake(nn.Module):
|
||||
'''
|
||||
Implementation of a sine-based periodic activation function
|
||||
Shape:
|
||||
- Input: (B, C, T)
|
||||
- Output: (B, C, T), same shape as the input
|
||||
Parameters:
|
||||
- alpha - trainable parameter
|
||||
References:
|
||||
- This activation function is from this paper by Liu Ziyin, Tilman Hartwig, Masahito Ueda:
|
||||
https://arxiv.org/abs/2006.08195
|
||||
Examples:
|
||||
>>> a1 = snake(256)
|
||||
>>> x = torch.randn(256)
|
||||
>>> x = a1(x)
|
||||
'''
|
||||
def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=False):
|
||||
'''
|
||||
Initialization.
|
||||
INPUT:
|
||||
- in_features: shape of the input
|
||||
- alpha: trainable parameter
|
||||
alpha is initialized to 1 by default, higher values = higher-frequency.
|
||||
alpha will be trained along with the rest of your model.
|
||||
'''
|
||||
super(Snake, self).__init__()
|
||||
self.in_features = in_features
|
||||
|
||||
# initialize alpha
|
||||
self.alpha_logscale = alpha_logscale
|
||||
if self.alpha_logscale: # log scale alphas initialized to zeros
|
||||
self.alpha = Parameter(torch.zeros(in_features) * alpha)
|
||||
else: # linear scale alphas initialized to ones
|
||||
self.alpha = Parameter(torch.ones(in_features) * alpha)
|
||||
|
||||
self.alpha.requires_grad = alpha_trainable
|
||||
|
||||
self.no_div_by_zero = 0.000000001
|
||||
|
||||
def forward(self, x):
|
||||
'''
|
||||
Forward pass of the function.
|
||||
Applies the function to the input elementwise.
|
||||
Snake ∶= x + 1/a * sin^2 (xa)
|
||||
'''
|
||||
alpha = self.alpha.unsqueeze(0).unsqueeze(-1) # line up with x to [B, C, T]
|
||||
if self.alpha_logscale:
|
||||
alpha = torch.exp(alpha)
|
||||
x = x + (1.0 / (alpha + self.no_div_by_zero)) * pow(sin(x * alpha), 2)
|
||||
|
||||
return x
|
||||
330
models/CosyVoice/cosyvoice/transformer/attention.py
Normal file
330
models/CosyVoice/cosyvoice/transformer/attention.py
Normal file
@@ -0,0 +1,330 @@
|
||||
# Copyright (c) 2019 Shigeki Karita
|
||||
# 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Multi-Head Attention layer definition."""
|
||||
|
||||
import math
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class MultiHeadedAttention(nn.Module):
|
||||
"""Multi-Head Attention layer.
|
||||
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_head: int,
|
||||
n_feat: int,
|
||||
dropout_rate: float,
|
||||
key_bias: bool = True):
|
||||
"""Construct an MultiHeadedAttention object."""
|
||||
super().__init__()
|
||||
assert n_feat % n_head == 0
|
||||
# We assume d_v always equals d_k
|
||||
self.d_k = n_feat // n_head
|
||||
self.h = n_head
|
||||
self.linear_q = nn.Linear(n_feat, n_feat)
|
||||
self.linear_k = nn.Linear(n_feat, n_feat, bias=key_bias)
|
||||
self.linear_v = nn.Linear(n_feat, n_feat)
|
||||
self.linear_out = nn.Linear(n_feat, n_feat)
|
||||
self.dropout = nn.Dropout(p=dropout_rate)
|
||||
|
||||
def forward_qkv(
|
||||
self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Transform query, key and value.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed query tensor, size
|
||||
(#batch, n_head, time1, d_k).
|
||||
torch.Tensor: Transformed key tensor, size
|
||||
(#batch, n_head, time2, d_k).
|
||||
torch.Tensor: Transformed value tensor, size
|
||||
(#batch, n_head, time2, d_k).
|
||||
|
||||
"""
|
||||
n_batch = query.size(0)
|
||||
q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
|
||||
k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
|
||||
v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
|
||||
q = q.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
k = k.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
v = v.transpose(1, 2) # (batch, head, time2, d_k)
|
||||
|
||||
return q, k, v
|
||||
|
||||
def forward_attention(
|
||||
self,
|
||||
value: torch.Tensor,
|
||||
scores: torch.Tensor,
|
||||
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool)
|
||||
) -> torch.Tensor:
|
||||
"""Compute attention context vector.
|
||||
|
||||
Args:
|
||||
value (torch.Tensor): Transformed value, size
|
||||
(#batch, n_head, time2, d_k).
|
||||
scores (torch.Tensor): Attention score, size
|
||||
(#batch, n_head, time1, time2).
|
||||
mask (torch.Tensor): Mask, size (#batch, 1, time2) or
|
||||
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Transformed value (#batch, time1, d_model)
|
||||
weighted by the attention score (#batch, time1, time2).
|
||||
|
||||
"""
|
||||
n_batch = value.size(0)
|
||||
# NOTE(xcsong): When will `if mask.size(2) > 0` be True?
|
||||
# 1. onnx(16/4) [WHY? Because we feed real cache & real mask for the
|
||||
# 1st chunk to ease the onnx export.]
|
||||
# 2. pytorch training
|
||||
if mask.size(2) > 0: # time2 > 0
|
||||
mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
|
||||
# For last chunk, time2 might be larger than scores.size(-1)
|
||||
mask = mask[:, :, :, :scores.size(-1)] # (batch, 1, *, time2)
|
||||
scores = scores.masked_fill(mask, -float('inf'))
|
||||
attn = torch.softmax(scores, dim=-1).masked_fill(
|
||||
mask, 0.0) # (batch, head, time1, time2)
|
||||
# NOTE(xcsong): When will `if mask.size(2) > 0` be False?
|
||||
# 1. onnx(16/-1, -1/-1, 16/0)
|
||||
# 2. jit (16/-1, -1/-1, 16/0, 16/4)
|
||||
else:
|
||||
attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
|
||||
|
||||
p_attn = self.dropout(attn)
|
||||
x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
|
||||
x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
|
||||
self.h * self.d_k)
|
||||
) # (batch, time1, d_model)
|
||||
|
||||
return self.linear_out(x) # (batch, time1, d_model)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
||||
pos_emb: torch.Tensor = torch.empty(0),
|
||||
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute scaled dot product attention.
|
||||
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2).
|
||||
1.When applying cross attention between decoder and encoder,
|
||||
the batch padding mask for input is in (#batch, 1, T) shape.
|
||||
2.When applying self attention of encoder,
|
||||
the mask is in (#batch, T, T) shape.
|
||||
3.When applying self attention of decoder,
|
||||
the mask is in (#batch, L, L) shape.
|
||||
4.If the different position in decoder see different block
|
||||
of the encoder, such as Mocha, the passed in mask could be
|
||||
in (#batch, L, T) shape. But there is no such case in current
|
||||
CosyVoice.
|
||||
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
||||
where `cache_t == chunk_size * num_decoding_left_chunks`
|
||||
and `head * d_k == size`
|
||||
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
||||
where `cache_t == chunk_size * num_decoding_left_chunks`
|
||||
and `head * d_k == size`
|
||||
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
|
||||
# NOTE(xcsong):
|
||||
# when export onnx model, for 1st chunk, we feed
|
||||
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
||||
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
||||
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
||||
# and we will always do splitting and
|
||||
# concatnation(this will simplify onnx export). Note that
|
||||
# it's OK to concat & split zero-shaped tensors(see code below).
|
||||
# when export jit model, for 1st chunk, we always feed
|
||||
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
||||
# >>> a = torch.ones((1, 2, 0, 4))
|
||||
# >>> b = torch.ones((1, 2, 3, 4))
|
||||
# >>> c = torch.cat((a, b), dim=2)
|
||||
# >>> torch.equal(b, c) # True
|
||||
# >>> d = torch.split(a, 2, dim=-1)
|
||||
# >>> torch.equal(d[0], d[1]) # True
|
||||
if cache.size(0) > 0:
|
||||
key_cache, value_cache = torch.split(cache,
|
||||
cache.size(-1) // 2,
|
||||
dim=-1)
|
||||
k = torch.cat([key_cache, k], dim=2)
|
||||
v = torch.cat([value_cache, v], dim=2)
|
||||
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
||||
# non-trivial to calculate `next_cache_start` here.
|
||||
new_cache = torch.cat((k, v), dim=-1)
|
||||
|
||||
scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
|
||||
return self.forward_attention(v, scores, mask), new_cache
|
||||
|
||||
|
||||
class RelPositionMultiHeadedAttention(MultiHeadedAttention):
|
||||
"""Multi-Head Attention layer with relative position encoding.
|
||||
Paper: https://arxiv.org/abs/1901.02860
|
||||
Args:
|
||||
n_head (int): The number of heads.
|
||||
n_feat (int): The number of features.
|
||||
dropout_rate (float): Dropout rate.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
n_head: int,
|
||||
n_feat: int,
|
||||
dropout_rate: float,
|
||||
key_bias: bool = True):
|
||||
"""Construct an RelPositionMultiHeadedAttention object."""
|
||||
super().__init__(n_head, n_feat, dropout_rate, key_bias)
|
||||
# linear transformation for positional encoding
|
||||
self.linear_pos = nn.Linear(n_feat, n_feat, bias=False)
|
||||
# these two learnable bias are used in matrix c and matrix d
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
self.pos_bias_u = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
self.pos_bias_v = nn.Parameter(torch.Tensor(self.h, self.d_k))
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_u)
|
||||
torch.nn.init.xavier_uniform_(self.pos_bias_v)
|
||||
|
||||
def rel_shift(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute relative positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, head, time1, 2*time1-1).
|
||||
time1 means the length of query vector.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor.
|
||||
|
||||
"""
|
||||
zero_pad = torch.zeros((x.size()[0], x.size()[1], x.size()[2], 1),
|
||||
device=x.device,
|
||||
dtype=x.dtype)
|
||||
x_padded = torch.cat([zero_pad, x], dim=-1)
|
||||
|
||||
x_padded = x_padded.view(x.size()[0],
|
||||
x.size()[1],
|
||||
x.size(3) + 1, x.size(2))
|
||||
x = x_padded[:, :, 1:].view_as(x)[
|
||||
:, :, :, : x.size(-1) // 2 + 1
|
||||
] # only keep the positions from 0 to time2
|
||||
return x
|
||||
|
||||
def forward(
|
||||
self,
|
||||
query: torch.Tensor,
|
||||
key: torch.Tensor,
|
||||
value: torch.Tensor,
|
||||
mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
||||
pos_emb: torch.Tensor = torch.empty(0),
|
||||
cache: torch.Tensor = torch.zeros((0, 0, 0, 0))
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute 'Scaled Dot Product Attention' with rel. positional encoding.
|
||||
Args:
|
||||
query (torch.Tensor): Query tensor (#batch, time1, size).
|
||||
key (torch.Tensor): Key tensor (#batch, time2, size).
|
||||
value (torch.Tensor): Value tensor (#batch, time2, size).
|
||||
mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
|
||||
(#batch, time1, time2), (0, 0, 0) means fake mask.
|
||||
pos_emb (torch.Tensor): Positional embedding tensor
|
||||
(#batch, time2, size).
|
||||
cache (torch.Tensor): Cache tensor (1, head, cache_t, d_k * 2),
|
||||
where `cache_t == chunk_size * num_decoding_left_chunks`
|
||||
and `head * d_k == size`
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time1, d_model).
|
||||
torch.Tensor: Cache tensor (1, head, cache_t + time1, d_k * 2)
|
||||
where `cache_t == chunk_size * num_decoding_left_chunks`
|
||||
and `head * d_k == size`
|
||||
"""
|
||||
q, k, v = self.forward_qkv(query, key, value)
|
||||
q = q.transpose(1, 2) # (batch, time1, head, d_k)
|
||||
|
||||
# NOTE(xcsong):
|
||||
# when export onnx model, for 1st chunk, we feed
|
||||
# cache(1, head, 0, d_k * 2) (16/-1, -1/-1, 16/0 mode)
|
||||
# or cache(1, head, real_cache_t, d_k * 2) (16/4 mode).
|
||||
# In all modes, `if cache.size(0) > 0` will alwayse be `True`
|
||||
# and we will always do splitting and
|
||||
# concatnation(this will simplify onnx export). Note that
|
||||
# it's OK to concat & split zero-shaped tensors(see code below).
|
||||
# when export jit model, for 1st chunk, we always feed
|
||||
# cache(0, 0, 0, 0) since jit supports dynamic if-branch.
|
||||
# >>> a = torch.ones((1, 2, 0, 4))
|
||||
# >>> b = torch.ones((1, 2, 3, 4))
|
||||
# >>> c = torch.cat((a, b), dim=2)
|
||||
# >>> torch.equal(b, c) # True
|
||||
# >>> d = torch.split(a, 2, dim=-1)
|
||||
# >>> torch.equal(d[0], d[1]) # True
|
||||
if cache.size(0) > 0:
|
||||
key_cache, value_cache = torch.split(cache,
|
||||
cache.size(-1) // 2,
|
||||
dim=-1)
|
||||
k = torch.cat([key_cache, k], dim=2)
|
||||
v = torch.cat([value_cache, v], dim=2)
|
||||
# NOTE(xcsong): We do cache slicing in encoder.forward_chunk, since it's
|
||||
# non-trivial to calculate `next_cache_start` here.
|
||||
new_cache = torch.cat((k, v), dim=-1)
|
||||
|
||||
n_batch_pos = pos_emb.size(0)
|
||||
p = self.linear_pos(pos_emb).view(n_batch_pos, -1, self.h, self.d_k)
|
||||
p = p.transpose(1, 2) # (batch, head, time1, d_k)
|
||||
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_u = (q + self.pos_bias_u).transpose(1, 2)
|
||||
# (batch, head, time1, d_k)
|
||||
q_with_bias_v = (q + self.pos_bias_v).transpose(1, 2)
|
||||
|
||||
# compute attention score
|
||||
# first compute matrix a and matrix c
|
||||
# as described in https://arxiv.org/abs/1901.02860 Section 3.3
|
||||
# (batch, head, time1, time2)
|
||||
matrix_ac = torch.matmul(q_with_bias_u, k.transpose(-2, -1))
|
||||
|
||||
# compute matrix b and matrix d
|
||||
# (batch, head, time1, time2)
|
||||
matrix_bd = torch.matmul(q_with_bias_v, p.transpose(-2, -1))
|
||||
# NOTE(Xiang Lyu): Keep rel_shift since espnet rel_pos_emb is used
|
||||
if matrix_ac.shape != matrix_bd.shape:
|
||||
matrix_bd = self.rel_shift(matrix_bd)
|
||||
|
||||
scores = (matrix_ac + matrix_bd) / math.sqrt(
|
||||
self.d_k) # (batch, head, time1, time2)
|
||||
|
||||
return self.forward_attention(v, scores, mask), new_cache
|
||||
258
models/CosyVoice/cosyvoice/transformer/convolution.py
Normal file
258
models/CosyVoice/cosyvoice/transformer/convolution.py
Normal file
@@ -0,0 +1,258 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""ConvolutionModule definition."""
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class ConvolutionModule(nn.Module):
|
||||
"""ConvolutionModule in Conformer model."""
|
||||
|
||||
def __init__(self,
|
||||
channels: int,
|
||||
kernel_size: int = 15,
|
||||
activation: nn.Module = nn.ReLU(),
|
||||
norm: str = "batch_norm",
|
||||
causal: bool = False,
|
||||
bias: bool = True):
|
||||
"""Construct an ConvolutionModule object.
|
||||
Args:
|
||||
channels (int): The number of channels of conv layers.
|
||||
kernel_size (int): Kernel size of conv layers.
|
||||
causal (int): Whether use causal convolution or not
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.pointwise_conv1 = nn.Conv1d(
|
||||
channels,
|
||||
2 * channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
# self.lorder is used to distinguish if it's a causal convolution,
|
||||
# if self.lorder > 0: it's a causal convolution, the input will be
|
||||
# padded with self.lorder frames on the left in forward.
|
||||
# else: it's a symmetrical convolution
|
||||
if causal:
|
||||
padding = 0
|
||||
self.lorder = kernel_size - 1
|
||||
else:
|
||||
# kernel_size should be an odd number for none causal convolution
|
||||
assert (kernel_size - 1) % 2 == 0
|
||||
padding = (kernel_size - 1) // 2
|
||||
self.lorder = 0
|
||||
self.depthwise_conv = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size,
|
||||
stride=1,
|
||||
padding=padding,
|
||||
groups=channels,
|
||||
bias=bias,
|
||||
)
|
||||
|
||||
assert norm in ['batch_norm', 'layer_norm']
|
||||
if norm == "batch_norm":
|
||||
self.use_layer_norm = False
|
||||
self.norm = nn.BatchNorm1d(channels)
|
||||
else:
|
||||
self.use_layer_norm = True
|
||||
self.norm = nn.LayerNorm(channels)
|
||||
|
||||
self.pointwise_conv2 = nn.Conv1d(
|
||||
channels,
|
||||
channels,
|
||||
kernel_size=1,
|
||||
stride=1,
|
||||
padding=0,
|
||||
bias=bias,
|
||||
)
|
||||
self.activation = activation
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
||||
cache: torch.Tensor = torch.zeros((0, 0, 0)),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute convolution module.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, channels).
|
||||
mask_pad (torch.Tensor): used for batch padding (#batch, 1, time),
|
||||
(0, 0, 0) means fake mask.
|
||||
cache (torch.Tensor): left context cache, it is only
|
||||
used in causal convolution (#batch, channels, cache_t),
|
||||
(0, 0, 0) meas fake cache.
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, channels).
|
||||
"""
|
||||
# exchange the temporal dimension and the feature dimension
|
||||
x = x.transpose(1, 2) # (#batch, channels, time)
|
||||
|
||||
# mask batch padding
|
||||
if mask_pad.size(2) > 0: # time > 0
|
||||
x.masked_fill_(~mask_pad, 0.0)
|
||||
|
||||
if self.lorder > 0:
|
||||
if cache.size(2) == 0: # cache_t == 0
|
||||
x = nn.functional.pad(x, (self.lorder, 0), 'constant', 0.0)
|
||||
else:
|
||||
assert cache.size(0) == x.size(0) # equal batch
|
||||
assert cache.size(1) == x.size(1) # equal channel
|
||||
x = torch.cat((cache, x), dim=2)
|
||||
assert (x.size(2) > self.lorder)
|
||||
new_cache = x[:, :, -self.lorder:]
|
||||
else:
|
||||
# It's better we just return None if no cache is required,
|
||||
# However, for JIT export, here we just fake one tensor instead of
|
||||
# None.
|
||||
new_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
||||
|
||||
# GLU mechanism
|
||||
x = self.pointwise_conv1(x) # (batch, 2*channel, dim)
|
||||
x = nn.functional.glu(x, dim=1) # (batch, channel, dim)
|
||||
|
||||
# 1D Depthwise Conv
|
||||
x = self.depthwise_conv(x)
|
||||
if self.use_layer_norm:
|
||||
x = x.transpose(1, 2)
|
||||
x = self.activation(self.norm(x))
|
||||
if self.use_layer_norm:
|
||||
x = x.transpose(1, 2)
|
||||
x = self.pointwise_conv2(x)
|
||||
# mask batch padding
|
||||
if mask_pad.size(2) > 0: # time > 0
|
||||
x.masked_fill_(~mask_pad, 0.0)
|
||||
|
||||
return x.transpose(1, 2), new_cache
|
||||
|
||||
|
||||
# NOTE(Xiang Lyu) causal conv module used in convolution-based vocoder
|
||||
class CausalConv1d(torch.nn.Conv1d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
causal_type: str = 'left',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super(CausalConv1d, self).__init__(in_channels, out_channels,
|
||||
kernel_size, stride=1,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype)
|
||||
assert stride == 1
|
||||
self.causal_padding = int((kernel_size * dilation - dilation) / 2) * 2 + (kernel_size + 1) % 2
|
||||
assert causal_type in ['left', 'right']
|
||||
self.causal_type = causal_type
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor]:
|
||||
input_timestep = x.shape[2]
|
||||
if cache.size(2) == 0:
|
||||
cache = torch.zeros(x.shape[0], x.shape[1], self.causal_padding).to(x)
|
||||
assert cache.size(2) == self.causal_padding
|
||||
if self.causal_type == 'left':
|
||||
x = torch.concat([cache, x], dim=2)
|
||||
else:
|
||||
x = torch.concat([x, cache], dim=2)
|
||||
x = super(CausalConv1d, self).forward(x)
|
||||
assert x.shape[2] == input_timestep
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv1dDownSample(torch.nn.Conv1d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super(CausalConv1dDownSample, self).__init__(in_channels, out_channels,
|
||||
kernel_size, stride,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype)
|
||||
assert stride != 1 and dilation == 1
|
||||
assert kernel_size % stride == 0
|
||||
self.causal_padding = stride - 1
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
if cache.size(2) == 0:
|
||||
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||
else:
|
||||
assert cache.size(2) == self.causal_padding
|
||||
x = torch.concat([cache, x], dim=2)
|
||||
x = super(CausalConv1dDownSample, self).forward(x)
|
||||
return x
|
||||
|
||||
|
||||
class CausalConv1dUpsample(torch.nn.Conv1d):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
kernel_size: int,
|
||||
stride: int = 1,
|
||||
dilation: int = 1,
|
||||
groups: int = 1,
|
||||
bias: bool = True,
|
||||
padding_mode: str = 'zeros',
|
||||
device=None,
|
||||
dtype=None
|
||||
) -> None:
|
||||
super(CausalConv1dUpsample, self).__init__(in_channels, out_channels,
|
||||
kernel_size, 1,
|
||||
padding=0, dilation=dilation,
|
||||
groups=groups, bias=bias,
|
||||
padding_mode=padding_mode,
|
||||
device=device, dtype=dtype)
|
||||
assert dilation == 1
|
||||
self.causal_padding = kernel_size - 1
|
||||
self.upsample = torch.nn.Upsample(scale_factor=stride, mode='nearest')
|
||||
|
||||
def forward(self, x: torch.Tensor, cache: torch.Tensor = torch.zeros(0, 0, 0)) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
x = self.upsample(x)
|
||||
input_timestep = x.shape[2]
|
||||
if cache.size(2) == 0:
|
||||
x = F.pad(x, (self.causal_padding, 0), value=0.0)
|
||||
else:
|
||||
assert cache.size(2) == self.causal_padding
|
||||
x = torch.concat([cache, x], dim=2)
|
||||
x = super(CausalConv1dUpsample, self).forward(x)
|
||||
assert input_timestep == x.shape[2]
|
||||
return x
|
||||
396
models/CosyVoice/cosyvoice/transformer/decoder.py
Normal file
396
models/CosyVoice/cosyvoice/transformer/decoder.py
Normal file
@@ -0,0 +1,396 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Decoder definition."""
|
||||
from typing import Tuple, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as ckpt
|
||||
import logging
|
||||
|
||||
from cosyvoice.transformer.decoder_layer import DecoderLayer
|
||||
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
||||
from cosyvoice.utils.class_utils import (
|
||||
COSYVOICE_EMB_CLASSES,
|
||||
COSYVOICE_ATTENTION_CLASSES,
|
||||
COSYVOICE_ACTIVATION_CLASSES,
|
||||
)
|
||||
from cosyvoice.utils.mask import (subsequent_mask, make_pad_mask)
|
||||
|
||||
|
||||
class TransformerDecoder(torch.nn.Module):
|
||||
"""Base class of Transfomer decoder module.
|
||||
Args:
|
||||
vocab_size: output dim
|
||||
encoder_output_size: dimension of attention
|
||||
attention_heads: the number of heads of multi head attention
|
||||
linear_units: the hidden units number of position-wise feedforward
|
||||
num_blocks: the number of decoder blocks
|
||||
dropout_rate: dropout rate
|
||||
self_attention_dropout_rate: dropout rate for attention
|
||||
input_layer: input layer type
|
||||
use_output_layer: whether to use output layer
|
||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
||||
normalize_before:
|
||||
True: use layer_norm before each sub-block of a layer.
|
||||
False: use layer_norm after each sub-block of a layer.
|
||||
src_attention: if false, encoder-decoder cross attention is not
|
||||
applied, such as CIF model
|
||||
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
||||
gradient_checkpointing: rerunning a forward-pass segment for each
|
||||
checkpointed segment during backward.
|
||||
tie_word_embedding: Tie or clone module weights depending of whether we are
|
||||
using TorchScript or not
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
normalize_before: bool = True,
|
||||
src_attention: bool = True,
|
||||
key_bias: bool = True,
|
||||
activation_type: str = "relu",
|
||||
gradient_checkpointing: bool = False,
|
||||
tie_word_embedding: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
attention_dim = encoder_output_size
|
||||
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
||||
|
||||
self.embed = torch.nn.Sequential(
|
||||
torch.nn.Identity() if input_layer == "no_pos" else
|
||||
torch.nn.Embedding(vocab_size, attention_dim),
|
||||
COSYVOICE_EMB_CLASSES[input_layer](attention_dim,
|
||||
positional_dropout_rate),
|
||||
)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
self.after_norm = torch.nn.LayerNorm(attention_dim, eps=1e-5)
|
||||
self.use_output_layer = use_output_layer
|
||||
if use_output_layer:
|
||||
self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
|
||||
else:
|
||||
self.output_layer = torch.nn.Identity()
|
||||
self.num_blocks = num_blocks
|
||||
self.decoders = torch.nn.ModuleList([
|
||||
DecoderLayer(
|
||||
attention_dim,
|
||||
COSYVOICE_ATTENTION_CLASSES["selfattn"](
|
||||
attention_heads, attention_dim,
|
||||
self_attention_dropout_rate, key_bias),
|
||||
COSYVOICE_ATTENTION_CLASSES["selfattn"](
|
||||
attention_heads, attention_dim, src_attention_dropout_rate,
|
||||
key_bias) if src_attention else None,
|
||||
PositionwiseFeedForward(attention_dim, linear_units,
|
||||
dropout_rate, activation),
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
) for _ in range(self.num_blocks)
|
||||
])
|
||||
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
self.tie_word_embedding = tie_word_embedding
|
||||
|
||||
def forward(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
memory_mask: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
r_ys_in_pad: torch.Tensor = torch.empty(0),
|
||||
reverse_weight: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
Args:
|
||||
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
|
||||
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
|
||||
ys_in_lens: input lengths of this batch (batch)
|
||||
r_ys_in_pad: not used in transformer decoder, in order to unify api
|
||||
with bidirectional decoder
|
||||
reverse_weight: not used in transformer decoder, in order to unify
|
||||
api with bidirectional decode
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
x: decoded token score before softmax (batch, maxlen_out,
|
||||
vocab_size) if use_output_layer is True,
|
||||
torch.tensor(0.0), in order to unify api with bidirectional decoder
|
||||
olens: (batch, )
|
||||
NOTE(xcsong):
|
||||
We pass the `__call__` method of the modules instead of `forward` to the
|
||||
checkpointing API because `__call__` attaches all the hooks of the module.
|
||||
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
||||
"""
|
||||
tgt = ys_in_pad
|
||||
maxlen = tgt.size(1)
|
||||
# tgt_mask: (B, 1, L)
|
||||
tgt_mask = ~make_pad_mask(ys_in_lens, maxlen).unsqueeze(1)
|
||||
tgt_mask = tgt_mask.to(tgt.device)
|
||||
# m: (1, L, L)
|
||||
m = subsequent_mask(tgt_mask.size(-1),
|
||||
device=tgt_mask.device).unsqueeze(0)
|
||||
# tgt_mask: (B, L, L)
|
||||
tgt_mask = tgt_mask & m
|
||||
x, _ = self.embed(tgt)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
x = self.forward_layers_checkpointed(x, tgt_mask, memory,
|
||||
memory_mask)
|
||||
else:
|
||||
x = self.forward_layers(x, tgt_mask, memory, memory_mask)
|
||||
if self.normalize_before:
|
||||
x = self.after_norm(x)
|
||||
if self.use_output_layer:
|
||||
x = self.output_layer(x)
|
||||
olens = tgt_mask.sum(1)
|
||||
return x, torch.tensor(0.0), olens
|
||||
|
||||
def forward_layers(self, x: torch.Tensor, tgt_mask: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
memory_mask: torch.Tensor) -> torch.Tensor:
|
||||
for layer in self.decoders:
|
||||
x, tgt_mask, memory, memory_mask = layer(x, tgt_mask, memory,
|
||||
memory_mask)
|
||||
return x
|
||||
|
||||
@torch.jit.unused
|
||||
def forward_layers_checkpointed(self, x: torch.Tensor,
|
||||
tgt_mask: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
memory_mask: torch.Tensor) -> torch.Tensor:
|
||||
for layer in self.decoders:
|
||||
x, tgt_mask, memory, memory_mask = ckpt.checkpoint(
|
||||
layer.__call__, x, tgt_mask, memory, memory_mask)
|
||||
return x
|
||||
|
||||
def forward_one_step(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
memory_mask: torch.Tensor,
|
||||
tgt: torch.Tensor,
|
||||
tgt_mask: torch.Tensor,
|
||||
cache: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Forward one step.
|
||||
This is only used for decoding.
|
||||
Args:
|
||||
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
|
||||
tgt: input token ids, int64 (batch, maxlen_out)
|
||||
tgt_mask: input token mask, (batch, maxlen_out)
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
|
||||
cache: cached output list of (batch, max_time_out-1, size)
|
||||
Returns:
|
||||
y, cache: NN output value and cache per `self.decoders`.
|
||||
y.shape` is (batch, maxlen_out, token)
|
||||
"""
|
||||
x, _ = self.embed(tgt)
|
||||
new_cache = []
|
||||
for i, decoder in enumerate(self.decoders):
|
||||
if cache is None:
|
||||
c = None
|
||||
else:
|
||||
c = cache[i]
|
||||
x, tgt_mask, memory, memory_mask = decoder(x,
|
||||
tgt_mask,
|
||||
memory,
|
||||
memory_mask,
|
||||
cache=c)
|
||||
new_cache.append(x)
|
||||
if self.normalize_before:
|
||||
y = self.after_norm(x[:, -1])
|
||||
else:
|
||||
y = x[:, -1]
|
||||
if self.use_output_layer:
|
||||
y = torch.log_softmax(self.output_layer(y), dim=-1)
|
||||
return y, new_cache
|
||||
|
||||
def tie_or_clone_weights(self, jit_mode: bool = True):
|
||||
"""Tie or clone module weights (between word_emb and output_layer)
|
||||
depending of whether we are using TorchScript or not"""
|
||||
if not self.use_output_layer:
|
||||
return
|
||||
if jit_mode:
|
||||
logging.info("clone emb.weight to output.weight")
|
||||
self.output_layer.weight = torch.nn.Parameter(
|
||||
self.embed[0].weight.clone())
|
||||
else:
|
||||
logging.info("tie emb.weight with output.weight")
|
||||
self.output_layer.weight = self.embed[0].weight
|
||||
|
||||
if getattr(self.output_layer, "bias", None) is not None:
|
||||
self.output_layer.bias.data = torch.nn.functional.pad(
|
||||
self.output_layer.bias.data,
|
||||
(
|
||||
0,
|
||||
self.output_layer.weight.shape[0] -
|
||||
self.output_layer.bias.shape[0],
|
||||
),
|
||||
"constant",
|
||||
0,
|
||||
)
|
||||
|
||||
|
||||
class BiTransformerDecoder(torch.nn.Module):
|
||||
"""Base class of Transfomer decoder module.
|
||||
Args:
|
||||
vocab_size: output dim
|
||||
encoder_output_size: dimension of attention
|
||||
attention_heads: the number of heads of multi head attention
|
||||
linear_units: the hidden units number of position-wise feedforward
|
||||
num_blocks: the number of decoder blocks
|
||||
r_num_blocks: the number of right to left decoder blocks
|
||||
dropout_rate: dropout rate
|
||||
self_attention_dropout_rate: dropout rate for attention
|
||||
input_layer: input layer type
|
||||
use_output_layer: whether to use output layer
|
||||
pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
|
||||
normalize_before:
|
||||
True: use layer_norm before each sub-block of a layer.
|
||||
False: use layer_norm after each sub-block of a layer.
|
||||
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size: int,
|
||||
encoder_output_size: int,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
r_num_blocks: int = 0,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
self_attention_dropout_rate: float = 0.0,
|
||||
src_attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "embed",
|
||||
use_output_layer: bool = True,
|
||||
normalize_before: bool = True,
|
||||
key_bias: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
tie_word_embedding: bool = False,
|
||||
):
|
||||
|
||||
super().__init__()
|
||||
self.tie_word_embedding = tie_word_embedding
|
||||
self.left_decoder = TransformerDecoder(
|
||||
vocab_size,
|
||||
encoder_output_size,
|
||||
attention_heads,
|
||||
linear_units,
|
||||
num_blocks,
|
||||
dropout_rate,
|
||||
positional_dropout_rate,
|
||||
self_attention_dropout_rate,
|
||||
src_attention_dropout_rate,
|
||||
input_layer,
|
||||
use_output_layer,
|
||||
normalize_before,
|
||||
key_bias=key_bias,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
tie_word_embedding=tie_word_embedding)
|
||||
|
||||
self.right_decoder = TransformerDecoder(
|
||||
vocab_size,
|
||||
encoder_output_size,
|
||||
attention_heads,
|
||||
linear_units,
|
||||
r_num_blocks,
|
||||
dropout_rate,
|
||||
positional_dropout_rate,
|
||||
self_attention_dropout_rate,
|
||||
src_attention_dropout_rate,
|
||||
input_layer,
|
||||
use_output_layer,
|
||||
normalize_before,
|
||||
key_bias=key_bias,
|
||||
gradient_checkpointing=gradient_checkpointing,
|
||||
tie_word_embedding=tie_word_embedding)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
memory_mask: torch.Tensor,
|
||||
ys_in_pad: torch.Tensor,
|
||||
ys_in_lens: torch.Tensor,
|
||||
r_ys_in_pad: torch.Tensor,
|
||||
reverse_weight: float = 0.0,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Forward decoder.
|
||||
Args:
|
||||
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
memory_mask: encoder memory mask, (batch, 1, maxlen_in)
|
||||
ys_in_pad: padded input token ids, int64 (batch, maxlen_out)
|
||||
ys_in_lens: input lengths of this batch (batch)
|
||||
r_ys_in_pad: padded input token ids, int64 (batch, maxlen_out),
|
||||
used for right to left decoder
|
||||
reverse_weight: used for right to left decoder
|
||||
Returns:
|
||||
(tuple): tuple containing:
|
||||
x: decoded token score before softmax (batch, maxlen_out,
|
||||
vocab_size) if use_output_layer is True,
|
||||
r_x: x: decoded token score (right to left decoder)
|
||||
before softmax (batch, maxlen_out, vocab_size)
|
||||
if use_output_layer is True,
|
||||
olens: (batch, )
|
||||
"""
|
||||
l_x, _, olens = self.left_decoder(memory, memory_mask, ys_in_pad,
|
||||
ys_in_lens)
|
||||
r_x = torch.tensor(0.0)
|
||||
if reverse_weight > 0.0:
|
||||
r_x, _, olens = self.right_decoder(memory, memory_mask,
|
||||
r_ys_in_pad, ys_in_lens)
|
||||
return l_x, r_x, olens
|
||||
|
||||
def forward_one_step(
|
||||
self,
|
||||
memory: torch.Tensor,
|
||||
memory_mask: torch.Tensor,
|
||||
tgt: torch.Tensor,
|
||||
tgt_mask: torch.Tensor,
|
||||
cache: Optional[List[torch.Tensor]] = None,
|
||||
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
|
||||
"""Forward one step.
|
||||
This is only used for decoding.
|
||||
Args:
|
||||
memory: encoded memory, float32 (batch, maxlen_in, feat)
|
||||
memory_mask: encoded memory mask, (batch, 1, maxlen_in)
|
||||
tgt: input token ids, int64 (batch, maxlen_out)
|
||||
tgt_mask: input token mask, (batch, maxlen_out)
|
||||
dtype=torch.uint8 in PyTorch 1.2-
|
||||
dtype=torch.bool in PyTorch 1.2+ (include 1.2)
|
||||
cache: cached output list of (batch, max_time_out-1, size)
|
||||
Returns:
|
||||
y, cache: NN output value and cache per `self.decoders`.
|
||||
y.shape` is (batch, maxlen_out, token)
|
||||
"""
|
||||
return self.left_decoder.forward_one_step(memory, memory_mask, tgt,
|
||||
tgt_mask, cache)
|
||||
|
||||
def tie_or_clone_weights(self, jit_mode: bool = True):
|
||||
"""Tie or clone module weights (between word_emb and output_layer)
|
||||
depending of whether we are using TorchScript or not"""
|
||||
self.left_decoder.tie_or_clone_weights(jit_mode)
|
||||
self.right_decoder.tie_or_clone_weights(jit_mode)
|
||||
132
models/CosyVoice/cosyvoice/transformer/decoder_layer.py
Normal file
132
models/CosyVoice/cosyvoice/transformer/decoder_layer.py
Normal file
@@ -0,0 +1,132 @@
|
||||
# Copyright (c) 2019 Shigeki Karita
|
||||
# 2020 Mobvoi Inc (Binbin Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Decoder self-attention layer definition."""
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class DecoderLayer(nn.Module):
|
||||
"""Single decoder layer module.
|
||||
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` instance can be used as the argument.
|
||||
src_attn (torch.nn.Module): Inter-attention module instance.
|
||||
`MultiHeadedAttention` instance can be used as the argument.
|
||||
If `None` is passed, Inter-attention is not used, such as
|
||||
CIF, GPT, and other decoder only model.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward` instance can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool):
|
||||
True: use layer_norm before each sub-block.
|
||||
False: to use layer_norm after each sub-block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
self_attn: nn.Module,
|
||||
src_attn: Optional[nn.Module],
|
||||
feed_forward: nn.Module,
|
||||
dropout_rate: float,
|
||||
normalize_before: bool = True,
|
||||
):
|
||||
"""Construct an DecoderLayer object."""
|
||||
super().__init__()
|
||||
self.size = size
|
||||
self.self_attn = self_attn
|
||||
self.src_attn = src_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = nn.LayerNorm(size, eps=1e-5)
|
||||
self.norm2 = nn.LayerNorm(size, eps=1e-5)
|
||||
self.norm3 = nn.LayerNorm(size, eps=1e-5)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def forward(
|
||||
self,
|
||||
tgt: torch.Tensor,
|
||||
tgt_mask: torch.Tensor,
|
||||
memory: torch.Tensor,
|
||||
memory_mask: torch.Tensor,
|
||||
cache: Optional[torch.Tensor] = None
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute decoded features.
|
||||
|
||||
Args:
|
||||
tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
|
||||
tgt_mask (torch.Tensor): Mask for input tensor
|
||||
(#batch, maxlen_out).
|
||||
memory (torch.Tensor): Encoded memory
|
||||
(#batch, maxlen_in, size).
|
||||
memory_mask (torch.Tensor): Encoded memory mask
|
||||
(#batch, maxlen_in).
|
||||
cache (torch.Tensor): cached tensors.
|
||||
(#batch, maxlen_out - 1, size).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, maxlen_out, size).
|
||||
torch.Tensor: Mask for output tensor (#batch, maxlen_out).
|
||||
torch.Tensor: Encoded memory (#batch, maxlen_in, size).
|
||||
torch.Tensor: Encoded memory mask (#batch, maxlen_in).
|
||||
|
||||
"""
|
||||
residual = tgt
|
||||
if self.normalize_before:
|
||||
tgt = self.norm1(tgt)
|
||||
|
||||
if cache is None:
|
||||
tgt_q = tgt
|
||||
tgt_q_mask = tgt_mask
|
||||
else:
|
||||
# compute only the last frame query keeping dim: max_time_out -> 1
|
||||
assert cache.shape == (
|
||||
tgt.shape[0],
|
||||
tgt.shape[1] - 1,
|
||||
self.size,
|
||||
), "{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
|
||||
tgt_q = tgt[:, -1:, :]
|
||||
residual = residual[:, -1:, :]
|
||||
tgt_q_mask = tgt_mask[:, -1:, :]
|
||||
|
||||
x = residual + self.dropout(
|
||||
self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)[0])
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
if self.src_attn is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + self.dropout(
|
||||
self.src_attn(x, memory, memory, memory_mask)[0])
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
x = residual + self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm3(x)
|
||||
|
||||
if cache is not None:
|
||||
x = torch.cat([cache, x], dim=1)
|
||||
|
||||
return x, tgt_mask, memory, memory_mask
|
||||
302
models/CosyVoice/cosyvoice/transformer/embedding.py
Normal file
302
models/CosyVoice/cosyvoice/transformer/embedding.py
Normal file
@@ -0,0 +1,302 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc. (authors: Binbin Zhang, Di Wu)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Positonal Encoding Module."""
|
||||
|
||||
import math
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import numpy as np
|
||||
|
||||
|
||||
class PositionalEncoding(torch.nn.Module):
|
||||
"""Positional encoding.
|
||||
|
||||
:param int d_model: embedding dim
|
||||
:param float dropout_rate: dropout rate
|
||||
:param int max_len: maximum input length
|
||||
|
||||
PE(pos, 2i) = sin(pos/(10000^(2i/dmodel)))
|
||||
PE(pos, 2i+1) = cos(pos/(10000^(2i/dmodel)))
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
d_model: int,
|
||||
dropout_rate: float,
|
||||
max_len: int = 5000,
|
||||
reverse: bool = False):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.max_len = max_len
|
||||
|
||||
self.pe = torch.zeros(self.max_len, self.d_model)
|
||||
position = torch.arange(0, self.max_len,
|
||||
dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32) *
|
||||
-(math.log(10000.0) / self.d_model))
|
||||
self.pe[:, 0::2] = torch.sin(position * div_term)
|
||||
self.pe[:, 1::2] = torch.cos(position * div_term)
|
||||
self.pe = self.pe.unsqueeze(0)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input. Its shape is (batch, time, ...)
|
||||
offset (int, torch.tensor): position offset
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
|
||||
torch.Tensor: for compatibility to RelPositionalEncoding
|
||||
"""
|
||||
|
||||
self.pe = self.pe.to(x.device)
|
||||
pos_emb = self.position_encoding(offset, x.size(1), False)
|
||||
x = x * self.xscale + pos_emb
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
def position_encoding(self,
|
||||
offset: Union[int, torch.Tensor],
|
||||
size: int,
|
||||
apply_dropout: bool = True) -> torch.Tensor:
|
||||
""" For getting encoding in a streaming fashion
|
||||
|
||||
Attention!!!!!
|
||||
we apply dropout only once at the whole utterance level in a none
|
||||
streaming way, but will call this function several times with
|
||||
increasing input size in a streaming scenario, so the dropout will
|
||||
be applied several times.
|
||||
|
||||
Args:
|
||||
offset (int or torch.tensor): start offset
|
||||
size (int): required size of position encoding
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Corresponding encoding
|
||||
"""
|
||||
# How to subscript a Union type:
|
||||
# https://github.com/pytorch/pytorch/issues/69434
|
||||
if isinstance(offset, int):
|
||||
assert offset + size <= self.max_len
|
||||
pos_emb = self.pe[:, offset:offset + size]
|
||||
elif isinstance(offset, torch.Tensor) and offset.dim() == 0: # scalar
|
||||
assert offset + size <= self.max_len
|
||||
pos_emb = self.pe[:, offset:offset + size]
|
||||
else: # for batched streaming decoding on GPU
|
||||
assert torch.max(offset) + size <= self.max_len
|
||||
index = offset.unsqueeze(1) + \
|
||||
torch.arange(0, size).to(offset.device) # B X T
|
||||
flag = index > 0
|
||||
# remove negative offset
|
||||
index = index * flag
|
||||
pos_emb = F.embedding(index, self.pe[0]) # B X T X d_model
|
||||
|
||||
if apply_dropout:
|
||||
pos_emb = self.dropout(pos_emb)
|
||||
return pos_emb
|
||||
|
||||
|
||||
class RelPositionalEncoding(PositionalEncoding):
|
||||
"""Relative positional encoding module.
|
||||
See : Appendix B in https://arxiv.org/abs/1901.02860
|
||||
Args:
|
||||
d_model (int): Embedding dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
max_len (int): Maximum input length.
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
||||
"""Initialize class."""
|
||||
super().__init__(d_model, dropout_rate, max_len, reverse=True)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Compute positional encoding.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
torch.Tensor: Positional embedding tensor (1, time, `*`).
|
||||
"""
|
||||
self.pe = self.pe.to(x.device)
|
||||
x = x * self.xscale
|
||||
pos_emb = self.position_encoding(offset, x.size(1), False)
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
|
||||
class WhisperPositionalEncoding(PositionalEncoding):
|
||||
""" Sinusoids position encoding used in openai-whisper.encoder
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 1500):
|
||||
super().__init__(d_model, dropout_rate, max_len)
|
||||
self.xscale = 1.0
|
||||
log_timescale_increment = np.log(10000) / (d_model // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment *
|
||||
torch.arange(d_model // 2))
|
||||
scaled_time = torch.arange(max_len)[:, np.newaxis] * \
|
||||
inv_timescales[np.newaxis, :]
|
||||
pe = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
delattr(self, "pe")
|
||||
self.register_buffer("pe", pe.unsqueeze(0))
|
||||
|
||||
|
||||
class LearnablePositionalEncoding(PositionalEncoding):
|
||||
""" Learnable position encoding used in openai-whisper.decoder
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 448):
|
||||
super().__init__(d_model, dropout_rate, max_len)
|
||||
# NOTE(xcsong): overwrite self.pe & self.xscale
|
||||
self.pe = torch.nn.Parameter(torch.empty(1, max_len, d_model))
|
||||
self.xscale = 1.0
|
||||
|
||||
|
||||
class NoPositionalEncoding(torch.nn.Module):
|
||||
""" No position encoding
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, dropout_rate: float):
|
||||
super().__init__()
|
||||
self.d_model = d_model
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
|
||||
def forward(self,
|
||||
x: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||
""" Just return zero vector for interface compatibility
|
||||
"""
|
||||
pos_emb = torch.zeros(1, x.size(1), self.d_model).to(x.device)
|
||||
return self.dropout(x), pos_emb
|
||||
|
||||
def position_encoding(self, offset: Union[int, torch.Tensor],
|
||||
size: int) -> torch.Tensor:
|
||||
return torch.zeros(1, size, self.d_model)
|
||||
|
||||
|
||||
class EspnetRelPositionalEncoding(torch.nn.Module):
|
||||
"""Relative positional encoding module (new implementation).
|
||||
|
||||
Details can be found in https://github.com/espnet/espnet/pull/2816.
|
||||
|
||||
See : Appendix B in https://arxiv.org/abs/1901.02860
|
||||
|
||||
Args:
|
||||
d_model (int): Embedding dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
max_len (int): Maximum input length.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
|
||||
"""Construct an PositionalEncoding object."""
|
||||
super(EspnetRelPositionalEncoding, self).__init__()
|
||||
self.d_model = d_model
|
||||
self.xscale = math.sqrt(self.d_model)
|
||||
self.dropout = torch.nn.Dropout(p=dropout_rate)
|
||||
self.pe = None
|
||||
self.extend_pe(torch.tensor(0.0).expand(1, max_len))
|
||||
|
||||
def extend_pe(self, x: torch.Tensor):
|
||||
"""Reset the positional encodings."""
|
||||
if self.pe is not None:
|
||||
# self.pe contains both positive and negative parts
|
||||
# the length of self.pe is 2 * input_len - 1
|
||||
if self.pe.size(1) >= x.size(1) * 2 - 1:
|
||||
if self.pe.dtype != x.dtype or self.pe.device != x.device:
|
||||
self.pe = self.pe.to(dtype=x.dtype, device=x.device)
|
||||
return
|
||||
# Suppose `i` means to the position of query vecotr and `j` means the
|
||||
# position of key vector. We use position relative positions when keys
|
||||
# are to the left (i>j) and negative relative positions otherwise (i<j).
|
||||
pe_positive = torch.zeros(x.size(1), self.d_model)
|
||||
pe_negative = torch.zeros(x.size(1), self.d_model)
|
||||
position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
|
||||
div_term = torch.exp(
|
||||
torch.arange(0, self.d_model, 2, dtype=torch.float32)
|
||||
* -(math.log(10000.0) / self.d_model)
|
||||
)
|
||||
pe_positive[:, 0::2] = torch.sin(position * div_term)
|
||||
pe_positive[:, 1::2] = torch.cos(position * div_term)
|
||||
pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
|
||||
pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)
|
||||
|
||||
# Reserve the order of positive indices and concat both positive and
|
||||
# negative indices. This is used to support the shifting trick
|
||||
# as in https://arxiv.org/abs/1901.02860
|
||||
pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
|
||||
pe_negative = pe_negative[1:].unsqueeze(0)
|
||||
pe = torch.cat([pe_positive, pe_negative], dim=1)
|
||||
self.pe = pe.to(device=x.device, dtype=x.dtype)
|
||||
|
||||
def forward(self, x: torch.Tensor, offset: Union[int, torch.Tensor] = 0) \
|
||||
-> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Add positional encoding.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (batch, time, `*`).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Encoded tensor (batch, time, `*`).
|
||||
|
||||
"""
|
||||
self.extend_pe(x)
|
||||
x = x * self.xscale
|
||||
pos_emb = self.position_encoding(size=x.size(1), offset=offset)
|
||||
return self.dropout(x), self.dropout(pos_emb)
|
||||
|
||||
def position_encoding(self,
|
||||
offset: Union[int, torch.Tensor],
|
||||
size: int) -> torch.Tensor:
|
||||
""" For getting encoding in a streaming fashion
|
||||
|
||||
Attention!!!!!
|
||||
we apply dropout only once at the whole utterance level in a none
|
||||
streaming way, but will call this function several times with
|
||||
increasing input size in a streaming scenario, so the dropout will
|
||||
be applied several times.
|
||||
|
||||
Args:
|
||||
offset (int or torch.tensor): start offset
|
||||
size (int): required size of position encoding
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Corresponding encoding
|
||||
"""
|
||||
# How to subscript a Union type:
|
||||
# https://github.com/pytorch/pytorch/issues/69434
|
||||
if isinstance(offset, int):
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
|
||||
]
|
||||
elif isinstance(offset, torch.Tensor):
|
||||
pos_emb = self.pe[
|
||||
:,
|
||||
self.pe.size(1) // 2 - size - offset + 1: self.pe.size(1) // 2 + size + offset,
|
||||
]
|
||||
return pos_emb
|
||||
474
models/CosyVoice/cosyvoice/transformer/encoder.py
Normal file
474
models/CosyVoice/cosyvoice/transformer/encoder.py
Normal file
@@ -0,0 +1,474 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
||||
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Encoder definition."""
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.utils.checkpoint as ckpt
|
||||
|
||||
from cosyvoice.transformer.convolution import ConvolutionModule
|
||||
from cosyvoice.transformer.encoder_layer import TransformerEncoderLayer
|
||||
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
|
||||
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
||||
from cosyvoice.utils.class_utils import (
|
||||
COSYVOICE_EMB_CLASSES,
|
||||
COSYVOICE_SUBSAMPLE_CLASSES,
|
||||
COSYVOICE_ATTENTION_CLASSES,
|
||||
COSYVOICE_ACTIVATION_CLASSES,
|
||||
)
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||
|
||||
|
||||
class BaseEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "conv2d",
|
||||
pos_enc_layer_type: str = "abs_pos",
|
||||
normalize_before: bool = True,
|
||||
static_chunk_size: int = 0,
|
||||
use_dynamic_chunk: bool = False,
|
||||
global_cmvn: torch.nn.Module = None,
|
||||
use_dynamic_left_chunk: bool = False,
|
||||
gradient_checkpointing: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_size (int): input dim
|
||||
output_size (int): dimension of attention
|
||||
attention_heads (int): the number of heads of multi head attention
|
||||
linear_units (int): the hidden units number of position-wise feed
|
||||
forward
|
||||
num_blocks (int): the number of decoder blocks
|
||||
dropout_rate (float): dropout rate
|
||||
attention_dropout_rate (float): dropout rate in attention
|
||||
positional_dropout_rate (float): dropout rate after adding
|
||||
positional encoding
|
||||
input_layer (str): input layer type.
|
||||
optional [linear, conv2d, conv2d6, conv2d8]
|
||||
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
||||
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
||||
normalize_before (bool):
|
||||
True: use layer_norm before each sub-block of a layer.
|
||||
False: use layer_norm after each sub-block of a layer.
|
||||
static_chunk_size (int): chunk size for static chunk training and
|
||||
decoding
|
||||
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
||||
training or not, You can only use fixed chunk(chunk_size > 0)
|
||||
or dyanmic chunk size(use_dynamic_chunk = True)
|
||||
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
||||
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
||||
dynamic chunk training
|
||||
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
||||
gradient_checkpointing: rerunning a forward-pass segment for each
|
||||
checkpointed segment during backward.
|
||||
"""
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
self.global_cmvn = global_cmvn
|
||||
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
||||
positional_dropout_rate),
|
||||
)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
||||
self.static_chunk_size = static_chunk_size
|
||||
self.use_dynamic_chunk = use_dynamic_chunk
|
||||
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
xs_lens: torch.Tensor,
|
||||
decoding_chunk_size: int = 0,
|
||||
num_decoding_left_chunks: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs: padded input tensor (B, T, D)
|
||||
xs_lens: input length (B)
|
||||
decoding_chunk_size: decoding chunk size for dynamic chunk
|
||||
0: default for training, use random dynamic chunk.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
||||
the chunk size is decoding_chunk_size.
|
||||
>=0: use num_decoding_left_chunks
|
||||
<0: use all left chunks
|
||||
Returns:
|
||||
encoder output tensor xs, and subsampled masks
|
||||
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
||||
masks: torch.Tensor batch padding mask after subsample
|
||||
(B, 1, T' ~= T/subsample_rate)
|
||||
NOTE(xcsong):
|
||||
We pass the `__call__` method of the modules instead of `forward` to the
|
||||
checkpointing API because `__call__` attaches all the hooks of the module.
|
||||
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
||||
"""
|
||||
T = xs.size(1)
|
||||
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||
if self.global_cmvn is not None:
|
||||
xs = self.global_cmvn(xs)
|
||||
xs, pos_emb, masks = self.embed(xs, masks)
|
||||
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||
chunk_masks = add_optional_chunk_mask(xs, masks,
|
||||
self.use_dynamic_chunk,
|
||||
self.use_dynamic_left_chunk,
|
||||
decoding_chunk_size,
|
||||
self.static_chunk_size,
|
||||
num_decoding_left_chunks)
|
||||
if self.gradient_checkpointing and self.training:
|
||||
xs = self.forward_layers_checkpointed(xs, chunk_masks, pos_emb,
|
||||
mask_pad)
|
||||
else:
|
||||
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
# Here we assume the mask is not changed in encoder layers, so just
|
||||
# return the masks before encoder layers, and the masks will be used
|
||||
# for cross attention with decoder later
|
||||
return xs, masks
|
||||
|
||||
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask_pad: torch.Tensor) -> torch.Tensor:
|
||||
for layer in self.encoders:
|
||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||
return xs
|
||||
|
||||
@torch.jit.unused
|
||||
def forward_layers_checkpointed(self, xs: torch.Tensor,
|
||||
chunk_masks: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask_pad: torch.Tensor) -> torch.Tensor:
|
||||
for layer in self.encoders:
|
||||
xs, chunk_masks, _, _ = ckpt.checkpoint(layer.__call__, xs,
|
||||
chunk_masks, pos_emb,
|
||||
mask_pad)
|
||||
return xs
|
||||
|
||||
@torch.jit.export
|
||||
def forward_chunk(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
offset: int,
|
||||
required_cache_size: int,
|
||||
att_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
cnn_cache: torch.Tensor = torch.zeros(0, 0, 0, 0),
|
||||
att_mask: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
""" Forward just one chunk
|
||||
|
||||
Args:
|
||||
xs (torch.Tensor): chunk input, with shape (b=1, time, mel-dim),
|
||||
where `time == (chunk_size - 1) * subsample_rate + \
|
||||
subsample.right_context + 1`
|
||||
offset (int): current offset in encoder output time stamp
|
||||
required_cache_size (int): cache size required for next chunk
|
||||
compuation
|
||||
>=0: actual cache size
|
||||
<0: means all history cache is required
|
||||
att_cache (torch.Tensor): cache tensor for KEY & VALUE in
|
||||
transformer/conformer attention, with shape
|
||||
(elayers, head, cache_t1, d_k * 2), where
|
||||
`head * d_k == hidden-dim` and
|
||||
`cache_t1 == chunk_size * num_decoding_left_chunks`.
|
||||
cnn_cache (torch.Tensor): cache tensor for cnn_module in conformer,
|
||||
(elayers, b=1, hidden-dim, cache_t2), where
|
||||
`cache_t2 == cnn.lorder - 1`
|
||||
|
||||
Returns:
|
||||
torch.Tensor: output of current input xs,
|
||||
with shape (b=1, chunk_size, hidden-dim).
|
||||
torch.Tensor: new attention cache required for next chunk, with
|
||||
dynamic shape (elayers, head, ?, d_k * 2)
|
||||
depending on required_cache_size.
|
||||
torch.Tensor: new conformer cnn cache required for next chunk, with
|
||||
same shape as the original cnn_cache.
|
||||
|
||||
"""
|
||||
assert xs.size(0) == 1
|
||||
# tmp_masks is just for interface compatibility
|
||||
tmp_masks = torch.ones(1,
|
||||
xs.size(1),
|
||||
device=xs.device,
|
||||
dtype=torch.bool)
|
||||
tmp_masks = tmp_masks.unsqueeze(1)
|
||||
if self.global_cmvn is not None:
|
||||
xs = self.global_cmvn(xs)
|
||||
# NOTE(xcsong): Before embed, shape(xs) is (b=1, time, mel-dim)
|
||||
xs, pos_emb, _ = self.embed(xs, tmp_masks, offset)
|
||||
# NOTE(xcsong): After embed, shape(xs) is (b=1, chunk_size, hidden-dim)
|
||||
elayers, cache_t1 = att_cache.size(0), att_cache.size(2)
|
||||
chunk_size = xs.size(1)
|
||||
attention_key_size = cache_t1 + chunk_size
|
||||
pos_emb = self.embed.position_encoding(offset=offset - cache_t1,
|
||||
size=attention_key_size)
|
||||
if required_cache_size < 0:
|
||||
next_cache_start = 0
|
||||
elif required_cache_size == 0:
|
||||
next_cache_start = attention_key_size
|
||||
else:
|
||||
next_cache_start = max(attention_key_size - required_cache_size, 0)
|
||||
r_att_cache = []
|
||||
r_cnn_cache = []
|
||||
for i, layer in enumerate(self.encoders):
|
||||
# NOTE(xcsong): Before layer.forward
|
||||
# shape(att_cache[i:i + 1]) is (1, head, cache_t1, d_k * 2),
|
||||
# shape(cnn_cache[i]) is (b=1, hidden-dim, cache_t2)
|
||||
xs, _, new_att_cache, new_cnn_cache = layer(
|
||||
xs,
|
||||
att_mask,
|
||||
pos_emb,
|
||||
att_cache=att_cache[i:i + 1] if elayers > 0 else att_cache,
|
||||
cnn_cache=cnn_cache[i] if cnn_cache.size(0) > 0 else cnn_cache)
|
||||
# NOTE(xcsong): After layer.forward
|
||||
# shape(new_att_cache) is (1, head, attention_key_size, d_k * 2),
|
||||
# shape(new_cnn_cache) is (b=1, hidden-dim, cache_t2)
|
||||
r_att_cache.append(new_att_cache[:, :, next_cache_start:, :])
|
||||
r_cnn_cache.append(new_cnn_cache.unsqueeze(0))
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
|
||||
# NOTE(xcsong): shape(r_att_cache) is (elayers, head, ?, d_k * 2),
|
||||
# ? may be larger than cache_t1, it depends on required_cache_size
|
||||
r_att_cache = torch.cat(r_att_cache, dim=0)
|
||||
# NOTE(xcsong): shape(r_cnn_cache) is (e, b=1, hidden-dim, cache_t2)
|
||||
r_cnn_cache = torch.cat(r_cnn_cache, dim=0)
|
||||
|
||||
return (xs, r_att_cache, r_cnn_cache)
|
||||
|
||||
@torch.jit.unused
|
||||
def forward_chunk_by_chunk(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
decoding_chunk_size: int,
|
||||
num_decoding_left_chunks: int = -1,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
""" Forward input chunk by chunk with chunk_size like a streaming
|
||||
fashion
|
||||
|
||||
Here we should pay special attention to computation cache in the
|
||||
streaming style forward chunk by chunk. Three things should be taken
|
||||
into account for computation in the current network:
|
||||
1. transformer/conformer encoder layers output cache
|
||||
2. convolution in conformer
|
||||
3. convolution in subsampling
|
||||
|
||||
However, we don't implement subsampling cache for:
|
||||
1. We can control subsampling module to output the right result by
|
||||
overlapping input instead of cache left context, even though it
|
||||
wastes some computation, but subsampling only takes a very
|
||||
small fraction of computation in the whole model.
|
||||
2. Typically, there are several covolution layers with subsampling
|
||||
in subsampling module, it is tricky and complicated to do cache
|
||||
with different convolution layers with different subsampling
|
||||
rate.
|
||||
3. Currently, nn.Sequential is used to stack all the convolution
|
||||
layers in subsampling, we need to rewrite it to make it work
|
||||
with cache, which is not preferred.
|
||||
Args:
|
||||
xs (torch.Tensor): (1, max_len, dim)
|
||||
chunk_size (int): decoding chunk size
|
||||
"""
|
||||
assert decoding_chunk_size > 0
|
||||
# The model is trained by static or dynamic chunk
|
||||
assert self.static_chunk_size > 0 or self.use_dynamic_chunk
|
||||
subsampling = self.embed.subsampling_rate
|
||||
context = self.embed.right_context + 1 # Add current frame
|
||||
stride = subsampling * decoding_chunk_size
|
||||
decoding_window = (decoding_chunk_size - 1) * subsampling + context
|
||||
num_frames = xs.size(1)
|
||||
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
||||
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0), device=xs.device)
|
||||
outputs = []
|
||||
offset = 0
|
||||
required_cache_size = decoding_chunk_size * num_decoding_left_chunks
|
||||
|
||||
# Feed forward overlap input step by step
|
||||
for cur in range(0, num_frames - context + 1, stride):
|
||||
end = min(cur + decoding_window, num_frames)
|
||||
chunk_xs = xs[:, cur:end, :]
|
||||
(y, att_cache,
|
||||
cnn_cache) = self.forward_chunk(chunk_xs, offset,
|
||||
required_cache_size, att_cache,
|
||||
cnn_cache)
|
||||
outputs.append(y)
|
||||
offset += y.size(1)
|
||||
ys = torch.cat(outputs, 1)
|
||||
masks = torch.ones((1, 1, ys.size(1)),
|
||||
device=ys.device,
|
||||
dtype=torch.bool)
|
||||
return ys, masks
|
||||
|
||||
|
||||
class TransformerEncoder(BaseEncoder):
|
||||
"""Transformer encoder module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "conv2d",
|
||||
pos_enc_layer_type: str = "abs_pos",
|
||||
normalize_before: bool = True,
|
||||
static_chunk_size: int = 0,
|
||||
use_dynamic_chunk: bool = False,
|
||||
global_cmvn: torch.nn.Module = None,
|
||||
use_dynamic_left_chunk: bool = False,
|
||||
key_bias: bool = True,
|
||||
selfattention_layer_type: str = "selfattn",
|
||||
activation_type: str = "relu",
|
||||
gradient_checkpointing: bool = False,
|
||||
):
|
||||
""" Construct TransformerEncoder
|
||||
|
||||
See Encoder for the meaning of each parameter.
|
||||
"""
|
||||
super().__init__(input_size, output_size, attention_heads,
|
||||
linear_units, num_blocks, dropout_rate,
|
||||
positional_dropout_rate, attention_dropout_rate,
|
||||
input_layer, pos_enc_layer_type, normalize_before,
|
||||
static_chunk_size, use_dynamic_chunk, global_cmvn,
|
||||
use_dynamic_left_chunk, gradient_checkpointing)
|
||||
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
||||
self.encoders = torch.nn.ModuleList([
|
||||
TransformerEncoderLayer(
|
||||
output_size,
|
||||
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
key_bias),
|
||||
PositionwiseFeedForward(output_size, linear_units,
|
||||
dropout_rate, activation),
|
||||
dropout_rate, normalize_before) for _ in range(num_blocks)
|
||||
])
|
||||
|
||||
|
||||
class ConformerEncoder(BaseEncoder):
|
||||
"""Conformer encoder module."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "conv2d",
|
||||
pos_enc_layer_type: str = "rel_pos",
|
||||
normalize_before: bool = True,
|
||||
static_chunk_size: int = 0,
|
||||
use_dynamic_chunk: bool = False,
|
||||
global_cmvn: torch.nn.Module = None,
|
||||
use_dynamic_left_chunk: bool = False,
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
macaron_style: bool = True,
|
||||
selfattention_layer_type: str = "rel_selfattn",
|
||||
activation_type: str = "swish",
|
||||
use_cnn_module: bool = True,
|
||||
cnn_module_kernel: int = 15,
|
||||
causal: bool = False,
|
||||
cnn_module_norm: str = "batch_norm",
|
||||
key_bias: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
):
|
||||
"""Construct ConformerEncoder
|
||||
|
||||
Args:
|
||||
input_size to use_dynamic_chunk, see in BaseEncoder
|
||||
positionwise_conv_kernel_size (int): Kernel size of positionwise
|
||||
conv1d layer.
|
||||
macaron_style (bool): Whether to use macaron style for
|
||||
positionwise layer.
|
||||
selfattention_layer_type (str): Encoder attention layer type,
|
||||
the parameter has no effect now, it's just for configure
|
||||
compatibility.
|
||||
activation_type (str): Encoder activation function type.
|
||||
use_cnn_module (bool): Whether to use convolution module.
|
||||
cnn_module_kernel (int): Kernel size of convolution module.
|
||||
causal (bool): whether to use causal convolution or not.
|
||||
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
||||
"""
|
||||
super().__init__(input_size, output_size, attention_heads,
|
||||
linear_units, num_blocks, dropout_rate,
|
||||
positional_dropout_rate, attention_dropout_rate,
|
||||
input_layer, pos_enc_layer_type, normalize_before,
|
||||
static_chunk_size, use_dynamic_chunk, global_cmvn,
|
||||
use_dynamic_left_chunk, gradient_checkpointing)
|
||||
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
||||
|
||||
# self-attention module definition
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
key_bias,
|
||||
)
|
||||
# feed-forward module definition
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
# convolution module definition
|
||||
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
||||
cnn_module_norm, causal)
|
||||
|
||||
self.encoders = torch.nn.ModuleList([
|
||||
ConformerEncoderLayer(
|
||||
output_size,
|
||||
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
||||
*encoder_selfattn_layer_args),
|
||||
PositionwiseFeedForward(*positionwise_layer_args),
|
||||
PositionwiseFeedForward(
|
||||
*positionwise_layer_args) if macaron_style else None,
|
||||
ConvolutionModule(
|
||||
*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
236
models/CosyVoice/cosyvoice/transformer/encoder_layer.py
Normal file
236
models/CosyVoice/cosyvoice/transformer/encoder_layer.py
Normal file
@@ -0,0 +1,236 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
||||
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Encoder self-attention layer definition."""
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class TransformerEncoderLayer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
||||
instance can be used as the argument.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward`, instance can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool):
|
||||
True: use layer_norm before each sub-block.
|
||||
False: to use layer_norm after each sub-block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
self_attn: torch.nn.Module,
|
||||
feed_forward: torch.nn.Module,
|
||||
dropout_rate: float,
|
||||
normalize_before: bool = True,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super().__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.norm1 = nn.LayerNorm(size, eps=1e-12)
|
||||
self.norm2 = nn.LayerNorm(size, eps=1e-12)
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
||||
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
||||
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): (#batch, time, size)
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
||||
(0, 0, 0) means fake mask.
|
||||
pos_emb (torch.Tensor): just for interface compatibility
|
||||
to ConformerEncoderLayer
|
||||
mask_pad (torch.Tensor): does not used in transformer layer,
|
||||
just for unified api with conformer.
|
||||
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
||||
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
||||
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
||||
(#batch=1, size, cache_t2), not used here, it's for interface
|
||||
compatibility to ConformerEncoderLayer.
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time, time).
|
||||
torch.Tensor: att_cache tensor,
|
||||
(#batch=1, head, cache_t1 + time, d_k * 2).
|
||||
torch.Tensor: cnn_cahce tensor (#batch=1, size, cache_t2).
|
||||
|
||||
"""
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb=pos_emb, cache=att_cache)
|
||||
x = residual + self.dropout(x_att)
|
||||
if not self.normalize_before:
|
||||
x = self.norm1(x)
|
||||
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
x = residual + self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm2(x)
|
||||
|
||||
fake_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
||||
return x, mask, new_att_cache, fake_cnn_cache
|
||||
|
||||
|
||||
class ConformerEncoderLayer(nn.Module):
|
||||
"""Encoder layer module.
|
||||
Args:
|
||||
size (int): Input dimension.
|
||||
self_attn (torch.nn.Module): Self-attention module instance.
|
||||
`MultiHeadedAttention` or `RelPositionMultiHeadedAttention`
|
||||
instance can be used as the argument.
|
||||
feed_forward (torch.nn.Module): Feed-forward module instance.
|
||||
`PositionwiseFeedForward` instance can be used as the argument.
|
||||
feed_forward_macaron (torch.nn.Module): Additional feed-forward module
|
||||
instance.
|
||||
`PositionwiseFeedForward` instance can be used as the argument.
|
||||
conv_module (torch.nn.Module): Convolution module instance.
|
||||
`ConvlutionModule` instance can be used as the argument.
|
||||
dropout_rate (float): Dropout rate.
|
||||
normalize_before (bool):
|
||||
True: use layer_norm before each sub-block.
|
||||
False: use layer_norm after each sub-block.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
size: int,
|
||||
self_attn: torch.nn.Module,
|
||||
feed_forward: Optional[nn.Module] = None,
|
||||
feed_forward_macaron: Optional[nn.Module] = None,
|
||||
conv_module: Optional[nn.Module] = None,
|
||||
dropout_rate: float = 0.1,
|
||||
normalize_before: bool = True,
|
||||
):
|
||||
"""Construct an EncoderLayer object."""
|
||||
super().__init__()
|
||||
self.self_attn = self_attn
|
||||
self.feed_forward = feed_forward
|
||||
self.feed_forward_macaron = feed_forward_macaron
|
||||
self.conv_module = conv_module
|
||||
self.norm_ff = nn.LayerNorm(size, eps=1e-12) # for the FNN module
|
||||
self.norm_mha = nn.LayerNorm(size, eps=1e-12) # for the MHA module
|
||||
if feed_forward_macaron is not None:
|
||||
self.norm_ff_macaron = nn.LayerNorm(size, eps=1e-12)
|
||||
self.ff_scale = 0.5
|
||||
else:
|
||||
self.ff_scale = 1.0
|
||||
if self.conv_module is not None:
|
||||
self.norm_conv = nn.LayerNorm(size, eps=1e-12) # for the CNN module
|
||||
self.norm_final = nn.LayerNorm(
|
||||
size, eps=1e-12) # for the final output of the block
|
||||
self.dropout = nn.Dropout(dropout_rate)
|
||||
self.size = size
|
||||
self.normalize_before = normalize_before
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
mask: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask_pad: torch.Tensor = torch.ones((0, 0, 0), dtype=torch.bool),
|
||||
att_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
||||
cnn_cache: torch.Tensor = torch.zeros((0, 0, 0, 0)),
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Compute encoded features.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): (#batch, time, size)
|
||||
mask (torch.Tensor): Mask tensor for the input (#batch, time,time),
|
||||
(0, 0, 0) means fake mask.
|
||||
pos_emb (torch.Tensor): positional encoding, must not be None
|
||||
for ConformerEncoderLayer.
|
||||
mask_pad (torch.Tensor): batch padding mask used for conv module.
|
||||
(#batch, 1,time), (0, 0, 0) means fake mask.
|
||||
att_cache (torch.Tensor): Cache tensor of the KEY & VALUE
|
||||
(#batch=1, head, cache_t1, d_k * 2), head * d_k == size.
|
||||
cnn_cache (torch.Tensor): Convolution cache in conformer layer
|
||||
(#batch=1, size, cache_t2)
|
||||
Returns:
|
||||
torch.Tensor: Output tensor (#batch, time, size).
|
||||
torch.Tensor: Mask tensor (#batch, time, time).
|
||||
torch.Tensor: att_cache tensor,
|
||||
(#batch=1, head, cache_t1 + time, d_k * 2).
|
||||
torch.Tensor: cnn_cahce tensor (#batch, size, cache_t2).
|
||||
"""
|
||||
|
||||
# whether to use macaron style
|
||||
if self.feed_forward_macaron is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
x = residual + self.ff_scale * self.dropout(
|
||||
self.feed_forward_macaron(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff_macaron(x)
|
||||
|
||||
# multi-headed self-attention module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
x_att, new_att_cache = self.self_attn(x, x, x, mask, pos_emb,
|
||||
att_cache)
|
||||
x = residual + self.dropout(x_att)
|
||||
if not self.normalize_before:
|
||||
x = self.norm_mha(x)
|
||||
|
||||
# convolution module
|
||||
# Fake new cnn cache here, and then change it in conv_module
|
||||
new_cnn_cache = torch.zeros((0, 0, 0), dtype=x.dtype, device=x.device)
|
||||
if self.conv_module is not None:
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
x, new_cnn_cache = self.conv_module(x, mask_pad, cnn_cache)
|
||||
x = residual + self.dropout(x)
|
||||
|
||||
if not self.normalize_before:
|
||||
x = self.norm_conv(x)
|
||||
|
||||
# feed forward module
|
||||
residual = x
|
||||
if self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
|
||||
x = residual + self.ff_scale * self.dropout(self.feed_forward(x))
|
||||
if not self.normalize_before:
|
||||
x = self.norm_ff(x)
|
||||
|
||||
if self.conv_module is not None:
|
||||
x = self.norm_final(x)
|
||||
|
||||
return x, mask, new_att_cache, new_cnn_cache
|
||||
@@ -0,0 +1,96 @@
|
||||
# Copyright (c) 2019 Shigeki Karita
|
||||
# 2020 Mobvoi Inc (Binbin Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Label smoothing module."""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
|
||||
|
||||
class LabelSmoothingLoss(nn.Module):
|
||||
"""Label-smoothing loss.
|
||||
|
||||
In a standard CE loss, the label's data distribution is:
|
||||
[0,1,2] ->
|
||||
[
|
||||
[1.0, 0.0, 0.0],
|
||||
[0.0, 1.0, 0.0],
|
||||
[0.0, 0.0, 1.0],
|
||||
]
|
||||
|
||||
In the smoothing version CE Loss,some probabilities
|
||||
are taken from the true label prob (1.0) and are divided
|
||||
among other labels.
|
||||
|
||||
e.g.
|
||||
smoothing=0.1
|
||||
[0,1,2] ->
|
||||
[
|
||||
[0.9, 0.05, 0.05],
|
||||
[0.05, 0.9, 0.05],
|
||||
[0.05, 0.05, 0.9],
|
||||
]
|
||||
|
||||
Args:
|
||||
size (int): the number of class
|
||||
padding_idx (int): padding class id which will be ignored for loss
|
||||
smoothing (float): smoothing rate (0.0 means the conventional CE)
|
||||
normalize_length (bool):
|
||||
normalize loss by sequence length if True
|
||||
normalize loss by batch size if False
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
size: int,
|
||||
padding_idx: int,
|
||||
smoothing: float,
|
||||
normalize_length: bool = False):
|
||||
"""Construct an LabelSmoothingLoss object."""
|
||||
super(LabelSmoothingLoss, self).__init__()
|
||||
self.criterion = nn.KLDivLoss(reduction="none")
|
||||
self.padding_idx = padding_idx
|
||||
self.confidence = 1.0 - smoothing
|
||||
self.smoothing = smoothing
|
||||
self.size = size
|
||||
self.normalize_length = normalize_length
|
||||
|
||||
def forward(self, x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute loss between x and target.
|
||||
|
||||
The model outputs and data labels tensors are flatten to
|
||||
(batch*seqlen, class) shape and a mask is applied to the
|
||||
padding part which should not be calculated for loss.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): prediction (batch, seqlen, class)
|
||||
target (torch.Tensor):
|
||||
target signal masked with self.padding_id (batch, seqlen)
|
||||
Returns:
|
||||
loss (torch.Tensor) : The KL loss, scalar float value
|
||||
"""
|
||||
assert x.size(2) == self.size
|
||||
batch_size = x.size(0)
|
||||
x = x.view(-1, self.size)
|
||||
target = target.view(-1)
|
||||
# use zeros_like instead of torch.no_grad() for true_dist,
|
||||
# since no_grad() can not be exported by JIT
|
||||
true_dist = torch.zeros_like(x)
|
||||
true_dist.fill_(self.smoothing / (self.size - 1))
|
||||
ignore = target == self.padding_idx # (B,)
|
||||
total = len(target) - ignore.sum().item()
|
||||
target = target.masked_fill(ignore, 0) # avoid -1 index
|
||||
true_dist.scatter_(1, target.unsqueeze(1), self.confidence)
|
||||
kl = self.criterion(torch.log_softmax(x, dim=1), true_dist)
|
||||
denom = total if self.normalize_length else batch_size
|
||||
return kl.masked_fill(ignore.unsqueeze(1), 0).sum() / denom
|
||||
@@ -0,0 +1,115 @@
|
||||
# Copyright (c) 2019 Shigeki Karita
|
||||
# 2020 Mobvoi Inc (Binbin Zhang)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Positionwise feed forward layer definition."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class PositionwiseFeedForward(torch.nn.Module):
|
||||
"""Positionwise feed forward layer.
|
||||
|
||||
FeedForward are appied on each position of the sequence.
|
||||
The output dim is same with the input dim.
|
||||
|
||||
Args:
|
||||
idim (int): Input dimenstion.
|
||||
hidden_units (int): The number of hidden units.
|
||||
dropout_rate (float): Dropout rate.
|
||||
activation (torch.nn.Module): Activation function
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
idim: int,
|
||||
hidden_units: int,
|
||||
dropout_rate: float,
|
||||
activation: torch.nn.Module = torch.nn.ReLU(),
|
||||
):
|
||||
"""Construct a PositionwiseFeedForward object."""
|
||||
super(PositionwiseFeedForward, self).__init__()
|
||||
self.w_1 = torch.nn.Linear(idim, hidden_units)
|
||||
self.activation = activation
|
||||
self.dropout = torch.nn.Dropout(dropout_rate)
|
||||
self.w_2 = torch.nn.Linear(hidden_units, idim)
|
||||
|
||||
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
||||
"""Forward function.
|
||||
|
||||
Args:
|
||||
xs: input tensor (B, L, D)
|
||||
Returns:
|
||||
output tensor, (B, L, D)
|
||||
"""
|
||||
return self.w_2(self.dropout(self.activation(self.w_1(xs))))
|
||||
|
||||
|
||||
class MoEFFNLayer(torch.nn.Module):
|
||||
"""
|
||||
Mixture of expert with Positionwise feed forward layer
|
||||
See also figure 1 in https://arxiv.org/pdf/2305.15663.pdf
|
||||
The output dim is same with the input dim.
|
||||
|
||||
Modified from https://github.com/Lightning-AI/lit-gpt/pull/823
|
||||
https://github.com/mistralai/mistral-src/blob/b46d6/moe_one_file_ref.py#L203-L219
|
||||
Args:
|
||||
n_expert: number of expert.
|
||||
n_expert_per_token: The actual number of experts used for each frame
|
||||
idim (int): Input dimenstion.
|
||||
hidden_units (int): The number of hidden units.
|
||||
dropout_rate (float): Dropout rate.
|
||||
activation (torch.nn.Module): Activation function
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_expert: int,
|
||||
n_expert_per_token: int,
|
||||
idim: int,
|
||||
hidden_units: int,
|
||||
dropout_rate: float,
|
||||
activation: torch.nn.Module = torch.nn.ReLU(),
|
||||
):
|
||||
super(MoEFFNLayer, self).__init__()
|
||||
self.gate = torch.nn.Linear(idim, n_expert, bias=False)
|
||||
self.experts = torch.nn.ModuleList(
|
||||
PositionwiseFeedForward(idim, hidden_units, dropout_rate,
|
||||
activation) for _ in range(n_expert))
|
||||
self.n_expert_per_token = n_expert_per_token
|
||||
|
||||
def forward(self, xs: torch.Tensor) -> torch.Tensor:
|
||||
"""Foward function.
|
||||
Args:
|
||||
xs: input tensor (B, L, D)
|
||||
Returns:
|
||||
output tensor, (B, L, D)
|
||||
|
||||
"""
|
||||
B, L, D = xs.size(
|
||||
) # batch size, sequence length, embedding dimension (idim)
|
||||
xs = xs.view(-1, D) # (B*L, D)
|
||||
router = self.gate(xs) # (B*L, n_expert)
|
||||
logits, indices = torch.topk(
|
||||
router, self.n_expert_per_token
|
||||
) # probs:(B*L, n_expert), indices: (B*L, n_expert)
|
||||
weights = torch.nn.functional.softmax(
|
||||
logits, dim=1,
|
||||
dtype=torch.float).to(dtype=xs.dtype) # (B*L, n_expert_per_token)
|
||||
output = torch.zeros_like(xs) # (B*L, D)
|
||||
for i, expert in enumerate(self.experts):
|
||||
mask = indices == i
|
||||
batch_idx, ith_expert = torch.where(mask)
|
||||
output[batch_idx] += weights[batch_idx, ith_expert, None] * expert(
|
||||
xs[batch_idx])
|
||||
return output.view(B, L, D)
|
||||
383
models/CosyVoice/cosyvoice/transformer/subsampling.py
Normal file
383
models/CosyVoice/cosyvoice/transformer/subsampling.py
Normal file
@@ -0,0 +1,383 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Subsampling layer definition."""
|
||||
|
||||
from typing import Tuple, Union
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class BaseSubsampling(torch.nn.Module):
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.right_context = 0
|
||||
self.subsampling_rate = 1
|
||||
|
||||
def position_encoding(self, offset: Union[int, torch.Tensor],
|
||||
size: int) -> torch.Tensor:
|
||||
return self.pos_enc.position_encoding(offset, size)
|
||||
|
||||
|
||||
class EmbedinigNoSubsampling(BaseSubsampling):
|
||||
"""Embedding input without subsampling
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
||||
pos_enc_class: torch.nn.Module):
|
||||
super().__init__()
|
||||
self.embed = torch.nn.Embedding(idim, odim)
|
||||
self.pos_enc = pos_enc_class
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Input x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: linear input tensor (#batch, time', odim),
|
||||
where time' = time .
|
||||
torch.Tensor: linear input mask (#batch, 1, time'),
|
||||
where time' = time .
|
||||
|
||||
"""
|
||||
x = self.embed(x)
|
||||
x, pos_emb = self.pos_enc(x, offset)
|
||||
return x, pos_emb, x_mask
|
||||
|
||||
|
||||
class LinearNoSubsampling(BaseSubsampling):
|
||||
"""Linear transform the input without subsampling
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
||||
pos_enc_class: torch.nn.Module):
|
||||
"""Construct an linear object."""
|
||||
super().__init__()
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, odim),
|
||||
torch.nn.LayerNorm(odim, eps=1e-5),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
)
|
||||
self.pos_enc = pos_enc_class
|
||||
self.right_context = 0
|
||||
self.subsampling_rate = 1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Input x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: linear input tensor (#batch, time', odim),
|
||||
where time' = time .
|
||||
torch.Tensor: linear input mask (#batch, 1, time'),
|
||||
where time' = time .
|
||||
|
||||
"""
|
||||
x = self.out(x)
|
||||
x, pos_emb = self.pos_enc(x, offset)
|
||||
return x, pos_emb, x_mask
|
||||
|
||||
|
||||
class Conv1dSubsampling2(BaseSubsampling):
|
||||
"""Convolutional 1D subsampling (to 1/2 length).
|
||||
It is designed for Whisper, ref:
|
||||
https://github.com/openai/whisper/blob/main/whisper/model.py
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
||||
pos_enc_class: torch.nn.Module):
|
||||
"""Construct an Conv1dSubsampling2 object."""
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv1d(idim, odim, kernel_size=3, padding=1),
|
||||
torch.nn.GELU(),
|
||||
torch.nn.Conv1d(odim, odim, kernel_size=3, stride=2, padding=1),
|
||||
torch.nn.GELU(),
|
||||
)
|
||||
self.pos_enc = pos_enc_class
|
||||
# The right context for every conv layer is computed by:
|
||||
# (kernel_size - 1) * frame_rate_of_this_layer
|
||||
self.subsampling_rate = 2
|
||||
# 4 = (3 - 1) * 1 + (3 - 1) * 1
|
||||
self.right_context = 4
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
||||
where time' = time // 2.
|
||||
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
||||
where time' = time // 2.
|
||||
torch.Tensor: positional encoding
|
||||
|
||||
"""
|
||||
time = x.size(1)
|
||||
x = x.transpose(1, 2) # (b, f, t)
|
||||
x = self.conv(x)
|
||||
x = x.transpose(1, 2) # (b, t, f)
|
||||
x, pos_emb = self.pos_enc(x, offset)
|
||||
return x, pos_emb, x_mask[:, :, (time + 1) % 2::2]
|
||||
|
||||
|
||||
class Conv2dSubsampling4(BaseSubsampling):
|
||||
"""Convolutional 2D subsampling (to 1/4 length).
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
||||
pos_enc_class: torch.nn.Module):
|
||||
"""Construct an Conv2dSubsampling4 object."""
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(odim * (((idim - 1) // 2 - 1) // 2), odim))
|
||||
self.pos_enc = pos_enc_class
|
||||
# The right context for every conv layer is computed by:
|
||||
# (kernel_size - 1) * frame_rate_of_this_layer
|
||||
self.subsampling_rate = 4
|
||||
# 6 = (3 - 1) * 1 + (3 - 1) * 2
|
||||
self.right_context = 6
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
||||
where time' = time // 4.
|
||||
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
||||
where time' = time // 4.
|
||||
torch.Tensor: positional encoding
|
||||
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c=1, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.out(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
x, pos_emb = self.pos_enc(x, offset)
|
||||
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2]
|
||||
|
||||
|
||||
class Conv2dSubsampling6(BaseSubsampling):
|
||||
"""Convolutional 2D subsampling (to 1/6 length).
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
pos_enc (torch.nn.Module): Custom position encoding layer.
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
||||
pos_enc_class: torch.nn.Module):
|
||||
"""Construct an Conv2dSubsampling6 object."""
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 5, 3),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.linear = torch.nn.Linear(odim * (((idim - 1) // 2 - 2) // 3),
|
||||
odim)
|
||||
self.pos_enc = pos_enc_class
|
||||
# 10 = (3 - 1) * 1 + (5 - 1) * 2
|
||||
self.subsampling_rate = 6
|
||||
self.right_context = 10
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Subsample x.
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
||||
where time' = time // 6.
|
||||
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
||||
where time' = time // 6.
|
||||
torch.Tensor: positional encoding
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
x, pos_emb = self.pos_enc(x, offset)
|
||||
return x, pos_emb, x_mask[:, :, 2::2][:, :, 4::3]
|
||||
|
||||
|
||||
class Conv2dSubsampling8(BaseSubsampling):
|
||||
"""Convolutional 2D subsampling (to 1/8 length).
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
||||
pos_enc_class: torch.nn.Module):
|
||||
"""Construct an Conv2dSubsampling8 object."""
|
||||
super().__init__()
|
||||
self.conv = torch.nn.Sequential(
|
||||
torch.nn.Conv2d(1, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Conv2d(odim, odim, 3, 2),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.linear = torch.nn.Linear(
|
||||
odim * ((((idim - 1) // 2 - 1) // 2 - 1) // 2), odim)
|
||||
self.pos_enc = pos_enc_class
|
||||
self.subsampling_rate = 8
|
||||
# 14 = (3 - 1) * 1 + (3 - 1) * 2 + (3 - 1) * 4
|
||||
self.right_context = 14
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Subsample x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Subsampled tensor (#batch, time', odim),
|
||||
where time' = time // 8.
|
||||
torch.Tensor: Subsampled mask (#batch, 1, time'),
|
||||
where time' = time // 8.
|
||||
torch.Tensor: positional encoding
|
||||
"""
|
||||
x = x.unsqueeze(1) # (b, c, t, f)
|
||||
x = self.conv(x)
|
||||
b, c, t, f = x.size()
|
||||
x = self.linear(x.transpose(1, 2).contiguous().view(b, t, c * f))
|
||||
x, pos_emb = self.pos_enc(x, offset)
|
||||
return x, pos_emb, x_mask[:, :, 2::2][:, :, 2::2][:, :, 2::2]
|
||||
|
||||
|
||||
class LegacyLinearNoSubsampling(BaseSubsampling):
|
||||
"""Linear transform the input without subsampling
|
||||
|
||||
Args:
|
||||
idim (int): Input dimension.
|
||||
odim (int): Output dimension.
|
||||
dropout_rate (float): Dropout rate.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, idim: int, odim: int, dropout_rate: float,
|
||||
pos_enc_class: torch.nn.Module):
|
||||
"""Construct an linear object."""
|
||||
super().__init__()
|
||||
self.out = torch.nn.Sequential(
|
||||
torch.nn.Linear(idim, odim),
|
||||
torch.nn.LayerNorm(odim, eps=1e-5),
|
||||
torch.nn.Dropout(dropout_rate),
|
||||
torch.nn.ReLU(),
|
||||
)
|
||||
self.pos_enc = pos_enc_class
|
||||
self.right_context = 0
|
||||
self.subsampling_rate = 1
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: torch.Tensor,
|
||||
x_mask: torch.Tensor,
|
||||
offset: Union[int, torch.Tensor] = 0
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Input x.
|
||||
|
||||
Args:
|
||||
x (torch.Tensor): Input tensor (#batch, time, idim).
|
||||
x_mask (torch.Tensor): Input mask (#batch, 1, time).
|
||||
|
||||
Returns:
|
||||
torch.Tensor: linear input tensor (#batch, time', odim),
|
||||
where time' = time .
|
||||
torch.Tensor: linear input mask (#batch, 1, time'),
|
||||
where time' = time .
|
||||
|
||||
"""
|
||||
x = self.out(x)
|
||||
x, pos_emb = self.pos_enc(x, offset)
|
||||
return x, pos_emb, x_mask
|
||||
321
models/CosyVoice/cosyvoice/transformer/upsample_encoder.py
Normal file
321
models/CosyVoice/cosyvoice/transformer/upsample_encoder.py
Normal file
@@ -0,0 +1,321 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc (Binbin Zhang, Di Wu)
|
||||
# 2022 Xingchen Song (sxc19@mails.tsinghua.edu.cn)
|
||||
# 2024 Alibaba Inc (Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Encoder definition."""
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from cosyvoice.transformer.convolution import ConvolutionModule
|
||||
from cosyvoice.transformer.encoder_layer import ConformerEncoderLayer
|
||||
from cosyvoice.transformer.positionwise_feed_forward import PositionwiseFeedForward
|
||||
from cosyvoice.utils.class_utils import (
|
||||
COSYVOICE_EMB_CLASSES,
|
||||
COSYVOICE_SUBSAMPLE_CLASSES,
|
||||
COSYVOICE_ATTENTION_CLASSES,
|
||||
COSYVOICE_ACTIVATION_CLASSES,
|
||||
)
|
||||
from cosyvoice.utils.mask import make_pad_mask
|
||||
from cosyvoice.utils.mask import add_optional_chunk_mask
|
||||
|
||||
|
||||
class Upsample1D(nn.Module):
|
||||
"""A 1D upsampling layer with an optional convolution.
|
||||
|
||||
Parameters:
|
||||
channels (`int`):
|
||||
number of channels in the inputs and outputs.
|
||||
use_conv (`bool`, default `False`):
|
||||
option to use a convolution.
|
||||
use_conv_transpose (`bool`, default `False`):
|
||||
option to use a convolution transpose.
|
||||
out_channels (`int`, optional):
|
||||
number of output channels. Defaults to `channels`.
|
||||
"""
|
||||
|
||||
def __init__(self, channels: int, out_channels: int, stride: int = 2):
|
||||
super().__init__()
|
||||
self.channels = channels
|
||||
self.out_channels = out_channels
|
||||
self.stride = stride
|
||||
# In this mode, first repeat interpolate, than conv with stride=1
|
||||
self.conv = nn.Conv1d(self.channels, self.out_channels, stride * 2 + 1, stride=1, padding=0)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
outputs = F.interpolate(inputs, scale_factor=float(self.stride), mode="nearest")
|
||||
outputs = F.pad(outputs, (self.stride * 2, 0), value=0.0)
|
||||
outputs = self.conv(outputs)
|
||||
return outputs, input_lengths * self.stride
|
||||
|
||||
|
||||
class PreLookaheadLayer(nn.Module):
|
||||
def __init__(self, in_channels: int, channels: int, pre_lookahead_len: int = 1):
|
||||
super().__init__()
|
||||
self.in_channels = in_channels
|
||||
self.channels = channels
|
||||
self.pre_lookahead_len = pre_lookahead_len
|
||||
self.conv1 = nn.Conv1d(
|
||||
in_channels, channels,
|
||||
kernel_size=pre_lookahead_len + 1,
|
||||
stride=1, padding=0,
|
||||
)
|
||||
self.conv2 = nn.Conv1d(
|
||||
channels, in_channels,
|
||||
kernel_size=3, stride=1, padding=0,
|
||||
)
|
||||
|
||||
def forward(self, inputs: torch.Tensor, context: torch.Tensor = torch.zeros(0, 0, 0)) -> torch.Tensor:
|
||||
"""
|
||||
inputs: (batch_size, seq_len, channels)
|
||||
"""
|
||||
outputs = inputs.transpose(1, 2).contiguous()
|
||||
context = context.transpose(1, 2).contiguous()
|
||||
# look ahead
|
||||
if context.size(2) == 0:
|
||||
outputs = F.pad(outputs, (0, self.pre_lookahead_len), mode='constant', value=0.0)
|
||||
else:
|
||||
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
||||
assert context.size(2) == self.pre_lookahead_len
|
||||
outputs = F.pad(torch.concat([outputs, context], dim=2), (0, self.pre_lookahead_len - context.size(2)), mode='constant', value=0.0)
|
||||
outputs = F.leaky_relu(self.conv1(outputs))
|
||||
# outputs
|
||||
outputs = F.pad(outputs, (self.conv2.kernel_size[0] - 1, 0), mode='constant', value=0.0)
|
||||
outputs = self.conv2(outputs)
|
||||
outputs = outputs.transpose(1, 2).contiguous()
|
||||
|
||||
# residual connection
|
||||
outputs = outputs + inputs
|
||||
return outputs
|
||||
|
||||
|
||||
class UpsampleConformerEncoder(torch.nn.Module):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_size: int,
|
||||
output_size: int = 256,
|
||||
attention_heads: int = 4,
|
||||
linear_units: int = 2048,
|
||||
num_blocks: int = 6,
|
||||
dropout_rate: float = 0.1,
|
||||
positional_dropout_rate: float = 0.1,
|
||||
attention_dropout_rate: float = 0.0,
|
||||
input_layer: str = "conv2d",
|
||||
pos_enc_layer_type: str = "rel_pos",
|
||||
normalize_before: bool = True,
|
||||
static_chunk_size: int = 0,
|
||||
use_dynamic_chunk: bool = False,
|
||||
global_cmvn: torch.nn.Module = None,
|
||||
use_dynamic_left_chunk: bool = False,
|
||||
positionwise_conv_kernel_size: int = 1,
|
||||
macaron_style: bool = True,
|
||||
selfattention_layer_type: str = "rel_selfattn",
|
||||
activation_type: str = "swish",
|
||||
use_cnn_module: bool = True,
|
||||
cnn_module_kernel: int = 15,
|
||||
causal: bool = False,
|
||||
cnn_module_norm: str = "batch_norm",
|
||||
key_bias: bool = True,
|
||||
gradient_checkpointing: bool = False,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
input_size (int): input dim
|
||||
output_size (int): dimension of attention
|
||||
attention_heads (int): the number of heads of multi head attention
|
||||
linear_units (int): the hidden units number of position-wise feed
|
||||
forward
|
||||
num_blocks (int): the number of decoder blocks
|
||||
dropout_rate (float): dropout rate
|
||||
attention_dropout_rate (float): dropout rate in attention
|
||||
positional_dropout_rate (float): dropout rate after adding
|
||||
positional encoding
|
||||
input_layer (str): input layer type.
|
||||
optional [linear, conv2d, conv2d6, conv2d8]
|
||||
pos_enc_layer_type (str): Encoder positional encoding layer type.
|
||||
opitonal [abs_pos, scaled_abs_pos, rel_pos, no_pos]
|
||||
normalize_before (bool):
|
||||
True: use layer_norm before each sub-block of a layer.
|
||||
False: use layer_norm after each sub-block of a layer.
|
||||
static_chunk_size (int): chunk size for static chunk training and
|
||||
decoding
|
||||
use_dynamic_chunk (bool): whether use dynamic chunk size for
|
||||
training or not, You can only use fixed chunk(chunk_size > 0)
|
||||
or dyanmic chunk size(use_dynamic_chunk = True)
|
||||
global_cmvn (Optional[torch.nn.Module]): Optional GlobalCMVN module
|
||||
use_dynamic_left_chunk (bool): whether use dynamic left chunk in
|
||||
dynamic chunk training
|
||||
key_bias: whether use bias in attention.linear_k, False for whisper models.
|
||||
gradient_checkpointing: rerunning a forward-pass segment for each
|
||||
checkpointed segment during backward.
|
||||
"""
|
||||
super().__init__()
|
||||
self._output_size = output_size
|
||||
|
||||
self.global_cmvn = global_cmvn
|
||||
self.embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
||||
positional_dropout_rate),
|
||||
)
|
||||
|
||||
self.normalize_before = normalize_before
|
||||
self.after_norm = torch.nn.LayerNorm(output_size, eps=1e-5)
|
||||
self.static_chunk_size = static_chunk_size
|
||||
self.use_dynamic_chunk = use_dynamic_chunk
|
||||
self.use_dynamic_left_chunk = use_dynamic_left_chunk
|
||||
self.gradient_checkpointing = gradient_checkpointing
|
||||
activation = COSYVOICE_ACTIVATION_CLASSES[activation_type]()
|
||||
# self-attention module definition
|
||||
encoder_selfattn_layer_args = (
|
||||
attention_heads,
|
||||
output_size,
|
||||
attention_dropout_rate,
|
||||
key_bias,
|
||||
)
|
||||
# feed-forward module definition
|
||||
positionwise_layer_args = (
|
||||
output_size,
|
||||
linear_units,
|
||||
dropout_rate,
|
||||
activation,
|
||||
)
|
||||
# convolution module definition
|
||||
convolution_layer_args = (output_size, cnn_module_kernel, activation,
|
||||
cnn_module_norm, causal)
|
||||
self.pre_lookahead_layer = PreLookaheadLayer(in_channels=512, channels=512, pre_lookahead_len=3)
|
||||
self.encoders = torch.nn.ModuleList([
|
||||
ConformerEncoderLayer(
|
||||
output_size,
|
||||
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
||||
*encoder_selfattn_layer_args),
|
||||
PositionwiseFeedForward(*positionwise_layer_args),
|
||||
PositionwiseFeedForward(
|
||||
*positionwise_layer_args) if macaron_style else None,
|
||||
ConvolutionModule(
|
||||
*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
) for _ in range(num_blocks)
|
||||
])
|
||||
self.up_layer = Upsample1D(channels=512, out_channels=512, stride=2)
|
||||
self.up_embed = COSYVOICE_SUBSAMPLE_CLASSES[input_layer](
|
||||
input_size,
|
||||
output_size,
|
||||
dropout_rate,
|
||||
COSYVOICE_EMB_CLASSES[pos_enc_layer_type](output_size,
|
||||
positional_dropout_rate),
|
||||
)
|
||||
self.up_encoders = torch.nn.ModuleList([
|
||||
ConformerEncoderLayer(
|
||||
output_size,
|
||||
COSYVOICE_ATTENTION_CLASSES[selfattention_layer_type](
|
||||
*encoder_selfattn_layer_args),
|
||||
PositionwiseFeedForward(*positionwise_layer_args),
|
||||
PositionwiseFeedForward(
|
||||
*positionwise_layer_args) if macaron_style else None,
|
||||
ConvolutionModule(
|
||||
*convolution_layer_args) if use_cnn_module else None,
|
||||
dropout_rate,
|
||||
normalize_before,
|
||||
) for _ in range(4)
|
||||
])
|
||||
|
||||
def output_size(self) -> int:
|
||||
return self._output_size
|
||||
|
||||
def forward(
|
||||
self,
|
||||
xs: torch.Tensor,
|
||||
xs_lens: torch.Tensor,
|
||||
context: torch.Tensor = torch.zeros(0, 0, 0),
|
||||
decoding_chunk_size: int = 0,
|
||||
num_decoding_left_chunks: int = -1,
|
||||
streaming: bool = False,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Embed positions in tensor.
|
||||
|
||||
Args:
|
||||
xs: padded input tensor (B, T, D)
|
||||
xs_lens: input length (B)
|
||||
decoding_chunk_size: decoding chunk size for dynamic chunk
|
||||
0: default for training, use random dynamic chunk.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
||||
the chunk size is decoding_chunk_size.
|
||||
>=0: use num_decoding_left_chunks
|
||||
<0: use all left chunks
|
||||
Returns:
|
||||
encoder output tensor xs, and subsampled masks
|
||||
xs: padded output tensor (B, T' ~= T/subsample_rate, D)
|
||||
masks: torch.Tensor batch padding mask after subsample
|
||||
(B, 1, T' ~= T/subsample_rate)
|
||||
NOTE(xcsong):
|
||||
We pass the `__call__` method of the modules instead of `forward` to the
|
||||
checkpointing API because `__call__` attaches all the hooks of the module.
|
||||
https://discuss.pytorch.org/t/any-different-between-model-input-and-model-forward-input/3690/2
|
||||
"""
|
||||
T = xs.size(1)
|
||||
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||
if self.global_cmvn is not None:
|
||||
xs = self.global_cmvn(xs)
|
||||
xs, pos_emb, masks = self.embed(xs, masks)
|
||||
if context.size(1) != 0:
|
||||
assert self.training is False, 'you have passed context, make sure that you are running inference mode'
|
||||
context_masks = torch.ones(1, 1, context.size(1)).to(masks)
|
||||
context, _, _ = self.embed(context, context_masks, offset=xs.size(1))
|
||||
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size if streaming is True else 0, -1)
|
||||
# lookahead + conformer encoder
|
||||
xs = self.pre_lookahead_layer(xs, context=context)
|
||||
xs = self.forward_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||
|
||||
# upsample + conformer encoder
|
||||
xs = xs.transpose(1, 2).contiguous()
|
||||
xs, xs_lens = self.up_layer(xs, xs_lens)
|
||||
xs = xs.transpose(1, 2).contiguous()
|
||||
T = xs.size(1)
|
||||
masks = ~make_pad_mask(xs_lens, T).unsqueeze(1) # (B, 1, T)
|
||||
xs, pos_emb, masks = self.up_embed(xs, masks)
|
||||
mask_pad = masks # (B, 1, T/subsample_rate)
|
||||
chunk_masks = add_optional_chunk_mask(xs, masks, False, False, 0, self.static_chunk_size * self.up_layer.stride if streaming is True else 0, -1)
|
||||
xs = self.forward_up_layers(xs, chunk_masks, pos_emb, mask_pad)
|
||||
|
||||
if self.normalize_before:
|
||||
xs = self.after_norm(xs)
|
||||
# Here we assume the mask is not changed in encoder layers, so just
|
||||
# return the masks before encoder layers, and the masks will be used
|
||||
# for cross attention with decoder later
|
||||
return xs, masks
|
||||
|
||||
def forward_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask_pad: torch.Tensor) -> torch.Tensor:
|
||||
for layer in self.encoders:
|
||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||
return xs
|
||||
|
||||
def forward_up_layers(self, xs: torch.Tensor, chunk_masks: torch.Tensor,
|
||||
pos_emb: torch.Tensor,
|
||||
mask_pad: torch.Tensor) -> torch.Tensor:
|
||||
for layer in self.up_encoders:
|
||||
xs, chunk_masks, _, _ = layer(xs, chunk_masks, pos_emb, mask_pad)
|
||||
return xs
|
||||
0
models/CosyVoice/cosyvoice/utils/__init__.py
Normal file
0
models/CosyVoice/cosyvoice/utils/__init__.py
Normal file
85
models/CosyVoice/cosyvoice/utils/class_utils.py
Normal file
85
models/CosyVoice/cosyvoice/utils/class_utils.py
Normal file
@@ -0,0 +1,85 @@
|
||||
# Copyright [2023-11-28] <sxc19@mails.tsinghua.edu.cn, Xingchen Song>
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import torch
|
||||
|
||||
from cosyvoice.transformer.activation import Swish
|
||||
from cosyvoice.transformer.subsampling import (
|
||||
LinearNoSubsampling,
|
||||
EmbedinigNoSubsampling,
|
||||
Conv1dSubsampling2,
|
||||
Conv2dSubsampling4,
|
||||
Conv2dSubsampling6,
|
||||
Conv2dSubsampling8,
|
||||
)
|
||||
from cosyvoice.transformer.embedding import (PositionalEncoding,
|
||||
RelPositionalEncoding,
|
||||
WhisperPositionalEncoding,
|
||||
LearnablePositionalEncoding,
|
||||
NoPositionalEncoding)
|
||||
from cosyvoice.transformer.attention import (MultiHeadedAttention,
|
||||
RelPositionMultiHeadedAttention)
|
||||
from cosyvoice.transformer.embedding import EspnetRelPositionalEncoding
|
||||
from cosyvoice.transformer.subsampling import LegacyLinearNoSubsampling
|
||||
from cosyvoice.llm.llm import TransformerLM, Qwen2LM, CosyVoice3LM
|
||||
from cosyvoice.flow.flow import MaskedDiffWithXvec, CausalMaskedDiffWithXvec, CausalMaskedDiffWithDiT
|
||||
from cosyvoice.hifigan.generator import HiFTGenerator, CausalHiFTGenerator
|
||||
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
|
||||
|
||||
|
||||
COSYVOICE_ACTIVATION_CLASSES = {
|
||||
"hardtanh": torch.nn.Hardtanh,
|
||||
"tanh": torch.nn.Tanh,
|
||||
"relu": torch.nn.ReLU,
|
||||
"selu": torch.nn.SELU,
|
||||
"swish": getattr(torch.nn, "SiLU", Swish),
|
||||
"gelu": torch.nn.GELU,
|
||||
}
|
||||
|
||||
COSYVOICE_SUBSAMPLE_CLASSES = {
|
||||
"linear": LinearNoSubsampling,
|
||||
"linear_legacy": LegacyLinearNoSubsampling,
|
||||
"embed": EmbedinigNoSubsampling,
|
||||
"conv1d2": Conv1dSubsampling2,
|
||||
"conv2d": Conv2dSubsampling4,
|
||||
"conv2d6": Conv2dSubsampling6,
|
||||
"conv2d8": Conv2dSubsampling8,
|
||||
'paraformer_dummy': torch.nn.Identity
|
||||
}
|
||||
|
||||
COSYVOICE_EMB_CLASSES = {
|
||||
"embed": PositionalEncoding,
|
||||
"abs_pos": PositionalEncoding,
|
||||
"rel_pos": RelPositionalEncoding,
|
||||
"rel_pos_espnet": EspnetRelPositionalEncoding,
|
||||
"no_pos": NoPositionalEncoding,
|
||||
"abs_pos_whisper": WhisperPositionalEncoding,
|
||||
"embed_learnable_pe": LearnablePositionalEncoding,
|
||||
}
|
||||
|
||||
COSYVOICE_ATTENTION_CLASSES = {
|
||||
"selfattn": MultiHeadedAttention,
|
||||
"rel_selfattn": RelPositionMultiHeadedAttention,
|
||||
}
|
||||
|
||||
|
||||
def get_model_type(configs):
|
||||
# NOTE CosyVoice2Model inherits CosyVoiceModel
|
||||
if isinstance(configs['llm'], TransformerLM) and isinstance(configs['flow'], MaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||
return CosyVoiceModel
|
||||
if isinstance(configs['llm'], Qwen2LM) and isinstance(configs['flow'], CausalMaskedDiffWithXvec) and isinstance(configs['hift'], HiFTGenerator):
|
||||
return CosyVoice2Model
|
||||
if isinstance(configs['llm'], CosyVoice3LM) and isinstance(configs['flow'], CausalMaskedDiffWithDiT) and isinstance(configs['hift'], CausalHiFTGenerator):
|
||||
return CosyVoice3Model
|
||||
raise TypeError('No valid model type found!')
|
||||
214
models/CosyVoice/cosyvoice/utils/common.py
Normal file
214
models/CosyVoice/cosyvoice/utils/common.py
Normal file
@@ -0,0 +1,214 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
"""Unility functions for Transformer."""
|
||||
|
||||
import queue
|
||||
import random
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
IGNORE_ID = -1
|
||||
|
||||
instruct_list = ["You are a helpful assistant. 请用广东话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用东北话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用甘肃话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用贵州话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用河南话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用湖北话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用湖南话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用江西话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用闽南话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用宁夏话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用山西话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用陕西话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用山东话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用上海话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用四川话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用天津话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用云南话表达。<|endofprompt|>",
|
||||
"You are a helpful assistant. Please say a sentence as loudly as possible.<|endofprompt|>",
|
||||
"You are a helpful assistant. Please say a sentence in a very soft voice.<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用尽可能慢地语速说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请非常开心地说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请非常伤心地说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 请非常生气地说一句话。<|endofprompt|>",
|
||||
"You are a helpful assistant. 我想体验一下小猪佩奇风格,可以吗?<|endofprompt|>",
|
||||
"You are a helpful assistant. 你可以尝试用机器人的方式解答吗?<|endofprompt|>"]
|
||||
|
||||
|
||||
def pad_list(xs: List[torch.Tensor], pad_value: int):
|
||||
"""Perform padding for the list of tensors.
|
||||
|
||||
Args:
|
||||
xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
|
||||
pad_value (float): Value for padding.
|
||||
|
||||
Returns:
|
||||
Tensor: Padded tensor (B, Tmax, `*`).
|
||||
|
||||
Examples:
|
||||
>>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
|
||||
>>> x
|
||||
[tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
|
||||
>>> pad_list(x, 0)
|
||||
tensor([[1., 1., 1., 1.],
|
||||
[1., 1., 0., 0.],
|
||||
[1., 0., 0., 0.]])
|
||||
|
||||
"""
|
||||
max_len = max([len(item) for item in xs])
|
||||
batchs = len(xs)
|
||||
ndim = xs[0].ndim
|
||||
if ndim == 1:
|
||||
pad_res = torch.zeros(batchs,
|
||||
max_len,
|
||||
dtype=xs[0].dtype,
|
||||
device=xs[0].device)
|
||||
elif ndim == 2:
|
||||
pad_res = torch.zeros(batchs,
|
||||
max_len,
|
||||
xs[0].shape[1],
|
||||
dtype=xs[0].dtype,
|
||||
device=xs[0].device)
|
||||
elif ndim == 3:
|
||||
pad_res = torch.zeros(batchs,
|
||||
max_len,
|
||||
xs[0].shape[1],
|
||||
xs[0].shape[2],
|
||||
dtype=xs[0].dtype,
|
||||
device=xs[0].device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported ndim: {ndim}")
|
||||
pad_res.fill_(pad_value)
|
||||
for i in range(batchs):
|
||||
pad_res[i, :len(xs[i])] = xs[i]
|
||||
return pad_res
|
||||
|
||||
|
||||
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
|
||||
ignore_label: int) -> torch.Tensor:
|
||||
"""Calculate accuracy.
|
||||
|
||||
Args:
|
||||
pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
|
||||
pad_targets (LongTensor): Target label tensors (B, Lmax).
|
||||
ignore_label (int): Ignore label id.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Accuracy value (0.0 - 1.0).
|
||||
|
||||
"""
|
||||
pad_pred = pad_outputs.view(pad_targets.size(0), pad_targets.size(1),
|
||||
pad_outputs.size(1)).argmax(2)
|
||||
mask = pad_targets != ignore_label
|
||||
numerator = torch.sum(
|
||||
pad_pred.masked_select(mask) == pad_targets.masked_select(mask))
|
||||
denominator = torch.sum(mask)
|
||||
return (numerator / denominator).detach()
|
||||
|
||||
|
||||
def get_padding(kernel_size, dilation=1):
|
||||
return int((kernel_size * dilation - dilation) / 2)
|
||||
|
||||
|
||||
def init_weights(m, mean=0.0, std=0.01):
|
||||
classname = m.__class__.__name__
|
||||
if classname.find("Conv") != -1:
|
||||
m.weight.data.normal_(mean, std)
|
||||
|
||||
|
||||
# Repetition Aware Sampling in VALL-E 2
|
||||
def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
|
||||
top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
|
||||
rep_num = (torch.tensor(decoded_tokens[-win_size:]).to(weighted_scores.device) == top_ids).sum().item()
|
||||
if rep_num >= win_size * tau_r:
|
||||
weighted_scores[top_ids] = -float('inf')
|
||||
top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
|
||||
return top_ids
|
||||
|
||||
|
||||
def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
|
||||
prob, indices = [], []
|
||||
cum_prob = 0.0
|
||||
sorted_value, sorted_idx = weighted_scores.softmax(dim=0).sort(descending=True, stable=True)
|
||||
for i in range(len(sorted_idx)):
|
||||
# sampling both top-p and numbers.
|
||||
if cum_prob < top_p and len(prob) < top_k:
|
||||
cum_prob += sorted_value[i]
|
||||
prob.append(sorted_value[i])
|
||||
indices.append(sorted_idx[i])
|
||||
else:
|
||||
break
|
||||
prob = torch.tensor(prob).to(weighted_scores)
|
||||
indices = torch.tensor(indices, dtype=torch.long).to(weighted_scores.device)
|
||||
top_ids = indices[prob.multinomial(1, replacement=True)].item()
|
||||
return top_ids
|
||||
|
||||
|
||||
def random_sampling(weighted_scores, decoded_tokens, sampling):
|
||||
top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True).item()
|
||||
return top_ids
|
||||
|
||||
|
||||
def fade_in_out(fade_in_mel, fade_out_mel, window):
|
||||
device = fade_in_mel.device
|
||||
fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
|
||||
mel_overlap_len = int(window.shape[0] / 2)
|
||||
if fade_in_mel.device == torch.device('cpu'):
|
||||
fade_in_mel = fade_in_mel.clone()
|
||||
fade_in_mel[..., :mel_overlap_len] = fade_in_mel[..., :mel_overlap_len] * window[:mel_overlap_len] + \
|
||||
fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
|
||||
return fade_in_mel.to(device)
|
||||
|
||||
|
||||
def set_all_random_seed(seed):
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
|
||||
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
|
||||
assert mask.dtype == torch.bool
|
||||
assert dtype in [torch.float32, torch.bfloat16, torch.float16]
|
||||
mask = mask.to(dtype)
|
||||
# attention mask bias
|
||||
# NOTE(Mddct): torch.finfo jit issues
|
||||
# chunk_masks = (1.0 - chunk_masks) * torch.finfo(dtype).min
|
||||
mask = (1.0 - mask) * -1.0e+10
|
||||
return mask
|
||||
|
||||
|
||||
class TrtContextWrapper:
|
||||
def __init__(self, trt_engine, trt_concurrent=1, device='cuda:0'):
|
||||
self.trt_context_pool = queue.Queue(maxsize=trt_concurrent)
|
||||
self.trt_engine = trt_engine
|
||||
for _ in range(trt_concurrent):
|
||||
trt_context = trt_engine.create_execution_context()
|
||||
trt_stream = torch.cuda.stream(torch.cuda.Stream(device))
|
||||
assert trt_context is not None, 'failed to create trt context, maybe not enough CUDA memory, try reduce current trt concurrent {}'.format(trt_concurrent)
|
||||
self.trt_context_pool.put([trt_context, trt_stream])
|
||||
assert self.trt_context_pool.empty() is False, 'no avaialbe estimator context'
|
||||
|
||||
def acquire_estimator(self):
|
||||
return self.trt_context_pool.get(), self.trt_engine
|
||||
|
||||
def release_estimator(self, context, stream):
|
||||
self.trt_context_pool.put([context, stream])
|
||||
176
models/CosyVoice/cosyvoice/utils/executor.py
Normal file
176
models/CosyVoice/cosyvoice/utils/executor.py
Normal file
@@ -0,0 +1,176 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from contextlib import nullcontext
|
||||
import os
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
|
||||
from cosyvoice.utils.train_utils import update_parameter_and_lr, log_per_step, log_per_save, batch_forward, batch_backward, save_model, cosyvoice_join
|
||||
|
||||
|
||||
class Executor:
|
||||
|
||||
def __init__(self, gan: bool = False, ref_model: torch.nn.Module = None, dpo_loss: torch.nn.Module = None):
|
||||
self.gan = gan
|
||||
self.ref_model = ref_model
|
||||
self.dpo_loss = dpo_loss
|
||||
self.step = 0
|
||||
self.epoch = 0
|
||||
self.rank = int(os.environ.get('RANK', 0))
|
||||
self.device = torch.device('cuda:{}'.format(self.rank))
|
||||
|
||||
def train_one_epoc(self, model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=None):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
lr = optimizer.param_groups[0]['lr']
|
||||
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
||||
logging.info('using accumulate grad, new batch size is {} times'
|
||||
' larger than before'.format(info_dict['accum_grad']))
|
||||
# A context manager to be used in conjunction with an instance of
|
||||
# torch.nn.parallel.DistributedDataParallel to be able to train
|
||||
# with uneven inputs across participating processes.
|
||||
model.train()
|
||||
if self.ref_model is not None:
|
||||
self.ref_model.eval()
|
||||
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
||||
with model_context():
|
||||
for batch_idx, batch_dict in enumerate(train_data_loader):
|
||||
info_dict["tag"] = "TRAIN"
|
||||
info_dict["step"] = self.step
|
||||
info_dict["epoch"] = self.epoch
|
||||
info_dict["batch_idx"] = batch_idx
|
||||
if cosyvoice_join(group_join, info_dict):
|
||||
break
|
||||
|
||||
# Disable gradient synchronizations across DDP processes.
|
||||
# Within this context, gradients will be accumulated on module
|
||||
# variables, which will later be synchronized.
|
||||
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
||||
context = model.no_sync
|
||||
# Used for single gpu training and DDP gradient synchronization
|
||||
# processes.
|
||||
else:
|
||||
context = nullcontext
|
||||
|
||||
with context():
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict, ref_model=self.ref_model, dpo_loss=self.dpo_loss)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||
log_per_step(writer, info_dict)
|
||||
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
||||
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
||||
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
||||
model.train()
|
||||
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||
self.step += 1
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
def train_one_epoc_gan(self, model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
|
||||
writer, info_dict, scaler, group_join):
|
||||
''' Train one epoch
|
||||
'''
|
||||
|
||||
lr = optimizer.param_groups[0]['lr']
|
||||
logging.info('Epoch {} TRAIN info lr {} rank {}'.format(self.epoch, lr, self.rank))
|
||||
logging.info('using accumulate grad, new batch size is {} times'
|
||||
' larger than before'.format(info_dict['accum_grad']))
|
||||
# A context manager to be used in conjunction with an instance of
|
||||
# torch.nn.parallel.DistributedDataParallel to be able to train
|
||||
# with uneven inputs across participating processes.
|
||||
model.train()
|
||||
model_context = model.join if info_dict['train_engine'] == 'torch_ddp' else nullcontext
|
||||
with model_context():
|
||||
for batch_idx, batch_dict in enumerate(train_data_loader):
|
||||
info_dict["tag"] = "TRAIN"
|
||||
info_dict["step"] = self.step
|
||||
info_dict["epoch"] = self.epoch
|
||||
info_dict["batch_idx"] = batch_idx
|
||||
if cosyvoice_join(group_join, info_dict):
|
||||
break
|
||||
|
||||
# Disable gradient synchronizations across DDP processes.
|
||||
# Within this context, gradients will be accumulated on module
|
||||
# variables, which will later be synchronized.
|
||||
if info_dict['train_engine'] == 'torch_ddp' and (batch_idx + 1) % info_dict["accum_grad"] != 0:
|
||||
context = model.no_sync
|
||||
# Used for single gpu training and DDP gradient synchronization
|
||||
# processes.
|
||||
else:
|
||||
context = nullcontext
|
||||
|
||||
with context():
|
||||
batch_dict['turn'] = 'discriminator'
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer_d, scheduler_d, scaler, info_dict)
|
||||
optimizer.zero_grad()
|
||||
log_per_step(writer, info_dict)
|
||||
with context():
|
||||
batch_dict['turn'] = 'generator'
|
||||
info_dict = batch_forward(model, batch_dict, scaler, info_dict)
|
||||
info_dict = batch_backward(model, scaler, info_dict)
|
||||
info_dict = update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict)
|
||||
optimizer_d.zero_grad()
|
||||
log_per_step(writer, info_dict)
|
||||
# NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
|
||||
if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
|
||||
(batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
|
||||
model.train()
|
||||
if (batch_idx + 1) % info_dict["accum_grad"] == 0:
|
||||
self.step += 1
|
||||
dist.barrier()
|
||||
self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=True)
|
||||
|
||||
@torch.inference_mode()
|
||||
def cv(self, model, cv_data_loader, writer, info_dict, on_batch_end=True):
|
||||
''' Cross validation on
|
||||
'''
|
||||
logging.info('Epoch {} Step {} on_batch_end {} CV rank {}'.format(self.epoch, self.step + 1, on_batch_end, self.rank))
|
||||
model.eval()
|
||||
total_num_utts, total_loss_dict = 0, {} # avoid division by 0
|
||||
for batch_idx, batch_dict in enumerate(cv_data_loader):
|
||||
info_dict["tag"] = "CV"
|
||||
info_dict["step"] = self.step
|
||||
info_dict["epoch"] = self.epoch
|
||||
info_dict["batch_idx"] = batch_idx
|
||||
|
||||
num_utts = len(batch_dict["utts"])
|
||||
total_num_utts += num_utts
|
||||
|
||||
if self.gan is True:
|
||||
batch_dict['turn'] = 'generator'
|
||||
info_dict = batch_forward(model, batch_dict, None, info_dict)
|
||||
|
||||
for k, v in info_dict['loss_dict'].items():
|
||||
if k not in total_loss_dict:
|
||||
total_loss_dict[k] = []
|
||||
total_loss_dict[k].append(v.mean().item() * num_utts)
|
||||
log_per_step(None, info_dict)
|
||||
for k, v in total_loss_dict.items():
|
||||
total_loss_dict[k] = sum(v) / total_num_utts
|
||||
info_dict['loss_dict'] = total_loss_dict
|
||||
log_per_save(writer, info_dict)
|
||||
model_name = 'epoch_{}_whole'.format(self.epoch) if on_batch_end else 'epoch_{}_step_{}'.format(self.epoch, self.step + 1)
|
||||
save_model(model, model_name, info_dict)
|
||||
122
models/CosyVoice/cosyvoice/utils/file_utils.py
Normal file
122
models/CosyVoice/cosyvoice/utils/file_utils.py
Normal file
@@ -0,0 +1,122 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu, Zetao Hu)
|
||||
# 2025 Alibaba Inc (authors: Xiang Lyu, Yabin Li)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
import json
|
||||
import numpy as np
|
||||
import torch
|
||||
import torchaudio
|
||||
import soundfile as sf
|
||||
import logging
|
||||
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
||||
logging.basicConfig(level=logging.INFO,
|
||||
format='%(asctime)s %(levelname)s %(message)s')
|
||||
|
||||
|
||||
def read_lists(list_file):
|
||||
lists = []
|
||||
with open(list_file, 'r', encoding='utf8') as fin:
|
||||
for line in fin:
|
||||
lists.append(line.strip())
|
||||
return lists
|
||||
|
||||
|
||||
def read_json_lists(list_file):
|
||||
lists = read_lists(list_file)
|
||||
results = {}
|
||||
for fn in lists:
|
||||
with open(fn, 'r', encoding='utf8') as fin:
|
||||
results.update(json.load(fin))
|
||||
return results
|
||||
|
||||
|
||||
def load_wav(wav, target_sr, min_sr=16000):
|
||||
speech_np, sample_rate = sf.read(wav, dtype='float32', always_2d=True)
|
||||
# soundfile: [frames, channels] -> torch: [channels, frames]
|
||||
speech = torch.from_numpy(np.ascontiguousarray(speech_np.T))
|
||||
speech = speech.mean(dim=0, keepdim=True)
|
||||
if sample_rate != target_sr:
|
||||
assert sample_rate >= min_sr, 'wav sample rate {} must be greater than {}'.format(sample_rate, target_sr)
|
||||
speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
|
||||
return speech
|
||||
|
||||
|
||||
def convert_onnx_to_trt(trt_model, trt_kwargs, onnx_model, fp16):
|
||||
import tensorrt as trt
|
||||
logging.info("Converting onnx to trt...")
|
||||
network_flags = 1 << int(trt.NetworkDefinitionCreationFlag.EXPLICIT_BATCH)
|
||||
logger = trt.Logger(trt.Logger.INFO)
|
||||
builder = trt.Builder(logger)
|
||||
network = builder.create_network(network_flags)
|
||||
parser = trt.OnnxParser(network, logger)
|
||||
config = builder.create_builder_config()
|
||||
config.set_memory_pool_limit(trt.MemoryPoolType.WORKSPACE, 1 << 32) # 4GB
|
||||
if fp16:
|
||||
config.set_flag(trt.BuilderFlag.FP16)
|
||||
profile = builder.create_optimization_profile()
|
||||
# load onnx model
|
||||
with open(onnx_model, "rb") as f:
|
||||
if not parser.parse(f.read()):
|
||||
for error in range(parser.num_errors):
|
||||
print(parser.get_error(error))
|
||||
raise ValueError('failed to parse {}'.format(onnx_model))
|
||||
# set input shapes
|
||||
for i in range(len(trt_kwargs['input_names'])):
|
||||
profile.set_shape(trt_kwargs['input_names'][i], trt_kwargs['min_shape'][i], trt_kwargs['opt_shape'][i], trt_kwargs['max_shape'][i])
|
||||
tensor_dtype = trt.DataType.HALF if fp16 else trt.DataType.FLOAT
|
||||
# set input and output data type
|
||||
for i in range(network.num_inputs):
|
||||
input_tensor = network.get_input(i)
|
||||
input_tensor.dtype = tensor_dtype
|
||||
for i in range(network.num_outputs):
|
||||
output_tensor = network.get_output(i)
|
||||
output_tensor.dtype = tensor_dtype
|
||||
config.add_optimization_profile(profile)
|
||||
engine_bytes = builder.build_serialized_network(network, config)
|
||||
# save trt engine
|
||||
with open(trt_model, "wb") as f:
|
||||
f.write(engine_bytes)
|
||||
logging.info("Succesfully convert onnx to trt...")
|
||||
|
||||
|
||||
# NOTE do not support bistream inference as only speech token embedding/head is kept
|
||||
def export_cosyvoice2_vllm(model, model_path, device):
|
||||
if os.path.exists(model_path):
|
||||
return
|
||||
|
||||
dtype = torch.bfloat16
|
||||
# lm_head
|
||||
use_bias = True if model.llm_decoder.bias is not None else False
|
||||
model.llm.model.lm_head = model.llm_decoder
|
||||
# embed_tokens
|
||||
embed_tokens = model.llm.model.model.embed_tokens
|
||||
model.llm.model.set_input_embeddings(model.speech_embedding)
|
||||
model.llm.model.to(device)
|
||||
model.llm.model.to(dtype)
|
||||
tmp_vocab_size = model.llm.model.config.vocab_size
|
||||
tmp_tie_embedding = model.llm.model.config.tie_word_embeddings
|
||||
del model.llm.model.generation_config.eos_token_id
|
||||
del model.llm.model.config.bos_token_id
|
||||
del model.llm.model.config.eos_token_id
|
||||
model.llm.model.config.vocab_size = model.speech_embedding.num_embeddings
|
||||
model.llm.model.config.tie_word_embeddings = False
|
||||
model.llm.model.config.use_bias = use_bias
|
||||
model.llm.model.save_pretrained(model_path)
|
||||
if use_bias is True:
|
||||
os.system('sed -i s@Qwen2ForCausalLM@CosyVoice2ForCausalLM@g {}/config.json'.format(os.path.abspath(model_path)))
|
||||
model.llm.model.config.vocab_size = tmp_vocab_size
|
||||
model.llm.model.config.tie_word_embeddings = tmp_tie_embedding
|
||||
model.llm.model.set_input_embeddings(embed_tokens)
|
||||
136
models/CosyVoice/cosyvoice/utils/frontend_utils.py
Normal file
136
models/CosyVoice/cosyvoice/utils/frontend_utils.py
Normal file
@@ -0,0 +1,136 @@
|
||||
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import re
|
||||
import regex
|
||||
chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
|
||||
|
||||
|
||||
# whether contain chinese character
|
||||
def contains_chinese(text):
|
||||
return bool(chinese_char_pattern.search(text))
|
||||
|
||||
|
||||
# replace special symbol
|
||||
def replace_corner_mark(text):
|
||||
text = text.replace('²', '平方')
|
||||
text = text.replace('³', '立方')
|
||||
return text
|
||||
|
||||
|
||||
# remove meaningless symbol
|
||||
def remove_bracket(text):
|
||||
text = text.replace('(', '').replace(')', '')
|
||||
text = text.replace('【', '').replace('】', '')
|
||||
text = text.replace('`', '').replace('`', '')
|
||||
text = text.replace("——", " ")
|
||||
return text
|
||||
|
||||
|
||||
# spell Arabic numerals
|
||||
def spell_out_number(text: str, inflect_parser):
|
||||
new_text = []
|
||||
st = None
|
||||
for i, c in enumerate(text):
|
||||
if not c.isdigit():
|
||||
if st is not None:
|
||||
num_str = inflect_parser.number_to_words(text[st: i])
|
||||
new_text.append(num_str)
|
||||
st = None
|
||||
new_text.append(c)
|
||||
else:
|
||||
if st is None:
|
||||
st = i
|
||||
if st is not None and st < len(text):
|
||||
num_str = inflect_parser.number_to_words(text[st:])
|
||||
new_text.append(num_str)
|
||||
return ''.join(new_text)
|
||||
|
||||
|
||||
# split paragrah logic:
|
||||
# 1. per sentence max len token_max_n, min len token_min_n, merge if last sentence len less than merge_len
|
||||
# 2. cal sentence len according to lang
|
||||
# 3. split sentence according to puncatation
|
||||
def split_paragraph(text: str, tokenize, lang="zh", token_max_n=80, token_min_n=60, merge_len=20, comma_split=False):
|
||||
def calc_utt_length(_text: str):
|
||||
if lang == "zh":
|
||||
return len(_text)
|
||||
else:
|
||||
return len(tokenize(_text))
|
||||
|
||||
def should_merge(_text: str):
|
||||
if lang == "zh":
|
||||
return len(_text) < merge_len
|
||||
else:
|
||||
return len(tokenize(_text)) < merge_len
|
||||
|
||||
if lang == "zh":
|
||||
pounc = ['。', '?', '!', ';', ':', '、', '.', '?', '!', ';']
|
||||
else:
|
||||
pounc = ['.', '?', '!', ';', ':']
|
||||
if comma_split:
|
||||
pounc.extend([',', ','])
|
||||
|
||||
if text[-1] not in pounc:
|
||||
if lang == "zh":
|
||||
text += "。"
|
||||
else:
|
||||
text += "."
|
||||
|
||||
st = 0
|
||||
utts = []
|
||||
for i, c in enumerate(text):
|
||||
if c in pounc:
|
||||
if len(text[st: i]) > 0:
|
||||
utts.append(text[st: i] + c)
|
||||
if i + 1 < len(text) and text[i + 1] in ['"', '”']:
|
||||
tmp = utts.pop(-1)
|
||||
utts.append(tmp + text[i + 1])
|
||||
st = i + 2
|
||||
else:
|
||||
st = i + 1
|
||||
|
||||
final_utts = []
|
||||
cur_utt = ""
|
||||
for utt in utts:
|
||||
if calc_utt_length(cur_utt + utt) > token_max_n and calc_utt_length(cur_utt) > token_min_n:
|
||||
final_utts.append(cur_utt)
|
||||
cur_utt = ""
|
||||
cur_utt = cur_utt + utt
|
||||
if len(cur_utt) > 0:
|
||||
if should_merge(cur_utt) and len(final_utts) != 0:
|
||||
final_utts[-1] = final_utts[-1] + cur_utt
|
||||
else:
|
||||
final_utts.append(cur_utt)
|
||||
|
||||
return final_utts
|
||||
|
||||
|
||||
# remove blank between chinese character
|
||||
def replace_blank(text: str):
|
||||
out_str = []
|
||||
for i, c in enumerate(text):
|
||||
if c == " ":
|
||||
if ((text[i + 1].isascii() and text[i + 1] != " ") and
|
||||
(text[i - 1].isascii() and text[i - 1] != " ")):
|
||||
out_str.append(c)
|
||||
else:
|
||||
out_str.append(c)
|
||||
return "".join(out_str)
|
||||
|
||||
|
||||
def is_only_punctuation(text):
|
||||
# Regular expression: Match strings that consist only of punctuation marks or are empty.
|
||||
punctuation_pattern = r'^[\p{P}\p{S}]*$'
|
||||
return bool(regex.fullmatch(punctuation_pattern, text))
|
||||
57
models/CosyVoice/cosyvoice/utils/losses.py
Normal file
57
models/CosyVoice/cosyvoice/utils/losses.py
Normal file
@@ -0,0 +1,57 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
def tpr_loss(disc_real_outputs, disc_generated_outputs, tau):
|
||||
loss = 0
|
||||
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
||||
m_DG = torch.median((dr - dg))
|
||||
L_rel = torch.mean((((dr - dg) - m_DG) ** 2)[dr < dg + m_DG])
|
||||
loss += tau - F.relu(tau - L_rel)
|
||||
return loss
|
||||
|
||||
|
||||
def mel_loss(real_speech, generated_speech, mel_transforms):
|
||||
loss = 0
|
||||
for transform in mel_transforms:
|
||||
mel_r = transform(real_speech)
|
||||
mel_g = transform(generated_speech)
|
||||
loss += F.l1_loss(mel_g, mel_r)
|
||||
return loss
|
||||
|
||||
|
||||
class DPOLoss(torch.nn.Module):
|
||||
"""
|
||||
DPO Loss
|
||||
"""
|
||||
|
||||
def __init__(self, beta: float, label_smoothing: float = 0.0, ipo: bool = False) -> None:
|
||||
super().__init__()
|
||||
self.beta = beta
|
||||
self.label_smoothing = label_smoothing
|
||||
self.ipo = ipo
|
||||
|
||||
def forward(
|
||||
self,
|
||||
policy_chosen_logps: torch.Tensor,
|
||||
policy_rejected_logps: torch.Tensor,
|
||||
reference_chosen_logps: torch.Tensor,
|
||||
reference_rejected_logps: torch.Tensor,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
pi_logratios = policy_chosen_logps - policy_rejected_logps
|
||||
ref_logratios = reference_chosen_logps - reference_rejected_logps
|
||||
logits = pi_logratios - ref_logratios
|
||||
if self.ipo:
|
||||
losses = (logits - 1 / (2 * self.beta)) ** 2 # Eq. 17 of https://arxiv.org/pdf/2310.12036v2.pdf
|
||||
else:
|
||||
# Eq. 3 https://ericmitchell.ai/cdpo.pdf; label_smoothing=0 gives original DPO (Eq. 7 of https://arxiv.org/pdf/2305.18290.pdf)
|
||||
losses = (
|
||||
-F.logsigmoid(self.beta * logits) * (1 - self.label_smoothing)
|
||||
- F.logsigmoid(-self.beta * logits) * self.label_smoothing
|
||||
)
|
||||
loss = losses.mean()
|
||||
chosen_rewards = self.beta * (policy_chosen_logps - reference_chosen_logps).detach()
|
||||
rejected_rewards = self.beta * (policy_rejected_logps - reference_rejected_logps).detach()
|
||||
|
||||
return loss, chosen_rewards, rejected_rewards
|
||||
265
models/CosyVoice/cosyvoice/utils/mask.py
Normal file
265
models/CosyVoice/cosyvoice/utils/mask.py
Normal file
@@ -0,0 +1,265 @@
|
||||
# Copyright (c) 2019 Shigeki Karita
|
||||
# 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
'''
|
||||
def subsequent_mask(
|
||||
size: int,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size).
|
||||
|
||||
This mask is used only in decoder which works in an auto-regressive mode.
|
||||
This means the current step could only do attention with its left steps.
|
||||
|
||||
In encoder, fully attention is used when streaming is not necessary and
|
||||
the sequence is not long. In this case, no attention mask is needed.
|
||||
|
||||
When streaming is need, chunk-based attention is used in encoder. See
|
||||
subsequent_chunk_mask for the chunk-based attention mask.
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
||||
dtype (torch.device): result dtype
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_mask(3)
|
||||
[[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1]]
|
||||
"""
|
||||
ret = torch.ones(size, size, device=device, dtype=torch.bool)
|
||||
return torch.tril(ret)
|
||||
'''
|
||||
|
||||
|
||||
def subsequent_mask(
|
||||
size: int,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size).
|
||||
|
||||
This mask is used only in decoder which works in an auto-regressive mode.
|
||||
This means the current step could only do attention with its left steps.
|
||||
|
||||
In encoder, fully attention is used when streaming is not necessary and
|
||||
the sequence is not long. In this case, no attention mask is needed.
|
||||
|
||||
When streaming is need, chunk-based attention is used in encoder. See
|
||||
subsequent_chunk_mask for the chunk-based attention mask.
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
str device (str): "cpu" or "cuda" or torch.Tensor.device
|
||||
dtype (torch.device): result dtype
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_mask(3)
|
||||
[[1, 0, 0],
|
||||
[1, 1, 0],
|
||||
[1, 1, 1]]
|
||||
"""
|
||||
arange = torch.arange(size, device=device)
|
||||
mask = arange.expand(size, size)
|
||||
arange = arange.unsqueeze(-1)
|
||||
mask = mask <= arange
|
||||
return mask
|
||||
|
||||
|
||||
def subsequent_chunk_mask_deprecated(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||
this is for streaming encoder
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
chunk_size (int): size of chunk
|
||||
num_left_chunks (int): number of left chunks
|
||||
<0: use full chunk
|
||||
>=0: use num_left_chunks
|
||||
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_chunk_mask(4, 2)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]]
|
||||
"""
|
||||
ret = torch.zeros(size, size, device=device, dtype=torch.bool)
|
||||
for i in range(size):
|
||||
if num_left_chunks < 0:
|
||||
start = 0
|
||||
else:
|
||||
start = max((i // chunk_size - num_left_chunks) * chunk_size, 0)
|
||||
ending = min((i // chunk_size + 1) * chunk_size, size)
|
||||
ret[i, start:ending] = True
|
||||
return ret
|
||||
|
||||
|
||||
def subsequent_chunk_mask(
|
||||
size: int,
|
||||
chunk_size: int,
|
||||
num_left_chunks: int = -1,
|
||||
device: torch.device = torch.device("cpu"),
|
||||
) -> torch.Tensor:
|
||||
"""Create mask for subsequent steps (size, size) with chunk size,
|
||||
this is for streaming encoder
|
||||
|
||||
Args:
|
||||
size (int): size of mask
|
||||
chunk_size (int): size of chunk
|
||||
num_left_chunks (int): number of left chunks
|
||||
<0: use full chunk
|
||||
>=0: use num_left_chunks
|
||||
device (torch.device): "cpu" or "cuda" or torch.Tensor.device
|
||||
|
||||
Returns:
|
||||
torch.Tensor: mask
|
||||
|
||||
Examples:
|
||||
>>> subsequent_chunk_mask(4, 2)
|
||||
[[1, 1, 0, 0],
|
||||
[1, 1, 0, 0],
|
||||
[1, 1, 1, 1],
|
||||
[1, 1, 1, 1]]
|
||||
"""
|
||||
# NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
|
||||
pos_idx = torch.arange(size, device=device)
|
||||
block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
|
||||
ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
|
||||
return ret
|
||||
|
||||
|
||||
def add_optional_chunk_mask(xs: torch.Tensor,
|
||||
masks: torch.Tensor,
|
||||
use_dynamic_chunk: bool,
|
||||
use_dynamic_left_chunk: bool,
|
||||
decoding_chunk_size: int,
|
||||
static_chunk_size: int,
|
||||
num_decoding_left_chunks: int,
|
||||
enable_full_context: bool = True):
|
||||
""" Apply optional mask for encoder.
|
||||
|
||||
Args:
|
||||
xs (torch.Tensor): padded input, (B, L, D), L for max length
|
||||
mask (torch.Tensor): mask for xs, (B, 1, L)
|
||||
use_dynamic_chunk (bool): whether to use dynamic chunk or not
|
||||
use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
|
||||
training.
|
||||
decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
|
||||
0: default for training, use random dynamic chunk.
|
||||
<0: for decoding, use full chunk.
|
||||
>0: for decoding, use fixed chunk size as set.
|
||||
static_chunk_size (int): chunk size for static chunk training/decoding
|
||||
if it's greater than 0, if use_dynamic_chunk is true,
|
||||
this parameter will be ignored
|
||||
num_decoding_left_chunks: number of left chunks, this is for decoding,
|
||||
the chunk size is decoding_chunk_size.
|
||||
>=0: use num_decoding_left_chunks
|
||||
<0: use all left chunks
|
||||
enable_full_context (bool):
|
||||
True: chunk size is either [1, 25] or full context(max_len)
|
||||
False: chunk size ~ U[1, 25]
|
||||
|
||||
Returns:
|
||||
torch.Tensor: chunk mask of the input xs.
|
||||
"""
|
||||
# Whether to use chunk mask or not
|
||||
if use_dynamic_chunk:
|
||||
max_len = xs.size(1)
|
||||
if decoding_chunk_size < 0:
|
||||
chunk_size = max_len
|
||||
num_left_chunks = -1
|
||||
elif decoding_chunk_size > 0:
|
||||
chunk_size = decoding_chunk_size
|
||||
num_left_chunks = num_decoding_left_chunks
|
||||
else:
|
||||
# chunk size is either [1, 25] or full context(max_len).
|
||||
# Since we use 4 times subsampling and allow up to 1s(100 frames)
|
||||
# delay, the maximum frame is 100 / 4 = 25.
|
||||
chunk_size = torch.randint(1, max_len, (1, )).item()
|
||||
num_left_chunks = -1
|
||||
if chunk_size > max_len // 2 and enable_full_context:
|
||||
chunk_size = max_len
|
||||
else:
|
||||
chunk_size = chunk_size % 25 + 1
|
||||
if use_dynamic_left_chunk:
|
||||
max_left_chunks = (max_len - 1) // chunk_size
|
||||
num_left_chunks = torch.randint(0, max_left_chunks,
|
||||
(1, )).item()
|
||||
chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
|
||||
num_left_chunks,
|
||||
xs.device) # (L, L)
|
||||
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
||||
chunk_masks = masks & chunk_masks # (B, L, L)
|
||||
elif static_chunk_size > 0:
|
||||
num_left_chunks = num_decoding_left_chunks
|
||||
chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
|
||||
num_left_chunks,
|
||||
xs.device) # (L, L)
|
||||
chunk_masks = chunk_masks.unsqueeze(0) # (1, L, L)
|
||||
chunk_masks = masks & chunk_masks # (B, L, L)
|
||||
else:
|
||||
chunk_masks = masks
|
||||
assert chunk_masks.dtype == torch.bool
|
||||
if (chunk_masks.sum(dim=-1) == 0).sum().item() != 0:
|
||||
print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
|
||||
chunk_masks[chunk_masks.sum(dim=-1) == 0] = True
|
||||
return chunk_masks
|
||||
|
||||
|
||||
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
|
||||
"""Make mask tensor containing indices of padded part.
|
||||
|
||||
See description of make_non_pad_mask.
|
||||
|
||||
Args:
|
||||
lengths (torch.Tensor): Batch of lengths (B,).
|
||||
Returns:
|
||||
torch.Tensor: Mask tensor containing indices of padded part.
|
||||
|
||||
Examples:
|
||||
>>> lengths = [5, 3, 2]
|
||||
>>> make_pad_mask(lengths)
|
||||
masks = [[0, 0, 0, 0 ,0],
|
||||
[0, 0, 0, 1, 1],
|
||||
[0, 0, 1, 1, 1]]
|
||||
"""
|
||||
batch_size = lengths.size(0)
|
||||
max_len = max_len if max_len > 0 else lengths.max().item()
|
||||
seq_range = torch.arange(0,
|
||||
max_len,
|
||||
dtype=torch.int64,
|
||||
device=lengths.device)
|
||||
seq_range_expand = seq_range.unsqueeze(0).expand(batch_size, max_len)
|
||||
seq_length_expand = lengths.unsqueeze(-1)
|
||||
mask = seq_range_expand >= seq_length_expand
|
||||
return mask
|
||||
54
models/CosyVoice/cosyvoice/utils/onnx.py
Normal file
54
models/CosyVoice/cosyvoice/utils/onnx.py
Normal file
@@ -0,0 +1,54 @@
|
||||
import onnxruntime
|
||||
import torch, random
|
||||
import os
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
|
||||
class SpeechTokenExtractor():
|
||||
def __init__(self, model_path):
|
||||
self.local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.speech_tokenizer_session = onnxruntime.InferenceSession(model_path,
|
||||
sess_options=option,
|
||||
providers=[("CUDAExecutionProvider", {'device_id': self.local_rank})])
|
||||
|
||||
def inference(self, feat, feat_lengths, device):
|
||||
speech_token = self.speech_tokenizer_session.run(None,
|
||||
{self.speech_tokenizer_session.get_inputs()[0].name:
|
||||
feat.transpose(1, 2).detach().cpu().numpy(),
|
||||
self.speech_tokenizer_session.get_inputs()[1].name:
|
||||
feat_lengths.detach().cpu().numpy()})[0]
|
||||
return torch.tensor(speech_token).to(torch.int32).to(device), (feat_lengths / 4).to(torch.int32).to(device)
|
||||
|
||||
|
||||
class EmbeddingExtractor():
|
||||
def __init__(self, model_path):
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
self.max_len = 10 * 16000
|
||||
self.campplus_session = onnxruntime.InferenceSession(model_path,
|
||||
sess_options=option,
|
||||
providers=["CPUExecutionProvider"])
|
||||
|
||||
def inference(self, speech):
|
||||
if speech.shape[1] > self.max_len:
|
||||
start_index = random.randint(0, speech.shape[1] - self.max_len)
|
||||
speech = speech[:, start_index: start_index + self.max_len]
|
||||
feat = kaldi.fbank(speech,
|
||||
num_mel_bins=80,
|
||||
dither=0,
|
||||
sample_frequency=16000)
|
||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||
embedding = self.campplus_session.run(None,
|
||||
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
|
||||
return torch.tensor(embedding).to(speech.device)
|
||||
|
||||
# singleton mode, only initialized once
|
||||
onnx_path = os.environ.get('onnx_path')
|
||||
if onnx_path is not None:
|
||||
embedding_extractor, online_feature = EmbeddingExtractor(model_path=os.path.join(onnx_path, 'campplus.onnx')), True
|
||||
else:
|
||||
embedding_extractor, online_feature = None, False
|
||||
738
models/CosyVoice/cosyvoice/utils/scheduler.py
Normal file
738
models/CosyVoice/cosyvoice/utils/scheduler.py
Normal file
@@ -0,0 +1,738 @@
|
||||
# Copyright (c) 2020 Mobvoi Inc (Binbin Zhang)
|
||||
# 2022 Ximalaya Inc (Yuguang Yang)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# Modified from ESPnet(https://github.com/espnet/espnet)
|
||||
# NeMo(https://github.com/NVIDIA/NeMo)
|
||||
|
||||
from typing import Union
|
||||
|
||||
import math
|
||||
import warnings
|
||||
import torch
|
||||
from torch.optim.lr_scheduler import _LRScheduler
|
||||
|
||||
|
||||
class WarmupLR(_LRScheduler):
|
||||
"""The WarmupLR scheduler
|
||||
|
||||
This scheduler is almost same as NoamLR Scheduler except for following
|
||||
difference:
|
||||
|
||||
NoamLR:
|
||||
lr = optimizer.lr * model_size ** -0.5
|
||||
* min(step ** -0.5, step * warmup_step ** -1.5)
|
||||
WarmupLR:
|
||||
lr = optimizer.lr * warmup_step ** 0.5
|
||||
* min(step ** -0.5, step * warmup_step ** -1.5)
|
||||
|
||||
Note that the maximum lr equals to optimizer.lr in this scheduler.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
warmup_steps: Union[int, float] = 25000,
|
||||
last_epoch: int = -1,
|
||||
):
|
||||
self.warmup_steps = warmup_steps
|
||||
|
||||
# __init__() must be invoked before setting field
|
||||
# because step() is also invoked in __init__()
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def __repr__(self):
|
||||
return f"{self.__class__.__name__}(warmup_steps={self.warmup_steps})"
|
||||
|
||||
def get_lr(self):
|
||||
step_num = self.last_epoch + 1
|
||||
if self.warmup_steps == 0:
|
||||
return [lr * step_num**-0.5 for lr in self.base_lrs]
|
||||
else:
|
||||
return [
|
||||
lr * self.warmup_steps**0.5 *
|
||||
min(step_num**-0.5, step_num * self.warmup_steps**-1.5)
|
||||
for lr in self.base_lrs
|
||||
]
|
||||
|
||||
def set_step(self, step: int):
|
||||
self.last_epoch = step
|
||||
|
||||
|
||||
class WarmupPolicy(_LRScheduler):
|
||||
"""Adds warmup kwargs and warmup logic to lr policy.
|
||||
All arguments should be passed as kwargs for clarity,
|
||||
Args:
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
warmup_steps=None,
|
||||
warmup_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1):
|
||||
assert not (warmup_steps is not None and warmup_ratio is not None),\
|
||||
"Either use particular number of step or ratio"
|
||||
assert warmup_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
# It is necessary to assign all attributes *before* __init__,
|
||||
# as class is wrapped by an inner class.
|
||||
self.max_steps = max_steps
|
||||
if warmup_steps is not None:
|
||||
self.warmup_steps = warmup_steps
|
||||
elif warmup_ratio is not None:
|
||||
self.warmup_steps = int(warmup_ratio * max_steps)
|
||||
else:
|
||||
self.warmup_steps = 0
|
||||
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed "
|
||||
"by the scheduler, please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = self.last_epoch
|
||||
|
||||
if step <= self.warmup_steps and self.warmup_steps > 0:
|
||||
return self._get_warmup_lr(step)
|
||||
|
||||
if step > self.max_steps:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
return self._get_lr(step)
|
||||
|
||||
def _get_warmup_lr(self, step):
|
||||
lr_val = (step + 1) / (self.warmup_steps + 1)
|
||||
return [initial_lr * lr_val for initial_lr in self.base_lrs]
|
||||
|
||||
def _get_lr(self, step):
|
||||
"""Simple const lr policy"""
|
||||
return self.base_lrs
|
||||
|
||||
|
||||
class SquareRootConstantPolicy(_LRScheduler):
|
||||
"""Adds warmup kwargs and warmup logic to lr policy.
|
||||
All arguments should be passed as kwargs for clarity,
|
||||
Args:
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
constant_steps=None,
|
||||
constant_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1):
|
||||
assert not (constant_steps is not None
|
||||
and constant_ratio is not None), \
|
||||
"Either use particular number of step or ratio"
|
||||
assert constant_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
# It is necessary to assign all attributes *before* __init__,
|
||||
# as class is wrapped by an inner class.
|
||||
self.max_steps = max_steps
|
||||
if constant_steps is not None:
|
||||
self.constant_steps = constant_steps
|
||||
elif constant_ratio is not None:
|
||||
self.constant_steps = int(constant_ratio * max_steps)
|
||||
else:
|
||||
self.constant_steps = 0
|
||||
|
||||
self.constant_lr = 1 / (constant_steps**0.5)
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed "
|
||||
"by the scheduler, please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = self.last_epoch
|
||||
|
||||
if step <= self.constant_steps:
|
||||
return [self.constant_lr for _ in self.base_lrs]
|
||||
|
||||
if step > self.max_steps:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
return self._get_lr(step)
|
||||
|
||||
def _get_lr(self, step):
|
||||
"""Simple const lr policy"""
|
||||
return self.base_lrs
|
||||
|
||||
|
||||
class WarmupHoldPolicy(WarmupPolicy):
|
||||
"""Variant of WarmupPolicy which maintains high
|
||||
learning rate for a defined number of steps.
|
||||
All arguments should be passed as kwargs for clarity,
|
||||
Args:
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
hold_steps: Number of training steps to
|
||||
hold the learning rate after warm up
|
||||
hold_ratio: Ratio of hold steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
*,
|
||||
warmup_steps=None,
|
||||
warmup_ratio=None,
|
||||
hold_steps=None,
|
||||
hold_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1,
|
||||
):
|
||||
assert not (hold_steps is not None and hold_ratio is not None), \
|
||||
"Either use particular number of step or ratio"
|
||||
assert hold_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
self.min_lr = min_lr
|
||||
self._last_warmup_lr = 0.0
|
||||
|
||||
# Necessary to duplicate as class attributes are hidden in inner class
|
||||
self.max_steps = max_steps
|
||||
if warmup_steps is not None:
|
||||
self.warmup_steps = warmup_steps
|
||||
elif warmup_ratio is not None:
|
||||
self.warmup_steps = int(warmup_ratio * max_steps)
|
||||
else:
|
||||
self.warmup_steps = 0
|
||||
|
||||
if hold_steps is not None:
|
||||
self.hold_steps = hold_steps + self.warmup_steps
|
||||
elif hold_ratio is not None:
|
||||
self.hold_steps = int(hold_ratio * max_steps) + self.warmup_steps
|
||||
else:
|
||||
self.hold_steps = 0
|
||||
|
||||
super().__init__(
|
||||
optimizer,
|
||||
warmup_steps=warmup_steps,
|
||||
warmup_ratio=warmup_ratio,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed by the scheduler,"
|
||||
" "
|
||||
"please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = self.last_epoch
|
||||
|
||||
# Warmup phase
|
||||
if step <= self.warmup_steps and self.warmup_steps > 0:
|
||||
return self._get_warmup_lr(step)
|
||||
|
||||
# Hold phase
|
||||
if (step >= self.warmup_steps) and (step < self.hold_steps):
|
||||
return self.base_lrs
|
||||
|
||||
if step > self.max_steps:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
return self._get_lr(step)
|
||||
|
||||
|
||||
class WarmupAnnealHoldPolicy(_LRScheduler):
|
||||
"""Adds warmup kwargs and warmup logic to lr policy.
|
||||
All arguments should be passed as kwargs for clarity,
|
||||
Args:
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
min_lr: Minimum lr to hold the learning rate after decay at.
|
||||
constant_steps: Number of steps to keep lr constant at.
|
||||
constant_ratio: Ratio of steps to keep lr constant.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer,
|
||||
*,
|
||||
warmup_steps=None,
|
||||
warmup_ratio=None,
|
||||
constant_steps=None,
|
||||
constant_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1,
|
||||
):
|
||||
assert not (warmup_steps is not None
|
||||
and warmup_ratio is not None), \
|
||||
"Either use particular number of step or ratio"
|
||||
assert not (constant_steps is not None
|
||||
and constant_ratio is not None), \
|
||||
"Either use constant_steps or constant_ratio"
|
||||
assert warmup_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
# It is necessary to assign all attributes *before* __init__,
|
||||
# as class is wrapped by an inner class.
|
||||
self.max_steps = max_steps
|
||||
|
||||
if warmup_steps is not None:
|
||||
self.warmup_steps = warmup_steps
|
||||
elif warmup_ratio is not None:
|
||||
self.warmup_steps = int(warmup_ratio * max_steps)
|
||||
else:
|
||||
self.warmup_steps = 0
|
||||
|
||||
if constant_steps is not None:
|
||||
self.constant_steps = constant_steps
|
||||
elif constant_ratio is not None:
|
||||
self.constant_steps = int(constant_ratio * max_steps)
|
||||
else:
|
||||
self.constant_steps = 0
|
||||
|
||||
self.decay_steps = max_steps - (self.constant_steps +
|
||||
self.warmup_steps)
|
||||
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed "
|
||||
"by the scheduler, please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = self.last_epoch
|
||||
|
||||
# Warmup steps
|
||||
if self.warmup_steps > 0 and step <= self.warmup_steps:
|
||||
return self._get_warmup_lr(step)
|
||||
|
||||
# Constant steps after warmup and decay
|
||||
if self.constant_steps > 0 and (
|
||||
self.warmup_steps + self.decay_steps) < step <= self.max_steps:
|
||||
return self._get_constant_lr(step)
|
||||
|
||||
# Min lr after max steps of updates
|
||||
if step > self.max_steps:
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
return self._get_lr(step)
|
||||
|
||||
def _get_warmup_lr(self, step):
|
||||
lr_val = (step + 1) / (self.warmup_steps + 1)
|
||||
return [initial_lr * lr_val for initial_lr in self.base_lrs]
|
||||
|
||||
def _get_constant_lr(self, step):
|
||||
return [self.min_lr for _ in self.base_lrs]
|
||||
|
||||
def _get_lr(self, step):
|
||||
"""Simple const lr policy"""
|
||||
return self.base_lrs
|
||||
|
||||
|
||||
def _squareroot_annealing(initial_lr, step, max_steps, min_lr):
|
||||
mult = ((max_steps - step) / max_steps)**0.5
|
||||
out_lr = initial_lr * mult
|
||||
out_lr = max(out_lr, min_lr)
|
||||
return out_lr
|
||||
|
||||
|
||||
def _square_annealing(initial_lr, step, max_steps, min_lr):
|
||||
mult = ((max_steps - step) / max_steps)**2
|
||||
out_lr = initial_lr * mult
|
||||
out_lr = max(out_lr, min_lr)
|
||||
return out_lr
|
||||
|
||||
|
||||
def _cosine_annealing(initial_lr, step, max_steps, min_lr):
|
||||
mult = 0.5 * (1 + math.cos(math.pi * step / max_steps))
|
||||
out_lr = (initial_lr - min_lr) * mult + min_lr
|
||||
return out_lr
|
||||
|
||||
|
||||
def _linear_warmup_with_cosine_annealing(max_lr, warmup_steps, step,
|
||||
decay_steps, min_lr):
|
||||
assert max_lr > min_lr
|
||||
# Use linear warmup for the initial part.
|
||||
if warmup_steps > 0 and step <= warmup_steps:
|
||||
return max_lr * float(step) / float(warmup_steps)
|
||||
|
||||
# For any steps larger than `decay_steps`, use `min_lr`.
|
||||
if step > warmup_steps + decay_steps:
|
||||
return min_lr
|
||||
|
||||
# If we are done with the warmup period, use the decay style.
|
||||
num_steps_ = step - warmup_steps
|
||||
decay_steps_ = decay_steps
|
||||
decay_ratio = float(num_steps_) / float(decay_steps_)
|
||||
assert decay_ratio >= 0.0
|
||||
assert decay_ratio <= 1.0
|
||||
delta_lr = max_lr - min_lr
|
||||
|
||||
coeff = 0.5 * (math.cos(math.pi * decay_ratio) + 1.0)
|
||||
|
||||
return min_lr + coeff * delta_lr
|
||||
|
||||
|
||||
def _poly_decay(initial_lr, step, decay_steps, power, min_lr, cycle):
|
||||
if cycle:
|
||||
multiplier = 1.0 if step == 0 else math.ceil(step / decay_steps)
|
||||
decay_steps *= multiplier
|
||||
else:
|
||||
step = min(step, decay_steps)
|
||||
p = step / decay_steps
|
||||
lr = (initial_lr - min_lr) * math.pow(1.0 - p, power)
|
||||
lr += min_lr
|
||||
return lr
|
||||
|
||||
|
||||
def _noam_hold_annealing(initial_lr, step, warmup_steps, hold_steps,
|
||||
decay_rate, min_lr):
|
||||
# hold_steps = total number of steps
|
||||
# to hold the LR, not the warmup + hold steps.
|
||||
T_warmup_decay = max(1, warmup_steps**decay_rate)
|
||||
T_hold_decay = max(1, (step - hold_steps)**decay_rate)
|
||||
lr = (initial_lr * T_warmup_decay) / T_hold_decay
|
||||
lr = max(lr, min_lr)
|
||||
return lr
|
||||
|
||||
|
||||
class SquareAnnealing(WarmupPolicy):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
max_steps,
|
||||
min_lr=1e-5,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
super().__init__(optimizer=optimizer,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
**kwargs)
|
||||
|
||||
def _get_lr(self, step):
|
||||
new_lrs = [
|
||||
_square_annealing(
|
||||
initial_lr=initial_lr,
|
||||
step=step - self.warmup_steps,
|
||||
max_steps=self.max_steps - self.warmup_steps,
|
||||
min_lr=self.min_lr,
|
||||
) for initial_lr in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
|
||||
class SquareRootAnnealing(WarmupPolicy):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
max_steps,
|
||||
min_lr=0,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
super().__init__(optimizer=optimizer,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
**kwargs)
|
||||
|
||||
def _get_lr(self, step):
|
||||
new_lrs = [
|
||||
_squareroot_annealing(initial_lr=initial_lr,
|
||||
step=step,
|
||||
max_steps=self.max_steps,
|
||||
min_lr=self.min_lr)
|
||||
for initial_lr in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
|
||||
class CosineAnnealing(WarmupAnnealHoldPolicy):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
max_steps,
|
||||
min_lr=0,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
super().__init__(optimizer=optimizer,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
**kwargs)
|
||||
|
||||
def _get_lr(self, step):
|
||||
for initial_lr in self.base_lrs:
|
||||
if initial_lr < self.min_lr:
|
||||
raise ValueError(
|
||||
f"{self} received an initial learning rate "
|
||||
f"that was lower than the minimum learning rate.")
|
||||
|
||||
if self.constant_steps is None or self.constant_steps == 0:
|
||||
new_lrs = [
|
||||
_cosine_annealing(
|
||||
initial_lr=initial_lr,
|
||||
step=step - self.warmup_steps,
|
||||
max_steps=self.max_steps - self.warmup_steps,
|
||||
min_lr=self.min_lr,
|
||||
) for initial_lr in self.base_lrs
|
||||
]
|
||||
else:
|
||||
new_lrs = self._get_linear_warmup_with_cosine_annealing_lr(step)
|
||||
return new_lrs
|
||||
|
||||
def _get_warmup_lr(self, step):
|
||||
if self.constant_steps is None or self.constant_steps == 0:
|
||||
return super()._get_warmup_lr(step)
|
||||
else:
|
||||
# Use linear warmup for the initial part.
|
||||
return self._get_linear_warmup_with_cosine_annealing_lr(step)
|
||||
|
||||
def _get_constant_lr(self, step):
|
||||
# Only called when `constant_steps` > 0.
|
||||
return self._get_linear_warmup_with_cosine_annealing_lr(step)
|
||||
|
||||
def _get_linear_warmup_with_cosine_annealing_lr(self, step):
|
||||
# Cosine Schedule for Megatron LM,
|
||||
# slightly different warmup schedule + constant LR at the end.
|
||||
new_lrs = [
|
||||
_linear_warmup_with_cosine_annealing(
|
||||
max_lr=self.base_lrs[0],
|
||||
warmup_steps=self.warmup_steps,
|
||||
step=step,
|
||||
decay_steps=self.decay_steps,
|
||||
min_lr=self.min_lr,
|
||||
) for _ in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
|
||||
class NoamAnnealing(_LRScheduler):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
d_model,
|
||||
warmup_steps=None,
|
||||
warmup_ratio=None,
|
||||
max_steps=None,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1):
|
||||
self._normalize = d_model**(-0.5)
|
||||
assert not (warmup_steps is not None and warmup_ratio is not None), \
|
||||
"Either use particular number of step or ratio"
|
||||
assert warmup_ratio is None or max_steps is not None, \
|
||||
"If there is a ratio, there should be a total steps"
|
||||
|
||||
# It is necessary to assign all attributes *before* __init__,
|
||||
# as class is wrapped by an inner class.
|
||||
self.max_steps = max_steps
|
||||
if warmup_steps is not None:
|
||||
self.warmup_steps = warmup_steps
|
||||
elif warmup_ratio is not None:
|
||||
self.warmup_steps = int(warmup_ratio * max_steps)
|
||||
else:
|
||||
self.warmup_steps = 0
|
||||
|
||||
self.min_lr = min_lr
|
||||
super().__init__(optimizer, last_epoch)
|
||||
|
||||
def get_lr(self):
|
||||
if not self._get_lr_called_within_step:
|
||||
warnings.warn(
|
||||
"To get the last learning rate computed "
|
||||
"by the scheduler, please use `get_last_lr()`.",
|
||||
UserWarning,
|
||||
stacklevel=2)
|
||||
|
||||
step = max(1, self.last_epoch)
|
||||
|
||||
for initial_lr in self.base_lrs:
|
||||
if initial_lr < self.min_lr:
|
||||
raise ValueError(
|
||||
f"{self} received an initial learning rate "
|
||||
f"that was lower than the minimum learning rate.")
|
||||
|
||||
new_lrs = [
|
||||
self._noam_annealing(initial_lr=initial_lr, step=step)
|
||||
for initial_lr in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
def _noam_annealing(self, initial_lr, step):
|
||||
if self.warmup_steps > 0:
|
||||
mult = self._normalize * min(step**(-0.5),
|
||||
step * (self.warmup_steps**(-1.5)))
|
||||
else:
|
||||
mult = self._normalize * step**(-0.5)
|
||||
|
||||
out_lr = initial_lr * mult
|
||||
if step > self.warmup_steps:
|
||||
out_lr = max(out_lr, self.min_lr)
|
||||
return out_lr
|
||||
|
||||
|
||||
class NoamHoldAnnealing(WarmupHoldPolicy):
|
||||
|
||||
def __init__(self,
|
||||
optimizer,
|
||||
*,
|
||||
max_steps,
|
||||
decay_rate=0.5,
|
||||
min_lr=0.0,
|
||||
last_epoch=-1,
|
||||
**kwargs):
|
||||
"""
|
||||
From Nemo:
|
||||
Implementation of the Noam Hold Annealing policy
|
||||
from the SqueezeFormer paper.
|
||||
|
||||
Unlike NoamAnnealing, the peak learning rate
|
||||
can be explicitly set for this scheduler.
|
||||
The schedule first performs linear warmup,
|
||||
then holds the peak LR, then decays with some schedule for
|
||||
the remainder of the steps.
|
||||
Therefore the min-lr is still dependent
|
||||
on the hyper parameters selected.
|
||||
|
||||
It's schedule is determined by three factors-
|
||||
|
||||
Warmup Steps: Initial stage, where linear warmup
|
||||
occurs uptil the peak LR is reached. Unlike NoamAnnealing,
|
||||
the peak LR is explicitly stated here instead of a scaling factor.
|
||||
|
||||
Hold Steps: Intermediate stage, where the peak LR
|
||||
is maintained for some number of steps. In this region,
|
||||
the high peak LR allows the model to converge faster
|
||||
if training is stable. However the high LR
|
||||
may also cause instability during training.
|
||||
Should usually be a significant fraction of training
|
||||
steps (around 30-40% of the entire training steps).
|
||||
|
||||
Decay Steps: Final stage, where the LR rapidly decays
|
||||
with some scaling rate (set by decay rate).
|
||||
To attain Noam decay, use 0.5,
|
||||
for Squeezeformer recommended decay, use 1.0.
|
||||
The fast decay after prolonged high LR during
|
||||
hold phase allows for rapid convergence.
|
||||
|
||||
References:
|
||||
- [Squeezeformer:
|
||||
An Efficient Transformer for Automatic Speech Recognition]
|
||||
(https://arxiv.org/abs/2206.00888)
|
||||
|
||||
Args:
|
||||
optimizer: Pytorch compatible Optimizer object.
|
||||
warmup_steps: Number of training steps in warmup stage
|
||||
warmup_ratio: Ratio of warmup steps to total steps
|
||||
hold_steps: Number of training steps to
|
||||
hold the learning rate after warm up
|
||||
hold_ratio: Ratio of hold steps to total steps
|
||||
max_steps: Total number of steps while training or `None` for
|
||||
infinite training
|
||||
decay_rate: Float value describing the polynomial decay
|
||||
after the hold period. Default value
|
||||
of 0.5 corresponds to Noam decay.
|
||||
min_lr: Minimum learning rate.
|
||||
"""
|
||||
self.decay_rate = decay_rate
|
||||
super().__init__(optimizer=optimizer,
|
||||
max_steps=max_steps,
|
||||
last_epoch=last_epoch,
|
||||
min_lr=min_lr,
|
||||
**kwargs)
|
||||
|
||||
def _get_lr(self, step):
|
||||
if self.warmup_steps is None or self.warmup_steps == 0:
|
||||
raise ValueError(
|
||||
"Noam scheduler cannot be used without warmup steps")
|
||||
|
||||
if self.hold_steps > 0:
|
||||
hold_steps = self.hold_steps - self.warmup_steps
|
||||
else:
|
||||
hold_steps = 0
|
||||
|
||||
new_lrs = [
|
||||
_noam_hold_annealing(
|
||||
initial_lr,
|
||||
step=step,
|
||||
warmup_steps=self.warmup_steps,
|
||||
hold_steps=hold_steps,
|
||||
decay_rate=self.decay_rate,
|
||||
min_lr=self.min_lr,
|
||||
) for initial_lr in self.base_lrs
|
||||
]
|
||||
return new_lrs
|
||||
|
||||
def set_step(self, step: int):
|
||||
self.last_epoch = step
|
||||
|
||||
|
||||
class ConstantLR(_LRScheduler):
|
||||
"""The ConstantLR scheduler
|
||||
|
||||
This scheduler keeps a constant lr
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
optimizer: torch.optim.Optimizer,
|
||||
):
|
||||
# __init__() must be invoked before setting field
|
||||
# because step() is also invoked in __init__()
|
||||
super().__init__(optimizer)
|
||||
|
||||
def get_lr(self):
|
||||
return self.base_lrs
|
||||
|
||||
def set_step(self, step: int):
|
||||
self.last_epoch = step
|
||||
367
models/CosyVoice/cosyvoice/utils/train_utils.py
Normal file
367
models/CosyVoice/cosyvoice/utils/train_utils.py
Normal file
@@ -0,0 +1,367 @@
|
||||
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
|
||||
# 2023 Horizon Inc. (authors: Xingchen Song)
|
||||
# 2024 Alibaba Inc (authors: Xiang Lyu)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import torch
|
||||
import json
|
||||
import re
|
||||
import datetime
|
||||
import yaml
|
||||
|
||||
import deepspeed
|
||||
import torch.optim as optim
|
||||
import torch.distributed as dist
|
||||
|
||||
from torch.utils.tensorboard import SummaryWriter
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.nn.utils import clip_grad_norm_
|
||||
|
||||
from deepspeed.runtime.zero.stage_1_and_2 import estimate_zero2_model_states_mem_needs_all_live
|
||||
|
||||
from cosyvoice.dataset.dataset import Dataset
|
||||
from cosyvoice.utils.scheduler import WarmupLR, NoamHoldAnnealing, ConstantLR
|
||||
|
||||
|
||||
def init_distributed(args):
|
||||
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
||||
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
logging.info('training on multiple gpus, this gpu {}'.format(local_rank) +
|
||||
', rank {}, world_size {}'.format(rank, world_size))
|
||||
if args.train_engine == 'torch_ddp':
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group(args.dist_backend)
|
||||
else:
|
||||
deepspeed.init_distributed(dist_backend=args.dist_backend)
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def init_dataset_and_dataloader(args, configs, gan, dpo):
|
||||
data_pipeline = configs['data_pipeline_gan'] if gan is True else configs['data_pipeline']
|
||||
train_dataset = Dataset(args.train_data, data_pipeline=data_pipeline, mode='train', gan=gan, dpo=dpo, shuffle=True, partition=True)
|
||||
cv_dataset = Dataset(args.cv_data, data_pipeline=data_pipeline, mode='dev', gan=gan, dpo=dpo, shuffle=False, partition=False)
|
||||
|
||||
# do not use persistent_workers=True, as whisper tokenizer opens tiktoken file each time when the for loop starts
|
||||
train_data_loader = DataLoader(train_dataset,
|
||||
batch_size=None,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
cv_data_loader = DataLoader(cv_dataset,
|
||||
batch_size=None,
|
||||
pin_memory=args.pin_memory,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch)
|
||||
return train_dataset, cv_dataset, train_data_loader, cv_data_loader
|
||||
|
||||
|
||||
def check_modify_and_save_config(args, configs):
|
||||
if args.train_engine == "torch_ddp":
|
||||
configs['train_conf']["dtype"] = 'bf16' if args.use_amp is True else 'fp32'
|
||||
else:
|
||||
with open(args.deepspeed_config, 'r') as fin:
|
||||
ds_configs = json.load(fin)
|
||||
if "fp16" in ds_configs and ds_configs["fp16"]["enabled"]:
|
||||
configs['train_conf']["dtype"] = "fp16"
|
||||
elif "bf16" in ds_configs and ds_configs["bf16"]["enabled"]:
|
||||
configs['train_conf']["dtype"] = "bf16"
|
||||
else:
|
||||
configs['train_conf']["dtype"] = "fp32"
|
||||
assert ds_configs["train_micro_batch_size_per_gpu"] == 1
|
||||
# if use deepspeed, override ddp config
|
||||
configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
|
||||
configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
|
||||
configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
|
||||
configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
|
||||
configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
|
||||
return configs
|
||||
|
||||
|
||||
def wrap_cuda_model(args, model):
|
||||
local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
|
||||
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
||||
if args.train_engine == "torch_ddp": # native pytorch ddp
|
||||
assert (torch.cuda.is_available())
|
||||
model.cuda()
|
||||
model = torch.nn.parallel.DistributedDataParallel(model, find_unused_parameters=True)
|
||||
else:
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
logging.info("Estimating model states memory needs (zero2)...")
|
||||
estimate_zero2_model_states_mem_needs_all_live(
|
||||
model,
|
||||
num_gpus_per_node=local_world_size,
|
||||
num_nodes=world_size // local_world_size)
|
||||
return model
|
||||
|
||||
|
||||
def init_optimizer_and_scheduler(args, configs, model, gan):
|
||||
if gan is False:
|
||||
if configs['train_conf']['optim'] == 'adam':
|
||||
optimizer = optim.Adam(model.parameters(), **configs['train_conf']['optim_conf'])
|
||||
elif configs['train_conf']['optim'] == 'adamw':
|
||||
optimizer = optim.AdamW(model.parameters(), **configs['train_conf']['optim_conf'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler = ConstantLR(optimizer)
|
||||
else:
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
|
||||
# use deepspeed optimizer for speedup
|
||||
if args.train_engine == "deepspeed":
|
||||
def scheduler(opt):
|
||||
return scheduler_type(opt, **configs['train_conf']['scheduler_conf'])
|
||||
model, optimizer, _, scheduler = deepspeed.initialize(
|
||||
args=args,
|
||||
model=model,
|
||||
optimizer=None,
|
||||
lr_scheduler=scheduler,
|
||||
model_parameters=model.parameters())
|
||||
|
||||
optimizer_d, scheduler_d = None, None
|
||||
|
||||
else:
|
||||
# currently we wrap generator and discriminator in one model, so we cannot use deepspeed
|
||||
if configs['train_conf']['optim'] == 'adam':
|
||||
optimizer = optim.Adam(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
elif configs['train_conf']['optim'] == 'adamw':
|
||||
optimizer = optim.AdamW(model.module.generator.parameters(), **configs['train_conf']['optim_conf'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler = WarmupLR(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler = NoamHoldAnnealing(optimizer, **configs['train_conf']['scheduler_conf'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler = ConstantLR(optimizer)
|
||||
else:
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['optim_d'] == 'adam':
|
||||
optimizer_d = optim.Adam(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
|
||||
elif configs['train_conf']['optim_d'] == 'adamw':
|
||||
optimizer_d = optim.AdamW(model.module.discriminator.parameters(), **configs['train_conf']['optim_conf_d'])
|
||||
else:
|
||||
raise ValueError("unknown optimizer: " + configs['train_conf'])
|
||||
|
||||
if configs['train_conf']['scheduler_d'] == 'warmuplr':
|
||||
scheduler_type = WarmupLR
|
||||
scheduler_d = WarmupLR(optimizer_d, **configs['train_conf']['scheduler_d'])
|
||||
elif configs['train_conf']['scheduler_d'] == 'NoamHoldAnnealing':
|
||||
scheduler_type = NoamHoldAnnealing
|
||||
scheduler_d = NoamHoldAnnealing(optimizer_d, **configs['train_conf']['scheduler_d'])
|
||||
elif configs['train_conf']['scheduler'] == 'constantlr':
|
||||
scheduler_type = ConstantLR
|
||||
scheduler_d = ConstantLR(optimizer_d)
|
||||
else:
|
||||
raise ValueError("unknown scheduler: " + configs['train_conf'])
|
||||
return model, optimizer, scheduler, optimizer_d, scheduler_d
|
||||
|
||||
|
||||
def init_summarywriter(args):
|
||||
writer = None
|
||||
if int(os.environ.get('RANK', 0)) == 0:
|
||||
os.makedirs(args.model_dir, exist_ok=True)
|
||||
writer = SummaryWriter(args.tensorboard_dir)
|
||||
return writer
|
||||
|
||||
|
||||
def save_model(model, model_name, info_dict):
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
model_dir = info_dict["model_dir"]
|
||||
save_model_path = os.path.join(model_dir, '{}.pt'.format(model_name))
|
||||
|
||||
if info_dict["train_engine"] == "torch_ddp":
|
||||
if rank == 0:
|
||||
torch.save({**model.module.state_dict(), 'epoch': info_dict['epoch'], 'step': info_dict['step']}, save_model_path)
|
||||
else:
|
||||
with torch.no_grad():
|
||||
model.save_checkpoint(save_dir=model_dir,
|
||||
tag=model_name,
|
||||
client_state=info_dict)
|
||||
if rank == 0:
|
||||
info_path = re.sub('.pt$', '.yaml', save_model_path)
|
||||
info_dict['save_time'] = datetime.datetime.now().strftime('%d/%m/%Y %H:%M:%S')
|
||||
with open(info_path, 'w') as fout:
|
||||
data = yaml.dump(info_dict)
|
||||
fout.write(data)
|
||||
logging.info('[Rank {}] Checkpoint: save to checkpoint {}'.format(rank, save_model_path))
|
||||
|
||||
|
||||
def cosyvoice_join(group_join, info_dict):
|
||||
world_size = int(os.environ.get('WORLD_SIZE', 1))
|
||||
local_rank = int(os.environ.get('LOCAL_RANK', 0))
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
|
||||
if info_dict["batch_idx"] != 0:
|
||||
# we try to join all rank in both ddp and deepspeed mode, in case different rank has different lr
|
||||
try:
|
||||
dist.monitored_barrier(group=group_join,
|
||||
timeout=group_join.options._timeout)
|
||||
return False
|
||||
except RuntimeError as e:
|
||||
logging.info("Detected uneven workload distribution: {}\n".format(e) +
|
||||
"Break current worker to manually join all workers, " +
|
||||
"world_size {}, current rank {}, current local_rank {}\n".
|
||||
format(world_size, rank, local_rank))
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def batch_forward(model, batch, scaler, info_dict, ref_model=None, dpo_loss=None):
|
||||
device = int(os.environ.get('LOCAL_RANK', 0))
|
||||
|
||||
dtype = info_dict["dtype"]
|
||||
if dtype == "fp16":
|
||||
dtype = torch.float16
|
||||
elif dtype == "bf16":
|
||||
dtype = torch.bfloat16
|
||||
else: # fp32
|
||||
dtype = torch.float32
|
||||
|
||||
if info_dict['train_engine'] == 'torch_ddp':
|
||||
autocast = torch.cuda.amp.autocast(enabled=scaler is not None, dtype=dtype)
|
||||
else:
|
||||
autocast = torch.cuda.amp.autocast(enabled=True, dtype=dtype, cache_enabled=False)
|
||||
|
||||
with autocast:
|
||||
info_dict['loss_dict'] = model(batch, device)
|
||||
if ref_model is not None and dpo_loss is not None:
|
||||
chosen_logps = info_dict['loss_dict']["chosen_logps"]
|
||||
rejected_logps = info_dict['loss_dict']["rejected_logps"]
|
||||
sft_loss = info_dict['loss_dict']['loss']
|
||||
with torch.no_grad():
|
||||
ref_loss_dict = ref_model(batch, device)
|
||||
reference_chosen_logps = ref_loss_dict["chosen_logps"]
|
||||
reference_rejected_logps = ref_loss_dict["rejected_logps"]
|
||||
preference_loss, chosen_reward, reject_reward = dpo_loss(
|
||||
chosen_logps, rejected_logps, reference_chosen_logps, reference_rejected_logps
|
||||
)
|
||||
dpo_acc = (chosen_reward > reject_reward).float().mean()
|
||||
info_dict['loss_dict']["loss"] = preference_loss + sft_loss
|
||||
info_dict['loss_dict']["sft_loss"] = sft_loss
|
||||
info_dict['loss_dict']["dpo_loss"] = preference_loss
|
||||
info_dict['loss_dict']["dpo_acc"] = dpo_acc
|
||||
info_dict['loss_dict']["chosen_reward"] = chosen_reward.mean()
|
||||
info_dict['loss_dict']["reject_reward"] = reject_reward.mean()
|
||||
return info_dict
|
||||
|
||||
|
||||
def batch_backward(model, scaler, info_dict):
|
||||
if info_dict["train_engine"] == "deepspeed":
|
||||
scaled_loss = model.backward(info_dict['loss_dict']['loss'])
|
||||
else:
|
||||
scaled_loss = info_dict['loss_dict']['loss'] / info_dict['accum_grad']
|
||||
if scaler is not None:
|
||||
scaler.scale(scaled_loss).backward()
|
||||
else:
|
||||
scaled_loss.backward()
|
||||
|
||||
info_dict['loss_dict']['loss'] = scaled_loss
|
||||
return info_dict
|
||||
|
||||
|
||||
def update_parameter_and_lr(model, optimizer, scheduler, scaler, info_dict):
|
||||
grad_norm = 0.0
|
||||
if info_dict['train_engine'] == "deepspeed":
|
||||
info_dict["is_gradient_accumulation_boundary"] = model.is_gradient_accumulation_boundary()
|
||||
model.step()
|
||||
grad_norm = model.get_global_grad_norm()
|
||||
elif (info_dict['batch_idx'] + 1) % info_dict["accum_grad"] == 0:
|
||||
# Use mixed precision training
|
||||
if scaler is not None:
|
||||
scaler.unscale_(optimizer)
|
||||
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
||||
# We don't check grad here since that if the gradient
|
||||
# has inf/nan values, scaler.step will skip
|
||||
# optimizer.step().
|
||||
if torch.isfinite(grad_norm):
|
||||
scaler.step(optimizer)
|
||||
else:
|
||||
logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
|
||||
scaler.update()
|
||||
else:
|
||||
grad_norm = clip_grad_norm_(model.parameters(), info_dict['grad_clip'])
|
||||
if torch.isfinite(grad_norm):
|
||||
optimizer.step()
|
||||
else:
|
||||
logging.warning('get infinite grad_norm, check your code/data if it appears frequently')
|
||||
optimizer.zero_grad()
|
||||
scheduler.step()
|
||||
info_dict["lr"] = optimizer.param_groups[0]['lr']
|
||||
info_dict["grad_norm"] = grad_norm
|
||||
return info_dict
|
||||
|
||||
|
||||
def log_per_step(writer, info_dict):
|
||||
tag = info_dict["tag"]
|
||||
epoch = info_dict.get('epoch', 0)
|
||||
step = info_dict["step"]
|
||||
batch_idx = info_dict["batch_idx"]
|
||||
loss_dict = info_dict['loss_dict']
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
|
||||
# only rank 0 write to tensorboard to avoid multi-process write
|
||||
if writer is not None:
|
||||
if (info_dict['train_engine'] == 'deepspeed' and info_dict['is_gradient_accumulation_boundary'] is True) or \
|
||||
(info_dict['train_engine'] == 'torch_ddp' and (info_dict['batch_idx'] + 1) % info_dict['accum_grad'] == 0):
|
||||
for k in ['epoch', 'lr', 'grad_norm']:
|
||||
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
|
||||
for k, v in loss_dict.items():
|
||||
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
|
||||
|
||||
# TRAIN & CV, Shell log (stdout)
|
||||
if (info_dict['batch_idx'] + 1) % info_dict['log_interval'] == 0:
|
||||
log_str = '{} Batch {}/{} '.format(tag, epoch, batch_idx + 1)
|
||||
for name, value in loss_dict.items():
|
||||
log_str += '{} {:.6f} '.format(name, value)
|
||||
if tag == "TRAIN":
|
||||
log_str += 'lr {:.8f} grad_norm {:.6f}'.format(
|
||||
info_dict["lr"], info_dict['grad_norm'])
|
||||
log_str += ' rank {}'.format(rank)
|
||||
logging.debug(log_str)
|
||||
|
||||
|
||||
def log_per_save(writer, info_dict):
|
||||
tag = info_dict["tag"]
|
||||
epoch = info_dict["epoch"]
|
||||
step = info_dict["step"]
|
||||
loss_dict = info_dict["loss_dict"]
|
||||
lr = info_dict['lr']
|
||||
rank = int(os.environ.get('RANK', 0))
|
||||
logging.info(
|
||||
'Epoch {} Step {} CV info lr {} {} rank {}'.format(
|
||||
epoch, step + 1, lr, rank, ' '.join(['{} {}'.format(k, v) for k, v in loss_dict.items()])))
|
||||
|
||||
if writer is not None:
|
||||
for k in ['epoch', 'lr']:
|
||||
writer.add_scalar('{}/{}'.format(tag, k), info_dict[k], step + 1)
|
||||
for k, v in loss_dict.items():
|
||||
writer.add_scalar('{}/{}'.format(tag, k), v, step + 1)
|
||||
116
models/CosyVoice/cosyvoice/vllm/cosyvoice2.py
Normal file
116
models/CosyVoice/cosyvoice/vllm/cosyvoice2.py
Normal file
@@ -0,0 +1,116 @@
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
|
||||
# Adapted from
|
||||
# https://github.com/huggingface/transformers/blob/v4.28.0/src/transformers/models/qwen2/modeling_qwen2.py
|
||||
# Copyright 2024 The Qwen team.
|
||||
# Copyright 2023 The vLLM team.
|
||||
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
|
||||
# and OPT implementations in this library. It has been modified from its
|
||||
# original forms to accommodate minor architectural differences compared
|
||||
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Inference-only Qwen2 model compatible with HuggingFace weights."""
|
||||
from typing import Optional
|
||||
from packaging.version import parse as vparse
|
||||
import vllm
|
||||
|
||||
# vLLM-0.11.0+ only support V1 engine
|
||||
VLLM_V1_ENGINE_ONLY: bool = vparse(vllm.__version__) >= vparse("0.11.0")
|
||||
if VLLM_V1_ENGINE_ONLY:
|
||||
from vllm.v1.sample.metadata import SamplingMetadata
|
||||
|
||||
from vllm.model_executor.models.qwen2 import *
|
||||
|
||||
|
||||
class CosyVoice2ForCausalLM(nn.Module, SupportsLoRA, SupportsPP):
|
||||
packed_modules_mapping = {
|
||||
"qkv_proj": [
|
||||
"q_proj",
|
||||
"k_proj",
|
||||
"v_proj",
|
||||
],
|
||||
"gate_up_proj": [
|
||||
"gate_proj",
|
||||
"up_proj",
|
||||
],
|
||||
}
|
||||
|
||||
def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
|
||||
super().__init__()
|
||||
config = vllm_config.model_config.hf_config
|
||||
quant_config = vllm_config.quant_config
|
||||
lora_config = vllm_config.lora_config
|
||||
|
||||
self.config = config
|
||||
self.lora_config = lora_config
|
||||
|
||||
self.quant_config = quant_config
|
||||
self.model = Qwen2Model(vllm_config=vllm_config,
|
||||
prefix=maybe_prefix(prefix, "model"))
|
||||
|
||||
if get_pp_group().is_last_rank:
|
||||
if config.tie_word_embeddings:
|
||||
self.lm_head = self.model.embed_tokens
|
||||
else:
|
||||
self.lm_head = ParallelLMHead(config.vocab_size,
|
||||
config.hidden_size,
|
||||
True,
|
||||
quant_config=quant_config,
|
||||
prefix=maybe_prefix(
|
||||
prefix, "lm_head"))
|
||||
else:
|
||||
self.lm_head = PPMissingLayer()
|
||||
|
||||
self.logits_processor = LogitsProcessor(config.vocab_size)
|
||||
|
||||
self.make_empty_intermediate_tensors = (
|
||||
self.model.make_empty_intermediate_tensors)
|
||||
|
||||
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
||||
return self.model.get_input_embeddings(input_ids)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
positions: torch.Tensor,
|
||||
intermediate_tensors: Optional[IntermediateTensors] = None,
|
||||
inputs_embeds: Optional[torch.Tensor] = None,
|
||||
) -> Union[torch.Tensor, IntermediateTensors]:
|
||||
hidden_states = self.model(input_ids, positions, intermediate_tensors,
|
||||
inputs_embeds)
|
||||
return hidden_states
|
||||
|
||||
def compute_logits(
|
||||
self,
|
||||
hidden_states: torch.Tensor,
|
||||
sampling_metadata: Optional[SamplingMetadata] = None,
|
||||
) -> Optional[torch.Tensor]:
|
||||
if VLLM_V1_ENGINE_ONLY:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
self.lm_head.bias)
|
||||
else:
|
||||
logits = self.logits_processor(self.lm_head, hidden_states,
|
||||
sampling_metadata, self.lm_head.bias)
|
||||
return logits
|
||||
|
||||
def load_weights(self, weights: Iterable[tuple[str,
|
||||
torch.Tensor]]) -> set[str]:
|
||||
loader = AutoWeightsLoader(
|
||||
self,
|
||||
skip_prefixes=(["lm_head."]
|
||||
if self.config.tie_word_embeddings else None),
|
||||
)
|
||||
return loader.load_weights(weights)
|
||||
328
models/CosyVoice/cosyvoice_server.py
Normal file
328
models/CosyVoice/cosyvoice_server.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""
|
||||
CosyVoice 3.0 声音克隆服务
|
||||
端口: 8010
|
||||
GPU: 0
|
||||
|
||||
启动方式:
|
||||
conda activate cosyvoice
|
||||
python cosyvoice_server.py
|
||||
|
||||
PM2 启动:
|
||||
pm2 start run_cosyvoice.sh --name vigent2-cosyvoice
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
|
||||
# 设置 GPU
|
||||
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
|
||||
|
||||
# CosyVoice 需要 Matcha-TTS 子模块
|
||||
SCRIPT_DIR = Path(__file__).parent
|
||||
sys.path.append(str(SCRIPT_DIR / "third_party" / "Matcha-TTS"))
|
||||
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form
|
||||
from fastapi.responses import FileResponse
|
||||
from pydantic import BaseModel
|
||||
import uvicorn
|
||||
|
||||
app = FastAPI(title="CosyVoice 3.0 Voice Clone Service", version="1.0")
|
||||
|
||||
MODEL_DIR = SCRIPT_DIR / "pretrained_models" / "Fun-CosyVoice3-0.5B"
|
||||
|
||||
# 全局模型实例
|
||||
_model = None
|
||||
_model_loaded = False
|
||||
_poisoned = False
|
||||
|
||||
# GPU 推理锁
|
||||
_inference_lock = asyncio.Lock()
|
||||
|
||||
|
||||
def _schedule_force_exit(reason: str, delay_sec: float = 1.5):
|
||||
"""超时后强制退出进程,让 PM2 立即拉起新进程。"""
|
||||
import threading
|
||||
|
||||
def _killer():
|
||||
time.sleep(delay_sec)
|
||||
print(f"💥 Force exiting process: {reason}")
|
||||
os._exit(1)
|
||||
|
||||
threading.Thread(target=_killer, daemon=True).start()
|
||||
|
||||
|
||||
def load_model():
|
||||
"""加载模型(启动时调用)"""
|
||||
global _model, _model_loaded
|
||||
|
||||
if _model_loaded:
|
||||
return
|
||||
|
||||
print(f"🔄 Loading CosyVoice 3.0 model from {MODEL_DIR}...")
|
||||
start = time.time()
|
||||
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
_model = AutoModel(model_dir=str(MODEL_DIR))
|
||||
|
||||
_model_loaded = True
|
||||
print(f"✅ CosyVoice 3.0 model loaded in {time.time() - start:.1f}s")
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
service: str
|
||||
model: str
|
||||
ready: bool
|
||||
gpu_id: int
|
||||
|
||||
|
||||
def _startup_selftest():
|
||||
"""启动自检:用短文本做一次推理,验证 GPU 推理链路可用。"""
|
||||
import torch
|
||||
|
||||
print("🔍 Running startup self-test inference...")
|
||||
start = time.time()
|
||||
|
||||
test_text = "你好"
|
||||
# 使用一段静音作为参考音频(0.5秒 24kHz)
|
||||
ref_audio_path = None
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
|
||||
ref_audio_path = tmp.name
|
||||
import torchaudio
|
||||
silence = torch.zeros(1, 12000) # 0.5s @ 24kHz
|
||||
torchaudio.save(ref_audio_path, silence, 24000)
|
||||
|
||||
prompt_text = f"You are a helpful assistant.<|endofprompt|>你好"
|
||||
results = list(_model.inference_zero_shot(
|
||||
test_text,
|
||||
prompt_text,
|
||||
ref_audio_path,
|
||||
stream=False,
|
||||
text_frontend=True,
|
||||
))
|
||||
if not results:
|
||||
raise RuntimeError("Self-test returned empty results")
|
||||
|
||||
segments = [r["tts_speech"] for r in results if isinstance(r, dict) and "tts_speech" in r]
|
||||
if not segments:
|
||||
raise RuntimeError("Self-test returned no tts_speech segments")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
print(f"✅ Self-test passed in {time.time() - start:.1f}s "
|
||||
f"(output shape: {segments[0].shape})")
|
||||
return True
|
||||
except Exception as e:
|
||||
print(f"❌ Self-test FAILED: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
return False
|
||||
finally:
|
||||
if ref_audio_path:
|
||||
try:
|
||||
os.unlink(ref_audio_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def startup():
|
||||
"""服务启动时预加载模型并自检推理"""
|
||||
try:
|
||||
load_model()
|
||||
except Exception as e:
|
||||
print(f"❌ Model loading failed: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
return
|
||||
|
||||
# 自检推理 — 失败则标记为不可用
|
||||
global _model_loaded
|
||||
if not _startup_selftest():
|
||||
_model_loaded = False
|
||||
print("⚠️ Self-test failed, marking service as NOT ready")
|
||||
|
||||
|
||||
@app.get("/health", response_model=HealthResponse)
|
||||
async def health():
|
||||
"""健康检查"""
|
||||
gpu_ok = False
|
||||
try:
|
||||
import torch
|
||||
gpu_ok = torch.cuda.is_available()
|
||||
except:
|
||||
pass
|
||||
|
||||
return HealthResponse(
|
||||
service="CosyVoice 3.0 Voice Clone",
|
||||
model="Fun-CosyVoice3-0.5B",
|
||||
ready=_model_loaded and gpu_ok and not _poisoned,
|
||||
gpu_id=0
|
||||
)
|
||||
|
||||
|
||||
@app.post("/generate")
|
||||
async def generate(
|
||||
ref_audio: UploadFile = File(...),
|
||||
text: str = Form(...),
|
||||
ref_text: str = Form(...),
|
||||
language: str = Form("Chinese"),
|
||||
speed: float = Form(1.0),
|
||||
):
|
||||
"""
|
||||
声音克隆生成
|
||||
|
||||
Args:
|
||||
ref_audio: 参考音频文件 (WAV)
|
||||
text: 要合成的文本
|
||||
ref_text: 参考音频的转写文字
|
||||
language: 语言(兼容参数,CosyVoice 自动检测语言)
|
||||
|
||||
Returns:
|
||||
生成的音频文件 (WAV)
|
||||
"""
|
||||
global _poisoned
|
||||
|
||||
if not _model_loaded:
|
||||
raise HTTPException(status_code=503, detail="Model not loaded")
|
||||
|
||||
if _poisoned:
|
||||
raise HTTPException(status_code=503, detail="Service poisoned after timeout, waiting for restart")
|
||||
|
||||
if _inference_lock.locked():
|
||||
raise HTTPException(status_code=429, detail="GPU busy, please retry later")
|
||||
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
# 保存上传的参考音频到临时文件
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_ref:
|
||||
content = await ref_audio.read()
|
||||
tmp_ref.write(content)
|
||||
ref_audio_path = tmp_ref.name
|
||||
|
||||
# 参考音频过长时自动截取前 10 秒(CosyVoice 建议 3-10 秒)
|
||||
MAX_REF_SEC = 10
|
||||
try:
|
||||
info = torchaudio.info(ref_audio_path)
|
||||
ref_dur = info.num_frames / info.sample_rate
|
||||
if ref_dur > MAX_REF_SEC:
|
||||
print(f"✂️ Ref audio too long ({ref_dur:.1f}s), trimming to {MAX_REF_SEC}s")
|
||||
wav, sr = torchaudio.load(ref_audio_path, num_frames=int(info.sample_rate * MAX_REF_SEC))
|
||||
torchaudio.save(ref_audio_path, wav, sr)
|
||||
except Exception as e:
|
||||
print(f"⚠️ Could not check ref audio duration: {e}")
|
||||
|
||||
output_path = tempfile.mktemp(suffix=".wav")
|
||||
|
||||
try:
|
||||
async with _inference_lock:
|
||||
print(f"🎤 Generating: {text[:50]}... ({len(text)} chars)")
|
||||
print(f"📝 Ref text: {ref_text[:50]}...")
|
||||
print(f"🌐 Language: {language}")
|
||||
print(f"⚡ Speed: {speed}")
|
||||
|
||||
start = time.time()
|
||||
|
||||
# 超时保护:基础60秒 + 每字符2秒,上限300秒
|
||||
timeout_sec = min(60 + len(text) * 2, 300)
|
||||
|
||||
# CosyVoice3 的 prompt_text 格式
|
||||
prompt_text = f"You are a helpful assistant.<|endofprompt|>{ref_text}"
|
||||
|
||||
def _do_inference():
|
||||
"""在线程池中执行推理"""
|
||||
results = list(_model.inference_zero_shot(
|
||||
text,
|
||||
prompt_text,
|
||||
ref_audio_path,
|
||||
stream=False,
|
||||
speed=speed,
|
||||
text_frontend=True,
|
||||
))
|
||||
if not results:
|
||||
raise RuntimeError("CosyVoice returned empty results")
|
||||
|
||||
segments = [r["tts_speech"] for r in results if isinstance(r, dict) and "tts_speech" in r]
|
||||
if not segments:
|
||||
raise RuntimeError("CosyVoice returned no tts_speech segments")
|
||||
|
||||
if len(segments) == 1:
|
||||
merged = segments[0]
|
||||
else:
|
||||
gap = torch.zeros((segments[0].shape[0], int(_model.sample_rate * 0.05)), dtype=segments[0].dtype)
|
||||
parts = [segments[0]]
|
||||
for seg in segments[1:]:
|
||||
parts.append(gap)
|
||||
parts.append(seg)
|
||||
merged = torch.cat(parts, dim=-1)
|
||||
|
||||
return merged, _model.sample_rate
|
||||
|
||||
try:
|
||||
speech, sr = await asyncio.wait_for(
|
||||
asyncio.to_thread(_do_inference),
|
||||
timeout=timeout_sec,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
_poisoned = True
|
||||
print(f"⏰ Generation timed out after {timeout_sec}s for {len(text)} chars — service POISONED")
|
||||
torch.cuda.empty_cache()
|
||||
_schedule_force_exit("generation timeout")
|
||||
raise HTTPException(status_code=500, detail=f"生成超时({timeout_sec}s),请缩短文本后重试")
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
|
||||
torchaudio.save(output_path, speech, sr)
|
||||
|
||||
duration = speech.shape[-1] / sr
|
||||
print(f"✅ Generated in {time.time() - start:.1f}s, duration: {duration:.1f}s")
|
||||
|
||||
return FileResponse(
|
||||
output_path,
|
||||
media_type="audio/wav",
|
||||
filename="output.wav",
|
||||
background=None
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Generation failed: {e}")
|
||||
try:
|
||||
torch.cuda.empty_cache()
|
||||
except:
|
||||
pass
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
finally:
|
||||
try:
|
||||
os.unlink(ref_audio_path)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def shutdown():
|
||||
"""清理临时文件"""
|
||||
import glob
|
||||
for f in glob.glob("/tmp/tmp*.wav"):
|
||||
try:
|
||||
os.unlink(f)
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
uvicorn.run(
|
||||
app,
|
||||
host="0.0.0.0",
|
||||
port=8010,
|
||||
log_level="info"
|
||||
)
|
||||
51
models/CosyVoice/docker/Dockerfile
Normal file
51
models/CosyVoice/docker/Dockerfile
Normal file
@@ -0,0 +1,51 @@
|
||||
FROM nvidia/cuda:12.4.1-cudnn-devel-ubuntu22.04
|
||||
|
||||
ARG VENV_NAME="cosyvoice"
|
||||
ENV VENV=$VENV_NAME
|
||||
ENV LANG=C.UTF-8 LC_ALL=C.UTF-8
|
||||
|
||||
ENV DEBIAN_FRONTEND=noninteractive
|
||||
ENV PYTHONUNBUFFERED=1
|
||||
SHELL ["/bin/bash", "--login", "-c"]
|
||||
|
||||
RUN apt-get update -y --fix-missing
|
||||
RUN apt-get install -y git build-essential curl wget ffmpeg unzip git git-lfs sox libsox-dev && \
|
||||
apt-get clean && \
|
||||
git lfs install
|
||||
|
||||
# ==================================================================
|
||||
# conda install and conda forge channel as default
|
||||
# ------------------------------------------------------------------
|
||||
# Install miniforge
|
||||
RUN wget --quiet https://github.com/conda-forge/miniforge/releases/latest/download/Miniforge3-Linux-x86_64.sh -O ~/miniforge.sh && \
|
||||
/bin/bash ~/miniforge.sh -b -p /opt/conda && \
|
||||
rm ~/miniforge.sh && \
|
||||
ln -s /opt/conda/etc/profile.d/conda.sh /etc/profile.d/conda.sh && \
|
||||
echo "source /opt/conda/etc/profile.d/conda.sh" >> /opt/nvidia/entrypoint.d/100.conda.sh && \
|
||||
echo "source /opt/conda/etc/profile.d/conda.sh" >> ~/.bashrc && \
|
||||
echo "conda activate ${VENV}" >> /opt/nvidia/entrypoint.d/110.conda_default_env.sh && \
|
||||
echo "conda activate ${VENV}" >> $HOME/.bashrc
|
||||
|
||||
ENV PATH /opt/conda/bin:$PATH
|
||||
|
||||
RUN conda config --add channels conda-forge && \
|
||||
conda config --set channel_priority strict
|
||||
# ------------------------------------------------------------------
|
||||
# ~conda
|
||||
# ==================================================================
|
||||
|
||||
RUN conda create -y -n ${VENV} python=3.10
|
||||
ENV CONDA_DEFAULT_ENV=${VENV}
|
||||
ENV PATH /opt/conda/bin:/opt/conda/envs/${VENV}/bin:$PATH
|
||||
|
||||
WORKDIR /workspace
|
||||
|
||||
ENV PYTHONPATH="${PYTHONPATH}:/workspace/CosyVoice:/workspace/CosyVoice/third_party/Matcha-TTS"
|
||||
|
||||
RUN git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
|
||||
|
||||
RUN conda activate ${VENV} && conda install -y -c conda-forge pynini==2.1.5
|
||||
RUN conda activate ${VENV} && cd CosyVoice && \
|
||||
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com --no-cache-dir
|
||||
|
||||
WORKDIR /workspace/CosyVoice
|
||||
112
models/CosyVoice/example.py
Normal file
112
models/CosyVoice/example.py
Normal file
@@ -0,0 +1,112 @@
|
||||
import sys
|
||||
sys.path.append('third_party/Matcha-TTS')
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
import torchaudio
|
||||
|
||||
|
||||
def cosyvoice_example():
|
||||
""" CosyVoice Usage, check https://fun-audio-llm.github.io/ for more details
|
||||
"""
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-SFT')
|
||||
# sft usage
|
||||
print(cosyvoice.list_available_spks())
|
||||
# change stream=True for chunk stream inference
|
||||
for i, j in enumerate(cosyvoice.inference_sft('你好,我是通义生成式语音大模型,请问有什么可以帮您的吗?', '中文女', stream=False)):
|
||||
torchaudio.save('sft_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M')
|
||||
# zero_shot usage
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
# cross_lingual usage, <|zh|><|en|><|ja|><|yue|><|ko|> for Chinese/English/Japanese/Cantonese/Korean
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('<|en|>And then later on, fully acquiring that company. So keeping management in line, interest in line with the asset that\'s coming into the family is a reason why sometimes we don\'t buy the whole thing.',
|
||||
'./asset/cross_lingual_prompt.wav')):
|
||||
torchaudio.save('cross_lingual_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
# vc usage
|
||||
for i, j in enumerate(cosyvoice.inference_vc('./asset/cross_lingual_prompt.wav', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('vc_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice-300M-Instruct')
|
||||
# instruct usage, support <laughter></laughter><strong></strong>[laughter][breath]
|
||||
for i, j in enumerate(cosyvoice.inference_instruct('在面对挑战时,他展现了非凡的<strong>勇气</strong>与<strong>智慧</strong>。', '中文男',
|
||||
'Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.<|endofprompt|>')):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
|
||||
def cosyvoice2_example():
|
||||
""" CosyVoice2 Usage, check https://funaudiollm.github.io/cosyvoice2/ for more details
|
||||
"""
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/CosyVoice2-0.5B')
|
||||
|
||||
# NOTE if you want to reproduce the results on https://funaudiollm.github.io/cosyvoice2, please add text_frontend=False during inference
|
||||
# zero_shot usage
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# save zero_shot spk for future usage
|
||||
assert cosyvoice.add_zero_shot_spk('希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', 'my_zero_shot_spk') is True
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '', '', zero_shot_spk_id='my_zero_shot_spk')):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
cosyvoice.save_spkinfo()
|
||||
|
||||
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L248
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('在他讲述那个荒诞故事的过程中,他突然[laughter]停下来,因为他自己也被逗笑了[laughter]。', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# instruct usage
|
||||
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', '用四川话说这句话<|endofprompt|>', './asset/zero_shot_prompt.wav')):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# bistream usage, you can use generator as input, this is useful when using text llm model as input
|
||||
# NOTE you should still have some basic sentence split logic because llm can not handle arbitrary sentence length
|
||||
def text_generator():
|
||||
yield '收到好友从远方寄来的生日礼物,'
|
||||
yield '那份意外的惊喜与深深的祝福'
|
||||
yield '让我心中充满了甜蜜的快乐,'
|
||||
yield '笑容如花儿般绽放。'
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot(text_generator(), '希望你以后能够做的比我还好呦。', './asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('zero_shot_bistream_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
|
||||
def cosyvoice3_example():
|
||||
""" CosyVoice3 Usage, check https://funaudiollm.github.io/cosyvoice3/ for more details
|
||||
"""
|
||||
cosyvoice = AutoModel(model_dir='pretrained_models/Fun-CosyVoice3-0.5B')
|
||||
# zero_shot usage
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('八百标兵奔北坡,北坡炮兵并排跑,炮兵怕把标兵碰,标兵怕碰炮兵炮。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('zero_shot_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# fine grained control, for supported control, check cosyvoice/tokenizer/tokenizer.py#L280
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>[breath]因为他们那一辈人[breath]在乡里面住的要习惯一点,[breath]邻居都很活络,[breath]嗯,都很熟悉。[breath]',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('fine_grained_control_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# instruct usage, for supported control, check cosyvoice/utils/common.py#L28
|
||||
for i, j in enumerate(cosyvoice.inference_instruct2('好少咯,一般系放嗰啲国庆啊,中秋嗰啲可能会咯。', 'You are a helpful assistant. 请用广东话表达。<|endofprompt|>',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
for i, j in enumerate(cosyvoice.inference_instruct2('收到好友从远方寄来的生日礼物,那份意外的惊喜与深深的祝福让我心中充满了甜蜜的快乐,笑容如花儿般绽放。', 'You are a helpful assistant. 请用尽可能快地语速说一句话。<|endofprompt|>',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('instruct_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# hotfix usage
|
||||
for i, j in enumerate(cosyvoice.inference_zero_shot('高管也通过电话、短信、微信等方式对报道[j][ǐ]予好评。', 'You are a helpful assistant.<|endofprompt|>希望你以后能够做的比我还好呦。',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('hotfix_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
# NOTE for Japanese usage, you must translate it to katakana.
|
||||
# 歴史的世界においては、過去は単に過ぎ去ったものではない、プラトンのいう如く非有が有である。 -> レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン ノ イウ ゴトク ヒ ユー ガ ユー デ アル。
|
||||
for i, j in enumerate(cosyvoice.inference_cross_lingual('You are a helpful assistant.<|endofprompt|>レキシ テキ セカイ ニ オイ テ ワ、カコ ワ タンニ スギサッ タ モノ デ ワ ナイ、プラトン ノ イウ ゴトク ヒ ユー ガ ユー デ アル。',
|
||||
'./asset/zero_shot_prompt.wav', stream=False)):
|
||||
torchaudio.save('japanese_{}.wav'.format(i), j['tts_speech'], cosyvoice.sample_rate)
|
||||
|
||||
|
||||
def main():
|
||||
# cosyvoice_example()
|
||||
# cosyvoice2_example()
|
||||
cosyvoice3_example()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
main()
|
||||
6
models/CosyVoice/examples/grpo/cosyvoice2/Dockerfile
Normal file
6
models/CosyVoice/examples/grpo/cosyvoice2/Dockerfile
Normal file
@@ -0,0 +1,6 @@
|
||||
FROM verlai/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
|
||||
COPY requirements.txt /myworkspace/requirements.txt
|
||||
RUN pip install -r /myworkspace/requirements.txt
|
||||
RUN pip install -U nvidia-pytriton
|
||||
RUN git clone https://github.com/yuekaizhang/verl.git /myworkspace/verl -b thread && cd /myworkspace/verl && pip install --no-deps -e .
|
||||
RUN git clone https://github.com/yuekaizhang/PytritonSenseVoice.git /myworkspace/PytritonSenseVoice && cd /myworkspace/PytritonSenseVoice && pip install -e .
|
||||
125
models/CosyVoice/examples/grpo/cosyvoice2/README.md
Normal file
125
models/CosyVoice/examples/grpo/cosyvoice2/README.md
Normal file
@@ -0,0 +1,125 @@
|
||||
# CosyVoice2 LLM Reinforcement Learning Recipe
|
||||
|
||||
This recipe demonstrates how to fine-tune the **CosyVoice2** large language model with reinforcement learning algorithms—specifically **GRPO**—using the [veRL](https://github.com/volcengine/verl) framework. Our experiments show that applying GRPO reduces the character error rate (CER) on the CosyVoice3 `zero_shot_zh` set from 4.08% to 3.36%.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Environment Setup](#environment-setup)
|
||||
- [Data Preparation](#data-preparation)
|
||||
- [Reward Function & ASR Server](#reward-function--asr-server)
|
||||
- [Training](#training)
|
||||
- [Evaluation](#evaluation)
|
||||
- [Export Model](#export-model)
|
||||
- [Results](#results)
|
||||
- [Acknowledgement](#acknowledgement)
|
||||
|
||||
## Environment Setup
|
||||
We recommend using the pre-built Docker image below. Alternatively, you can manually install the dependencies following the Dockerfile.
|
||||
```bash
|
||||
docker pull soar97/verl:app-verl0.4-vllm0.8.5-mcore0.12.2-te2.2
|
||||
```
|
||||
If Docker is not available, you can refer to `run.sh` `stage -2` to install the dependencies locally.
|
||||
|
||||
## Data Preparation
|
||||
|
||||
`prepare_data.py` expects a JSON/JSONL file with at least the following schema:
|
||||
|
||||
```jsonc
|
||||
{
|
||||
"text": "An example sentence to be synthesized."
|
||||
}
|
||||
```
|
||||
You can download the JSONL files from the metadata directory of the [SparkAudio/voxbox](https://huggingface.co/datasets/SparkAudio/voxbox/tree/main/metadata) dataset on Hugging Face.
|
||||
|
||||
Stage `0` converts raw JSONL files into the parquet format expected by veRL:
|
||||
|
||||
```bash
|
||||
bash run.sh 0 0
|
||||
```
|
||||
Create two JSONL files—`train.jsonl` and `test.jsonl`.
|
||||
The script will then generate two Parquet files:
|
||||
|
||||
```
|
||||
data/parquet_tiny/train.parquet
|
||||
data/parquet_tiny/test.parquet
|
||||
```
|
||||
|
||||
Each sample is automatically wrapped into a CosyVoice2-style prompt so that the LLM learns to output CosyVoice2 speech tokens.
|
||||
|
||||
|
||||
## Reward Function & ASR Server
|
||||
|
||||
To compute rewards, we run a lightweight server that:
|
||||
|
||||
1. Converts generated speech tokens back to a 16 kHz waveform with the **CosyVoice2** pretrained U-Net model.
|
||||
2. Transcribes the waveform with **SenseVoice** ASR.
|
||||
3. Calculates the pinyin-level error rate relative to the ground-truth text and maps it to a score between 0 and 1.
|
||||
|
||||
Start the server (stage `1`) in a dedicated terminal or on a separate GPU:
|
||||
|
||||
```bash
|
||||
bash run.sh 1 1
|
||||
# Triton server listens on ports 8000/8001/8002
|
||||
```
|
||||
|
||||
The custom reward implementation is located in [`reward_tts.py`](./reward_tts.py) and calls the server to obtain the reward score.
|
||||
|
||||
## Training
|
||||
|
||||
Run stage `2` to start GRPO training:
|
||||
|
||||
```bash
|
||||
bash run.sh 2 2
|
||||
```
|
||||
|
||||
Key CLI arguments passed to `verl.trainer.main_ppo`:
|
||||
|
||||
* `algorithm.adv_estimator=grpo` – use GRPO instead of PPO.
|
||||
* `data.train_files=data/parquet_aishell3/train.parquet` and `data.val_files=data/parquet_aishell3/test.parquet`
|
||||
* `custom_reward_function.path=reward_tts.py` – custom reward function described above.
|
||||
|
||||
Adjust `CUDA_VISIBLE_DEVICES`, batch sizes, and other hyperparameters to match your hardware.
|
||||
> [!TIP]
|
||||
> Note: the lm_head bias is disabled during training to make the model compatible with VLLM and Transformers' Qwen model.
|
||||
|
||||
## Evaluation
|
||||
|
||||
After training is complete, collect the sharded FSDP weights and export a Hugging Face-style checkpoint (stage `3`):
|
||||
|
||||
```bash
|
||||
bash run.sh 3 3 # merges weights into $llm_path/merged_hf_model
|
||||
```
|
||||
|
||||
You can then evaluate the model on the CosyVoice3 zero-shot Chinese test set (stage `4`):
|
||||
|
||||
```bash
|
||||
bash run.sh 4 4
|
||||
```
|
||||
|
||||
This command launches distributed inference via `infer_dataset.py` and computes WER with `scripts/compute_wer.sh`.
|
||||
|
||||
> [!TIP]
|
||||
> The script also supports the Seed-TTS test set by setting `dataset=test_zh`.
|
||||
|
||||
## Export Model
|
||||
|
||||
To use the RL-trained model with the official CosyVoice repository:
|
||||
|
||||
```bash
|
||||
bash run.sh 5 5
|
||||
```
|
||||
|
||||
The script converts the Hugging Face checkpoint back into the format expected by the CosyVoice repository.
|
||||
> [!TIP]
|
||||
> However, we observed a slight accuracy drop when using the RL-trained model after conversion, compared with the Hugging Face format.
|
||||
|
||||
## Results
|
||||
|
||||
| Model | Seed-TTS `test_zh` CER | CosyVoice3 `zero_shot_zh` CER | Comment |
|
||||
|-------|------------------------|------------------------------|---------|
|
||||
| CosyVoice2 LLM (official) | 1.45% | 4.08% | See the [paper](https://arxiv.org/abs/2412.10117) |
|
||||
| CosyVoice2 LLM + GRPO | 1.37% | **3.36%** | See the [decoding results](yuekai/official-cosyvoice-llm-grpo-aishell3), Hugging Face-format model |
|
||||
|
||||
## Acknowledgement
|
||||
|
||||
This work was inspired by the implementation in [ch-tts-llasa-rl-grpo](https://github.com/channel-io/ch-tts-llasa-rl-grpo).
|
||||
@@ -0,0 +1,71 @@
|
||||
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
python3 hf2pretrained.py --hf-cosyvoice2-llm-path /workspace/rl-exp/checkpoint-400 --output-path /workspace/CosyVoice2-0.5B/llm-new.pt
|
||||
"""
|
||||
from argparse import ArgumentParser
|
||||
import torch
|
||||
from safetensors import safe_open
|
||||
from transformers import AutoTokenizer
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--hf-cosyvoice2-llm-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The RL trained CosyVoice2 model path in HuggingFace format",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-path",
|
||||
type=str,
|
||||
default="./llm.pt",
|
||||
help="The path to save the llm.pt",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.hf_cosyvoice2_llm_path)
|
||||
speech_start_idx = tokenizer.convert_tokens_to_ids("<|s_0|>")
|
||||
cosyvoice2_token_size = 6561 + 3
|
||||
llm_embedding_vocab_size = 2
|
||||
|
||||
hf_tensors = {}
|
||||
with safe_open(f"{args.hf_cosyvoice2_llm_path}/model.safetensors", framework="pt", device="cpu") as f:
|
||||
for k in f.keys():
|
||||
if k.startswith("lm_head.bias"):
|
||||
# RL trained model disable bias for lm_head
|
||||
continue
|
||||
new_k = "llm.model." + k
|
||||
hf_tensors[new_k] = f.get_tensor(k)
|
||||
if k.startswith("lm_head"):
|
||||
hf_tensors["llm_decoder.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
|
||||
hf_tensors["llm_decoder.bias"] = torch.zeros_like(hf_tensors["llm_decoder.weight"][:, 0])
|
||||
if k.startswith("model.embed_tokens"):
|
||||
hf_tensors["speech_embedding.weight"] = f.get_tensor(k)[speech_start_idx:speech_start_idx + cosyvoice2_token_size]
|
||||
hf_tensors["llm_embedding.weight"] = f.get_tensor(k)[speech_start_idx + cosyvoice2_token_size:speech_start_idx + cosyvoice2_token_size + llm_embedding_vocab_size]
|
||||
|
||||
# use tie_word_embeddings=True
|
||||
hf_tensors["llm.model.model.embed_tokens.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"][:151936]
|
||||
hf_tensors["llm.model.lm_head.weight"] = hf_tensors["llm.model.model.embed_tokens.weight"]
|
||||
|
||||
torch.save(hf_tensors, args.output_path)
|
||||
397
models/CosyVoice/examples/grpo/cosyvoice2/infer_dataset.py
Normal file
397
models/CosyVoice/examples/grpo/cosyvoice2/infer_dataset.py
Normal file
@@ -0,0 +1,397 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
""" Example Usage
|
||||
dataset=zero_shot_zh
|
||||
output_dir=./outputs_rl_aishell3_step${step}_${dataset}_jit_trt_fp16_reward_tts
|
||||
|
||||
token2wav_path=/workspace/CosyVoice2-0.5B
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
torchrun --nproc_per_node=8 \
|
||||
infer_dataset.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $llm_path/merged_hf_model \
|
||||
--token2wav-path $token2wav_path \
|
||||
--split-name ${dataset} || exit 1
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn.functional as F
|
||||
import torchaudio
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
from datasets import load_dataset
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
import soundfile as sf
|
||||
import s3tokenizer
|
||||
from functools import partial
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
try:
|
||||
torch.multiprocessing.set_start_method("spawn")
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
|
||||
TEMPLATE = "{% for message in messages %}{%- if message['role'] == 'user' %}{{- '<|im_start|>' + message['role'] + '\n' + 'Convert the text to speech: ' + message['content'] + '<|im_end|>\n'}}{%- elif message['role'] == 'assistant' %}{{- '<|im_start|>' + message['role'] + '\n' + '<|SPEECH_GENERATION_START|>' + message['content']}}{%- endif %}{%- endfor %}" # noqa: E501
|
||||
|
||||
|
||||
def audio_decode_cosyvoice2(
|
||||
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||
):
|
||||
"""
|
||||
Generate audio from tokens with optional tone and prompt embedding.
|
||||
"""
|
||||
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||
"empty", prompt_text, prompt_speech_16k, 24000
|
||||
)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||
token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token_len=torch.tensor(
|
||||
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||
finalize=True,
|
||||
)
|
||||
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
return audio_hat
|
||||
|
||||
|
||||
def extract_speech_ids(speech_tokens_str):
|
||||
"""Extract speech IDs from token strings like <|s_23456|>"""
|
||||
speech_ids = []
|
||||
for token_str in speech_tokens_str:
|
||||
if token_str.startswith('<|s_') and token_str.endswith('|>'):
|
||||
num_str = token_str[4:-2]
|
||||
num = int(num_str)
|
||||
speech_ids.append(num)
|
||||
else:
|
||||
print(f"Unexpected token: {token_str}")
|
||||
return speech_ids
|
||||
|
||||
|
||||
def convert_cosy2_tokens_to_speech_id_str(cosy2_tokens):
|
||||
"""Convert CosyVoice2 tokens to speech IDs string like <|s_23456|>"""
|
||||
speech_id_str = ""
|
||||
for token in cosy2_tokens:
|
||||
speech_id_str += f"<|s_{token}|>"
|
||||
return speech_id_str
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(description="Speech generation using LLM + CosyVoice2")
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
help="huggingface dataset split name, see yuekai/CV3-Eval, yuekai/seed_tts_cosy2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir", required=True, type=str, help="dir to save result"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--batch-size",
|
||||
default=1,
|
||||
type=int,
|
||||
help="batch size (per-device) for inference",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-workers", type=int, default=1, help="workers for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prefetch", type=int, default=5, help="prefetch for dataloader"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--llm-model-name-or-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="LLM model path (includes both model and tokenizer)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--token2wav-path",
|
||||
required=True,
|
||||
type=str,
|
||||
help="CosyVoice2 token2wav model path",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-text",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The prompt text for CosyVoice2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--prompt-speech-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the prompt speech for CosyVoice2",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-p",
|
||||
type=float,
|
||||
default=0.95,
|
||||
help="top p for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--temperature",
|
||||
type=float,
|
||||
default=0.8,
|
||||
help="temperature for sampling",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--top-k",
|
||||
type=int,
|
||||
default=50,
|
||||
help="top k for sampling",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
def data_collator(batch, tokenizer, s3_tokenizer):
|
||||
"""Simplified data collator for batch_size=1 processing"""
|
||||
target_sample_rate = 16000 # CosyVoice2 uses 16kHz for prompt audio
|
||||
device = s3_tokenizer.device if s3_tokenizer is not None else torch.device("cpu")
|
||||
input_ids_list, prompt_audio_list, prompt_text_list = [], [], []
|
||||
mels, prompt_audio_cosy2tokens_list = [], []
|
||||
for item in batch:
|
||||
prompt_text, target_text = (
|
||||
item["prompt_text"],
|
||||
item["target_text"],
|
||||
)
|
||||
prompt_text_list.append(prompt_text)
|
||||
# Combine prompt and target text
|
||||
full_text = prompt_text + target_text
|
||||
|
||||
# get prompt audio for CosyVoice2 (convert to 16kHz)
|
||||
ref_audio_org, ref_sr = (
|
||||
item["prompt_audio"]["array"],
|
||||
item["prompt_audio"]["sampling_rate"],
|
||||
)
|
||||
ref_audio_org = torch.from_numpy(ref_audio_org).float().unsqueeze(0)
|
||||
# ref_audio_org = ref_audio_org.mean(dim=0, keepdim=True)
|
||||
print(ref_audio_org.shape)
|
||||
|
||||
if ref_sr != target_sample_rate:
|
||||
resampler = torchaudio.transforms.Resample(ref_sr, target_sample_rate)
|
||||
ref_audio = resampler(ref_audio_org)
|
||||
else:
|
||||
ref_audio = ref_audio_org
|
||||
|
||||
prompt_audio_list.append(ref_audio)
|
||||
|
||||
if "prompt_audio_cosy2_tokens" in item:
|
||||
prompt_audio_cosy2tokens = item["prompt_audio_cosy2_tokens"]
|
||||
prompt_audio_cosy2tokens_list.append(prompt_audio_cosy2tokens)
|
||||
else:
|
||||
# convert to float first
|
||||
mels.append(s3tokenizer.log_mel_spectrogram(ref_audio.squeeze(0)))
|
||||
|
||||
if len(mels) > 0:
|
||||
mels, mels_lens = s3tokenizer.padding(mels)
|
||||
codes, codes_lens = s3_tokenizer.quantize(mels.to(device), mels_lens.to(device))
|
||||
for i in range(len(codes)):
|
||||
prompt_audio_cosy2tokens_list.append(codes[i, :codes_lens[i].item()])
|
||||
for prompt_audio_cosy2tokens in prompt_audio_cosy2tokens_list:
|
||||
prompt_audio_cosy2_id_str = convert_cosy2_tokens_to_speech_id_str(prompt_audio_cosy2tokens)
|
||||
# Create chat template for LLM generation
|
||||
chat = [
|
||||
{"role": "user", "content": full_text},
|
||||
{"role": "assistant", "content": prompt_audio_cosy2_id_str}
|
||||
]
|
||||
if 'system' in tokenizer.chat_template:
|
||||
tokenizer.chat_template = TEMPLATE
|
||||
input_ids = tokenizer.apply_chat_template(
|
||||
chat,
|
||||
tokenize=True,
|
||||
return_tensors='pt',
|
||||
continue_final_message=True
|
||||
)
|
||||
input_ids_list.append(input_ids.squeeze(0))
|
||||
|
||||
# For batch_size=1, no need to pad
|
||||
if len(input_ids_list) == 1:
|
||||
input_ids = input_ids_list[0].unsqueeze(0)
|
||||
else:
|
||||
# Handle batch > 1 if needed
|
||||
max_len = max([len(input_ids) for input_ids in input_ids_list])
|
||||
input_ids_list = [
|
||||
torch.cat([torch.full((max_len - len(input_ids),), tokenizer.pad_token_id), input_ids])
|
||||
for input_ids in input_ids_list
|
||||
]
|
||||
input_ids = torch.stack(input_ids_list)
|
||||
|
||||
ids = [item["id"] for item in batch]
|
||||
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"ids": ids,
|
||||
"prompt_text": prompt_text_list,
|
||||
"prompt_audio_list": prompt_audio_list,
|
||||
}
|
||||
|
||||
|
||||
def init_distributed():
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
rank = int(os.environ.get("RANK", 0))
|
||||
print(
|
||||
"Inference on multiple gpus, this gpu {}".format(local_rank)
|
||||
+ ", rank {}, world_size {}".format(rank, world_size)
|
||||
)
|
||||
torch.cuda.set_device(local_rank)
|
||||
dist.init_process_group("nccl")
|
||||
return world_size, local_rank, rank
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
os.makedirs(args.output_dir, exist_ok=True)
|
||||
|
||||
assert torch.cuda.is_available()
|
||||
world_size, local_rank, rank = init_distributed()
|
||||
device = torch.device(f"cuda:{local_rank}")
|
||||
|
||||
# Load LLM model and tokenizer directly
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.llm_model_name_or_path)
|
||||
model = AutoModelForCausalLM.from_pretrained(args.llm_model_name_or_path)
|
||||
model.eval()
|
||||
model.to(device)
|
||||
|
||||
cosyvoice_codec = CosyVoice2(
|
||||
args.token2wav_path, load_jit=True, load_trt=True, fp16=True
|
||||
)
|
||||
if args.prompt_speech_path:
|
||||
prompt_speech_16k = load_wav(args.prompt_speech_path, 16000)
|
||||
else:
|
||||
prompt_speech_16k = None
|
||||
s3_tokenizer = s3tokenizer.load_model("speech_tokenizer_v2_25hz").to(device) if 'zero' in args.split_name else None
|
||||
dataset_name = "yuekai/CV3-Eval" if 'zero' in args.split_name else "yuekai/seed_tts_cosy2"
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank)
|
||||
|
||||
dataloader = DataLoader(
|
||||
dataset,
|
||||
batch_size=args.batch_size,
|
||||
sampler=sampler,
|
||||
shuffle=False,
|
||||
num_workers=args.num_workers,
|
||||
prefetch_factor=args.prefetch,
|
||||
collate_fn=partial(data_collator, tokenizer=tokenizer, s3_tokenizer=s3_tokenizer),
|
||||
)
|
||||
|
||||
total_steps = len(dataset)
|
||||
|
||||
if rank == 0:
|
||||
progress_bar = tqdm(total=total_steps, desc="Processing", unit="wavs")
|
||||
|
||||
for batch in dataloader:
|
||||
with torch.no_grad():
|
||||
input_ids = batch["input_ids"].to(device)
|
||||
|
||||
# Generate speech tokens using LLM
|
||||
outputs = model.generate(
|
||||
input_ids,
|
||||
max_new_tokens=2048, # Max length for generation
|
||||
do_sample=True,
|
||||
top_p=args.top_p,
|
||||
temperature=args.temperature,
|
||||
top_k=args.top_k,
|
||||
)
|
||||
|
||||
# Process each sample in the batch
|
||||
for i in range(len(batch["ids"])):
|
||||
# Extract generated tokens (excluding input)
|
||||
input_length = input_ids[i].shape[0]
|
||||
generated_ids = outputs[i][input_length:-1] # Remove last token if needed
|
||||
speech_tokens_str = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
|
||||
|
||||
# Extract speech IDs from token strings like <|s_23456|>
|
||||
speech_ids = extract_speech_ids(speech_tokens_str)
|
||||
|
||||
if len(speech_ids) == 0:
|
||||
print(f"Warning: No speech tokens generated for sample {batch['ids'][i]}, skipping")
|
||||
continue
|
||||
|
||||
# Convert to tensor for CosyVoice2
|
||||
audio_tokens = torch.tensor(speech_ids, dtype=torch.long, device=device).unsqueeze(0)
|
||||
|
||||
if args.prompt_text is not None:
|
||||
current_prompt_text = args.prompt_text
|
||||
current_prompt_audio = prompt_speech_16k
|
||||
else:
|
||||
current_prompt_text = batch["prompt_text"][i]
|
||||
current_prompt_audio = batch["prompt_audio_list"][i]
|
||||
|
||||
if current_prompt_audio is not None:
|
||||
# Generate audio using CosyVoice2
|
||||
audio_hat = audio_decode_cosyvoice2(
|
||||
audio_tokens,
|
||||
current_prompt_text,
|
||||
current_prompt_audio,
|
||||
cosyvoice_codec,
|
||||
)
|
||||
|
||||
# Convert to numpy and save
|
||||
generated_wave = audio_hat.squeeze(0).cpu().numpy()
|
||||
target_sample_rate = 24000
|
||||
|
||||
utt = batch["ids"][i]
|
||||
sf.write(f"{args.output_dir}/{utt}.wav", generated_wave, target_sample_rate)
|
||||
|
||||
print(f"Generated audio for sample {utt} with {len(speech_ids)} tokens")
|
||||
else:
|
||||
print(f"Warning: No prompt audio available for sample {batch['ids'][i]}, skipping")
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.update(world_size * len(batch["ids"]))
|
||||
|
||||
if rank == 0:
|
||||
progress_bar.close()
|
||||
|
||||
dist.barrier()
|
||||
dist.destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
86
models/CosyVoice/examples/grpo/cosyvoice2/prepare_data.py
Normal file
86
models/CosyVoice/examples/grpo/cosyvoice2/prepare_data.py
Normal file
@@ -0,0 +1,86 @@
|
||||
# Copyright 2024 Bytedance Ltd. and/or its affiliates
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Preprocess the Text to Speech dataset to parquet format
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
|
||||
import datasets
|
||||
|
||||
from verl.utils.hdfs_io import copy, makedirs
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--train_file", required=True, help="Path to training JSON/JSONL file")
|
||||
parser.add_argument("--test_file", required=True, help="Path to test JSON/JSONL file")
|
||||
parser.add_argument("--local_dir", default=None, required=True)
|
||||
parser.add_argument("--hdfs_dir", default=None)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Load datasets from local JSON files
|
||||
train_dataset = datasets.load_dataset("json", data_files=args.train_file)['train']
|
||||
test_dataset = datasets.load_dataset("json", data_files=args.test_file)['train']
|
||||
|
||||
# add a row to each data item that represents a unique id
|
||||
def make_map_fn(split):
|
||||
def process_fn(example, idx):
|
||||
text = example.pop("text")
|
||||
|
||||
# use cosyvoice2 official huggingface compatible checkpoint template
|
||||
question = text
|
||||
answer = ""
|
||||
|
||||
data = {
|
||||
"data_source": f"{args.train_file}_{args.test_file}", # Use file names as data source
|
||||
"prompt": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": question,
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": answer,
|
||||
},
|
||||
],
|
||||
"ability": "text-to-speech",
|
||||
"reward_model": {"style": "rule", "ground_truth": text},
|
||||
"extra_info": {
|
||||
"split": split,
|
||||
"index": idx,
|
||||
"text": text,
|
||||
},
|
||||
}
|
||||
return data
|
||||
|
||||
return process_fn
|
||||
|
||||
train_dataset = train_dataset.map(function=make_map_fn("train"), with_indices=True)
|
||||
test_dataset = test_dataset.map(function=make_map_fn("test"), with_indices=True)
|
||||
|
||||
local_dir = args.local_dir
|
||||
hdfs_dir = args.hdfs_dir
|
||||
|
||||
print(train_dataset)
|
||||
print(test_dataset)
|
||||
train_dataset.to_parquet(os.path.join(local_dir, "train.parquet"))
|
||||
test_dataset.to_parquet(os.path.join(local_dir, "test.parquet"))
|
||||
|
||||
if hdfs_dir is not None:
|
||||
makedirs(hdfs_dir)
|
||||
|
||||
copy(src=local_dir, dst=hdfs_dir)
|
||||
@@ -0,0 +1,133 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Usage: Instruct TTS
|
||||
python3 infer.py \
|
||||
--token2wav-path /workspace/CosyVoice2-0.5B \
|
||||
--prompt-text "吃燕窝就选燕之屋,本节目由26年专注高品质燕窝的燕之屋冠名播出。豆奶牛奶换着喝,营养更均衡,本节目由豆本豆豆奶特约播出。" \
|
||||
--prompt-speech-path ./assets/prompt_audio.wav \
|
||||
--model-path ./transformers_cosyvoice2_llm \
|
||||
--input-text "用四川话说<|endofprompt|>扁担长,板凳宽,扁担绑在板凳上。吃葡萄不吐葡萄皮,不吃葡萄倒吐葡萄皮。"
|
||||
"""
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
import sys
|
||||
from argparse import ArgumentParser
|
||||
from transformers import AutoTokenizer, AutoModelForCausalLM
|
||||
import torch
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--pretrained-cosyvoice2-path",
|
||||
type=str,
|
||||
default="/workspace/CosyVoice2-0.5B",
|
||||
help="Token2Wav path, default to %(default)r",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--save-path",
|
||||
type=str,
|
||||
default='./transformers_cosyvoice2_llm',
|
||||
help="The path to save the model",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
return args
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
args = get_args()
|
||||
cosy2_model = CosyVoice2(
|
||||
args.pretrained_cosyvoice2_path, load_jit=False, load_trt=False, fp16=False
|
||||
)
|
||||
|
||||
llm = cosy2_model.model.llm.llm.model
|
||||
|
||||
speech_embedding = cosy2_model.model.llm.speech_embedding
|
||||
llm_decoder = cosy2_model.model.llm.llm_decoder
|
||||
llm_embedding = cosy2_model.model.llm.llm_embedding
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(f"{args.pretrained_cosyvoice2_path}/CosyVoice-BlankEN")
|
||||
special_tokens = {
|
||||
'eos_token': '<|endoftext|>',
|
||||
'pad_token': '<|endoftext|>',
|
||||
'additional_special_tokens': [
|
||||
'<|im_start|>', '<|im_end|>', '<|endofprompt|>',
|
||||
'[breath]', '<strong>', '</strong>', '[noise]',
|
||||
'[laughter]', '[cough]', '[clucking]', '[accent]',
|
||||
'[quick_breath]',
|
||||
"<laughter>", "</laughter>",
|
||||
"[hissing]", "[sigh]", "[vocalized-noise]",
|
||||
"[lipsmack]", "[mn]"
|
||||
]
|
||||
}
|
||||
tokenizer.add_special_tokens(special_tokens)
|
||||
|
||||
original_tokenizer_vocab_size = len(tokenizer)
|
||||
cosyvoice2_token_size = 6561
|
||||
new_tokens = [f"<|s_{i}|>" for i in range(cosyvoice2_token_size)] + [
|
||||
"<|eos1|>", "<|eos2|>", "<|eos3|>", "<|sos|>", "<|task_id|>"
|
||||
]
|
||||
num_added_tokens = tokenizer.add_tokens(new_tokens)
|
||||
|
||||
llm.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=128)
|
||||
vocab_size = llm.get_input_embeddings().weight.shape[0]
|
||||
|
||||
feature_size = speech_embedding.embedding_dim
|
||||
new_lm_head = torch.nn.Linear(in_features=feature_size, out_features=vocab_size, bias=True)
|
||||
|
||||
with torch.no_grad():
|
||||
# set the weight and bias of the new lm_head to 0
|
||||
new_lm_head.weight.data.zero_()
|
||||
# make bias value -inf
|
||||
new_lm_head.bias.data.fill_(-float('inf'))
|
||||
new_lm_head.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.weight
|
||||
new_lm_head.bias[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = llm_decoder.bias
|
||||
|
||||
llm.lm_head = new_lm_head
|
||||
input_embeddings = llm.get_input_embeddings()
|
||||
|
||||
with torch.no_grad():
|
||||
input_embeddings.weight[original_tokenizer_vocab_size:original_tokenizer_vocab_size + cosyvoice2_token_size + 3] = speech_embedding.weight
|
||||
input_embeddings.weight[original_tokenizer_vocab_size + cosyvoice2_token_size + 3:original_tokenizer_vocab_size + cosyvoice2_token_size + 3 + 2] = llm_embedding.weight
|
||||
|
||||
eos_token_ids = [original_tokenizer_vocab_size + cosyvoice2_token_size,
|
||||
original_tokenizer_vocab_size + cosyvoice2_token_size + 1,
|
||||
original_tokenizer_vocab_size + cosyvoice2_token_size + 2]
|
||||
llm.generation_config.eos_token_id = eos_token_ids
|
||||
llm.generation_config.temperature = 1.0
|
||||
llm.generation_config.top_p = 0.8
|
||||
llm.generation_config.top_k = 25
|
||||
|
||||
llm.config.eos_token_id = original_tokenizer_vocab_size + cosyvoice2_token_size
|
||||
llm.config.vocab_size = vocab_size
|
||||
llm.config.tie_word_embeddings = False
|
||||
llm.config.use_bias = True
|
||||
llm.to(torch.bfloat16)
|
||||
llm.save_pretrained(args.save_path)
|
||||
|
||||
TEMPLATE = (
|
||||
"{%- for message in messages %}"
|
||||
"{%- if message['role'] == 'user' %}"
|
||||
"{{- '<|sos|>' + message['content'] + '<|task_id|>' }}"
|
||||
"{%- elif message['role'] == 'assistant' %}"
|
||||
"{{- message['content']}}"
|
||||
"{%- endif %}"
|
||||
"{%- endfor %}"
|
||||
)
|
||||
tokenizer.chat_template = TEMPLATE
|
||||
tokenizer.save_pretrained(args.save_path)
|
||||
31
models/CosyVoice/examples/grpo/cosyvoice2/requirements.txt
Normal file
31
models/CosyVoice/examples/grpo/cosyvoice2/requirements.txt
Normal file
@@ -0,0 +1,31 @@
|
||||
conformer==0.3.2
|
||||
diffusers==0.29.0
|
||||
gdown==5.1.0
|
||||
gradio
|
||||
hydra-core==1.3.2
|
||||
HyperPyYAML==1.2.2
|
||||
inflect==7.3.1
|
||||
librosa==0.10.2
|
||||
lightning==2.2.4
|
||||
matplotlib==3.7.5
|
||||
modelscope==1.15.0
|
||||
networkx==3.1
|
||||
omegaconf==2.3.0
|
||||
onnx==1.16.0
|
||||
onnxruntime-gpu==1.18.0
|
||||
protobuf==4.25
|
||||
pydantic==2.7.0
|
||||
pyworld==0.3.4
|
||||
rich==13.7.1
|
||||
soundfile==0.12.1
|
||||
tensorboard==2.14.0
|
||||
wget==3.2
|
||||
WeTextProcessing==1.0.3
|
||||
s3tokenizer
|
||||
tensorrt
|
||||
sherpa_onnx
|
||||
jiwer
|
||||
zhon
|
||||
numpy==1.25.2
|
||||
pypinyin
|
||||
openai-whisper
|
||||
233
models/CosyVoice/examples/grpo/cosyvoice2/reward_tts.py
Normal file
233
models/CosyVoice/examples/grpo/cosyvoice2/reward_tts.py
Normal file
@@ -0,0 +1,233 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Reward calculation for CosyVoice2-0.5B.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
import json
|
||||
import time
|
||||
import argparse
|
||||
from typing import List
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
|
||||
REWARD_SERVER_URL = "http://localhost:8000/v2/models/token2wav_asr/infer"
|
||||
|
||||
|
||||
def _parse_ids(token_str: str) -> List[int]:
|
||||
return [int(t) for t in re.findall(r"<\|s_(\d+)\|>", token_str)]
|
||||
|
||||
|
||||
def _remote_reward(tokens: List[int], ground_truth: str, timeout: float = 200.0) -> float:
|
||||
"""Send token IDs and ground-truth text to the Triton server and get reward."""
|
||||
|
||||
tokens_arr = np.array(tokens, dtype=np.int32).reshape(1, -1)
|
||||
lens_arr = np.array([[tokens_arr.shape[1]]], dtype=np.int32)
|
||||
|
||||
gt_arr = np.array([ground_truth.encode("utf-8")], dtype=object)
|
||||
|
||||
payload = {
|
||||
"inputs": [
|
||||
{
|
||||
"name": "TOKENS",
|
||||
"shape": list(tokens_arr.shape),
|
||||
"datatype": "INT32",
|
||||
"data": tokens_arr.tolist(),
|
||||
},
|
||||
{
|
||||
"name": "TOKEN_LENS",
|
||||
"shape": list(lens_arr.shape),
|
||||
"datatype": "INT32",
|
||||
"data": lens_arr.tolist(),
|
||||
},
|
||||
{
|
||||
"name": "GT_TEXT",
|
||||
"shape": [1, 1],
|
||||
"datatype": "BYTES",
|
||||
"data": [ground_truth],
|
||||
},
|
||||
]
|
||||
}
|
||||
rsp = requests.post(
|
||||
REWARD_SERVER_URL,
|
||||
headers={"Content-Type": "application/json"},
|
||||
json=payload,
|
||||
timeout=timeout,
|
||||
verify=False,
|
||||
params={"request_id": "0"},
|
||||
)
|
||||
rsp.raise_for_status()
|
||||
result = rsp.json()
|
||||
|
||||
try:
|
||||
# Reward is returned as the first output
|
||||
return float(result["outputs"][0]["data"][0])
|
||||
except (KeyError, IndexError, TypeError):
|
||||
return 0.0
|
||||
|
||||
|
||||
def compute_score(
|
||||
data_source: str,
|
||||
solution_str: str,
|
||||
ground_truth: str,
|
||||
extra_info: dict | None = None,
|
||||
*,
|
||||
debug_dump: bool = False,
|
||||
) -> float:
|
||||
"""Return reward in [0, 1] using the Triton ASR service.
|
||||
|
||||
The reward is based on the pinyin-level WER between the ASR transcript
|
||||
produced from *solution_str* and the provided *ground_truth* text.
|
||||
"""
|
||||
|
||||
# Decode token IDs
|
||||
ids = _parse_ids(solution_str)
|
||||
|
||||
# Query remote server for reward
|
||||
try:
|
||||
reward = _remote_reward(ids, ground_truth)
|
||||
except Exception as e:
|
||||
reward = 0.0
|
||||
|
||||
if debug_dump:
|
||||
print(
|
||||
f"\033[92m[{data_source}] Remote reward: {reward:.4f}\033[0m"
|
||||
)
|
||||
|
||||
return reward
|
||||
|
||||
|
||||
# CLI quick test
|
||||
if __name__ == "__main__":
|
||||
import sys
|
||||
|
||||
def get_args():
|
||||
"""Parse command line arguments."""
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Test TTS CER scoring with data from JSONL file",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--input", "-i",
|
||||
type=str,
|
||||
default="data/emilia_zh-cosy-tiny-test.jsonl",
|
||||
help="Path to input JSONL file"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--max-samples", "-n",
|
||||
type=int,
|
||||
default=None,
|
||||
help="Maximum number of samples to process (default: all)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--no-interactive",
|
||||
action="store_true",
|
||||
help="Run in non-interactive mode (process all samples without prompts)"
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
action="store_true",
|
||||
help="Enable debug mode"
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
def load_jsonl(file_path: str):
|
||||
"""Load data from jsonl file."""
|
||||
data = []
|
||||
with open(file_path, 'r', encoding='utf-8') as f:
|
||||
for line in f:
|
||||
data.append(json.loads(line.strip()))
|
||||
return data
|
||||
|
||||
def code_to_solution_str(code_list: List[int]) -> str:
|
||||
"""Convert code list to solution string format."""
|
||||
return ''.join([f"<|s_{code}|>" for code in code_list])
|
||||
|
||||
# Parse command line arguments
|
||||
args = get_args()
|
||||
|
||||
try:
|
||||
# Load data from jsonl file
|
||||
print(f"Loading data from: {args.input}")
|
||||
data_list = load_jsonl(args.input)
|
||||
print(f"Loaded {len(data_list)} samples")
|
||||
|
||||
# Limit samples if specified
|
||||
if args.max_samples is not None:
|
||||
data_list = data_list[:args.max_samples]
|
||||
print(f"Processing first {len(data_list)} samples (limited by --max-samples)")
|
||||
|
||||
# Process each sample
|
||||
begin_time = time.time()
|
||||
for i, sample in enumerate(data_list):
|
||||
print(f"\n--- Sample {i+1}/{len(data_list)} ---")
|
||||
print(f"Index: {sample.get('index', 'unknown')}")
|
||||
print(f"Text: {sample['text']}")
|
||||
|
||||
# Extract required fields
|
||||
code_list = sample['code']
|
||||
ground_truth = sample['text']
|
||||
data_source = sample.get('index', f'sample_{i}') # Use index as data_source
|
||||
|
||||
# Convert code list to solution string
|
||||
solution_str = code_to_solution_str(code_list)
|
||||
print(f"Solution tokens: {len(code_list)} tokens")
|
||||
if args.debug:
|
||||
print(f"Solution string: {solution_str}")
|
||||
else:
|
||||
print(f"Solution string preview: {solution_str[:100]}..." if len(solution_str) > 100 else f"Solution string: {solution_str}")
|
||||
|
||||
# Call compute_score function
|
||||
try:
|
||||
score = compute_score(
|
||||
data_source=data_source,
|
||||
solution_str=solution_str,
|
||||
ground_truth=ground_truth,
|
||||
extra_info=None,
|
||||
debug_dump=args.debug
|
||||
)
|
||||
print(f"Final Score: {score:.4f}")
|
||||
except Exception as e:
|
||||
print(f"Error computing score: {e}")
|
||||
|
||||
# Ask user if they want to continue (for interactive mode)
|
||||
if not args.no_interactive and i < len(data_list) - 1:
|
||||
try:
|
||||
response = input("\nPress Enter to continue or 'q' to quit: ").strip().lower()
|
||||
if response == 'q':
|
||||
break
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopped by user")
|
||||
break
|
||||
|
||||
print(f"\nProcessed {min(i+1, len(data_list))} samples")
|
||||
end_time = time.time()
|
||||
print(f"Time taken: {end_time - begin_time} seconds")
|
||||
except FileNotFoundError:
|
||||
print(f"Error: File not found - {args.input}")
|
||||
print("Please check the file path or use --input to specify correct path")
|
||||
print("Run with --help for usage information")
|
||||
except Exception as e:
|
||||
print(f"Error: {e}")
|
||||
159
models/CosyVoice/examples/grpo/cosyvoice2/run.sh
Normal file
159
models/CosyVoice/examples/grpo/cosyvoice2/run.sh
Normal file
@@ -0,0 +1,159 @@
|
||||
#!/usr/bin/env bash
|
||||
|
||||
set -eou pipefail
|
||||
|
||||
stage=-1
|
||||
stop_stage=4
|
||||
|
||||
log() {
|
||||
# This function is from espnet
|
||||
local fname=${BASH_SOURCE[1]##*/}
|
||||
echo -e "$(date '+%Y-%m-%d %H:%M:%S') (${fname}:${BASH_LINENO[0]}:${FUNCNAME[1]}) $*"
|
||||
}
|
||||
|
||||
export PYTHONPATH=/workspace/CosyVoice
|
||||
model_scope_model_path=./CosyVoice2-0.5B
|
||||
sft_model_path=./transformers_cosyvoice2_llm
|
||||
|
||||
if [ $stage -le -2 ] && [ $stop_stage -ge -2 ]; then
|
||||
log "stage -2: install dependencies locally if pre-built docker image is not available"
|
||||
conda create -n cosyvoice2 python=3.10 -y
|
||||
conda activate cosyvoice2
|
||||
# install verl
|
||||
git clone https://github.com/yuekaizhang/verl.git -b thread
|
||||
cd verl
|
||||
USE_MEGATRON=0 bash scripts/install_vllm_sglang_mcore.sh
|
||||
pip install --no-deps -e .
|
||||
cd -
|
||||
# install requirements
|
||||
pip install -r requirements.txt
|
||||
pip install -U nvidia-pytriton
|
||||
git clone https://github.com/yuekaizhang/PytritonSenseVoice.git && cd PytritonSenseVoice && pip install -e .
|
||||
fi
|
||||
|
||||
if [ $stage -le -1 ] && [ $stop_stage -ge -1 ]; then
|
||||
log "stage -1: download official CosyVoice2-0.5B LLM model and convert to huggingface compatible checkpoint"
|
||||
modelscope download --model iic/CosyVoice2-0.5B --local_dir $model_scope_model_path
|
||||
python3 pretrained_to_huggingface.py \
|
||||
--pretrained-cosyvoice2-path $model_scope_model_path \
|
||||
--save-path $sft_model_path
|
||||
|
||||
# Or, you could use the following command to download the huggingface compatible checkpoint
|
||||
# huggingface-cli download --local-dir $sft_model_path yuekai/cosyvoice2_llm
|
||||
|
||||
# Note: we remove the lm_head's bias to make it compatible with the Qwen2.5-0.5B model in Transformers.
|
||||
fi
|
||||
|
||||
data_dir=data/parquet_aishell3
|
||||
if [ $stage -le 0 ] && [ $stop_stage -ge 0 ]; then
|
||||
log "stage 0: prepare data into verl format"
|
||||
mkdir -p $data_dir
|
||||
wget -O data/aishell-3.jsonl https://huggingface.co/datasets/SparkAudio/voxbox/resolve/main/metadata/aishell-3.jsonl
|
||||
# total 88035 samples
|
||||
head -n 80000 data/aishell-3.jsonl > data/train.jsonl
|
||||
tail -n 100 data/aishell-3.jsonl > data/test.jsonl
|
||||
python prepare_data.py \
|
||||
--train_file data/train.jsonl \
|
||||
--test_file data/test.jsonl \
|
||||
--local_dir $data_dir
|
||||
fi
|
||||
|
||||
if [ $stage -le 1 ] && [ $stop_stage -ge 1 ]; then
|
||||
log "stage 1: start token2wav asr server for reward function"
|
||||
python3 token2wav_asr_server.py --number-of-devices 8
|
||||
fi
|
||||
|
||||
exp_name=official_llm_aishell3_grpo
|
||||
if [ $stage -le 2 ] && [ $stop_stage -ge 2 ]; then
|
||||
log "stage 2: grpo train"
|
||||
export CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7"
|
||||
export MKL_SERVICE_FORCE_INTEL=TRUE
|
||||
n_gpus_per_node=8
|
||||
micro_batch_size=4
|
||||
train_batch_size=32
|
||||
python3 -m verl.trainer.main_ppo \
|
||||
algorithm.adv_estimator=grpo \
|
||||
data.train_files=$data_dir/train.parquet \
|
||||
data.val_files=$data_dir/test.parquet \
|
||||
data.train_batch_size=$train_batch_size \
|
||||
data.max_prompt_length=1024 \
|
||||
data.max_response_length=512 \
|
||||
data.truncation='error' \
|
||||
actor_rollout_ref.model.use_remove_padding=False \
|
||||
actor_rollout_ref.model.path=$sft_model_path \
|
||||
actor_rollout_ref.actor.optim.lr=1e-6 \
|
||||
actor_rollout_ref.actor.ppo_mini_batch_size=32 \
|
||||
actor_rollout_ref.actor.ppo_micro_batch_size_per_gpu=$micro_batch_size \
|
||||
actor_rollout_ref.actor.use_kl_loss=False \
|
||||
actor_rollout_ref.model.enable_gradient_checkpointing=True \
|
||||
actor_rollout_ref.actor.fsdp_config.param_offload=False \
|
||||
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
|
||||
actor_rollout_ref.rollout.log_prob_micro_batch_size_per_gpu=$micro_batch_size \
|
||||
actor_rollout_ref.rollout.tensor_model_parallel_size=1 \
|
||||
actor_rollout_ref.rollout.name=vllm \
|
||||
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
|
||||
actor_rollout_ref.rollout.do_sample=true \
|
||||
actor_rollout_ref.rollout.temperature=0.8 \
|
||||
actor_rollout_ref.rollout.top_p=0.95 \
|
||||
actor_rollout_ref.rollout.top_k=25 \
|
||||
actor_rollout_ref.rollout.n=4 \
|
||||
actor_rollout_ref.rollout.val_kwargs.do_sample=true \
|
||||
actor_rollout_ref.rollout.val_kwargs.temperature=0.8 \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_p=0.95 \
|
||||
actor_rollout_ref.rollout.val_kwargs.top_k=25 \
|
||||
reward_model.reward_manager=prime \
|
||||
custom_reward_function.path=reward_tts.py \
|
||||
custom_reward_function.name=compute_score \
|
||||
trainer.project_name='cosyvoice2_grpo' \
|
||||
trainer.experiment_name=$exp_name \
|
||||
trainer.logger=['console','wandb'] \
|
||||
trainer.n_gpus_per_node=$n_gpus_per_node \
|
||||
trainer.nnodes=1 \
|
||||
trainer.save_freq=100 \
|
||||
trainer.test_freq=100 \
|
||||
trainer.resume_mode='auto' \
|
||||
trainer.total_epochs=1 \
|
||||
trainer.val_before_train=False
|
||||
fi
|
||||
|
||||
steps=(100 200 300 400 500)
|
||||
for step in ${steps[@]}; do
|
||||
llm_path=./checkpoints/cosyvoice2_grpo/$exp_name/global_step_${step}
|
||||
if [ $stage -le 3 ] && [ $stop_stage -ge 3 ]; then
|
||||
log "stage 3: merge the model"
|
||||
python -m verl.model_merger merge \
|
||||
--backend fsdp \
|
||||
--local_dir $llm_path/actor \
|
||||
--target_dir $llm_path/merged_hf_model || exit 1
|
||||
fi
|
||||
|
||||
if [ $stage -le 4 ] && [ $stop_stage -ge 4 ]; then
|
||||
log "stage 4: Test the model"
|
||||
dataset=zero_shot_zh # from CosyVoice3 test set
|
||||
# dataset=test_zh # from seed_tts test set
|
||||
output_dir=./outputs_${exp_name}_${step}_${dataset}
|
||||
|
||||
token2wav_path=/workspace/CosyVoice2-0.5B
|
||||
model_path=$llm_path/merged_hf_model
|
||||
|
||||
CUDA_VISIBLE_DEVICES=0,1,2,3,4,5,6,7 \
|
||||
torchrun --nproc_per_node=8 \
|
||||
infer_dataset.py \
|
||||
--output-dir $output_dir \
|
||||
--llm-model-name-or-path $model_path \
|
||||
--token2wav-path $token2wav_path \
|
||||
--split-name ${dataset} || exit 1
|
||||
|
||||
bash scripts/compute_wer.sh $output_dir ${dataset}
|
||||
fi
|
||||
done
|
||||
|
||||
if [ $stage -le 5 ] && [ $stop_stage -ge 5 ]; then
|
||||
log "stage 5: Convert the RL trained model to CosyVoice repo format"
|
||||
python3 huggingface_to_pretrained.py \
|
||||
--hf-cosyvoice2-llm-path $llm_path/merged_hf_model \
|
||||
--output-path /workspace/CosyVoice2-0.5B/llm-new.pt
|
||||
# You need to manually move the llm-new.pt to overwrite /workspace/CosyVoice2-0.5B/llm.pt
|
||||
# However, we found that the RL trained model accuracy would slightly drop after this conversion.
|
||||
# Please be careful or use the huggingface format inference code.
|
||||
fi
|
||||
@@ -0,0 +1,33 @@
|
||||
wav_dir=$1
|
||||
wav_files=$(ls $wav_dir/*.wav)
|
||||
# if wav_files is empty, then exit
|
||||
if [ -z "$wav_files" ]; then
|
||||
exit 1
|
||||
fi
|
||||
split_name=$2
|
||||
model_path=models/sherpa-onnx-paraformer-zh-2023-09-14
|
||||
|
||||
if [ ! -d $model_path ]; then
|
||||
pip install sherpa-onnx
|
||||
wget -nc https://github.com/k2-fsa/sherpa-onnx/releases/download/asr-models/sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2
|
||||
mkdir models
|
||||
tar xvf sherpa-onnx-paraformer-zh-2023-09-14.tar.bz2 -C models
|
||||
fi
|
||||
|
||||
python3 scripts/offline-decode-files.py \
|
||||
--tokens=$model_path/tokens.txt \
|
||||
--paraformer=$model_path/model.int8.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
--sample-rate=24000 \
|
||||
--log-dir $wav_dir \
|
||||
--feature-dim=80 \
|
||||
--split-name $split_name \
|
||||
--name sherpa_onnx \
|
||||
$wav_files
|
||||
|
||||
# python3 scripts/paraformer-pytriton-client.py \
|
||||
# --log-dir $wav_dir \
|
||||
# --split-name $split_name \
|
||||
# $wav_files
|
||||
@@ -0,0 +1,754 @@
|
||||
# Copyright (c) 2023 by manyeyes
|
||||
# Copyright (c) 2023 Xiaomi Corporation
|
||||
|
||||
"""
|
||||
This file demonstrates how to use sherpa-onnx Python API to transcribe
|
||||
file(s) with a non-streaming model.
|
||||
|
||||
(1) For paraformer
|
||||
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--paraformer=/path/to/paraformer.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
--sample-rate=16000 \
|
||||
--feature-dim=80 \
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(2) For transducer models from icefall
|
||||
|
||||
./python-api-examples/offline-decode-files.py \
|
||||
--tokens=/path/to/tokens.txt \
|
||||
--encoder=/path/to/encoder.onnx \
|
||||
--decoder=/path/to/decoder.onnx \
|
||||
--joiner=/path/to/joiner.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
--sample-rate=16000 \
|
||||
--feature-dim=80 \
|
||||
/path/to/0.wav \
|
||||
/path/to/1.wav
|
||||
|
||||
(3) For CTC models from NeMo
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--tokens=./sherpa-onnx-nemo-ctc-en-citrinet-512/tokens.txt \
|
||||
--nemo-ctc=./sherpa-onnx-nemo-ctc-en-citrinet-512/model.onnx \
|
||||
--num-threads=2 \
|
||||
--decoding-method=greedy_search \
|
||||
--debug=false \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/0.wav \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/1.wav \
|
||||
./sherpa-onnx-nemo-ctc-en-citrinet-512/test_wavs/8k.wav
|
||||
|
||||
(4) For Whisper models
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--whisper-encoder=./sherpa-onnx-whisper-base.en/base.en-encoder.int8.onnx \
|
||||
--whisper-decoder=./sherpa-onnx-whisper-base.en/base.en-decoder.int8.onnx \
|
||||
--tokens=./sherpa-onnx-whisper-base.en/base.en-tokens.txt \
|
||||
--whisper-task=transcribe \
|
||||
--num-threads=1 \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/0.wav \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/1.wav \
|
||||
./sherpa-onnx-whisper-base.en/test_wavs/8k.wav
|
||||
|
||||
(5) For CTC models from WeNet
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--wenet-ctc=./sherpa-onnx-zh-wenet-wenetspeech/model.onnx \
|
||||
--tokens=./sherpa-onnx-zh-wenet-wenetspeech/tokens.txt \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/0.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/1.wav \
|
||||
./sherpa-onnx-zh-wenet-wenetspeech/test_wavs/8k.wav
|
||||
|
||||
(6) For tdnn models of the yesno recipe from icefall
|
||||
|
||||
python3 ./python-api-examples/offline-decode-files.py \
|
||||
--sample-rate=8000 \
|
||||
--feature-dim=23 \
|
||||
--tdnn-model=./sherpa-onnx-tdnn-yesno/model-epoch-14-avg-2.onnx \
|
||||
--tokens=./sherpa-onnx-tdnn-yesno/tokens.txt \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_0_1_0_0_0_1.wav \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_0_1_0.wav \
|
||||
./sherpa-onnx-tdnn-yesno/test_wavs/0_0_1_0_0_1_1_1.wav
|
||||
|
||||
Please refer to
|
||||
https://k2-fsa.github.io/sherpa/onnx/index.html
|
||||
to install sherpa-onnx and to download non-streaming pre-trained models
|
||||
used in this file.
|
||||
"""
|
||||
import argparse
|
||||
import time
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import List, Tuple, Dict, Iterable, TextIO, Union
|
||||
|
||||
import numpy as np
|
||||
import sherpa_onnx
|
||||
import soundfile as sf
|
||||
from datasets import load_dataset
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
import kaldialign
|
||||
from zhon.hanzi import punctuation
|
||||
import string
|
||||
punctuation_all = punctuation + string.punctuation
|
||||
Pathlike = Union[str, Path]
|
||||
|
||||
|
||||
def remove_punctuation(text: str) -> str:
|
||||
for x in punctuation_all:
|
||||
if x == '\'':
|
||||
continue
|
||||
text = text.replace(x, '')
|
||||
return text
|
||||
|
||||
|
||||
def store_transcripts(
|
||||
filename: Pathlike, texts: Iterable[Tuple[str, str, str]], char_level: bool = False
|
||||
) -> None:
|
||||
"""Save predicted results and reference transcripts to a file.
|
||||
|
||||
Args:
|
||||
filename:
|
||||
File to save the results to.
|
||||
texts:
|
||||
An iterable of tuples. The first element is the cur_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
If it is a multi-talker ASR system, the ref and hyp may also be lists of
|
||||
strings.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
with open(filename, "w", encoding="utf8") as f:
|
||||
for cut_id, ref, hyp in texts:
|
||||
if char_level:
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
print(f"{cut_id}:\tref={ref}", file=f)
|
||||
print(f"{cut_id}:\thyp={hyp}", file=f)
|
||||
|
||||
|
||||
def write_error_stats(
|
||||
f: TextIO,
|
||||
test_set_name: str,
|
||||
results: List[Tuple[str, str]],
|
||||
enable_log: bool = True,
|
||||
compute_CER: bool = False,
|
||||
sclite_mode: bool = False,
|
||||
) -> float:
|
||||
"""Write statistics based on predicted results and reference transcripts.
|
||||
|
||||
It will write the following to the given file:
|
||||
|
||||
- WER
|
||||
- number of insertions, deletions, substitutions, corrects and total
|
||||
reference words. For example::
|
||||
|
||||
Errors: 23 insertions, 57 deletions, 212 substitutions, over 2606
|
||||
reference words (2337 correct)
|
||||
|
||||
- The difference between the reference transcript and predicted result.
|
||||
An instance is given below::
|
||||
|
||||
THE ASSOCIATION OF (EDISON->ADDISON) ILLUMINATING COMPANIES
|
||||
|
||||
The above example shows that the reference word is `EDISON`,
|
||||
but it is predicted to `ADDISON` (a substitution error).
|
||||
|
||||
Another example is::
|
||||
|
||||
FOR THE FIRST DAY (SIR->*) I THINK
|
||||
|
||||
The reference word `SIR` is missing in the predicted
|
||||
results (a deletion error).
|
||||
results:
|
||||
An iterable of tuples. The first element is the cut_id, the second is
|
||||
the reference transcript and the third element is the predicted result.
|
||||
enable_log:
|
||||
If True, also print detailed WER to the console.
|
||||
Otherwise, it is written only to the given file.
|
||||
Returns:
|
||||
Return None.
|
||||
"""
|
||||
subs: Dict[Tuple[str, str], int] = defaultdict(int)
|
||||
ins: Dict[str, int] = defaultdict(int)
|
||||
dels: Dict[str, int] = defaultdict(int)
|
||||
|
||||
# `words` stores counts per word, as follows:
|
||||
# corr, ref_sub, hyp_sub, ins, dels
|
||||
words: Dict[str, List[int]] = defaultdict(lambda: [0, 0, 0, 0, 0])
|
||||
num_corr = 0
|
||||
ERR = "*"
|
||||
|
||||
if compute_CER:
|
||||
for i, res in enumerate(results):
|
||||
cut_id, ref, hyp = res
|
||||
ref = list("".join(ref))
|
||||
hyp = list("".join(hyp))
|
||||
results[i] = (cut_id, ref, hyp)
|
||||
|
||||
for _cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR, sclite_mode=sclite_mode)
|
||||
for ref_word, hyp_word in ali:
|
||||
if ref_word == ERR:
|
||||
ins[hyp_word] += 1
|
||||
words[hyp_word][3] += 1
|
||||
elif hyp_word == ERR:
|
||||
dels[ref_word] += 1
|
||||
words[ref_word][4] += 1
|
||||
elif hyp_word != ref_word:
|
||||
subs[(ref_word, hyp_word)] += 1
|
||||
words[ref_word][1] += 1
|
||||
words[hyp_word][2] += 1
|
||||
else:
|
||||
words[ref_word][0] += 1
|
||||
num_corr += 1
|
||||
ref_len = sum([len(r) for _, r, _ in results])
|
||||
sub_errs = sum(subs.values())
|
||||
ins_errs = sum(ins.values())
|
||||
del_errs = sum(dels.values())
|
||||
tot_errs = sub_errs + ins_errs + del_errs
|
||||
tot_err_rate = "%.2f" % (100.0 * tot_errs / ref_len)
|
||||
|
||||
if enable_log:
|
||||
logging.info(
|
||||
f"[{test_set_name}] %WER {tot_errs / ref_len:.2%} "
|
||||
f"[{tot_errs} / {ref_len}, {ins_errs} ins, "
|
||||
f"{del_errs} del, {sub_errs} sub ]"
|
||||
)
|
||||
|
||||
print(f"%WER = {tot_err_rate}", file=f)
|
||||
print(
|
||||
f"Errors: {ins_errs} insertions, {del_errs} deletions, "
|
||||
f"{sub_errs} substitutions, over {ref_len} reference "
|
||||
f"words ({num_corr} correct)",
|
||||
file=f,
|
||||
)
|
||||
print(
|
||||
"Search below for sections starting with PER-UTT DETAILS:, "
|
||||
"SUBSTITUTIONS:, DELETIONS:, INSERTIONS:, PER-WORD STATS:",
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-UTT DETAILS: corr or (ref->hyp) ", file=f)
|
||||
for cut_id, ref, hyp in results:
|
||||
ali = kaldialign.align(ref, hyp, ERR)
|
||||
combine_successive_errors = True
|
||||
if combine_successive_errors:
|
||||
ali = [[[x], [y]] for x, y in ali]
|
||||
for i in range(len(ali) - 1):
|
||||
if ali[i][0] != ali[i][1] and ali[i + 1][0] != ali[i + 1][1]:
|
||||
ali[i + 1][0] = ali[i][0] + ali[i + 1][0]
|
||||
ali[i + 1][1] = ali[i][1] + ali[i + 1][1]
|
||||
ali[i] = [[], []]
|
||||
ali = [
|
||||
[
|
||||
list(filter(lambda a: a != ERR, x)),
|
||||
list(filter(lambda a: a != ERR, y)),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
ali = list(filter(lambda x: x != [[], []], ali))
|
||||
ali = [
|
||||
[
|
||||
ERR if x == [] else " ".join(x),
|
||||
ERR if y == [] else " ".join(y),
|
||||
]
|
||||
for x, y in ali
|
||||
]
|
||||
|
||||
print(
|
||||
f"{cut_id}:\t"
|
||||
+ " ".join(
|
||||
(
|
||||
ref_word if ref_word == hyp_word else f"({ref_word}->{hyp_word})"
|
||||
for ref_word, hyp_word in ali
|
||||
)
|
||||
),
|
||||
file=f,
|
||||
)
|
||||
|
||||
print("", file=f)
|
||||
print("SUBSTITUTIONS: count ref -> hyp", file=f)
|
||||
|
||||
for count, (ref, hyp) in sorted([(v, k) for k, v in subs.items()], reverse=True):
|
||||
print(f"{count} {ref} -> {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("DELETIONS: count ref", file=f)
|
||||
for count, ref in sorted([(v, k) for k, v in dels.items()], reverse=True):
|
||||
print(f"{count} {ref}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("INSERTIONS: count hyp", file=f)
|
||||
for count, hyp in sorted([(v, k) for k, v in ins.items()], reverse=True):
|
||||
print(f"{count} {hyp}", file=f)
|
||||
|
||||
print("", file=f)
|
||||
print("PER-WORD STATS: word corr tot_errs count_in_ref count_in_hyp", file=f)
|
||||
for _, word, counts in sorted(
|
||||
[(sum(v[1:]), k, v) for k, v in words.items()], reverse=True
|
||||
):
|
||||
(corr, ref_sub, hyp_sub, ins, dels) = counts
|
||||
tot_errs = ref_sub + hyp_sub + ins + dels
|
||||
ref_count = corr + ref_sub + dels
|
||||
hyp_count = corr + hyp_sub + ins
|
||||
|
||||
print(f"{word} {corr} {tot_errs} {ref_count} {hyp_count}", file=f)
|
||||
return float(tot_err_rate)
|
||||
|
||||
|
||||
def get_args():
|
||||
parser = argparse.ArgumentParser(
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tokens",
|
||||
type=str,
|
||||
help="Path to tokens.txt",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hotwords-file",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The file containing hotwords, one words/phrases per line, like
|
||||
HELLO WORLD
|
||||
你好世界
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--hotwords-score",
|
||||
type=float,
|
||||
default=1.5,
|
||||
help="""
|
||||
The hotword score of each token for biasing word/phrase. Used only if
|
||||
--hotwords-file is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--modeling-unit",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The modeling unit of the model, valid values are cjkchar, bpe, cjkchar+bpe.
|
||||
Used only when hotwords-file is given.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--bpe-vocab",
|
||||
type=str,
|
||||
default="",
|
||||
help="""
|
||||
The path to the bpe vocabulary, the bpe vocabulary is generated by
|
||||
sentencepiece, you can also export the bpe vocabulary through a bpe model
|
||||
by `scripts/export_bpe_vocab.py`. Used only when hotwords-file is given
|
||||
and modeling-unit is bpe or cjkchar+bpe.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--joiner",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the joiner model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--paraformer",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from Paraformer",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--nemo-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from NeMo CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--wenet-ctc",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx from WeNet CTC",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--tdnn-model",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to the model.onnx for the tdnn model of the yesno recipe",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--num-threads",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of threads for neural network computation",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-encoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to whisper encoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-decoder",
|
||||
default="",
|
||||
type=str,
|
||||
help="Path to whisper decoder model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-language",
|
||||
default="",
|
||||
type=str,
|
||||
help="""It specifies the spoken language in the input audio file.
|
||||
Example values: en, fr, de, zh, jp.
|
||||
Available languages for multilingual models can be found at
|
||||
https://github.com/openai/whisper/blob/main/whisper/tokenizer.py#L10
|
||||
If not specified, we infer the language from the input audio file.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-task",
|
||||
default="transcribe",
|
||||
choices=["transcribe", "translate"],
|
||||
type=str,
|
||||
help="""For multilingual models, if you specify translate, the output
|
||||
will be in English.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--whisper-tail-paddings",
|
||||
default=-1,
|
||||
type=int,
|
||||
help="""Number of tail padding frames.
|
||||
We have removed the 30-second constraint from whisper, so you need to
|
||||
choose the amount of tail padding frames by yourself.
|
||||
Use -1 to use a default value for tail padding.
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--blank-penalty",
|
||||
type=float,
|
||||
default=0.0,
|
||||
help="""
|
||||
The penalty applied on blank symbol during decoding.
|
||||
Note: It is a positive value that would be applied to logits like
|
||||
this `logits[:, 0] -= blank_penalty` (suppose logits.shape is
|
||||
[batch_size, vocab] and blank id is 0).
|
||||
""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--decoding-method",
|
||||
type=str,
|
||||
default="greedy_search",
|
||||
help="Valid values are greedy_search and modified_beam_search",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--debug",
|
||||
type=bool,
|
||||
default=False,
|
||||
help="True to show debug messages",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--sample-rate",
|
||||
type=int,
|
||||
default=16000,
|
||||
help="""Sample rate of the feature extractor. Must match the one
|
||||
expected by the model. Note: The input sound files can have a
|
||||
different sample rate from this argument.""",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--feature-dim",
|
||||
type=int,
|
||||
default=80,
|
||||
help="Feature dimension. Must match the one expected by the model",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"sound_files",
|
||||
type=str,
|
||||
nargs="+",
|
||||
help="The input sound file(s) to decode. Each file must be of WAVE"
|
||||
"format with a single channel, and each sample has 16-bit, "
|
||||
"i.e., int16_t. "
|
||||
"The sample rate of the file can be arbitrary and does not need to "
|
||||
"be 16 kHz",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--name",
|
||||
type=str,
|
||||
default="",
|
||||
help="The directory containing the input sound files to decode",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--log-dir",
|
||||
type=str,
|
||||
default="",
|
||||
help="The directory containing the input sound files to decode",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--label",
|
||||
type=str,
|
||||
default=None,
|
||||
help="wav_base_name label",
|
||||
)
|
||||
|
||||
# Dataset related arguments for loading labels when label file is not provided
|
||||
parser.add_argument(
|
||||
"--dataset-name",
|
||||
type=str,
|
||||
default="yuekai/seed_tts_cosy2",
|
||||
help="Huggingface dataset name for loading labels",
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--split-name",
|
||||
type=str,
|
||||
default="wenetspeech4tts",
|
||||
help="Dataset split name for loading labels",
|
||||
)
|
||||
|
||||
return parser.parse_args()
|
||||
|
||||
|
||||
def assert_file_exists(filename: str):
|
||||
assert Path(filename).is_file(), (
|
||||
f"{filename} does not exist!\n"
|
||||
"Please refer to "
|
||||
"https://k2-fsa.github.io/sherpa/onnx/pretrained_models/index.html to download it"
|
||||
)
|
||||
|
||||
|
||||
def read_wave(wave_filename: str) -> Tuple[np.ndarray, int]:
|
||||
"""
|
||||
Args:
|
||||
wave_filename:
|
||||
Path to a wave file. It should be single channel and can be of type
|
||||
32-bit floating point PCM. Its sample rate does not need to be 24kHz.
|
||||
|
||||
Returns:
|
||||
Return a tuple containing:
|
||||
- A 1-D array of dtype np.float32 containing the samples,
|
||||
which are normalized to the range [-1, 1].
|
||||
- Sample rate of the wave file.
|
||||
"""
|
||||
|
||||
samples, sample_rate = sf.read(wave_filename, dtype="float32")
|
||||
assert (
|
||||
samples.ndim == 1
|
||||
), f"Expected single channel, but got {samples.ndim} channels."
|
||||
|
||||
samples_float32 = samples.astype(np.float32)
|
||||
|
||||
return samples_float32, sample_rate
|
||||
|
||||
|
||||
def normalize_text_alimeeting(text: str) -> str:
|
||||
"""
|
||||
Text normalization similar to M2MeT challenge baseline.
|
||||
See: https://github.com/yufan-aslp/AliMeeting/blob/main/asr/local/text_normalize.pl
|
||||
"""
|
||||
import re
|
||||
text = text.replace('\u00A0', '') # test_hard
|
||||
text = text.replace(" ", "")
|
||||
text = text.replace("<sil>", "")
|
||||
text = text.replace("<%>", "")
|
||||
text = text.replace("<->", "")
|
||||
text = text.replace("<$>", "")
|
||||
text = text.replace("<#>", "")
|
||||
text = text.replace("<_>", "")
|
||||
text = text.replace("<space>", "")
|
||||
text = text.replace("`", "")
|
||||
text = text.replace("&", "")
|
||||
text = text.replace(",", "")
|
||||
if re.search("[a-zA-Z]", text):
|
||||
text = text.upper()
|
||||
text = text.replace("A", "A")
|
||||
text = text.replace("a", "A")
|
||||
text = text.replace("b", "B")
|
||||
text = text.replace("c", "C")
|
||||
text = text.replace("k", "K")
|
||||
text = text.replace("t", "T")
|
||||
text = text.replace(",", "")
|
||||
text = text.replace("丶", "")
|
||||
text = text.replace("。", "")
|
||||
text = text.replace("、", "")
|
||||
text = text.replace("?", "")
|
||||
text = remove_punctuation(text)
|
||||
return text
|
||||
|
||||
|
||||
def main():
|
||||
args = get_args()
|
||||
assert_file_exists(args.tokens)
|
||||
assert args.num_threads > 0, args.num_threads
|
||||
|
||||
assert len(args.nemo_ctc) == 0, args.nemo_ctc
|
||||
assert len(args.wenet_ctc) == 0, args.wenet_ctc
|
||||
assert len(args.whisper_encoder) == 0, args.whisper_encoder
|
||||
assert len(args.whisper_decoder) == 0, args.whisper_decoder
|
||||
assert len(args.tdnn_model) == 0, args.tdnn_model
|
||||
|
||||
assert_file_exists(args.paraformer)
|
||||
|
||||
recognizer = sherpa_onnx.OfflineRecognizer.from_paraformer(
|
||||
paraformer=args.paraformer,
|
||||
tokens=args.tokens,
|
||||
num_threads=args.num_threads,
|
||||
sample_rate=args.sample_rate,
|
||||
feature_dim=args.feature_dim,
|
||||
decoding_method=args.decoding_method,
|
||||
debug=args.debug,
|
||||
)
|
||||
|
||||
print("Started!")
|
||||
start_time = time.time()
|
||||
|
||||
streams, results = [], []
|
||||
total_duration = 0
|
||||
|
||||
for i, wave_filename in enumerate(args.sound_files):
|
||||
assert_file_exists(wave_filename)
|
||||
samples, sample_rate = read_wave(wave_filename)
|
||||
duration = len(samples) / sample_rate
|
||||
total_duration += duration
|
||||
s = recognizer.create_stream()
|
||||
s.accept_waveform(sample_rate, samples)
|
||||
|
||||
streams.append(s)
|
||||
if i % 10 == 0:
|
||||
recognizer.decode_streams(streams)
|
||||
results += [s.result.text for s in streams]
|
||||
streams = []
|
||||
print(f"Processed {i} files")
|
||||
# process the last batch
|
||||
if streams:
|
||||
recognizer.decode_streams(streams)
|
||||
results += [s.result.text for s in streams]
|
||||
end_time = time.time()
|
||||
print("Done!")
|
||||
|
||||
results_dict = {}
|
||||
for wave_filename, result in zip(args.sound_files, results):
|
||||
print(f"{wave_filename}\n{result}")
|
||||
print("-" * 10)
|
||||
wave_basename = Path(wave_filename).stem
|
||||
results_dict[wave_basename] = result
|
||||
|
||||
elapsed_seconds = end_time - start_time
|
||||
rtf = elapsed_seconds / total_duration
|
||||
print(f"num_threads: {args.num_threads}")
|
||||
print(f"decoding_method: {args.decoding_method}")
|
||||
print(f"Wave duration: {total_duration:.3f} s")
|
||||
print(f"Elapsed time: {elapsed_seconds:.3f} s")
|
||||
print(
|
||||
f"Real time factor (RTF): {elapsed_seconds:.3f}/{total_duration:.3f} = {rtf:.3f}"
|
||||
)
|
||||
|
||||
# Load labels either from file or from dataset
|
||||
labels_dict = {}
|
||||
|
||||
if args.label:
|
||||
# Load labels from file (original functionality)
|
||||
print(f"Loading labels from file: {args.label}")
|
||||
with open(args.label, "r") as f:
|
||||
for line in f:
|
||||
# fields = line.strip().split(" ")
|
||||
# fields = [item for item in fields if item]
|
||||
# assert len(fields) == 4
|
||||
# prompt_text, prompt_audio, text, audio_path = fields
|
||||
|
||||
fields = line.strip().split("|")
|
||||
fields = [item for item in fields if item]
|
||||
assert len(fields) == 4
|
||||
audio_path, prompt_text, prompt_audio, text = fields
|
||||
labels_dict[Path(audio_path).stem] = normalize_text_alimeeting(text)
|
||||
else:
|
||||
# Load labels from dataset (new functionality)
|
||||
print(f"Loading labels from dataset: {args.dataset_name}, split: {args.split_name}")
|
||||
if 'zero' in args.split_name:
|
||||
dataset_name = "yuekai/CV3-Eval"
|
||||
else:
|
||||
dataset_name = "yuekai/seed_tts_cosy2"
|
||||
dataset = load_dataset(
|
||||
dataset_name,
|
||||
split=args.split_name,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
for item in dataset:
|
||||
audio_id = item["id"]
|
||||
labels_dict[audio_id] = normalize_text_alimeeting(item["target_text"])
|
||||
|
||||
print(f"Loaded {len(labels_dict)} labels from dataset")
|
||||
|
||||
# Perform evaluation if labels are available
|
||||
if labels_dict:
|
||||
|
||||
final_results = []
|
||||
for key, value in results_dict.items():
|
||||
if key in labels_dict:
|
||||
final_results.append((key, labels_dict[key], value))
|
||||
else:
|
||||
print(f"Warning: No label found for {key}, skipping...")
|
||||
|
||||
if final_results:
|
||||
store_transcripts(
|
||||
filename=f"{args.log_dir}/recogs-{args.name}.txt", texts=final_results
|
||||
)
|
||||
with open(f"{args.log_dir}/errs-{args.name}.txt", "w") as f:
|
||||
write_error_stats(f, "test-set", final_results, enable_log=True)
|
||||
|
||||
with open(f"{args.log_dir}/errs-{args.name}.txt", "r") as f:
|
||||
print(f.readline()) # WER
|
||||
print(f.readline()) # Detailed errors
|
||||
else:
|
||||
print("No matching labels found for evaluation")
|
||||
else:
|
||||
print("No labels available for evaluation")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,346 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Pytriton server for token2wav conversion and ASR"""
|
||||
|
||||
from datasets import load_dataset
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from omnisense.models import OmniSenseVoiceSmall
|
||||
from pytriton.proxy.types import Request
|
||||
from pytriton.triton import Triton, TritonConfig
|
||||
from pytriton.model_config import DynamicBatcher, ModelConfig, Tensor
|
||||
from pytriton.decorators import batch
|
||||
import argparse
|
||||
import io
|
||||
import logging
|
||||
from typing import Any, List
|
||||
import numpy as np
|
||||
import torch
|
||||
from scipy.signal import resample
|
||||
import sys
|
||||
import random
|
||||
import re
|
||||
from jiwer import wer
|
||||
from pypinyin import lazy_pinyin, Style
|
||||
from tn.chinese.normalizer import Normalizer as ZhNormalizer
|
||||
|
||||
# Chinese text normalizer (cached globally)
|
||||
zh_tn_model = ZhNormalizer(
|
||||
cache_dir="./cache",
|
||||
remove_erhua=False,
|
||||
remove_interjections=False,
|
||||
remove_puncts=True,
|
||||
overwrite_cache=True,
|
||||
)
|
||||
|
||||
|
||||
sys.path.append("/workspace/CosyVoice/third_party/Matcha-TTS")
|
||||
|
||||
logger = logging.getLogger("token2wav_asr_server")
|
||||
|
||||
|
||||
class _ASR_Server:
|
||||
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||
|
||||
def __init__(self, device_id: int):
|
||||
self._model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
|
||||
|
||||
@batch
|
||||
def __call__(self, WAV: np.ndarray, WAV_LENS: np.ndarray, LANGUAGE: np.ndarray, TEXT_NORM: np.ndarray):
|
||||
"""
|
||||
WAV: np.ndarray, WAV_LENS: np.ndarray
|
||||
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
|
||||
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
|
||||
"""
|
||||
logger.debug("WAV: %s, WAV_LENS: %s, shapes: %s %s", type(WAV), type(WAV_LENS), WAV.shape, WAV_LENS.shape)
|
||||
wavs = [WAV[i, :WAV_LENS[i, 0]] for i in range(len(WAV))]
|
||||
|
||||
results = self._model.transcribe_single_batch(
|
||||
wavs,
|
||||
language="zh",
|
||||
textnorm="woitn",
|
||||
)
|
||||
texts = [result.text for result in results]
|
||||
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||
return {"TRANSCRIPTS": transcripts}
|
||||
|
||||
|
||||
def audio_decode_cosyvoice2(
|
||||
audio_tokens, prompt_text, prompt_speech_16k, codec_decoder
|
||||
):
|
||||
"""
|
||||
Generate audio from tokens with optional tone and prompt embedding.
|
||||
"""
|
||||
model_inputs_dict = codec_decoder.frontend.frontend_zero_shot(
|
||||
"empty", prompt_text, prompt_speech_16k, 24000
|
||||
)
|
||||
tts_mel, _ = codec_decoder.model.flow.inference(
|
||||
token=audio_tokens.to(codec_decoder.model.device),
|
||||
token_len=torch.tensor([audio_tokens.shape[1]], dtype=torch.int32).to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token=model_inputs_dict["flow_prompt_speech_token"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_token_len=torch.tensor(
|
||||
[model_inputs_dict["flow_prompt_speech_token_len"]], dtype=torch.int32
|
||||
).to(codec_decoder.model.device),
|
||||
prompt_feat=model_inputs_dict["prompt_speech_feat"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
prompt_feat_len=model_inputs_dict["prompt_speech_feat_len"].to(
|
||||
codec_decoder.model.device
|
||||
),
|
||||
embedding=model_inputs_dict["flow_embedding"].to(codec_decoder.model.device),
|
||||
finalize=True,
|
||||
)
|
||||
|
||||
audio_hat, _ = codec_decoder.model.hift.inference(
|
||||
speech_feat=tts_mel, cache_source=torch.zeros(1, 1, 0)
|
||||
)
|
||||
|
||||
return audio_hat
|
||||
|
||||
|
||||
def get_random_prompt_from_dataset(dataset):
|
||||
"""
|
||||
Get random prompt text and speech from the pre-loaded dataset.
|
||||
Returns (prompt_text, prompt_speech_16k)
|
||||
"""
|
||||
random_idx = random.randint(0, len(dataset) - 1)
|
||||
sample = dataset[random_idx]
|
||||
|
||||
# Extract audio data
|
||||
audio_data = sample["audio"]
|
||||
audio_array = audio_data["array"]
|
||||
sample_rate = audio_data["sampling_rate"]
|
||||
|
||||
# Convert audio to 16kHz if needed
|
||||
if sample_rate != 16000:
|
||||
num_samples = int(len(audio_array) * (16000 / sample_rate))
|
||||
audio_array = resample(audio_array, num_samples)
|
||||
|
||||
# Convert to torch tensor
|
||||
prompt_speech_16k = torch.from_numpy(audio_array).float().unsqueeze(0)
|
||||
prompt_text = sample["text"]
|
||||
# remove space in prompt_text
|
||||
prompt_text = prompt_text.replace(" ", "")
|
||||
return prompt_text, prompt_speech_16k
|
||||
|
||||
|
||||
class _Token2Wav_ASR:
|
||||
"""Wraps a single OmniSenseVoiceSmall model instance for Triton."""
|
||||
|
||||
def __init__(self, device_id: int):
|
||||
self.asr_model = OmniSenseVoiceSmall("iic/SenseVoiceSmall", quantize=False, device_id=device_id)
|
||||
self.dataset = load_dataset("yuekai/aishell", "test", trust_remote_code=True)["test"]
|
||||
|
||||
# Make sure the CosyVoice2 decoder lives on the same GPU as the ASR model
|
||||
# CosyVoice2 internally uses generic "cuda" device, so we first switch the
|
||||
# current CUDA context to the desired card before the object is created.
|
||||
# Afterwards, all parameters loaded with the generic "cuda" device will
|
||||
# reside on this GPU. We keep the selected id in `self.device_id` and
|
||||
# will set the context again for every forward call to avoid race
|
||||
# conditions when several instances are used in the same process.
|
||||
|
||||
self.device_id = device_id
|
||||
|
||||
# Construct the TTS codec decoder under the correct CUDA device context
|
||||
with torch.cuda.device(self.device_id):
|
||||
self.codec_decoder = CosyVoice2(
|
||||
"/workspace/CosyVoice2-0.5B", load_jit=True, load_trt=True, fp16=True
|
||||
)
|
||||
|
||||
@batch
|
||||
def __call__(self, TOKENS: np.ndarray, TOKEN_LENS: np.ndarray, GT_TEXT: np.ndarray):
|
||||
"""
|
||||
WAV: np.ndarray, WAV_LENS: np.ndarray
|
||||
LANGUAGE: np.ndarray, TEXTNORM: np.ndarray for backward compatibility, not used
|
||||
See: https://github.com/modelscope/FunASR/tree/main/runtime/triton_gpu
|
||||
"""
|
||||
# Ensure the default CUDA device is set correctly for this invocation
|
||||
torch.cuda.set_device(self.device_id)
|
||||
|
||||
if self.device_id == 0:
|
||||
print(f"device_id: {self.device_id}, TOKENS: {TOKENS.shape}, TOKEN_LENS: {TOKEN_LENS.shape}")
|
||||
|
||||
tokens_list = [TOKENS[i, :TOKEN_LENS[i, 0]] for i in range(len(TOKENS))]
|
||||
|
||||
# Decode ground-truth text strings (BYTES → str)
|
||||
if GT_TEXT.ndim == 2:
|
||||
gt_texts = [GT_TEXT[i, 0].decode("utf-8") for i in range(len(GT_TEXT))]
|
||||
else:
|
||||
gt_texts = [GT_TEXT[i].decode("utf-8") for i in range(len(GT_TEXT))]
|
||||
|
||||
wavs = []
|
||||
for tokens in tokens_list:
|
||||
prompt_text, prompt_speech_16k = get_random_prompt_from_dataset(self.dataset)
|
||||
audio_tokens = torch.tensor(tokens, dtype=torch.long, device=self.asr_model.device).unsqueeze(0)
|
||||
audio_hat = audio_decode_cosyvoice2(
|
||||
audio_tokens,
|
||||
prompt_text,
|
||||
prompt_speech_16k,
|
||||
self.codec_decoder,
|
||||
)
|
||||
# resample to 16000 using soundfile
|
||||
audio_hat = audio_hat.squeeze(0).float().cpu()
|
||||
audio_hat = audio_hat.numpy()
|
||||
num_samples = int(len(audio_hat) * (16000 / 24000))
|
||||
audio_hat = resample(audio_hat, num_samples)
|
||||
wavs.append(audio_hat)
|
||||
|
||||
results = self.asr_model.transcribe_single_batch(
|
||||
wavs,
|
||||
language="zh",
|
||||
textnorm="woitn",
|
||||
)
|
||||
texts = [result.text for result in results]
|
||||
|
||||
# ---------------- Reward computation ----------------
|
||||
rewards = []
|
||||
for gt_text, hyp_text in zip(gt_texts, texts):
|
||||
gt_norm = zh_tn_model.normalize(gt_text).lower()
|
||||
hyp_norm = zh_tn_model.normalize(hyp_text).lower()
|
||||
|
||||
gt_pinyin = lazy_pinyin(
|
||||
gt_norm,
|
||||
style=Style.TONE3,
|
||||
tone_sandhi=True,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
||||
hyp_pinyin = lazy_pinyin(
|
||||
hyp_norm,
|
||||
style=Style.TONE3,
|
||||
tone_sandhi=True,
|
||||
neutral_tone_with_five=True,
|
||||
)
|
||||
|
||||
c = float(wer(" ".join(gt_pinyin), " ".join(hyp_pinyin)))
|
||||
reward_val = 1.0 - np.tanh(3.0 * c)
|
||||
reward_val = max(0.0, min(1.0, reward_val))
|
||||
rewards.append(reward_val)
|
||||
print(f"gt_text: {gt_text}, hyp_text: {hyp_text}, reward_val: {reward_val}")
|
||||
|
||||
transcripts = np.char.encode(np.array(texts).reshape(-1, 1), "utf-8")
|
||||
rewards_arr = np.array(rewards, dtype=np.float32).reshape(-1, 1)
|
||||
|
||||
return {"REWARDS": rewards_arr, "TRANSCRIPTS": transcripts}
|
||||
|
||||
|
||||
def _infer_function_factory(device_ids: List[int], model_name: str):
|
||||
"""Creates a list of inference functions, one for each requested device ID."""
|
||||
infer_funcs = []
|
||||
for device_id in device_ids:
|
||||
if model_name == "sensevoice":
|
||||
infer_funcs.append(_ASR_Server(device_id=device_id))
|
||||
else:
|
||||
infer_funcs.append(_Token2Wav_ASR(device_id=device_id))
|
||||
return infer_funcs
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description=__doc__)
|
||||
parser.add_argument(
|
||||
"--max-batch-size",
|
||||
type=int,
|
||||
default=32,
|
||||
help="Batch size of request.",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--verbose",
|
||||
action="store_true",
|
||||
default=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--number-of-instances-per-device",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Number of model instances to load.",
|
||||
required=False,
|
||||
)
|
||||
parser.add_argument(
|
||||
"--number-of-devices",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of devices to use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--model-name",
|
||||
type=str,
|
||||
default="token2wav_asr",
|
||||
choices=["token2wav_asr", "sensevoice"],
|
||||
help="Model name.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
log_level = logging.DEBUG if args.verbose else logging.INFO
|
||||
logging.basicConfig(level=log_level, format="%(asctime)s - %(levelname)s - %(name)s: %(message)s")
|
||||
|
||||
triton_config = TritonConfig(
|
||||
http_port=8000,
|
||||
grpc_port=8001,
|
||||
metrics_port=8002,
|
||||
)
|
||||
|
||||
device_ids = list(range(args.number_of_devices))
|
||||
device_ids = device_ids * args.number_of_instances_per_device
|
||||
|
||||
with Triton(config=triton_config) as triton:
|
||||
logger.info("Loading SenseVoice model on device ids: %s", device_ids)
|
||||
if args.model_name == "sensevoice":
|
||||
triton.bind(
|
||||
model_name="sensevoice",
|
||||
infer_func=_infer_function_factory(device_ids, args.model_name),
|
||||
inputs=[
|
||||
Tensor(name="WAV", dtype=np.float32, shape=(-1,)),
|
||||
Tensor(name="WAV_LENS", dtype=np.int32, shape=(-1,)),
|
||||
Tensor(name="LANGUAGE", dtype=np.int32, shape=(-1,)),
|
||||
Tensor(name="TEXT_NORM", dtype=np.int32, shape=(-1,)),
|
||||
],
|
||||
outputs=[
|
||||
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
|
||||
],
|
||||
config=ModelConfig(
|
||||
max_batch_size=args.max_batch_size,
|
||||
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
else:
|
||||
triton.bind(
|
||||
model_name="token2wav_asr",
|
||||
infer_func=_infer_function_factory(device_ids, args.model_name),
|
||||
inputs=[
|
||||
Tensor(name="TOKENS", dtype=np.int32, shape=(-1,)),
|
||||
Tensor(name="TOKEN_LENS", dtype=np.int32, shape=(-1,)),
|
||||
Tensor(name="GT_TEXT", dtype=bytes, shape=(-1,)),
|
||||
],
|
||||
outputs=[
|
||||
Tensor(name="REWARDS", dtype=np.float32, shape=(-1,)),
|
||||
Tensor(name="TRANSCRIPTS", dtype=bytes, shape=(-1,)),
|
||||
],
|
||||
config=ModelConfig(
|
||||
max_batch_size=args.max_batch_size,
|
||||
batcher=DynamicBatcher(max_queue_delay_microseconds=10000), # 10ms
|
||||
),
|
||||
strict=True,
|
||||
)
|
||||
logger.info("Serving inference")
|
||||
triton.serve()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
257
models/CosyVoice/examples/libritts/cosyvoice/conf/cosyvoice.yaml
Normal file
257
models/CosyVoice/examples/libritts/cosyvoice/conf/cosyvoice.yaml
Normal file
@@ -0,0 +1,257 @@
|
||||
# set random seed, so that you may reproduce your result.
|
||||
__set_seed1: !apply:random.seed [1986]
|
||||
__set_seed2: !apply:numpy.random.seed [1986]
|
||||
__set_seed3: !apply:torch.manual_seed [1986]
|
||||
__set_seed4: !apply:torch.cuda.manual_seed_all [1986]
|
||||
|
||||
# fixed params
|
||||
sample_rate: 22050
|
||||
text_encoder_input_size: 512
|
||||
llm_input_size: 1024
|
||||
llm_output_size: 1024
|
||||
spk_embed_dim: 192
|
||||
|
||||
# model params
|
||||
# for all class/function included in this repo, we use !<name> or !<new> for intialization, so that user may find all corresponding class/function according to one single yaml.
|
||||
# for system/third_party class/function, we do not require this.
|
||||
llm: !new:cosyvoice.llm.llm.TransformerLM
|
||||
text_encoder_input_size: !ref <text_encoder_input_size>
|
||||
llm_input_size: !ref <llm_input_size>
|
||||
llm_output_size: !ref <llm_output_size>
|
||||
text_token_size: 51866 # change to 60515 if you want to train with CosyVoice-300M-25Hz recipe
|
||||
speech_token_size: 4096
|
||||
length_normalized_loss: True
|
||||
lsm_weight: 0
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
text_encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||
input_size: !ref <text_encoder_input_size>
|
||||
output_size: 1024
|
||||
attention_heads: 16
|
||||
linear_units: 4096
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
normalize_before: True
|
||||
input_layer: 'linear'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
use_dynamic_chunk: False
|
||||
use_dynamic_left_chunk: False
|
||||
static_chunk_size: 1
|
||||
llm: !new:cosyvoice.transformer.encoder.TransformerEncoder
|
||||
input_size: !ref <llm_input_size>
|
||||
output_size: !ref <llm_output_size>
|
||||
attention_heads: 16
|
||||
linear_units: 4096
|
||||
num_blocks: 14
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.0
|
||||
input_layer: 'linear_legacy'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
static_chunk_size: 1
|
||||
sampling: !name:cosyvoice.utils.common.ras_sampling
|
||||
top_p: 0.8
|
||||
top_k: 25
|
||||
win_size: 10
|
||||
tau_r: 0.1
|
||||
|
||||
flow: !new:cosyvoice.flow.flow.MaskedDiffWithXvec
|
||||
input_size: 512
|
||||
output_size: 80
|
||||
spk_embed_dim: !ref <spk_embed_dim>
|
||||
output_type: 'mel'
|
||||
vocab_size: 4096
|
||||
input_frame_rate: 50 # change to 25 if you want to train with CosyVoice-300M-25Hz recipe
|
||||
only_mask_loss: True
|
||||
encoder: !new:cosyvoice.transformer.encoder.ConformerEncoder
|
||||
output_size: 512
|
||||
attention_heads: 8
|
||||
linear_units: 2048
|
||||
num_blocks: 6
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.1
|
||||
normalize_before: True
|
||||
input_layer: 'linear'
|
||||
pos_enc_layer_type: 'rel_pos_espnet'
|
||||
selfattention_layer_type: 'rel_selfattn'
|
||||
input_size: 512
|
||||
use_cnn_module: False
|
||||
macaron_style: False
|
||||
length_regulator: !new:cosyvoice.flow.length_regulator.InterpolateRegulator
|
||||
channels: 80
|
||||
sampling_ratios: [1, 1, 1, 1]
|
||||
decoder: !new:cosyvoice.flow.flow_matching.ConditionalCFM
|
||||
in_channels: 240
|
||||
n_spks: 1
|
||||
spk_emb_dim: 80
|
||||
cfm_params: !new:omegaconf.DictConfig
|
||||
content:
|
||||
sigma_min: 1e-06
|
||||
solver: 'euler'
|
||||
t_scheduler: 'cosine'
|
||||
training_cfg_rate: 0.2
|
||||
inference_cfg_rate: 0.7
|
||||
reg_loss_type: 'l1'
|
||||
estimator: !new:cosyvoice.flow.decoder.ConditionalDecoder
|
||||
in_channels: 320
|
||||
out_channels: 80
|
||||
channels: [256, 256]
|
||||
dropout: 0.0
|
||||
attention_head_dim: 64
|
||||
n_blocks: 4
|
||||
num_mid_blocks: 12
|
||||
num_heads: 8
|
||||
act_fn: 'gelu'
|
||||
|
||||
hift: !new:cosyvoice.hifigan.generator.HiFTGenerator
|
||||
in_channels: 80
|
||||
base_channels: 512
|
||||
nb_harmonics: 8
|
||||
sampling_rate: !ref <sample_rate>
|
||||
nsf_alpha: 0.1
|
||||
nsf_sigma: 0.003
|
||||
nsf_voiced_threshold: 10
|
||||
upsample_rates: [8, 8]
|
||||
upsample_kernel_sizes: [16, 16]
|
||||
istft_params:
|
||||
n_fft: 16
|
||||
hop_len: 4
|
||||
resblock_kernel_sizes: [3, 7, 11]
|
||||
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
|
||||
source_resblock_kernel_sizes: [7, 11]
|
||||
source_resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5]]
|
||||
lrelu_slope: 0.1
|
||||
audio_limit: 0.99
|
||||
f0_predictor: !new:cosyvoice.hifigan.f0_predictor.ConvRNNF0Predictor
|
||||
num_class: 1
|
||||
in_channels: 80
|
||||
cond_channels: 512
|
||||
|
||||
# gan related module
|
||||
mel_spec_transform1: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
num_mels: 80
|
||||
sampling_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
win_size: 1024
|
||||
fmin: 0
|
||||
fmax: null
|
||||
center: False
|
||||
hifigan: !new:cosyvoice.hifigan.hifigan.HiFiGan
|
||||
generator: !ref <hift>
|
||||
discriminator: !new:cosyvoice.hifigan.discriminator.MultipleDiscriminator
|
||||
mpd: !new:matcha.hifigan.models.MultiPeriodDiscriminator
|
||||
mrd: !new:cosyvoice.hifigan.discriminator.MultiResSpecDiscriminator
|
||||
mel_spec_transform: [
|
||||
!ref <mel_spec_transform1>
|
||||
]
|
||||
|
||||
# processor functions
|
||||
parquet_opener: !name:cosyvoice.dataset.processor.parquet_opener
|
||||
get_tokenizer: !name:whisper.tokenizer.get_tokenizer # change to !name:cosyvoice.tokenizer.tokenizer.get_tokenizer if you want to train with CosyVoice-300M-25Hz recipe
|
||||
multilingual: True
|
||||
num_languages: 100
|
||||
language: 'en'
|
||||
task: 'transcribe'
|
||||
allowed_special: 'all'
|
||||
tokenize: !name:cosyvoice.dataset.processor.tokenize
|
||||
get_tokenizer: !ref <get_tokenizer>
|
||||
allowed_special: !ref <allowed_special>
|
||||
filter: !name:cosyvoice.dataset.processor.filter
|
||||
max_length: 40960
|
||||
min_length: 0
|
||||
token_max_length: 200
|
||||
token_min_length: 1
|
||||
resample: !name:cosyvoice.dataset.processor.resample
|
||||
resample_rate: !ref <sample_rate>
|
||||
truncate: !name:cosyvoice.dataset.processor.truncate
|
||||
truncate_length: 24576 # must be a multiplier of hop_size
|
||||
feat_extractor: !name:matcha.utils.audio.mel_spectrogram
|
||||
n_fft: 1024
|
||||
num_mels: 80
|
||||
sampling_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
win_size: 1024
|
||||
fmin: 0
|
||||
fmax: 8000
|
||||
center: False
|
||||
compute_fbank: !name:cosyvoice.dataset.processor.compute_fbank
|
||||
feat_extractor: !ref <feat_extractor>
|
||||
compute_f0: !name:cosyvoice.dataset.processor.compute_f0
|
||||
sample_rate: !ref <sample_rate>
|
||||
hop_size: 256
|
||||
parse_embedding: !name:cosyvoice.dataset.processor.parse_embedding
|
||||
normalize: True
|
||||
shuffle: !name:cosyvoice.dataset.processor.shuffle
|
||||
shuffle_size: 1000
|
||||
sort: !name:cosyvoice.dataset.processor.sort
|
||||
sort_size: 500 # sort_size should be less than shuffle_size
|
||||
batch: !name:cosyvoice.dataset.processor.batch
|
||||
batch_type: 'dynamic'
|
||||
max_frames_in_batch: 2000 # change to 1400 in gan train on v100 16g
|
||||
padding: !name:cosyvoice.dataset.processor.padding
|
||||
use_spk_embedding: False # change to True during sft
|
||||
|
||||
# dataset processor pipeline
|
||||
data_pipeline: [
|
||||
!ref <parquet_opener>,
|
||||
!ref <tokenize>,
|
||||
!ref <filter>,
|
||||
!ref <resample>,
|
||||
!ref <compute_fbank>,
|
||||
!ref <parse_embedding>,
|
||||
!ref <shuffle>,
|
||||
!ref <sort>,
|
||||
!ref <batch>,
|
||||
!ref <padding>,
|
||||
]
|
||||
data_pipeline_gan: [
|
||||
!ref <parquet_opener>,
|
||||
!ref <tokenize>,
|
||||
!ref <filter>,
|
||||
!ref <resample>,
|
||||
!ref <truncate>,
|
||||
!ref <compute_fbank>,
|
||||
!ref <compute_f0>,
|
||||
!ref <parse_embedding>,
|
||||
!ref <shuffle>,
|
||||
!ref <sort>,
|
||||
!ref <batch>,
|
||||
!ref <padding>,
|
||||
]
|
||||
|
||||
# llm flow train conf
|
||||
train_conf:
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.001 # change to 1e-5 during sft
|
||||
scheduler: warmuplr # change to constantlr during sft
|
||||
scheduler_conf:
|
||||
warmup_steps: 2500
|
||||
max_epoch: 200
|
||||
grad_clip: 5
|
||||
accum_grad: 2
|
||||
log_interval: 100
|
||||
save_per_step: -1
|
||||
|
||||
# gan train conf
|
||||
train_conf_gan:
|
||||
optim: adam
|
||||
optim_conf:
|
||||
lr: 0.0002 # use small lr for gan training
|
||||
scheduler: constantlr
|
||||
optim_d: adam
|
||||
optim_conf_d:
|
||||
lr: 0.0002 # use small lr for gan training
|
||||
scheduler_d: constantlr
|
||||
max_epoch: 200
|
||||
grad_clip: 5
|
||||
accum_grad: 1 # in gan training, accum_grad must be 1
|
||||
log_interval: 100
|
||||
save_per_step: -1
|
||||
@@ -0,0 +1,42 @@
|
||||
{
|
||||
"train_micro_batch_size_per_gpu": 1,
|
||||
"gradient_accumulation_steps": 1,
|
||||
"steps_per_print": 100,
|
||||
"gradient_clipping": 5,
|
||||
"fp16": {
|
||||
"enabled": false,
|
||||
"auto_cast": false,
|
||||
"loss_scale": 0,
|
||||
"initial_scale_power": 16,
|
||||
"loss_scale_window": 256,
|
||||
"hysteresis": 2,
|
||||
"consecutive_hysteresis": false,
|
||||
"min_loss_scale": 1
|
||||
},
|
||||
"bf16": {
|
||||
"enabled": false
|
||||
},
|
||||
"zero_force_ds_cpu_optimizer": false,
|
||||
"zero_optimization": {
|
||||
"stage": 2,
|
||||
"offload_optimizer": {
|
||||
"device": "none",
|
||||
"pin_memory": true
|
||||
},
|
||||
"allgather_partitions": true,
|
||||
"allgather_bucket_size": 5e8,
|
||||
"overlap_comm": false,
|
||||
"reduce_scatter": true,
|
||||
"reduce_bucket_size": 5e8,
|
||||
"contiguous_gradients" : true
|
||||
},
|
||||
"optimizer": {
|
||||
"type": "AdamW",
|
||||
"params": {
|
||||
"lr": 0.001,
|
||||
"weight_decay": 0.0001,
|
||||
"torch_adam": true,
|
||||
"adam_w_mode": true
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,97 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Copyright 2014 Johns Hopkins University (author: Daniel Povey)
|
||||
# Apache 2.0
|
||||
|
||||
remove_archive=false
|
||||
|
||||
if [ "$1" == --remove-archive ]; then
|
||||
remove_archive=true
|
||||
shift
|
||||
fi
|
||||
|
||||
if [ $# -ne 3 ]; then
|
||||
echo "Usage: $0 [--remove-archive] <data-base> <url-base> <corpus-part>"
|
||||
echo "e.g.: $0 /export/a15/vpanayotov/data www.openslr.org/resources/11 dev-clean"
|
||||
echo "With --remove-archive it will remove the archive after successfully un-tarring it."
|
||||
echo "<corpus-part> can be one of: dev-clean, test-clean, dev-other, test-other,"
|
||||
echo " train-clean-100, train-clean-360, train-other-500."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
data=$1
|
||||
url=$2
|
||||
part=$3
|
||||
|
||||
if [ ! -d "$data" ]; then
|
||||
echo "$0: no such directory $data"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
part_ok=false
|
||||
list="dev-clean test-clean dev-other test-other train-clean-100 train-clean-360 train-other-500"
|
||||
for x in $list; do
|
||||
if [ "$part" == $x ]; then part_ok=true; fi
|
||||
done
|
||||
if ! $part_ok; then
|
||||
echo "$0: expected <corpus-part> to be one of $list, but got '$part'"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -z "$url" ]; then
|
||||
echo "$0: empty URL base."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ -f $data/LibriTTS/$part/.complete ]; then
|
||||
echo "$0: data part $part was already successfully extracted, nothing to do."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
|
||||
# sizes of the archive files in bytes. This is some older versions.
|
||||
sizes_old="371012589 347390293 379743611 361838298 6420417880 23082659865 30626749128"
|
||||
# sizes_new is the archive file sizes of the final release. Some of these sizes are of
|
||||
# things we probably won't download.
|
||||
sizes_new="337926286 314305928 695964615 297279345 87960560420 33373768 346663984 328757843 6387309499 23049477885 30593501606"
|
||||
|
||||
if [ -f $data/$part.tar.gz ]; then
|
||||
size=$(/bin/ls -l $data/$part.tar.gz | awk '{print $5}')
|
||||
size_ok=false
|
||||
for s in $sizes_old $sizes_new; do if [ $s == $size ]; then size_ok=true; fi; done
|
||||
if ! $size_ok; then
|
||||
echo "$0: removing existing file $data/$part.tar.gz because its size in bytes $size"
|
||||
echo "does not equal the size of one of the archives."
|
||||
rm $data/$part.tar.gz
|
||||
else
|
||||
echo "$data/$part.tar.gz exists and appears to be complete."
|
||||
fi
|
||||
fi
|
||||
|
||||
if [ ! -f $data/$part.tar.gz ]; then
|
||||
if ! which wget >/dev/null; then
|
||||
echo "$0: wget is not installed."
|
||||
exit 1
|
||||
fi
|
||||
full_url=$url/$part.tar.gz
|
||||
echo "$0: downloading data from $full_url. This may take some time, please be patient."
|
||||
|
||||
if ! wget -P $data --no-check-certificate $full_url; then
|
||||
echo "$0: error executing wget $full_url"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
|
||||
if ! tar -C $data -xvzf $data/$part.tar.gz; then
|
||||
echo "$0: error un-tarring archive $data/$part.tar.gz"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
touch $data/LibriTTS/$part/.complete
|
||||
|
||||
echo "$0: Successfully downloaded and un-tarred $data/$part.tar.gz"
|
||||
|
||||
if $remove_archive; then
|
||||
echo "$0: removing $data/$part.tar.gz file since --remove-archive option was supplied."
|
||||
rm $data/$part.tar.gz
|
||||
fi
|
||||
@@ -0,0 +1,60 @@
|
||||
import argparse
|
||||
import logging
|
||||
import glob
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def main():
|
||||
wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
|
||||
|
||||
utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
|
||||
for wav in tqdm(wavs):
|
||||
txt = wav.replace('.wav', '.normalized.txt')
|
||||
if not os.path.exists(txt):
|
||||
logger.warning('{} do not exsist'.format(txt))
|
||||
continue
|
||||
with open(txt) as f:
|
||||
content = ''.join(l.replace('\n', '') for l in f.readline())
|
||||
utt = os.path.basename(wav).replace('.wav', '')
|
||||
spk = utt.split('_')[0]
|
||||
utt2wav[utt] = wav
|
||||
utt2text[utt] = content
|
||||
utt2spk[utt] = spk
|
||||
if spk not in spk2utt:
|
||||
spk2utt[spk] = []
|
||||
spk2utt[spk].append(utt)
|
||||
|
||||
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
|
||||
for k, v in utt2wav.items():
|
||||
f.write('{} {}\n'.format(k, v))
|
||||
with open('{}/text'.format(args.des_dir), 'w') as f:
|
||||
for k, v in utt2text.items():
|
||||
f.write('{} {}\n'.format(k, v))
|
||||
with open('{}/utt2spk'.format(args.des_dir), 'w') as f:
|
||||
for k, v in utt2spk.items():
|
||||
f.write('{} {}\n'.format(k, v))
|
||||
with open('{}/spk2utt'.format(args.des_dir), 'w') as f:
|
||||
for k, v in spk2utt.items():
|
||||
f.write('{} {}\n'.format(k, ' '.join(v)))
|
||||
if args.instruct != '':
|
||||
with open('{}/instruct'.format(args.des_dir), 'w') as f:
|
||||
for k, v in utt2text.items():
|
||||
f.write('{} {}\n'.format(k, args.instruct))
|
||||
return
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--src_dir',
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--instruct',
|
||||
type=str,
|
||||
default='')
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
@@ -0,0 +1,50 @@
|
||||
import argparse
|
||||
import logging
|
||||
import os
|
||||
from tqdm import tqdm
|
||||
import torch
|
||||
import torchaudio
|
||||
from cosyvoice.cli.cosyvoice import CosyVoice2
|
||||
from cosyvoice.utils.file_utils import load_wav
|
||||
|
||||
|
||||
logger = logging.getLogger()
|
||||
|
||||
|
||||
def main():
|
||||
cosyvoice = CosyVoice2(args.ref_model)
|
||||
|
||||
utt2wav, utt2text = {}, {}
|
||||
with open('{}/wav.scp'.format(args.src_dir)) as f:
|
||||
for l in f:
|
||||
l = l.split('\n')[0].split()
|
||||
utt2wav[l[0]] = l[1]
|
||||
with open('{}/text'.format(args.src_dir)) as f:
|
||||
for l in f:
|
||||
l = l.split('\n')[0].split()
|
||||
utt2text[l[0]] = ' '.join(l[1:])
|
||||
|
||||
os.makedirs('{}/wav'.format(args.des_dir), exist_ok=True)
|
||||
with open('{}/wav.scp'.format(args.des_dir), 'w') as f:
|
||||
for utt, wav in tqdm(utt2wav.items()):
|
||||
prompt_speech_16k = load_wav(wav, 16000)
|
||||
if prompt_speech_16k.shape[1] >= 30 * 16000:
|
||||
continue
|
||||
speech_list = []
|
||||
for _, j in enumerate(cosyvoice.inference_zero_shot(utt2text[utt], utt2text[utt], prompt_speech_16k, stream=False, text_frontend=False)):
|
||||
speech_list.append(j['tts_speech'])
|
||||
negative_wav = os.path.abspath('{}/wav/{}'.format(args.des_dir, os.path.basename(wav)))
|
||||
torchaudio.save(negative_wav, torch.concat(speech_list, dim=1), cosyvoice.sample_rate, backend='soundfile')
|
||||
f.write('{} {}\n'.format(utt, negative_wav))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--src_dir',
|
||||
type=str)
|
||||
parser.add_argument('--des_dir',
|
||||
type=str)
|
||||
parser.add_argument('--ref_model',
|
||||
type=str)
|
||||
args = parser.parse_args()
|
||||
main()
|
||||
3
models/CosyVoice/examples/libritts/cosyvoice/path.sh
Normal file
3
models/CosyVoice/examples/libritts/cosyvoice/path.sh
Normal file
@@ -0,0 +1,3 @@
|
||||
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
|
||||
export PYTHONIOENCODING=UTF-8
|
||||
export PYTHONPATH=../../../:../../../third_party/Matcha-TTS:$PYTHONPATH
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user