Compare commits

...

6 Commits

Author SHA1 Message Date
Kevin Wong
0a5a17402c 更新 2026-02-24 16:55:29 +08:00
Kevin Wong
bc0fe9326a 更新 2026-02-11 17:48:38 +08:00
Kevin Wong
035ee29d72 更新 2026-02-11 14:33:05 +08:00
Kevin Wong
a6cc919e5c 更新 2026-02-11 13:57:41 +08:00
Kevin Wong
96a298e51c 更新 2026-02-11 13:48:45 +08:00
Kevin Wong
e33dfc3031 更新 2026-02-10 13:31:29 +08:00
336 changed files with 99081 additions and 882 deletions

278
Docs/ALIPAY_DEPLOY.md Normal file
View File

@@ -0,0 +1,278 @@
# 支付宝付费开通会员 — 部署指南
本文档涵盖支付宝电脑网站支付功能的完整部署流程。用户注册后通过支付宝付费自动激活会员,有效期 1 年。
---
## 前置条件
- 支付宝企业/个体商户账号
- 已在 [支付宝开放平台](https://open.alipay.com) 创建应用并获取 APPID
- 应用已开通 **「电脑网站支付」** 产品权限(`alipay.trade.page.pay` 接口)
- 服务器域名已配置 HTTPS支付宝回调要求公网可达
---
## 第一部分:支付宝开放平台配置
### 1. 创建应用
登录 https://open.alipay.com → 控制台 → 创建应用(或使用已有应用)。
### 2. 开通「电脑网站支付」产品
进入应用详情 → 产品绑定/产品管理 → 添加 **「电脑网站支付」** → 提交审核。
> **注意**:未开通此产品会导致 `ACQ.ACCESS_FORBIDDEN` 错误。
### 3. 生成密钥对
进入应用详情 → 开发设置 → 接口加签方式 → 选择 **RSA2(SHA256)**
1. 使用支付宝官方密钥工具生成 RSA2048 密钥对
2.**应用公钥** 上传到开放平台
3. 上传后平台会显示 **支付宝公钥**`alipayPublicKey_RSA2`
最终你会得到两样东西:
- **应用私钥**:你本地保存,代码用来签名请求
- **支付宝公钥**:平台返回给你,代码用来验证回调签名
> 应用公钥只是上传用的中间产物,代码中不需要。
---
## 第二部分:服务器配置
### 1. 放置密钥文件
将密钥保存为标准 PEM 格式,放到 `backend/keys/` 目录:
```bash
mkdir -p /home/rongye/ProgramFiles/ViGent2/backend/keys
```
**`backend/keys/app_private_key.pem`**(应用私钥):
```
-----BEGIN PRIVATE KEY-----
MIIEvQIBADANBgkqhkiG9w0BAQEFAASC...(你的私钥内容)
...
-----END PRIVATE KEY-----
```
**`backend/keys/alipay_public_key.pem`**(支付宝公钥):
```
-----BEGIN PUBLIC KEY-----
MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8A...(支付宝公钥内容)
...
-----END PUBLIC KEY-----
```
#### PEM 格式要求
支付宝密钥工具导出的是一行纯文本,需要转换为标准 PEM 格式:
- 必须有头尾标记(`-----BEGIN/END ...-----`
- 密钥内容每 64 字符换行
- 私钥头标记为 `-----BEGIN PRIVATE KEY-----`PKCS#8 格式)
- 公钥头标记为 `-----BEGIN PUBLIC KEY-----`
如果你拿到的是一行裸密钥,用以下命令转换:
```bash
# 私钥格式化(假设裸密钥在 raw_private.txt 中)
echo "-----BEGIN PRIVATE KEY-----" > app_private_key.pem
cat raw_private.txt | fold -w 64 >> app_private_key.pem
echo "-----END PRIVATE KEY-----" >> app_private_key.pem
# 公钥格式化
echo "-----BEGIN PUBLIC KEY-----" > alipay_public_key.pem
cat raw_public.txt | fold -w 64 >> alipay_public_key.pem
echo "-----END PUBLIC KEY-----" >> alipay_public_key.pem
```
> `backend/keys/` 目录已加入 `.gitignore`,不会被提交到仓库。
### 2. 配置环境变量
`backend/.env` 中添加:
```ini
# =============== 支付宝配置 ===============
ALIPAY_APP_ID=你的应用APPID
ALIPAY_PRIVATE_KEY_PATH=/home/rongye/ProgramFiles/ViGent2/backend/keys/app_private_key.pem
ALIPAY_PUBLIC_KEY_PATH=/home/rongye/ProgramFiles/ViGent2/backend/keys/alipay_public_key.pem
ALIPAY_NOTIFY_URL=https://vigent.hbyrkj.top/api/payment/notify
ALIPAY_RETURN_URL=https://vigent.hbyrkj.top/pay
```
| 变量 | 说明 |
|------|------|
| `ALIPAY_APP_ID` | 支付宝开放平台应用 APPID |
| `ALIPAY_PRIVATE_KEY_PATH` | 应用私钥 PEM 文件绝对路径 |
| `ALIPAY_PUBLIC_KEY_PATH` | 支付宝公钥 PEM 文件绝对路径 |
| `ALIPAY_NOTIFY_URL` | 异步回调地址(服务器间通信),必须公网 HTTPS 可达 |
| `ALIPAY_RETURN_URL` | 同步跳转地址(用户支付完成后浏览器跳转回的页面) |
`config.py` 中还有几个可调参数(已有默认值,一般不需要加到 .env
| 变量 | 默认值 | 说明 |
|------|--------|------|
| `ALIPAY_SANDBOX` | `false` | 是否使用沙箱环境 |
| `PAYMENT_AMOUNT` | `999.00` | 会员价格(元) |
| `PAYMENT_EXPIRE_DAYS` | `365` | 会员有效天数 |
### 3. 创建数据库表
通过 Docker 在本地 Supabase 中执行:
```bash
docker exec -i supabase-db psql -U postgres -c "
CREATE TABLE IF NOT EXISTS orders (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID REFERENCES users(id) ON DELETE CASCADE,
out_trade_no TEXT UNIQUE NOT NULL,
amount DECIMAL(10, 2) NOT NULL DEFAULT 999.00,
status TEXT DEFAULT 'pending' CHECK (status IN ('pending', 'paid', 'failed')),
trade_no TEXT,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
paid_at TIMESTAMP WITH TIME ZONE
);
CREATE INDEX IF NOT EXISTS idx_orders_user_id ON orders(user_id);
CREATE INDEX IF NOT EXISTS idx_orders_out_trade_no ON orders(out_trade_no);
"
```
### 4. 安装依赖
```bash
# 后端(在 venv 中)
cd /home/rongye/ProgramFiles/ViGent2/backend
venv/bin/pip install python-alipay-sdk
```
> 前端无额外依赖需要安装。
### 5. Nginx 配置
确保 Nginx 将 `/api/payment/notify` 代理到后端。如果现有配置已覆盖 `/api/` 前缀,则无需额外修改:
```nginx
location /api/ {
proxy_pass http://localhost:8006;
# ... 现有配置
}
```
### 6. 重启服务
```bash
# 构建前端
cd /home/rongye/ProgramFiles/ViGent2/frontend
npx next build
# 重启
pm2 restart vigent2-backend
pm2 restart vigent2-frontend
```
---
## 第三部分:正式上线
测试通过后,将 `backend/app/core/config.py` 中的测试金额改为正式价格:
```python
PAYMENT_AMOUNT: float = 999.00 # 正式价格
```
或在 `backend/.env` 中添加覆盖:
```ini
PAYMENT_AMOUNT=999.00
```
然后重启后端:
```bash
pm2 restart vigent2-backend
```
---
## 支付流程说明
```
用户注册 → 登录(密码正确但 is_active=false
→ 后端返回 403 + payment_token
→ 前端跳转 /pay 页面
→ POST /api/payment/create-order → 返回支付宝收银台 URL
→ 前端重定向到支付宝收银台页面(支持扫码、账号登录、余额等多种支付方式)
→ 用户完成支付
→ 支付宝异步回调 POST /api/payment/notify
→ 后端验签 → 更新订单 → 激活用户is_active=true, expires_at=+365天
→ 支付宝同步跳转回 /pay?out_trade_no=xxx
→ 前端轮询 GET /api/payment/status/{out_trade_no}
→ 轮询到 paid → 提示成功 → 跳转登录页
→ 用户重新登录 → 成功进入系统
```
**电脑网站支付 vs 当面付**:电脑网站支付(`alipay.trade.page.pay`)会跳转到支付宝官方收银台页面,用户可以选择扫码、支付宝账号登录、余额等多种方式支付,体验更好。当面付(`alipay.trade.precreate`)仅生成一个二维码,只能扫码支付。
会员到期续费同流程:登录时检测到过期 → 返回 PAYMENT_REQUIRED → 跳转 /pay。
管理员手动激活功能不受影响,两种方式并存。
---
## 涉及文件
| 文件 | 变更类型 | 说明 |
|------|---------|------|
| `backend/requirements.txt` | 修改 | 添加 `python-alipay-sdk` |
| `backend/database/schema.sql` | 修改 | 新增 `orders` 表 |
| `backend/app/core/config.py` | 修改 | 支付宝配置项 |
| `backend/app/core/security.py` | 修改 | payment_token 函数 |
| `backend/app/core/deps.py` | 修改 | is_active 安全兜底 |
| `backend/app/repositories/orders.py` | 新建 | orders 数据层 |
| `backend/app/modules/payment/__init__.py` | 新建 | 模块初始化 |
| `backend/app/modules/payment/schemas.py` | 新建 | 请求/响应模型 |
| `backend/app/modules/payment/service.py` | 新建 | 支付业务逻辑(电脑网站支付) |
| `backend/app/modules/payment/router.py` | 新建 | 3 个 API 端点 |
| `backend/app/modules/auth/router.py` | 修改 | 登录返回 PAYMENT_REQUIRED |
| `backend/app/main.py` | 修改 | 注册 payment_router |
| `backend/.env` | 修改 | 支付宝环境变量 |
| `backend/keys/` | 新建 | PEM 密钥文件 |
| `frontend/src/shared/lib/auth.ts` | 修改 | login() 处理 paymentToken |
| `frontend/src/shared/api/axios.ts` | 修改 | PUBLIC_PATHS 加 /pay |
| `frontend/src/app/login/page.tsx` | 修改 | paymentToken 跳转 |
| `frontend/src/app/register/page.tsx` | 修改 | 注册成功提示文案 |
| `frontend/src/app/pay/page.tsx` | 新建 | 付费页面(重定向到支付宝收银台) |
---
## 常见问题
### RSA key format is not supported
密钥文件缺少 PEM 头尾标记或未按 64 字符换行。参考「PEM 格式要求」重新格式化。
### ACQ.ACCESS_FORBIDDEN
应用未开通「电脑网站支付」产品。在支付宝开放平台 → 应用详情 → 产品管理中添加并开通。
### 支付宝回调不到
1. 检查 `ALIPAY_NOTIFY_URL` 是否公网 HTTPS 可达
2. 检查 Nginx 是否将 `/api/payment/notify` 代理到后端
3. 支付宝回调超时15s 未响应)会重试,共重试 8 次,持续 24 小时
### 支付完成后页面未跳转回来
检查 `ALIPAY_RETURN_URL` 配置是否正确,必须是前端 `/pay` 页面的完整 URL`https://vigent.hbyrkj.top/pay`)。支付宝会在用户支付完成后将浏览器重定向到此地址,并附带 `out_trade_no` 等参数。
### 前端显示"网络错误"而非具体错误
API 函数缺少 try/catch 捕获 axios 异常。已在 `auth.ts``register()``login()` 中修复。

View File

@@ -33,11 +33,13 @@ backend/
│ │ ├── materials/ # 素材管理router/schemas/service
│ │ ├── publish/ # 多平台发布
│ │ ├── auth/ # 认证与会话
│ │ ├── ai/ # AI 功能(标题标签生成
│ │ ├── ai/ # AI 功能(标题标签生成、多语言翻译
│ │ ├── assets/ # 静态资源(字体/样式/BGM
│ │ ├── ref_audios/ # 声音克隆参考音频router/schemas/service
│ │ ├── generated_audios/ # 预生成配音管理router/schemas/service
│ │ ├── login_helper/ # 扫码登录辅助
│ │ ├── tools/ # 工具接口router/schemas/service
│ │ ├── payment/ # 支付宝付费开通router/schemas/service
│ │ └── admin/ # 管理员功能
│ ├── repositories/ # Supabase 数据访问
│ ├── services/ # 外部服务集成
@@ -73,6 +75,18 @@ backend/
- 错误通过 `HTTPException` 抛出,统一由全局异常处理返回 `{success:false, message, code}`
- 不再使用 `detail` 作为前端错误文案(前端已改为读 `message`)。
### `/api/videos/generate` 参数契约(关键约定)
- `custom_assignments` 每项使用 `material_path/start/end/source_start/source_end?`,并以时间轴可见段为准。
- `output_aspect_ratio` 仅允许 `9:16` / `16:9`,默认 `9:16`
- 标题显示模式参数:
- `title_display_mode`: `short` / `persistent`(默认 `short`
- `title_duration`: 默认 `4.0`(秒),仅 `short` 模式生效
- 片头副标题参数:
- `secondary_title`: 副标题文字(可选,限 20 字),仅在视频画面中显示,不参与发布标题
- `secondary_title_style_id` / `secondary_title_font_size` / `secondary_title_top_margin`: 副标题样式配置
- workflow/remotion 侧需保持字段透传一致,避免前后端语义漂移。
---
## 4. 认证与权限
@@ -156,7 +170,13 @@ backend/user_data/{user_uuid}/cookies/
- `DOUYIN_LOCALE` / `DOUYIN_TIMEZONE_ID`
- `DOUYIN_FORCE_SWIFTSHADER`
- `DOUYIN_DEBUG_ARTIFACTS` / `DOUYIN_RECORD_VIDEO` / `DOUYIN_KEEP_SUCCESS_VIDEO`
- `DOUYIN_COOKIE` (抖音视频下载 Cookie)
### 支付宝
- `ALIPAY_APP_ID` / `ALIPAY_PRIVATE_KEY_PATH` / `ALIPAY_PUBLIC_KEY_PATH`
- `ALIPAY_NOTIFY_URL` / `ALIPAY_RETURN_URL`
- `ALIPAY_SANDBOX` (沙箱模式,默认 false)
- `PAYMENT_AMOUNT` (会员价格,默认 999.00)
- `PAYMENT_EXPIRE_DAYS` (会员有效天数,默认 365)
---

View File

@@ -19,12 +19,14 @@ backend/
│ │ ├── materials/ # 素材管理router/schemas/service
│ │ ├── publish/ # 多平台发布
│ │ ├── auth/ # 认证与会话
│ │ ├── ai/ # AI 功能(标题标签生成)
│ │ ├── assets/ # 静态资源(字体/样式/BGM
│ │ ├── ref_audios/ # 声音克隆参考音频router/schemas/service
│ │ ├── login_helper/ # 扫码登录辅助
│ │ ├── tools/ # 工具接口router/schemas/service
│ │ ── admin/ # 管理员功能
│ │ ├── ai/ # AI 功能(标题标签生成、多语言翻译
│ │ ├── assets/ # 静态资源(字体/样式/BGM
│ │ ├── ref_audios/ # 声音克隆参考音频router/schemas/service
│ │ ├── generated_audios/ # 预生成配音管理router/schemas/service
│ │ ├── login_helper/ # 扫码登录辅助
│ │ ── tools/ # 工具接口router/schemas/service
│ │ ├── payment/ # 支付宝付费开通router/schemas/service
│ │ └── admin/ # 管理员功能
│ ├── repositories/ # Supabase 数据访问
│ ├── services/ # 外部服务集成 (TTS/Remotion/Storage/Uploader 等)
│ └── tests/ # 单元测试与集成测试
@@ -50,6 +52,8 @@ backend/
* `POST /api/auth/register`: 用户注册
* `GET /api/auth/me`: 获取当前用户信息
> 授权有效期策略:在登录与受保护接口鉴权时,后端会检查 `users.expires_at`。账号到期会自动停用 (`is_active=false`) 并清理 session返回 `403: 会员已到期,请续费`。
2. **视频生成 (Videos)**
* `POST /api/videos/generate`: 提交生成任务
* `GET /api/videos/tasks/{task_id}`: 查询单个任务状态
@@ -76,20 +80,36 @@ backend/
* `GET /api/assets/bgm`: 背景音乐列表
6. **声音克隆 (Ref Audios)**
* `POST /api/ref-audios`: 上传参考音频 (multipart/form-data)
* `POST /api/ref-audios`: 上传参考音频 (multipart/form-data,自动 Whisper 转写 ref_text)
* `GET /api/ref-audios`: 获取参考音频列表
* `PUT /api/ref-audios/{id}`: 重命名参考音频
* `DELETE /api/ref-audios/{id}`: 删除参考音频
* `POST /api/ref-audios/{id}/retranscribe`: 重新识别参考音频文字Whisper 转写 + 超 10s 自动截取)
7. **AI 功能 (AI)**
* `POST /api/ai/generate-meta`: AI 生成标题和标签
* `POST /api/ai/translate`: AI 多语言翻译(支持 9 种目标语言)
8. **工具 (Tools)**
8. **预生成配音 (Generated Audios)**
* `POST /api/generated-audios/generate`: 异步生成配音(返回 task_id
* `GET /api/generated-audios/tasks/{task_id}`: 轮询生成进度
* `GET /api/generated-audios`: 列出用户所有配音
* `DELETE /api/generated-audios/{audio_id}`: 删除配音
* `PUT /api/generated-audios/{audio_id}`: 重命名配音
9. **工具 (Tools)**
* `POST /api/tools/extract-script`: 从视频链接提取文案
9. **健康检查**
10. **健康检查**
* `GET /api/lipsync/health`: LatentSync 服务健康状态
* `GET /api/voiceclone/health`: Qwen3-TTS 服务健康状态
* `GET /api/voiceclone/health`: CosyVoice 3.0 服务健康状态
11. **支付 (Payment)**
* `POST /api/payment/create-order`: 创建支付宝电脑网站支付订单(需 payment_token
* `POST /api/payment/notify`: 支付宝异步通知回调(返回纯文本 success/fail
* `GET /api/payment/status/{out_trade_no}`: 查询订单支付状态(前端轮询)
> 登录时若账号未激活或已过期,返回 403 + `payment_token`,前端跳转 `/pay` 页面完成付费。详见 [支付宝部署指南](ALIPAY_DEPLOY.md)。
### 统一响应结构
@@ -113,17 +133,34 @@ backend/
- `tts_mode`: TTS 模式 (`edgetts` / `voiceclone`)
- `voice`: EdgeTTS 音色 IDedgetts 模式)
- `ref_audio_id` / `ref_text`: 参考音频 ID 与文本voiceclone 模式)
- `generated_audio_id`: 预生成配音 ID存在时跳过内联 TTS使用已生成的配音文件
- `speed`: 语速(声音克隆模式,默认 1.0,范围 0.8-1.2
- `custom_assignments`: 自定义素材分配数组(每项含 `material_path` / `start` / `end` / `source_start` / `source_end?`),存在时优先按时间轴可见段生成
- `output_aspect_ratio`: 输出画面比例(`9:16``16:9`,默认 `9:16`
- `language`: TTS 语言(默认自动检测,声音克隆时透传给 CosyVoice 3.0
- `title`: 片头标题文字
- `title_display_mode`: 标题显示模式(`short` / `persistent`,默认 `short`
- `title_duration`: 标题显示时长(秒,默认 `4.0``short` 模式生效)
- `subtitle_style_id`: 字幕样式 ID
- `title_style_id`: 标题样式 ID
- `subtitle_font_size`: 字幕字号(覆盖样式默认值)
- `title_font_size`: 标题字号(覆盖样式默认值)
- `title_top_margin`: 标题距顶部像素
- `secondary_title`: 片头副标题文字(可选,限 20 字,仅视频画面显示)
- `secondary_title_style_id`: 副标题样式 ID
- `secondary_title_font_size`: 副标题字号
- `secondary_title_top_margin`: 副标题距主标题间距
- `subtitle_bottom_margin`: 字幕距底部像素
- `enable_subtitles`: 是否启用字幕
- `bgm_id`: 背景音乐 ID
- `bgm_volume`: 背景音乐音量0-1默认 0.2
### 多素材稳定性说明
- 多素材片段在拼接前统一重编码,并强制 `25fps + CFR`,减少段边界时间基不一致导致的画面卡顿。
- concat 流程启用 `+genpts` 重建时间戳,提升拼接后时间轴连续性。
- 对带旋转元数据的 MOV 素材会先做方向归一化,再进入分辨率判断和后续流程。
## 📦 资源库与静态资源
- 本地资源目录:`backend/assets/{fonts,bgm,styles}`

211
Docs/COSYVOICE3_DEPLOY.md Normal file
View File

@@ -0,0 +1,211 @@
# CosyVoice 3.0 部署文档
## 概览
| 项目 | 值 |
|------|------|
| 模型 | Fun-CosyVoice3-0.5B-2512 (0.5B 参数) |
| 端口 | 8010 |
| GPU | 0 (CUDA_VISIBLE_DEVICES=0) |
| PM2 名称 | vigent2-cosyvoice (id=15) |
| Conda 环境 | cosyvoice (Python 3.10) |
| 启动脚本 | `run_cosyvoice.sh` |
| 服务脚本 | `models/CosyVoice/cosyvoice_server.py` |
| 模型加载时间 | ~22-34 秒 |
| 显存占用 | ~3-5 GB |
## 支持语言
中文、英文、日语、韩语、德语、西班牙语、法语、意大利语、俄语18+ 中国方言
## 目录结构
```
models/CosyVoice/
├── cosyvoice_server.py # FastAPI 服务 (端口 8010)
├── cosyvoice/ # CosyVoice 源码
│ └── cli/cosyvoice.py # AutoModel 入口
├── third_party/Matcha-TTS/ # 子模块依赖
├── pretrained_models/
│ ├── Fun-CosyVoice3-0.5B/ # 模型文件 (~8.2GB)
│ │ ├── llm.pt # LLM 模型 (1.9GB)
│ │ ├── llm.rl.pt # RL 模型 (1.9GB, 备用)
│ │ ├── flow.pt # Flow 模型 (1.3GB)
│ │ ├── hift.pt # HiFT 声码器 (80MB)
│ │ ├── campplus.onnx # 说话人嵌入 (27MB)
│ │ ├── speech_tokenizer_v3.onnx # 语音分词器 (925MB)
│ │ ├── cosyvoice3.yaml # 模型配置
│ │ └── CosyVoice-BlankEN/ # Qwen tokenizer
│ └── CosyVoice-ttsfrd/ # 文本正则化资源
│ ├── resource/ # 解压后的 ttsfrd 资源
│ └── resource.zip
run_cosyvoice.sh # PM2 启动脚本
```
## API 接口
### GET /health
健康检查,返回:
```json
{
"service": "CosyVoice 3.0 Voice Clone",
"model": "Fun-CosyVoice3-0.5B",
"ready": true,
"gpu_id": 0
}
```
### POST /generate
声音克隆生成。
**参数 (multipart/form-data)**
| 参数 | 类型 | 必填 | 说明 |
|------|------|------|------|
| ref_audio | File | 是 | 参考音频 (WAV) |
| text | string | 是 | 要合成的文本 |
| ref_text | string | 是 | 参考音频的转写文字 |
| language | string | 否 | 语言 (默认 "Chinese"CosyVoice 自动检测) |
| speed | float | 否 | 语速 (默认 1.0,范围 0.5-2.0,建议 0.8-1.2) |
**返回:** WAV 音频文件
**状态码:**
- 200: 成功
- 429: GPU 忙,请重试
- 500: 生成失败/超时
- 503: 模型未加载/服务中毒
## 安全机制
1. **GPU 推理锁** (`asyncio.Lock`): 防止并发推理导致 GPU 状态损坏
2. **429 拒绝**: 锁被占用时立即返回 429客户端重试
3. **超时保护**: `60 + len(text) * 2` 秒,上限 300 秒
4. **Poisoned 标记**: 超时后标记服务为中毒状态,健康检查返回 `ready: false`
5. **强制退出**: 超时后 1.5 秒强制 `os._exit(1)`PM2 自动重启
6. **启动自检**: 启动时用短文本做一次真实推理,验证 GPU 推理链路可用;失败则 `_model_loaded = False`,健康检查返回 `ready: false`,避免假阳性
7. **参考音频自动截取**: 参考音频超过 10 秒时自动截取前 10 秒CosyVoice 建议 3-10 秒),避免采样异常
## 运维命令
```bash
# 启动
pm2 start run_cosyvoice.sh --name vigent2-cosyvoice
# 重启
pm2 restart vigent2-cosyvoice
# 查看日志
pm2 logs vigent2-cosyvoice --lines 50
# 健康检查
curl http://localhost:8010/health
# 停止
pm2 stop vigent2-cosyvoice
```
## 从零部署步骤
### 1. 克隆仓库
```bash
cd /home/rongye/ProgramFiles/ViGent2/models
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
cd CosyVoice
git submodule update --init --recursive
```
### 2. 创建 Conda 环境
```bash
conda create -n cosyvoice -y python=3.10
conda activate cosyvoice
```
### 3. 安装依赖
注意:不能直接 `pip install -r requirements.txt`,有版本冲突需要处理。
```bash
# 安装 PyTorch 2.3.1 (CUDA 12.1) — 必须先装,版本严格要求
pip install torch==2.3.1 torchaudio==2.3.1 --index-url https://download.pytorch.org/whl/cu121
# 核心推理依赖
pip install conformer==0.3.2 HyperPyYAML==1.2.2 inflect==7.3.1 \
librosa==0.10.2 lightning==2.2.4 modelscope==1.20.0 omegaconf==2.3.0 \
pydantic==2.7.0 soundfile==0.12.1 fastapi==0.115.6 uvicorn==0.30.0 \
transformers==4.51.3 protobuf==4.25 hydra-core==1.3.2 \
rich==13.7.1 diffusers==0.29.0 x-transformers==2.11.24 wetext==0.0.4
# onnxruntime-gpu
pip install onnxruntime-gpu==1.18.0 \
--extra-index-url https://aiinfra.pkgs.visualstudio.com/PublicPackages/_packaging/onnxruntime-cuda-12/pypi/simple/
# 其他必要依赖
pip install gdown matplotlib pyarrow wget onnx python-multipart httpx
# openai-whisper 需要 setuptools < 71提供 pkg_resources
pip install "setuptools<71"
pip install --no-build-isolation openai-whisper==20231117
# pyworld 需要 g++ 和 Cython
pip install Cython
PATH="/usr/bin:$PATH" pip install pyworld==0.3.4
# 关键版本修复
pip install "numpy<2" # onnxruntime-gpu 不兼容 numpy 2.x
pip install "ruamel.yaml<0.18" # hyperpyyaml 不兼容 ruamel.yaml 0.19+
```
> **重要**: CosyVoice 要求 torch==2.3.1。torch 2.10+ 会导致 CUBLAS_STATUS_INVALID_VALUE 错误。
> torch 2.3.1+cu121 自带 nvidia-cudnn-cu12onnxruntime CUDAExecutionProvider 可正常使用。
### 4. 下载模型
```bash
# 使用 huggingface_hub (国内用 hf-mirror.com)
HF_ENDPOINT=https://hf-mirror.com python -c "
from huggingface_hub import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
"
```
### 5. 安装 ttsfrd (可选,提升文本正则化质量)
```bash
cd pretrained_models/CosyVoice-ttsfrd/
unzip resource.zip -d .
pip install ttsfrd_dependency-0.1-py3-none-any.whl
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
```
### 6. 注册 PM2
```bash
pm2 start run_cosyvoice.sh --name vigent2-cosyvoice
pm2 save
```
## 已知问题
1. **ttsfrd "prepare tts engine failed"**: ttsfrd C 库内部日志Python 层初始化成功,不影响使用
2. **Sliding Window Attention 警告**: transformers 库提示,不影响推理结果
3. **onnxruntime Memcpy 性能提示**: `Memcpy nodes are not supported by the CUDA EP`,仅为性能建议日志,不影响功能
> libcudnn.so.8 问题在 torch 2.3.1+cu121 环境下已解决(自带 nvidia-cudnn-cu12onnxruntime CUDAExecutionProvider 可正常加载。
## 与 Qwen3-TTS 对比
| 特性 | Qwen3-TTS (已停用) | CosyVoice 3.0 (当前) |
|------|-----------|----------------|
| 端口 | 8009 | 8010 |
| 模型大小 | 0.6B | 0.5B |
| 语言 | 中/英/日/韩 | 9 语言 + 18 方言 |
| 克隆方式 | ref_audio + ref_text | ref_audio + ref_text |
| prompt 格式 | 直接传 ref_text | `You are a helpful assistant.<\|endofprompt\|>` + ref_text |
| 内置分段 | 无,需客户端分段 | 内置 text_normalize 自动分段 |
| 状态 | 已停用 (PM2 stopped) | 生产使用中 |

View File

@@ -165,6 +165,8 @@ playwright install chromium
CREATE POLICY "Allow public read" ON storage.objects FOR SELECT TO anon USING (bucket_id = 'materials' OR bucket_id = 'outputs');
EOF
```
> **注意**:后端启动时会自动创建额外的存储桶(`ref-audios`、`generated-audios`),无需手动创建。
---
@@ -211,6 +213,15 @@ cp .env.example .env
| `DOUYIN_KEEP_SUCCESS_VIDEO` | false | 成功后保留录屏 |
| `CORS_ORIGINS` | `*` | CORS 允许源 (生产环境建议白名单) |
| `DOUYIN_COOKIE` | 空 | 抖音视频下载 Cookie (文案提取功能) |
| `ALIPAY_APP_ID` | 空 | 支付宝应用 APPID |
| `ALIPAY_PRIVATE_KEY_PATH` | 空 | 应用私钥 PEM 文件路径 |
| `ALIPAY_PUBLIC_KEY_PATH` | 空 | 支付宝公钥 PEM 文件路径 |
| `ALIPAY_NOTIFY_URL` | 空 | 支付宝异步回调地址 (公网 HTTPS) |
| `ALIPAY_RETURN_URL` | 空 | 支付完成后浏览器跳转地址 |
| `PAYMENT_AMOUNT` | `999.00` | 会员价格 (元) |
| `PAYMENT_EXPIRE_DAYS` | `365` | 会员有效天数 |
> 支付宝完整配置步骤密钥生成、PEM 格式、产品开通等)请参考 **[支付宝部署指南](ALIPAY_DEPLOY.md)**。
---
@@ -334,34 +345,28 @@ chmod +x run_latentsync.sh
pm2 start ./run_latentsync.sh --name vigent2-latentsync
```
### 4. 启动 Qwen3-TTS 声音克隆服务 (可选)
### 4. 启动 CosyVoice 3.0 声音克隆服务 (可选)
> 如需使用声音克隆功能,需要启动此服务。
> 如需使用声音克隆功能,需要启动此服务。详细部署步骤见 [CosyVoice 3.0 部署文档](COSYVOICE3_DEPLOY.md)。
1. 安装 HTTP 服务依赖:
```bash
conda activate qwen-tts
pip install fastapi uvicorn python-multipart
```
1. 启动脚本位于项目根目录: `run_cosyvoice.sh`
2. 启动脚本位于项目根目录: `run_qwen_tts.sh`
3. 使用 pm2 启动:
2. 使用 pm2 启动:
```bash
cd /home/rongye/ProgramFiles/ViGent2
pm2 start ./run_qwen_tts.sh --name vigent2-qwen-tts
pm2 start ./run_cosyvoice.sh --name vigent2-cosyvoice
pm2 save
```
4. 验证服务:
3. 验证服务:
```bash
# 检查健康状态
curl http://localhost:8009/health
curl http://localhost:8010/health
```
### 5. 启动服务看门狗 (Watchdog)
> 🛡️ **推荐**:监控 Qwen-TTS 和 LatentSync 服务健康状态,卡死时自动重启。
> 🛡️ **推荐**:监控 CosyVoice 和 LatentSync 服务健康状态,卡死时自动重启。
```bash
cd /home/rongye/ProgramFiles/ViGent2
@@ -382,7 +387,7 @@ pm2 startup
pm2 status # 查看所有服务状态
pm2 logs # 查看所有日志
pm2 logs vigent2-backend # 查看后端日志
pm2 logs vigent2-qwen-tts # 查看 Qwen3-TTS 日志
pm2 logs vigent2-cosyvoice # 查看 CosyVoice 日志
pm2 restart all # 重启所有服务
pm2 stop vigent2-latentsync # 停止 LatentSync 服务
pm2 delete all # 删除所有服务
@@ -521,7 +526,7 @@ python3 -c "import torch; print(torch.cuda.is_available())"
sudo lsof -i :8006
sudo lsof -i :3002
sudo lsof -i :8007
sudo lsof -i :8009 # Qwen3-TTS
sudo lsof -i :8010 # CosyVoice
```
### 查看日志
@@ -531,7 +536,7 @@ sudo lsof -i :8009 # Qwen3-TTS
pm2 logs vigent2-backend
pm2 logs vigent2-frontend
pm2 logs vigent2-latentsync
pm2 logs vigent2-qwen-tts
pm2 logs vigent2-cosyvoice
```
### SSH 连接卡顿 / 系统响应慢
@@ -562,6 +567,7 @@ pm2 logs vigent2-qwen-tts
| `playwright` | 社交媒体自动发布 |
| `biliup` | B站视频上传 |
| `loguru` | 日志管理 |
| `python-alipay-sdk` | 支付宝支付集成 |
### 前端关键依赖
@@ -570,6 +576,7 @@ pm2 logs vigent2-qwen-tts
| `next` | React 框架 |
| `swr` | 数据请求与缓存 |
| `tailwindcss` | CSS 样式 |
| `wavesurfer.js` | 音频波形(时间轴编辑器) |
### LatentSync 关键依赖

856
Docs/DevLogs/Day23.md Normal file
View File

@@ -0,0 +1,856 @@
## 🎙️ 配音前置重构 — 第一阶段 (Day 23)
### 概述
将配音从视频生成流程中独立出来,实现"先生成配音 → 选中配音 → 再选素材 → 生成视频"的新工作流。用户可以独立管理配音(生成/试听/改名/删除/选择),并在选中配音后看到时长信息,为第二阶段的素材时间轴编排奠定数据基础。
**旧流程**: 文案 + 选素材 → 一键生成(内联 TTS → Whisper → 均分 → LipSync → 合成)
**新流程**: 文案 → 配音方式 → **生成配音** → 选中配音 → 选素材 → 背景音乐 → 生成视频
---
### 一、后端:新增 `generated_audios` 模块
#### 模块结构
```
backend/app/modules/generated_audios/
├── __init__.py
├── router.py # 5 个 API 端点
├── schemas.py # 请求/响应模型
└── service.py # 生成/列表/删除/改名
```
#### API 端点
| 方法 | 路径 | 说明 |
|------|------|------|
| POST | `/api/generated-audios/generate` | 异步生成配音(返回 task_id |
| GET | `/api/generated-audios/tasks/{task_id}` | 轮询生成进度 |
| GET | `/api/generated-audios` | 列出用户所有配音 |
| DELETE | `/api/generated-audios/{audio_id}` | 删除配音 |
| PUT | `/api/generated-audios/{audio_id}` | 改名 |
#### 存储方案
- Supabase 存储桶:`generated-audios`(启动时自动创建)
- 音频文件:`{user_id}/{timestamp}_audio.wav`
- 元数据文件:`{user_id}/{timestamp}_audio.json`(含 display_name、text、tts_mode、duration_sec 等)
#### 生成流程
复用现有 `TTSService` / `voice_clone_service` / `task_store`
```
POST /generate → 创建 task → BackgroundTask:
1. edgetts → TTSService.generate_audio()
voiceclone → 下载 ref_audio → voice_clone_service.generate_audio()
2. ffprobe 获取时长
3. 上传 .wav + .json 到 generated-audios 桶
4. 更新 task(status=completed, output={audio_id, duration_sec, ...})
```
---
### 二、后端:修改视频生成 workflow
#### `GenerateRequest` 新增字段
```python
generated_audio_id: Optional[str] = None # 预生成配音 ID存在时跳过内联 TTS
```
#### `workflow.py` TTS 阶段新增分支
```python
if req.generated_audio_id:
# 下载预生成配音 + 从元数据读取 language
elif req.tts_mode == "voiceclone":
# 原有声音克隆逻辑
else:
# 原有 EdgeTTS 逻辑
```
向后兼容:不传 `generated_audio_id` 时,原有内联 TTS 流程不受影响。
---
### 三、前端:新增配音列表 hook + 面板
#### `useGeneratedAudios.ts`
- 状态:`generatedAudios[]``selectedAudio``isGeneratingAudio``audioTask`
- 方法:`fetchGeneratedAudios()``generateAudio()``deleteAudio()``renameAudio()``selectAudio()`
- 轮询:生成后 1s 轮询 task 状态,完成后自动刷新列表并选中最新配音
- 独立于视频生成的 TaskContext不互相干扰
#### `GeneratedAudiosPanel.tsx`
- 每条配音:播放/暂停、名称、时长、重命名、删除
- 选中态:`border-purple-500 bg-purple-500/20`
- 内嵌进度条(生成中显示)
- 底部显示选中配音的原始文案(截断)
- 播放逻辑自包含于面板内(`new Audio()` + play/pause toggle
---
### 四、前端UI 面板重排序
**旧顺序**: MaterialSelector → ScriptEditor → TitleSubtitle → VoiceSelector → BgmPanel → GenerateActionBar
**新顺序**:
1. ScriptEditor文案编辑
2. TitleSubtitlePanel标题与字幕样式
3. VoiceSelector配音方式
4. **GeneratedAudiosPanel**(配音列表)← 新增
5. MaterialSelector视频素材← 后移,需选中配音才解锁
6. BgmPanel背景音乐
7. GenerateActionBar生成视频
#### 素材区门控
未选中配音时,素材区显示半透明遮罩 + "请先生成并选中配音"提示。素材上传/预览/改名/删除始终可用,仅选择勾选被遮罩。
#### 时长信息
选中配音后MaterialSelector 顶部显示:
```
当前配音: 45.2 秒 | 已选 3 个素材(自动均分每段 ~15.1 秒)
```
#### 生成按钮条件更新
```typescript
// 旧条件
disabled={isGenerating || selectedMaterials.length === 0 || (ttsMode === "voiceclone" && !selectedRefAudio)}
// 新条件
disabled={isGenerating || selectedMaterials.length === 0 || !selectedAudio}
```
---
### 五、持久化
`useHomePersistence` 新增 `selectedAudioId` 的 localStorage 读写,刷新页面后恢复选中的配音。
---
### 涉及文件汇总
#### 后端新增
| 文件 | 说明 |
|------|------|
| `backend/app/modules/generated_audios/__init__.py` | 模块标记 |
| `backend/app/modules/generated_audios/router.py` | 5 个 API 端点 |
| `backend/app/modules/generated_audios/service.py` | 生成/列表/删除/改名 |
| `backend/app/modules/generated_audios/schemas.py` | 请求/响应模型 |
#### 后端修改
| 文件 | 变更 |
|------|------|
| `backend/app/main.py` | 注册 generated_audios 路由 |
| `backend/app/services/storage.py` | 新增 `BUCKET_GENERATED_AUDIOS`,启动时自动创建桶 |
| `backend/app/modules/videos/schemas.py` | `GenerateRequest` 新增 `generated_audio_id` 字段 |
| `backend/app/modules/videos/workflow.py` | TTS 阶段新增预生成音频分支 |
#### 前端新增
| 文件 | 说明 |
|------|------|
| `frontend/src/features/home/model/useGeneratedAudios.ts` | 配音列表 hook |
| `frontend/src/features/home/ui/GeneratedAudiosPanel.tsx` | 配音列表面板 |
#### 前端修改
| 文件 | 变更 |
|------|------|
| `frontend/src/features/home/ui/HomePage.tsx` | 面板重排序 + 素材区门控 + 插入 GeneratedAudiosPanel |
| `frontend/src/features/home/ui/MaterialSelector.tsx` | 新增 `selectedAudioDuration` prop + 时长信息显示 |
| `frontend/src/features/home/ui/GenerateActionBar.tsx` | 禁用条件改为 `!selectedAudio` |
| `frontend/src/features/home/model/useHomeController.ts` | 集成 useGeneratedAudios、新增 handleGenerateAudio、修改 handleGenerate 使用 generated_audio_id |
| `frontend/src/features/home/model/useHomePersistence.ts` | 新增 selectedAudioId 持久化 |
---
## 🎞️ 素材时间轴编排 — 第二阶段 (Day 23)
### 概述
在第一阶段"配音前置"基础上,新增**时间轴编辑器**,用户可以:
1. 在音频波形上查看各素材块的时长分配
2. 拖拽分割线调整每段素材的时长(无缝铺满,调整一段自动压缩/扩展相邻段)
3. 为每段素材设置**源视频截取起点**(从视频任意位置开始,而非始终从头)
**旧行为**: 多素材时自动均分(`_split_equal`),无法控制每段时长和源视频起始点
**新行为**: 时间轴编辑器可视化分配 + 拖拽调整 + ClipTrimmer 截取设置
---
### 一、后端改动
#### 1.1 新增 `CustomAssignment` 模型
```python
# backend/app/modules/videos/schemas.py
class CustomAssignment(BaseModel):
material_path: str
start: float # 音频时间轴起点
end: float # 音频时间轴终点
source_start: float = 0.0 # 源视频截取起点
```
`GenerateRequest` 新增 `custom_assignments: Optional[List[CustomAssignment]] = None`。存在时跳过 Whisper 均分,直接使用用户定义的分配。
#### 1.2 `prepare_segment` 支持 `source_start`
```python
def prepare_segment(self, video_path, target_duration, output_path,
target_resolution=None, source_start: float = 0.0):
```
关键逻辑:
- `source_start > 0` 时使用 `-ss` 快速 seek并强制重编码避免 stream copy 关键帧不精确)
- 当需要循环且有 `source_start` 时,先裁剪出 `source_start` 到视频结尾的片段,再循环裁剪后的文件(避免 `stream_loop` 从视频 0s 开始循环)
- 裁剪临时文件在 `finally` 中自动清理
#### 1.3 `workflow.py` 支持 `custom_assignments`
- **多素材模式**: `custom_assignments` 存在时,直接使用用户分配(仍运行 Whisper 生成字幕),每个 `prepare_segment` 调用传入 `source_start`
- **单素材模式**: `custom_assignments` 有 1 条且 `source_start > 0` 时,先截取片段再传入 LatentSync
- **向后兼容**: `custom_assignments``None` 时完全走旧路径
---
### 二、前端新增组件
#### 2.1 `useTimelineEditor.ts` — 时间轴段管理 hook
```typescript
interface TimelineSegment {
id: string; // React key
materialId: string; // 素材 ID
materialName: string; // 显示名
start: number; // 音频时间轴开始秒数
end: number; // 音频时间轴结束秒数
sourceStart: number; // 源视频截取起点(默认 0
sourceEnd: number; // 源视频截取终点0 = 到结尾)
color: string; // 色块颜色
}
```
核心方法:
- `initSegments()`: selectedMaterials 变化时按数量均分 audioDuration
- `resizeSegment(id, newEnd)`: 拖拽右边界,约束每段最小 1s
- `setSourceRange(id, sourceStart, sourceEnd)`: 设置截取范围
- `toCustomAssignments()`: 转为后端 `CustomAssignment[]` 格式
#### 2.2 `TimelineEditor.tsx` — 波形 + 色块时间轴
- **wavesurfer.js** 渲染音频波形(仅展示,不播放)
- 色块层按比例排列,显示素材名 + 时长 + 截取标记
- 色块间分割线可拖拽(`onPointerDown/Move/Up` 实现连续像素拖拽)
- 点击色块打开 ClipTrimmer
#### 2.3 `ClipTrimmer.tsx` — 素材截取模态框
- HTML5 `<video>` 实时预览,拖拽滑块时 `video.currentTime` 跟随
- 双端 Range Slider起点/终点),互锁约束 ≥ 0.5s
- 显示截取时长 vs 分配时长对比(循环补足/截断提示)
- `loadedmetadata` 获取源视频时长
---
### 三、前端整合改动
#### 3.1 `useHomeController.ts`
- 集成 `useTimelineEditor` hook
- 新增 `clipTrimmerOpen` / `clipTrimmerSegmentId` 状态
- `handleGenerate` 多素材时始终发送 `custom_assignments`;单素材 + `sourceStart > 0` 时也发送
- 移除不再使用的 `reorderMaterials` 导出
#### 3.2 `HomePage.tsx`
- 在 MaterialSelector 和 BgmPanel 之间插入 TimelineEditor仅当有配音且已选素材时显示
- 底部新增 ClipTrimmer 模态框
- 移除 `reorderMaterials``selectedAudioDuration` prop 传递
#### 3.3 `MaterialSelector.tsx`
- 移除配音时长信息栏(功能迁至 TimelineEditor
- 移除拖拽排序区SortableChip + @dnd-kit 相关代码)
- 移除 `onReorderMaterials` / `selectedAudioDuration` prop
---
### 四、审查修复的 Bug
| # | 严重程度 | 问题 | 修复 |
|---|---------|------|------|
| 1 | **中** | `prepare_segment` 使用 `source_start > 0` + stream copy 时 seek 不精确 | 添加 `source_start > 0` 到重编码条件 |
| 2 | **高** | `stream_loop + source_start` 循环时从视频 0s 开始而非从 source_start 循环 | 改为两步:先裁剪片段再循环裁剪后的文件 |
| 3 | **低** | `useHomeController` 导出已废弃的 `reorderMaterials` | 移除 |
---
### 涉及文件汇总
#### 后端修改
| 文件 | 变更 |
|------|------|
| `backend/app/modules/videos/schemas.py` | 新增 `CustomAssignment` model`GenerateRequest` 新增 `custom_assignments` 字段 |
| `backend/app/services/video_service.py` | `prepare_segment` 新增 `source_start` 参数,循环+截取两步处理 |
| `backend/app/modules/videos/workflow.py` | 多素材/单素材流水线支持 `custom_assignments`,传递 `source_start` |
#### 前端新增
| 文件 | 说明 |
|------|------|
| `frontend/src/features/home/model/useTimelineEditor.ts` | 时间轴段管理 hook |
| `frontend/src/features/home/ui/TimelineEditor.tsx` | 波形 + 色块时间轴组件 |
| `frontend/src/features/home/ui/ClipTrimmer.tsx` | 素材截取模态框 |
#### 前端修改
| 文件 | 变更 |
|------|------|
| `frontend/src/features/home/ui/HomePage.tsx` | 插入 TimelineEditor + ClipTrimmer |
| `frontend/src/features/home/ui/MaterialSelector.tsx` | 移除时长信息 + 拖拽排序区 + 相关 prop |
| `frontend/src/features/home/model/useHomeController.ts` | 集成 useTimelineEditorhandleGenerate 发送 custom_assignments |
| `frontend/package.json` | 新增 `wavesurfer.js` 依赖 |
---
## 🎨 UI 体验优化 + TTS 稳定性修复 — 第三阶段 (Day 23)
### 概述
根据用户反馈,修复 6 项 UI 体验问题,同时修复声音克隆服务的 SoX 路径问题和显存缓存管理。
> **注**: Qwen3-TTS 已在后续被 CosyVoice 3.0 (端口 8010) 替换,以下记录为当时的修复过程。
---
### 一、Qwen3-TTS 稳定性修复 (已被 CosyVoice 3.0 替换)
#### 1.1 SoX PATH 修复
**问题**: PM2 启动 qwen-tts 时,`sox` 工具安装在 conda env 的 bin 目录中,系统 PATH 找不到,导致音频编解码走 fallback 路径CPU 密集型),日志中出现 `SoX could not be found!` 警告。
**修复**: `run_qwen_tts.sh` 中 export conda env bin 到 PATH
```bash
export PATH="/home/rongye/ProgramFiles/miniconda3/envs/qwen-tts/bin:$PATH"
```
#### 1.2 CUDA 缓存清理
**修复**: `qwen_tts_server.py` 每次生成完成后(无论成功或失败)调用 `torch.cuda.empty_cache()`,防止显存碎片累积。使用 `asyncio.to_thread()` 在线程池中运行推理,避免阻塞事件循环导致健康检查超时。
> **后续**: Qwen3-TTS 已停用CosyVoice 3.0 沿用了相同的保护机制GPU 推理锁、超时保护、显存清理、启动自检)。
---
### 二、配音列表按钮布局统一 (反馈 #1 + #6)
**问题**: `GeneratedAudiosPanel` 的试听按钮位于左侧(独立于 Edit/Delete`RefAudioPanel` 的布局不一致。底部文案摘要区域不需要展示。
**修复**:
- Play/Edit/Delete 按钮统一放在右侧同组hover 显示,顺序为 试听→重命名→删除
- 移除选中配音的文案摘要区域
- 布局与 RefAudioPanel 一致:左侧名称+时长,右侧操作按钮组
---
### 三、视频素材区域移除配音依赖遮罩 (反馈 #2)
**问题**: MaterialSelector 被 `!selectedAudio` 遮罩覆盖,必须先选配音才能操作素材。
**修复**: 移除 `HomePage.tsx` 中 MaterialSelector 外层的 disabled overlay `<div>`。素材随时可上传/预览/管理,仅 TimelineEditor 需要选中配音才显示(已有独立条件 `selectedAudio && selectedMaterials.length > 0`)。
---
### 四、时间轴拖拽排序 (反馈 #3)
**问题**: TimelineEditor 不支持调换素材顺序。
**修复**:
- `useTimelineEditor` 已有 `reorderSegments()` 方法(交换两个段的素材信息但保留时间范围)
- 通过 `useHomeController` 暴露 `reorderSegments`,传入 `TimelineEditor`
- 色块支持 HTML5 Drag & Drop`draggable` + `onDragStart/Over/Drop/End`
- 拖拽时:源色块半透明(`opacity-50`),目标色块高亮 ring`ring-2 ring-purple-400 scale-[1.02]`
- 光标样式:`cursor-grab` / `active:cursor-grabbing`
---
### 五、截取设置双手柄 Range Slider (反馈 #4)
**问题**: ClipTrimmer 使用两个独立的 `<input type="range">` 滑块,起点和终点分开操作,体验不直观。
**修复**: 改为自定义双手柄 range slider
- 单条轨道,紫色圆形手柄(起点)+ 粉色圆形手柄(终点)
- 轨道底色 `bg-white/10`,选中范围用素材对应颜色高亮
- Pointer Events 实现拖拽:`onPointerDown` 捕获手柄 → `onPointerMove` 更新位置 → `onPointerUp` 释放
- 手柄互锁约束:起点不超过终点 - 0.5s,终点不低于起点 + 0.5s
- 底部显示起点(紫色)和终点(粉色)时间标签
---
### 六、截取设置视频预览 (反馈 #5)
**问题**: ClipTrimmer 的视频只能静态查看,无法播放预览截取范围。
**修复**:
- 视频区域点击可播放/暂停Play/Pause 图标覆盖层)
- 播放范围:从 sourceStart 播放到 sourceEnd 自动停止
- 播放结束后回到起点
- 拖拽手柄时 `video.currentTime` 实时跟随seek 到当前位置查看画面)
- 播放进度条(白色竖线)叠加在 range slider 轨道上
- `preload="auto"` 预加载视频,确保拖拽时快速 seek
---
### 涉及文件汇总
#### 后端修改
| 文件 | 变更 |
|------|------|
| `run_qwen_tts.sh` | export conda env bin 到 PATH修复 SoX 找不到问题 (已停用) |
| `models/Qwen3-TTS/qwen_tts_server.py` | 每次生成后 `torch.cuda.empty_cache()`asyncio.to_thread 避免阻塞 (已停用) |
#### 前端修改
| 文件 | 变更 |
|------|------|
| `frontend/src/features/home/ui/GeneratedAudiosPanel.tsx` | 按钮布局统一Play/Edit/Delete 右侧同组),移除文案摘要 |
| `frontend/src/features/home/ui/HomePage.tsx` | 移除 MaterialSelector 配音遮罩,传入 onReorderSegment |
| `frontend/src/features/home/ui/TimelineEditor.tsx` | 新增 HTML5 Drag & Drop 排序,新增 onReorderSegment prop |
| `frontend/src/features/home/ui/ClipTrimmer.tsx` | 双手柄 range slider + 视频播放预览 + 播放进度指示 |
| `frontend/src/features/home/model/useHomeController.ts` | 暴露 reorderSegments 方法 |
---
## 📝 历史文案保存 + 时间轴拖拽修复 — 第四阶段 (Day 23)
### 概述
新增文案手动保存与加载功能,修复时间轴拖拽排序后素材时长不跟随的 Bug统一按钮视觉规范。
---
### 一、历史文案保存与加载
#### 功能
用户可手动保存当前文案到历史列表,随时从历史中加载恢复。只有手动保存的文案才出现在历史列表中,与自动保存(`useHomePersistence`)完全独立。
#### UI 布局
```
按钮栏: [历史文案▼] [文案提取助手] [AI多语言▼] [AI生成标题标签]
底部栏: 128 字 [保存文案]
```
- **历史文案下拉**: 展示已保存列表(名称 + 日期 + 删除按钮),点击条目加载文案,空列表显示"暂无保存的文案"
- **保存文案按钮**: 文案为空时 disabled点击后 `toast.success("文案已保存")`
- **预计时长已移除**: 底部栏只保留字数 + 保存按钮
#### 实现
##### `useSavedScripts.ts`(新建)
```typescript
interface SavedScript { id: string; name: string; content: string; savedAt: number }
```
- localStorage key: `vigent_{storageKey}_savedScripts`
- `saveScript(content)`: 取前 15 字符自动命名,新条目插入列表头部,**直接写入 localStorage**
- `deleteScript(id)`: 删除指定条目,直接写入 localStorage
- `useEffect([lsKey])`: lsKey 变化时guest → userId重新从 localStorage 读取
- **不使用自动持久化 effect**,避免 storageKey 切换时空数组覆盖已有数据
##### 数据流
```
ScriptEditor (UI)
↑ savedScripts / onSaveScript / onLoadScript / onDeleteScript (纯 props + callbacks)
useHomeController
├── useSavedScripts(storageKey) → { savedScripts, saveScript, deleteScript }
└── handleSaveScript() → saveScript(text) + toast
HomePage
└── 传递 props 到 ScriptEditor
```
---
### 二、时间轴拖拽排序 Bug 修复
#### 问题
拖拽调换素材顺序后各素材的时长没有跟随素材移动而是留在原槽位。例如素材1(3s) + 素材2(8s+4s循环)拖拽后变成素材2(3s) + 素材1(8s+4s循环),时长分配没变。
#### 根因
`reorderSegments` 使用**属性交换**方式:逐个拷贝 `materialId``sourceStart``sourceEnd` 等属性在两个槽位间交换,然后调用 `recalcPositions` 重算位置。
#### 修复
改为**数组移动**splice将整个 segment 对象从旧位置取出插入到新位置。segment 对象携带全部属性materialId、sourceStart、sourceEnd、color 等)作为一个整体移动,再由 `recalcPositions` 重算位置。
```typescript
// 修复前:属性交换
const fromMat = { materialId: next[fromIdx].materialId, ... };
const toMat = { materialId: next[toIdx].materialId, ... };
next[fromIdx] = { ...next[fromIdx], ...toMat };
next[toIdx] = { ...next[toIdx], ...fromMat };
// 修复后:数组移动
const [moved] = next.splice(fromIdx, 1);
next.splice(toIdx, 0, moved);
```
附带优势3+ 素材拖拽行为从"交换"变为"插入",更符合用户直觉。
---
### 三、按钮视觉统一
#### 问题
历史文案、文案提取助手、AI多语言、AI生成标题标签 4 个按钮高度不一致AI 按钮的文本被 `<span>` 嵌套包裹导致内部布局差异。
#### 修复
- 4 个按钮统一为 `h-7 px-2.5 text-xs rounded inline-flex items-center gap-1`(固定高度 28px
- 移除 AI多语言 / AI生成标题标签 按钮内多余的 `<span>` 嵌套,改为 `<>...</>` fragment
---
### 涉及文件汇总
#### 前端新增
| 文件 | 说明 |
|------|------|
| `frontend/src/features/home/model/useSavedScripts.ts` | 历史文案 hooklocalStorage 持久化) |
#### 前端修改
| 文件 | 变更 |
|------|------|
| `frontend/src/features/home/ui/ScriptEditor.tsx` | 历史文案下拉 + 保存按钮 + 移除预计时长 + 按钮高度统一 |
| `frontend/src/features/home/model/useHomeController.ts` | 集成 useSavedScripts新增 handleSaveScript |
| `frontend/src/features/home/ui/HomePage.tsx` | 传递 savedScripts / handleSaveScript / deleteSavedScript 到 ScriptEditor |
| `frontend/src/features/home/model/useTimelineEditor.ts` | reorderSegments 从属性交换改为数组移动splice |
---
## 🔤 字幕语言不匹配 + 视频比例错位修复 — 第五阶段 (Day 23)
### 概述
修复两个视频生成 Bug
1. **字幕语言不匹配**: 中文配音 + 英文翻译文案 → 字幕错误显示英文Whisper 独立转录,忽略原文)
2. **标题字幕比例错位**: 9:16 竖屏素材生成视频后,标题/字幕按 16:9 横屏布局渲染
附带修复代码审查中发现的 `split_word_to_chars` 英文空格丢失问题。
---
### 一、字幕用原文替换 Whisper 转录文字
#### 根因
Whisper 对音频独立转录,完全忽略传入的 `text` 参数。当配音语言与编辑器文案语言不一致时(例如:用户先写中文文案 → 翻译成英文 → 生成英文配音 → 再改回中文文案Whisper "听到"英文语音就输出英文字幕。
#### 修复思路
Whisper 仅负责检测**语音总时间范围**`first_start``last_end`),字幕文字永远用配音保存的原始文案。
#### `whisper_service.py` — `align()` 新增 `original_text` 参数
```python
async def align(self, audio_path, text, output_path=None,
language="zh", original_text=None):
```
`original_text` 非空时:
1. 正常运行 Whisper 转录,记录 `whisper_first_start``whisper_last_end`
2.`original_text` 传入 `split_word_to_chars()` 在总时间范围上线性分布
3.`split_segment_to_lines()` 按标点和字数断行
4. 替换 Whisper 的转录结果
#### `workflow.py` — 配音元数据无条件覆盖 + 传入原文
```python
# 改前(只在文案为空时覆盖)
if not req.text.strip():
req.text = meta.get("text", req.text)
# 改后(无条件用配音元数据覆盖)
meta_text = meta.get("text", "")
if meta_text:
req.text = meta_text
```
所有 4 处 `whisper_service.align()` 调用添加 `original_text=req.text`
---
### 二、Remotion 动态传入视频尺寸
#### 根因
`remotion/src/Root.tsx` 硬编码 `width={1280} height={720}`。虽然 `render.ts` 用 ffprobe 检测真实尺寸后覆盖 `composition.width/height`,但 `selectComposition` 阶段组件已按 1280×720 初始化,标题和字幕定位基于错误的画布尺寸。
#### 修复
##### `Root.tsx` — `calculateMetadata` 从 props 读取尺寸
```tsx
<Composition
id="ViGentVideo"
component={Video}
durationInFrames={300}
fps={25}
width={1080}
height={1920}
calculateMetadata={async ({ props }) => ({
width: props.width || 1080,
height: props.height || 1920,
})}
defaultProps={{
videoSrc: '',
width: 1080,
height: 1920,
// ...
}}
/>
```
默认从 1280×720 改为 1080×1920竖屏优先`calculateMetadata` 确保 `selectComposition` 阶段使用 ffprobe 检测的真实尺寸。
##### `Video.tsx` — VideoProps 新增可选 `width/height`
仅供 `calculateMetadata` 访问,组件渲染不引用。
##### `render.ts` — inputProps 统一传入视频尺寸
```typescript
const inputProps = {
videoSrc: videoFileName,
captions,
title: options.title,
// ...
width: videoWidth, // ffprobe 检测值
height: videoHeight, // ffprobe 检测值
};
```
`selectComposition``renderMedia` 使用同一个 `inputProps`。保留显式 `composition.width/height` 覆盖作为保险。
---
### 三、代码审查修复:英文空格丢失
#### 问题
`split_word_to_chars` 原设计处理 Whisper 单个词(如 `" Hello"`),但 `original_text` 传入整段文本时,中间空格被 `continue` 跳过且不 flush `ascii_buffer`,导致 `"Hello World"` 变成 `"HelloWorld"`
#### 执行路径追踪
```
输入: "Hello World"
H,e,l,l,o → ascii_buffer = "Hello"
' ' → continue跳过不 flush
W,o,r,l,d → ascii_buffer = "HelloWorld"
结果: tokens = ["HelloWorld"] ← 空格丢失
```
#### 修复
遇到空格时 flush `ascii_buffer`,并用 `pending_space` 标记给下一个 token 前置空格:
```python
if not char.strip():
if ascii_buffer:
tokens.append(ascii_buffer)
ascii_buffer = ""
if tokens:
pending_space = True
continue
```
修复后:`"Hello World"` → tokens = `["Hello", " World"]` → 字幕正确显示。中文不受影响。
---
### 涉及文件汇总
#### 后端修改
| 文件 | 变更 |
|------|------|
| `backend/app/services/whisper_service.py` | `align()` 新增 `original_text` 参数;`split_word_to_chars` 修复英文空格丢失 |
| `backend/app/modules/videos/workflow.py` | 配音元数据无条件覆盖 text/language4 处 `align()` 调用传入 `original_text` |
#### 前端修改Remotion
| 文件 | 变更 |
|------|------|
| `remotion/src/Root.tsx` | 默认尺寸改为 1080×1920新增 `calculateMetadata` + width/height defaultProps |
| `remotion/src/Video.tsx` | VideoProps 新增可选 `width`/`height` |
| `remotion/render.ts` | inputProps 统一传入 `videoWidth`/`videoHeight`selectComposition 和 renderMedia 共用 |
---
## 🎤 参考音频自动转写 + 语速控制 — 第六阶段 (Day 23)
### 概述
解决声音克隆 ref_text 不匹配问题:旧方案使用前端固定文字作为 ref_textCosyVoice zero-shot 克隆要求 ref_text 必须与参考音频实际内容匹配,不匹配时模型会在生成音频开头"幻觉"出多余片段。
**改进**:上传参考音频时自动调用 Whisper 转写内容作为 ref_text同时新增语速控制功能。
---
### 一、Whisper 自动转写参考音频
#### 1.1 `whisper_service.py` — 语言自动检测
`transcribe()` 方法原先硬编码 `language="zh"`,改为接受可选 `language` 参数(默认 `None` = 自动检测),支持多语言参考音频。
#### 1.2 `ref_audios/service.py` — 上传时自动转写
上传流程变更:转码 WAV → 检查时长(≥1s) → 超 10s 在静音点截取 → **Whisper 自动转写** → 验证非空 → 上传。
```python
try:
transcribed = await whisper_service.transcribe(tmp_wav_path)
if transcribed.strip():
ref_text = transcribed.strip()
except Exception as e:
logger.warning(f"Auto-transcribe failed: {e}")
if not ref_text or not ref_text.strip():
raise ValueError("无法识别音频内容,请确保音频包含清晰的语音")
```
#### 1.3 `ref_audios/router.py` — ref_text 改为可选
`ref_text: str = Form("")`(不再必填),前端不再发送固定文字。
---
### 二、参考音频智能截取10 秒上限)
CosyVoice 对 3-10 秒参考音频效果最好。
#### 2.1 静音点检测
使用 ffmpeg `silencedetect` 找 10 秒内最后一个静音结束点(阈值 -30dB最短 0.3s),避免在字词中间硬切:
```python
def _find_silence_cut_point(file_path, max_duration):
# silencedetect → 解析 silence_end → 找 3s~max_duration 内最后的静音点
# 找不到则回退到 max_duration
```
#### 2.2 淡出处理
截取时末尾 0.1 秒淡出(`afade=t=out`),避免截断爆音。
---
### 三、重新识别功能(旧数据迁移)
#### 3.1 新增 API
`POST /api/ref-audios/{audio_id}/retranscribe` — 下载音频 → 超 10s 截取 → Whisper 转写 → 重新上传音频和元数据。
#### 3.2 前端 UI
- RefAudioPanel 新增 RotateCw 按钮("重新识别文字"),转写中显示 `animate-spin`
- 旧音频 ref_text 以固定文字开头时显示 ⚠ 黄色警告
---
### 四、语速控制CosyVoice speed 参数)
#### 4.1 全链路传递
```
前端 GeneratedAudiosPanel (速度选择器)
→ useHomeController (speed state + persistence)
→ useGeneratedAudios.generateAudio(params)
→ POST /api/generated-audios/generate { speed: 1.0 }
→ GenerateAudioRequest.speed (Pydantic)
→ generate_audio_task → voice_clone_service.generate_audio(speed=)
→ _generate_once → POST /generate { speed: "1.0" }
→ cosyvoice_server → _model.inference_zero_shot(speed=speed)
```
#### 4.2 前端 UI
声音克隆模式下,配音列表面板标题栏"生成配音"按钮左侧显示语速下拉菜单(`语速: 正常 ▼`
| 标签 | speed 值 |
|------|----------|
| 较慢 | 0.8 |
| 稍慢 | 0.9 |
| 正常 | 1.0 (默认) |
| 稍快 | 1.1 |
| 较快 | 1.2 |
语速选择持久化到 localStorage`vigent_{storageKey}_speed`)。
---
### 五、缺少参考音频门控
声音克隆模式下未选参考音频时:
- "生成配音"按钮禁用 + title 提示"请先选择参考音频"
- 面板内显示黄色警告条"声音克隆模式需要先选择参考音频"
---
### 六、前端清理
- 移除 `FIXED_REF_TEXT` 常量和 `fixedRefText` prop
- 移除"请朗读以下内容"引导区块
- 上传提示简化为"上传任意语音样本3-10秒系统将自动识别内容并克隆声音"
- 录音区备注"建议 3-10 秒,超出将自动截取"
---
### 涉及文件汇总
#### 后端修改
| 文件 | 变更 |
|------|------|
| `backend/app/services/whisper_service.py` | `transcribe()` 增加可选 `language` 参数,默认 None (自动检测) |
| `backend/app/modules/ref_audios/service.py` | 上传自动转写 + 静音点截取 + 淡出 + retranscribe 函数 |
| `backend/app/modules/ref_audios/router.py` | `ref_text` 改为 Form(""),新增 retranscribe 端点 |
| `backend/app/modules/generated_audios/schemas.py` | `GenerateAudioRequest` 新增 `speed: float = 1.0` |
| `backend/app/modules/generated_audios/service.py` | 传递 `req.speed` 到 voice_clone_service |
| `backend/app/services/voice_clone_service.py` | `generate_audio()` / `_generate_once()` 接受并传递 speed |
| `models/CosyVoice/cosyvoice_server.py` | `/generate` 端点接受 `speed` 参数,传递到 `inference_zero_shot(speed=)` |
#### 前端修改
| 文件 | 变更 |
|------|------|
| `frontend/src/features/home/model/useHomeController.ts` | 新增 speed state移除 FIXED_REF_TEXThandleGenerateAudio 传 speed |
| `frontend/src/features/home/model/useHomePersistence.ts` | 新增 speed 持久化 |
| `frontend/src/features/home/model/useRefAudios.ts` | 移除 fixedRefText新增 retranscribe |
| `frontend/src/features/home/model/useGeneratedAudios.ts` | generateAudio params 新增 speed |
| `frontend/src/features/home/ui/GeneratedAudiosPanel.tsx` | 新增语速选择器 + 缺少参考音频门控 |
| `frontend/src/features/home/ui/RefAudioPanel.tsx` | 移除朗读引导,新增重新识别按钮 + ⚠ 警告 |
| `frontend/src/features/home/ui/HomePage.tsx` | 传递 speed/setSpeed/ttsMode 到 GeneratedAudiosPanel |

185
Docs/DevLogs/Day24.md Normal file
View File

@@ -0,0 +1,185 @@
## 🔧 鉴权到期治理 + 多素材时间轴稳定性修复 (Day 24)
### 概述
本日主要完成两条主线:
1. **账号与鉴权治理**:会员到期改为请求时自动失效(登录/鉴权接口触发),并统一返回续费提示。
2. **视频生成稳定性**:围绕多素材时间轴、截取语义、拼接边界冻结、画面比例与字幕标题适配进行一轮端到端修复。
---
## 🔐 会员到期请求时失效 — 第一阶段 (Day 24)
### 目标
避免依赖定时任务,用户在触发登录或访问受保护接口时即可完成到期判定与账号停用。
### 行为调整
- 到期判断基于 `users.expires_at`
- 判定到期后:
-`is_active` 自动置为 `false`
- 删除该用户全部 session
- 返回 `403`,提示:`会员已到期,请续费`
### 实现点
- `users.py` 新增 `deactivate_user_if_expired()`,并补充 `_parse_expires_at()` 统一时区解析。
- `deps.py``get_current_user` / `get_current_user_optional` 中统一接入到期检查。
- `auth/router.py` 在登录路径增加到期停用逻辑;`/api/auth/me` 统一走 `Depends(get_current_user)`
---
## 🖼️ 画面比例控制 + 字幕标题适配 — 第二阶段 (Day 24)
### 2.1 输出画面比例可配置
- 时间轴顶部新增“画面比例”下拉:`9:16` / `16:9`
- 默认值 `9:16`,并持久化到 localStorage。
- 生成请求携带 `output_aspect_ratio`,后端在单素材与多素材流程中统一按目标分辨率处理。
### 2.2 标题/字幕在窄屏画布防溢出
为减少“预览正常、成片溢出”的差异,统一了预览与渲染策略:
- 根据 composition 宽度进行响应式缩放。
- 开启可换行:`white-space: normal` + `word-break` + `overflow-wrap`
- 描边、字距、上下边距同步按比例缩放。
### 2.3 片头标题显示模式(短暂/常驻)
- 在“标题与字幕”面板的“片头标题”行尾新增下拉,支持:`短暂显示` / `常驻显示`
- 默认模式为 `短暂显示`,短暂模式默认时长为 4 秒。
- 用户选择会持久化到 localStorage刷新后保持上次配置。
- 生成请求新增 `title_display_mode`,短暂模式透传 `title_duration=4.0`
- Remotion 端到端支持该参数:
- `short`:标题在设定时长后淡出并结束渲染;
- `persistent`:标题全程常驻(保留淡入动画,不执行淡出)。
---
## 🎥 方向归一化 + 多素材拼接稳定性 — 第三阶段 (Day 24)
### 3.1 MOV 旋转元数据导致横竖识别错误
问题场景:编码分辨率是横屏,但依赖 rotation side-data 才能正确显示为竖屏(常见于手机 MOV
修复方案:
- `get_video_metadata()` 扩展返回 `rotation/effective_width/effective_height`
- 新增 `normalize_orientation()`,在流程前对带旋转元数据素材做物理方向归一化。
- 单素材和多素材下载后统一执行方向归一化,再做分辨率决策。
### 3.2 多素材“只看到第一段”与边界冻结
针对拼接可靠性补了两类保护:
- **分配保护**`custom_assignments` 与素材数量不一致时,后端回退自动分配,避免异常输入导致仅首段生效。
- **编码一致性**
- 片段准备阶段统一重编码;
- concat 阶段不再走拷贝;
- 进一步统一为 `25fps + CFR`,并在 concat 增加 `+genpts`,降低段边界时间基不连续导致的“画面冻结口型还动”风险。
---
## ⏱️ 时间轴截取语义对齐修复 — 第四阶段 (Day 24)
### 背景
时间轴设计语义是:
- 每段可以设置 `sourceStart/sourceEnd`
- 总时长超出音频时,仅保留可见段,末段截齐音频;
- 总时长不足时,由最后可见段循环补齐。
本日将前后端对齐到这一语义。
### 4.1 `source_end` 全链路打通
此前仅传 `source_start`,导致后端无法准确知道“截到哪里”。
本次改动:
- 前端 `toCustomAssignments()` 增加可选 `source_end`
- 后端 `CustomAssignment` schema 增加 `source_end`
- workflow 将 `source_end` 透传到 `prepare_segment()`(单素材/多素材均支持)。
- `prepare_segment()` 增加 `source_end` 参数,按 `[source_start, source_end)` 计算可用片段,并在需要循环时先裁剪再循环,避免循环范围错位。
### 4.2 时间轴有效时长计算修复
修复 `sourceStart > 0 且 sourceEnd = 0` 时的有效时长错误:
- 旧逻辑会按整段素材时长计算;
- 新逻辑改为 `materialDuration - sourceStart`
该修复同时用于:
- `recalcPositions()` 的段时长计算;
- TimelineEditor 中“循环补足”可视化比例计算。
### 4.3 可见段分配优先级修复
修复“可见段数 < 已选素材数时custom_assignments 被丢弃回退自动分配”的问题:
- 生成请求优先以时间轴可见段的 `assignments` 为准;
- 超出时间轴的素材不参与本次生成。
### 4.4 单素材截取触发条件补齐
单素材模式下,若只改了终点(`sourceEnd > 0`)也会发送 `custom_assignments`,确保截取生效。
---
## 🧭 页面交互与体验细节 — 第五阶段 (Day 24)
- 页面刷新后自动回到顶部,避免从历史滚动位置进入页面。
- 素材列表与历史视频列表滚动增加“跳过首次自动滚动”保护,减少恢复状态时页面跳动。
- 时间轴比例区移除多余文案,保持信息简洁。
---
## 涉及文件汇总
### 后端修改
| 文件 | 变更 |
|------|------|
| `backend/app/repositories/users.py` | 新增 `deactivate_user_if_expired()``_parse_expires_at()` |
| `backend/app/core/deps.py` | `get_current_user` / `get_current_user_optional` 接入到期失效检查 |
| `backend/app/modules/auth/router.py` | 登录时到期停用 + `/api/auth/me` 统一鉴权依赖 |
| `backend/app/modules/videos/schemas.py` | `CustomAssignment` 新增 `source_end`;保留 `output_aspect_ratio` |
| `backend/app/modules/videos/workflow.py` | 多素材/单素材透传 `source_end`;多素材 prepare/concat 统一 25fps标题显示模式参数透传 Remotion |
| `backend/app/services/video_service.py` | 旋转元数据解析与方向归一化;`prepare_segment` 支持 `source_end/target_fps`concat 强制 CFR + `+genpts` |
| `backend/app/services/remotion_service.py` | render 支持 `title_display_mode/title_duration` 并传递到 render.ts |
### 前端修改
| 文件 | 变更 |
|------|------|
| `frontend/src/features/home/model/useTimelineEditor.ts` | `CustomAssignment` 新增 `source_end`;修复 sourceStart 开放终点时长计算 |
| `frontend/src/features/home/model/useHomeController.ts` | 多素材以可见 assignments 为准发送;单素材截取触发条件补齐 |
| `frontend/src/features/home/ui/TimelineEditor.tsx` | 画面比例下拉;循环比例按截取后有效时长计算 |
| `frontend/src/features/home/model/useHomePersistence.ts` | `outputAspectRatio``titleDisplayMode` 持久化 |
| `frontend/src/features/home/ui/HomePage.tsx` | 页面进入滚动到顶部ClipTrimmer/Timeline 交互保持一致 |
| `frontend/src/features/home/ui/FloatingStylePreview.tsx` | 标题/字幕样式预览与成片渲染策略对齐 |
| `frontend/src/features/home/ui/TitleSubtitlePanel.tsx` | 标题行新增“短暂显示/常驻显示”下拉 |
### Remotion 修改
| 文件 | 变更 |
|------|------|
| `remotion/src/components/Title.tsx` | 标题响应式缩放与自动换行;新增短暂/常驻显示模式控制 |
| `remotion/src/components/Subtitles.tsx` | 字幕响应式缩放与自动换行,减少预览/成片差异 |
| `remotion/src/Video.tsx` | 新增 `titleDisplayMode` 透传到标题组件 |
| `remotion/src/Root.tsx` | 默认 props 增加 `titleDisplayMode='short'``titleDuration=4` |
| `remotion/render.ts` | CLI 参数新增 `--titleDisplayMode`inputProps 增加 `titleDisplayMode` |
---
## 验证记录
- 后端语法检查:`python -m py_compile backend/app/modules/videos/schemas.py backend/app/modules/videos/workflow.py backend/app/services/video_service.py backend/app/services/remotion_service.py`
- 前端类型检查:`npx tsc --noEmit`
- 前端 ESLint`npx eslint src/features/home/model/useHomeController.ts src/features/home/model/useHomePersistence.ts src/features/home/ui/HomePage.tsx src/features/home/ui/TitleSubtitlePanel.tsx`
- Remotion 渲染脚本构建:`npm run build:render`

254
Docs/DevLogs/Day25.md Normal file
View File

@@ -0,0 +1,254 @@
## 🔧 文案提取助手修复 — 抖音链接无法提取文案 (Day 25)
### 概述
文案提取助手粘贴抖音链接后无法提取文案yt-dlp 报错 `Fresh cookies are needed`,手动回退方案也因抖音页面结构变化失效。本日完成了完整修复,并清理了不再需要的 `DOUYIN_COOKIE` 配置。
---
## 🐛 问题诊断
### 错误链路
1. **yt-dlp 失败**`ERROR: [Douyin] Fresh cookies (not necessarily logged in) are needed`
- yt-dlp 版本 `2025.12.08` 过旧
- 抖音 API `aweme/v1/web/aweme/detail/` 需要签名 cookie`s_v_web_id` 等),即使升级 yt-dlp 到最新版 + 传入 cookie 仍无法解决,属 yt-dlp 已知问题
2. **手动回退失败**`Could not find RENDER_DATA in page`
- 旧方案通过桌面端用户主页 + `modal_id` 访问,抖音 SSR 已不再返回 `videoDetail` 数据
3. **`.env``DOUYIN_COOKIE`**:时间戳 2024 年 12 月,早已过期
---
## ✅ 修复方案:移动端分享页 + 自动获取 ttwid
### 核心思路
放弃依赖 yt-dlp 下载抖音视频和手动维护 cookie改为
1. 自动从 ByteDance 公共 API 获取新鲜 `ttwid`(匿名令牌,不绑定账号)
2.`ttwid` 访问移动端分享页 `m.douyin.com/share/video/{id}`
3. 从页面内嵌 JSON 中提取 `play_addr` 播放地址并下载
### 关键代码(`_download_douyin_manual` 重写)
```python
# 1. 获取新鲜 ttwid
ttwid_resp = await client.post(
"https://ttwid.bytedance.com/ttwid/union/register/",
json={"region": "cn", "aid": 6383, "service": "www.douyin.com", ...}
)
ttwid = ttwid_resp.cookies.get("ttwid", "")
# 2. 访问移动端分享页
page_resp = await client.get(
f"https://m.douyin.com/share/video/{video_id}",
headers={"cookie": f"ttwid={ttwid}", ...}
)
# 3. 提取 play_addr
addr_match = re.search(r'"play_addr":\{"uri":"([^"]+)","url_list":\["([^"]+)"', page_text)
video_url = addr_match.group(2).replace(r"\u002F", "/")
```
### 优势
- 不再依赖手动维护的 `DOUYIN_COOKIE`ttwid 每次请求自动获取
- 不受 yt-dlp 对抖音支持状况影响
- 所有用户通用,不绑定特定账号
---
## 🧹 清理 DOUYIN_COOKIE 配置
`DOUYIN_COOKIE` 仅用于文案提取,新方案不再需要,已从以下位置删除:
| 文件 | 变更 |
|------|------|
| `backend/.env` | 删除 `DOUYIN_COOKIE` 配置项及注释 |
| `backend/app/core/config.py` | 删除 `DOUYIN_COOKIE: str = ""` 字段定义 |
| `backend/app/modules/tools/service.py` | 删除 yt-dlp 传 cookie 逻辑和 `_write_netscape_cookies` 辅助函数 |
---
## 🔤 前端文案修正
将文案提取界面中的"AI 洗稿结果"改为"AI 改写结果"。
| 文件 | 变更 |
|------|------|
| `frontend/src/features/home/ui/ScriptExtractionModal.tsx` | `AI 洗稿结果``AI 改写结果` |
| `backend/app/modules/tools/service.py` | 注释中"洗稿"→"改写" |
| `backend/app/services/glm_service.py` | docstring 中"洗稿"→"改写文案" |
---
## 📦 其他变更
- **yt-dlp 升级**`2025.12.08``2026.2.21`
- **yt-dlp 初始化修正**:改为 `YoutubeDL(ydl_opts)` 直接传参初始化(原先空初始化后 update params 不生效)
- **User-Agent 更新**yt-dlp 中 `Chrome/91``Chrome/131`
---
## 涉及文件汇总
### 后端修改
| 文件 | 变更 |
|------|------|
| `backend/app/modules/tools/service.py` | 重写 `_download_douyin_manual`(移动端分享页方案);修正 yt-dlp 初始化;清理 cookie 相关代码;注释改写 |
| `backend/app/services/glm_service.py` | docstring "洗稿" → "改写文案" |
| `backend/app/core/config.py` | 删除 `DOUYIN_COOKIE` 字段 |
| `backend/.env` | 删除 `DOUYIN_COOKIE` 配置 |
| `backend/requirements.txt` | yt-dlp 版本升级 |
### 前端修改
| 文件 | 变更 |
|------|------|
| `frontend/src/features/home/ui/ScriptExtractionModal.tsx` | "AI 洗稿结果" → "AI 改写结果" |
---
## ✏️ AI 智能改写 — 自定义提示词功能
### 概述
文案提取助手的"AI 智能改写"原先使用硬编码 prompt用户无法定制改写风格。本次在 checkbox 右侧新增"自定义提示词"折叠区域,用户可编辑自定义 prompt持久化到 localStorage后端按需替换默认 prompt。
### 后端修改
**路由层** (`router.py`)`extract_script_tool` 新增可选 Form 参数 `custom_prompt: Optional[str] = Form(None)`,透传给 service。
**服务层** (`service.py`)`extract_script()` 签名新增 `custom_prompt`,透传给 `glm_service.rewrite_script(script, custom_prompt)`
**AI 层** (`glm_service.py`)`rewrite_script(self, text, custom_prompt=None)`,若 `custom_prompt` 有值则用自定义 prompt + 原文拼接,否则保持原有默认 prompt。
```python
if custom_prompt and custom_prompt.strip():
prompt = f"""{custom_prompt.strip()}
原始文案:
{text}"""
else:
prompt = f"""请将以下视频文案进行改写。...(原有默认)"""
```
### 前端修改
**Hook** (`useScriptExtraction.ts`)
- 新增 `customPrompt` / `showCustomPrompt` 状态
- 初始值从 `localStorage.getItem("vigent_rewriteCustomPrompt")` 恢复
- `customPrompt` 变化时防抖 300ms 保存到 localStorage
- `handleExtract()` 中若 `doRewrite && customPrompt.trim()` 有值,追加 `formData.append("custom_prompt", ...)`
- modal 重置时不清空 customPrompt持久化偏好
**UI** (`ScriptExtractionModal.tsx`)
- checkbox 同行右侧新增"自定义提示词 ▼"按钮(仅 `doRewrite` 时显示)
- 点击展开 textarea 编辑区域,底部提示"留空则使用默认提示词"
- 取消勾选 AI 智能改写时,自定义提示词区域自动隐藏
### 涉及文件
| 文件 | 变更 |
|------|------|
| `backend/app/modules/tools/router.py` | 新增 `custom_prompt` Form 参数 |
| `backend/app/modules/tools/service.py` | `extract_script()` 透传 `custom_prompt` |
| `backend/app/services/glm_service.py` | `rewrite_script()` 支持自定义 prompt |
| `frontend/.../useScriptExtraction.ts` | 新增状态、localStorage 持久化、FormData 传参 |
| `frontend/.../ScriptExtractionModal.tsx` | UI 按钮 + 展开 textarea |
### 验证
- 后端 `python -m py_compile` 三个文件通过
- 前端 `npx tsc --noEmit` 通过
---
## 🐛 SSR 构建修复 — localStorage is not defined
### 问题
`npm run build` 报错 `ReferenceError: localStorage is not defined`,因为 `useScriptExtraction.ts``useState` 的初始化函数在 SSRNode.js环境下也会执行而服务端没有 `localStorage`
### 修复
`useState` 初始化加 `typeof window !== "undefined"` 守卫:
```typescript
const [customPrompt, setCustomPrompt] = useState(
() => typeof window !== "undefined" ? localStorage.getItem(CUSTOM_PROMPT_KEY) || "" : ""
);
```
| 文件 | 变更 |
|------|------|
| `frontend/.../useScriptExtraction.ts` | `useState` 初始化增加 SSR 安全守卫 |
---
## 🎬 片头副标题功能
### 概述
新增片头副标题secondary_title显示在主标题下方用于补充说明或悬念引导。副标题有独立的样式配置字体、字号、颜色等可由 AI 同时生成20 字限制,仅在视频画面中显示,不参与发布标题。
命名约定:后端 `secondary_title`snake_case前端 `videoSecondaryTitle`camelCase用户界面"片头副标题"。
---
### 后端修改
| 文件 | 变更 |
|------|------|
| `backend/app/modules/videos/schemas.py` | `GenerateRequest` 新增 4 个可选字段:`secondary_title``secondary_title_style_id``secondary_title_font_size``secondary_title_top_margin` |
| `backend/app/services/glm_service.py` | AI prompt 增加副标题生成要求不超过20字JSON 格式新增 `secondary_title` 字段 |
| `backend/app/modules/ai/router.py` | `GenerateMetaResponse` 增加 `secondary_title: str = ""`endpoint 返回时取 `result.get("secondary_title", "")` |
| `backend/app/modules/videos/workflow.py` | `use_remotion` 条件增加 `or req.secondary_title`;副标题样式解析复用 `get_style("title", ...)`;字号/间距覆盖;`prepare_style_for_remotion` 处理副标题字体;`remotion_service.render()` 传入 `secondary_title` + `secondary_title_style` |
| `backend/app/services/remotion_service.py` | `render()` 新增 `secondary_title``secondary_title_style` 参数,构建 CLI 参数 `--secondaryTitle``--secondaryTitleStyle` |
### Remotion 修改
| 文件 | 变更 |
|------|------|
| `remotion/render.ts` | `RenderOptions` 新增 `secondaryTitle?` + `secondaryTitleStyle?``parseArgs()` 新增 switch case`inputProps` 新增两个字段 |
| `remotion/src/components/Title.tsx` | `TitleProps` 新增 `secondaryTitle?``secondaryTitleStyle?``AbsoluteFill` 改为 `flexDirection: 'column'` 垂直堆叠;主标题 `<h1>` 后增加副标题 `<h2>`,独立样式(默认字号 48px、字重 700共享淡入淡出动画副标题字体使用独立 `@font-face``SecondaryTitleFont`)避免与主标题冲突 |
| `remotion/src/Video.tsx` | `VideoProps` 新增 `secondaryTitle?` + `secondaryTitleStyle?`;传递给 `<Title>` 组件;渲染条件改为 `{(title \|\| secondaryTitle) && ...}` |
| `remotion/src/Root.tsx` | `defaultProps` 新增 `secondaryTitle: undefined` + `secondaryTitleStyle: undefined` |
### 前端修改
| 文件 | 变更 |
|------|------|
| `frontend/src/shared/lib/title.ts` | 新增 `SECONDARY_TITLE_MAX_LENGTH = 20``clampSecondaryTitle()` |
| `frontend/src/features/home/model/useHomeController.ts` | 新增状态 `videoSecondaryTitle``selectedSecondaryTitleStyleId``secondaryTitleFontSize``secondaryTitleTopMargin``secondaryTitleSizeLocked`;新建 `secondaryTitleInput = useTitleInput({ maxLength: 20 })`(不 sync 到发布页);`handleGenerateMeta()` 接收并填充 `secondary_title``handleGenerate()` 构建 payload 增加副标题字段return 暴露所有新状态 |
| `frontend/src/features/home/model/useHomePersistence.ts` | 新增 localStorage key`secondaryTitle``secondaryTitleStyle``secondaryTitleFontSize``secondaryTitleTopMargin`;对应的恢复和保存 effect |
| `frontend/src/features/home/ui/TitleSubtitlePanel.tsx` | Props 新增副标题相关;主标题输入框下方添加"片头副标题限制20个字"输入框;副标题样式选择器(复用 titleStyles 预设、字号滑块30-100px、间距滑块0-100px |
| `frontend/src/features/home/ui/FloatingStylePreview.tsx` | 标题预览改为 flex column 布局;主标题下方增加副标题预览行,独立样式渲染 |
| `frontend/src/features/home/ui/HomePage.tsx` | 从 `useHomeController` 解构新状态,传给 `TitleSubtitlePanel` |
---
## 🐛 参考音频上传 — 中文文件名 InvalidKey 修复
### 问题
上传中文名参考音频(如"我的声音.wav"Supabase Storage 报 `InvalidKey`,因为存储路径直接使用了原始中文文件名。
### 修复
`ref_audios/service.py` 新增 `sanitize_filename()` 函数,将存储路径的文件名清洗为 ASCII 安全字符(仅 `A-Za-z0-9._-`
- NFKD 规范化 → 丢弃非 ASCII → 非法字符替换为 `_`
- 纯中文/emoji 清洗后为空时,使用 MD5 哈希兜底(如 `audio_e924b1193007`
- 文件名限长 50 字符
- 原始中文文件名保留在 metadata 中作为展示名,前端显示不受影响
```
修复前: cbbe.../1771915755_我的声音.wav → InvalidKey
修复后: cbbe.../1771915755_audio_xxxxxxxx.wav → 上传成功
```
| 文件 | 变更 |
|------|------|
| `backend/app/modules/ref_audios/service.py` | 新增 `sanitize_filename()` 函数,上传路径使用清洗后文件名 |

View File

@@ -30,7 +30,7 @@
| ⚡ **Med** | `Docs/BACKEND_README.md` | **(后端文档)** 接口说明、架构设计 |
| ⚡ **Med** | `Docs/FRONTEND_DEV.md` | **(前端规范)** API封装、日期格式化、新页面规范 |
| ⚡ **Med** | `Docs/FRONTEND_README.md` | **(前端文档)** 功能说明、页面变更 |
| 🧊 **Low** | `Docs/*_DEPLOY.md` | **(子系统部署)** LatentSync/Qwen3/字幕等独立部署文档 |
| 🧊 **Low** | `Docs/*_DEPLOY.md` | **(子系统部署)** LatentSync/CosyVoice/字幕等独立部署文档 |
---
@@ -195,7 +195,8 @@ ViGent2/Docs/
├── DEPLOY_MANUAL.md # 部署手册
├── SUPABASE_DEPLOY.md # Supabase 部署文档
├── LATENTSYNC_DEPLOY.md # LatentSync 部署文档
├── QWEN3_TTS_DEPLOY.md # 声音克隆部署文档
├── COSYVOICE3_DEPLOY.md # 声音克隆部署文档
├── ALIPAY_DEPLOY.md # 支付宝付费部署文档
├── SUBTITLE_DEPLOY.md # 字幕系统部署文档
└── DevLogs/
├── Day1.md # 开发日志
@@ -304,4 +305,4 @@ ViGent2/Docs/
---
**最后更新**2026-02-08
**最后更新**2026-02-11

View File

@@ -10,8 +10,9 @@ frontend/src/
│ ├── page.tsx # 首页(视频生成)
│ ├── publish/ # 发布管理页
│ ├── admin/ # 管理员页面
│ ├── login/ # 登录
── register/ # 注册
│ ├── login/ # 登录
── register/ # 注册
│ └── pay/ # 付费开通会员
├── features/ # 功能模块(按业务拆分)
│ ├── home/
│ │ ├── model/ # 业务逻辑 hooks
@@ -19,9 +20,12 @@ frontend/src/
│ │ │ ├── useHomePersistence.ts # 持久化管理
│ │ │ ├── useBgm.ts
│ │ │ ├── useGeneratedVideos.ts
│ │ │ ├── useGeneratedAudios.ts
│ │ │ ├── useMaterials.ts
│ │ │ ├── useMediaPlayers.ts
│ │ │ ├── useRefAudios.ts
│ │ │ ├── useSavedScripts.ts
│ │ │ ├── useTimelineEditor.ts
│ │ │ └── useTitleSubtitleStyles.ts
│ │ └── ui/ # UI 组件(纯 props + 回调)
│ │ ├── HomePage.tsx
@@ -35,6 +39,9 @@ frontend/src/
│ │ ├── FloatingStylePreview.tsx
│ │ ├── VoiceSelector.tsx
│ │ ├── RefAudioPanel.tsx
│ │ ├── GeneratedAudiosPanel.tsx
│ │ ├── TimelineEditor.tsx
│ │ ├── ClipTrimmer.tsx
│ │ ├── BgmPanel.tsx
│ │ ├── GenerateActionBar.tsx
│ │ ├── PreviewPanel.tsx
@@ -250,6 +257,12 @@ import { formatDate } from '@/shared/lib/media';
## ⚡️ 体验优化规范
### 刷新回顶部(统一体验)
- 长页面(如首页/发布页)在首次挂载时统一回到顶部,避免浏览器恢复旧滚动位置导致进入即跳到中部。
- 推荐实现:`useEffect(() => { window.scrollTo({ top: 0, left: 0, behavior: 'auto' }); }, [])`
- 列表内自动定位(素材/历史记录)应跳过恢复后的首次触发,防止刷新后页面二次跳动。
### 路由预取
- 首页进入发布管理时使用 `router.prefetch("/publish")`
@@ -299,8 +312,20 @@ import { formatDate } from '@/shared/lib/media';
- **必须持久化**
- 标题样式 ID / 字幕样式 ID
- 标题字号 / 字幕字号
- 标题显示模式(`short` / `persistent`
- 背景音乐选择 / 音量 / 开关状态
- 输出画面比例(`9:16` / `16:9`
- 素材选择 / 历史作品选择
- 选中配音 ID (`selectedAudioId`)
- 语速 (`speed`,声音克隆模式)
- 时间轴段信息 (`useTimelineEditor` 的 localStorage)
### 历史文案(独立持久化)
`useSavedScripts` hook 独立管理历史文案的 localStorage 持久化:
- key: `vigent_{storageKey}_savedScripts`
- 仅在用户手动保存/删除时写入 localStorage不使用自动持久化 effect
-`useHomePersistence` 完全独立,互不影响
### 实施规范
- 使用 `storageKey = userId || 'guest'`,按用户隔离。
@@ -317,6 +342,7 @@ import { formatDate } from '@/shared/lib/media';
- 片头标题与发布信息标题统一限制 15 字。
- 中文输入法合成阶段不截断,合成结束后才校验长度。
- 首页片头标题修改会同步写入 `vigent_${storageKey}_publish_title`
- 标题显示模式使用 `short` / `persistent` 两个固定值;默认 `short`(短暂显示 4 秒)。
- 避免使用 `maxLength` 强制截断输入法合成态。
- 推荐使用 `@/shared/hooks/useTitleInput` 统一处理输入逻辑。
@@ -346,9 +372,11 @@ import { formatDate } from '@/shared/lib/media';
| 接口 | 方法 | 功能 |
|------|------|------|
| `/api/ref-audios` | POST | 上传参考音频 (multipart/form-data: file + ref_text) |
| `/api/ref-audios` | POST | 上传参考音频 (multipart/form-data: fileref_text 可选,后端自动 Whisper 转写) |
| `/api/ref-audios` | GET | 列出用户的参考音频 |
| `/api/ref-audios/{id}` | PUT | 重命名参考音频 |
| `/api/ref-audios/{id}` | DELETE | 删除参考音频 (id 需 encodeURIComponent) |
| `/api/ref-audios/{id}/retranscribe` | POST | 重新识别参考音频文字Whisper 转写 + 超 10s 自动截取) |
### 视频生成 API 扩展
@@ -367,7 +395,8 @@ await api.post('/api/videos/generate', {
text: '口播文案',
tts_mode: 'voiceclone',
ref_audio_id: 'user_id/timestamp_name.wav',
ref_text: '参考音频对应文字',
ref_text: '参考音频对应文字', // 从参考音频 metadata 自动获取
speed: 1.0, // 语速 (0.8-1.2)
});
```
@@ -381,8 +410,14 @@ const stream = await navigator.mediaDevices.getUserMedia({ audio: true });
const mediaRecorder = new MediaRecorder(stream, { mimeType: 'audio/webm' });
```
### 参考音频自动处理
- **自动转写**: 上传参考音频时后端自动调用 Whisper 转写内容作为 `ref_text`,无需用户手动输入
- **自动截取**: 参考音频超过 10 秒时自动在静音点截取前 10 秒CosyVoice 建议 3-10 秒)
- **重新识别**: 旧参考音频可通过 retranscribe 端点重新转写并截取
### UI 结构
配音方式使用 Tab 切换:
- **EdgeTTS 音色** - 预设音色 2x3 网格
- **声音克隆** - 参考音频列表 + 在线录音 + 参考文字输入
- **声音克隆** - 参考音频列表 + 在线录音 + 语速下拉菜单 (5 档: 较慢/稍慢/正常/稍快/较快)

View File

@@ -17,7 +17,9 @@ ViGent2 的前端界面,采用 Next.js 16 + TailwindCSS 构建。
- **作品预览**: 生成完成后直接播放下载(作品预览 + 历史作品)。
- **预览优化**: 预览视频 `metadata` 预取,首帧加载更快。
- **本地保存**: 文案/标题/偏好由 `useHomePersistence` 统一持久化,刷新后恢复 (Day 14/17)。
- **历史文案**: 手动保存/加载/删除历史文案,独立 localStorage 持久化 (Day 23)。
- **选择持久化**: 首页/发布页作品选择均使用稳定 `id` 持久化,刷新保持用户选择;新视频生成后自动选中最新 (Day 21)。
- **AI 多语言翻译**: 支持 9 种目标语言翻译文案 + 还原原文 (Day 22)。
### 2. 全自动发布 (`/publish`) [Day 7 新增]
- **多平台管理**: 统一管理抖音、微信视频号、B站、小红书账号状态。
@@ -33,30 +35,51 @@ ViGent2 的前端界面,采用 Next.js 16 + TailwindCSS 构建。
### 3. 声音克隆 [Day 13 新增]
- **TTS 模式选择**: EdgeTTS (预设音色) / 声音克隆 (自定义音色) 切换。
- **参考音频管理**: 上传/列表/删除参考音频 (3-20秒 WAV)
- **一键克隆**: 选择参考音频后自动调用 Qwen3-TTS 服务
- **参考音频管理**: 上传/列表/重命名/删除参考音频,上传后自动 Whisper 转写 ref_text + 超 10s 自动截取
- **重新识别**: 参考音频可重新转写并截取 (RotateCw 按钮)
- **一键克隆**: 选择参考音频后自动调用 CosyVoice 3.0 服务。
- **语速控制**: 声音克隆模式下支持 5 档语速 (0.8-1.2),选择持久化 (Day 23)。
- **多语言支持**: EdgeTTS 10 语言声音列表,声音克隆 language 透传 (Day 22)。
### 4. 字幕与标题 [Day 13 新增]
- **片头标题**: 可选输入,限制 15 字,视频开头显示 3 秒淡入淡出标题
### 4. 配音前置 + 时间轴编排 [Day 23 新增]
- **配音独立生成**: 先生成配音 → 选中配音 → 再选素材 → 生成视频
- **配音管理面板**: 生成/试听/改名/删除/选中,异步生成 + 进度轮询。
- **时间轴编辑器**: wavesurfer.js 音频波形 + 色块可视化素材分配,拖拽分割线调整各段时长。
- **素材截取设置**: ClipTrimmer 双手柄 range slider + HTML5 视频预览播放。
- **拖拽排序**: 时间轴色块支持 HTML5 Drag & Drop 调换素材顺序。
- **自定义分配**: 后端 `custom_assignments` 支持用户定义的素材分配方案(含 `source_start/source_end` 截取区间)。
- **时间轴语义对齐**: 超出音频时仅保留可见段并截齐末段,超出段不参与生成;不足音频时最后可见段自动循环补齐。
- **画面比例控制**: 时间轴顶部支持 `9:16 / 16:9` 输出比例选择,设置持久化并透传后端。
### 5. 字幕与标题 [Day 13 新增]
- **片头标题**: 可选输入,限制 15 字;支持”短暂显示 / 常驻显示”默认短暂显示4 秒)。
- **片头副标题**: 可选输入,限制 20 字;显示在主标题下方,用于补充说明或悬念引导;独立样式配置(字体/字号/颜色/间距),可由 AI 同时生成;仅在视频画面中显示,不参与发布标题 (Day 25)。
- **标题同步**: 首页片头标题修改会同步到发布信息标题。
- **逐字高亮字幕**: 卡拉OK效果默认开启可关闭。
- **自动对齐**: 基于 faster-whisper 生成字级别时间戳。
- **样式预设**: 标题/字幕样式选择 + 预览 + 字号调节 (Day 16)。
- **样式预设**: 标题/字幕/副标题样式选择 + 预览 + 字号调节 (Day 16/25)。
- **默认样式**: 标题 90px 站酷快乐体;字幕 60px 经典黄字 + DingTalkJinBuTi (Day 17)。
- **样式持久化**: 标题/字幕样式与字号刷新保留 (Day 17)。
- **样式持久化**: 标题/字幕/副标题样式与字号刷新保留 (Day 17/25)。
### 5. 背景音乐 [Day 16 新增]
### 6. 背景音乐 [Day 16 新增]
- **试听预览**: 点击试听即选中,音量滑块实时生效。
- **混音控制**: 仅影响 BGM配音保持原音量。
### 6. 账户设置 [Day 15 新增]
### 7. 账户设置 [Day 15 新增]
- **手机号登录**: 11位中国手机号验证登录。
- **账户下拉菜单**: 显示有效期 + 修改密码 + 安全退出。
- **修改密码**: 弹窗输入当前密码与新密码,修改后强制重新登录。
### 7. 文案提取助手 (`ScriptExtractionModal`) [Day 15 新增]
### 8. 付费开通会员 (`/pay`)
- **支付宝电脑网站支付**: 跳转支付宝官方收银台,支持扫码/账号登录/余额等多种支付方式。
- **自动激活**: 支付成功后异步回调自动激活会员(有效期 1 年),前端轮询检测支付结果。
- **到期续费**: 会员到期后登录自动跳转付费页续费,流程与首次开通一致。
- **管理员激活**: 管理员手动激活功能并存,两种方式互不影响。
### 8. 文案提取助手 (`ScriptExtractionModal`) [Day 15 新增]
- **多源提取**: 支持文件拖拽上传与 URL 粘贴 (B站/抖音/TikTok)。
- **AI 洗稿**: 集成 GLM-4.7-Flash自动改写为口播文案。
- **AI 智能改写**: 集成 GLM-4.7-Flash自动改写为口播文案。
- **自定义提示词**: 可自定义改写提示词,留空使用默认;设置持久化到 localStorage (Day 25)。
- **一键填入**: 提取结果直接填充至视频生成输入框。
- **智能交互**: 实时进度展示,防误触设计。
@@ -66,6 +89,7 @@ ViGent2 的前端界面,采用 Next.js 16 + TailwindCSS 构建。
- **样式**: TailwindCSS
- **图标**: Lucide React
- **组件**: 自定义现代化组件 (Glassmorphism 风格)
- **音频波形**: wavesurfer.js (时间轴编辑器)
- **API**: Axios 实例 `@/shared/api/axios` (对接后端 FastAPI :8006)
## 🚀 开发指南
@@ -93,6 +117,8 @@ src/
│ ├── page.tsx # 视频生成主页
│ ├── publish/ # 发布管理页
│ │ └── page.tsx
│ ├── pay/ # 付费开通会员页
│ │ └── page.tsx
│ └── layout.tsx # 全局布局 (导航栏)
├── features/
│ ├── home/

View File

@@ -298,12 +298,20 @@ Response: audio/wav 文件
SoX could not be found!
```
**解决**: 通过 conda 安装 sox
**解决**:
1. 通过 conda 安装 sox
```bash
conda install -y -c conda-forge sox
```
2. 确保启动脚本 `run_qwen_tts.sh` 中已 export conda env bin 到 PATHPM2 启动时系统 PATH 不含 conda 环境目录):
```bash
export PATH="/home/rongye/ProgramFiles/miniconda3/envs/qwen-tts/bin:$PATH"
```
### CUDA 内存不足
Qwen3-TTS 1.7B 通常需要 8-10GB VRAM。如果遇到 OOM
@@ -371,6 +379,7 @@ FOR INSERT TO anon WITH CHECK (bucket_id = 'ref-audios');
| 日期 | 版本 | 说明 |
|------|------|------|
| 2026-02-09 | 1.2.0 | 修复 SoX PATH 问题run_qwen_tts.sh export conda bin每次生成后 empty_cache() |
| 2026-01-30 | 1.1.0 | 明确默认模型升级为 1.7B-Base替换旧版 0.6B 路径 |
---

View File

@@ -15,9 +15,13 @@
原有流程:
文本 → EdgeTTS → 音频 → LatentSync → FFmpeg合成 → 最终视频
新流程:
文本 → EdgeTTS → 音频 ─┬→ LatentSync → 唇形视频 ─┐
└→ faster-whisper → 字幕JSON ─┴→ Remotion合成 → 最终视频
新流程 (单素材):
文本 → EdgeTTS/Qwen3-TTS/预生成配音 → 音频 ─┬→ LatentSync → 唇形视频 ─┐
└→ faster-whisper → 字幕JSON ─┴→ Remotion合成 → 最终视频
新流程 (多素材):
音频 → 多素材按 custom_assignments 拼接 → LatentSync (单次推理) → 唇形视频 ─┐
音频 → faster-whisper → 字幕JSON ─────────────────────────────────────────────┴→ Remotion合成 → 最终视频
```
## 系统要求
@@ -140,7 +144,7 @@ remotion/
| 阶段 | 进度 | 说明 |
|------|------|------|
| 下载素材 | 0% → 5% | 从 Supabase 下载输入视频 |
| TTS 语音生成 | 5% → 25% | EdgeTTS Qwen3-TTS 生成音频 |
| TTS 语音生成 | 5% → 25% | EdgeTTS / Qwen3-TTS / 预生成配音下载 |
| 唇形同步 | 25% → 80% | LatentSync 推理 |
| 字幕对齐 | 80% → 85% | faster-whisper 生成字级别时间戳 |
| Remotion 渲染 | 85% → 95% | 合成字幕和标题 |
@@ -181,7 +185,8 @@ Remotion 渲染参数在 `backend/app/services/remotion_service.py` 中配置:
| 参数 | 默认值 | 说明 |
|------|--------|------|
| `fps` | 25 | 输出帧率 |
| `title_duration` | 3.0 | 标题显示时长(秒 |
| `title_display_mode` | `short` | 标题显示模式(`short`=短暂显示;`persistent`=常驻显示 |
| `title_duration` | 4.0 | 标题显示时长(秒,仅 `short` 模式生效) |
---
@@ -282,4 +287,5 @@ WhisperService(device="cuda:0") # 或 "cuda:1"
| 日期 | 版本 | 说明 |
|------|------|------|
| 2026-01-29 | 1.0.0 | 初始版本,使用 faster-whisper + Remotion 实现逐字高亮字幕和片头标题 |
| 2026-02-10 | 1.1.0 | 更新架构图:多素材 concat-then-infer、预生成配音选项 |
| 2026-01-30 | 1.0.1 | 字幕高亮样式与标题动画优化,视觉表现更清晰 |

View File

@@ -1,8 +1,8 @@
# ViGent2 开发任务清单 (Task Log)
**项目**: ViGent2 数字人口播视频生成系统
**进度**: 100% (Day 21 - 缺陷修复与持久化回归治理)
**更新时间**: 2026-02-08
**进度**: 100% (Day 25 - 支付宝付费开通会员)
**更新时间**: 2026-02-24
---
@@ -10,7 +10,80 @@
> 这里记录了每一天的核心开发内容与 milestone。
### Day 21: 缺陷修复 + 浮动预览 + 发布重构 + 架构优化 + 多素材生成 (Current)
### Day 25: 文案提取修复 + 自定义提示词 + 片头副标题 (Current)
- [x] **抖音文案提取修复**: yt-dlp Fresh cookies 报错,重写 `_download_douyin_manual` 为移动端分享页 + 自动获取 ttwid 方案。
- [x] **清理 DOUYIN_COOKIE**: 新方案不再需要手动维护 Cookie`.env`/`config.py`/`service.py` 全面删除。
- [x] **AI 智能改写自定义提示词**: 后端 `rewrite_script()` 支持 `custom_prompt` 参数;前端 checkbox 旁新增折叠式提示词编辑区localStorage 持久化。
- [x] **SSR 构建修复**: `useState` 初始化 `localStorage` 访问加 `typeof window` 守卫,修复 `npm run build` 报错。
- [x] **片头副标题**: 新增 secondary_title后端/Remotion/前端全链路AI 同时生成独立样式配置20 字限制。
- [x] **前端文案修正**: "AI 洗稿结果"→"AI 改写结果"。
- [x] **yt-dlp 升级**: `2025.12.08``2026.2.21`
- [x] **参考音频中文文件名修复**: `sanitize_filename()` 将存储路径清洗为 ASCII 安全字符,纯中文名哈希兜底,原始名保留为展示名。
### Day 24: 鉴权到期治理 + 多素材时间轴稳定性修复
- [x] **会员到期请求时失效**: 登录与鉴权接口统一执行 `expires_at` 检查;到期后自动停用账号、清理 session并返回“会员已到期请续费”。
- [x] **画面比例控制**: 时间轴新增 `9:16 / 16:9` 输出比例选择,前端持久化并透传后端,单素材/多素材统一按目标分辨率处理。
- [x] **标题/字幕防溢出**: Remotion 与前端预览统一响应式缩放、自动换行、描边/字距/边距比例缩放,降低预览与成片差异。
- [x] **标题显示模式**: 标题行新增“短暂显示/常驻显示”下拉默认短暂显示4 秒),用户选择持久化并透传至 Remotion 渲染链路。
- [x] **MOV 方向归一化**: 新增旋转元数据解析与 orientation normalize修复“编码横屏+旋转元数据”导致的竖屏判断偏差。
- [x] **多素材拼接稳定性**: 片段 prepare 与 concat 统一 25fps/CFRconcat 增加 `+genpts`,缓解段切换处“画面冻结口型还动”。
- [x] **时间轴语义对齐**: 打通 `source_end` 全链路;修复 `sourceStart>0 且 sourceEnd=0` 时长计算;生成时以时间轴可见段 assignments 为准,超出段不参与。
- [x] **交互细节优化**: 页面刷新回顶部;素材/历史列表首轮自动滚动抑制,减少恢复状态时页面跳动。
### Day 23: 配音前置重构 + 素材时间轴编排 + UI 体验优化 + 声音克隆增强
#### 第一阶段:配音前置
- [x] **配音生成独立化**: 新增 `generated_audios` 后端模块router/schemas/service5 个 API 端点,复用现有 TTSService / voice_clone_service / task_store。
- [x] **配音管理面板**: 前端新增 `useGeneratedAudios` hook + `GeneratedAudiosPanel` 组件,支持生成/试听/改名/删除/选中。
- [x] **UI 面板重排序**: 文案 → 标题字幕 → 配音方式 → 配音列表 → 素材选择 → BGM → 生成视频。
- [x] **素材区门控**: 未选中配音时素材区显示遮罩,选中后显示配音时长 + 素材均分信息。
- [x] **视频生成对接**: workflow.py 新增预生成音频分支(`generated_audio_id`),跳过内联 TTS向后兼容。
- [x] **持久化**: selectedAudioId 加入 useHomePersistence刷新页面恢复选中配音。
#### 第二阶段:素材时间轴编排
- [x] **时间轴编辑器**: 新增 `TimelineEditor` 组件wavesurfer.js 音频波形 + 色块可视化素材分配,拖拽分割线调整各段时长。
- [x] **素材截取设置**: 新增 `ClipTrimmer` 模态框HTML5 视频预览 + 双端滑块设置源视频截取起点/终点。
- [x] **后端自定义分配**: 新增 `CustomAssignment` 模型,`prepare_segment` 支持 `source_start`workflow 多素材/单素材流水线支持 `custom_assignments`
- [x] **循环截取修复**: `stream_loop + source_start` 改为两步处理(先裁剪再循环),确保从截取起点循环而非从视频 0s 开始。
- [x] **MaterialSelector 精简**: 移除旧的时长信息栏和拖拽排序区(功能迁移到 TimelineEditor
#### 第三阶段UI 体验优化 + TTS 稳定性
- [x] **TTS SoX PATH 修复**: `run_qwen_tts.sh` export conda env bin 到 PATH (Qwen3-TTS 已停用,已被 CosyVoice 3.0 替换)。
- [x] **TTS 显存管理**: 每次生成后 `torch.cuda.empty_cache()`asyncio.to_thread 避免阻塞事件循环 (CosyVoice 沿用相同机制)。
- [x] **配音列表按钮统一**: Play/Edit/Delete 按钮右侧同组 hover 显示,与 RefAudioPanel 一致,移除文案摘要。
- [x] **素材区解除配音门控**: 移除 MaterialSelector 的 selectedAudio 遮罩,素材随时可上传管理。
- [x] **时间轴拖拽排序**: TimelineEditor 色块支持 HTML5 Drag & Drop 调换素材顺序。
- [x] **截取设置 Range Slider**: ClipTrimmer 改为单轨道双手柄(紫色起点+粉色终点),替换两个独立滑块。
- [x] **截取设置视频预览**: 视频区域可播放/暂停,从 sourceStart 到 sourceEnd 自动停止,拖拽手柄时实时 seek。
#### 第四阶段:历史文案 + Bug 修复
- [x] **历史文案保存与加载**: 新增 `useSavedScripts` hook手动保存/加载/删除历史文案,独立 localStorage 持久化。
- [x] **时间轴拖拽修复**: `reorderSegments` 从属性交换改为数组移动splice修复拖拽后时长不跟随素材的 Bug。
- [x] **按钮视觉统一**: 文案编辑区 4 个按钮统一为固定高度 `h-7`,移除多余 `<span>` 嵌套。
- [x] **底部栏调整**: "保存文案"按钮移至底部右侧,移除预计时长显示。
#### 第五阶段:字幕语言不匹配 + 视频比例错位修复
- [x] **字幕用原文替换 Whisper 转录**: `align()` 新增 `original_text` 参数,字幕文字永远用配音保存的原始文案。
- [x] **Remotion 动态视频尺寸**: `calculateMetadata` 从 props 读取真实尺寸,修复标题/字幕比例错位。
- [x] **英文空格丢失修复**: `split_word_to_chars` 遇到空格时 flush buffer + pending_space 标记。
#### 第六阶段:参考音频自动转写 + 语速控制
- [x] **Whisper 自动转写 ref_text**: 上传参考音频时自动调用 Whisper 转写内容作为 ref_text不再使用前端固定文字。
- [x] **参考音频自动截取**: 超过 10 秒自动在静音点截取ffmpeg silencedetect末尾 0.1 秒淡出避免截断爆音。
- [x] **重新识别功能**: 新增 `POST /ref-audios/{id}/retranscribe` 端点 + 前端 RotateCw 按钮,旧音频可重新转写并截取。
- [x] **语速控制**: 全链路 speed 参数(前端选择器 → 持久化 → 后端 → CosyVoice `inference_zero_shot(speed=)`5 档:较慢(0.8)/稍慢(0.9)/正常(1.0)/稍快(1.1)/较快(1.2)。
- [x] **缺少参考音频门控**: 声音克隆模式下未选参考音频时,生成配音按钮禁用 + 黄色警告提示。
- [x] **Whisper 语言自动检测**: `transcribe()` language 参数改为可选(默认 None = 自动检测),支持多语言参考音频。
- [x] **前端清理**: 移除固定 ref_text 常量、朗读引导文字,简化为"上传任意语音样本,系统将自动识别内容并克隆声音"。
### Day 22: 多素材优化 + AI 翻译 + TTS 多语言
- [x] **多素材 Bug 修复**: 6 个高优 Bug边界溢出、单段 fallback、除零、duration 校验、Whisper 兜底、空列表检查)。
- [x] **架构重构**: 多素材从"逐段 LatentSync"重构为"先拼接再推理",推理次数 N→1。
- [x] **前端优化**: payload 安全、进度消息、上传自动选中、Material 接口统一、拖拽修复、素材上限 4 个。
- [x] **AI 多语言翻译**: 新增 `/api/ai/translate` 接口,前端 9 种语言翻译 + 还原原文。
- [x] **TTS 多语言**: EdgeTTS 10 语言声音列表、翻译自动切换声音、声音克隆 language 透传、textLang 持久化。
### Day 21: 缺陷修复 + 浮动预览 + 发布重构 + 架构优化 + 多素材生成
- [x] **Remotion 崩溃容错**: 渲染进程 SIGABRT 退出时检查输出文件,避免误判失败导致标题/字幕丢失。
- [x] **首页作品选择持久化**: 修复 `fetchGeneratedVideos` 无条件覆盖恢复值的问题,新增 `preferVideoId` 参数控制选中逻辑。
- [x] **发布页作品选择持久化**: 根因为签名 URL 不稳定,全面改用 `video.id` 替代 `path` 进行选择/持久化/比较。
@@ -78,7 +151,7 @@
- [x] **体验细节优化**: 录音预览 URL 回收,预览弹窗滚动恢复,全局任务提示挂载。
### Day 16: 深度性能优化
- [x] **Qwen-TTS 加速**: 集成 Flash Attention 2,模型加载速度提升至 8.9s
- [x] **Qwen-TTS 加速**: 集成 Flash Attention 2 (已停用,被 CosyVoice 3.0 替换)
- [x] **服务守护**: 开发 `Watchdog` 看门狗机制,自动监控并重启僵死服务。
- [x] **LatentSync 性能确认**: 验证 DeepCache + 原生 Flash Attn 生效。
- [x] **文档重构**: 全面更新 README、部署手册及后端文档。
@@ -91,10 +164,10 @@
### Day 14: AI 增强与体验优化
- [x] **AI 标题/标签**: 集成 GLM-4API 自动生成视频元数据。
- [x] **字幕升级**: Remotion 逐字高亮字幕 (卡拉OK效果) 及动画片头。
- [x] **模型升级**: Qwen3-TTS 升级至 1.7B-Base 版本
- [x] **模型升级**: 声音克隆已迁移至 CosyVoice 3.0 (0.5B)
### Day 13: 声音克隆集成
- [x] **声音克隆微服务**: 封装 Qwen3-TTS 为独立 API (8009端口)。
- [x] **声音克隆微服务**: 封装 CosyVoice 3.0 为独立 API (8010端口替换 Qwen3-TTS)。
- [x] **参考音频管理**: Supabase 存储桶配置与管理接口。
- [x] **多模态 TTS**: 前端支持 EdgeTTS / Clone Voice 切换。
@@ -129,6 +202,7 @@
## 🛤️ 后续规划 (Roadmap)
### 🔴 优先待办
- [x] ~~**配音前置重构 — 第二阶段**: 素材片段截取 + 语音时间轴编排~~ ✅ Day 23 已完成
- [ ] **批量生成架构**: 支持 Excel 导入,批量生产视频。
- [ ] **定时任务后台化**: 迁移前端触发的定时发布到后端 APScheduler。
- [ ] **发布任务恢复机制**: 发布任务化 + 状态持久化 + 前端断点恢复,解决刷新后状态丢失。
@@ -146,9 +220,10 @@
| **核心 API** | 100% | ✅ 稳定 |
| **Web UI** | 100% | ✅ 稳定 (移动端适配) |
| **唇形同步** | 100% | ✅ LatentSync 1.6 |
| **TTS 配音** | 100% | ✅ EdgeTTS + Qwen3 |
| **TTS 配音** | 100% | ✅ EdgeTTS + CosyVoice 3.0 + 配音前置 + 时间轴编排 + 自动转写 + 语速控制 |
| **自动发布** | 100% | ✅ 抖音/微信视频号/B站/小红书 |
| **用户认证** | 100% | ✅ 手机号 + JWT |
| **付费会员** | 100% | ✅ 支付宝电脑网站支付 + 自动激活 |
| **部署运维** | 100% | ✅ PM2 + Watchdog |
---

View File

@@ -5,7 +5,7 @@
> 📹 **上传人物** · 🎙️ **输入文案** · 🎬 **一键成片**
基于 **LatentSync 1.6 + EdgeTTS** 的开源数字人口播视频生成系统。
集成 **Qwen3-TTS** 声音克隆与自动社交媒体发布功能。
集成 **CosyVoice 3.0** 声音克隆与自动社交媒体发布功能。
[功能特性](#-功能特性) • [技术栈](#-技术栈) • [文档中心](#-文档中心) • [部署指南](Docs/DEPLOY_MANUAL.md)
@@ -17,19 +17,24 @@
### 核心能力
- 🎬 **高清唇形同步** - LatentSync 1.6 驱动512×512 高分辨率 Latent Diffusion 模型。
- 🎙️ **多模态配音** - 支持 **EdgeTTS** (微软超自然语音) 和 **Qwen3-TTS** (3秒极速声音克隆)
- 🎙️ **多模态配音** - 支持 **EdgeTTS** (微软超自然语音, 10 语言) 和 **CosyVoice 3.0** (3秒极速声音克隆, 9语言+18方言, 语速可调)。上传参考音频自动 Whisper 转写 + 智能截取。配音前置工作流:先生成配音 → 选素材 → 生成视频
- 📝 **智能字幕** - 集成 faster-whisper + Remotion自动生成逐字高亮 (卡拉OK效果) 字幕。
- 🎨 **样式预设** - 标题/字幕样式选择 + 预览 + 字号调节,支持自定义字体库。
- 🖼 **作品预览一致性** - 标题/字幕预览按素材分辨率缩放,效果更接近成片
- 💾 **用户偏好持久化** - 首页状态统一恢复/保存,刷新后延续上次配置
- 🎨 **样式预设** - 标题/副标题/字幕样式选择 + 预览 + 字号调节,支持自定义字体库。
- 🏷 **标题显示模式** - 片头标题支持 `短暂显示` / `常驻显示`默认短暂显示4秒用户偏好自动持久化
- 📌 **片头副标题** - 可选副标题显示在主标题下方独立样式配置AI 可同时生成20 字限制
- 🖼️ **作品预览一致性** - 标题/字幕预览与 Remotion 成片统一响应式缩放和自动换行,窄屏画布也稳定显示。
- 🎞️ **多素材多机位** - 支持多选素材 + 时间轴编辑器 (wavesurfer.js 波形可视化),拖拽分割线调整时长、拖拽排序切换机位、按 `source_start/source_end` 截取片段。
- 📐 **画面比例控制** - 时间轴一键切换 `9:16 / 16:9` 输出比例,生成链路全程按目标比例处理。
- 💾 **用户偏好持久化** - 首页状态统一恢复/保存,刷新后延续上次配置。历史文案手动保存与加载。
- 🎵 **背景音乐** - 试听 + 音量控制 + 混音,保持配音音量稳定。
- 🤖 **AI 辅助创作** - 内置 GLM-4.7-Flash支持 B站/抖音链接文案提取、AI 洗稿、标题/标签自动生成。
- 🤖 **AI 辅助创作** - 内置 GLM-4.7-Flash支持 B站/抖音链接文案提取、AI 智能改写(支持自定义提示词)、标题/标签自动生成、9 语言翻译
### 平台化功能
- 📱 **全自动发布** - 支持抖音/微信视频号/B站/小红书立即发布;扫码登录 + Cookie 持久化。
- 🖥️ **发布管理预览** - 支持签名 URL / 相对路径作品预览,确保可直接播放。
- 📸 **发布结果可视化** - 抖音/微信视频号发布成功后返回截图,发布页结果卡片可直接查看。
- 🛡️ **发布防误操作** - 发布进行中自动提示“请勿刷新或关闭网页”,并拦截刷新/关页二次确认。
- 💳 **付费会员** - 支付宝电脑网站支付自动开通会员,到期自动停用并引导续费,管理员手动激活并存。
- 🔐 **认证与隔离** - 基于 Supabase 的用户隔离,支持手机号注册/登录、密码管理。
- 🛡️ **服务守护** - 内置 Watchdog 看门狗机制,自动监控并重启僵死服务,确保 7x24h 稳定运行。
- 🚀 **性能优化** - 视频预压缩、模型常驻服务(近实时加载)、双 GPU 流水线并发。
@@ -40,11 +45,11 @@
| 领域 | 核心技术 | 说明 |
|------|----------|------|
| **前端** | Next.js 16 | TypeScript, TailwindCSS, SWR |
| **前端** | Next.js 16 | TypeScript, TailwindCSS, SWR, wavesurfer.js |
| **后端** | FastAPI | Python 3.10, AsyncIO, PM2 |
| **数据库** | Supabase | PostgreSQL, Storage (本地/S3), Auth |
| **唇形同步** | LatentSync 1.6 | PyTorch 2.5, Diffusers, DeepCache |
| **声音克隆** | Qwen3-TTS | 1.7B 参数量Flash Attention 2 加速 |
| **声音克隆** | CosyVoice 3.0 | 0.5B 参数量9 语言 + 18 方言 |
| **自动化** | Playwright | 社交媒体无头浏览器自动化 |
| **部署** | Docker & PM2 | 混合部署架构 |
@@ -56,9 +61,10 @@
### 部署运维
- **[部署手册 (DEPLOY_MANUAL.md)](Docs/DEPLOY_MANUAL.md)** - 👈 **部署请看这里**!包含完整的环境搭建步骤。
- [参考音频服务部署 (QWEN3_TTS_DEPLOY.md)](Docs/QWEN3_TTS_DEPLOY.md) - 声音克隆模型部署指南。
- [参考音频服务部署 (COSYVOICE3_DEPLOY.md)](Docs/COSYVOICE3_DEPLOY.md) - 声音克隆模型部署指南。
- [LatentSync 部署指南](models/LatentSync/DEPLOY.md) - 唇形同步模型独立部署。
- [Supabase 部署指南 (SUPABASE_DEPLOY.md)](Docs/SUPABASE_DEPLOY.md) - Supabase 与认证系统配置。
- [支付宝部署指南 (ALIPAY_DEPLOY.md)](Docs/ALIPAY_DEPLOY.md) - 支付宝付费开通会员配置。
### 开发文档
- [后端开发指南](Docs/BACKEND_README.md) - 接口规范与开发流程。
@@ -81,7 +87,7 @@ ViGent2/
├── remotion/ # Remotion 视频渲染 (标题/字幕合成)
├── models/ # AI 模型仓库
│ ├── LatentSync/ # 唇形同步服务
│ └── Qwen3-TTS/ # 声音克隆服务
│ └── CosyVoice/ # 声音克隆服务
└── Docs/ # 项目文档
```
@@ -96,7 +102,7 @@ ViGent2/
| **Web UI** | 3002 | 用户访问入口 (Next.js) |
| **Backend API** | 8006 | 核心业务接口 (FastAPI) |
| **LatentSync** | 8007 | 唇形同步推理服务 |
| **Qwen3-TTS** | 8009 | 声音克隆推理服务 |
| **CosyVoice 3.0** | 8010 | 声音克隆推理服务 |
| **Supabase** | 8008 | 数据库与认证网关 |
---

View File

@@ -71,5 +71,10 @@ GLM_MODEL=glm-4.7-flash
SUPABASE_STORAGE_LOCAL_PATH=/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub
# =============== 抖音视频下载 Cookie ===============
# 用于从抖音 URL 提取视频文案功能,会过期需要定期更新
DOUYIN_COOKIE=douyin.com; device_web_cpu_core=10; device_web_memory_size=8; __ac_nonce=06760391f00b9b51264ae; __ac_signature=_02B4Z6wo00f019a5ceAAAIDAhEZR-X3jjWfWmXVAAJLXd4; ttwid=1%7C7MTKBSMsP4eOv9h5NAh8p0E-NYIud09ftNmB0mjLpWc%7C1734359327%7C8794abeabbd47447e1f56e5abc726be089f2a0344d6343b5f75f23e7b0f0028f; UIFID_TEMP=0de8750d2b188f4235dbfd208e44abbb976428f0720eb983255afefa45d39c0c6532e1d4768dd8587bf919f866ff1396912bcb2af71efee56a14a2a9f37b74010d0a0413795262f6d4afe02a032ac7ab; s_v_web_id=verify_m4r4ribr_c7krmY1z_WoeI_43po_ATpO_I4o8U1bex2D7; hevc_supported=true; home_can_add_dy_2_desktop=%220%22; dy_swidth=2560; dy_sheight=1440; stream_recommend_feed_params=%22%7B%5C%22cookie_enabled%5C%22%3Atrue%2C%5C%22screen_width%5C%22%3A2560%2C%5C%22screen_height%5C%22%3A1440%2C%5C%22browser_online%5C%22%3Atrue%2C%5C%22cpu_core_num%5C%22%3A10%2C%5C%22device_memory%5C%22%3A8%2C%5C%22downlink%5C%22%3A10%2C%5C%22effective_type%5C%22%3A%5C%224g%5C%22%2C%5C%22round_trip_time%5C%22%3A50%7D%22; strategyABtestKey=%221734359328.577%22; csrf_session_id=2f53aed9aa6974e83aa9a1014180c3a4; fpk1=U2FsdGVkX1/IpBh0qdmlKAVhGyYHgur4/VtL9AReZoeSxadXn4juKvsakahRGqjxOPytHWspYoBogyhS/V6QSw==; fpk2=0845b309c7b9b957afd9ecf775a4c21f; passport_csrf_token=d80e0c5b2fa2328219856be5ba7e671e; passport_csrf_token_default=d80e0c5b2fa2328219856be5ba7e671e; odin_tt=3c891091d2eb0f4718c1d5645bc4a0017032d4d5aa989decb729e9da2ad570918cbe5e9133dc6b145fa8c758de98efe32ff1f81aa0d611e838cc73ab08ef7d3f6adf66ab4d10e8372ddd628f94f16b8e; volume_info=%7B%22isUserMute%22%3Afalse%2C%22isMute%22%3Afalse%2C%22volume%22%3A0.5%7D; bd_ticket_guard_client_web_domain=2; FORCE_LOGIN=%7B%22videoConsumedRemainSeconds%22%3A180%7D; UIFID=0de8750d2b188f4235dbfd208e44abbb976428f0720eb983255afefa45d39c0c6532e1d4768dd8587bf919f866ff139655a3c2b735923234f371c699560c657923fd3d6c5b63ab7bb9b83423b6cb4787e2ce66a7fbc4ecb24c8570f520fe6de068bbb95115023c0c6c1b6ee31b49fb7e3996fb8349f43a3fd8b7a61cd9e18e8fe65eb6a7c13de4c0960d84e344b644725db3eb2fa6b7caf821de1b50527979f2; is_dash_user=1; biz_trace_id=b57a241f; bd_ticket_guard_client_data=eyJiZC10aWNrZXQtZ3VhcmQtdmVyc2lvbiI6MiwiYmQtdGlja2V0LWd1YXJkLWl0ZXJhdGlvbi12ZXJzaW9uIjoxLCJiZC10aWNrZXQtZ3VhcmQtcmVlLXB1YmxpYy1rZXkiOiJCTEo2R0lDalVoWW1XcHpGOFdrN0Vrc0dXcCtaUzNKY1g4NGNGY2k0TTl1TEowNjdUb21mbFU5aDdvWVBGamhNRWNRQWtKdnN1MnM3RmpTWnlJQXpHMjA9IiwiYmQtdGlja2V0LWd1YXJkLXdlYi12ZXJzaW9uIjoyfQ%3D%3D; download_guide=%221%2F20241216%2F0%22; sdk_source_info=7e276470716a68645a606960273f276364697660272927676c715a6d6069756077273f276364697660272927666d776a68605a607d71606b766c6a6b5a7666776c7571273f275e58272927666a6b766a69605a696c6061273f27636469766027292762696a6764695a7364776c6467696076273f275e5827292771273f273d33323131333c3036313632342778; bit_env=RiOY4jzzpxZoVCl6zdVSVhVRjdwHRTxqcqWdqMBZLPGjMdB4Tax1kAELHNTVAAh72KuhumewE4Lq6f0-VJ2UpJrkrhSxoPw9LUb3zQrq1OSwbeSPHkRlRgRQvO89sItdGUyq1oFr0XyRCnMYG87KSeWyc4x0czGR0o50hTDoDLG5rJVoRcdQOLvjiAegsqyytKF59sPX_QM9qffK2SqYsg0hCggURc_AI6kguDDE5DvG0bnyz1utw4z1eEnIoLrkGDqzqBZj4dOAr0BVU6ofbsS-pOQ2u2PM1dLP9FlBVBlVaqYVgHJeSLsR5k76BRTddUjTb4zEilVIEwAMJWGN4I1BxVt6fC9B5tBQpuT0lj3n3eKXCKXZsd8FrEs5_pbfDsxV-e_WMiXI2ff4qxiTC0U73sfo9OpicKICtZjdq8qsHxJuu6wVR36zvXeL2Wch5C6MzprNvkivv0l8nbh2mSgy1nabZr3dmU6NcR-Bg3Q3xTWUlR9aAUmpopC-cNuXjgLpT-Lw1AYGilSUnCvosth1Gfypq-b0MpgmdSDgTrQ%3D; gulu_source_res=eyJwX2luIjoiMDhjOGQ3ZTJiODQyNjZkZWI5Y2VkMGJiODNlNmY1ZWY0ZjMyNTE2ZmYyZjAzNDMzZjI0OWU1Y2Q1NTczNTk5NyJ9; passport_auth_mix_state=hp9bc3dgb1tm5wd8p82zawus27g0e3ue; IsDouyinActive=false
# =============== 支付宝配置 ===============
ALIPAY_APP_ID=2021006132600283
ALIPAY_PRIVATE_KEY_PATH=/home/rongye/ProgramFiles/ViGent2/backend/keys/app_private_key.pem
ALIPAY_PUBLIC_KEY_PATH=/home/rongye/ProgramFiles/ViGent2/backend/keys/alipay_public_key.pem
ALIPAY_NOTIFY_URL=https://vigent.hbyrkj.top/api/payment/notify
ALIPAY_RETURN_URL=https://vigent.hbyrkj.top/pay

View File

@@ -76,12 +76,18 @@ class Settings(BaseSettings):
GLM_API_KEY: str = ""
GLM_MODEL: str = "glm-4.7-flash"
# 支付宝配置
ALIPAY_APP_ID: str = ""
ALIPAY_PRIVATE_KEY_PATH: str = "" # 应用私钥 PEM 文件路径
ALIPAY_PUBLIC_KEY_PATH: str = "" # 支付宝公钥 PEM 文件路径
ALIPAY_NOTIFY_URL: str = "" # 异步通知回调地址(公网可达)
ALIPAY_RETURN_URL: str = "" # 支付成功后同步跳转地址
ALIPAY_SANDBOX: bool = False # 是否使用沙箱环境
PAYMENT_AMOUNT: float = 999.00 # 会员价格(元)
PAYMENT_EXPIRE_DAYS: int = 365 # 会员有效天数
# CORS 配置 (逗号分隔的域名列表,* 表示允许所有)
CORS_ORIGINS: str = "*"
# 抖音 Cookie (用于视频下载功能,会过期需要定期更新)
DOUYIN_COOKIE: str = ""
@property
def LATENTSYNC_DIR(self) -> Path:
"""LatentSync 目录路径 (动态计算)"""

View File

@@ -1,11 +1,11 @@
"""
依赖注入模块:认证和用户获取
"""
from typing import Optional, Any, Dict, cast
from typing import Optional, Any, Dict, cast
from fastapi import Request, HTTPException, Depends, status
from app.core.security import decode_access_token, TokenData
from app.repositories.sessions import get_session
from app.repositories.users import get_user_by_id
from app.core.security import decode_access_token
from app.repositories.sessions import get_session, delete_sessions
from app.repositories.users import get_user_by_id, deactivate_user_if_expired
from loguru import logger
@@ -14,9 +14,9 @@ async def get_token_from_cookie(request: Request) -> Optional[str]:
return request.cookies.get("access_token")
async def get_current_user_optional(
request: Request
) -> Optional[Dict[str, Any]]:
async def get_current_user_optional(
request: Request
) -> Optional[Dict[str, Any]]:
"""
获取当前用户 (可选,未登录返回 None)
"""
@@ -29,22 +29,30 @@ async def get_current_user_optional(
return None
# 验证 session_token 是否有效 (单设备登录检查)
try:
session = get_session(token_data.user_id, token_data.session_token)
if not session:
logger.warning(f"Session token 无效: user_id={token_data.user_id}")
return None
user = get_user_by_id(token_data.user_id)
return cast(Optional[Dict[str, Any]], user)
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
try:
session = get_session(token_data.user_id, token_data.session_token)
if not session:
logger.warning(f"Session token 无效: user_id={token_data.user_id}")
return None
user = cast(Optional[Dict[str, Any]], get_user_by_id(token_data.user_id))
if user and deactivate_user_if_expired(user):
delete_sessions(token_data.user_id)
return None
if user and not user.get("is_active"):
delete_sessions(token_data.user_id)
return None
return user
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
async def get_current_user(
request: Request
) -> Dict[str, Any]:
async def get_current_user(
request: Request
) -> Dict[str, Any]:
"""
获取当前用户 (必须登录)
@@ -66,40 +74,45 @@ async def get_current_user(
detail="Token 无效或已过期"
)
try:
session = get_session(token_data.user_id, token_data.session_token)
if not session:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="会话已失效,请重新登录(可能已在其他设备登录)"
)
user = get_user_by_id(token_data.user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
user = cast(Dict[str, Any], user)
if user.get("expires_at"):
from datetime import datetime, timezone
expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
if datetime.now(timezone.utc) > expires_at:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="授权已过期,请联系管理员续期"
)
return user
except HTTPException:
raise
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="服务器错误"
)
try:
session = get_session(token_data.user_id, token_data.session_token)
if not session:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="会话已失效,请重新登录(可能已在其他设备登录)"
)
user = get_user_by_id(token_data.user_id)
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
user = cast(Dict[str, Any], user)
if deactivate_user_if_expired(user):
delete_sessions(token_data.user_id)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="会员已到期,请续费"
)
if not user.get("is_active"):
delete_sessions(token_data.user_id)
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账号已停用"
)
return user
except HTTPException:
raise
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="服务器错误"
)
async def get_current_admin(

View File

@@ -110,3 +110,28 @@ def set_auth_cookie(response: Response, token: str) -> None:
def clear_auth_cookie(response: Response) -> None:
"""清除认证 Cookie"""
response.delete_cookie(key="access_token")
def create_payment_token(user_id: str) -> str:
"""生成付费专用短期 JWT token30 分钟有效)"""
payload = {
"sub": user_id,
"purpose": "payment",
"exp": datetime.now(timezone.utc) + timedelta(minutes=30),
}
return jwt.encode(payload, settings.JWT_SECRET_KEY, algorithm=settings.JWT_ALGORITHM)
def decode_payment_token(token: str) -> str | None:
"""解析 payment_token返回 user_id仅 purpose=payment 有效)"""
try:
data = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM],
)
if data.get("purpose") != "payment":
return None
return data.get("sub")
except JWTError:
return None

View File

@@ -15,6 +15,8 @@ from app.modules.ref_audios.router import router as ref_audios_router
from app.modules.ai.router import router as ai_router
from app.modules.tools.router import router as tools_router
from app.modules.assets.router import router as assets_router
from app.modules.generated_audios.router import router as generated_audios_router
from app.modules.payment.router import router as payment_router
from loguru import logger
import os
@@ -124,6 +126,8 @@ app.include_router(ref_audios_router, prefix="/api/ref-audios", tags=["RefAudios
app.include_router(ai_router) # /api/ai
app.include_router(tools_router, prefix="/api/tools", tags=["Tools"])
app.include_router(assets_router, prefix="/api/assets", tags=["Assets"])
app.include_router(generated_audios_router, prefix="/api/generated-audios", tags=["GeneratedAudios"])
app.include_router(payment_router) # /api/payment
@app.on_event("startup")

View File

@@ -21,6 +21,7 @@ class GenerateMetaRequest(BaseModel):
class GenerateMetaResponse(BaseModel):
"""生成标题标签响应"""
title: str
secondary_title: str = ""
tags: list[str]
@@ -66,6 +67,7 @@ async def generate_meta(req: GenerateMetaRequest):
result = await glm_service.generate_title_tags(req.text)
return success_response(GenerateMetaResponse(
title=result.get("title", ""),
secondary_title=result.get("secondary_title", ""),
tags=result.get("tags", [])
).model_dump())
except Exception as e:

View File

@@ -1,22 +1,32 @@
"""
认证 API注册、登录、登出、修改密码
"""
from fastapi import APIRouter, HTTPException, Response, status, Request
from fastapi import APIRouter, HTTPException, Response, status, Request, Depends
from fastapi.responses import JSONResponse
from pydantic import BaseModel, field_validator
from app.core.security import (
get_password_hash,
verify_password,
create_access_token,
generate_session_token,
set_auth_cookie,
clear_auth_cookie,
decode_access_token
)
from app.repositories.sessions import create_session, delete_sessions
from app.repositories.users import create_user, get_user_by_id, get_user_by_phone, user_exists_by_phone, update_user
from app.core.response import success_response
from app.core.security import (
get_password_hash,
verify_password,
create_access_token,
generate_session_token,
set_auth_cookie,
clear_auth_cookie,
decode_access_token,
create_payment_token,
)
from app.repositories.sessions import create_session, delete_sessions
from app.repositories.users import (
create_user,
get_user_by_id,
get_user_by_phone,
user_exists_by_phone,
update_user,
deactivate_user_if_expired,
)
from app.core.deps import get_current_user
from app.core.response import success_response
from loguru import logger
from typing import Optional, Any, cast
from typing import Optional, Any, cast
import re
router = APIRouter(prefix="/api/auth", tags=["认证"])
@@ -76,26 +86,26 @@ async def register(request: RegisterRequest):
注册后状态为 pending需要管理员激活
"""
try:
if user_exists_by_phone(request.phone):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该手机号已注册"
)
if user_exists_by_phone(request.phone):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="该手机号已注册"
)
# 创建用户
password_hash = get_password_hash(request.password)
create_user({
"phone": request.phone,
"password_hash": password_hash,
"username": request.username or f"用户{request.phone[-4:]}",
"role": "pending",
"is_active": False
})
create_user({
"phone": request.phone,
"password_hash": password_hash,
"username": request.username or f"用户{request.phone[-4:]}",
"role": "pending",
"is_active": False
})
logger.info(f"新用户注册: {request.phone}")
return success_response(message="注册成功,请等待管理员审核激活")
return success_response(message="注册成功,请等待管理员审核激活")
except HTTPException:
raise
except Exception as e:
@@ -116,12 +126,12 @@ async def login(request: LoginRequest, response: Response):
- 实现"后踢前"单设备登录
"""
try:
user = cast(dict[str, Any], get_user_by_phone(request.phone) or {})
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="手机号或密码错误"
)
user = cast(dict[str, Any], get_user_by_phone(request.phone) or {})
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="手机号或密码错误"
)
# 验证密码
if not verify_password(request.password, user["password_hash"]):
@@ -130,29 +140,33 @@ async def login(request: LoginRequest, response: Response):
detail="手机号或密码错误"
)
# 检查是否激活
if not user["is_active"]:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="账号未激活,请等待管理员审核"
# 过期自动停用(注意:只更新 DB不修改内存中的 user 字典)
expired = deactivate_user_if_expired(user)
if expired:
delete_sessions(user["id"])
# 过期 或 未激活(新注册)→ 返回付费指引
if expired or not user["is_active"]:
payment_token = create_payment_token(user["id"])
return JSONResponse(
status_code=403,
content={
"success": False,
"message": "请付费开通会员",
"code": 403,
"data": {
"reason": "PAYMENT_REQUIRED",
"payment_token": payment_token,
}
}
)
# 检查授权是否过期
if user.get("expires_at"):
from datetime import datetime, timezone
expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
if datetime.now(timezone.utc) > expires_at:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="授权已过期,请联系管理员续期"
)
# 生成新的 session_token (后踢前)
session_token = generate_session_token()
# 删除旧 session插入新 session
delete_sessions(user["id"])
create_session(user["id"], session_token, None)
delete_sessions(user["id"])
create_session(user["id"], session_token, None)
# 生成 JWT Token
token = create_access_token(user["id"], session_token)
@@ -162,19 +176,19 @@ async def login(request: LoginRequest, response: Response):
logger.info(f"用户登录: {request.phone}")
return success_response(
data={
"user": UserResponse(
id=user["id"],
phone=user["phone"],
username=user.get("username"),
role=user["role"],
is_active=user["is_active"],
expires_at=user.get("expires_at")
).model_dump()
},
message="登录成功",
)
return success_response(
data={
"user": UserResponse(
id=user["id"],
phone=user["phone"],
username=user.get("username"),
role=user["role"],
is_active=user["is_active"],
expires_at=user.get("expires_at")
).model_dump()
},
message="登录成功",
)
except HTTPException:
raise
except Exception as e:
@@ -186,10 +200,10 @@ async def login(request: LoginRequest, response: Response):
@router.post("/logout")
async def logout(response: Response):
"""用户登出"""
clear_auth_cookie(response)
return success_response(message="已登出")
async def logout(response: Response):
"""用户登出"""
clear_auth_cookie(response)
return success_response(message="已登出")
@router.post("/change-password")
@@ -217,12 +231,12 @@ async def change_password(request: ChangePasswordRequest, req: Request, response
)
try:
user = cast(dict[str, Any], get_user_by_id(token_data.user_id) or {})
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
user = cast(dict[str, Any], get_user_by_id(token_data.user_id) or {})
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
# 验证当前密码
if not verify_password(request.old_password, user["password_hash"]):
@@ -233,13 +247,13 @@ async def change_password(request: ChangePasswordRequest, req: Request, response
# 更新密码
new_password_hash = get_password_hash(request.new_password)
update_user(user["id"], {"password_hash": new_password_hash})
update_user(user["id"], {"password_hash": new_password_hash})
# 生成新的 session token使旧 token 失效
new_session_token = generate_session_token()
delete_sessions(user["id"])
create_session(user["id"], new_session_token, None)
delete_sessions(user["id"])
create_session(user["id"], new_session_token, None)
# 生成新的 JWT Token
new_token = create_access_token(user["id"], new_session_token)
@@ -247,7 +261,7 @@ async def change_password(request: ChangePasswordRequest, req: Request, response
logger.info(f"用户修改密码: {user['phone']}")
return success_response(message="密码修改成功")
return success_response(message="密码修改成功")
except HTTPException:
raise
except Exception as e:
@@ -259,35 +273,13 @@ async def change_password(request: ChangePasswordRequest, req: Request, response
@router.get("/me")
async def get_me(request: Request):
async def get_me(user: dict = Depends(get_current_user)):
"""获取当前用户信息"""
# 从 Cookie 获取用户
token = request.cookies.get("access_token")
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未登录"
)
token_data = decode_access_token(token)
if not token_data:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 无效"
)
user = cast(dict[str, Any], get_user_by_id(token_data.user_id) or {})
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
return success_response(UserResponse(
id=user["id"],
phone=user["phone"],
username=user.get("username"),
role=user["role"],
is_active=user["is_active"],
expires_at=user.get("expires_at")
).model_dump())
return success_response(UserResponse(
id=user["id"],
phone=user["phone"],
username=user.get("username"),
role=user["role"],
is_active=user["is_active"],
expires_at=user.get("expires_at")
).model_dump())

View File

@@ -0,0 +1,77 @@
"""生成配音 API"""
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException
import uuid
from loguru import logger
from app.core.deps import get_current_user
from app.core.response import success_response
from app.modules.videos.task_store import create_task, get_task
from app.modules.generated_audios.schemas import GenerateAudioRequest, RenameAudioRequest
from app.modules.generated_audios import service
router = APIRouter()
@router.post("/generate")
async def generate_audio(
req: GenerateAudioRequest,
background_tasks: BackgroundTasks,
user: dict = Depends(get_current_user),
):
"""异步生成配音(返回 task_id"""
task_id = str(uuid.uuid4())
create_task(task_id, user["id"])
background_tasks.add_task(service.generate_audio_task, task_id, req, user["id"])
return success_response({"task_id": task_id})
@router.get("/tasks/{task_id}")
async def get_audio_task(task_id: str, user: dict = Depends(get_current_user)):
"""轮询配音生成进度"""
task = get_task(task_id)
if task.get("status") != "not_found" and task.get("user_id") != user["id"]:
return success_response({"status": "not_found"})
return success_response(task)
@router.get("")
async def list_audios(user: dict = Depends(get_current_user)):
"""列出当前用户所有已生成配音"""
try:
result = await service.list_generated_audios(user["id"])
return success_response(result)
except Exception as e:
logger.error(f"列出配音失败: {e}")
raise HTTPException(status_code=500, detail=f"获取列表失败: {str(e)}")
@router.delete("/{audio_id:path}")
async def delete_audio(audio_id: str, user: dict = Depends(get_current_user)):
"""删除配音"""
try:
await service.delete_generated_audio(audio_id, user["id"])
return success_response(message="删除成功")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except Exception as e:
logger.error(f"删除配音失败: {e}")
raise HTTPException(status_code=500, detail=f"删除失败: {str(e)}")
@router.put("/{audio_id:path}")
async def rename_audio(
audio_id: str,
request: RenameAudioRequest,
user: dict = Depends(get_current_user),
):
"""重命名配音"""
try:
result = await service.rename_generated_audio(audio_id, request.new_name, user["id"])
return success_response(result, message="重命名成功")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"重命名配音失败: {e}")
raise HTTPException(status_code=500, detail=f"重命名失败: {str(e)}")

View File

@@ -0,0 +1,31 @@
from pydantic import BaseModel
from typing import Optional, List
class GenerateAudioRequest(BaseModel):
text: str
tts_mode: str = "edgetts"
voice: str = "zh-CN-YunxiNeural"
ref_audio_id: Optional[str] = None
ref_text: Optional[str] = None
language: str = "zh-CN"
speed: float = 1.0
class RenameAudioRequest(BaseModel):
new_name: str
class GeneratedAudioItem(BaseModel):
id: str
name: str
path: str
duration_sec: float
text: str
tts_mode: str
language: str
created_at: int
class GeneratedAudioListResponse(BaseModel):
items: List[GeneratedAudioItem]

View File

@@ -0,0 +1,264 @@
"""生成配音 - 业务逻辑"""
import re
import json
import time
import asyncio
import subprocess
import tempfile
import os
from pathlib import Path
from typing import Optional
import httpx
from loguru import logger
from app.services.storage import storage_service
from app.services.tts_service import TTSService
from app.services.voice_clone_service import voice_clone_service
from app.modules.videos.task_store import task_store
from app.modules.generated_audios.schemas import (
GenerateAudioRequest,
GeneratedAudioItem,
GeneratedAudioListResponse,
)
BUCKET = "generated-audios"
def _locale_to_tts_lang(locale: str) -> str:
mapping = {"zh": "Chinese", "en": "English"}
return mapping.get(locale.split("-")[0], "Auto")
def _get_audio_duration(file_path: str) -> float:
try:
result = subprocess.run(
['ffprobe', '-v', 'quiet', '-show_entries', 'format=duration',
'-of', 'csv=p=0', file_path],
capture_output=True, text=True, timeout=10
)
return float(result.stdout.strip())
except Exception as e:
logger.warning(f"获取音频时长失败: {e}")
return 0.0
async def generate_audio_task(task_id: str, req: GenerateAudioRequest, user_id: str):
"""后台任务:生成配音"""
try:
task_store.update(task_id, {"status": "processing", "progress": 10, "message": "正在生成配音..."})
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
audio_path = tmp.name
try:
if req.tts_mode == "voiceclone":
if not req.ref_audio_id or not req.ref_text:
raise ValueError("声音克隆模式需要提供参考音频和参考文字")
task_store.update(task_id, {"progress": 20, "message": "正在下载参考音频..."})
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp_ref:
ref_local = tmp_ref.name
try:
ref_url = await storage_service.get_signed_url(
bucket="ref-audios", path=req.ref_audio_id
)
timeout = httpx.Timeout(None)
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", ref_url) as resp:
resp.raise_for_status()
with open(ref_local, "wb") as f:
async for chunk in resp.aiter_bytes():
f.write(chunk)
task_store.update(task_id, {"progress": 40, "message": "正在克隆声音..."})
await voice_clone_service.generate_audio(
text=req.text,
ref_audio_path=ref_local,
ref_text=req.ref_text,
output_path=audio_path,
language=_locale_to_tts_lang(req.language),
speed=req.speed,
)
finally:
if os.path.exists(ref_local):
os.unlink(ref_local)
else:
task_store.update(task_id, {"progress": 30, "message": "正在生成语音..."})
tts = TTSService()
await tts.generate_audio(req.text, req.voice, audio_path)
task_store.update(task_id, {"progress": 70, "message": "正在上传配音..."})
duration = _get_audio_duration(audio_path)
timestamp = int(time.time())
audio_id = f"{user_id}/{timestamp}_audio.wav"
meta_id = f"{user_id}/{timestamp}_audio.json"
# 生成 display_name
now = time.strftime("%Y%m%d_%H%M", time.localtime(timestamp))
display_name = f"配音_{now}"
with open(audio_path, "rb") as f:
wav_data = f.read()
await storage_service.upload_file(
bucket=BUCKET, path=audio_id,
file_data=wav_data, content_type="audio/wav",
)
metadata = {
"display_name": display_name,
"text": req.text,
"tts_mode": req.tts_mode,
"voice": req.voice if req.tts_mode == "edgetts" else None,
"ref_audio_id": req.ref_audio_id,
"language": req.language,
"duration_sec": duration,
"created_at": timestamp,
}
await storage_service.upload_file(
bucket=BUCKET, path=meta_id,
file_data=json.dumps(metadata, ensure_ascii=False).encode("utf-8"),
content_type="application/json",
)
signed_url = await storage_service.get_signed_url(BUCKET, audio_id)
task_store.update(task_id, {
"status": "completed",
"progress": 100,
"message": f"配音生成完成 ({duration:.1f}s)",
"output": {
"audio_id": audio_id,
"name": display_name,
"path": signed_url,
"duration_sec": duration,
"text": req.text,
"tts_mode": req.tts_mode,
"language": req.language,
"created_at": timestamp,
},
})
finally:
if os.path.exists(audio_path):
os.unlink(audio_path)
except Exception as e:
import traceback
task_store.update(task_id, {
"status": "failed",
"message": f"配音生成失败: {str(e)}",
"error": traceback.format_exc(),
})
logger.error(f"Generate audio failed: {e}")
async def list_generated_audios(user_id: str) -> dict:
"""列出用户的所有已生成配音"""
files = await storage_service.list_files(BUCKET, user_id)
wav_files = [f for f in files if f.get("name", "").endswith("_audio.wav")]
if not wav_files:
return GeneratedAudioListResponse(items=[]).model_dump()
async def fetch_info(f):
name = f.get("name", "")
storage_path = f"{user_id}/{name}"
meta_name = name.replace("_audio.wav", "_audio.json")
meta_path = f"{user_id}/{meta_name}"
display_name = name
text = ""
tts_mode = "edgetts"
language = "zh-CN"
duration_sec = 0.0
created_at = 0
try:
meta_url = await storage_service.get_signed_url(BUCKET, meta_path)
async with httpx.AsyncClient(timeout=5.0) as client:
resp = await client.get(meta_url)
if resp.status_code == 200:
meta = resp.json()
display_name = meta.get("display_name", name)
text = meta.get("text", "")
tts_mode = meta.get("tts_mode", "edgetts")
language = meta.get("language", "zh-CN")
duration_sec = meta.get("duration_sec", 0.0)
created_at = meta.get("created_at", 0)
except Exception as e:
logger.debug(f"读取配音 metadata 失败: {e}")
try:
created_at = int(name.split("_")[0])
except:
pass
signed_url = await storage_service.get_signed_url(BUCKET, storage_path)
return GeneratedAudioItem(
id=storage_path,
name=display_name,
path=signed_url,
duration_sec=duration_sec,
text=text,
tts_mode=tts_mode,
language=language,
created_at=created_at,
)
items = await asyncio.gather(*[fetch_info(f) for f in wav_files])
items = sorted(items, key=lambda x: x.created_at, reverse=True)
return GeneratedAudioListResponse(items=items).model_dump()
async def delete_generated_audio(audio_id: str, user_id: str) -> None:
if not audio_id.startswith(f"{user_id}/"):
raise PermissionError("无权删除此文件")
await storage_service.delete_file(BUCKET, audio_id)
meta_path = audio_id.replace("_audio.wav", "_audio.json")
try:
await storage_service.delete_file(BUCKET, meta_path)
except:
pass
async def rename_generated_audio(audio_id: str, new_name: str, user_id: str) -> dict:
if not audio_id.startswith(f"{user_id}/"):
raise PermissionError("无权修改此文件")
new_name = new_name.strip()
if not new_name:
raise ValueError("新名称不能为空")
meta_path = audio_id.replace("_audio.wav", "_audio.json")
try:
meta_url = await storage_service.get_signed_url(BUCKET, meta_path)
async with httpx.AsyncClient() as client:
resp = await client.get(meta_url)
if resp.status_code == 200:
metadata = resp.json()
else:
raise Exception(f"Failed to fetch metadata: {resp.status_code}")
except Exception as e:
logger.warning(f"无法读取配音元数据: {e}, 将创建新的")
metadata = {
"display_name": new_name,
"text": "",
"tts_mode": "edgetts",
"language": "zh-CN",
"duration_sec": 0.0,
"created_at": int(time.time()),
}
metadata["display_name"] = new_name
await storage_service.upload_file(
bucket=BUCKET,
path=meta_path,
file_data=json.dumps(metadata, ensure_ascii=False).encode("utf-8"),
content_type="application/json",
)
return {"name": new_name}

View File

View File

@@ -0,0 +1,52 @@
"""
支付 API创建订单、异步通知、状态查询
遵循 BACKEND_DEV.md 规范router 只做参数校验、调用 service、返回统一响应
"""
from fastapi import APIRouter, HTTPException, Request, status
from fastapi.responses import PlainTextResponse
from app.core.response import success_response
from .schemas import CreateOrderRequest, CreateOrderResponse, OrderStatusResponse
from . import service
router = APIRouter(prefix="/api/payment", tags=["支付"])
@router.post("/create-order")
async def create_payment_order(request: CreateOrderRequest):
"""创建支付宝电脑网站支付订单,返回收银台 URL"""
try:
result = service.create_payment_order(request.payment_token)
except ValueError as e:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=str(e))
except RuntimeError as e:
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e))
return success_response(
CreateOrderResponse(**result).model_dump()
)
@router.post("/notify")
async def payment_notify(request: Request):
"""
支付宝异步通知回调
必须返回纯文本 "success"(不是 JSON否则支付宝会重复推送。
"""
form_data = await request.form()
verified = service.handle_payment_notify(dict(form_data))
return PlainTextResponse("success" if verified else "fail")
@router.get("/status/{out_trade_no}")
async def check_payment_status(out_trade_no: str):
"""查询订单支付状态(前端轮询)"""
order_status = service.get_order_status(out_trade_no)
if order_status is None:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="订单不存在")
return success_response(
OrderStatusResponse(status=order_status).model_dump()
)

View File

@@ -0,0 +1,15 @@
from pydantic import BaseModel
class CreateOrderRequest(BaseModel):
payment_token: str
class CreateOrderResponse(BaseModel):
pay_url: str
out_trade_no: str
amount: float
class OrderStatusResponse(BaseModel):
status: str

View File

@@ -0,0 +1,137 @@
"""
支付业务服务
职责Alipay SDK 封装、创建订单、处理支付通知、查询状态
遵循 BACKEND_DEV.md "薄路由 + 厚服务" 原则
"""
from datetime import datetime, timezone, timedelta
import uuid
from alipay import AliPay
from loguru import logger
from app.core.config import settings
from app.core.security import decode_payment_token
from app.repositories.orders import create_order, get_order_by_trade_no, update_order_status
from app.repositories.users import update_user
# 支付宝网关地址
ALIPAY_GATEWAY = "https://openapi.alipay.com/gateway.do"
ALIPAY_GATEWAY_SANDBOX = "https://openapi-sandbox.dl.alipaydev.com/gateway.do"
def _get_alipay_client() -> AliPay:
"""延迟初始化 Alipay 客户端"""
return AliPay(
appid=settings.ALIPAY_APP_ID,
app_notify_url=settings.ALIPAY_NOTIFY_URL,
app_private_key_string=open(settings.ALIPAY_PRIVATE_KEY_PATH).read(),
alipay_public_key_string=open(settings.ALIPAY_PUBLIC_KEY_PATH).read(),
sign_type="RSA2",
debug=settings.ALIPAY_SANDBOX,
)
def _create_page_pay_url(out_trade_no: str, amount: float, subject: str) -> str | None:
"""调用 alipay.trade.page.pay返回支付宝收银台 URL"""
client = _get_alipay_client()
order_string = client.api_alipay_trade_page_pay(
subject=subject,
out_trade_no=out_trade_no,
total_amount=amount,
return_url=settings.ALIPAY_RETURN_URL,
)
if not order_string:
logger.error(f"电脑网站支付下单失败: {out_trade_no}")
return None
gateway = ALIPAY_GATEWAY_SANDBOX if settings.ALIPAY_SANDBOX else ALIPAY_GATEWAY
pay_url = f"{gateway}?{order_string}"
logger.info(f"电脑网站支付下单成功: {out_trade_no}")
return pay_url
def _verify_signature(data: dict, signature: str) -> bool:
"""验证支付宝异步通知签名"""
client = _get_alipay_client()
return client.verify(data, signature)
def create_payment_order(payment_token: str) -> dict:
"""
创建支付订单完整流程
Returns: {"pay_url": str, "out_trade_no": str, "amount": float}
Raises: ValueError (token 无效), RuntimeError (API 失败)
"""
user_id = decode_payment_token(payment_token)
if not user_id:
raise ValueError("付费凭证无效或已过期,请重新登录")
out_trade_no = f"VG_{int(datetime.now().timestamp())}_{uuid.uuid4().hex[:8]}"
amount = settings.PAYMENT_AMOUNT
create_order(user_id, out_trade_no, amount)
pay_url = _create_page_pay_url(out_trade_no, amount, "IPAgent 会员开通")
if not pay_url:
raise RuntimeError("创建支付订单失败,请稍后重试")
logger.info(f"用户 {user_id} 创建支付订单: {out_trade_no}")
return {"pay_url": pay_url, "out_trade_no": out_trade_no, "amount": amount}
def handle_payment_notify(form_data: dict) -> bool:
"""
处理支付宝异步通知完整流程
Returns: True=验签通过, False=验签失败
"""
data = dict(form_data)
signature = data.pop("sign", "")
data.pop("sign_type", None)
if not _verify_signature(data, signature):
logger.warning(f"支付宝通知验签失败: {data.get('out_trade_no')}")
return False
out_trade_no = data.get("out_trade_no", "")
trade_status = data.get("trade_status", "")
trade_no = data.get("trade_no", "")
logger.info(f"收到支付宝通知: {out_trade_no}, status={trade_status}, trade_no={trade_no}")
if trade_status not in ("TRADE_SUCCESS", "TRADE_FINISHED"):
return True
order = get_order_by_trade_no(out_trade_no)
if not order:
logger.warning(f"订单不存在: {out_trade_no}")
return True
if order["status"] == "paid":
logger.info(f"订单已处理过: {out_trade_no}")
return True
update_order_status(out_trade_no, "paid", trade_no)
user_id = order["user_id"]
expires_at = (datetime.now(timezone.utc) + timedelta(days=settings.PAYMENT_EXPIRE_DAYS)).isoformat()
update_user(user_id, {
"is_active": True,
"role": "user",
"expires_at": expires_at,
})
logger.success(f"用户 {user_id} 支付成功,已激活,有效期至 {expires_at}")
return True
def get_order_status(out_trade_no: str) -> str | None:
"""查询订单支付状态"""
order = get_order_by_trade_no(out_trade_no)
if not order:
return None
return order["status"]

View File

@@ -13,7 +13,7 @@ router = APIRouter()
@router.post("")
async def upload_ref_audio(
file: UploadFile = File(...),
ref_text: str = Form(...),
ref_text: str = Form(""),
user: dict = Depends(get_current_user)
):
"""上传参考音频"""
@@ -68,3 +68,21 @@ async def rename_ref_audio(
except Exception as e:
logger.error(f"重命名失败: {e}")
raise HTTPException(status_code=500, detail=f"重命名失败: {str(e)}")
@router.post("/{audio_id:path}/retranscribe")
async def retranscribe_ref_audio(
audio_id: str,
user: dict = Depends(get_current_user)
):
"""重新识别参考音频的文字内容"""
try:
result = await service.retranscribe_ref_audio(audio_id, user["id"])
return success_response(result, message="识别完成")
except PermissionError as e:
raise HTTPException(status_code=403, detail=str(e))
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"重新识别失败: {e}")
raise HTTPException(status_code=500, detail=f"识别失败: {str(e)}")

View File

@@ -2,9 +2,11 @@ import re
import os
import time
import json
import hashlib
import asyncio
import subprocess
import tempfile
import unicodedata
from pathlib import Path
from typing import Optional
@@ -19,8 +21,16 @@ BUCKET_REF_AUDIOS = "ref-audios"
def sanitize_filename(filename: str) -> str:
"""清理文件名,移除特殊字符"""
safe_name = re.sub(r'[<>:"/\\|?*\s]', '_', filename)
"""清理文件名用于 Storage key仅保留 ASCII 安全字符)。"""
normalized = unicodedata.normalize("NFKD", filename)
ascii_name = normalized.encode("ascii", "ignore").decode("ascii")
safe_name = re.sub(r"[^A-Za-z0-9._-]+", "_", ascii_name).strip("._-")
# 纯中文/emoji 等场景会被清空,使用稳定哈希兜底,避免 InvalidKey
if not safe_name:
digest = hashlib.md5(filename.encode("utf-8")).hexdigest()[:12]
safe_name = f"audio_{digest}"
if len(safe_name) > 50:
ext = Path(safe_name).suffix
safe_name = safe_name[:50 - len(ext)] + ext
@@ -41,16 +51,40 @@ def _get_audio_duration(file_path: str) -> float:
return 0.0
def _convert_to_wav(input_path: str, output_path: str) -> bool:
"""将音频转换为 WAV 格式 (16kHz, mono)"""
def _find_silence_cut_point(file_path: str, max_duration: float) -> float:
"""在 max_duration 附近找一个静音点作为截取位置,找不到则回退到 max_duration"""
try:
subprocess.run([
'ffmpeg', '-y', '-i', input_path,
'-ar', '16000',
'-ac', '1',
'-acodec', 'pcm_s16le',
output_path
], capture_output=True, timeout=60, check=True)
# 用 silencedetect 找所有静音段(阈值 -30dB最短 0.3 秒)
result = subprocess.run(
['ffmpeg', '-i', file_path, '-af',
'silencedetect=noise=-30dB:d=0.3', '-f', 'null', '-'],
capture_output=True, text=True, timeout=30
)
# 解析 silence_end 时间点
import re as _re
ends = [float(m) for m in _re.findall(r'silence_end:\s*([\d.]+)', result.stderr)]
# 找 max_duration 之前最后一个静音结束点(至少 3 秒)
candidates = [t for t in ends if 3.0 <= t <= max_duration]
if candidates:
cut = candidates[-1]
logger.info(f"Found silence cut point at {cut:.1f}s (max={max_duration}s)")
return cut
except Exception as e:
logger.warning(f"Silence detection failed: {e}")
return max_duration
def _convert_to_wav(input_path: str, output_path: str, max_duration: float = 0) -> bool:
"""将音频转换为 WAV 格式 (16kHz, mono),可选截取前 max_duration 秒并淡出"""
try:
cmd = ['ffmpeg', '-y', '-i', input_path]
if max_duration > 0:
cmd += ['-t', str(max_duration)]
# 末尾 0.1 秒淡出,避免截断爆音
fade_start = max(0, max_duration - 0.1)
cmd += ['-af', f'afade=t=out:st={fade_start}:d=0.1']
cmd += ['-ar', '16000', '-ac', '1', '-acodec', 'pcm_s16le', output_path]
subprocess.run(cmd, capture_output=True, timeout=60, check=True)
return True
except Exception as e:
logger.error(f"音频转换失败: {e}")
@@ -67,9 +101,6 @@ async def upload_ref_audio(file, ref_text: str, user_id: str) -> dict:
if ext not in ALLOWED_AUDIO_EXTENSIONS:
raise ValueError(f"不支持的音频格式: {ext}。支持的格式: {', '.join(ALLOWED_AUDIO_EXTENSIONS)}")
if not ref_text or len(ref_text.strip()) < 2:
raise ValueError("参考文字不能为空")
# 创建临时文件
with tempfile.NamedTemporaryFile(delete=False, suffix=ext) as tmp_input:
content = await file.read()
@@ -86,8 +117,31 @@ async def upload_ref_audio(file, ref_text: str, user_id: str) -> dict:
duration = _get_audio_duration(tmp_wav_path)
if duration < 1.0:
raise ValueError("音频时长过短,至少需要 1 秒")
if duration > 60.0:
raise ValueError("音频时长过长,最多 60 秒")
# 超过 10 秒自动在静音点截取CosyVoice 对 3-10 秒效果最好)
MAX_REF_DURATION = 10.0
if duration > MAX_REF_DURATION:
cut_point = _find_silence_cut_point(tmp_wav_path, MAX_REF_DURATION)
logger.info(f"Ref audio {duration:.1f}s > {MAX_REF_DURATION}s, trimming at {cut_point:.1f}s")
trimmed_path = tmp_input_path + "_trimmed.wav"
if not _convert_to_wav(tmp_wav_path, trimmed_path, max_duration=cut_point):
raise RuntimeError("音频截取失败")
os.unlink(tmp_wav_path)
tmp_wav_path = trimmed_path
duration = _get_audio_duration(tmp_wav_path)
# 自动转写参考音频内容
try:
from app.services.whisper_service import whisper_service
transcribed = await whisper_service.transcribe(tmp_wav_path)
if transcribed.strip():
ref_text = transcribed.strip()
logger.info(f"Auto-transcribed ref audio: {ref_text[:50]}...")
except Exception as e:
logger.warning(f"Auto-transcribe failed: {e}")
if not ref_text or not ref_text.strip():
raise ValueError("无法识别音频内容,请确保音频包含清晰的语音")
# 检查重名
existing_files = await storage_service.list_files(BUCKET_REF_AUDIOS, user_id)
@@ -267,3 +321,85 @@ async def rename_ref_audio(audio_id: str, new_name: str, user_id: str) -> dict:
)
return {"name": new_name}
async def retranscribe_ref_audio(audio_id: str, user_id: str) -> dict:
"""重新转写参考音频的 ref_text并截取前 10 秒重新上传(用于迁移旧数据)"""
if not audio_id.startswith(f"{user_id}/"):
raise PermissionError("无权修改此文件")
# 下载音频到临时文件
audio_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, audio_id)
tmp_wav_path = None
trimmed_path = None
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as tmp:
tmp_wav_path = tmp.name
timeout = httpx.Timeout(None)
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", audio_url) as resp:
resp.raise_for_status()
async for chunk in resp.aiter_bytes():
tmp.write(chunk)
# 超过 10 秒则截取前 10 秒并重新上传音频
MAX_REF_DURATION = 10.0
duration = _get_audio_duration(tmp_wav_path)
transcribe_path = tmp_wav_path
need_reupload = False
if duration > MAX_REF_DURATION:
cut_point = _find_silence_cut_point(tmp_wav_path, MAX_REF_DURATION)
logger.info(f"Retranscribe: trimming {audio_id} from {duration:.1f}s at {cut_point:.1f}s")
trimmed_path = tmp_wav_path + "_trimmed.wav"
if _convert_to_wav(tmp_wav_path, trimmed_path, max_duration=cut_point):
transcribe_path = trimmed_path
duration = _get_audio_duration(trimmed_path)
need_reupload = True
# Whisper 转写
from app.services.whisper_service import whisper_service
transcribed = await whisper_service.transcribe(transcribe_path)
if not transcribed or not transcribed.strip():
raise ValueError("无法识别音频内容")
ref_text = transcribed.strip()
logger.info(f"Re-transcribed ref audio {audio_id}: {ref_text[:50]}...")
# 截取过的音频重新上传覆盖原文件
if need_reupload and trimmed_path:
with open(trimmed_path, "rb") as f:
await storage_service.upload_file(
bucket=BUCKET_REF_AUDIOS, path=audio_id,
file_data=f.read(), content_type="audio/wav",
)
logger.info(f"Re-uploaded trimmed audio: {audio_id} ({duration:.1f}s)")
# 更新 metadata
metadata_path = audio_id.replace(".wav", ".json")
try:
meta_url = await storage_service.get_signed_url(BUCKET_REF_AUDIOS, metadata_path)
async with httpx.AsyncClient(timeout=5.0) as client:
resp = await client.get(meta_url)
if resp.status_code == 200:
metadata = resp.json()
else:
raise Exception(f"status {resp.status_code}")
except Exception:
metadata = {}
metadata["ref_text"] = ref_text
metadata["duration_sec"] = duration
await storage_service.upload_file(
bucket=BUCKET_REF_AUDIOS,
path=metadata_path,
file_data=json.dumps(metadata, ensure_ascii=False).encode('utf-8'),
content_type="application/json"
)
return {"ref_text": ref_text, "duration_sec": duration}
finally:
if tmp_wav_path and os.path.exists(tmp_wav_path):
os.unlink(tmp_wav_path)
if trimmed_path and os.path.exists(trimmed_path):
os.unlink(trimmed_path)

View File

@@ -13,11 +13,12 @@ router = APIRouter()
async def extract_script_tool(
file: Optional[UploadFile] = File(None),
url: Optional[str] = Form(None),
rewrite: bool = Form(True)
rewrite: bool = Form(True),
custom_prompt: Optional[str] = Form(None)
):
"""独立文案提取工具"""
try:
result = await service.extract_script(file=file, url=url, rewrite=rewrite)
result = await service.extract_script(file=file, url=url, rewrite=rewrite, custom_prompt=custom_prompt)
return success_response(result)
except ValueError as e:
raise HTTPException(400, str(e))

View File

@@ -17,9 +17,9 @@ from app.services.whisper_service import whisper_service
from app.services.glm_service import glm_service
async def extract_script(file=None, url: Optional[str] = None, rewrite: bool = True) -> dict:
async def extract_script(file=None, url: Optional[str] = None, rewrite: bool = True, custom_prompt: Optional[str] = None) -> dict:
"""
文案提取:上传文件或视频链接 -> Whisper 转写 -> (可选) GLM 洗稿
文案提取:上传文件或视频链接 -> Whisper 转写 -> (可选) GLM 改写
"""
if not file and not url:
raise ValueError("必须提供文件或视频链接")
@@ -63,11 +63,11 @@ async def extract_script(file=None, url: Optional[str] = None, rewrite: bool = T
# 2. 提取文案 (Whisper)
script = await whisper_service.transcribe(str(audio_path))
# 3. AI 洗稿 (GLM)
# 3. AI 改写 (GLM)
rewritten = None
if rewrite and script and len(script.strip()) > 0:
logger.info("Rewriting script...")
rewritten = await glm_service.rewrite_script(script)
rewritten = await glm_service.rewrite_script(script, custom_prompt)
return {
"original_script": script,
@@ -156,125 +156,120 @@ def _download_yt_dlp(url_value: str, temp_dir: Path, timestamp: int) -> Path:
'quiet': True,
'no_warnings': True,
'http_headers': {
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36',
'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
'Referer': 'https://www.douyin.com/',
}
}
with yt_dlp.YoutubeDL() as ydl_raw:
ydl: Any = ydl_raw
ydl.params.update(ydl_opts)
with yt_dlp.YoutubeDL(ydl_opts) as ydl:
info = ydl.extract_info(url_value, download=True)
if 'requested_downloads' in info:
downloaded_file = info['requested_downloads'][0]['filepath']
else:
ext = info.get('ext', 'mp4')
id = info.get('id')
downloaded_file = str(temp_dir / f"tool_download_{timestamp}_{id}.{ext}")
vid_id = info.get('id')
downloaded_file = str(temp_dir / f"tool_download_{timestamp}_{vid_id}.{ext}")
return Path(downloaded_file)
async def _download_douyin_manual(url: str, temp_dir: Path, timestamp: int) -> Optional[Path]:
"""手动下载抖音视频 (Fallback)"""
logger.info(f"[SuperIPAgent] Starting download for: {url}")
"""手动下载抖音视频 (Fallback) — 通过移动端分享页获取播放地址"""
logger.info(f"[douyin-fallback] Starting download for: {url}")
try:
# 1. 解析短链接,提取视频 ID
headers = {
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36"
"user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 16_0 like Mac OS X) AppleWebKit/605.1.15"
}
async with httpx.AsyncClient(follow_redirects=True, timeout=10.0) as client:
resp = await client.get(url, headers=headers)
final_url = str(resp.url)
logger.info(f"[SuperIPAgent] Final URL: {final_url}")
logger.info(f"[douyin-fallback] Final URL: {final_url}")
modal_id = None
video_id = None
match = re.search(r'/video/(\d+)', final_url)
if match:
modal_id = match.group(1)
video_id = match.group(1)
if not modal_id:
logger.error("[SuperIPAgent] Could not extract modal_id")
if not video_id:
logger.error("[douyin-fallback] Could not extract video_id")
return None
logger.info(f"[SuperIPAgent] Extracted modal_id: {modal_id}")
logger.info(f"[douyin-fallback] Extracted video_id: {video_id}")
target_url = f"https://www.douyin.com/user/MS4wLjABAAAAN_s_hups7LD0N4qnrM3o2gI0vuG3pozNaEolz2_py3cHTTrpVr1Z4dukFD9SOlwY?from_tab_name=main&modal_id={modal_id}"
# 2. 获取新鲜 ttwid
ttwid = ""
try:
async with httpx.AsyncClient(timeout=10.0) as client:
ttwid_resp = await client.post(
"https://ttwid.bytedance.com/ttwid/union/register/",
json={
"region": "cn", "aid": 6383, "needFid": False,
"service": "www.douyin.com",
"migrate_info": {"ticket": "", "source": "node"},
"cbUrlProtocol": "https", "union": True,
}
)
ttwid = ttwid_resp.cookies.get("ttwid", "")
logger.info(f"[douyin-fallback] Got fresh ttwid (len={len(ttwid)})")
except Exception as e:
logger.warning(f"[douyin-fallback] Failed to get ttwid: {e}")
from app.core.config import settings
if not settings.DOUYIN_COOKIE:
logger.warning("[SuperIPAgent] DOUYIN_COOKIE 未配置,视频下载可能失败")
headers_with_cookie = {
"accept": "text/html,application/xhtml+xml,application/xml;q=0.9,image/avif,image/webp,image/apng,*/*;q=0.8,application/signed-exchange;v=b3;q=0.7",
"cookie": settings.DOUYIN_COOKIE,
"user-agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36",
# 3. 访问移动端分享页提取播放地址
page_headers = {
"user-agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 16_0 like Mac OS X) AppleWebKit/605.1.15",
"cookie": f"ttwid={ttwid}" if ttwid else "",
}
logger.info(f"[SuperIPAgent] Requesting page with Cookie...")
async with httpx.AsyncClient(follow_redirects=True, timeout=15.0) as client:
page_resp = await client.get(
f"https://m.douyin.com/share/video/{video_id}",
headers=page_headers,
)
async with httpx.AsyncClient(timeout=10.0) as client:
response = await client.get(target_url, headers=headers_with_cookie)
page_text = page_resp.text
logger.info(f"[douyin-fallback] Mobile page length: {len(page_text)}")
content_match = re.findall(r'<script id="RENDER_DATA" type="application/json">(.*?)</script>', response.text)
if not content_match:
if "SSR_HYDRATED_DATA" in response.text:
content_match = re.findall(r'<script id="SSR_HYDRATED_DATA" type="application/json">(.*?)</script>', response.text)
if not content_match:
logger.error(f"[SuperIPAgent] Could not find RENDER_DATA in page (len={len(response.text)})")
return None
content = unquote(content_match[0])
try:
data = json.loads(content)
except:
logger.error("[SuperIPAgent] JSON decode failed")
return None
video_url = None
try:
if "app" in data and "videoDetail" in data["app"]:
info = data["app"]["videoDetail"]["video"]
if "bitRateList" in info and info["bitRateList"]:
video_url = info["bitRateList"][0]["playAddr"][0]["src"]
elif "playAddr" in info and info["playAddr"]:
video_url = info["playAddr"][0]["src"]
except Exception as e:
logger.error(f"[SuperIPAgent] Path extraction failed: {e}")
if not video_url:
logger.error("[SuperIPAgent] No video_url found")
# 4. 提取 play_addr
addr_match = re.search(
r'"play_addr":\{"uri":"([^"]+)","url_list":\["([^"]+)"',
page_text,
)
if not addr_match:
logger.error("[douyin-fallback] Could not find play_addr in mobile page")
return None
video_url = addr_match.group(2).replace(r"\u002F", "/")
if video_url.startswith("//"):
video_url = "https:" + video_url
logger.info(f"[SuperIPAgent] Found video URL: {video_url[:50]}...")
logger.info(f"[douyin-fallback] Found video URL: {video_url[:80]}...")
# 5. 下载视频
temp_path = temp_dir / f"douyin_manual_{timestamp}.mp4"
download_headers = {
'Referer': 'https://www.douyin.com/',
'User-Agent': 'Mozilla/5.0 (Macintosh; Intel Mac OS X 10_15_7) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/131.0.0.0 Safari/537.36',
"Referer": "https://www.douyin.com/",
"User-Agent": "Mozilla/5.0 (iPhone; CPU iPhone OS 16_0 like Mac OS X) AppleWebKit/605.1.15",
}
async with httpx.AsyncClient(timeout=60.0) as client:
async with httpx.AsyncClient(timeout=120.0, follow_redirects=True) as client:
async with client.stream("GET", video_url, headers=download_headers) as dl_resp:
if dl_resp.status_code == 200:
with open(temp_path, 'wb') as f:
with open(temp_path, "wb") as f:
async for chunk in dl_resp.aiter_bytes(chunk_size=8192):
f.write(chunk)
logger.info(f"[SuperIPAgent] Downloaded successfully: {temp_path}")
logger.info(f"[douyin-fallback] Downloaded successfully: {temp_path}")
return temp_path
else:
logger.error(f"[SuperIPAgent] Download failed: {dl_resp.status_code}")
logger.error(f"[douyin-fallback] Download failed: {dl_resp.status_code}")
return None
except Exception as e:
logger.error(f"[SuperIPAgent] Logic failed: {e}")
logger.error(f"[douyin-fallback] Logic failed: {e}")
return None

View File

@@ -1,5 +1,13 @@
from pydantic import BaseModel
from typing import Optional, List
from typing import Optional, List, Literal
class CustomAssignment(BaseModel):
material_path: str
start: float # 音频时间轴起点
end: float # 音频时间轴终点
source_start: float = 0.0 # 源视频截取起点
source_end: Optional[float] = None # 源视频截取终点(可选)
class GenerateRequest(BaseModel):
@@ -11,13 +19,22 @@ class GenerateRequest(BaseModel):
ref_audio_id: Optional[str] = None
ref_text: Optional[str] = None
language: str = "zh-CN"
generated_audio_id: Optional[str] = None # 预生成配音 ID存在时跳过内联 TTS
title: Optional[str] = None
title_display_mode: Literal["short", "persistent"] = "short"
title_duration: float = 4.0
enable_subtitles: bool = True
subtitle_style_id: Optional[str] = None
title_style_id: Optional[str] = None
secondary_title: Optional[str] = None
secondary_title_style_id: Optional[str] = None
secondary_title_font_size: Optional[int] = None
secondary_title_top_margin: Optional[int] = None
subtitle_font_size: Optional[int] = None
title_font_size: Optional[int] = None
title_top_margin: Optional[int] = None
subtitle_bottom_margin: Optional[int] = None
bgm_id: Optional[str] = None
bgm_volume: Optional[float] = 0.2
custom_assignments: Optional[List[CustomAssignment]] = None
output_aspect_ratio: Literal["9:16", "16:9"] = "9:16"

View File

@@ -29,7 +29,7 @@ def _locale_to_whisper_lang(locale: str) -> str:
return locale.split("-")[0] if "-" in locale else locale
def _locale_to_qwen_lang(locale: str) -> str:
def _locale_to_tts_lang(locale: str) -> str:
"""'zh-CN''Chinese', 'en-US''English', 其他 → 'Auto'"""
mapping = {"zh": "Chinese", "en": "English"}
return mapping.get(locale.split("-")[0], "Auto")
@@ -174,17 +174,27 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
# ── 确定素材列表 ──
material_paths: List[str] = []
if req.material_paths and len(req.material_paths) > 1:
if req.custom_assignments and len(req.custom_assignments) > 1:
material_paths = [a.material_path for a in req.custom_assignments if a.material_path]
elif req.material_paths and len(req.material_paths) > 1:
material_paths = req.material_paths
else:
material_paths = [req.material_path]
is_multi = len(material_paths) > 1
target_resolution = (1080, 1920) if req.output_aspect_ratio == "9:16" else (1920, 1080)
logger.info(
f"[Render] 输出画面比例: {req.output_aspect_ratio}, "
f"目标分辨率: {target_resolution[0]}x{target_resolution[1]}"
)
_update_task(task_id, status="processing", progress=5, message="正在下载素材...")
temp_dir = settings.UPLOAD_DIR / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
video = VideoService()
input_material_path: Optional[Path] = None
# 单素材模式:下载主素材
if not is_multi:
@@ -192,12 +202,50 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
temp_files.append(input_material_path)
await _download_material(material_paths[0], input_material_path)
# 归一化旋转元数据(如 iPhone MOV 1920x1080 + rotation=-90
normalized_input_path = temp_dir / f"{task_id}_input_norm.mp4"
normalized_result = video.normalize_orientation(
str(input_material_path),
str(normalized_input_path),
)
if normalized_result != str(input_material_path):
temp_files.append(normalized_input_path)
input_material_path = normalized_input_path
_update_task(task_id, message="正在生成语音...", progress=10)
audio_path = temp_dir / f"{task_id}_audio.wav"
temp_files.append(audio_path)
if req.tts_mode == "voiceclone":
if req.generated_audio_id:
# 新流程:使用预生成的配音
_update_task(task_id, message="正在下载配音...", progress=12)
audio_url = await storage_service.get_signed_url(
bucket="generated-audios",
path=req.generated_audio_id,
)
await _download_material(audio_url, audio_path)
# 从元数据获取 language
meta_path = req.generated_audio_id.replace("_audio.wav", "_audio.json")
try:
meta_url = await storage_service.get_signed_url(
bucket="generated-audios", path=meta_path,
)
import httpx as _httpx
async with _httpx.AsyncClient(timeout=5.0) as client:
resp = await client.get(meta_url)
if resp.status_code == 200:
meta = resp.json()
req.language = meta.get("language", req.language)
# 无条件用配音元数据覆盖文案,确保字幕与配音语言一致
meta_text = meta.get("text", "")
if meta_text:
req.text = meta_text
except Exception as e:
logger.warning(f"读取配音元数据失败: {e}")
elif req.tts_mode == "voiceclone":
if not req.ref_audio_id or not req.ref_text:
raise ValueError("声音克隆模式需要提供参考音频和参考文字")
@@ -212,13 +260,13 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
)
await _download_material(ref_audio_url, ref_audio_local)
_update_task(task_id, message="正在克隆声音 (Qwen3-TTS)...")
_update_task(task_id, message="正在克隆声音...")
await voice_clone_service.generate_audio(
text=req.text,
ref_audio_path=str(ref_audio_local),
ref_text=req.ref_text,
output_path=str(audio_path),
language=_locale_to_qwen_lang(req.language)
language=_locale_to_tts_lang(req.language)
)
else:
_update_task(task_id, message="正在生成语音 (EdgeTTS)...")
@@ -232,47 +280,126 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
lipsync_video_path = temp_dir / f"{task_id}_lipsync.mp4"
temp_files.append(lipsync_video_path)
video = VideoService()
captions_path = None
if is_multi:
# ══════════════════════════════════════
# 多素材流水线
# ══════════════════════════════════════
_update_task(task_id, progress=12, message="正在生成字幕 (Whisper)...")
_update_task(task_id, progress=12, message="正在分配素材...")
captions_path = temp_dir / f"{task_id}_captions.json"
temp_files.append(captions_path)
try:
captions_data = await whisper_service.align(
audio_path=str(audio_path),
text=req.text,
output_path=str(captions_path),
language=_locale_to_whisper_lang(req.language),
)
print(f"[Pipeline] Whisper alignment completed (multi-material)")
except Exception as e:
logger.warning(f"Whisper alignment failed: {e}")
captions_data = None
captions_path = None
_update_task(task_id, progress=15, message="正在分配素材...")
if captions_data and captions_data.get("segments"):
assignments = _split_equal(captions_data["segments"], material_paths)
else:
# Whisper 失败 → 按时长均分(不依赖字符对齐)
logger.warning("[MultiMat] Whisper 无数据,按时长均分")
audio_dur = video._get_duration(str(audio_path))
if audio_dur <= 0:
audio_dur = 30.0 # 安全兜底
seg_dur = audio_dur / len(material_paths)
if req.custom_assignments and len(req.custom_assignments) == len(material_paths):
# 用户自定义分配,跳过 Whisper 均分
assignments = [
{"material_path": material_paths[i], "start": i * seg_dur,
"end": (i + 1) * seg_dur, "index": i}
for i in range(len(material_paths))
{
"material_path": a.material_path,
"start": a.start,
"end": a.end,
"source_start": a.source_start,
"source_end": a.source_end,
"index": i,
}
for i, a in enumerate(req.custom_assignments)
]
# 仍然需要 Whisper 生成字幕(如果启用)
captions_path = temp_dir / f"{task_id}_captions.json"
temp_files.append(captions_path)
if req.enable_subtitles:
_update_task(task_id, message="正在生成字幕 (Whisper)...")
try:
await whisper_service.align(
audio_path=str(audio_path),
text=req.text,
output_path=str(captions_path),
language=_locale_to_whisper_lang(req.language),
original_text=req.text,
)
print(f"[Pipeline] Whisper alignment completed (custom assignments)")
except Exception as e:
logger.warning(f"Whisper alignment failed: {e}")
captions_path = None
else:
captions_path = None
elif req.custom_assignments:
logger.warning(
f"[MultiMat] custom_assignments 数量({len(req.custom_assignments)})"
f" 与素材数量({len(material_paths)})不一致,回退自动分配"
)
# 原有逻辑Whisper → _split_equal
_update_task(task_id, message="正在生成字幕 (Whisper)...")
captions_path = temp_dir / f"{task_id}_captions.json"
temp_files.append(captions_path)
try:
captions_data = await whisper_service.align(
audio_path=str(audio_path),
text=req.text,
output_path=str(captions_path),
language=_locale_to_whisper_lang(req.language),
original_text=req.text,
)
print(f"[Pipeline] Whisper alignment completed (multi-material)")
except Exception as e:
logger.warning(f"Whisper alignment failed: {e}")
captions_data = None
captions_path = None
_update_task(task_id, progress=15, message="正在分配素材...")
if captions_data and captions_data.get("segments"):
assignments = _split_equal(captions_data["segments"], material_paths)
else:
# Whisper 失败 → 按时长均分(不依赖字符对齐)
logger.warning("[MultiMat] Whisper 无数据,按时长均分")
audio_dur = video._get_duration(str(audio_path))
if audio_dur <= 0:
audio_dur = 30.0 # 安全兜底
seg_dur = audio_dur / len(material_paths)
assignments = [
{"material_path": material_paths[i], "start": i * seg_dur,
"end": (i + 1) * seg_dur, "index": i}
for i in range(len(material_paths))
]
else:
# 原有逻辑Whisper → _split_equal
_update_task(task_id, message="正在生成字幕 (Whisper)...")
captions_path = temp_dir / f"{task_id}_captions.json"
temp_files.append(captions_path)
try:
captions_data = await whisper_service.align(
audio_path=str(audio_path),
text=req.text,
output_path=str(captions_path),
language=_locale_to_whisper_lang(req.language),
original_text=req.text,
)
print(f"[Pipeline] Whisper alignment completed (multi-material)")
except Exception as e:
logger.warning(f"Whisper alignment failed: {e}")
captions_data = None
captions_path = None
_update_task(task_id, progress=15, message="正在分配素材...")
if captions_data and captions_data.get("segments"):
assignments = _split_equal(captions_data["segments"], material_paths)
else:
# Whisper 失败 → 按时长均分(不依赖字符对齐)
logger.warning("[MultiMat] Whisper 无数据,按时长均分")
audio_dur = video._get_duration(str(audio_path))
if audio_dur <= 0:
audio_dur = 30.0 # 安全兜底
seg_dur = audio_dur / len(material_paths)
assignments = [
{"material_path": material_paths[i], "start": i * seg_dur,
"end": (i + 1) * seg_dur, "index": i}
for i in range(len(material_paths))
]
# 扩展段覆盖完整音频范围首段从0开始末段到音频结尾
audio_duration = video._get_duration(str(audio_path))
@@ -296,12 +423,23 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
material_local = temp_dir / f"{task_id}_material_{i}.mp4"
temp_files.append(material_local)
await _download_material(assignment["material_path"], material_local)
# 归一化旋转元数据,确保分辨率判断与后续推理一致
normalized_material = temp_dir / f"{task_id}_material_{i}_norm.mp4"
normalized_result = video.normalize_orientation(
str(material_local),
str(normalized_material),
)
if normalized_result != str(material_local):
temp_files.append(normalized_material)
material_local = normalized_material
material_locals.append(material_local)
resolutions.append(video.get_resolution(str(material_local)))
# 分辨率不一致时,统一到第一个素材的分辨率
base_res = resolutions[0] if resolutions else (0, 0)
need_scale = any(r != base_res for r in resolutions) and base_res[0] > 0
# 按用户选择的画面比例统一分辨率
base_res = target_resolution
need_scale = any(r != base_res for r in resolutions)
if need_scale:
logger.info(f"[MultiMat] 素材分辨率不一致,统一到 {base_res[0]}x{base_res[1]}")
@@ -321,7 +459,11 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
temp_files.append(prepared_path)
video.prepare_segment(
str(material_locals[i]), seg_dur, str(prepared_path),
target_resolution=base_res if need_scale else None
# 多素材拼接前统一重编码为同分辨率/同编码,避免 concat 仅保留首段
target_resolution=base_res,
source_start=assignment.get("source_start", 0.0),
source_end=assignment.get("source_end"),
target_fps=25,
)
prepared_segments.append(prepared_path)
@@ -331,7 +473,8 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
temp_files.append(concat_path)
video.concat_videos(
[str(p) for p in prepared_segments],
str(concat_path)
str(concat_path),
target_fps=25,
)
# ── 第三步:一次 LatentSync 推理 ──
@@ -363,6 +506,33 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
# ══════════════════════════════════════
# 单素材流水线(原有逻辑)
# ══════════════════════════════════════
if input_material_path is None:
raise RuntimeError("单素材流程缺少输入素材")
# 单素材:按用户选择画面比例统一到目标分辨率,并应用 source_start
single_source_start = 0.0
single_source_end = None
if req.custom_assignments and len(req.custom_assignments) == 1:
single_source_start = req.custom_assignments[0].source_start
single_source_end = req.custom_assignments[0].source_end
_update_task(task_id, progress=20, message="正在准备素材片段...")
audio_dur = video._get_duration(str(audio_path))
if audio_dur <= 0:
audio_dur = 30.0
prepared_single_path = temp_dir / f"{task_id}_prepared_single.mp4"
temp_files.append(prepared_single_path)
video.prepare_segment(
str(input_material_path),
audio_dur,
str(prepared_single_path),
target_resolution=target_resolution,
source_start=single_source_start,
source_end=single_source_end,
)
input_material_path = prepared_single_path
_update_task(task_id, progress=25)
_update_task(task_id, message="正在合成唇形 (LatentSync)...", progress=30)
@@ -396,6 +566,7 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
text=req.text,
output_path=str(captions_path),
language=_locale_to_whisper_lang(req.language),
original_text=req.text,
)
print(f"[Pipeline] Whisper alignment completed")
except Exception as e:
@@ -427,14 +598,17 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
else:
logger.warning(f"BGM not found: {req.bgm_id}")
use_remotion = (captions_path and captions_path.exists()) or req.title
use_remotion = (captions_path and captions_path.exists()) or req.title or req.secondary_title
subtitle_style = None
title_style = None
secondary_title_style = None
if req.enable_subtitles:
subtitle_style = get_style("subtitle", req.subtitle_style_id) or get_default_style("subtitle")
if req.title:
title_style = get_style("title", req.title_style_id) or get_default_style("title")
if req.secondary_title:
secondary_title_style = get_style("title", req.secondary_title_style_id) or get_default_style("title")
if req.subtitle_font_size and req.enable_subtitles:
if subtitle_style is None:
@@ -456,6 +630,16 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
subtitle_style = {}
subtitle_style["bottom_margin"] = int(req.subtitle_bottom_margin)
if req.secondary_title_font_size and req.secondary_title:
if secondary_title_style is None:
secondary_title_style = {}
secondary_title_style["font_size"] = int(req.secondary_title_font_size)
if req.secondary_title_top_margin is not None and req.secondary_title:
if secondary_title_style is None:
secondary_title_style = {}
secondary_title_style["top_margin"] = int(req.secondary_title_top_margin)
if use_remotion:
subtitle_style = prepare_style_for_remotion(
subtitle_style,
@@ -467,6 +651,11 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
temp_dir,
f"{task_id}_title_font"
)
secondary_title_style = prepare_style_for_remotion(
secondary_title_style,
temp_dir,
f"{task_id}_secondary_title_font"
)
final_output_local_path = temp_dir / f"{task_id}_output.mp4"
temp_files.append(final_output_local_path)
@@ -486,16 +675,26 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
mapped = 87 + int(percent * 0.08)
_update_task(task_id, progress=mapped)
title_display_mode = (
req.title_display_mode
if req.title_display_mode in ("short", "persistent")
else "short"
)
title_duration = max(0.5, min(float(req.title_duration or 4.0), 30.0))
await remotion_service.render(
video_path=str(composed_video_path),
output_path=str(final_output_local_path),
captions_path=str(captions_path) if captions_path else None,
title=req.title,
title_duration=3.0,
title_duration=title_duration,
title_display_mode=title_display_mode,
fps=25,
enable_subtitles=req.enable_subtitles,
subtitle_style=subtitle_style,
title_style=title_style,
secondary_title=req.secondary_title,
secondary_title_style=secondary_title_style,
on_progress=on_remotion_progress
)
print(f"[Pipeline] Remotion render completed")

View File

@@ -0,0 +1,34 @@
"""
订单数据访问层
"""
from datetime import datetime, timezone
from typing import Any, Dict, Optional, cast
from app.core.supabase import get_supabase
def create_order(user_id: str, out_trade_no: str, amount: float) -> Dict[str, Any]:
supabase = get_supabase()
result = supabase.table("orders").insert({
"user_id": user_id,
"out_trade_no": out_trade_no,
"amount": amount,
"status": "pending",
}).execute()
return cast(Dict[str, Any], (result.data or [{}])[0])
def get_order_by_trade_no(out_trade_no: str) -> Optional[Dict[str, Any]]:
supabase = get_supabase()
result = supabase.table("orders").select("*").eq("out_trade_no", out_trade_no).single().execute()
return cast(Optional[Dict[str, Any]], result.data or None)
def update_order_status(out_trade_no: str, status: str, trade_no: str | None = None) -> None:
supabase = get_supabase()
payload: Dict[str, Any] = {"status": status}
if trade_no:
payload["trade_no"] = trade_no
if status == "paid":
payload["paid_at"] = datetime.now(timezone.utc).isoformat()
supabase.table("orders").update(payload).eq("out_trade_no", out_trade_no).execute()

View File

@@ -1,3 +1,4 @@
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional, cast
from app.core.supabase import get_supabase
@@ -37,3 +38,33 @@ def update_user(user_id: str, payload: Dict[str, Any]) -> List[Dict[str, Any]]:
supabase = get_supabase()
result = supabase.table("users").update(payload).eq("id", user_id).execute()
return cast(List[Dict[str, Any]], result.data or [])
def _parse_expires_at(expires_at: Any) -> Optional[datetime]:
try:
expires_at_dt = datetime.fromisoformat(str(expires_at).replace("Z", "+00:00"))
except Exception:
return None
if expires_at_dt.tzinfo is None:
expires_at_dt = expires_at_dt.replace(tzinfo=timezone.utc)
return expires_at_dt.astimezone(timezone.utc)
def deactivate_user_if_expired(user: Dict[str, Any]) -> bool:
expires_at = user.get("expires_at")
if not expires_at:
return False
expires_at_dt = _parse_expires_at(expires_at)
if not expires_at_dt:
return False
if datetime.now(timezone.utc) <= expires_at_dt:
return False
user_id = user.get("id")
if user.get("is_active") and user_id:
update_user(cast(str, user_id), {"is_active": False})
return True

View File

@@ -35,18 +35,19 @@ class GLMService:
Returns:
{"title": "标题", "tags": ["标签1", "标签2", ...]}
"""
prompt = f"""根据以下口播文案生成一个吸引人的短视频标题和3个相关标签。
prompt = f"""根据以下口播文案,生成一个吸引人的短视频标题、副标题和3个相关标签。
口播文案:
{text}
要求:
1. 标题要简洁有力能吸引观众点击不超过10个字
2. 标签要与内容相关便于搜索和推荐只要3个
3. 标题和标签必须使用与口播文案相同的语言(如文案是英文就用英文,日文就用日文)
2. 副标题是对标题的补充说明或描述性文字不超过20个字
3. 标签要与内容相关便于搜索和推荐只要3个
4. 标题、副标题和标签必须使用与口播文案相同的语言(如文案是英文就用英文,日文就用日文)
请严格按以下JSON格式返回不要包含其他内容
{{"title": "标题", "tags": ["标签1", "标签2", "标签3"]}}"""
{{"title": "标题", "secondary_title": "副标题", "tags": ["标签1", "标签2", "标签3"]}}"""
try:
client = self._get_client()
@@ -75,17 +76,24 @@ class GLMService:
logger.error(f"GLM service error: {e}")
raise Exception(f"AI 生成失败: {str(e)}")
async def rewrite_script(self, text: str) -> str:
async def rewrite_script(self, text: str, custom_prompt: str = None) -> str:
"""
AI 洗稿(文案改写)
AI 改写文案
Args:
text: 原始文案
custom_prompt: 自定义提示词,为空则使用默认提示词
Returns:
改写后的文案
"""
prompt = f"""请将以下视频文案进行改写。
if custom_prompt and custom_prompt.strip():
prompt = f"""{custom_prompt.strip()}
原始文案:
{text}"""
else:
prompt = f"""请将以下视频文案进行改写。
原始文案:
{text}
@@ -174,6 +182,8 @@ class GLMService:
# 尝试提取 JSON 块
json_match = re.search(r'\{[^{}]*"title"[^{}]*"tags"[^{}]*\}', content, re.DOTALL)
if not json_match:
json_match = re.search(r'\{[^{}]*"title"[^{}]*"secondary_title"[^{}]*"tags"[^{}]*\}', content, re.DOTALL)
if json_match:
try:
return json.loads(json_match.group())

View File

@@ -7,6 +7,7 @@ import asyncio
import json
import os
import subprocess
from collections.abc import Callable
from pathlib import Path
from typing import Optional
from loguru import logger
@@ -29,12 +30,15 @@ class RemotionService:
output_path: str,
captions_path: Optional[str] = None,
title: Optional[str] = None,
title_duration: float = 3.0,
title_duration: float = 4.0,
title_display_mode: str = "short",
fps: int = 25,
enable_subtitles: bool = True,
subtitle_style: Optional[dict] = None,
title_style: Optional[dict] = None,
on_progress: Optional[callable] = None
secondary_title: Optional[str] = None,
secondary_title_style: Optional[dict] = None,
on_progress: Optional[Callable[[int], None]] = None
) -> str:
"""
使用 Remotion 渲染视频(添加字幕和标题)
@@ -45,6 +49,7 @@ class RemotionService:
captions_path: 字幕 JSON 文件路径Whisper 生成)
title: 视频标题(可选)
title_duration: 标题显示时长(秒)
title_display_mode: 标题显示模式short/persistent
fps: 帧率
enable_subtitles: 是否启用字幕
on_progress: 进度回调函数
@@ -75,6 +80,7 @@ class RemotionService:
if title:
cmd.extend(["--title", title])
cmd.extend(["--titleDuration", str(title_duration)])
cmd.extend(["--titleDisplayMode", title_display_mode])
if subtitle_style:
cmd.extend(["--subtitleStyle", json.dumps(subtitle_style, ensure_ascii=False)])
@@ -82,6 +88,12 @@ class RemotionService:
if title_style:
cmd.extend(["--titleStyle", json.dumps(title_style, ensure_ascii=False)])
if secondary_title:
cmd.extend(["--secondaryTitle", secondary_title])
if secondary_title_style:
cmd.extend(["--secondaryTitleStyle", json.dumps(secondary_title_style, ensure_ascii=False)])
logger.info(f"Running Remotion render: {' '.join(cmd)}")
# 在线程池中运行子进程
@@ -95,8 +107,12 @@ class RemotionService:
bufsize=1
)
if process.stdout is None:
raise RuntimeError("Remotion process stdout is unavailable")
stdout = process.stdout
output_lines = []
for line in iter(process.stdout.readline, ''):
for line in iter(stdout.readline, ''):
line = line.strip()
if line:
output_lines.append(line)

View File

@@ -20,12 +20,13 @@ class StorageService:
self.BUCKET_MATERIALS = "materials"
self.BUCKET_OUTPUTS = "outputs"
self.BUCKET_REF_AUDIOS = "ref-audios"
self.BUCKET_GENERATED_AUDIOS = "generated-audios"
# 确保所有 bucket 存在
self._ensure_buckets()
def _ensure_buckets(self):
"""确保所有必需的 bucket 存在"""
buckets = [self.BUCKET_MATERIALS, self.BUCKET_OUTPUTS, self.BUCKET_REF_AUDIOS]
buckets = [self.BUCKET_MATERIALS, self.BUCKET_OUTPUTS, self.BUCKET_REF_AUDIOS, self.BUCKET_GENERATED_AUDIOS]
try:
existing = self.supabase.storage.list_buckets()
existing_names = {b.name for b in existing} if existing else set()

View File

@@ -9,9 +9,110 @@ from pathlib import Path
from loguru import logger
from typing import Optional
class VideoService:
def __init__(self):
pass
class VideoService:
def __init__(self):
pass
def get_video_metadata(self, file_path: str) -> dict:
"""获取视频元信息(含旋转角与有效显示分辨率)"""
cmd = [
"ffprobe", "-v", "error",
"-select_streams", "v:0",
"-show_entries", "stream=width,height:stream_side_data=rotation",
"-of", "json",
file_path,
]
default_info = {
"width": 0,
"height": 0,
"rotation": 0,
"effective_width": 0,
"effective_height": 0,
}
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
if result.returncode != 0:
return default_info
payload = json.loads(result.stdout or "{}")
streams = payload.get("streams") or []
if not streams:
return default_info
stream = streams[0]
width = int(stream.get("width") or 0)
height = int(stream.get("height") or 0)
rotation = 0
for side_data in stream.get("side_data_list") or []:
if not isinstance(side_data, dict):
continue
raw_rotation = side_data.get("rotation")
if raw_rotation is None:
continue
try:
rotation = int(round(float(str(raw_rotation))))
except Exception:
rotation = 0
break
norm_rotation = rotation % 360
if norm_rotation > 180:
norm_rotation -= 360
swap_wh = abs(norm_rotation) == 90
effective_width = height if swap_wh else width
effective_height = width if swap_wh else height
return {
"width": width,
"height": height,
"rotation": norm_rotation,
"effective_width": effective_width,
"effective_height": effective_height,
}
except Exception as e:
logger.warning(f"获取视频元信息失败: {e}")
return default_info
def normalize_orientation(self, video_path: str, output_path: str) -> str:
"""将带旋转元数据的视频转为物理方向,避免后续流程忽略 rotation。"""
info = self.get_video_metadata(video_path)
rotation = int(info.get("rotation") or 0)
if rotation == 0:
return video_path
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
logger.info(
f"检测到旋转元数据 rotation={rotation},归一化方向: "
f"{info.get('effective_width', 0)}x{info.get('effective_height', 0)}"
)
cmd = [
"ffmpeg", "-y",
"-i", video_path,
"-map", "0:v:0",
"-map", "0:a?",
"-c:v", "libx264",
"-preset", "fast",
"-crf", "18",
"-c:a", "copy",
"-movflags", "+faststart",
output_path,
]
if self._run_ffmpeg(cmd):
normalized = self.get_video_metadata(output_path)
logger.info(
"视频方向归一化完成: "
f"coded={normalized.get('width', 0)}x{normalized.get('height', 0)}, "
f"rotation={normalized.get('rotation', 0)}"
)
return output_path
logger.warning("视频方向归一化失败,回退使用原视频")
return video_path
def _run_ffmpeg(self, cmd: list) -> bool:
cmd_str = ' '.join(shlex.quote(str(c)) for c in cmd)
@@ -139,8 +240,8 @@ class VideoService:
else:
raise RuntimeError("FFmpeg composition failed")
def concat_videos(self, video_paths: list, output_path: str) -> str:
"""使用 FFmpeg concat demuxer 拼接多个视频片段"""
def concat_videos(self, video_paths: list, output_path: str, target_fps: int = 25) -> str:
"""使用 FFmpeg concat demuxer 拼接多个视频片段"""
if not video_paths:
raise ValueError("No video segments to concat")
@@ -152,14 +253,22 @@ class VideoService:
for vp in video_paths:
f.write(f"file '{vp}'\n")
cmd = [
"ffmpeg", "-y",
"-f", "concat",
"-safe", "0",
"-i", str(list_path),
"-c", "copy",
output_path,
]
cmd = [
"ffmpeg", "-y",
"-f", "concat",
"-safe", "0",
"-fflags", "+genpts",
"-i", str(list_path),
"-an",
"-vsync", "cfr",
"-r", str(target_fps),
"-c:v", "libx264",
"-preset", "fast",
"-crf", "18",
"-pix_fmt", "yuv420p",
"-movflags", "+faststart",
output_path,
]
try:
if self._run_ffmpeg(cmd):
@@ -193,54 +302,104 @@ class VideoService:
return output_path
raise RuntimeError(f"FFmpeg audio split failed: {start}-{end}")
def get_resolution(self, file_path: str) -> tuple:
"""获取视频分辨率,返回 (width, height)"""
cmd = [
'ffprobe', '-v', 'error',
'-select_streams', 'v:0',
'-show_entries', 'stream=width,height',
'-of', 'csv=p=0',
file_path
]
try:
result = subprocess.run(cmd, capture_output=True, text=True, timeout=10)
parts = result.stdout.strip().split(',')
return (int(parts[0]), int(parts[1]))
except Exception:
return (0, 0)
def get_resolution(self, file_path: str) -> tuple[int, int]:
"""获取视频有效显示分辨率(考虑旋转元数据)。"""
info = self.get_video_metadata(file_path)
return (
int(info.get("effective_width") or 0),
int(info.get("effective_height") or 0),
)
def prepare_segment(self, video_path: str, target_duration: float, output_path: str,
target_resolution: tuple = None) -> str:
"""将素材视频裁剪或循环到指定时长(无音频)。
target_resolution: (width, height) 如需统一分辨率则传入,否则保持原分辨率
"""
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
def prepare_segment(self, video_path: str, target_duration: float, output_path: str,
target_resolution: Optional[tuple] = None, source_start: float = 0.0,
source_end: Optional[float] = None, target_fps: Optional[int] = None) -> str:
"""将素材视频裁剪或循环到指定时长(无音频)
target_resolution: (width, height) 如需统一分辨率则传入,否则保持原分辨率。
source_start: 源视频截取起点(秒),默认 0。
source_end: 源视频截取终点(秒),默认到素材结尾。
target_fps: 输出帧率(可选),用于多素材拼接前统一时间基。
"""
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
video_dur = self._get_duration(video_path)
if video_dur <= 0:
video_dur = target_duration
clip_end = video_dur
if source_end is not None:
try:
source_end_value = float(source_end)
if source_end_value > source_start:
clip_end = min(source_end_value, video_dur)
except Exception:
pass
# 可用时长 = 从 source_start 到视频结尾
available = max(clip_end - source_start, 0.1)
needs_loop = target_duration > available
needs_scale = target_resolution is not None
needs_fps = bool(target_fps and target_fps > 0)
has_source_end = clip_end < video_dur
# 当需要循环且存在截取范围时,先裁剪出片段,再循环裁剪后的文件
# 避免 stream_loop 循环整个视频(而不是截取后的片段)
actual_input = video_path
trim_temp = None
if needs_loop and (source_start > 0 or has_source_end):
trim_temp = str(Path(output_path).parent / (Path(output_path).stem + "_trim_tmp.mp4"))
trim_cmd = [
"ffmpeg", "-y",
"-ss", str(source_start),
"-i", video_path,
"-t", str(available),
"-an",
"-c:v", "libx264", "-preset", "fast", "-crf", "18",
trim_temp,
]
if not self._run_ffmpeg(trim_cmd):
raise RuntimeError(f"FFmpeg trim for loop failed: {video_path}")
actual_input = trim_temp
source_start = 0.0 # 已裁剪,不需要再 seek
# 重新计算循环次数(基于裁剪后文件)
available = self._get_duration(trim_temp) or available
video_dur = self._get_duration(video_path)
if video_dur <= 0:
video_dur = target_duration
needs_loop = target_duration > video_dur
needs_scale = target_resolution is not None
loop_count = int(target_duration / available) + 1 if needs_loop else 0
cmd = ["ffmpeg", "-y"]
if needs_loop:
loop_count = int(target_duration / video_dur) + 1
cmd.extend(["-stream_loop", str(loop_count)])
cmd.extend(["-i", video_path, "-t", str(target_duration), "-an"])
if needs_scale:
w, h = target_resolution
cmd.extend(["-vf", f"scale={w}:{h}:force_original_aspect_ratio=decrease,pad={w}:{h}:(ow-iw)/2:(oh-ih)/2"])
# 需要循环或缩放时必须重编码,否则用 stream copy 保持原画质
if needs_loop or needs_scale:
cmd.extend(["-c:v", "libx264", "-preset", "fast", "-crf", "18"])
else:
cmd.extend(["-c:v", "copy"])
if source_start > 0:
cmd.extend(["-ss", str(source_start)])
cmd.extend(["-i", actual_input, "-t", str(target_duration), "-an"])
filters = []
if needs_fps:
filters.append(f"fps={int(target_fps)}")
if needs_scale:
w, h = target_resolution
filters.append(f"scale={w}:{h}:force_original_aspect_ratio=decrease,pad={w}:{h}:(ow-iw)/2:(oh-ih)/2")
if filters:
cmd.extend(["-vf", ",".join(filters)])
if needs_fps:
cmd.extend(["-vsync", "cfr", "-r", str(int(target_fps))])
# 需要循环、缩放或指定起点时必须重编码,否则用 stream copy 保持原画质
if needs_loop or needs_scale or source_start > 0 or has_source_end or needs_fps:
cmd.extend(["-c:v", "libx264", "-preset", "fast", "-crf", "18"])
else:
cmd.extend(["-c:v", "copy"])
cmd.append(output_path)
if self._run_ffmpeg(cmd):
return output_path
raise RuntimeError(f"FFmpeg prepare_segment failed: {video_path}")
try:
if self._run_ffmpeg(cmd):
return output_path
raise RuntimeError(f"FFmpeg prepare_segment failed: {video_path}")
finally:
# 清理裁剪临时文件
if trim_temp:
try:
Path(trim_temp).unlink(missing_ok=True)
except Exception:
pass

View File

@@ -1,37 +1,104 @@
"""
声音克隆服务
通过 HTTP 调用 Qwen3-TTS 独立服务 (端口 8009)
通过 HTTP 调用 CosyVoice 3.0 独立服务 (端口 8010)
"""
import httpx
import asyncio
from pathlib import Path
from typing import Optional
import httpx
from loguru import logger
from app.core.config import settings
# Qwen3-TTS 服务地址
QWEN_TTS_URL = "http://localhost:8009"
# CosyVoice 3.0 服务地址
VOICE_CLONE_URL = "http://localhost:8010"
class VoiceCloneService:
"""声音克隆服务 - 调用 Qwen3-TTS HTTP API"""
"""声音克隆服务 - 调用 CosyVoice 3.0 HTTP API"""
def __init__(self):
self.base_url = QWEN_TTS_URL
self.base_url = VOICE_CLONE_URL
# 健康状态缓存
self._health_cache: Optional[dict] = None
self._health_cache_time: float = 0
# GPU 并发锁 (Serial Queue)
self._lock = asyncio.Lock()
async def _generate_once(
self,
*,
text: str,
ref_audio_data: bytes,
ref_text: str,
language: str,
speed: float = 1.0,
max_retries: int = 4,
) -> bytes:
timeout = httpx.Timeout(240.0)
for attempt in range(max_retries):
try:
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.post(
f"{self.base_url}/generate",
files={"ref_audio": ("ref.wav", ref_audio_data, "audio/wav")},
data={
"text": text,
"ref_text": ref_text,
"language": language,
"speed": str(speed),
},
)
retryable = False
reason = ""
if response.status_code in (429, 502, 503, 504):
retryable = True
reason = f"HTTP {response.status_code}"
elif response.status_code == 500 and (
"生成超时" in response.text or "timeout" in response.text.lower()
):
retryable = True
reason = "upstream timeout"
if retryable and attempt < max_retries - 1:
wait = 8 * (attempt + 1)
logger.warning(
f"Voice clone retryable error ({reason}), retrying in {wait}s "
f"(attempt {attempt + 1}/{max_retries})"
)
await asyncio.sleep(wait)
continue
response.raise_for_status()
return response.content
except httpx.HTTPStatusError as e:
logger.error(f"Voice clone API error: {e.response.status_code} - {e.response.text}")
raise RuntimeError(f"声音克隆服务错误: {e.response.text}")
except httpx.RequestError as e:
if attempt < max_retries - 1:
wait = 6 * (attempt + 1)
logger.warning(
f"Voice clone connection error: {e}; retrying in {wait}s "
f"(attempt {attempt + 1}/{max_retries})"
)
await asyncio.sleep(wait)
continue
logger.error(f"Voice clone connection error: {e}")
raise RuntimeError("无法连接声音克隆服务,请检查服务是否启动")
raise RuntimeError("声音克隆服务繁忙,请稍后重试")
async def generate_audio(
self,
text: str,
ref_audio_path: str,
ref_text: str,
output_path: str,
language: str = "Chinese"
language: str = "Chinese",
speed: float = 1.0,
) -> str:
"""
使用声音克隆生成语音
@@ -51,60 +118,49 @@ class VoiceCloneService:
logger.info(f"🎤 Voice Clone: {text[:30]}... (language={language})")
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
# 读取参考音频
text = text.strip()
if not text:
raise RuntimeError("文本为空,无法生成语音")
with open(ref_audio_path, "rb") as f:
ref_audio_data = f.read()
# 调用 Qwen3-TTS 服务
timeout = httpx.Timeout(300.0) # 5分钟超时
async with httpx.AsyncClient(timeout=timeout) as client:
try:
response = await client.post(
f"{self.base_url}/generate",
files={"ref_audio": ("ref.wav", ref_audio_data, "audio/wav")},
data={
"text": text,
"ref_text": ref_text,
"language": language
}
)
response.raise_for_status()
# 保存返回的音频
with open(output_path, "wb") as f:
f.write(response.content)
logger.info(f"✅ Voice clone saved: {output_path}")
return output_path
except httpx.HTTPStatusError as e:
logger.error(f"Qwen3-TTS API error: {e.response.status_code} - {e.response.text}")
raise RuntimeError(f"声音克隆服务错误: {e.response.text}")
except httpx.RequestError as e:
logger.error(f"Qwen3-TTS connection error: {e}")
raise RuntimeError("无法连接声音克隆服务,请检查服务是否启动")
# CosyVoice 内部自带 text_normalize 分段,无需客户端切分
audio_bytes = await self._generate_once(
text=text,
ref_audio_data=ref_audio_data,
ref_text=ref_text,
language=language,
speed=speed,
)
with open(output_path, "wb") as f:
f.write(audio_bytes)
logger.info(f"✅ Voice clone saved: {output_path}")
return output_path
async def check_health(self) -> dict:
"""健康检查"""
import time
# 5分钟缓存
# 30秒缓存
now = time.time()
if self._health_cache and (now - self._health_cache_time) < 300:
return self._health_cache
cached = self._health_cache
if cached is not None and (now - self._health_cache_time) < 30:
return cached
try:
async with httpx.AsyncClient(timeout=5.0) as client:
response = await client.get(f"{self.base_url}/health")
response.raise_for_status()
self._health_cache = response.json()
payload = response.json()
self._health_cache = payload
self._health_cache_time = now
return self._health_cache
return payload
except Exception as e:
logger.warning(f"Qwen3-TTS health check failed: {e}")
logger.warning(f"Voice clone health check failed: {e}")
return {
"service": "Qwen3-TTS Voice Clone",
"model": "0.6B-Base",
"service": "CosyVoice 3.0 Voice Clone",
"model": "unknown",
"ready": False,
"gpu_id": 0,
"error": str(e)

View File

@@ -39,12 +39,22 @@ def split_word_to_chars(word: str, start: float, end: float) -> list:
tokens = []
ascii_buffer = ""
pending_space = False # 记录是否有待处理的空格(用于英文单词间距)
for char in word:
if not char.strip():
# 空格flush ascii_buffer标记下一个 token 需要前导空格
if ascii_buffer:
tokens.append(ascii_buffer)
ascii_buffer = ""
if tokens: # 仅在已有 token 时标记(避免开头重复空格)
pending_space = True
continue
if char.isascii() and char.isalnum():
if pending_space and not ascii_buffer:
ascii_buffer = " " # 将空格前置到新英文单词
pending_space = False
ascii_buffer += char
continue
@@ -52,7 +62,9 @@ def split_word_to_chars(word: str, start: float, end: float) -> list:
tokens.append(ascii_buffer)
ascii_buffer = ""
tokens.append(char)
prefix = " " if pending_space else ""
pending_space = False
tokens.append(prefix + char)
if ascii_buffer:
tokens.append(ascii_buffer)
@@ -175,6 +187,7 @@ class WhisperService:
text: str,
output_path: Optional[str] = None,
language: str = "zh",
original_text: Optional[str] = None,
) -> dict:
"""
对音频进行转录,生成字级别时间戳
@@ -184,6 +197,8 @@ class WhisperService:
text: 原始文本(用于参考,但实际使用 whisper 转录结果)
output_path: 可选,输出 JSON 文件路径
language: 语言代码 (zh/en 等)
original_text: 原始文案。非空时Whisper 仅用于检测总时间范围,
字幕文字用此原文替换(解决语言不匹配问题)
Returns:
包含字级别时间戳的字典
@@ -208,16 +223,19 @@ class WhisperService:
logger.info(f"Detected language: {info.language} (prob: {info.language_probability:.2f})")
# 收集 Whisper 转录结果(始终需要,用于获取时间范围)
all_segments = []
whisper_first_start = None
whisper_last_end = None
for segment in segments_iter:
# 提取每个字的时间戳,并拆分成单字
all_words = []
if segment.words:
for word_info in segment.words:
word_text = word_info.word
if word_text.strip():
# 将词拆分成单字,时间戳线性插值
# 保留前导空格用于英文词间距
if whisper_first_start is None:
whisper_first_start = word_info.start
whisper_last_end = word_info.end
chars = split_word_to_chars(
word_text,
word_info.start,
@@ -225,11 +243,24 @@ class WhisperService:
)
all_words.extend(chars)
# 将长段落按标点和字数拆分成多行
if all_words:
line_segments = split_segment_to_lines(all_words, max_chars)
all_segments.extend(line_segments)
# 如果提供了 original_text用原文替换 Whisper 转录文字
if original_text and original_text.strip() and whisper_first_start is not None:
logger.info(f"Using original_text for subtitles (len={len(original_text)}), "
f"Whisper time range: {whisper_first_start:.2f}-{whisper_last_end:.2f}s")
# 用 split_word_to_chars 拆分原文
orig_chars = split_word_to_chars(
original_text.strip(),
whisper_first_start,
whisper_last_end
)
if orig_chars:
all_segments = split_segment_to_lines(orig_chars, max_chars)
logger.info(f"Rebuilt {len(all_segments)} subtitle segments from original text")
logger.info(f"Generated {len(all_segments)} subtitle segments")
return {"segments": all_segments}
@@ -247,12 +278,13 @@ class WhisperService:
return result
async def transcribe(self, audio_path: str) -> str:
async def transcribe(self, audio_path: str, language: str | None = None) -> str:
"""
仅转录文本(用于提取文案)
Args:
audio_path: 音频/视频文件路径
language: 语言代码None 表示自动检测
Returns:
纯文本内容
@@ -266,7 +298,7 @@ class WhisperService:
# 转录 (无需字级时间戳)
segments_iter, _ = model.transcribe(
audio_path,
language="zh",
language=language,
word_timestamps=False,
vad_filter=True,
)

View File

@@ -71,3 +71,18 @@ CREATE TRIGGER users_updated_at
BEFORE UPDATE ON users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at();
-- 8. 订单表(支付宝付费)
CREATE TABLE IF NOT EXISTS orders (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID REFERENCES users(id) ON DELETE CASCADE,
out_trade_no TEXT UNIQUE NOT NULL,
amount DECIMAL(10, 2) NOT NULL DEFAULT 999.00,
status TEXT DEFAULT 'pending' CHECK (status IN ('pending', 'paid', 'failed')),
trade_no TEXT,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
paid_at TIMESTAMP WITH TIME ZONE
);
CREATE INDEX IF NOT EXISTS idx_orders_user_id ON orders(user_id);
CREATE INDEX IF NOT EXISTS idx_orders_out_trade_no ON orders(out_trade_no);

31
backend/package-lock.json generated Normal file
View File

@@ -0,0 +1,31 @@
{
"name": "backend",
"lockfileVersion": 3,
"requires": true,
"packages": {
"": {
"dependencies": {
"qrcode.react": "^4.2.0"
}
},
"node_modules/qrcode.react": {
"version": "4.2.0",
"resolved": "https://registry.npmjs.org/qrcode.react/-/qrcode.react-4.2.0.tgz",
"integrity": "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA==",
"license": "ISC",
"peerDependencies": {
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/react": {
"version": "19.2.4",
"resolved": "https://registry.npmjs.org/react/-/react-19.2.4.tgz",
"integrity": "sha512-9nfp2hYpCwOjAN+8TZFGhtWEwgvWHXqESH8qT89AT/lWklpLON22Lc8pEtnpsZz7VmawabSU0gCjnj8aC0euHQ==",
"license": "MIT",
"peer": true,
"engines": {
"node": ">=0.10.0"
}
}
}
}

5
backend/package.json Normal file
View File

@@ -0,0 +1,5 @@
{
"dependencies": {
"qrcode.react": "^4.2.0"
}
}

View File

@@ -29,6 +29,9 @@ python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
bcrypt==4.0.1
# 支付宝支付
python-alipay-sdk>=3.6.0
# 字幕对齐
faster-whisper>=1.0.0

View File

@@ -20,62 +20,81 @@ logger = logging.getLogger("Watchdog")
# 服务配置
SERVICES = [
{
"name": "vigent2-qwen-tts",
"url": "http://localhost:8009/health",
"name": "vigent2-cosyvoice",
"url": "http://localhost:8010/health",
"failures": 0,
"threshold": 3,
"threshold": 3, # 连续3次失败才重启3×15s ≈ 45秒容忍期
"timeout": 10.0,
"restart_cmd": ["pm2", "restart", "vigent2-qwen-tts"]
"restart_cmd": ["pm2", "restart", "vigent2-cosyvoice"],
"cooldown_until": 0, # 重启后的冷却截止时间戳
"cooldown_sec": 45, # 重启后等待45秒再开始检查
}
]
async def check_service(service):
"""检查单个服务健康状态"""
# 冷却期内跳过检查
now = time.time()
if now < service.get("cooldown_until", 0):
remaining = int(service["cooldown_until"] - now)
logger.debug(f"⏳ 服务 {service['name']} 冷却中,剩余 {remaining}s")
return True
try:
timeout = service.get("timeout", 10.0)
async with httpx.AsyncClient(timeout=timeout) as client:
response = await client.get(service["url"])
if response.status_code == 200:
# 成功
if service["failures"] > 0:
logger.info(f"✅ 服务 {service['name']} 已恢复正常")
service["failures"] = 0
return True
ready = True
try:
payload = response.json()
ready = bool(payload.get("ready", True))
except Exception:
payload = {}
if ready:
if service["failures"] > 0:
logger.info(f"✅ 服务 {service['name']} 已恢复正常")
service["failures"] = 0
return True
logger.warning(f"⚠️ 服务 {service['name']} ready=false健康检查未通过: {payload}")
else:
logger.warning(f"⚠️ 服务 {service['name']} 返回状态码 {response.status_code}")
except Exception as e:
logger.warning(f"⚠️ 无法连接服务 {service['name']}: {str(e)}")
# 失败处理
service["failures"] += 1
logger.warning(f"❌ 服务 {service['name']} 连续失败 {service['failures']}/{service['threshold']}")
if service["failures"] >= service['threshold']:
logger.error(f"🚨 服务 {service['name']} 已达到失败阈值,正在重启...")
try:
subprocess.run(service["restart_cmd"], check=True)
logger.info(f"♻️ 服务 {service['name']} 重启命令已发送")
# 重启后给予一段宽限期 (例如 60秒) 不检查,等待服务启动
service["failures"] = 0 # 重置计数
return "restarting"
service["failures"] = 0
# 设置冷却期,等待服务完成启动和模型加载
service["cooldown_until"] = time.time() + service.get("cooldown_sec", 120)
return "restarting"
except Exception as restart_error:
logger.error(f"💥 重启服务 {service['name']} 失败: {restart_error}")
return False
async def main():
logger.info("🛡️ ViGent2 服务看门狗 (Watchdog) 已启动")
# 启动时给所有服务一个初始冷却期,避免服务还没起来就被判定失败
for service in SERVICES:
service["cooldown_until"] = time.time() + 60
while True:
# 并发检查所有服务
for service in SERVICES:
result = await check_service(service)
if result == "restarting":
# 如果有服务重启,额外等待包含启动时间
pass
# 每 30 秒检查一次
await asyncio.sleep(30)
await check_service(service)
# 每 15 秒检查一次
await asyncio.sleep(15)
if __name__ == "__main__":
try:

View File

@@ -15,10 +15,12 @@
"axios": "^1.13.4",
"lucide-react": "^0.563.0",
"next": "16.1.1",
"qrcode.react": "^4.2.0",
"react": "19.2.3",
"react-dom": "19.2.3",
"sonner": "^2.0.7",
"swr": "^2.3.8"
"swr": "^2.3.8",
"wavesurfer.js": "^7.12.1"
},
"devDependencies": {
"@tailwindcss/postcss": "^4",
@@ -5617,6 +5619,15 @@
"node": ">=6"
}
},
"node_modules/qrcode.react": {
"version": "4.2.0",
"resolved": "https://registry.npmjs.org/qrcode.react/-/qrcode.react-4.2.0.tgz",
"integrity": "sha512-QpgqWi8rD9DsS9EP3z7BT+5lY5SFhsqGjpgW5DY/i3mK4M9DTBNz3ErMi8BWYEfI3L0d8GIbGmcdFAS1uIRGjA==",
"license": "ISC",
"peerDependencies": {
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/queue-microtask": {
"version": "1.2.3",
"resolved": "https://registry.npmjs.org/queue-microtask/-/queue-microtask-1.2.3.tgz",
@@ -6667,6 +6678,12 @@
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/wavesurfer.js": {
"version": "7.12.1",
"resolved": "https://registry.npmjs.org/wavesurfer.js/-/wavesurfer.js-7.12.1.tgz",
"integrity": "sha512-NswPjVHxk0Q1F/VMRemCPUzSojjuHHisQrBqQiRXg7MVbe3f5vQ6r0rTTXA/a/neC/4hnOEC4YpXca4LpH0SUg==",
"license": "BSD-3-Clause"
},
"node_modules/which": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",

View File

@@ -16,10 +16,12 @@
"axios": "^1.13.4",
"lucide-react": "^0.563.0",
"next": "16.1.1",
"qrcode.react": "^4.2.0",
"react": "19.2.3",
"react-dom": "19.2.3",
"sonner": "^2.0.7",
"swr": "^2.3.8"
"swr": "^2.3.8",
"wavesurfer.js": "^7.12.1"
},
"devDependencies": {
"@tailwindcss/postcss": "^4",

View File

@@ -46,7 +46,6 @@ export default function RootLayout({
<Toaster
position="top-center"
richColors
closeButton
toastOptions={{
duration: 3000,
className: "text-sm",

View File

@@ -25,7 +25,10 @@ export default function LoginPage() {
try {
const result = await login(phone, password);
if (result.success) {
if (result.paymentToken) {
sessionStorage.setItem('payment_token', result.paymentToken);
router.push('/pay');
} else if (result.success) {
router.push('/');
} else {
setError(result.message || '登录失败');

View File

@@ -0,0 +1,160 @@
'use client';
import { Suspense, useState, useEffect, useRef } from 'react';
import { useRouter, useSearchParams } from 'next/navigation';
import api from '@/shared/api/axios';
type PageStatus = 'loading' | 'redirecting' | 'checking' | 'success' | 'error';
function PayContent() {
const router = useRouter();
const searchParams = useSearchParams();
const [status, setStatus] = useState<PageStatus>('loading');
const [errorMsg, setErrorMsg] = useState('');
const pollRef = useRef<ReturnType<typeof setInterval> | null>(null);
useEffect(() => {
const outTradeNo = searchParams.get('out_trade_no');
if (outTradeNo) {
setStatus('checking');
startPolling(outTradeNo);
return;
}
const token = sessionStorage.getItem('payment_token');
if (!token) {
router.replace('/login');
return;
}
createOrder(token);
return () => {
if (pollRef.current) clearInterval(pollRef.current);
};
}, []);
const createOrder = async (token: string) => {
try {
const { data } = await api.post('/api/payment/create-order', { payment_token: token });
const { pay_url } = data.data;
setStatus('redirecting');
window.location.href = pay_url;
} catch (err: any) {
setStatus('error');
setErrorMsg(err.response?.data?.message || '创建订单失败,请重新登录');
}
};
const startPolling = (tradeNo: string) => {
checkStatus(tradeNo);
pollRef.current = setInterval(() => checkStatus(tradeNo), 3000);
};
const checkStatus = async (tradeNo: string) => {
try {
const { data } = await api.get(`/api/payment/status/${tradeNo}`);
if (data.data.status === 'paid') {
if (pollRef.current) clearInterval(pollRef.current);
setStatus('success');
sessionStorage.removeItem('payment_token');
setTimeout(() => router.replace('/login'), 3000);
}
} catch {
// ignore polling errors
}
};
return (
<div className="w-full max-w-md p-8 bg-white/10 backdrop-blur-lg rounded-2xl shadow-2xl border border-white/20">
{(status === 'loading' || status === 'redirecting') && (
<div className="text-center">
<div className="mb-6">
<svg className="animate-spin h-12 w-12 mx-auto text-purple-400" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
</div>
<p className="text-gray-300">
{status === 'loading' ? '正在创建订单...' : '正在跳转到支付宝...'}
</p>
</div>
)}
{status === 'checking' && (
<div className="text-center">
<h1 className="text-2xl font-bold text-white mb-6"></h1>
<div className="flex items-center justify-center gap-2 text-purple-300 mb-4">
<svg className="animate-spin h-5 w-5" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
...
</div>
<p className="text-gray-400 text-sm"></p>
</div>
)}
{status === 'success' && (
<div className="text-center">
<div className="mb-6">
<svg className="w-16 h-16 mx-auto text-green-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
</div>
<h2 className="text-2xl font-bold text-white mb-4"></h2>
<p className="text-gray-300 mb-2">...</p>
<p className="text-gray-500 text-sm">使</p>
</div>
)}
{status === 'error' && (
<div className="text-center">
<div className="mb-6">
<svg className="w-16 h-16 mx-auto text-red-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M12 8v4m0 4h.01M21 12a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
</div>
<h2 className="text-2xl font-bold text-white mb-4"></h2>
<p className="text-red-300 mb-6">{errorMsg}</p>
<button
onClick={() => router.replace('/login')}
className="py-3 px-6 bg-gradient-to-r from-purple-600 to-pink-600 text-white font-semibold rounded-lg"
>
</button>
</div>
)}
{status === 'checking' && (
<div className="mt-6 text-center">
<button
onClick={() => {
if (pollRef.current) clearInterval(pollRef.current);
router.replace('/login');
}}
className="text-purple-300 hover:text-purple-200 text-sm"
>
</button>
</div>
)}
</div>
);
}
export default function PayPage() {
return (
<div className="min-h-dvh flex items-center justify-center">
<Suspense fallback={
<div className="w-full max-w-md p-8 bg-white/10 backdrop-blur-lg rounded-2xl shadow-2xl border border-white/20 text-center">
<svg className="animate-spin h-12 w-12 mx-auto text-purple-400" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
</div>
}>
<PayContent />
</Suspense>
</div>
);
}

View File

@@ -61,7 +61,7 @@ export default function RegisterPage() {
</div>
<h2 className="text-2xl font-bold text-white mb-4"></h2>
<p className="text-gray-300 mb-6">
</p>
<a
href="/login"

View File

@@ -0,0 +1,193 @@
import { useCallback, useEffect, useRef, useState } from "react";
import api from "@/shared/api/axios";
import { ApiResponse, unwrap } from "@/shared/api/types";
import { toast } from "sonner";
export interface GeneratedAudio {
id: string;
name: string;
path: string;
duration_sec: number;
text: string;
tts_mode: string;
language: string;
created_at: number;
}
interface AudioTask {
status: string;
progress?: number;
message?: string;
output?: GeneratedAudio & { audio_id: string };
}
interface UseGeneratedAudiosOptions {
selectedAudioId: string | null;
setSelectedAudioId: React.Dispatch<React.SetStateAction<string | null>>;
}
export const useGeneratedAudios = ({
selectedAudioId,
setSelectedAudioId,
}: UseGeneratedAudiosOptions) => {
const [generatedAudios, setGeneratedAudios] = useState<GeneratedAudio[]>([]);
const [selectedAudio, setSelectedAudio] = useState<GeneratedAudio | null>(null);
const [isGeneratingAudio, setIsGeneratingAudio] = useState(false);
const [audioTaskId, setAudioTaskId] = useState<string | null>(null);
const [audioTask, setAudioTask] = useState<AudioTask | null>(null);
const pollRef = useRef<NodeJS.Timeout | null>(null);
const fetchGeneratedAudios = useCallback(async (selectId?: string) => {
try {
const { data: res } = await api.get<ApiResponse<{ items: GeneratedAudio[] }>>(
"/api/generated-audios"
);
const payload = unwrap(res);
const items: GeneratedAudio[] = payload.items || [];
setGeneratedAudios(items);
if (selectId && items.length > 0) {
if (selectId === "__latest__") {
setSelectedAudioId(items[0].id);
setSelectedAudio(items[0]);
} else {
const found = items.find((a) => a.id === selectId);
if (found) {
setSelectedAudioId(found.id);
setSelectedAudio(found);
}
}
}
} catch (error) {
console.error("获取配音列表失败:", error);
}
}, [setSelectedAudioId]);
// Sync selectedAudio when selectedAudioId changes externally (e.g. from persistence)
useEffect(() => {
if (!selectedAudioId || generatedAudios.length === 0) return;
const found = generatedAudios.find((a) => a.id === selectedAudioId);
if (found) {
setSelectedAudio(found);
}
}, [selectedAudioId, generatedAudios]);
const stopPolling = useCallback(() => {
if (pollRef.current) {
clearInterval(pollRef.current);
pollRef.current = null;
}
}, []);
const startPolling = useCallback((taskId: string) => {
stopPolling();
pollRef.current = setInterval(async () => {
try {
const { data: res } = await api.get<ApiResponse<AudioTask>>(
`/api/generated-audios/tasks/${taskId}`
);
const task = unwrap(res);
setAudioTask(task);
if (task.status === "completed") {
stopPolling();
setIsGeneratingAudio(false);
setAudioTaskId(null);
// Refresh list and select the new audio
await fetchGeneratedAudios("__latest__");
toast.success(task.message || "配音生成完成");
} else if (task.status === "failed") {
stopPolling();
setIsGeneratingAudio(false);
setAudioTaskId(null);
toast.error(task.message || "配音生成失败");
} else if (task.status === "not_found") {
stopPolling();
setIsGeneratingAudio(false);
setAudioTaskId(null);
setAudioTask(null);
toast.error("任务已丢失(服务可能已重启),请重新生成");
}
} catch {
// Network error, keep polling
}
}, 1000);
}, [stopPolling, fetchGeneratedAudios]);
// Cleanup on unmount
useEffect(() => {
return () => stopPolling();
}, [stopPolling]);
const generateAudio = useCallback(async (params: {
text: string;
tts_mode: string;
voice?: string;
ref_audio_id?: string;
ref_text?: string;
language: string;
speed?: number;
}) => {
setIsGeneratingAudio(true);
setAudioTask({ status: "pending", progress: 0, message: "正在提交..." });
try {
const { data: res } = await api.post<ApiResponse<{ task_id: string }>>(
"/api/generated-audios/generate",
params
);
const { task_id } = unwrap(res);
setAudioTaskId(task_id);
startPolling(task_id);
} catch (err: unknown) {
setIsGeneratingAudio(false);
setAudioTask(null);
const axiosErr = err as { response?: { data?: { message?: string } }; message?: string };
const errorMsg = axiosErr.response?.data?.message || axiosErr.message || String(err);
toast.error(`配音生成失败: ${errorMsg}`);
}
}, [startPolling]);
const deleteAudio = useCallback(async (audioId: string) => {
if (!confirm("确定要删除这个配音吗?")) return;
try {
await api.delete(`/api/generated-audios/${encodeURIComponent(audioId)}`);
if (selectedAudioId === audioId) {
setSelectedAudioId(null);
setSelectedAudio(null);
}
fetchGeneratedAudios();
} catch (error) {
toast.error("删除失败: " + error);
}
}, [fetchGeneratedAudios, selectedAudioId, setSelectedAudioId]);
const renameAudio = useCallback(async (audioId: string, newName: string) => {
try {
await api.put(`/api/generated-audios/${encodeURIComponent(audioId)}`, {
new_name: newName,
});
fetchGeneratedAudios();
} catch (err: unknown) {
toast.error("重命名失败: " + String(err));
}
}, [fetchGeneratedAudios]);
const selectAudio = useCallback((audio: GeneratedAudio) => {
setSelectedAudioId(audio.id);
setSelectedAudio(audio);
}, [setSelectedAudioId]);
return {
generatedAudios,
selectedAudio,
selectedAudioId,
isGeneratingAudio,
audioTask,
fetchGeneratedAudios,
generateAudio,
deleteAudio,
renameAudio,
selectAudio,
};
};

View File

@@ -9,7 +9,7 @@ import {
resolveBgmUrl,
resolveMediaUrl,
} from "@/shared/lib/media";
import { clampTitle } from "@/shared/lib/title";
import { clampTitle, clampSecondaryTitle, SECONDARY_TITLE_MAX_LENGTH } from "@/shared/lib/title";
import { useTitleInput } from "@/shared/hooks/useTitleInput";
import { useAuth } from "@/shared/contexts/AuthContext";
import { useTask } from "@/shared/contexts/TaskContext";
@@ -18,11 +18,14 @@ import { usePublishPrefetch } from "@/shared/hooks/usePublishPrefetch";
import { PublishAccount } from "@/shared/types/publish";
import { useBgm } from "@/features/home/model/useBgm";
import { useGeneratedVideos } from "@/features/home/model/useGeneratedVideos";
import { useGeneratedAudios } from "@/features/home/model/useGeneratedAudios";
import { useHomePersistence } from "@/features/home/model/useHomePersistence";
import { useMaterials } from "@/features/home/model/useMaterials";
import { useMediaPlayers } from "@/features/home/model/useMediaPlayers";
import { useRefAudios } from "@/features/home/model/useRefAudios";
import { useTitleSubtitleStyles } from "@/features/home/model/useTitleSubtitleStyles";
import { useTimelineEditor } from "@/features/home/model/useTimelineEditor";
import { useSavedScripts } from "@/features/home/model/useSavedScripts";
import { ApiResponse, unwrap } from "@/shared/api/types";
const VOICES: Record<string, { id: string; name: string }[]> = {
@@ -84,10 +87,9 @@ const LANG_TO_LOCALE: Record<string, string> = {
"Português": "pt-BR",
};
const DEFAULT_SHORT_TITLE_DURATION = 4;
const FIXED_REF_TEXT =
"其实生活中有许多美好的瞬间,比如清晨的阳光,或者一杯温热的清茶。希望这次生成的音色能够自然、流畅,完美还原出我最真实的声音状态。";
const scrollContainerToItem = (container: HTMLDivElement, item: HTMLDivElement) => {
const containerRect = container.getBoundingClientRect();
@@ -149,10 +151,19 @@ export const useHomeController = () => {
const [subtitleSizeLocked, setSubtitleSizeLocked] = useState<boolean>(false);
const [titleSizeLocked, setTitleSizeLocked] = useState<boolean>(false);
const [titleTopMargin, setTitleTopMargin] = useState<number>(62);
const [titleDisplayMode, setTitleDisplayMode] = useState<"short" | "persistent">("short");
const [subtitleBottomMargin, setSubtitleBottomMargin] = useState<number>(80);
const [outputAspectRatio, setOutputAspectRatio] = useState<"9:16" | "16:9">("9:16");
const [showStylePreview, setShowStylePreview] = useState<boolean>(false);
const [materialDimensions, setMaterialDimensions] = useState<{ width: number; height: number } | null>(null);
// 副标题相关状态
const [videoSecondaryTitle, setVideoSecondaryTitle] = useState<string>("");
const [selectedSecondaryTitleStyleId, setSelectedSecondaryTitleStyleId] = useState<string>("");
const [secondaryTitleFontSize, setSecondaryTitleFontSize] = useState<number>(48);
const [secondaryTitleTopMargin, setSecondaryTitleTopMargin] = useState<number>(12);
const [secondaryTitleSizeLocked, setSecondaryTitleSizeLocked] = useState<boolean>(false);
// 背景音乐相关状态
const [selectedBgmId, setSelectedBgmId] = useState<string>("");
@@ -162,7 +173,17 @@ export const useHomeController = () => {
// 声音克隆相关状态
const [ttsMode, setTtsMode] = useState<"edgetts" | "voiceclone">("edgetts");
const [selectedRefAudio, setSelectedRefAudio] = useState<RefAudio | null>(null);
const [refText, setRefText] = useState(FIXED_REF_TEXT);
const [refText, setRefText] = useState("");
// 预生成配音选中 ID
const [selectedAudioId, setSelectedAudioId] = useState<string | null>(null);
// 语速控制
const [speed, setSpeed] = useState<number>(1.0);
// ClipTrimmer 模态框状态
const [clipTrimmerOpen, setClipTrimmerOpen] = useState(false);
const [clipTrimmerSegmentId, setClipTrimmerSegmentId] = useState<string | null>(null);
// 音频预览与重命名状态
const [editingAudioId, setEditingAudioId] = useState<string | null>(null);
@@ -276,7 +297,6 @@ export const useHomeController = () => {
setUploadError,
fetchMaterials,
toggleMaterial,
reorderMaterials,
deleteMaterial,
handleUpload,
} = useMaterials({
@@ -304,8 +324,9 @@ export const useHomeController = () => {
fetchRefAudios,
uploadRefAudio,
deleteRefAudio,
retranscribeRefAudio,
retranscribingId,
} = useRefAudios({
fixedRefText: FIXED_REF_TEXT,
selectedRefAudio,
setSelectedRefAudio,
setRefText,
@@ -347,6 +368,33 @@ export const useHomeController = () => {
resolveMediaUrl,
});
const {
generatedAudios,
selectedAudio,
isGeneratingAudio,
audioTask,
fetchGeneratedAudios,
generateAudio,
deleteAudio,
renameAudio,
selectAudio,
} = useGeneratedAudios({
selectedAudioId,
setSelectedAudioId,
});
const {
segments: timelineSegments,
reorderSegments,
setSourceRange,
toCustomAssignments,
} = useTimelineEditor({
audioDuration: selectedAudio?.duration_sec ?? 0,
materials,
selectedMaterials,
storageKey,
});
useEffect(() => {
if (isAuthLoading || !userId) return;
let active = true;
@@ -389,6 +437,8 @@ export const useHomeController = () => {
setText,
videoTitle,
setVideoTitle,
videoSecondaryTitle,
setVideoSecondaryTitle,
ttsMode,
setTtsMode,
voice,
@@ -401,16 +451,27 @@ export const useHomeController = () => {
setSelectedSubtitleStyleId,
selectedTitleStyleId,
setSelectedTitleStyleId,
selectedSecondaryTitleStyleId,
setSelectedSecondaryTitleStyleId,
subtitleFontSize,
setSubtitleFontSize,
titleFontSize,
setTitleFontSize,
secondaryTitleFontSize,
setSecondaryTitleFontSize,
setSubtitleSizeLocked,
setTitleSizeLocked,
setSecondaryTitleSizeLocked,
titleTopMargin,
setTitleTopMargin,
secondaryTitleTopMargin,
setSecondaryTitleTopMargin,
titleDisplayMode,
setTitleDisplayMode,
subtitleBottomMargin,
setSubtitleBottomMargin,
outputAspectRatio,
setOutputAspectRatio,
selectedBgmId,
setSelectedBgmId,
bgmVolume,
@@ -420,8 +481,20 @@ export const useHomeController = () => {
selectedVideoId,
setSelectedVideoId,
selectedRefAudio,
selectedAudioId,
setSelectedAudioId,
speed,
setSpeed,
});
const { savedScripts, saveScript, deleteScript: deleteSavedScript } = useSavedScripts(storageKey);
const handleSaveScript = () => {
if (!text.trim()) return;
saveScript(text);
toast.success("文案已保存");
};
const syncTitleToPublish = (value: string) => {
if (typeof window !== "undefined") {
localStorage.setItem(`vigent_${storageKey}_publish_title`, value);
@@ -434,6 +507,12 @@ export const useHomeController = () => {
onCommit: syncTitleToPublish,
});
const secondaryTitleInput = useTitleInput({
value: videoSecondaryTitle,
onChange: setVideoSecondaryTitle,
maxLength: SECONDARY_TITLE_MAX_LENGTH,
});
// 加载素材列表和历史视频
useEffect(() => {
if (isAuthLoading) return;
@@ -441,6 +520,7 @@ export const useHomeController = () => {
fetchMaterials(),
fetchGeneratedVideos(),
fetchRefAudios(),
fetchGeneratedAudios(),
refreshSubtitleStyles(),
refreshTitleStyles(),
fetchBgmList(),
@@ -475,7 +555,6 @@ export const useHomeController = () => {
let isActive = true;
const video = document.createElement("video");
video.crossOrigin = "anonymous";
video.preload = "metadata";
video.src = url;
video.load();
@@ -525,6 +604,16 @@ export const useHomeController = () => {
}
}, [titleStyles, selectedTitleStyleId, titleSizeLocked]);
useEffect(() => {
if (secondaryTitleSizeLocked || titleStyles.length === 0) return;
const active = titleStyles.find((s) => s.id === selectedSecondaryTitleStyleId)
|| titleStyles.find((s) => s.is_default)
|| titleStyles[0];
if (active?.font_size) {
setSecondaryTitleFontSize(active.font_size);
}
}, [titleStyles, selectedSecondaryTitleStyleId, secondaryTitleSizeLocked]);
// 移除重复的 BGM 持久化恢复逻辑 (已统一移动到 useHomePersistence 中)
// useEffect(() => { ... })
@@ -537,14 +626,22 @@ export const useHomeController = () => {
}
}, [selectedBgmId, bgmList]);
// 素材列表滚动:跳过首次恢复,仅用户主动操作时滚动
const materialScrollReady = useRef(false);
useEffect(() => {
const firstSelected = selectedMaterials[0];
if (!firstSelected) return;
if (!materialScrollReady.current) {
// 首次有选中素材时标记就绪,但不滚动(避免刷新后整页跳动)
materialScrollReady.current = true;
return;
}
const target = materialItemRefs.current[firstSelected];
if (target) {
target.scrollIntoView({ block: "nearest", behavior: "smooth" });
}
}, [selectedMaterials, materials]);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [selectedMaterials.length]);
// 【修复】历史视频默认选中逻辑
// 当持久化恢复完成,且列表加载完毕,如果没选中任何视频,默认选中第一个
@@ -554,7 +651,7 @@ export const useHomeController = () => {
setSelectedVideoId(firstId);
setGeneratedVideo(resolveMediaUrl(generatedVideos[0].path));
}
}, [isRestored, generatedVideos, selectedVideoId, setSelectedVideoId, setGeneratedVideo, resolveMediaUrl]);
}, [isRestored, generatedVideos, selectedVideoId, setSelectedVideoId, setGeneratedVideo]);
// 【修复】BGM 默认选中逻辑
useEffect(() => {
@@ -563,8 +660,14 @@ export const useHomeController = () => {
}
}, [isRestored, bgmList, selectedBgmId, enableBgm, setSelectedBgmId]);
const videoScrollReady = useRef(false);
useEffect(() => {
if (!selectedVideoId) return;
if (!videoScrollReady.current) {
videoScrollReady.current = true;
return;
}
const target = videoItemRefs.current[selectedVideoId];
if (target) {
target.scrollIntoView({ block: "nearest", behavior: "smooth" });
@@ -670,7 +773,7 @@ export const useHomeController = () => {
setIsGeneratingMeta(true);
try {
const { data: res } = await api.post<ApiResponse<{ title?: string; tags?: string[] }>>(
const { data: res } = await api.post<ApiResponse<{ title?: string; secondary_title?: string; tags?: string[] }>>(
"/api/ai/generate-meta",
{ text: text.trim() }
);
@@ -680,6 +783,10 @@ export const useHomeController = () => {
const nextTitle = clampTitle(payload.title || "");
titleInput.commitValue(nextTitle);
// 更新副标题
const nextSecondaryTitle = clampSecondaryTitle(payload.secondary_title || "");
secondaryTitleInput.commitValue(nextSecondaryTitle);
// 同步到发布页 localStorage
localStorage.setItem(`vigent_${storageKey}_publish_tags`, JSON.stringify(payload.tags || []));
} catch (err: unknown) {
@@ -741,6 +848,29 @@ export const useHomeController = () => {
}
};
// 生成配音
const handleGenerateAudio = async () => {
if (!text.trim()) {
toast.error("请先输入文案");
return;
}
if (ttsMode === "voiceclone" && !selectedRefAudio) {
toast.error("请选择参考音频");
return;
}
const params = {
text: text.trim(),
tts_mode: ttsMode,
voice: ttsMode === "edgetts" ? voice : undefined,
ref_audio_id: ttsMode === "voiceclone" ? selectedRefAudio!.id : undefined,
ref_text: ttsMode === "voiceclone" ? refText : undefined,
language: textLang,
speed: ttsMode === "voiceclone" ? speed : undefined,
};
await generateAudio(params);
};
// 生成视频
const handleGenerate = async () => {
if (selectedMaterials.length === 0 || !text.trim()) {
@@ -748,12 +878,9 @@ export const useHomeController = () => {
return;
}
// 声音克隆模式校验
if (ttsMode === "voiceclone") {
if (!selectedRefAudio) {
toast.error("请选择或上传参考音频");
return;
}
if (!selectedAudio) {
toast.error("请先生成并选中配音");
return;
}
if (enableBgm && !selectedBgmId) {
@@ -771,20 +898,68 @@ export const useHomeController = () => {
return;
}
// 构建请求参数
// 构建请求参数 - 使用预生成配音
const payload: Record<string, unknown> = {
material_path: firstMaterialObj.path,
text: text,
tts_mode: ttsMode,
text: selectedAudio.text || text,
generated_audio_id: selectedAudio.id,
language: selectedAudio.language || textLang,
title: videoTitle.trim() || undefined,
enable_subtitles: true,
output_aspect_ratio: outputAspectRatio,
};
// 多素材
if (selectedMaterials.length > 1) {
payload.material_paths = selectedMaterials
const timelineOrderedIds = timelineSegments
.map((seg) => seg.materialId)
.filter((id, index, arr) => arr.indexOf(id) === index);
const orderedMaterialIds = [
...timelineOrderedIds.filter((id) => selectedMaterials.includes(id)),
...selectedMaterials.filter((id) => !timelineOrderedIds.includes(id)),
];
const materialPaths = orderedMaterialIds
.map((id) => materials.find((x) => x.id === id)?.path)
.filter((path): path is string => !!path);
if (materialPaths.length === 0) {
toast.error("多素材解析失败,请刷新素材后重试");
return;
}
payload.material_paths = materialPaths;
payload.material_path = materialPaths[0];
// 发送自定义时间轴分配
const assignments = toCustomAssignments();
if (assignments.length > 0) {
const assignmentPaths = assignments
.map((a) => a.material_path)
.filter((path): path is string => !!path);
if (assignmentPaths.length === assignments.length) {
// 以时间轴可见段为准:超出时间轴的素材不会参与本次生成
payload.material_paths = assignmentPaths;
payload.material_path = assignmentPaths[0];
}
payload.custom_assignments = assignments;
} else {
console.warn(
"[Timeline] custom_assignments 为空,回退后端自动分配",
{ materials: materialPaths.length }
);
}
}
// 单素材 + 截取范围
const singleSeg = timelineSegments[0];
if (
selectedMaterials.length === 1
&& singleSeg
&& (singleSeg.sourceStart > 0 || singleSeg.sourceEnd > 0)
) {
payload.custom_assignments = toCustomAssignments();
}
if (selectedSubtitleStyleId) {
@@ -804,9 +979,24 @@ export const useHomeController = () => {
}
if (videoTitle.trim()) {
payload.title_display_mode = titleDisplayMode;
if (titleDisplayMode === "short") {
payload.title_duration = DEFAULT_SHORT_TITLE_DURATION;
}
payload.title_top_margin = Math.round(titleTopMargin);
}
if (videoSecondaryTitle.trim()) {
payload.secondary_title = videoSecondaryTitle.trim();
if (selectedSecondaryTitleStyleId) {
payload.secondary_title_style_id = selectedSecondaryTitleStyleId;
}
if (secondaryTitleFontSize) {
payload.secondary_title_font_size = Math.round(secondaryTitleFontSize);
}
payload.secondary_title_top_margin = Math.round(secondaryTitleTopMargin);
}
payload.subtitle_bottom_margin = Math.round(subtitleBottomMargin);
if (enableBgm && selectedBgmId) {
@@ -814,15 +1004,6 @@ export const useHomeController = () => {
payload.bgm_volume = bgmVolume;
}
payload.language = textLang;
if (ttsMode === "edgetts") {
payload.voice = voice;
} else {
payload.ref_audio_id = selectedRefAudio!.id;
payload.ref_text = refText;
}
// 创建生成任务
const { data: res } = await api.post<ApiResponse<{ task_id: string }>>(
"/api/videos/generate",
@@ -885,7 +1066,6 @@ export const useHomeController = () => {
handleUpload,
selectedMaterials,
toggleMaterial,
reorderMaterials,
handlePreviewMaterial,
editingMaterialId,
editMaterialName,
@@ -903,6 +1083,9 @@ export const useHomeController = () => {
isTranslating,
originalText,
handleRestoreOriginal,
savedScripts,
handleSaveScript,
deleteSavedScript,
showStylePreview,
setShowStylePreview,
videoTitle,
@@ -913,6 +1096,15 @@ export const useHomeController = () => {
titleFontSize,
setTitleFontSize,
setTitleSizeLocked,
videoSecondaryTitle,
secondaryTitleInput,
selectedSecondaryTitleStyleId,
setSelectedSecondaryTitleStyleId,
secondaryTitleFontSize,
setSecondaryTitleFontSize,
setSecondaryTitleSizeLocked,
secondaryTitleTopMargin,
setSecondaryTitleTopMargin,
subtitleStyles,
selectedSubtitleStyleId,
setSelectedSubtitleStyleId,
@@ -921,8 +1113,12 @@ export const useHomeController = () => {
setSubtitleSizeLocked,
titleTopMargin,
setTitleTopMargin,
titleDisplayMode,
setTitleDisplayMode,
subtitleBottomMargin,
setSubtitleBottomMargin,
outputAspectRatio,
setOutputAspectRatio,
resolveAssetUrl,
getFontFormat,
buildTextShadow,
@@ -950,6 +1146,8 @@ export const useHomeController = () => {
saveEditing,
cancelEditing,
deleteRefAudio,
retranscribeRefAudio,
retranscribingId,
recordedBlob,
isRecording,
recordingTime,
@@ -957,7 +1155,6 @@ export const useHomeController = () => {
stopRecording,
useRecording,
formatRecordingTime,
fixedRefText: FIXED_REF_TEXT,
bgmList,
bgmLoading,
bgmError,
@@ -983,5 +1180,24 @@ export const useHomeController = () => {
fetchGeneratedVideos,
registerVideoRef,
formatDate,
generatedAudios,
selectedAudio,
selectedAudioId,
isGeneratingAudio,
audioTask,
fetchGeneratedAudios,
handleGenerateAudio,
deleteAudio,
renameAudio,
selectAudio,
speed,
setSpeed,
timelineSegments,
reorderSegments,
setSourceRange,
clipTrimmerOpen,
setClipTrimmerOpen,
clipTrimmerSegmentId,
setClipTrimmerSegmentId,
};
};

View File

@@ -1,5 +1,5 @@
import { useEffect, useState } from "react";
import { clampTitle } from "@/shared/lib/title";
import { clampTitle, clampSecondaryTitle } from "@/shared/lib/title";
interface RefAudio {
id: string;
@@ -17,6 +17,8 @@ interface UseHomePersistenceOptions {
setText: React.Dispatch<React.SetStateAction<string>>;
videoTitle: string;
setVideoTitle: React.Dispatch<React.SetStateAction<string>>;
videoSecondaryTitle: string;
setVideoSecondaryTitle: React.Dispatch<React.SetStateAction<string>>;
ttsMode: 'edgetts' | 'voiceclone';
setTtsMode: React.Dispatch<React.SetStateAction<'edgetts' | 'voiceclone'>>;
voice: string;
@@ -29,16 +31,27 @@ interface UseHomePersistenceOptions {
setSelectedSubtitleStyleId: React.Dispatch<React.SetStateAction<string>>;
selectedTitleStyleId: string;
setSelectedTitleStyleId: React.Dispatch<React.SetStateAction<string>>;
selectedSecondaryTitleStyleId: string;
setSelectedSecondaryTitleStyleId: React.Dispatch<React.SetStateAction<string>>;
subtitleFontSize: number;
setSubtitleFontSize: React.Dispatch<React.SetStateAction<number>>;
titleFontSize: number;
setTitleFontSize: React.Dispatch<React.SetStateAction<number>>;
secondaryTitleFontSize: number;
setSecondaryTitleFontSize: React.Dispatch<React.SetStateAction<number>>;
setSubtitleSizeLocked: React.Dispatch<React.SetStateAction<boolean>>;
setTitleSizeLocked: React.Dispatch<React.SetStateAction<boolean>>;
setSecondaryTitleSizeLocked: React.Dispatch<React.SetStateAction<boolean>>;
titleTopMargin: number;
setTitleTopMargin: React.Dispatch<React.SetStateAction<number>>;
secondaryTitleTopMargin: number;
setSecondaryTitleTopMargin: React.Dispatch<React.SetStateAction<number>>;
titleDisplayMode: 'short' | 'persistent';
setTitleDisplayMode: React.Dispatch<React.SetStateAction<'short' | 'persistent'>>;
subtitleBottomMargin: number;
setSubtitleBottomMargin: React.Dispatch<React.SetStateAction<number>>;
outputAspectRatio: '9:16' | '16:9';
setOutputAspectRatio: React.Dispatch<React.SetStateAction<'9:16' | '16:9'>>;
selectedBgmId: string;
setSelectedBgmId: React.Dispatch<React.SetStateAction<string>>;
bgmVolume: number;
@@ -48,6 +61,10 @@ interface UseHomePersistenceOptions {
selectedVideoId: string | null;
setSelectedVideoId: React.Dispatch<React.SetStateAction<string | null>>;
selectedRefAudio: RefAudio | null;
selectedAudioId: string | null;
setSelectedAudioId: React.Dispatch<React.SetStateAction<string | null>>;
speed: number;
setSpeed: React.Dispatch<React.SetStateAction<number>>;
}
export const useHomePersistence = ({
@@ -57,6 +74,8 @@ export const useHomePersistence = ({
setText,
videoTitle,
setVideoTitle,
videoSecondaryTitle,
setVideoSecondaryTitle,
ttsMode,
setTtsMode,
voice,
@@ -69,16 +88,27 @@ export const useHomePersistence = ({
setSelectedSubtitleStyleId,
selectedTitleStyleId,
setSelectedTitleStyleId,
selectedSecondaryTitleStyleId,
setSelectedSecondaryTitleStyleId,
subtitleFontSize,
setSubtitleFontSize,
titleFontSize,
setTitleFontSize,
secondaryTitleFontSize,
setSecondaryTitleFontSize,
setSubtitleSizeLocked,
setTitleSizeLocked,
setSecondaryTitleSizeLocked,
titleTopMargin,
setTitleTopMargin,
secondaryTitleTopMargin,
setSecondaryTitleTopMargin,
titleDisplayMode,
setTitleDisplayMode,
subtitleBottomMargin,
setSubtitleBottomMargin,
outputAspectRatio,
setOutputAspectRatio,
selectedBgmId,
setSelectedBgmId,
bgmVolume,
@@ -88,6 +118,10 @@ export const useHomePersistence = ({
selectedVideoId,
setSelectedVideoId,
selectedRefAudio,
selectedAudioId,
setSelectedAudioId,
speed,
setSpeed,
}: UseHomePersistenceOptions) => {
const [isRestored, setIsRestored] = useState(false);
@@ -96,23 +130,32 @@ export const useHomePersistence = ({
const savedText = localStorage.getItem(`vigent_${storageKey}_text`);
const savedTitle = localStorage.getItem(`vigent_${storageKey}_title`);
const savedSecondaryTitle = localStorage.getItem(`vigent_${storageKey}_secondaryTitle`);
const savedTtsMode = localStorage.getItem(`vigent_${storageKey}_ttsMode`);
const savedVoice = localStorage.getItem(`vigent_${storageKey}_voice`);
const savedTextLang = localStorage.getItem(`vigent_${storageKey}_textLang`);
const savedMaterial = localStorage.getItem(`vigent_${storageKey}_material`);
const savedSubtitleStyle = localStorage.getItem(`vigent_${storageKey}_subtitleStyle`);
const savedTitleStyle = localStorage.getItem(`vigent_${storageKey}_titleStyle`);
const savedSecondaryTitleStyle = localStorage.getItem(`vigent_${storageKey}_secondaryTitleStyle`);
const savedSubtitleFontSize = localStorage.getItem(`vigent_${storageKey}_subtitleFontSize`);
const savedTitleFontSize = localStorage.getItem(`vigent_${storageKey}_titleFontSize`);
const savedSecondaryTitleFontSize = localStorage.getItem(`vigent_${storageKey}_secondaryTitleFontSize`);
const savedBgmId = localStorage.getItem(`vigent_${storageKey}_bgmId`);
const savedSelectedVideoId = localStorage.getItem(`vigent_${storageKey}_selectedVideoId`);
const savedSelectedAudioId = localStorage.getItem(`vigent_${storageKey}_selectedAudioId`);
const savedBgmVolume = localStorage.getItem(`vigent_${storageKey}_bgmVolume`);
const savedEnableBgm = localStorage.getItem(`vigent_${storageKey}_enableBgm`);
const savedTitleTopMargin = localStorage.getItem(`vigent_${storageKey}_titleTopMargin`);
const savedSecondaryTitleTopMargin = localStorage.getItem(`vigent_${storageKey}_secondaryTitleTopMargin`);
const savedTitleDisplayMode = localStorage.getItem(`vigent_${storageKey}_titleDisplayMode`);
const savedSubtitleBottomMargin = localStorage.getItem(`vigent_${storageKey}_subtitleBottomMargin`);
const savedOutputAspectRatio = localStorage.getItem(`vigent_${storageKey}_outputAspectRatio`);
const savedSpeed = localStorage.getItem(`vigent_${storageKey}_speed`);
setText(savedText || "大家好,欢迎来到我的频道,今天给大家分享一些有趣的内容。");
setVideoTitle(savedTitle ? clampTitle(savedTitle) : "");
setVideoSecondaryTitle(savedSecondaryTitle ? clampSecondaryTitle(savedSecondaryTitle) : "");
setTtsMode((savedTtsMode as 'edgetts' | 'voiceclone') || 'edgetts');
setVoice(savedVoice || "zh-CN-YunxiNeural");
if (savedTextLang) setTextLang(savedTextLang);
@@ -132,6 +175,7 @@ export const useHomePersistence = ({
}
if (savedSubtitleStyle) setSelectedSubtitleStyleId(savedSubtitleStyle);
if (savedTitleStyle) setSelectedTitleStyleId(savedTitleStyle);
if (savedSecondaryTitleStyle) setSelectedSecondaryTitleStyleId(savedSecondaryTitleStyle);
if (savedSubtitleFontSize) {
const parsed = parseInt(savedSubtitleFontSize, 10);
@@ -149,20 +193,45 @@ export const useHomePersistence = ({
}
}
if (savedSecondaryTitleFontSize) {
const parsed = parseInt(savedSecondaryTitleFontSize, 10);
if (!Number.isNaN(parsed)) {
setSecondaryTitleFontSize(parsed);
setSecondaryTitleSizeLocked(true);
}
}
if (savedBgmId) setSelectedBgmId(savedBgmId);
if (savedBgmVolume) setBgmVolume(parseFloat(savedBgmVolume));
if (savedEnableBgm !== null) setEnableBgm(savedEnableBgm === 'true');
if (savedSelectedVideoId) setSelectedVideoId(savedSelectedVideoId);
if (savedSelectedAudioId) setSelectedAudioId(savedSelectedAudioId);
if (savedTitleTopMargin) {
const parsed = parseInt(savedTitleTopMargin, 10);
if (!Number.isNaN(parsed)) setTitleTopMargin(parsed);
}
if (savedSecondaryTitleTopMargin) {
const parsed = parseInt(savedSecondaryTitleTopMargin, 10);
if (!Number.isNaN(parsed)) setSecondaryTitleTopMargin(parsed);
}
if (savedTitleDisplayMode === 'short' || savedTitleDisplayMode === 'persistent') {
setTitleDisplayMode(savedTitleDisplayMode);
}
if (savedSubtitleBottomMargin) {
const parsed = parseInt(savedSubtitleBottomMargin, 10);
if (!Number.isNaN(parsed)) setSubtitleBottomMargin(parsed);
}
if (savedOutputAspectRatio === '9:16' || savedOutputAspectRatio === '16:9') {
setOutputAspectRatio(savedOutputAspectRatio);
}
if (savedSpeed) {
const parsed = parseFloat(savedSpeed);
if (!Number.isNaN(parsed)) setSpeed(parsed);
}
// eslint-disable-next-line react-hooks/set-state-in-effect
setIsRestored(true);
}, [
@@ -173,17 +242,26 @@ export const useHomePersistence = ({
setSelectedMaterials,
setSelectedSubtitleStyleId,
setSelectedTitleStyleId,
setSelectedSecondaryTitleStyleId,
setSelectedVideoId,
setSelectedAudioId,
setSpeed,
setSubtitleFontSize,
setSubtitleSizeLocked,
setText,
setTextLang,
setTitleFontSize,
setTitleSizeLocked,
setSecondaryTitleFontSize,
setSecondaryTitleSizeLocked,
setTitleTopMargin,
setSecondaryTitleTopMargin,
setTitleDisplayMode,
setSubtitleBottomMargin,
setOutputAspectRatio,
setTtsMode,
setVideoTitle,
setVideoSecondaryTitle,
setVoice,
storageKey,
]);
@@ -204,6 +282,14 @@ export const useHomePersistence = ({
return () => clearTimeout(timeout);
}, [videoTitle, storageKey, isRestored]);
useEffect(() => {
if (!isRestored) return;
const timeout = setTimeout(() => {
localStorage.setItem(`vigent_${storageKey}_secondaryTitle`, videoSecondaryTitle);
}, 300);
return () => clearTimeout(timeout);
}, [videoSecondaryTitle, storageKey, isRestored]);
useEffect(() => {
if (isRestored) localStorage.setItem(`vigent_${storageKey}_ttsMode`, ttsMode);
}, [ttsMode, storageKey, isRestored]);
@@ -234,6 +320,12 @@ export const useHomePersistence = ({
}
}, [selectedTitleStyleId, storageKey, isRestored]);
useEffect(() => {
if (isRestored && selectedSecondaryTitleStyleId) {
localStorage.setItem(`vigent_${storageKey}_secondaryTitleStyle`, selectedSecondaryTitleStyleId);
}
}, [selectedSecondaryTitleStyleId, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_subtitleFontSize`, String(subtitleFontSize));
@@ -246,18 +338,42 @@ export const useHomePersistence = ({
}
}, [titleFontSize, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_secondaryTitleFontSize`, String(secondaryTitleFontSize));
}
}, [secondaryTitleFontSize, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_titleTopMargin`, String(titleTopMargin));
}
}, [titleTopMargin, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_secondaryTitleTopMargin`, String(secondaryTitleTopMargin));
}
}, [secondaryTitleTopMargin, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_titleDisplayMode`, titleDisplayMode);
}
}, [titleDisplayMode, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_subtitleBottomMargin`, String(subtitleBottomMargin));
}
}, [subtitleBottomMargin, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_outputAspectRatio`, outputAspectRatio);
}
}, [outputAspectRatio, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_bgmId`, selectedBgmId);
@@ -287,11 +403,26 @@ export const useHomePersistence = ({
}
}, [selectedVideoId, storageKey, isRestored]);
useEffect(() => {
if (!isRestored) return;
if (selectedAudioId) {
localStorage.setItem(`vigent_${storageKey}_selectedAudioId`, selectedAudioId);
} else {
localStorage.removeItem(`vigent_${storageKey}_selectedAudioId`);
}
}, [selectedAudioId, storageKey, isRestored]);
useEffect(() => {
if (isRestored && selectedRefAudio) {
localStorage.setItem(`vigent_${storageKey}_refAudioId`, selectedRefAudio.id);
}
}, [selectedRefAudio, storageKey, isRestored]);
useEffect(() => {
if (isRestored) {
localStorage.setItem(`vigent_${storageKey}_speed`, String(speed));
}
}, [speed, storageKey, isRestored]);
return { isRestored };
};

View File

@@ -2,8 +2,36 @@ import { useCallback, useState } from "react";
import api from "@/shared/api/axios";
import { ApiResponse, unwrap } from "@/shared/api/types";
import { toast } from "sonner";
import { resolveMediaUrl } from "@/shared/lib/media";
import type { Material } from "@/shared/types/material";
/** Probe video duration from a URL using <video> element */
function probeVideoDuration(url: string): Promise<number> {
return new Promise((resolve) => {
const video = document.createElement("video");
video.preload = "metadata";
video.crossOrigin = "anonymous";
const cleanup = () => {
video.removeEventListener("loadedmetadata", onMeta);
video.removeEventListener("error", onError);
video.src = "";
};
const onMeta = () => {
const dur = video.duration;
cleanup();
resolve(Number.isFinite(dur) ? dur : 0);
};
const onError = () => {
cleanup();
resolve(0);
};
video.addEventListener("loadedmetadata", onMeta);
video.addEventListener("error", onError);
video.src = url;
video.load();
});
}
interface UseMaterialsOptions {
selectedMaterials: string[];
setSelectedMaterials: React.Dispatch<React.SetStateAction<string[]>>;
@@ -34,6 +62,18 @@ export const useMaterials = ({
setMaterials(nextMaterials);
setLastMaterialCount(nextMaterials.length);
// Probe video durations in background
if (nextMaterials.length > 0) {
Promise.all(
nextMaterials.map(async (m) => {
const url = resolveMediaUrl(m.path);
if (!url) return m;
const dur = await probeVideoDuration(url);
return { ...m, duration_sec: dur };
})
).then((enriched) => setMaterials(enriched));
}
setSelectedMaterials((prev) => {
// 保留已选中且仍存在的
const existingIds = new Set(nextMaterials.map((m) => m.id));
@@ -133,11 +173,26 @@ export const useMaterials = ({
setMaterials(nextMaterials);
setLastMaterialCount(nextMaterials.length);
// 找出新增的素材 ID 并自动选中
// Probe video durations in background
if (nextMaterials.length > 0) {
Promise.all(
nextMaterials.map(async (m) => {
const url = resolveMediaUrl(m.path);
if (!url) return m;
const dur = await probeVideoDuration(url);
return { ...m, duration_sec: dur };
})
).then((enriched) => setMaterials(enriched));
}
// 找出新增素材并默认仅选中新上传项,避免误触发多素材模式
const oldIds = new Set(materials.map((m) => m.id));
const newIds = nextMaterials.filter((m) => !oldIds.has(m.id)).map((m) => m.id);
if (newIds.length > 0) {
setSelectedMaterials((prev) => [...prev, ...newIds]);
setSelectedMaterials([newIds[0]]);
} else if (nextMaterials[0]?.id) {
// 兜底:即使未识别到新增项,也保持单素材默认选择最新一个
setSelectedMaterials([nextMaterials[0].id]);
}
} catch (err: unknown) {
console.error("Upload failed:", err);
@@ -148,7 +203,7 @@ export const useMaterials = ({
}
e.target.value = '';
}, [fetchMaterials]);
}, [materials, setSelectedMaterials]);
return {
materials,

View File

@@ -13,14 +13,12 @@ interface RefAudio {
}
interface UseRefAudiosOptions {
fixedRefText: string;
selectedRefAudio: RefAudio | null;
setSelectedRefAudio: React.Dispatch<React.SetStateAction<RefAudio | null>>;
setRefText: React.Dispatch<React.SetStateAction<string>>;
}
export const useRefAudios = ({
fixedRefText,
selectedRefAudio,
setSelectedRefAudio,
setRefText,
@@ -28,6 +26,7 @@ export const useRefAudios = ({
const [refAudios, setRefAudios] = useState<RefAudio[]>([]);
const [isUploadingRef, setIsUploadingRef] = useState(false);
const [uploadRefError, setUploadRefError] = useState<string | null>(null);
const [retranscribingId, setRetranscribingId] = useState<string | null>(null);
const fetchRefAudios = useCallback(async () => {
try {
@@ -42,15 +41,12 @@ export const useRefAudios = ({
}, []);
const uploadRefAudio = useCallback(async (file: File) => {
const refTextInput = fixedRefText;
setIsUploadingRef(true);
setUploadRefError(null);
try {
const formData = new FormData();
formData.append('file', file);
formData.append('ref_text', refTextInput);
const { data: res } = await api.post<ApiResponse<RefAudio>>('/api/ref-audios', formData, {
headers: { 'Content-Type': 'multipart/form-data' },
@@ -68,7 +64,7 @@ export const useRefAudios = ({
const errorMsg = axiosErr.response?.data?.message || axiosErr.message || String(err);
setUploadRefError(`上传失败: ${errorMsg}`);
}
}, [fetchRefAudios, fixedRefText, setRefText, setSelectedRefAudio]);
}, [fetchRefAudios, setRefText, setSelectedRefAudio]);
const deleteRefAudio = useCallback(async (audioId: string) => {
if (!confirm("确定要删除这个参考音频吗?")) return;
@@ -84,6 +80,28 @@ export const useRefAudios = ({
}
}, [fetchRefAudios, selectedRefAudio, setRefText, setSelectedRefAudio]);
const retranscribeRefAudio = useCallback(async (audioId: string) => {
setRetranscribingId(audioId);
try {
const { data: res } = await api.post<ApiResponse<{ ref_text: string }>>(
`/api/ref-audios/${encodeURIComponent(audioId)}/retranscribe`
);
const payload = unwrap(res);
toast.success("识别完成");
// 更新列表和当前选中
await fetchRefAudios();
if (selectedRefAudio?.id === audioId) {
setRefText(payload.ref_text);
}
} catch (err: unknown) {
const axiosErr = err as { response?: { data?: { message?: string } }; message?: string };
const errorMsg = axiosErr.response?.data?.message || axiosErr.message || String(err);
toast.error(`识别失败: ${errorMsg}`);
} finally {
setRetranscribingId(null);
}
}, [fetchRefAudios, selectedRefAudio, setRefText]);
return {
refAudios,
isUploadingRef,
@@ -92,5 +110,7 @@ export const useRefAudios = ({
fetchRefAudios,
uploadRefAudio,
deleteRefAudio,
retranscribeRefAudio,
retranscribingId,
};
};

View File

@@ -0,0 +1,51 @@
import { useState, useEffect, useRef } from "react";
export interface SavedScript {
id: string;
name: string;
content: string;
savedAt: number;
}
export function useSavedScripts(storageKey: string) {
const lsKey = `vigent_${storageKey}_savedScripts`;
const lsKeyRef = useRef(lsKey);
lsKeyRef.current = lsKey;
const [savedScripts, setSavedScripts] = useState<SavedScript[]>([]);
// Re-read from localStorage whenever lsKey changes (e.g. guest → userId)
useEffect(() => {
try {
const raw = localStorage.getItem(lsKey);
setSavedScripts(raw ? JSON.parse(raw) : []);
} catch {
setSavedScripts([]);
}
}, [lsKey]);
const saveScript = (content: string) => {
const name = content.slice(0, 15).replace(/\n/g, " ") || "未命名";
const entry: SavedScript = {
id: Date.now().toString(36) + Math.random().toString(36).slice(2, 6),
name,
content,
savedAt: Date.now(),
};
setSavedScripts((prev) => {
const next = [entry, ...prev];
localStorage.setItem(lsKeyRef.current, JSON.stringify(next));
return next;
});
};
const deleteScript = (id: string) => {
setSavedScripts((prev) => {
const next = prev.filter((s) => s.id !== id);
localStorage.setItem(lsKeyRef.current, JSON.stringify(next));
return next;
});
};
return { savedScripts, saveScript, deleteScript };
}

View File

@@ -0,0 +1,256 @@
import { useCallback, useEffect, useRef, useState } from "react";
import type { Material } from "@/shared/types/material";
export interface TimelineSegment {
id: string;
materialId: string;
materialName: string;
start: number;
end: number;
sourceStart: number;
sourceEnd: number;
color: string;
}
export interface CustomAssignment {
material_path: string;
start: number;
end: number;
source_start: number;
source_end?: number;
}
const COLORS = ["#8b5cf6", "#ec4899", "#06b6d4", "#f59e0b", "#10b981", "#f97316"];
/** Serializable subset for localStorage */
interface SegmentSnapshot {
materialId: string;
start: number;
end: number;
sourceStart: number;
sourceEnd: number;
}
/** Get effective duration of a segment (clipped range or full material duration) */
function getEffectiveDuration(
seg: { sourceStart: number; sourceEnd: number; materialId: string },
mats: Material[]
): number {
const mat = mats.find((m) => m.id === seg.materialId);
const matDur = mat?.duration_sec ?? 0;
if (seg.sourceEnd > seg.sourceStart) return seg.sourceEnd - seg.sourceStart;
if (seg.sourceStart > 0) return Math.max(matDur - seg.sourceStart, 0);
return matDur;
}
/**
* Recalculate segment start/end positions based on effective durations.
* - Segments placed sequentially by effective duration
* - Segments exceeding audioDuration keep their positions (overflow, start >= duration)
* - Last visible segment is capped/extended to exactly audioDuration (loop fill)
*/
function recalcPositions(
segs: TimelineSegment[],
mats: Material[],
duration: number
): TimelineSegment[] {
if (segs.length === 0 || duration <= 0) return segs;
const fallbackDur = duration / segs.length;
let cursor = 0;
const result = segs.map((seg) => {
const effDur = getEffectiveDuration(seg, mats);
const dur = effDur > 0 ? effDur : fallbackDur;
const newSeg = { ...seg, start: cursor, end: cursor + dur };
cursor += dur;
return newSeg;
});
// Find last segment that starts before audioDuration
let lastVisibleIdx = -1;
for (let i = result.length - 1; i >= 0; i--) {
if (result[i].start < duration) {
lastVisibleIdx = i;
break;
}
}
// Cap/extend last visible segment to exactly audioDuration
if (lastVisibleIdx >= 0) {
result[lastVisibleIdx] = { ...result[lastVisibleIdx], end: duration };
}
return result;
}
interface UseTimelineEditorOptions {
audioDuration: number;
materials: Material[];
selectedMaterials: string[];
storageKey?: string;
}
export const useTimelineEditor = ({
audioDuration,
materials,
selectedMaterials,
storageKey,
}: UseTimelineEditorOptions) => {
const [segments, setSegments] = useState<TimelineSegment[]>([]);
const prevKey = useRef("");
const restoredRef = useRef(false);
// Refs for stable callbacks (avoid recreating on every materials/duration change)
const materialsRef = useRef(materials);
const audioDurationRef = useRef(audioDuration);
useEffect(() => {
materialsRef.current = materials;
}, [materials]);
useEffect(() => {
audioDurationRef.current = audioDuration;
}, [audioDuration]);
// Build a durationsKey so segments re-init when material durations become available
const durationsKey = selectedMaterials
.map((id) => materials.find((m) => m.id === id)?.duration_sec ?? 0)
.join(",");
// Build a cache key from materials + duration
const cacheKey = `${selectedMaterials.join(",")}_${audioDuration.toFixed(1)}`;
const lsKey = storageKey ? `vigent_${storageKey}_timeline` : null;
const initSegments = useCallback(() => {
if (selectedMaterials.length === 0 || audioDuration <= 0) {
setSegments([]);
return;
}
// Try restore from localStorage
if (lsKey) {
try {
const raw = localStorage.getItem(lsKey);
if (raw) {
const saved = JSON.parse(raw) as { key: string; segments: SegmentSnapshot[] };
if (saved.key === cacheKey && saved.segments.length === selectedMaterials.length) {
const allMatch = saved.segments.every(
(s, i) => s.materialId === selectedMaterials[i] || saved.segments.some((ss) => ss.materialId === selectedMaterials[i])
);
if (allMatch) {
const restored: TimelineSegment[] = saved.segments.map((s, i) => {
const mat = materials.find((m) => m.id === s.materialId);
return {
id: `seg-${i}-${Date.now()}`,
materialId: s.materialId,
materialName: mat?.scene || mat?.name || s.materialId,
start: 0,
end: 0,
sourceStart: s.sourceStart,
sourceEnd: s.sourceEnd,
color: COLORS[i % COLORS.length],
};
});
setSegments(recalcPositions(restored, materials, audioDuration));
restoredRef.current = true;
return;
}
}
}
} catch {
// ignore parse errors
}
}
// Create fresh segments — positions derived by recalcPositions
const newSegments: TimelineSegment[] = selectedMaterials.map((matId, i) => {
const mat = materials.find((m) => m.id === matId);
return {
id: `seg-${i}-${Date.now()}`,
materialId: matId,
materialName: mat?.scene || mat?.name || matId,
start: 0,
end: 0,
sourceStart: 0,
sourceEnd: 0,
color: COLORS[i % COLORS.length],
};
});
setSegments(recalcPositions(newSegments, materials, audioDuration));
}, [audioDuration, materials, selectedMaterials, lsKey, cacheKey]);
// Auto-init when selectedMaterials, audioDuration, or material durations change
useEffect(() => {
const key = `${selectedMaterials.join(",")}_${audioDuration}_${durationsKey}`;
if (key !== prevKey.current) {
prevKey.current = key;
initSegments();
}
}, [selectedMaterials, audioDuration, durationsKey, initSegments]);
// Persist segments to localStorage on change (debounced)
useEffect(() => {
if (!lsKey || segments.length === 0) return;
const timeout = setTimeout(() => {
const snapshots: SegmentSnapshot[] = segments.map((s) => ({
materialId: s.materialId,
start: s.start,
end: s.end,
sourceStart: s.sourceStart,
sourceEnd: s.sourceEnd,
}));
localStorage.setItem(lsKey, JSON.stringify({ key: cacheKey, segments: snapshots }));
}, 300);
return () => clearTimeout(timeout);
}, [segments, lsKey, cacheKey]);
const reorderSegments = useCallback(
(fromIdx: number, toIdx: number) => {
setSegments((prev) => {
if (fromIdx < 0 || toIdx < 0 || fromIdx >= prev.length || toIdx >= prev.length) return prev;
if (fromIdx === toIdx) return prev;
const next = [...prev];
// Move the segment: remove from old position, insert at new position
const [moved] = next.splice(fromIdx, 1);
next.splice(toIdx, 0, moved);
return recalcPositions(next, materialsRef.current, audioDurationRef.current);
});
},
[]
);
const setSourceRange = useCallback(
(id: string, sourceStart: number, sourceEnd: number) => {
setSegments((prev) => {
const updated = prev.map((s) => (s.id === id ? { ...s, sourceStart, sourceEnd } : s));
return recalcPositions(updated, materialsRef.current, audioDurationRef.current);
});
},
[]
);
const toCustomAssignments = useCallback((): CustomAssignment[] => {
const duration = audioDurationRef.current;
return segments
.filter((seg) => seg.start < duration)
.map((seg) => {
const mat = materialsRef.current.find((m) => m.id === seg.materialId);
return {
material_path: mat?.path || seg.materialId,
start: seg.start,
end: seg.end,
source_start: seg.sourceStart,
source_end: seg.sourceEnd > seg.sourceStart ? seg.sourceEnd : undefined,
};
});
}, [segments]);
return {
segments,
initSegments,
reorderSegments,
setSourceRange,
toCustomAssignments,
};
};

View File

@@ -0,0 +1,293 @@
import { useCallback, useEffect, useRef, useState } from "react";
import { X, Play, Pause } from "lucide-react";
import type { TimelineSegment } from "@/features/home/model/useTimelineEditor";
interface ClipTrimmerProps {
isOpen: boolean;
segment: TimelineSegment | null;
materialUrl: string | null;
onConfirm: (sourceStart: number, sourceEnd: number) => void;
onClose: () => void;
}
function formatSec(sec: number): string {
const m = Math.floor(sec / 60);
const s = sec % 60;
return `${String(m).padStart(2, "0")}:${s.toFixed(1).padStart(4, "0")}`;
}
export function ClipTrimmer({
isOpen,
segment,
materialUrl,
onConfirm,
onClose,
}: ClipTrimmerProps) {
const videoRef = useRef<HTMLVideoElement>(null);
const trackRef = useRef<HTMLDivElement>(null);
const [duration, setDuration] = useState(0);
const [sourceStart, setSourceStart] = useState(0);
const [sourceEnd, setSourceEnd] = useState(0);
const [currentTime, setCurrentTime] = useState(0);
const [isPlaying, setIsPlaying] = useState(false);
const [dragging, setDragging] = useState<"start" | "end" | null>(null);
const animRef = useRef<number>(0);
// Reset state when segment changes
useEffect(() => {
if (segment && isOpen) {
setSourceStart(segment.sourceStart);
setSourceEnd(segment.sourceEnd);
setCurrentTime(segment.sourceStart);
setIsPlaying(false);
}
}, [segment, isOpen]);
// Track currentTime during playback
useEffect(() => {
if (!isPlaying || !videoRef.current) return;
const tick = () => {
if (!videoRef.current) return;
const t = videoRef.current.currentTime;
const end = sourceEnd || duration;
if (t >= end) {
videoRef.current.pause();
videoRef.current.currentTime = sourceStart;
setCurrentTime(sourceStart);
setIsPlaying(false);
return;
}
setCurrentTime(t);
animRef.current = requestAnimationFrame(tick);
};
animRef.current = requestAnimationFrame(tick);
return () => cancelAnimationFrame(animRef.current);
}, [isPlaying, sourceStart, sourceEnd, duration]);
// Seek video when not playing and currentTime changes
useEffect(() => {
if (videoRef.current && !isPlaying) {
videoRef.current.currentTime = currentTime;
}
}, [currentTime, isPlaying]);
const handleLoadedMetadata = useCallback(() => {
if (videoRef.current) {
const dur = videoRef.current.duration;
setDuration(dur);
if (sourceEnd === 0) {
setSourceEnd(dur);
}
}
}, [sourceEnd]);
const togglePlay = useCallback(() => {
if (!videoRef.current || duration === 0) return;
if (isPlaying) {
videoRef.current.pause();
setIsPlaying(false);
} else {
const end = sourceEnd || duration;
if (videoRef.current.currentTime >= end || videoRef.current.currentTime < sourceStart) {
videoRef.current.currentTime = sourceStart;
setCurrentTime(sourceStart);
}
videoRef.current.play().catch(() => {});
setIsPlaying(true);
}
}, [isPlaying, sourceStart, sourceEnd, duration]);
// --- Dual-handle slider logic ---
const getPositionFromEvent = useCallback(
(clientX: number) => {
if (!trackRef.current || duration === 0) return 0;
const rect = trackRef.current.getBoundingClientRect();
const ratio = Math.max(0, Math.min(1, (clientX - rect.left) / rect.width));
return ratio * duration;
},
[duration]
);
const handleThumbPointerDown = useCallback(
(which: "start" | "end", e: React.PointerEvent) => {
e.preventDefault();
e.stopPropagation();
setDragging(which);
(e.target as HTMLElement).setPointerCapture(e.pointerId);
},
[]
);
const handleTrackPointerMove = useCallback(
(e: React.PointerEvent) => {
if (!dragging) return;
const pos = getPositionFromEvent(e.clientX);
const minGap = 0.5;
if (dragging === "start") {
const clamped = Math.max(0, Math.min(pos, (sourceEnd || duration) - minGap));
setSourceStart(clamped);
setCurrentTime(clamped);
} else {
const clamped = Math.min(duration, Math.max(pos, sourceStart + minGap));
setSourceEnd(clamped);
}
},
[dragging, getPositionFromEvent, sourceStart, sourceEnd, duration]
);
const handleTrackPointerUp = useCallback(() => {
setDragging(null);
}, []);
const handleConfirm = () => {
onConfirm(sourceStart, sourceEnd >= duration ? 0 : sourceEnd);
};
if (!isOpen || !segment) return null;
const assignedDur = segment.end - segment.start;
const effectiveEnd = sourceEnd || duration;
const clipDur = effectiveEnd - sourceStart;
const startPct = duration > 0 ? (sourceStart / duration) * 100 : 0;
const endPct = duration > 0 ? (effectiveEnd / duration) * 100 : 100;
const playheadPct = duration > 0 ? (currentTime / duration) * 100 : 0;
return (
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/60 backdrop-blur-sm" onClick={onClose}>
<div
className="bg-gray-900 border border-white/10 rounded-2xl w-full max-w-lg mx-4 overflow-hidden"
onClick={(e) => e.stopPropagation()}
>
{/* Header */}
<div className="flex items-center justify-between px-5 py-3 border-b border-white/10">
<h3 className="text-white font-semibold text-sm">
- {segment.materialName}
</h3>
<button onClick={onClose} className="text-gray-400 hover:text-white">
<X className="h-4 w-4" />
</button>
</div>
{/* Video preview */}
<div className="px-5 pt-4">
<div className="relative bg-black rounded-lg overflow-hidden aspect-video group">
{materialUrl ? (
<video
ref={videoRef}
src={materialUrl}
className="w-full h-full object-contain"
onLoadedMetadata={handleLoadedMetadata}
onEnded={() => setIsPlaying(false)}
preload="auto"
muted
/>
) : (
<div className="flex items-center justify-center h-full text-gray-500 text-sm">
</div>
)}
{/* Play/Pause overlay */}
{materialUrl && (
<button
onClick={togglePlay}
className="absolute inset-0 flex items-center justify-center bg-black/0 hover:bg-black/30 transition-colors"
>
<div className={`p-3 rounded-full bg-black/60 text-white transition-opacity ${isPlaying ? "opacity-0 group-hover:opacity-100" : "opacity-100"}`}>
{isPlaying ? <Pause className="h-6 w-6" /> : <Play className="h-6 w-6" />}
</div>
</button>
)}
<div className="absolute bottom-2 right-2 bg-black/70 text-white text-[10px] px-2 py-0.5 rounded pointer-events-none">
{formatSec(currentTime)}
</div>
</div>
</div>
{/* Dual-handle range slider */}
<div className="px-5 py-4 space-y-3">
<div className="text-xs text-gray-400 flex justify-between">
<span>: {duration > 0 ? formatSec(duration) : "加载中..."}</span>
</div>
{/* Custom range track */}
<div
ref={trackRef}
className="relative h-8 cursor-pointer select-none touch-none"
onPointerMove={handleTrackPointerMove}
onPointerUp={handleTrackPointerUp}
onPointerLeave={handleTrackPointerUp}
>
{/* Background track */}
<div className="absolute top-1/2 -translate-y-1/2 left-0 right-0 h-2 bg-white/10 rounded-full" />
{/* Selected range */}
<div
className="absolute top-1/2 -translate-y-1/2 h-2 rounded-full"
style={{
left: `${startPct}%`,
width: `${endPct - startPct}%`,
backgroundColor: segment.color + "88",
}}
/>
{/* Playhead indicator */}
{duration > 0 && (
<div
className="absolute top-1/2 -translate-y-1/2 w-0.5 h-4 bg-white/60 rounded-full pointer-events-none"
style={{ left: `${playheadPct}%` }}
/>
)}
{/* Start thumb */}
<div
onPointerDown={(e) => handleThumbPointerDown("start", e)}
className="absolute top-1/2 -translate-y-1/2 -translate-x-1/2 w-4 h-4 rounded-full bg-purple-500 border-2 border-white shadow-lg cursor-grab active:cursor-grabbing hover:scale-110 transition-transform z-10"
style={{ left: `${startPct}%` }}
title={`起点: ${formatSec(sourceStart)}`}
/>
{/* End thumb */}
<div
onPointerDown={(e) => handleThumbPointerDown("end", e)}
className="absolute top-1/2 -translate-y-1/2 -translate-x-1/2 w-4 h-4 rounded-full bg-pink-500 border-2 border-white shadow-lg cursor-grab active:cursor-grabbing hover:scale-110 transition-transform z-10"
style={{ left: `${endPct}%` }}
title={`终点: ${formatSec(effectiveEnd)}`}
/>
</div>
{/* Time labels */}
<div className="flex justify-between text-xs text-gray-400">
<span className="text-purple-400">{formatSec(sourceStart)}</span>
<span className="text-pink-400">{formatSec(effectiveEnd)}</span>
</div>
{/* Info */}
<div className="text-[11px] text-gray-500 flex items-center gap-2 flex-wrap">
<span>: {clipDur.toFixed(1)}s</span>
<span className="text-gray-600">|</span>
<span>: {assignedDur.toFixed(1)}s</span>
{clipDur < assignedDur && <span className="text-amber-500">()</span>}
{clipDur > assignedDur && <span className="text-cyan-500">()</span>}
</div>
</div>
{/* Actions */}
<div className="flex justify-end gap-2 px-5 pb-4">
<button
onClick={onClose}
className="px-4 py-1.5 text-xs bg-white/10 hover:bg-white/20 rounded-lg text-gray-300 transition-colors"
>
</button>
<button
onClick={handleConfirm}
className="px-4 py-1.5 text-xs bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white rounded-lg transition-colors"
>
</button>
</div>
</div>
</div>
);
}

View File

@@ -35,9 +35,13 @@ interface TitleStyleOption {
interface FloatingStylePreviewProps {
onClose: () => void;
videoTitle: string;
videoSecondaryTitle: string;
titleStyles: TitleStyleOption[];
selectedTitleStyleId: string;
titleFontSize: number;
selectedSecondaryTitleStyleId: string;
secondaryTitleFontSize: number;
secondaryTitleTopMargin: number;
subtitleStyles: SubtitleStyleOption[];
selectedSubtitleStyleId: string;
subtitleFontSize: number;
@@ -56,9 +60,13 @@ const DESKTOP_WIDTH = 280;
export function FloatingStylePreview({
onClose,
videoTitle,
videoSecondaryTitle,
titleStyles,
selectedTitleStyleId,
titleFontSize,
selectedSecondaryTitleStyleId,
secondaryTitleFontSize,
secondaryTitleTopMargin,
subtitleStyles,
selectedSubtitleStyleId,
subtitleFontSize,
@@ -86,6 +94,8 @@ export function FloatingStylePreview({
const previewScale = windowWidth / previewBaseWidth;
const previewHeight = previewBaseHeight * previewScale;
const widthScale = Math.min(1, previewBaseWidth / 1080);
const responsiveScale = Math.max(0.55, widthScale);
const activeSubtitleStyle = subtitleStyles.find((s) => s.id === selectedSubtitleStyleId)
|| subtitleStyles.find((s) => s.is_default)
@@ -102,8 +112,8 @@ export function FloatingStylePreview({
const subtitleHighlightColor = activeSubtitleStyle?.highlight_color || "#FFE600";
const subtitleNormalColor = activeSubtitleStyle?.normal_color || "#FFFFFF";
const subtitleStrokeColor = activeSubtitleStyle?.stroke_color || "#000000";
const subtitleStrokeSize = activeSubtitleStyle?.stroke_size ?? 3;
const subtitleLetterSpacing = activeSubtitleStyle?.letter_spacing ?? 2;
const subtitleStrokeSize = Math.max(1, Math.round((activeSubtitleStyle?.stroke_size ?? 3) * responsiveScale));
const subtitleLetterSpacing = Math.max(0, (activeSubtitleStyle?.letter_spacing ?? 2) * responsiveScale);
const subtitleFontFamilyName = `SubtitlePreview-${activeSubtitleStyle?.id || "default"}`;
const subtitleFontUrl = activeSubtitleStyle?.font_file
? resolveAssetUrl(`fonts/${activeSubtitleStyle.font_file}`)
@@ -111,14 +121,35 @@ export function FloatingStylePreview({
const titleColor = activeTitleStyle?.color || "#FFFFFF";
const titleStrokeColor = activeTitleStyle?.stroke_color || "#000000";
const titleStrokeSize = activeTitleStyle?.stroke_size ?? 8;
const titleLetterSpacing = activeTitleStyle?.letter_spacing ?? 4;
const titleStrokeSize = Math.max(1, Math.round((activeTitleStyle?.stroke_size ?? 8) * responsiveScale));
const titleLetterSpacing = Math.max(0, (activeTitleStyle?.letter_spacing ?? 4) * responsiveScale);
const titleFontWeight = activeTitleStyle?.font_weight ?? 900;
const titleFontFamilyName = `TitlePreview-${activeTitleStyle?.id || "default"}`;
const titleFontUrl = activeTitleStyle?.font_file
? resolveAssetUrl(`fonts/${activeTitleStyle.font_file}`)
: null;
const scaledTitleFontSize = Math.max(36, Math.round(titleFontSize * responsiveScale));
const scaledSubtitleFontSize = Math.max(28, Math.round(subtitleFontSize * responsiveScale));
const scaledTitleTopMargin = Math.max(0, Math.round(titleTopMargin * responsiveScale));
const scaledSubtitleBottomMargin = Math.max(0, Math.round(subtitleBottomMargin * responsiveScale));
// 副标题样式
const activeSecondaryTitleStyle = titleStyles.find((s) => s.id === selectedSecondaryTitleStyleId)
|| activeTitleStyle;
const stColor = activeSecondaryTitleStyle?.color || "#FFFFFF";
const stStrokeColor = activeSecondaryTitleStyle?.stroke_color || "#000000";
const stStrokeSize = Math.max(1, Math.round((activeSecondaryTitleStyle?.stroke_size ?? 6) * responsiveScale));
const stLetterSpacing = Math.max(0, (activeSecondaryTitleStyle?.letter_spacing ?? 2) * responsiveScale);
const stFontWeight = activeSecondaryTitleStyle?.font_weight ?? 700;
const stFontFamilyName = `SecondaryTitlePreview-${activeSecondaryTitleStyle?.id || "default"}`;
const stFontUrl = activeSecondaryTitleStyle?.font_file
? resolveAssetUrl(`fonts/${activeSecondaryTitleStyle.font_file}`)
: null;
const scaledSecondaryTitleFontSize = Math.max(24, Math.round(secondaryTitleFontSize * responsiveScale));
const scaledSecondaryTitleTopMargin = Math.max(0, Math.round(secondaryTitleTopMargin * responsiveScale));
const previewSecondaryTitleText = videoSecondaryTitle.trim() || "";
const content = (
<div
style={{
@@ -152,9 +183,10 @@ export function FloatingStylePreview({
className="relative overflow-hidden rounded-b-xl"
style={{ height: `${previewHeight}px` }}
>
{(titleFontUrl || subtitleFontUrl) && (
{(titleFontUrl || subtitleFontUrl || stFontUrl) && (
<style>{`
${titleFontUrl ? `@font-face { font-family: '${titleFontFamilyName}'; src: url('${titleFontUrl}') format('${getFontFormat(activeTitleStyle?.font_file)}'); font-weight: 400; font-style: normal; }` : ''}
${stFontUrl && stFontUrl !== titleFontUrl ? `@font-face { font-family: '${stFontFamilyName}'; src: url('${stFontUrl}') format('${getFontFormat(activeSecondaryTitleStyle?.font_file)}'); font-weight: 400; font-style: normal; }` : ''}
${subtitleFontUrl ? `@font-face { font-family: '${subtitleFontFamilyName}'; src: url('${subtitleFontUrl}') format('${getFontFormat(activeSubtitleStyle?.font_file)}'); font-weight: 400; font-style: normal; }` : ''}
`}</style>
)}
@@ -172,39 +204,78 @@ export function FloatingStylePreview({
className="w-full text-center"
style={{
position: 'absolute',
top: `${titleTopMargin}px`,
top: `${scaledTitleTopMargin}px`,
left: 0,
right: 0,
color: titleColor,
fontSize: `${titleFontSize}px`,
fontWeight: titleFontWeight,
fontFamily: titleFontUrl
? `'${titleFontFamilyName}', "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif`
: '"PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif',
textShadow: buildTextShadow(titleStrokeColor, titleStrokeSize),
letterSpacing: `${titleLetterSpacing}px`,
lineHeight: 1.2,
opacity: videoTitle.trim() ? 1 : 0.7,
display: 'flex',
flexDirection: 'column',
alignItems: 'center',
padding: '0 5%',
boxSizing: 'border-box',
}}
>
{previewTitleText}
<div
style={{
color: titleColor,
fontSize: `${scaledTitleFontSize}px`,
fontWeight: titleFontWeight,
fontFamily: titleFontUrl
? `'${titleFontFamilyName}', "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif`
: '"PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif',
textShadow: buildTextShadow(titleStrokeColor, titleStrokeSize),
letterSpacing: `${titleLetterSpacing}px`,
lineHeight: 1.2,
whiteSpace: 'normal',
wordBreak: 'break-word',
overflowWrap: 'anywhere',
opacity: videoTitle.trim() ? 1 : 0.7,
}}
>
{previewTitleText}
</div>
{previewSecondaryTitleText && (
<div
style={{
marginTop: `${scaledSecondaryTitleTopMargin}px`,
color: stColor,
fontSize: `${scaledSecondaryTitleFontSize}px`,
fontWeight: stFontWeight,
fontFamily: stFontUrl && stFontUrl !== titleFontUrl
? `'${stFontFamilyName}', "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif`
: titleFontUrl
? `'${titleFontFamilyName}', "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif`
: '"PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif',
textShadow: buildTextShadow(stStrokeColor, stStrokeSize),
letterSpacing: `${stLetterSpacing}px`,
lineHeight: 1.2,
whiteSpace: 'normal',
wordBreak: 'break-word',
overflowWrap: 'anywhere',
}}
>
{previewSecondaryTitleText}
</div>
)}
</div>
<div
className="w-full text-center"
style={{
position: 'absolute',
bottom: `${subtitleBottomMargin}px`,
bottom: `${scaledSubtitleBottomMargin}px`,
left: 0,
right: 0,
fontSize: `${subtitleFontSize}px`,
fontSize: `${scaledSubtitleFontSize}px`,
fontFamily: subtitleFontUrl
? `'${subtitleFontFamilyName}', "PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif`
: '"PingFang SC", "Hiragino Sans GB", "Microsoft YaHei", "Noto Sans SC", sans-serif',
textShadow: buildTextShadow(subtitleStrokeColor, subtitleStrokeSize),
letterSpacing: `${subtitleLetterSpacing}px`,
lineHeight: 1.35,
whiteSpace: 'normal',
wordBreak: 'break-word',
overflowWrap: 'anywhere',
boxSizing: 'border-box',
padding: '0 6%',
}}
>

View File

@@ -0,0 +1,293 @@
import { useState, useRef, useCallback, useEffect } from "react";
import { Play, Pause, Pencil, Trash2, Check, X, RefreshCw, Mic, ChevronDown } from "lucide-react";
import type { GeneratedAudio } from "@/features/home/model/useGeneratedAudios";
interface AudioTask {
status: string;
progress?: number;
message?: string;
}
interface GeneratedAudiosPanelProps {
generatedAudios: GeneratedAudio[];
selectedAudioId: string | null;
isGeneratingAudio: boolean;
audioTask: AudioTask | null;
onGenerateAudio: () => void;
onRefresh: () => void;
onSelectAudio: (audio: GeneratedAudio) => void;
onDeleteAudio: (id: string) => void;
onRenameAudio: (id: string, newName: string) => void;
hasText: boolean;
missingRefAudio?: boolean;
speed: number;
onSpeedChange: (speed: number) => void;
ttsMode: string;
}
export function GeneratedAudiosPanel({
generatedAudios,
selectedAudioId,
isGeneratingAudio,
audioTask,
onGenerateAudio,
onRefresh,
onSelectAudio,
onDeleteAudio,
onRenameAudio,
hasText,
missingRefAudio = false,
speed,
onSpeedChange,
ttsMode,
}: GeneratedAudiosPanelProps) {
const [editingId, setEditingId] = useState<string | null>(null);
const [editName, setEditName] = useState("");
const [playingId, setPlayingId] = useState<string | null>(null);
const [speedOpen, setSpeedOpen] = useState(false);
const audioRef = useRef<HTMLAudioElement | null>(null);
const speedRef = useRef<HTMLDivElement>(null);
const stopPlaying = useCallback(() => {
if (audioRef.current) {
audioRef.current.pause();
audioRef.current.currentTime = 0;
audioRef.current = null;
}
setPlayingId(null);
}, []);
// Cleanup on unmount
useEffect(() => {
return () => {
if (audioRef.current) {
audioRef.current.pause();
audioRef.current = null;
}
};
}, []);
// Close speed dropdown on click outside
useEffect(() => {
const handler = (e: MouseEvent) => {
if (speedRef.current && !speedRef.current.contains(e.target as Node)) {
setSpeedOpen(false);
}
};
if (speedOpen) document.addEventListener("mousedown", handler);
return () => document.removeEventListener("mousedown", handler);
}, [speedOpen]);
const togglePlay = (audio: GeneratedAudio, e: React.MouseEvent) => {
e.stopPropagation();
if (playingId === audio.id) {
stopPlaying();
return;
}
stopPlaying();
const player = new Audio(audio.path);
player.onended = () => setPlayingId(null);
player.play().catch(() => {});
audioRef.current = player;
setPlayingId(audio.id);
};
const startEditing = (audio: GeneratedAudio, e: React.MouseEvent) => {
e.stopPropagation();
setEditingId(audio.id);
setEditName(audio.name);
};
const saveEditing = (audioId: string, e: React.MouseEvent) => {
e.stopPropagation();
if (!editName.trim()) return;
onRenameAudio(audioId, editName.trim());
setEditingId(null);
setEditName("");
};
const cancelEditing = (e: React.MouseEvent) => {
e.stopPropagation();
setEditingId(null);
setEditName("");
};
const canGenerate = hasText && !missingRefAudio;
const speedOptions = [
{ value: 0.8, label: "较慢" },
{ value: 0.9, label: "稍慢" },
{ value: 1.0, label: "正常" },
{ value: 1.1, label: "稍快" },
{ value: 1.2, label: "较快" },
] as const;
const currentSpeedLabel = speedOptions.find((o) => o.value === speed)?.label ?? "正常";
return (
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm relative z-10">
<div className="flex justify-between items-center gap-2 mb-4">
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2 whitespace-nowrap">
<Mic className="h-4 w-4 text-purple-400" />
</h2>
<div className="flex gap-1.5">
{/* 语速下拉 (仅声音克隆模式) */}
{ttsMode === "voiceclone" && (
<div ref={speedRef} className="relative">
<button
onClick={() => setSpeedOpen((v) => !v)}
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1 transition-all"
>
: {currentSpeedLabel}
<ChevronDown className={`h-3 w-3 transition-transform ${speedOpen ? "rotate-180" : ""}`} />
</button>
{speedOpen && (
<div className="absolute right-0 top-full mt-1 bg-gray-800 border border-white/20 rounded-lg shadow-xl py-1 z-50 min-w-[80px]">
{speedOptions.map((opt) => (
<button
key={opt.value}
onClick={() => { onSpeedChange(opt.value); setSpeedOpen(false); }}
className={`w-full text-left px-3 py-1.5 text-xs transition-colors ${
speed === opt.value
? "bg-purple-600/40 text-purple-200"
: "text-gray-300 hover:bg-white/10"
}`}
>
{opt.label}
</button>
))}
</div>
)}
</div>
)}
<button
onClick={onGenerateAudio}
disabled={isGeneratingAudio || !canGenerate}
title={missingRefAudio ? "请先选择参考音频" : !hasText ? "请先输入文案" : ""}
className={`px-2 py-1 text-xs rounded transition-all whitespace-nowrap flex items-center gap-1 ${
isGeneratingAudio || !canGenerate
? "bg-gray-600 cursor-not-allowed text-gray-400"
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white"
}`}
>
<Mic className="h-3.5 w-3.5" />
</button>
<button
onClick={onRefresh}
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1"
>
<RefreshCw className="h-3.5 w-3.5" />
</button>
</div>
</div>
{/* 缺少参考音频提示 */}
{missingRefAudio && (
<div className="mb-3 px-3 py-2 bg-yellow-500/10 border border-yellow-500/30 rounded-lg text-yellow-300 text-xs">
</div>
)}
{/* 生成进度 */}
{isGeneratingAudio && audioTask && (
<div className="mb-4 p-3 bg-purple-500/10 rounded-xl border border-purple-500/30">
<div className="flex justify-between text-sm text-purple-300 mb-2">
<span>{audioTask.message || "生成中..."}</span>
<span>{audioTask.progress || 0}%</span>
</div>
<div className="h-2 bg-black/30 rounded-full overflow-hidden">
<div
className="h-full bg-gradient-to-r from-purple-500 to-pink-500 transition-all duration-300"
style={{ width: `${audioTask.progress || 0}%` }}
/>
</div>
</div>
)}
{/* 配音列表 */}
{generatedAudios.length === 0 ? (
<div className="text-center py-6 text-gray-400">
<p className="text-sm"></p>
<p className="text-xs mt-1 text-gray-500"></p>
</div>
) : (
<div className="space-y-2 max-h-48 sm:max-h-56 overflow-y-auto hide-scrollbar">
{generatedAudios.map((audio) => {
const isSelected = selectedAudioId === audio.id;
return (
<div
key={audio.id}
onClick={() => onSelectAudio(audio)}
className={`p-3 rounded-lg border transition-all cursor-pointer flex items-center justify-between group ${
isSelected
? "border-purple-500 bg-purple-500/20"
: "border-white/10 bg-white/5 hover:border-white/30"
}`}
>
{editingId === audio.id ? (
<div className="flex-1 flex items-center gap-2" onClick={(e) => e.stopPropagation()}>
<input
value={editName}
onChange={(e) => setEditName(e.target.value)}
className="flex-1 bg-black/40 border border-white/20 rounded-md px-2 py-1 text-xs text-white"
autoFocus
onKeyDown={(e) => {
if (e.key === "Enter") saveEditing(audio.id, e as unknown as React.MouseEvent);
if (e.key === "Escape") cancelEditing(e as unknown as React.MouseEvent);
}}
/>
<button onClick={(e) => saveEditing(audio.id, e)} className="p-1 text-green-400 hover:text-green-300" title="保存">
<Check className="h-4 w-4" />
</button>
<button onClick={cancelEditing} className="p-1 text-gray-400 hover:text-white" title="取消">
<X className="h-4 w-4" />
</button>
</div>
) : (
<>
<div className="min-w-0 flex-1">
<div className="text-white text-sm truncate">{audio.name}</div>
<div className="text-gray-400 text-xs">{audio.duration_sec.toFixed(1)}s</div>
</div>
<div className="flex items-center gap-1 pl-2 opacity-0 group-hover:opacity-100 transition-opacity">
<button
onClick={(e) => togglePlay(audio, e)}
className="p-1 text-gray-500 hover:text-purple-400 transition-colors"
title={playingId === audio.id ? "暂停" : "播放"}
>
{playingId === audio.id ? (
<Pause className="h-3.5 w-3.5" />
) : (
<Play className="h-3.5 w-3.5" />
)}
</button>
<button
onClick={(e) => startEditing(audio, e)}
className="p-1 text-gray-500 hover:text-white transition-colors"
title="重命名"
>
<Pencil className="h-3.5 w-3.5" />
</button>
<button
onClick={(e) => {
e.stopPropagation();
onDeleteAudio(audio.id);
}}
className="p-1 text-gray-500 hover:text-red-400 transition-colors"
title="删除"
>
<Trash2 className="h-3.5 w-3.5" />
</button>
</div>
</>
)}
</div>
);
})}
</div>
)}
</div>
);
}

View File

@@ -1,20 +1,24 @@
"use client";
import { useEffect } from "react";
import { useEffect, useMemo } from "react";
import { useRouter } from "next/navigation";
import VideoPreviewModal from "@/components/VideoPreviewModal";
import ScriptExtractionModal from "./ScriptExtractionModal";
import { useHomeController } from "@/features/home/model/useHomeController";
import { resolveMediaUrl } from "@/shared/lib/media";
import { BgmPanel } from "@/features/home/ui/BgmPanel";
import { GenerateActionBar } from "@/features/home/ui/GenerateActionBar";
import { HistoryList } from "@/features/home/ui/HistoryList";
import { HomeHeader } from "@/features/home/ui/HomeHeader";
import { MaterialSelector } from "@/features/home/ui/MaterialSelector";
import { TimelineEditor } from "@/features/home/ui/TimelineEditor";
import { ClipTrimmer } from "@/features/home/ui/ClipTrimmer";
import { PreviewPanel } from "@/features/home/ui/PreviewPanel";
import { RefAudioPanel } from "@/features/home/ui/RefAudioPanel";
import { ScriptEditor } from "@/features/home/ui/ScriptEditor";
import { TitleSubtitlePanel } from "@/features/home/ui/TitleSubtitlePanel";
import { VoiceSelector } from "@/features/home/ui/VoiceSelector";
import { GeneratedAudiosPanel } from "@/features/home/ui/GeneratedAudiosPanel";
export function HomePage() {
const router = useRouter();
@@ -36,7 +40,6 @@ export function HomePage() {
handleUpload,
selectedMaterials,
toggleMaterial,
reorderMaterials,
handlePreviewMaterial,
editingMaterialId,
editMaterialName,
@@ -54,6 +57,9 @@ export function HomePage() {
isTranslating,
originalText,
handleRestoreOriginal,
savedScripts,
handleSaveScript,
deleteSavedScript,
showStylePreview,
setShowStylePreview,
videoTitle,
@@ -64,6 +70,15 @@ export function HomePage() {
titleFontSize,
setTitleFontSize,
setTitleSizeLocked,
videoSecondaryTitle,
secondaryTitleInput,
selectedSecondaryTitleStyleId,
setSelectedSecondaryTitleStyleId,
secondaryTitleFontSize,
setSecondaryTitleFontSize,
setSecondaryTitleSizeLocked,
secondaryTitleTopMargin,
setSecondaryTitleTopMargin,
subtitleStyles,
selectedSubtitleStyleId,
setSelectedSubtitleStyleId,
@@ -74,10 +89,13 @@ export function HomePage() {
setTitleTopMargin,
subtitleBottomMargin,
setSubtitleBottomMargin,
titleDisplayMode,
setTitleDisplayMode,
outputAspectRatio,
setOutputAspectRatio,
resolveAssetUrl,
getFontFormat,
buildTextShadow,
materialDimensions,
ttsMode,
setTtsMode,
voices,
@@ -100,6 +118,8 @@ export function HomePage() {
saveEditing,
cancelEditing,
deleteRefAudio,
retranscribeRefAudio,
retranscribingId,
recordedBlob,
isRecording,
recordingTime,
@@ -107,7 +127,6 @@ export function HomePage() {
stopRecording,
useRecording,
formatRecordingTime,
fixedRefText,
bgmList,
bgmLoading,
bgmError,
@@ -133,12 +152,47 @@ export function HomePage() {
fetchGeneratedVideos,
registerVideoRef,
formatDate,
generatedAudios,
selectedAudio,
selectedAudioId,
isGeneratingAudio,
audioTask,
fetchGeneratedAudios,
handleGenerateAudio,
deleteAudio,
renameAudio,
selectAudio,
speed,
setSpeed,
timelineSegments,
reorderSegments,
setSourceRange,
clipTrimmerOpen,
setClipTrimmerOpen,
clipTrimmerSegmentId,
setClipTrimmerSegmentId,
} = useHomeController();
useEffect(() => {
router.prefetch("/publish");
}, [router]);
useEffect(() => {
if (typeof window === "undefined") return;
window.scrollTo({ top: 0, left: 0, behavior: "auto" });
}, []);
const clipTrimmerSegment = useMemo(
() => timelineSegments.find((s) => s.id === clipTrimmerSegmentId) ?? null,
[timelineSegments, clipTrimmerSegmentId]
);
const clipTrimmerMaterialUrl = useMemo(() => {
if (!clipTrimmerSegment) return null;
const mat = materials.find((m) => m.id === clipTrimmerSegment.materialId);
return mat?.path ? resolveMediaUrl(mat.path) : null;
}, [clipTrimmerSegment, materials]);
return (
<div className="min-h-dvh">
<HomeHeader />
@@ -147,34 +201,7 @@ export function HomePage() {
<div className="grid grid-cols-1 lg:grid-cols-2 gap-8">
{/* 左侧: 输入区域 */}
<div className="space-y-6">
{/* 素材选择 */}
<MaterialSelector
materials={materials}
selectedMaterials={selectedMaterials}
isFetching={isFetching}
lastMaterialCount={lastMaterialCount}
editingMaterialId={editingMaterialId}
editMaterialName={editMaterialName}
isUploading={isUploading}
uploadProgress={uploadProgress}
uploadError={uploadError}
fetchError={fetchError}
apiBase={apiBase}
onUploadChange={handleUpload}
onRefresh={fetchMaterials}
onToggleMaterial={toggleMaterial}
onReorderMaterials={reorderMaterials}
onPreviewMaterial={handlePreviewMaterial}
onStartEditing={startMaterialEditing}
onEditNameChange={setEditMaterialName}
onSaveEditing={saveMaterialEditing}
onCancelEditing={cancelMaterialEditing}
onDeleteMaterial={deleteMaterial}
onClearUploadError={() => setUploadError(null)}
registerMaterialRef={registerMaterialRef}
/>
{/* 文案输入 */}
{/* 1. 文案输入 */}
<ScriptEditor
text={text}
onChangeText={setText}
@@ -185,9 +212,13 @@ export function HomePage() {
isTranslating={isTranslating}
hasOriginalText={originalText !== null}
onRestoreOriginal={handleRestoreOriginal}
savedScripts={savedScripts}
onSaveScript={handleSaveScript}
onLoadScript={setText}
onDeleteScript={deleteSavedScript}
/>
{/* 标题和字幕设置 */}
{/* 2. 标题和字幕设置 */}
<TitleSubtitlePanel
showStylePreview={showStylePreview}
onTogglePreview={() => setShowStylePreview((prev) => !prev)}
@@ -195,6 +226,10 @@ export function HomePage() {
onTitleChange={titleInput.handleChange}
onTitleCompositionStart={titleInput.handleCompositionStart}
onTitleCompositionEnd={titleInput.handleCompositionEnd}
videoSecondaryTitle={videoSecondaryTitle}
onSecondaryTitleChange={secondaryTitleInput.handleChange}
onSecondaryTitleCompositionStart={secondaryTitleInput.handleCompositionStart}
onSecondaryTitleCompositionEnd={secondaryTitleInput.handleCompositionEnd}
titleStyles={titleStyles}
selectedTitleStyleId={selectedTitleStyleId}
onSelectTitleStyle={setSelectedTitleStyleId}
@@ -203,6 +238,15 @@ export function HomePage() {
setTitleFontSize(value);
setTitleSizeLocked(true);
}}
selectedSecondaryTitleStyleId={selectedSecondaryTitleStyleId}
onSelectSecondaryTitleStyle={setSelectedSecondaryTitleStyleId}
secondaryTitleFontSize={secondaryTitleFontSize}
onSecondaryTitleFontSizeChange={(value) => {
setSecondaryTitleFontSize(value);
setSecondaryTitleSizeLocked(true);
}}
secondaryTitleTopMargin={secondaryTitleTopMargin}
onSecondaryTitleTopMarginChange={setSecondaryTitleTopMargin}
subtitleStyles={subtitleStyles}
selectedSubtitleStyleId={selectedSubtitleStyleId}
onSelectSubtitleStyle={setSelectedSubtitleStyleId}
@@ -215,14 +259,16 @@ export function HomePage() {
onTitleTopMarginChange={setTitleTopMargin}
subtitleBottomMargin={subtitleBottomMargin}
onSubtitleBottomMarginChange={setSubtitleBottomMargin}
titleDisplayMode={titleDisplayMode}
onTitleDisplayModeChange={setTitleDisplayMode}
resolveAssetUrl={resolveAssetUrl}
getFontFormat={getFontFormat}
buildTextShadow={buildTextShadow}
previewBaseWidth={materialDimensions?.width || 1080}
previewBaseHeight={materialDimensions?.height || 1920}
previewBaseWidth={outputAspectRatio === "16:9" ? 1920 : 1080}
previewBaseHeight={outputAspectRatio === "16:9" ? 1080 : 1920}
/>
{/* 配音方式选择 */}
{/* 3. 配音方式选择 */}
<VoiceSelector
ttsMode={ttsMode}
onSelectTtsMode={setTtsMode}
@@ -248,6 +294,8 @@ export function HomePage() {
onSaveEditing={saveEditing}
onCancelEditing={cancelEditing}
onDeleteRefAudio={deleteRefAudio}
onRetranscribe={retranscribeRefAudio}
retranscribingId={retranscribingId}
recordedBlob={recordedBlob}
isRecording={isRecording}
recordingTime={recordingTime}
@@ -255,12 +303,79 @@ export function HomePage() {
onStopRecording={stopRecording}
onUseRecording={useRecording}
formatRecordingTime={formatRecordingTime}
fixedRefText={fixedRefText}
/>
)}
/>
{/* 背景音乐 */}
{/* 4. 配音列表 */}
<GeneratedAudiosPanel
generatedAudios={generatedAudios}
selectedAudioId={selectedAudioId}
isGeneratingAudio={isGeneratingAudio}
audioTask={audioTask}
onGenerateAudio={handleGenerateAudio}
onRefresh={() => fetchGeneratedAudios()}
onSelectAudio={selectAudio}
onDeleteAudio={deleteAudio}
onRenameAudio={renameAudio}
hasText={!!text.trim()}
missingRefAudio={ttsMode === "voiceclone" && !selectedRefAudio}
speed={speed}
onSpeedChange={setSpeed}
ttsMode={ttsMode}
/>
{/* 5. 视频素材 */}
<MaterialSelector
materials={materials}
selectedMaterials={selectedMaterials}
isFetching={isFetching}
lastMaterialCount={lastMaterialCount}
editingMaterialId={editingMaterialId}
editMaterialName={editMaterialName}
isUploading={isUploading}
uploadProgress={uploadProgress}
uploadError={uploadError}
fetchError={fetchError}
apiBase={apiBase}
onUploadChange={handleUpload}
onRefresh={fetchMaterials}
onToggleMaterial={toggleMaterial}
onPreviewMaterial={handlePreviewMaterial}
onStartEditing={startMaterialEditing}
onEditNameChange={setEditMaterialName}
onSaveEditing={saveMaterialEditing}
onCancelEditing={cancelMaterialEditing}
onDeleteMaterial={deleteMaterial}
onClearUploadError={() => setUploadError(null)}
registerMaterialRef={registerMaterialRef}
/>
{/* 5.5 时间轴编辑器 — 未选配音/素材时模糊遮挡 */}
<div className="relative">
{(!selectedAudio || selectedMaterials.length === 0) && (
<div className="absolute inset-0 bg-black/50 backdrop-blur-sm rounded-2xl flex items-center justify-center z-10">
<p className="text-gray-400">
{!selectedAudio ? "请先生成并选中配音" : "请先选择素材"}
</p>
</div>
)}
<TimelineEditor
audioDuration={selectedAudio?.duration_sec ?? 0}
audioUrl={selectedAudio ? (resolveMediaUrl(selectedAudio.path) || "") : ""}
segments={timelineSegments}
materials={materials}
outputAspectRatio={outputAspectRatio}
onOutputAspectRatioChange={setOutputAspectRatio}
onReorderSegment={reorderSegments}
onClickSegment={(seg) => {
setClipTrimmerSegmentId(seg.id);
setClipTrimmerOpen(true);
}}
/>
</div>
{/* 6. 背景音乐 */}
<BgmPanel
bgmList={bgmList}
bgmLoading={bgmLoading}
@@ -278,12 +393,12 @@ export function HomePage() {
registerBgmItemRef={registerBgmItemRef}
/>
{/* 生成按钮 */}
{/* 7. 生成按钮 */}
<GenerateActionBar
isGenerating={isGenerating}
progress={currentTask?.progress || 0}
materialCount={selectedMaterials.length}
disabled={isGenerating || selectedMaterials.length === 0 || (ttsMode === "voiceclone" && !selectedRefAudio)}
disabled={isGenerating || selectedMaterials.length === 0 || !selectedAudio}
onGenerate={handleGenerate}
/>
</div>
@@ -319,6 +434,19 @@ export function HomePage() {
onClose={() => setExtractModalOpen(false)}
onApply={(nextText) => setText(nextText)}
/>
<ClipTrimmer
isOpen={clipTrimmerOpen}
segment={clipTrimmerSegment}
materialUrl={clipTrimmerMaterialUrl}
onConfirm={(sourceStart, sourceEnd) => {
if (clipTrimmerSegmentId) {
setSourceRange(clipTrimmerSegmentId, sourceStart, sourceEnd);
}
setClipTrimmerOpen(false);
}}
onClose={() => setClipTrimmerOpen(false)}
/>
</div>
);
}

View File

@@ -1,21 +1,6 @@
import { type ChangeEvent, type MouseEvent } from "react";
import { Upload, RefreshCw, Eye, Trash2, X, Pencil, Check, GripVertical } from "lucide-react";
import { Upload, RefreshCw, Eye, Trash2, X, Pencil, Check } from "lucide-react";
import type { Material } from "@/shared/types/material";
import {
DndContext,
closestCenter,
KeyboardSensor,
PointerSensor,
useSensor,
useSensors,
type DragEndEvent,
} from "@dnd-kit/core";
import {
SortableContext,
horizontalListSortingStrategy,
useSortable,
} from "@dnd-kit/sortable";
import { CSS } from "@dnd-kit/utilities";
interface MaterialSelectorProps {
materials: Material[];
@@ -32,7 +17,6 @@ interface MaterialSelectorProps {
onUploadChange: (event: ChangeEvent<HTMLInputElement>) => void;
onRefresh: () => void;
onToggleMaterial: (id: string) => void;
onReorderMaterials: (activeId: string, overId: string) => void;
onPreviewMaterial: (path: string) => void;
onStartEditing: (material: Material, event: MouseEvent) => void;
onEditNameChange: (value: string) => void;
@@ -43,61 +27,6 @@ interface MaterialSelectorProps {
registerMaterialRef: (id: string, element: HTMLDivElement | null) => void;
}
function SortableChip({
id,
index,
label,
onRemove,
}: {
id: string;
index: number;
label: string;
onRemove: () => void;
}) {
const {
attributes,
listeners,
setNodeRef,
transform,
transition,
isDragging,
} = useSortable({ id });
const style = {
transform: CSS.Translate.toString(transform),
transition,
};
const circledNumbers = ["\u2460", "\u2461", "\u2462", "\u2463", "\u2464", "\u2465", "\u2466", "\u2467", "\u2468", "\u2469"];
return (
<div
ref={setNodeRef}
style={style}
className={`flex items-center gap-1 rounded-lg px-2 py-1 text-xs whitespace-nowrap transition-colors ${
isDragging
? "bg-purple-500/50 border border-purple-400 text-white shadow-lg shadow-purple-500/30 z-10"
: "bg-purple-500/30 border border-purple-500/50 text-purple-200"
}`}
>
<span {...attributes} {...listeners} className="cursor-grab active:cursor-grabbing text-purple-400">
<GripVertical className="h-3 w-3" />
</span>
<span className="text-purple-300">{circledNumbers[index] || `${index + 1}`}</span>
<span className="max-w-[80px] truncate">{label}</span>
<button
onClick={(e) => {
e.stopPropagation();
onRemove();
}}
className="text-purple-400 hover:text-white ml-0.5"
>
<X className="h-3 w-3" />
</button>
</div>
);
}
export function MaterialSelector({
materials,
selectedMaterials,
@@ -113,7 +42,6 @@ export function MaterialSelector({
onUploadChange,
onRefresh,
onToggleMaterial,
onReorderMaterials,
onPreviewMaterial,
onStartEditing,
onEditNameChange,
@@ -123,21 +51,8 @@ export function MaterialSelector({
onClearUploadError,
registerMaterialRef,
}: MaterialSelectorProps) {
const sensors = useSensors(
useSensor(PointerSensor, { activationConstraint: { distance: 5 } }),
useSensor(KeyboardSensor)
);
const handleDragEnd = (event: DragEndEvent) => {
const { active, over } = event;
if (over && active.id !== over.id) {
onReorderMaterials(String(active.id), String(over.id));
}
};
const selectedSet = new Set(selectedMaterials);
const isFull = selectedMaterials.length >= 4;
const circledNumbers = ["\u2460", "\u2461", "\u2462", "\u2463", "\u2464", "\u2465", "\u2466", "\u2467", "\u2468", "\u2469"];
return (
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
@@ -200,38 +115,6 @@ export function MaterialSelector({
</div>
)}
{/* 已选素材排列(拖拽排序区) - 仅当选中 >= 2 个时显示 */}
{selectedMaterials.length >= 2 && (
<div className="mb-3 p-3 bg-purple-500/10 rounded-xl border border-purple-500/20">
<div className="text-[11px] text-purple-300/70 mb-2">🎬 ()</div>
<DndContext
sensors={sensors}
collisionDetection={closestCenter}
onDragEnd={handleDragEnd}
>
<SortableContext
items={selectedMaterials}
strategy={horizontalListSortingStrategy}
>
<div className="flex flex-wrap gap-1.5">
{selectedMaterials.map((id, index) => {
const m = materials.find((x) => x.id === id);
return (
<SortableChip
key={id}
id={id}
index={index}
label={m?.scene || m?.name || id}
onRemove={() => onToggleMaterial(id)}
/>
);
})}
</div>
</SortableContext>
</DndContext>
</div>
)}
{fetchError ? (
<div className="p-4 bg-red-500/20 text-red-200 rounded-xl text-sm mb-4">
: {fetchError}
@@ -265,7 +148,6 @@ export function MaterialSelector({
>
{materials.map((m) => {
const isSelected = selectedSet.has(m.id);
const selIndex = selectedMaterials.indexOf(m.id);
return (
<div
key={m.id}
@@ -309,7 +191,7 @@ export function MaterialSelector({
: "border-white/30 text-transparent"
}`}
>
{isSelected ? (selIndex >= 0 ? circledNumbers[selIndex] || "✓" : "✓") : ""}
{isSelected ? "✓" : ""}
</span>
<div className="min-w-0">
<div className="text-white text-sm truncate">{m.scene || m.name}</div>

View File

@@ -1,6 +1,6 @@
import { useEffect, useState } from "react";
import type { MouseEvent } from "react";
import { Upload, RefreshCw, Play, Pause, Pencil, Trash2, Check, X, Mic, Square } from "lucide-react";
import { Upload, RefreshCw, Play, Pause, Pencil, Trash2, Check, X, Mic, Square, RotateCw } from "lucide-react";
interface RefAudio {
id: string;
@@ -29,6 +29,8 @@ interface RefAudioPanelProps {
onSaveEditing: (id: string, event: MouseEvent) => void;
onCancelEditing: (event: MouseEvent) => void;
onDeleteRefAudio: (id: string) => void;
onRetranscribe: (id: string) => void;
retranscribingId: string | null;
recordedBlob: Blob | null;
isRecording: boolean;
recordingTime: number;
@@ -36,9 +38,10 @@ interface RefAudioPanelProps {
onStopRecording: () => void;
onUseRecording: () => void;
formatRecordingTime: (seconds: number) => string;
fixedRefText: string;
}
const OLD_FIXED_REF_TEXT = "其实生活中有许多美好的瞬间";
export function RefAudioPanel({
refAudios,
selectedRefAudio,
@@ -57,6 +60,8 @@ export function RefAudioPanel({
onSaveEditing,
onCancelEditing,
onDeleteRefAudio,
onRetranscribe,
retranscribingId,
recordedBlob,
isRecording,
recordingTime,
@@ -64,7 +69,6 @@ export function RefAudioPanel({
onStopRecording,
onUseRecording,
formatRecordingTime,
fixedRefText,
}: RefAudioPanelProps) {
const [recordedUrl, setRecordedUrl] = useState<string | null>(null);
@@ -81,6 +85,9 @@ export function RefAudioPanel({
};
}, [recordedBlob]);
const needsRetranscribe = (audio: RefAudio) =>
audio.ref_text.startsWith(OLD_FIXED_REF_TEXT);
return (
<div className="space-y-4">
<div>
@@ -122,7 +129,7 @@ export function RefAudioPanel({
{isUploadingRef && (
<div className="mb-2 p-2 bg-purple-500/10 rounded text-sm text-purple-300">
...
...
</div>
)}
@@ -192,6 +199,17 @@ export function RefAudioPanel({
<Play className="h-3.5 w-3.5" />
)}
</button>
<button
onClick={(e) => {
e.stopPropagation();
onRetranscribe(audio.id);
}}
disabled={retranscribingId === audio.id}
className="text-gray-400 hover:text-cyan-400 text-xs disabled:opacity-50"
title="重新识别文字"
>
<RotateCw className={`h-3.5 w-3.5 ${retranscribingId === audio.id ? 'animate-spin' : ''}`} />
</button>
<button
onClick={(e) => onStartEditing(audio, e)}
className="text-gray-400 hover:text-blue-400 text-xs"
@@ -211,7 +229,12 @@ export function RefAudioPanel({
</button>
</div>
</div>
<div className="text-gray-400 text-xs">{audio.duration_sec.toFixed(1)}s</div>
<div className="text-gray-400 text-xs">
{audio.duration_sec.toFixed(1)}s
{needsRetranscribe(audio) && (
<span className="text-yellow-500 ml-1" title="需要重新识别文字"></span>
)}
</div>
</>
)}
</div>
@@ -221,7 +244,7 @@ export function RefAudioPanel({
</div>
<div className="border-t border-white/10 pt-4">
<span className="text-sm text-gray-300 mb-2 block">🎤 线</span>
<span className="text-sm text-gray-300 mb-2 block">🎤 线 <span className="text-xs text-gray-500"> 3-10 </span></span>
<div className="flex gap-2 items-center">
{!isRecording ? (
<button
@@ -264,15 +287,9 @@ export function RefAudioPanel({
)}
</div>
<div className="border-t border-white/10 pt-4">
<label className="text-sm text-gray-300 mb-2 block">📝 /</label>
<div className="w-full bg-black/30 border border-white/10 rounded-lg p-3 text-white text-sm">
{fixedRefText}
</div>
<p className="text-xs text-gray-500 mt-1">
</p>
</div>
<p className="text-xs text-gray-500 mt-2 border-t border-white/10 pt-3">
3-10
</p>
</div>
);
}

View File

@@ -1,5 +1,6 @@
import { useEffect, useRef, useState } from "react";
import { FileText, Languages, Loader2, RotateCcw, Sparkles } from "lucide-react";
import { FileText, History, Languages, Loader2, RotateCcw, Save, Sparkles, Trash2 } from "lucide-react";
import type { SavedScript } from "@/features/home/model/useSavedScripts";
const LANGUAGES = [
{ code: "English", label: "英语 English" },
@@ -23,6 +24,10 @@ interface ScriptEditorProps {
isTranslating: boolean;
hasOriginalText: boolean;
onRestoreOriginal: () => void;
savedScripts: SavedScript[];
onSaveScript: () => void;
onLoadScript: (content: string) => void;
onDeleteScript: (id: string) => void;
}
export function ScriptEditor({
@@ -35,9 +40,15 @@ export function ScriptEditor({
isTranslating,
hasOriginalText,
onRestoreOriginal,
savedScripts,
onSaveScript,
onLoadScript,
onDeleteScript,
}: ScriptEditorProps) {
const [showLangMenu, setShowLangMenu] = useState(false);
const langMenuRef = useRef<HTMLDivElement>(null);
const [showHistoryMenu, setShowHistoryMenu] = useState(false);
const historyMenuRef = useRef<HTMLDivElement>(null);
useEffect(() => {
if (!showLangMenu) return;
@@ -50,21 +61,81 @@ export function ScriptEditor({
return () => document.removeEventListener("mousedown", handleClickOutside);
}, [showLangMenu]);
useEffect(() => {
if (!showHistoryMenu) return;
const handleClickOutside = (e: MouseEvent) => {
if (historyMenuRef.current && !historyMenuRef.current.contains(e.target as Node)) {
setShowHistoryMenu(false);
}
};
document.addEventListener("mousedown", handleClickOutside);
return () => document.removeEventListener("mousedown", handleClickOutside);
}, [showHistoryMenu]);
const handleSelectLang = (langCode: string) => {
setShowLangMenu(false);
onTranslate(langCode);
};
const formatDate = (ts: number) => {
const d = new Date(ts);
return `${(d.getMonth() + 1).toString().padStart(2, "0")}-${d.getDate().toString().padStart(2, "0")} ${d.getHours().toString().padStart(2, "0")}:${d.getMinutes().toString().padStart(2, "0")}`;
};
return (
<div className="relative z-10 bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
<div className="mb-4 space-y-3">
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2">
</h2>
<div className="flex gap-2 flex-wrap justify-end">
<div className="flex gap-2 flex-wrap justify-end items-center">
{/* 历史文案 */}
<div className="relative" ref={historyMenuRef}>
<button
onClick={() => setShowHistoryMenu((prev) => !prev)}
className="h-7 px-2.5 text-xs rounded transition-all whitespace-nowrap bg-gray-600 hover:bg-gray-500 text-white inline-flex items-center gap-1"
>
<History className="h-3.5 w-3.5" />
</button>
{showHistoryMenu && (
<div className="absolute left-0 top-full mt-1 z-50 bg-gray-800 border border-white/10 rounded-lg shadow-xl py-1 min-w-[220px] max-h-[280px] overflow-y-auto">
{savedScripts.length === 0 ? (
<div className="px-3 py-3 text-xs text-gray-500 text-center"></div>
) : (
savedScripts.map((script) => (
<div
key={script.id}
className="flex items-center gap-1 px-3 py-1.5 hover:bg-white/10 transition-colors group"
>
<button
onClick={() => {
onLoadScript(script.content);
setShowHistoryMenu(false);
}}
className="flex-1 text-left min-w-0"
>
<div className="text-xs text-gray-200 truncate">{script.name}</div>
<div className="text-[10px] text-gray-500">{formatDate(script.savedAt)}</div>
</button>
<button
onClick={(e) => {
e.stopPropagation();
onDeleteScript(script.id);
}}
className="opacity-0 group-hover:opacity-100 p-1 text-gray-500 hover:text-red-400 transition-all shrink-0"
>
<Trash2 className="h-3 w-3" />
</button>
</div>
))
)}
</div>
)}
</div>
<button
onClick={onOpenExtractModal}
className="px-2 py-1 text-xs rounded transition-all whitespace-nowrap bg-purple-600 hover:bg-purple-700 text-white flex items-center gap-1"
className="h-7 px-2.5 text-xs rounded transition-all whitespace-nowrap bg-purple-600 hover:bg-purple-700 text-white inline-flex items-center gap-1"
>
<FileText className="h-3.5 w-3.5" />
@@ -73,22 +144,22 @@ export function ScriptEditor({
<button
onClick={() => setShowLangMenu((prev) => !prev)}
disabled={isTranslating || !text.trim()}
className={`px-2 py-1 text-xs rounded transition-all whitespace-nowrap ${
className={`h-7 px-2.5 text-xs rounded transition-all whitespace-nowrap inline-flex items-center gap-1 ${
isTranslating || !text.trim()
? "bg-gray-600 cursor-not-allowed text-gray-400"
: "bg-gradient-to-r from-emerald-600 to-teal-600 hover:from-emerald-700 hover:to-teal-700 text-white"
}`}
>
{isTranslating ? (
<span className="flex items-center gap-1">
<>
<Loader2 className="h-3.5 w-3.5 animate-spin" />
...
</span>
</>
) : (
<span className="flex items-center gap-1">
<>
<Languages className="h-3.5 w-3.5" />
AI多语言
</span>
</>
)}
</button>
{showLangMenu && (
@@ -120,21 +191,21 @@ export function ScriptEditor({
<button
onClick={onGenerateMeta}
disabled={isGeneratingMeta || !text.trim()}
className={`px-2 py-1 text-xs rounded transition-all whitespace-nowrap ${isGeneratingMeta || !text.trim()
className={`h-7 px-2.5 text-xs rounded transition-all whitespace-nowrap inline-flex items-center gap-1 ${isGeneratingMeta || !text.trim()
? "bg-gray-600 cursor-not-allowed text-gray-400"
: "bg-gradient-to-r from-blue-600 to-cyan-600 hover:from-blue-700 hover:to-cyan-700 text-white"
}`}
>
{isGeneratingMeta ? (
<span className="flex items-center gap-1">
<>
<Loader2 className="h-3.5 w-3.5 animate-spin" />
...
</span>
</>
) : (
<span className="flex items-center gap-1">
<>
<Sparkles className="h-3.5 w-3.5" />
AI生成标题标签
</span>
</>
)}
</button>
</div>
@@ -145,9 +216,20 @@ export function ScriptEditor({
placeholder="请输入你想说的话..."
className="w-full h-40 bg-black/30 border border-white/10 rounded-xl p-4 text-white placeholder-gray-500 resize-none focus:outline-none focus:border-purple-500 transition-colors hide-scrollbar"
/>
<div className="flex justify-between mt-2 text-sm text-gray-400">
<div className="flex items-center justify-between mt-2 text-sm text-gray-400">
<span>{text.length} </span>
<span>: ~{Math.ceil(text.length / 4)} </span>
<button
onClick={onSaveScript}
disabled={!text.trim()}
className={`px-2.5 py-1 text-xs rounded transition-all flex items-center gap-1 ${
!text.trim()
? "bg-gray-700 cursor-not-allowed text-gray-500"
: "bg-amber-600/80 hover:bg-amber-600 text-white"
}`}
>
<Save className="h-3 w-3" />
</button>
</div>
</div>
);

View File

@@ -26,9 +26,13 @@ export default function ScriptExtractionModal({
selectedFile,
activeTab,
inputUrl,
customPrompt,
showCustomPrompt,
setDoRewrite,
setActiveTab,
setInputUrl,
setCustomPrompt,
setShowCustomPrompt,
handleDrag,
handleDrop,
handleFileChange,
@@ -187,18 +191,43 @@ export default function ScriptExtractionModal({
)}
{/* Options */}
<div className="flex items-center gap-3 bg-white/5 rounded-xl p-4 border border-white/10">
<label className="flex items-center gap-2 cursor-pointer">
<input
type="checkbox"
checked={doRewrite}
onChange={(e) => setDoRewrite(e.target.checked)}
className="w-4 h-4 rounded bg-white/10 border-white/20 text-purple-500 focus:ring-purple-500"
/>
<span className="text-sm text-gray-300">
AI
</span>
</label>
<div className="bg-white/5 rounded-xl border border-white/10 overflow-hidden">
<div className="flex items-center justify-between p-4">
<label className="flex items-center gap-2 cursor-pointer">
<input
type="checkbox"
checked={doRewrite}
onChange={(e) => setDoRewrite(e.target.checked)}
className="w-4 h-4 rounded bg-white/10 border-white/20 text-purple-500 focus:ring-purple-500"
/>
<span className="text-sm text-gray-300">
AI
</span>
</label>
{doRewrite && (
<button
type="button"
onClick={() => setShowCustomPrompt(!showCustomPrompt)}
className="text-xs text-purple-400 hover:text-purple-300 transition-colors flex items-center gap-1"
>
{showCustomPrompt ? "▲" : "▼"}
</button>
)}
</div>
{doRewrite && showCustomPrompt && (
<div className="px-4 pb-4 space-y-2">
<textarea
value={customPrompt}
onChange={(e) => setCustomPrompt(e.target.value)}
placeholder="输入自定义改写提示词..."
rows={3}
className="w-full bg-black/20 border border-white/10 rounded-lg px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:border-purple-500 transition-colors resize-none"
/>
<p className="text-xs text-gray-500">
使
</p>
</div>
)}
</div>
{/* Error */}
@@ -261,7 +290,7 @@ export default function ScriptExtractionModal({
<div className="space-y-2">
<div className="flex justify-between items-center">
<h4 className="font-semibold text-purple-300 flex items-center gap-2">
AI 稿{" "}
AI {" "}
<span className="text-xs font-normal text-purple-400/70">
()
</span>

View File

@@ -0,0 +1,349 @@
import { useEffect, useRef, useCallback, useState } from "react";
import WaveSurfer from "wavesurfer.js";
import { ChevronDown } from "lucide-react";
import type { TimelineSegment } from "@/features/home/model/useTimelineEditor";
import type { Material } from "@/shared/types/material";
interface TimelineEditorProps {
audioDuration: number;
audioUrl: string;
segments: TimelineSegment[];
materials: Material[];
outputAspectRatio: "9:16" | "16:9";
onOutputAspectRatioChange: (ratio: "9:16" | "16:9") => void;
onReorderSegment: (fromIdx: number, toIdx: number) => void;
onClickSegment: (segment: TimelineSegment) => void;
}
function formatTime(sec: number): string {
const m = Math.floor(sec / 60);
const s = sec % 60;
return `${String(m).padStart(2, "0")}:${s.toFixed(1).padStart(4, "0")}`;
}
export function TimelineEditor({
audioDuration,
audioUrl,
segments,
materials,
outputAspectRatio,
onOutputAspectRatioChange,
onReorderSegment,
onClickSegment,
}: TimelineEditorProps) {
const waveRef = useRef<HTMLDivElement>(null);
const wsRef = useRef<WaveSurfer | null>(null);
const [waveReady, setWaveReady] = useState(false);
const [isPlaying, setIsPlaying] = useState(false);
// Refs for high-frequency DOM updates (avoid 60fps re-renders)
const playheadRef = useRef<HTMLDivElement>(null);
const timeRef = useRef<HTMLSpanElement>(null);
const audioDurationRef = useRef(audioDuration);
useEffect(() => {
audioDurationRef.current = audioDuration;
}, [audioDuration]);
// Drag-to-reorder state
const [dragFromIdx, setDragFromIdx] = useState<number | null>(null);
const [dragOverIdx, setDragOverIdx] = useState<number | null>(null);
// Aspect ratio dropdown
const [ratioOpen, setRatioOpen] = useState(false);
const ratioRef = useRef<HTMLDivElement>(null);
const ratioOptions = [
{ value: "9:16" as const, label: "竖屏 9:16" },
{ value: "16:9" as const, label: "横屏 16:9" },
];
const currentRatioLabel =
ratioOptions.find((opt) => opt.value === outputAspectRatio)?.label ?? "竖屏 9:16";
useEffect(() => {
const handler = (e: MouseEvent) => {
if (ratioRef.current && !ratioRef.current.contains(e.target as Node)) {
setRatioOpen(false);
}
};
if (ratioOpen) document.addEventListener("mousedown", handler);
return () => document.removeEventListener("mousedown", handler);
}, [ratioOpen]);
// Create / recreate wavesurfer when audioUrl changes
useEffect(() => {
if (!waveRef.current || !audioUrl) return;
const playheadEl = playheadRef.current;
const timeEl = timeRef.current;
// Destroy previous instance
if (wsRef.current) {
wsRef.current.destroy();
wsRef.current = null;
}
const ws = WaveSurfer.create({
container: waveRef.current,
height: 56,
waveColor: "#6d28d9",
progressColor: "#a855f7",
barWidth: 2,
barGap: 1,
barRadius: 2,
cursorWidth: 1,
cursorColor: "#e879f9",
interact: true,
normalize: true,
});
// Click waveform → seek + auto-play
ws.on("interaction", () => ws.play());
ws.on("play", () => setIsPlaying(true));
ws.on("pause", () => setIsPlaying(false));
ws.on("finish", () => {
setIsPlaying(false);
if (playheadRef.current) playheadRef.current.style.display = "none";
});
// High-frequency: update playhead + time via refs (no React re-render)
ws.on("timeupdate", (time: number) => {
const dur = audioDurationRef.current;
if (playheadRef.current && dur > 0) {
playheadRef.current.style.left = `${(time / dur) * 100}%`;
playheadRef.current.style.display = "block";
}
if (timeRef.current) {
timeRef.current.textContent = formatTime(time);
}
});
ws.load(audioUrl);
wsRef.current = ws;
return () => {
ws.destroy();
wsRef.current = null;
setIsPlaying(false);
if (playheadEl) playheadEl.style.display = "none";
if (timeEl) timeEl.textContent = formatTime(0);
};
}, [audioUrl, waveReady]);
// Callback ref to detect when waveRef div mounts
const waveCallbackRef = useCallback((node: HTMLDivElement | null) => {
(waveRef as React.MutableRefObject<HTMLDivElement | null>).current = node;
setWaveReady(!!node);
}, []);
const handlePlayPause = useCallback(() => {
wsRef.current?.playPause();
}, []);
// Drag-to-reorder handlers
const handleDragStart = useCallback((idx: number, e: React.DragEvent) => {
setDragFromIdx(idx);
e.dataTransfer.effectAllowed = "move";
e.dataTransfer.setData("text/plain", String(idx));
}, []);
const handleDragOver = useCallback((idx: number, e: React.DragEvent) => {
e.preventDefault();
e.dataTransfer.dropEffect = "move";
setDragOverIdx(idx);
}, []);
const handleDragLeave = useCallback(() => {
setDragOverIdx(null);
}, []);
const handleDrop = useCallback((toIdx: number, e: React.DragEvent) => {
e.preventDefault();
const fromIdx = parseInt(e.dataTransfer.getData("text/plain"), 10);
if (!isNaN(fromIdx) && fromIdx !== toIdx) {
onReorderSegment(fromIdx, toIdx);
}
setDragFromIdx(null);
setDragOverIdx(null);
}, [onReorderSegment]);
const handleDragEnd = useCallback(() => {
setDragFromIdx(null);
setDragOverIdx(null);
}, []);
// Filter visible vs overflow segments
const visibleSegments = segments.filter((s) => s.start < audioDuration);
const overflowSegments = segments.filter((s) => s.start >= audioDuration);
const hasSegments = visibleSegments.length > 0;
return (
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
<div className="flex items-center justify-between mb-3">
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2">
🎞
</h2>
<div className="flex items-center gap-2 text-xs text-gray-400">
<div ref={ratioRef} className="relative">
<button
type="button"
onClick={() => setRatioOpen((v) => !v)}
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1 transition-all"
title="设置输出画面比例"
>
: {currentRatioLabel}
<ChevronDown className={`h-3 w-3 transition-transform ${ratioOpen ? "rotate-180" : ""}`} />
</button>
{ratioOpen && (
<div className="absolute right-0 top-full mt-1 bg-gray-800 border border-white/20 rounded-lg shadow-xl py-1 z-50 min-w-[106px]">
{ratioOptions.map((opt) => (
<button
key={opt.value}
type="button"
onClick={() => {
onOutputAspectRatioChange(opt.value);
setRatioOpen(false);
}}
className={`w-full text-left px-3 py-1.5 text-xs transition-colors ${
outputAspectRatio === opt.value
? "bg-purple-600/40 text-purple-200"
: "text-gray-300 hover:bg-white/10"
}`}
>
{opt.label}
</button>
))}
</div>
)}
</div>
{audioUrl && (
<>
<button
onClick={handlePlayPause}
className="w-7 h-7 flex items-center justify-center rounded-full bg-white/10 hover:bg-white/20 text-white transition-colors"
title={isPlaying ? "暂停" : "播放"}
>
{isPlaying ? "⏸" : "▶"}
</button>
<span ref={timeRef} className="tabular-nums">00:00.0</span>
<span className="text-gray-600">/</span>
<span className="tabular-nums">{formatTime(audioDuration)}</span>
</>
)}
</div>
</div>
{/* Waveform — always rendered so ref stays mounted */}
<div className="relative mb-1">
<div ref={waveCallbackRef} className="rounded-lg overflow-hidden bg-black/20 cursor-pointer" style={{ minHeight: 56 }} />
</div>
{/* Segment blocks or empty placeholder */}
{hasSegments ? (
<>
<div className="relative h-14 flex select-none">
{/* Playhead — syncs with audio playback */}
<div
ref={playheadRef}
className="absolute top-0 h-full w-0.5 bg-fuchsia-400 z-10 pointer-events-none"
style={{ display: "none", left: "0%" }}
/>
{visibleSegments.map((seg, i) => {
const left = (seg.start / audioDuration) * 100;
const width = ((seg.end - seg.start) / audioDuration) * 100;
const segDur = seg.end - seg.start;
const isDragTarget = dragOverIdx === i && dragFromIdx !== i;
// Compute loop portion for the last visible segment
const isLastVisible = i === visibleSegments.length - 1;
let loopPercent = 0;
if (isLastVisible && audioDuration > 0) {
const mat = materials.find((m) => m.id === seg.materialId);
const matDur = mat?.duration_sec ?? 0;
const effDur = (seg.sourceEnd > seg.sourceStart)
? (seg.sourceEnd - seg.sourceStart)
: Math.max(matDur - seg.sourceStart, 0);
if (effDur > 0 && segDur > effDur + 0.1) {
loopPercent = ((segDur - effDur) / segDur) * 100;
}
}
return (
<div key={seg.id} className="absolute top-0 h-full" style={{ left: `${left}%`, width: `${width}%` }}>
<button
draggable
onDragStart={(e) => handleDragStart(i, e)}
onDragOver={(e) => handleDragOver(i, e)}
onDragLeave={handleDragLeave}
onDrop={(e) => handleDrop(i, e)}
onDragEnd={handleDragEnd}
onClick={() => onClickSegment(seg)}
className={`relative w-full h-full rounded-lg flex flex-col items-center justify-center overflow-hidden cursor-grab active:cursor-grabbing transition-all border ${
isDragTarget
? "ring-2 ring-purple-400 border-purple-400 scale-[1.02]"
: dragFromIdx === i
? "opacity-50 border-white/10"
: "hover:opacity-90 border-white/10"
}`}
style={{ backgroundColor: seg.color + "33", borderColor: isDragTarget ? undefined : seg.color + "66" }}
title={`拖拽可调换顺序 · 点击设置截取范围\n${seg.materialName}\n${segDur.toFixed(1)}s${loopPercent > 0 ? ` (含循环 ${(segDur * loopPercent / 100).toFixed(1)}s)` : ""}`}
>
<span className="text-[11px] text-white/90 truncate max-w-full px-1 leading-tight z-[1]">
{seg.materialName}
</span>
<span className="text-[10px] text-white/60 leading-tight z-[1]">
{segDur.toFixed(1)}s
</span>
{seg.sourceStart > 0 && (
<span className="text-[9px] text-amber-400/80 leading-tight z-[1]">
{seg.sourceStart.toFixed(1)}s
</span>
)}
{/* Loop fill stripe overlay */}
{loopPercent > 0 && (
<div
className="absolute top-0 right-0 h-full pointer-events-none flex items-center justify-center"
style={{
width: `${loopPercent}%`,
background: `repeating-linear-gradient(-45deg, transparent, transparent 3px, rgba(255,255,255,0.07) 3px, rgba(255,255,255,0.07) 6px)`,
borderLeft: "1px dashed rgba(255,255,255,0.25)",
}}
>
<span className="text-[9px] text-white/30"></span>
</div>
)}
</button>
</div>
);
})}
</div>
{/* Overflow segments — shown as gray chips */}
{overflowSegments.length > 0 && (
<div className="flex flex-wrap items-center gap-1.5 mt-1.5">
<span className="text-[10px] text-gray-500">使:</span>
{overflowSegments.map((seg) => (
<span
key={seg.id}
className="text-[10px] text-gray-500 bg-white/5 border border-white/10 rounded px-1.5 py-0.5"
>
{seg.materialName}
</span>
))}
</div>
)}
<p className="text-[10px] text-gray-500 mt-1.5">
· ·
</p>
</>
) : (
<>
<div className="h-14 bg-white/5 rounded-lg" />
<p className="text-[10px] text-gray-500 mt-1.5">
</p>
</>
)}
</div>
);
}

View File

@@ -1,4 +1,4 @@
import { Eye } from "lucide-react";
import { ChevronDown, Eye } from "lucide-react";
import { FloatingStylePreview } from "@/features/home/ui/FloatingStylePreview";
interface SubtitleStyleOption {
@@ -38,11 +38,21 @@ interface TitleSubtitlePanelProps {
onTitleChange: (value: string) => void;
onTitleCompositionStart?: () => void;
onTitleCompositionEnd?: (value: string) => void;
videoSecondaryTitle: string;
onSecondaryTitleChange: (value: string) => void;
onSecondaryTitleCompositionStart?: () => void;
onSecondaryTitleCompositionEnd?: (value: string) => void;
titleStyles: TitleStyleOption[];
selectedTitleStyleId: string;
onSelectTitleStyle: (id: string) => void;
titleFontSize: number;
onTitleFontSizeChange: (value: number) => void;
selectedSecondaryTitleStyleId: string;
onSelectSecondaryTitleStyle: (id: string) => void;
secondaryTitleFontSize: number;
onSecondaryTitleFontSizeChange: (value: number) => void;
secondaryTitleTopMargin: number;
onSecondaryTitleTopMarginChange: (value: number) => void;
subtitleStyles: SubtitleStyleOption[];
selectedSubtitleStyleId: string;
onSelectSubtitleStyle: (id: string) => void;
@@ -52,6 +62,8 @@ interface TitleSubtitlePanelProps {
onTitleTopMarginChange: (value: number) => void;
subtitleBottomMargin: number;
onSubtitleBottomMarginChange: (value: number) => void;
titleDisplayMode: "short" | "persistent";
onTitleDisplayModeChange: (mode: "short" | "persistent") => void;
resolveAssetUrl: (path?: string | null) => string | null;
getFontFormat: (fontFile?: string) => string;
buildTextShadow: (color: string, size: number) => string;
@@ -66,11 +78,21 @@ export function TitleSubtitlePanel({
onTitleChange,
onTitleCompositionStart,
onTitleCompositionEnd,
videoSecondaryTitle,
onSecondaryTitleChange,
onSecondaryTitleCompositionStart,
onSecondaryTitleCompositionEnd,
titleStyles,
selectedTitleStyleId,
onSelectTitleStyle,
titleFontSize,
onTitleFontSizeChange,
selectedSecondaryTitleStyleId,
onSelectSecondaryTitleStyle,
secondaryTitleFontSize,
onSecondaryTitleFontSizeChange,
secondaryTitleTopMargin,
onSecondaryTitleTopMarginChange,
subtitleStyles,
selectedSubtitleStyleId,
onSelectSubtitleStyle,
@@ -80,6 +102,8 @@ export function TitleSubtitlePanel({
onTitleTopMarginChange,
subtitleBottomMargin,
onSubtitleBottomMarginChange,
titleDisplayMode,
onTitleDisplayModeChange,
resolveAssetUrl,
getFontFormat,
buildTextShadow,
@@ -105,9 +129,13 @@ export function TitleSubtitlePanel({
<FloatingStylePreview
onClose={onTogglePreview}
videoTitle={videoTitle}
videoSecondaryTitle={videoSecondaryTitle}
titleStyles={titleStyles}
selectedTitleStyleId={selectedTitleStyleId}
titleFontSize={titleFontSize}
selectedSecondaryTitleStyleId={selectedSecondaryTitleStyleId}
secondaryTitleFontSize={secondaryTitleFontSize}
secondaryTitleTopMargin={secondaryTitleTopMargin}
subtitleStyles={subtitleStyles}
selectedSubtitleStyleId={selectedSubtitleStyleId}
subtitleFontSize={subtitleFontSize}
@@ -123,7 +151,21 @@ export function TitleSubtitlePanel({
)}
<div className="mb-4">
<label className="text-sm text-gray-300 mb-2 block">15</label>
<div className="mb-2 flex items-center justify-between gap-2">
<label className="text-sm text-gray-300">15</label>
<div className="relative shrink-0">
<select
value={titleDisplayMode}
onChange={(e) => onTitleDisplayModeChange(e.target.value as "short" | "persistent")}
className="appearance-none rounded-lg border border-white/15 bg-black/35 px-2.5 py-1.5 pr-7 text-xs text-gray-200 outline-none transition-colors hover:border-white/25 focus:border-purple-500"
aria-label="标题显示方式"
>
<option value="short"></option>
<option value="persistent"></option>
</select>
<ChevronDown className="pointer-events-none absolute right-2 top-1/2 h-3.5 w-3.5 -translate-y-1/2 text-gray-400" />
</div>
</div>
<input
type="text"
value={videoTitle}
@@ -135,6 +177,19 @@ export function TitleSubtitlePanel({
/>
</div>
<div className="mb-4">
<label className="text-sm text-gray-300 mb-2 block">20</label>
<input
type="text"
value={videoSecondaryTitle}
onChange={(e) => onSecondaryTitleChange(e.target.value)}
onCompositionStart={onSecondaryTitleCompositionStart}
onCompositionEnd={(e) => onSecondaryTitleCompositionEnd?.(e.currentTarget.value)}
placeholder="输入副标题,显示在主标题下方"
className="w-full px-3 sm:px-4 py-2 text-sm sm:text-base bg-black/30 border border-white/10 rounded-xl text-white placeholder-gray-500 focus:outline-none focus:border-purple-500 transition-colors"
/>
</div>
{titleStyles.length > 0 && (
<div className="mb-4">
<label className="text-sm text-gray-300 mb-2 block"></label>
@@ -182,6 +237,53 @@ export function TitleSubtitlePanel({
</div>
)}
{titleStyles.length > 0 && (
<div className="mb-4">
<label className="text-sm text-gray-300 mb-2 block"></label>
<div className="grid grid-cols-2 gap-2">
{titleStyles.map((style) => (
<button
key={style.id}
onClick={() => onSelectSecondaryTitleStyle(style.id)}
className={`p-2 rounded-lg border transition-all text-left ${selectedSecondaryTitleStyleId === style.id
? "border-purple-500 bg-purple-500/20"
: "border-white/10 bg-white/5 hover:border-white/30"
}`}
>
<div className="text-white text-sm truncate">{style.label}</div>
<div className="text-xs text-gray-400 truncate">
{style.font_family || style.font_file || ""}
</div>
</button>
))}
</div>
<div className="mt-3">
<label className="text-xs text-gray-400 mb-2 block">: {secondaryTitleFontSize}px</label>
<input
type="range"
min="30"
max="100"
step="1"
value={secondaryTitleFontSize}
onChange={(e) => onSecondaryTitleFontSizeChange(parseInt(e.target.value, 10))}
className="w-full accent-purple-500"
/>
</div>
<div className="mt-3">
<label className="text-xs text-gray-400 mb-2 block">: {secondaryTitleTopMargin}px</label>
<input
type="range"
min="0"
max="100"
step="1"
value={secondaryTitleTopMargin}
onChange={(e) => onSecondaryTitleTopMarginChange(parseInt(e.target.value, 10))}
className="w-full accent-purple-500"
/>
</div>
</div>
)}
{subtitleStyles.length > 0 && (
<div className="mt-4">
<label className="text-sm text-gray-300 mb-2 block"></label>

View File

@@ -1,4 +1,4 @@
import { useState, useEffect, useCallback } from "react";
import { useState, useEffect, useCallback, useRef } from "react";
import api from "@/shared/api/axios";
import { ApiResponse, unwrap } from "@/shared/api/types";
import { toast } from "sonner";
@@ -7,6 +7,7 @@ export type ExtractionStep = "config" | "processing" | "result";
export type InputTab = "file" | "url";
const VALID_FILE_TYPES = [".mp4", ".mov", ".avi", ".mp3", ".wav", ".m4a"];
const CUSTOM_PROMPT_KEY = "vigent_rewriteCustomPrompt";
interface UseScriptExtractionOptions {
isOpen: boolean;
@@ -23,8 +24,19 @@ export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
const [selectedFile, setSelectedFile] = useState<File | null>(null);
const [activeTab, setActiveTab] = useState<InputTab>("url");
const [inputUrl, setInputUrl] = useState("");
const [customPrompt, setCustomPrompt] = useState(() => typeof window !== "undefined" ? localStorage.getItem(CUSTOM_PROMPT_KEY) || "" : "");
const [showCustomPrompt, setShowCustomPrompt] = useState(false);
// Reset state when modal opens
// Debounced save customPrompt to localStorage
const debounceRef = useRef<ReturnType<typeof setTimeout>>(undefined);
useEffect(() => {
debounceRef.current = setTimeout(() => {
localStorage.setItem(CUSTOM_PROMPT_KEY, customPrompt);
}, 300);
return () => clearTimeout(debounceRef.current);
}, [customPrompt]);
// Reset state when modal opens (customPrompt is persistent, not reset)
useEffect(() => {
if (isOpen) {
setStep("config");
@@ -101,6 +113,9 @@ export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
formData.append("url", inputUrl.trim());
}
formData.append("rewrite", doRewrite ? "true" : "false");
if (doRewrite && customPrompt.trim()) {
formData.append("custom_prompt", customPrompt.trim());
}
const { data: res } = await api.post<
ApiResponse<{ original_script: string; rewritten_script?: string }>
@@ -126,7 +141,7 @@ export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
} finally {
setIsLoading(false);
}
}, [activeTab, selectedFile, inputUrl, doRewrite]);
}, [activeTab, selectedFile, inputUrl, doRewrite, customPrompt]);
const copyToClipboard = useCallback((text: string) => {
if (navigator.clipboard && window.isSecureContext) {
@@ -193,10 +208,14 @@ export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
selectedFile,
activeTab,
inputUrl,
customPrompt,
showCustomPrompt,
// Setters
setDoRewrite,
setActiveTab,
setInputUrl,
setCustomPrompt,
setShowCustomPrompt,
// Handlers
handleDrag,
handleDrop,

View File

@@ -12,7 +12,7 @@ const API_BASE = typeof window === 'undefined'
// 防止重复跳转
let isRedirecting = false;
const PUBLIC_PATHS = new Set(['/login', '/register']);
const PUBLIC_PATHS = new Set(['/login', '/register', '/pay']);
// 创建 axios 实例
const api = axios.create({

View File

@@ -12,6 +12,7 @@ export interface AuthResponse {
success: boolean;
message: string;
user?: User;
paymentToken?: string;
}
interface ApiResponse<T> {
@@ -25,20 +26,41 @@ interface ApiResponse<T> {
* 用户注册
*/
export async function register(phone: string, password: string, username?: string): Promise<AuthResponse> {
const { data: payload } = await api.post<ApiResponse<null>>('/api/auth/register', {
phone, password, username
});
return { success: payload.success, message: payload.message };
try {
const { data: payload } = await api.post<ApiResponse<null>>('/api/auth/register', {
phone, password, username
});
return { success: payload.success, message: payload.message };
} catch (err: any) {
return {
success: false,
message: err.response?.data?.message || '注册失败',
};
}
}
/**
* 用户登录
*/
export async function login(phone: string, password: string): Promise<AuthResponse> {
const { data: payload } = await api.post<ApiResponse<{ user?: User }>>('/api/auth/login', {
phone, password
});
return { success: payload.success, message: payload.message, user: payload.data?.user };
try {
const { data: payload } = await api.post<ApiResponse<{ user?: User }>>('/api/auth/login', {
phone, password
});
return { success: payload.success, message: payload.message, user: payload.data?.user };
} catch (err: any) {
if (err.response?.status === 403 && err.response?.data?.data?.reason === 'PAYMENT_REQUIRED') {
return {
success: false,
message: err.response.data.message,
paymentToken: err.response.data.data.payment_token,
};
}
return {
success: false,
message: err.response?.data?.message || '登录失败',
};
}
}
/**

View File

@@ -1,8 +1,12 @@
export const TITLE_MAX_LENGTH = 15;
export const SECONDARY_TITLE_MAX_LENGTH = 20;
export const clampTitle = (value: string, maxLength: number = TITLE_MAX_LENGTH) =>
value.slice(0, maxLength);
export const clampSecondaryTitle = (value: string, maxLength: number = SECONDARY_TITLE_MAX_LENGTH) =>
value.slice(0, maxLength);
export const applyTitleLimit = (
prev: string,
next: string,

View File

@@ -4,4 +4,5 @@ export interface Material {
path: string;
size_mb: number;
scene?: string;
duration_sec?: number;
}

View File

@@ -0,0 +1,76 @@
# Contributor Covenant Code of Conduct
## Our Pledge
In the interest of fostering an open and welcoming environment, we as
contributors and maintainers pledge to making participation in our project and
our community a harassment-free experience for everyone, regardless of age, body
size, disability, ethnicity, sex characteristics, gender identity and expression,
level of experience, education, socio-economic status, nationality, personal
appearance, race, religion, or sexual identity and orientation.
## Our Standards
Examples of behavior that contributes to creating a positive environment
include:
* Using welcoming and inclusive language
* Being respectful of differing viewpoints and experiences
* Gracefully accepting constructive criticism
* Focusing on what is best for the community
* Showing empathy towards other community members
Examples of unacceptable behavior by participants include:
* The use of sexualized language or imagery and unwelcome sexual attention or
advances
* Trolling, insulting/derogatory comments, and personal or political attacks
* Public or private harassment
* Publishing others' private information, such as a physical or electronic
address, without explicit permission
* Other conduct which could reasonably be considered inappropriate in a
professional setting
## Our Responsibilities
Project maintainers are responsible for clarifying the standards of acceptable
behavior and are expected to take appropriate and fair corrective action in
response to any instances of unacceptable behavior.
Project maintainers have the right and responsibility to remove, edit, or
reject comments, commits, code, wiki edits, issues, and other contributions
that are not aligned to this Code of Conduct, or to ban temporarily or
permanently any contributor for other behaviors that they deem inappropriate,
threatening, offensive, or harmful.
## Scope
This Code of Conduct applies both within project spaces and in public spaces
when an individual is representing the project or its community. Examples of
representing a project or community include using an official project e-mail
address, posting via an official social media account, or acting as an appointed
representative at an online or offline event. Representation of a project may be
further defined and clarified by project maintainers.
## Enforcement
Instances of abusive, harassing, or otherwise unacceptable behavior may be
reported by contacting the project team at mikelei@mobvoi.com. All
complaints will be reviewed and investigated and will result in a response that
is deemed necessary and appropriate to the circumstances. The project team is
obligated to maintain confidentiality with regard to the reporter of an incident.
Further details of specific enforcement policies may be posted separately.
Project maintainers who do not follow or enforce the Code of Conduct in good
faith may face temporary or permanent repercussions as determined by other
members of the project's leadership.
## Attribution
This Code of Conduct is adapted from the [Contributor Covenant][homepage], version 1.4,
available at https://www.contributor-covenant.org/version/1/4/code-of-conduct.html
[homepage]: https://www.contributor-covenant.org
For answers to common questions about this code of conduct, see
https://www.contributor-covenant.org/faq

16
models/CosyVoice/FAQ.md Normal file
View File

@@ -0,0 +1,16 @@
## ModuleNotFoundError: No module named 'matcha'
Matcha-TTS is a third_party module. Please check `third_party` directory. If there is no `Matcha-TTS`, execute `git submodule update --init --recursive`.
run `export PYTHONPATH=third_party/Matcha-TTS` if you want to use `from cosyvoice.cli.cosyvoice import CosyVoice` in python script.
## cannot find resource.zip or cannot unzip resource.zip
Please make sure you have git-lfs installed. Execute
```sh
git clone https://www.modelscope.cn/iic/CosyVoice-ttsfrd.git pretrained_models/CosyVoice-ttsfrd
cd pretrained_models/CosyVoice-ttsfrd/
unzip resource.zip -d .
pip install ttsfrd-0.3.6-cp38-cp38-linux_x86_64.whl
```

201
models/CosyVoice/LICENSE Normal file
View File

@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS
APPENDIX: How to apply the Apache License to your work.
To apply the Apache License to your work, attach the following
boilerplate notice, with the fields enclosed by brackets "[]"
replaced with your own identifying information. (Don't include
the brackets!) The text should be enclosed in the appropriate
comment syntax for the file format. We also recommend that a
file or class name and description of purpose be included on the
same "printed page" as the copyright notice for easier
identification within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.

264
models/CosyVoice/README.md Normal file
View File

@@ -0,0 +1,264 @@
![SVG Banners](https://svg-banners.vercel.app/api?type=origin&text1=CosyVoice🤠&text2=Text-to-Speech%20💖%20Large%20Language%20Model&width=800&height=210)
## 👉🏻 CosyVoice 👈🏻
**Fun-CosyVoice 3.0**: [Demos](https://funaudiollm.github.io/cosyvoice3/); [Paper](https://arxiv.org/pdf/2505.17589); [Modelscope](https://www.modelscope.cn/models/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [Huggingface](https://huggingface.co/FunAudioLLM/Fun-CosyVoice3-0.5B-2512); [CV3-Eval](https://github.com/FunAudioLLM/CV3-Eval)
**CosyVoice 2.0**: [Demos](https://funaudiollm.github.io/cosyvoice2/); [Paper](https://arxiv.org/pdf/2412.10117); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice2-0.5B); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice2-0.5B)
**CosyVoice 1.0**: [Demos](https://fun-audio-llm.github.io); [Paper](https://funaudiollm.github.io/pdf/CosyVoice_v1.pdf); [Modelscope](https://www.modelscope.cn/models/iic/CosyVoice-300M); [HuggingFace](https://huggingface.co/FunAudioLLM/CosyVoice-300M)
## Highlight🔥
**Fun-CosyVoice 3.0** is an advanced text-to-speech (TTS) system based on large language models (LLM), surpassing its predecessor (CosyVoice 2.0) in content consistency, speaker similarity, and prosody naturalness. It is designed for zero-shot multilingual speech synthesis in the wild.
### Key Features
- **Language Coverage**: Covers 9 common languages (Chinese, English, Japanese, Korean, German, Spanish, French, Italian, Russian), 18+ Chinese dialects/accents (Guangdong, Minnan, Sichuan, Dongbei, Shan3xi, Shan1xi, Shanghai, Tianjin, Shandong, Ningxia, Gansu, etc.) and meanwhile supports both multi-lingual/cross-lingual zero-shot voice cloning.
- **Content Consistency & Naturalness**: Achieves state-of-the-art performance in content consistency, speaker similarity, and prosody naturalness.
- **Pronunciation Inpainting**: Supports pronunciation inpainting of Chinese Pinyin and English CMU phonemes, providing more controllability and thus suitable for production use.
- **Text Normalization**: Supports reading of numbers, special symbols and various text formats without a traditional frontend module.
- **Bi-Streaming**: Support both text-in streaming and audio-out streaming, and achieves latency as low as 150ms while maintaining high-quality audio output.
- **Instruct Support**: Supports various instructions such as languages, dialects, emotions, speed, volume, etc.
## Roadmap
- [x] 2025/12
- [x] release Fun-CosyVoice3-0.5B-2512 base model, rl model and its training/inference script
- [x] release Fun-CosyVoice3-0.5B modelscope gradio space
- [x] 2025/08
- [x] Thanks to the contribution from NVIDIA Yuekai Zhang, add triton trtllm runtime support and cosyvoice2 grpo training support
- [x] 2025/07
- [x] release Fun-CosyVoice 3.0 eval set
- [x] 2025/05
- [x] add CosyVoice2-0.5B vllm support
- [x] 2024/12
- [x] 25hz CosyVoice2-0.5B released
- [x] 2024/09
- [x] 25hz CosyVoice-300M base model
- [x] 25hz CosyVoice-300M voice conversion function
- [x] 2024/08
- [x] Repetition Aware Sampling(RAS) inference for llm stability
- [x] Streaming inference mode support, including kv cache and sdpa for rtf optimization
- [x] 2024/07
- [x] Flow matching training support
- [x] WeTextProcessing support when ttsfrd is not available
- [x] Fastapi server and client
## Evaluation
| Model | Open-Source | Model Size | test-zh<br>CER (%) ↓ | test-zh<br>SS (%) ↑ | test-en<br>WER (%) ↓ | test-en<br>SS (%) ↑ | test-hard<br>CER (%) ↓ | test-hard<br>SS (%) ↑ |
| :--- | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |
| Human | - | - | 1.26 | 75.5 | 2.14 | 73.4 | - | - |
| Seed-TTS | ❌ | - | 1.12 | 79.6 | 2.25 | 76.2 | 7.59 | 77.6 |
| MiniMax-Speech | ❌ | - | 0.83 | 78.3 | 1.65 | 69.2 | - | - |
| F5-TTS | ✅ | 0.3B | 1.52 | 74.1 | 2.00 | 64.7 | 8.67 | 71.3 |
| Spark TTS | ✅ | 0.5B | 1.2 | 66.0 | 1.98 | 57.3 | - | - |
| CosyVoice2 | ✅ | 0.5B | 1.45 | 75.7 | 2.57 | 65.9 | 6.83 | 72.4 |
| FireRedTTS2 | ✅ | 1.5B | 1.14 | 73.2 | 1.95 | 66.5 | - | - |
| Index-TTS2 | ✅ | 1.5B | 1.03 | 76.5 | 2.23 | 70.6 | 7.12 | 75.5 |
| VibeVoice-1.5B | ✅ | 1.5B | 1.16 | 74.4 | 3.04 | 68.9 | - | - |
| VibeVoice-Realtime | ✅ | 0.5B | - | - | 2.05 | 63.3 | - | - |
| HiggsAudio-v2 | ✅ | 3B | 1.50 | 74.0 | 2.44 | 67.7 | - | - |
| VoxCPM | ✅ | 0.5B | 0.93 | 77.2 | 1.85 | 72.9 | 8.87 | 73.0 |
| GLM-TTS | ✅ | 1.5B | 1.03 | 76.1 | - | - | - | - |
| GLM-TTS RL | ✅ | 1.5B | 0.89 | 76.4 | - | - | - | - |
| Fun-CosyVoice3-0.5B-2512 | ✅ | 0.5B | 1.21 | 78.0 | 2.24 | 71.8 | 6.71 | 75.8 |
| Fun-CosyVoice3-0.5B-2512_RL | ✅ | 0.5B | 0.81 | 77.4 | 1.68 | 69.5 | 5.44 | 75.0 |
## Install
### Clone and install
- Clone the repo
``` sh
git clone --recursive https://github.com/FunAudioLLM/CosyVoice.git
# If you failed to clone the submodule due to network failures, please run the following command until success
cd CosyVoice
git submodule update --init --recursive
```
- Install Conda: please see https://docs.conda.io/en/latest/miniconda.html
- Create Conda env:
``` sh
conda create -n cosyvoice -y python=3.10
conda activate cosyvoice
pip install -r requirements.txt -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# If you encounter sox compatibility issues
# ubuntu
sudo apt-get install sox libsox-dev
# centos
sudo yum install sox sox-devel
```
### Model download
We strongly recommend that you download our pretrained `Fun-CosyVoice3-0.5B` `CosyVoice2-0.5B` `CosyVoice-300M` `CosyVoice-300M-SFT` `CosyVoice-300M-Instruct` model and `CosyVoice-ttsfrd` resource.
``` python
# modelscope SDK model download
from modelscope import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('iic/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('iic/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('iic/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('iic/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('iic/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
# for oversea users, huggingface SDK model download
from huggingface_hub import snapshot_download
snapshot_download('FunAudioLLM/Fun-CosyVoice3-0.5B-2512', local_dir='pretrained_models/Fun-CosyVoice3-0.5B')
snapshot_download('FunAudioLLM/CosyVoice2-0.5B', local_dir='pretrained_models/CosyVoice2-0.5B')
snapshot_download('FunAudioLLM/CosyVoice-300M', local_dir='pretrained_models/CosyVoice-300M')
snapshot_download('FunAudioLLM/CosyVoice-300M-SFT', local_dir='pretrained_models/CosyVoice-300M-SFT')
snapshot_download('FunAudioLLM/CosyVoice-300M-Instruct', local_dir='pretrained_models/CosyVoice-300M-Instruct')
snapshot_download('FunAudioLLM/CosyVoice-ttsfrd', local_dir='pretrained_models/CosyVoice-ttsfrd')
```
Optionally, you can unzip `ttsfrd` resource and install `ttsfrd` package for better text normalization performance.
Notice that this step is not necessary. If you do not install `ttsfrd` package, we will use wetext by default.
``` sh
cd pretrained_models/CosyVoice-ttsfrd/
unzip resource.zip -d .
pip install ttsfrd_dependency-0.1-py3-none-any.whl
pip install ttsfrd-0.4.2-cp310-cp310-linux_x86_64.whl
```
### Basic Usage
We strongly recommend using `Fun-CosyVoice3-0.5B` for better performance.
Follow the code in `example.py` for detailed usage of each model.
```sh
python example.py
```
#### vLLM Usage
CosyVoice2/3 now supports **vLLM 0.11.x+ (V1 engine)** and **vLLM 0.9.0 (legacy)**.
Older vllm version(<0.9.0) do not support CosyVoice inference, and versions in between (e.g., 0.10.x) are not tested.
Notice that `vllm` has a lot of specific requirements. You can create a new env to in case your hardward do not support vllm and old env is corrupted.
``` sh
conda create -n cosyvoice_vllm --clone cosyvoice
conda activate cosyvoice_vllm
# for vllm==0.9.0
pip install vllm==v0.9.0 transformers==4.51.3 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
# for vllm>=0.11.0
pip install vllm==v0.11.0 transformers==4.57.1 numpy==1.26.4 -i https://mirrors.aliyun.com/pypi/simple/ --trusted-host=mirrors.aliyun.com
python vllm_example.py
```
#### Start web demo
You can use our web demo page to get familiar with CosyVoice quickly.
Please see the demo website for details.
``` python
# change iic/CosyVoice-300M-SFT for sft inference, or iic/CosyVoice-300M-Instruct for instruct inference
python3 webui.py --port 50000 --model_dir pretrained_models/CosyVoice-300M
```
#### Advanced Usage
For advanced users, we have provided training and inference scripts in `examples/libritts`.
#### Build for deployment
Optionally, if you want service deployment,
You can run the following steps.
``` sh
cd runtime/python
docker build -t cosyvoice:v1.0 .
# change iic/CosyVoice-300M to iic/CosyVoice-300M-Instruct if you want to use instruct inference
# for grpc usage
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/grpc && python3 server.py --port 50000 --max_conc 4 --model_dir iic/CosyVoice-300M && sleep infinity"
cd grpc && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
# for fastapi usage
docker run -d --runtime=nvidia -p 50000:50000 cosyvoice:v1.0 /bin/bash -c "cd /opt/CosyVoice/CosyVoice/runtime/python/fastapi && python3 server.py --port 50000 --model_dir iic/CosyVoice-300M && sleep infinity"
cd fastapi && python3 client.py --port 50000 --mode <sft|zero_shot|cross_lingual|instruct>
```
#### Using Nvidia TensorRT-LLM for deployment
Using TensorRT-LLM to accelerate cosyvoice2 llm could give 4x acceleration comparing with huggingface transformers implementation.
To quick start:
``` sh
cd runtime/triton_trtllm
docker compose up -d
```
For more details, you could check [here](https://github.com/FunAudioLLM/CosyVoice/tree/main/runtime/triton_trtllm)
## Discussion & Communication
You can directly discuss on [Github Issues](https://github.com/FunAudioLLM/CosyVoice/issues).
You can also scan the QR code to join our official Dingding chat group.
<img src="./asset/dingding.png" width="250px">
## Acknowledge
1. We borrowed a lot of code from [FunASR](https://github.com/modelscope/FunASR).
2. We borrowed a lot of code from [FunCodec](https://github.com/modelscope/FunCodec).
3. We borrowed a lot of code from [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS).
4. We borrowed a lot of code from [AcademiCodec](https://github.com/yangdongchao/AcademiCodec).
5. We borrowed a lot of code from [WeNet](https://github.com/wenet-e2e/wenet).
## Citations
``` bibtex
@article{du2024cosyvoice,
title={Cosyvoice: A scalable multilingual zero-shot text-to-speech synthesizer based on supervised semantic tokens},
author={Du, Zhihao and Chen, Qian and Zhang, Shiliang and Hu, Kai and Lu, Heng and Yang, Yexin and Hu, Hangrui and Zheng, Siqi and Gu, Yue and Ma, Ziyang and others},
journal={arXiv preprint arXiv:2407.05407},
year={2024}
}
@article{du2024cosyvoice,
title={Cosyvoice 2: Scalable streaming speech synthesis with large language models},
author={Du, Zhihao and Wang, Yuxuan and Chen, Qian and Shi, Xian and Lv, Xiang and Zhao, Tianyu and Gao, Zhifu and Yang, Yexin and Gao, Changfeng and Wang, Hui and others},
journal={arXiv preprint arXiv:2412.10117},
year={2024}
}
@article{du2025cosyvoice,
title={CosyVoice 3: Towards In-the-wild Speech Generation via Scaling-up and Post-training},
author={Du, Zhihao and Gao, Changfeng and Wang, Yuxuan and Yu, Fan and Zhao, Tianyu and Wang, Hao and Lv, Xiang and Wang, Hui and Shi, Xian and An, Keyu and others},
journal={arXiv preprint arXiv:2505.17589},
year={2025}
}
@inproceedings{lyu2025build,
title={Build LLM-Based Zero-Shot Streaming TTS System with Cosyvoice},
author={Lyu, Xiang and Wang, Yuxuan and Zhao, Tianyu and Wang, Hao and Liu, Huadai and Du, Zhihao},
booktitle={ICASSP 2025-2025 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP)},
pages={1--2},
year={2025},
organization={IEEE}
}
```
## Disclaimer
The content provided above is for academic purposes only and is intended to demonstrate technical capabilities. Some examples are sourced from the internet. If any content infringes on your rights, please contact us to request its removal.

View File

View File

@@ -0,0 +1,93 @@
# Copyright (c) 2020 Mobvoi Inc (Di Wu)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import argparse
import glob
import yaml
import torch
def get_args():
parser = argparse.ArgumentParser(description='average model')
parser.add_argument('--dst_model', required=True, help='averaged model')
parser.add_argument('--src_path',
required=True,
help='src model path for average')
parser.add_argument('--val_best',
action="store_true",
help='averaged model')
parser.add_argument('--num',
default=5,
type=int,
help='nums for averaged model')
args = parser.parse_args()
print(args)
return args
def main():
args = get_args()
val_scores = []
if args.val_best:
yamls = glob.glob('{}/*.yaml'.format(args.src_path))
yamls = [
f for f in yamls
if not (os.path.basename(f).startswith('train')
or os.path.basename(f).startswith('init'))
]
for y in yamls:
with open(y, 'r') as f:
dic_yaml = yaml.load(f, Loader=yaml.BaseLoader)
loss = float(dic_yaml['loss_dict']['loss'])
epoch = int(dic_yaml['epoch'])
step = int(dic_yaml['step'])
tag = dic_yaml['tag']
val_scores += [[epoch, step, loss, tag]]
sorted_val_scores = sorted(val_scores,
key=lambda x: x[2],
reverse=False)
print("best val (epoch, step, loss, tag) = " +
str(sorted_val_scores[:args.num]))
path_list = [
args.src_path + '/epoch_{}_whole.pt'.format(score[0])
for score in sorted_val_scores[:args.num]
]
print(path_list)
avg = {}
num = args.num
assert num == len(path_list)
for path in path_list:
print('Processing {}'.format(path))
states = torch.load(path, map_location=torch.device('cpu'))
for k in states.keys():
if k not in ['step', 'epoch']:
if k not in avg.keys():
avg[k] = states[k].clone()
else:
avg[k] += states[k]
# average
for k in avg.keys():
if avg[k] is not None:
# pytorch 1.6 use true_divide instead of /=
avg[k] = torch.true_divide(avg[k], num)
print('Saving to {}'.format(args.dst_model))
torch.save(avg, args.dst_model)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,99 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import torch
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
def get_args():
parser = argparse.ArgumentParser(description='export your model for deployment')
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/CosyVoice-300M',
help='local path')
args = parser.parse_args()
print(args)
return args
def get_optimized_script(model, preserved_attrs=[]):
script = torch.jit.script(model)
if preserved_attrs != []:
script = torch.jit.freeze(script, preserved_attrs=preserved_attrs)
else:
script = torch.jit.freeze(script)
script = torch.jit.optimize_for_inference(script)
return script
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
torch._C._jit_set_fusion_strategy([('STATIC', 1)])
torch._C._jit_set_profiling_mode(False)
torch._C._jit_set_profiling_executor(False)
model = AutoModel(model_dir=args.model_dir)
if model.__class__.__name__ == 'CosyVoice':
# 1. export llm text_encoder
llm_text_encoder = model.model.llm.text_encoder
script = get_optimized_script(llm_text_encoder)
script.save('{}/llm.text_encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(llm_text_encoder.half())
script.save('{}/llm.text_encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export llm_text_encoder')
# 2. export llm llm
llm_llm = model.model.llm.llm
script = get_optimized_script(llm_llm, ['forward_chunk'])
script.save('{}/llm.llm.fp32.zip'.format(args.model_dir))
script = get_optimized_script(llm_llm.half(), ['forward_chunk'])
script.save('{}/llm.llm.fp16.zip'.format(args.model_dir))
logging.info('successfully export llm_llm')
# 3. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
elif model.__class__.__name__ == 'CosyVoice2':
# 1. export flow encoder
flow_encoder = model.model.flow.encoder
script = get_optimized_script(flow_encoder)
script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
script = get_optimized_script(flow_encoder.half())
script.save('{}/flow.encoder.fp16.zip'.format(args.model_dir))
logging.info('successfully export flow_encoder')
else:
raise ValueError('unsupported model type')
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,114 @@
# Copyright (c) 2024 Antgroup Inc (authors: Zhoubofan, hexisyztem@icloud.com)
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
import os
import sys
import onnxruntime
import random
import torch
from tqdm import tqdm
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append('{}/../..'.format(ROOT_DIR))
sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
from cosyvoice.cli.cosyvoice import AutoModel
from cosyvoice.utils.file_utils import logging
def get_dummy_input(batch_size, seq_len, out_channels, device):
x = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
mask = torch.ones((batch_size, 1, seq_len), dtype=torch.float32, device=device)
mu = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
t = torch.rand((batch_size), dtype=torch.float32, device=device)
spks = torch.rand((batch_size, out_channels), dtype=torch.float32, device=device)
cond = torch.rand((batch_size, out_channels, seq_len), dtype=torch.float32, device=device)
return x, mask, mu, t, spks, cond
def get_args():
parser = argparse.ArgumentParser(description='export your model for deployment')
parser.add_argument('--model_dir',
type=str,
default='pretrained_models/CosyVoice-300M',
help='local path')
args = parser.parse_args()
print(args)
return args
@torch.no_grad()
def main():
args = get_args()
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
model = AutoModel(model_dir=args.model_dir)
# 1. export flow decoder estimator
estimator = model.model.flow.decoder.estimator
estimator.eval()
device = model.model.device
batch_size, seq_len = 2, 256
out_channels = model.model.flow.decoder.estimator.out_channels
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, seq_len, out_channels, device)
torch.onnx.export(
estimator,
(x, mask, mu, t, spks, cond),
'{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
export_params=True,
opset_version=18,
do_constant_folding=True,
input_names=['x', 'mask', 'mu', 't', 'spks', 'cond'],
output_names=['estimator_out'],
dynamic_axes={
'x': {2: 'seq_len'},
'mask': {2: 'seq_len'},
'mu': {2: 'seq_len'},
'cond': {2: 'seq_len'},
'estimator_out': {2: 'seq_len'},
}
)
# 2. test computation consistency
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
sess_options=option, providers=providers)
for _ in tqdm(range(10)):
x, mask, mu, t, spks, cond = get_dummy_input(batch_size, random.randint(16, 512), out_channels, device)
output_pytorch = estimator(x, mask, mu, t, spks, cond)
ort_inputs = {
'x': x.cpu().numpy(),
'mask': mask.cpu().numpy(),
'mu': mu.cpu().numpy(),
't': t.cpu().numpy(),
'spks': spks.cpu().numpy(),
'cond': cond.cpu().numpy()
}
output_onnx = estimator_onnx.run(None, ort_inputs)[0]
torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
logging.info('successfully export estimator')
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,195 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import argparse
import datetime
import logging
logging.getLogger('matplotlib').setLevel(logging.WARNING)
from copy import deepcopy
import os
import torch
import torch.distributed as dist
import deepspeed
from hyperpyyaml import load_hyperpyyaml
from torch.distributed.elastic.multiprocessing.errors import record
from cosyvoice.utils.losses import DPOLoss
from cosyvoice.utils.executor import Executor
from cosyvoice.utils.train_utils import (
init_distributed,
init_dataset_and_dataloader,
init_optimizer_and_scheduler,
init_summarywriter, save_model,
wrap_cuda_model, check_modify_and_save_config)
def get_args():
parser = argparse.ArgumentParser(description='training your network')
parser.add_argument('--train_engine',
default='torch_ddp',
choices=['torch_ddp', 'deepspeed'],
help='Engine for paralleled training')
parser.add_argument('--model', required=True, help='model which will be trained')
parser.add_argument('--ref_model', required=False, help='ref model used in dpo')
parser.add_argument('--config', required=True, help='config file')
parser.add_argument('--train_data', required=True, help='train data file')
parser.add_argument('--cv_data', required=True, help='cv data file')
parser.add_argument('--qwen_pretrain_path', required=False, help='qwen pretrain path')
parser.add_argument('--onnx_path', required=False, help='onnx path, which is required for online feature extraction')
parser.add_argument('--checkpoint', help='checkpoint model')
parser.add_argument('--model_dir', required=True, help='save model dir')
parser.add_argument('--tensorboard_dir',
default='tensorboard',
help='tensorboard log dir')
parser.add_argument('--ddp.dist_backend',
dest='dist_backend',
default='nccl',
choices=['nccl', 'gloo'],
help='distributed backend')
parser.add_argument('--num_workers',
default=0,
type=int,
help='num of subprocess workers for reading')
parser.add_argument('--prefetch',
default=100,
type=int,
help='prefetch number')
parser.add_argument('--pin_memory',
action='store_true',
default=False,
help='Use pinned memory buffers used for reading')
parser.add_argument('--use_amp',
action='store_true',
default=False,
help='Use automatic mixed precision training')
parser.add_argument('--dpo',
action='store_true',
default=False,
help='Use Direct Preference Optimization')
parser.add_argument('--deepspeed.save_states',
dest='save_states',
default='model_only',
choices=['model_only', 'model+optimizer'],
help='save model/optimizer states')
parser.add_argument('--timeout',
default=60,
type=int,
help='timeout (in seconds) of cosyvoice_join.')
parser = deepspeed.add_config_arguments(parser)
args = parser.parse_args()
return args
@record
def main():
args = get_args()
os.environ['onnx_path'] = args.onnx_path
logging.basicConfig(level=logging.DEBUG,
format='%(asctime)s %(levelname)s %(message)s')
# gan train has some special initialization logic
gan = True if args.model == 'hifigan' else False
override_dict = {k: None for k in ['llm', 'flow', 'hift', 'hifigan'] if k != args.model}
if gan is True:
override_dict.pop('hift')
if args.qwen_pretrain_path is not None:
override_dict['qwen_pretrain_path'] = args.qwen_pretrain_path
with open(args.config, 'r') as f:
configs = load_hyperpyyaml(f, overrides=override_dict)
if gan is True:
configs['train_conf'] = configs['train_conf_gan']
configs['train_conf'].update(vars(args))
# Init env for ddp
init_distributed(args)
# Get dataset & dataloader
train_dataset, cv_dataset, train_data_loader, cv_data_loader = \
init_dataset_and_dataloader(args, configs, gan, args.dpo)
# Do some sanity checks and save config to arsg.model_dir
configs = check_modify_and_save_config(args, configs)
# Tensorboard summary
writer = init_summarywriter(args)
# load checkpoint
if args.dpo is True:
configs[args.model].forward = configs[args.model].forward_dpo
model = configs[args.model]
start_step, start_epoch = 0, -1
if args.checkpoint is not None:
if os.path.exists(args.checkpoint):
state_dict = torch.load(args.checkpoint, map_location='cpu')
model.load_state_dict(state_dict, strict=False)
if 'step' in state_dict:
start_step = state_dict['step']
if 'epoch' in state_dict:
start_epoch = state_dict['epoch']
else:
logging.warning('checkpoint {} do not exsist!'.format(args.checkpoint))
# Dispatch model from cpu to gpu
model = wrap_cuda_model(args, model)
# Get optimizer & scheduler
model, optimizer, scheduler, optimizer_d, scheduler_d = init_optimizer_and_scheduler(args, configs, model, gan)
scheduler.set_step(start_step)
if scheduler_d is not None:
scheduler_d.set_step(start_step)
# Save init checkpoints
info_dict = deepcopy(configs['train_conf'])
info_dict['step'] = start_step
info_dict['epoch'] = start_epoch
save_model(model, 'init', info_dict)
# DPO related
if args.dpo is True:
ref_model = deepcopy(configs[args.model])
state_dict = torch.load(args.ref_model, map_location='cpu')
ref_model.load_state_dict(state_dict, strict=False)
dpo_loss = DPOLoss(beta=0.01, label_smoothing=0.0, ipo=False)
# NOTE maybe it is not needed to wrap ref_model as ddp because its parameter is not updated
ref_model = wrap_cuda_model(args, ref_model)
else:
ref_model, dpo_loss = None, None
# Get executor
executor = Executor(gan=gan, ref_model=ref_model, dpo_loss=dpo_loss)
executor.step = start_step
# Init scaler, used for pytorch amp mixed precision training
scaler = torch.cuda.amp.GradScaler() if args.use_amp else None
print('start step {} start epoch {}'.format(start_step, start_epoch))
# Start training loop
for epoch in range(start_epoch + 1, info_dict['max_epoch']):
executor.epoch = epoch
train_dataset.set_epoch(epoch)
dist.barrier()
group_join = dist.new_group(backend="gloo", timeout=datetime.timedelta(seconds=args.timeout))
if gan is True:
executor.train_one_epoc_gan(model, optimizer, scheduler, optimizer_d, scheduler_d, train_data_loader, cv_data_loader,
writer, info_dict, scaler, group_join)
else:
executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, scaler, group_join, ref_model=ref_model)
dist.destroy_process_group(group_join)
if __name__ == '__main__':
main()

View File

@@ -0,0 +1,240 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import time
from typing import Generator
from tqdm import tqdm
from hyperpyyaml import load_hyperpyyaml
from modelscope import snapshot_download
import torch
from cosyvoice.cli.frontend import CosyVoiceFrontEnd
from cosyvoice.cli.model import CosyVoiceModel, CosyVoice2Model, CosyVoice3Model
from cosyvoice.utils.file_utils import logging
from cosyvoice.utils.class_utils import get_model_type
class CosyVoice:
def __init__(self, model_dir, load_jit=False, load_trt=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
hyper_yaml_path = '{}/cosyvoice.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f)
assert get_model_type(configs) == CosyVoiceModel, 'do not use {} for CosyVoice initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
'{}/speech_tokenizer_v1.onnx'.format(model_dir),
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or fp16 is True):
load_jit, load_trt, fp16 = False, False, False
logging.warning('no cuda device, set load_jit/load_trt/fp16 to False')
self.model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_jit:
self.model.load_jit('{}/llm.text_encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/llm.llm.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
del configs
def list_available_spks(self):
spks = list(self.frontend.spk2info.keys())
return spks
def add_zero_shot_spk(self, prompt_text, prompt_wav, zero_shot_spk_id):
assert zero_shot_spk_id != '', 'do not use empty zero_shot_spk_id'
model_input = self.frontend.frontend_zero_shot('', prompt_text, prompt_wav, self.sample_rate, '')
del model_input['text']
del model_input['text_len']
self.frontend.spk2info[zero_shot_spk_id] = model_input
return True
def save_spkinfo(self):
torch.save(self.frontend.spk2info, '{}/spk2info.pt'.format(self.model_dir))
def inference_sft(self, tts_text, spk_id, stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_sft(i, spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
def inference_zero_shot(self, tts_text, prompt_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
if self.__class__.__name__ == 'CosyVoice3' and '<|endofprompt|>' not in prompt_text + tts_text:
logging.warning('<|endofprompt|> not found in CosyVoice3 inference, check your input text')
prompt_text = self.frontend.text_normalize(prompt_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
if (not isinstance(i, Generator)) and len(i) < 0.5 * len(prompt_text):
logging.warning('synthesis text {} too short than prompt text {}, this may lead to bad performance'.format(i, prompt_text))
model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
def inference_cross_lingual(self, tts_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_cross_lingual(i, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False, speed=1.0, text_frontend=True):
assert self.__class__.__name__ == 'CosyVoice', 'inference_instruct is only implemented for CosyVoice!'
instruct_text = self.frontend.text_normalize(instruct_text, split=False, text_frontend=text_frontend)
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
def inference_vc(self, source_wav, prompt_wav, stream=False, speed=1.0):
model_input = self.frontend.frontend_vc(source_wav, prompt_wav, self.sample_rate)
start_time = time.time()
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
class CosyVoice2(CosyVoice):
def __init__(self, model_dir, load_jit=False, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
hyper_yaml_path = '{}/cosyvoice2.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
assert get_model_type(configs) == CosyVoice2Model, 'do not use {} for CosyVoice2 initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
'{}/speech_tokenizer_v2.onnx'.format(model_dir),
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_jit is True or load_trt is True or load_vllm is True or fp16 is True):
load_jit, load_trt, load_vllm, fp16 = False, False, False, False
logging.warning('no cuda device, set load_jit/load_trt/load_vllm/fp16 to False')
self.model = CosyVoice2Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_vllm:
self.model.load_vllm('{}/vllm'.format(model_dir))
if load_jit:
self.model.load_jit('{}/flow.encoder.{}.zip'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'))
if load_trt:
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
del configs
def inference_instruct2(self, tts_text, instruct_text, prompt_wav, zero_shot_spk_id='', stream=False, speed=1.0, text_frontend=True):
for i in tqdm(self.frontend.text_normalize(tts_text, split=True, text_frontend=text_frontend)):
model_input = self.frontend.frontend_instruct2(i, instruct_text, prompt_wav, self.sample_rate, zero_shot_spk_id)
start_time = time.time()
logging.info('synthesis text {}'.format(i))
for model_output in self.model.tts(**model_input, stream=stream, speed=speed):
speech_len = model_output['tts_speech'].shape[1] / self.sample_rate
logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
yield model_output
start_time = time.time()
class CosyVoice3(CosyVoice2):
def __init__(self, model_dir, load_trt=False, load_vllm=False, fp16=False, trt_concurrent=1):
self.model_dir = model_dir
self.fp16 = fp16
if not os.path.exists(model_dir):
model_dir = snapshot_download(model_dir)
hyper_yaml_path = '{}/cosyvoice3.yaml'.format(model_dir)
if not os.path.exists(hyper_yaml_path):
raise ValueError('{} not found!'.format(hyper_yaml_path))
with open(hyper_yaml_path, 'r') as f:
configs = load_hyperpyyaml(f, overrides={'qwen_pretrain_path': os.path.join(model_dir, 'CosyVoice-BlankEN')})
assert get_model_type(configs) == CosyVoice3Model, 'do not use {} for CosyVoice3 initialization!'.format(model_dir)
self.frontend = CosyVoiceFrontEnd(configs['get_tokenizer'],
configs['feat_extractor'],
'{}/campplus.onnx'.format(model_dir),
'{}/speech_tokenizer_v3.onnx'.format(model_dir),
'{}/spk2info.pt'.format(model_dir),
configs['allowed_special'])
self.sample_rate = configs['sample_rate']
if torch.cuda.is_available() is False and (load_trt is True or fp16 is True):
load_trt, fp16 = False, False
logging.warning('no cuda device, set load_trt/fp16 to False')
self.model = CosyVoice3Model(configs['llm'], configs['flow'], configs['hift'], fp16)
self.model.load('{}/llm.pt'.format(model_dir),
'{}/flow.pt'.format(model_dir),
'{}/hift.pt'.format(model_dir))
if load_vllm:
self.model.load_vllm('{}/vllm'.format(model_dir))
if load_trt:
if self.fp16 is True:
logging.warning('DiT tensorRT fp16 engine have some performance issue, use at caution!')
self.model.load_trt('{}/flow.decoder.estimator.{}.mygpu.plan'.format(model_dir, 'fp16' if self.fp16 is True else 'fp32'),
'{}/flow.decoder.estimator.fp32.onnx'.format(model_dir),
trt_concurrent,
self.fp16)
del configs
def AutoModel(**kwargs):
if not os.path.exists(kwargs['model_dir']):
kwargs['model_dir'] = snapshot_download(kwargs['model_dir'])
if os.path.exists('{}/cosyvoice.yaml'.format(kwargs['model_dir'])):
return CosyVoice(**kwargs)
elif os.path.exists('{}/cosyvoice2.yaml'.format(kwargs['model_dir'])):
return CosyVoice2(**kwargs)
elif os.path.exists('{}/cosyvoice3.yaml'.format(kwargs['model_dir'])):
return CosyVoice3(**kwargs)
else:
raise TypeError('No valid model type found!')

View File

@@ -0,0 +1,224 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from functools import partial
from typing import Generator
import json
import onnxruntime
import torch
import numpy as np
import whisper
from typing import Callable
import torchaudio.compliance.kaldi as kaldi
import os
import re
import inflect
from cosyvoice.utils.file_utils import logging, load_wav
from cosyvoice.utils.frontend_utils import contains_chinese, replace_blank, replace_corner_mark, remove_bracket, spell_out_number, split_paragraph, is_only_punctuation
class CosyVoiceFrontEnd:
def __init__(self,
get_tokenizer: Callable,
feat_extractor: Callable,
campplus_model: str,
speech_tokenizer_model: str,
spk2info: str = '',
allowed_special: str = 'all'):
self.tokenizer = get_tokenizer()
self.feat_extractor = feat_extractor
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
"CPUExecutionProvider"])
if os.path.exists(spk2info):
self.spk2info = torch.load(spk2info, map_location=self.device, weights_only=True)
else:
self.spk2info = {}
self.allowed_special = allowed_special
self.inflect_parser = inflect.engine()
# NOTE compatible when no text frontend tool is avaliable
try:
import ttsfrd
self.frd = ttsfrd.TtsFrontendEngine()
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
'failed to initialize ttsfrd resource'
self.frd.set_lang_type('pinyinvg')
self.text_frontend = 'ttsfrd'
logging.info('use ttsfrd frontend')
except:
try:
from wetext import Normalizer as ZhNormalizer
from wetext import Normalizer as EnNormalizer
self.zh_tn_model = ZhNormalizer(remove_erhua=False)
self.en_tn_model = EnNormalizer()
self.text_frontend = 'wetext'
logging.info('use wetext frontend')
except:
self.text_frontend = ''
logging.info('no frontend is avaliable')
def _extract_text_token(self, text):
if isinstance(text, Generator):
logging.info('get tts_text generator, will return _extract_text_token_generator!')
# NOTE add a dummy text_token_len for compatibility
return self._extract_text_token_generator(text), torch.tensor([0], dtype=torch.int32).to(self.device)
else:
text_token = self.tokenizer.encode(text, allowed_special=self.allowed_special)
text_token = torch.tensor([text_token], dtype=torch.int32).to(self.device)
text_token_len = torch.tensor([text_token.shape[1]], dtype=torch.int32).to(self.device)
return text_token, text_token_len
def _extract_text_token_generator(self, text_generator):
for text in text_generator:
text_token, _ = self._extract_text_token(text)
for i in range(text_token.shape[1]):
yield text_token[:, i: i + 1]
def _extract_speech_token(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
assert speech.shape[1] / 16000 <= 30, 'do not support extract speech token for audio longer than 30s'
feat = whisper.log_mel_spectrogram(speech, n_mels=128)
speech_token = self.speech_tokenizer_session.run(None,
{self.speech_tokenizer_session.get_inputs()[0].name:
feat.detach().cpu().numpy(),
self.speech_tokenizer_session.get_inputs()[1].name:
np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
return speech_token, speech_token_len
def _extract_spk_embedding(self, prompt_wav):
speech = load_wav(prompt_wav, 16000)
feat = kaldi.fbank(speech,
num_mel_bins=80,
dither=0,
sample_frequency=16000)
feat = feat - feat.mean(dim=0, keepdim=True)
embedding = self.campplus_session.run(None,
{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
embedding = torch.tensor([embedding]).to(self.device)
return embedding
def _extract_speech_feat(self, prompt_wav):
speech = load_wav(prompt_wav, 24000)
speech_feat = self.feat_extractor(speech).squeeze(dim=0).transpose(0, 1).to(self.device)
speech_feat = speech_feat.unsqueeze(dim=0)
speech_feat_len = torch.tensor([speech_feat.shape[1]], dtype=torch.int32).to(self.device)
return speech_feat, speech_feat_len
def text_normalize(self, text, split=True, text_frontend=True):
if isinstance(text, Generator):
logging.info('get tts_text generator, will skip text_normalize!')
return [text]
# NOTE skip text_frontend when ssml symbol in text
if '<|' in text and '|>' in text:
text_frontend = False
if text_frontend is False or text == '':
return [text] if split is True else text
text = text.strip()
if self.text_frontend == 'ttsfrd':
texts = [i["text"] for i in json.loads(self.frd.do_voicegen_frd(text))["sentences"]]
text = ''.join(texts)
else:
if contains_chinese(text):
if self.text_frontend == 'wetext':
text = self.zh_tn_model.normalize(text)
text = text.replace("\n", "")
text = replace_blank(text)
text = replace_corner_mark(text)
text = text.replace(".", "")
text = text.replace(" - ", "")
text = remove_bracket(text)
text = re.sub(r'[,、]+$', '', text)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
else:
if self.text_frontend == 'wetext':
text = self.en_tn_model.normalize(text)
text = spell_out_number(text, self.inflect_parser)
texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
token_min_n=60, merge_len=20, comma_split=False))
texts = [i for i in texts if not is_only_punctuation(i)]
return texts if split is True else text
def frontend_sft(self, tts_text, spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
embedding = self.spk2info[spk_id]['embedding']
model_input = {'text': tts_text_token, 'text_len': tts_text_token_len, 'llm_embedding': embedding, 'flow_embedding': embedding}
return model_input
def frontend_zero_shot(self, tts_text, prompt_text, prompt_wav, resample_rate, zero_shot_spk_id):
tts_text_token, tts_text_token_len = self._extract_text_token(tts_text)
if zero_shot_spk_id == '':
prompt_text_token, prompt_text_token_len = self._extract_text_token(prompt_text)
speech_feat, speech_feat_len = self._extract_speech_feat(prompt_wav)
speech_token, speech_token_len = self._extract_speech_token(prompt_wav)
if resample_rate == 24000:
# cosyvoice2, force speech_feat % speech_token = 2
token_len = min(int(speech_feat.shape[1] / 2), speech_token.shape[1])
speech_feat, speech_feat_len[:] = speech_feat[:, :2 * token_len], 2 * token_len
speech_token, speech_token_len[:] = speech_token[:, :token_len], token_len
embedding = self._extract_spk_embedding(prompt_wav)
model_input = {'prompt_text': prompt_text_token, 'prompt_text_len': prompt_text_token_len,
'llm_prompt_speech_token': speech_token, 'llm_prompt_speech_token_len': speech_token_len,
'flow_prompt_speech_token': speech_token, 'flow_prompt_speech_token_len': speech_token_len,
'prompt_speech_feat': speech_feat, 'prompt_speech_feat_len': speech_feat_len,
'llm_embedding': embedding, 'flow_embedding': embedding}
else:
model_input = {**self.spk2info[zero_shot_spk_id]}
model_input['text'] = tts_text_token
model_input['text_len'] = tts_text_token_len
return model_input
def frontend_cross_lingual(self, tts_text, prompt_wav, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, '', prompt_wav, resample_rate, zero_shot_spk_id)
# in cross lingual mode, we remove prompt in llm
del model_input['prompt_text']
del model_input['prompt_text_len']
del model_input['llm_prompt_speech_token']
del model_input['llm_prompt_speech_token_len']
return model_input
def frontend_instruct(self, tts_text, spk_id, instruct_text):
model_input = self.frontend_sft(tts_text, spk_id)
# in instruct mode, we remove spk_embedding in llm due to information leakage
del model_input['llm_embedding']
instruct_text_token, instruct_text_token_len = self._extract_text_token(instruct_text)
model_input['prompt_text'] = instruct_text_token
model_input['prompt_text_len'] = instruct_text_token_len
return model_input
def frontend_instruct2(self, tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id):
model_input = self.frontend_zero_shot(tts_text, instruct_text, prompt_wav, resample_rate, zero_shot_spk_id)
del model_input['llm_prompt_speech_token']
del model_input['llm_prompt_speech_token_len']
return model_input
def frontend_vc(self, source_speech_16k, prompt_wav, resample_rate):
prompt_speech_token, prompt_speech_token_len = self._extract_speech_token(prompt_wav)
prompt_speech_feat, prompt_speech_feat_len = self._extract_speech_feat(prompt_wav)
embedding = self._extract_spk_embedding(prompt_wav)
source_speech_token, source_speech_token_len = self._extract_speech_token(source_speech_16k)
model_input = {'source_speech_token': source_speech_token, 'source_speech_token_len': source_speech_token_len,
'flow_prompt_speech_token': prompt_speech_token, 'flow_prompt_speech_token_len': prompt_speech_token_len,
'prompt_speech_feat': prompt_speech_feat, 'prompt_speech_feat_len': prompt_speech_feat_len,
'flow_embedding': embedding}
return model_input

View File

@@ -0,0 +1,450 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from typing import Generator
import torch
import numpy as np
import threading
import time
from torch.nn import functional as F
from contextlib import nullcontext
import uuid
from cosyvoice.utils.common import fade_in_out
from cosyvoice.utils.file_utils import convert_onnx_to_trt, export_cosyvoice2_vllm
from cosyvoice.utils.common import TrtContextWrapper
class CosyVoiceModel:
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
self.token_min_hop_len = 2 * self.flow.input_frame_rate
self.token_max_hop_len = 4 * self.flow.input_frame_rate
self.token_overlap_len = 20
# mel fade in out
self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
self.mel_window = np.hamming(2 * self.mel_overlap_len)
# hift cache
self.mel_cache_len = 20
self.source_cache_len = int(self.mel_cache_len * 256)
# speech fade in out
self.speech_window = np.hamming(2 * self.source_cache_len)
# rtf and decoding related
self.stream_scale_factor = 1
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.mel_overlap_dict = {}
self.flow_cache_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []
def load(self, llm_model, flow_model, hift_model):
self.llm.load_state_dict(torch.load(llm_model, map_location=self.device, weights_only=True), strict=True)
self.llm.to(self.device).eval()
self.flow.load_state_dict(torch.load(flow_model, map_location=self.device, weights_only=True), strict=True)
self.flow.to(self.device).eval()
# in case hift_model is a hifigan model
hift_state_dict = {k.replace('generator.', ''): v for k, v in torch.load(hift_model, map_location=self.device, weights_only=True).items()}
self.hift.load_state_dict(hift_state_dict, strict=True)
self.hift.to(self.device).eval()
def load_jit(self, llm_text_encoder_model, llm_llm_model, flow_encoder_model):
llm_text_encoder = torch.jit.load(llm_text_encoder_model, map_location=self.device)
self.llm.text_encoder = llm_text_encoder
llm_llm = torch.jit.load(llm_llm_model, map_location=self.device)
self.llm.llm = llm_llm
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load_trt(self, flow_decoder_estimator_model, flow_decoder_onnx_model, trt_concurrent, fp16):
assert torch.cuda.is_available(), 'tensorrt only supports gpu!'
if not os.path.exists(flow_decoder_estimator_model) or os.path.getsize(flow_decoder_estimator_model) == 0:
convert_onnx_to_trt(flow_decoder_estimator_model, self.get_trt_kwargs(), flow_decoder_onnx_model, fp16)
del self.flow.decoder.estimator
import tensorrt as trt
with open(flow_decoder_estimator_model, 'rb') as f:
estimator_engine = trt.Runtime(trt.Logger(trt.Logger.INFO)).deserialize_cuda_engine(f.read())
assert estimator_engine is not None, 'failed to load trt {}'.format(flow_decoder_estimator_model)
self.flow.decoder.estimator = TrtContextWrapper(estimator_engine, trt_concurrent=trt_concurrent, device=self.device)
def get_trt_kwargs(self):
min_shape = [(2, 80, 4), (2, 1, 4), (2, 80, 4), (2, 80, 4)]
opt_shape = [(2, 80, 500), (2, 1, 500), (2, 80, 500), (2, 80, 500)]
max_shape = [(2, 80, 3000), (2, 1, 3000), (2, 80, 3000), (2, 80, 3000)]
input_names = ["x", "mask", "mu", "cond"]
return {'min_shape': min_shape, 'opt_shape': opt_shape, 'max_shape': max_shape, 'input_names': input_names}
def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
cur_silent_token_num, max_silent_token_num = 0, 5
with self.llm_context, torch.cuda.amp.autocast(self.fp16 is True and hasattr(self.llm, 'vllm') is False):
if isinstance(text, Generator):
assert (self.__class__.__name__ != 'CosyVoiceModel') and not hasattr(self.llm, 'vllm'), 'streaming input text is only implemented for CosyVoice2/3 and do not support vllm!'
token_generator = self.llm.inference_bistream(text=text,
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device))
else:
token_generator = self.llm.inference(text=text.to(self.device),
text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
prompt_text=prompt_text.to(self.device),
prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
prompt_speech_token=llm_prompt_speech_token.to(self.device),
prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
embedding=llm_embedding.to(self.device),
uuid=uuid)
for i in token_generator:
if i in self.silent_tokens:
cur_silent_token_num += 1
if cur_silent_token_num > max_silent_token_num:
continue
else:
cur_silent_token_num = 0
self.tts_speech_token_dict[uuid].append(i)
self.llm_end_dict[uuid] = True
def vc_job(self, source_speech_token, uuid):
self.tts_speech_token_dict[uuid] = source_speech_token.flatten().tolist()
self.llm_end_dict[uuid] = True
def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, self.flow_cache_dict[uuid] = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
flow_cache=self.flow_cache_dict[uuid])
# mel overlap fade in out
if self.mel_overlap_dict[uuid].shape[2] != 0:
tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)
# keep overlap mel and hift cache
if finalize is False:
self.mel_overlap_dict[uuid] = tts_mel[:, :, -self.mel_overlap_len:]
tts_mel = tts_mel[:, :, :-self.mel_overlap_len]
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
self.mel_overlap_dict[this_uuid] = torch.zeros(1, 80, 0)
self.flow_cache_dict[this_uuid] = torch.zeros(1, 80, 0, 2)
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
p.start()
if stream is True:
token_hop_len = self.token_min_hop_len
while True:
time.sleep(0.1)
if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
.unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=False)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
# increase token_hop_len for better speech quality
token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
p.join()
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
uuid=this_uuid,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.mel_overlap_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
self.flow_cache_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
class CosyVoice2Model(CosyVoiceModel):
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# NOTE increase token_hop_len incrementally to avoid duplicate inference
self.token_max_hop_len = 4 * self.token_hop_len
self.stream_scale_factor = 2
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
# hift cache
self.mel_cache_len = 8
self.source_cache_len = int(self.mel_cache_len * 480)
# speech fade in out
self.speech_window = np.hamming(2 * self.source_cache_len)
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
self.silent_tokens = []
def load_jit(self, flow_encoder_model):
flow_encoder = torch.jit.load(flow_encoder_model, map_location=self.device)
self.flow.encoder = flow_encoder
def load_vllm(self, model_dir):
export_cosyvoice2_vllm(self.llm, model_dir, self.device)
from vllm import EngineArgs, LLMEngine
engine_args = EngineArgs(model=model_dir,
skip_tokenizer_init=True,
enable_prompt_embeds=True,
gpu_memory_utilization=0.2)
self.llm.vllm = LLMEngine.from_engine_args(engine_args)
self.llm.lock = threading.Lock()
del self.llm.llm.model.model.layers
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append hift cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel, hift_cache_source = self.hift_cache_dict[uuid]['mel'], self.hift_cache_dict[uuid]['source']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
else:
hift_cache_source = torch.zeros(1, 1, 0)
# keep overlap mel and hift cache
if finalize is False:
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
self.hift_cache_dict[uuid] = {'mel': tts_mel[:, :, -self.mel_cache_len:],
'source': tts_source[:, :, -self.source_cache_len:],
'speech': tts_speech[:, -self.source_cache_len:]}
tts_speech = tts_speech[:, :-self.source_cache_len]
else:
if speed != 1.0:
assert self.hift_cache_dict[uuid] is None, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, tts_source = self.hift.inference(speech_feat=tts_mel, cache_source=hift_cache_source)
if self.hift_cache_dict[uuid] is not None:
tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
return tts_speech
def tts(self, text=torch.zeros(1, 0, dtype=torch.int32), flow_embedding=torch.zeros(0, 192), llm_embedding=torch.zeros(0, 192),
prompt_text=torch.zeros(1, 0, dtype=torch.int32),
llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
prompt_speech_feat=torch.zeros(1, 0, 80), source_speech_token=torch.zeros(1, 0, dtype=torch.int32), stream=False, speed=1.0, **kwargs):
# this_uuid is used to track variables related to this inference thread
this_uuid = str(uuid.uuid1())
with self.lock:
self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
self.hift_cache_dict[this_uuid] = None
if source_speech_token.shape[1] == 0:
p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
else:
p = threading.Thread(target=self.vc_job, args=(source_speech_token, this_uuid))
p.start()
if stream is True:
token_offset = 0
prompt_token_pad = int(np.ceil(flow_prompt_speech_token.shape[1] / self.token_hop_len) * self.token_hop_len - flow_prompt_speech_token.shape[1])
while True:
time.sleep(0.1)
this_token_hop_len = self.token_hop_len + prompt_token_pad if token_offset == 0 else self.token_hop_len
if len(self.tts_speech_token_dict[this_uuid]) - token_offset >= this_token_hop_len + self.flow.pre_lookahead_len:
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_offset + this_token_hop_len + self.flow.pre_lookahead_len]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
token_offset=token_offset,
uuid=this_uuid,
stream=stream,
finalize=False)
token_offset += this_token_hop_len
self.token_hop_len = min(self.token_max_hop_len, self.token_hop_len * self.stream_scale_factor)
yield {'tts_speech': this_tts_speech.cpu()}
if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) - token_offset < this_token_hop_len + self.flow.pre_lookahead_len:
break
p.join()
# deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
token_offset=token_offset,
uuid=this_uuid,
finalize=True)
yield {'tts_speech': this_tts_speech.cpu()}
else:
# deal with all tokens
p.join()
this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
this_tts_speech = self.token2wav(token=this_tts_speech_token,
prompt_token=flow_prompt_speech_token,
prompt_feat=prompt_speech_feat,
embedding=flow_embedding,
token_offset=0,
uuid=this_uuid,
finalize=True,
speed=speed)
yield {'tts_speech': this_tts_speech.cpu()}
with self.lock:
self.tts_speech_token_dict.pop(this_uuid)
self.llm_end_dict.pop(this_uuid)
self.hift_cache_dict.pop(this_uuid)
if torch.cuda.is_available():
torch.cuda.empty_cache()
torch.cuda.current_stream().synchronize()
class CosyVoice3Model(CosyVoice2Model):
def __init__(self,
llm: torch.nn.Module,
flow: torch.nn.Module,
hift: torch.nn.Module,
fp16: bool = False):
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
self.llm = llm
self.flow = flow
self.hift = hift
self.fp16 = fp16
# NOTE must matching training static_chunk_size
self.token_hop_len = 25
# NOTE increase token_hop_len incrementally to avoid duplicate inference
self.token_max_hop_len = 4 * self.token_hop_len
self.stream_scale_factor = 2
assert self.stream_scale_factor >= 1, 'stream_scale_factor should be greater than 1, change it according to your actual rtf'
# rtf and decoding related
self.llm_context = torch.cuda.stream(torch.cuda.Stream(self.device)) if torch.cuda.is_available() else nullcontext()
self.lock = threading.Lock()
# dict used to store session related variable
self.tts_speech_token_dict = {}
self.llm_end_dict = {}
self.hift_cache_dict = {}
# FSQ silent and breath token
self.silent_tokens = [1, 2, 28, 29, 55, 248, 494, 2241, 2242, 2322, 2323]
def token2wav(self, token, prompt_token, prompt_feat, embedding, token_offset, uuid, stream=False, finalize=False, speed=1.0):
with torch.cuda.amp.autocast(self.fp16):
tts_mel, _ = self.flow.inference(token=token.to(self.device, dtype=torch.int32),
token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
prompt_token=prompt_token.to(self.device),
prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
prompt_feat=prompt_feat.to(self.device),
prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
embedding=embedding.to(self.device),
streaming=stream,
finalize=finalize)
tts_mel = tts_mel[:, :, token_offset * self.flow.token_mel_ratio:]
# append mel cache
if self.hift_cache_dict[uuid] is not None:
hift_cache_mel = self.hift_cache_dict[uuid]['mel']
tts_mel = torch.concat([hift_cache_mel, tts_mel], dim=2)
self.hift_cache_dict[uuid]['mel'] = tts_mel
else:
self.hift_cache_dict[uuid] = {'mel': tts_mel, 'speech_offset': 0}
if speed != 1.0:
assert token_offset == 0 and finalize is True, 'speed change only support non-stream inference mode'
tts_mel = F.interpolate(tts_mel, size=int(tts_mel.shape[2] / speed), mode='linear')
tts_speech, _ = self.hift.inference(speech_feat=tts_mel, finalize=finalize)
tts_speech = tts_speech[:, self.hift_cache_dict[uuid]['speech_offset']:]
self.hift_cache_dict[uuid]['speech_offset'] += tts_speech.shape[1]
return tts_speech

View File

@@ -0,0 +1,155 @@
# Copyright (c) 2021 Mobvoi Inc. (authors: Binbin Zhang)
# 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import random
import math
from functools import partial
import torch
import torch.distributed as dist
from torch.utils.data import IterableDataset
from cosyvoice.utils.file_utils import read_lists
class Processor(IterableDataset):
def __init__(self, source, f, *args, **kw):
assert callable(f)
self.source = source
self.f = f
self.args = args
self.kw = kw
def set_epoch(self, epoch):
self.source.set_epoch(epoch)
def __iter__(self):
""" Return an iterator over the source dataset processed by the
given processor.
"""
assert self.source is not None
assert callable(self.f)
return self.f(iter(self.source), *self.args, **self.kw)
def apply(self, f):
assert callable(f)
return Processor(self, f, *self.args, **self.kw)
class DistributedSampler:
def __init__(self, shuffle=True, partition=True):
self.epoch = -1
self.update()
self.shuffle = shuffle
self.partition = partition
def update(self):
assert dist.is_available()
if dist.is_initialized():
self.rank = dist.get_rank()
self.world_size = dist.get_world_size()
else:
self.rank = 0
self.world_size = 1
worker_info = torch.utils.data.get_worker_info()
if worker_info is None:
self.worker_id = 0
self.num_workers = 1
else:
self.worker_id = worker_info.id
self.num_workers = worker_info.num_workers
return dict(rank=self.rank,
world_size=self.world_size,
worker_id=self.worker_id,
num_workers=self.num_workers)
def set_epoch(self, epoch):
self.epoch = epoch
def sample(self, data):
""" Sample data according to rank/world_size/num_workers
Args:
data(List): input data list
Returns:
List: data list after sample
"""
data = list(range(len(data)))
# force datalist even
if self.partition:
if self.shuffle:
random.Random(self.epoch).shuffle(data)
if len(data) < self.world_size:
data = data * math.ceil(self.world_size / len(data))
data = data[:self.world_size]
data = data[self.rank::self.world_size]
if len(data) < self.num_workers:
data = data * math.ceil(self.num_workers / len(data))
data = data[:self.num_workers]
data = data[self.worker_id::self.num_workers]
return data
class DataList(IterableDataset):
def __init__(self, lists, shuffle=True, partition=True):
self.lists = lists
self.sampler = DistributedSampler(shuffle, partition)
def set_epoch(self, epoch):
self.sampler.set_epoch(epoch)
def __iter__(self):
sampler_info = self.sampler.update()
indexes = self.sampler.sample(self.lists)
for index in indexes:
data = dict(src=self.lists[index])
data.update(sampler_info)
yield data
def Dataset(data_list_file,
data_pipeline,
mode='train',
gan=False,
dpo=False,
shuffle=True,
partition=True):
""" Construct dataset from arguments
We have two shuffle stage in the Dataset. The first is global
shuffle at shards tar/raw file level. The second is global shuffle
at training samples level.
Args:
data_type(str): raw/shard
tokenizer (BaseTokenizer): tokenizer to tokenize
partition(bool): whether to do data partition in terms of rank
"""
lists = read_lists(data_list_file)
dataset = DataList(lists,
shuffle=shuffle,
partition=partition)
# map partial arg to padding func
for i in range(1, len(data_pipeline)):
if data_pipeline[i].func.__name__ == 'compute_fbank' and gan is True:
data_pipeline[i] = partial(data_pipeline[i], token_mel_ratio=0)
if data_pipeline[i].func.__name__ == 'padding':
data_pipeline[i] = partial(data_pipeline[i], gan=gan, dpo=dpo)
for func in data_pipeline:
dataset = Processor(dataset, func, mode=mode)
return dataset

View File

@@ -0,0 +1,431 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import random
import pyarrow.parquet as pq
from io import BytesIO
import numpy as np
import whisper
import torch
import torchaudio
from torch.nn.utils.rnn import pad_sequence
import torch.nn.functional as F
import pyworld as pw
from cosyvoice.utils.onnx import embedding_extractor, online_feature
AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
def parquet_opener(data, mode='train'):
""" Give url or local file, return file descriptor
Inplace operation.
Args:
data(Iterable[str]): url or local file list
Returns:
Iterable[{src, stream}]
"""
for sample in data:
assert 'src' in sample
url = sample['src']
try:
for df in pq.ParquetFile(url).iter_batches(batch_size=64):
df = df.to_pandas()
for i in range(len(df)):
sample.update(dict(df.loc[i]))
# NOTE do not return sample directly, must initialize a new dict
yield {**sample}
except Exception as ex:
logging.warning('Failed to open {}, ex info {}'.format(url, ex))
def filter(data,
max_length=10240,
min_length=10,
token_max_length=200,
token_min_length=1,
min_output_input_ratio=0.0005,
max_output_input_ratio=1,
mode='train'):
""" Filter sample according to feature and label length
Inplace operation.
Args::
data: Iterable[{key, wav, label, sample_rate}]
max_length: drop utterance which is greater than max_length(10ms)
min_length: drop utterance which is less than min_length(10ms)
token_max_length: drop utterance which is greater than
token_max_length, especially when use char unit for
english modeling
token_min_length: drop utterance which is
less than token_max_length
min_output_input_ratio: minimal ration of
token_length / feats_length(10ms)
max_output_input_ratio: maximum ration of
token_length / feats_length(10ms)
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
sample['speech'], sample['sample_rate'] = torchaudio.load(BytesIO(sample['audio_data']))
sample['speech'] = sample['speech'].mean(dim=0, keepdim=True)
del sample['audio_data']
# sample['wav'] is torch.Tensor, we have 100 frames every second
num_frames = sample['speech'].size(1) / sample['sample_rate'] * 100
if num_frames < min_length:
continue
if num_frames > max_length:
continue
if len(sample['text_token']) < token_min_length:
continue
if len(sample['text_token']) > token_max_length:
continue
if online_feature is False and len(sample['speech_token']) == 0:
continue
if online_feature is False and 'reject_speech_token' in sample and len(sample['reject_speech_token']) == 0:
continue
if num_frames != 0:
if len(sample['text_token']) / num_frames < min_output_input_ratio:
continue
if len(sample['text_token']) / num_frames > max_output_input_ratio:
continue
yield sample
def resample(data, resample_rate=22050, min_sample_rate=16000, mode='train'):
""" Resample data.
Inplace operation.
Args:
data: Iterable[{key, wav, label, sample_rate}]
resample_rate: target resample rate
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
sample_rate = sample['sample_rate']
waveform = sample['speech']
if sample_rate != resample_rate:
if sample_rate < min_sample_rate:
continue
sample['sample_rate'] = resample_rate
sample['speech'] = torchaudio.transforms.Resample(
orig_freq=sample_rate, new_freq=resample_rate)(waveform)
max_val = sample['speech'].abs().max()
if max_val > 1:
sample['speech'] /= max_val
yield sample
def truncate(data, truncate_length=24576, mode='train'):
""" Truncate data.
Args:
data: Iterable[{key, wav, label, sample_rate}]
truncate_length: truncate length
Returns:
Iterable[{key, wav, label, sample_rate}]
"""
for sample in data:
waveform = sample['speech']
if waveform.shape[1] > truncate_length:
start = random.randint(0, waveform.shape[1] - truncate_length)
waveform = waveform[:, start: start + truncate_length]
else:
waveform = torch.concat([waveform, torch.zeros(1, truncate_length - waveform.shape[1])], dim=1)
sample['speech'] = waveform
yield sample
def compute_fbank(data,
feat_extractor,
num_frames=-1,
mode='train'):
""" Extract fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
# NOTE in cosyvoice2/3, we support online token extraction, so we need to align speech to 25hz first
if num_frames != -1:
index = int(np.ceil(sample['speech'].shape[1] / num_frames))
sample['speech'] = torch.concat([sample['speech'], torch.zeros(1, index * num_frames - sample['speech'].shape[1])], dim=1)
sample['speech_feat'] = feat_extractor(sample['speech']).squeeze(dim=0).transpose(0, 1)
yield sample
def compute_whisper_fbank(data, num_frames=-1, mode='train'):
""" Extract whisper fbank
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
if num_frames != -1:
assert sample['speech'].shape[1] % num_frames == 0, 'speech length is not aligned with speech_token'
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
sample['whisper_feat'] = whisper.log_mel_spectrogram(sample['speech_16k'], n_mels=128).squeeze(dim=0).transpose(0, 1)
yield sample
def compute_f0(data, sample_rate, hop_size, mode='train'):
""" Extract f0
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
frame_period = hop_size * 1000 / sample_rate
for sample in data:
assert 'sample_rate' in sample
assert 'speech' in sample
assert 'utt' in sample
assert 'text_token' in sample
waveform = sample['speech']
_f0, t = pw.harvest(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period)
if sum(_f0 != 0) < 5: # this happens when the algorithm fails
_f0, t = pw.dio(waveform.squeeze(dim=0).numpy().astype('double'), sample_rate, frame_period=frame_period) # if harvest fails, try dio
f0 = pw.stonemask(waveform.squeeze(dim=0).numpy().astype('double'), _f0, t, sample_rate)
f0 = F.interpolate(torch.from_numpy(f0).view(1, 1, -1), size=sample['speech_feat'].shape[0], mode='linear').view(-1)
sample['pitch_feat'] = f0
yield sample
def parse_embedding(data, normalize, mode='train'):
""" Parse utt_embedding/spk_embedding
Args:
data: Iterable[{key, wav, label, sample_rate}]
Returns:
Iterable[{key, feat, label}]
"""
for sample in data:
if 'utt_embedding' not in sample and 'spk_embedding' not in sample:
sample['speech_16k'] = torchaudio.transforms.Resample(orig_freq=sample['sample_rate'], new_freq=16000)(sample['speech'])
embedding = embedding_extractor.inference(sample['speech_16k'])
sample['spk_embedding'] = sample['utt_embedding'] = embedding
else:
sample['utt_embedding'] = torch.tensor(sample['utt_embedding'], dtype=torch.float32)
sample['spk_embedding'] = torch.tensor(sample['spk_embedding'], dtype=torch.float32)
if normalize:
sample['utt_embedding'] = F.normalize(sample['utt_embedding'], dim=0)
sample['spk_embedding'] = F.normalize(sample['spk_embedding'], dim=0)
yield sample
def tokenize(data, get_tokenizer, allowed_special, mode='train'):
""" Decode text to chars or BPE
Inplace operation
Args:
data: Iterable[{key, wav, txt, sample_rate}]
Returns:
Iterable[{key, wav, txt, tokens, label, sample_rate}]
"""
tokenizer = get_tokenizer()
for sample in data:
assert 'text' in sample
sample['text_token'] = tokenizer.encode(sample['text'], allowed_special=allowed_special)
if 'instruct' in sample:
sample['instruct_token'] = tokenizer.encode(sample['instruct'], allowed_special=allowed_special)
yield sample
def shuffle(data, shuffle_size=10000, mode='train'):
""" Local shuffle the data
Args:
data: Iterable[{key, feat, label}]
shuffle_size: buffer size for shuffle
Returns:
Iterable[{key, feat, label}]
"""
buf = []
yield_size = int(shuffle_size / 2)
for sample in data:
buf.append(sample)
if len(buf) >= shuffle_size:
random.shuffle(buf)
for x in buf[:yield_size]:
yield x
buf = buf[yield_size:]
# The sample left over
random.shuffle(buf)
for x in buf:
yield x
def sort(data, sort_size=500, mode='train'):
""" Sort the data by feature length.
Sort is used after shuffle and before batch, so we can group
utts with similar lengths into a batch, and `sort_size` should
be less than `shuffle_size`
Args:
data: Iterable[{key, feat, label}]
sort_size: buffer size for sort
Returns:
Iterable[{key, feat, label}]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= sort_size:
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
buf = []
# The sample left over
buf.sort(key=lambda x: x['speech_feat'].size(0))
for x in buf:
yield x
def static_batch(data, batch_size=16):
""" Static batch the data by `batch_size`
Args:
data: Iterable[{key, feat, label}]
batch_size: batch size
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
for sample in data:
buf.append(sample)
if len(buf) >= batch_size:
yield buf
buf = []
if len(buf) > 0:
yield buf
def dynamic_batch(data, max_frames_in_batch=12000, mode='train'):
""" Dynamic batch the data until the total frames in batch
reach `max_frames_in_batch`
Args:
data: Iterable[{key, feat, label}]
max_frames_in_batch: max_frames in one batch
Returns:
Iterable[List[{key, feat, label}]]
"""
buf = []
longest_frames = 0
for sample in data:
assert 'speech_feat' in sample
assert isinstance(sample['speech_feat'], torch.Tensor)
new_sample_frames = sample['speech_feat'].size(0)
longest_frames = max(longest_frames, new_sample_frames)
frames_after_padding = longest_frames * (len(buf) + 1)
if frames_after_padding > max_frames_in_batch:
yield buf
buf = [sample]
longest_frames = new_sample_frames
else:
buf.append(sample)
if len(buf) > 0:
yield buf
def batch(data, batch_type='static', batch_size=16, max_frames_in_batch=12000, mode='train'):
""" Wrapper for static/dynamic batch
"""
if batch_type == 'static':
return static_batch(data, batch_size)
elif batch_type == 'dynamic':
return dynamic_batch(data, max_frames_in_batch)
else:
logging.fatal('Unsupported batch type {}'.format(batch_type))
def padding(data, use_spk_embedding, mode='train', gan=False, dpo=False):
""" Padding the data into training data
Args:
data: Iterable[List[{key, feat, label}]]
Returns:
Iterable[Tuple(keys, feats, labels, feats lengths, label lengths)]
"""
for sample in data:
assert isinstance(sample, list)
order = torch.argsort(torch.tensor([x['speech'].size(1) for x in sample], dtype=torch.int32), descending=True)
batch = {}
batch['utts'] = [sample[i]['utt'] for i in order]
batch['text'] = [sample[i]['text'] for i in order]
text_token = [torch.tensor(sample[i]['text_token']) for i in order]
batch['text_token_len'] = torch.tensor([i.size(0) for i in text_token], dtype=torch.int32)
batch['text_token'] = pad_sequence(text_token, batch_first=True, padding_value=0)
speech_feat = [sample[i]['speech_feat'] for i in order]
batch['speech_feat_len'] = torch.tensor([i.size(0) for i in speech_feat], dtype=torch.int32)
batch['speech_feat'] = pad_sequence(speech_feat, batch_first=True, padding_value=0)
batch['utt_embedding'] = torch.stack([sample[i]['utt_embedding'] for i in order], dim=0)
batch['spk_embedding'] = torch.stack([sample[i]['spk_embedding'] for i in order], dim=0)
if torch.tensor(['instruct_token' in sample[i] for i in order]).all():
instruct_token = [torch.tensor(sample[i]['instruct_token']) for i in order]
batch['instruct_token_len'] = torch.tensor([i.size(0) for i in instruct_token], dtype=torch.int32)
batch['instruct_token'] = pad_sequence(instruct_token, batch_first=True, padding_value=0)
if torch.tensor(['whisper_feat' in sample[i] for i in order]).all():
whisper_feat = [sample[i]['whisper_feat'] for i in order]
batch['whisper_feat_len'] = torch.tensor([i.size(0) for i in whisper_feat], dtype=torch.int32)
batch['whisper_feat'] = pad_sequence(whisper_feat, batch_first=True, padding_value=0)
if torch.tensor(['speech_token' in sample[i] for i in order]).all():
speech_token = [torch.tensor(sample[i]['speech_token']) for i in order]
batch['speech_token_len'] = torch.tensor([i.size(0) for i in speech_token], dtype=torch.int32)
batch['speech_token'] = pad_sequence(speech_token, batch_first=True, padding_value=0)
if gan is True:
# in gan train, we need speech/pitch_feat
speech = [sample[i]['speech'].squeeze(dim=0) for i in order]
batch['speech_len'] = torch.tensor([i.size(0) for i in speech], dtype=torch.int32)
batch['speech'] = pad_sequence(speech, batch_first=True, padding_value=0)
pitch_feat = [sample[i]['pitch_feat'] for i in order]
batch['pitch_feat_len'] = torch.tensor([i.size(0) for i in pitch_feat], dtype=torch.int32)
batch['pitch_feat'] = pad_sequence(pitch_feat, batch_first=True, padding_value=0)
if dpo is True:
reject_speech_token = [torch.tensor(sample[i]['reject_speech_token']) for i in order]
batch['reject_speech_token_len'] = torch.tensor([i.size(0) for i in reject_speech_token], dtype=torch.int32)
batch['reject_speech_token'] = pad_sequence(reject_speech_token, batch_first=True, padding_value=0)
if use_spk_embedding is True:
batch["embedding"] = batch["spk_embedding"]
else:
batch["embedding"] = batch["utt_embedding"]
yield batch

View File

@@ -0,0 +1,176 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
import torch
from torch import nn
import torch.nn.functional as F
from einops import repeat
from x_transformers.x_transformers import RotaryEmbedding
from cosyvoice.utils.mask import add_optional_chunk_mask
from cosyvoice.flow.DiT.modules import (
TimestepEmbedding,
ConvNeXtV2Block,
CausalConvPositionEmbedding,
DiTBlock,
AdaLayerNormZero_Final,
precompute_freqs_cis,
get_pos_embed_indices,
)
# Text embedding
class TextEmbedding(nn.Module):
def __init__(self, text_num_embeds, text_dim, conv_layers=0, conv_mult=2):
super().__init__()
self.text_embed = nn.Embedding(text_num_embeds + 1, text_dim) # use 0 as filler token
if conv_layers > 0:
self.extra_modeling = True
self.precompute_max_pos = 4096 # ~44s of 24khz audio
self.register_buffer("freqs_cis", precompute_freqs_cis(text_dim, self.precompute_max_pos), persistent=False)
self.text_blocks = nn.Sequential(
*[ConvNeXtV2Block(text_dim, text_dim * conv_mult) for _ in range(conv_layers)]
)
else:
self.extra_modeling = False
def forward(self, text: int["b nt"], seq_len, drop_text=False): # noqa: F722
batch, text_len = text.shape[0], text.shape[1]
text = text + 1 # use 0 as filler token. preprocess of batch pad -1, see list_str_to_idx()
text = text[:, :seq_len] # curtail if character tokens are more than the mel spec tokens
text = F.pad(text, (0, seq_len - text_len), value=0)
if drop_text: # cfg for text
text = torch.zeros_like(text)
text = self.text_embed(text) # b n -> b n d
# possible extra modeling
if self.extra_modeling:
# sinus pos emb
batch_start = torch.zeros((batch,), dtype=torch.long)
pos_idx = get_pos_embed_indices(batch_start, seq_len, max_pos=self.precompute_max_pos)
text_pos_embed = self.freqs_cis[pos_idx]
text = text + text_pos_embed
# convnextv2 blocks
text = self.text_blocks(text)
return text
# noised input audio and context mixing embedding
class InputEmbedding(nn.Module):
def __init__(self, mel_dim, text_dim, out_dim, spk_dim=None):
super().__init__()
spk_dim = 0 if spk_dim is None else spk_dim
self.spk_dim = spk_dim
self.proj = nn.Linear(mel_dim * 2 + text_dim + spk_dim, out_dim)
self.conv_pos_embed = CausalConvPositionEmbedding(dim=out_dim)
def forward(
self,
x: float["b n d"],
cond: float["b n d"],
text_embed: float["b n d"],
spks: float["b d"],
):
to_cat = [x, cond, text_embed]
if self.spk_dim > 0:
spks = repeat(spks, "b c -> b t c", t=x.shape[1])
to_cat.append(spks)
x = self.proj(torch.cat(to_cat, dim=-1))
x = self.conv_pos_embed(x) + x
return x
# Transformer backbone using DiT blocks
class DiT(nn.Module):
def __init__(
self,
*,
dim,
depth=8,
heads=8,
dim_head=64,
dropout=0.1,
ff_mult=4,
mel_dim=80,
mu_dim=None,
long_skip_connection=False,
spk_dim=None,
out_channels=None,
static_chunk_size=50,
num_decoding_left_chunks=2
):
super().__init__()
self.time_embed = TimestepEmbedding(dim)
if mu_dim is None:
mu_dim = mel_dim
self.input_embed = InputEmbedding(mel_dim, mu_dim, dim, spk_dim)
self.rotary_embed = RotaryEmbedding(dim_head)
self.dim = dim
self.depth = depth
self.transformer_blocks = nn.ModuleList(
[DiTBlock(dim=dim, heads=heads, dim_head=dim_head, ff_mult=ff_mult, dropout=dropout) for _ in range(depth)]
)
self.long_skip_connection = nn.Linear(dim * 2, dim, bias=False) if long_skip_connection else None
self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
self.proj_out = nn.Linear(dim, mel_dim)
self.out_channels = out_channels
self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
x = x.transpose(1, 2)
mu = mu.transpose(1, 2)
cond = cond.transpose(1, 2)
spks = spks.unsqueeze(dim=1)
batch, seq_len = x.shape[0], x.shape[1]
if t.ndim == 0:
t = t.repeat(batch)
# t: conditioning time, c: context (text + masked cond audio), x: noised input audio
t = self.time_embed(t)
x = self.input_embed(x, cond, mu, spks.squeeze(1))
rope = self.rotary_embed.forward_from_seq_len(seq_len)
if self.long_skip_connection is not None:
residual = x
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, self.static_chunk_size, -1).unsqueeze(dim=1)
else:
attn_mask = add_optional_chunk_mask(x, mask.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1).unsqueeze(dim=1)
for block in self.transformer_blocks:
x = block(x, t, mask=attn_mask.bool(), rope=rope)
if self.long_skip_connection is not None:
x = self.long_skip_connection(torch.cat((x, residual), dim=-1))
x = self.norm_out(x, t)
output = self.proj_out(x).transpose(1, 2)
return output

View File

@@ -0,0 +1,616 @@
"""
ein notation:
b - batch
n - sequence
nt - text sequence
nw - raw wave length
d - dimension
"""
from __future__ import annotations
from typing import Optional
import math
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
from x_transformers.x_transformers import apply_rotary_pos_emb
# raw wav to mel spec
class MelSpec(nn.Module):
def __init__(
self,
filter_length=1024,
hop_length=256,
win_length=1024,
n_mel_channels=100,
target_sample_rate=24_000,
normalize=False,
power=1,
norm=None,
center=True,
):
super().__init__()
self.n_mel_channels = n_mel_channels
self.mel_stft = torchaudio.transforms.MelSpectrogram(
sample_rate=target_sample_rate,
n_fft=filter_length,
win_length=win_length,
hop_length=hop_length,
n_mels=n_mel_channels,
power=power,
center=center,
normalized=normalize,
norm=norm,
)
self.register_buffer("dummy", torch.tensor(0), persistent=False)
def forward(self, inp):
if len(inp.shape) == 3:
inp = inp.squeeze(1) # 'b 1 nw -> b nw'
assert len(inp.shape) == 2
if self.dummy.device != inp.device:
self.to(inp.device)
mel = self.mel_stft(inp)
mel = mel.clamp(min=1e-5).log()
return mel
# sinusoidal position embedding
class SinusPositionEmbedding(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
def forward(self, x, scale=1000):
device = x.device
half_dim = self.dim // 2
emb = math.log(10000) / (half_dim - 1)
emb = torch.exp(torch.arange(half_dim, device=device).float() * -emb)
emb = scale * x.unsqueeze(1) * emb.unsqueeze(0)
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
return emb
# convolutional position embedding
class ConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.conv1d = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=kernel_size // 2),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = self.conv1d(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
class CausalConvPositionEmbedding(nn.Module):
def __init__(self, dim, kernel_size=31, groups=16):
super().__init__()
assert kernel_size % 2 != 0
self.kernel_size = kernel_size
self.conv1 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
)
self.conv2 = nn.Sequential(
nn.Conv1d(dim, dim, kernel_size, groups=groups, padding=0),
nn.Mish(),
)
def forward(self, x: float["b n d"], mask: bool["b n"] | None = None): # noqa: F722
if mask is not None:
mask = mask[..., None]
x = x.masked_fill(~mask, 0.0)
x = x.permute(0, 2, 1)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv1(x)
x = F.pad(x, (self.kernel_size - 1, 0, 0, 0))
x = self.conv2(x)
out = x.permute(0, 2, 1)
if mask is not None:
out = out.masked_fill(~mask, 0.0)
return out
# rotary positional embedding related
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0, theta_rescale_factor=1.0):
# proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
# has some connection to NTK literature
# https://www.reddit.com/r/LocalLLaMA/comments/14lz7j5/ntkaware_scaled_rope_allows_llama_models_to_have/
# https://github.com/lucidrains/rotary-embedding-torch/blob/main/rotary_embedding_torch/rotary_embedding_torch.py
theta *= theta_rescale_factor ** (dim / (dim - 2))
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
t = torch.arange(end, device=freqs.device) # type: ignore
freqs = torch.outer(t, freqs).float() # type: ignore
freqs_cos = torch.cos(freqs) # real part
freqs_sin = torch.sin(freqs) # imaginary part
return torch.cat([freqs_cos, freqs_sin], dim=-1)
def get_pos_embed_indices(start, length, max_pos, scale=1.0):
# length = length if isinstance(length, int) else length.max()
scale = scale * torch.ones_like(start, dtype=torch.float32) # in case scale is a scalar
pos = (
start.unsqueeze(1)
+ (torch.arange(length, device=start.device, dtype=torch.float32).unsqueeze(0) * scale.unsqueeze(1)).long()
)
# avoid extra long error.
pos = torch.where(pos < max_pos, pos, max_pos - 1)
return pos
# Global Response Normalization layer (Instance Normalization ?)
class GRN(nn.Module):
def __init__(self, dim):
super().__init__()
self.gamma = nn.Parameter(torch.zeros(1, 1, dim))
self.beta = nn.Parameter(torch.zeros(1, 1, dim))
def forward(self, x):
Gx = torch.norm(x, p=2, dim=1, keepdim=True)
Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
return self.gamma * (x * Nx) + self.beta + x
# ConvNeXt-V2 Block https://github.com/facebookresearch/ConvNeXt-V2/blob/main/models/convnextv2.py
# ref: https://github.com/bfs18/e2_tts/blob/main/rfwave/modules.py#L108
class ConvNeXtV2Block(nn.Module):
def __init__(
self,
dim: int,
intermediate_dim: int,
dilation: int = 1,
):
super().__init__()
padding = (dilation * (7 - 1)) // 2
self.dwconv = nn.Conv1d(
dim, dim, kernel_size=7, padding=padding, groups=dim, dilation=dilation
) # depthwise conv
self.norm = nn.LayerNorm(dim, eps=1e-6)
self.pwconv1 = nn.Linear(dim, intermediate_dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.grn = GRN(intermediate_dim)
self.pwconv2 = nn.Linear(intermediate_dim, dim)
def forward(self, x: torch.Tensor) -> torch.Tensor:
residual = x
x = x.transpose(1, 2) # b n d -> b d n
x = self.dwconv(x)
x = x.transpose(1, 2) # b d n -> b n d
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.grn(x)
x = self.pwconv2(x)
return residual + x
# AdaLayerNormZero
# return with modulated x for attn input, and params for later mlp modulation
class AdaLayerNormZero(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 6)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb=None):
emb = self.linear(self.silu(emb))
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = torch.chunk(emb, 6, dim=1)
x = self.norm(x) * (1 + scale_msa[:, None]) + shift_msa[:, None]
return x, gate_msa, shift_mlp, scale_mlp, gate_mlp
# AdaLayerNormZero for final layer
# return only with modulated x for attn input, cuz no more mlp modulation
class AdaLayerNormZero_Final(nn.Module):
def __init__(self, dim):
super().__init__()
self.silu = nn.SiLU()
self.linear = nn.Linear(dim, dim * 2)
self.norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
def forward(self, x, emb):
emb = self.linear(self.silu(emb))
scale, shift = torch.chunk(emb, 2, dim=1)
x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
return x
# FeedForward
class FeedForward(nn.Module):
def __init__(self, dim, dim_out=None, mult=4, dropout=0.0, approximate: str = "none"):
super().__init__()
inner_dim = int(dim * mult)
dim_out = dim_out if dim_out is not None else dim
activation = nn.GELU(approximate=approximate)
project_in = nn.Sequential(nn.Linear(dim, inner_dim), activation)
self.ff = nn.Sequential(project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out))
def forward(self, x):
return self.ff(x)
# Attention with possible joint part
# modified from diffusers/src/diffusers/models/attention_processor.py
class Attention(nn.Module):
def __init__(
self,
processor: JointAttnProcessor | AttnProcessor,
dim: int,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
context_dim: Optional[int] = None, # if not None -> joint attention
context_pre_only=None,
):
super().__init__()
if not hasattr(F, "scaled_dot_product_attention"):
raise ImportError("Attention equires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
self.processor = processor
self.dim = dim
self.heads = heads
self.inner_dim = dim_head * heads
self.dropout = dropout
self.context_dim = context_dim
self.context_pre_only = context_pre_only
self.to_q = nn.Linear(dim, self.inner_dim)
self.to_k = nn.Linear(dim, self.inner_dim)
self.to_v = nn.Linear(dim, self.inner_dim)
if self.context_dim is not None:
self.to_k_c = nn.Linear(context_dim, self.inner_dim)
self.to_v_c = nn.Linear(context_dim, self.inner_dim)
if self.context_pre_only is not None:
self.to_q_c = nn.Linear(context_dim, self.inner_dim)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(self.inner_dim, dim))
self.to_out.append(nn.Dropout(dropout))
if self.context_pre_only is not None and not self.context_pre_only:
self.to_out_c = nn.Linear(self.inner_dim, dim)
def forward(
self,
x: float["b n d"], # noised input x # noqa: F722
c: float["b n d"] = None, # context c # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.Tensor:
if c is not None:
return self.processor(self, x, c=c, mask=mask, rope=rope, c_rope=c_rope)
else:
return self.processor(self, x, mask=mask, rope=rope)
# Attention processor
class AttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding
) -> torch.FloatTensor:
batch_size = x.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# apply rotary position embedding
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
# attention
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = mask
if attn_mask.dim() == 2:
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if mask is not None:
if mask.dim() == 2:
mask = mask.unsqueeze(-1)
else:
mask = mask[:, 0, -1].unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
return x
# Joint Attention processor for MM-DiT
# modified from diffusers/src/diffusers/models/attention_processor.py
class JointAttnProcessor:
def __init__(self):
pass
def __call__(
self,
attn: Attention,
x: float["b n d"], # noised input x # noqa: F722
c: float["b nt d"] = None, # context c, here text # noqa: F722
mask: bool["b n"] | None = None, # noqa: F722
rope=None, # rotary position embedding for x
c_rope=None, # rotary position embedding for c
) -> torch.FloatTensor:
residual = x
batch_size = c.shape[0]
# `sample` projections.
query = attn.to_q(x)
key = attn.to_k(x)
value = attn.to_v(x)
# `context` projections.
c_query = attn.to_q_c(c)
c_key = attn.to_k_c(c)
c_value = attn.to_v_c(c)
# apply rope for context and noised input independently
if rope is not None:
freqs, xpos_scale = rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
query = apply_rotary_pos_emb(query, freqs, q_xpos_scale)
key = apply_rotary_pos_emb(key, freqs, k_xpos_scale)
if c_rope is not None:
freqs, xpos_scale = c_rope
q_xpos_scale, k_xpos_scale = (xpos_scale, xpos_scale**-1.0) if xpos_scale is not None else (1.0, 1.0)
c_query = apply_rotary_pos_emb(c_query, freqs, q_xpos_scale)
c_key = apply_rotary_pos_emb(c_key, freqs, k_xpos_scale)
# attention
query = torch.cat([query, c_query], dim=1)
key = torch.cat([key, c_key], dim=1)
value = torch.cat([value, c_value], dim=1)
inner_dim = key.shape[-1]
head_dim = inner_dim // attn.heads
query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
# mask. e.g. inference got a batch with different target durations, mask out the padding
if mask is not None:
attn_mask = F.pad(mask, (0, c.shape[1]), value=True) # no mask for c (text)
attn_mask = attn_mask.unsqueeze(1).unsqueeze(1) # 'b n -> b 1 1 n'
attn_mask = attn_mask.expand(batch_size, attn.heads, query.shape[-2], key.shape[-2])
else:
attn_mask = None
x = F.scaled_dot_product_attention(query, key, value, attn_mask=attn_mask, dropout_p=0.0, is_causal=False)
x = x.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
x = x.to(query.dtype)
# Split the attention outputs.
x, c = (
x[:, : residual.shape[1]],
x[:, residual.shape[1]:],
)
# linear proj
x = attn.to_out[0](x)
# dropout
x = attn.to_out[1](x)
if not attn.context_pre_only:
c = attn.to_out_c(c)
if mask is not None:
mask = mask.unsqueeze(-1)
x = x.masked_fill(~mask, 0.0)
# c = c.masked_fill(~mask, 0.) # no mask for c (text)
return x, c
# DiT Block
class DiTBlock(nn.Module):
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1):
super().__init__()
self.attn_norm = AdaLayerNormZero(dim)
self.attn = Attention(
processor=AttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
)
self.ff_norm = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, t, mask=None, rope=None): # x: noised input, t: time embedding
# pre-norm & modulation for attention input
norm, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.attn_norm(x, emb=t)
# attention
attn_output = self.attn(x=norm, mask=mask, rope=rope)
# process attention output for input x
x = x + gate_msa.unsqueeze(1) * attn_output
ff_norm = self.ff_norm(x) * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
ff_output = self.ff(ff_norm)
x = x + gate_mlp.unsqueeze(1) * ff_output
return x
# MMDiT Block https://arxiv.org/abs/2403.03206
class MMDiTBlock(nn.Module):
r"""
modified from diffusers/src/diffusers/models/attention.py
notes.
_c: context related. text, cond, etc. (left part in sd3 fig2.b)
_x: noised input related. (right part)
context_pre_only: last layer only do prenorm + modulation cuz no more ffn
"""
def __init__(self, dim, heads, dim_head, ff_mult=4, dropout=0.1, context_pre_only=False):
super().__init__()
self.context_pre_only = context_pre_only
self.attn_norm_c = AdaLayerNormZero_Final(dim) if context_pre_only else AdaLayerNormZero(dim)
self.attn_norm_x = AdaLayerNormZero(dim)
self.attn = Attention(
processor=JointAttnProcessor(),
dim=dim,
heads=heads,
dim_head=dim_head,
dropout=dropout,
context_dim=dim,
context_pre_only=context_pre_only,
)
if not context_pre_only:
self.ff_norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_c = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
else:
self.ff_norm_c = None
self.ff_c = None
self.ff_norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
self.ff_x = FeedForward(dim=dim, mult=ff_mult, dropout=dropout, approximate="tanh")
def forward(self, x, c, t, mask=None, rope=None, c_rope=None): # x: noised input, c: context, t: time embedding
# pre-norm & modulation for attention input
if self.context_pre_only:
norm_c = self.attn_norm_c(c, t)
else:
norm_c, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.attn_norm_c(c, emb=t)
norm_x, x_gate_msa, x_shift_mlp, x_scale_mlp, x_gate_mlp = self.attn_norm_x(x, emb=t)
# attention
x_attn_output, c_attn_output = self.attn(x=norm_x, c=norm_c, mask=mask, rope=rope, c_rope=c_rope)
# process attention output for context c
if self.context_pre_only:
c = None
else: # if not last layer
c = c + c_gate_msa.unsqueeze(1) * c_attn_output
norm_c = self.ff_norm_c(c) * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
c_ff_output = self.ff_c(norm_c)
c = c + c_gate_mlp.unsqueeze(1) * c_ff_output
# process attention output for input x
x = x + x_gate_msa.unsqueeze(1) * x_attn_output
norm_x = self.ff_norm_x(x) * (1 + x_scale_mlp[:, None]) + x_shift_mlp[:, None]
x_ff_output = self.ff_x(norm_x)
x = x + x_gate_mlp.unsqueeze(1) * x_ff_output
return c, x
# time step conditioning embedding
class TimestepEmbedding(nn.Module):
def __init__(self, dim, freq_embed_dim=256):
super().__init__()
self.time_embed = SinusPositionEmbedding(freq_embed_dim)
self.time_mlp = nn.Sequential(nn.Linear(freq_embed_dim, dim), nn.SiLU(), nn.Linear(dim, dim))
def forward(self, timestep: float["b"]): # noqa: F821
time_hidden = self.time_embed(timestep)
time_hidden = time_hidden.to(timestep.dtype)
time = self.time_mlp(time_hidden) # b d
return time

View File

@@ -0,0 +1,494 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import pack, rearrange, repeat
from cosyvoice.utils.common import mask_to_bias
from cosyvoice.utils.mask import add_optional_chunk_mask
from matcha.models.components.decoder import SinusoidalPosEmb, Block1D, ResnetBlock1D, Downsample1D, TimestepEmbedding, Upsample1D
from matcha.models.components.transformer import BasicTransformerBlock
class Transpose(torch.nn.Module):
def __init__(self, dim0: int, dim1: int):
super().__init__()
self.dim0 = dim0
self.dim1 = dim1
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = torch.transpose(x, self.dim0, self.dim1)
return x
class CausalConv1d(torch.nn.Conv1d):
def __init__(
self,
in_channels: int,
out_channels: int,
kernel_size: int,
stride: int = 1,
dilation: int = 1,
groups: int = 1,
bias: bool = True,
padding_mode: str = 'zeros',
device=None,
dtype=None
) -> None:
super(CausalConv1d, self).__init__(in_channels, out_channels,
kernel_size, stride,
padding=0, dilation=dilation,
groups=groups, bias=bias,
padding_mode=padding_mode,
device=device, dtype=dtype)
assert stride == 1
self.causal_padding = kernel_size - 1
def forward(self, x: torch.Tensor) -> torch.Tensor:
x = F.pad(x, (self.causal_padding, 0), value=0.0)
x = super(CausalConv1d, self).forward(x)
return x
class CausalBlock1D(Block1D):
def __init__(self, dim: int, dim_out: int):
super(CausalBlock1D, self).__init__(dim, dim_out)
self.block = torch.nn.Sequential(
CausalConv1d(dim, dim_out, 3),
Transpose(1, 2),
nn.LayerNorm(dim_out),
Transpose(1, 2),
nn.Mish(),
)
def forward(self, x: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
output = self.block(x * mask)
return output * mask
class CausalResnetBlock1D(ResnetBlock1D):
def __init__(self, dim: int, dim_out: int, time_emb_dim: int, groups: int = 8):
super(CausalResnetBlock1D, self).__init__(dim, dim_out, time_emb_dim, groups)
self.block1 = CausalBlock1D(dim, dim_out)
self.block2 = CausalBlock1D(dim_out, dim_out)
class ConditionalDecoder(nn.Module):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
"""
super().__init__()
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = ResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else nn.Conv1d(output_channel, output_channel, 3, padding=1)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = Block1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv1d):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.GroupNorm):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.kaiming_normal_(m.weight, nonlinearity="relu")
if m.bias is not None:
nn.init.constant_(m.bias, 0)
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask
class CausalConditionalDecoder(ConditionalDecoder):
def __init__(
self,
in_channels,
out_channels,
channels=(256, 256),
dropout=0.05,
attention_head_dim=64,
n_blocks=1,
num_mid_blocks=2,
num_heads=4,
act_fn="snake",
static_chunk_size=50,
num_decoding_left_chunks=2,
):
"""
This decoder requires an input with the same shape of the target. So, if your text content
is shorter or longer than the outputs, please re-sampling it before feeding to the decoder.
"""
torch.nn.Module.__init__(self)
channels = tuple(channels)
self.in_channels = in_channels
self.out_channels = out_channels
self.time_embeddings = SinusoidalPosEmb(in_channels)
time_embed_dim = channels[0] * 4
self.time_mlp = TimestepEmbedding(
in_channels=in_channels,
time_embed_dim=time_embed_dim,
act_fn="silu",
)
self.static_chunk_size = static_chunk_size
self.num_decoding_left_chunks = num_decoding_left_chunks
self.down_blocks = nn.ModuleList([])
self.mid_blocks = nn.ModuleList([])
self.up_blocks = nn.ModuleList([])
output_channel = in_channels
for i in range(len(channels)): # pylint: disable=consider-using-enumerate
input_channel = output_channel
output_channel = channels[i]
is_last = i == len(channels) - 1
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
downsample = (
Downsample1D(output_channel) if not is_last else CausalConv1d(output_channel, output_channel, 3)
)
self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
for _ in range(num_mid_blocks):
input_channel = channels[-1]
out_channels = channels[-1]
resnet = CausalResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
self.mid_blocks.append(nn.ModuleList([resnet, transformer_blocks]))
channels = channels[::-1] + (channels[0],)
for i in range(len(channels) - 1):
input_channel = channels[i] * 2
output_channel = channels[i + 1]
is_last = i == len(channels) - 2
resnet = CausalResnetBlock1D(
dim=input_channel,
dim_out=output_channel,
time_emb_dim=time_embed_dim,
)
transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
dim=output_channel,
num_attention_heads=num_heads,
attention_head_dim=attention_head_dim,
dropout=dropout,
activation_fn=act_fn,
)
for _ in range(n_blocks)
]
)
upsample = (
Upsample1D(output_channel, use_conv_transpose=True)
if not is_last
else CausalConv1d(output_channel, output_channel, 3)
)
self.up_blocks.append(nn.ModuleList([resnet, transformer_blocks, upsample]))
self.final_block = CausalBlock1D(channels[-1], channels[-1])
self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
self.initialize_weights()
def forward(self, x, mask, mu, t, spks=None, cond=None, streaming=False):
"""Forward pass of the UNet1DConditional model.
Args:
x (torch.Tensor): shape (batch_size, in_channels, time)
mask (_type_): shape (batch_size, 1, time)
t (_type_): shape (batch_size)
spks (_type_, optional): shape: (batch_size, condition_channels). Defaults to None.
cond (_type_, optional): placeholder for future use. Defaults to None.
Raises:
ValueError: _description_
ValueError: _description_
Returns:
_type_: _description_
"""
t = self.time_embeddings(t).to(t.dtype)
t = self.time_mlp(t)
x = pack([x, mu], "b * t")[0]
if spks is not None:
spks = repeat(spks, "b c -> b c t", t=x.shape[-1])
x = pack([x, spks], "b * t")[0]
if cond is not None:
x = pack([x, cond], "b * t")[0]
hiddens = []
masks = [mask]
for resnet, transformer_blocks, downsample in self.down_blocks:
mask_down = masks[-1]
x = resnet(x, mask_down, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_down.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
hiddens.append(x) # Save hidden states for skip connections
x = downsample(x * mask_down)
masks.append(mask_down[:, :, ::2])
masks = masks[:-1]
mask_mid = masks[-1]
for resnet, transformer_blocks in self.mid_blocks:
x = resnet(x, mask_mid, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_mid.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
for resnet, transformer_blocks, upsample in self.up_blocks:
mask_up = masks.pop()
skip = hiddens.pop()
x = pack([x[:, :, :skip.shape[-1]], skip], "b * t")[0]
x = resnet(x, mask_up, t)
x = rearrange(x, "b c t -> b t c").contiguous()
if streaming is True:
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, self.static_chunk_size, -1)
else:
attn_mask = add_optional_chunk_mask(x, mask_up.bool(), False, False, 0, 0, -1).repeat(1, x.size(1), 1)
attn_mask = mask_to_bias(attn_mask, x.dtype)
for transformer_block in transformer_blocks:
x = transformer_block(
hidden_states=x,
attention_mask=attn_mask,
timestep=t,
)
x = rearrange(x, "b t c -> b c t").contiguous()
x = upsample(x * mask_up)
x = self.final_block(x, mask_up)
output = self.final_proj(x * mask_up)
return output * mask

View File

@@ -0,0 +1,443 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os, logging
import random
from typing import Dict, Optional
import torch
import torch.nn as nn
from torch.nn import functional as F
from omegaconf import DictConfig
from cosyvoice.utils.mask import make_pad_mask
from cosyvoice.utils.onnx import SpeechTokenExtractor, online_feature, onnx_path
class MaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
encoder: torch.nn.Module = None,
length_regulator: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.length_regulator = length_regulator
self.only_mask_loss = only_mask_loss
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
h, h_lengths = self.length_regulator(h, feat_len)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(feat_len)).to(h)
# NOTE this is unnecessary, feat/h already same shape
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
flow_cache):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat speech token and prompt speech token
token_len1, token_len2 = prompt_token.shape[1], token.shape[1]
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len)
h = self.encoder_proj(h)
mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, flow_cache = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
prompt_len=mel_len1,
cache=flow_cache
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), flow_cache
class CausalMaskedDiffWithXvec(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
token_mel_ratio: int = 2,
pre_lookahead_len: int = 3,
encoder: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.encoder = encoder
self.encoder_proj = torch.nn.Linear(self.encoder.output_size(), output_size)
self.decoder = decoder
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
self.pre_lookahead_len = pre_lookahead_len
if online_feature is True:
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v2.batch.onnx'))
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
if 'speech_token' not in batch:
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified training, static_chunk_size > 0 or = 0
streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
h = self.encoder_proj(h)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(h_lengths.sum(dim=-1).squeeze(dim=1))).to(h)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
streaming=streaming,
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
streaming,
finalize):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
if finalize is True:
h, h_lengths = self.encoder(token, token_len, streaming=streaming)
else:
token, context = token[:, :-self.pre_lookahead_len], token[:, -self.pre_lookahead_len:]
h, h_lengths = self.encoder(token, token_len, context=context, streaming=streaming)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
h = self.encoder_proj(h)
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), None
class CausalMaskedDiffWithDiT(torch.nn.Module):
def __init__(self,
input_size: int = 512,
output_size: int = 80,
spk_embed_dim: int = 192,
output_type: str = "mel",
vocab_size: int = 4096,
input_frame_rate: int = 50,
only_mask_loss: bool = True,
token_mel_ratio: int = 2,
pre_lookahead_len: int = 3,
pre_lookahead_layer: torch.nn.Module = None,
decoder: torch.nn.Module = None,
decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}}):
super().__init__()
self.input_size = input_size
self.output_size = output_size
self.decoder_conf = decoder_conf
self.vocab_size = vocab_size
self.output_type = output_type
self.input_frame_rate = input_frame_rate
logging.info(f"input frame rate={self.input_frame_rate}")
self.input_embedding = nn.Embedding(vocab_size, input_size)
self.spk_embed_affine_layer = torch.nn.Linear(spk_embed_dim, output_size)
self.pre_lookahead_len = pre_lookahead_len
self.pre_lookahead_layer = pre_lookahead_layer
self.decoder = decoder
self.only_mask_loss = only_mask_loss
self.token_mel_ratio = token_mel_ratio
if online_feature is True:
self.speech_token_extractor = SpeechTokenExtractor(model_path=os.path.join(onnx_path, 'speech_tokenizer_v3.batch.onnx'))
def forward(
self,
batch: dict,
device: torch.device,
) -> Dict[str, Optional[torch.Tensor]]:
if 'speech_token' not in batch:
token, token_len = self.speech_token_extractor.inference(batch['whisper_feat'], batch['whisper_feat_len'], device)
else:
token = batch['speech_token'].to(device)
token_len = batch['speech_token_len'].to(device)
feat = batch['speech_feat'].to(device)
feat_len = batch['speech_feat_len'].to(device)
embedding = batch['embedding'].to(device)
# NOTE unified training, static_chunk_size > 0 or = 0
streaming = True if random.random() < 0.5 else False
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
mask = (~make_pad_mask(token_len)).float().unsqueeze(-1).to(device)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
h = self.pre_lookahead_layer(token)
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mask = mask.repeat_interleave(self.token_mel_ratio, dim=1).squeeze(dim=-1)
# get conditions
conds = torch.zeros(feat.shape, device=token.device)
for i, j in enumerate(feat_len):
if random.random() < 0.5:
continue
index = random.randint(0, int(0.3 * j))
conds[i, :index] = feat[i, :index]
conds = conds.transpose(1, 2)
loss, _ = self.decoder.compute_loss(
feat.transpose(1, 2).contiguous(),
mask.unsqueeze(1),
h.transpose(1, 2).contiguous(),
embedding,
cond=conds,
streaming=streaming,
)
return {'loss': loss}
@torch.inference_mode()
def inference(self,
token,
token_len,
prompt_token,
prompt_token_len,
prompt_feat,
prompt_feat_len,
embedding,
streaming,
finalize):
assert token.shape[0] == 1
# xvec projection
embedding = F.normalize(embedding, dim=1)
embedding = self.spk_embed_affine_layer(embedding)
# concat text and prompt_text
token, token_len = torch.concat([prompt_token, token], dim=1), prompt_token_len + token_len
mask = (~make_pad_mask(token_len)).unsqueeze(-1).to(embedding)
token = self.input_embedding(torch.clamp(token, min=0)) * mask
# text encode
if finalize is True:
h = self.pre_lookahead_layer(token)
else:
h = self.pre_lookahead_layer(token[:, :-self.pre_lookahead_len], context=token[:, -self.pre_lookahead_len:])
h = h.repeat_interleave(self.token_mel_ratio, dim=1)
mel_len1, mel_len2 = prompt_feat.shape[1], h.shape[1] - prompt_feat.shape[1]
# get conditions
conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device).to(h.dtype)
conds[:, :mel_len1] = prompt_feat
conds = conds.transpose(1, 2)
mask = (~make_pad_mask(torch.tensor([mel_len1 + mel_len2]))).to(h)
feat, _ = self.decoder(
mu=h.transpose(1, 2).contiguous(),
mask=mask.unsqueeze(1),
spks=embedding,
cond=conds,
n_timesteps=10,
streaming=streaming
)
feat = feat[:, :, mel_len1:]
assert feat.shape[2] == mel_len2
return feat.float(), None
if __name__ == '__main__':
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
from hyperpyyaml import load_hyperpyyaml
with open('./pretrained_models/Fun-CosyVoice3-0.5B/cosyvoice3.yaml', 'r') as f:
configs = load_hyperpyyaml(f, overrides={'llm': None, 'hift': None})
model = configs['flow']
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model.to(device)
model.eval()
max_len = 10 * model.decoder.estimator.static_chunk_size
chunk_size = model.decoder.estimator.static_chunk_size
context_size = model.pre_lookahead_layer.pre_lookahead_len
token = torch.randint(0, 6561, size=(1, max_len)).to(device)
token_len = torch.tensor([max_len]).to(device)
prompt_token = torch.randint(0, 6561, size=(1, chunk_size)).to(device)
prompt_token_len = torch.tensor([chunk_size]).to(device)
prompt_feat = torch.rand(1, chunk_size * 2, 80).to(device)
prompt_feat_len = torch.tensor([chunk_size * 2]).to(device)
prompt_embedding = torch.rand(1, 192).to(device)
pred_gt, _ = model.inference(token, token_len, prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=True)
for i in range(0, max_len, chunk_size):
finalize = True if i + chunk_size + context_size >= max_len else False
pred_chunk, _ = model.inference(token[:, :i + chunk_size + context_size], torch.tensor([token[:, :i + chunk_size + context_size].shape[1]]).to(device),
prompt_token, prompt_token_len, prompt_feat, prompt_feat_len, prompt_embedding, streaming=True, finalize=finalize)
pred_chunk = pred_chunk[:, :, i * model.token_mel_ratio:]
print((pred_gt[:, :, i * model.token_mel_ratio: i * model.token_mel_ratio + pred_chunk.shape[2]] - pred_chunk).abs().max().item())

View File

@@ -0,0 +1,227 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
# 2025 Alibaba Inc (authors: Xiang Lyu, Bofan Zhou)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from matcha.models.components.flow_matching import BASECFM
from cosyvoice.utils.common import set_all_random_seed
class ConditionalCFM(BASECFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
super().__init__(
n_feats=in_channels,
cfm_params=cfm_params,
n_spks=n_spks,
spk_emb_dim=spk_emb_dim,
)
self.t_scheduler = cfm_params.t_scheduler
self.training_cfg_rate = cfm_params.training_cfg_rate
self.inference_cfg_rate = cfm_params.inference_cfg_rate
in_channels = in_channels + (spk_emb_dim if n_spks > 0 else 0)
# Just change the architecture of the estimator here
self.estimator = estimator
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, prompt_len=0, cache=torch.zeros(1, 80, 0, 2)):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = torch.randn_like(mu).to(mu.device).to(mu.dtype) * temperature
cache_size = cache.shape[2]
# fix prompt and overlap part mu and z
if cache_size != 0:
z[:, :, :cache_size] = cache[:, :, :, 0]
mu[:, :, :cache_size] = cache[:, :, :, 1]
z_cache = torch.concat([z[:, :, :prompt_len], z[:, :, -34:]], dim=2)
mu_cache = torch.concat([mu[:, :, :prompt_len], mu[:, :, -34:]], dim=2)
cache = torch.stack([z_cache, mu_cache], dim=-1)
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond), cache
def solve_euler(self, x, t_span, mu, mask, spks, cond, streaming=False):
"""
Fixed euler solver for ODEs.
Args:
x (torch.Tensor): random noise
t_span (torch.Tensor): n_timesteps interpolated
shape: (n_timesteps + 1,)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
"""
t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
t = t.unsqueeze(dim=0)
# I am storing this because I can later plot it by putting a debugger here and saving it to a file
# Or in future might add like a return_all_steps flag
sol = []
# Do not use concat, it may cause memory format changed and trt infer with wrong results!
# NOTE when flow run in amp mode, x.dtype is float32, which cause nan in trt fp16 inference, so set dtype=spks.dtype
x_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
mask_in = torch.zeros([2, 1, x.size(2)], device=x.device, dtype=spks.dtype)
mu_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
t_in = torch.zeros([2], device=x.device, dtype=spks.dtype)
spks_in = torch.zeros([2, 80], device=x.device, dtype=spks.dtype)
cond_in = torch.zeros([2, 80, x.size(2)], device=x.device, dtype=spks.dtype)
for step in range(1, len(t_span)):
# Classifier-Free Guidance inference introduced in VoiceBox
x_in[:] = x
mask_in[:] = mask
mu_in[0] = mu
t_in[:] = t.unsqueeze(0)
spks_in[0] = spks
cond_in[0] = cond
dphi_dt = self.forward_estimator(
x_in, mask_in,
mu_in, t_in,
spks_in,
cond_in,
streaming
)
dphi_dt, cfg_dphi_dt = torch.split(dphi_dt, [x.size(0), x.size(0)], dim=0)
dphi_dt = ((1.0 + self.inference_cfg_rate) * dphi_dt - self.inference_cfg_rate * cfg_dphi_dt)
x = x + dt * dphi_dt
t = t + dt
sol.append(x)
if step < len(t_span) - 1:
dt = t_span[step + 1] - t
return sol[-1].float()
def forward_estimator(self, x, mask, mu, t, spks, cond, streaming=False):
if isinstance(self.estimator, torch.nn.Module):
return self.estimator(x, mask, mu, t, spks, cond, streaming=streaming)
else:
[estimator, stream], trt_engine = self.estimator.acquire_estimator()
# NOTE need to synchronize when switching stream
torch.cuda.current_stream().synchronize()
with stream:
estimator.set_input_shape('x', (2, 80, x.size(2)))
estimator.set_input_shape('mask', (2, 1, x.size(2)))
estimator.set_input_shape('mu', (2, 80, x.size(2)))
estimator.set_input_shape('t', (2,))
estimator.set_input_shape('spks', (2, 80))
estimator.set_input_shape('cond', (2, 80, x.size(2)))
data_ptrs = [x.contiguous().data_ptr(),
mask.contiguous().data_ptr(),
mu.contiguous().data_ptr(),
t.contiguous().data_ptr(),
spks.contiguous().data_ptr(),
cond.contiguous().data_ptr(),
x.data_ptr()]
for i, j in enumerate(data_ptrs):
estimator.set_tensor_address(trt_engine.get_tensor_name(i), j)
# run trt engine
assert estimator.execute_async_v3(torch.cuda.current_stream().cuda_stream) is True
torch.cuda.current_stream().synchronize()
self.estimator.release_estimator(estimator, stream)
return x
def compute_loss(self, x1, mask, mu, spks=None, cond=None, streaming=False):
"""Computes diffusion loss
Args:
x1 (torch.Tensor): Target
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): target mask
shape: (batch_size, 1, mel_timesteps)
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
spks (torch.Tensor, optional): speaker embedding. Defaults to None.
shape: (batch_size, spk_emb_dim)
Returns:
loss: conditional flow matching loss
y: conditional flow
shape: (batch_size, n_feats, mel_timesteps)
"""
b, _, t = mu.shape
# random timestep
t = torch.rand([b, 1, 1], device=mu.device, dtype=mu.dtype)
# sample noise p(x_0)
z = torch.randn_like(x1)
y = (1 - (1 - self.sigma_min) * t) * z + t * x1
u = x1 - (1 - self.sigma_min) * z
# during training, we randomly drop condition to trade off mode coverage and sample fidelity
if self.training_cfg_rate > 0:
cfg_mask = torch.rand(b, device=x1.device) > self.training_cfg_rate
mu = mu * cfg_mask.view(-1, 1, 1)
spks = spks * cfg_mask.view(-1, 1)
cond = cond * cfg_mask.view(-1, 1, 1)
pred = self.estimator(y, mask, mu, t.squeeze(), spks, cond, streaming=streaming)
loss = F.mse_loss(pred * mask, u * mask, reduction="sum") / (torch.sum(mask) * u.shape[1])
return loss, y
class CausalConditionalCFM(ConditionalCFM):
def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
super().__init__(in_channels, cfm_params, n_spks, spk_emb_dim, estimator)
set_all_random_seed(0)
self.rand_noise = torch.randn([1, 80, 50 * 300])
@torch.inference_mode()
def forward(self, mu, mask, n_timesteps, temperature=1.0, spks=None, cond=None, streaming=False):
"""Forward diffusion
Args:
mu (torch.Tensor): output of encoder
shape: (batch_size, n_feats, mel_timesteps)
mask (torch.Tensor): output_mask
shape: (batch_size, 1, mel_timesteps)
n_timesteps (int): number of diffusion steps
temperature (float, optional): temperature for scaling noise. Defaults to 1.0.
spks (torch.Tensor, optional): speaker ids. Defaults to None.
shape: (batch_size, spk_emb_dim)
cond: Not used but kept for future purposes
Returns:
sample: generated mel-spectrogram
shape: (batch_size, n_feats, mel_timesteps)
"""
z = self.rand_noise[:, :, :mu.size(2)].to(mu.device).to(mu.dtype) * temperature
# fix prompt and overlap part mu and z
t_span = torch.linspace(0, 1, n_timesteps + 1, device=mu.device, dtype=mu.dtype)
if self.t_scheduler == 'cosine':
t_span = 1 - torch.cos(t_span * 0.5 * torch.pi)
return self.solve_euler(z, t_span=t_span, mu=mu, mask=mask, spks=spks, cond=cond, streaming=streaming), None

View File

@@ -0,0 +1,70 @@
# Copyright (c) 2024 Alibaba Inc (authors: Xiang Lyu, Zhihao Du)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Tuple
import torch.nn as nn
import torch
from torch.nn import functional as F
from cosyvoice.utils.mask import make_pad_mask
class InterpolateRegulator(nn.Module):
def __init__(
self,
channels: int,
sampling_ratios: Tuple,
out_channels: int = None,
groups: int = 1,
):
super().__init__()
self.sampling_ratios = sampling_ratios
out_channels = out_channels or channels
model = nn.ModuleList([])
if len(sampling_ratios) > 0:
for _ in sampling_ratios:
module = nn.Conv1d(channels, channels, 3, 1, 1)
norm = nn.GroupNorm(groups, channels)
act = nn.Mish()
model.extend([module, norm, act])
model.append(
nn.Conv1d(channels, out_channels, 1, 1)
)
self.model = nn.Sequential(*model)
def forward(self, x, ylens=None):
# x in (B, T, D)
mask = (~make_pad_mask(ylens)).to(x).unsqueeze(-1)
x = F.interpolate(x.transpose(1, 2).contiguous(), size=ylens.max(), mode='linear')
out = self.model(x).transpose(1, 2).contiguous()
olens = ylens
return out * mask, olens
def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
# in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
# NOTE 20 corresponds to token_overlap_len in cosyvoice/cli/model.py
# x in (B, T, D)
if x2.shape[1] > 40:
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
mode='linear')
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
else:
x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
if x1.shape[1] != 0:
x1 = F.interpolate(x1.transpose(1, 2).contiguous(), size=mel_len1, mode='linear')
x = torch.concat([x1, x2], dim=2)
else:
x = x2
out = self.model(x).transpose(1, 2).contiguous()
return out, mel_len1 + mel_len2

View File

@@ -0,0 +1,230 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
try:
from torch.nn.utils.parametrizations import weight_norm, spectral_norm
except ImportError:
from torch.nn.utils import weight_norm, spectral_norm
from typing import List, Optional, Tuple
from einops import rearrange
from torchaudio.transforms import Spectrogram
LRELU_SLOPE = 0.1
class MultipleDiscriminator(nn.Module):
def __init__(
self, mpd: nn.Module, mrd: nn.Module
):
super().__init__()
self.mpd = mpd
self.mrd = mrd
def forward(self, y: torch.Tensor, y_hat: torch.Tensor):
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mpd(y.unsqueeze(dim=1), y_hat.unsqueeze(dim=1))
y_d_rs += this_y_d_rs
y_d_gs += this_y_d_gs
fmap_rs += this_fmap_rs
fmap_gs += this_fmap_gs
this_y_d_rs, this_y_d_gs, this_fmap_rs, this_fmap_gs = self.mrd(y, y_hat)
y_d_rs += this_y_d_rs
y_d_gs += this_y_d_gs
fmap_rs += this_fmap_rs
fmap_gs += this_fmap_gs
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class MultiResolutionDiscriminator(nn.Module):
def __init__(
self,
fft_sizes: Tuple[int, ...] = (2048, 1024, 512),
num_embeddings: Optional[int] = None,
):
"""
Multi-Resolution Discriminator module adapted from https://github.com/descriptinc/descript-audio-codec.
Additionally, it allows incorporating conditional information with a learned embeddings table.
Args:
fft_sizes (tuple[int]): Tuple of window lengths for FFT. Defaults to (2048, 1024, 512).
num_embeddings (int, optional): Number of embeddings. None means non-conditional discriminator.
Defaults to None.
"""
super().__init__()
self.discriminators = nn.ModuleList(
[DiscriminatorR(window_length=w, num_embeddings=num_embeddings) for w in fft_sizes]
)
def forward(
self, y: torch.Tensor, y_hat: torch.Tensor, bandwidth_id: torch.Tensor = None
) -> Tuple[List[torch.Tensor], List[torch.Tensor], List[List[torch.Tensor]], List[List[torch.Tensor]]]:
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for d in self.discriminators:
y_d_r, fmap_r = d(x=y, cond_embedding_id=bandwidth_id)
y_d_g, fmap_g = d(x=y_hat, cond_embedding_id=bandwidth_id)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
class DiscriminatorR(nn.Module):
def __init__(
self,
window_length: int,
num_embeddings: Optional[int] = None,
channels: int = 32,
hop_factor: float = 0.25,
bands: Tuple[Tuple[float, float], ...] = ((0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)),
):
super().__init__()
self.window_length = window_length
self.hop_factor = hop_factor
self.spec_fn = Spectrogram(
n_fft=window_length, hop_length=int(window_length * hop_factor), win_length=window_length, power=None
)
n_fft = window_length // 2 + 1
bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands]
self.bands = bands
convs = lambda: nn.ModuleList(
[
weight_norm(nn.Conv2d(2, channels, (3, 9), (1, 1), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 9), (1, 2), padding=(1, 4))),
weight_norm(nn.Conv2d(channels, channels, (3, 3), (1, 1), padding=(1, 1))),
]
)
self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))])
if num_embeddings is not None:
self.emb = torch.nn.Embedding(num_embeddings=num_embeddings, embedding_dim=channels)
torch.nn.init.zeros_(self.emb.weight)
self.conv_post = weight_norm(nn.Conv2d(channels, 1, (3, 3), (1, 1), padding=(1, 1)))
def spectrogram(self, x):
# Remove DC offset
x = x - x.mean(dim=-1, keepdims=True)
# Peak normalize the volume of input audio
x = 0.8 * x / (x.abs().max(dim=-1, keepdim=True)[0] + 1e-9)
x = self.spec_fn(x)
x = torch.view_as_real(x)
x = rearrange(x, "b f t c -> b c t f")
# Split into bands
x_bands = [x[..., b[0]: b[1]] for b in self.bands]
return x_bands
def forward(self, x: torch.Tensor, cond_embedding_id: torch.Tensor = None):
x_bands = self.spectrogram(x)
fmap = []
x = []
for band, stack in zip(x_bands, self.band_convs):
for i, layer in enumerate(stack):
band = layer(band)
band = torch.nn.functional.leaky_relu(band, 0.1)
if i > 0:
fmap.append(band)
x.append(band)
x = torch.cat(x, dim=-1)
if cond_embedding_id is not None:
emb = self.emb(cond_embedding_id)
h = (emb.view(1, -1, 1, 1) * x).sum(dim=1, keepdims=True)
else:
h = 0
x = self.conv_post(x)
fmap.append(x)
x += h
return x, fmap
class MultiResSpecDiscriminator(torch.nn.Module):
def __init__(self,
fft_sizes=[1024, 2048, 512],
hop_sizes=[120, 240, 50],
win_lengths=[600, 1200, 240],
window="hann_window"):
super(MultiResSpecDiscriminator, self).__init__()
self.discriminators = nn.ModuleList([
SpecDiscriminator(fft_sizes[0], hop_sizes[0], win_lengths[0], window),
SpecDiscriminator(fft_sizes[1], hop_sizes[1], win_lengths[1], window),
SpecDiscriminator(fft_sizes[2], hop_sizes[2], win_lengths[2], window)])
def forward(self, y, y_hat):
y_d_rs = []
y_d_gs = []
fmap_rs = []
fmap_gs = []
for _, d in enumerate(self.discriminators):
y_d_r, fmap_r = d(y)
y_d_g, fmap_g = d(y_hat)
y_d_rs.append(y_d_r)
fmap_rs.append(fmap_r)
y_d_gs.append(y_d_g)
fmap_gs.append(fmap_g)
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
def stft(x, fft_size, hop_size, win_length, window):
"""Perform STFT and convert to magnitude spectrogram.
Args:
x (Tensor): Input signal tensor (B, T).
fft_size (int): FFT size.
hop_size (int): Hop size.
win_length (int): Window length.
window (str): Window function type.
Returns:
Tensor: Magnitude spectrogram (B, #frames, fft_size // 2 + 1).
"""
x_stft = torch.stft(x, fft_size, hop_size, win_length, window, return_complex=True)
# NOTE(kan-bayashi): clamp is needed to avoid nan or inf
return torch.abs(x_stft).transpose(2, 1)
class SpecDiscriminator(nn.Module):
"""docstring for Discriminator."""
def __init__(self, fft_size=1024, shift_size=120, win_length=600, window="hann_window", use_spectral_norm=False):
super(SpecDiscriminator, self).__init__()
norm_f = weight_norm if use_spectral_norm is False else spectral_norm
self.fft_size = fft_size
self.shift_size = shift_size
self.win_length = win_length
self.window = getattr(torch, window)(win_length)
self.discriminators = nn.ModuleList([
norm_f(nn.Conv2d(1, 32, kernel_size=(3, 9), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 9), stride=(1, 2), padding=(1, 4))),
norm_f(nn.Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))),
])
self.out = norm_f(nn.Conv2d(32, 1, 3, 1, 1))
def forward(self, y):
fmap = []
y = y.squeeze(1)
y = stft(y, self.fft_size, self.shift_size, self.win_length, self.window.to(y.device))
y = y.unsqueeze(1)
for _, d in enumerate(self.discriminators):
y = d(y)
y = F.leaky_relu(y, LRELU_SLOPE)
fmap.append(y)
y = self.out(y)
fmap.append(y)
return torch.flatten(y, 1, -1), fmap

Some files were not shown because too many files have changed in this diff Show More