Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e3502c6f0 | ||
|
|
a1604979f0 | ||
|
|
08221e48de | ||
|
|
42b5cc0c02 | ||
|
|
1717635bfd |
1
.gitignore
vendored
1
.gitignore
vendored
@@ -40,6 +40,7 @@ backend/uploads/
|
||||
backend/cookies/
|
||||
backend/user_data/
|
||||
backend/debug_screenshots/
|
||||
backend/keys/
|
||||
*_cookies.json
|
||||
|
||||
# ============ 模型权重 ============
|
||||
|
||||
@@ -156,6 +156,14 @@ backend/user_data/{user_uuid}/cookies/
|
||||
- `LATENTSYNC_*`
|
||||
- `CORS_ORIGINS` (CORS 白名单,默认 *)
|
||||
|
||||
### MuseTalk / 混合唇形同步
|
||||
- `MUSETALK_GPU_ID` (GPU 编号,默认 0)
|
||||
- `MUSETALK_API_URL` (常驻服务地址,默认 http://localhost:8011)
|
||||
- `MUSETALK_BATCH_SIZE` (推理批大小,默认 32)
|
||||
- `MUSETALK_VERSION` (v15)
|
||||
- `MUSETALK_USE_FLOAT16` (半精度,默认 true)
|
||||
- `LIPSYNC_DURATION_THRESHOLD` (秒,>=此值用 MuseTalk,默认 120)
|
||||
|
||||
### 微信视频号
|
||||
- `WEIXIN_HEADLESS_MODE` (headful/headless-new)
|
||||
- `WEIXIN_CHROME_PATH` / `WEIXIN_BROWSER_CHANNEL`
|
||||
|
||||
@@ -101,7 +101,7 @@ backend/
|
||||
* `POST /api/tools/extract-script`: 从视频链接提取文案
|
||||
|
||||
10. **健康检查**
|
||||
* `GET /api/lipsync/health`: LatentSync 服务健康状态
|
||||
* `GET /api/lipsync/health`: 唇形同步服务健康状态(含 LatentSync + MuseTalk + 混合路由阈值)
|
||||
* `GET /api/voiceclone/health`: CosyVoice 3.0 服务健康状态
|
||||
|
||||
11. **支付 (Payment)**
|
||||
@@ -202,6 +202,12 @@ GLM_API_KEY=your_glm_api_key
|
||||
|
||||
# LatentSync 配置
|
||||
LATENTSYNC_GPU_ID=1
|
||||
|
||||
# MuseTalk 配置 (长视频唇形同步)
|
||||
MUSETALK_GPU_ID=0
|
||||
MUSETALK_API_URL=http://localhost:8011
|
||||
MUSETALK_BATCH_SIZE=32
|
||||
LIPSYNC_DURATION_THRESHOLD=120
|
||||
```
|
||||
|
||||
### 4. 启动服务
|
||||
@@ -224,6 +230,14 @@ uvicorn app.main:app --host 0.0.0.0 --port 8006 --reload
|
||||
3. **重要**: 如果模型占用 GPU,请务必使用 `asyncio.Lock` 进行并发控制,防止 OOM。
|
||||
4. 在 `app/modules/` 下创建对应模块,添加 router/service/schemas,并在 `main.py` 注册路由。
|
||||
|
||||
### 唇形同步混合路由
|
||||
|
||||
`lipsync_service.py` 实现了 LatentSync + MuseTalk 混合路由:
|
||||
- 短视频 (<`LIPSYNC_DURATION_THRESHOLD`s) → LatentSync 1.6 (GPU1, 端口 8007)
|
||||
- 长视频 (>=阈值) → MuseTalk 1.5 (GPU0, 端口 8011)
|
||||
- MuseTalk 不可用时自动回退到 LatentSync
|
||||
- 路由逻辑对 workflow 完全透明
|
||||
|
||||
### 添加定时任务
|
||||
|
||||
目前推荐使用 **APScheduler** 或 **Crontab** 来管理定时任务。
|
||||
|
||||
@@ -7,6 +7,7 @@
|
||||
| 模型 | Fun-CosyVoice3-0.5B-2512 (0.5B 参数) |
|
||||
| 端口 | 8010 |
|
||||
| GPU | 0 (CUDA_VISIBLE_DEVICES=0) |
|
||||
| 推理精度 | FP16 (自动混合精度) |
|
||||
| PM2 名称 | vigent2-cosyvoice (id=15) |
|
||||
| Conda 环境 | cosyvoice (Python 3.10) |
|
||||
| 启动脚本 | `run_cosyvoice.sh` |
|
||||
|
||||
@@ -7,8 +7,8 @@
|
||||
| 服务器 | Dell PowerEdge R730 |
|
||||
| CPU | 2× Intel Xeon E5-2680 v4 (56 线程) |
|
||||
| 内存 | 192GB DDR4 |
|
||||
| GPU 0 | NVIDIA RTX 3090 24GB |
|
||||
| GPU 1 | NVIDIA RTX 3090 24GB (用于 LatentSync) |
|
||||
| GPU 0 | NVIDIA RTX 3090 24GB (MuseTalk + CosyVoice) |
|
||||
| GPU 1 | NVIDIA RTX 3090 24GB (LatentSync) |
|
||||
| 部署路径 | `/home/rongye/ProgramFiles/ViGent2` |
|
||||
|
||||
---
|
||||
@@ -72,7 +72,9 @@ cd /home/rongye/ProgramFiles/ViGent2
|
||||
|
||||
---
|
||||
|
||||
## 步骤 3: 部署 AI 模型 (LatentSync 1.6)
|
||||
## 步骤 3: 部署 AI 模型
|
||||
|
||||
### 3a. LatentSync 1.6 (短视频唇形同步, GPU1)
|
||||
|
||||
> ⚠️ **重要**:LatentSync 需要独立的 Conda 环境和 **~18GB VRAM**。请**不要**直接安装在后端环境中。
|
||||
|
||||
@@ -93,6 +95,26 @@ conda activate latentsync
|
||||
python -m scripts.server # 测试能否启动,Ctrl+C 退出
|
||||
```
|
||||
|
||||
### 3b. MuseTalk 1.5 (长视频唇形同步, GPU0)
|
||||
|
||||
> MuseTalk 是单步潜空间修复模型(非扩散模型),推理速度接近实时,适合 >=120s 的长视频。与 CosyVoice 共享 GPU0,fp16 推理约需 4-8GB 显存。
|
||||
|
||||
请参考详细的独立部署指南:
|
||||
**[MuseTalk 部署指南](MUSETALK_DEPLOY.md)**
|
||||
|
||||
简要步骤:
|
||||
1. 创建独立的 `musetalk` Conda 环境 (Python 3.10 + PyTorch 2.0.1 + CUDA 11.8)
|
||||
2. 安装 mmcv/mmdet/mmpose 等依赖
|
||||
3. 下载模型权重 (`download_weights.sh`)
|
||||
4. 创建必要的软链接 (`musetalk/config.json`, `musetalk/musetalkV15`)
|
||||
|
||||
**验证 MuseTalk 部署**:
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/MuseTalk
|
||||
/home/rongye/ProgramFiles/miniconda3/envs/musetalk/bin/python scripts/server.py
|
||||
# 另一个终端: curl http://localhost:8011/health
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 4: 安装后端依赖
|
||||
@@ -189,7 +211,7 @@ cp .env.example .env
|
||||
| `SUPABASE_PUBLIC_URL` | `https://api.hbyrkj.top` | Supabase API 公网地址 (前端访问) |
|
||||
| `LATENTSYNC_GPU_ID` | 1 | GPU 选择 (0 或 1) |
|
||||
| `LATENTSYNC_USE_SERVER` | false | 设为 true 以启用常驻服务加速 |
|
||||
| `LATENTSYNC_INFERENCE_STEPS` | 20 | 推理步数 (20-50) |
|
||||
| `LATENTSYNC_INFERENCE_STEPS` | 16 | 推理步数 (16-50) |
|
||||
| `LATENTSYNC_GUIDANCE_SCALE` | 1.5 | 引导系数 (1.0-3.0) |
|
||||
| `DEBUG` | true | 生产环境改为 false |
|
||||
| `REDIS_URL` | `redis://localhost:6379/0` | 任务状态存储(不可用时回退内存) |
|
||||
@@ -212,7 +234,12 @@ cp .env.example .env
|
||||
| `DOUYIN_RECORD_VIDEO` | false | 录制浏览器操作视频 |
|
||||
| `DOUYIN_KEEP_SUCCESS_VIDEO` | false | 成功后保留录屏 |
|
||||
| `CORS_ORIGINS` | `*` | CORS 允许源 (生产环境建议白名单) |
|
||||
| `DOUYIN_COOKIE` | 空 | 抖音视频下载 Cookie (文案提取功能) |
|
||||
| `MUSETALK_GPU_ID` | 0 | MuseTalk GPU 编号 |
|
||||
| `MUSETALK_API_URL` | `http://localhost:8011` | MuseTalk 常驻服务地址 |
|
||||
| `MUSETALK_BATCH_SIZE` | 32 | MuseTalk 推理批大小 |
|
||||
| `MUSETALK_VERSION` | v15 | MuseTalk 模型版本 |
|
||||
| `MUSETALK_USE_FLOAT16` | true | MuseTalk 半精度加速 |
|
||||
| `LIPSYNC_DURATION_THRESHOLD` | 120 | 秒,>=此值用 MuseTalk,<此值用 LatentSync |
|
||||
| `ALIPAY_APP_ID` | 空 | 支付宝应用 APPID |
|
||||
| `ALIPAY_PRIVATE_KEY_PATH` | 空 | 应用私钥 PEM 文件路径 |
|
||||
| `ALIPAY_PUBLIC_KEY_PATH` | 空 | 支付宝公钥 PEM 文件路径 |
|
||||
@@ -271,6 +298,13 @@ cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
|
||||
conda activate latentsync
|
||||
python -m scripts.server
|
||||
```
|
||||
|
||||
### 启动 MuseTalk (终端 4, 长视频唇形同步)
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/MuseTalk
|
||||
/home/rongye/ProgramFiles/miniconda3/envs/musetalk/bin/python scripts/server.py
|
||||
```
|
||||
|
||||
### 验证
|
||||
|
||||
@@ -364,7 +398,27 @@ pm2 save
|
||||
curl http://localhost:8010/health
|
||||
```
|
||||
|
||||
### 5. 启动服务看门狗 (Watchdog)
|
||||
### 5. 启动 MuseTalk 长视频唇形同步服务
|
||||
|
||||
> 长视频 (>=120s) 自动路由到 MuseTalk。MuseTalk 不可用时自动回退 LatentSync。
|
||||
> 详细部署步骤见 [MuseTalk 部署指南](MUSETALK_DEPLOY.md)。
|
||||
|
||||
1. 启动脚本位于项目根目录: `run_musetalk.sh`
|
||||
|
||||
2. 使用 pm2 启动:
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2
|
||||
pm2 start ./run_musetalk.sh --name vigent2-musetalk
|
||||
pm2 save
|
||||
```
|
||||
|
||||
3. 验证服务:
|
||||
```bash
|
||||
curl http://localhost:8011/health
|
||||
# {"status":"ok","model_loaded":true}
|
||||
```
|
||||
|
||||
### 6. 启动服务看门狗 (Watchdog)
|
||||
|
||||
> 🛡️ **推荐**:监控 CosyVoice 和 LatentSync 服务健康状态,卡死时自动重启。
|
||||
|
||||
@@ -381,6 +435,8 @@ pm2 save
|
||||
pm2 startup
|
||||
```
|
||||
|
||||
> **提示**: 完整的 PM2 进程列表应包含 5-6 个服务: vigent2-backend, vigent2-frontend, vigent2-latentsync, vigent2-cosyvoice, vigent2-musetalk, vigent2-watchdog。
|
||||
|
||||
### pm2 常用命令
|
||||
|
||||
```bash
|
||||
@@ -388,6 +444,7 @@ pm2 status # 查看所有服务状态
|
||||
pm2 logs # 查看所有日志
|
||||
pm2 logs vigent2-backend # 查看后端日志
|
||||
pm2 logs vigent2-cosyvoice # 查看 CosyVoice 日志
|
||||
pm2 logs vigent2-musetalk # 查看 MuseTalk 日志
|
||||
pm2 restart all # 重启所有服务
|
||||
pm2 stop vigent2-latentsync # 停止 LatentSync 服务
|
||||
pm2 delete all # 删除所有服务
|
||||
@@ -527,6 +584,7 @@ sudo lsof -i :8006
|
||||
sudo lsof -i :3002
|
||||
sudo lsof -i :8007
|
||||
sudo lsof -i :8010 # CosyVoice
|
||||
sudo lsof -i :8011 # MuseTalk
|
||||
```
|
||||
|
||||
### 查看日志
|
||||
@@ -537,6 +595,7 @@ pm2 logs vigent2-backend
|
||||
pm2 logs vigent2-frontend
|
||||
pm2 logs vigent2-latentsync
|
||||
pm2 logs vigent2-cosyvoice
|
||||
pm2 logs vigent2-musetalk
|
||||
```
|
||||
|
||||
### SSH 连接卡顿 / 系统响应慢
|
||||
|
||||
239
Docs/DevLogs/Day26.md
Normal file
239
Docs/DevLogs/Day26.md
Normal file
@@ -0,0 +1,239 @@
|
||||
## 🎨 前端优化:板块合并 + 序号标题 + UI 精细化 (Day 26)
|
||||
|
||||
### 概述
|
||||
|
||||
首页原有 9 个独立板块(左栏 7 个 + 右栏 2 个),每个都有自己的卡片容器和标题,视觉碎片化严重。本次将相关板块合并为 5 个主板块,添加中文序号(一~十),移除 emoji 图标,并对多个子组件的布局和交互细节进行优化。
|
||||
|
||||
---
|
||||
|
||||
## ✅ 改动内容
|
||||
|
||||
### 1. 板块合并方案
|
||||
|
||||
**左栏(4 个主板块 + 2 个独立区域):**
|
||||
|
||||
| 序号 | 板块名 | 子板块 | 原组件 |
|
||||
|------|--------|--------|--------|
|
||||
| 一 | 文案提取与编辑 | — | ScriptEditor |
|
||||
| 二 | 标题与字幕 | — | TitleSubtitlePanel |
|
||||
| 三 | 配音 | 配音方式 / 配音列表 | VoiceSelector + GeneratedAudiosPanel |
|
||||
| 四 | 素材编辑 | 视频素材 / 时间轴编辑 | MaterialSelector + TimelineEditor |
|
||||
| 五 | 背景音乐 | — | BgmPanel |
|
||||
| — | 生成按钮 | — | GenerateActionBar(不编号) |
|
||||
|
||||
**右栏(1 个主板块):**
|
||||
|
||||
| 序号 | 板块名 | 子板块 | 原组件 |
|
||||
|------|--------|--------|--------|
|
||||
| 六 | 作品 | 作品列表 / 作品预览 | HistoryList + PreviewPanel |
|
||||
|
||||
**发布页(/publish):**
|
||||
|
||||
| 序号 | 板块名 |
|
||||
|------|--------|
|
||||
| 七 | 平台账号 |
|
||||
| 八 | 选择发布作品 |
|
||||
| 九 | 发布信息 |
|
||||
| 十 | 选择发布平台 |
|
||||
|
||||
### 2. embedded 模式
|
||||
|
||||
6 个组件新增 `embedded?: boolean` prop(默认 `false`):
|
||||
|
||||
- `VoiceSelector` — embedded 时不渲染外层卡片和主标题
|
||||
- `GeneratedAudiosPanel` — embedded 时两行布局:第 1 行(语速+生成配音右对齐)、第 2 行(配音列表+刷新)
|
||||
- `MaterialSelector` — embedded 时自渲染 h3 子标题"视频素材"+ 上传/刷新按钮同行
|
||||
- `TimelineEditor` — embedded 时自渲染 h3 子标题"时间轴编辑"+ 画面比例/播放控件同行
|
||||
- `PreviewPanel` — embedded 时不渲染外层卡片和标题
|
||||
- `HistoryList` — embedded 时不渲染外层卡片和标题(刷新按钮由 HomePage 提供)
|
||||
|
||||
### 3. 序号标题 + emoji 移除
|
||||
|
||||
所有编号板块移除 emoji 图标,使用纯中文序号:
|
||||
|
||||
- ScriptEditor: `✍️ 文案提取与编辑` → `一、文案提取与编辑`
|
||||
- TitleSubtitlePanel: `🎬 标题与字幕` → `二、标题与字幕`
|
||||
- BgmPanel: `🎵 背景音乐` → `五、背景音乐`
|
||||
- HomePage 右栏: `五、作品` → `六、作品`
|
||||
- PublishPage: `👤 平台账号` → `七、平台账号`、`📹 选择发布作品` → `八、选择发布作品`、`✍️ 发布信息` → `九、发布信息`、`📱 选择发布平台` → `十、选择发布平台`
|
||||
|
||||
### 4. 子标题与分隔样式
|
||||
|
||||
- **主标题**: `text-base sm:text-lg font-semibold text-white`
|
||||
- **子标题**: `text-sm font-medium text-gray-400`
|
||||
- **分隔线**: `<div className="border-t border-white/10 my-4" />`
|
||||
|
||||
### 5. 配音列表布局优化
|
||||
|
||||
GeneratedAudiosPanel embedded 模式下采用两行布局:
|
||||
- **第 1 行**:语速下拉 + 生成配音按钮(右对齐,`flex justify-end`)
|
||||
- **第 2 行**:`<h3>配音列表</h3>` + 刷新按钮(两端对齐)
|
||||
- 非 embedded 模式保持原单行布局
|
||||
|
||||
### 6. TitleSubtitlePanel 下拉对齐
|
||||
|
||||
- 标题样式/副标题样式/字幕样式三行标签统一 `w-20`(固定 80px),确保下拉菜单垂直对齐
|
||||
- 下拉菜单宽度 `w-1/3 min-w-[100px]`,避免过宽
|
||||
|
||||
### 7. RefAudioPanel 文案简化
|
||||
|
||||
- 原底部段落"上传任意语音样本(3-10秒)…" 移至 "我的参考音频" 标题旁,简化为 `(上传3-10秒语音样本)`
|
||||
|
||||
### 8. 账户下拉菜单添加手机号
|
||||
|
||||
- AccountSettingsDropdown 在账户有效期上方新增手机号显示区域
|
||||
- 显示 `user?.phone || '未知账户'`
|
||||
|
||||
### 9. 标题显示模式对副标题生效
|
||||
|
||||
- **payload 修复**: `useHomeController.ts` 中 `title_display_mode` 的发送条件从 `videoTitle.trim()` 改为 `videoTitle.trim() || videoSecondaryTitle.trim()`,确保仅有副标题时也能发送显示模式
|
||||
- **UI 调整**: 短暂显示/常驻显示下拉从片头标题输入行移至"二、标题与字幕"板块标题行(与预览样式按钮同行),明确表示该设置对标题和副标题同时生效
|
||||
- Remotion 端 `Title.tsx` 已支持(标题和副标题作为整体组件渲染,`displayMode` 统一控制)
|
||||
|
||||
### 10. 时间轴模糊遮罩
|
||||
|
||||
遮罩从外层 wrapper 移入"四、素材编辑"卡片内,仅覆盖时间轴子区域(`rounded-xl`)。
|
||||
|
||||
### 11. 登录后用户信息立即可用
|
||||
|
||||
- AuthContext 新增 `setUser` 方法暴露给消费者
|
||||
- 登录页成功后调用 `setUser(result.user)` 立即写入 Context,无需等页面刷新
|
||||
- 修复登录后账户下拉显示"未知账户"、刷新后才显示手机号的问题
|
||||
|
||||
### 12. 文案与选项微调
|
||||
|
||||
- MaterialSelector 描述 `(可多选,最多4个)` → `(上传自拍视频,最多可选4个)`
|
||||
- TitleSubtitlePanel 显示模式选项 `短暂显示/常驻显示` → `标题短暂显示/标题常驻显示`
|
||||
|
||||
### 13. UI/UX 体验优化(6 项)
|
||||
|
||||
- **操作按钮移动端可见**: 配音列表、作品列表、素材列表、参考音频、历史文案的操作按钮从 `opacity-0`(hover 才显示)改为 `opacity-40`(平时半透明可见,hover 全亮),解决触屏设备无法发现按钮的问题
|
||||
- **手机号脱敏**: AccountSettingsDropdown 手机号中间四位遮掩 `138****5678`
|
||||
- **标题字数计数器**: TitleSubtitlePanel 标题/副标题输入框右侧显示实时字数 `3/15`,超限变红
|
||||
- **列表滚动条提示**: ~~配音列表、作品列表、素材列表、BGM 列表从 `hide-scrollbar` 改为 `custom-scrollbar`~~ → 已全部改回 `hide-scrollbar` 隐藏滚动条(滚动功能不变)
|
||||
- **时间轴拖拽提示**: TimelineEditor 色块左上角新增 `GripVertical` 抓手图标,暗示可拖拽排序
|
||||
- **截取滑块放大**: ClipTrimmer 手柄从 16px 放大到 20px,触控区从 32px 放大到 40px
|
||||
|
||||
### 14. 代码质量修复(4 项)
|
||||
|
||||
- **AccountSettingsDropdown**: 关闭密码弹窗补齐 `setSuccess('')` 清空
|
||||
- **MaterialSelector**: `selectedSet` 加 `useMemo` 避免每次渲染重建
|
||||
- **TimelineEditor**: `visibleSegments`/`overflowSegments` 加 `useMemo`
|
||||
- **MaterialSelector**: 素材满 4 个时非选中项按钮加 `disabled`
|
||||
|
||||
### 15. 发布页平台账号响应式布局
|
||||
|
||||
- **单行布局**:图标+名称+状态在左,按钮在右(`flex items-center`)
|
||||
- **移动端紧凑**:图标 `h-6 w-6`、按钮 `text-xs px-2 py-1 rounded-md`、间距 `space-y-2 px-3 py-2.5`
|
||||
- **桌面端宽松**:`sm:h-7 sm:w-7`、`sm:text-sm sm:px-3 sm:py-1.5 sm:rounded-lg`、`sm:space-y-3 sm:px-4 sm:py-3.5`
|
||||
- 两端各自美观,风格与其他板块一致
|
||||
|
||||
### 16. 移动端刷新回顶部修复
|
||||
|
||||
- **问题**: 移动端刷新页面后不回到顶部,而是滚动到背景音乐板块
|
||||
- **根因**: 1) 浏览器原生滚动恢复覆盖 `scrollTo(0,0)`;2) 列表 scroll effect 有双依赖(`selectedId` + `list`),数据异步加载时第二次触发跳过了 ref 守卫,执行了 `scrollIntoView` 导致页面跳动
|
||||
- **修复**: 三管齐下 — ① `history.scrollRestoration = "manual"` 禁用浏览器原生恢复;② 时间门控 `scrollEffectsEnabled` ref(1 秒内禁止所有列表自动滚动)替代单次 ref 守卫;③ 200ms 延迟兜底 `scrollTo(0,0)`
|
||||
|
||||
### 17. 移动端样式预览窗口缩小
|
||||
|
||||
- **问题**: 移动端点击"预览样式"后窗口占满整屏(宽 358px,高约 636px),遮挡样式调节控件
|
||||
- **修复**: 移动端宽度从 `window.innerWidth - 32` 缩小到 **160px**;位置从左上角改为**右下角**(`right:12, bottom:12`),不遮挡上方控件;最大高度限制 `50dvh`
|
||||
- 桌面端保持不变(280px,左上角)
|
||||
|
||||
### 18. 列表滚动条统一隐藏
|
||||
|
||||
- 将 Day 26 早期改为 `custom-scrollbar`(细紫色滚动条)的 7 处全部改回 `hide-scrollbar`
|
||||
- 涉及:BgmPanel、GeneratedAudiosPanel、HistoryList、MaterialSelector(2处)、ScriptExtractionModal(2处)
|
||||
- 滚动功能不受影响,仅视觉上不显示滚动条
|
||||
|
||||
### 19. 配音按钮移动端适配
|
||||
|
||||
- VoiceSelector "选择声音/克隆声音" 按钮:内边距 `px-4` → `px-2 sm:px-4`,字号加 `text-sm sm:text-base`,图标加 `shrink-0`
|
||||
- 修复移动端窄屏下按钮被挤压导致"克隆声音"不可见的问题
|
||||
|
||||
### 20. 素材标题溢出修复
|
||||
|
||||
- MaterialSelector embedded 标题行移除 `whitespace-nowrap`
|
||||
- 描述文字 `(上传自拍视频,最多可选4个)` 在移动端隐藏(`hidden sm:inline`),桌面端正常显示
|
||||
- 修复移动端刷新按钮被推出容器外的问题
|
||||
|
||||
### 21. 生成配音按钮放大
|
||||
|
||||
- "生成配音" 作为核心操作按钮,从辅助尺寸升级为主操作尺寸
|
||||
- 内边距 `px-2/px-3 py-1/py-1.5` → `px-4 py-2`,字号 `text-xs` → `text-sm font-medium`
|
||||
- 图标 `h-3.5 w-3.5` → `h-4 w-4`,新增 `shadow-sm` + hover `shadow-md`
|
||||
- embedded 与非 embedded 模式统一放大
|
||||
|
||||
### 22. 生成进度条位置调整
|
||||
|
||||
- **问题**: 生成进度条在"六、作品"卡片内部(作品预览下方),不够醒目
|
||||
- **修复**: 进度条从 PreviewPanel 内部提取到 HomePage 右栏,作为独立卡片渲染在"六、作品"卡片**上方**
|
||||
- 使用紫色边框(`border-purple-500/30`)区分,显示任务消息和百分比
|
||||
- PreviewPanel embedded 模式下不再渲染进度条(传入 `currentTask={null}`)
|
||||
- 生成完成后进度卡片自动消失
|
||||
|
||||
### 23. LatentSync 超时修复
|
||||
|
||||
- **问题**: 约 2 分钟的视频(3023 帧,190 段推理)预计推理 54 分钟,但 httpx 超时仅 20 分钟,导致 LatentSync 调用失败并回退到无口型同步
|
||||
- **根因**: `lipsync_service.py` 中 `httpx.AsyncClient(timeout=1200.0)` 不足以覆盖长视频推理时间
|
||||
- **修复**: 超时从 `1200s`(20 分钟)改为 `3600s`(1 小时),足以覆盖 2-3 分钟视频的推理
|
||||
|
||||
### 24. 字幕时间戳节奏映射(修复长视频字幕漂移)
|
||||
|
||||
- **问题**: 2 分钟视频字幕明显对不上语音,越到后面偏差越大
|
||||
- **根因**: `whisper_service.py` 的 `original_text` 处理逻辑丢弃了 Whisper 逐词时间戳,仅保留总时间范围后做全程线性插值,每个字分配相同时长,完全忽略语速变化和停顿
|
||||
- **修复**: 保留 Whisper 的逐字时间戳作为语音节奏模板,将原文字符按比例映射到 Whisper 时间节奏上(rhythm-mapping),而非线性均分。字幕文字不变,只是时间戳跟随真实语速
|
||||
- **算法**: 原文第 i 个字符映射到 Whisper 时间线的 `(i/N)*M` 位置(N=原文字符数,M=Whisper字符数),在相邻 Whisper 时间点间线性插值
|
||||
|
||||
---
|
||||
|
||||
## 📁 修改文件清单
|
||||
|
||||
| 文件 | 改动 |
|
||||
|------|------|
|
||||
| `VoiceSelector.tsx` | 新增 embedded prop,移动端按钮适配(`px-2 sm:px-4`) |
|
||||
| `GeneratedAudiosPanel.tsx` | 新增 embedded prop,两行布局,操作按钮可见度,"生成配音"按钮放大 |
|
||||
| `MaterialSelector.tsx` | 新增 embedded prop,自渲染子标题+操作按钮,useMemo,disabled 守卫,操作按钮可见度,标题溢出修复 |
|
||||
| `TimelineEditor.tsx` | 新增 embedded prop,自渲染子标题+控件,useMemo,拖拽抓手图标 |
|
||||
| `PreviewPanel.tsx` | 新增 embedded prop |
|
||||
| `HistoryList.tsx` | 新增 embedded prop,操作按钮可见度 |
|
||||
| `ScriptEditor.tsx` | 标题加序号,移除 emoji,操作按钮可见度 |
|
||||
| `TitleSubtitlePanel.tsx` | 标题加序号,移除 emoji,下拉对齐,显示模式下拉上移,字数计数器 |
|
||||
| `BgmPanel.tsx` | 标题加序号 |
|
||||
| `HomePage.tsx` | 核心重构:合并板块、序号标题、生成配音按钮迁入、`scrollRestoration` + 延迟兜底修复刷新回顶部、生成进度条提取到作品卡片上方 |
|
||||
| `PublishPage.tsx` | 四个板块加序号(七~十),移除 emoji,平台卡片响应式单行布局 |
|
||||
| `RefAudioPanel.tsx` | 简化提示文案,操作按钮可见度 |
|
||||
| `AccountSettingsDropdown.tsx` | 新增手机号显示(脱敏),补齐 success 清空 |
|
||||
| `AuthContext.tsx` | 新增 `setUser` 方法,登录后立即更新用户状态 |
|
||||
| `login/page.tsx` | 登录成功后调用 `setUser` 写入用户数据 |
|
||||
| `useHomeController.ts` | titleDisplayMode 条件修复,列表 scroll 时间门控 `scrollEffectsEnabled` |
|
||||
| `FloatingStylePreview.tsx` | 移动端预览窗口缩小(160px)并移至右下角 |
|
||||
| `ScriptExtractionModal.tsx` | 滚动条改回隐藏 |
|
||||
| `ClipTrimmer.tsx` | 滑块手柄放大、触控区增高 |
|
||||
| `lipsync_service.py` | httpx 超时从 1200s 改为 3600s |
|
||||
| `whisper_service.py` | 字幕时间戳从线性插值改为 Whisper 节奏映射 |
|
||||
|
||||
---
|
||||
|
||||
## 🔍 验证
|
||||
|
||||
- `npm run build` — 零报错零警告
|
||||
- 合并后布局:各子板块分隔清晰、主标题有序号
|
||||
- 向后兼容:`embedded` 默认 `false`,组件独立使用不受影响
|
||||
- 配音列表两行布局:语速+生成配音在上,配音列表+刷新在下
|
||||
- 下拉菜单垂直对齐正确
|
||||
- 短暂显示/常驻显示对标题和副标题同时生效
|
||||
- 操作按钮在移动端(触屏)可见
|
||||
- 手机号脱敏显示
|
||||
- 标题字数计数器正常
|
||||
- 列表滚动条全部隐藏
|
||||
- 时间轴拖拽抓手图标显示
|
||||
- 发布页平台卡片:移动端紧凑、桌面端宽松,风格一致
|
||||
- 移动端刷新后回到顶部,不再滚动到背景音乐位置
|
||||
- 移动端样式预览窗口不遮挡控件
|
||||
- 移动端配音按钮(选择声音/克隆声音)均可见
|
||||
- 移动端素材标题行按钮不溢出
|
||||
- 生成配音按钮视觉层级高于辅助按钮
|
||||
- 生成进度条在作品卡片上方独立显示
|
||||
- LatentSync 长视频推理不再超时回退
|
||||
- 字幕时间戳与语音节奏同步,长视频不漂移
|
||||
231
Docs/DevLogs/Day27.md
Normal file
231
Docs/DevLogs/Day27.md
Normal file
@@ -0,0 +1,231 @@
|
||||
## Remotion 描边修复 + 字体样式扩展 + TypeScript 修复 (Day 27)
|
||||
|
||||
### 概述
|
||||
|
||||
修复标题/字幕描边渲染问题(描边过粗 + 副标题重影),扩展字体样式选项(标题 4→12、字幕 4→8),修复 Remotion 项目 TypeScript 类型错误。
|
||||
|
||||
---
|
||||
|
||||
## ✅ 改动内容
|
||||
|
||||
### 1. 描边渲染修复(标题 + 字幕)
|
||||
|
||||
- **问题**: 标题黑色描边过粗,副标题出现重影/鬼影
|
||||
- **根因**: `buildTextShadow` 用 4 方向 `textShadow` 模拟描边 — 对角线叠加导致描边视觉上比实际 `stroke_size` 更粗;4 角方向在中间有间隙和叠加,造成重影
|
||||
- **修复**: 改用 CSS 原生描边 `-webkit-text-stroke` + `paint-order: stroke fill`(Remotion 用 Chromium 渲染,完美支持)
|
||||
- **旧方案**:
|
||||
```javascript
|
||||
textShadow: `-8px -8px 0 #000, 8px -8px 0 #000, -8px 8px 0 #000, 8px 8px 0 #000, 0 0 16px rgba(0,0,0,0.5), 0 2px 4px rgba(0,0,0,0.3)`
|
||||
```
|
||||
- **新方案**:
|
||||
```javascript
|
||||
WebkitTextStroke: `5px #000000`,
|
||||
paintOrder: 'stroke fill',
|
||||
textShadow: `0 2px 4px rgba(0,0,0,0.3)`,
|
||||
```
|
||||
- 同时将所有预设样式的 `stroke_size` 从 8 降到 5,配合原生描边视觉更干净
|
||||
|
||||
### 2. 字体样式扩展
|
||||
|
||||
**标题样式**: 4 个 → 12 个(+8)
|
||||
|
||||
| ID | 样式名 | 字体 | 配色 |
|
||||
|----|--------|------|------|
|
||||
| title_pangmen | 庞门正道 | 庞门正道标题体3.0 | 白字黑描 |
|
||||
| title_round | 优设标题圆 | 优设标题圆 | 白字紫描 |
|
||||
| title_alibaba | 阿里数黑体 | 阿里巴巴数黑体 | 白字黑描 |
|
||||
| title_chaohei | 文道潮黑 | 文道潮黑 | 青蓝字深蓝描 |
|
||||
| title_wujie | 无界黑 | 标小智无界黑 | 白字深灰描 |
|
||||
| title_houdi | 厚底黑 | Aa厚底黑 | 红字深黑描 |
|
||||
| title_banyuan | 寒蝉半圆体 | 寒蝉半圆体 | 白字黑描 |
|
||||
| title_jixiang | 欣意吉祥宋 | 字体圈欣意吉祥宋 | 金字棕描 |
|
||||
|
||||
**字幕样式**: 4 个 → 8 个(+4)
|
||||
|
||||
| ID | 样式名 | 字体 | 高亮色 |
|
||||
|----|--------|------|--------|
|
||||
| subtitle_pink | 少女粉 | DingTalk JinBuTi | 粉色 #FF69B4 |
|
||||
| subtitle_lime | 清新绿 | DingTalk Sans | 荧光绿 #76FF03 |
|
||||
| subtitle_gold | 金色隶书 | 阿里妈妈刀隶体 | 金色 #FDE68A |
|
||||
| subtitle_kai | 楷体红字 | SimKai | 红色 #FF4444 |
|
||||
|
||||
### 3. TypeScript 类型错误修复
|
||||
|
||||
- **Root.tsx**: `Composition` 泛型类型与 `calculateMetadata` 参数类型不匹配 — 内联 `calculateMetadata` 并显式标注参数类型,`defaultProps` 使用 `satisfies VideoProps` 约束
|
||||
- **Video.tsx**: `VideoProps` 接口添加 `[key: string]: unknown` 索引签名,兼容 Remotion 要求的 `Record<string, unknown>` 约束
|
||||
- **VideoLayer.tsx**: `OffthreadVideo` 组件不支持 `loop` prop — 移除(该 prop 原本就被忽略)
|
||||
|
||||
### 4. 进度条文案还原
|
||||
|
||||
- **问题**: 进度条显示后端推送的详细阶段消息(如"正在合成唇型"),用户希望只显示"正在AI生成中..."
|
||||
- **修复**: `HomePage.tsx` 进度条文案从 `{currentTask.message || "正在AI生成中..."}` 改为固定 `正在AI生成中...`
|
||||
|
||||
---
|
||||
|
||||
## 📁 修改文件清单
|
||||
|
||||
| 文件 | 改动 |
|
||||
|------|------|
|
||||
| `remotion/src/components/Title.tsx` | `buildTextShadow` → `buildStrokeStyle`(CSS 原生描边),标题+副标题同时生效 |
|
||||
| `remotion/src/components/Subtitles.tsx` | `buildTextShadow` → `buildStrokeStyle`(CSS 原生描边) |
|
||||
| `remotion/src/Root.tsx` | 修复 `Composition` 泛型类型、`calculateMetadata` 参数类型 |
|
||||
| `remotion/src/Video.tsx` | `VideoProps` 添加索引签名 |
|
||||
| `remotion/src/components/VideoLayer.tsx` | 移除 `OffthreadVideo` 不支持的 `loop` prop |
|
||||
| `backend/assets/styles/title.json` | 标题样式从 4 个扩展到 12 个,`stroke_size` 8→5 |
|
||||
| `backend/assets/styles/subtitle.json` | 字幕样式从 4 个扩展到 8 个 |
|
||||
| `frontend/.../HomePage.tsx` | 进度条文案还原为固定"正在AI生成中..." |
|
||||
|
||||
---
|
||||
|
||||
## 🔍 验证
|
||||
|
||||
- `npx tsc --noEmit` — 零错误
|
||||
- `npm run build:render` — 渲染脚本编译成功
|
||||
- `npm run build`(前端)— 零报错
|
||||
- 描边:标题/副标题/字幕使用 CSS 原生描边,无重影、无虚胖
|
||||
- 样式选择:前端下拉可加载全部 12 个标题 + 8 个字幕样式
|
||||
|
||||
---
|
||||
|
||||
## 视频生成流水线性能优化
|
||||
|
||||
### 概述
|
||||
|
||||
针对视频生成流水线进行全面性能优化,涵盖 FFmpeg 编码参数、LatentSync 推理参数、多素材并行化、以及后处理阶段并行化。预估 15s 单素材视频从 ~280s 降至 ~190s (32%),30s 双素材从 ~400s 降至 ~240s (40%)。
|
||||
|
||||
**服务器配置**: 2x RTX 3090 (24GB), 2x Xeon E5-2680 v4 (56核), 192GB RAM
|
||||
|
||||
### 第一阶段:FFmpeg 编码优化
|
||||
|
||||
**最终合成 preset `slow` → `medium`**
|
||||
- 合成阶段从 ~50s 降到 ~25s,质量几乎无变化
|
||||
|
||||
**中间文件 CRF 18 → 23**
|
||||
- 中间产物(trim、prepare_segment、concat、loop、normalize_orientation)不是最终输出,不需要高质量编码
|
||||
- 每个中间步骤快 3-8 秒
|
||||
|
||||
**最终合成 CRF 18 → 20**
|
||||
- 15 秒口播视频 CRF 18 vs 20 肉眼无法区分
|
||||
|
||||
### 第二阶段:LatentSync 推理参数调优
|
||||
|
||||
**inference_steps 20 → 16**
|
||||
- 推理时间线性减少 20%(~180s → ~144s)
|
||||
|
||||
**guidance_scale 2.0 → 1.5**
|
||||
- classifier-free guidance 权重降低,每步计算量微降(5-10%)
|
||||
|
||||
> ⚠️ 两项需重启 LatentSync 服务后测试唇形质量,确认可接受再保留。如质量不佳可回退 .env 参数。
|
||||
|
||||
### 第三阶段:多素材流水线并行化
|
||||
|
||||
**素材下载 + 归一化并行**
|
||||
- 串行 `for` 循环改为 `asyncio.gather()`,`normalize_orientation` 通过 `run_in_executor` 在线程池执行
|
||||
- N 个素材从串行 N×5s → ~5s
|
||||
|
||||
**片段预处理并行**
|
||||
- 逐个 `prepare_segment` 改为 `asyncio.gather()` + `run_in_executor`
|
||||
- 2 素材 ~90s → ~50s;4 素材 ~180s → ~60s
|
||||
|
||||
### 第四阶段:流水线交叠
|
||||
|
||||
**Whisper 字幕对齐 与 BGM 混音 并行**
|
||||
- 两者互不依赖(都只依赖 audio_path),用 `asyncio.gather()` 并行执行
|
||||
- 单素材模式下 Whisper 从 LatentSync 之后的串行步骤移至与 BGM 并行
|
||||
- 不开 BGM 或不开字幕时行为不变,只有同时启用时才并行
|
||||
|
||||
### 修改文件
|
||||
|
||||
| 文件 | 改动 |
|
||||
|------|------|
|
||||
| `backend/app/services/video_service.py` | compose: preset slow→medium, CRF 18→20; normalize_orientation/prepare_segment/concat: CRF 18→23 |
|
||||
| `backend/app/services/lipsync_service.py` | _loop_video_to_duration: CRF 18→23 |
|
||||
| `backend/.env` | LATENTSYNC_INFERENCE_STEPS=16, LATENTSYNC_GUIDANCE_SCALE=1.5 |
|
||||
| `backend/app/modules/videos/workflow.py` | import asyncio; 素材下载/归一化并行; 片段预处理并行; Whisper+BGM 并行 |
|
||||
|
||||
### 回退方案
|
||||
|
||||
- FFmpeg 参数:如画质不满意,将最终 CRF 改回 18、preset 改回 slow
|
||||
- LatentSync:如唇形质量下降,将 .env 中 `INFERENCE_STEPS` 改回 20、`GUIDANCE_SCALE` 改回 2.0
|
||||
- 并行化:纯架构优化,无质量影响,无需回退
|
||||
|
||||
---
|
||||
|
||||
## MuseTalk + LatentSync 混合唇形同步方案
|
||||
|
||||
### 概述
|
||||
|
||||
LatentSync 1.6 质量高但推理极慢(~78% 总时长),长视频(>=2min)耗时 20-60 分钟不可接受。MuseTalk 1.5 是单步潜空间修复(非扩散模型),逐帧推理速度接近实时(30fps+ on V100),适合长视频。混合方案按音频时长自动路由:短视频用 LatentSync 保质量,长视频用 MuseTalk 保速度。
|
||||
|
||||
### 架构
|
||||
|
||||
- **路由阈值**: `LIPSYNC_DURATION_THRESHOLD` (默认 120s)
|
||||
- **短视频 (<120s)**: LatentSync 1.6 (GPU1, 端口 8007)
|
||||
- **长视频 (>=120s)**: MuseTalk 1.5 (GPU0, 端口 8011)
|
||||
- **回退**: MuseTalk 不可用时自动 fallback 到 LatentSync
|
||||
|
||||
### 改动文件
|
||||
|
||||
| 文件 | 改动 |
|
||||
|------|------|
|
||||
| `models/MuseTalk/` | 从 Temp/MuseTalk 复制代码 + 下载权重 |
|
||||
| `models/MuseTalk/scripts/server.py` | 新建 FastAPI 常驻服务 (端口 8011, GPU0) |
|
||||
| `backend/app/core/config.py` | 新增 MUSETALK_* 和 LIPSYNC_DURATION_THRESHOLD |
|
||||
| `backend/.env` | 新增对应环境变量 |
|
||||
| `backend/app/services/lipsync_service.py` | 新增 `_call_musetalk_server()` + 混合路由逻辑 + 扩展 `check_health()` |
|
||||
|
||||
---
|
||||
|
||||
## MuseTalk 推理性能优化 (server.py v2)
|
||||
|
||||
### 概述
|
||||
|
||||
MuseTalk 首次长视频测试 (136s, 3404 帧) 耗时 1799s (~30 分钟),分析发现瓶颈集中在人脸检测 (28%)、BiSeNet 合成 (22%)、I/O (17%),而非 UNet 推理本身 (17%)。通过 6 项优化预估降至 8-10 分钟 (~3x 加速)。
|
||||
|
||||
### 性能瓶颈分析 (优化前, 1799s)
|
||||
|
||||
| 阶段 | 耗时 | 占比 | 瓶颈原因 |
|
||||
|------|------|------|---------|
|
||||
| DWPose + 人脸检测 | ~510s | 28% | `batch_size_fa=1`, 每帧跑 2 个 NN, 完全串行 |
|
||||
| 合成 + BiSeNet 人脸解析 | ~400s | 22% | 每帧都跑 BiSeNet + PNG 写盘 |
|
||||
| UNet 推理 | ~300s | 17% | batch_size=8 太小 |
|
||||
| I/O (PNG 读写 + FFmpeg) | ~300s | 17% | PNG 压缩慢, ffmpeg→PNG→imread 链路 |
|
||||
| VAE 编码 | ~100s | 6% | 逐帧编码, 未批处理 |
|
||||
|
||||
### 6 项优化
|
||||
|
||||
| # | 优化项 | 详情 |
|
||||
|---|--------|------|
|
||||
| 1 | **batch_size 8→32** | `.env` 修改, RTX 3090 显存充裕 |
|
||||
| 2 | **cv2.VideoCapture 直读帧** | 跳过 ffmpeg→PNG→imread 链路, 省去 3404 次 PNG 编解码 |
|
||||
| 3 | **人脸检测降频 (每5帧)** | 每 5 帧运行 DWPose + FaceAlignment, 中间帧线性插值 bbox |
|
||||
| 4 | **BiSeNet mask 缓存 (每5帧)** | 每 5 帧运行 `get_image_prepare_material`, 中间帧用 `get_image_blending` 复用缓存 mask |
|
||||
| 5 | **cv2.VideoWriter 直写** | 跳过逐帧 PNG 写盘 + ffmpeg 重编码, 用 VideoWriter 直写 mp4 |
|
||||
| 6 | **每阶段计时** | 7 个阶段精确计时, 方便后续进一步调优 |
|
||||
|
||||
### 修改文件
|
||||
|
||||
| 文件 | 改动 |
|
||||
|------|------|
|
||||
| `models/MuseTalk/scripts/server.py` | 完全重写 `_run_inference()`, 新增 `_detect_faces_subsampled()` |
|
||||
| `backend/.env` | `MUSETALK_BATCH_SIZE` 8→32 |
|
||||
|
||||
---
|
||||
|
||||
## Remotion 并发渲染优化
|
||||
|
||||
### 概述
|
||||
|
||||
Remotion 渲染在 56 核服务器上默认只用 8 并发 (`min(8, cores/2)`),改为 16 并发,预估从 ~5 分钟降到 ~2-3 分钟。
|
||||
|
||||
### 改动
|
||||
|
||||
- `remotion/render.ts`: `renderMedia()` 新增 `concurrency` 参数 (默认 16), 支持 `--concurrency` CLI 参数覆盖
|
||||
- `remotion/dist/render.js`: 重新编译
|
||||
|
||||
### 修改文件
|
||||
|
||||
| 文件 | 改动 |
|
||||
|------|------|
|
||||
| `remotion/render.ts` | `RenderOptions` 新增 `concurrency` 字段, `renderMedia()` 传入 `concurrency` |
|
||||
| `remotion/dist/render.js` | TypeScript 重新编译 |
|
||||
203
Docs/DevLogs/Day28.md
Normal file
203
Docs/DevLogs/Day28.md
Normal file
@@ -0,0 +1,203 @@
|
||||
## CosyVoice FP16 加速 + 文档更新 + AI改写界面重构 + 标题字幕面板重排与视频帧预览 (Day 28)
|
||||
|
||||
### 概述
|
||||
|
||||
CosyVoice 3.0 声音克隆服务开启 FP16 半精度推理,预估提速 30-40%。同步更新 4 个项目文档。重构 AI 改写文案界面(RewriteModal 两步流程 + ScriptExtractionModal 逻辑抽取)。前端将"标题与字幕"面板从第二步移至第四步(素材编辑之后),样式预览窗口背景从紫粉渐变改为视频片头帧截图,实现所见即所得。
|
||||
|
||||
---
|
||||
|
||||
## ✅ 改动内容
|
||||
|
||||
### 1. CosyVoice FP16 半精度加速
|
||||
|
||||
- **问题**: CosyVoice 3.0 以 FP32 全精度运行,RTF (Real-Time Factor) 约 0.9-1.35x,生成 2 分钟音频需要约 2 分钟
|
||||
- **根因**: `AutoModel()` 初始化时未传入 `fp16=True`,LLM 推理和 Flow Matching (DiT) 均在 FP32 下运行
|
||||
- **修复**: 一行改动开启 FP16 自动混合精度
|
||||
|
||||
```python
|
||||
# 旧: _model = AutoModel(model_dir=str(MODEL_DIR))
|
||||
# 新:
|
||||
_model = AutoModel(model_dir=str(MODEL_DIR), fp16=True)
|
||||
```
|
||||
|
||||
- **生效机制**: `CosyVoice3Model` 在 `llm_job()` 和 `token2wav()` 中通过 `torch.cuda.amp.autocast(self.fp16)` 自动将计算转为 FP16
|
||||
- **预期效果**:
|
||||
- 推理速度提升 30-40%
|
||||
- 显存占用降低 ~30%
|
||||
- 语音质量基本无损(0.5B 模型 FP16 精度充足)
|
||||
- **验证**: 服务重启后自检通过,健康检查 `ready: true`
|
||||
|
||||
### 2. 文档全面更新 (4 个文件)
|
||||
|
||||
补充 Day 27 新增的 MuseTalk 混合唇形同步方案、性能优化、Remotion 并发渲染等内容到所有相关文档。
|
||||
|
||||
#### README.md
|
||||
- 项目描述更新为 "LatentSync 1.6 + MuseTalk 1.5 混合唇形同步"
|
||||
- 唇形同步功能描述改为混合方案(短视频 LatentSync,长视频 MuseTalk)
|
||||
- 技术栈表新增 MuseTalk 1.5
|
||||
- 项目结构新增 `models/MuseTalk/`
|
||||
- 服务架构表新增 MuseTalk (端口 8011)
|
||||
- 文档中心新增 MuseTalk 部署指南链接
|
||||
- 性能优化描述新增降频检测 + Remotion 16 并发
|
||||
|
||||
#### DEPLOY_MANUAL.md
|
||||
- GPU 分配说明更新 (GPU0=MuseTalk+CosyVoice, GPU1=LatentSync)
|
||||
- 步骤 3 拆分为 3a (LatentSync) + 3b (MuseTalk)
|
||||
- 环境变量表新增 7 个 MuseTalk 变量,移除过时的 `DOUYIN_COOKIE`
|
||||
- LatentSync 推理步数默认值 20→16
|
||||
- 测试运行新增 MuseTalk 启动终端
|
||||
- PM2 管理新增 MuseTalk 服务(第 5 项)
|
||||
- 端口检查、日志查看命令新增 8011/vigent2-musetalk
|
||||
|
||||
#### SUBTITLE_DEPLOY.md
|
||||
- 技术架构图更新为 LatentSync/MuseTalk 混合路由
|
||||
- 新增唇形同步路由说明
|
||||
- Remotion 配置表新增 `concurrency` 参数 (默认 16)
|
||||
- GPU 分配说明更新
|
||||
- 更新日志新增 v1.3.0 条目
|
||||
|
||||
#### BACKEND_README.md
|
||||
- 健康检查接口描述更新为含 LatentSync + MuseTalk + 混合路由阈值
|
||||
- 环境变量配置新增 MuseTalk 相关变量
|
||||
- 服务集成指南新增"唇形同步混合路由"章节
|
||||
|
||||
---
|
||||
|
||||
### 3. AI 改写文案界面重构
|
||||
|
||||
#### RewriteModal 重构
|
||||
|
||||
将 AI 改写弹窗改为两步式流程,提升交互体验:
|
||||
|
||||
**第一步 — 配置与触发**:
|
||||
- 自定义提示词输入(可选),自动持久化到 localStorage
|
||||
- "开始改写"按钮触发 `/api/ai/rewrite` 请求
|
||||
|
||||
**第二步 — 结果对比与选择**:
|
||||
- 上方:AI 改写结果 + "使用此结果"按钮(紫粉渐变色,醒目)
|
||||
- 下方:原文对比 + "保留原文"按钮(灰色低调)
|
||||
- 底部:可"重新改写"(重回第一步,保留自定义提示词)
|
||||
- ESC 快捷键关闭
|
||||
|
||||
#### ScriptExtractionModal 逻辑抽取
|
||||
|
||||
将文案提取模态框的全部业务逻辑抽取到独立 hook `useScriptExtraction`:
|
||||
|
||||
- **useScriptExtraction.ts** (新建): 管理 URL/文件双模式输入、拖拽上传、提取请求、步骤状态机 (config → processing → result)、剪贴板复制
|
||||
- **ScriptExtractionModal.tsx**: 纯展示组件,消费 hook 返回值,新增 ESC/Enter 快捷键
|
||||
|
||||
#### ScriptEditor 工具栏调整
|
||||
|
||||
- 按钮组右对齐 (`justify-end`),统一高度 `h-7` 和圆角
|
||||
- "历史文案"按钮用灰色 (bg-gray-600) 区分辅助功能
|
||||
- "文案提取助手"用紫色 (bg-purple-600) 表示主功能
|
||||
- "AI多语言"用绿渐变 (emerald-teal),"AI生成标题标签"用蓝渐变 (blue-cyan)
|
||||
- "AI智能改写"和"保存文案"移至文本框下方状态栏
|
||||
|
||||
---
|
||||
|
||||
### 4. 标题字幕面板重排 + 视频帧背景预览
|
||||
|
||||
#### 面板顺序重排
|
||||
|
||||
将 `<TitleSubtitlePanel>` 从第二步移至第四步(素材编辑之后),使用户在设置标题字幕样式时已经完成了素材选择和时间轴编排。
|
||||
|
||||
新顺序:
|
||||
```
|
||||
一、文案提取与编辑(不变)
|
||||
二、配音(原三)
|
||||
三、素材编辑(原四)
|
||||
四、标题与字幕(原二)→ 移到素材编辑之后
|
||||
```
|
||||
|
||||
#### 新建 useVideoFrameCapture hook
|
||||
|
||||
从视频 URL 截取 0.1s 处帧画面,返回 JPEG data URL:
|
||||
|
||||
- 创建 `<video>` 元素,设置 `crossOrigin="anonymous"`(素材存储在 Supabase Storage 跨域地址)
|
||||
- 先绑定 `loadedmetadata` / `canplay` / `seeked` / `error` 事件监听,再设 src(避免事件丢失)
|
||||
- `loadedmetadata` 或 `canplay` 触发后 seek 到 0.1s,`seeked` 回调中用 canvas `drawImage` 截帧
|
||||
- canvas 缩放到 480px 宽再编码(预览窗口最大 280px,节省内存)
|
||||
- `canvas.toDataURL("image/jpeg", 0.7)` 导出
|
||||
- 防御 `videoWidth/videoHeight` 为 0 的边界情况
|
||||
- try-catch 防 canvas taint,失败返回 null(降级渐变)
|
||||
- `isActive` 标志 + `seeked` 去重标志防止 stale 和重复更新
|
||||
- 截图完成后清理 video 元素释放内存
|
||||
|
||||
#### 按需截取(性能优化)
|
||||
|
||||
只在样式预览窗口打开时才触发截取:
|
||||
|
||||
```typescript
|
||||
const materialPosterUrl = useVideoFrameCapture(
|
||||
showStylePreview ? firstTimelineMaterialUrl : null
|
||||
);
|
||||
```
|
||||
|
||||
截取源优先使用**时间轴第一段素材**(用户拖拽排序后的真实片头),回退到 `selectedMaterials[0]`(未生成配音、时间轴为空时)。
|
||||
|
||||
#### 预览背景替换
|
||||
|
||||
`FloatingStylePreview` 有视频帧时直接显示原始画面(不加半透明,保证颜色真实),文字靠描边保证可读性;无视频帧时降级为原紫粉渐变背景。
|
||||
|
||||
#### 踩坑记录
|
||||
|
||||
1. **CORS tainted canvas**: 素材文件存储在 Supabase Storage (`api.hbyrkj.top`),是跨域签名链接。必须设 `video.crossOrigin = "anonymous"` 才能让 canvas `toDataURL` 不被 SecurityError 拦截
|
||||
2. **时间轴为空**: `useTimelineEditor` 在 `audioDuration <= 0`(未选配音)时返回空数组,需回退到 `selectedMaterials[0]`
|
||||
3. **事件监听顺序**: 必须先绑定事件监听再设 `video.src`,否则快速加载时事件可能丢失
|
||||
|
||||
---
|
||||
|
||||
## 📁 修改文件清单
|
||||
|
||||
| 文件 | 改动 |
|
||||
|------|------|
|
||||
| `models/CosyVoice/cosyvoice_server.py` | `AutoModel()` 新增 `fp16=True` 参数 |
|
||||
| `README.md` | 混合唇形同步描述、技术栈、服务架构、项目结构更新 |
|
||||
| `Docs/DEPLOY_MANUAL.md` | MuseTalk 部署步骤、环境变量、PM2 管理、端口检查 |
|
||||
| `Docs/SUBTITLE_DEPLOY.md` | 架构图、Remotion concurrency、GPU 分配、更新日志 |
|
||||
| `Docs/BACKEND_README.md` | 健康检查、环境变量、混合路由章节 |
|
||||
| `frontend/.../RewriteModal.tsx` | 两步式改写流程(自定义提示词 → 结果对比) |
|
||||
| `frontend/.../script-extraction/useScriptExtraction.ts` | **新建** — 文案提取逻辑 hook |
|
||||
| `frontend/.../ScriptExtractionModal.tsx` | 纯展示组件,消费 hook,新增快捷键 |
|
||||
| `frontend/.../ScriptEditor.tsx` | 工具栏右对齐 + 按钮分色 + 改写/保存移至底部 |
|
||||
| `frontend/.../useVideoFrameCapture.ts` | **新建** — 视频帧截取 hook,crossOrigin + canvas 缩放 |
|
||||
| `frontend/.../useHomeController.ts` | 新增 useMemo 计算素材 URL,调用帧截取 hook,showStylePreview 门控 |
|
||||
| `frontend/.../HomePage.tsx` | 面板重排(二↔四互换),编号更新,透传 materialPosterUrl |
|
||||
| `frontend/.../TitleSubtitlePanel.tsx` | 编号"二"→"四",新增 previewBackgroundUrl prop |
|
||||
| `frontend/.../FloatingStylePreview.tsx` | 新增 previewBackgroundUrl prop,条件渲染视频帧/渐变背景 |
|
||||
|
||||
---
|
||||
|
||||
## 🔍 验证
|
||||
|
||||
- CosyVoice 重启成功,健康检查 `{"ready": true}`
|
||||
- 自检推理通过(7.2s for "你好")
|
||||
- FP16 通过 `torch.cuda.amp.autocast(self.fp16)` 在 LLM 和 Flow Matching 阶段生效
|
||||
- `npx tsc --noEmit` — 零错误
|
||||
- AI 改写:自定义提示词持久化 → 改写结果 + 原文对比 → "使用此结果"/"保留原文"
|
||||
- 文案提取:URL / 文件双模式 → 处理中动画 → 结果填入
|
||||
- 面板顺序:一→文案、二→配音、三→素材编辑、四→标题与字幕
|
||||
- 样式预览背景:有素材时显示真实视频片头帧,无素材降级紫粉渐变
|
||||
- 预览关闭时不触发截取,不浪费资源
|
||||
|
||||
---
|
||||
|
||||
## 💡 CosyVoice 性能分析备注
|
||||
|
||||
### 当前性能基线 (FP32, 优化前)
|
||||
|
||||
| 文本长度 | 音频时长 | 推理耗时 | RTF |
|
||||
|----------|----------|----------|-----|
|
||||
| 42 字 | 9.8s | 13.2s | 1.35x |
|
||||
| 89 字 | 18.2s | 20.3s | 1.12x |
|
||||
| ~530 字 | 115.8s | 107.7s | 0.93x |
|
||||
| ~670 字 | 143.5s | 131.6s | 0.92x |
|
||||
|
||||
### 未来可选优化(收益递减,暂不实施)
|
||||
|
||||
| 优化项 | 预期提升 | 复杂度 |
|
||||
|--------|----------|--------|
|
||||
| TensorRT (DiT 模块) | +20-30% | 需编译 .plan 引擎 |
|
||||
| torch.compile() | +10-20% | 一行代码,但首次编译慢 |
|
||||
| vLLM (LLM 模块) | +10-15% | 额外依赖 |
|
||||
@@ -151,6 +151,33 @@ body {
|
||||
| `sm:` | ≥ 640px | 平板/桌面 |
|
||||
| `lg:` | ≥ 1024px | 大屏桌面 |
|
||||
|
||||
### embedded 组件模式
|
||||
|
||||
合并板块时,子组件通过 `embedded?: boolean` prop 控制是否渲染外层卡片容器和主标题。
|
||||
|
||||
```tsx
|
||||
// embedded=false(独立使用):渲染完整卡片
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10">
|
||||
<h2>标题</h2>
|
||||
{content}
|
||||
</div>
|
||||
|
||||
// embedded=true(嵌入父卡片):只渲染内容
|
||||
{content}
|
||||
```
|
||||
|
||||
- 子标题使用 `<h3 className="text-sm font-medium text-gray-400">`
|
||||
- 分隔线使用 `<div className="border-t border-white/10 my-4" />`
|
||||
- 移动端标题行避免 `whitespace-nowrap`,长描述文字可用 `hidden sm:inline` 在移动端隐藏
|
||||
|
||||
### 按钮视觉层级
|
||||
|
||||
| 层级 | 样式 | 用途 |
|
||||
|------|------|------|
|
||||
| 主操作 | `px-4 py-2 text-sm font-medium bg-gradient-to-r from-purple-600 to-pink-600 shadow-sm` | 生成配音、立即发布 |
|
||||
| 辅助操作 | `px-2 py-1 text-xs bg-white/10 rounded` | 刷新、上传、语速 |
|
||||
| 触屏可见 | `opacity-40 group-hover:opacity-100` | 列表行内操作(编辑/删除) |
|
||||
|
||||
---
|
||||
|
||||
## API 请求规范
|
||||
@@ -259,9 +286,35 @@ import { formatDate } from '@/shared/lib/media';
|
||||
|
||||
### 刷新回顶部(统一体验)
|
||||
|
||||
- 长页面(如首页/发布页)在首次挂载时统一回到顶部,避免浏览器恢复旧滚动位置导致进入即跳到中部。
|
||||
- 推荐实现:`useEffect(() => { window.scrollTo({ top: 0, left: 0, behavior: 'auto' }); }, [])`
|
||||
- 列表内自动定位(素材/历史记录)应跳过恢复后的首次触发,防止刷新后页面二次跳动。
|
||||
- 长页面(如首页/发布页)在首次挂载时统一回到顶部。
|
||||
- **必须**在页面级 `useEffect` 中设置 `history.scrollRestoration = "manual"` 禁用浏览器原生滚动恢复。
|
||||
- 调用 `window.scrollTo({ top: 0, left: 0, behavior: "auto" })` 并追加 200ms 延迟兜底(防止异步 effect 覆盖)。
|
||||
- **列表自动滚动必须使用时间门控**:页面加载后 1 秒内禁止所有列表自动滚动效果(`scrollEffectsEnabled` ref),防止持久化恢复 + 异步数据加载触发 `scrollIntoView` 导致页面跳动。
|
||||
- 推荐模式:
|
||||
|
||||
```typescript
|
||||
// 页面级(HomePage / PublishPage)
|
||||
useEffect(() => {
|
||||
if (typeof window === "undefined") return;
|
||||
if ("scrollRestoration" in history) history.scrollRestoration = "manual";
|
||||
window.scrollTo({ top: 0, left: 0, behavior: "auto" });
|
||||
const timer = setTimeout(() => window.scrollTo({ top: 0, left: 0, behavior: "auto" }), 200);
|
||||
return () => clearTimeout(timer);
|
||||
}, []);
|
||||
|
||||
// Controller 级(列表滚动时间门控)
|
||||
const scrollEffectsEnabled = useRef(false);
|
||||
useEffect(() => {
|
||||
const timer = setTimeout(() => { scrollEffectsEnabled.current = true; }, 1000);
|
||||
return () => clearTimeout(timer);
|
||||
}, []);
|
||||
|
||||
// 列表滚动 effect(BGM/素材/视频等)
|
||||
useEffect(() => {
|
||||
if (!selectedId || !scrollEffectsEnabled.current) return;
|
||||
target?.scrollIntoView({ block: "nearest", behavior: "smooth" });
|
||||
}, [selectedId, list]);
|
||||
```
|
||||
|
||||
### 路由预取
|
||||
|
||||
|
||||
@@ -5,14 +5,12 @@ ViGent2 的前端界面,采用 Next.js 16 + TailwindCSS 构建。
|
||||
## ✨ 核心功能
|
||||
|
||||
### 1. 视频生成 (`/`)
|
||||
- **素材管理**: 拖拽上传人物视频,实时预览。
|
||||
- **素材重命名**: 支持在列表中直接重命名素材。
|
||||
- **文案配音**: 集成 EdgeTTS,支持多音色选择 (云溪 / 晓晓)。
|
||||
- **AI 标题/标签**: 一键生成视频标题与标签 (Day 14)。
|
||||
- **标题/字幕样式**: 样式选择 + 预览 + 字号调节 (Day 16)。
|
||||
- **背景音乐**: 试听 + 音量控制 + 选择持久化 (Day 16)。
|
||||
- **交互优化**: 选择项持久化、列表内定位、刷新回顶部 (Day 16)。
|
||||
- **预览一致性**: 标题/字幕预览按素材分辨率缩放,效果更接近成片 (Day 17)。
|
||||
- **一、文案提取与编辑**: 文案输入/提取/翻译/保存。
|
||||
- **二、配音**: 配音方式(EdgeTTS/声音克隆)+ 配音列表(生成/试听/管理)合并为一个板块。
|
||||
- **三、素材编辑**: 视频素材(上传/选择/管理)+ 时间轴编辑(波形/色块/拖拽排序)合并为一个板块。
|
||||
- **四、标题与字幕**: 片头标题/副标题/字幕样式配置;短暂显示/常驻显示;样式预览使用视频片头帧作为真实背景 (Day 28)。
|
||||
- **五、背景音乐**: 试听 + 音量控制 + 选择持久化。
|
||||
- **六、作品**(右栏): 作品列表 + 作品预览合并为一个板块。
|
||||
- **进度追踪**: 实时显示视频生成进度 (10% -> 100%)。
|
||||
- **作品预览**: 生成完成后直接播放下载(作品预览 + 历史作品)。
|
||||
- **预览优化**: 预览视频 `metadata` 预取,首帧加载更快。
|
||||
@@ -52,8 +50,8 @@ ViGent2 的前端界面,采用 Next.js 16 + TailwindCSS 构建。
|
||||
- **画面比例控制**: 时间轴顶部支持 `9:16 / 16:9` 输出比例选择,设置持久化并透传后端。
|
||||
|
||||
### 5. 字幕与标题 [Day 13 新增]
|
||||
- **片头标题**: 可选输入,限制 15 字;支持”短暂显示 / 常驻显示”,默认短暂显示(4 秒)。
|
||||
- **片头副标题**: 可选输入,限制 20 字;显示在主标题下方,用于补充说明或悬念引导;独立样式配置(字体/字号/颜色/间距),可由 AI 同时生成;仅在视频画面中显示,不参与发布标题 (Day 25)。
|
||||
- **片头标题**: 可选输入,限制 15 字;支持”短暂显示 / 常驻显示”,默认短暂显示(4 秒),对标题和副标题同时生效。
|
||||
- **片头副标题**: 可选输入,限制 20 字;显示在主标题下方,用于补充说明或悬念引导;独立样式配置(字体/字号/颜色/间距),可由 AI 同时生成;与标题共享显示模式设定;仅在视频画面中显示,不参与发布标题 (Day 25)。
|
||||
- **标题同步**: 首页片头标题修改会同步到发布信息标题。
|
||||
- **逐字高亮字幕**: 卡拉OK效果,默认开启,可关闭。
|
||||
- **自动对齐**: 基于 faster-whisper 生成字级别时间戳。
|
||||
@@ -67,8 +65,9 @@ ViGent2 的前端界面,采用 Next.js 16 + TailwindCSS 构建。
|
||||
|
||||
### 7. 账户设置 [Day 15 新增]
|
||||
- **手机号登录**: 11位中国手机号验证登录。
|
||||
- **账户下拉菜单**: 显示有效期 + 修改密码 + 安全退出。
|
||||
- **账户下拉菜单**: 显示手机号(中间四位脱敏)+ 有效期 + 修改密码 + 安全退出。
|
||||
- **修改密码**: 弹窗输入当前密码与新密码,修改后强制重新登录。
|
||||
- **登录即时生效**: 登录成功后 AuthContext 立即写入用户数据,无需刷新即显示手机号。
|
||||
|
||||
### 8. 付费开通会员 (`/pay`)
|
||||
- **支付宝电脑网站支付**: 跳转支付宝官方收银台,支持扫码/账号登录/余额等多种支付方式。
|
||||
@@ -143,5 +142,8 @@ src/
|
||||
## 🎨 设计规范
|
||||
|
||||
- **主色调**: 深紫/黑色系 (Dark Mode)
|
||||
- **交互**: 悬停微动画 (Hover Effects)
|
||||
- **响应式**: 适配桌面端大屏操作
|
||||
- **交互**: 悬停微动画 (Hover Effects);操作按钮默认半透明可见 (opacity-40),hover 时全亮,兼顾触屏设备
|
||||
- **响应式**: 适配桌面端与移动端;发布页平台卡片响应式布局(移动端紧凑/桌面端宽松)
|
||||
- **滚动体验**: 列表滚动条统一隐藏 (hide-scrollbar);刷新后自动回到顶部(禁用浏览器滚动恢复 + 列表 scroll 时间门控)
|
||||
- **样式预览**: 浮动预览窗口,桌面端左上角 280px,移动端右下角 160px(不遮挡控件)
|
||||
- **输入辅助**: 标题/副标题输入框实时字数计数器,超限变红
|
||||
|
||||
252
Docs/MUSETALK_DEPLOY.md
Normal file
252
Docs/MUSETALK_DEPLOY.md
Normal file
@@ -0,0 +1,252 @@
|
||||
# MuseTalk 部署指南
|
||||
|
||||
> **更新时间**:2026-02-27
|
||||
> **适用版本**:MuseTalk v1.5 (常驻服务模式)
|
||||
> **架构**:FastAPI 常驻服务 + PM2 进程管理
|
||||
|
||||
---
|
||||
|
||||
## 架构概览
|
||||
|
||||
MuseTalk 作为 **混合唇形同步方案** 的长视频引擎:
|
||||
|
||||
- **短视频 (<120s)** → LatentSync 1.6 (GPU1, 端口 8007)
|
||||
- **长视频 (>=120s)** → MuseTalk 1.5 (GPU0, 端口 8011)
|
||||
- 路由阈值由 `LIPSYNC_DURATION_THRESHOLD` 控制
|
||||
- MuseTalk 不可用时自动回退到 LatentSync
|
||||
|
||||
---
|
||||
|
||||
## 硬件要求
|
||||
|
||||
| 配置 | 最低要求 | 推荐配置 |
|
||||
|------|----------|----------|
|
||||
| GPU | 8GB VRAM (RTX 3060) | 24GB VRAM (RTX 3090) |
|
||||
| 内存 | 32GB | 64GB |
|
||||
| CUDA | 11.7+ | 11.8 |
|
||||
|
||||
> MuseTalk fp16 推理约需 4-8GB 显存,可与 CosyVoice 共享 GPU0。
|
||||
|
||||
---
|
||||
|
||||
## 安装步骤
|
||||
|
||||
### 1. Conda 环境
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/MuseTalk
|
||||
conda create -n musetalk python=3.10 -y
|
||||
conda activate musetalk
|
||||
```
|
||||
|
||||
### 2. PyTorch 2.0.1 + CUDA 11.8
|
||||
|
||||
> 必须使用此版本,mmcv 预编译包依赖。
|
||||
|
||||
```bash
|
||||
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
|
||||
```
|
||||
|
||||
### 3. 依赖安装
|
||||
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
|
||||
# MMLab 系列
|
||||
pip install --no-cache-dir -U openmim
|
||||
mim install mmengine
|
||||
mim install "mmcv==2.0.1"
|
||||
mim install "mmdet==3.1.0"
|
||||
pip install chumpy --no-build-isolation
|
||||
pip install "mmpose==1.1.0" --no-deps
|
||||
|
||||
# FastAPI 服务依赖
|
||||
pip install fastapi uvicorn httpx
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 模型权重
|
||||
|
||||
### 目录结构
|
||||
|
||||
```
|
||||
models/MuseTalk/models/
|
||||
├── musetalk/ ← v1 基础模型
|
||||
│ ├── config.json -> musetalk.json (软链接)
|
||||
│ ├── musetalk.json
|
||||
│ ├── musetalkV15 -> ../musetalkV15 (软链接, 关键!)
|
||||
│ └── pytorch_model.bin (~3.2GB)
|
||||
├── musetalkV15/ ← v1.5 UNet 模型
|
||||
│ ├── musetalk.json
|
||||
│ └── unet.pth (~3.2GB)
|
||||
├── sd-vae/ ← Stable Diffusion VAE
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.bin
|
||||
├── whisper/ ← OpenAI Whisper Tiny
|
||||
│ ├── config.json
|
||||
│ ├── pytorch_model.bin (~151MB)
|
||||
│ └── preprocessor_config.json
|
||||
├── dwpose/ ← DWPose 人体姿态检测
|
||||
│ └── dw-ll_ucoco_384.pth (~387MB)
|
||||
├── syncnet/ ← SyncNet 唇形同步评估
|
||||
│ └── latentsync_syncnet.pt
|
||||
└── face-parse-bisent/ ← 人脸解析模型
|
||||
├── 79999_iter.pth (~53MB)
|
||||
└── resnet18-5c106cde.pth (~45MB)
|
||||
```
|
||||
|
||||
### 下载方式
|
||||
|
||||
使用项目自带脚本:
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/MuseTalk
|
||||
conda activate musetalk
|
||||
bash download_weights.sh
|
||||
```
|
||||
|
||||
或手动 Python API 下载:
|
||||
|
||||
```bash
|
||||
conda activate musetalk
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
python -c "
|
||||
from huggingface_hub import snapshot_download
|
||||
snapshot_download('TMElyralab/MuseTalk', local_dir='models',
|
||||
allow_patterns=['musetalk/*', 'musetalkV15/*'])
|
||||
snapshot_download('stabilityai/sd-vae-ft-mse', local_dir='models/sd-vae',
|
||||
allow_patterns=['config.json', 'diffusion_pytorch_model.bin'])
|
||||
snapshot_download('openai/whisper-tiny', local_dir='models/whisper',
|
||||
allow_patterns=['config.json', 'pytorch_model.bin', 'preprocessor_config.json'])
|
||||
snapshot_download('yzd-v/DWPose', local_dir='models/dwpose',
|
||||
allow_patterns=['dw-ll_ucoco_384.pth'])
|
||||
"
|
||||
```
|
||||
|
||||
### 创建必要的软链接
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/MuseTalk/models/musetalk
|
||||
ln -sf musetalk.json config.json
|
||||
ln -sf ../musetalkV15 musetalkV15
|
||||
```
|
||||
|
||||
> **关键**:`musetalk/musetalkV15` 软链接缺失会导致权重检测失败 (`weights: False`)。
|
||||
|
||||
---
|
||||
|
||||
## 服务启动
|
||||
|
||||
### PM2 进程管理(推荐)
|
||||
|
||||
```bash
|
||||
# 首次注册
|
||||
cd /home/rongye/ProgramFiles/ViGent2
|
||||
pm2 start run_musetalk.sh --name vigent2-musetalk
|
||||
pm2 save
|
||||
|
||||
# 日常管理
|
||||
pm2 restart vigent2-musetalk
|
||||
pm2 logs vigent2-musetalk
|
||||
pm2 stop vigent2-musetalk
|
||||
```
|
||||
|
||||
### 手动启动
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/MuseTalk
|
||||
/home/rongye/ProgramFiles/miniconda3/envs/musetalk/bin/python scripts/server.py
|
||||
```
|
||||
|
||||
### 健康检查
|
||||
|
||||
```bash
|
||||
curl http://localhost:8011/health
|
||||
# {"status":"ok","model_loaded":true}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 后端配置
|
||||
|
||||
`backend/.env` 中的相关变量:
|
||||
|
||||
```ini
|
||||
# MuseTalk 配置
|
||||
MUSETALK_GPU_ID=0 # GPU 编号 (与 CosyVoice 共存)
|
||||
MUSETALK_API_URL=http://localhost:8011 # 常驻服务地址
|
||||
MUSETALK_BATCH_SIZE=32 # 推理批大小
|
||||
MUSETALK_VERSION=v15 # 模型版本
|
||||
MUSETALK_USE_FLOAT16=true # 半精度加速
|
||||
|
||||
# 混合唇形同步路由
|
||||
LIPSYNC_DURATION_THRESHOLD=120 # 秒, >=此值用 MuseTalk
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 相关文件
|
||||
|
||||
| 文件 | 说明 |
|
||||
|------|------|
|
||||
| `models/MuseTalk/scripts/server.py` | FastAPI 常驻服务 (端口 8011) |
|
||||
| `run_musetalk.sh` | PM2 启动脚本 |
|
||||
| `backend/app/services/lipsync_service.py` | 混合路由 + `_call_musetalk_server()` |
|
||||
| `backend/app/core/config.py` | `MUSETALK_*` 配置项 |
|
||||
|
||||
---
|
||||
|
||||
## 性能优化 (server.py v2)
|
||||
|
||||
首次长视频测试 (136s, 3404 帧) 耗时 30 分钟。分析发现瓶颈在人脸检测 (28%)、BiSeNet 合成 (22%)、I/O (17%),而非 UNet 推理 (17%)。
|
||||
|
||||
### 已实施优化
|
||||
|
||||
| 优化项 | 说明 |
|
||||
|--------|------|
|
||||
| `MUSETALK_BATCH_SIZE` 8→32 | RTX 3090 显存充裕,UNet 推理加速 ~3x |
|
||||
| cv2.VideoCapture 直读帧 | 跳过 ffmpeg→PNG→imread 链路 |
|
||||
| 人脸检测降频 (每5帧) | DWPose + FaceAlignment 只在采样帧运行,中间帧线性插值 bbox |
|
||||
| BiSeNet mask 缓存 (每5帧) | `get_image_prepare_material` 每 5 帧运行,中间帧用 `get_image_blending` 复用 |
|
||||
| cv2.VideoWriter 直写 | 跳过逐帧 PNG 写盘 + ffmpeg 重编码 |
|
||||
| 每阶段计时 | 7 个阶段精确计时,方便后续调优 |
|
||||
|
||||
### 调优参数
|
||||
|
||||
`models/MuseTalk/scripts/server.py` 顶部可调:
|
||||
|
||||
```python
|
||||
DETECT_EVERY = 5 # 人脸检测降频间隔 (帧)
|
||||
BLEND_CACHE_EVERY = 5 # BiSeNet mask 缓存间隔 (帧)
|
||||
```
|
||||
|
||||
> 对于口播视频 (人脸几乎不动),5 帧间隔的插值误差可忽略。
|
||||
> 如人脸运动剧烈的场景,可降低为 2-3。
|
||||
|
||||
---
|
||||
|
||||
## 常见问题
|
||||
|
||||
### huggingface-hub 版本冲突
|
||||
|
||||
```
|
||||
ImportError: huggingface-hub>=0.19.3,<1.0 is required
|
||||
```
|
||||
|
||||
**解决**:降级 huggingface-hub
|
||||
|
||||
```bash
|
||||
pip install "huggingface-hub>=0.19.3,<1.0"
|
||||
```
|
||||
|
||||
### mmcv 导入失败
|
||||
|
||||
```bash
|
||||
pip uninstall mmcv mmcv-full -y
|
||||
mim install "mmcv==2.0.1"
|
||||
```
|
||||
|
||||
### 音视频长度不匹配
|
||||
|
||||
已在 `musetalk/utils/audio_processor.py` 中修复(零填充逻辑),无需额外处理。
|
||||
@@ -16,14 +16,16 @@
|
||||
文本 → EdgeTTS → 音频 → LatentSync → FFmpeg合成 → 最终视频
|
||||
|
||||
新流程 (单素材):
|
||||
文本 → EdgeTTS/Qwen3-TTS/预生成配音 → 音频 ─┬→ LatentSync → 唇形视频 ─┐
|
||||
文本 → EdgeTTS/CosyVoice/预生成配音 → 音频 ─┬→ LatentSync/MuseTalk → 唇形视频 ─┐
|
||||
└→ faster-whisper → 字幕JSON ─┴→ Remotion合成 → 最终视频
|
||||
|
||||
新流程 (多素材):
|
||||
音频 → 多素材按 custom_assignments 拼接 → LatentSync (单次推理) → 唇形视频 ─┐
|
||||
音频 → 多素材按 custom_assignments 拼接 → LatentSync/MuseTalk (单次推理) → 唇形视频 ─┐
|
||||
音频 → faster-whisper → 字幕JSON ─────────────────────────────────────────────┴→ Remotion合成 → 最终视频
|
||||
```
|
||||
|
||||
> **唇形同步路由**: 短视频 (<120s) 用 LatentSync 1.6 (GPU1),长视频 (>=120s) 用 MuseTalk 1.5 (GPU0),由 `LIPSYNC_DURATION_THRESHOLD` 控制。
|
||||
|
||||
## 系统要求
|
||||
|
||||
| 组件 | 要求 |
|
||||
@@ -185,6 +187,7 @@ Remotion 渲染参数在 `backend/app/services/remotion_service.py` 中配置:
|
||||
| 参数 | 默认值 | 说明 |
|
||||
|------|--------|------|
|
||||
| `fps` | 25 | 输出帧率 |
|
||||
| `concurrency` | 16 | Remotion 并发渲染进程数(默认 16,可通过 `--concurrency` CLI 参数覆盖) |
|
||||
| `title_display_mode` | `short` | 标题显示模式(`short`=短暂显示;`persistent`=常驻显示) |
|
||||
| `title_duration` | 4.0 | 标题显示时长(秒,仅 `short` 模式生效) |
|
||||
|
||||
@@ -273,7 +276,7 @@ wget https://github.com/googlefonts/noto-cjk/raw/main/Sans/OTF/SimplifiedChinese
|
||||
|
||||
### 使用 GPU 0
|
||||
|
||||
faster-whisper 默认使用 GPU 0,与 LatentSync (GPU 1) 分开,避免显存冲突。如需指定 GPU:
|
||||
faster-whisper 默认使用 GPU 0,与 MuseTalk 共享 GPU 0;LatentSync 使用 GPU 1,互不冲突。如需指定 GPU:
|
||||
|
||||
```python
|
||||
# 在 whisper_service.py 中修改
|
||||
@@ -289,3 +292,5 @@ WhisperService(device="cuda:0") # 或 "cuda:1"
|
||||
| 2026-01-29 | 1.0.0 | 初始版本,使用 faster-whisper + Remotion 实现逐字高亮字幕和片头标题 |
|
||||
| 2026-02-10 | 1.1.0 | 更新架构图:多素材 concat-then-infer、预生成配音选项 |
|
||||
| 2026-01-30 | 1.0.1 | 字幕高亮样式与标题动画优化,视觉表现更清晰 |
|
||||
| 2026-02-25 | 1.2.0 | 字幕时间戳从线性插值改为 Whisper 节奏映射,修复长视频字幕漂移 |
|
||||
| 2026-02-27 | 1.3.0 | 架构图更新 MuseTalk 混合路由;Remotion 并发渲染从 8 提升到 16;GPU 分配说明更新 |
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# ViGent2 开发任务清单 (Task Log)
|
||||
|
||||
**项目**: ViGent2 数字人口播视频生成系统
|
||||
**进度**: 100% (Day 25 - 支付宝付费开通会员)
|
||||
**更新时间**: 2026-02-24
|
||||
**进度**: 100% (Day 28 - CosyVoice FP16 加速 + 文档全面更新)
|
||||
**更新时间**: 2026-02-27
|
||||
|
||||
---
|
||||
|
||||
@@ -10,7 +10,45 @@
|
||||
|
||||
> 这里记录了每一天的核心开发内容与 milestone。
|
||||
|
||||
### Day 25: 文案提取修复 + 自定义提示词 + 片头副标题 (Current)
|
||||
### Day 28: CosyVoice FP16 加速 + 文档全面更新 (Current)
|
||||
- [x] **CosyVoice FP16 半精度加速**: `AutoModel()` 开启 `fp16=True`,LLM 推理和 Flow Matching 自动混合精度运行,预估提速 30-40%、显存降低 ~30%。
|
||||
- [x] **文档全面更新**: README.md / DEPLOY_MANUAL.md / SUBTITLE_DEPLOY.md / BACKEND_README.md 补充 MuseTalk 混合唇形同步方案、性能优化、Remotion 并发渲染等内容。
|
||||
|
||||
### Day 27: Remotion 描边修复 + 字体样式扩展 + 混合唇形同步 + 性能优化
|
||||
- [x] **描边渲染修复**: 标题/副标题/字幕从 `textShadow` 4 方向模拟改为 CSS 原生 `-webkit-text-stroke` + `paint-order: stroke fill`,修复描边过粗和副标题重影问题。
|
||||
- [x] **字体样式扩展**: 标题样式 4→12 个(+庞门正道/优设标题圆/阿里数黑体/文道潮黑/无界黑/厚底黑/寒蝉半圆体/欣意吉祥宋),字幕样式 4→8 个(+少女粉/清新绿/金色隶书/楷体红字)。
|
||||
- [x] **描边参数优化**: 所有预设 `stroke_size` 从 8 降至 4~5,配合原生描边视觉更干净。
|
||||
- [x] **TypeScript 类型修复**: Root.tsx `Composition` 泛型与 `calculateMetadata` 参数类型对齐;Video.tsx `VideoProps` 添加索引签名兼容 `Record<string, unknown>`;VideoLayer.tsx 移除 `OffthreadVideo` 不支持的 `loop` prop。
|
||||
- [x] **进度条文案还原**: 进度条从显示后端推送消息改回固定 `正在AI生成中...`。
|
||||
- [x] **MuseTalk 混合唇形同步**: 部署 MuseTalk 1.5 常驻服务 (GPU0, 端口 8011),按音频时长自动路由 — 短视频 (<120s) 走 LatentSync,长视频 (>=120s) 走 MuseTalk,MuseTalk 不可用时自动回退。
|
||||
- [x] **MuseTalk 推理性能优化**: server.py v2 重写 — cv2 直读帧(跳过 ffmpeg→PNG)、人脸检测降频(每5帧)、BiSeNet mask 缓存(每5帧)、cv2.VideoWriter 直写(跳过 PNG 写盘)、batch_size 8→32,预估 30min→8-10min (~3x)。
|
||||
- [x] **Remotion 并发渲染优化**: render.ts 新增 concurrency 参数,从默认 8 提升到 16 (56核 CPU),预估 5min→2-3min。
|
||||
|
||||
### Day 26: 前端优化:板块合并 + 序号标题 + UI 精细化
|
||||
- [x] **板块合并**: 首页 9 个独立板块合并为 5 个主板块(配音方式+配音列表→三、配音;视频素材+时间轴→四、素材编辑;历史作品+作品预览→六、作品)。
|
||||
- [x] **中文序号标题**: 一~十编号(首页一~六,发布页七~十),移除所有 emoji 图标。
|
||||
- [x] **embedded 模式**: 6 个组件支持 `embedded` prop,嵌入时不渲染外层卡片/标题。
|
||||
- [x] **配音列表两行布局**: embedded 模式第 1 行语速+生成配音(右对齐),第 2 行配音列表+刷新。
|
||||
- [x] **子组件自渲染子标题**: MaterialSelector/TimelineEditor embedded 时自渲染 h3 子标题+操作按钮同行。
|
||||
- [x] **下拉对齐**: TitleSubtitlePanel 标签统一 `w-20`,下拉 `w-1/3 min-w-[100px]`,垂直对齐。
|
||||
- [x] **参考音频文案简化**: 底部段落移至标题旁,简化为 `(上传3-10秒语音样本)`。
|
||||
- [x] **账户手机号显示**: AccountSettingsDropdown 新增手机号显示。
|
||||
- [x] **标题显示模式对副标题生效**: payload 条件修复 + UI 下拉上移至板块标题行。
|
||||
- [x] **登录后用户信息立即可用**: AuthContext 暴露 `setUser`,登录成功后立即写入用户数据,修复登录后显示"未知账户"的问题。
|
||||
- [x] **文案微调**: 素材描述改为"上传自拍视频,最多可选4个";显示模式选项加"标题"前缀。
|
||||
- [x] **UI/UX 体验优化**: 操作按钮移动端可见(opacity-40)、手机号脱敏、标题字数计数器、时间轴拖拽抓手图标、截取滑块放大。
|
||||
- [x] **代码质量修复**: 密码弹窗 success 清空、MaterialSelector useMemo + disabled 守卫、TimelineEditor useMemo。
|
||||
- [x] **发布页响应式布局**: 平台账号卡片单行布局,移动端紧凑(小图标/小按钮),桌面端宽松(与其他板块风格一致)。
|
||||
- [x] **移动端刷新回顶部**: `scrollRestoration = "manual"` + 列表 scroll 时间门控(`scrollEffectsEnabled` ref,1 秒内禁止自动滚动)+ 延迟兜底 `scrollTo(0,0)`。
|
||||
- [x] **移动端样式预览缩小**: FloatingStylePreview 移动端宽度缩至 160px,位置改为右下角,不遮挡样式调节控件。
|
||||
- [x] **列表滚动条统一隐藏**: 所有列表(BGM/配音/作品/素材/文案提取)滚动条改回 `hide-scrollbar`。
|
||||
- [x] **移动端配音/素材适配**: VoiceSelector 按钮移动端缩小(`px-2 sm:px-4`)修复克隆声音不可见;MaterialSelector 标题行移除 `whitespace-nowrap`,描述移动端隐藏,修复刷新按钮溢出。
|
||||
- [x] **生成配音按钮放大**: 从辅助尺寸(`text-xs px-2 py-1`)升级为主操作尺寸(`text-sm font-medium px-4 py-2`),新增阴影。
|
||||
- [x] **生成进度条位置调整**: 从"六、作品"卡片内部提取到右栏独立卡片,显示在作品卡片上方,更醒目。
|
||||
- [x] **LatentSync 超时修复**: httpx 超时从 1200s(20 分钟)改为 3600s(1 小时),修复 2 分钟以上视频口型推理超时回退问题。
|
||||
- [x] **字幕时间戳节奏映射**: `whisper_service.py` 从全程线性插值改为 Whisper 逐词节奏映射,修复长视频字幕漂移。
|
||||
|
||||
### Day 25: 文案提取修复 + 自定义提示词 + 片头副标题
|
||||
- [x] **抖音文案提取修复**: yt-dlp Fresh cookies 报错,重写 `_download_douyin_manual` 为移动端分享页 + 自动获取 ttwid 方案。
|
||||
- [x] **清理 DOUYIN_COOKIE**: 新方案不再需要手动维护 Cookie,从 `.env`/`config.py`/`service.py` 全面删除。
|
||||
- [x] **AI 智能改写自定义提示词**: 后端 `rewrite_script()` 支持 `custom_prompt` 参数;前端 checkbox 旁新增折叠式提示词编辑区,localStorage 持久化。
|
||||
|
||||
29
README.md
29
README.md
@@ -4,7 +4,7 @@
|
||||
|
||||
> 📹 **上传人物** · 🎙️ **输入文案** · 🎬 **一键成片**
|
||||
|
||||
基于 **LatentSync 1.6 + EdgeTTS** 的开源数字人口播视频生成系统。
|
||||
基于 **LatentSync 1.6 + MuseTalk 1.5 混合唇形同步** 的开源数字人口播视频生成系统。
|
||||
集成 **CosyVoice 3.0** 声音克隆与自动社交媒体发布功能。
|
||||
|
||||
[功能特性](#-功能特性) • [技术栈](#-技术栈) • [文档中心](#-文档中心) • [部署指南](Docs/DEPLOY_MANUAL.md)
|
||||
@@ -16,10 +16,10 @@
|
||||
## ✨ 功能特性
|
||||
|
||||
### 核心能力
|
||||
- 🎬 **高清唇形同步** - LatentSync 1.6 驱动,512×512 高分辨率 Latent Diffusion 模型。
|
||||
- 🎬 **高清唇形同步** - 混合方案:短视频 (<120s) 用 LatentSync 1.6 (高质量 Latent Diffusion),长视频 (>=120s) 用 MuseTalk 1.5 (实时级单步推理),自动路由 + 回退。
|
||||
- 🎙️ **多模态配音** - 支持 **EdgeTTS** (微软超自然语音, 10 语言) 和 **CosyVoice 3.0** (3秒极速声音克隆, 9语言+18方言, 语速可调)。上传参考音频自动 Whisper 转写 + 智能截取。配音前置工作流:先生成配音 → 选素材 → 生成视频。
|
||||
- 📝 **智能字幕** - 集成 faster-whisper + Remotion,自动生成逐字高亮 (卡拉OK效果) 字幕。
|
||||
- 🎨 **样式预设** - 标题/副标题/字幕样式选择 + 预览 + 字号调节,支持自定义字体库。
|
||||
- 🎨 **样式预设** - 12 种标题 + 8 种字幕样式预设,支持预览 + 字号调节 + 自定义字体库。CSS 原生描边渲染,清晰无重影。
|
||||
- 🏷️ **标题显示模式** - 片头标题支持 `短暂显示` / `常驻显示`,默认短暂显示(4秒),用户偏好自动持久化。
|
||||
- 📌 **片头副标题** - 可选副标题显示在主标题下方,独立样式配置,AI 可同时生成,20 字限制。
|
||||
- 🖼️ **作品预览一致性** - 标题/字幕预览与 Remotion 成片统一响应式缩放和自动换行,窄屏画布也稳定显示。
|
||||
@@ -37,7 +37,7 @@
|
||||
- 💳 **付费会员** - 支付宝电脑网站支付自动开通会员,到期自动停用并引导续费,管理员手动激活并存。
|
||||
- 🔐 **认证与隔离** - 基于 Supabase 的用户隔离,支持手机号注册/登录、密码管理。
|
||||
- 🛡️ **服务守护** - 内置 Watchdog 看门狗机制,自动监控并重启僵死服务,确保 7x24h 稳定运行。
|
||||
- 🚀 **性能优化** - 视频预压缩、模型常驻服务(近实时加载)、双 GPU 流水线并发。
|
||||
- 🚀 **性能优化** - 视频预压缩、模型常驻服务(近实时加载)、双 GPU 流水线并发、MuseTalk 人脸检测降频 + BiSeNet 缓存、Remotion 16 并发渲染。
|
||||
|
||||
---
|
||||
|
||||
@@ -46,9 +46,9 @@
|
||||
| 领域 | 核心技术 | 说明 |
|
||||
|------|----------|------|
|
||||
| **前端** | Next.js 16 | TypeScript, TailwindCSS, SWR, wavesurfer.js |
|
||||
| **后端** | FastAPI | Python 3.10, AsyncIO, PM2 |
|
||||
| **后端** | FastAPI | Python 3.12, AsyncIO, PM2 |
|
||||
| **数据库** | Supabase | PostgreSQL, Storage (本地/S3), Auth |
|
||||
| **唇形同步** | LatentSync 1.6 | PyTorch 2.5, Diffusers, DeepCache |
|
||||
| **唇形同步** | LatentSync 1.6 + MuseTalk 1.5 | 混合路由:短视频 Diffusion 高质量,长视频单步实时推理 |
|
||||
| **声音克隆** | CosyVoice 3.0 | 0.5B 参数量,9 语言 + 18 方言 |
|
||||
| **自动化** | Playwright | 社交媒体无头浏览器自动化 |
|
||||
| **部署** | Docker & PM2 | 混合部署架构 |
|
||||
@@ -62,14 +62,17 @@
|
||||
### 部署运维
|
||||
- **[部署手册 (DEPLOY_MANUAL.md)](Docs/DEPLOY_MANUAL.md)** - 👈 **部署请看这里**!包含完整的环境搭建步骤。
|
||||
- [参考音频服务部署 (COSYVOICE3_DEPLOY.md)](Docs/COSYVOICE3_DEPLOY.md) - 声音克隆模型部署指南。
|
||||
- [LatentSync 部署指南](models/LatentSync/DEPLOY.md) - 唇形同步模型独立部署。
|
||||
- [LatentSync 部署指南 (LATENTSYNC_DEPLOY.md)](Docs/LATENTSYNC_DEPLOY.md) - 唇形同步模型独立部署。
|
||||
- [MuseTalk 部署指南 (MUSETALK_DEPLOY.md)](Docs/MUSETALK_DEPLOY.md) - 长视频唇形同步模型部署。
|
||||
- [Supabase 部署指南 (SUPABASE_DEPLOY.md)](Docs/SUPABASE_DEPLOY.md) - Supabase 与认证系统配置。
|
||||
- [支付宝部署指南 (ALIPAY_DEPLOY.md)](Docs/ALIPAY_DEPLOY.md) - 支付宝付费开通会员配置。
|
||||
|
||||
### 开发文档
|
||||
- [后端开发指南](Docs/BACKEND_README.md) - 接口规范与开发流程。
|
||||
- [后端开发规范](Docs/BACKEND_DEV.md) - 分层约定与开发习惯。
|
||||
- [前端开发指南](Docs/FRONTEND_DEV.md) - UI 组件与页面规范。
|
||||
- [后端开发指南 (BACKEND_README.md)](Docs/BACKEND_README.md) - 接口规范与开发流程。
|
||||
- [后端开发规范 (BACKEND_DEV.md)](Docs/BACKEND_DEV.md) - 分层约定与开发习惯。
|
||||
- [前端开发指南 (FRONTEND_DEV.md)](Docs/FRONTEND_DEV.md) - UI 组件与页面规范。
|
||||
- [前端组件文档 (FRONTEND_README.md)](Docs/FRONTEND_README.md) - 组件结构与板块说明。
|
||||
- [Remotion 字幕部署 (SUBTITLE_DEPLOY.md)](Docs/SUBTITLE_DEPLOY.md) - 字幕渲染服务部署。
|
||||
- [开发日志 (DevLogs)](Docs/DevLogs/) - 每日开发进度与技术决策记录。
|
||||
|
||||
---
|
||||
@@ -86,7 +89,8 @@ ViGent2/
|
||||
├── frontend/ # Next.js 前端应用
|
||||
├── remotion/ # Remotion 视频渲染 (标题/字幕合成)
|
||||
├── models/ # AI 模型仓库
|
||||
│ ├── LatentSync/ # 唇形同步服务
|
||||
│ ├── LatentSync/ # 唇形同步服务 (GPU1, 短视频)
|
||||
│ ├── MuseTalk/ # 唇形同步服务 (GPU0, 长视频)
|
||||
│ └── CosyVoice/ # 声音克隆服务
|
||||
└── Docs/ # 项目文档
|
||||
```
|
||||
@@ -101,7 +105,8 @@ ViGent2/
|
||||
|----------|------|------|
|
||||
| **Web UI** | 3002 | 用户访问入口 (Next.js) |
|
||||
| **Backend API** | 8006 | 核心业务接口 (FastAPI) |
|
||||
| **LatentSync** | 8007 | 唇形同步推理服务 |
|
||||
| **LatentSync** | 8007 | 唇形同步推理服务 (GPU1, 短视频) |
|
||||
| **MuseTalk** | 8011 | 唇形同步推理服务 (GPU0, 长视频) |
|
||||
| **CosyVoice 3.0** | 8010 | 声音克隆推理服务 |
|
||||
| **Supabase** | 8008 | 数据库与认证网关 |
|
||||
|
||||
|
||||
@@ -25,10 +25,10 @@ LATENTSYNC_USE_SERVER=true
|
||||
# LATENTSYNC_API_URL=http://localhost:8007
|
||||
|
||||
# 推理步数 (20-50, 越高质量越好,速度越慢)
|
||||
LATENTSYNC_INFERENCE_STEPS=40
|
||||
LATENTSYNC_INFERENCE_STEPS=16
|
||||
|
||||
# 引导系数 (1.0-3.0, 越高唇同步越准,但可能抖动)
|
||||
LATENTSYNC_GUIDANCE_SCALE=2.0
|
||||
LATENTSYNC_GUIDANCE_SCALE=1.5
|
||||
|
||||
# 启用 DeepCache 加速 (推荐开启)
|
||||
LATENTSYNC_ENABLE_DEEPCACHE=true
|
||||
@@ -36,6 +36,26 @@ LATENTSYNC_ENABLE_DEEPCACHE=true
|
||||
# 随机种子 (设为 -1 则随机)
|
||||
LATENTSYNC_SEED=1247
|
||||
|
||||
# =============== MuseTalk 配置 ===============
|
||||
# GPU 选择 (默认 GPU0,与 CosyVoice 共存)
|
||||
MUSETALK_GPU_ID=0
|
||||
|
||||
# 常驻服务地址 (端口 8011)
|
||||
MUSETALK_API_URL=http://localhost:8011
|
||||
|
||||
# 推理批大小
|
||||
MUSETALK_BATCH_SIZE=32
|
||||
|
||||
# 模型版本
|
||||
MUSETALK_VERSION=v15
|
||||
|
||||
# 半精度加速
|
||||
MUSETALK_USE_FLOAT16=true
|
||||
|
||||
# =============== 混合唇形同步路由 ===============
|
||||
# 音频时长 >= 此阈值(秒)用 MuseTalk,< 此阈值用 LatentSync
|
||||
LIPSYNC_DURATION_THRESHOLD=120
|
||||
|
||||
# =============== 上传配置 ===============
|
||||
# 最大上传文件大小 (MB)
|
||||
MAX_UPLOAD_SIZE_MB=500
|
||||
@@ -70,8 +90,6 @@ GLM_MODEL=glm-4.7-flash
|
||||
# 确保存储卷映射正确,避免硬编码路径
|
||||
SUPABASE_STORAGE_LOCAL_PATH=/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub
|
||||
|
||||
# =============== 抖音视频下载 Cookie ===============
|
||||
|
||||
# =============== 支付宝配置 ===============
|
||||
ALIPAY_APP_ID=2021006132600283
|
||||
ALIPAY_PRIVATE_KEY_PATH=/home/rongye/ProgramFiles/ViGent2/backend/keys/app_private_key.pem
|
||||
|
||||
@@ -57,7 +57,17 @@ class Settings(BaseSettings):
|
||||
LATENTSYNC_ENABLE_DEEPCACHE: bool = True # 启用 DeepCache 加速
|
||||
LATENTSYNC_SEED: int = 1247 # 随机种子 (-1 则随机)
|
||||
LATENTSYNC_USE_SERVER: bool = True # 使用常驻服务 (Persistent Server) 加速
|
||||
|
||||
|
||||
# MuseTalk 配置
|
||||
MUSETALK_GPU_ID: int = 0 # GPU ID (默认使用 GPU0)
|
||||
MUSETALK_API_URL: str = "http://localhost:8011" # 常驻服务地址
|
||||
MUSETALK_BATCH_SIZE: int = 8 # 推理批大小
|
||||
MUSETALK_VERSION: str = "v15" # 模型版本
|
||||
MUSETALK_USE_FLOAT16: bool = True # 半精度加速
|
||||
|
||||
# 混合唇形同步路由
|
||||
LIPSYNC_DURATION_THRESHOLD: float = 120.0 # 秒,>=此值用 MuseTalk
|
||||
|
||||
# Supabase 配置
|
||||
SUPABASE_URL: str = ""
|
||||
SUPABASE_PUBLIC_URL: str = "" # 公网访问地址,用于生成前端可访问的 URL
|
||||
@@ -93,6 +103,11 @@ class Settings(BaseSettings):
|
||||
"""LatentSync 目录路径 (动态计算)"""
|
||||
return self.BASE_DIR.parent.parent / "models" / "LatentSync"
|
||||
|
||||
@property
|
||||
def MUSETALK_DIR(self) -> Path:
|
||||
"""MuseTalk 目录路径 (动态计算)"""
|
||||
return self.BASE_DIR.parent.parent / "models" / "MuseTalk"
|
||||
|
||||
class Config:
|
||||
env_file = ".env"
|
||||
extra = "ignore" # 忽略未知的环境变量
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
AI 相关 API 路由
|
||||
"""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
from pydantic import BaseModel
|
||||
from loguru import logger
|
||||
@@ -25,6 +27,12 @@ class GenerateMetaResponse(BaseModel):
|
||||
tags: list[str]
|
||||
|
||||
|
||||
class RewriteRequest(BaseModel):
|
||||
"""改写请求"""
|
||||
text: str
|
||||
custom_prompt: Optional[str] = None
|
||||
|
||||
|
||||
class TranslateRequest(BaseModel):
|
||||
"""翻译请求"""
|
||||
text: str
|
||||
@@ -73,3 +81,18 @@ async def generate_meta(req: GenerateMetaRequest):
|
||||
except Exception as e:
|
||||
logger.error(f"Generate meta failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/rewrite")
|
||||
async def rewrite_script(req: RewriteRequest):
|
||||
"""AI 改写文案"""
|
||||
if not req.text or not req.text.strip():
|
||||
raise HTTPException(status_code=400, detail="文案不能为空")
|
||||
|
||||
try:
|
||||
logger.info(f"Rewriting text: {req.text[:50]}...")
|
||||
rewritten = await glm_service.rewrite_script(req.text.strip(), req.custom_prompt)
|
||||
return success_response({"rewritten_text": rewritten})
|
||||
except Exception as e:
|
||||
logger.error(f"Rewrite failed: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
@@ -63,11 +63,15 @@ async def extract_script(file=None, url: Optional[str] = None, rewrite: bool = T
|
||||
# 2. 提取文案 (Whisper)
|
||||
script = await whisper_service.transcribe(str(audio_path))
|
||||
|
||||
# 3. AI 改写 (GLM)
|
||||
# 3. AI 改写 (GLM) — 失败时降级返回原文
|
||||
rewritten = None
|
||||
if rewrite and script and len(script.strip()) > 0:
|
||||
logger.info("Rewriting script...")
|
||||
rewritten = await glm_service.rewrite_script(script, custom_prompt)
|
||||
try:
|
||||
rewritten = await glm_service.rewrite_script(script, custom_prompt)
|
||||
except Exception as e:
|
||||
logger.warning(f"GLM rewrite failed, returning original script: {e}")
|
||||
rewritten = None
|
||||
|
||||
return {
|
||||
"original_script": script,
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Optional, Any, List
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import time
|
||||
import traceback
|
||||
import httpx
|
||||
@@ -415,18 +416,21 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
|
||||
lipsync_start = time.time()
|
||||
|
||||
# ── 第一步:下载所有素材并检测分辨率 ──
|
||||
# ── 第一步:并行下载所有素材并检测分辨率 ──
|
||||
material_locals: List[Path] = []
|
||||
resolutions = []
|
||||
|
||||
for i, assignment in enumerate(assignments):
|
||||
async def _download_and_normalize(i: int, assignment: dict):
|
||||
"""下载单个素材并归一化方向"""
|
||||
material_local = temp_dir / f"{task_id}_material_{i}.mp4"
|
||||
temp_files.append(material_local)
|
||||
await _download_material(assignment["material_path"], material_local)
|
||||
|
||||
# 归一化旋转元数据,确保分辨率判断与后续推理一致
|
||||
normalized_material = temp_dir / f"{task_id}_material_{i}_norm.mp4"
|
||||
normalized_result = video.normalize_orientation(
|
||||
loop = asyncio.get_event_loop()
|
||||
normalized_result = await loop.run_in_executor(
|
||||
None,
|
||||
video.normalize_orientation,
|
||||
str(material_local),
|
||||
str(normalized_material),
|
||||
)
|
||||
@@ -434,8 +438,17 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
temp_files.append(normalized_material)
|
||||
material_local = normalized_material
|
||||
|
||||
material_locals.append(material_local)
|
||||
resolutions.append(video.get_resolution(str(material_local)))
|
||||
res = video.get_resolution(str(material_local))
|
||||
return material_local, res
|
||||
|
||||
download_tasks = [
|
||||
_download_and_normalize(i, assignment)
|
||||
for i, assignment in enumerate(assignments)
|
||||
]
|
||||
download_results = await asyncio.gather(*download_tasks)
|
||||
for local, res in download_results:
|
||||
material_locals.append(local)
|
||||
resolutions.append(res)
|
||||
|
||||
# 按用户选择的画面比例统一分辨率
|
||||
base_res = target_resolution
|
||||
@@ -443,29 +456,42 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
if need_scale:
|
||||
logger.info(f"[MultiMat] 素材分辨率不一致,统一到 {base_res[0]}x{base_res[1]}")
|
||||
|
||||
# ── 第二步:裁剪每段素材到对应时长 ──
|
||||
prepared_segments: List[Path] = []
|
||||
# ── 第二步:并行裁剪每段素材到对应时长 ──
|
||||
prepared_segments: List[Path] = [None] * num_segments
|
||||
|
||||
for i, assignment in enumerate(assignments):
|
||||
seg_progress = 15 + int((i / num_segments) * 30) # 15% → 45%
|
||||
async def _prepare_one_segment(i: int, assignment: dict):
|
||||
"""将单个素材裁剪/循环到对应时长"""
|
||||
seg_dur = assignment["end"] - assignment["start"]
|
||||
_update_task(
|
||||
task_id,
|
||||
progress=seg_progress,
|
||||
message=f"正在准备素材 {i+1}/{num_segments}..."
|
||||
)
|
||||
|
||||
prepared_path = temp_dir / f"{task_id}_prepared_{i}.mp4"
|
||||
temp_files.append(prepared_path)
|
||||
video.prepare_segment(
|
||||
str(material_locals[i]), seg_dur, str(prepared_path),
|
||||
# 多素材拼接前统一重编码为同分辨率/同编码,避免 concat 仅保留首段
|
||||
target_resolution=base_res,
|
||||
source_start=assignment.get("source_start", 0.0),
|
||||
source_end=assignment.get("source_end"),
|
||||
target_fps=25,
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
video.prepare_segment,
|
||||
str(material_locals[i]),
|
||||
seg_dur,
|
||||
str(prepared_path),
|
||||
base_res,
|
||||
assignment.get("source_start", 0.0),
|
||||
assignment.get("source_end"),
|
||||
25,
|
||||
)
|
||||
prepared_segments.append(prepared_path)
|
||||
return i, prepared_path
|
||||
|
||||
_update_task(
|
||||
task_id,
|
||||
progress=15,
|
||||
message=f"正在并行准备 {num_segments} 个素材片段..."
|
||||
)
|
||||
|
||||
prepare_tasks = [
|
||||
_prepare_one_segment(i, assignment)
|
||||
for i, assignment in enumerate(assignments)
|
||||
]
|
||||
prepare_results = await asyncio.gather(*prepare_tasks)
|
||||
for i, path in prepare_results:
|
||||
prepared_segments[i] = path
|
||||
|
||||
# ── 第二步:拼接所有素材片段 ──
|
||||
_update_task(task_id, progress=50, message="正在拼接素材片段...")
|
||||
@@ -553,51 +579,89 @@ async def process_video_generation(task_id: str, req: GenerateRequest, user_id:
|
||||
print(f"[Pipeline] LipSync completed in {lipsync_time:.1f}s")
|
||||
_update_task(task_id, progress=80)
|
||||
|
||||
# 单素材模式:Whisper 在 LatentSync 之后
|
||||
if req.enable_subtitles:
|
||||
# 单素材模式:Whisper 延迟到下方与 BGM 并行执行
|
||||
if not req.enable_subtitles:
|
||||
captions_path = None
|
||||
|
||||
_update_task(task_id, progress=85)
|
||||
|
||||
# ── Whisper 字幕 + BGM 混音 并行(两者都只依赖 audio_path)──
|
||||
final_audio_path = audio_path
|
||||
_whisper_task = None
|
||||
_bgm_task = None
|
||||
|
||||
# 单素材模式下 Whisper 尚未执行,这里与 BGM 并行启动
|
||||
need_whisper = not is_multi and req.enable_subtitles and captions_path is None
|
||||
if need_whisper:
|
||||
captions_path = temp_dir / f"{task_id}_captions.json"
|
||||
temp_files.append(captions_path)
|
||||
_captions_path_str = str(captions_path)
|
||||
|
||||
async def _run_whisper():
|
||||
_update_task(task_id, message="正在生成字幕 (Whisper)...", progress=82)
|
||||
|
||||
captions_path = temp_dir / f"{task_id}_captions.json"
|
||||
temp_files.append(captions_path)
|
||||
|
||||
try:
|
||||
await whisper_service.align(
|
||||
audio_path=str(audio_path),
|
||||
text=req.text,
|
||||
output_path=str(captions_path),
|
||||
output_path=_captions_path_str,
|
||||
language=_locale_to_whisper_lang(req.language),
|
||||
original_text=req.text,
|
||||
)
|
||||
print(f"[Pipeline] Whisper alignment completed")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"Whisper alignment failed, skipping subtitles: {e}")
|
||||
captions_path = None
|
||||
return False
|
||||
|
||||
_update_task(task_id, progress=85)
|
||||
_whisper_task = _run_whisper()
|
||||
|
||||
final_audio_path = audio_path
|
||||
if req.bgm_id:
|
||||
_update_task(task_id, message="正在合成背景音乐...", progress=86)
|
||||
|
||||
bgm_path = resolve_bgm_path(req.bgm_id)
|
||||
if bgm_path:
|
||||
mix_output_path = temp_dir / f"{task_id}_audio_mix.wav"
|
||||
temp_files.append(mix_output_path)
|
||||
volume = req.bgm_volume if req.bgm_volume is not None else 0.2
|
||||
volume = max(0.0, min(float(volume), 1.0))
|
||||
try:
|
||||
video.mix_audio(
|
||||
voice_path=str(audio_path),
|
||||
bgm_path=str(bgm_path),
|
||||
output_path=str(mix_output_path),
|
||||
bgm_volume=volume
|
||||
)
|
||||
final_audio_path = mix_output_path
|
||||
except Exception as e:
|
||||
logger.warning(f"BGM mix failed, fallback to voice only: {e}")
|
||||
_mix_output = str(mix_output_path)
|
||||
_bgm_path = str(bgm_path)
|
||||
_voice_path = str(audio_path)
|
||||
_volume = volume
|
||||
|
||||
async def _run_bgm():
|
||||
_update_task(task_id, message="正在合成背景音乐...", progress=86)
|
||||
loop = asyncio.get_event_loop()
|
||||
try:
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
video.mix_audio,
|
||||
_voice_path,
|
||||
_bgm_path,
|
||||
_mix_output,
|
||||
_volume,
|
||||
)
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.warning(f"BGM mix failed, fallback to voice only: {e}")
|
||||
return False
|
||||
|
||||
_bgm_task = _run_bgm()
|
||||
else:
|
||||
logger.warning(f"BGM not found: {req.bgm_id}")
|
||||
|
||||
# 并行等待 Whisper + BGM
|
||||
parallel_tasks = [t for t in (_whisper_task, _bgm_task) if t is not None]
|
||||
if parallel_tasks:
|
||||
results = await asyncio.gather(*parallel_tasks)
|
||||
result_idx = 0
|
||||
if _whisper_task is not None:
|
||||
if not results[result_idx]:
|
||||
captions_path = None
|
||||
result_idx += 1
|
||||
if _bgm_task is not None:
|
||||
if results[result_idx]:
|
||||
final_audio_path = mix_output_path
|
||||
|
||||
|
||||
use_remotion = (captions_path and captions_path.exists()) or req.title or req.secondary_title
|
||||
|
||||
subtitle_style = None
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"""
|
||||
唇形同步服务
|
||||
通过 subprocess 调用 LatentSync conda 环境进行推理
|
||||
配置为使用 GPU1 (CUDA:1)
|
||||
混合方案: 短视频用 LatentSync (高质量), 长视频用 MuseTalk (高速度)
|
||||
路由阈值: LIPSYNC_DURATION_THRESHOLD (默认 120s)
|
||||
"""
|
||||
import os
|
||||
import shutil
|
||||
@@ -17,15 +17,18 @@ from app.core.config import settings
|
||||
|
||||
|
||||
class LipSyncService:
|
||||
"""唇形同步服务 - LatentSync 1.6 集成 (Subprocess 方式)"""
|
||||
|
||||
"""唇形同步服务 - LatentSync 1.6 + MuseTalk 1.5 混合方案"""
|
||||
|
||||
def __init__(self):
|
||||
self.use_local = settings.LATENTSYNC_LOCAL
|
||||
self.api_url = settings.LATENTSYNC_API_URL
|
||||
self.latentsync_dir = settings.LATENTSYNC_DIR
|
||||
self.gpu_id = settings.LATENTSYNC_GPU_ID
|
||||
self.use_server = settings.LATENTSYNC_USE_SERVER
|
||||
|
||||
|
||||
# MuseTalk 配置
|
||||
self.musetalk_api_url = settings.MUSETALK_API_URL
|
||||
|
||||
# GPU 并发锁 (Serial Queue)
|
||||
self._lock = asyncio.Lock()
|
||||
|
||||
@@ -103,7 +106,7 @@ class LipSyncService:
|
||||
"-t", str(target_duration), # 截取到目标时长
|
||||
"-c:v", "libx264",
|
||||
"-preset", "fast",
|
||||
"-crf", "18",
|
||||
"-crf", "23",
|
||||
"-an", # 去掉原音频
|
||||
output_path
|
||||
]
|
||||
@@ -268,6 +271,18 @@ class LipSyncService:
|
||||
else:
|
||||
actual_video_path = video_path
|
||||
|
||||
# 混合路由: 长视频走 MuseTalk,短视频走 LatentSync
|
||||
if audio_duration and audio_duration >= settings.LIPSYNC_DURATION_THRESHOLD:
|
||||
logger.info(
|
||||
f"🔄 音频 {audio_duration:.1f}s >= {settings.LIPSYNC_DURATION_THRESHOLD}s,路由到 MuseTalk"
|
||||
)
|
||||
musetalk_result = await self._call_musetalk_server(
|
||||
actual_video_path, audio_path, output_path
|
||||
)
|
||||
if musetalk_result:
|
||||
return musetalk_result
|
||||
logger.warning("⚠️ MuseTalk 不可用,回退到 LatentSync(长视频,会较慢)")
|
||||
|
||||
if self.use_server:
|
||||
# 模式 A: 调用常驻服务 (加速模式)
|
||||
return await self._call_persistent_server(actual_video_path, audio_path, output_path)
|
||||
@@ -352,6 +367,55 @@ class LipSyncService:
|
||||
shutil.copy(video_path, output_path)
|
||||
return output_path
|
||||
|
||||
async def _call_musetalk_server(
|
||||
self, video_path: str, audio_path: str, output_path: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
调用 MuseTalk 常驻服务。
|
||||
成功返回 output_path,不可用返回 None(信号上层回退到 LatentSync)。
|
||||
"""
|
||||
server_url = self.musetalk_api_url
|
||||
logger.info(f"⚡ 调用 MuseTalk 服务: {server_url}")
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=3600.0) as client:
|
||||
# 健康检查
|
||||
try:
|
||||
resp = await client.get(f"{server_url}/health", timeout=5.0)
|
||||
if resp.status_code != 200:
|
||||
logger.warning("⚠️ MuseTalk 健康检查失败")
|
||||
return None
|
||||
health = resp.json()
|
||||
if not health.get("model_loaded"):
|
||||
logger.warning("⚠️ MuseTalk 模型未加载")
|
||||
return None
|
||||
except Exception:
|
||||
logger.warning("⚠️ 无法连接 MuseTalk 服务")
|
||||
return None
|
||||
|
||||
# 发送推理请求
|
||||
payload = {
|
||||
"video_path": str(Path(video_path).resolve()),
|
||||
"audio_path": str(Path(audio_path).resolve()),
|
||||
"video_out_path": str(Path(output_path).resolve()),
|
||||
"batch_size": settings.MUSETALK_BATCH_SIZE,
|
||||
}
|
||||
|
||||
response = await client.post(f"{server_url}/lipsync", json=payload)
|
||||
|
||||
if response.status_code == 200:
|
||||
result = response.json()
|
||||
if Path(result["output_path"]).exists():
|
||||
logger.info(f"✅ MuseTalk 推理完成: {output_path}")
|
||||
return output_path
|
||||
|
||||
logger.error(f"❌ MuseTalk 服务报错: {response.text}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"❌ MuseTalk 调用失败: {e}")
|
||||
return None
|
||||
|
||||
async def _call_persistent_server(self, video_path: str, audio_path: str, output_path: str) -> str:
|
||||
"""调用本地常驻服务 (server.py)"""
|
||||
server_url = "http://localhost:8007"
|
||||
@@ -369,7 +433,7 @@ class LipSyncService:
|
||||
}
|
||||
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=1200.0) as client:
|
||||
async with httpx.AsyncClient(timeout=3600.0) as client:
|
||||
# 先检查健康状态
|
||||
try:
|
||||
resp = await client.get(f"{server_url}/health", timeout=5.0)
|
||||
@@ -477,8 +541,18 @@ class LipSyncService:
|
||||
except:
|
||||
pass
|
||||
|
||||
# 检查 MuseTalk 服务
|
||||
musetalk_ready = False
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=5.0) as client:
|
||||
resp = await client.get(f"{self.musetalk_api_url}/health")
|
||||
if resp.status_code == 200:
|
||||
musetalk_ready = resp.json().get("model_loaded", False)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
return {
|
||||
"model": "LatentSync 1.6",
|
||||
"model": "LatentSync 1.6 + MuseTalk 1.5",
|
||||
"conda_env": conda_ok,
|
||||
"weights": weights_ok,
|
||||
"gpu": gpu_ok,
|
||||
@@ -486,5 +560,7 @@ class LipSyncService:
|
||||
"gpu_id": self.gpu_id,
|
||||
"inference_steps": settings.LATENTSYNC_INFERENCE_STEPS,
|
||||
"guidance_scale": settings.LATENTSYNC_GUIDANCE_SCALE,
|
||||
"ready": conda_ok and weights_ok and gpu_ok
|
||||
"ready": conda_ok and weights_ok and gpu_ok,
|
||||
"musetalk_ready": musetalk_ready,
|
||||
"lipsync_threshold": settings.LIPSYNC_DURATION_THRESHOLD,
|
||||
}
|
||||
|
||||
@@ -1,14 +1,14 @@
|
||||
"""
|
||||
视频合成服务
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import json
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
"""
|
||||
视频合成服务
|
||||
"""
|
||||
import os
|
||||
import subprocess
|
||||
import json
|
||||
import shlex
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
class VideoService:
|
||||
def __init__(self):
|
||||
pass
|
||||
@@ -96,7 +96,7 @@ class VideoService:
|
||||
"-map", "0:a?",
|
||||
"-c:v", "libx264",
|
||||
"-preset", "fast",
|
||||
"-crf", "18",
|
||||
"-crf", "23",
|
||||
"-c:a", "copy",
|
||||
"-movflags", "+faststart",
|
||||
output_path,
|
||||
@@ -113,146 +113,146 @@ class VideoService:
|
||||
|
||||
logger.warning("视频方向归一化失败,回退使用原视频")
|
||||
return video_path
|
||||
|
||||
def _run_ffmpeg(self, cmd: list) -> bool:
|
||||
cmd_str = ' '.join(shlex.quote(str(c)) for c in cmd)
|
||||
logger.debug(f"FFmpeg CMD: {cmd_str}")
|
||||
try:
|
||||
# Synchronous call for BackgroundTasks compatibility
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
shell=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
encoding='utf-8',
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.error(f"FFmpeg Error: {result.stderr}")
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"FFmpeg Exception: {e}")
|
||||
return False
|
||||
|
||||
def _get_duration(self, file_path: str) -> float:
|
||||
# Synchronous call for BackgroundTasks compatibility
|
||||
# 使用参数列表形式避免 shell=True 的命令注入风险
|
||||
cmd = [
|
||||
'ffprobe', '-v', 'error',
|
||||
'-show_entries', 'format=duration',
|
||||
'-of', 'default=noprint_wrappers=1:nokey=1',
|
||||
file_path
|
||||
]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return float(result.stdout.strip())
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def mix_audio(
|
||||
self,
|
||||
voice_path: str,
|
||||
bgm_path: str,
|
||||
output_path: str,
|
||||
bgm_volume: float = 0.2
|
||||
) -> str:
|
||||
"""混合人声与背景音乐"""
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
volume = max(0.0, min(float(bgm_volume), 1.0))
|
||||
filter_complex = (
|
||||
f"[0:a]volume=1.0[a0];"
|
||||
f"[1:a]volume={volume}[a1];"
|
||||
f"[a0][a1]amix=inputs=2:duration=first:dropout_transition=2:normalize=0[aout]"
|
||||
)
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", voice_path,
|
||||
"-stream_loop", "-1", "-i", bgm_path,
|
||||
"-filter_complex", filter_complex,
|
||||
"-map", "[aout]",
|
||||
"-c:a", "pcm_s16le",
|
||||
"-shortest",
|
||||
output_path,
|
||||
]
|
||||
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
raise RuntimeError("FFmpeg audio mix failed")
|
||||
|
||||
async def compose(
|
||||
self,
|
||||
video_path: str,
|
||||
audio_path: str,
|
||||
output_path: str,
|
||||
subtitle_path: Optional[str] = None
|
||||
) -> str:
|
||||
"""合成视频"""
|
||||
# Ensure output dir
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_duration = self._get_duration(video_path)
|
||||
audio_duration = self._get_duration(audio_path)
|
||||
|
||||
# Audio loop if needed
|
||||
loop_count = 1
|
||||
if audio_duration > video_duration and video_duration > 0:
|
||||
loop_count = int(audio_duration / video_duration) + 1
|
||||
|
||||
cmd = ["ffmpeg", "-y"]
|
||||
|
||||
# Input video (stream_loop must be before -i)
|
||||
if loop_count > 1:
|
||||
cmd.extend(["-stream_loop", str(loop_count)])
|
||||
cmd.extend(["-i", video_path])
|
||||
|
||||
# Input audio
|
||||
cmd.extend(["-i", audio_path])
|
||||
|
||||
# Filter complex
|
||||
filter_complex = []
|
||||
|
||||
# Subtitles (skip for now to mimic previous state or implement basic)
|
||||
# Previous state: subtitles disabled due to font issues
|
||||
# if subtitle_path: ...
|
||||
|
||||
# Audio map with high quality encoding
|
||||
cmd.extend([
|
||||
"-c:v", "libx264",
|
||||
"-preset", "slow", # 慢速预设,更好的压缩效率
|
||||
"-crf", "18", # 高质量(与 LatentSync 一致)
|
||||
"-c:a", "aac",
|
||||
"-b:a", "192k", # 音频比特率
|
||||
"-shortest"
|
||||
])
|
||||
# Use audio from input 1
|
||||
cmd.extend(["-map", "0:v", "-map", "1:a"])
|
||||
|
||||
cmd.append(output_path)
|
||||
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
else:
|
||||
raise RuntimeError("FFmpeg composition failed")
|
||||
|
||||
|
||||
def _run_ffmpeg(self, cmd: list) -> bool:
|
||||
cmd_str = ' '.join(shlex.quote(str(c)) for c in cmd)
|
||||
logger.debug(f"FFmpeg CMD: {cmd_str}")
|
||||
try:
|
||||
# Synchronous call for BackgroundTasks compatibility
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
shell=False,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
encoding='utf-8',
|
||||
)
|
||||
if result.returncode != 0:
|
||||
logger.error(f"FFmpeg Error: {result.stderr}")
|
||||
return False
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"FFmpeg Exception: {e}")
|
||||
return False
|
||||
|
||||
def _get_duration(self, file_path: str) -> float:
|
||||
# Synchronous call for BackgroundTasks compatibility
|
||||
# 使用参数列表形式避免 shell=True 的命令注入风险
|
||||
cmd = [
|
||||
'ffprobe', '-v', 'error',
|
||||
'-show_entries', 'format=duration',
|
||||
'-of', 'default=noprint_wrappers=1:nokey=1',
|
||||
file_path
|
||||
]
|
||||
try:
|
||||
result = subprocess.run(
|
||||
cmd,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
)
|
||||
return float(result.stdout.strip())
|
||||
except Exception:
|
||||
return 0.0
|
||||
|
||||
def mix_audio(
|
||||
self,
|
||||
voice_path: str,
|
||||
bgm_path: str,
|
||||
output_path: str,
|
||||
bgm_volume: float = 0.2
|
||||
) -> str:
|
||||
"""混合人声与背景音乐"""
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
volume = max(0.0, min(float(bgm_volume), 1.0))
|
||||
filter_complex = (
|
||||
f"[0:a]volume=1.0[a0];"
|
||||
f"[1:a]volume={volume}[a1];"
|
||||
f"[a0][a1]amix=inputs=2:duration=first:dropout_transition=2:normalize=0[aout]"
|
||||
)
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-i", voice_path,
|
||||
"-stream_loop", "-1", "-i", bgm_path,
|
||||
"-filter_complex", filter_complex,
|
||||
"-map", "[aout]",
|
||||
"-c:a", "pcm_s16le",
|
||||
"-shortest",
|
||||
output_path,
|
||||
]
|
||||
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
raise RuntimeError("FFmpeg audio mix failed")
|
||||
|
||||
async def compose(
|
||||
self,
|
||||
video_path: str,
|
||||
audio_path: str,
|
||||
output_path: str,
|
||||
subtitle_path: Optional[str] = None
|
||||
) -> str:
|
||||
"""合成视频"""
|
||||
# Ensure output dir
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
video_duration = self._get_duration(video_path)
|
||||
audio_duration = self._get_duration(audio_path)
|
||||
|
||||
# Audio loop if needed
|
||||
loop_count = 1
|
||||
if audio_duration > video_duration and video_duration > 0:
|
||||
loop_count = int(audio_duration / video_duration) + 1
|
||||
|
||||
cmd = ["ffmpeg", "-y"]
|
||||
|
||||
# Input video (stream_loop must be before -i)
|
||||
if loop_count > 1:
|
||||
cmd.extend(["-stream_loop", str(loop_count)])
|
||||
cmd.extend(["-i", video_path])
|
||||
|
||||
# Input audio
|
||||
cmd.extend(["-i", audio_path])
|
||||
|
||||
# Filter complex
|
||||
filter_complex = []
|
||||
|
||||
# Subtitles (skip for now to mimic previous state or implement basic)
|
||||
# Previous state: subtitles disabled due to font issues
|
||||
# if subtitle_path: ...
|
||||
|
||||
# Audio map with high quality encoding
|
||||
cmd.extend([
|
||||
"-c:v", "libx264",
|
||||
"-preset", "medium", # 平衡速度与压缩效率
|
||||
"-crf", "20", # 最终输出:高质量(肉眼无损)
|
||||
"-c:a", "aac",
|
||||
"-b:a", "192k", # 音频比特率
|
||||
"-shortest"
|
||||
])
|
||||
# Use audio from input 1
|
||||
cmd.extend(["-map", "0:v", "-map", "1:a"])
|
||||
|
||||
cmd.append(output_path)
|
||||
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
else:
|
||||
raise RuntimeError("FFmpeg composition failed")
|
||||
|
||||
def concat_videos(self, video_paths: list, output_path: str, target_fps: int = 25) -> str:
|
||||
"""使用 FFmpeg concat demuxer 拼接多个视频片段"""
|
||||
if not video_paths:
|
||||
raise ValueError("No video segments to concat")
|
||||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成 concat list 文件
|
||||
list_path = Path(output_path).parent / f"{Path(output_path).stem}_concat.txt"
|
||||
with open(list_path, "w", encoding="utf-8") as f:
|
||||
for vp in video_paths:
|
||||
f.write(f"file '{vp}'\n")
|
||||
|
||||
if not video_paths:
|
||||
raise ValueError("No video segments to concat")
|
||||
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 生成 concat list 文件
|
||||
list_path = Path(output_path).parent / f"{Path(output_path).stem}_concat.txt"
|
||||
with open(list_path, "w", encoding="utf-8") as f:
|
||||
for vp in video_paths:
|
||||
f.write(f"file '{vp}'\n")
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-f", "concat",
|
||||
@@ -264,44 +264,44 @@ class VideoService:
|
||||
"-r", str(target_fps),
|
||||
"-c:v", "libx264",
|
||||
"-preset", "fast",
|
||||
"-crf", "18",
|
||||
"-crf", "23",
|
||||
"-pix_fmt", "yuv420p",
|
||||
"-movflags", "+faststart",
|
||||
output_path,
|
||||
]
|
||||
|
||||
try:
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
else:
|
||||
raise RuntimeError("FFmpeg concat failed")
|
||||
finally:
|
||||
try:
|
||||
list_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def split_audio(self, audio_path: str, start: float, end: float, output_path: str) -> str:
|
||||
"""用 FFmpeg 按时间范围切分音频"""
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
duration = end - start
|
||||
if duration <= 0:
|
||||
raise ValueError(f"Invalid audio split range: start={start}, end={end}, duration={duration}")
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-ss", str(start),
|
||||
"-t", str(duration),
|
||||
"-i", audio_path,
|
||||
"-c", "copy",
|
||||
output_path,
|
||||
]
|
||||
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
raise RuntimeError(f"FFmpeg audio split failed: {start}-{end}")
|
||||
|
||||
|
||||
try:
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
else:
|
||||
raise RuntimeError("FFmpeg concat failed")
|
||||
finally:
|
||||
try:
|
||||
list_path.unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def split_audio(self, audio_path: str, start: float, end: float, output_path: str) -> str:
|
||||
"""用 FFmpeg 按时间范围切分音频"""
|
||||
Path(output_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
duration = end - start
|
||||
if duration <= 0:
|
||||
raise ValueError(f"Invalid audio split range: start={start}, end={end}, duration={duration}")
|
||||
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-ss", str(start),
|
||||
"-t", str(duration),
|
||||
"-i", audio_path,
|
||||
"-c", "copy",
|
||||
output_path,
|
||||
]
|
||||
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
raise RuntimeError(f"FFmpeg audio split failed: {start}-{end}")
|
||||
|
||||
def get_resolution(self, file_path: str) -> tuple[int, int]:
|
||||
"""获取视频有效显示分辨率(考虑旋转元数据)。"""
|
||||
info = self.get_video_metadata(file_path)
|
||||
@@ -309,7 +309,7 @@ class VideoService:
|
||||
int(info.get("effective_width") or 0),
|
||||
int(info.get("effective_height") or 0),
|
||||
)
|
||||
|
||||
|
||||
def prepare_segment(self, video_path: str, target_duration: float, output_path: str,
|
||||
target_resolution: Optional[tuple] = None, source_start: float = 0.0,
|
||||
source_end: Optional[float] = None, target_fps: Optional[int] = None) -> str:
|
||||
@@ -353,21 +353,21 @@ class VideoService:
|
||||
"-i", video_path,
|
||||
"-t", str(available),
|
||||
"-an",
|
||||
"-c:v", "libx264", "-preset", "fast", "-crf", "18",
|
||||
"-c:v", "libx264", "-preset", "fast", "-crf", "23",
|
||||
trim_temp,
|
||||
]
|
||||
if not self._run_ffmpeg(trim_cmd):
|
||||
raise RuntimeError(f"FFmpeg trim for loop failed: {video_path}")
|
||||
actual_input = trim_temp
|
||||
source_start = 0.0 # 已裁剪,不需要再 seek
|
||||
# 重新计算循环次数(基于裁剪后文件)
|
||||
available = self._get_duration(trim_temp) or available
|
||||
|
||||
loop_count = int(target_duration / available) + 1 if needs_loop else 0
|
||||
|
||||
cmd = ["ffmpeg", "-y"]
|
||||
if needs_loop:
|
||||
cmd.extend(["-stream_loop", str(loop_count)])
|
||||
if not self._run_ffmpeg(trim_cmd):
|
||||
raise RuntimeError(f"FFmpeg trim for loop failed: {video_path}")
|
||||
actual_input = trim_temp
|
||||
source_start = 0.0 # 已裁剪,不需要再 seek
|
||||
# 重新计算循环次数(基于裁剪后文件)
|
||||
available = self._get_duration(trim_temp) or available
|
||||
|
||||
loop_count = int(target_duration / available) + 1 if needs_loop else 0
|
||||
|
||||
cmd = ["ffmpeg", "-y"]
|
||||
if needs_loop:
|
||||
cmd.extend(["-stream_loop", str(loop_count)])
|
||||
if source_start > 0:
|
||||
cmd.extend(["-ss", str(source_start)])
|
||||
cmd.extend(["-i", actual_input, "-t", str(target_duration), "-an"])
|
||||
@@ -386,20 +386,20 @@ class VideoService:
|
||||
|
||||
# 需要循环、缩放或指定起点时必须重编码,否则用 stream copy 保持原画质
|
||||
if needs_loop or needs_scale or source_start > 0 or has_source_end or needs_fps:
|
||||
cmd.extend(["-c:v", "libx264", "-preset", "fast", "-crf", "18"])
|
||||
cmd.extend(["-c:v", "libx264", "-preset", "fast", "-crf", "23"])
|
||||
else:
|
||||
cmd.extend(["-c:v", "copy"])
|
||||
|
||||
cmd.append(output_path)
|
||||
|
||||
try:
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
raise RuntimeError(f"FFmpeg prepare_segment failed: {video_path}")
|
||||
finally:
|
||||
# 清理裁剪临时文件
|
||||
if trim_temp:
|
||||
try:
|
||||
Path(trim_temp).unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
cmd.append(output_path)
|
||||
|
||||
try:
|
||||
if self._run_ffmpeg(cmd):
|
||||
return output_path
|
||||
raise RuntimeError(f"FFmpeg prepare_segment failed: {video_path}")
|
||||
finally:
|
||||
# 清理裁剪临时文件
|
||||
if trim_temp:
|
||||
try:
|
||||
Path(trim_temp).unlink(missing_ok=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
@@ -247,19 +247,67 @@ class WhisperService:
|
||||
line_segments = split_segment_to_lines(all_words, max_chars)
|
||||
all_segments.extend(line_segments)
|
||||
|
||||
# 如果提供了 original_text,用原文替换 Whisper 转录文字
|
||||
# 如果提供了 original_text,用原文替换 Whisper 转录文字,保留语音节奏
|
||||
if original_text and original_text.strip() and whisper_first_start is not None:
|
||||
logger.info(f"Using original_text for subtitles (len={len(original_text)}), "
|
||||
f"Whisper time range: {whisper_first_start:.2f}-{whisper_last_end:.2f}s")
|
||||
# 用 split_word_to_chars 拆分原文
|
||||
# 收集 Whisper 逐字时间戳(保留真实语音节奏)
|
||||
whisper_chars = []
|
||||
for seg in all_segments:
|
||||
whisper_chars.extend(seg.get("words", []))
|
||||
|
||||
# 用原文字符 + Whisper 节奏生成新的时间戳
|
||||
orig_chars = split_word_to_chars(
|
||||
original_text.strip(),
|
||||
whisper_first_start,
|
||||
whisper_last_end
|
||||
)
|
||||
if orig_chars:
|
||||
|
||||
if orig_chars and len(whisper_chars) >= 2:
|
||||
# 将原文字符按比例映射到 Whisper 的时间节奏上
|
||||
n_w = len(whisper_chars)
|
||||
n_o = len(orig_chars)
|
||||
w_starts = [c["start"] for c in whisper_chars]
|
||||
w_final_end = whisper_chars[-1]["end"]
|
||||
|
||||
logger.info(
|
||||
f"Using original_text for subtitles (len={len(original_text)}), "
|
||||
f"rhythm-mapping {n_o} orig chars onto {n_w} Whisper chars, "
|
||||
f"time range: {whisper_first_start:.2f}-{whisper_last_end:.2f}s"
|
||||
)
|
||||
|
||||
remapped = []
|
||||
for i, oc in enumerate(orig_chars):
|
||||
# 原文第 i 个字符对应 Whisper 时间线的位置
|
||||
pos = (i / n_o) * n_w
|
||||
idx = min(int(pos), n_w - 1)
|
||||
frac = pos - idx
|
||||
t_start = (
|
||||
w_starts[idx] + frac * (w_starts[idx + 1] - w_starts[idx])
|
||||
if idx < n_w - 1
|
||||
else w_starts[idx] + frac * (w_final_end - w_starts[idx])
|
||||
)
|
||||
|
||||
# 结束时间 = 下一个字符的开始时间
|
||||
pos_next = ((i + 1) / n_o) * n_w
|
||||
idx_n = min(int(pos_next), n_w - 1)
|
||||
frac_n = pos_next - idx_n
|
||||
t_end = (
|
||||
w_starts[idx_n] + frac_n * (w_starts[idx_n + 1] - w_starts[idx_n])
|
||||
if idx_n < n_w - 1
|
||||
else w_starts[idx_n] + frac_n * (w_final_end - w_starts[idx_n])
|
||||
)
|
||||
|
||||
remapped.append({
|
||||
"word": oc["word"],
|
||||
"start": round(t_start, 3),
|
||||
"end": round(t_end, 3),
|
||||
})
|
||||
|
||||
all_segments = split_segment_to_lines(remapped, max_chars)
|
||||
logger.info(f"Rebuilt {len(all_segments)} subtitle segments (rhythm-mapped)")
|
||||
elif orig_chars:
|
||||
# Whisper 字符不足,退回线性插值
|
||||
all_segments = split_segment_to_lines(orig_chars, max_chars)
|
||||
logger.info(f"Rebuilt {len(all_segments)} subtitle segments from original text")
|
||||
logger.info(f"Rebuilt {len(all_segments)} subtitle segments (linear fallback)")
|
||||
|
||||
logger.info(f"Generated {len(all_segments)} subtitle segments")
|
||||
return {"segments": all_segments}
|
||||
|
||||
@@ -54,5 +54,61 @@
|
||||
"letter_spacing": 1,
|
||||
"bottom_margin": 72,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "subtitle_pink",
|
||||
"label": "少女粉",
|
||||
"font_file": "DingTalk JinBuTi.ttf",
|
||||
"font_family": "DingTalkJinBuTi",
|
||||
"font_size": 56,
|
||||
"highlight_color": "#FF69B4",
|
||||
"normal_color": "#FFFFFF",
|
||||
"stroke_color": "#1A0010",
|
||||
"stroke_size": 3,
|
||||
"letter_spacing": 2,
|
||||
"bottom_margin": 80,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "subtitle_lime",
|
||||
"label": "清新绿",
|
||||
"font_file": "DingTalk Sans.ttf",
|
||||
"font_family": "DingTalkSans",
|
||||
"font_size": 50,
|
||||
"highlight_color": "#76FF03",
|
||||
"normal_color": "#FFFFFF",
|
||||
"stroke_color": "#001A00",
|
||||
"stroke_size": 3,
|
||||
"letter_spacing": 1,
|
||||
"bottom_margin": 78,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "subtitle_gold",
|
||||
"label": "金色隶书",
|
||||
"font_file": "阿里妈妈刀隶体.ttf",
|
||||
"font_family": "AliMamaDaoLiTi",
|
||||
"font_size": 56,
|
||||
"highlight_color": "#FDE68A",
|
||||
"normal_color": "#E8D5B0",
|
||||
"stroke_color": "#2B1B00",
|
||||
"stroke_size": 3,
|
||||
"letter_spacing": 3,
|
||||
"bottom_margin": 80,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "subtitle_kai",
|
||||
"label": "楷体红字",
|
||||
"font_file": "simkai.ttf",
|
||||
"font_family": "SimKai",
|
||||
"font_size": 54,
|
||||
"highlight_color": "#FF4444",
|
||||
"normal_color": "#FFFFFF",
|
||||
"stroke_color": "#000000",
|
||||
"stroke_size": 3,
|
||||
"letter_spacing": 2,
|
||||
"bottom_margin": 80,
|
||||
"is_default": false
|
||||
}
|
||||
]
|
||||
|
||||
@@ -7,7 +7,7 @@
|
||||
"font_size": 90,
|
||||
"color": "#FFFFFF",
|
||||
"stroke_color": "#000000",
|
||||
"stroke_size": 8,
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 5,
|
||||
"top_margin": 62,
|
||||
"font_weight": 900,
|
||||
@@ -21,7 +21,7 @@
|
||||
"font_size": 72,
|
||||
"color": "#FFFFFF",
|
||||
"stroke_color": "#000000",
|
||||
"stroke_size": 8,
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 4,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
@@ -35,7 +35,7 @@
|
||||
"font_size": 70,
|
||||
"color": "#FDE68A",
|
||||
"stroke_color": "#2B1B00",
|
||||
"stroke_size": 8,
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 3,
|
||||
"top_margin": 58,
|
||||
"font_weight": 800,
|
||||
@@ -49,10 +49,122 @@
|
||||
"font_size": 72,
|
||||
"color": "#FFFFFF",
|
||||
"stroke_color": "#1F0A00",
|
||||
"stroke_size": 8,
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 4,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "title_pangmen",
|
||||
"label": "庞门正道",
|
||||
"font_file": "title/庞门正道标题体3.0.ttf",
|
||||
"font_family": "PangMenZhengDao",
|
||||
"font_size": 80,
|
||||
"color": "#FFFFFF",
|
||||
"stroke_color": "#000000",
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 5,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "title_round",
|
||||
"label": "优设标题圆",
|
||||
"font_file": "title/优设标题圆.otf",
|
||||
"font_family": "YouSheBiaoTiYuan",
|
||||
"font_size": 78,
|
||||
"color": "#FFFFFF",
|
||||
"stroke_color": "#4A1A6B",
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 4,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "title_alibaba",
|
||||
"label": "阿里数黑体",
|
||||
"font_file": "title/阿里巴巴数黑体.ttf",
|
||||
"font_family": "AlibabaShuHeiTi",
|
||||
"font_size": 72,
|
||||
"color": "#FFFFFF",
|
||||
"stroke_color": "#000000",
|
||||
"stroke_size": 4,
|
||||
"letter_spacing": 3,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "title_chaohei",
|
||||
"label": "文道潮黑",
|
||||
"font_file": "title/文道潮黑.ttf",
|
||||
"font_family": "WenDaoChaoHei",
|
||||
"font_size": 76,
|
||||
"color": "#00E5FF",
|
||||
"stroke_color": "#001A33",
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 4,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "title_wujie",
|
||||
"label": "无界黑",
|
||||
"font_file": "title/标小智无界黑.otf",
|
||||
"font_family": "BiaoXiaoZhiWuJieHei",
|
||||
"font_size": 74,
|
||||
"color": "#FFFFFF",
|
||||
"stroke_color": "#1A1A1A",
|
||||
"stroke_size": 4,
|
||||
"letter_spacing": 3,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "title_houdi",
|
||||
"label": "厚底黑",
|
||||
"font_file": "title/Aa厚底黑.ttf",
|
||||
"font_family": "AaHouDiHei",
|
||||
"font_size": 76,
|
||||
"color": "#FF6B6B",
|
||||
"stroke_color": "#1A0000",
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 4,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "title_banyuan",
|
||||
"label": "寒蝉半圆体",
|
||||
"font_file": "title/寒蝉半圆体.otf",
|
||||
"font_family": "HanChanBanYuan",
|
||||
"font_size": 78,
|
||||
"color": "#FFFFFF",
|
||||
"stroke_color": "#000000",
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 4,
|
||||
"top_margin": 60,
|
||||
"font_weight": 900,
|
||||
"is_default": false
|
||||
},
|
||||
{
|
||||
"id": "title_jixiang",
|
||||
"label": "欣意吉祥宋",
|
||||
"font_file": "title/字体圈欣意吉祥宋.ttf",
|
||||
"font_family": "XinYiJiXiangSong",
|
||||
"font_size": 70,
|
||||
"color": "#FDE68A",
|
||||
"stroke_color": "#2B1B00",
|
||||
"stroke_size": 5,
|
||||
"letter_spacing": 3,
|
||||
"top_margin": 58,
|
||||
"font_weight": 800,
|
||||
"is_default": false
|
||||
}
|
||||
]
|
||||
|
||||
@@ -3,9 +3,11 @@
|
||||
import { useState } from 'react';
|
||||
import { useRouter } from 'next/navigation';
|
||||
import { login } from "@/shared/lib/auth";
|
||||
import { useAuth } from "@/shared/contexts/AuthContext";
|
||||
|
||||
export default function LoginPage() {
|
||||
const router = useRouter();
|
||||
const { setUser } = useAuth();
|
||||
const [phone, setPhone] = useState('');
|
||||
const [password, setPassword] = useState('');
|
||||
const [error, setError] = useState('');
|
||||
@@ -29,6 +31,7 @@ export default function LoginPage() {
|
||||
sessionStorage.setItem('payment_token', result.paymentToken);
|
||||
router.push('/pay');
|
||||
} else if (result.success) {
|
||||
if (result.user) setUser(result.user);
|
||||
router.push('/');
|
||||
} else {
|
||||
setError(result.message || '登录失败');
|
||||
|
||||
@@ -106,6 +106,10 @@ export default function AccountSettingsDropdown() {
|
||||
{/* 下拉菜单 */}
|
||||
{isOpen && (
|
||||
<div className="absolute right-0 mt-2 bg-gray-800 border border-white/10 rounded-lg shadow-xl z-[160] overflow-hidden whitespace-nowrap">
|
||||
{/* 账户名称 */}
|
||||
<div className="px-3 py-2 border-b border-white/10 text-center">
|
||||
<div className="text-sm text-white font-medium">{user?.phone ? `${user.phone.slice(0, 3)}****${user.phone.slice(-4)}` : '未知账户'}</div>
|
||||
</div>
|
||||
{/* 有效期显示 */}
|
||||
<div className="px-3 py-2 border-b border-white/10 text-center">
|
||||
<div className="text-xs text-gray-400">账户有效期</div>
|
||||
@@ -188,6 +192,7 @@ export default function AccountSettingsDropdown() {
|
||||
onClick={() => {
|
||||
setShowPasswordModal(false);
|
||||
setError('');
|
||||
setSuccess('');
|
||||
setOldPassword('');
|
||||
setNewPassword('');
|
||||
setConfirmPassword('');
|
||||
|
||||
@@ -12,7 +12,7 @@ interface GeneratedVideo {
|
||||
}
|
||||
|
||||
interface UseGeneratedVideosOptions {
|
||||
|
||||
storageKey: string;
|
||||
selectedVideoId: string | null;
|
||||
setSelectedVideoId: React.Dispatch<React.SetStateAction<string | null>>;
|
||||
setGeneratedVideo: React.Dispatch<React.SetStateAction<string | null>>;
|
||||
@@ -20,7 +20,7 @@ interface UseGeneratedVideosOptions {
|
||||
}
|
||||
|
||||
export const useGeneratedVideos = ({
|
||||
|
||||
storageKey,
|
||||
selectedVideoId,
|
||||
setSelectedVideoId,
|
||||
setGeneratedVideo,
|
||||
@@ -45,6 +45,8 @@ export const useGeneratedVideos = ({
|
||||
if (preferVideoId === "__latest__") {
|
||||
setSelectedVideoId(videos[0].id);
|
||||
setGeneratedVideo(resolveMediaUrl(videos[0].path));
|
||||
// 写入跨页面共享标记,让另一个页面也能感知最新生成的视频
|
||||
localStorage.setItem(`vigent_${storageKey}_latestGeneratedVideoId`, videos[0].id);
|
||||
} else {
|
||||
const found = videos.find(v => v.id === preferVideoId);
|
||||
if (found) {
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useEffect, useRef, useState } from "react";
|
||||
import { useEffect, useMemo, useRef, useState } from "react";
|
||||
import api from "@/shared/api/axios";
|
||||
import {
|
||||
buildTextShadow,
|
||||
@@ -26,6 +26,7 @@ import { useRefAudios } from "@/features/home/model/useRefAudios";
|
||||
import { useTitleSubtitleStyles } from "@/features/home/model/useTitleSubtitleStyles";
|
||||
import { useTimelineEditor } from "@/features/home/model/useTimelineEditor";
|
||||
import { useSavedScripts } from "@/features/home/model/useSavedScripts";
|
||||
import { useVideoFrameCapture } from "@/features/home/model/useVideoFrameCapture";
|
||||
import { ApiResponse, unwrap } from "@/shared/api/types";
|
||||
|
||||
const VOICES: Record<string, { id: string; name: string }[]> = {
|
||||
@@ -280,6 +281,9 @@ export const useHomeController = () => {
|
||||
// 文案提取模态框
|
||||
const [extractModalOpen, setExtractModalOpen] = useState(false);
|
||||
|
||||
// AI 改写模态框
|
||||
const [rewriteModalOpen, setRewriteModalOpen] = useState(false);
|
||||
|
||||
// 获取存储 key 的前缀(登录用户使用 userId,未登录使用 guest)
|
||||
const storageKey = userId || "guest";
|
||||
|
||||
@@ -361,7 +365,7 @@ export const useHomeController = () => {
|
||||
fetchGeneratedVideos,
|
||||
deleteVideo,
|
||||
} = useGeneratedVideos({
|
||||
|
||||
storageKey,
|
||||
selectedVideoId,
|
||||
setSelectedVideoId,
|
||||
setGeneratedVideo,
|
||||
@@ -395,6 +399,18 @@ export const useHomeController = () => {
|
||||
storageKey,
|
||||
});
|
||||
|
||||
// 时间轴第一段素材的视频 URL(用于帧截取预览)
|
||||
// 有时间轴段时用第一段,没有(如未选配音)回退到 selectedMaterials[0]
|
||||
const firstTimelineMaterialUrl = useMemo(() => {
|
||||
const firstSeg = timelineSegments[0];
|
||||
const matId = firstSeg?.materialId ?? selectedMaterials[0];
|
||||
if (!matId) return null;
|
||||
const mat = materials.find((m) => m.id === matId);
|
||||
return mat?.path ? resolveMediaUrl(mat.path) : null;
|
||||
}, [materials, timelineSegments, selectedMaterials]);
|
||||
|
||||
const materialPosterUrl = useVideoFrameCapture(showStylePreview ? firstTimelineMaterialUrl : null);
|
||||
|
||||
useEffect(() => {
|
||||
if (isAuthLoading || !userId) return;
|
||||
let active = true;
|
||||
@@ -617,8 +633,19 @@ export const useHomeController = () => {
|
||||
// 移除重复的 BGM 持久化恢复逻辑 (已统一移动到 useHomePersistence 中)
|
||||
// useEffect(() => { ... })
|
||||
|
||||
// 时间门控:页面加载后 1 秒内禁止所有列表自动滚动效果
|
||||
// 防止持久化恢复 + 异步数据加载触发 scrollIntoView 导致移动端页面跳动
|
||||
const scrollEffectsEnabled = useRef(false);
|
||||
useEffect(() => {
|
||||
if (!selectedBgmId) return;
|
||||
const timer = setTimeout(() => {
|
||||
scrollEffectsEnabled.current = true;
|
||||
}, 1000);
|
||||
return () => clearTimeout(timer);
|
||||
}, []);
|
||||
|
||||
// BGM 列表滚动
|
||||
useEffect(() => {
|
||||
if (!selectedBgmId || !scrollEffectsEnabled.current) return;
|
||||
const container = bgmListContainerRef.current;
|
||||
const target = bgmItemRefs.current[selectedBgmId];
|
||||
if (container && target) {
|
||||
@@ -626,16 +653,10 @@ export const useHomeController = () => {
|
||||
}
|
||||
}, [selectedBgmId, bgmList]);
|
||||
|
||||
// 素材列表滚动:跳过首次恢复,仅用户主动操作时滚动
|
||||
const materialScrollReady = useRef(false);
|
||||
// 素材列表滚动
|
||||
useEffect(() => {
|
||||
const firstSelected = selectedMaterials[0];
|
||||
if (!firstSelected) return;
|
||||
if (!materialScrollReady.current) {
|
||||
// 首次有选中素材时标记就绪,但不滚动(避免刷新后整页跳动)
|
||||
materialScrollReady.current = true;
|
||||
return;
|
||||
}
|
||||
if (!firstSelected || !scrollEffectsEnabled.current) return;
|
||||
const target = materialItemRefs.current[firstSelected];
|
||||
if (target) {
|
||||
target.scrollIntoView({ block: "nearest", behavior: "smooth" });
|
||||
@@ -660,14 +681,9 @@ export const useHomeController = () => {
|
||||
}
|
||||
}, [isRestored, bgmList, selectedBgmId, enableBgm, setSelectedBgmId]);
|
||||
|
||||
const videoScrollReady = useRef(false);
|
||||
// 视频列表滚动
|
||||
useEffect(() => {
|
||||
if (!selectedVideoId) return;
|
||||
if (!videoScrollReady.current) {
|
||||
videoScrollReady.current = true;
|
||||
return;
|
||||
}
|
||||
|
||||
if (!selectedVideoId || !scrollEffectsEnabled.current) return;
|
||||
const target = videoItemRefs.current[selectedVideoId];
|
||||
if (target) {
|
||||
target.scrollIntoView({ block: "nearest", behavior: "smooth" });
|
||||
@@ -978,11 +994,14 @@ export const useHomeController = () => {
|
||||
payload.title_font_size = Math.round(titleFontSize);
|
||||
}
|
||||
|
||||
if (videoTitle.trim()) {
|
||||
if (videoTitle.trim() || videoSecondaryTitle.trim()) {
|
||||
payload.title_display_mode = titleDisplayMode;
|
||||
if (titleDisplayMode === "short") {
|
||||
payload.title_duration = DEFAULT_SHORT_TITLE_DURATION;
|
||||
}
|
||||
}
|
||||
|
||||
if (videoTitle.trim()) {
|
||||
payload.title_top_margin = Math.round(titleTopMargin);
|
||||
}
|
||||
|
||||
@@ -1077,6 +1096,8 @@ export const useHomeController = () => {
|
||||
setText,
|
||||
extractModalOpen,
|
||||
setExtractModalOpen,
|
||||
rewriteModalOpen,
|
||||
setRewriteModalOpen,
|
||||
handleGenerateMeta,
|
||||
isGeneratingMeta,
|
||||
handleTranslate,
|
||||
@@ -1123,6 +1144,7 @@ export const useHomeController = () => {
|
||||
getFontFormat,
|
||||
buildTextShadow,
|
||||
materialDimensions,
|
||||
materialPosterUrl,
|
||||
ttsMode,
|
||||
setTtsMode,
|
||||
voices: VOICES[textLang] || VOICES["zh-CN"],
|
||||
|
||||
@@ -142,7 +142,8 @@ export const useHomePersistence = ({
|
||||
const savedTitleFontSize = localStorage.getItem(`vigent_${storageKey}_titleFontSize`);
|
||||
const savedSecondaryTitleFontSize = localStorage.getItem(`vigent_${storageKey}_secondaryTitleFontSize`);
|
||||
const savedBgmId = localStorage.getItem(`vigent_${storageKey}_bgmId`);
|
||||
const savedSelectedVideoId = localStorage.getItem(`vigent_${storageKey}_selectedVideoId`);
|
||||
const savedSelectedVideoId = localStorage.getItem(`vigent_${storageKey}_latestGeneratedVideoId`)
|
||||
|| localStorage.getItem(`vigent_${storageKey}_selectedVideoId`);
|
||||
const savedSelectedAudioId = localStorage.getItem(`vigent_${storageKey}_selectedAudioId`);
|
||||
const savedBgmVolume = localStorage.getItem(`vigent_${storageKey}_bgmVolume`);
|
||||
const savedEnableBgm = localStorage.getItem(`vigent_${storageKey}_enableBgm`);
|
||||
@@ -205,6 +206,8 @@ export const useHomePersistence = ({
|
||||
if (savedBgmVolume) setBgmVolume(parseFloat(savedBgmVolume));
|
||||
if (savedEnableBgm !== null) setEnableBgm(savedEnableBgm === 'true');
|
||||
if (savedSelectedVideoId) setSelectedVideoId(savedSelectedVideoId);
|
||||
// 消费后清除跨页面共享标记,避免反复覆盖
|
||||
localStorage.removeItem(`vigent_${storageKey}_latestGeneratedVideoId`);
|
||||
if (savedSelectedAudioId) setSelectedAudioId(savedSelectedAudioId);
|
||||
|
||||
if (savedTitleTopMargin) {
|
||||
|
||||
94
frontend/src/features/home/model/useVideoFrameCapture.ts
Normal file
94
frontend/src/features/home/model/useVideoFrameCapture.ts
Normal file
@@ -0,0 +1,94 @@
|
||||
import { useEffect, useState } from "react";
|
||||
|
||||
/** 预览窗口最大 280px 宽,截取无需超过此尺寸 */
|
||||
const MAX_CAPTURE_WIDTH = 480;
|
||||
|
||||
/**
|
||||
* 从视频 URL 截取 0.1s 处的帧,返回 JPEG data URL。
|
||||
* 失败时返回 null(降级渐变背景)。
|
||||
*/
|
||||
export function useVideoFrameCapture(videoUrl: string | null): string | null {
|
||||
const [frameUrl, setFrameUrl] = useState<string | null>(null);
|
||||
|
||||
useEffect(() => {
|
||||
if (!videoUrl) {
|
||||
setFrameUrl(null);
|
||||
return;
|
||||
}
|
||||
|
||||
let isActive = true;
|
||||
const video = document.createElement("video");
|
||||
video.crossOrigin = "anonymous";
|
||||
video.muted = true;
|
||||
video.preload = "auto";
|
||||
video.playsInline = true;
|
||||
|
||||
const cleanup = () => {
|
||||
video.removeEventListener("loadedmetadata", onLoaded);
|
||||
video.removeEventListener("canplay", onLoaded);
|
||||
video.removeEventListener("seeked", onSeeked);
|
||||
video.removeEventListener("error", onError);
|
||||
video.src = "";
|
||||
video.load();
|
||||
};
|
||||
|
||||
const onSeeked = () => {
|
||||
if (!isActive) return;
|
||||
try {
|
||||
const vw = video.videoWidth;
|
||||
const vh = video.videoHeight;
|
||||
if (!vw || !vh) {
|
||||
if (isActive) setFrameUrl(null);
|
||||
cleanup();
|
||||
return;
|
||||
}
|
||||
|
||||
const scale = Math.min(1, MAX_CAPTURE_WIDTH / vw);
|
||||
const cw = Math.round(vw * scale);
|
||||
const ch = Math.round(vh * scale);
|
||||
|
||||
const canvas = document.createElement("canvas");
|
||||
canvas.width = cw;
|
||||
canvas.height = ch;
|
||||
const ctx = canvas.getContext("2d");
|
||||
if (!ctx) {
|
||||
if (isActive) setFrameUrl(null);
|
||||
cleanup();
|
||||
return;
|
||||
}
|
||||
ctx.drawImage(video, 0, 0, cw, ch);
|
||||
const dataUrl = canvas.toDataURL("image/jpeg", 0.7);
|
||||
if (isActive) setFrameUrl(dataUrl);
|
||||
} catch {
|
||||
if (isActive) setFrameUrl(null);
|
||||
}
|
||||
cleanup();
|
||||
};
|
||||
|
||||
let seeked = false;
|
||||
const onLoaded = () => {
|
||||
if (!isActive || seeked) return;
|
||||
seeked = true;
|
||||
video.currentTime = 0.1;
|
||||
};
|
||||
|
||||
const onError = () => {
|
||||
if (isActive) setFrameUrl(null);
|
||||
cleanup();
|
||||
};
|
||||
|
||||
// 先绑定监听,再设 src
|
||||
video.addEventListener("loadedmetadata", onLoaded);
|
||||
video.addEventListener("canplay", onLoaded);
|
||||
video.addEventListener("seeked", onSeeked);
|
||||
video.addEventListener("error", onError);
|
||||
video.src = videoUrl;
|
||||
|
||||
return () => {
|
||||
isActive = false;
|
||||
cleanup();
|
||||
};
|
||||
}, [videoUrl]);
|
||||
|
||||
return frameUrl;
|
||||
}
|
||||
@@ -43,7 +43,7 @@ export function BgmPanel({
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<div className="flex items-center justify-between mb-4">
|
||||
<h2 className="text-lg font-semibold text-white flex items-center gap-2">🎵 背景音乐</h2>
|
||||
<h2 className="text-lg font-semibold text-white flex items-center gap-2">五、背景音乐</h2>
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
onClick={onRefresh}
|
||||
|
||||
@@ -213,7 +213,7 @@ export function ClipTrimmer({
|
||||
{/* Custom range track */}
|
||||
<div
|
||||
ref={trackRef}
|
||||
className="relative h-8 cursor-pointer select-none touch-none"
|
||||
className="relative h-10 cursor-pointer select-none touch-none"
|
||||
onPointerMove={handleTrackPointerMove}
|
||||
onPointerUp={handleTrackPointerUp}
|
||||
onPointerLeave={handleTrackPointerUp}
|
||||
@@ -242,7 +242,7 @@ export function ClipTrimmer({
|
||||
{/* Start thumb */}
|
||||
<div
|
||||
onPointerDown={(e) => handleThumbPointerDown("start", e)}
|
||||
className="absolute top-1/2 -translate-y-1/2 -translate-x-1/2 w-4 h-4 rounded-full bg-purple-500 border-2 border-white shadow-lg cursor-grab active:cursor-grabbing hover:scale-110 transition-transform z-10"
|
||||
className="absolute top-1/2 -translate-y-1/2 -translate-x-1/2 w-5 h-5 rounded-full bg-purple-500 border-2 border-white shadow-lg cursor-grab active:cursor-grabbing hover:scale-110 transition-transform z-10"
|
||||
style={{ left: `${startPct}%` }}
|
||||
title={`起点: ${formatSec(sourceStart)}`}
|
||||
/>
|
||||
@@ -250,7 +250,7 @@ export function ClipTrimmer({
|
||||
{/* End thumb */}
|
||||
<div
|
||||
onPointerDown={(e) => handleThumbPointerDown("end", e)}
|
||||
className="absolute top-1/2 -translate-y-1/2 -translate-x-1/2 w-4 h-4 rounded-full bg-pink-500 border-2 border-white shadow-lg cursor-grab active:cursor-grabbing hover:scale-110 transition-transform z-10"
|
||||
className="absolute top-1/2 -translate-y-1/2 -translate-x-1/2 w-5 h-5 rounded-full bg-pink-500 border-2 border-white shadow-lg cursor-grab active:cursor-grabbing hover:scale-110 transition-transform z-10"
|
||||
style={{ left: `${endPct}%` }}
|
||||
title={`终点: ${formatSec(effectiveEnd)}`}
|
||||
/>
|
||||
|
||||
@@ -53,9 +53,11 @@ interface FloatingStylePreviewProps {
|
||||
buildTextShadow: (color: string, size: number) => string;
|
||||
previewBaseWidth: number;
|
||||
previewBaseHeight: number;
|
||||
previewBackgroundUrl?: string | null;
|
||||
}
|
||||
|
||||
const DESKTOP_WIDTH = 280;
|
||||
const MOBILE_WIDTH = 160;
|
||||
|
||||
export function FloatingStylePreview({
|
||||
onClose,
|
||||
@@ -78,11 +80,10 @@ export function FloatingStylePreview({
|
||||
buildTextShadow,
|
||||
previewBaseWidth,
|
||||
previewBaseHeight,
|
||||
previewBackgroundUrl,
|
||||
}: FloatingStylePreviewProps) {
|
||||
const isMobile = typeof window !== "undefined" && window.innerWidth < 640;
|
||||
const windowWidth = isMobile
|
||||
? Math.min(window.innerWidth - 32, 360)
|
||||
: DESKTOP_WIDTH;
|
||||
const windowWidth = isMobile ? MOBILE_WIDTH : DESKTOP_WIDTH;
|
||||
|
||||
useEffect(() => {
|
||||
const handleKeyDown = (e: KeyboardEvent) => {
|
||||
@@ -154,11 +155,12 @@ export function FloatingStylePreview({
|
||||
<div
|
||||
style={{
|
||||
position: "fixed",
|
||||
left: "16px",
|
||||
top: "16px",
|
||||
...(isMobile
|
||||
? { right: "12px", bottom: "12px" }
|
||||
: { left: "16px", top: "16px" }),
|
||||
width: `${windowWidth}px`,
|
||||
zIndex: 150,
|
||||
maxHeight: "calc(100dvh - 32px)",
|
||||
maxHeight: isMobile ? "calc(50dvh)" : "calc(100dvh - 32px)",
|
||||
overflow: "hidden",
|
||||
}}
|
||||
className="rounded-xl border border-white/20 bg-gray-900/95 backdrop-blur-md shadow-2xl"
|
||||
@@ -190,7 +192,11 @@ export function FloatingStylePreview({
|
||||
${subtitleFontUrl ? `@font-face { font-family: '${subtitleFontFamilyName}'; src: url('${subtitleFontUrl}') format('${getFontFormat(activeSubtitleStyle?.font_file)}'); font-weight: 400; font-style: normal; }` : ''}
|
||||
`}</style>
|
||||
)}
|
||||
<div className="absolute inset-0 opacity-20 bg-gradient-to-br from-purple-500/40 via-transparent to-pink-500/30" />
|
||||
{previewBackgroundUrl ? (
|
||||
<img src={previewBackgroundUrl} alt="" className="absolute inset-0 w-full h-full object-cover" />
|
||||
) : (
|
||||
<div className="absolute inset-0 opacity-20 bg-gradient-to-br from-purple-500/40 via-transparent to-pink-500/30" />
|
||||
)}
|
||||
<div
|
||||
className="absolute top-0 left-0"
|
||||
style={{
|
||||
|
||||
@@ -23,6 +23,7 @@ interface GeneratedAudiosPanelProps {
|
||||
speed: number;
|
||||
onSpeedChange: (speed: number) => void;
|
||||
ttsMode: string;
|
||||
embedded?: boolean;
|
||||
}
|
||||
|
||||
export function GeneratedAudiosPanel({
|
||||
@@ -40,6 +41,7 @@ export function GeneratedAudiosPanel({
|
||||
speed,
|
||||
onSpeedChange,
|
||||
ttsMode,
|
||||
embedded = false,
|
||||
}: GeneratedAudiosPanelProps) {
|
||||
const [editingId, setEditingId] = useState<string | null>(null);
|
||||
const [editName, setEditName] = useState("");
|
||||
@@ -123,64 +125,124 @@ export function GeneratedAudiosPanel({
|
||||
] as const;
|
||||
const currentSpeedLabel = speedOptions.find((o) => o.value === speed)?.label ?? "正常";
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm relative z-10">
|
||||
<div className="flex justify-between items-center gap-2 mb-4">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2 whitespace-nowrap">
|
||||
<Mic className="h-4 w-4 text-purple-400" />
|
||||
配音列表
|
||||
</h2>
|
||||
<div className="flex gap-1.5">
|
||||
{/* 语速下拉 (仅声音克隆模式) */}
|
||||
{ttsMode === "voiceclone" && (
|
||||
<div ref={speedRef} className="relative">
|
||||
<button
|
||||
onClick={() => setSpeedOpen((v) => !v)}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1 transition-all"
|
||||
>
|
||||
语速: {currentSpeedLabel}
|
||||
<ChevronDown className={`h-3 w-3 transition-transform ${speedOpen ? "rotate-180" : ""}`} />
|
||||
</button>
|
||||
{speedOpen && (
|
||||
<div className="absolute right-0 top-full mt-1 bg-gray-800 border border-white/20 rounded-lg shadow-xl py-1 z-50 min-w-[80px]">
|
||||
{speedOptions.map((opt) => (
|
||||
<button
|
||||
key={opt.value}
|
||||
onClick={() => { onSpeedChange(opt.value); setSpeedOpen(false); }}
|
||||
className={`w-full text-left px-3 py-1.5 text-xs transition-colors ${
|
||||
speed === opt.value
|
||||
? "bg-purple-600/40 text-purple-200"
|
||||
: "text-gray-300 hover:bg-white/10"
|
||||
}`}
|
||||
>
|
||||
{opt.label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={onGenerateAudio}
|
||||
disabled={isGeneratingAudio || !canGenerate}
|
||||
title={missingRefAudio ? "请先选择参考音频" : !hasText ? "请先输入文案" : ""}
|
||||
className={`px-2 py-1 text-xs rounded transition-all whitespace-nowrap flex items-center gap-1 ${
|
||||
isGeneratingAudio || !canGenerate
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white"
|
||||
}`}
|
||||
>
|
||||
<Mic className="h-3.5 w-3.5" />
|
||||
生成配音
|
||||
</button>
|
||||
<button
|
||||
onClick={onRefresh}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1"
|
||||
>
|
||||
<RefreshCw className="h-3.5 w-3.5" />
|
||||
</button>
|
||||
const content = (
|
||||
<>
|
||||
{embedded ? (
|
||||
<>
|
||||
{/* Row 1: 语速 + 生成配音 (right-aligned) */}
|
||||
<div className="flex justify-end items-center gap-1.5 mb-3">
|
||||
{ttsMode === "voiceclone" && (
|
||||
<div ref={speedRef} className="relative">
|
||||
<button
|
||||
onClick={() => setSpeedOpen((v) => !v)}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1 transition-all"
|
||||
>
|
||||
语速: {currentSpeedLabel}
|
||||
<ChevronDown className={`h-3 w-3 transition-transform ${speedOpen ? "rotate-180" : ""}`} />
|
||||
</button>
|
||||
{speedOpen && (
|
||||
<div className="absolute right-0 top-full mt-1 bg-gray-800 border border-white/20 rounded-lg shadow-xl py-1 z-50 min-w-[80px]">
|
||||
{speedOptions.map((opt) => (
|
||||
<button
|
||||
key={opt.value}
|
||||
onClick={() => { onSpeedChange(opt.value); setSpeedOpen(false); }}
|
||||
className={`w-full text-left px-3 py-1.5 text-xs transition-colors ${
|
||||
speed === opt.value
|
||||
? "bg-purple-600/40 text-purple-200"
|
||||
: "text-gray-300 hover:bg-white/10"
|
||||
}`}
|
||||
>
|
||||
{opt.label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={onGenerateAudio}
|
||||
disabled={isGeneratingAudio || !canGenerate}
|
||||
title={missingRefAudio ? "请先选择参考音频" : !hasText ? "请先输入文案" : ""}
|
||||
className={`px-4 py-2 text-sm font-medium rounded-lg transition-all whitespace-nowrap flex items-center gap-1.5 shadow-sm ${
|
||||
isGeneratingAudio || !canGenerate
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white hover:shadow-md"
|
||||
}`}
|
||||
>
|
||||
<Mic className="h-4 w-4" />
|
||||
生成配音
|
||||
</button>
|
||||
</div>
|
||||
{/* Row 2: 配音列表 + 刷新 */}
|
||||
<div className="flex justify-between items-center mb-3">
|
||||
<h3 className="text-sm font-medium text-gray-400">配音列表</h3>
|
||||
<button
|
||||
onClick={onRefresh}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1"
|
||||
>
|
||||
<RefreshCw className="h-3.5 w-3.5" />
|
||||
刷新
|
||||
</button>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div className="flex justify-between items-center gap-2 mb-4">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2 whitespace-nowrap">
|
||||
<Mic className="h-4 w-4 text-purple-400" />
|
||||
配音列表
|
||||
</h2>
|
||||
<div className="flex gap-1.5">
|
||||
{ttsMode === "voiceclone" && (
|
||||
<div ref={speedRef} className="relative">
|
||||
<button
|
||||
onClick={() => setSpeedOpen((v) => !v)}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1 transition-all"
|
||||
>
|
||||
语速: {currentSpeedLabel}
|
||||
<ChevronDown className={`h-3 w-3 transition-transform ${speedOpen ? "rotate-180" : ""}`} />
|
||||
</button>
|
||||
{speedOpen && (
|
||||
<div className="absolute right-0 top-full mt-1 bg-gray-800 border border-white/20 rounded-lg shadow-xl py-1 z-50 min-w-[80px]">
|
||||
{speedOptions.map((opt) => (
|
||||
<button
|
||||
key={opt.value}
|
||||
onClick={() => { onSpeedChange(opt.value); setSpeedOpen(false); }}
|
||||
className={`w-full text-left px-3 py-1.5 text-xs transition-colors ${
|
||||
speed === opt.value
|
||||
? "bg-purple-600/40 text-purple-200"
|
||||
: "text-gray-300 hover:bg-white/10"
|
||||
}`}
|
||||
>
|
||||
{opt.label}
|
||||
</button>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
<button
|
||||
onClick={onGenerateAudio}
|
||||
disabled={isGeneratingAudio || !canGenerate}
|
||||
title={missingRefAudio ? "请先选择参考音频" : !hasText ? "请先输入文案" : ""}
|
||||
className={`px-4 py-2 text-sm font-medium rounded-lg transition-all whitespace-nowrap flex items-center gap-1.5 shadow-sm ${
|
||||
isGeneratingAudio || !canGenerate
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white hover:shadow-md"
|
||||
}`}
|
||||
>
|
||||
<Mic className="h-4 w-4" />
|
||||
生成配音
|
||||
</button>
|
||||
<button
|
||||
onClick={onRefresh}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap flex items-center gap-1"
|
||||
>
|
||||
<RefreshCw className="h-3.5 w-3.5" />
|
||||
刷新
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* 缺少参考音频提示 */}
|
||||
{missingRefAudio && (
|
||||
@@ -250,7 +312,7 @@ export function GeneratedAudiosPanel({
|
||||
<div className="text-white text-sm truncate">{audio.name}</div>
|
||||
<div className="text-gray-400 text-xs">{audio.duration_sec.toFixed(1)}s</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-1 pl-2 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<div className="flex items-center gap-1 pl-2 opacity-40 group-hover:opacity-100 transition-opacity">
|
||||
<button
|
||||
onClick={(e) => togglePlay(audio, e)}
|
||||
className="p-1 text-gray-500 hover:text-purple-400 transition-colors"
|
||||
@@ -287,7 +349,14 @@ export function GeneratedAudiosPanel({
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
if (embedded) return content;
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm relative z-10">
|
||||
{content}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -16,6 +16,7 @@ interface HistoryListProps {
|
||||
onRefresh: () => void;
|
||||
registerVideoRef: (id: string, element: HTMLDivElement | null) => void;
|
||||
formatDate: (timestamp: number) => string;
|
||||
embedded?: boolean;
|
||||
}
|
||||
|
||||
export function HistoryList({
|
||||
@@ -26,19 +27,22 @@ export function HistoryList({
|
||||
onRefresh,
|
||||
registerVideoRef,
|
||||
formatDate,
|
||||
embedded = false,
|
||||
}: HistoryListProps) {
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<h2 className="text-lg font-semibold text-white flex items-center gap-2">📂 历史作品</h2>
|
||||
<button
|
||||
onClick={onRefresh}
|
||||
className="px-3 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 flex items-center gap-1"
|
||||
>
|
||||
<RefreshCw className="h-3.5 w-3.5" />
|
||||
刷新
|
||||
</button>
|
||||
</div>
|
||||
const content = (
|
||||
<>
|
||||
{!embedded && (
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<h2 className="text-lg font-semibold text-white flex items-center gap-2">历史作品</h2>
|
||||
<button
|
||||
onClick={onRefresh}
|
||||
className="px-3 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 flex items-center gap-1"
|
||||
>
|
||||
<RefreshCw className="h-3.5 w-3.5" />
|
||||
刷新
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
{generatedVideos.length === 0 ? (
|
||||
<div className="text-center py-4 text-gray-500">
|
||||
<p>暂无生成的作品</p>
|
||||
@@ -66,7 +70,7 @@ export function HistoryList({
|
||||
e.stopPropagation();
|
||||
onDeleteVideo(v.id);
|
||||
}}
|
||||
className="p-1 text-gray-500 hover:text-red-400 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
className="p-1 text-gray-500 hover:text-red-400 opacity-40 group-hover:opacity-100 transition-opacity"
|
||||
title="删除视频"
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
@@ -75,6 +79,14 @@ export function HistoryList({
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
if (embedded) return content;
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
{content}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -2,8 +2,10 @@
|
||||
|
||||
import { useEffect, useMemo } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { RefreshCw } from "lucide-react";
|
||||
import VideoPreviewModal from "@/components/VideoPreviewModal";
|
||||
import ScriptExtractionModal from "./ScriptExtractionModal";
|
||||
import RewriteModal from "./RewriteModal";
|
||||
import { useHomeController } from "@/features/home/model/useHomeController";
|
||||
import { resolveMediaUrl } from "@/shared/lib/media";
|
||||
import { BgmPanel } from "@/features/home/ui/BgmPanel";
|
||||
@@ -51,6 +53,8 @@ export function HomePage() {
|
||||
setText,
|
||||
extractModalOpen,
|
||||
setExtractModalOpen,
|
||||
rewriteModalOpen,
|
||||
setRewriteModalOpen,
|
||||
handleGenerateMeta,
|
||||
isGeneratingMeta,
|
||||
handleTranslate,
|
||||
@@ -171,6 +175,7 @@ export function HomePage() {
|
||||
setClipTrimmerOpen,
|
||||
clipTrimmerSegmentId,
|
||||
setClipTrimmerSegmentId,
|
||||
materialPosterUrl,
|
||||
} = useHomeController();
|
||||
|
||||
useEffect(() => {
|
||||
@@ -179,7 +184,15 @@ export function HomePage() {
|
||||
|
||||
useEffect(() => {
|
||||
if (typeof window === "undefined") return;
|
||||
if ("scrollRestoration" in history) {
|
||||
history.scrollRestoration = "manual";
|
||||
}
|
||||
window.scrollTo({ top: 0, left: 0, behavior: "auto" });
|
||||
// 兜底:等所有恢复 effect + 异步数据加载 settle 后再次强制回顶部
|
||||
const timer = setTimeout(() => {
|
||||
window.scrollTo({ top: 0, left: 0, behavior: "auto" });
|
||||
}, 200);
|
||||
return () => clearTimeout(timer);
|
||||
}, []);
|
||||
|
||||
const clipTrimmerSegment = useMemo(
|
||||
@@ -201,11 +214,12 @@ export function HomePage() {
|
||||
<div className="grid grid-cols-1 lg:grid-cols-2 gap-8">
|
||||
{/* 左侧: 输入区域 */}
|
||||
<div className="space-y-6">
|
||||
{/* 1. 文案输入 */}
|
||||
{/* 一、文案提取与编辑 */}
|
||||
<ScriptEditor
|
||||
text={text}
|
||||
onChangeText={setText}
|
||||
onOpenExtractModal={() => setExtractModalOpen(true)}
|
||||
onOpenRewriteModal={() => setRewriteModalOpen(true)}
|
||||
onGenerateMeta={handleGenerateMeta}
|
||||
isGeneratingMeta={isGeneratingMeta}
|
||||
onTranslate={handleTranslate}
|
||||
@@ -218,7 +232,127 @@ export function HomePage() {
|
||||
onDeleteScript={deleteSavedScript}
|
||||
/>
|
||||
|
||||
{/* 2. 标题和字幕设置 */}
|
||||
{/* 二、配音 */}
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white mb-4">
|
||||
二、配音
|
||||
</h2>
|
||||
<h3 className="text-sm font-medium text-gray-400 mb-3">配音方式</h3>
|
||||
<VoiceSelector
|
||||
embedded
|
||||
ttsMode={ttsMode}
|
||||
onSelectTtsMode={setTtsMode}
|
||||
voices={voices}
|
||||
voice={voice}
|
||||
onSelectVoice={setVoice}
|
||||
voiceCloneSlot={(
|
||||
<RefAudioPanel
|
||||
refAudios={refAudios}
|
||||
selectedRefAudio={selectedRefAudio}
|
||||
onSelectRefAudio={handleSelectRefAudio}
|
||||
isUploadingRef={isUploadingRef}
|
||||
uploadRefError={uploadRefError}
|
||||
onClearUploadRefError={() => setUploadRefError(null)}
|
||||
onUploadRefAudio={uploadRefAudio}
|
||||
onFetchRefAudios={fetchRefAudios}
|
||||
playingAudioId={playingAudioId}
|
||||
onTogglePlayPreview={togglePlayPreview}
|
||||
editingAudioId={editingAudioId}
|
||||
editName={editName}
|
||||
onEditNameChange={setEditName}
|
||||
onStartEditing={startEditing}
|
||||
onSaveEditing={saveEditing}
|
||||
onCancelEditing={cancelEditing}
|
||||
onDeleteRefAudio={deleteRefAudio}
|
||||
onRetranscribe={retranscribeRefAudio}
|
||||
retranscribingId={retranscribingId}
|
||||
recordedBlob={recordedBlob}
|
||||
isRecording={isRecording}
|
||||
recordingTime={recordingTime}
|
||||
onStartRecording={startRecording}
|
||||
onStopRecording={stopRecording}
|
||||
onUseRecording={useRecording}
|
||||
formatRecordingTime={formatRecordingTime}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
<div className="border-t border-white/10 my-4" />
|
||||
<GeneratedAudiosPanel
|
||||
embedded
|
||||
generatedAudios={generatedAudios}
|
||||
selectedAudioId={selectedAudioId}
|
||||
isGeneratingAudio={isGeneratingAudio}
|
||||
audioTask={audioTask}
|
||||
onGenerateAudio={handleGenerateAudio}
|
||||
onRefresh={() => fetchGeneratedAudios()}
|
||||
onSelectAudio={selectAudio}
|
||||
onDeleteAudio={deleteAudio}
|
||||
onRenameAudio={renameAudio}
|
||||
hasText={!!text.trim()}
|
||||
missingRefAudio={ttsMode === "voiceclone" && !selectedRefAudio}
|
||||
speed={speed}
|
||||
onSpeedChange={setSpeed}
|
||||
ttsMode={ttsMode}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* 三、素材编辑 */}
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white mb-4">
|
||||
三、素材编辑
|
||||
</h2>
|
||||
<MaterialSelector
|
||||
embedded
|
||||
materials={materials}
|
||||
selectedMaterials={selectedMaterials}
|
||||
isFetching={isFetching}
|
||||
lastMaterialCount={lastMaterialCount}
|
||||
editingMaterialId={editingMaterialId}
|
||||
editMaterialName={editMaterialName}
|
||||
isUploading={isUploading}
|
||||
uploadProgress={uploadProgress}
|
||||
uploadError={uploadError}
|
||||
fetchError={fetchError}
|
||||
apiBase={apiBase}
|
||||
onUploadChange={handleUpload}
|
||||
onRefresh={fetchMaterials}
|
||||
onToggleMaterial={toggleMaterial}
|
||||
onPreviewMaterial={handlePreviewMaterial}
|
||||
onStartEditing={startMaterialEditing}
|
||||
onEditNameChange={setEditMaterialName}
|
||||
onSaveEditing={saveMaterialEditing}
|
||||
onCancelEditing={cancelMaterialEditing}
|
||||
onDeleteMaterial={deleteMaterial}
|
||||
onClearUploadError={() => setUploadError(null)}
|
||||
registerMaterialRef={registerMaterialRef}
|
||||
/>
|
||||
<div className="border-t border-white/10 my-4" />
|
||||
<div className="relative">
|
||||
{(!selectedAudio || selectedMaterials.length === 0) && (
|
||||
<div className="absolute inset-0 bg-black/50 backdrop-blur-sm rounded-xl flex items-center justify-center z-10">
|
||||
<p className="text-gray-400">
|
||||
{!selectedAudio ? "请先生成并选中配音" : "请先选择素材"}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
<TimelineEditor
|
||||
embedded
|
||||
audioDuration={selectedAudio?.duration_sec ?? 0}
|
||||
audioUrl={selectedAudio ? (resolveMediaUrl(selectedAudio.path) || "") : ""}
|
||||
segments={timelineSegments}
|
||||
materials={materials}
|
||||
outputAspectRatio={outputAspectRatio}
|
||||
onOutputAspectRatioChange={setOutputAspectRatio}
|
||||
onReorderSegment={reorderSegments}
|
||||
onClickSegment={(seg) => {
|
||||
setClipTrimmerSegmentId(seg.id);
|
||||
setClipTrimmerOpen(true);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* 四、标题与字幕 */}
|
||||
<TitleSubtitlePanel
|
||||
showStylePreview={showStylePreview}
|
||||
onTogglePreview={() => setShowStylePreview((prev) => !prev)}
|
||||
@@ -266,116 +400,10 @@ export function HomePage() {
|
||||
buildTextShadow={buildTextShadow}
|
||||
previewBaseWidth={outputAspectRatio === "16:9" ? 1920 : 1080}
|
||||
previewBaseHeight={outputAspectRatio === "16:9" ? 1080 : 1920}
|
||||
previewBackgroundUrl={materialPosterUrl}
|
||||
/>
|
||||
|
||||
{/* 3. 配音方式选择 */}
|
||||
<VoiceSelector
|
||||
ttsMode={ttsMode}
|
||||
onSelectTtsMode={setTtsMode}
|
||||
voices={voices}
|
||||
voice={voice}
|
||||
onSelectVoice={setVoice}
|
||||
voiceCloneSlot={(
|
||||
<RefAudioPanel
|
||||
refAudios={refAudios}
|
||||
selectedRefAudio={selectedRefAudio}
|
||||
onSelectRefAudio={handleSelectRefAudio}
|
||||
isUploadingRef={isUploadingRef}
|
||||
uploadRefError={uploadRefError}
|
||||
onClearUploadRefError={() => setUploadRefError(null)}
|
||||
onUploadRefAudio={uploadRefAudio}
|
||||
onFetchRefAudios={fetchRefAudios}
|
||||
playingAudioId={playingAudioId}
|
||||
onTogglePlayPreview={togglePlayPreview}
|
||||
editingAudioId={editingAudioId}
|
||||
editName={editName}
|
||||
onEditNameChange={setEditName}
|
||||
onStartEditing={startEditing}
|
||||
onSaveEditing={saveEditing}
|
||||
onCancelEditing={cancelEditing}
|
||||
onDeleteRefAudio={deleteRefAudio}
|
||||
onRetranscribe={retranscribeRefAudio}
|
||||
retranscribingId={retranscribingId}
|
||||
recordedBlob={recordedBlob}
|
||||
isRecording={isRecording}
|
||||
recordingTime={recordingTime}
|
||||
onStartRecording={startRecording}
|
||||
onStopRecording={stopRecording}
|
||||
onUseRecording={useRecording}
|
||||
formatRecordingTime={formatRecordingTime}
|
||||
/>
|
||||
)}
|
||||
/>
|
||||
|
||||
{/* 4. 配音列表 */}
|
||||
<GeneratedAudiosPanel
|
||||
generatedAudios={generatedAudios}
|
||||
selectedAudioId={selectedAudioId}
|
||||
isGeneratingAudio={isGeneratingAudio}
|
||||
audioTask={audioTask}
|
||||
onGenerateAudio={handleGenerateAudio}
|
||||
onRefresh={() => fetchGeneratedAudios()}
|
||||
onSelectAudio={selectAudio}
|
||||
onDeleteAudio={deleteAudio}
|
||||
onRenameAudio={renameAudio}
|
||||
hasText={!!text.trim()}
|
||||
missingRefAudio={ttsMode === "voiceclone" && !selectedRefAudio}
|
||||
speed={speed}
|
||||
onSpeedChange={setSpeed}
|
||||
ttsMode={ttsMode}
|
||||
/>
|
||||
|
||||
{/* 5. 视频素材 */}
|
||||
<MaterialSelector
|
||||
materials={materials}
|
||||
selectedMaterials={selectedMaterials}
|
||||
isFetching={isFetching}
|
||||
lastMaterialCount={lastMaterialCount}
|
||||
editingMaterialId={editingMaterialId}
|
||||
editMaterialName={editMaterialName}
|
||||
isUploading={isUploading}
|
||||
uploadProgress={uploadProgress}
|
||||
uploadError={uploadError}
|
||||
fetchError={fetchError}
|
||||
apiBase={apiBase}
|
||||
onUploadChange={handleUpload}
|
||||
onRefresh={fetchMaterials}
|
||||
onToggleMaterial={toggleMaterial}
|
||||
onPreviewMaterial={handlePreviewMaterial}
|
||||
onStartEditing={startMaterialEditing}
|
||||
onEditNameChange={setEditMaterialName}
|
||||
onSaveEditing={saveMaterialEditing}
|
||||
onCancelEditing={cancelMaterialEditing}
|
||||
onDeleteMaterial={deleteMaterial}
|
||||
onClearUploadError={() => setUploadError(null)}
|
||||
registerMaterialRef={registerMaterialRef}
|
||||
/>
|
||||
|
||||
{/* 5.5 时间轴编辑器 — 未选配音/素材时模糊遮挡 */}
|
||||
<div className="relative">
|
||||
{(!selectedAudio || selectedMaterials.length === 0) && (
|
||||
<div className="absolute inset-0 bg-black/50 backdrop-blur-sm rounded-2xl flex items-center justify-center z-10">
|
||||
<p className="text-gray-400">
|
||||
{!selectedAudio ? "请先生成并选中配音" : "请先选择素材"}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
<TimelineEditor
|
||||
audioDuration={selectedAudio?.duration_sec ?? 0}
|
||||
audioUrl={selectedAudio ? (resolveMediaUrl(selectedAudio.path) || "") : ""}
|
||||
segments={timelineSegments}
|
||||
materials={materials}
|
||||
outputAspectRatio={outputAspectRatio}
|
||||
onOutputAspectRatioChange={setOutputAspectRatio}
|
||||
onReorderSegment={reorderSegments}
|
||||
onClickSegment={(seg) => {
|
||||
setClipTrimmerSegmentId(seg.id);
|
||||
setClipTrimmerOpen(true);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* 6. 背景音乐 */}
|
||||
{/* 背景音乐 (不编号) */}
|
||||
<BgmPanel
|
||||
bgmList={bgmList}
|
||||
bgmLoading={bgmLoading}
|
||||
@@ -393,7 +421,7 @@ export function HomePage() {
|
||||
registerBgmItemRef={registerBgmItemRef}
|
||||
/>
|
||||
|
||||
{/* 7. 生成按钮 */}
|
||||
{/* 生成按钮 (不编号) */}
|
||||
<GenerateActionBar
|
||||
isGenerating={isGenerating}
|
||||
progress={currentTask?.progress || 0}
|
||||
@@ -403,23 +431,59 @@ export function HomePage() {
|
||||
/>
|
||||
</div>
|
||||
|
||||
{/* 右侧: 预览区域 */}
|
||||
{/* 右侧: 作品区域 */}
|
||||
<div className="space-y-6">
|
||||
<PreviewPanel
|
||||
currentTask={currentTask}
|
||||
isGenerating={isGenerating}
|
||||
generatedVideo={generatedVideo}
|
||||
/>
|
||||
|
||||
<HistoryList
|
||||
generatedVideos={generatedVideos}
|
||||
selectedVideoId={selectedVideoId}
|
||||
onSelectVideo={handleSelectVideo}
|
||||
onDeleteVideo={deleteVideo}
|
||||
onRefresh={() => fetchGeneratedVideos()}
|
||||
registerVideoRef={registerVideoRef}
|
||||
formatDate={formatDate}
|
||||
/>
|
||||
{/* 生成进度(在作品卡片上方) */}
|
||||
{currentTask && isGenerating && (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-purple-500/30 backdrop-blur-sm">
|
||||
<div className="space-y-3">
|
||||
<div className="flex justify-between text-sm text-purple-300 mb-1">
|
||||
<span>正在AI生成中...</span>
|
||||
<span>{currentTask.progress || 0}%</span>
|
||||
</div>
|
||||
<div className="h-3 bg-black/30 rounded-full overflow-hidden">
|
||||
<div
|
||||
className="h-full bg-gradient-to-r from-purple-500 to-pink-500 transition-all duration-300"
|
||||
style={{ width: `${currentTask.progress || 0}%` }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
{/* 六、作品 */}
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white mb-4">
|
||||
六、作品
|
||||
</h2>
|
||||
<div className="flex justify-between items-center mb-3">
|
||||
<h3 className="text-sm font-medium text-gray-400">作品列表</h3>
|
||||
<button
|
||||
onClick={() => fetchGeneratedVideos()}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 flex items-center gap-1"
|
||||
>
|
||||
<RefreshCw className="h-3.5 w-3.5" />
|
||||
刷新
|
||||
</button>
|
||||
</div>
|
||||
<HistoryList
|
||||
embedded
|
||||
generatedVideos={generatedVideos}
|
||||
selectedVideoId={selectedVideoId}
|
||||
onSelectVideo={handleSelectVideo}
|
||||
onDeleteVideo={deleteVideo}
|
||||
onRefresh={() => fetchGeneratedVideos()}
|
||||
registerVideoRef={registerVideoRef}
|
||||
formatDate={formatDate}
|
||||
/>
|
||||
<div className="border-t border-white/10 my-4" />
|
||||
<h3 className="text-sm font-medium text-gray-400 mb-3">作品预览</h3>
|
||||
<PreviewPanel
|
||||
embedded
|
||||
currentTask={null}
|
||||
isGenerating={false}
|
||||
generatedVideo={generatedVideo}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
@@ -435,6 +499,13 @@ export function HomePage() {
|
||||
onApply={(nextText) => setText(nextText)}
|
||||
/>
|
||||
|
||||
<RewriteModal
|
||||
isOpen={rewriteModalOpen}
|
||||
onClose={() => setRewriteModalOpen(false)}
|
||||
originalText={text}
|
||||
onApply={(newText) => setText(newText)}
|
||||
/>
|
||||
|
||||
<ClipTrimmer
|
||||
isOpen={clipTrimmerOpen}
|
||||
segment={clipTrimmerSegment}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { type ChangeEvent, type MouseEvent } from "react";
|
||||
import { type ChangeEvent, type MouseEvent, useMemo } from "react";
|
||||
import { Upload, RefreshCw, Eye, Trash2, X, Pencil, Check } from "lucide-react";
|
||||
import type { Material } from "@/shared/types/material";
|
||||
|
||||
@@ -25,6 +25,7 @@ interface MaterialSelectorProps {
|
||||
onDeleteMaterial: (id: string) => void;
|
||||
onClearUploadError: () => void;
|
||||
registerMaterialRef: (id: string, element: HTMLDivElement | null) => void;
|
||||
embedded?: boolean;
|
||||
}
|
||||
|
||||
export function MaterialSelector({
|
||||
@@ -50,19 +51,27 @@ export function MaterialSelector({
|
||||
onDeleteMaterial,
|
||||
onClearUploadError,
|
||||
registerMaterialRef,
|
||||
embedded = false,
|
||||
}: MaterialSelectorProps) {
|
||||
const selectedSet = new Set(selectedMaterials);
|
||||
const selectedSet = useMemo(() => new Set(selectedMaterials), [selectedMaterials]);
|
||||
const isFull = selectedMaterials.length >= 4;
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
const content = (
|
||||
<>
|
||||
<div className="flex justify-between items-center gap-2 mb-4">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2 whitespace-nowrap">
|
||||
📹 视频素材
|
||||
<span className="ml-1 text-[11px] sm:text-xs text-gray-400/90 font-normal">
|
||||
(可多选,最多4个)
|
||||
</span>
|
||||
</h2>
|
||||
{!embedded ? (
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2 min-w-0">
|
||||
<span className="shrink-0">视频素材</span>
|
||||
<span className="text-[11px] sm:text-xs text-gray-400/90 font-normal truncate">
|
||||
(上传自拍视频,最多可选4个)
|
||||
</span>
|
||||
</h2>
|
||||
) : (
|
||||
<h3 className="text-sm font-medium text-gray-400 min-w-0">
|
||||
<span className="shrink-0">视频素材</span>
|
||||
<span className="ml-1 text-[11px] text-gray-400/90 font-normal hidden sm:inline">(上传自拍视频,最多可选4个)</span>
|
||||
</h3>
|
||||
)}
|
||||
<div className="flex gap-1.5">
|
||||
<input
|
||||
type="file"
|
||||
@@ -94,7 +103,7 @@ export function MaterialSelector({
|
||||
{isUploading && (
|
||||
<div className="mb-4 p-4 bg-purple-500/10 rounded-xl border border-purple-500/30">
|
||||
<div className="flex justify-between text-sm text-purple-300 mb-2">
|
||||
<span>📤 上传中...</span>
|
||||
<span>上传中...</span>
|
||||
<span>{uploadProgress}%</span>
|
||||
</div>
|
||||
<div className="h-2 bg-black/30 rounded-full overflow-hidden">
|
||||
@@ -108,7 +117,7 @@ export function MaterialSelector({
|
||||
|
||||
{uploadError && (
|
||||
<div className="mb-4 p-4 bg-red-500/20 text-red-200 rounded-xl text-sm flex justify-between items-center">
|
||||
<span>❌ {uploadError}</span>
|
||||
<span>{uploadError}</span>
|
||||
<button onClick={onClearUploadError} className="text-red-300 hover:text-white">
|
||||
<X className="h-3.5 w-3.5" />
|
||||
</button>
|
||||
@@ -138,7 +147,7 @@ export function MaterialSelector({
|
||||
<div className="text-5xl mb-4">📁</div>
|
||||
<p>暂无视频素材</p>
|
||||
<p className="text-sm mt-2">
|
||||
点击上方「📤 上传视频」按钮添加视频素材
|
||||
点击上方「上传」按钮添加视频素材
|
||||
</p>
|
||||
</div>
|
||||
) : (
|
||||
@@ -183,7 +192,7 @@ export function MaterialSelector({
|
||||
</button>
|
||||
</div>
|
||||
) : (
|
||||
<button onClick={() => onToggleMaterial(m.id)} className="flex-1 text-left flex items-center gap-2">
|
||||
<button onClick={() => onToggleMaterial(m.id)} disabled={isFull && !isSelected} className="flex-1 text-left flex items-center gap-2">
|
||||
{/* 复选框 */}
|
||||
<span
|
||||
className={`flex-shrink-0 w-4 h-4 rounded border flex items-center justify-center text-[10px] ${isSelected
|
||||
@@ -207,7 +216,7 @@ export function MaterialSelector({
|
||||
onPreviewMaterial(m.path);
|
||||
}
|
||||
}}
|
||||
className="p-1 text-gray-500 hover:text-white opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
className="p-1 text-gray-500 hover:text-white opacity-40 group-hover:opacity-100 transition-opacity"
|
||||
title="预览视频"
|
||||
>
|
||||
<Eye className="h-4 w-4" />
|
||||
@@ -215,7 +224,7 @@ export function MaterialSelector({
|
||||
{editingMaterialId !== m.id && (
|
||||
<button
|
||||
onClick={(e) => onStartEditing(m, e)}
|
||||
className="p-1 text-gray-500 hover:text-white opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
className="p-1 text-gray-500 hover:text-white opacity-40 group-hover:opacity-100 transition-opacity"
|
||||
title="重命名"
|
||||
>
|
||||
<Pencil className="h-4 w-4" />
|
||||
@@ -226,7 +235,7 @@ export function MaterialSelector({
|
||||
e.stopPropagation();
|
||||
onDeleteMaterial(m.id);
|
||||
}}
|
||||
className="p-1 text-gray-500 hover:text-red-400 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
className="p-1 text-gray-500 hover:text-red-400 opacity-40 group-hover:opacity-100 transition-opacity"
|
||||
title="删除素材"
|
||||
>
|
||||
<Trash2 className="h-4 w-4" />
|
||||
@@ -237,6 +246,14 @@ export function MaterialSelector({
|
||||
})}
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
if (embedded) return content;
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
{content}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -12,18 +12,20 @@ interface PreviewPanelProps {
|
||||
currentTask: Task | null;
|
||||
isGenerating: boolean;
|
||||
generatedVideo: string | null;
|
||||
embedded?: boolean;
|
||||
}
|
||||
|
||||
export function PreviewPanel({
|
||||
currentTask,
|
||||
isGenerating,
|
||||
generatedVideo,
|
||||
embedded = false,
|
||||
}: PreviewPanelProps) {
|
||||
return (
|
||||
const content = (
|
||||
<>
|
||||
{currentTask && isGenerating && (
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-lg font-semibold text-white mb-4">⏳ 生成进度</h2>
|
||||
<div className={embedded ? "mb-4" : "bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm"}>
|
||||
{!embedded && <h2 className="text-lg font-semibold text-white mb-4">生成进度</h2>}
|
||||
<div className="space-y-3">
|
||||
<div className="h-3 bg-black/30 rounded-full overflow-hidden">
|
||||
<div
|
||||
@@ -36,8 +38,8 @@ export function PreviewPanel({
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-lg font-semibold text-white mb-4 flex items-center gap-2">🎥 作品预览</h2>
|
||||
<div className={embedded ? "" : "bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm"}>
|
||||
{!embedded && <h2 className="text-lg font-semibold text-white mb-4 flex items-center gap-2">作品预览</h2>}
|
||||
<div className="aspect-video bg-black/50 rounded-xl overflow-hidden flex items-center justify-center">
|
||||
{generatedVideo ? (
|
||||
<video src={generatedVideo} controls preload="metadata" className="w-full h-full object-contain" />
|
||||
@@ -71,4 +73,6 @@ export function PreviewPanel({
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
|
||||
return content;
|
||||
}
|
||||
|
||||
@@ -92,7 +92,7 @@ export function RefAudioPanel({
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
<div className="flex justify-between items-center mb-2">
|
||||
<span className="text-sm text-gray-300">📁 我的参考音频</span>
|
||||
<span className="text-sm text-gray-300">📁 我的参考音频 <span className="text-xs text-gray-500 font-normal">(上传3-10秒语音样本)</span></span>
|
||||
<div className="flex gap-2">
|
||||
<input
|
||||
type="file"
|
||||
@@ -187,7 +187,7 @@ export function RefAudioPanel({
|
||||
<div className="text-white text-xs truncate pr-1 flex-1" title={audio.name}>
|
||||
{audio.name}
|
||||
</div>
|
||||
<div className="flex gap-1 opacity-0 group-hover:opacity-100 transition-opacity">
|
||||
<div className="flex gap-1 opacity-40 group-hover:opacity-100 transition-opacity">
|
||||
<button
|
||||
onClick={(e) => onTogglePlayPreview(audio, e)}
|
||||
className="text-gray-400 hover:text-purple-400 text-xs"
|
||||
@@ -287,9 +287,6 @@ export function RefAudioPanel({
|
||||
)}
|
||||
</div>
|
||||
|
||||
<p className="text-xs text-gray-500 mt-2 border-t border-white/10 pt-3">
|
||||
上传任意语音样本(3-10秒),系统将自动识别内容并克隆声音
|
||||
</p>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
213
frontend/src/features/home/ui/RewriteModal.tsx
Normal file
213
frontend/src/features/home/ui/RewriteModal.tsx
Normal file
@@ -0,0 +1,213 @@
|
||||
import { useState, useEffect, useRef, useCallback } from "react";
|
||||
import { Loader2, Sparkles } from "lucide-react";
|
||||
import api from "@/shared/api/axios";
|
||||
import { ApiResponse, unwrap } from "@/shared/api/types";
|
||||
|
||||
const CUSTOM_PROMPT_KEY = "vigent_rewriteCustomPrompt";
|
||||
|
||||
interface RewriteModalProps {
|
||||
isOpen: boolean;
|
||||
onClose: () => void;
|
||||
originalText: string;
|
||||
onApply: (text: string) => void;
|
||||
}
|
||||
|
||||
export default function RewriteModal({
|
||||
isOpen,
|
||||
onClose,
|
||||
originalText,
|
||||
onApply,
|
||||
}: RewriteModalProps) {
|
||||
const [customPrompt, setCustomPrompt] = useState(
|
||||
() => (typeof window !== "undefined" ? localStorage.getItem(CUSTOM_PROMPT_KEY) || "" : "")
|
||||
);
|
||||
const [rewrittenText, setRewrittenText] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
// Debounced save customPrompt to localStorage
|
||||
const debounceRef = useRef<ReturnType<typeof setTimeout>>(undefined);
|
||||
useEffect(() => {
|
||||
debounceRef.current = setTimeout(() => {
|
||||
localStorage.setItem(CUSTOM_PROMPT_KEY, customPrompt);
|
||||
}, 300);
|
||||
return () => clearTimeout(debounceRef.current);
|
||||
}, [customPrompt]);
|
||||
|
||||
// Reset state when modal opens
|
||||
useEffect(() => {
|
||||
if (isOpen) {
|
||||
setRewrittenText("");
|
||||
setError(null);
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [isOpen]);
|
||||
|
||||
const handleRewrite = useCallback(async () => {
|
||||
if (!originalText.trim()) return;
|
||||
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
|
||||
try {
|
||||
const { data: res } = await api.post<
|
||||
ApiResponse<{ rewritten_text: string }>
|
||||
>("/api/ai/rewrite", {
|
||||
text: originalText,
|
||||
custom_prompt: customPrompt.trim() || null,
|
||||
});
|
||||
const payload = unwrap(res);
|
||||
setRewrittenText(payload.rewritten_text || "");
|
||||
} catch (err: unknown) {
|
||||
console.error("AI rewrite failed:", err);
|
||||
const axiosErr = err as {
|
||||
response?: { data?: { message?: string } };
|
||||
message?: string;
|
||||
};
|
||||
const msg =
|
||||
axiosErr.response?.data?.message || axiosErr.message || "改写失败,请重试";
|
||||
setError(msg);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [originalText, customPrompt]);
|
||||
|
||||
const handleApply = () => {
|
||||
onApply(rewrittenText);
|
||||
onClose();
|
||||
};
|
||||
|
||||
const handleRetry = () => {
|
||||
setRewrittenText("");
|
||||
setError(null);
|
||||
};
|
||||
|
||||
// ESC to close
|
||||
useEffect(() => {
|
||||
if (!isOpen) return;
|
||||
const handleKeyDown = (e: KeyboardEvent) => {
|
||||
if (e.key === "Escape") onClose();
|
||||
};
|
||||
document.addEventListener("keydown", handleKeyDown);
|
||||
return () => document.removeEventListener("keydown", handleKeyDown);
|
||||
}, [isOpen, onClose]);
|
||||
|
||||
if (!isOpen) return null;
|
||||
|
||||
return (
|
||||
<div className="fixed inset-0 z-50 flex items-center justify-center bg-black/80 backdrop-blur-sm p-4 animate-in fade-in duration-200">
|
||||
<div className="bg-[#1a1a1a] border border-white/10 rounded-2xl w-full max-w-2xl max-h-[90vh] overflow-hidden flex flex-col shadow-2xl">
|
||||
{/* Header */}
|
||||
<div className="flex items-center justify-between p-4 border-b border-white/10 bg-white/5">
|
||||
<h3 className="text-lg font-semibold text-white flex items-center gap-2">
|
||||
<Sparkles className="h-5 w-5 text-purple-400" />
|
||||
AI 智能改写
|
||||
</h3>
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="text-gray-400 hover:text-white transition-colors text-2xl leading-none"
|
||||
>
|
||||
×
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* Content */}
|
||||
<div className="flex-1 overflow-y-auto p-6 space-y-5">
|
||||
{/* Custom Prompt */}
|
||||
<div className="space-y-2">
|
||||
<label className="text-sm text-gray-300">
|
||||
自定义提示词 (可选)
|
||||
</label>
|
||||
<textarea
|
||||
value={customPrompt}
|
||||
onChange={(e) => setCustomPrompt(e.target.value)}
|
||||
placeholder="输入改写要求..."
|
||||
rows={3}
|
||||
className="w-full bg-black/20 border border-white/10 rounded-xl px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:border-purple-500 transition-colors resize-none"
|
||||
/>
|
||||
<p className="text-xs text-gray-500">留空则使用默认提示词</p>
|
||||
</div>
|
||||
|
||||
{/* Action button (before result) */}
|
||||
{!rewrittenText && (
|
||||
<button
|
||||
onClick={handleRewrite}
|
||||
disabled={isLoading || !originalText.trim()}
|
||||
className="w-full py-3 px-4 bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-500 hover:to-pink-500 disabled:opacity-50 disabled:cursor-not-allowed text-white rounded-xl transition-all font-medium shadow-lg flex items-center justify-center gap-2"
|
||||
>
|
||||
{isLoading ? (
|
||||
<>
|
||||
<Loader2 className="w-5 h-5 animate-spin" />
|
||||
改写中...
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<Sparkles className="w-5 h-5" />
|
||||
开始改写
|
||||
</>
|
||||
)}
|
||||
</button>
|
||||
)}
|
||||
|
||||
{/* Error */}
|
||||
{error && (
|
||||
<div className="bg-red-500/10 border border-red-500/30 rounded-xl p-4">
|
||||
<p className="text-red-400 text-sm">{error}</p>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Rewritten result */}
|
||||
{rewrittenText && (
|
||||
<>
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<h4 className="font-semibold text-purple-300 flex items-center gap-2">
|
||||
<Sparkles className="h-4 w-4" />
|
||||
AI 改写结果
|
||||
</h4>
|
||||
<button
|
||||
onClick={handleApply}
|
||||
className="text-xs bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-500 hover:to-pink-500 text-white px-3 py-1.5 rounded-lg transition-colors shadow-sm"
|
||||
>
|
||||
使用此结果
|
||||
</button>
|
||||
</div>
|
||||
<div className="bg-purple-900/10 border border-purple-500/20 rounded-xl p-4 max-h-60 overflow-y-auto hide-scrollbar">
|
||||
<p className="text-gray-200 text-sm leading-relaxed whitespace-pre-wrap">
|
||||
{rewrittenText}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<h4 className="font-semibold text-gray-400 flex items-center gap-2">
|
||||
📝 原文对比
|
||||
</h4>
|
||||
<button
|
||||
onClick={onClose}
|
||||
className="text-xs bg-white/10 hover:bg-white/20 text-white px-3 py-1.5 rounded-lg transition-colors"
|
||||
>
|
||||
保留原文
|
||||
</button>
|
||||
</div>
|
||||
<div className="bg-white/5 border border-white/10 rounded-xl p-4 max-h-40 overflow-y-auto hide-scrollbar">
|
||||
<p className="text-gray-400 text-sm leading-relaxed whitespace-pre-wrap">
|
||||
{originalText}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<button
|
||||
onClick={handleRetry}
|
||||
className="w-full py-2.5 px-4 bg-white/10 hover:bg-white/20 text-white rounded-xl transition-colors"
|
||||
>
|
||||
重新改写
|
||||
</button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -18,6 +18,7 @@ interface ScriptEditorProps {
|
||||
text: string;
|
||||
onChangeText: (value: string) => void;
|
||||
onOpenExtractModal: () => void;
|
||||
onOpenRewriteModal: () => void;
|
||||
onGenerateMeta: () => void;
|
||||
isGeneratingMeta: boolean;
|
||||
onTranslate: (targetLang: string) => void;
|
||||
@@ -34,6 +35,7 @@ export function ScriptEditor({
|
||||
text,
|
||||
onChangeText,
|
||||
onOpenExtractModal,
|
||||
onOpenRewriteModal,
|
||||
onGenerateMeta,
|
||||
isGeneratingMeta,
|
||||
onTranslate,
|
||||
@@ -86,7 +88,7 @@ export function ScriptEditor({
|
||||
<div className="relative z-10 bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
<div className="mb-4 space-y-3">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2">
|
||||
✍️ 文案提取与编辑
|
||||
一、文案提取与编辑
|
||||
</h2>
|
||||
<div className="flex gap-2 flex-wrap justify-end items-center">
|
||||
{/* 历史文案 */}
|
||||
@@ -123,7 +125,7 @@ export function ScriptEditor({
|
||||
e.stopPropagation();
|
||||
onDeleteScript(script.id);
|
||||
}}
|
||||
className="opacity-0 group-hover:opacity-100 p-1 text-gray-500 hover:text-red-400 transition-all shrink-0"
|
||||
className="opacity-40 group-hover:opacity-100 p-1 text-gray-500 hover:text-red-400 transition-all shrink-0"
|
||||
>
|
||||
<Trash2 className="h-3 w-3" />
|
||||
</button>
|
||||
@@ -218,18 +220,32 @@ export function ScriptEditor({
|
||||
/>
|
||||
<div className="flex items-center justify-between mt-2 text-sm text-gray-400">
|
||||
<span>{text.length} 字</span>
|
||||
<button
|
||||
onClick={onSaveScript}
|
||||
disabled={!text.trim()}
|
||||
className={`px-2.5 py-1 text-xs rounded transition-all flex items-center gap-1 ${
|
||||
!text.trim()
|
||||
? "bg-gray-700 cursor-not-allowed text-gray-500"
|
||||
: "bg-amber-600/80 hover:bg-amber-600 text-white"
|
||||
}`}
|
||||
>
|
||||
<Save className="h-3 w-3" />
|
||||
保存文案
|
||||
</button>
|
||||
<div className="flex items-center gap-2">
|
||||
<button
|
||||
onClick={onOpenRewriteModal}
|
||||
disabled={!text.trim()}
|
||||
className={`px-2.5 py-1 text-xs rounded transition-all flex items-center gap-1 ${
|
||||
!text.trim()
|
||||
? "bg-gray-700 cursor-not-allowed text-gray-500"
|
||||
: "bg-purple-600/80 hover:bg-purple-600 text-white"
|
||||
}`}
|
||||
>
|
||||
<Sparkles className="h-3 w-3" />
|
||||
AI智能改写
|
||||
</button>
|
||||
<button
|
||||
onClick={onSaveScript}
|
||||
disabled={!text.trim()}
|
||||
className={`px-2.5 py-1 text-xs rounded transition-all flex items-center gap-1 ${
|
||||
!text.trim()
|
||||
? "bg-gray-700 cursor-not-allowed text-gray-500"
|
||||
: "bg-amber-600/80 hover:bg-amber-600 text-white"
|
||||
}`}
|
||||
>
|
||||
<Save className="h-3 w-3" />
|
||||
保存文案
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -18,21 +18,14 @@ export default function ScriptExtractionModal({
|
||||
const {
|
||||
isLoading,
|
||||
script,
|
||||
rewrittenScript,
|
||||
error,
|
||||
doRewrite,
|
||||
step,
|
||||
dragActive,
|
||||
selectedFile,
|
||||
activeTab,
|
||||
inputUrl,
|
||||
customPrompt,
|
||||
showCustomPrompt,
|
||||
setDoRewrite,
|
||||
setActiveTab,
|
||||
setInputUrl,
|
||||
setCustomPrompt,
|
||||
setShowCustomPrompt,
|
||||
handleDrag,
|
||||
handleDrop,
|
||||
handleFileChange,
|
||||
@@ -190,46 +183,6 @@ export default function ScriptExtractionModal({
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Options */}
|
||||
<div className="bg-white/5 rounded-xl border border-white/10 overflow-hidden">
|
||||
<div className="flex items-center justify-between p-4">
|
||||
<label className="flex items-center gap-2 cursor-pointer">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={doRewrite}
|
||||
onChange={(e) => setDoRewrite(e.target.checked)}
|
||||
className="w-4 h-4 rounded bg-white/10 border-white/20 text-purple-500 focus:ring-purple-500"
|
||||
/>
|
||||
<span className="text-sm text-gray-300">
|
||||
AI 智能改写
|
||||
</span>
|
||||
</label>
|
||||
{doRewrite && (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setShowCustomPrompt(!showCustomPrompt)}
|
||||
className="text-xs text-purple-400 hover:text-purple-300 transition-colors flex items-center gap-1"
|
||||
>
|
||||
自定义提示词 {showCustomPrompt ? "▲" : "▼"}
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
{doRewrite && showCustomPrompt && (
|
||||
<div className="px-4 pb-4 space-y-2">
|
||||
<textarea
|
||||
value={customPrompt}
|
||||
onChange={(e) => setCustomPrompt(e.target.value)}
|
||||
placeholder="输入自定义改写提示词..."
|
||||
rows={3}
|
||||
className="w-full bg-black/20 border border-white/10 rounded-lg px-3 py-2 text-sm text-white placeholder-gray-500 focus:outline-none focus:border-purple-500 transition-colors resize-none"
|
||||
/>
|
||||
<p className="text-xs text-gray-500">
|
||||
留空则使用默认提示词
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Error */}
|
||||
{error && (
|
||||
<div className="bg-red-500/10 border border-red-500/30 rounded-xl p-4">
|
||||
@@ -273,9 +226,7 @@ export default function ScriptExtractionModal({
|
||||
<p className="text-sm text-gray-400 text-center max-w-sm px-4">
|
||||
{activeTab === "url" && "正在下载视频..."}
|
||||
<br />
|
||||
{doRewrite
|
||||
? "正在进行语音识别和 AI 智能改写..."
|
||||
: "正在进行语音识别..."}
|
||||
正在进行语音识别...
|
||||
<br />
|
||||
<span className="opacity-75">
|
||||
大文件可能需要几分钟,请不要关闭窗口
|
||||
@@ -286,60 +237,30 @@ export default function ScriptExtractionModal({
|
||||
|
||||
{step === "result" && (
|
||||
<div className="space-y-6">
|
||||
{rewrittenScript && (
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<h4 className="font-semibold text-purple-300 flex items-center gap-2">
|
||||
✨ AI 改写结果{" "}
|
||||
<span className="text-xs font-normal text-purple-400/70">
|
||||
(推荐)
|
||||
</span>
|
||||
</h4>
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<h4 className="font-semibold text-gray-300 flex items-center gap-2">
|
||||
🎙️ 识别结果
|
||||
</h4>
|
||||
<div className="flex items-center gap-2">
|
||||
{onApply && (
|
||||
<button
|
||||
onClick={() => handleApplyAndClose(rewrittenScript)}
|
||||
onClick={() => handleApplyAndClose(script)}
|
||||
className="text-xs bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-500 hover:to-pink-500 text-white px-3 py-1.5 rounded-lg transition-colors flex items-center gap-1 shadow-sm"
|
||||
>
|
||||
📥 填入
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={() => copyToClipboard(rewrittenScript)}
|
||||
className="text-xs bg-purple-600 hover:bg-purple-500 text-white px-3 py-1.5 rounded-lg transition-colors flex items-center gap-1"
|
||||
onClick={() => copyToClipboard(script)}
|
||||
className="text-xs bg-white/10 hover:bg-white/20 text-white px-3 py-1.5 rounded-lg transition-colors"
|
||||
>
|
||||
📋 复制内容
|
||||
复制
|
||||
</button>
|
||||
</div>
|
||||
<div className="bg-purple-900/10 border border-purple-500/20 rounded-xl p-4 max-h-60 overflow-y-auto custom-scrollbar">
|
||||
<p className="text-gray-200 text-sm leading-relaxed whitespace-pre-wrap">
|
||||
{rewrittenScript}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="space-y-2">
|
||||
<div className="flex justify-between items-center">
|
||||
<h4 className="font-semibold text-gray-400 flex items-center gap-2">
|
||||
🎙️ 原始识别结果
|
||||
</h4>
|
||||
{onApply && (
|
||||
<button
|
||||
onClick={() => handleApplyAndClose(script)}
|
||||
className="text-xs bg-white/10 hover:bg-white/20 text-white px-3 py-1.5 rounded-lg transition-colors flex items-center gap-1"
|
||||
>
|
||||
📥 填入
|
||||
</button>
|
||||
)}
|
||||
<button
|
||||
onClick={() => copyToClipboard(script)}
|
||||
className="text-xs bg-white/10 hover:bg-white/20 text-white px-3 py-1.5 rounded-lg transition-colors"
|
||||
>
|
||||
复制
|
||||
</button>
|
||||
</div>
|
||||
<div className="bg-white/5 border border-white/10 rounded-xl p-4 max-h-40 overflow-y-auto custom-scrollbar">
|
||||
<p className="text-gray-400 text-sm leading-relaxed whitespace-pre-wrap">
|
||||
<div className="bg-white/5 border border-white/10 rounded-xl p-4 max-h-60 overflow-y-auto hide-scrollbar">
|
||||
<p className="text-gray-200 text-sm leading-relaxed whitespace-pre-wrap">
|
||||
{script}
|
||||
</p>
|
||||
</div>
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import { useEffect, useRef, useCallback, useState } from "react";
|
||||
import { useEffect, useRef, useCallback, useState, useMemo } from "react";
|
||||
import WaveSurfer from "wavesurfer.js";
|
||||
import { ChevronDown } from "lucide-react";
|
||||
import { ChevronDown, GripVertical } from "lucide-react";
|
||||
import type { TimelineSegment } from "@/features/home/model/useTimelineEditor";
|
||||
import type { Material } from "@/shared/types/material";
|
||||
|
||||
|
||||
interface TimelineEditorProps {
|
||||
audioDuration: number;
|
||||
audioUrl: string;
|
||||
@@ -13,14 +13,15 @@ interface TimelineEditorProps {
|
||||
onOutputAspectRatioChange: (ratio: "9:16" | "16:9") => void;
|
||||
onReorderSegment: (fromIdx: number, toIdx: number) => void;
|
||||
onClickSegment: (segment: TimelineSegment) => void;
|
||||
embedded?: boolean;
|
||||
}
|
||||
|
||||
function formatTime(sec: number): string {
|
||||
const m = Math.floor(sec / 60);
|
||||
const s = sec % 60;
|
||||
return `${String(m).padStart(2, "0")}:${s.toFixed(1).padStart(4, "0")}`;
|
||||
}
|
||||
|
||||
|
||||
function formatTime(sec: number): string {
|
||||
const m = Math.floor(sec / 60);
|
||||
const s = sec % 60;
|
||||
return `${String(m).padStart(2, "0")}:${s.toFixed(1).padStart(4, "0")}`;
|
||||
}
|
||||
|
||||
export function TimelineEditor({
|
||||
audioDuration,
|
||||
audioUrl,
|
||||
@@ -30,12 +31,13 @@ export function TimelineEditor({
|
||||
onOutputAspectRatioChange,
|
||||
onReorderSegment,
|
||||
onClickSegment,
|
||||
embedded = false,
|
||||
}: TimelineEditorProps) {
|
||||
const waveRef = useRef<HTMLDivElement>(null);
|
||||
const wsRef = useRef<WaveSurfer | null>(null);
|
||||
const [waveReady, setWaveReady] = useState(false);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
|
||||
const waveRef = useRef<HTMLDivElement>(null);
|
||||
const wsRef = useRef<WaveSurfer | null>(null);
|
||||
const [waveReady, setWaveReady] = useState(false);
|
||||
const [isPlaying, setIsPlaying] = useState(false);
|
||||
|
||||
// Refs for high-frequency DOM updates (avoid 60fps re-renders)
|
||||
const playheadRef = useRef<HTMLDivElement>(null);
|
||||
const timeRef = useRef<HTMLSpanElement>(null);
|
||||
@@ -44,7 +46,7 @@ export function TimelineEditor({
|
||||
useEffect(() => {
|
||||
audioDurationRef.current = audioDuration;
|
||||
}, [audioDuration]);
|
||||
|
||||
|
||||
// Drag-to-reorder state
|
||||
const [dragFromIdx, setDragFromIdx] = useState<number | null>(null);
|
||||
const [dragOverIdx, setDragOverIdx] = useState<number | null>(null);
|
||||
@@ -68,57 +70,57 @@ export function TimelineEditor({
|
||||
if (ratioOpen) document.addEventListener("mousedown", handler);
|
||||
return () => document.removeEventListener("mousedown", handler);
|
||||
}, [ratioOpen]);
|
||||
|
||||
// Create / recreate wavesurfer when audioUrl changes
|
||||
|
||||
// Create / recreate wavesurfer when audioUrl changes
|
||||
useEffect(() => {
|
||||
if (!waveRef.current || !audioUrl) return;
|
||||
|
||||
const playheadEl = playheadRef.current;
|
||||
const timeEl = timeRef.current;
|
||||
|
||||
// Destroy previous instance
|
||||
if (wsRef.current) {
|
||||
wsRef.current.destroy();
|
||||
wsRef.current = null;
|
||||
}
|
||||
|
||||
const ws = WaveSurfer.create({
|
||||
container: waveRef.current,
|
||||
height: 56,
|
||||
waveColor: "#6d28d9",
|
||||
progressColor: "#a855f7",
|
||||
barWidth: 2,
|
||||
barGap: 1,
|
||||
barRadius: 2,
|
||||
cursorWidth: 1,
|
||||
cursorColor: "#e879f9",
|
||||
interact: true,
|
||||
normalize: true,
|
||||
});
|
||||
|
||||
// Click waveform → seek + auto-play
|
||||
ws.on("interaction", () => ws.play());
|
||||
ws.on("play", () => setIsPlaying(true));
|
||||
ws.on("pause", () => setIsPlaying(false));
|
||||
ws.on("finish", () => {
|
||||
setIsPlaying(false);
|
||||
if (playheadRef.current) playheadRef.current.style.display = "none";
|
||||
});
|
||||
// High-frequency: update playhead + time via refs (no React re-render)
|
||||
ws.on("timeupdate", (time: number) => {
|
||||
const dur = audioDurationRef.current;
|
||||
if (playheadRef.current && dur > 0) {
|
||||
playheadRef.current.style.left = `${(time / dur) * 100}%`;
|
||||
playheadRef.current.style.display = "block";
|
||||
}
|
||||
if (timeRef.current) {
|
||||
timeRef.current.textContent = formatTime(time);
|
||||
}
|
||||
});
|
||||
|
||||
ws.load(audioUrl);
|
||||
wsRef.current = ws;
|
||||
|
||||
|
||||
// Destroy previous instance
|
||||
if (wsRef.current) {
|
||||
wsRef.current.destroy();
|
||||
wsRef.current = null;
|
||||
}
|
||||
|
||||
const ws = WaveSurfer.create({
|
||||
container: waveRef.current,
|
||||
height: 56,
|
||||
waveColor: "#6d28d9",
|
||||
progressColor: "#a855f7",
|
||||
barWidth: 2,
|
||||
barGap: 1,
|
||||
barRadius: 2,
|
||||
cursorWidth: 1,
|
||||
cursorColor: "#e879f9",
|
||||
interact: true,
|
||||
normalize: true,
|
||||
});
|
||||
|
||||
// Click waveform → seek + auto-play
|
||||
ws.on("interaction", () => ws.play());
|
||||
ws.on("play", () => setIsPlaying(true));
|
||||
ws.on("pause", () => setIsPlaying(false));
|
||||
ws.on("finish", () => {
|
||||
setIsPlaying(false);
|
||||
if (playheadRef.current) playheadRef.current.style.display = "none";
|
||||
});
|
||||
// High-frequency: update playhead + time via refs (no React re-render)
|
||||
ws.on("timeupdate", (time: number) => {
|
||||
const dur = audioDurationRef.current;
|
||||
if (playheadRef.current && dur > 0) {
|
||||
playheadRef.current.style.left = `${(time / dur) * 100}%`;
|
||||
playheadRef.current.style.display = "block";
|
||||
}
|
||||
if (timeRef.current) {
|
||||
timeRef.current.textContent = formatTime(time);
|
||||
}
|
||||
});
|
||||
|
||||
ws.load(audioUrl);
|
||||
wsRef.current = ws;
|
||||
|
||||
return () => {
|
||||
ws.destroy();
|
||||
wsRef.current = null;
|
||||
@@ -127,60 +129,64 @@ export function TimelineEditor({
|
||||
if (timeEl) timeEl.textContent = formatTime(0);
|
||||
};
|
||||
}, [audioUrl, waveReady]);
|
||||
|
||||
// Callback ref to detect when waveRef div mounts
|
||||
const waveCallbackRef = useCallback((node: HTMLDivElement | null) => {
|
||||
(waveRef as React.MutableRefObject<HTMLDivElement | null>).current = node;
|
||||
setWaveReady(!!node);
|
||||
}, []);
|
||||
|
||||
const handlePlayPause = useCallback(() => {
|
||||
wsRef.current?.playPause();
|
||||
}, []);
|
||||
|
||||
// Drag-to-reorder handlers
|
||||
const handleDragStart = useCallback((idx: number, e: React.DragEvent) => {
|
||||
setDragFromIdx(idx);
|
||||
e.dataTransfer.effectAllowed = "move";
|
||||
e.dataTransfer.setData("text/plain", String(idx));
|
||||
}, []);
|
||||
|
||||
const handleDragOver = useCallback((idx: number, e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.dataTransfer.dropEffect = "move";
|
||||
setDragOverIdx(idx);
|
||||
}, []);
|
||||
|
||||
const handleDragLeave = useCallback(() => {
|
||||
setDragOverIdx(null);
|
||||
}, []);
|
||||
|
||||
const handleDrop = useCallback((toIdx: number, e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
const fromIdx = parseInt(e.dataTransfer.getData("text/plain"), 10);
|
||||
if (!isNaN(fromIdx) && fromIdx !== toIdx) {
|
||||
onReorderSegment(fromIdx, toIdx);
|
||||
}
|
||||
setDragFromIdx(null);
|
||||
setDragOverIdx(null);
|
||||
}, [onReorderSegment]);
|
||||
|
||||
const handleDragEnd = useCallback(() => {
|
||||
setDragFromIdx(null);
|
||||
setDragOverIdx(null);
|
||||
}, []);
|
||||
|
||||
// Filter visible vs overflow segments
|
||||
const visibleSegments = segments.filter((s) => s.start < audioDuration);
|
||||
const overflowSegments = segments.filter((s) => s.start >= audioDuration);
|
||||
const hasSegments = visibleSegments.length > 0;
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
|
||||
// Callback ref to detect when waveRef div mounts
|
||||
const waveCallbackRef = useCallback((node: HTMLDivElement | null) => {
|
||||
(waveRef as React.MutableRefObject<HTMLDivElement | null>).current = node;
|
||||
setWaveReady(!!node);
|
||||
}, []);
|
||||
|
||||
const handlePlayPause = useCallback(() => {
|
||||
wsRef.current?.playPause();
|
||||
}, []);
|
||||
|
||||
// Drag-to-reorder handlers
|
||||
const handleDragStart = useCallback((idx: number, e: React.DragEvent) => {
|
||||
setDragFromIdx(idx);
|
||||
e.dataTransfer.effectAllowed = "move";
|
||||
e.dataTransfer.setData("text/plain", String(idx));
|
||||
}, []);
|
||||
|
||||
const handleDragOver = useCallback((idx: number, e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
e.dataTransfer.dropEffect = "move";
|
||||
setDragOverIdx(idx);
|
||||
}, []);
|
||||
|
||||
const handleDragLeave = useCallback(() => {
|
||||
setDragOverIdx(null);
|
||||
}, []);
|
||||
|
||||
const handleDrop = useCallback((toIdx: number, e: React.DragEvent) => {
|
||||
e.preventDefault();
|
||||
const fromIdx = parseInt(e.dataTransfer.getData("text/plain"), 10);
|
||||
if (!isNaN(fromIdx) && fromIdx !== toIdx) {
|
||||
onReorderSegment(fromIdx, toIdx);
|
||||
}
|
||||
setDragFromIdx(null);
|
||||
setDragOverIdx(null);
|
||||
}, [onReorderSegment]);
|
||||
|
||||
const handleDragEnd = useCallback(() => {
|
||||
setDragFromIdx(null);
|
||||
setDragOverIdx(null);
|
||||
}, []);
|
||||
|
||||
// Filter visible vs overflow segments
|
||||
const visibleSegments = useMemo(() => segments.filter((s) => s.start < audioDuration), [segments, audioDuration]);
|
||||
const overflowSegments = useMemo(() => segments.filter((s) => s.start >= audioDuration), [segments, audioDuration]);
|
||||
const hasSegments = visibleSegments.length > 0;
|
||||
|
||||
const content = (
|
||||
<>
|
||||
<div className="flex items-center justify-between mb-3">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2">
|
||||
🎞️ 时间轴编辑
|
||||
</h2>
|
||||
{!embedded ? (
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2">
|
||||
时间轴编辑
|
||||
</h2>
|
||||
) : (
|
||||
<h3 className="text-sm font-medium text-gray-400">时间轴编辑</h3>
|
||||
)}
|
||||
<div className="flex items-center gap-2 text-xs text-gray-400">
|
||||
<div ref={ratioRef} className="relative">
|
||||
<button
|
||||
@@ -231,28 +237,28 @@ export function TimelineEditor({
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Waveform — always rendered so ref stays mounted */}
|
||||
<div className="relative mb-1">
|
||||
<div ref={waveCallbackRef} className="rounded-lg overflow-hidden bg-black/20 cursor-pointer" style={{ minHeight: 56 }} />
|
||||
</div>
|
||||
|
||||
{/* Segment blocks or empty placeholder */}
|
||||
{hasSegments ? (
|
||||
<>
|
||||
<div className="relative h-14 flex select-none">
|
||||
{/* Playhead — syncs with audio playback */}
|
||||
<div
|
||||
ref={playheadRef}
|
||||
className="absolute top-0 h-full w-0.5 bg-fuchsia-400 z-10 pointer-events-none"
|
||||
style={{ display: "none", left: "0%" }}
|
||||
/>
|
||||
{visibleSegments.map((seg, i) => {
|
||||
const left = (seg.start / audioDuration) * 100;
|
||||
const width = ((seg.end - seg.start) / audioDuration) * 100;
|
||||
const segDur = seg.end - seg.start;
|
||||
const isDragTarget = dragOverIdx === i && dragFromIdx !== i;
|
||||
|
||||
|
||||
{/* Waveform — always rendered so ref stays mounted */}
|
||||
<div className="relative mb-1">
|
||||
<div ref={waveCallbackRef} className="rounded-lg overflow-hidden bg-black/20 cursor-pointer" style={{ minHeight: 56 }} />
|
||||
</div>
|
||||
|
||||
{/* Segment blocks or empty placeholder */}
|
||||
{hasSegments ? (
|
||||
<>
|
||||
<div className="relative h-14 flex select-none">
|
||||
{/* Playhead — syncs with audio playback */}
|
||||
<div
|
||||
ref={playheadRef}
|
||||
className="absolute top-0 h-full w-0.5 bg-fuchsia-400 z-10 pointer-events-none"
|
||||
style={{ display: "none", left: "0%" }}
|
||||
/>
|
||||
{visibleSegments.map((seg, i) => {
|
||||
const left = (seg.start / audioDuration) * 100;
|
||||
const width = ((seg.end - seg.start) / audioDuration) * 100;
|
||||
const segDur = seg.end - seg.start;
|
||||
const isDragTarget = dragOverIdx === i && dragFromIdx !== i;
|
||||
|
||||
// Compute loop portion for the last visible segment
|
||||
const isLastVisible = i === visibleSegments.length - 1;
|
||||
let loopPercent = 0;
|
||||
@@ -266,84 +272,93 @@ export function TimelineEditor({
|
||||
loopPercent = ((segDur - effDur) / segDur) * 100;
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<div key={seg.id} className="absolute top-0 h-full" style={{ left: `${left}%`, width: `${width}%` }}>
|
||||
<button
|
||||
draggable
|
||||
onDragStart={(e) => handleDragStart(i, e)}
|
||||
onDragOver={(e) => handleDragOver(i, e)}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={(e) => handleDrop(i, e)}
|
||||
onDragEnd={handleDragEnd}
|
||||
onClick={() => onClickSegment(seg)}
|
||||
className={`relative w-full h-full rounded-lg flex flex-col items-center justify-center overflow-hidden cursor-grab active:cursor-grabbing transition-all border ${
|
||||
isDragTarget
|
||||
? "ring-2 ring-purple-400 border-purple-400 scale-[1.02]"
|
||||
: dragFromIdx === i
|
||||
? "opacity-50 border-white/10"
|
||||
: "hover:opacity-90 border-white/10"
|
||||
}`}
|
||||
style={{ backgroundColor: seg.color + "33", borderColor: isDragTarget ? undefined : seg.color + "66" }}
|
||||
title={`拖拽可调换顺序 · 点击设置截取范围\n${seg.materialName}\n${segDur.toFixed(1)}s${loopPercent > 0 ? ` (含循环 ${(segDur * loopPercent / 100).toFixed(1)}s)` : ""}`}
|
||||
>
|
||||
<span className="text-[11px] text-white/90 truncate max-w-full px-1 leading-tight z-[1]">
|
||||
{seg.materialName}
|
||||
</span>
|
||||
<span className="text-[10px] text-white/60 leading-tight z-[1]">
|
||||
{segDur.toFixed(1)}s
|
||||
</span>
|
||||
{seg.sourceStart > 0 && (
|
||||
<span className="text-[9px] text-amber-400/80 leading-tight z-[1]">
|
||||
✂ {seg.sourceStart.toFixed(1)}s
|
||||
</span>
|
||||
)}
|
||||
{/* Loop fill stripe overlay */}
|
||||
{loopPercent > 0 && (
|
||||
<div
|
||||
className="absolute top-0 right-0 h-full pointer-events-none flex items-center justify-center"
|
||||
style={{
|
||||
width: `${loopPercent}%`,
|
||||
background: `repeating-linear-gradient(-45deg, transparent, transparent 3px, rgba(255,255,255,0.07) 3px, rgba(255,255,255,0.07) 6px)`,
|
||||
borderLeft: "1px dashed rgba(255,255,255,0.25)",
|
||||
}}
|
||||
>
|
||||
<span className="text-[9px] text-white/30">循环</span>
|
||||
</div>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* Overflow segments — shown as gray chips */}
|
||||
{overflowSegments.length > 0 && (
|
||||
<div className="flex flex-wrap items-center gap-1.5 mt-1.5">
|
||||
<span className="text-[10px] text-gray-500">未使用:</span>
|
||||
{overflowSegments.map((seg) => (
|
||||
<span
|
||||
key={seg.id}
|
||||
className="text-[10px] text-gray-500 bg-white/5 border border-white/10 rounded px-1.5 py-0.5"
|
||||
>
|
||||
{seg.materialName}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<p className="text-[10px] text-gray-500 mt-1.5">
|
||||
点击波形定位播放 · 拖拽色块调换顺序 · 点击色块设置截取范围
|
||||
</p>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="h-14 bg-white/5 rounded-lg" />
|
||||
<p className="text-[10px] text-gray-500 mt-1.5">
|
||||
选中配音和素材后可编辑时间轴
|
||||
</p>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div key={seg.id} className="absolute top-0 h-full" style={{ left: `${left}%`, width: `${width}%` }}>
|
||||
<button
|
||||
draggable
|
||||
onDragStart={(e) => handleDragStart(i, e)}
|
||||
onDragOver={(e) => handleDragOver(i, e)}
|
||||
onDragLeave={handleDragLeave}
|
||||
onDrop={(e) => handleDrop(i, e)}
|
||||
onDragEnd={handleDragEnd}
|
||||
onClick={() => onClickSegment(seg)}
|
||||
className={`relative w-full h-full rounded-lg flex flex-col items-center justify-center overflow-hidden cursor-grab active:cursor-grabbing transition-all border ${
|
||||
isDragTarget
|
||||
? "ring-2 ring-purple-400 border-purple-400 scale-[1.02]"
|
||||
: dragFromIdx === i
|
||||
? "opacity-50 border-white/10"
|
||||
: "hover:opacity-90 border-white/10"
|
||||
}`}
|
||||
style={{ backgroundColor: seg.color + "33", borderColor: isDragTarget ? undefined : seg.color + "66" }}
|
||||
title={`拖拽可调换顺序 · 点击设置截取范围\n${seg.materialName}\n${segDur.toFixed(1)}s${loopPercent > 0 ? ` (含循环 ${(segDur * loopPercent / 100).toFixed(1)}s)` : ""}`}
|
||||
>
|
||||
<GripVertical className="absolute top-0.5 left-0.5 h-3 w-3 text-white/30 z-[1]" />
|
||||
<span className="text-[11px] text-white/90 truncate max-w-full px-1 leading-tight z-[1]">
|
||||
{seg.materialName}
|
||||
</span>
|
||||
<span className="text-[10px] text-white/60 leading-tight z-[1]">
|
||||
{segDur.toFixed(1)}s
|
||||
</span>
|
||||
{seg.sourceStart > 0 && (
|
||||
<span className="text-[9px] text-amber-400/80 leading-tight z-[1]">
|
||||
✂ {seg.sourceStart.toFixed(1)}s
|
||||
</span>
|
||||
)}
|
||||
{/* Loop fill stripe overlay */}
|
||||
{loopPercent > 0 && (
|
||||
<div
|
||||
className="absolute top-0 right-0 h-full pointer-events-none flex items-center justify-center"
|
||||
style={{
|
||||
width: `${loopPercent}%`,
|
||||
background: `repeating-linear-gradient(-45deg, transparent, transparent 3px, rgba(255,255,255,0.07) 3px, rgba(255,255,255,0.07) 6px)`,
|
||||
borderLeft: "1px dashed rgba(255,255,255,0.25)",
|
||||
}}
|
||||
>
|
||||
<span className="text-[9px] text-white/30">循环</span>
|
||||
</div>
|
||||
)}
|
||||
</button>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
|
||||
{/* Overflow segments — shown as gray chips */}
|
||||
{overflowSegments.length > 0 && (
|
||||
<div className="flex flex-wrap items-center gap-1.5 mt-1.5">
|
||||
<span className="text-[10px] text-gray-500">未使用:</span>
|
||||
{overflowSegments.map((seg) => (
|
||||
<span
|
||||
key={seg.id}
|
||||
className="text-[10px] text-gray-500 bg-white/5 border border-white/10 rounded px-1.5 py-0.5"
|
||||
>
|
||||
{seg.materialName}
|
||||
</span>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<p className="text-[10px] text-gray-500 mt-1.5">
|
||||
点击波形定位播放 · 拖拽色块调换顺序 · 点击色块设置截取范围
|
||||
</p>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<div className="h-14 bg-white/5 rounded-lg" />
|
||||
<p className="text-[10px] text-gray-500 mt-1.5">
|
||||
选中配音和素材后可编辑时间轴
|
||||
</p>
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
if (embedded) return content;
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
{content}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -69,6 +69,7 @@ interface TitleSubtitlePanelProps {
|
||||
buildTextShadow: (color: string, size: number) => string;
|
||||
previewBaseWidth?: number;
|
||||
previewBaseHeight?: number;
|
||||
previewBackgroundUrl?: string | null;
|
||||
}
|
||||
|
||||
export function TitleSubtitlePanel({
|
||||
@@ -109,20 +110,35 @@ export function TitleSubtitlePanel({
|
||||
buildTextShadow,
|
||||
previewBaseWidth = 1080,
|
||||
previewBaseHeight = 1920,
|
||||
previewBackgroundUrl,
|
||||
}: TitleSubtitlePanelProps) {
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
<div className="flex items-center justify-between mb-4 gap-2">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2">
|
||||
🎬 标题与字幕
|
||||
四、标题与字幕
|
||||
</h2>
|
||||
<button
|
||||
onClick={onTogglePreview}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 flex items-center gap-1"
|
||||
>
|
||||
<Eye className="h-3.5 w-3.5" />
|
||||
{showStylePreview ? "收起预览" : "预览样式"}
|
||||
</button>
|
||||
<div className="flex items-center gap-1.5">
|
||||
<div className="relative shrink-0">
|
||||
<select
|
||||
value={titleDisplayMode}
|
||||
onChange={(e) => onTitleDisplayModeChange(e.target.value as "short" | "persistent")}
|
||||
className="appearance-none rounded-lg border border-white/15 bg-black/35 px-2.5 py-1.5 pr-7 text-xs text-gray-200 outline-none transition-colors hover:border-white/25 focus:border-purple-500"
|
||||
aria-label="标题显示方式"
|
||||
>
|
||||
<option value="short">标题短暂显示</option>
|
||||
<option value="persistent">标题常驻显示</option>
|
||||
</select>
|
||||
<ChevronDown className="pointer-events-none absolute right-2 top-1/2 h-3.5 w-3.5 -translate-y-1/2 text-gray-400" />
|
||||
</div>
|
||||
<button
|
||||
onClick={onTogglePreview}
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 flex items-center gap-1"
|
||||
>
|
||||
<Eye className="h-3.5 w-3.5" />
|
||||
{showStylePreview ? "收起预览" : "预览样式"}
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{showStylePreview && (
|
||||
@@ -147,24 +163,14 @@ export function TitleSubtitlePanel({
|
||||
buildTextShadow={buildTextShadow}
|
||||
previewBaseWidth={previewBaseWidth}
|
||||
previewBaseHeight={previewBaseHeight}
|
||||
previewBackgroundUrl={previewBackgroundUrl}
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="mb-4">
|
||||
<div className="mb-2 flex items-center justify-between gap-2">
|
||||
<label className="text-sm text-gray-300">片头标题(限制15个字)</label>
|
||||
<div className="relative shrink-0">
|
||||
<select
|
||||
value={titleDisplayMode}
|
||||
onChange={(e) => onTitleDisplayModeChange(e.target.value as "short" | "persistent")}
|
||||
className="appearance-none rounded-lg border border-white/15 bg-black/35 px-2.5 py-1.5 pr-7 text-xs text-gray-200 outline-none transition-colors hover:border-white/25 focus:border-purple-500"
|
||||
aria-label="标题显示方式"
|
||||
>
|
||||
<option value="short">短暂显示</option>
|
||||
<option value="persistent">常驻显示</option>
|
||||
</select>
|
||||
<ChevronDown className="pointer-events-none absolute right-2 top-1/2 h-3.5 w-3.5 -translate-y-1/2 text-gray-400" />
|
||||
</div>
|
||||
<div className="flex items-center justify-between mb-2">
|
||||
<label className="text-sm text-gray-300">片头标题</label>
|
||||
<span className={`text-xs ${videoTitle.length > 15 ? "text-red-400" : "text-gray-500"}`}>{videoTitle.length}/15</span>
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
@@ -178,7 +184,10 @@ export function TitleSubtitlePanel({
|
||||
</div>
|
||||
|
||||
<div className="mb-4">
|
||||
<label className="text-sm text-gray-300 mb-2 block">片头副标题(限制20个字)</label>
|
||||
<div className="flex items-center justify-between mb-2">
|
||||
<label className="text-sm text-gray-300">片头副标题</label>
|
||||
<span className={`text-xs ${videoSecondaryTitle.length > 20 ? "text-red-400" : "text-gray-500"}`}>{videoSecondaryTitle.length}/20</span>
|
||||
</div>
|
||||
<input
|
||||
type="text"
|
||||
value={videoSecondaryTitle}
|
||||
@@ -191,142 +200,85 @@ export function TitleSubtitlePanel({
|
||||
</div>
|
||||
|
||||
{titleStyles.length > 0 && (
|
||||
<div className="mb-4">
|
||||
<label className="text-sm text-gray-300 mb-2 block">标题样式</label>
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{titleStyles.map((style) => (
|
||||
<button
|
||||
key={style.id}
|
||||
onClick={() => onSelectTitleStyle(style.id)}
|
||||
className={`p-2 rounded-lg border transition-all text-left ${selectedTitleStyleId === style.id
|
||||
? "border-purple-500 bg-purple-500/20"
|
||||
: "border-white/10 bg-white/5 hover:border-white/30"
|
||||
}`}
|
||||
<div className="mb-4 space-y-3">
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-sm text-gray-300 shrink-0 w-20">标题样式</label>
|
||||
<div className="relative w-1/3 min-w-[100px]">
|
||||
<select
|
||||
value={selectedTitleStyleId}
|
||||
onChange={(e) => onSelectTitleStyle(e.target.value)}
|
||||
className="w-full appearance-none rounded-lg border border-white/15 bg-black/35 px-3 py-2 pr-8 text-sm text-gray-200 outline-none transition-colors hover:border-white/25 focus:border-purple-500"
|
||||
>
|
||||
<div className="text-white text-sm truncate">{style.label}</div>
|
||||
<div className="text-xs text-gray-400 truncate">
|
||||
{style.font_family || style.font_file || ""}
|
||||
</div>
|
||||
</button>
|
||||
))}
|
||||
{titleStyles.map((style) => (
|
||||
<option key={style.id} value={style.id}>{style.label}</option>
|
||||
))}
|
||||
</select>
|
||||
<ChevronDown className="pointer-events-none absolute right-2.5 top-1/2 h-3.5 w-3.5 -translate-y-1/2 text-gray-400" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-3">
|
||||
<label className="text-xs text-gray-400 mb-2 block">标题字号: {titleFontSize}px</label>
|
||||
<input
|
||||
type="range"
|
||||
min="60"
|
||||
max="150"
|
||||
step="1"
|
||||
value={titleFontSize}
|
||||
onChange={(e) => onTitleFontSizeChange(parseInt(e.target.value, 10))}
|
||||
className="w-full accent-purple-500"
|
||||
/>
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-xs text-gray-400 shrink-0 w-20">字号 {titleFontSize}</label>
|
||||
<input type="range" min="60" max="150" step="1" value={titleFontSize} onChange={(e) => onTitleFontSizeChange(parseInt(e.target.value, 10))} className="flex-1 accent-purple-500" />
|
||||
</div>
|
||||
<div className="mt-3">
|
||||
<label className="text-xs text-gray-400 mb-2 block">标题位置: {titleTopMargin}px</label>
|
||||
<input
|
||||
type="range"
|
||||
min="0"
|
||||
max="300"
|
||||
step="1"
|
||||
value={titleTopMargin}
|
||||
onChange={(e) => onTitleTopMarginChange(parseInt(e.target.value, 10))}
|
||||
className="w-full accent-purple-500"
|
||||
/>
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-xs text-gray-400 shrink-0 w-20">位置 {titleTopMargin}</label>
|
||||
<input type="range" min="0" max="300" step="1" value={titleTopMargin} onChange={(e) => onTitleTopMarginChange(parseInt(e.target.value, 10))} className="flex-1 accent-purple-500" />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{titleStyles.length > 0 && (
|
||||
<div className="mb-4">
|
||||
<label className="text-sm text-gray-300 mb-2 block">副标题样式</label>
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{titleStyles.map((style) => (
|
||||
<button
|
||||
key={style.id}
|
||||
onClick={() => onSelectSecondaryTitleStyle(style.id)}
|
||||
className={`p-2 rounded-lg border transition-all text-left ${selectedSecondaryTitleStyleId === style.id
|
||||
? "border-purple-500 bg-purple-500/20"
|
||||
: "border-white/10 bg-white/5 hover:border-white/30"
|
||||
}`}
|
||||
<div className="mb-4 space-y-3">
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-sm text-gray-300 shrink-0 w-20">副标题样式</label>
|
||||
<div className="relative w-1/3 min-w-[100px]">
|
||||
<select
|
||||
value={selectedSecondaryTitleStyleId}
|
||||
onChange={(e) => onSelectSecondaryTitleStyle(e.target.value)}
|
||||
className="w-full appearance-none rounded-lg border border-white/15 bg-black/35 px-3 py-2 pr-8 text-sm text-gray-200 outline-none transition-colors hover:border-white/25 focus:border-purple-500"
|
||||
>
|
||||
<div className="text-white text-sm truncate">{style.label}</div>
|
||||
<div className="text-xs text-gray-400 truncate">
|
||||
{style.font_family || style.font_file || ""}
|
||||
</div>
|
||||
</button>
|
||||
))}
|
||||
{titleStyles.map((style) => (
|
||||
<option key={style.id} value={style.id}>{style.label}</option>
|
||||
))}
|
||||
</select>
|
||||
<ChevronDown className="pointer-events-none absolute right-2.5 top-1/2 h-3.5 w-3.5 -translate-y-1/2 text-gray-400" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-3">
|
||||
<label className="text-xs text-gray-400 mb-2 block">副标题字号: {secondaryTitleFontSize}px</label>
|
||||
<input
|
||||
type="range"
|
||||
min="30"
|
||||
max="100"
|
||||
step="1"
|
||||
value={secondaryTitleFontSize}
|
||||
onChange={(e) => onSecondaryTitleFontSizeChange(parseInt(e.target.value, 10))}
|
||||
className="w-full accent-purple-500"
|
||||
/>
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-xs text-gray-400 shrink-0 w-20">字号 {secondaryTitleFontSize}</label>
|
||||
<input type="range" min="30" max="100" step="1" value={secondaryTitleFontSize} onChange={(e) => onSecondaryTitleFontSizeChange(parseInt(e.target.value, 10))} className="flex-1 accent-purple-500" />
|
||||
</div>
|
||||
<div className="mt-3">
|
||||
<label className="text-xs text-gray-400 mb-2 block">副标题间距: {secondaryTitleTopMargin}px</label>
|
||||
<input
|
||||
type="range"
|
||||
min="0"
|
||||
max="100"
|
||||
step="1"
|
||||
value={secondaryTitleTopMargin}
|
||||
onChange={(e) => onSecondaryTitleTopMarginChange(parseInt(e.target.value, 10))}
|
||||
className="w-full accent-purple-500"
|
||||
/>
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-xs text-gray-400 shrink-0 w-20">间距 {secondaryTitleTopMargin}</label>
|
||||
<input type="range" min="0" max="100" step="1" value={secondaryTitleTopMargin} onChange={(e) => onSecondaryTitleTopMarginChange(parseInt(e.target.value, 10))} className="flex-1 accent-purple-500" />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
{subtitleStyles.length > 0 && (
|
||||
<div className="mt-4">
|
||||
<label className="text-sm text-gray-300 mb-2 block">字幕样式</label>
|
||||
<div className="grid grid-cols-2 gap-2">
|
||||
{subtitleStyles.map((style) => (
|
||||
<button
|
||||
key={style.id}
|
||||
onClick={() => onSelectSubtitleStyle(style.id)}
|
||||
className={`p-2 rounded-lg border transition-all text-left ${selectedSubtitleStyleId === style.id
|
||||
? "border-purple-500 bg-purple-500/20"
|
||||
: "border-white/10 bg-white/5 hover:border-white/30"
|
||||
}`}
|
||||
<div className="mt-4 space-y-3">
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-sm text-gray-300 shrink-0 w-20">字幕样式</label>
|
||||
<div className="relative w-1/3 min-w-[100px]">
|
||||
<select
|
||||
value={selectedSubtitleStyleId}
|
||||
onChange={(e) => onSelectSubtitleStyle(e.target.value)}
|
||||
className="w-full appearance-none rounded-lg border border-white/15 bg-black/35 px-3 py-2 pr-8 text-sm text-gray-200 outline-none transition-colors hover:border-white/25 focus:border-purple-500"
|
||||
>
|
||||
<div className="text-white text-sm truncate">{style.label}</div>
|
||||
<div className="text-xs text-gray-400 truncate">
|
||||
{style.font_family || style.font_file || ""}
|
||||
</div>
|
||||
</button>
|
||||
))}
|
||||
{subtitleStyles.map((style) => (
|
||||
<option key={style.id} value={style.id}>{style.label}</option>
|
||||
))}
|
||||
</select>
|
||||
<ChevronDown className="pointer-events-none absolute right-2.5 top-1/2 h-3.5 w-3.5 -translate-y-1/2 text-gray-400" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-3">
|
||||
<label className="text-xs text-gray-400 mb-2 block">字幕字号: {subtitleFontSize}px</label>
|
||||
<input
|
||||
type="range"
|
||||
min="40"
|
||||
max="90"
|
||||
step="1"
|
||||
value={subtitleFontSize}
|
||||
onChange={(e) => onSubtitleFontSizeChange(parseInt(e.target.value, 10))}
|
||||
className="w-full accent-purple-500"
|
||||
/>
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-xs text-gray-400 shrink-0 w-20">字号 {subtitleFontSize}</label>
|
||||
<input type="range" min="40" max="90" step="1" value={subtitleFontSize} onChange={(e) => onSubtitleFontSizeChange(parseInt(e.target.value, 10))} className="flex-1 accent-purple-500" />
|
||||
</div>
|
||||
<div className="mt-3">
|
||||
<label className="text-xs text-gray-400 mb-2 block">字幕位置: {subtitleBottomMargin}px</label>
|
||||
<input
|
||||
type="range"
|
||||
min="0"
|
||||
max="300"
|
||||
step="1"
|
||||
value={subtitleBottomMargin}
|
||||
onChange={(e) => onSubtitleBottomMarginChange(parseInt(e.target.value, 10))}
|
||||
className="w-full accent-purple-500"
|
||||
/>
|
||||
<div className="flex items-center gap-3">
|
||||
<label className="text-xs text-gray-400 shrink-0 w-20">位置 {subtitleBottomMargin}</label>
|
||||
<input type="range" min="0" max="300" step="1" value={subtitleBottomMargin} onChange={(e) => onSubtitleBottomMarginChange(parseInt(e.target.value, 10))} className="flex-1 accent-purple-500" />
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -13,6 +13,7 @@ interface VoiceSelectorProps {
|
||||
voice: string;
|
||||
onSelectVoice: (id: string) => void;
|
||||
voiceCloneSlot: ReactNode;
|
||||
embedded?: boolean;
|
||||
}
|
||||
|
||||
export function VoiceSelector({
|
||||
@@ -22,32 +23,29 @@ export function VoiceSelector({
|
||||
voice,
|
||||
onSelectVoice,
|
||||
voiceCloneSlot,
|
||||
embedded = false,
|
||||
}: VoiceSelectorProps) {
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-lg font-semibold text-white mb-4 flex items-center gap-2">
|
||||
🎙️ 配音方式
|
||||
</h2>
|
||||
|
||||
const content = (
|
||||
<>
|
||||
<div className="flex gap-2 mb-4">
|
||||
<button
|
||||
onClick={() => onSelectTtsMode("edgetts")}
|
||||
className={`flex-1 py-2 px-4 rounded-lg font-medium transition-all flex items-center justify-center gap-2 ${ttsMode === "edgetts"
|
||||
className={`flex-1 py-2 px-2 sm:px-4 rounded-lg text-sm sm:text-base font-medium transition-all flex items-center justify-center gap-1.5 sm:gap-2 ${ttsMode === "edgetts"
|
||||
? "bg-purple-600 text-white"
|
||||
: "bg-white/10 text-gray-300 hover:bg-white/20"
|
||||
}`}
|
||||
>
|
||||
<Volume2 className="h-4 w-4" />
|
||||
<Volume2 className="h-4 w-4 shrink-0" />
|
||||
选择声音
|
||||
</button>
|
||||
<button
|
||||
onClick={() => onSelectTtsMode("voiceclone")}
|
||||
className={`flex-1 py-2 px-4 rounded-lg font-medium transition-all flex items-center justify-center gap-2 ${ttsMode === "voiceclone"
|
||||
className={`flex-1 py-2 px-2 sm:px-4 rounded-lg text-sm sm:text-base font-medium transition-all flex items-center justify-center gap-1.5 sm:gap-2 ${ttsMode === "voiceclone"
|
||||
? "bg-purple-600 text-white"
|
||||
: "bg-white/10 text-gray-300 hover:bg-white/20"
|
||||
}`}
|
||||
>
|
||||
<Mic className="h-4 w-4" />
|
||||
<Mic className="h-4 w-4 shrink-0" />
|
||||
克隆声音
|
||||
</button>
|
||||
</div>
|
||||
@@ -70,6 +68,17 @@ export function VoiceSelector({
|
||||
)}
|
||||
|
||||
{ttsMode === "voiceclone" && voiceCloneSlot}
|
||||
</>
|
||||
);
|
||||
|
||||
if (embedded) return content;
|
||||
|
||||
return (
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-lg font-semibold text-white mb-4 flex items-center gap-2">
|
||||
🎙️ 配音方式
|
||||
</h2>
|
||||
{content}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import { useState, useEffect, useCallback, useRef } from "react";
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import api from "@/shared/api/axios";
|
||||
import { ApiResponse, unwrap } from "@/shared/api/types";
|
||||
import { toast } from "sonner";
|
||||
@@ -7,7 +7,6 @@ export type ExtractionStep = "config" | "processing" | "result";
|
||||
export type InputTab = "file" | "url";
|
||||
|
||||
const VALID_FILE_TYPES = [".mp4", ".mov", ".avi", ".mp3", ".wav", ".m4a"];
|
||||
const CUSTOM_PROMPT_KEY = "vigent_rewriteCustomPrompt";
|
||||
|
||||
interface UseScriptExtractionOptions {
|
||||
isOpen: boolean;
|
||||
@@ -16,32 +15,18 @@ interface UseScriptExtractionOptions {
|
||||
export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [script, setScript] = useState("");
|
||||
const [rewrittenScript, setRewrittenScript] = useState("");
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
const [doRewrite, setDoRewrite] = useState(true);
|
||||
const [step, setStep] = useState<ExtractionStep>("config");
|
||||
const [dragActive, setDragActive] = useState(false);
|
||||
const [selectedFile, setSelectedFile] = useState<File | null>(null);
|
||||
const [activeTab, setActiveTab] = useState<InputTab>("url");
|
||||
const [inputUrl, setInputUrl] = useState("");
|
||||
const [customPrompt, setCustomPrompt] = useState(() => typeof window !== "undefined" ? localStorage.getItem(CUSTOM_PROMPT_KEY) || "" : "");
|
||||
const [showCustomPrompt, setShowCustomPrompt] = useState(false);
|
||||
|
||||
// Debounced save customPrompt to localStorage
|
||||
const debounceRef = useRef<ReturnType<typeof setTimeout>>(undefined);
|
||||
useEffect(() => {
|
||||
debounceRef.current = setTimeout(() => {
|
||||
localStorage.setItem(CUSTOM_PROMPT_KEY, customPrompt);
|
||||
}, 300);
|
||||
return () => clearTimeout(debounceRef.current);
|
||||
}, [customPrompt]);
|
||||
|
||||
// Reset state when modal opens (customPrompt is persistent, not reset)
|
||||
// Reset state when modal opens
|
||||
useEffect(() => {
|
||||
if (isOpen) {
|
||||
setStep("config");
|
||||
setScript("");
|
||||
setRewrittenScript("");
|
||||
setError(null);
|
||||
setIsLoading(false);
|
||||
setSelectedFile(null);
|
||||
@@ -112,13 +97,10 @@ export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
|
||||
} else if (activeTab === "url") {
|
||||
formData.append("url", inputUrl.trim());
|
||||
}
|
||||
formData.append("rewrite", doRewrite ? "true" : "false");
|
||||
if (doRewrite && customPrompt.trim()) {
|
||||
formData.append("custom_prompt", customPrompt.trim());
|
||||
}
|
||||
formData.append("rewrite", "false");
|
||||
|
||||
const { data: res } = await api.post<
|
||||
ApiResponse<{ original_script: string; rewritten_script?: string }>
|
||||
ApiResponse<{ original_script: string }>
|
||||
>("/api/tools/extract-script", formData, {
|
||||
headers: { "Content-Type": "multipart/form-data" },
|
||||
timeout: 180000, // 3 minutes timeout
|
||||
@@ -126,7 +108,6 @@ export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
|
||||
|
||||
const payload = unwrap(res);
|
||||
setScript(payload.original_script);
|
||||
setRewrittenScript(payload.rewritten_script || "");
|
||||
setStep("result");
|
||||
} catch (err: unknown) {
|
||||
console.error(err);
|
||||
@@ -141,7 +122,7 @@ export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, [activeTab, selectedFile, inputUrl, doRewrite, customPrompt]);
|
||||
}, [activeTab, selectedFile, inputUrl]);
|
||||
|
||||
const copyToClipboard = useCallback((text: string) => {
|
||||
if (navigator.clipboard && window.isSecureContext) {
|
||||
@@ -200,22 +181,15 @@ export const useScriptExtraction = ({ isOpen }: UseScriptExtractionOptions) => {
|
||||
// State
|
||||
isLoading,
|
||||
script,
|
||||
rewrittenScript,
|
||||
error,
|
||||
doRewrite,
|
||||
step,
|
||||
dragActive,
|
||||
selectedFile,
|
||||
activeTab,
|
||||
inputUrl,
|
||||
customPrompt,
|
||||
showCustomPrompt,
|
||||
// Setters
|
||||
setDoRewrite,
|
||||
setActiveTab,
|
||||
setInputUrl,
|
||||
setCustomPrompt,
|
||||
setShowCustomPrompt,
|
||||
// Handlers
|
||||
handleDrag,
|
||||
handleDrop,
|
||||
|
||||
@@ -83,6 +83,8 @@ export const usePublishController = () => {
|
||||
setVideos(nextVideos);
|
||||
if (nextVideos.length > 0 && autoSelectLatest) {
|
||||
setSelectedVideo(nextVideos[0].id);
|
||||
// 写入跨页面共享标记,让首页也能感知最新生成的视频
|
||||
localStorage.setItem(`vigent_${getStorageKey()}_latestGeneratedVideoId`, nextVideos[0].id);
|
||||
}
|
||||
updatePrefetch({ videos: nextVideos });
|
||||
} catch (error) {
|
||||
@@ -109,16 +111,23 @@ export const usePublishController = () => {
|
||||
|
||||
// ---- 视频选择恢复(唯一一个 effect,条件极简) ----
|
||||
// 等 auth 完成 + videos 有数据 → 恢复一次,之后再也不跑
|
||||
// 优先检查跨页面共享标记(最新生成的视频),其次恢复上次选择
|
||||
useEffect(() => {
|
||||
if (isAuthLoading || videos.length === 0 || videoRestoredRef.current) return;
|
||||
videoRestoredRef.current = true;
|
||||
|
||||
const key = getStorageKey();
|
||||
const saved = localStorage.getItem(`vigent_${key}_publish_selected_video`);
|
||||
if (saved && videos.some(v => v.id === saved)) {
|
||||
setSelectedVideo(saved);
|
||||
const latestId = localStorage.getItem(`vigent_${key}_latestGeneratedVideoId`);
|
||||
if (latestId && videos.some(v => v.id === latestId)) {
|
||||
setSelectedVideo(latestId);
|
||||
localStorage.removeItem(`vigent_${key}_latestGeneratedVideoId`);
|
||||
} else {
|
||||
setSelectedVideo(videos[0].id);
|
||||
const saved = localStorage.getItem(`vigent_${key}_publish_selected_video`);
|
||||
if (saved && videos.some(v => v.id === saved)) {
|
||||
setSelectedVideo(saved);
|
||||
} else {
|
||||
setSelectedVideo(videos[0].id);
|
||||
}
|
||||
}
|
||||
}, [isAuthLoading, videos, getStorageKey]);
|
||||
|
||||
|
||||
@@ -135,7 +135,7 @@ export function PublishPage() {
|
||||
<div className="space-y-6">
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-lg font-semibold text-white mb-4 flex items-center gap-2">
|
||||
👤 平台账号
|
||||
七、平台账号
|
||||
</h2>
|
||||
|
||||
{isAccountsLoading ? (
|
||||
@@ -157,62 +157,60 @@ export function PublishPage() {
|
||||
))}
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-3">
|
||||
<div className="space-y-2 sm:space-y-3">
|
||||
{accounts.map((account) => (
|
||||
<div
|
||||
key={account.platform}
|
||||
className="flex items-center justify-between p-4 bg-black/30 rounded-xl"
|
||||
className="flex items-center gap-3 px-3 py-2.5 sm:px-4 sm:py-3.5 bg-black/30 rounded-xl"
|
||||
>
|
||||
<div className="flex items-center gap-3">
|
||||
{platformIcons[account.platform] ? (
|
||||
<Image
|
||||
src={platformIcons[account.platform].src}
|
||||
alt={platformIcons[account.platform].alt}
|
||||
width={28}
|
||||
height={28}
|
||||
className="h-7 w-7"
|
||||
/>
|
||||
) : (
|
||||
<span className="text-2xl">🌐</span>
|
||||
)}
|
||||
<div>
|
||||
<div className="text-white font-medium">
|
||||
{account.name}
|
||||
</div>
|
||||
<div
|
||||
className={`text-sm ${account.logged_in
|
||||
? "text-green-400"
|
||||
: "text-gray-500"
|
||||
}`}
|
||||
>
|
||||
{account.logged_in ? "✓ 已登录" : "未登录"}
|
||||
</div>
|
||||
{platformIcons[account.platform] ? (
|
||||
<Image
|
||||
src={platformIcons[account.platform].src}
|
||||
alt={platformIcons[account.platform].alt}
|
||||
width={28}
|
||||
height={28}
|
||||
className="h-6 w-6 sm:h-7 sm:w-7 shrink-0"
|
||||
/>
|
||||
) : (
|
||||
<span className="text-xl sm:text-2xl">🌐</span>
|
||||
)}
|
||||
<div className="min-w-0 flex-1">
|
||||
<div className="text-sm sm:text-base text-white font-medium leading-tight">
|
||||
{account.name}
|
||||
</div>
|
||||
<div
|
||||
className={`text-xs sm:text-sm leading-tight ${account.logged_in
|
||||
? "text-green-400"
|
||||
: "text-gray-500"
|
||||
}`}
|
||||
>
|
||||
{account.logged_in ? "✓ 已登录" : "未登录"}
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-2">
|
||||
<div className="flex items-center gap-1.5 sm:gap-2 shrink-0">
|
||||
{account.logged_in ? (
|
||||
<>
|
||||
<button
|
||||
onClick={() => handleLogin(account.platform)}
|
||||
className="px-3 py-1 bg-white/10 hover:bg-white/20 text-white text-sm rounded-lg transition-colors flex items-center gap-1"
|
||||
className="px-2 py-1 sm:px-3 sm:py-1.5 bg-white/10 hover:bg-white/20 text-white text-xs sm:text-sm rounded-md sm:rounded-lg transition-colors flex items-center gap-1"
|
||||
>
|
||||
<RotateCcw className="h-3.5 w-3.5" />
|
||||
<RotateCcw className="h-3 w-3 sm:h-3.5 sm:w-3.5" />
|
||||
重新登录
|
||||
</button>
|
||||
<button
|
||||
onClick={() => handleLogout(account.platform)}
|
||||
className="px-3 py-1 bg-red-500/80 hover:bg-red-600 text-white text-sm rounded-lg transition-colors flex items-center gap-1"
|
||||
className="px-2 py-1 sm:px-3 sm:py-1.5 bg-red-500/80 hover:bg-red-600 text-white text-xs sm:text-sm rounded-md sm:rounded-lg transition-colors flex items-center gap-1"
|
||||
>
|
||||
<LogOut className="h-3.5 w-3.5" />
|
||||
<LogOut className="h-3 w-3 sm:h-3.5 sm:w-3.5" />
|
||||
注销
|
||||
</button>
|
||||
</>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => handleLogin(account.platform)}
|
||||
className="px-3 py-1 bg-purple-500/80 hover:bg-purple-600 text-white text-sm rounded-lg transition-colors flex items-center gap-1"
|
||||
className="px-2 py-1 sm:px-3 sm:py-1.5 bg-purple-500/80 hover:bg-purple-600 text-white text-xs sm:text-sm rounded-md sm:rounded-lg transition-colors flex items-center gap-1"
|
||||
>
|
||||
<QrCode className="h-3.5 w-3.5" />
|
||||
<QrCode className="h-3 w-3 sm:h-3.5 sm:w-3.5" />
|
||||
登录
|
||||
</button>
|
||||
)}
|
||||
@@ -228,7 +226,7 @@ export function PublishPage() {
|
||||
<div className="space-y-6">
|
||||
{/* 选择视频 */}
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-lg font-semibold text-white mb-4">📹 选择发布作品</h2>
|
||||
<h2 className="text-lg font-semibold text-white mb-4">八、选择发布作品</h2>
|
||||
|
||||
<div className="flex items-center gap-3 mb-4">
|
||||
<Search className="text-gray-400 w-4 h-4" />
|
||||
@@ -303,7 +301,7 @@ export function PublishPage() {
|
||||
|
||||
{/* 填写信息 */}
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-lg font-semibold text-white mb-4">✍️ 发布信息</h2>
|
||||
<h2 className="text-lg font-semibold text-white mb-4">九、发布信息</h2>
|
||||
|
||||
<div className="space-y-4">
|
||||
<div>
|
||||
@@ -337,7 +335,7 @@ export function PublishPage() {
|
||||
|
||||
{/* 选择平台 */}
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<h2 className="text-lg font-semibold text-white mb-4">📱 选择发布平台</h2>
|
||||
<h2 className="text-lg font-semibold text-white mb-4">十、选择发布平台</h2>
|
||||
|
||||
<div className="grid grid-cols-3 gap-3">
|
||||
{accounts
|
||||
|
||||
@@ -11,6 +11,7 @@ interface AuthContextType {
|
||||
user: User | null;
|
||||
isLoading: boolean;
|
||||
isAuthenticated: boolean;
|
||||
setUser: (user: User | null) => void;
|
||||
}
|
||||
|
||||
const AuthContext = createContext<AuthContextType>({
|
||||
@@ -18,6 +19,7 @@ const AuthContext = createContext<AuthContextType>({
|
||||
user: null,
|
||||
isLoading: true,
|
||||
isAuthenticated: false,
|
||||
setUser: () => {},
|
||||
});
|
||||
|
||||
export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
@@ -63,7 +65,8 @@ export function AuthProvider({ children }: { children: ReactNode }) {
|
||||
userId: user?.id || null,
|
||||
user,
|
||||
isLoading,
|
||||
isAuthenticated: !!user
|
||||
isAuthenticated: !!user,
|
||||
setUser,
|
||||
}}>
|
||||
{children}
|
||||
</AuthContext.Provider>
|
||||
|
||||
@@ -65,7 +65,7 @@ def load_model():
|
||||
start = time.time()
|
||||
|
||||
from cosyvoice.cli.cosyvoice import AutoModel
|
||||
_model = AutoModel(model_dir=str(MODEL_DIR))
|
||||
_model = AutoModel(model_dir=str(MODEL_DIR), fp16=True)
|
||||
|
||||
_model_loaded = True
|
||||
print(f"✅ CosyVoice 3.0 model loaded in {time.time() - start:.1f}s")
|
||||
|
||||
159
models/MuseTalk/LICENSE
Normal file
159
models/MuseTalk/LICENSE
Normal file
@@ -0,0 +1,159 @@
|
||||
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 Tencent Music Entertainment Group
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
|
||||
|
||||
Other dependencies and licenses:
|
||||
|
||||
|
||||
Open Source Software Licensed under the MIT License:
|
||||
--------------------------------------------------------------------
|
||||
1. sd-vae-ft-mse
|
||||
Files:https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main
|
||||
License:MIT license
|
||||
For details:https://choosealicense.com/licenses/mit/
|
||||
|
||||
2. whisper
|
||||
Files:https://github.com/openai/whisper
|
||||
License:MIT license
|
||||
Copyright (c) 2022 OpenAI
|
||||
For details:https://github.com/openai/whisper/blob/main/LICENSE
|
||||
|
||||
3. face-parsing.PyTorch
|
||||
Files:https://github.com/zllrunning/face-parsing.PyTorch
|
||||
License:MIT License
|
||||
Copyright (c) 2019 zll
|
||||
For details:https://github.com/zllrunning/face-parsing.PyTorch/blob/master/LICENSE
|
||||
|
||||
|
||||
|
||||
Open Source Software Licensed under the Apache License Version 2.0:
|
||||
--------------------------------------------------------------------
|
||||
1. DWpose
|
||||
Files:https://huggingface.co/yzd-v/DWPose/tree/main
|
||||
License:Apache-2.0
|
||||
For details:https://choosealicense.com/licenses/apache-2.0/
|
||||
|
||||
|
||||
Terms of the Apache License Version 2.0:
|
||||
--------------------------------------------------------------------
|
||||
Apache License
|
||||
|
||||
Version 2.0, January 2004
|
||||
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
||||
|
||||
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
||||
|
||||
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
||||
|
||||
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
||||
|
||||
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
|
||||
|
||||
Open Source Software Licensed under the BSD 3-Clause License:
|
||||
--------------------------------------------------------------------
|
||||
1. face-alignment
|
||||
Files:https://github.com/1adrianb/face-alignment/tree/master
|
||||
License:BSD 3-Clause License
|
||||
Copyright (c) 2017, Adrian Bulat
|
||||
All rights reserved.
|
||||
For details:https://github.com/1adrianb/face-alignment/blob/master/LICENSE
|
||||
|
||||
|
||||
Terms of the BSD 3-Clause License:
|
||||
--------------------------------------------------------------------
|
||||
Redistribution and use in source and binary forms, with or without
|
||||
modification, are permitted provided that the following conditions are met:
|
||||
|
||||
* Redistributions of source code must retain the above copyright notice, this
|
||||
list of conditions and the following disclaimer.
|
||||
|
||||
* Redistributions in binary form must reproduce the above copyright notice,
|
||||
this list of conditions and the following disclaimer in the documentation
|
||||
and/or other materials provided with the distribution.
|
||||
|
||||
* Neither the name of the copyright holder nor the names of its
|
||||
contributors may be used to endorse or promote products derived from
|
||||
this software without specific prior written permission.
|
||||
|
||||
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
|
||||
AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
|
||||
IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
||||
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
||||
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
||||
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
||||
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
||||
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
||||
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
||||
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
||||
|
||||
|
||||
Open Source Software:
|
||||
--------------------------------------------------------------------
|
||||
1.s3FD
|
||||
Files:https://github.com/yxlijun/S3FD.pytorch
|
||||
556
models/MuseTalk/README.md
Normal file
556
models/MuseTalk/README.md
Normal file
@@ -0,0 +1,556 @@
|
||||
# MuseTalk
|
||||
|
||||
> **ViGent2 集成说明**
|
||||
>
|
||||
> 本目录为 MuseTalk v1.5 的部署副本,作为混合唇形同步方案的长视频引擎。
|
||||
>
|
||||
> - **服务**: `scripts/server.py` — FastAPI 常驻推理服务 (端口 8011, GPU0)
|
||||
> - **PM2**: `vigent2-musetalk` (启动脚本 `run_musetalk.sh`)
|
||||
> - **路由**: 音频 >=120s 自动路由到 MuseTalk, <120s 走 LatentSync
|
||||
> - **部署文档**: [`Docs/MUSETALK_DEPLOY.md`](../../Docs/MUSETALK_DEPLOY.md)
|
||||
> - **修改记录**: `scripts/inference.py` 增强 FFmpeg 调用 + CLI 参数; `musetalk/utils/audio_processor.py` 音视频长度不匹配时零填充
|
||||
|
||||
---
|
||||
|
||||
<strong>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</strong>
|
||||
|
||||
Yue Zhang<sup>\*</sup>,
|
||||
Zhizhou Zhong<sup>\*</sup>,
|
||||
Minhao Liu<sup>\*</sup>,
|
||||
Zhaokang Chen,
|
||||
Bin Wu<sup>†</sup>,
|
||||
Yubin Zeng,
|
||||
Chao Zhan,
|
||||
Junxin Huang,
|
||||
Yingjie He,
|
||||
Wenjiang Zhou
|
||||
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)
|
||||
|
||||
Lyra Lab, Tencent Music Entertainment
|
||||
|
||||
**[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[space](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **[Technical report](https://arxiv.org/abs/2410.10122)**
|
||||
|
||||
We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+ on an NVIDIA Tesla V100). MuseTalk can be applied with input videos, e.g., generated by [MuseV](https://github.com/TMElyralab/MuseV), as a complete virtual human solution.
|
||||
|
||||
## 🔥 Updates
|
||||
We're excited to unveil MuseTalk 1.5.
|
||||
This version **(1)** integrates training with perceptual loss, GAN loss, and sync loss, significantly boosting its overall performance. **(2)** We've implemented a two-stage training strategy and a spatio-temporal data sampling approach to strike a balance between visual quality and lip-sync accuracy.
|
||||
Learn more details [here](https://arxiv.org/abs/2410.10122).
|
||||
**The inference codes, training codes and model weights of MuseTalk 1.5 are all available now!** 🚀
|
||||
|
||||
# Overview
|
||||
`MuseTalk` is a real-time high quality audio-driven lip-syncing model trained in the latent space of `ft-mse-vae`, which
|
||||
|
||||
1. modifies an unseen face according to the input audio, with a size of face region of `256 x 256`.
|
||||
1. supports audio in various languages, such as Chinese, English, and Japanese.
|
||||
1. supports real-time inference with 30fps+ on an NVIDIA Tesla V100.
|
||||
1. supports modification of the center point of the face region proposes, which **SIGNIFICANTLY** affects generation results.
|
||||
1. checkpoint available trained on the HDTF and private dataset.
|
||||
|
||||
# News
|
||||
- [04/05/2025] :mega: We are excited to announce that the training code is now open-sourced! You can now train your own MuseTalk model using our provided training scripts and configurations.
|
||||
- [03/28/2025] We are thrilled to announce the release of our 1.5 version. This version is a significant improvement over the 1.0 version, with enhanced clarity, identity consistency, and precise lip-speech synchronization. We update the [technical report](https://arxiv.org/abs/2410.10122) with more details.
|
||||
- [10/18/2024] We release the [technical report](https://arxiv.org/abs/2410.10122v2). Our report details a superior model to the open-source L1 loss version. It includes GAN and perceptual losses for improved clarity, and sync loss for enhanced performance.
|
||||
- [04/17/2024] We release a pipeline that utilizes MuseTalk for real-time inference.
|
||||
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
|
||||
- [04/02/2024] Release MuseTalk project and pretrained models.
|
||||
|
||||
|
||||
## Model
|
||||

|
||||
MuseTalk was trained in latent spaces, where the images were encoded by a freezed VAE. The audio was encoded by a freezed `whisper-tiny` model. The architecture of the generation network was borrowed from the UNet of the `stable-diffusion-v1-4`, where the audio embeddings were fused to the image embeddings by cross-attention.
|
||||
|
||||
Note that although we use a very similar architecture as Stable Diffusion, MuseTalk is distinct in that it is **NOT** a diffusion model. Instead, MuseTalk operates by inpainting in the latent space with a single step.
|
||||
|
||||
## Cases
|
||||
|
||||
<table>
|
||||
<tr>
|
||||
<td width="33%">
|
||||
|
||||
### Input Video
|
||||
---
|
||||
https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/1ce3e850-90ac-4a31-a45f-8dfa4f2960ac
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/fa3b13a1-ae26-4d1d-899e-87435f8d22b3
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/15800692-39d1-4f4c-99f2-aef044dc3251
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/a843f9c9-136d-4ed4-9303-4a7269787a60
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/6eb4e70e-9e19-48e9-85a9-bbfa589c5fcb
|
||||
|
||||
</td>
|
||||
<td width="33%">
|
||||
|
||||
### MuseTalk 1.0
|
||||
---
|
||||
https://github.com/user-attachments/assets/c04f3cd5-9f77-40e9-aafd-61978380d0ef
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/2051a388-1cef-4c1d-b2a2-3c1ceee5dc99
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/b5f56f71-5cdc-4e2e-a519-454242000d32
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/a5843835-04ab-4c31-989f-0995cfc22f34
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/3dc7f1d7-8747-4733-bbdd-97874af0c028
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/3c78064e-faad-4637-83ae-28452a22b09a
|
||||
|
||||
</td>
|
||||
<td width="33%">
|
||||
|
||||
### MuseTalk 1.5
|
||||
---
|
||||
https://github.com/user-attachments/assets/999a6f5b-61dd-48e1-b902-bb3f9cbc7247
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/d26a5c9a-003c-489d-a043-c9a331456e75
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/471290d7-b157-4cf6-8a6d-7e899afa302c
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/1ee77c4c-8c70-4add-b6db-583a12faa7dc
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/370510ea-624c-43b7-bbb0-ab5333e0fcc4
|
||||
|
||||
---
|
||||
https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
|
||||
|
||||
</td>
|
||||
</tr>
|
||||
</table>
|
||||
|
||||
|
||||
# TODO:
|
||||
- [x] trained models and inference codes.
|
||||
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
|
||||
- [x] codes for real-time inference.
|
||||
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
|
||||
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
|
||||
- [x] realtime inference code for 1.5 version.
|
||||
- [x] training and data preprocessing codes.
|
||||
- [ ] **always** welcome to submit issues and PRs to improve this repository! 😊
|
||||
|
||||
|
||||
# Getting Started
|
||||
We provide a detailed tutorial about the installation and the basic usage of MuseTalk for new users:
|
||||
|
||||
## Third party integration
|
||||
Thanks for the third-party integration, which makes installation and use more convenient for everyone.
|
||||
We also hope you note that we have not verified, maintained, or updated third-party. Please refer to this project for specific results.
|
||||
|
||||
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseTalk)
|
||||
|
||||
## Installation
|
||||
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
|
||||
|
||||
### Build environment
|
||||
We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
|
||||
|
||||
```shell
|
||||
conda create -n MuseTalk python==3.10
|
||||
conda activate MuseTalk
|
||||
```
|
||||
|
||||
### Install PyTorch 2.0.1
|
||||
Choose one of the following installation methods:
|
||||
|
||||
```shell
|
||||
# Option 1: Using pip
|
||||
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
# Option 2: Using conda
|
||||
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
|
||||
```
|
||||
|
||||
### Install Dependencies
|
||||
Install the remaining required packages:
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
### Install MMLab Packages
|
||||
Install the MMLab ecosystem packages:
|
||||
|
||||
```bash
|
||||
pip install --no-cache-dir -U openmim
|
||||
mim install mmengine
|
||||
mim install "mmcv==2.0.1"
|
||||
mim install "mmdet==3.1.0"
|
||||
mim install "mmpose==1.1.0"
|
||||
```
|
||||
|
||||
### Setup FFmpeg
|
||||
1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
|
||||
|
||||
2. Configure FFmpeg based on your operating system:
|
||||
|
||||
For Linux:
|
||||
```bash
|
||||
export FFMPEG_PATH=/path/to/ffmpeg
|
||||
# Example:
|
||||
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
|
||||
```
|
||||
|
||||
For Windows:
|
||||
Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
|
||||
|
||||
### Download weights
|
||||
You can download weights in two ways:
|
||||
|
||||
#### Option 1: Using Download Scripts
|
||||
We provide two scripts for automatic downloading:
|
||||
|
||||
For Linux:
|
||||
```bash
|
||||
sh ./download_weights.sh
|
||||
```
|
||||
|
||||
For Windows:
|
||||
```batch
|
||||
# Run the script
|
||||
download_weights.bat
|
||||
```
|
||||
|
||||
#### Option 2: Manual Download
|
||||
You can also download the weights manually from the following links:
|
||||
|
||||
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
|
||||
2. Download the weights of other components:
|
||||
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
|
||||
- [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
|
||||
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
|
||||
- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
|
||||
- [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
|
||||
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
|
||||
|
||||
Finally, these weights should be organized in `models` as follows:
|
||||
```
|
||||
./models/
|
||||
├── musetalk
|
||||
│ └── musetalk.json
|
||||
│ └── pytorch_model.bin
|
||||
├── musetalkV15
|
||||
│ └── musetalk.json
|
||||
│ └── unet.pth
|
||||
├── syncnet
|
||||
│ └── latentsync_syncnet.pt
|
||||
├── dwpose
|
||||
│ └── dw-ll_ucoco_384.pth
|
||||
├── face-parse-bisent
|
||||
│ ├── 79999_iter.pth
|
||||
│ └── resnet18-5c106cde.pth
|
||||
├── sd-vae
|
||||
│ ├── config.json
|
||||
│ └── diffusion_pytorch_model.bin
|
||||
└── whisper
|
||||
├── config.json
|
||||
├── pytorch_model.bin
|
||||
└── preprocessor_config.json
|
||||
|
||||
```
|
||||
## Quickstart
|
||||
|
||||
### Inference
|
||||
We provide inference scripts for both versions of MuseTalk:
|
||||
|
||||
#### Prerequisites
|
||||
Before running inference, please ensure ffmpeg is installed and accessible:
|
||||
```bash
|
||||
# Check ffmpeg installation
|
||||
ffmpeg -version
|
||||
```
|
||||
If ffmpeg is not found, please install it first:
|
||||
- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
|
||||
- Linux: `sudo apt-get install ffmpeg`
|
||||
|
||||
#### Normal Inference
|
||||
##### Linux Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
sh inference.sh v1.5 normal
|
||||
|
||||
# MuseTalk 1.0
|
||||
sh inference.sh v1.0 normal
|
||||
```
|
||||
|
||||
##### Windows Environment
|
||||
|
||||
Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
|
||||
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
|
||||
# For MuseTalk 1.0, change:
|
||||
# - models\musetalkV15 -> models\musetalk
|
||||
# - unet.pth -> pytorch_model.bin
|
||||
# - --version v15 -> --version v1
|
||||
```
|
||||
|
||||
#### Real-time Inference
|
||||
##### Linux Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
sh inference.sh v1.5 realtime
|
||||
|
||||
# MuseTalk 1.0
|
||||
sh inference.sh v1.0 realtime
|
||||
```
|
||||
|
||||
##### Windows Environment
|
||||
```bash
|
||||
# MuseTalk 1.5 (Recommended)
|
||||
python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
|
||||
# For MuseTalk 1.0, change:
|
||||
# - models\musetalkV15 -> models\musetalk
|
||||
# - unet.pth -> pytorch_model.bin
|
||||
# - --version v15 -> --version v1
|
||||
```
|
||||
|
||||
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
|
||||
- `video_path`: Path to the input video, image file, or directory of images
|
||||
- `audio_path`: Path to the input audio file
|
||||
|
||||
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
|
||||
|
||||
Important notes for real-time inference:
|
||||
1. Set `preparation` to `True` when processing a new avatar
|
||||
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
|
||||
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
|
||||
4. Set `preparation` to `False` for generating more videos with the same avatar
|
||||
|
||||
For faster generation without saving images, you can use:
|
||||
```bash
|
||||
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
|
||||
```
|
||||
|
||||
## Gradio Demo
|
||||
We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
|
||||

|
||||
For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes. 
|
||||
|
||||
Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
|
||||
|
||||
```bash
|
||||
# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
|
||||
python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
|
||||
```
|
||||
|
||||
## Training
|
||||
|
||||
### Data Preparation
|
||||
To train MuseTalk, you need to prepare your dataset following these steps:
|
||||
|
||||
1. **Place your source videos**
|
||||
|
||||
For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
|
||||
|
||||
2. **Run the preprocessing script**
|
||||
```bash
|
||||
python -m scripts.preprocess --config ./configs/training/preprocess.yaml
|
||||
```
|
||||
This script will:
|
||||
- Extract frames from videos
|
||||
- Detect and align faces
|
||||
- Generate audio features
|
||||
- Create the necessary data structure for training
|
||||
|
||||
### Training Process
|
||||
After data preprocessing, you can start the training process:
|
||||
|
||||
1. **First Stage**
|
||||
```bash
|
||||
sh train.sh stage1
|
||||
```
|
||||
|
||||
2. **Second Stage**
|
||||
```bash
|
||||
sh train.sh stage2
|
||||
```
|
||||
|
||||
### Configuration Adjustment
|
||||
Before starting the training, you should adjust the configuration files according to your hardware and requirements:
|
||||
|
||||
1. **GPU Configuration** (`configs/training/gpu.yaml`):
|
||||
- `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
|
||||
- `num_processes`: Set this to match the number of GPUs you're using
|
||||
|
||||
2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
|
||||
- `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
|
||||
- `data.n_sample_frames`: Number of sampled frames per video (default: 1)
|
||||
|
||||
3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
|
||||
- `random_init_unet`: Must be set to `False` to use the model from stage 1
|
||||
- `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
|
||||
- `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
|
||||
- `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8)
|
||||
|
||||
|
||||
### GPU Memory Requirements
|
||||
Based on our testing on a machine with 8 NVIDIA H20 GPUs:
|
||||
|
||||
#### Stage 1 Memory Usage
|
||||
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|
||||
|:----------:|:----------------------:|:--------------:|:--------------:|
|
||||
| 8 | 1 | ~32GB | |
|
||||
| 16 | 1 | ~45GB | |
|
||||
| 32 | 1 | ~74GB | ✓ |
|
||||
|
||||
#### Stage 2 Memory Usage
|
||||
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|
||||
|:----------:|:----------------------:|:--------------:|:--------------:|
|
||||
| 1 | 8 | ~54GB | |
|
||||
| 2 | 2 | ~80GB | |
|
||||
| 2 | 8 | ~85GB | ✓ |
|
||||
|
||||
<details close>
|
||||
## TestCases For 1.0
|
||||
<table class="center">
|
||||
<tr style="font-weight: bolder;text-align:center;">
|
||||
<td width="33%">Image</td>
|
||||
<td width="33%">MuseV</td>
|
||||
<td width="33%">+MuseTalk</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/musk/musk.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/4a4bb2d1-9d14-4ca9-85c8-7f19c39f712e controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/b2a879c2-e23a-4d39-911d-51f0343218e4 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/yongen/yongen.jpeg width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/57ef9dee-a9fd-4dc8-839b-3fbbbf0ff3f4 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/94d8dcba-1bcd-4b54-9d1d-8b6fc53228f0 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sit/sit.jpeg width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/5fbab81b-d3f2-4c75-abb5-14c76e51769e controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/f8100f4a-3df8-4151-8de2-291b09269f66 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/man/man.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/a6e7d431-5643-4745-9868-8b423a454153 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/6ccf7bc7-cb48-42de-85bd-076d5ee8a623 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/monalisa/monalisa.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/1568f604-a34f-4526-a13a-7d282aa2e773 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/a40784fc-a885-4c1f-9b7e-8f87b7caf4e0 controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sun1/sun.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/172f4ff1-d432-45bd-a5a7-a07dec33a26b controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td>
|
||||
<img src=assets/demo/sun2/sun.png width="95%">
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107 controls preload></video>
|
||||
</td>
|
||||
<td >
|
||||
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/85a6873d-a028-4cce-af2b-6c59a1f2971d controls preload></video>
|
||||
</td>
|
||||
</tr>
|
||||
</table >
|
||||
|
||||
#### Use of bbox_shift to have adjustable results(For 1.0)
|
||||
:mag_right: We have found that upper-bound of the mask has an important impact on mouth openness. Thus, to control the mask region, we suggest using the `bbox_shift` parameter. Positive values (moving towards the lower half) increase mouth openness, while negative values (moving towards the upper half) decrease mouth openness.
|
||||
|
||||
You can start by running with the default configuration to obtain the adjustable value range, and then re-run the script within this range.
|
||||
|
||||
For example, in the case of `Xinying Sun`, after running the default configuration, it shows that the adjustable value rage is [-9, 9]. Then, to decrease the mouth openness, we set the value to be `-7`.
|
||||
```
|
||||
python -m scripts.inference --inference_config configs/inference/test.yaml --bbox_shift -7
|
||||
```
|
||||
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
|
||||
|
||||
|
||||
#### Combining MuseV and MuseTalk
|
||||
|
||||
As a complete solution to virtual human generation, you are suggested to first apply [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video or pose-to-video) by referring [this](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is suggested to increase frame rate. Then, you can use `MuseTalk` to generate a lip-sync video by referring [this](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
|
||||
|
||||
# Acknowledgement
|
||||
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
|
||||
1. MuseTalk has referred much to [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
|
||||
1. MuseTalk has been built on [HDTF](https://github.com/MRzzm/HDTF) datasets.
|
||||
|
||||
Thanks for open-sourcing!
|
||||
|
||||
# Limitations
|
||||
- Resolution: Though MuseTalk uses a face region size of 256 x 256, which make it better than other open-source methods, it has not yet reached the theoretical resolution bound. We will continue to deal with this problem.
|
||||
If you need higher resolution, you could apply super resolution models such as [GFPGAN](https://github.com/TencentARC/GFPGAN) in combination with MuseTalk.
|
||||
|
||||
- Identity preservation: Some details of the original face are not well preserved, such as mustache, lip shape and color.
|
||||
|
||||
- Jitter: There exists some jitter as the current pipeline adopts single-frame generation.
|
||||
|
||||
# Citation
|
||||
```bib
|
||||
@article{musetalk,
|
||||
title={MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling},
|
||||
author={Zhang, Yue and Zhong, Zhizhou and Liu, Minhao and Chen, Zhaokang and Wu, Bin and Zeng, Yubin and Zhan, Chao and He, Yingjie and Huang, Junxin and Zhou, Wenjiang},
|
||||
journal={arxiv},
|
||||
year={2025}
|
||||
}
|
||||
```
|
||||
# Disclaimer/License
|
||||
1. `code`: The code of MuseTalk is released under the MIT License. There is no limitation for both academic and commercial usage.
|
||||
1. `model`: The trained model are available for any purpose, even commercially.
|
||||
1. `other opensource model`: Other open-source models used must comply with their license, such as `whisper`, `ft-mse-vae`, `dwpose`, `S3FD`, etc..
|
||||
1. The testdata are collected from internet, which are available for non-commercial research purposes only.
|
||||
1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.
|
||||
570
models/MuseTalk/app.py
Normal file
570
models/MuseTalk/app.py
Normal file
@@ -0,0 +1,570 @@
|
||||
import os
|
||||
import time
|
||||
import pdb
|
||||
import re
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import sys
|
||||
import subprocess
|
||||
|
||||
from huggingface_hub import snapshot_download
|
||||
import requests
|
||||
|
||||
import argparse
|
||||
import os
|
||||
from omegaconf import OmegaConf
|
||||
import numpy as np
|
||||
import cv2
|
||||
import torch
|
||||
import glob
|
||||
import pickle
|
||||
from tqdm import tqdm
|
||||
import copy
|
||||
from argparse import Namespace
|
||||
import shutil
|
||||
import gdown
|
||||
import imageio
|
||||
import ffmpeg
|
||||
from moviepy.editor import *
|
||||
from transformers import WhisperModel
|
||||
|
||||
ProjectDir = os.path.abspath(os.path.dirname(__file__))
|
||||
CheckpointsDir = os.path.join(ProjectDir, "models")
|
||||
|
||||
@torch.no_grad()
|
||||
def debug_inpainting(video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
|
||||
left_cheek_width=90, right_cheek_width=90):
|
||||
"""Debug inpainting parameters, only process the first frame"""
|
||||
# Set default parameters
|
||||
args_dict = {
|
||||
"result_dir": './results/debug',
|
||||
"fps": 25,
|
||||
"batch_size": 1,
|
||||
"output_vid_name": '',
|
||||
"use_saved_coord": False,
|
||||
"audio_padding_length_left": 2,
|
||||
"audio_padding_length_right": 2,
|
||||
"version": "v15",
|
||||
"extra_margin": extra_margin,
|
||||
"parsing_mode": parsing_mode,
|
||||
"left_cheek_width": left_cheek_width,
|
||||
"right_cheek_width": right_cheek_width
|
||||
}
|
||||
args = Namespace(**args_dict)
|
||||
|
||||
# Create debug directory
|
||||
os.makedirs(args.result_dir, exist_ok=True)
|
||||
|
||||
# Read first frame
|
||||
if get_file_type(video_path) == "video":
|
||||
reader = imageio.get_reader(video_path)
|
||||
first_frame = reader.get_data(0)
|
||||
reader.close()
|
||||
else:
|
||||
first_frame = cv2.imread(video_path)
|
||||
first_frame = cv2.cvtColor(first_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Save first frame
|
||||
debug_frame_path = os.path.join(args.result_dir, "debug_frame.png")
|
||||
cv2.imwrite(debug_frame_path, cv2.cvtColor(first_frame, cv2.COLOR_RGB2BGR))
|
||||
|
||||
# Get face coordinates
|
||||
coord_list, frame_list = get_landmark_and_bbox([debug_frame_path], bbox_shift)
|
||||
bbox = coord_list[0]
|
||||
frame = frame_list[0]
|
||||
|
||||
if bbox == coord_placeholder:
|
||||
return None, "No face detected, please adjust bbox_shift parameter"
|
||||
|
||||
# Initialize face parser
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
|
||||
# Process first frame
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
|
||||
# Generate random audio features
|
||||
random_audio = torch.randn(1, 50, 384, device=device, dtype=weight_dtype)
|
||||
audio_feature = pe(random_audio)
|
||||
|
||||
# Get latents
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
latents = latents.to(dtype=weight_dtype)
|
||||
|
||||
# Generate prediction results
|
||||
pred_latents = unet.model(latents, timesteps, encoder_hidden_states=audio_feature).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
|
||||
# Inpaint back to original image
|
||||
res_frame = recon[0]
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
combine_frame = get_image(frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
|
||||
# Save results (no need to convert color space again since get_image already returns RGB format)
|
||||
debug_result_path = os.path.join(args.result_dir, "debug_result.png")
|
||||
cv2.imwrite(debug_result_path, combine_frame)
|
||||
|
||||
# Create information text
|
||||
info_text = f"Parameter information:\n" + \
|
||||
f"bbox_shift: {bbox_shift}\n" + \
|
||||
f"extra_margin: {extra_margin}\n" + \
|
||||
f"parsing_mode: {parsing_mode}\n" + \
|
||||
f"left_cheek_width: {left_cheek_width}\n" + \
|
||||
f"right_cheek_width: {right_cheek_width}\n" + \
|
||||
f"Detected face coordinates: [{x1}, {y1}, {x2}, {y2}]"
|
||||
|
||||
return cv2.cvtColor(combine_frame, cv2.COLOR_RGB2BGR), info_text
|
||||
|
||||
def print_directory_contents(path):
|
||||
for child in os.listdir(path):
|
||||
child_path = os.path.join(path, child)
|
||||
if os.path.isdir(child_path):
|
||||
print(child_path)
|
||||
|
||||
def download_model():
|
||||
# 检查必需的模型文件是否存在
|
||||
required_models = {
|
||||
"MuseTalk": f"{CheckpointsDir}/musetalkV15/unet.pth",
|
||||
"MuseTalk": f"{CheckpointsDir}/musetalkV15/musetalk.json",
|
||||
"SD VAE": f"{CheckpointsDir}/sd-vae/config.json",
|
||||
"Whisper": f"{CheckpointsDir}/whisper/config.json",
|
||||
"DWPose": f"{CheckpointsDir}/dwpose/dw-ll_ucoco_384.pth",
|
||||
"SyncNet": f"{CheckpointsDir}/syncnet/latentsync_syncnet.pt",
|
||||
"Face Parse": f"{CheckpointsDir}/face-parse-bisent/79999_iter.pth",
|
||||
"ResNet": f"{CheckpointsDir}/face-parse-bisent/resnet18-5c106cde.pth"
|
||||
}
|
||||
|
||||
missing_models = []
|
||||
for model_name, model_path in required_models.items():
|
||||
if not os.path.exists(model_path):
|
||||
missing_models.append(model_name)
|
||||
|
||||
if missing_models:
|
||||
# 全用英文
|
||||
print("The following required model files are missing:")
|
||||
for model in missing_models:
|
||||
print(f"- {model}")
|
||||
print("\nPlease run the download script to download the missing models:")
|
||||
if sys.platform == "win32":
|
||||
print("Windows: Run download_weights.bat")
|
||||
else:
|
||||
print("Linux/Mac: Run ./download_weights.sh")
|
||||
sys.exit(1)
|
||||
else:
|
||||
print("All required model files exist.")
|
||||
|
||||
|
||||
|
||||
|
||||
download_model() # for huggingface deployment.
|
||||
|
||||
from musetalk.utils.blending import get_image
|
||||
from musetalk.utils.face_parsing import FaceParsing
|
||||
from musetalk.utils.audio_processor import AudioProcessor
|
||||
from musetalk.utils.utils import get_file_type, get_video_fps, datagen, load_all_model
|
||||
from musetalk.utils.preprocessing import get_landmark_and_bbox, read_imgs, coord_placeholder, get_bbox_range
|
||||
|
||||
|
||||
def fast_check_ffmpeg():
|
||||
try:
|
||||
subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
|
||||
return True
|
||||
except:
|
||||
return False
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def inference(audio_path, video_path, bbox_shift, extra_margin=10, parsing_mode="jaw",
|
||||
left_cheek_width=90, right_cheek_width=90, progress=gr.Progress(track_tqdm=True)):
|
||||
# Set default parameters, aligned with inference.py
|
||||
args_dict = {
|
||||
"result_dir": './results/output',
|
||||
"fps": 25,
|
||||
"batch_size": 8,
|
||||
"output_vid_name": '',
|
||||
"use_saved_coord": False,
|
||||
"audio_padding_length_left": 2,
|
||||
"audio_padding_length_right": 2,
|
||||
"version": "v15", # Fixed use v15 version
|
||||
"extra_margin": extra_margin,
|
||||
"parsing_mode": parsing_mode,
|
||||
"left_cheek_width": left_cheek_width,
|
||||
"right_cheek_width": right_cheek_width
|
||||
}
|
||||
args = Namespace(**args_dict)
|
||||
|
||||
# Check ffmpeg
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
input_basename = os.path.basename(video_path).split('.')[0]
|
||||
audio_basename = os.path.basename(audio_path).split('.')[0]
|
||||
output_basename = f"{input_basename}_{audio_basename}"
|
||||
|
||||
# Create temporary directory
|
||||
temp_dir = os.path.join(args.result_dir, f"{args.version}")
|
||||
os.makedirs(temp_dir, exist_ok=True)
|
||||
|
||||
# Set result save path
|
||||
result_img_save_path = os.path.join(temp_dir, output_basename)
|
||||
crop_coord_save_path = os.path.join(args.result_dir, "../", input_basename+".pkl")
|
||||
os.makedirs(result_img_save_path, exist_ok=True)
|
||||
|
||||
if args.output_vid_name == "":
|
||||
output_vid_name = os.path.join(temp_dir, output_basename+".mp4")
|
||||
else:
|
||||
output_vid_name = os.path.join(temp_dir, args.output_vid_name)
|
||||
|
||||
############################################## extract frames from source video ##############################################
|
||||
if get_file_type(video_path) == "video":
|
||||
save_dir_full = os.path.join(temp_dir, input_basename)
|
||||
os.makedirs(save_dir_full, exist_ok=True)
|
||||
# Read video
|
||||
reader = imageio.get_reader(video_path)
|
||||
|
||||
# Save images
|
||||
for i, im in enumerate(reader):
|
||||
imageio.imwrite(f"{save_dir_full}/{i:08d}.png", im)
|
||||
input_img_list = sorted(glob.glob(os.path.join(save_dir_full, '*.[jpJP][pnPN]*[gG]')))
|
||||
fps = get_video_fps(video_path)
|
||||
else: # input img folder
|
||||
input_img_list = glob.glob(os.path.join(video_path, '*.[jpJP][pnPN]*[gG]'))
|
||||
input_img_list = sorted(input_img_list, key=lambda x: int(os.path.splitext(os.path.basename(x))[0]))
|
||||
fps = args.fps
|
||||
|
||||
############################################## extract audio feature ##############################################
|
||||
# Extract audio features
|
||||
whisper_input_features, librosa_length = audio_processor.get_audio_feature(audio_path)
|
||||
whisper_chunks = audio_processor.get_whisper_chunk(
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=fps,
|
||||
audio_padding_length_left=args.audio_padding_length_left,
|
||||
audio_padding_length_right=args.audio_padding_length_right,
|
||||
)
|
||||
|
||||
############################################## preprocess input image ##############################################
|
||||
if os.path.exists(crop_coord_save_path) and args.use_saved_coord:
|
||||
print("using extracted coordinates")
|
||||
with open(crop_coord_save_path,'rb') as f:
|
||||
coord_list = pickle.load(f)
|
||||
frame_list = read_imgs(input_img_list)
|
||||
else:
|
||||
print("extracting landmarks...time consuming")
|
||||
coord_list, frame_list = get_landmark_and_bbox(input_img_list, bbox_shift)
|
||||
with open(crop_coord_save_path, 'wb') as f:
|
||||
pickle.dump(coord_list, f)
|
||||
bbox_shift_text = get_bbox_range(input_img_list, bbox_shift)
|
||||
|
||||
# Initialize face parser
|
||||
fp = FaceParsing(
|
||||
left_cheek_width=args.left_cheek_width,
|
||||
right_cheek_width=args.right_cheek_width
|
||||
)
|
||||
|
||||
i = 0
|
||||
input_latent_list = []
|
||||
for bbox, frame in zip(coord_list, frame_list):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
crop_frame = cv2.resize(crop_frame,(256,256),interpolation = cv2.INTER_LANCZOS4)
|
||||
latents = vae.get_latents_for_unet(crop_frame)
|
||||
input_latent_list.append(latents)
|
||||
|
||||
# to smooth the first and the last frame
|
||||
frame_list_cycle = frame_list + frame_list[::-1]
|
||||
coord_list_cycle = coord_list + coord_list[::-1]
|
||||
input_latent_list_cycle = input_latent_list + input_latent_list[::-1]
|
||||
|
||||
############################################## inference batch by batch ##############################################
|
||||
print("start inference")
|
||||
video_num = len(whisper_chunks)
|
||||
batch_size = args.batch_size
|
||||
gen = datagen(
|
||||
whisper_chunks=whisper_chunks,
|
||||
vae_encode_latents=input_latent_list_cycle,
|
||||
batch_size=batch_size,
|
||||
delay_frame=0,
|
||||
device=device,
|
||||
)
|
||||
res_frame_list = []
|
||||
for i, (whisper_batch,latent_batch) in enumerate(tqdm(gen,total=int(np.ceil(float(video_num)/batch_size)))):
|
||||
audio_feature_batch = pe(whisper_batch)
|
||||
# Ensure latent_batch is consistent with model weight type
|
||||
latent_batch = latent_batch.to(dtype=weight_dtype)
|
||||
|
||||
pred_latents = unet.model(latent_batch, timesteps, encoder_hidden_states=audio_feature_batch).sample
|
||||
recon = vae.decode_latents(pred_latents)
|
||||
for res_frame in recon:
|
||||
res_frame_list.append(res_frame)
|
||||
|
||||
############################################## pad to full image ##############################################
|
||||
print("pad talking image to original video")
|
||||
for i, res_frame in enumerate(tqdm(res_frame_list)):
|
||||
bbox = coord_list_cycle[i%(len(coord_list_cycle))]
|
||||
ori_frame = copy.deepcopy(frame_list_cycle[i%(len(frame_list_cycle))])
|
||||
x1, y1, x2, y2 = bbox
|
||||
y2 = y2 + args.extra_margin
|
||||
y2 = min(y2, frame.shape[0])
|
||||
try:
|
||||
res_frame = cv2.resize(res_frame.astype(np.uint8),(x2-x1,y2-y1))
|
||||
except:
|
||||
continue
|
||||
|
||||
# Use v15 version blending
|
||||
combine_frame = get_image(ori_frame, res_frame, [x1, y1, x2, y2], mode=args.parsing_mode, fp=fp)
|
||||
|
||||
cv2.imwrite(f"{result_img_save_path}/{str(i).zfill(8)}.png",combine_frame)
|
||||
|
||||
# Frame rate
|
||||
fps = 25
|
||||
# Output video path
|
||||
output_video = 'temp.mp4'
|
||||
|
||||
# Read images
|
||||
def is_valid_image(file):
|
||||
pattern = re.compile(r'\d{8}\.png')
|
||||
return pattern.match(file)
|
||||
|
||||
images = []
|
||||
files = [file for file in os.listdir(result_img_save_path) if is_valid_image(file)]
|
||||
files.sort(key=lambda x: int(x.split('.')[0]))
|
||||
|
||||
for file in files:
|
||||
filename = os.path.join(result_img_save_path, file)
|
||||
images.append(imageio.imread(filename))
|
||||
|
||||
|
||||
# Save video
|
||||
imageio.mimwrite(output_video, images, 'FFMPEG', fps=fps, codec='libx264', pixelformat='yuv420p')
|
||||
|
||||
input_video = './temp.mp4'
|
||||
# Check if the input_video and audio_path exist
|
||||
if not os.path.exists(input_video):
|
||||
raise FileNotFoundError(f"Input video file not found: {input_video}")
|
||||
if not os.path.exists(audio_path):
|
||||
raise FileNotFoundError(f"Audio file not found: {audio_path}")
|
||||
|
||||
# Read video
|
||||
reader = imageio.get_reader(input_video)
|
||||
fps = reader.get_meta_data()['fps'] # Get original video frame rate
|
||||
reader.close() # Otherwise, error on win11: PermissionError: [WinError 32] Another program is using this file, process cannot access. : 'temp.mp4'
|
||||
# Store frames in list
|
||||
frames = images
|
||||
|
||||
print(len(frames))
|
||||
|
||||
# Load the video
|
||||
video_clip = VideoFileClip(input_video)
|
||||
|
||||
# Load the audio
|
||||
audio_clip = AudioFileClip(audio_path)
|
||||
|
||||
# Set the audio to the video
|
||||
video_clip = video_clip.set_audio(audio_clip)
|
||||
|
||||
# Write the output video
|
||||
video_clip.write_videofile(output_vid_name, codec='libx264', audio_codec='aac',fps=25)
|
||||
|
||||
os.remove("temp.mp4")
|
||||
#shutil.rmtree(result_img_save_path)
|
||||
print(f"result is save to {output_vid_name}")
|
||||
return output_vid_name,bbox_shift_text
|
||||
|
||||
|
||||
|
||||
# load model weights
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
vae, unet, pe = load_all_model(
|
||||
unet_model_path="./models/musetalkV15/unet.pth",
|
||||
vae_type="sd-vae",
|
||||
unet_config="./models/musetalkV15/musetalk.json",
|
||||
device=device
|
||||
)
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--ffmpeg_path", type=str, default=r"ffmpeg-master-latest-win64-gpl-shared\bin", help="Path to ffmpeg executable")
|
||||
parser.add_argument("--ip", type=str, default="127.0.0.1", help="IP address to bind to")
|
||||
parser.add_argument("--port", type=int, default=7860, help="Port to bind to")
|
||||
parser.add_argument("--share", action="store_true", help="Create a public link")
|
||||
parser.add_argument("--use_float16", action="store_true", help="Use float16 for faster inference")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Set data type
|
||||
if args.use_float16:
|
||||
# Convert models to half precision for better performance
|
||||
pe = pe.half()
|
||||
vae.vae = vae.vae.half()
|
||||
unet.model = unet.model.half()
|
||||
weight_dtype = torch.float16
|
||||
else:
|
||||
weight_dtype = torch.float32
|
||||
|
||||
# Move models to specified device
|
||||
pe = pe.to(device)
|
||||
vae.vae = vae.vae.to(device)
|
||||
unet.model = unet.model.to(device)
|
||||
|
||||
timesteps = torch.tensor([0], device=device)
|
||||
|
||||
# Initialize audio processor and Whisper model
|
||||
audio_processor = AudioProcessor(feature_extractor_path="./models/whisper")
|
||||
whisper = WhisperModel.from_pretrained("./models/whisper")
|
||||
whisper = whisper.to(device=device, dtype=weight_dtype).eval()
|
||||
whisper.requires_grad_(False)
|
||||
|
||||
|
||||
def check_video(video):
|
||||
if not isinstance(video, str):
|
||||
return video # in case of none type
|
||||
# Define the output video file name
|
||||
dir_path, file_name = os.path.split(video)
|
||||
if file_name.startswith("outputxxx_"):
|
||||
return video
|
||||
# Add the output prefix to the file name
|
||||
output_file_name = "outputxxx_" + file_name
|
||||
|
||||
os.makedirs('./results',exist_ok=True)
|
||||
os.makedirs('./results/output',exist_ok=True)
|
||||
os.makedirs('./results/input',exist_ok=True)
|
||||
|
||||
# Combine the directory path and the new file name
|
||||
output_video = os.path.join('./results/input', output_file_name)
|
||||
|
||||
|
||||
# read video
|
||||
reader = imageio.get_reader(video)
|
||||
fps = reader.get_meta_data()['fps'] # get fps from original video
|
||||
|
||||
# conver fps to 25
|
||||
frames = [im for im in reader]
|
||||
target_fps = 25
|
||||
|
||||
L = len(frames)
|
||||
L_target = int(L / fps * target_fps)
|
||||
original_t = [x / fps for x in range(1, L+1)]
|
||||
t_idx = 0
|
||||
target_frames = []
|
||||
for target_t in range(1, L_target+1):
|
||||
while target_t / target_fps > original_t[t_idx]:
|
||||
t_idx += 1 # find the first t_idx so that target_t / target_fps <= original_t[t_idx]
|
||||
if t_idx >= L:
|
||||
break
|
||||
target_frames.append(frames[t_idx])
|
||||
|
||||
# save video
|
||||
imageio.mimwrite(output_video, target_frames, 'FFMPEG', fps=25, codec='libx264', quality=9, pixelformat='yuv420p')
|
||||
return output_video
|
||||
|
||||
|
||||
|
||||
|
||||
css = """#input_img {max-width: 1024px !important} #output_vid {max-width: 1024px; max-height: 576px}"""
|
||||
|
||||
with gr.Blocks(css=css) as demo:
|
||||
gr.Markdown(
|
||||
"""<div align='center'> <h1>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</h1> \
|
||||
<h2 style='font-weight: 450; font-size: 1rem; margin: 0rem'>\
|
||||
</br>\
|
||||
Yue Zhang <sup>*</sup>,\
|
||||
Zhizhou Zhong <sup>*</sup>,\
|
||||
Minhao Liu<sup>*</sup>,\
|
||||
Zhaokang Chen,\
|
||||
Bin Wu<sup>†</sup>,\
|
||||
Yubin Zeng,\
|
||||
Chao Zhang,\
|
||||
Yingjie He,\
|
||||
Junxin Huang,\
|
||||
Wenjiang Zhou <br>\
|
||||
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)\
|
||||
Lyra Lab, Tencent Music Entertainment\
|
||||
</h2> \
|
||||
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Github Repo]</a>\
|
||||
<a style='font-size:18px;color: #000000' href='https://github.com/TMElyralab/MuseTalk'>[Huggingface]</a>\
|
||||
<a style='font-size:18px;color: #000000' href='https://arxiv.org/abs/2410.10122'> [Technical report] </a>"""
|
||||
)
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
audio = gr.Audio(label="Drving Audio",type="filepath")
|
||||
video = gr.Video(label="Reference Video",sources=['upload'])
|
||||
bbox_shift = gr.Number(label="BBox_shift value, px", value=0)
|
||||
extra_margin = gr.Slider(label="Extra Margin", minimum=0, maximum=40, value=10, step=1)
|
||||
parsing_mode = gr.Radio(label="Parsing Mode", choices=["jaw", "raw"], value="jaw")
|
||||
left_cheek_width = gr.Slider(label="Left Cheek Width", minimum=20, maximum=160, value=90, step=5)
|
||||
right_cheek_width = gr.Slider(label="Right Cheek Width", minimum=20, maximum=160, value=90, step=5)
|
||||
bbox_shift_scale = gr.Textbox(label="'left_cheek_width' and 'right_cheek_width' parameters determine the range of left and right cheeks editing when parsing model is 'jaw'. The 'extra_margin' parameter determines the movement range of the jaw. Users can freely adjust these three parameters to obtain better inpainting results.")
|
||||
|
||||
with gr.Row():
|
||||
debug_btn = gr.Button("1. Test Inpainting ")
|
||||
btn = gr.Button("2. Generate")
|
||||
with gr.Column():
|
||||
debug_image = gr.Image(label="Test Inpainting Result (First Frame)")
|
||||
debug_info = gr.Textbox(label="Parameter Information", lines=5)
|
||||
out1 = gr.Video()
|
||||
|
||||
video.change(
|
||||
fn=check_video, inputs=[video], outputs=[video]
|
||||
)
|
||||
btn.click(
|
||||
fn=inference,
|
||||
inputs=[
|
||||
audio,
|
||||
video,
|
||||
bbox_shift,
|
||||
extra_margin,
|
||||
parsing_mode,
|
||||
left_cheek_width,
|
||||
right_cheek_width
|
||||
],
|
||||
outputs=[out1,bbox_shift_scale]
|
||||
)
|
||||
debug_btn.click(
|
||||
fn=debug_inpainting,
|
||||
inputs=[
|
||||
video,
|
||||
bbox_shift,
|
||||
extra_margin,
|
||||
parsing_mode,
|
||||
left_cheek_width,
|
||||
right_cheek_width
|
||||
],
|
||||
outputs=[debug_image, debug_info]
|
||||
)
|
||||
|
||||
# Check ffmpeg and add to PATH
|
||||
if not fast_check_ffmpeg():
|
||||
print(f"Adding ffmpeg to PATH: {args.ffmpeg_path}")
|
||||
# According to operating system, choose path separator
|
||||
path_separator = ';' if sys.platform == 'win32' else ':'
|
||||
os.environ["PATH"] = f"{args.ffmpeg_path}{path_separator}{os.environ['PATH']}"
|
||||
if not fast_check_ffmpeg():
|
||||
print("Warning: Unable to find ffmpeg, please ensure ffmpeg is properly installed")
|
||||
|
||||
# Solve asynchronous IO issues on Windows
|
||||
if sys.platform == 'win32':
|
||||
import asyncio
|
||||
asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
|
||||
|
||||
# Start Gradio application
|
||||
demo.queue().launch(
|
||||
share=args.share,
|
||||
debug=True,
|
||||
server_name=args.ip,
|
||||
server_port=args.port
|
||||
)
|
||||
10
models/MuseTalk/configs/inference/realtime.yaml
Normal file
10
models/MuseTalk/configs/inference/realtime.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
avator_1:
|
||||
preparation: True # your can set it to False if you want to use the existing avator, it will save time
|
||||
bbox_shift: 5
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_clips:
|
||||
audio_0: "data/audio/yongen.wav"
|
||||
audio_1: "data/audio/eng.wav"
|
||||
|
||||
|
||||
|
||||
10
models/MuseTalk/configs/inference/test.yaml
Normal file
10
models/MuseTalk/configs/inference/test.yaml
Normal file
@@ -0,0 +1,10 @@
|
||||
task_0:
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_path: "data/audio/yongen.wav"
|
||||
|
||||
task_1:
|
||||
video_path: "data/video/yongen.mp4"
|
||||
audio_path: "data/audio/eng.wav"
|
||||
bbox_shift: -7
|
||||
|
||||
|
||||
21
models/MuseTalk/configs/training/gpu.yaml
Normal file
21
models/MuseTalk/configs/training/gpu.yaml
Normal file
@@ -0,0 +1,21 @@
|
||||
compute_environment: LOCAL_MACHINE
|
||||
debug: True
|
||||
deepspeed_config:
|
||||
offload_optimizer_device: none
|
||||
offload_param_device: none
|
||||
zero3_init_flag: False
|
||||
zero_stage: 2
|
||||
|
||||
distributed_type: DEEPSPEED
|
||||
downcast_bf16: 'no'
|
||||
gpu_ids: "5, 7" # modify this according to your GPU number
|
||||
machine_rank: 0
|
||||
main_training_function: main
|
||||
num_machines: 1
|
||||
num_processes: 2 # it should be the same as the number of GPUs
|
||||
rdzv_backend: static
|
||||
same_network: true
|
||||
tpu_env: []
|
||||
tpu_use_cluster: false
|
||||
tpu_use_sudo: false
|
||||
use_cpu: false
|
||||
31
models/MuseTalk/configs/training/preprocess.yaml
Normal file
31
models/MuseTalk/configs/training/preprocess.yaml
Normal file
@@ -0,0 +1,31 @@
|
||||
clip_len_second: 30 # the length of the video clip
|
||||
video_root_raw: "./dataset/HDTF/source/" # the path of the original video
|
||||
val_list_hdtf:
|
||||
- RD_Radio7_000
|
||||
- RD_Radio8_000
|
||||
- RD_Radio9_000
|
||||
- WDA_TinaSmith_000
|
||||
- WDA_TomCarper_000
|
||||
- WDA_TomPerez_000
|
||||
- WDA_TomUdall_000
|
||||
- WDA_VeronicaEscobar0_000
|
||||
- WDA_VeronicaEscobar1_000
|
||||
- WDA_WhipJimClyburn_000
|
||||
- WDA_XavierBecerra_000
|
||||
- WDA_XavierBecerra_001
|
||||
- WDA_XavierBecerra_002
|
||||
- WDA_ZoeLofgren_000
|
||||
- WRA_SteveScalise1_000
|
||||
- WRA_TimScott_000
|
||||
- WRA_ToddYoung_000
|
||||
- WRA_TomCotton_000
|
||||
- WRA_TomPrice_000
|
||||
- WRA_VickyHartzler_000
|
||||
|
||||
# following dir will be automatically generated
|
||||
video_root_25fps: "./dataset/HDTF/video_root_25fps/"
|
||||
video_file_list: "./dataset/HDTF/video_file_list.txt"
|
||||
video_audio_clip_root: "./dataset/HDTF/video_audio_clip_root/"
|
||||
meta_root: "./dataset/HDTF/meta/"
|
||||
video_clip_file_list_train: "./dataset/HDTF/train.txt"
|
||||
video_clip_file_list_val: "./dataset/HDTF/val.txt"
|
||||
89
models/MuseTalk/configs/training/stage1.yaml
Normal file
89
models/MuseTalk/configs/training/stage1.yaml
Normal file
@@ -0,0 +1,89 @@
|
||||
exp_name: 'test' # Name of the experiment
|
||||
output_dir: './exp_out/stage1/' # Directory to save experiment outputs
|
||||
unet_sub_folder: musetalk # Subfolder name for UNet model
|
||||
random_init_unet: True # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
|
||||
whisper_path: "./models/whisper" # Path to the Whisper model
|
||||
pretrained_model_name_or_path: "./models" # Path to pretrained models
|
||||
resume_from_checkpoint: True # Whether to resume training from a checkpoint
|
||||
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
|
||||
vae_type: "sd-vae" # Type of VAE model to use
|
||||
# Validation parameters
|
||||
num_images_to_keep: 8 # Number of validation images to keep
|
||||
ref_dropout_rate: 0 # Dropout rate for reference images
|
||||
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
|
||||
use_adapted_weight: False # Whether to use adapted weights for loss calculation
|
||||
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
|
||||
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
|
||||
crop_type: "crop_resize" # Type of cropping method
|
||||
random_margin_method: "normal" # Method for random margin generation
|
||||
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
|
||||
|
||||
data:
|
||||
dataset_key: "HDTF" # Dataset to use for training
|
||||
train_bs: 32 # Training batch size (actual batch size is train_bs*n_sample_frames)
|
||||
image_size: 256 # Size of input images
|
||||
n_sample_frames: 1 # Number of frames to sample per batch
|
||||
num_workers: 8 # Number of data loading workers
|
||||
audio_padding_length_left: 2 # Left padding length for audio features
|
||||
audio_padding_length_right: 2 # Right padding length for audio features
|
||||
sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
|
||||
top_k_ratio: 0.51 # Ratio for top-k sampling
|
||||
contorl_face_min_size: True # Whether to control minimum face size
|
||||
min_face_size: 150 # Minimum face size in pixels
|
||||
|
||||
loss_params:
|
||||
l1_loss: 1.0 # Weight for L1 loss
|
||||
vgg_loss: 0.01 # Weight for VGG perceptual loss
|
||||
vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
|
||||
pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
|
||||
gan_loss: 0 # Weight for GAN loss
|
||||
fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
|
||||
sync_loss: 0 # Weight for sync loss
|
||||
mouth_gan_loss: 0 # Weight for mouth-specific GAN loss
|
||||
|
||||
model_params:
|
||||
discriminator_params:
|
||||
scales: [1] # Scales for discriminator
|
||||
block_expansion: 32 # Expansion factor for discriminator blocks
|
||||
max_features: 512 # Maximum number of features in discriminator
|
||||
num_blocks: 4 # Number of blocks in discriminator
|
||||
sn: True # Whether to use spectral normalization
|
||||
image_channel: 3 # Number of image channels
|
||||
estimate_jacobian: False # Whether to estimate Jacobian
|
||||
|
||||
discriminator_train_params:
|
||||
lr: 0.000005 # Learning rate for discriminator
|
||||
eps: 0.00000001 # Epsilon for optimizer
|
||||
weight_decay: 0.01 # Weight decay for optimizer
|
||||
patch_size: 1 # Size of patches for discriminator
|
||||
betas: [0.5, 0.999] # Beta parameters for Adam optimizer
|
||||
epochs: 10000 # Number of training epochs
|
||||
start_gan: 1000 # Step to start GAN training
|
||||
|
||||
solver:
|
||||
gradient_accumulation_steps: 1 # Number of steps for gradient accumulation
|
||||
uncond_steps: 10 # Number of unconditional steps
|
||||
mixed_precision: 'fp32' # Precision mode for training
|
||||
enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
|
||||
gradient_checkpointing: True # Whether to use gradient checkpointing
|
||||
max_train_steps: 250000 # Maximum number of training steps
|
||||
max_grad_norm: 1.0 # Maximum gradient norm for clipping
|
||||
# Learning rate parameters
|
||||
learning_rate: 2.0e-5 # Base learning rate
|
||||
scale_lr: False # Whether to scale learning rate
|
||||
lr_warmup_steps: 1000 # Number of warmup steps for learning rate
|
||||
lr_scheduler: "linear" # Type of learning rate scheduler
|
||||
# Optimizer parameters
|
||||
use_8bit_adam: False # Whether to use 8-bit Adam optimizer
|
||||
adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
|
||||
adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
|
||||
adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
|
||||
adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
|
||||
|
||||
total_limit: 10 # Maximum number of checkpoints to keep
|
||||
save_model_epoch_interval: 250000 # Interval between model saves
|
||||
checkpointing_steps: 10000 # Number of steps between checkpoints
|
||||
val_freq: 2000 # Frequency of validation
|
||||
|
||||
seed: 41 # Random seed for reproducibility
|
||||
|
||||
89
models/MuseTalk/configs/training/stage2.yaml
Normal file
89
models/MuseTalk/configs/training/stage2.yaml
Normal file
@@ -0,0 +1,89 @@
|
||||
exp_name: 'test' # Name of the experiment
|
||||
output_dir: './exp_out/stage2/' # Directory to save experiment outputs
|
||||
unet_sub_folder: musetalk # Subfolder name for UNet model
|
||||
random_init_unet: False # Whether to randomly initialize UNet (stage1) or use pretrained weights (stage2)
|
||||
whisper_path: "./models/whisper" # Path to the Whisper model
|
||||
pretrained_model_name_or_path: "./models" # Path to pretrained models
|
||||
resume_from_checkpoint: True # Whether to resume training from a checkpoint
|
||||
padding_pixel_mouth: 10 # Number of pixels to pad around the mouth region
|
||||
vae_type: "sd-vae" # Type of VAE model to use
|
||||
# Validation parameters
|
||||
num_images_to_keep: 8 # Number of validation images to keep
|
||||
ref_dropout_rate: 0 # Dropout rate for reference images
|
||||
syncnet_config_path: "./configs/training/syncnet.yaml" # Path to SyncNet configuration
|
||||
use_adapted_weight: False # Whether to use adapted weights for loss calculation
|
||||
cropping_jaw2edge_margin_mean: 10 # Mean margin for jaw-to-edge cropping
|
||||
cropping_jaw2edge_margin_std: 10 # Standard deviation for jaw-to-edge cropping
|
||||
crop_type: "dynamic_margin_crop_resize" # Type of cropping method
|
||||
random_margin_method: "normal" # Method for random margin generation
|
||||
num_backward_frames: 16 # Number of frames to use for backward pass in SyncNet
|
||||
|
||||
data:
|
||||
dataset_key: "HDTF" # Dataset to use for training
|
||||
train_bs: 2 # Training batch size (actual batch size is train_bs*n_sample_frames)
|
||||
image_size: 256 # Size of input images
|
||||
n_sample_frames: 16 # Number of frames to sample per batch
|
||||
num_workers: 8 # Number of data loading workers
|
||||
audio_padding_length_left: 2 # Left padding length for audio features
|
||||
audio_padding_length_right: 2 # Right padding length for audio features
|
||||
sample_method: pose_similarity_and_mouth_dissimilarity # Method for sampling frames
|
||||
top_k_ratio: 0.51 # Ratio for top-k sampling
|
||||
contorl_face_min_size: True # Whether to control minimum face size
|
||||
min_face_size: 200 # Minimum face size in pixels
|
||||
|
||||
loss_params:
|
||||
l1_loss: 1.0 # Weight for L1 loss
|
||||
vgg_loss: 0.01 # Weight for VGG perceptual loss
|
||||
vgg_layer_weight: [1, 1, 1, 1, 1] # Weights for different VGG layers
|
||||
pyramid_scale: [1, 0.5, 0.25, 0.125] # Scales for image pyramid
|
||||
gan_loss: 0.01 # Weight for GAN loss
|
||||
fm_loss: [1.0, 1.0, 1.0, 1.0] # Weights for feature matching loss
|
||||
sync_loss: 0.05 # Weight for sync loss
|
||||
mouth_gan_loss: 0.01 # Weight for mouth-specific GAN loss
|
||||
|
||||
model_params:
|
||||
discriminator_params:
|
||||
scales: [1] # Scales for discriminator
|
||||
block_expansion: 32 # Expansion factor for discriminator blocks
|
||||
max_features: 512 # Maximum number of features in discriminator
|
||||
num_blocks: 4 # Number of blocks in discriminator
|
||||
sn: True # Whether to use spectral normalization
|
||||
image_channel: 3 # Number of image channels
|
||||
estimate_jacobian: False # Whether to estimate Jacobian
|
||||
|
||||
discriminator_train_params:
|
||||
lr: 0.000005 # Learning rate for discriminator
|
||||
eps: 0.00000001 # Epsilon for optimizer
|
||||
weight_decay: 0.01 # Weight decay for optimizer
|
||||
patch_size: 1 # Size of patches for discriminator
|
||||
betas: [0.5, 0.999] # Beta parameters for Adam optimizer
|
||||
epochs: 10000 # Number of training epochs
|
||||
start_gan: 1000 # Step to start GAN training
|
||||
|
||||
solver:
|
||||
gradient_accumulation_steps: 8 # Number of steps for gradient accumulation
|
||||
uncond_steps: 10 # Number of unconditional steps
|
||||
mixed_precision: 'fp32' # Precision mode for training
|
||||
enable_xformers_memory_efficient_attention: True # Whether to use memory efficient attention
|
||||
gradient_checkpointing: True # Whether to use gradient checkpointing
|
||||
max_train_steps: 250000 # Maximum number of training steps
|
||||
max_grad_norm: 1.0 # Maximum gradient norm for clipping
|
||||
# Learning rate parameters
|
||||
learning_rate: 5.0e-6 # Base learning rate
|
||||
scale_lr: False # Whether to scale learning rate
|
||||
lr_warmup_steps: 1000 # Number of warmup steps for learning rate
|
||||
lr_scheduler: "linear" # Type of learning rate scheduler
|
||||
# Optimizer parameters
|
||||
use_8bit_adam: False # Whether to use 8-bit Adam optimizer
|
||||
adam_beta1: 0.5 # Beta1 parameter for Adam optimizer
|
||||
adam_beta2: 0.999 # Beta2 parameter for Adam optimizer
|
||||
adam_weight_decay: 1.0e-2 # Weight decay for Adam optimizer
|
||||
adam_epsilon: 1.0e-8 # Epsilon for Adam optimizer
|
||||
|
||||
total_limit: 10 # Maximum number of checkpoints to keep
|
||||
save_model_epoch_interval: 250000 # Interval between model saves
|
||||
checkpointing_steps: 2000 # Number of steps between checkpoints
|
||||
val_freq: 2000 # Frequency of validation
|
||||
|
||||
seed: 41 # Random seed for reproducibility
|
||||
|
||||
19
models/MuseTalk/configs/training/syncnet.yaml
Normal file
19
models/MuseTalk/configs/training/syncnet.yaml
Normal file
@@ -0,0 +1,19 @@
|
||||
# This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/configs/training/syncnet_16_pixel.yaml).
|
||||
model:
|
||||
audio_encoder: # input (1, 80, 52)
|
||||
in_channels: 1
|
||||
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
|
||||
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
visual_encoder: # input (48, 128, 256)
|
||||
in_channels: 48
|
||||
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
|
||||
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
|
||||
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
|
||||
dropout: 0.0
|
||||
|
||||
ckpt:
|
||||
resume_ckpt_path: ""
|
||||
inference_ckpt_path: ./models/syncnet/latentsync_syncnet.pt # this pretrained model is from LatentSync (https://huggingface.co/ByteDance/LatentSync/tree/main)
|
||||
save_ckpt_steps: 2500
|
||||
41
models/MuseTalk/download_weights.bat
Normal file
41
models/MuseTalk/download_weights.bat
Normal file
@@ -0,0 +1,41 @@
|
||||
@echo off
|
||||
setlocal
|
||||
|
||||
:: Set the checkpoints directory
|
||||
set CheckpointsDir=models
|
||||
|
||||
:: Create necessary directories
|
||||
mkdir %CheckpointsDir%\musetalk
|
||||
mkdir %CheckpointsDir%\musetalkV15
|
||||
mkdir %CheckpointsDir%\syncnet
|
||||
mkdir %CheckpointsDir%\dwpose
|
||||
mkdir %CheckpointsDir%\face-parse-bisent
|
||||
mkdir %CheckpointsDir%\sd-vae-ft-mse
|
||||
mkdir %CheckpointsDir%\whisper
|
||||
|
||||
:: Install required packages
|
||||
pip install -U "huggingface_hub[hf_xet]"
|
||||
|
||||
:: Set HuggingFace endpoint
|
||||
set HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
:: Download MuseTalk weights
|
||||
hf download TMElyralab/MuseTalk --local-dir %CheckpointsDir%
|
||||
|
||||
:: Download SD VAE weights
|
||||
hf download stabilityai/sd-vae-ft-mse --local-dir %CheckpointsDir%\sd-vae --include "config.json" "diffusion_pytorch_model.bin"
|
||||
|
||||
:: Download Whisper weights
|
||||
hf download openai/whisper-tiny --local-dir %CheckpointsDir%\whisper --include "config.json" "pytorch_model.bin" "preprocessor_config.json"
|
||||
|
||||
:: Download DWPose weights
|
||||
hf download yzd-v/DWPose --local-dir %CheckpointsDir%\dwpose --include "dw-ll_ucoco_384.pth"
|
||||
|
||||
:: Download SyncNet weights
|
||||
hf download ByteDance/LatentSync --local-dir %CheckpointsDir%\syncnet --include "latentsync_syncnet.pt"
|
||||
|
||||
:: Download face-parse-bisent weights
|
||||
hf download ManyOtherFunctions/face-parse-bisent --local-dir %CheckpointsDir%\face-parse-bisent --include "79999_iter.pth" "resnet18-5c106cde.pth"
|
||||
|
||||
echo All weights have been downloaded successfully!
|
||||
endlocal
|
||||
51
models/MuseTalk/download_weights.sh
Normal file
51
models/MuseTalk/download_weights.sh
Normal file
@@ -0,0 +1,51 @@
|
||||
#!/bin/bash
|
||||
|
||||
# Set the checkpoints directory
|
||||
CheckpointsDir="models"
|
||||
|
||||
# Create necessary directories
|
||||
mkdir -p models/musetalk models/musetalkV15 models/syncnet models/dwpose models/face-parse-bisent models/sd-vae models/whisper
|
||||
|
||||
# Install required packages
|
||||
pip install -U "huggingface_hub[cli]"
|
||||
pip install gdown
|
||||
|
||||
# Set HuggingFace mirror endpoint
|
||||
export HF_ENDPOINT=https://hf-mirror.com
|
||||
|
||||
# Download MuseTalk V1.0 weights
|
||||
huggingface-cli download TMElyralab/MuseTalk \
|
||||
--local-dir $CheckpointsDir \
|
||||
--include "musetalk/musetalk.json" "musetalk/pytorch_model.bin"
|
||||
|
||||
# Download MuseTalk V1.5 weights (unet.pth)
|
||||
huggingface-cli download TMElyralab/MuseTalk \
|
||||
--local-dir $CheckpointsDir \
|
||||
--include "musetalkV15/musetalk.json" "musetalkV15/unet.pth"
|
||||
|
||||
# Download SD VAE weights
|
||||
huggingface-cli download stabilityai/sd-vae-ft-mse \
|
||||
--local-dir $CheckpointsDir/sd-vae \
|
||||
--include "config.json" "diffusion_pytorch_model.bin"
|
||||
|
||||
# Download Whisper weights
|
||||
huggingface-cli download openai/whisper-tiny \
|
||||
--local-dir $CheckpointsDir/whisper \
|
||||
--include "config.json" "pytorch_model.bin" "preprocessor_config.json"
|
||||
|
||||
# Download DWPose weights
|
||||
huggingface-cli download yzd-v/DWPose \
|
||||
--local-dir $CheckpointsDir/dwpose \
|
||||
--include "dw-ll_ucoco_384.pth"
|
||||
|
||||
# Download SyncNet weights
|
||||
huggingface-cli download ByteDance/LatentSync \
|
||||
--local-dir $CheckpointsDir/syncnet \
|
||||
--include "latentsync_syncnet.pt"
|
||||
|
||||
# Download Face Parse Bisent weights
|
||||
gdown --id 154JgKpzCPW82qINcVieuPH3fZ2e0P812 -O $CheckpointsDir/face-parse-bisent/79999_iter.pth
|
||||
curl -L https://download.pytorch.org/models/resnet18-5c106cde.pth \
|
||||
-o $CheckpointsDir/face-parse-bisent/resnet18-5c106cde.pth
|
||||
|
||||
echo "✅ All weights have been downloaded successfully!"
|
||||
9
models/MuseTalk/entrypoint.sh
Normal file
9
models/MuseTalk/entrypoint.sh
Normal file
@@ -0,0 +1,9 @@
|
||||
#!/bin/bash
|
||||
|
||||
echo "entrypoint.sh"
|
||||
whoami
|
||||
which python
|
||||
source /opt/conda/etc/profile.d/conda.sh
|
||||
conda activate musev
|
||||
which python
|
||||
python app.py
|
||||
72
models/MuseTalk/inference.sh
Normal file
72
models/MuseTalk/inference.sh
Normal file
@@ -0,0 +1,72 @@
|
||||
#!/bin/bash
|
||||
|
||||
# This script runs inference based on the version and mode specified by the user.
|
||||
# Usage:
|
||||
# To run v1.0 inference: sh inference.sh v1.0 [normal|realtime]
|
||||
# To run v1.5 inference: sh inference.sh v1.5 [normal|realtime]
|
||||
|
||||
# Check if the correct number of arguments is provided
|
||||
if [ "$#" -ne 2 ]; then
|
||||
echo "Usage: $0 <version> <mode>"
|
||||
echo "Example: $0 v1.0 normal or $0 v1.5 realtime"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Get the version and mode from the user input
|
||||
version=$1
|
||||
mode=$2
|
||||
|
||||
# Validate mode
|
||||
if [ "$mode" != "normal" ] && [ "$mode" != "realtime" ]; then
|
||||
echo "Invalid mode specified. Please use 'normal' or 'realtime'."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set config path based on mode
|
||||
if [ "$mode" = "normal" ]; then
|
||||
config_path="./configs/inference/test.yaml"
|
||||
result_dir="./results/test"
|
||||
else
|
||||
config_path="./configs/inference/realtime.yaml"
|
||||
result_dir="./results/realtime"
|
||||
fi
|
||||
|
||||
# Define the model paths based on the version
|
||||
if [ "$version" = "v1.0" ]; then
|
||||
model_dir="./models/musetalk"
|
||||
unet_model_path="$model_dir/pytorch_model.bin"
|
||||
unet_config="$model_dir/musetalk.json"
|
||||
version_arg="v1"
|
||||
elif [ "$version" = "v1.5" ]; then
|
||||
model_dir="./models/musetalkV15"
|
||||
unet_model_path="$model_dir/unet.pth"
|
||||
unet_config="$model_dir/musetalk.json"
|
||||
version_arg="v15"
|
||||
else
|
||||
echo "Invalid version specified. Please use v1.0 or v1.5."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
# Set script name based on mode
|
||||
if [ "$mode" = "normal" ]; then
|
||||
script_name="scripts.inference"
|
||||
else
|
||||
script_name="scripts.realtime_inference"
|
||||
fi
|
||||
|
||||
# Base command arguments
|
||||
cmd_args="--inference_config $config_path \
|
||||
--result_dir $result_dir \
|
||||
--unet_model_path $unet_model_path \
|
||||
--unet_config $unet_config \
|
||||
--version $version_arg"
|
||||
|
||||
# Add realtime-specific arguments if in realtime mode
|
||||
if [ "$mode" = "realtime" ]; then
|
||||
cmd_args="$cmd_args \
|
||||
--fps 25 \
|
||||
--version $version_arg"
|
||||
fi
|
||||
|
||||
# Run inference
|
||||
python3 -m $script_name $cmd_args
|
||||
168
models/MuseTalk/musetalk/data/audio.py
Normal file
168
models/MuseTalk/musetalk/data/audio.py
Normal file
@@ -0,0 +1,168 @@
|
||||
import librosa
|
||||
import librosa.filters
|
||||
import numpy as np
|
||||
from scipy import signal
|
||||
from scipy.io import wavfile
|
||||
|
||||
class HParams:
|
||||
# copy from wav2lip
|
||||
def __init__(self):
|
||||
self.n_fft = 800
|
||||
self.hop_size = 200
|
||||
self.win_size = 800
|
||||
self.sample_rate = 16000
|
||||
self.frame_shift_ms = None
|
||||
self.signal_normalization = True
|
||||
|
||||
self.allow_clipping_in_normalization = True
|
||||
self.symmetric_mels = True
|
||||
self.max_abs_value = 4.0
|
||||
self.preemphasize = True
|
||||
self.preemphasis = 0.97
|
||||
self.min_level_db = -100
|
||||
self.ref_level_db = 20
|
||||
self.fmin = 55
|
||||
self.fmax=7600
|
||||
|
||||
self.use_lws=False
|
||||
self.num_mels=80 # Number of mel-spectrogram channels and local conditioning dimensionality
|
||||
self.rescale=True # Whether to rescale audio prior to preprocessing
|
||||
self.rescaling_max=0.9 # Rescaling value
|
||||
self.use_lws=False
|
||||
|
||||
|
||||
hp = HParams()
|
||||
|
||||
def load_wav(path, sr):
|
||||
return librosa.core.load(path, sr=sr)[0]
|
||||
#def load_wav(path, sr):
|
||||
# audio, sr_native = sf.read(path)
|
||||
# if sr != sr_native:
|
||||
# audio = librosa.resample(audio.T, sr_native, sr).T
|
||||
# return audio
|
||||
|
||||
def save_wav(wav, path, sr):
|
||||
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
|
||||
#proposed by @dsmiller
|
||||
wavfile.write(path, sr, wav.astype(np.int16))
|
||||
|
||||
def save_wavenet_wav(wav, path, sr):
|
||||
librosa.output.write_wav(path, wav, sr=sr)
|
||||
|
||||
def preemphasis(wav, k, preemphasize=True):
|
||||
if preemphasize:
|
||||
return signal.lfilter([1, -k], [1], wav)
|
||||
return wav
|
||||
|
||||
def inv_preemphasis(wav, k, inv_preemphasize=True):
|
||||
if inv_preemphasize:
|
||||
return signal.lfilter([1], [1, -k], wav)
|
||||
return wav
|
||||
|
||||
def get_hop_size():
|
||||
hop_size = hp.hop_size
|
||||
if hop_size is None:
|
||||
assert hp.frame_shift_ms is not None
|
||||
hop_size = int(hp.frame_shift_ms / 1000 * hp.sample_rate)
|
||||
return hop_size
|
||||
|
||||
def linearspectrogram(wav):
|
||||
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
||||
S = _amp_to_db(np.abs(D)) - hp.ref_level_db
|
||||
|
||||
if hp.signal_normalization:
|
||||
return _normalize(S)
|
||||
return S
|
||||
|
||||
def melspectrogram(wav):
|
||||
D = _stft(preemphasis(wav, hp.preemphasis, hp.preemphasize))
|
||||
S = _amp_to_db(_linear_to_mel(np.abs(D))) - hp.ref_level_db
|
||||
|
||||
if hp.signal_normalization:
|
||||
return _normalize(S)
|
||||
return S
|
||||
|
||||
def _lws_processor():
|
||||
import lws
|
||||
return lws.lws(hp.n_fft, get_hop_size(), fftsize=hp.win_size, mode="speech")
|
||||
|
||||
def _stft(y):
|
||||
if hp.use_lws:
|
||||
return _lws_processor(hp).stft(y).T
|
||||
else:
|
||||
return librosa.stft(y=y, n_fft=hp.n_fft, hop_length=get_hop_size(), win_length=hp.win_size)
|
||||
|
||||
##########################################################
|
||||
#Those are only correct when using lws!!! (This was messing with Wavenet quality for a long time!)
|
||||
def num_frames(length, fsize, fshift):
|
||||
"""Compute number of time frames of spectrogram
|
||||
"""
|
||||
pad = (fsize - fshift)
|
||||
if length % fshift == 0:
|
||||
M = (length + pad * 2 - fsize) // fshift + 1
|
||||
else:
|
||||
M = (length + pad * 2 - fsize) // fshift + 2
|
||||
return M
|
||||
|
||||
|
||||
def pad_lr(x, fsize, fshift):
|
||||
"""Compute left and right padding
|
||||
"""
|
||||
M = num_frames(len(x), fsize, fshift)
|
||||
pad = (fsize - fshift)
|
||||
T = len(x) + 2 * pad
|
||||
r = (M - 1) * fshift + fsize - T
|
||||
return pad, pad + r
|
||||
##########################################################
|
||||
#Librosa correct padding
|
||||
def librosa_pad_lr(x, fsize, fshift):
|
||||
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
|
||||
|
||||
# Conversions
|
||||
_mel_basis = None
|
||||
|
||||
def _linear_to_mel(spectogram):
|
||||
global _mel_basis
|
||||
if _mel_basis is None:
|
||||
_mel_basis = _build_mel_basis()
|
||||
return np.dot(_mel_basis, spectogram)
|
||||
|
||||
def _build_mel_basis():
|
||||
assert hp.fmax <= hp.sample_rate // 2
|
||||
return librosa.filters.mel(sr=hp.sample_rate, n_fft=hp.n_fft, n_mels=hp.num_mels,
|
||||
fmin=hp.fmin, fmax=hp.fmax)
|
||||
|
||||
def _amp_to_db(x):
|
||||
min_level = np.exp(hp.min_level_db / 20 * np.log(10))
|
||||
return 20 * np.log10(np.maximum(min_level, x))
|
||||
|
||||
def _db_to_amp(x):
|
||||
return np.power(10.0, (x) * 0.05)
|
||||
|
||||
def _normalize(S):
|
||||
if hp.allow_clipping_in_normalization:
|
||||
if hp.symmetric_mels:
|
||||
return np.clip((2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value,
|
||||
-hp.max_abs_value, hp.max_abs_value)
|
||||
else:
|
||||
return np.clip(hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db)), 0, hp.max_abs_value)
|
||||
|
||||
assert S.max() <= 0 and S.min() - hp.min_level_db >= 0
|
||||
if hp.symmetric_mels:
|
||||
return (2 * hp.max_abs_value) * ((S - hp.min_level_db) / (-hp.min_level_db)) - hp.max_abs_value
|
||||
else:
|
||||
return hp.max_abs_value * ((S - hp.min_level_db) / (-hp.min_level_db))
|
||||
|
||||
def _denormalize(D):
|
||||
if hp.allow_clipping_in_normalization:
|
||||
if hp.symmetric_mels:
|
||||
return (((np.clip(D, -hp.max_abs_value,
|
||||
hp.max_abs_value) + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value))
|
||||
+ hp.min_level_db)
|
||||
else:
|
||||
return ((np.clip(D, 0, hp.max_abs_value) * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
||||
|
||||
if hp.symmetric_mels:
|
||||
return (((D + hp.max_abs_value) * -hp.min_level_db / (2 * hp.max_abs_value)) + hp.min_level_db)
|
||||
else:
|
||||
return ((D * -hp.min_level_db / hp.max_abs_value) + hp.min_level_db)
|
||||
610
models/MuseTalk/musetalk/data/dataset.py
Normal file
610
models/MuseTalk/musetalk/data/dataset.py
Normal file
@@ -0,0 +1,610 @@
|
||||
import os
|
||||
import numpy as np
|
||||
import random
|
||||
from PIL import Image
|
||||
import torch
|
||||
from torch.utils.data import Dataset, ConcatDataset
|
||||
import torchvision.transforms as transforms
|
||||
from transformers import AutoFeatureExtractor
|
||||
import librosa
|
||||
import time
|
||||
import json
|
||||
import math
|
||||
from decord import AudioReader, VideoReader
|
||||
from decord.ndarray import cpu
|
||||
|
||||
from musetalk.data.sample_method import get_src_idx, shift_landmarks_to_face_coordinates, resize_landmark
|
||||
from musetalk.data import audio
|
||||
from musetalk.utils.audio_utils import ensure_wav
|
||||
|
||||
syncnet_mel_step_size = math.ceil(16 / 5 * 16) # latentsync
|
||||
|
||||
|
||||
class FaceDataset(Dataset):
|
||||
"""Dataset class for loading and processing video data
|
||||
|
||||
Each video can be represented as:
|
||||
- Concatenated frame images
|
||||
- '.mp4' or '.gif' files
|
||||
- Folder containing all frames
|
||||
"""
|
||||
def __init__(self,
|
||||
cfg,
|
||||
list_paths,
|
||||
root_path='./dataset/',
|
||||
repeats=None):
|
||||
# Initialize dataset paths
|
||||
meta_paths = []
|
||||
if repeats is None:
|
||||
repeats = [1] * len(list_paths)
|
||||
assert len(repeats) == len(list_paths)
|
||||
|
||||
# Load data list
|
||||
for list_path, repeat_time in zip(list_paths, repeats):
|
||||
with open(list_path, 'r') as f:
|
||||
num = 0
|
||||
f.readline() # Skip header line
|
||||
for line in f.readlines():
|
||||
line_info = line.strip()
|
||||
meta = line_info.split()
|
||||
meta = meta[0]
|
||||
meta_paths.extend([os.path.join(root_path, meta)] * repeat_time)
|
||||
num += 1
|
||||
print(f'{list_path}: {num} x {repeat_time} = {num * repeat_time} samples')
|
||||
|
||||
# Set basic attributes
|
||||
self.meta_paths = meta_paths
|
||||
self.root_path = root_path
|
||||
self.image_size = cfg['image_size']
|
||||
self.min_face_size = cfg['min_face_size']
|
||||
self.T = cfg['T']
|
||||
self.sample_method = cfg['sample_method']
|
||||
self.top_k_ratio = cfg['top_k_ratio']
|
||||
self.max_attempts = 200
|
||||
self.padding_pixel_mouth = cfg['padding_pixel_mouth']
|
||||
|
||||
# Cropping related parameters
|
||||
self.crop_type = cfg['crop_type']
|
||||
self.jaw2edge_margin_mean = cfg['cropping_jaw2edge_margin_mean']
|
||||
self.jaw2edge_margin_std = cfg['cropping_jaw2edge_margin_std']
|
||||
self.random_margin_method = cfg['random_margin_method']
|
||||
|
||||
# Image transformations
|
||||
self.to_tensor = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
|
||||
])
|
||||
self.pose_to_tensor = transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
])
|
||||
|
||||
# Feature extractor
|
||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(cfg['whisper_path'])
|
||||
self.contorl_face_min_size = cfg["contorl_face_min_size"]
|
||||
|
||||
print("The sample method is: ", self.sample_method)
|
||||
print(f"only use face size > {self.min_face_size}", self.contorl_face_min_size)
|
||||
|
||||
def generate_random_value(self):
|
||||
"""Generate random value
|
||||
|
||||
Returns:
|
||||
float: Generated random value
|
||||
"""
|
||||
if self.random_margin_method == "uniform":
|
||||
random_value = np.random.uniform(
|
||||
self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
|
||||
self.jaw2edge_margin_mean + self.jaw2edge_margin_std
|
||||
)
|
||||
elif self.random_margin_method == "normal":
|
||||
random_value = np.random.normal(
|
||||
loc=self.jaw2edge_margin_mean,
|
||||
scale=self.jaw2edge_margin_std
|
||||
)
|
||||
random_value = np.clip(
|
||||
random_value,
|
||||
self.jaw2edge_margin_mean - self.jaw2edge_margin_std,
|
||||
self.jaw2edge_margin_mean + self.jaw2edge_margin_std,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid random margin method: {self.random_margin_method}")
|
||||
return max(0, random_value)
|
||||
|
||||
def dynamic_margin_crop(self, img, original_bbox, extra_margin=None):
|
||||
"""Dynamically crop image with dynamic margin
|
||||
|
||||
Args:
|
||||
img: Input image
|
||||
original_bbox: Original bounding box
|
||||
extra_margin: Extra margin
|
||||
|
||||
Returns:
|
||||
tuple: (x1, y1, x2, y2, extra_margin)
|
||||
"""
|
||||
if extra_margin is None:
|
||||
extra_margin = self.generate_random_value()
|
||||
w, h = img.size
|
||||
x1, y1, x2, y2 = original_bbox
|
||||
y2 = min(y2 + int(extra_margin), h)
|
||||
return x1, y1, x2, y2, extra_margin
|
||||
|
||||
def crop_resize_img(self, img, bbox, crop_type='crop_resize', extra_margin=None):
|
||||
"""Crop and resize image
|
||||
|
||||
Args:
|
||||
img: Input image
|
||||
bbox: Bounding box
|
||||
crop_type: Type of cropping
|
||||
extra_margin: Extra margin
|
||||
|
||||
Returns:
|
||||
tuple: (Processed image, extra_margin, mask_scaled_factor)
|
||||
"""
|
||||
mask_scaled_factor = 1.
|
||||
if crop_type == 'crop_resize':
|
||||
x1, y1, x2, y2 = bbox
|
||||
img = img.crop((x1, y1, x2, y2))
|
||||
img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||
elif crop_type == 'dynamic_margin_crop_resize':
|
||||
x1, y1, x2, y2, extra_margin = self.dynamic_margin_crop(img, bbox, extra_margin)
|
||||
w_original, _ = img.size
|
||||
img = img.crop((x1, y1, x2, y2))
|
||||
w_cropped, _ = img.size
|
||||
mask_scaled_factor = w_cropped / w_original
|
||||
img = img.resize((self.image_size, self.image_size), Image.LANCZOS)
|
||||
elif crop_type == 'resize':
|
||||
w, h = img.size
|
||||
scale = np.sqrt(self.image_size ** 2 / (h * w))
|
||||
new_w = int(w * scale) / 64 * 64
|
||||
new_h = int(h * scale) / 64 * 64
|
||||
img = img.resize((new_w, new_h), Image.LANCZOS)
|
||||
return img, extra_margin, mask_scaled_factor
|
||||
|
||||
def get_audio_file(self, wav_path, start_index):
|
||||
"""Get audio file features
|
||||
|
||||
Args:
|
||||
wav_path: Audio file path
|
||||
start_index: Starting index
|
||||
|
||||
Returns:
|
||||
tuple: (Audio features, start index)
|
||||
"""
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
wav_path_converted = ensure_wav(wav_path)
|
||||
audio_input_librosa, sampling_rate = librosa.load(wav_path_converted, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
|
||||
while start_index >= 25 * 30:
|
||||
audio_input = audio_input_librosa[16000*30:]
|
||||
start_index -= 25 * 30
|
||||
if start_index + 2 * 25 >= 25 * 30:
|
||||
start_index -= 4 * 25
|
||||
audio_input = audio_input_librosa[16000*4:16000*34]
|
||||
else:
|
||||
audio_input = audio_input_librosa[:16000*30]
|
||||
|
||||
assert 2 * (start_index) >= 0
|
||||
assert 2 * (start_index + 2 * 25) <= 1500
|
||||
|
||||
audio_input = self.feature_extractor(
|
||||
audio_input,
|
||||
return_tensors="pt",
|
||||
sampling_rate=sampling_rate
|
||||
).input_features
|
||||
return audio_input, start_index
|
||||
|
||||
def get_audio_file_mel(self, wav_path, start_index):
|
||||
"""Get mel spectrogram of audio file
|
||||
|
||||
Args:
|
||||
wav_path: Audio file path
|
||||
start_index: Starting index
|
||||
|
||||
Returns:
|
||||
tuple: (Mel spectrogram, start index)
|
||||
"""
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
|
||||
wav_path_converted = ensure_wav(wav_path)
|
||||
audio_input_librosa, sampling_rate = librosa.load(wav_path_converted, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
|
||||
audio_mel = self.mel_feature_extractor(audio_input_librosa)
|
||||
return audio_mel, start_index
|
||||
|
||||
def mel_feature_extractor(self, audio_input):
|
||||
"""Extract mel spectrogram features
|
||||
|
||||
Args:
|
||||
audio_input: Input audio
|
||||
|
||||
Returns:
|
||||
ndarray: Mel spectrogram features
|
||||
"""
|
||||
orig_mel = audio.melspectrogram(audio_input)
|
||||
return orig_mel.T
|
||||
|
||||
def crop_audio_window(self, spec, start_frame_num, fps=25):
|
||||
"""Crop audio window
|
||||
|
||||
Args:
|
||||
spec: Spectrogram
|
||||
start_frame_num: Starting frame number
|
||||
fps: Frames per second
|
||||
|
||||
Returns:
|
||||
ndarray: Cropped spectrogram
|
||||
"""
|
||||
start_idx = int(80. * (start_frame_num / float(fps)))
|
||||
end_idx = start_idx + syncnet_mel_step_size
|
||||
return spec[start_idx: end_idx, :]
|
||||
|
||||
def get_syncnet_input(self, video_path):
|
||||
"""Get SyncNet input features
|
||||
|
||||
Args:
|
||||
video_path: Video file path
|
||||
|
||||
Returns:
|
||||
ndarray: SyncNet input features
|
||||
"""
|
||||
ar = AudioReader(video_path, sample_rate=16000)
|
||||
original_mel = audio.melspectrogram(ar[:].asnumpy().squeeze(0))
|
||||
return original_mel.T
|
||||
|
||||
def get_resized_mouth_mask(
|
||||
self,
|
||||
img_resized,
|
||||
landmark_array,
|
||||
face_shape,
|
||||
padding_pixel_mouth=0,
|
||||
image_size=256,
|
||||
crop_margin=0
|
||||
):
|
||||
landmark_array = np.array(landmark_array)
|
||||
resized_landmark = resize_landmark(
|
||||
landmark_array, w=face_shape[0], h=face_shape[1], new_w=image_size, new_h=image_size)
|
||||
|
||||
landmark_array = np.array(resized_landmark[48 : 67]) # the lip landmarks in 68 landmarks format
|
||||
min_x, min_y = np.min(landmark_array, axis=0)
|
||||
max_x, max_y = np.max(landmark_array, axis=0)
|
||||
min_x = min_x - padding_pixel_mouth
|
||||
max_x = max_x + padding_pixel_mouth
|
||||
|
||||
# Calculate x-axis length and use it for y-axis
|
||||
width = max_x - min_x
|
||||
|
||||
# Calculate old center point
|
||||
center_y = (max_y + min_y) / 2
|
||||
|
||||
# Determine new min_y and max_y based on width
|
||||
min_y = center_y - width / 4
|
||||
max_y = center_y + width / 4
|
||||
|
||||
# Adjust mask position for dynamic crop, shift y-axis
|
||||
min_y = min_y - crop_margin
|
||||
max_y = max_y - crop_margin
|
||||
|
||||
# Prevent out of bounds
|
||||
min_x = max(min_x, 0)
|
||||
min_y = max(min_y, 0)
|
||||
max_x = min(max_x, face_shape[0])
|
||||
max_y = min(max_y, face_shape[1])
|
||||
|
||||
mask = np.zeros_like(np.array(img_resized))
|
||||
mask[round(min_y):round(max_y), round(min_x):round(max_x)] = 255
|
||||
return Image.fromarray(mask)
|
||||
|
||||
def __len__(self):
|
||||
return 100000
|
||||
|
||||
def __getitem__(self, idx):
|
||||
attempts = 0
|
||||
while attempts < self.max_attempts:
|
||||
try:
|
||||
meta_path = random.sample(self.meta_paths, k=1)[0]
|
||||
with open(meta_path, 'r') as f:
|
||||
meta_data = json.load(f)
|
||||
except Exception as e:
|
||||
print(f"meta file error:{meta_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
video_path = meta_data["mp4_path"]
|
||||
wav_path = meta_data["wav_path"]
|
||||
bbox_list = meta_data["face_list"]
|
||||
landmark_list = meta_data["landmark_list"]
|
||||
T = self.T
|
||||
|
||||
s = 0
|
||||
e = meta_data["frames"]
|
||||
len_valid_clip = e - s
|
||||
|
||||
if len_valid_clip < T * 10:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has less than {T * 10} frames")
|
||||
continue
|
||||
|
||||
try:
|
||||
cap = VideoReader(video_path, fault_tol=1, ctx=cpu(0))
|
||||
total_frames = len(cap)
|
||||
assert total_frames == len(landmark_list)
|
||||
assert total_frames == len(bbox_list)
|
||||
landmark_shape = np.array(landmark_list).shape
|
||||
if landmark_shape != (total_frames, 68, 2):
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid landmark shape: {landmark_shape}, expected: {(total_frames, 68, 2)}") # we use 68 landmarks
|
||||
continue
|
||||
except Exception as e:
|
||||
print(f"video file error:{video_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
shift_landmarks, bbox_list_union, face_shapes = shift_landmarks_to_face_coordinates(
|
||||
landmark_list,
|
||||
bbox_list
|
||||
)
|
||||
if self.contorl_face_min_size and face_shapes[0][0] < self.min_face_size:
|
||||
print(f"video {video_path} has face size {face_shapes[0][0]} less than minimum required {self.min_face_size}")
|
||||
attempts += 1
|
||||
continue
|
||||
|
||||
step = 1
|
||||
drive_idx_start = random.randint(s, e - T * step)
|
||||
drive_idx_list = list(
|
||||
range(drive_idx_start, drive_idx_start + T * step, step))
|
||||
assert len(drive_idx_list) == T
|
||||
|
||||
src_idx_list = []
|
||||
list_index_out_of_range = False
|
||||
for drive_idx in drive_idx_list:
|
||||
src_idx = get_src_idx(
|
||||
drive_idx, T, self.sample_method, shift_landmarks, face_shapes, self.top_k_ratio)
|
||||
if src_idx is None:
|
||||
list_index_out_of_range = True
|
||||
break
|
||||
src_idx = min(src_idx, e - 1)
|
||||
src_idx = max(src_idx, s)
|
||||
src_idx_list.append(src_idx)
|
||||
|
||||
if list_index_out_of_range:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid source index for drive frames")
|
||||
continue
|
||||
|
||||
ref_face_valid_flag = True
|
||||
extra_margin = self.generate_random_value()
|
||||
|
||||
# Get reference images
|
||||
ref_imgs = []
|
||||
for src_idx in src_idx_list:
|
||||
imSrc = Image.fromarray(cap[src_idx].asnumpy())
|
||||
bbox_s = bbox_list_union[src_idx]
|
||||
imSrc, _, _ = self.crop_resize_img(
|
||||
imSrc,
|
||||
bbox_s,
|
||||
self.crop_type,
|
||||
extra_margin=None
|
||||
)
|
||||
if self.contorl_face_min_size and min(imSrc.size[0], imSrc.size[1]) < self.min_face_size:
|
||||
ref_face_valid_flag = False
|
||||
break
|
||||
ref_imgs.append(imSrc)
|
||||
|
||||
if not ref_face_valid_flag:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has reference face size smaller than minimum required {self.min_face_size}")
|
||||
continue
|
||||
|
||||
# Get target images and masks
|
||||
imSameIDs = []
|
||||
bboxes = []
|
||||
face_masks = []
|
||||
face_mask_valid = True
|
||||
target_face_valid_flag = True
|
||||
|
||||
for drive_idx in drive_idx_list:
|
||||
imSameID = Image.fromarray(cap[drive_idx].asnumpy())
|
||||
bbox_s = bbox_list_union[drive_idx]
|
||||
imSameID, _ , mask_scaled_factor = self.crop_resize_img(
|
||||
imSameID,
|
||||
bbox_s,
|
||||
self.crop_type,
|
||||
extra_margin=extra_margin
|
||||
)
|
||||
if self.contorl_face_min_size and min(imSameID.size[0], imSameID.size[1]) < self.min_face_size:
|
||||
target_face_valid_flag = False
|
||||
break
|
||||
crop_margin = extra_margin * mask_scaled_factor
|
||||
face_mask = self.get_resized_mouth_mask(
|
||||
imSameID,
|
||||
shift_landmarks[drive_idx],
|
||||
face_shapes[drive_idx],
|
||||
self.padding_pixel_mouth,
|
||||
self.image_size,
|
||||
crop_margin=crop_margin
|
||||
)
|
||||
if np.count_nonzero(face_mask) == 0:
|
||||
face_mask_valid = False
|
||||
break
|
||||
|
||||
if face_mask.size[1] == 0 or face_mask.size[0] == 0:
|
||||
print(f"video {video_path} has invalid face mask size at frame {drive_idx}")
|
||||
face_mask_valid = False
|
||||
break
|
||||
|
||||
imSameIDs.append(imSameID)
|
||||
bboxes.append(bbox_s)
|
||||
face_masks.append(face_mask)
|
||||
|
||||
if not face_mask_valid:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid face mask")
|
||||
continue
|
||||
|
||||
if not target_face_valid_flag:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has target face size smaller than minimum required {self.min_face_size}")
|
||||
continue
|
||||
|
||||
# Process audio features
|
||||
audio_offset = drive_idx_list[0]
|
||||
audio_step = step
|
||||
fps = 25.0 / step
|
||||
|
||||
try:
|
||||
audio_feature, audio_offset = self.get_audio_file(wav_path, audio_offset)
|
||||
_, audio_offset = self.get_audio_file_mel(wav_path, audio_offset)
|
||||
audio_feature_mel = self.get_syncnet_input(video_path)
|
||||
except Exception as e:
|
||||
print(f"audio file error:{wav_path}")
|
||||
print(e)
|
||||
attempts += 1
|
||||
time.sleep(0.1)
|
||||
continue
|
||||
|
||||
mel = self.crop_audio_window(audio_feature_mel, audio_offset)
|
||||
if mel.shape[0] != syncnet_mel_step_size:
|
||||
attempts += 1
|
||||
print(f"video {video_path} has invalid mel spectrogram shape: {mel.shape}, expected: {syncnet_mel_step_size}")
|
||||
continue
|
||||
|
||||
mel = torch.FloatTensor(mel.T).unsqueeze(0)
|
||||
|
||||
# Build sample dictionary
|
||||
sample = dict(
|
||||
pixel_values_vid=torch.stack(
|
||||
[self.to_tensor(imSameID) for imSameID in imSameIDs], dim=0),
|
||||
pixel_values_ref_img=torch.stack(
|
||||
[self.to_tensor(ref_img) for ref_img in ref_imgs], dim=0),
|
||||
pixel_values_face_mask=torch.stack(
|
||||
[self.pose_to_tensor(face_mask) for face_mask in face_masks], dim=0),
|
||||
audio_feature=audio_feature[0],
|
||||
audio_offset=audio_offset,
|
||||
audio_step=audio_step,
|
||||
mel=mel,
|
||||
wav_path=wav_path,
|
||||
fps=fps,
|
||||
)
|
||||
|
||||
return sample
|
||||
|
||||
raise ValueError("Unable to find a valid sample after maximum attempts.")
|
||||
|
||||
class HDTFDataset(FaceDataset):
|
||||
"""HDTF dataset class"""
|
||||
def __init__(self, cfg):
|
||||
root_path = './dataset/HDTF/meta'
|
||||
list_paths = [
|
||||
'./dataset/HDTF/train.txt',
|
||||
]
|
||||
|
||||
|
||||
repeats = [10]
|
||||
super().__init__(cfg, list_paths, root_path, repeats)
|
||||
print('HDTFDataset: ', len(self))
|
||||
|
||||
class VFHQDataset(FaceDataset):
|
||||
"""VFHQ dataset class"""
|
||||
def __init__(self, cfg):
|
||||
root_path = './dataset/VFHQ/meta'
|
||||
list_paths = [
|
||||
'./dataset/VFHQ/train.txt',
|
||||
]
|
||||
repeats = [1]
|
||||
super().__init__(cfg, list_paths, root_path, repeats)
|
||||
print('VFHQDataset: ', len(self))
|
||||
|
||||
def PortraitDataset(cfg=None):
|
||||
"""Return dataset based on configuration
|
||||
|
||||
Args:
|
||||
cfg: Configuration dictionary
|
||||
|
||||
Returns:
|
||||
Dataset: Combined dataset
|
||||
"""
|
||||
if cfg["dataset_key"] == "HDTF":
|
||||
return ConcatDataset([HDTFDataset(cfg)])
|
||||
elif cfg["dataset_key"] == "VFHQ":
|
||||
return ConcatDataset([VFHQDataset(cfg)])
|
||||
else:
|
||||
print("############ use all dataset ############ ")
|
||||
return ConcatDataset([HDTFDataset(cfg), VFHQDataset(cfg)])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
# Set random seeds for reproducibility
|
||||
seed = 42
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
|
||||
# Create dataset with configuration parameters
|
||||
dataset = PortraitDataset(cfg={
|
||||
'T': 1, # Number of frames to process at once
|
||||
'random_margin_method': "normal", # Method for generating random margins: "normal" or "uniform"
|
||||
'dataset_key': "HDTF", # Dataset to use: "HDTF", "VFHQ", or None for both
|
||||
'image_size': 256, # Size of processed images (height and width)
|
||||
'sample_method': 'pose_similarity_and_mouth_dissimilarity', # Method for selecting reference frames
|
||||
'top_k_ratio': 0.51, # Ratio for top-k selection in reference frame sampling
|
||||
'contorl_face_min_size': True, # Whether to enforce minimum face size
|
||||
'padding_pixel_mouth': 10, # Padding pixels around mouth region in mask
|
||||
'min_face_size': 200, # Minimum face size requirement for dataset
|
||||
'whisper_path': "./models/whisper", # Path to Whisper model
|
||||
'cropping_jaw2edge_margin_mean': 10, # Mean margin for jaw-to-edge cropping
|
||||
'cropping_jaw2edge_margin_std': 10, # Standard deviation for jaw-to-edge cropping
|
||||
'crop_type': "dynamic_margin_crop_resize", # Type of cropping: "crop_resize", "dynamic_margin_crop_resize", or "resize"
|
||||
})
|
||||
print(len(dataset))
|
||||
|
||||
import torchvision
|
||||
os.makedirs('debug', exist_ok=True)
|
||||
for i in range(10): # Check 10 samples
|
||||
sample = dataset[0]
|
||||
print(f"processing {i}")
|
||||
|
||||
# Get images and mask
|
||||
ref_img = (sample['pixel_values_ref_img'] + 1.0) / 2 # (b, c, h, w)
|
||||
target_img = (sample['pixel_values_vid'] + 1.0) / 2
|
||||
face_mask = sample['pixel_values_face_mask']
|
||||
|
||||
# Print dimension information
|
||||
print(f"ref_img shape: {ref_img.shape}")
|
||||
print(f"target_img shape: {target_img.shape}")
|
||||
print(f"face_mask shape: {face_mask.shape}")
|
||||
|
||||
# Create visualization images
|
||||
b, c, h, w = ref_img.shape
|
||||
|
||||
# Apply mask only to target image
|
||||
target_mask = face_mask
|
||||
|
||||
# Keep reference image unchanged
|
||||
ref_with_mask = ref_img.clone()
|
||||
|
||||
# Create mask overlay for target image
|
||||
target_with_mask = target_img.clone()
|
||||
target_with_mask = target_with_mask * (1 - target_mask) + target_mask # Apply mask only to target
|
||||
|
||||
# Save original images, mask, and overlay results
|
||||
# First row: original images
|
||||
# Second row: mask
|
||||
# Third row: overlay effect
|
||||
concatenated_img = torch.cat((
|
||||
ref_img, target_img, # Original images
|
||||
torch.zeros_like(ref_img), target_mask, # Mask (black for ref)
|
||||
ref_with_mask, target_with_mask # Overlay effect
|
||||
), dim=3)
|
||||
|
||||
torchvision.utils.save_image(
|
||||
concatenated_img, f'debug/mask_check_{i}.jpg', nrow=2)
|
||||
233
models/MuseTalk/musetalk/data/sample_method.py
Normal file
233
models/MuseTalk/musetalk/data/sample_method.py
Normal file
@@ -0,0 +1,233 @@
|
||||
import numpy as np
|
||||
import random
|
||||
|
||||
def summarize_tensor(x):
|
||||
return f"\033[34m{str(tuple(x.shape)).ljust(24)}\033[0m (\033[31mmin {x.min().item():+.4f}\033[0m / \033[32mmean {x.mean().item():+.4f}\033[0m / \033[33mmax {x.max().item():+.4f}\033[0m)"
|
||||
|
||||
def calculate_mouth_open_similarity(landmarks_list, select_idx,top_k=50,ascending=True):
|
||||
num_landmarks = len(landmarks_list)
|
||||
mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
|
||||
print(np.shape(landmarks_list))
|
||||
## Calculate mouth opening ratios
|
||||
for i, landmarks in enumerate(landmarks_list):
|
||||
# Assuming landmarks are in the format [x, y] and accessible by index
|
||||
mouth_top = landmarks[165] # Adjust index according to your landmarks format
|
||||
mouth_bottom = landmarks[147] # Adjust index according to your landmarks format
|
||||
mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
|
||||
mouth_open_ratios[i] = mouth_open_ratio
|
||||
|
||||
# Calculate differences matrix
|
||||
differences_matrix = np.abs(mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx])
|
||||
differences_matrix_with_signs = mouth_open_ratios[:, np.newaxis] - mouth_open_ratios[select_idx]
|
||||
print(differences_matrix.shape)
|
||||
# Find top_k similar indices for each landmark set
|
||||
if ascending:
|
||||
top_indices = np.argsort(differences_matrix[i])[:top_k]
|
||||
else:
|
||||
top_indices = np.argsort(-differences_matrix[i])[:top_k]
|
||||
similar_landmarks_indices = top_indices.tolist()
|
||||
similar_landmarks_distances = differences_matrix_with_signs[i].tolist() #注意这里不要排序
|
||||
|
||||
return similar_landmarks_indices, similar_landmarks_distances
|
||||
#############################################################################################
|
||||
def get_closed_mouth(landmarks_list,ascending=True,top_k=50):
|
||||
num_landmarks = len(landmarks_list)
|
||||
|
||||
mouth_open_ratios = np.zeros(num_landmarks) # Initialize as a numpy array
|
||||
## Calculate mouth opening ratios
|
||||
#print("landmarks shape",np.shape(landmarks_list))
|
||||
for i, landmarks in enumerate(landmarks_list):
|
||||
# Assuming landmarks are in the format [x, y] and accessible by index
|
||||
#print(landmarks[165])
|
||||
mouth_top = np.array(landmarks[165])# Adjust index according to your landmarks format
|
||||
mouth_bottom = np.array(landmarks[147]) # Adjust index according to your landmarks format
|
||||
mouth_open_ratio = np.linalg.norm(mouth_top - mouth_bottom)
|
||||
mouth_open_ratios[i] = mouth_open_ratio
|
||||
|
||||
# Find top_k similar indices for each landmark set
|
||||
if ascending:
|
||||
top_indices = np.argsort(mouth_open_ratios)[:top_k]
|
||||
else:
|
||||
top_indices = np.argsort(-mouth_open_ratios)[:top_k]
|
||||
return top_indices
|
||||
|
||||
def calculate_landmarks_similarity(selected_idx, landmarks_list,image_shapes, start_index, end_index, top_k=50,ascending=True):
|
||||
"""
|
||||
Calculate the similarity between sets of facial landmarks and return the indices of the most similar faces.
|
||||
|
||||
Parameters:
|
||||
landmarks_list (list): A list containing sets of facial landmarks, each element is a set of landmarks.
|
||||
image_shapes (list): A list containing the shape of each image, each element is a (width, height) tuple.
|
||||
start_index (int): The starting index of the facial landmarks.
|
||||
end_index (int): The ending index of the facial landmarks.
|
||||
top_k (int): The number of most similar landmark sets to return. Default is 50.
|
||||
ascending (bool): Controls the sorting order. If True, sort in ascending order; If False, sort in descending order. Default is True.
|
||||
|
||||
Returns:
|
||||
similar_landmarks_indices (list): A list containing the indices of the most similar facial landmarks for each face.
|
||||
resized_landmarks (list): A list containing the resized facial landmarks.
|
||||
"""
|
||||
num_landmarks = len(landmarks_list)
|
||||
resized_landmarks = []
|
||||
|
||||
# Preprocess landmarks
|
||||
for i in range(num_landmarks):
|
||||
landmark_array = np.array(landmarks_list[i])
|
||||
selected_landmarks = landmark_array[start_index:end_index]
|
||||
resized_landmark = resize_landmark(selected_landmarks, w=image_shapes[i][0], h=image_shapes[i][1],new_w=256,new_h=256)
|
||||
resized_landmarks.append(resized_landmark)
|
||||
|
||||
resized_landmarks_array = np.array(resized_landmarks) # Convert list to array for easier manipulation
|
||||
|
||||
# Calculate similarity
|
||||
distances = np.linalg.norm(resized_landmarks_array - resized_landmarks_array[selected_idx][np.newaxis, :], axis=2)
|
||||
overall_distances = np.mean(distances, axis=1) # Calculate mean distance for each set of landmarks
|
||||
|
||||
if ascending:
|
||||
sorted_indices = np.argsort(overall_distances)
|
||||
similar_landmarks_indices = sorted_indices[1:top_k+1].tolist() # Exclude self and take top_k
|
||||
else:
|
||||
sorted_indices = np.argsort(-overall_distances)
|
||||
similar_landmarks_indices = sorted_indices[0:top_k].tolist()
|
||||
|
||||
return similar_landmarks_indices
|
||||
|
||||
def process_bbox_musetalk(face_array, landmark_array):
|
||||
x_min_face, y_min_face, x_max_face, y_max_face = map(int, face_array)
|
||||
x_min_lm = min([int(x) for x, y in landmark_array])
|
||||
y_min_lm = min([int(y) for x, y in landmark_array])
|
||||
x_max_lm = max([int(x) for x, y in landmark_array])
|
||||
y_max_lm = max([int(y) for x, y in landmark_array])
|
||||
x_min = min(x_min_face, x_min_lm)
|
||||
y_min = min(y_min_face, y_min_lm)
|
||||
x_max = max(x_max_face, x_max_lm)
|
||||
y_max = max(y_max_face, y_max_lm)
|
||||
|
||||
x_min = max(x_min, 0)
|
||||
y_min = max(y_min, 0)
|
||||
|
||||
return [x_min, y_min, x_max, y_max]
|
||||
|
||||
def shift_landmarks_to_face_coordinates(landmark_list, face_list):
|
||||
"""
|
||||
Translates the data in landmark_list to the coordinates of the cropped larger face.
|
||||
|
||||
Parameters:
|
||||
landmark_list (list): A list containing multiple sets of facial landmarks.
|
||||
face_list (list): A list containing multiple facial images.
|
||||
|
||||
Returns:
|
||||
landmark_list_shift (list): The list of translated landmarks.
|
||||
bbox_union (list): The list of union bounding boxes.
|
||||
face_shapes (list): The list of facial shapes.
|
||||
"""
|
||||
landmark_list_shift = []
|
||||
bbox_union = []
|
||||
face_shapes = []
|
||||
|
||||
for i in range(len(face_list)):
|
||||
landmark_array = np.array(landmark_list[i]) # 转换为numpy数组并创建副本
|
||||
face_array = face_list[i]
|
||||
f_landmark_bbox = process_bbox_musetalk(face_array, landmark_array)
|
||||
x_min, y_min, x_max, y_max = f_landmark_bbox
|
||||
landmark_array[:, 0] = landmark_array[:, 0] - f_landmark_bbox[0]
|
||||
landmark_array[:, 1] = landmark_array[:, 1] - f_landmark_bbox[1]
|
||||
landmark_list_shift.append(landmark_array)
|
||||
bbox_union.append(f_landmark_bbox)
|
||||
face_shapes.append((x_max - x_min, y_max - y_min))
|
||||
|
||||
return landmark_list_shift, bbox_union, face_shapes
|
||||
|
||||
def resize_landmark(landmark, w, h, new_w, new_h):
|
||||
landmark_norm = landmark / [w, h]
|
||||
landmark_resized = landmark_norm * [new_w, new_h]
|
||||
|
||||
return landmark_resized
|
||||
|
||||
def get_src_idx(drive_idx, T, sample_method,landmarks_list,image_shapes,top_k_ratio):
|
||||
"""
|
||||
Calculate the source index (src_idx) based on the given drive index, T, s, e, and sampling method.
|
||||
|
||||
Parameters:
|
||||
- drive_idx (int): The current drive index.
|
||||
- T (int): Total number of frames or a specific range limit.
|
||||
- sample_method (str): Sampling method, which can be "random" or other methods.
|
||||
- landmarks_list (list): List of facial landmarks.
|
||||
- image_shapes (list): List of image shapes.
|
||||
- top_k_ratio (float): Ratio for selecting top k similar frames.
|
||||
|
||||
Returns:
|
||||
- src_idx (int): The calculated source index.
|
||||
"""
|
||||
if sample_method == "random":
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
elif sample_method == "pose_similarity":
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
# facial contour
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
src_idx = random.choice(pose_similarity_list)
|
||||
while abs(src_idx-drive_idx)<5:
|
||||
src_idx = random.choice(pose_similarity_list)
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
elif sample_method=="pose_similarity_and_closed_mouth":
|
||||
# facial contour
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
closed_mouth_list = get_closed_mouth(landmarks_list, ascending=True,top_k=top_k)
|
||||
#print("closed_mouth_list",closed_mouth_list)
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
#print("pose_similarity_list",pose_similarity_list)
|
||||
common_list = list(set(closed_mouth_list).intersection(set(pose_similarity_list)))
|
||||
if len(common_list) == 0:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
else:
|
||||
src_idx = random.choice(common_list)
|
||||
|
||||
while abs(src_idx-drive_idx) <5:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
elif sample_method=="pose_similarity_and_mouth_dissimilarity":
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
try:
|
||||
top_k = int(top_k_ratio*len(landmarks_list))
|
||||
|
||||
# facial contour for 68 landmarks format
|
||||
landmark_start_idx = 0
|
||||
landmark_end_idx = 16
|
||||
|
||||
pose_similarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=True)
|
||||
|
||||
# Mouth inner coutour for 68 landmarks format
|
||||
landmark_start_idx = 60
|
||||
landmark_end_idx = 67
|
||||
|
||||
mouth_dissimilarity_list = calculate_landmarks_similarity(drive_idx, landmarks_list,image_shapes, landmark_start_idx, landmark_end_idx,top_k=top_k, ascending=False)
|
||||
|
||||
common_list = list(set(pose_similarity_list).intersection(set(mouth_dissimilarity_list)))
|
||||
if len(common_list) == 0:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
else:
|
||||
src_idx = random.choice(common_list)
|
||||
|
||||
while abs(src_idx-drive_idx) <5:
|
||||
src_idx = random.randint(drive_idx - 5 * T, drive_idx + 5 * T)
|
||||
|
||||
except Exception as e:
|
||||
print(e)
|
||||
return None
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown sample_method: {sample_method}")
|
||||
return src_idx
|
||||
81
models/MuseTalk/musetalk/loss/basic_loss.py
Normal file
81
models/MuseTalk/musetalk/loss/basic_loss.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from torch import nn, optim
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from musetalk.loss.discriminator import MultiScaleDiscriminator,DiscriminatorFullModel
|
||||
import musetalk.loss.vgg_face as vgg_face
|
||||
|
||||
class Interpolate(nn.Module):
|
||||
def __init__(self, size=None, scale_factor=None, mode='nearest', align_corners=None):
|
||||
super(Interpolate, self).__init__()
|
||||
self.size = size
|
||||
self.scale_factor = scale_factor
|
||||
self.mode = mode
|
||||
self.align_corners = align_corners
|
||||
|
||||
def forward(self, input):
|
||||
return F.interpolate(input, self.size, self.scale_factor, self.mode, self.align_corners)
|
||||
|
||||
def set_requires_grad(net, requires_grad=False):
|
||||
if net is not None:
|
||||
for param in net.parameters():
|
||||
param.requires_grad = requires_grad
|
||||
|
||||
if __name__ == "__main__":
|
||||
cfg = OmegaConf.load("config/audio_adapter/E7.yaml")
|
||||
|
||||
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
|
||||
pyramid_scale = [1, 0.5, 0.25, 0.125]
|
||||
vgg_IN = vgg_face.Vgg19().to(device)
|
||||
pyramid = vgg_face.ImagePyramide(cfg.loss_params.pyramid_scale, 3).to(device)
|
||||
vgg_IN.eval()
|
||||
downsampler = Interpolate(size=(224, 224), mode='bilinear', align_corners=False)
|
||||
|
||||
image = torch.rand(8, 3, 256, 256).to(device)
|
||||
image_pred = torch.rand(8, 3, 256, 256).to(device)
|
||||
pyramide_real = pyramid(downsampler(image))
|
||||
pyramide_generated = pyramid(downsampler(image_pred))
|
||||
|
||||
|
||||
loss_IN = 0
|
||||
for scale in cfg.loss_params.pyramid_scale:
|
||||
x_vgg = vgg_IN(pyramide_generated['prediction_' + str(scale)])
|
||||
y_vgg = vgg_IN(pyramide_real['prediction_' + str(scale)])
|
||||
for i, weight in enumerate(cfg.loss_params.vgg_layer_weight):
|
||||
value = torch.abs(x_vgg[i] - y_vgg[i].detach()).mean()
|
||||
loss_IN += weight * value
|
||||
loss_IN /= sum(cfg.loss_params.vgg_layer_weight) # 对vgg不同层取均值,金字塔loss是每层叠
|
||||
print(loss_IN)
|
||||
|
||||
#print(cfg.model_params.discriminator_params)
|
||||
|
||||
discriminator = MultiScaleDiscriminator(**cfg.model_params.discriminator_params).to(device)
|
||||
discriminator_full = DiscriminatorFullModel(discriminator)
|
||||
disc_scales = cfg.model_params.discriminator_params.scales
|
||||
# Prepare optimizer and loss function
|
||||
optimizer_D = optim.AdamW(discriminator.parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
scheduler_D = CosineAnnealingLR(optimizer_D,
|
||||
T_max=cfg.discriminator_train_params.epochs,
|
||||
eta_min=1e-6)
|
||||
|
||||
discriminator.train()
|
||||
|
||||
set_requires_grad(discriminator, False)
|
||||
|
||||
loss_G = 0.
|
||||
discriminator_maps_generated = discriminator(pyramide_generated)
|
||||
discriminator_maps_real = discriminator(pyramide_real)
|
||||
|
||||
for scale in disc_scales:
|
||||
key = 'prediction_map_%s' % scale
|
||||
value = ((1 - discriminator_maps_generated[key]) ** 2).mean()
|
||||
loss_G += value
|
||||
|
||||
print(loss_G)
|
||||
44
models/MuseTalk/musetalk/loss/conv.py
Normal file
44
models/MuseTalk/musetalk/loss/conv.py
Normal file
@@ -0,0 +1,44 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
class Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
self.residual = residual
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
if self.residual:
|
||||
out += x
|
||||
return self.act(out)
|
||||
|
||||
class nonorm_Conv2d(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
||||
)
|
||||
self.act = nn.LeakyReLU(0.01, inplace=True)
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
return self.act(out)
|
||||
|
||||
class Conv2dTranspose(nn.Module):
|
||||
def __init__(self, cin, cout, kernel_size, stride, padding, output_padding=0, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.conv_block = nn.Sequential(
|
||||
nn.ConvTranspose2d(cin, cout, kernel_size, stride, padding, output_padding),
|
||||
nn.BatchNorm2d(cout)
|
||||
)
|
||||
self.act = nn.ReLU()
|
||||
|
||||
def forward(self, x):
|
||||
out = self.conv_block(x)
|
||||
return self.act(out)
|
||||
145
models/MuseTalk/musetalk/loss/discriminator.py
Normal file
145
models/MuseTalk/musetalk/loss/discriminator.py
Normal file
@@ -0,0 +1,145 @@
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
import torch
|
||||
from musetalk.loss.vgg_face import ImagePyramide
|
||||
|
||||
class DownBlock2d(nn.Module):
|
||||
"""
|
||||
Simple block for processing video (encoder).
|
||||
"""
|
||||
|
||||
def __init__(self, in_features, out_features, norm=False, kernel_size=4, pool=False, sn=False):
|
||||
super(DownBlock2d, self).__init__()
|
||||
self.conv = nn.Conv2d(in_channels=in_features, out_channels=out_features, kernel_size=kernel_size)
|
||||
|
||||
if sn:
|
||||
self.conv = nn.utils.spectral_norm(self.conv)
|
||||
|
||||
if norm:
|
||||
self.norm = nn.InstanceNorm2d(out_features, affine=True)
|
||||
else:
|
||||
self.norm = None
|
||||
self.pool = pool
|
||||
|
||||
def forward(self, x):
|
||||
out = x
|
||||
out = self.conv(out)
|
||||
if self.norm:
|
||||
out = self.norm(out)
|
||||
out = F.leaky_relu(out, 0.2)
|
||||
if self.pool:
|
||||
out = F.avg_pool2d(out, (2, 2))
|
||||
return out
|
||||
|
||||
|
||||
class Discriminator(nn.Module):
|
||||
"""
|
||||
Discriminator similar to Pix2Pix
|
||||
"""
|
||||
|
||||
def __init__(self, num_channels=3, block_expansion=64, num_blocks=4, max_features=512,
|
||||
sn=False, **kwargs):
|
||||
super(Discriminator, self).__init__()
|
||||
|
||||
down_blocks = []
|
||||
for i in range(num_blocks):
|
||||
down_blocks.append(
|
||||
DownBlock2d(num_channels if i == 0 else min(max_features, block_expansion * (2 ** i)),
|
||||
min(max_features, block_expansion * (2 ** (i + 1))),
|
||||
norm=(i != 0), kernel_size=4, pool=(i != num_blocks - 1), sn=sn))
|
||||
|
||||
self.down_blocks = nn.ModuleList(down_blocks)
|
||||
self.conv = nn.Conv2d(self.down_blocks[-1].conv.out_channels, out_channels=1, kernel_size=1)
|
||||
if sn:
|
||||
self.conv = nn.utils.spectral_norm(self.conv)
|
||||
|
||||
def forward(self, x):
|
||||
feature_maps = []
|
||||
out = x
|
||||
|
||||
for down_block in self.down_blocks:
|
||||
feature_maps.append(down_block(out))
|
||||
out = feature_maps[-1]
|
||||
prediction_map = self.conv(out)
|
||||
|
||||
return feature_maps, prediction_map
|
||||
|
||||
|
||||
class MultiScaleDiscriminator(nn.Module):
|
||||
"""
|
||||
Multi-scale (scale) discriminator
|
||||
"""
|
||||
|
||||
def __init__(self, scales=(), **kwargs):
|
||||
super(MultiScaleDiscriminator, self).__init__()
|
||||
self.scales = scales
|
||||
discs = {}
|
||||
for scale in scales:
|
||||
discs[str(scale).replace('.', '-')] = Discriminator(**kwargs)
|
||||
self.discs = nn.ModuleDict(discs)
|
||||
|
||||
def forward(self, x):
|
||||
out_dict = {}
|
||||
for scale, disc in self.discs.items():
|
||||
scale = str(scale).replace('-', '.')
|
||||
key = 'prediction_' + scale
|
||||
#print(key)
|
||||
#print(x)
|
||||
feature_maps, prediction_map = disc(x[key])
|
||||
out_dict['feature_maps_' + scale] = feature_maps
|
||||
out_dict['prediction_map_' + scale] = prediction_map
|
||||
return out_dict
|
||||
|
||||
|
||||
|
||||
class DiscriminatorFullModel(torch.nn.Module):
|
||||
"""
|
||||
Merge all discriminator related updates into single model for better multi-gpu usage
|
||||
"""
|
||||
|
||||
def __init__(self, discriminator):
|
||||
super(DiscriminatorFullModel, self).__init__()
|
||||
self.discriminator = discriminator
|
||||
self.scales = self.discriminator.scales
|
||||
print("scales",self.scales)
|
||||
self.pyramid = ImagePyramide(self.scales, 3)
|
||||
if torch.cuda.is_available():
|
||||
self.pyramid = self.pyramid.cuda()
|
||||
|
||||
self.zero_tensor = None
|
||||
|
||||
def get_zero_tensor(self, input):
|
||||
if self.zero_tensor is None:
|
||||
self.zero_tensor = torch.FloatTensor(1).fill_(0).cuda()
|
||||
self.zero_tensor.requires_grad_(False)
|
||||
return self.zero_tensor.expand_as(input)
|
||||
|
||||
def forward(self, x, generated, gan_mode='ls'):
|
||||
pyramide_real = self.pyramid(x)
|
||||
pyramide_generated = self.pyramid(generated.detach())
|
||||
|
||||
discriminator_maps_generated = self.discriminator(pyramide_generated)
|
||||
discriminator_maps_real = self.discriminator(pyramide_real)
|
||||
|
||||
value_total = 0
|
||||
for scale in self.scales:
|
||||
key = 'prediction_map_%s' % scale
|
||||
if gan_mode == 'hinge':
|
||||
value = -torch.mean(torch.min(discriminator_maps_real[key]-1, self.get_zero_tensor(discriminator_maps_real[key]))) - torch.mean(torch.min(-discriminator_maps_generated[key]-1, self.get_zero_tensor(discriminator_maps_generated[key])))
|
||||
elif gan_mode == 'ls':
|
||||
value = ((1 - discriminator_maps_real[key]) ** 2 + discriminator_maps_generated[key] ** 2).mean()
|
||||
else:
|
||||
raise ValueError('Unexpected gan_mode {}'.format(self.train_params['gan_mode']))
|
||||
|
||||
value_total += value
|
||||
|
||||
return value_total
|
||||
|
||||
def main():
|
||||
discriminator = MultiScaleDiscriminator(scales=[1],
|
||||
block_expansion=32,
|
||||
max_features=512,
|
||||
num_blocks=4,
|
||||
sn=True,
|
||||
image_channel=3,
|
||||
estimate_jacobian=False)
|
||||
152
models/MuseTalk/musetalk/loss/resnet.py
Normal file
152
models/MuseTalk/musetalk/loss/resnet.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import torch.nn as nn
|
||||
import math
|
||||
|
||||
__all__ = ['ResNet', 'resnet50']
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
expansion = 1
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(inplanes, planes, stride)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.conv2 = conv3x3(planes, planes)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, stride=stride, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ResNet(nn.Module):
|
||||
|
||||
def __init__(self, block, layers, num_classes=1000, include_top=True):
|
||||
self.inplanes = 64
|
||||
super(ResNet, self).__init__()
|
||||
self.include_top = include_top
|
||||
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=0, ceil_mode=True)
|
||||
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(7, stride=1)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = x * 255.
|
||||
x = x.flip(1)
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
|
||||
if not self.include_top:
|
||||
return x
|
||||
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
return x
|
||||
|
||||
def resnet50(**kwargs):
|
||||
"""Constructs a ResNet-50 model.
|
||||
"""
|
||||
model = ResNet(Bottleneck, [3, 4, 6, 3], **kwargs)
|
||||
return model
|
||||
95
models/MuseTalk/musetalk/loss/syncnet.py
Normal file
95
models/MuseTalk/musetalk/loss/syncnet.py
Normal file
@@ -0,0 +1,95 @@
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from .conv import Conv2d
|
||||
|
||||
logloss = nn.BCELoss(reduction="none")
|
||||
def cosine_loss(a, v, y):
|
||||
d = nn.functional.cosine_similarity(a, v)
|
||||
d = d.clamp(0,1) # cosine_similarity的取值范围是【-1,1】,BCE如果输入负数会报错RuntimeError: CUDA error: device-side assert triggered
|
||||
loss = logloss(d.unsqueeze(1), y).squeeze()
|
||||
loss = loss.mean()
|
||||
return loss, d
|
||||
|
||||
def get_sync_loss(
|
||||
audio_embed,
|
||||
gt_frames,
|
||||
pred_frames,
|
||||
syncnet,
|
||||
adapted_weight,
|
||||
frames_left_index=0,
|
||||
frames_right_index=16,
|
||||
):
|
||||
# 跟gt_frames做随机的插入交换,节省显存开销
|
||||
assert pred_frames.shape[1] == (frames_right_index - frames_left_index) * 3
|
||||
# 3通道图像
|
||||
frames_sync_loss = torch.cat(
|
||||
[gt_frames[:, :3 * frames_left_index, ...], pred_frames, gt_frames[:, 3 * frames_right_index:, ...]],
|
||||
axis=1
|
||||
)
|
||||
vision_embed = syncnet.get_image_embed(frames_sync_loss)
|
||||
y = torch.ones(frames_sync_loss.size(0), 1).float().to(audio_embed.device)
|
||||
loss, score = cosine_loss(audio_embed, vision_embed, y)
|
||||
return loss, score
|
||||
|
||||
class SyncNet_color(nn.Module):
|
||||
def __init__(self):
|
||||
super(SyncNet_color, self).__init__()
|
||||
|
||||
self.face_encoder = nn.Sequential(
|
||||
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3),
|
||||
|
||||
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(64, 128, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(128, 256, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(256, 512, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(512, 512, kernel_size=3, stride=2, padding=1),
|
||||
Conv2d(512, 512, kernel_size=3, stride=1, padding=0),
|
||||
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
||||
|
||||
self.audio_encoder = nn.Sequential(
|
||||
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
||||
|
||||
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
||||
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
||||
|
||||
def forward(self, audio_sequences, face_sequences): # audio_sequences := (B, dim, T)
|
||||
face_embedding = self.face_encoder(face_sequences)
|
||||
audio_embedding = self.audio_encoder(audio_sequences)
|
||||
|
||||
audio_embedding = audio_embedding.view(audio_embedding.size(0), -1)
|
||||
face_embedding = face_embedding.view(face_embedding.size(0), -1)
|
||||
|
||||
audio_embedding = F.normalize(audio_embedding, p=2, dim=1)
|
||||
face_embedding = F.normalize(face_embedding, p=2, dim=1)
|
||||
|
||||
|
||||
return audio_embedding, face_embedding
|
||||
237
models/MuseTalk/musetalk/loss/vgg_face.py
Normal file
237
models/MuseTalk/musetalk/loss/vgg_face.py
Normal file
@@ -0,0 +1,237 @@
|
||||
'''
|
||||
This part of code contains a pretrained vgg_face model.
|
||||
ref link: https://github.com/prlz77/vgg-face.pytorch
|
||||
'''
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.model_zoo
|
||||
import pickle
|
||||
from musetalk.loss import resnet as ResNet
|
||||
|
||||
|
||||
MODEL_URL = "https://github.com/claudio-unipv/vggface-pytorch/releases/download/v0.1/vggface-9d491dd7c30312.pth"
|
||||
VGG_FACE_PATH = '/apdcephfs_cq8/share_1367250/zhentaoyu/Driving/00_VASA/00_data/models/pretrain_models/resnet50_ft_weight.pkl'
|
||||
|
||||
# It was 93.5940, 104.7624, 129.1863 before dividing by 255
|
||||
MEAN_RGB = [
|
||||
0.367035294117647,
|
||||
0.41083294117647057,
|
||||
0.5066129411764705
|
||||
]
|
||||
def load_state_dict(model, fname):
|
||||
"""
|
||||
Set parameters converted from Caffe models authors of VGGFace2 provide.
|
||||
See https://www.robots.ox.ac.uk/~vgg/data/vgg_face2/.
|
||||
|
||||
Arguments:
|
||||
model: model
|
||||
fname: file name of parameters converted from a Caffe model, assuming the file format is Pickle.
|
||||
"""
|
||||
with open(fname, 'rb') as f:
|
||||
weights = pickle.load(f, encoding='latin1')
|
||||
|
||||
own_state = model.state_dict()
|
||||
for name, param in weights.items():
|
||||
if name in own_state:
|
||||
try:
|
||||
own_state[name].copy_(torch.from_numpy(param))
|
||||
except Exception:
|
||||
raise RuntimeError('While copying the parameter named {}, whose dimensions in the model are {} and whose '\
|
||||
'dimensions in the checkpoint are {}.'.format(name, own_state[name].size(), param.size()))
|
||||
else:
|
||||
raise KeyError('unexpected key "{}" in state_dict'.format(name))
|
||||
|
||||
|
||||
def vggface2(pretrained=True):
|
||||
vggface = ResNet.resnet50(num_classes=8631, include_top=True)
|
||||
load_state_dict(vggface, VGG_FACE_PATH)
|
||||
return vggface
|
||||
|
||||
def vggface(pretrained=False, **kwargs):
|
||||
"""VGGFace model.
|
||||
|
||||
Args:
|
||||
pretrained (bool): If True, returns pre-trained model
|
||||
"""
|
||||
model = VggFace(**kwargs)
|
||||
if pretrained:
|
||||
state = torch.utils.model_zoo.load_url(MODEL_URL)
|
||||
model.load_state_dict(state)
|
||||
return model
|
||||
|
||||
|
||||
class VggFace(torch.nn.Module):
|
||||
def __init__(self, classes=2622):
|
||||
"""VGGFace model.
|
||||
|
||||
Face recognition network. It takes as input a Bx3x224x224
|
||||
batch of face images and gives as output a BxC score vector
|
||||
(C is the number of identities).
|
||||
Input images need to be scaled in the 0-1 range and then
|
||||
normalized with respect to the mean RGB used during training.
|
||||
|
||||
Args:
|
||||
classes (int): number of identities recognized by the
|
||||
network
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.conv1 = _ConvBlock(3, 64, 64)
|
||||
self.conv2 = _ConvBlock(64, 128, 128)
|
||||
self.conv3 = _ConvBlock(128, 256, 256, 256)
|
||||
self.conv4 = _ConvBlock(256, 512, 512, 512)
|
||||
self.conv5 = _ConvBlock(512, 512, 512, 512)
|
||||
self.dropout = torch.nn.Dropout(0.5)
|
||||
self.fc1 = torch.nn.Linear(7 * 7 * 512, 4096)
|
||||
self.fc2 = torch.nn.Linear(4096, 4096)
|
||||
self.fc3 = torch.nn.Linear(4096, classes)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.conv2(x)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
x = self.conv5(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.dropout(F.relu(self.fc1(x)))
|
||||
x = self.dropout(F.relu(self.fc2(x)))
|
||||
x = self.fc3(x)
|
||||
return x
|
||||
|
||||
|
||||
class _ConvBlock(torch.nn.Module):
|
||||
"""A Convolutional block."""
|
||||
|
||||
def __init__(self, *units):
|
||||
"""Create a block with len(units) - 1 convolutions.
|
||||
|
||||
convolution number i transforms the number of channels from
|
||||
units[i - 1] to units[i] channels.
|
||||
|
||||
"""
|
||||
super().__init__()
|
||||
self.convs = torch.nn.ModuleList([
|
||||
torch.nn.Conv2d(in_, out, 3, 1, 1)
|
||||
for in_, out in zip(units[:-1], units[1:])
|
||||
])
|
||||
|
||||
def forward(self, x):
|
||||
# Each convolution is followed by a ReLU, then the block is
|
||||
# concluded by a max pooling.
|
||||
for c in self.convs:
|
||||
x = F.relu(c(x))
|
||||
return F.max_pool2d(x, 2, 2, 0, ceil_mode=True)
|
||||
|
||||
|
||||
|
||||
import numpy as np
|
||||
from torchvision import models
|
||||
class Vgg19(torch.nn.Module):
|
||||
"""
|
||||
Vgg19 network for perceptual loss.
|
||||
"""
|
||||
def __init__(self, requires_grad=False):
|
||||
super(Vgg19, self).__init__()
|
||||
vgg_pretrained_features = models.vgg19(pretrained=True).features
|
||||
self.slice1 = torch.nn.Sequential()
|
||||
self.slice2 = torch.nn.Sequential()
|
||||
self.slice3 = torch.nn.Sequential()
|
||||
self.slice4 = torch.nn.Sequential()
|
||||
self.slice5 = torch.nn.Sequential()
|
||||
for x in range(2):
|
||||
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(2, 7):
|
||||
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(7, 12):
|
||||
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(12, 21):
|
||||
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
||||
for x in range(21, 30):
|
||||
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
||||
|
||||
self.mean = torch.nn.Parameter(data=torch.Tensor(np.array([0.485, 0.456, 0.406]).reshape((1, 3, 1, 1))),
|
||||
requires_grad=False)
|
||||
self.std = torch.nn.Parameter(data=torch.Tensor(np.array([0.229, 0.224, 0.225]).reshape((1, 3, 1, 1))),
|
||||
requires_grad=False)
|
||||
|
||||
if not requires_grad:
|
||||
for param in self.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward(self, X):
|
||||
X = (X - self.mean) / self.std
|
||||
h_relu1 = self.slice1(X)
|
||||
h_relu2 = self.slice2(h_relu1)
|
||||
h_relu3 = self.slice3(h_relu2)
|
||||
h_relu4 = self.slice4(h_relu3)
|
||||
h_relu5 = self.slice5(h_relu4)
|
||||
out = [h_relu1, h_relu2, h_relu3, h_relu4, h_relu5]
|
||||
return out
|
||||
|
||||
|
||||
from torch import nn
|
||||
class AntiAliasInterpolation2d(nn.Module):
|
||||
"""
|
||||
Band-limited downsampling, for better preservation of the input signal.
|
||||
"""
|
||||
def __init__(self, channels, scale):
|
||||
super(AntiAliasInterpolation2d, self).__init__()
|
||||
sigma = (1 / scale - 1) / 2
|
||||
kernel_size = 2 * round(sigma * 4) + 1
|
||||
self.ka = kernel_size // 2
|
||||
self.kb = self.ka - 1 if kernel_size % 2 == 0 else self.ka
|
||||
|
||||
kernel_size = [kernel_size, kernel_size]
|
||||
sigma = [sigma, sigma]
|
||||
# The gaussian kernel is the product of the
|
||||
# gaussian function of each dimension.
|
||||
kernel = 1
|
||||
meshgrids = torch.meshgrid(
|
||||
[
|
||||
torch.arange(size, dtype=torch.float32)
|
||||
for size in kernel_size
|
||||
]
|
||||
)
|
||||
for size, std, mgrid in zip(kernel_size, sigma, meshgrids):
|
||||
mean = (size - 1) / 2
|
||||
kernel *= torch.exp(-(mgrid - mean) ** 2 / (2 * std ** 2))
|
||||
|
||||
# Make sure sum of values in gaussian kernel equals 1.
|
||||
kernel = kernel / torch.sum(kernel)
|
||||
# Reshape to depthwise convolutional weight
|
||||
kernel = kernel.view(1, 1, *kernel.size())
|
||||
kernel = kernel.repeat(channels, *[1] * (kernel.dim() - 1))
|
||||
|
||||
self.register_buffer('weight', kernel)
|
||||
self.groups = channels
|
||||
self.scale = scale
|
||||
inv_scale = 1 / scale
|
||||
self.int_inv_scale = int(inv_scale)
|
||||
|
||||
def forward(self, input):
|
||||
if self.scale == 1.0:
|
||||
return input
|
||||
|
||||
out = F.pad(input, (self.ka, self.kb, self.ka, self.kb))
|
||||
out = F.conv2d(out, weight=self.weight, groups=self.groups)
|
||||
out = out[:, :, ::self.int_inv_scale, ::self.int_inv_scale]
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class ImagePyramide(torch.nn.Module):
|
||||
"""
|
||||
Create image pyramide for computing pyramide perceptual loss.
|
||||
"""
|
||||
def __init__(self, scales, num_channels):
|
||||
super(ImagePyramide, self).__init__()
|
||||
downs = {}
|
||||
for scale in scales:
|
||||
downs[str(scale).replace('.', '-')] = AntiAliasInterpolation2d(num_channels, scale)
|
||||
self.downs = nn.ModuleDict(downs)
|
||||
|
||||
def forward(self, x):
|
||||
out_dict = {}
|
||||
for scale, down_module in self.downs.items():
|
||||
out_dict['prediction_' + str(scale).replace('-', '.')] = down_module(x)
|
||||
return out_dict
|
||||
240
models/MuseTalk/musetalk/models/syncnet.py
Normal file
240
models/MuseTalk/musetalk/models/syncnet.py
Normal file
@@ -0,0 +1,240 @@
|
||||
"""
|
||||
This file is modified from LatentSync (https://github.com/bytedance/LatentSync/blob/main/latentsync/models/stable_syncnet.py).
|
||||
"""
|
||||
|
||||
import torch
|
||||
from torch import nn
|
||||
from einops import rearrange
|
||||
from torch.nn import functional as F
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from diffusers.models.attention import Attention as CrossAttention, FeedForward
|
||||
from diffusers.utils.import_utils import is_xformers_available
|
||||
from einops import rearrange
|
||||
|
||||
|
||||
class SyncNet(nn.Module):
|
||||
def __init__(self, config):
|
||||
super().__init__()
|
||||
self.audio_encoder = DownEncoder2D(
|
||||
in_channels=config["audio_encoder"]["in_channels"],
|
||||
block_out_channels=config["audio_encoder"]["block_out_channels"],
|
||||
downsample_factors=config["audio_encoder"]["downsample_factors"],
|
||||
dropout=config["audio_encoder"]["dropout"],
|
||||
attn_blocks=config["audio_encoder"]["attn_blocks"],
|
||||
)
|
||||
|
||||
self.visual_encoder = DownEncoder2D(
|
||||
in_channels=config["visual_encoder"]["in_channels"],
|
||||
block_out_channels=config["visual_encoder"]["block_out_channels"],
|
||||
downsample_factors=config["visual_encoder"]["downsample_factors"],
|
||||
dropout=config["visual_encoder"]["dropout"],
|
||||
attn_blocks=config["visual_encoder"]["attn_blocks"],
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
def forward(self, image_sequences, audio_sequences):
|
||||
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
||||
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
||||
|
||||
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
||||
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
# Make them unit vectors
|
||||
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
||||
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
||||
|
||||
return vision_embeds, audio_embeds
|
||||
|
||||
def get_image_embed(self, image_sequences):
|
||||
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
|
||||
|
||||
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
# Make them unit vectors
|
||||
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
|
||||
|
||||
return vision_embeds
|
||||
|
||||
def get_audio_embed(self, audio_sequences):
|
||||
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
|
||||
|
||||
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
|
||||
|
||||
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
|
||||
|
||||
return audio_embeds
|
||||
|
||||
class ResnetBlock2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels: int,
|
||||
out_channels: int,
|
||||
dropout: float = 0.0,
|
||||
norm_num_groups: int = 32,
|
||||
eps: float = 1e-6,
|
||||
act_fn: str = "silu",
|
||||
downsample_factor=2,
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
|
||||
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
|
||||
self.dropout = nn.Dropout(dropout)
|
||||
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
if act_fn == "relu":
|
||||
self.act_fn = nn.ReLU()
|
||||
elif act_fn == "silu":
|
||||
self.act_fn = nn.SiLU()
|
||||
|
||||
if in_channels != out_channels:
|
||||
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
|
||||
else:
|
||||
self.conv_shortcut = None
|
||||
|
||||
if isinstance(downsample_factor, list):
|
||||
downsample_factor = tuple(downsample_factor)
|
||||
|
||||
if downsample_factor == 1:
|
||||
self.downsample_conv = None
|
||||
else:
|
||||
self.downsample_conv = nn.Conv2d(
|
||||
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
|
||||
)
|
||||
self.pad = (0, 1, 0, 1)
|
||||
if isinstance(downsample_factor, tuple):
|
||||
if downsample_factor[0] == 1:
|
||||
self.pad = (0, 1, 1, 1) # The padding order is from back to front
|
||||
elif downsample_factor[1] == 1:
|
||||
self.pad = (1, 1, 0, 1)
|
||||
|
||||
def forward(self, input_tensor):
|
||||
hidden_states = input_tensor
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
|
||||
hidden_states = self.conv1(hidden_states)
|
||||
hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.act_fn(hidden_states)
|
||||
|
||||
hidden_states = self.dropout(hidden_states)
|
||||
hidden_states = self.conv2(hidden_states)
|
||||
|
||||
if self.conv_shortcut is not None:
|
||||
input_tensor = self.conv_shortcut(input_tensor)
|
||||
|
||||
hidden_states += input_tensor
|
||||
|
||||
if self.downsample_conv is not None:
|
||||
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
|
||||
hidden_states = self.downsample_conv(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
|
||||
|
||||
class AttentionBlock2D(nn.Module):
|
||||
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
|
||||
super().__init__()
|
||||
if not is_xformers_available():
|
||||
raise ModuleNotFoundError(
|
||||
"You have to install xformers to enable memory efficient attetion", name="xformers"
|
||||
)
|
||||
# inner_dim = dim_head * heads
|
||||
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
|
||||
self.norm2 = nn.LayerNorm(query_dim)
|
||||
self.norm3 = nn.LayerNorm(query_dim)
|
||||
|
||||
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
|
||||
|
||||
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
||||
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.attn = CrossAttention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
|
||||
self.attn._use_memory_efficient_attention_xformers = True
|
||||
|
||||
def forward(self, hidden_states):
|
||||
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
|
||||
|
||||
batch, channel, height, width = hidden_states.shape
|
||||
residual = hidden_states
|
||||
|
||||
hidden_states = self.norm1(hidden_states)
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
|
||||
|
||||
norm_hidden_states = self.norm2(hidden_states)
|
||||
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
|
||||
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
|
||||
|
||||
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width)
|
||||
hidden_states = self.conv_out(hidden_states)
|
||||
|
||||
hidden_states = hidden_states + residual
|
||||
return hidden_states
|
||||
|
||||
|
||||
class DownEncoder2D(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
in_channels=4 * 16,
|
||||
block_out_channels=[64, 128, 256, 256],
|
||||
downsample_factors=[2, 2, 2, 2],
|
||||
layers_per_block=2,
|
||||
norm_num_groups=32,
|
||||
attn_blocks=[1, 1, 1, 1],
|
||||
dropout: float = 0.0,
|
||||
act_fn="silu",
|
||||
):
|
||||
super().__init__()
|
||||
self.layers_per_block = layers_per_block
|
||||
|
||||
# in
|
||||
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
|
||||
|
||||
# down
|
||||
self.down_blocks = nn.ModuleList([])
|
||||
|
||||
output_channels = block_out_channels[0]
|
||||
for i, block_out_channel in enumerate(block_out_channels):
|
||||
input_channels = output_channels
|
||||
output_channels = block_out_channel
|
||||
# is_final_block = i == len(block_out_channels) - 1
|
||||
|
||||
down_block = ResnetBlock2D(
|
||||
in_channels=input_channels,
|
||||
out_channels=output_channels,
|
||||
downsample_factor=downsample_factors[i],
|
||||
norm_num_groups=norm_num_groups,
|
||||
dropout=dropout,
|
||||
act_fn=act_fn,
|
||||
)
|
||||
|
||||
self.down_blocks.append(down_block)
|
||||
|
||||
if attn_blocks[i] == 1:
|
||||
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
|
||||
self.down_blocks.append(attention_block)
|
||||
|
||||
# out
|
||||
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
|
||||
self.act_fn_out = nn.ReLU()
|
||||
|
||||
def forward(self, hidden_states):
|
||||
hidden_states = self.conv_in(hidden_states)
|
||||
|
||||
# down
|
||||
for down_block in self.down_blocks:
|
||||
hidden_states = down_block(hidden_states)
|
||||
|
||||
# post-process
|
||||
hidden_states = self.norm_out(hidden_states)
|
||||
hidden_states = self.act_fn_out(hidden_states)
|
||||
|
||||
return hidden_states
|
||||
51
models/MuseTalk/musetalk/models/unet.py
Normal file
51
models/MuseTalk/musetalk/models/unet.py
Normal file
@@ -0,0 +1,51 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import math
|
||||
import json
|
||||
|
||||
from diffusers import UNet2DConditionModel
|
||||
import sys
|
||||
import time
|
||||
import numpy as np
|
||||
import os
|
||||
|
||||
class PositionalEncoding(nn.Module):
|
||||
def __init__(self, d_model=384, max_len=5000):
|
||||
super(PositionalEncoding, self).__init__()
|
||||
pe = torch.zeros(max_len, d_model)
|
||||
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
||||
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
||||
pe[:, 0::2] = torch.sin(position * div_term)
|
||||
pe[:, 1::2] = torch.cos(position * div_term)
|
||||
pe = pe.unsqueeze(0)
|
||||
self.register_buffer('pe', pe)
|
||||
|
||||
def forward(self, x):
|
||||
b, seq_len, d_model = x.size()
|
||||
pe = self.pe[:, :seq_len, :]
|
||||
x = x + pe.to(x.device)
|
||||
return x
|
||||
|
||||
class UNet():
|
||||
def __init__(self,
|
||||
unet_config,
|
||||
model_path,
|
||||
use_float16=False,
|
||||
device=None
|
||||
):
|
||||
with open(unet_config, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
self.model = UNet2DConditionModel(**unet_config)
|
||||
self.pe = PositionalEncoding(d_model=384)
|
||||
if device != None:
|
||||
self.device = device
|
||||
else:
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
weights = torch.load(model_path) if torch.cuda.is_available() else torch.load(model_path, map_location=self.device)
|
||||
self.model.load_state_dict(weights)
|
||||
if use_float16:
|
||||
self.model = self.model.half()
|
||||
self.model.to(self.device)
|
||||
|
||||
if __name__ == "__main__":
|
||||
unet = UNet()
|
||||
148
models/MuseTalk/musetalk/models/vae.py
Normal file
148
models/MuseTalk/musetalk/models/vae.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from diffusers import AutoencoderKL
|
||||
import torch
|
||||
import torchvision.transforms as transforms
|
||||
import torch.nn.functional as F
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
import os
|
||||
|
||||
class VAE():
|
||||
"""
|
||||
VAE (Variational Autoencoder) class for image processing.
|
||||
"""
|
||||
|
||||
def __init__(self, model_path="./models/sd-vae-ft-mse/", resized_img=256, use_float16=False):
|
||||
"""
|
||||
Initialize the VAE instance.
|
||||
|
||||
:param model_path: Path to the trained model.
|
||||
:param resized_img: The size to which images are resized.
|
||||
:param use_float16: Whether to use float16 precision.
|
||||
"""
|
||||
self.model_path = model_path
|
||||
self.vae = AutoencoderKL.from_pretrained(self.model_path)
|
||||
|
||||
self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
self.vae.to(self.device)
|
||||
|
||||
if use_float16:
|
||||
self.vae = self.vae.half()
|
||||
self._use_float16 = True
|
||||
else:
|
||||
self._use_float16 = False
|
||||
|
||||
self.scaling_factor = self.vae.config.scaling_factor
|
||||
self.transform = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
|
||||
self._resized_img = resized_img
|
||||
self._mask_tensor = self.get_mask_tensor()
|
||||
|
||||
def get_mask_tensor(self):
|
||||
"""
|
||||
Creates a mask tensor for image processing.
|
||||
:return: A mask tensor.
|
||||
"""
|
||||
mask_tensor = torch.zeros((self._resized_img,self._resized_img))
|
||||
mask_tensor[:self._resized_img//2,:] = 1
|
||||
mask_tensor[mask_tensor< 0.5] = 0
|
||||
mask_tensor[mask_tensor>= 0.5] = 1
|
||||
return mask_tensor
|
||||
|
||||
def preprocess_img(self,img_name,half_mask=False):
|
||||
"""
|
||||
Preprocess an image for the VAE.
|
||||
|
||||
:param img_name: The image file path or a list of image file paths.
|
||||
:param half_mask: Whether to apply a half mask to the image.
|
||||
:return: A preprocessed image tensor.
|
||||
"""
|
||||
window = []
|
||||
if isinstance(img_name, str):
|
||||
window_fnames = [img_name]
|
||||
for fname in window_fnames:
|
||||
img = cv2.imread(fname)
|
||||
img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
img = cv2.resize(img, (self._resized_img, self._resized_img),
|
||||
interpolation=cv2.INTER_LANCZOS4)
|
||||
window.append(img)
|
||||
else:
|
||||
img = cv2.cvtColor(img_name, cv2.COLOR_BGR2RGB)
|
||||
window.append(img)
|
||||
|
||||
x = np.asarray(window) / 255.
|
||||
x = np.transpose(x, (3, 0, 1, 2))
|
||||
x = torch.squeeze(torch.FloatTensor(x))
|
||||
if half_mask:
|
||||
x = x * (self._mask_tensor>0.5)
|
||||
x = self.transform(x)
|
||||
|
||||
x = x.unsqueeze(0) # [1, 3, 256, 256] torch tensor
|
||||
x = x.to(self.vae.device)
|
||||
|
||||
return x
|
||||
|
||||
def encode_latents(self,image):
|
||||
"""
|
||||
Encode an image into latent variables.
|
||||
|
||||
:param image: The image tensor to encode.
|
||||
:return: The encoded latent variables.
|
||||
"""
|
||||
with torch.no_grad():
|
||||
init_latent_dist = self.vae.encode(image.to(self.vae.dtype)).latent_dist
|
||||
init_latents = self.scaling_factor * init_latent_dist.sample()
|
||||
return init_latents
|
||||
|
||||
def decode_latents(self, latents):
|
||||
"""
|
||||
Decode latent variables back into an image.
|
||||
:param latents: The latent variables to decode.
|
||||
:return: A NumPy array representing the decoded image.
|
||||
"""
|
||||
latents = (1/ self.scaling_factor) * latents
|
||||
image = self.vae.decode(latents.to(self.vae.dtype)).sample
|
||||
image = (image / 2 + 0.5).clamp(0, 1)
|
||||
image = image.detach().cpu().permute(0, 2, 3, 1).float().numpy()
|
||||
image = (image * 255).round().astype("uint8")
|
||||
image = image[...,::-1] # RGB to BGR
|
||||
return image
|
||||
|
||||
def get_latents_for_unet(self,img):
|
||||
"""
|
||||
Prepare latent variables for a U-Net model.
|
||||
:param img: The image to process.
|
||||
:return: A concatenated tensor of latents for U-Net input.
|
||||
"""
|
||||
|
||||
ref_image = self.preprocess_img(img,half_mask=True) # [1, 3, 256, 256] RGB, torch tensor
|
||||
masked_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||
ref_image = self.preprocess_img(img,half_mask=False) # [1, 3, 256, 256] RGB, torch tensor
|
||||
ref_latents = self.encode_latents(ref_image) # [1, 4, 32, 32], torch tensor
|
||||
latent_model_input = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
return latent_model_input
|
||||
|
||||
if __name__ == "__main__":
|
||||
vae_mode_path = "./models/sd-vae-ft-mse/"
|
||||
vae = VAE(model_path = vae_mode_path,use_float16=False)
|
||||
img_path = "./results/sun001_crop/00000.png"
|
||||
|
||||
crop_imgs_path = "./results/sun001_crop/"
|
||||
latents_out_path = "./results/latents/"
|
||||
if not os.path.exists(latents_out_path):
|
||||
os.mkdir(latents_out_path)
|
||||
|
||||
files = os.listdir(crop_imgs_path)
|
||||
files.sort()
|
||||
files = [file for file in files if file.split(".")[-1] == "png"]
|
||||
|
||||
for file in files:
|
||||
index = file.split(".")[0]
|
||||
img_path = crop_imgs_path + file
|
||||
latents = vae.get_latents_for_unet(img_path)
|
||||
print(img_path,"latents",latents.size())
|
||||
#torch.save(latents,os.path.join(latents_out_path,index+".pt"))
|
||||
#reload_tensor = torch.load('tensor.pt')
|
||||
#print(reload_tensor.size())
|
||||
|
||||
|
||||
|
||||
5
models/MuseTalk/musetalk/utils/__init__.py
Normal file
5
models/MuseTalk/musetalk/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import sys
|
||||
from os.path import abspath, dirname
|
||||
current_dir = dirname(abspath(__file__))
|
||||
parent_dir = dirname(current_dir)
|
||||
sys.path.append(parent_dir+'/utils')
|
||||
113
models/MuseTalk/musetalk/utils/audio_processor.py
Normal file
113
models/MuseTalk/musetalk/utils/audio_processor.py
Normal file
@@ -0,0 +1,113 @@
|
||||
import math
|
||||
import os
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from einops import rearrange
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
|
||||
class AudioProcessor:
|
||||
def __init__(self, feature_extractor_path="openai/whisper-tiny/"):
|
||||
self.feature_extractor = AutoFeatureExtractor.from_pretrained(feature_extractor_path)
|
||||
|
||||
def get_audio_feature(self, wav_path, start_index=0, weight_dtype=None):
|
||||
if not os.path.exists(wav_path):
|
||||
return None
|
||||
librosa_output, sampling_rate = librosa.load(wav_path, sr=16000)
|
||||
assert sampling_rate == 16000
|
||||
# Split audio into 30s segments
|
||||
segment_length = 30 * sampling_rate
|
||||
segments = [librosa_output[i:i + segment_length] for i in range(0, len(librosa_output), segment_length)]
|
||||
|
||||
features = []
|
||||
for segment in segments:
|
||||
audio_feature = self.feature_extractor(
|
||||
segment,
|
||||
return_tensors="pt",
|
||||
sampling_rate=sampling_rate
|
||||
).input_features
|
||||
if weight_dtype is not None:
|
||||
audio_feature = audio_feature.to(dtype=weight_dtype)
|
||||
features.append(audio_feature)
|
||||
|
||||
return features, len(librosa_output)
|
||||
|
||||
def get_whisper_chunk(
|
||||
self,
|
||||
whisper_input_features,
|
||||
device,
|
||||
weight_dtype,
|
||||
whisper,
|
||||
librosa_length,
|
||||
fps=25,
|
||||
audio_padding_length_left=2,
|
||||
audio_padding_length_right=2,
|
||||
):
|
||||
audio_feature_length_per_frame = 2 * (audio_padding_length_left + audio_padding_length_right + 1)
|
||||
whisper_feature = []
|
||||
# Process multiple 30s mel input features
|
||||
for input_feature in whisper_input_features:
|
||||
input_feature = input_feature.to(device).to(weight_dtype)
|
||||
audio_feats = whisper.encoder(input_feature, output_hidden_states=True).hidden_states
|
||||
audio_feats = torch.stack(audio_feats, dim=2)
|
||||
whisper_feature.append(audio_feats)
|
||||
|
||||
whisper_feature = torch.cat(whisper_feature, dim=1)
|
||||
# Trim the last segment to remove padding
|
||||
sr = 16000
|
||||
audio_fps = 50
|
||||
fps = int(fps)
|
||||
whisper_idx_multiplier = audio_fps / fps
|
||||
num_frames = math.floor((librosa_length / sr) * fps)
|
||||
actual_length = math.floor((librosa_length / sr) * audio_fps)
|
||||
whisper_feature = whisper_feature[:,:actual_length,...]
|
||||
|
||||
# Calculate padding amount
|
||||
padding_nums = math.ceil(whisper_idx_multiplier)
|
||||
# Add padding at start and end
|
||||
whisper_feature = torch.cat([
|
||||
torch.zeros_like(whisper_feature[:, :padding_nums * audio_padding_length_left]),
|
||||
whisper_feature,
|
||||
# Add extra padding to prevent out of bounds
|
||||
torch.zeros_like(whisper_feature[:, :padding_nums * 3 * audio_padding_length_right])
|
||||
], 1)
|
||||
|
||||
audio_prompts = []
|
||||
for frame_index in range(num_frames):
|
||||
audio_index = math.floor(frame_index * whisper_idx_multiplier)
|
||||
end_index = audio_index + audio_feature_length_per_frame
|
||||
|
||||
# Handle case where audio is shorter than video
|
||||
if end_index > whisper_feature.shape[1]:
|
||||
available = whisper_feature[:, audio_index:]
|
||||
padding_size = end_index - whisper_feature.shape[1]
|
||||
if padding_size > 0:
|
||||
padding = torch.zeros((whisper_feature.shape[0], padding_size, *whisper_feature.shape[2:]),
|
||||
device=whisper_feature.device, dtype=whisper_feature.dtype)
|
||||
audio_clip = torch.cat([available, padding], dim=1)
|
||||
else:
|
||||
audio_clip = available
|
||||
else:
|
||||
audio_clip = whisper_feature[:, audio_index: end_index]
|
||||
|
||||
# Final size check and padding
|
||||
if audio_clip.shape[1] < audio_feature_length_per_frame:
|
||||
padding_size = audio_feature_length_per_frame - audio_clip.shape[1]
|
||||
padding = torch.zeros((whisper_feature.shape[0], padding_size, *whisper_feature.shape[2:]),
|
||||
device=whisper_feature.device, dtype=whisper_feature.dtype)
|
||||
audio_clip = torch.cat([audio_clip, padding], dim=1)
|
||||
|
||||
audio_prompts.append(audio_clip)
|
||||
|
||||
audio_prompts = torch.cat(audio_prompts, dim=0) # T, 10, 5, 384
|
||||
audio_prompts = rearrange(audio_prompts, 'b c h w -> b (c h) w')
|
||||
return audio_prompts
|
||||
|
||||
if __name__ == "__main__":
|
||||
audio_processor = AudioProcessor()
|
||||
wav_path = "./2.wav"
|
||||
audio_feature, librosa_feature_length = audio_processor.get_audio_feature(wav_path)
|
||||
print("Audio Feature shape:", audio_feature.shape)
|
||||
print("librosa_feature_length:", librosa_feature_length)
|
||||
17
models/MuseTalk/musetalk/utils/audio_utils.py
Normal file
17
models/MuseTalk/musetalk/utils/audio_utils.py
Normal file
@@ -0,0 +1,17 @@
|
||||
import os, subprocess
|
||||
|
||||
def ensure_wav(input_path: str, target_path: str | None = None) -> str:
|
||||
"""
|
||||
Convert any audio (mp3/ogg/m4a/wav/…) to 16kHz mono PCM WAV via ffmpeg.
|
||||
Returns path to the converted .wav (original if already correct).
|
||||
"""
|
||||
if not isinstance(input_path, str) or not os.path.exists(input_path):
|
||||
return input_path
|
||||
base, ext = os.path.splitext(input_path)
|
||||
ext = ext.lower()
|
||||
|
||||
if target_path is None:
|
||||
target_path = base + "_16k.wav"
|
||||
cmd = ["ffmpeg", "-y", "-i", input_path, "-ar", "16000", "-ac", "1", "-c:a", "pcm_s16le", target_path]
|
||||
subprocess.run(cmd, check=True, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)
|
||||
return target_path
|
||||
136
models/MuseTalk/musetalk/utils/blending.py
Normal file
136
models/MuseTalk/musetalk/utils/blending.py
Normal file
@@ -0,0 +1,136 @@
|
||||
from PIL import Image
|
||||
import numpy as np
|
||||
import cv2
|
||||
import copy
|
||||
|
||||
|
||||
def get_crop_box(box, expand):
|
||||
x, y, x1, y1 = box
|
||||
x_c, y_c = (x+x1)//2, (y+y1)//2
|
||||
w, h = x1-x, y1-y
|
||||
s = int(max(w, h)//2*expand)
|
||||
crop_box = [x_c-s, y_c-s, x_c+s, y_c+s]
|
||||
return crop_box, s
|
||||
|
||||
|
||||
def face_seg(image, mode="raw", fp=None):
|
||||
"""
|
||||
对图像进行面部解析,生成面部区域的掩码。
|
||||
|
||||
Args:
|
||||
image (PIL.Image): 输入图像。
|
||||
|
||||
Returns:
|
||||
PIL.Image: 面部区域的掩码图像。
|
||||
"""
|
||||
seg_image = fp(image, mode=mode) # 使用 FaceParsing 模型解析面部
|
||||
if seg_image is None:
|
||||
print("error, no person_segment") # 如果没有检测到面部,返回错误
|
||||
return None
|
||||
|
||||
seg_image = seg_image.resize(image.size) # 将掩码图像调整为输入图像的大小
|
||||
return seg_image
|
||||
|
||||
|
||||
def get_image(image, face, face_box, upper_boundary_ratio=0.5, expand=1.5, mode="raw", fp=None):
|
||||
"""
|
||||
将裁剪的面部图像粘贴回原始图像,并进行一些处理。
|
||||
|
||||
Args:
|
||||
image (numpy.ndarray): 原始图像(身体部分)。
|
||||
face (numpy.ndarray): 裁剪的面部图像。
|
||||
face_box (tuple): 面部边界框的坐标 (x, y, x1, y1)。
|
||||
upper_boundary_ratio (float): 用于控制面部区域的保留比例。
|
||||
expand (float): 扩展因子,用于放大裁剪框。
|
||||
mode: 融合mask构建方式
|
||||
|
||||
Returns:
|
||||
numpy.ndarray: 处理后的图像。
|
||||
"""
|
||||
# 将 numpy 数组转换为 PIL 图像
|
||||
body = Image.fromarray(image[:, :, ::-1]) # 身体部分图像(整张图)
|
||||
face = Image.fromarray(face[:, :, ::-1]) # 面部图像
|
||||
|
||||
x, y, x1, y1 = face_box # 获取面部边界框的坐标
|
||||
crop_box, s = get_crop_box(face_box, expand) # 计算扩展后的裁剪框
|
||||
x_s, y_s, x_e, y_e = crop_box # 裁剪框的坐标
|
||||
face_position = (x, y) # 面部在原始图像中的位置
|
||||
|
||||
# 从身体图像中裁剪出扩展后的面部区域(下巴到边界有距离)
|
||||
face_large = body.crop(crop_box)
|
||||
|
||||
ori_shape = face_large.size # 裁剪后图像的原始尺寸
|
||||
|
||||
# 对裁剪后的面部区域进行面部解析,生成掩码
|
||||
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
||||
|
||||
mask_small = mask_image.crop((x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 裁剪出面部区域的掩码
|
||||
|
||||
mask_image = Image.new('L', ori_shape, 0) # 创建一个全黑的掩码图像
|
||||
mask_image.paste(mask_small, (x - x_s, y - y_s, x1 - x_s, y1 - y_s)) # 将面部掩码粘贴到全黑图像上
|
||||
|
||||
|
||||
# 保留面部区域的上半部分(用于控制说话区域)
|
||||
width, height = mask_image.size
|
||||
top_boundary = int(height * upper_boundary_ratio) # 计算上半部分的边界
|
||||
modified_mask_image = Image.new('L', ori_shape, 0) # 创建一个新的全黑掩码图像
|
||||
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary)) # 粘贴上半部分掩码
|
||||
|
||||
|
||||
# 对掩码进行高斯模糊,使边缘更平滑
|
||||
blur_kernel_size = int(0.05 * ori_shape[0] // 2 * 2) + 1 # 计算模糊核大小
|
||||
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0) # 高斯模糊
|
||||
#mask_array = np.array(modified_mask_image)
|
||||
mask_image = Image.fromarray(mask_array) # 将模糊后的掩码转换回 PIL 图像
|
||||
|
||||
# 将裁剪的面部图像粘贴回扩展后的面部区域
|
||||
face_large.paste(face, (x - x_s, y - y_s, x1 - x_s, y1 - y_s))
|
||||
|
||||
body.paste(face_large, crop_box[:2], mask_image)
|
||||
|
||||
body = np.array(body) # 将 PIL 图像转换回 numpy 数组
|
||||
|
||||
return body[:, :, ::-1] # 返回处理后的图像(BGR 转 RGB)
|
||||
|
||||
|
||||
def get_image_blending(image, face, face_box, mask_array, crop_box):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
face = Image.fromarray(face[:,:,::-1])
|
||||
|
||||
x, y, x1, y1 = face_box
|
||||
x_s, y_s, x_e, y_e = crop_box
|
||||
face_large = body.crop(crop_box)
|
||||
|
||||
mask_image = Image.fromarray(mask_array)
|
||||
mask_image = mask_image.convert("L")
|
||||
face_large.paste(face, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
body.paste(face_large, crop_box[:2], mask_image)
|
||||
body = np.array(body)
|
||||
return body[:,:,::-1]
|
||||
|
||||
|
||||
def get_image_prepare_material(image, face_box, upper_boundary_ratio=0.5, expand=1.5, fp=None, mode="raw"):
|
||||
body = Image.fromarray(image[:,:,::-1])
|
||||
|
||||
x, y, x1, y1 = face_box
|
||||
#print(x1-x,y1-y)
|
||||
crop_box, s = get_crop_box(face_box, expand)
|
||||
x_s, y_s, x_e, y_e = crop_box
|
||||
|
||||
face_large = body.crop(crop_box)
|
||||
ori_shape = face_large.size
|
||||
|
||||
mask_image = face_seg(face_large, mode=mode, fp=fp)
|
||||
mask_small = mask_image.crop((x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
mask_image = Image.new('L', ori_shape, 0)
|
||||
mask_image.paste(mask_small, (x-x_s, y-y_s, x1-x_s, y1-y_s))
|
||||
|
||||
# keep upper_boundary_ratio of talking area
|
||||
width, height = mask_image.size
|
||||
top_boundary = int(height * upper_boundary_ratio)
|
||||
modified_mask_image = Image.new('L', ori_shape, 0)
|
||||
modified_mask_image.paste(mask_image.crop((0, top_boundary, width, height)), (0, top_boundary))
|
||||
|
||||
blur_kernel_size = int(0.1 * ori_shape[0] // 2 * 2) + 1
|
||||
mask_array = cv2.GaussianBlur(np.array(modified_mask_image), (blur_kernel_size, blur_kernel_size), 0)
|
||||
return mask_array, crop_box
|
||||
54
models/MuseTalk/musetalk/utils/dwpose/default_runtime.py
Normal file
54
models/MuseTalk/musetalk/utils/dwpose/default_runtime.py
Normal file
@@ -0,0 +1,54 @@
|
||||
default_scope = 'mmpose'
|
||||
|
||||
# hooks
|
||||
default_hooks = dict(
|
||||
timer=dict(type='IterTimerHook'),
|
||||
logger=dict(type='LoggerHook', interval=50),
|
||||
param_scheduler=dict(type='ParamSchedulerHook'),
|
||||
checkpoint=dict(type='CheckpointHook', interval=10),
|
||||
sampler_seed=dict(type='DistSamplerSeedHook'),
|
||||
visualization=dict(type='PoseVisualizationHook', enable=False),
|
||||
badcase=dict(
|
||||
type='BadCaseAnalysisHook',
|
||||
enable=False,
|
||||
out_dir='badcase',
|
||||
metric_type='loss',
|
||||
badcase_thr=5))
|
||||
|
||||
# custom hooks
|
||||
custom_hooks = [
|
||||
# Synchronize model buffers such as running_mean and running_var in BN
|
||||
# at the end of each epoch
|
||||
dict(type='SyncBuffersHook')
|
||||
]
|
||||
|
||||
# multi-processing backend
|
||||
env_cfg = dict(
|
||||
cudnn_benchmark=False,
|
||||
mp_cfg=dict(mp_start_method='fork', opencv_num_threads=0),
|
||||
dist_cfg=dict(backend='nccl'),
|
||||
)
|
||||
|
||||
# visualizer
|
||||
vis_backends = [
|
||||
dict(type='LocalVisBackend'),
|
||||
# dict(type='TensorboardVisBackend'),
|
||||
# dict(type='WandbVisBackend'),
|
||||
]
|
||||
visualizer = dict(
|
||||
type='PoseLocalVisualizer', vis_backends=vis_backends, name='visualizer')
|
||||
|
||||
# logger
|
||||
log_processor = dict(
|
||||
type='LogProcessor', window_size=50, by_epoch=True, num_digits=6)
|
||||
log_level = 'INFO'
|
||||
load_from = None
|
||||
resume = False
|
||||
|
||||
# file I/O backend
|
||||
backend_args = dict(backend='local')
|
||||
|
||||
# training/validation/testing progress
|
||||
train_cfg = dict(by_epoch=True)
|
||||
val_cfg = dict()
|
||||
test_cfg = dict()
|
||||
@@ -0,0 +1,257 @@
|
||||
#_base_ = ['../../../_base_/default_runtime.py']
|
||||
_base_ = ['default_runtime.py']
|
||||
|
||||
# runtime
|
||||
max_epochs = 270
|
||||
stage2_num_epochs = 30
|
||||
base_lr = 4e-3
|
||||
train_batch_size = 32
|
||||
val_batch_size = 32
|
||||
|
||||
train_cfg = dict(max_epochs=max_epochs, val_interval=10)
|
||||
randomness = dict(seed=21)
|
||||
|
||||
# optimizer
|
||||
optim_wrapper = dict(
|
||||
type='OptimWrapper',
|
||||
optimizer=dict(type='AdamW', lr=base_lr, weight_decay=0.05),
|
||||
paramwise_cfg=dict(
|
||||
norm_decay_mult=0, bias_decay_mult=0, bypass_duplicate=True))
|
||||
|
||||
# learning rate
|
||||
param_scheduler = [
|
||||
dict(
|
||||
type='LinearLR',
|
||||
start_factor=1.0e-5,
|
||||
by_epoch=False,
|
||||
begin=0,
|
||||
end=1000),
|
||||
dict(
|
||||
# use cosine lr from 150 to 300 epoch
|
||||
type='CosineAnnealingLR',
|
||||
eta_min=base_lr * 0.05,
|
||||
begin=max_epochs // 2,
|
||||
end=max_epochs,
|
||||
T_max=max_epochs // 2,
|
||||
by_epoch=True,
|
||||
convert_to_iter_based=True),
|
||||
]
|
||||
|
||||
# automatically scaling LR based on the actual training batch size
|
||||
auto_scale_lr = dict(base_batch_size=512)
|
||||
|
||||
# codec settings
|
||||
codec = dict(
|
||||
type='SimCCLabel',
|
||||
input_size=(288, 384),
|
||||
sigma=(6., 6.93),
|
||||
simcc_split_ratio=2.0,
|
||||
normalize=False,
|
||||
use_dark=False)
|
||||
|
||||
# model settings
|
||||
model = dict(
|
||||
type='TopdownPoseEstimator',
|
||||
data_preprocessor=dict(
|
||||
type='PoseDataPreprocessor',
|
||||
mean=[123.675, 116.28, 103.53],
|
||||
std=[58.395, 57.12, 57.375],
|
||||
bgr_to_rgb=True),
|
||||
backbone=dict(
|
||||
_scope_='mmdet',
|
||||
type='CSPNeXt',
|
||||
arch='P5',
|
||||
expand_ratio=0.5,
|
||||
deepen_factor=1.,
|
||||
widen_factor=1.,
|
||||
out_indices=(4, ),
|
||||
channel_attention=True,
|
||||
norm_cfg=dict(type='SyncBN'),
|
||||
act_cfg=dict(type='SiLU'),
|
||||
init_cfg=dict(
|
||||
type='Pretrained',
|
||||
prefix='backbone.',
|
||||
checkpoint='https://download.openmmlab.com/mmpose/v1/projects/'
|
||||
'rtmpose/cspnext-l_udp-aic-coco_210e-256x192-273b7631_20230130.pth' # noqa: E501
|
||||
)),
|
||||
head=dict(
|
||||
type='RTMCCHead',
|
||||
in_channels=1024,
|
||||
out_channels=133,
|
||||
input_size=codec['input_size'],
|
||||
in_featuremap_size=(9, 12),
|
||||
simcc_split_ratio=codec['simcc_split_ratio'],
|
||||
final_layer_kernel_size=7,
|
||||
gau_cfg=dict(
|
||||
hidden_dims=256,
|
||||
s=128,
|
||||
expansion_factor=2,
|
||||
dropout_rate=0.,
|
||||
drop_path=0.,
|
||||
act_fn='SiLU',
|
||||
use_rel_bias=False,
|
||||
pos_enc=False),
|
||||
loss=dict(
|
||||
type='KLDiscretLoss',
|
||||
use_target_weight=True,
|
||||
beta=10.,
|
||||
label_softmax=True),
|
||||
decoder=codec),
|
||||
test_cfg=dict(flip_test=True, ))
|
||||
|
||||
# base dataset settings
|
||||
dataset_type = 'UBody2dDataset'
|
||||
data_mode = 'topdown'
|
||||
data_root = 'data/UBody/'
|
||||
|
||||
backend_args = dict(backend='local')
|
||||
|
||||
scenes = [
|
||||
'Magic_show', 'Entertainment', 'ConductMusic', 'Online_class', 'TalkShow',
|
||||
'Speech', 'Fitness', 'Interview', 'Olympic', 'TVShow', 'Singing',
|
||||
'SignLanguage', 'Movie', 'LiveVlog', 'VideoConference'
|
||||
]
|
||||
|
||||
train_datasets = [
|
||||
dict(
|
||||
type='CocoWholeBodyDataset',
|
||||
data_root='data/coco/',
|
||||
data_mode=data_mode,
|
||||
ann_file='annotations/coco_wholebody_train_v1.0.json',
|
||||
data_prefix=dict(img='train2017/'),
|
||||
pipeline=[])
|
||||
]
|
||||
|
||||
for scene in scenes:
|
||||
train_dataset = dict(
|
||||
type=dataset_type,
|
||||
data_root=data_root,
|
||||
data_mode=data_mode,
|
||||
ann_file=f'annotations/{scene}/train_annotations.json',
|
||||
data_prefix=dict(img='images/'),
|
||||
pipeline=[],
|
||||
sample_interval=10)
|
||||
train_datasets.append(train_dataset)
|
||||
|
||||
# pipelines
|
||||
train_pipeline = [
|
||||
dict(type='LoadImage', backend_args=backend_args),
|
||||
dict(type='GetBBoxCenterScale'),
|
||||
dict(type='RandomFlip', direction='horizontal'),
|
||||
dict(type='RandomHalfBody'),
|
||||
dict(
|
||||
type='RandomBBoxTransform', scale_factor=[0.5, 1.5], rotate_factor=90),
|
||||
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(
|
||||
type='Albumentation',
|
||||
transforms=[
|
||||
dict(type='Blur', p=0.1),
|
||||
dict(type='MedianBlur', p=0.1),
|
||||
dict(
|
||||
type='CoarseDropout',
|
||||
max_holes=1,
|
||||
max_height=0.4,
|
||||
max_width=0.4,
|
||||
min_holes=1,
|
||||
min_height=0.2,
|
||||
min_width=0.2,
|
||||
p=1.0),
|
||||
]),
|
||||
dict(type='GenerateTarget', encoder=codec),
|
||||
dict(type='PackPoseInputs')
|
||||
]
|
||||
val_pipeline = [
|
||||
dict(type='LoadImage', backend_args=backend_args),
|
||||
dict(type='GetBBoxCenterScale'),
|
||||
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||
dict(type='PackPoseInputs')
|
||||
]
|
||||
|
||||
train_pipeline_stage2 = [
|
||||
dict(type='LoadImage', backend_args=backend_args),
|
||||
dict(type='GetBBoxCenterScale'),
|
||||
dict(type='RandomFlip', direction='horizontal'),
|
||||
dict(type='RandomHalfBody'),
|
||||
dict(
|
||||
type='RandomBBoxTransform',
|
||||
shift_factor=0.,
|
||||
scale_factor=[0.5, 1.5],
|
||||
rotate_factor=90),
|
||||
dict(type='TopdownAffine', input_size=codec['input_size']),
|
||||
dict(type='mmdet.YOLOXHSVRandomAug'),
|
||||
dict(
|
||||
type='Albumentation',
|
||||
transforms=[
|
||||
dict(type='Blur', p=0.1),
|
||||
dict(type='MedianBlur', p=0.1),
|
||||
dict(
|
||||
type='CoarseDropout',
|
||||
max_holes=1,
|
||||
max_height=0.4,
|
||||
max_width=0.4,
|
||||
min_holes=1,
|
||||
min_height=0.2,
|
||||
min_width=0.2,
|
||||
p=0.5),
|
||||
]),
|
||||
dict(type='GenerateTarget', encoder=codec),
|
||||
dict(type='PackPoseInputs')
|
||||
]
|
||||
|
||||
# data loaders
|
||||
train_dataloader = dict(
|
||||
batch_size=train_batch_size,
|
||||
num_workers=10,
|
||||
persistent_workers=True,
|
||||
sampler=dict(type='DefaultSampler', shuffle=True),
|
||||
dataset=dict(
|
||||
type='CombinedDataset',
|
||||
metainfo=dict(from_file='configs/_base_/datasets/coco_wholebody.py'),
|
||||
datasets=train_datasets,
|
||||
pipeline=train_pipeline,
|
||||
test_mode=False,
|
||||
))
|
||||
|
||||
val_dataloader = dict(
|
||||
batch_size=val_batch_size,
|
||||
num_workers=10,
|
||||
persistent_workers=True,
|
||||
drop_last=False,
|
||||
sampler=dict(type='DefaultSampler', shuffle=False, round_up=False),
|
||||
dataset=dict(
|
||||
type='CocoWholeBodyDataset',
|
||||
data_root=data_root,
|
||||
data_mode=data_mode,
|
||||
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json',
|
||||
bbox_file='data/coco/person_detection_results/'
|
||||
'COCO_val2017_detections_AP_H_56_person.json',
|
||||
data_prefix=dict(img='coco/val2017/'),
|
||||
test_mode=True,
|
||||
pipeline=val_pipeline,
|
||||
))
|
||||
test_dataloader = val_dataloader
|
||||
|
||||
# hooks
|
||||
default_hooks = dict(
|
||||
checkpoint=dict(
|
||||
save_best='coco-wholebody/AP', rule='greater', max_keep_ckpts=1))
|
||||
|
||||
custom_hooks = [
|
||||
dict(
|
||||
type='EMAHook',
|
||||
ema_type='ExpMomentumEMA',
|
||||
momentum=0.0002,
|
||||
update_buffers=True,
|
||||
priority=49),
|
||||
dict(
|
||||
type='mmdet.PipelineSwitchHook',
|
||||
switch_epoch=max_epochs - stage2_num_epochs,
|
||||
switch_pipeline=train_pipeline_stage2)
|
||||
]
|
||||
|
||||
# evaluators
|
||||
val_evaluator = dict(
|
||||
type='CocoWholeBodyMetric',
|
||||
ann_file='data/coco/annotations/coco_wholebody_val_v1.0.json')
|
||||
test_evaluator = val_evaluator
|
||||
1
models/MuseTalk/musetalk/utils/face_detection/README.md
Normal file
1
models/MuseTalk/musetalk/utils/face_detection/README.md
Normal file
@@ -0,0 +1 @@
|
||||
The code for Face Detection in this folder has been taken from the wonderful [face_alignment](https://github.com/1adrianb/face-alignment) repository. This has been modified to take batches of faces at a time.
|
||||
@@ -0,0 +1,7 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
__author__ = """Adrian Bulat"""
|
||||
__email__ = 'adrian.bulat@nottingham.ac.uk'
|
||||
__version__ = '1.0.1'
|
||||
|
||||
from .api import FaceAlignment, LandmarksType, NetworkSize, YOLOv8_face
|
||||
240
models/MuseTalk/musetalk/utils/face_detection/api.py
Normal file
240
models/MuseTalk/musetalk/utils/face_detection/api.py
Normal file
@@ -0,0 +1,240 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import torch
|
||||
from torch.utils.model_zoo import load_url
|
||||
from enum import Enum
|
||||
import numpy as np
|
||||
import cv2
|
||||
try:
|
||||
import urllib.request as request_file
|
||||
except BaseException:
|
||||
import urllib as request_file
|
||||
|
||||
from .models import FAN, ResNetDepth
|
||||
from .utils import *
|
||||
|
||||
|
||||
class LandmarksType(Enum):
|
||||
"""Enum class defining the type of landmarks to detect.
|
||||
|
||||
``_2D`` - the detected points ``(x,y)`` are detected in a 2D space and follow the visible contour of the face
|
||||
``_2halfD`` - this points represent the projection of the 3D points into 3D
|
||||
``_3D`` - detect the points ``(x,y,z)``` in a 3D space
|
||||
|
||||
"""
|
||||
_2D = 1
|
||||
_2halfD = 2
|
||||
_3D = 3
|
||||
|
||||
|
||||
class NetworkSize(Enum):
|
||||
# TINY = 1
|
||||
# SMALL = 2
|
||||
# MEDIUM = 3
|
||||
LARGE = 4
|
||||
|
||||
def __new__(cls, value):
|
||||
member = object.__new__(cls)
|
||||
member._value_ = value
|
||||
return member
|
||||
|
||||
def __int__(self):
|
||||
return self.value
|
||||
|
||||
|
||||
|
||||
class FaceAlignment:
|
||||
def __init__(self, landmarks_type, network_size=NetworkSize.LARGE,
|
||||
device='cuda', flip_input=False, face_detector='sfd', verbose=False):
|
||||
self.device = device
|
||||
self.flip_input = flip_input
|
||||
self.landmarks_type = landmarks_type
|
||||
self.verbose = verbose
|
||||
|
||||
network_size = int(network_size)
|
||||
|
||||
if 'cuda' in device:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cuda.matmul.allow_tf32 = False
|
||||
# torch.backends.cudnn.benchmark = True
|
||||
# torch.backends.cudnn.deterministic = False
|
||||
# torch.backends.cudnn.allow_tf32 = True
|
||||
print('cuda start')
|
||||
|
||||
|
||||
# Get the face detector
|
||||
face_detector_module = __import__('face_detection.detection.' + face_detector,
|
||||
globals(), locals(), [face_detector], 0)
|
||||
|
||||
self.face_detector = face_detector_module.FaceDetector(device=device, verbose=verbose)
|
||||
|
||||
def get_detections_for_batch(self, images):
|
||||
images = images[..., ::-1]
|
||||
detected_faces = self.face_detector.detect_from_batch(images.copy())
|
||||
results = []
|
||||
|
||||
for i, d in enumerate(detected_faces):
|
||||
if len(d) == 0:
|
||||
results.append(None)
|
||||
continue
|
||||
d = d[0]
|
||||
d = np.clip(d, 0, None)
|
||||
|
||||
x1, y1, x2, y2 = map(int, d[:-1])
|
||||
results.append((x1, y1, x2, y2))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
class YOLOv8_face:
|
||||
def __init__(self, path = 'face_detection/weights/yolov8n-face.onnx', conf_thres=0.2, iou_thres=0.5):
|
||||
self.conf_threshold = conf_thres
|
||||
self.iou_threshold = iou_thres
|
||||
self.class_names = ['face']
|
||||
self.num_classes = len(self.class_names)
|
||||
# Initialize model
|
||||
self.net = cv2.dnn.readNet(path)
|
||||
self.input_height = 640
|
||||
self.input_width = 640
|
||||
self.reg_max = 16
|
||||
|
||||
self.project = np.arange(self.reg_max)
|
||||
self.strides = (8, 16, 32)
|
||||
self.feats_hw = [(math.ceil(self.input_height / self.strides[i]), math.ceil(self.input_width / self.strides[i])) for i in range(len(self.strides))]
|
||||
self.anchors = self.make_anchors(self.feats_hw)
|
||||
|
||||
def make_anchors(self, feats_hw, grid_cell_offset=0.5):
|
||||
"""Generate anchors from features."""
|
||||
anchor_points = {}
|
||||
for i, stride in enumerate(self.strides):
|
||||
h,w = feats_hw[i]
|
||||
x = np.arange(0, w) + grid_cell_offset # shift x
|
||||
y = np.arange(0, h) + grid_cell_offset # shift y
|
||||
sx, sy = np.meshgrid(x, y)
|
||||
# sy, sx = np.meshgrid(y, x)
|
||||
anchor_points[stride] = np.stack((sx, sy), axis=-1).reshape(-1, 2)
|
||||
return anchor_points
|
||||
|
||||
def softmax(self, x, axis=1):
|
||||
x_exp = np.exp(x)
|
||||
# 如果是列向量,则axis=0
|
||||
x_sum = np.sum(x_exp, axis=axis, keepdims=True)
|
||||
s = x_exp / x_sum
|
||||
return s
|
||||
|
||||
def resize_image(self, srcimg, keep_ratio=True):
|
||||
top, left, newh, neww = 0, 0, self.input_width, self.input_height
|
||||
if keep_ratio and srcimg.shape[0] != srcimg.shape[1]:
|
||||
hw_scale = srcimg.shape[0] / srcimg.shape[1]
|
||||
if hw_scale > 1:
|
||||
newh, neww = self.input_height, int(self.input_width / hw_scale)
|
||||
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
|
||||
left = int((self.input_width - neww) * 0.5)
|
||||
img = cv2.copyMakeBorder(img, 0, 0, left, self.input_width - neww - left, cv2.BORDER_CONSTANT,
|
||||
value=(0, 0, 0)) # add border
|
||||
else:
|
||||
newh, neww = int(self.input_height * hw_scale), self.input_width
|
||||
img = cv2.resize(srcimg, (neww, newh), interpolation=cv2.INTER_AREA)
|
||||
top = int((self.input_height - newh) * 0.5)
|
||||
img = cv2.copyMakeBorder(img, top, self.input_height - newh - top, 0, 0, cv2.BORDER_CONSTANT,
|
||||
value=(0, 0, 0))
|
||||
else:
|
||||
img = cv2.resize(srcimg, (self.input_width, self.input_height), interpolation=cv2.INTER_AREA)
|
||||
return img, newh, neww, top, left
|
||||
|
||||
def detect(self, srcimg):
|
||||
input_img, newh, neww, padh, padw = self.resize_image(cv2.cvtColor(srcimg, cv2.COLOR_BGR2RGB))
|
||||
scale_h, scale_w = srcimg.shape[0]/newh, srcimg.shape[1]/neww
|
||||
input_img = input_img.astype(np.float32) / 255.0
|
||||
|
||||
blob = cv2.dnn.blobFromImage(input_img)
|
||||
self.net.setInput(blob)
|
||||
outputs = self.net.forward(self.net.getUnconnectedOutLayersNames())
|
||||
# if isinstance(outputs, tuple):
|
||||
# outputs = list(outputs)
|
||||
# if float(cv2.__version__[:3])>=4.7:
|
||||
# outputs = [outputs[2], outputs[0], outputs[1]] ###opencv4.7需要这一步,opencv4.5不需要
|
||||
# Perform inference on the image
|
||||
det_bboxes, det_conf, det_classid, landmarks = self.post_process(outputs, scale_h, scale_w, padh, padw)
|
||||
return det_bboxes, det_conf, det_classid, landmarks
|
||||
|
||||
def post_process(self, preds, scale_h, scale_w, padh, padw):
|
||||
bboxes, scores, landmarks = [], [], []
|
||||
for i, pred in enumerate(preds):
|
||||
stride = int(self.input_height/pred.shape[2])
|
||||
pred = pred.transpose((0, 2, 3, 1))
|
||||
|
||||
box = pred[..., :self.reg_max * 4]
|
||||
cls = 1 / (1 + np.exp(-pred[..., self.reg_max * 4:-15])).reshape((-1,1))
|
||||
kpts = pred[..., -15:].reshape((-1,15)) ### x1,y1,score1, ..., x5,y5,score5
|
||||
|
||||
# tmp = box.reshape(self.feats_hw[i][0], self.feats_hw[i][1], 4, self.reg_max)
|
||||
tmp = box.reshape(-1, 4, self.reg_max)
|
||||
bbox_pred = self.softmax(tmp, axis=-1)
|
||||
bbox_pred = np.dot(bbox_pred, self.project).reshape((-1,4))
|
||||
|
||||
bbox = self.distance2bbox(self.anchors[stride], bbox_pred, max_shape=(self.input_height, self.input_width)) * stride
|
||||
kpts[:, 0::3] = (kpts[:, 0::3] * 2.0 + (self.anchors[stride][:, 0].reshape((-1,1)) - 0.5)) * stride
|
||||
kpts[:, 1::3] = (kpts[:, 1::3] * 2.0 + (self.anchors[stride][:, 1].reshape((-1,1)) - 0.5)) * stride
|
||||
kpts[:, 2::3] = 1 / (1+np.exp(-kpts[:, 2::3]))
|
||||
|
||||
bbox -= np.array([[padw, padh, padw, padh]]) ###合理使用广播法则
|
||||
bbox *= np.array([[scale_w, scale_h, scale_w, scale_h]])
|
||||
kpts -= np.tile(np.array([padw, padh, 0]), 5).reshape((1,15))
|
||||
kpts *= np.tile(np.array([scale_w, scale_h, 1]), 5).reshape((1,15))
|
||||
|
||||
bboxes.append(bbox)
|
||||
scores.append(cls)
|
||||
landmarks.append(kpts)
|
||||
|
||||
bboxes = np.concatenate(bboxes, axis=0)
|
||||
scores = np.concatenate(scores, axis=0)
|
||||
landmarks = np.concatenate(landmarks, axis=0)
|
||||
|
||||
bboxes_wh = bboxes.copy()
|
||||
bboxes_wh[:, 2:4] = bboxes[:, 2:4] - bboxes[:, 0:2] ####xywh
|
||||
classIds = np.argmax(scores, axis=1)
|
||||
confidences = np.max(scores, axis=1) ####max_class_confidence
|
||||
|
||||
mask = confidences>self.conf_threshold
|
||||
bboxes_wh = bboxes_wh[mask] ###合理使用广播法则
|
||||
confidences = confidences[mask]
|
||||
classIds = classIds[mask]
|
||||
landmarks = landmarks[mask]
|
||||
|
||||
indices = cv2.dnn.NMSBoxes(bboxes_wh.tolist(), confidences.tolist(), self.conf_threshold,
|
||||
self.iou_threshold).flatten()
|
||||
if len(indices) > 0:
|
||||
mlvl_bboxes = bboxes_wh[indices]
|
||||
confidences = confidences[indices]
|
||||
classIds = classIds[indices]
|
||||
landmarks = landmarks[indices]
|
||||
return mlvl_bboxes, confidences, classIds, landmarks
|
||||
else:
|
||||
print('nothing detect')
|
||||
return np.array([]), np.array([]), np.array([]), np.array([])
|
||||
|
||||
def distance2bbox(self, points, distance, max_shape=None):
|
||||
x1 = points[:, 0] - distance[:, 0]
|
||||
y1 = points[:, 1] - distance[:, 1]
|
||||
x2 = points[:, 0] + distance[:, 2]
|
||||
y2 = points[:, 1] + distance[:, 3]
|
||||
if max_shape is not None:
|
||||
x1 = np.clip(x1, 0, max_shape[1])
|
||||
y1 = np.clip(y1, 0, max_shape[0])
|
||||
x2 = np.clip(x2, 0, max_shape[1])
|
||||
y2 = np.clip(y2, 0, max_shape[0])
|
||||
return np.stack([x1, y1, x2, y2], axis=-1)
|
||||
|
||||
def draw_detections(self, image, boxes, scores, kpts):
|
||||
for box, score, kp in zip(boxes, scores, kpts):
|
||||
x, y, w, h = box.astype(int)
|
||||
# Draw rectangle
|
||||
cv2.rectangle(image, (x, y), (x + w, y + h), (0, 0, 255), thickness=3)
|
||||
cv2.putText(image, "face:"+str(round(score,2)), (x, y - 5), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 0, 255), thickness=2)
|
||||
for i in range(5):
|
||||
cv2.circle(image, (int(kp[i * 3]), int(kp[i * 3 + 1])), 4, (0, 255, 0), thickness=-1)
|
||||
# cv2.putText(image, str(i), (int(kp[i * 3]), int(kp[i * 3 + 1]) - 10), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), thickness=1)
|
||||
return image
|
||||
|
||||
ROOT = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -0,0 +1 @@
|
||||
from .core import FaceDetector
|
||||
130
models/MuseTalk/musetalk/utils/face_detection/detection/core.py
Normal file
130
models/MuseTalk/musetalk/utils/face_detection/detection/core.py
Normal file
@@ -0,0 +1,130 @@
|
||||
import logging
|
||||
import glob
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import torch
|
||||
import cv2
|
||||
|
||||
|
||||
class FaceDetector(object):
|
||||
"""An abstract class representing a face detector.
|
||||
|
||||
Any other face detection implementation must subclass it. All subclasses
|
||||
must implement ``detect_from_image``, that return a list of detected
|
||||
bounding boxes. Optionally, for speed considerations detect from path is
|
||||
recommended.
|
||||
"""
|
||||
|
||||
def __init__(self, device, verbose):
|
||||
self.device = device
|
||||
self.verbose = verbose
|
||||
|
||||
if verbose:
|
||||
if 'cpu' in device:
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.warning("Detection running on CPU, this may be potentially slow.")
|
||||
|
||||
if 'cpu' not in device and 'cuda' not in device:
|
||||
if verbose:
|
||||
logger.error("Expected values for device are: {cpu, cuda} but got: %s", device)
|
||||
raise ValueError
|
||||
|
||||
def detect_from_image(self, tensor_or_path):
|
||||
"""Detects faces in a given image.
|
||||
|
||||
This function detects the faces present in a provided BGR(usually)
|
||||
image. The input can be either the image itself or the path to it.
|
||||
|
||||
Arguments:
|
||||
tensor_or_path {numpy.ndarray, torch.tensor or string} -- the path
|
||||
to an image or the image itself.
|
||||
|
||||
Example::
|
||||
|
||||
>>> path_to_image = 'data/image_01.jpg'
|
||||
... detected_faces = detect_from_image(path_to_image)
|
||||
[A list of bounding boxes (x1, y1, x2, y2)]
|
||||
>>> image = cv2.imread(path_to_image)
|
||||
... detected_faces = detect_from_image(image)
|
||||
[A list of bounding boxes (x1, y1, x2, y2)]
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def detect_from_directory(self, path, extensions=['.jpg', '.png'], recursive=False, show_progress_bar=True):
|
||||
"""Detects faces from all the images present in a given directory.
|
||||
|
||||
Arguments:
|
||||
path {string} -- a string containing a path that points to the folder containing the images
|
||||
|
||||
Keyword Arguments:
|
||||
extensions {list} -- list of string containing the extensions to be
|
||||
consider in the following format: ``.extension_name`` (default:
|
||||
{['.jpg', '.png']}) recursive {bool} -- option wherever to scan the
|
||||
folder recursively (default: {False}) show_progress_bar {bool} --
|
||||
display a progressbar (default: {True})
|
||||
|
||||
Example:
|
||||
>>> directory = 'data'
|
||||
... detected_faces = detect_from_directory(directory)
|
||||
{A dictionary of [lists containing bounding boxes(x1, y1, x2, y2)]}
|
||||
|
||||
"""
|
||||
if self.verbose:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if len(extensions) == 0:
|
||||
if self.verbose:
|
||||
logger.error("Expected at list one extension, but none was received.")
|
||||
raise ValueError
|
||||
|
||||
if self.verbose:
|
||||
logger.info("Constructing the list of images.")
|
||||
additional_pattern = '/**/*' if recursive else '/*'
|
||||
files = []
|
||||
for extension in extensions:
|
||||
files.extend(glob.glob(path + additional_pattern + extension, recursive=recursive))
|
||||
|
||||
if self.verbose:
|
||||
logger.info("Finished searching for images. %s images found", len(files))
|
||||
logger.info("Preparing to run the detection.")
|
||||
|
||||
predictions = {}
|
||||
for image_path in tqdm(files, disable=not show_progress_bar):
|
||||
if self.verbose:
|
||||
logger.info("Running the face detector on image: %s", image_path)
|
||||
predictions[image_path] = self.detect_from_image(image_path)
|
||||
|
||||
if self.verbose:
|
||||
logger.info("The detector was successfully run on all %s images", len(files))
|
||||
|
||||
return predictions
|
||||
|
||||
@property
|
||||
def reference_scale(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def reference_x_shift(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def reference_y_shift(self):
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def tensor_or_path_to_ndarray(tensor_or_path, rgb=True):
|
||||
"""Convert path (represented as a string) or torch.tensor to a numpy.ndarray
|
||||
|
||||
Arguments:
|
||||
tensor_or_path {numpy.ndarray, torch.tensor or string} -- path to the image, or the image itself
|
||||
"""
|
||||
if isinstance(tensor_or_path, str):
|
||||
return cv2.imread(tensor_or_path) if not rgb else cv2.imread(tensor_or_path)[..., ::-1]
|
||||
elif torch.is_tensor(tensor_or_path):
|
||||
# Call cpu in case its coming from cuda
|
||||
return tensor_or_path.cpu().numpy()[..., ::-1].copy() if not rgb else tensor_or_path.cpu().numpy()
|
||||
elif isinstance(tensor_or_path, np.ndarray):
|
||||
return tensor_or_path[..., ::-1].copy() if not rgb else tensor_or_path
|
||||
else:
|
||||
raise TypeError
|
||||
@@ -0,0 +1 @@
|
||||
from .sfd_detector import SFDDetector as FaceDetector
|
||||
@@ -0,0 +1,129 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import random
|
||||
import datetime
|
||||
import time
|
||||
import math
|
||||
import argparse
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
from iou import IOU
|
||||
except BaseException:
|
||||
# IOU cython speedup 10x
|
||||
def IOU(ax1, ay1, ax2, ay2, bx1, by1, bx2, by2):
|
||||
sa = abs((ax2 - ax1) * (ay2 - ay1))
|
||||
sb = abs((bx2 - bx1) * (by2 - by1))
|
||||
x1, y1 = max(ax1, bx1), max(ay1, by1)
|
||||
x2, y2 = min(ax2, bx2), min(ay2, by2)
|
||||
w = x2 - x1
|
||||
h = y2 - y1
|
||||
if w < 0 or h < 0:
|
||||
return 0.0
|
||||
else:
|
||||
return 1.0 * w * h / (sa + sb - w * h)
|
||||
|
||||
|
||||
def bboxlog(x1, y1, x2, y2, axc, ayc, aww, ahh):
|
||||
xc, yc, ww, hh = (x2 + x1) / 2, (y2 + y1) / 2, x2 - x1, y2 - y1
|
||||
dx, dy = (xc - axc) / aww, (yc - ayc) / ahh
|
||||
dw, dh = math.log(ww / aww), math.log(hh / ahh)
|
||||
return dx, dy, dw, dh
|
||||
|
||||
|
||||
def bboxloginv(dx, dy, dw, dh, axc, ayc, aww, ahh):
|
||||
xc, yc = dx * aww + axc, dy * ahh + ayc
|
||||
ww, hh = math.exp(dw) * aww, math.exp(dh) * ahh
|
||||
x1, x2, y1, y2 = xc - ww / 2, xc + ww / 2, yc - hh / 2, yc + hh / 2
|
||||
return x1, y1, x2, y2
|
||||
|
||||
|
||||
def nms(dets, thresh):
|
||||
if 0 == len(dets):
|
||||
return []
|
||||
x1, y1, x2, y2, scores = dets[:, 0], dets[:, 1], dets[:, 2], dets[:, 3], dets[:, 4]
|
||||
areas = (x2 - x1 + 1) * (y2 - y1 + 1)
|
||||
order = scores.argsort()[::-1]
|
||||
|
||||
keep = []
|
||||
while order.size > 0:
|
||||
i = order[0]
|
||||
keep.append(i)
|
||||
xx1, yy1 = np.maximum(x1[i], x1[order[1:]]), np.maximum(y1[i], y1[order[1:]])
|
||||
xx2, yy2 = np.minimum(x2[i], x2[order[1:]]), np.minimum(y2[i], y2[order[1:]])
|
||||
|
||||
w, h = np.maximum(0.0, xx2 - xx1 + 1), np.maximum(0.0, yy2 - yy1 + 1)
|
||||
ovr = w * h / (areas[i] + areas[order[1:]] - w * h)
|
||||
|
||||
inds = np.where(ovr <= thresh)[0]
|
||||
order = order[inds + 1]
|
||||
|
||||
return keep
|
||||
|
||||
|
||||
def encode(matched, priors, variances):
|
||||
"""Encode the variances from the priorbox layers into the ground truth boxes
|
||||
we have matched (based on jaccard overlap) with the prior boxes.
|
||||
Args:
|
||||
matched: (tensor) Coords of ground truth for each prior in point-form
|
||||
Shape: [num_priors, 4].
|
||||
priors: (tensor) Prior boxes in center-offset form
|
||||
Shape: [num_priors,4].
|
||||
variances: (list[float]) Variances of priorboxes
|
||||
Return:
|
||||
encoded boxes (tensor), Shape: [num_priors, 4]
|
||||
"""
|
||||
|
||||
# dist b/t match center and prior's center
|
||||
g_cxcy = (matched[:, :2] + matched[:, 2:]) / 2 - priors[:, :2]
|
||||
# encode variance
|
||||
g_cxcy /= (variances[0] * priors[:, 2:])
|
||||
# match wh / prior wh
|
||||
g_wh = (matched[:, 2:] - matched[:, :2]) / priors[:, 2:]
|
||||
g_wh = torch.log(g_wh) / variances[1]
|
||||
# return target for smooth_l1_loss
|
||||
return torch.cat([g_cxcy, g_wh], 1) # [num_priors,4]
|
||||
|
||||
|
||||
def decode(loc, priors, variances):
|
||||
"""Decode locations from predictions using priors to undo
|
||||
the encoding we did for offset regression at train time.
|
||||
Args:
|
||||
loc (tensor): location predictions for loc layers,
|
||||
Shape: [num_priors,4]
|
||||
priors (tensor): Prior boxes in center-offset form.
|
||||
Shape: [num_priors,4].
|
||||
variances: (list[float]) Variances of priorboxes
|
||||
Return:
|
||||
decoded bounding box predictions
|
||||
"""
|
||||
|
||||
boxes = torch.cat((
|
||||
priors[:, :2] + loc[:, :2] * variances[0] * priors[:, 2:],
|
||||
priors[:, 2:] * torch.exp(loc[:, 2:] * variances[1])), 1)
|
||||
boxes[:, :2] -= boxes[:, 2:] / 2
|
||||
boxes[:, 2:] += boxes[:, :2]
|
||||
return boxes
|
||||
|
||||
def batch_decode(loc, priors, variances):
|
||||
"""Decode locations from predictions using priors to undo
|
||||
the encoding we did for offset regression at train time.
|
||||
Args:
|
||||
loc (tensor): location predictions for loc layers,
|
||||
Shape: [num_priors,4]
|
||||
priors (tensor): Prior boxes in center-offset form.
|
||||
Shape: [num_priors,4].
|
||||
variances: (list[float]) Variances of priorboxes
|
||||
Return:
|
||||
decoded bounding box predictions
|
||||
"""
|
||||
|
||||
boxes = torch.cat((
|
||||
priors[:, :, :2] + loc[:, :, :2] * variances[0] * priors[:, :, 2:],
|
||||
priors[:, :, 2:] * torch.exp(loc[:, :, 2:] * variances[1])), 2)
|
||||
boxes[:, :, :2] -= boxes[:, :, 2:] / 2
|
||||
boxes[:, :, 2:] += boxes[:, :, :2]
|
||||
return boxes
|
||||
@@ -0,0 +1,114 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import os
|
||||
import sys
|
||||
import cv2
|
||||
import random
|
||||
import datetime
|
||||
import math
|
||||
import argparse
|
||||
import numpy as np
|
||||
|
||||
import scipy.io as sio
|
||||
import zipfile
|
||||
from .net_s3fd import s3fd
|
||||
from .bbox import *
|
||||
|
||||
|
||||
def detect(net, img, device):
|
||||
img = img - np.array([104, 117, 123])
|
||||
img = img.transpose(2, 0, 1)
|
||||
img = img.reshape((1,) + img.shape)
|
||||
|
||||
if 'cuda' in device:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
img = torch.from_numpy(img).float().to(device)
|
||||
BB, CC, HH, WW = img.size()
|
||||
with torch.no_grad():
|
||||
olist = net(img)
|
||||
|
||||
bboxlist = []
|
||||
for i in range(len(olist) // 2):
|
||||
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
||||
olist = [oelem.data.cpu() for oelem in olist]
|
||||
for i in range(len(olist) // 2):
|
||||
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
||||
FB, FC, FH, FW = ocls.size() # feature map size
|
||||
stride = 2**(i + 2) # 4,8,16,32,64,128
|
||||
anchor = stride * 4
|
||||
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
||||
for Iindex, hindex, windex in poss:
|
||||
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
||||
score = ocls[0, 1, hindex, windex]
|
||||
loc = oreg[0, :, hindex, windex].contiguous().view(1, 4)
|
||||
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]])
|
||||
variances = [0.1, 0.2]
|
||||
box = decode(loc, priors, variances)
|
||||
x1, y1, x2, y2 = box[0] * 1.0
|
||||
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
||||
bboxlist.append([x1, y1, x2, y2, score])
|
||||
bboxlist = np.array(bboxlist)
|
||||
if 0 == len(bboxlist):
|
||||
bboxlist = np.zeros((1, 5))
|
||||
|
||||
return bboxlist
|
||||
|
||||
def batch_detect(net, imgs, device):
|
||||
imgs = imgs - np.array([104, 117, 123])
|
||||
imgs = imgs.transpose(0, 3, 1, 2)
|
||||
|
||||
if 'cuda' in device:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
imgs = torch.from_numpy(imgs).float().to(device)
|
||||
BB, CC, HH, WW = imgs.size()
|
||||
with torch.no_grad():
|
||||
olist = net(imgs)
|
||||
# print(olist)
|
||||
|
||||
bboxlist = []
|
||||
for i in range(len(olist) // 2):
|
||||
olist[i * 2] = F.softmax(olist[i * 2], dim=1)
|
||||
|
||||
olist = [oelem.cpu() for oelem in olist]
|
||||
for i in range(len(olist) // 2):
|
||||
ocls, oreg = olist[i * 2], olist[i * 2 + 1]
|
||||
FB, FC, FH, FW = ocls.size() # feature map size
|
||||
stride = 2**(i + 2) # 4,8,16,32,64,128
|
||||
anchor = stride * 4
|
||||
poss = zip(*np.where(ocls[:, 1, :, :] > 0.05))
|
||||
for Iindex, hindex, windex in poss:
|
||||
axc, ayc = stride / 2 + windex * stride, stride / 2 + hindex * stride
|
||||
score = ocls[:, 1, hindex, windex]
|
||||
loc = oreg[:, :, hindex, windex].contiguous().view(BB, 1, 4)
|
||||
priors = torch.Tensor([[axc / 1.0, ayc / 1.0, stride * 4 / 1.0, stride * 4 / 1.0]]).view(1, 1, 4)
|
||||
variances = [0.1, 0.2]
|
||||
box = batch_decode(loc, priors, variances)
|
||||
box = box[:, 0] * 1.0
|
||||
# cv2.rectangle(imgshow,(int(x1),int(y1)),(int(x2),int(y2)),(0,0,255),1)
|
||||
bboxlist.append(torch.cat([box, score.unsqueeze(1)], 1).cpu().numpy())
|
||||
bboxlist = np.array(bboxlist)
|
||||
if 0 == len(bboxlist):
|
||||
bboxlist = np.zeros((1, BB, 5))
|
||||
|
||||
return bboxlist
|
||||
|
||||
def flip_detect(net, img, device):
|
||||
img = cv2.flip(img, 1)
|
||||
b = detect(net, img, device)
|
||||
|
||||
bboxlist = np.zeros(b.shape)
|
||||
bboxlist[:, 0] = img.shape[1] - b[:, 2]
|
||||
bboxlist[:, 1] = b[:, 1]
|
||||
bboxlist[:, 2] = img.shape[1] - b[:, 0]
|
||||
bboxlist[:, 3] = b[:, 3]
|
||||
bboxlist[:, 4] = b[:, 4]
|
||||
return bboxlist
|
||||
|
||||
|
||||
def pts_to_bb(pts):
|
||||
min_x, min_y = np.min(pts, axis=0)
|
||||
max_x, max_y = np.max(pts, axis=0)
|
||||
return np.array([min_x, min_y, max_x, max_y])
|
||||
@@ -0,0 +1,129 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class L2Norm(nn.Module):
|
||||
def __init__(self, n_channels, scale=1.0):
|
||||
super(L2Norm, self).__init__()
|
||||
self.n_channels = n_channels
|
||||
self.scale = scale
|
||||
self.eps = 1e-10
|
||||
self.weight = nn.Parameter(torch.Tensor(self.n_channels))
|
||||
self.weight.data *= 0.0
|
||||
self.weight.data += self.scale
|
||||
|
||||
def forward(self, x):
|
||||
norm = x.pow(2).sum(dim=1, keepdim=True).sqrt() + self.eps
|
||||
x = x / norm * self.weight.view(1, -1, 1, 1)
|
||||
return x
|
||||
|
||||
|
||||
class s3fd(nn.Module):
|
||||
def __init__(self):
|
||||
super(s3fd, self).__init__()
|
||||
self.conv1_1 = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
|
||||
self.conv1_2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv2_1 = nn.Conv2d(64, 128, kernel_size=3, stride=1, padding=1)
|
||||
self.conv2_2 = nn.Conv2d(128, 128, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv3_1 = nn.Conv2d(128, 256, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_2 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_3 = nn.Conv2d(256, 256, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv4_1 = nn.Conv2d(256, 512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.conv5_1 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_2 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_3 = nn.Conv2d(512, 512, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.fc6 = nn.Conv2d(512, 1024, kernel_size=3, stride=1, padding=3)
|
||||
self.fc7 = nn.Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0)
|
||||
|
||||
self.conv6_1 = nn.Conv2d(1024, 256, kernel_size=1, stride=1, padding=0)
|
||||
self.conv6_2 = nn.Conv2d(256, 512, kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.conv7_1 = nn.Conv2d(512, 128, kernel_size=1, stride=1, padding=0)
|
||||
self.conv7_2 = nn.Conv2d(128, 256, kernel_size=3, stride=2, padding=1)
|
||||
|
||||
self.conv3_3_norm = L2Norm(256, scale=10)
|
||||
self.conv4_3_norm = L2Norm(512, scale=8)
|
||||
self.conv5_3_norm = L2Norm(512, scale=5)
|
||||
|
||||
self.conv3_3_norm_mbox_conf = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv3_3_norm_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv4_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_3_norm_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv5_3_norm_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
self.fc7_mbox_conf = nn.Conv2d(1024, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.fc7_mbox_loc = nn.Conv2d(1024, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv6_2_mbox_conf = nn.Conv2d(512, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv6_2_mbox_loc = nn.Conv2d(512, 4, kernel_size=3, stride=1, padding=1)
|
||||
self.conv7_2_mbox_conf = nn.Conv2d(256, 2, kernel_size=3, stride=1, padding=1)
|
||||
self.conv7_2_mbox_loc = nn.Conv2d(256, 4, kernel_size=3, stride=1, padding=1)
|
||||
|
||||
def forward(self, x):
|
||||
h = F.relu(self.conv1_1(x))
|
||||
h = F.relu(self.conv1_2(h))
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.conv2_1(h))
|
||||
h = F.relu(self.conv2_2(h))
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.conv3_1(h))
|
||||
h = F.relu(self.conv3_2(h))
|
||||
h = F.relu(self.conv3_3(h))
|
||||
f3_3 = h
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.conv4_1(h))
|
||||
h = F.relu(self.conv4_2(h))
|
||||
h = F.relu(self.conv4_3(h))
|
||||
f4_3 = h
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.conv5_1(h))
|
||||
h = F.relu(self.conv5_2(h))
|
||||
h = F.relu(self.conv5_3(h))
|
||||
f5_3 = h
|
||||
h = F.max_pool2d(h, 2, 2)
|
||||
|
||||
h = F.relu(self.fc6(h))
|
||||
h = F.relu(self.fc7(h))
|
||||
ffc7 = h
|
||||
h = F.relu(self.conv6_1(h))
|
||||
h = F.relu(self.conv6_2(h))
|
||||
f6_2 = h
|
||||
h = F.relu(self.conv7_1(h))
|
||||
h = F.relu(self.conv7_2(h))
|
||||
f7_2 = h
|
||||
|
||||
f3_3 = self.conv3_3_norm(f3_3)
|
||||
f4_3 = self.conv4_3_norm(f4_3)
|
||||
f5_3 = self.conv5_3_norm(f5_3)
|
||||
|
||||
cls1 = self.conv3_3_norm_mbox_conf(f3_3)
|
||||
reg1 = self.conv3_3_norm_mbox_loc(f3_3)
|
||||
cls2 = self.conv4_3_norm_mbox_conf(f4_3)
|
||||
reg2 = self.conv4_3_norm_mbox_loc(f4_3)
|
||||
cls3 = self.conv5_3_norm_mbox_conf(f5_3)
|
||||
reg3 = self.conv5_3_norm_mbox_loc(f5_3)
|
||||
cls4 = self.fc7_mbox_conf(ffc7)
|
||||
reg4 = self.fc7_mbox_loc(ffc7)
|
||||
cls5 = self.conv6_2_mbox_conf(f6_2)
|
||||
reg5 = self.conv6_2_mbox_loc(f6_2)
|
||||
cls6 = self.conv7_2_mbox_conf(f7_2)
|
||||
reg6 = self.conv7_2_mbox_loc(f7_2)
|
||||
|
||||
# max-out background label
|
||||
chunk = torch.chunk(cls1, 4, 1)
|
||||
bmax = torch.max(torch.max(chunk[0], chunk[1]), chunk[2])
|
||||
cls1 = torch.cat([bmax, chunk[3]], dim=1)
|
||||
|
||||
return [cls1, reg1, cls2, reg2, cls3, reg3, cls4, reg4, cls5, reg5, cls6, reg6]
|
||||
@@ -0,0 +1,59 @@
|
||||
import os
|
||||
import cv2
|
||||
from torch.utils.model_zoo import load_url
|
||||
|
||||
from ..core import FaceDetector
|
||||
|
||||
from .net_s3fd import s3fd
|
||||
from .bbox import *
|
||||
from .detect import *
|
||||
|
||||
models_urls = {
|
||||
's3fd': 'https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth',
|
||||
}
|
||||
|
||||
|
||||
class SFDDetector(FaceDetector):
|
||||
def __init__(self, device, path_to_detector=os.path.join(os.path.dirname(os.path.abspath(__file__)), 's3fd.pth'), verbose=False):
|
||||
super(SFDDetector, self).__init__(device, verbose)
|
||||
|
||||
# Initialise the face detector
|
||||
if not os.path.isfile(path_to_detector):
|
||||
model_weights = load_url(models_urls['s3fd'])
|
||||
else:
|
||||
model_weights = torch.load(path_to_detector)
|
||||
|
||||
self.face_detector = s3fd()
|
||||
self.face_detector.load_state_dict(model_weights)
|
||||
self.face_detector.to(device)
|
||||
self.face_detector.eval()
|
||||
|
||||
def detect_from_image(self, tensor_or_path):
|
||||
image = self.tensor_or_path_to_ndarray(tensor_or_path)
|
||||
|
||||
bboxlist = detect(self.face_detector, image, device=self.device)
|
||||
keep = nms(bboxlist, 0.3)
|
||||
bboxlist = bboxlist[keep, :]
|
||||
bboxlist = [x for x in bboxlist if x[-1] > 0.5]
|
||||
|
||||
return bboxlist
|
||||
|
||||
def detect_from_batch(self, images):
|
||||
bboxlists = batch_detect(self.face_detector, images, device=self.device)
|
||||
keeps = [nms(bboxlists[:, i, :], 0.3) for i in range(bboxlists.shape[1])]
|
||||
bboxlists = [bboxlists[keep, i, :] for i, keep in enumerate(keeps)]
|
||||
bboxlists = [[x for x in bboxlist if x[-1] > 0.5] for bboxlist in bboxlists]
|
||||
|
||||
return bboxlists
|
||||
|
||||
@property
|
||||
def reference_scale(self):
|
||||
return 195
|
||||
|
||||
@property
|
||||
def reference_x_shift(self):
|
||||
return 0
|
||||
|
||||
@property
|
||||
def reference_y_shift(self):
|
||||
return 0
|
||||
261
models/MuseTalk/musetalk/utils/face_detection/models.py
Normal file
261
models/MuseTalk/musetalk/utils/face_detection/models.py
Normal file
@@ -0,0 +1,261 @@
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import math
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, strd=1, padding=1, bias=False):
|
||||
"3x3 convolution with padding"
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3,
|
||||
stride=strd, padding=padding, bias=bias)
|
||||
|
||||
|
||||
class ConvBlock(nn.Module):
|
||||
def __init__(self, in_planes, out_planes):
|
||||
super(ConvBlock, self).__init__()
|
||||
self.bn1 = nn.BatchNorm2d(in_planes)
|
||||
self.conv1 = conv3x3(in_planes, int(out_planes / 2))
|
||||
self.bn2 = nn.BatchNorm2d(int(out_planes / 2))
|
||||
self.conv2 = conv3x3(int(out_planes / 2), int(out_planes / 4))
|
||||
self.bn3 = nn.BatchNorm2d(int(out_planes / 4))
|
||||
self.conv3 = conv3x3(int(out_planes / 4), int(out_planes / 4))
|
||||
|
||||
if in_planes != out_planes:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.BatchNorm2d(in_planes),
|
||||
nn.ReLU(True),
|
||||
nn.Conv2d(in_planes, out_planes,
|
||||
kernel_size=1, stride=1, bias=False),
|
||||
)
|
||||
else:
|
||||
self.downsample = None
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out1 = self.bn1(x)
|
||||
out1 = F.relu(out1, True)
|
||||
out1 = self.conv1(out1)
|
||||
|
||||
out2 = self.bn2(out1)
|
||||
out2 = F.relu(out2, True)
|
||||
out2 = self.conv2(out2)
|
||||
|
||||
out3 = self.bn3(out2)
|
||||
out3 = F.relu(out3, True)
|
||||
out3 = self.conv3(out3)
|
||||
|
||||
out3 = torch.cat((out1, out2, out3), 1)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(residual)
|
||||
|
||||
out3 += residual
|
||||
|
||||
return out3
|
||||
|
||||
|
||||
class Bottleneck(nn.Module):
|
||||
|
||||
expansion = 4
|
||||
|
||||
def __init__(self, inplanes, planes, stride=1, downsample=None):
|
||||
super(Bottleneck, self).__init__()
|
||||
self.conv1 = nn.Conv2d(inplanes, planes, kernel_size=1, bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(planes)
|
||||
self.conv2 = nn.Conv2d(planes, planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
self.bn2 = nn.BatchNorm2d(planes)
|
||||
self.conv3 = nn.Conv2d(planes, planes * 4, kernel_size=1, bias=False)
|
||||
self.bn3 = nn.BatchNorm2d(planes * 4)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = downsample
|
||||
self.stride = stride
|
||||
|
||||
def forward(self, x):
|
||||
residual = x
|
||||
|
||||
out = self.conv1(x)
|
||||
out = self.bn1(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv2(out)
|
||||
out = self.bn2(out)
|
||||
out = self.relu(out)
|
||||
|
||||
out = self.conv3(out)
|
||||
out = self.bn3(out)
|
||||
|
||||
if self.downsample is not None:
|
||||
residual = self.downsample(x)
|
||||
|
||||
out += residual
|
||||
out = self.relu(out)
|
||||
|
||||
return out
|
||||
|
||||
|
||||
class HourGlass(nn.Module):
|
||||
def __init__(self, num_modules, depth, num_features):
|
||||
super(HourGlass, self).__init__()
|
||||
self.num_modules = num_modules
|
||||
self.depth = depth
|
||||
self.features = num_features
|
||||
|
||||
self._generate_network(self.depth)
|
||||
|
||||
def _generate_network(self, level):
|
||||
self.add_module('b1_' + str(level), ConvBlock(self.features, self.features))
|
||||
|
||||
self.add_module('b2_' + str(level), ConvBlock(self.features, self.features))
|
||||
|
||||
if level > 1:
|
||||
self._generate_network(level - 1)
|
||||
else:
|
||||
self.add_module('b2_plus_' + str(level), ConvBlock(self.features, self.features))
|
||||
|
||||
self.add_module('b3_' + str(level), ConvBlock(self.features, self.features))
|
||||
|
||||
def _forward(self, level, inp):
|
||||
# Upper branch
|
||||
up1 = inp
|
||||
up1 = self._modules['b1_' + str(level)](up1)
|
||||
|
||||
# Lower branch
|
||||
low1 = F.avg_pool2d(inp, 2, stride=2)
|
||||
low1 = self._modules['b2_' + str(level)](low1)
|
||||
|
||||
if level > 1:
|
||||
low2 = self._forward(level - 1, low1)
|
||||
else:
|
||||
low2 = low1
|
||||
low2 = self._modules['b2_plus_' + str(level)](low2)
|
||||
|
||||
low3 = low2
|
||||
low3 = self._modules['b3_' + str(level)](low3)
|
||||
|
||||
up2 = F.interpolate(low3, scale_factor=2, mode='nearest')
|
||||
|
||||
return up1 + up2
|
||||
|
||||
def forward(self, x):
|
||||
return self._forward(self.depth, x)
|
||||
|
||||
|
||||
class FAN(nn.Module):
|
||||
|
||||
def __init__(self, num_modules=1):
|
||||
super(FAN, self).__init__()
|
||||
self.num_modules = num_modules
|
||||
|
||||
# Base part
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.conv2 = ConvBlock(64, 128)
|
||||
self.conv3 = ConvBlock(128, 128)
|
||||
self.conv4 = ConvBlock(128, 256)
|
||||
|
||||
# Stacking part
|
||||
for hg_module in range(self.num_modules):
|
||||
self.add_module('m' + str(hg_module), HourGlass(1, 4, 256))
|
||||
self.add_module('top_m_' + str(hg_module), ConvBlock(256, 256))
|
||||
self.add_module('conv_last' + str(hg_module),
|
||||
nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
||||
self.add_module('bn_end' + str(hg_module), nn.BatchNorm2d(256))
|
||||
self.add_module('l' + str(hg_module), nn.Conv2d(256,
|
||||
68, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
if hg_module < self.num_modules - 1:
|
||||
self.add_module(
|
||||
'bl' + str(hg_module), nn.Conv2d(256, 256, kernel_size=1, stride=1, padding=0))
|
||||
self.add_module('al' + str(hg_module), nn.Conv2d(68,
|
||||
256, kernel_size=1, stride=1, padding=0))
|
||||
|
||||
def forward(self, x):
|
||||
x = F.relu(self.bn1(self.conv1(x)), True)
|
||||
x = F.avg_pool2d(self.conv2(x), 2, stride=2)
|
||||
x = self.conv3(x)
|
||||
x = self.conv4(x)
|
||||
|
||||
previous = x
|
||||
|
||||
outputs = []
|
||||
for i in range(self.num_modules):
|
||||
hg = self._modules['m' + str(i)](previous)
|
||||
|
||||
ll = hg
|
||||
ll = self._modules['top_m_' + str(i)](ll)
|
||||
|
||||
ll = F.relu(self._modules['bn_end' + str(i)]
|
||||
(self._modules['conv_last' + str(i)](ll)), True)
|
||||
|
||||
# Predict heatmaps
|
||||
tmp_out = self._modules['l' + str(i)](ll)
|
||||
outputs.append(tmp_out)
|
||||
|
||||
if i < self.num_modules - 1:
|
||||
ll = self._modules['bl' + str(i)](ll)
|
||||
tmp_out_ = self._modules['al' + str(i)](tmp_out)
|
||||
previous = previous + ll + tmp_out_
|
||||
|
||||
return outputs
|
||||
|
||||
|
||||
class ResNetDepth(nn.Module):
|
||||
|
||||
def __init__(self, block=Bottleneck, layers=[3, 8, 36, 3], num_classes=68):
|
||||
self.inplanes = 64
|
||||
super(ResNetDepth, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3 + 68, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = self._make_layer(block, 64, layers[0])
|
||||
self.layer2 = self._make_layer(block, 128, layers[1], stride=2)
|
||||
self.layer3 = self._make_layer(block, 256, layers[2], stride=2)
|
||||
self.layer4 = self._make_layer(block, 512, layers[3], stride=2)
|
||||
self.avgpool = nn.AvgPool2d(7)
|
||||
self.fc = nn.Linear(512 * block.expansion, num_classes)
|
||||
|
||||
for m in self.modules():
|
||||
if isinstance(m, nn.Conv2d):
|
||||
n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
|
||||
m.weight.data.normal_(0, math.sqrt(2. / n))
|
||||
elif isinstance(m, nn.BatchNorm2d):
|
||||
m.weight.data.fill_(1)
|
||||
m.bias.data.zero_()
|
||||
|
||||
def _make_layer(self, block, planes, blocks, stride=1):
|
||||
downsample = None
|
||||
if stride != 1 or self.inplanes != planes * block.expansion:
|
||||
downsample = nn.Sequential(
|
||||
nn.Conv2d(self.inplanes, planes * block.expansion,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(planes * block.expansion),
|
||||
)
|
||||
|
||||
layers = []
|
||||
layers.append(block(self.inplanes, planes, stride, downsample))
|
||||
self.inplanes = planes * block.expansion
|
||||
for i in range(1, blocks):
|
||||
layers.append(block(self.inplanes, planes))
|
||||
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = self.bn1(x)
|
||||
x = self.relu(x)
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
x = self.layer2(x)
|
||||
x = self.layer3(x)
|
||||
x = self.layer4(x)
|
||||
|
||||
x = self.avgpool(x)
|
||||
x = x.view(x.size(0), -1)
|
||||
x = self.fc(x)
|
||||
|
||||
return x
|
||||
313
models/MuseTalk/musetalk/utils/face_detection/utils.py
Normal file
313
models/MuseTalk/musetalk/utils/face_detection/utils.py
Normal file
@@ -0,0 +1,313 @@
|
||||
from __future__ import print_function
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
import torch
|
||||
import math
|
||||
import numpy as np
|
||||
import cv2
|
||||
|
||||
|
||||
def _gaussian(
|
||||
size=3, sigma=0.25, amplitude=1, normalize=False, width=None,
|
||||
height=None, sigma_horz=None, sigma_vert=None, mean_horz=0.5,
|
||||
mean_vert=0.5):
|
||||
# handle some defaults
|
||||
if width is None:
|
||||
width = size
|
||||
if height is None:
|
||||
height = size
|
||||
if sigma_horz is None:
|
||||
sigma_horz = sigma
|
||||
if sigma_vert is None:
|
||||
sigma_vert = sigma
|
||||
center_x = mean_horz * width + 0.5
|
||||
center_y = mean_vert * height + 0.5
|
||||
gauss = np.empty((height, width), dtype=np.float32)
|
||||
# generate kernel
|
||||
for i in range(height):
|
||||
for j in range(width):
|
||||
gauss[i][j] = amplitude * math.exp(-(math.pow((j + 1 - center_x) / (
|
||||
sigma_horz * width), 2) / 2.0 + math.pow((i + 1 - center_y) / (sigma_vert * height), 2) / 2.0))
|
||||
if normalize:
|
||||
gauss = gauss / np.sum(gauss)
|
||||
return gauss
|
||||
|
||||
|
||||
def draw_gaussian(image, point, sigma):
|
||||
# Check if the gaussian is inside
|
||||
ul = [math.floor(point[0] - 3 * sigma), math.floor(point[1] - 3 * sigma)]
|
||||
br = [math.floor(point[0] + 3 * sigma), math.floor(point[1] + 3 * sigma)]
|
||||
if (ul[0] > image.shape[1] or ul[1] > image.shape[0] or br[0] < 1 or br[1] < 1):
|
||||
return image
|
||||
size = 6 * sigma + 1
|
||||
g = _gaussian(size)
|
||||
g_x = [int(max(1, -ul[0])), int(min(br[0], image.shape[1])) - int(max(1, ul[0])) + int(max(1, -ul[0]))]
|
||||
g_y = [int(max(1, -ul[1])), int(min(br[1], image.shape[0])) - int(max(1, ul[1])) + int(max(1, -ul[1]))]
|
||||
img_x = [int(max(1, ul[0])), int(min(br[0], image.shape[1]))]
|
||||
img_y = [int(max(1, ul[1])), int(min(br[1], image.shape[0]))]
|
||||
assert (g_x[0] > 0 and g_y[1] > 0)
|
||||
image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]
|
||||
] = image[img_y[0] - 1:img_y[1], img_x[0] - 1:img_x[1]] + g[g_y[0] - 1:g_y[1], g_x[0] - 1:g_x[1]]
|
||||
image[image > 1] = 1
|
||||
return image
|
||||
|
||||
|
||||
def transform(point, center, scale, resolution, invert=False):
|
||||
"""Generate and affine transformation matrix.
|
||||
|
||||
Given a set of points, a center, a scale and a targer resolution, the
|
||||
function generates and affine transformation matrix. If invert is ``True``
|
||||
it will produce the inverse transformation.
|
||||
|
||||
Arguments:
|
||||
point {torch.tensor} -- the input 2D point
|
||||
center {torch.tensor or numpy.array} -- the center around which to perform the transformations
|
||||
scale {float} -- the scale of the face/object
|
||||
resolution {float} -- the output resolution
|
||||
|
||||
Keyword Arguments:
|
||||
invert {bool} -- define wherever the function should produce the direct or the
|
||||
inverse transformation matrix (default: {False})
|
||||
"""
|
||||
_pt = torch.ones(3)
|
||||
_pt[0] = point[0]
|
||||
_pt[1] = point[1]
|
||||
|
||||
h = 200.0 * scale
|
||||
t = torch.eye(3)
|
||||
t[0, 0] = resolution / h
|
||||
t[1, 1] = resolution / h
|
||||
t[0, 2] = resolution * (-center[0] / h + 0.5)
|
||||
t[1, 2] = resolution * (-center[1] / h + 0.5)
|
||||
|
||||
if invert:
|
||||
t = torch.inverse(t)
|
||||
|
||||
new_point = (torch.matmul(t, _pt))[0:2]
|
||||
|
||||
return new_point.int()
|
||||
|
||||
|
||||
def crop(image, center, scale, resolution=256.0):
|
||||
"""Center crops an image or set of heatmaps
|
||||
|
||||
Arguments:
|
||||
image {numpy.array} -- an rgb image
|
||||
center {numpy.array} -- the center of the object, usually the same as of the bounding box
|
||||
scale {float} -- scale of the face
|
||||
|
||||
Keyword Arguments:
|
||||
resolution {float} -- the size of the output cropped image (default: {256.0})
|
||||
|
||||
Returns:
|
||||
[type] -- [description]
|
||||
""" # Crop around the center point
|
||||
""" Crops the image around the center. Input is expected to be an np.ndarray """
|
||||
ul = transform([1, 1], center, scale, resolution, True)
|
||||
br = transform([resolution, resolution], center, scale, resolution, True)
|
||||
# pad = math.ceil(torch.norm((ul - br).float()) / 2.0 - (br[0] - ul[0]) / 2.0)
|
||||
if image.ndim > 2:
|
||||
newDim = np.array([br[1] - ul[1], br[0] - ul[0],
|
||||
image.shape[2]], dtype=np.int32)
|
||||
newImg = np.zeros(newDim, dtype=np.uint8)
|
||||
else:
|
||||
newDim = np.array([br[1] - ul[1], br[0] - ul[0]], dtype=np.int)
|
||||
newImg = np.zeros(newDim, dtype=np.uint8)
|
||||
ht = image.shape[0]
|
||||
wd = image.shape[1]
|
||||
newX = np.array(
|
||||
[max(1, -ul[0] + 1), min(br[0], wd) - ul[0]], dtype=np.int32)
|
||||
newY = np.array(
|
||||
[max(1, -ul[1] + 1), min(br[1], ht) - ul[1]], dtype=np.int32)
|
||||
oldX = np.array([max(1, ul[0] + 1), min(br[0], wd)], dtype=np.int32)
|
||||
oldY = np.array([max(1, ul[1] + 1), min(br[1], ht)], dtype=np.int32)
|
||||
newImg[newY[0] - 1:newY[1], newX[0] - 1:newX[1]
|
||||
] = image[oldY[0] - 1:oldY[1], oldX[0] - 1:oldX[1], :]
|
||||
newImg = cv2.resize(newImg, dsize=(int(resolution), int(resolution)),
|
||||
interpolation=cv2.INTER_LINEAR)
|
||||
return newImg
|
||||
|
||||
|
||||
def get_preds_fromhm(hm, center=None, scale=None):
|
||||
"""Obtain (x,y) coordinates given a set of N heatmaps. If the center
|
||||
and the scale is provided the function will return the points also in
|
||||
the original coordinate frame.
|
||||
|
||||
Arguments:
|
||||
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
||||
|
||||
Keyword Arguments:
|
||||
center {torch.tensor} -- the center of the bounding box (default: {None})
|
||||
scale {float} -- face scale (default: {None})
|
||||
"""
|
||||
max, idx = torch.max(
|
||||
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
||||
idx += 1
|
||||
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
||||
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
||||
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
||||
|
||||
for i in range(preds.size(0)):
|
||||
for j in range(preds.size(1)):
|
||||
hm_ = hm[i, j, :]
|
||||
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
||||
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
||||
diff = torch.FloatTensor(
|
||||
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
||||
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
||||
preds[i, j].add_(diff.sign_().mul_(.25))
|
||||
|
||||
preds.add_(-.5)
|
||||
|
||||
preds_orig = torch.zeros(preds.size())
|
||||
if center is not None and scale is not None:
|
||||
for i in range(hm.size(0)):
|
||||
for j in range(hm.size(1)):
|
||||
preds_orig[i, j] = transform(
|
||||
preds[i, j], center, scale, hm.size(2), True)
|
||||
|
||||
return preds, preds_orig
|
||||
|
||||
def get_preds_fromhm_batch(hm, centers=None, scales=None):
|
||||
"""Obtain (x,y) coordinates given a set of N heatmaps. If the centers
|
||||
and the scales is provided the function will return the points also in
|
||||
the original coordinate frame.
|
||||
|
||||
Arguments:
|
||||
hm {torch.tensor} -- the predicted heatmaps, of shape [B, N, W, H]
|
||||
|
||||
Keyword Arguments:
|
||||
centers {torch.tensor} -- the centers of the bounding box (default: {None})
|
||||
scales {float} -- face scales (default: {None})
|
||||
"""
|
||||
max, idx = torch.max(
|
||||
hm.view(hm.size(0), hm.size(1), hm.size(2) * hm.size(3)), 2)
|
||||
idx += 1
|
||||
preds = idx.view(idx.size(0), idx.size(1), 1).repeat(1, 1, 2).float()
|
||||
preds[..., 0].apply_(lambda x: (x - 1) % hm.size(3) + 1)
|
||||
preds[..., 1].add_(-1).div_(hm.size(2)).floor_().add_(1)
|
||||
|
||||
for i in range(preds.size(0)):
|
||||
for j in range(preds.size(1)):
|
||||
hm_ = hm[i, j, :]
|
||||
pX, pY = int(preds[i, j, 0]) - 1, int(preds[i, j, 1]) - 1
|
||||
if pX > 0 and pX < 63 and pY > 0 and pY < 63:
|
||||
diff = torch.FloatTensor(
|
||||
[hm_[pY, pX + 1] - hm_[pY, pX - 1],
|
||||
hm_[pY + 1, pX] - hm_[pY - 1, pX]])
|
||||
preds[i, j].add_(diff.sign_().mul_(.25))
|
||||
|
||||
preds.add_(-.5)
|
||||
|
||||
preds_orig = torch.zeros(preds.size())
|
||||
if centers is not None and scales is not None:
|
||||
for i in range(hm.size(0)):
|
||||
for j in range(hm.size(1)):
|
||||
preds_orig[i, j] = transform(
|
||||
preds[i, j], centers[i], scales[i], hm.size(2), True)
|
||||
|
||||
return preds, preds_orig
|
||||
|
||||
def shuffle_lr(parts, pairs=None):
|
||||
"""Shuffle the points left-right according to the axis of symmetry
|
||||
of the object.
|
||||
|
||||
Arguments:
|
||||
parts {torch.tensor} -- a 3D or 4D object containing the
|
||||
heatmaps.
|
||||
|
||||
Keyword Arguments:
|
||||
pairs {list of integers} -- [order of the flipped points] (default: {None})
|
||||
"""
|
||||
if pairs is None:
|
||||
pairs = [16, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0,
|
||||
26, 25, 24, 23, 22, 21, 20, 19, 18, 17, 27, 28, 29, 30, 35,
|
||||
34, 33, 32, 31, 45, 44, 43, 42, 47, 46, 39, 38, 37, 36, 41,
|
||||
40, 54, 53, 52, 51, 50, 49, 48, 59, 58, 57, 56, 55, 64, 63,
|
||||
62, 61, 60, 67, 66, 65]
|
||||
if parts.ndimension() == 3:
|
||||
parts = parts[pairs, ...]
|
||||
else:
|
||||
parts = parts[:, pairs, ...]
|
||||
|
||||
return parts
|
||||
|
||||
|
||||
def flip(tensor, is_label=False):
|
||||
"""Flip an image or a set of heatmaps left-right
|
||||
|
||||
Arguments:
|
||||
tensor {numpy.array or torch.tensor} -- [the input image or heatmaps]
|
||||
|
||||
Keyword Arguments:
|
||||
is_label {bool} -- [denote wherever the input is an image or a set of heatmaps ] (default: {False})
|
||||
"""
|
||||
if not torch.is_tensor(tensor):
|
||||
tensor = torch.from_numpy(tensor)
|
||||
|
||||
if is_label:
|
||||
tensor = shuffle_lr(tensor).flip(tensor.ndimension() - 1)
|
||||
else:
|
||||
tensor = tensor.flip(tensor.ndimension() - 1)
|
||||
|
||||
return tensor
|
||||
|
||||
# From pyzolib/paths.py (https://bitbucket.org/pyzo/pyzolib/src/tip/paths.py)
|
||||
|
||||
|
||||
def appdata_dir(appname=None, roaming=False):
|
||||
""" appdata_dir(appname=None, roaming=False)
|
||||
|
||||
Get the path to the application directory, where applications are allowed
|
||||
to write user specific files (e.g. configurations). For non-user specific
|
||||
data, consider using common_appdata_dir().
|
||||
If appname is given, a subdir is appended (and created if necessary).
|
||||
If roaming is True, will prefer a roaming directory (Windows Vista/7).
|
||||
"""
|
||||
|
||||
# Define default user directory
|
||||
userDir = os.getenv('FACEALIGNMENT_USERDIR', None)
|
||||
if userDir is None:
|
||||
userDir = os.path.expanduser('~')
|
||||
if not os.path.isdir(userDir): # pragma: no cover
|
||||
userDir = '/var/tmp' # issue #54
|
||||
|
||||
# Get system app data dir
|
||||
path = None
|
||||
if sys.platform.startswith('win'):
|
||||
path1, path2 = os.getenv('LOCALAPPDATA'), os.getenv('APPDATA')
|
||||
path = (path2 or path1) if roaming else (path1 or path2)
|
||||
elif sys.platform.startswith('darwin'):
|
||||
path = os.path.join(userDir, 'Library', 'Application Support')
|
||||
# On Linux and as fallback
|
||||
if not (path and os.path.isdir(path)):
|
||||
path = userDir
|
||||
|
||||
# Maybe we should store things local to the executable (in case of a
|
||||
# portable distro or a frozen application that wants to be portable)
|
||||
prefix = sys.prefix
|
||||
if getattr(sys, 'frozen', None):
|
||||
prefix = os.path.abspath(os.path.dirname(sys.executable))
|
||||
for reldir in ('settings', '../settings'):
|
||||
localpath = os.path.abspath(os.path.join(prefix, reldir))
|
||||
if os.path.isdir(localpath): # pragma: no cover
|
||||
try:
|
||||
open(os.path.join(localpath, 'test.write'), 'wb').close()
|
||||
os.remove(os.path.join(localpath, 'test.write'))
|
||||
except IOError:
|
||||
pass # We cannot write in this directory
|
||||
else:
|
||||
path = localpath
|
||||
break
|
||||
|
||||
# Get path specific for this app
|
||||
if appname:
|
||||
if path == userDir:
|
||||
appname = '.' + appname.lstrip('.') # Make it a hidden directory
|
||||
path = os.path.join(path, appname)
|
||||
if not os.path.isdir(path): # pragma: no cover
|
||||
os.mkdir(path)
|
||||
|
||||
# Done
|
||||
return path
|
||||
117
models/MuseTalk/musetalk/utils/face_parsing/__init__.py
Normal file
117
models/MuseTalk/musetalk/utils/face_parsing/__init__.py
Normal file
@@ -0,0 +1,117 @@
|
||||
import torch
|
||||
import time
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from .model import BiSeNet
|
||||
import torchvision.transforms as transforms
|
||||
|
||||
class FaceParsing():
|
||||
def __init__(self, left_cheek_width=80, right_cheek_width=80):
|
||||
self.net = self.model_init()
|
||||
self.preprocess = self.image_preprocess()
|
||||
# Ensure all size parameters are integers
|
||||
cone_height = 21
|
||||
tail_height = 12
|
||||
total_size = cone_height + tail_height
|
||||
|
||||
# Create kernel with explicit integer dimensions
|
||||
kernel = np.zeros((total_size, total_size), dtype=np.uint8)
|
||||
center_x = total_size // 2 # Ensure center coordinates are integers
|
||||
|
||||
# Cone part
|
||||
for row in range(cone_height):
|
||||
if row < cone_height//2:
|
||||
continue
|
||||
width = int(2 * (row - cone_height//2) + 1)
|
||||
start = int(center_x - (width // 2))
|
||||
end = int(center_x + (width // 2) + 1)
|
||||
kernel[row, start:end] = 1
|
||||
|
||||
# Vertical extension part
|
||||
if cone_height > 0:
|
||||
base_width = int(kernel[cone_height-1].sum())
|
||||
else:
|
||||
base_width = 1
|
||||
|
||||
for row in range(cone_height, total_size):
|
||||
start = max(0, int(center_x - (base_width//2)))
|
||||
end = min(total_size, int(center_x + (base_width//2) + 1))
|
||||
kernel[row, start:end] = 1
|
||||
self.kernel = kernel
|
||||
|
||||
# Modify cheek erosion kernel to be flatter ellipse
|
||||
self.cheek_kernel = cv2.getStructuringElement(
|
||||
cv2.MORPH_ELLIPSE, (35, 3))
|
||||
|
||||
# Add cheek area mask (protect chin area)
|
||||
self.cheek_mask = self._create_cheek_mask(left_cheek_width=left_cheek_width, right_cheek_width=right_cheek_width)
|
||||
|
||||
def _create_cheek_mask(self, left_cheek_width=80, right_cheek_width=80):
|
||||
"""Create cheek area mask (1/4 area on both sides)"""
|
||||
mask = np.zeros((512, 512), dtype=np.uint8)
|
||||
center = 512 // 2
|
||||
cv2.rectangle(mask, (0, 0), (center - left_cheek_width, 512), 255, -1) # Left cheek
|
||||
cv2.rectangle(mask, (center + right_cheek_width, 0), (512, 512), 255, -1) # Right cheek
|
||||
return mask
|
||||
|
||||
def model_init(self,
|
||||
resnet_path='./models/face-parse-bisent/resnet18-5c106cde.pth',
|
||||
model_pth='./models/face-parse-bisent/79999_iter.pth'):
|
||||
net = BiSeNet(resnet_path)
|
||||
if torch.cuda.is_available():
|
||||
net.cuda()
|
||||
net.load_state_dict(torch.load(model_pth))
|
||||
else:
|
||||
net.load_state_dict(torch.load(model_pth, map_location=torch.device('cpu')))
|
||||
net.eval()
|
||||
return net
|
||||
|
||||
def image_preprocess(self):
|
||||
return transforms.Compose([
|
||||
transforms.ToTensor(),
|
||||
transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
|
||||
])
|
||||
|
||||
def __call__(self, image, size=(512, 512), mode="raw"):
|
||||
if isinstance(image, str):
|
||||
image = Image.open(image)
|
||||
|
||||
width, height = image.size
|
||||
with torch.no_grad():
|
||||
image = image.resize(size, Image.BILINEAR)
|
||||
img = self.preprocess(image)
|
||||
if torch.cuda.is_available():
|
||||
img = torch.unsqueeze(img, 0).cuda()
|
||||
else:
|
||||
img = torch.unsqueeze(img, 0)
|
||||
out = self.net(img)[0]
|
||||
parsing = out.squeeze(0).cpu().numpy().argmax(0)
|
||||
|
||||
# Add 14:neck, remove 10:nose and 7:8:9
|
||||
if mode == "neck":
|
||||
parsing[np.isin(parsing, [1, 11, 12, 13, 14])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
elif mode == "jaw":
|
||||
face_region = np.isin(parsing, [1])*255
|
||||
face_region = face_region.astype(np.uint8)
|
||||
original_dilated = cv2.dilate(face_region, self.kernel, iterations=1)
|
||||
eroded = cv2.erode(original_dilated, self.cheek_kernel, iterations=2)
|
||||
face_region = cv2.bitwise_and(eroded, self.cheek_mask)
|
||||
face_region = cv2.bitwise_or(face_region, cv2.bitwise_and(original_dilated, ~self.cheek_mask))
|
||||
parsing[(face_region==255) & (~np.isin(parsing, [10]))] = 255
|
||||
parsing[np.isin(parsing, [11, 12, 13])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
else:
|
||||
parsing[np.isin(parsing, [1, 11, 12, 13])] = 255
|
||||
parsing[np.where(parsing!=255)] = 0
|
||||
|
||||
parsing = Image.fromarray(parsing.astype(np.uint8))
|
||||
return parsing
|
||||
|
||||
if __name__ == "__main__":
|
||||
fp = FaceParsing()
|
||||
segmap = fp('154_small.png')
|
||||
segmap.save('res.png')
|
||||
|
||||
283
models/MuseTalk/musetalk/utils/face_parsing/model.py
Normal file
283
models/MuseTalk/musetalk/utils/face_parsing/model.py
Normal file
@@ -0,0 +1,283 @@
|
||||
#!/usr/bin/python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchvision
|
||||
|
||||
from .resnet import Resnet18
|
||||
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
||||
|
||||
|
||||
class ConvBNReLU(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, ks=3, stride=1, padding=1, *args, **kwargs):
|
||||
super(ConvBNReLU, self).__init__()
|
||||
self.conv = nn.Conv2d(in_chan,
|
||||
out_chan,
|
||||
kernel_size = ks,
|
||||
stride = stride,
|
||||
padding = padding,
|
||||
bias = False)
|
||||
self.bn = nn.BatchNorm2d(out_chan)
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = F.relu(self.bn(x))
|
||||
return x
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
class BiSeNetOutput(nn.Module):
|
||||
def __init__(self, in_chan, mid_chan, n_classes, *args, **kwargs):
|
||||
super(BiSeNetOutput, self).__init__()
|
||||
self.conv = ConvBNReLU(in_chan, mid_chan, ks=3, stride=1, padding=1)
|
||||
self.conv_out = nn.Conv2d(mid_chan, n_classes, kernel_size=1, bias=False)
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv(x)
|
||||
x = self.conv_out(x)
|
||||
return x
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
class AttentionRefinementModule(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
||||
super(AttentionRefinementModule, self).__init__()
|
||||
self.conv = ConvBNReLU(in_chan, out_chan, ks=3, stride=1, padding=1)
|
||||
self.conv_atten = nn.Conv2d(out_chan, out_chan, kernel_size= 1, bias=False)
|
||||
self.bn_atten = nn.BatchNorm2d(out_chan)
|
||||
self.sigmoid_atten = nn.Sigmoid()
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
feat = self.conv(x)
|
||||
atten = F.avg_pool2d(feat, feat.size()[2:])
|
||||
atten = self.conv_atten(atten)
|
||||
atten = self.bn_atten(atten)
|
||||
atten = self.sigmoid_atten(atten)
|
||||
out = torch.mul(feat, atten)
|
||||
return out
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
|
||||
class ContextPath(nn.Module):
|
||||
def __init__(self, resnet_path, *args, **kwargs):
|
||||
super(ContextPath, self).__init__()
|
||||
self.resnet = Resnet18(resnet_path)
|
||||
self.arm16 = AttentionRefinementModule(256, 128)
|
||||
self.arm32 = AttentionRefinementModule(512, 128)
|
||||
self.conv_head32 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
||||
self.conv_head16 = ConvBNReLU(128, 128, ks=3, stride=1, padding=1)
|
||||
self.conv_avg = ConvBNReLU(512, 128, ks=1, stride=1, padding=0)
|
||||
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
H0, W0 = x.size()[2:]
|
||||
feat8, feat16, feat32 = self.resnet(x)
|
||||
H8, W8 = feat8.size()[2:]
|
||||
H16, W16 = feat16.size()[2:]
|
||||
H32, W32 = feat32.size()[2:]
|
||||
|
||||
avg = F.avg_pool2d(feat32, feat32.size()[2:])
|
||||
avg = self.conv_avg(avg)
|
||||
avg_up = F.interpolate(avg, (H32, W32), mode='nearest')
|
||||
|
||||
feat32_arm = self.arm32(feat32)
|
||||
feat32_sum = feat32_arm + avg_up
|
||||
feat32_up = F.interpolate(feat32_sum, (H16, W16), mode='nearest')
|
||||
feat32_up = self.conv_head32(feat32_up)
|
||||
|
||||
feat16_arm = self.arm16(feat16)
|
||||
feat16_sum = feat16_arm + feat32_up
|
||||
feat16_up = F.interpolate(feat16_sum, (H8, W8), mode='nearest')
|
||||
feat16_up = self.conv_head16(feat16_up)
|
||||
|
||||
return feat8, feat16_up, feat32_up # x8, x8, x16
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
### This is not used, since I replace this with the resnet feature with the same size
|
||||
class SpatialPath(nn.Module):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(SpatialPath, self).__init__()
|
||||
self.conv1 = ConvBNReLU(3, 64, ks=7, stride=2, padding=3)
|
||||
self.conv2 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
||||
self.conv3 = ConvBNReLU(64, 64, ks=3, stride=2, padding=1)
|
||||
self.conv_out = ConvBNReLU(64, 128, ks=1, stride=1, padding=0)
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
feat = self.conv1(x)
|
||||
feat = self.conv2(feat)
|
||||
feat = self.conv3(feat)
|
||||
feat = self.conv_out(feat)
|
||||
return feat
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
class FeatureFusionModule(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, *args, **kwargs):
|
||||
super(FeatureFusionModule, self).__init__()
|
||||
self.convblk = ConvBNReLU(in_chan, out_chan, ks=1, stride=1, padding=0)
|
||||
self.conv1 = nn.Conv2d(out_chan,
|
||||
out_chan//4,
|
||||
kernel_size = 1,
|
||||
stride = 1,
|
||||
padding = 0,
|
||||
bias = False)
|
||||
self.conv2 = nn.Conv2d(out_chan//4,
|
||||
out_chan,
|
||||
kernel_size = 1,
|
||||
stride = 1,
|
||||
padding = 0,
|
||||
bias = False)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.sigmoid = nn.Sigmoid()
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, fsp, fcp):
|
||||
fcat = torch.cat([fsp, fcp], dim=1)
|
||||
feat = self.convblk(fcat)
|
||||
atten = F.avg_pool2d(feat, feat.size()[2:])
|
||||
atten = self.conv1(atten)
|
||||
atten = self.relu(atten)
|
||||
atten = self.conv2(atten)
|
||||
atten = self.sigmoid(atten)
|
||||
feat_atten = torch.mul(feat, atten)
|
||||
feat_out = feat_atten + feat
|
||||
return feat_out
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
class BiSeNet(nn.Module):
|
||||
def __init__(self, resnet_path='models/resnet18-5c106cde.pth', n_classes=19, *args, **kwargs):
|
||||
super(BiSeNet, self).__init__()
|
||||
self.cp = ContextPath(resnet_path)
|
||||
## here self.sp is deleted
|
||||
self.ffm = FeatureFusionModule(256, 256)
|
||||
self.conv_out = BiSeNetOutput(256, 256, n_classes)
|
||||
self.conv_out16 = BiSeNetOutput(128, 64, n_classes)
|
||||
self.conv_out32 = BiSeNetOutput(128, 64, n_classes)
|
||||
self.init_weight()
|
||||
|
||||
def forward(self, x):
|
||||
H, W = x.size()[2:]
|
||||
feat_res8, feat_cp8, feat_cp16 = self.cp(x) # here return res3b1 feature
|
||||
feat_sp = feat_res8 # use res3b1 feature to replace spatial path feature
|
||||
feat_fuse = self.ffm(feat_sp, feat_cp8)
|
||||
|
||||
feat_out = self.conv_out(feat_fuse)
|
||||
feat_out16 = self.conv_out16(feat_cp8)
|
||||
feat_out32 = self.conv_out32(feat_cp16)
|
||||
|
||||
feat_out = F.interpolate(feat_out, (H, W), mode='bilinear', align_corners=True)
|
||||
feat_out16 = F.interpolate(feat_out16, (H, W), mode='bilinear', align_corners=True)
|
||||
feat_out32 = F.interpolate(feat_out32, (H, W), mode='bilinear', align_corners=True)
|
||||
return feat_out, feat_out16, feat_out32
|
||||
|
||||
def init_weight(self):
|
||||
for ly in self.children():
|
||||
if isinstance(ly, nn.Conv2d):
|
||||
nn.init.kaiming_normal_(ly.weight, a=1)
|
||||
if not ly.bias is None: nn.init.constant_(ly.bias, 0)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params = [], [], [], []
|
||||
for name, child in self.named_children():
|
||||
child_wd_params, child_nowd_params = child.get_params()
|
||||
if isinstance(child, FeatureFusionModule) or isinstance(child, BiSeNetOutput):
|
||||
lr_mul_wd_params += child_wd_params
|
||||
lr_mul_nowd_params += child_nowd_params
|
||||
else:
|
||||
wd_params += child_wd_params
|
||||
nowd_params += child_nowd_params
|
||||
return wd_params, nowd_params, lr_mul_wd_params, lr_mul_nowd_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = BiSeNet(19)
|
||||
net.cuda()
|
||||
net.eval()
|
||||
in_ten = torch.randn(16, 3, 640, 480).cuda()
|
||||
out, out16, out32 = net(in_ten)
|
||||
print(out.shape)
|
||||
|
||||
net.get_params()
|
||||
109
models/MuseTalk/musetalk/utils/face_parsing/resnet.py
Normal file
109
models/MuseTalk/musetalk/utils/face_parsing/resnet.py
Normal file
@@ -0,0 +1,109 @@
|
||||
#!/usr/bin/python
|
||||
# -*- encoding: utf-8 -*-
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torch.utils.model_zoo as modelzoo
|
||||
|
||||
# from modules.bn import InPlaceABNSync as BatchNorm2d
|
||||
|
||||
resnet18_url = 'https://download.pytorch.org/models/resnet18-5c106cde.pth'
|
||||
|
||||
|
||||
def conv3x3(in_planes, out_planes, stride=1):
|
||||
"""3x3 convolution with padding"""
|
||||
return nn.Conv2d(in_planes, out_planes, kernel_size=3, stride=stride,
|
||||
padding=1, bias=False)
|
||||
|
||||
|
||||
class BasicBlock(nn.Module):
|
||||
def __init__(self, in_chan, out_chan, stride=1):
|
||||
super(BasicBlock, self).__init__()
|
||||
self.conv1 = conv3x3(in_chan, out_chan, stride)
|
||||
self.bn1 = nn.BatchNorm2d(out_chan)
|
||||
self.conv2 = conv3x3(out_chan, out_chan)
|
||||
self.bn2 = nn.BatchNorm2d(out_chan)
|
||||
self.relu = nn.ReLU(inplace=True)
|
||||
self.downsample = None
|
||||
if in_chan != out_chan or stride != 1:
|
||||
self.downsample = nn.Sequential(
|
||||
nn.Conv2d(in_chan, out_chan,
|
||||
kernel_size=1, stride=stride, bias=False),
|
||||
nn.BatchNorm2d(out_chan),
|
||||
)
|
||||
|
||||
def forward(self, x):
|
||||
residual = self.conv1(x)
|
||||
residual = F.relu(self.bn1(residual))
|
||||
residual = self.conv2(residual)
|
||||
residual = self.bn2(residual)
|
||||
|
||||
shortcut = x
|
||||
if self.downsample is not None:
|
||||
shortcut = self.downsample(x)
|
||||
|
||||
out = shortcut + residual
|
||||
out = self.relu(out)
|
||||
return out
|
||||
|
||||
|
||||
def create_layer_basic(in_chan, out_chan, bnum, stride=1):
|
||||
layers = [BasicBlock(in_chan, out_chan, stride=stride)]
|
||||
for i in range(bnum-1):
|
||||
layers.append(BasicBlock(out_chan, out_chan, stride=1))
|
||||
return nn.Sequential(*layers)
|
||||
|
||||
|
||||
class Resnet18(nn.Module):
|
||||
def __init__(self, model_path):
|
||||
super(Resnet18, self).__init__()
|
||||
self.conv1 = nn.Conv2d(3, 64, kernel_size=7, stride=2, padding=3,
|
||||
bias=False)
|
||||
self.bn1 = nn.BatchNorm2d(64)
|
||||
self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
|
||||
self.layer1 = create_layer_basic(64, 64, bnum=2, stride=1)
|
||||
self.layer2 = create_layer_basic(64, 128, bnum=2, stride=2)
|
||||
self.layer3 = create_layer_basic(128, 256, bnum=2, stride=2)
|
||||
self.layer4 = create_layer_basic(256, 512, bnum=2, stride=2)
|
||||
self.init_weight(model_path)
|
||||
|
||||
def forward(self, x):
|
||||
x = self.conv1(x)
|
||||
x = F.relu(self.bn1(x))
|
||||
x = self.maxpool(x)
|
||||
|
||||
x = self.layer1(x)
|
||||
feat8 = self.layer2(x) # 1/8
|
||||
feat16 = self.layer3(feat8) # 1/16
|
||||
feat32 = self.layer4(feat16) # 1/32
|
||||
return feat8, feat16, feat32
|
||||
|
||||
def init_weight(self, model_path):
|
||||
state_dict = torch.load(model_path) #modelzoo.load_url(resnet18_url)
|
||||
self_state_dict = self.state_dict()
|
||||
for k, v in state_dict.items():
|
||||
if 'fc' in k: continue
|
||||
self_state_dict.update({k: v})
|
||||
self.load_state_dict(self_state_dict)
|
||||
|
||||
def get_params(self):
|
||||
wd_params, nowd_params = [], []
|
||||
for name, module in self.named_modules():
|
||||
if isinstance(module, (nn.Linear, nn.Conv2d)):
|
||||
wd_params.append(module.weight)
|
||||
if not module.bias is None:
|
||||
nowd_params.append(module.bias)
|
||||
elif isinstance(module, nn.BatchNorm2d):
|
||||
nowd_params += list(module.parameters())
|
||||
return wd_params, nowd_params
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
net = Resnet18()
|
||||
x = torch.randn(16, 3, 224, 224)
|
||||
out = net(x)
|
||||
print(out[0].size())
|
||||
print(out[1].size())
|
||||
print(out[2].size())
|
||||
net.get_params()
|
||||
155
models/MuseTalk/musetalk/utils/preprocessing.py
Normal file
155
models/MuseTalk/musetalk/utils/preprocessing.py
Normal file
@@ -0,0 +1,155 @@
|
||||
import sys
|
||||
from face_detection import FaceAlignment,LandmarksType
|
||||
from os import listdir, path
|
||||
import subprocess
|
||||
import numpy as np
|
||||
import cv2
|
||||
import pickle
|
||||
import os
|
||||
import json
|
||||
from mmpose.apis import inference_topdown, init_model
|
||||
from mmpose.structures import merge_data_samples
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
# initialize the mmpose model
|
||||
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
config_file = './musetalk/utils/dwpose/rtmpose-l_8xb32-270e_coco-ubody-wholebody-384x288.py'
|
||||
checkpoint_file = './models/dwpose/dw-ll_ucoco_384.pth'
|
||||
model = init_model(config_file, checkpoint_file, device=device)
|
||||
|
||||
# initialize the face detection model
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
fa = FaceAlignment(LandmarksType._2D, flip_input=False,device=device)
|
||||
|
||||
# maker if the bbox is not sufficient
|
||||
coord_placeholder = (0.0,0.0,0.0,0.0)
|
||||
|
||||
def resize_landmark(landmark, w, h, new_w, new_h):
|
||||
w_ratio = new_w / w
|
||||
h_ratio = new_h / h
|
||||
landmark_norm = landmark / [w, h]
|
||||
landmark_resized = landmark_norm * [new_w, new_h]
|
||||
return landmark_resized
|
||||
|
||||
def read_imgs(img_list):
|
||||
frames = []
|
||||
print('reading images...')
|
||||
for img_path in tqdm(img_list):
|
||||
frame = cv2.imread(img_path)
|
||||
frames.append(frame)
|
||||
return frames
|
||||
|
||||
def get_bbox_range(img_list,upperbondrange =0):
|
||||
frames = read_imgs(img_list)
|
||||
batch_size_fa = 1
|
||||
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
|
||||
coords_list = []
|
||||
landmarks = []
|
||||
if upperbondrange != 0:
|
||||
print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
|
||||
else:
|
||||
print('get key_landmark and face bounding boxes with the default value')
|
||||
average_range_minus = []
|
||||
average_range_plus = []
|
||||
for fb in tqdm(batches):
|
||||
results = inference_topdown(model, np.asarray(fb)[0])
|
||||
results = merge_data_samples(results)
|
||||
keypoints = results.pred_instances.keypoints
|
||||
face_land_mark= keypoints[0][23:91]
|
||||
face_land_mark = face_land_mark.astype(np.int32)
|
||||
|
||||
# get bounding boxes by face detetion
|
||||
bbox = fa.get_detections_for_batch(np.asarray(fb))
|
||||
|
||||
# adjust the bounding box refer to landmark
|
||||
# Add the bounding box to a tuple and append it to the coordinates list
|
||||
for j, f in enumerate(bbox):
|
||||
if f is None: # no face in the image
|
||||
coords_list += [coord_placeholder]
|
||||
continue
|
||||
|
||||
half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
|
||||
range_minus = (face_land_mark[30]- face_land_mark[29])[1]
|
||||
range_plus = (face_land_mark[29]- face_land_mark[28])[1]
|
||||
average_range_minus.append(range_minus)
|
||||
average_range_plus.append(range_plus)
|
||||
if upperbondrange != 0:
|
||||
half_face_coord[1] = upperbondrange+half_face_coord[1] #手动调整 + 向下(偏29) - 向上(偏28)
|
||||
|
||||
text_range=f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}"
|
||||
return text_range
|
||||
|
||||
|
||||
def get_landmark_and_bbox(img_list,upperbondrange =0):
|
||||
frames = read_imgs(img_list)
|
||||
batch_size_fa = 1
|
||||
batches = [frames[i:i + batch_size_fa] for i in range(0, len(frames), batch_size_fa)]
|
||||
coords_list = []
|
||||
landmarks = []
|
||||
if upperbondrange != 0:
|
||||
print('get key_landmark and face bounding boxes with the bbox_shift:',upperbondrange)
|
||||
else:
|
||||
print('get key_landmark and face bounding boxes with the default value')
|
||||
average_range_minus = []
|
||||
average_range_plus = []
|
||||
for fb in tqdm(batches):
|
||||
results = inference_topdown(model, np.asarray(fb)[0])
|
||||
results = merge_data_samples(results)
|
||||
keypoints = results.pred_instances.keypoints
|
||||
face_land_mark= keypoints[0][23:91]
|
||||
face_land_mark = face_land_mark.astype(np.int32)
|
||||
|
||||
# get bounding boxes by face detetion
|
||||
bbox = fa.get_detections_for_batch(np.asarray(fb))
|
||||
|
||||
# adjust the bounding box refer to landmark
|
||||
# Add the bounding box to a tuple and append it to the coordinates list
|
||||
for j, f in enumerate(bbox):
|
||||
if f is None: # no face in the image
|
||||
coords_list += [coord_placeholder]
|
||||
continue
|
||||
|
||||
half_face_coord = face_land_mark[29]#np.mean([face_land_mark[28], face_land_mark[29]], axis=0)
|
||||
range_minus = (face_land_mark[30]- face_land_mark[29])[1]
|
||||
range_plus = (face_land_mark[29]- face_land_mark[28])[1]
|
||||
average_range_minus.append(range_minus)
|
||||
average_range_plus.append(range_plus)
|
||||
if upperbondrange != 0:
|
||||
half_face_coord[1] = upperbondrange+half_face_coord[1] #手动调整 + 向下(偏29) - 向上(偏28)
|
||||
half_face_dist = np.max(face_land_mark[:,1]) - half_face_coord[1]
|
||||
min_upper_bond = 0
|
||||
upper_bond = max(min_upper_bond, half_face_coord[1] - half_face_dist)
|
||||
|
||||
f_landmark = (np.min(face_land_mark[:, 0]),int(upper_bond),np.max(face_land_mark[:, 0]),np.max(face_land_mark[:,1]))
|
||||
x1, y1, x2, y2 = f_landmark
|
||||
|
||||
if y2-y1<=0 or x2-x1<=0 or x1<0: # if the landmark bbox is not suitable, reuse the bbox
|
||||
coords_list += [f]
|
||||
w,h = f[2]-f[0], f[3]-f[1]
|
||||
print("error bbox:",f)
|
||||
else:
|
||||
coords_list += [f_landmark]
|
||||
|
||||
print("********************************************bbox_shift parameter adjustment**********************************************************")
|
||||
print(f"Total frame:「{len(frames)}」 Manually adjust range : [ -{int(sum(average_range_minus) / len(average_range_minus))}~{int(sum(average_range_plus) / len(average_range_plus))} ] , the current value: {upperbondrange}")
|
||||
print("*************************************************************************************************************************************")
|
||||
return coords_list,frames
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
img_list = ["./results/lyria/00000.png","./results/lyria/00001.png","./results/lyria/00002.png","./results/lyria/00003.png"]
|
||||
crop_coord_path = "./coord_face.pkl"
|
||||
coords_list,full_frames = get_landmark_and_bbox(img_list)
|
||||
with open(crop_coord_path, 'wb') as f:
|
||||
pickle.dump(coords_list, f)
|
||||
|
||||
for bbox, frame in zip(coords_list,full_frames):
|
||||
if bbox == coord_placeholder:
|
||||
continue
|
||||
x1, y1, x2, y2 = bbox
|
||||
crop_frame = frame[y1:y2, x1:x2]
|
||||
print('Cropped shape', crop_frame.shape)
|
||||
|
||||
#cv2.imwrite(path.join(save_dir, '{}.png'.format(i)),full_frames[i][0][y1:y2, x1:x2])
|
||||
print(coords_list)
|
||||
337
models/MuseTalk/musetalk/utils/training_utils.py
Normal file
337
models/MuseTalk/musetalk/utils/training_utils.py
Normal file
@@ -0,0 +1,337 @@
|
||||
import os
|
||||
import json
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from torch.optim.lr_scheduler import CosineAnnealingLR
|
||||
from diffusers import AutoencoderKL, UNet2DConditionModel
|
||||
from transformers import WhisperModel
|
||||
from diffusers.optimization import get_scheduler
|
||||
from omegaconf import OmegaConf
|
||||
from einops import rearrange
|
||||
|
||||
from musetalk.models.syncnet import SyncNet
|
||||
from musetalk.loss.discriminator import MultiScaleDiscriminator, DiscriminatorFullModel
|
||||
from musetalk.loss.basic_loss import Interpolate
|
||||
import musetalk.loss.vgg_face as vgg_face
|
||||
from musetalk.data.dataset import PortraitDataset
|
||||
from musetalk.utils.utils import (
|
||||
get_image_pred,
|
||||
process_audio_features,
|
||||
process_and_save_images
|
||||
)
|
||||
|
||||
class Net(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
unet: UNet2DConditionModel,
|
||||
):
|
||||
super().__init__()
|
||||
self.unet = unet
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_latents,
|
||||
timesteps,
|
||||
audio_prompts,
|
||||
):
|
||||
model_pred = self.unet(
|
||||
input_latents,
|
||||
timesteps,
|
||||
encoder_hidden_states=audio_prompts
|
||||
).sample
|
||||
return model_pred
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
def initialize_models_and_optimizers(cfg, accelerator, weight_dtype):
|
||||
"""Initialize models and optimizers"""
|
||||
model_dict = {
|
||||
'vae': None,
|
||||
'unet': None,
|
||||
'net': None,
|
||||
'wav2vec': None,
|
||||
'optimizer': None,
|
||||
'lr_scheduler': None,
|
||||
'scheduler_max_steps': None,
|
||||
'trainable_params': None
|
||||
}
|
||||
|
||||
model_dict['vae'] = AutoencoderKL.from_pretrained(
|
||||
cfg.pretrained_model_name_or_path,
|
||||
subfolder=cfg.vae_type,
|
||||
)
|
||||
|
||||
unet_config_file = os.path.join(
|
||||
cfg.pretrained_model_name_or_path,
|
||||
cfg.unet_sub_folder + "/musetalk.json"
|
||||
)
|
||||
|
||||
with open(unet_config_file, 'r') as f:
|
||||
unet_config = json.load(f)
|
||||
model_dict['unet'] = UNet2DConditionModel(**unet_config)
|
||||
|
||||
if not cfg.random_init_unet:
|
||||
pretrained_unet_path = os.path.join(cfg.pretrained_model_name_or_path, cfg.unet_sub_folder, "pytorch_model.bin")
|
||||
print(f"### Loading existing unet weights from {pretrained_unet_path}. ###")
|
||||
checkpoint = torch.load(pretrained_unet_path, map_location=accelerator.device)
|
||||
model_dict['unet'].load_state_dict(checkpoint)
|
||||
|
||||
unet_params = [p.numel() for n, p in model_dict['unet'].named_parameters()]
|
||||
logger.info(f"unet {sum(unet_params) / 1e6}M-parameter")
|
||||
|
||||
model_dict['vae'].requires_grad_(False)
|
||||
model_dict['unet'].requires_grad_(True)
|
||||
|
||||
model_dict['vae'].to(accelerator.device, dtype=weight_dtype)
|
||||
|
||||
model_dict['net'] = Net(model_dict['unet'])
|
||||
|
||||
model_dict['wav2vec'] = WhisperModel.from_pretrained(cfg.whisper_path).to(
|
||||
device="cuda", dtype=weight_dtype).eval()
|
||||
model_dict['wav2vec'].requires_grad_(False)
|
||||
|
||||
if cfg.solver.gradient_checkpointing:
|
||||
model_dict['unet'].enable_gradient_checkpointing()
|
||||
|
||||
if cfg.solver.scale_lr:
|
||||
learning_rate = (
|
||||
cfg.solver.learning_rate
|
||||
* cfg.solver.gradient_accumulation_steps
|
||||
* cfg.data.train_bs
|
||||
* accelerator.num_processes
|
||||
)
|
||||
else:
|
||||
learning_rate = cfg.solver.learning_rate
|
||||
|
||||
if cfg.solver.use_8bit_adam:
|
||||
try:
|
||||
import bitsandbytes as bnb
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
|
||||
)
|
||||
optimizer_cls = bnb.optim.AdamW8bit
|
||||
else:
|
||||
optimizer_cls = torch.optim.AdamW
|
||||
|
||||
model_dict['trainable_params'] = list(filter(lambda p: p.requires_grad, model_dict['net'].parameters()))
|
||||
if accelerator.is_main_process:
|
||||
print('trainable params')
|
||||
for n, p in model_dict['net'].named_parameters():
|
||||
if p.requires_grad:
|
||||
print(n)
|
||||
|
||||
model_dict['optimizer'] = optimizer_cls(
|
||||
model_dict['trainable_params'],
|
||||
lr=learning_rate,
|
||||
betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
|
||||
weight_decay=cfg.solver.adam_weight_decay,
|
||||
eps=cfg.solver.adam_epsilon,
|
||||
)
|
||||
|
||||
model_dict['scheduler_max_steps'] = cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps
|
||||
model_dict['lr_scheduler'] = get_scheduler(
|
||||
cfg.solver.lr_scheduler,
|
||||
optimizer=model_dict['optimizer'],
|
||||
num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
|
||||
num_training_steps=model_dict['scheduler_max_steps'],
|
||||
)
|
||||
|
||||
return model_dict
|
||||
|
||||
def initialize_dataloaders(cfg):
|
||||
"""Initialize training and validation dataloaders"""
|
||||
dataloader_dict = {
|
||||
'train_dataset': None,
|
||||
'val_dataset': None,
|
||||
'train_dataloader': None,
|
||||
'val_dataloader': None
|
||||
}
|
||||
|
||||
dataloader_dict['train_dataset'] = PortraitDataset(cfg={
|
||||
'image_size': cfg.data.image_size,
|
||||
'T': cfg.data.n_sample_frames,
|
||||
"sample_method": cfg.data.sample_method,
|
||||
'top_k_ratio': cfg.data.top_k_ratio,
|
||||
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
||||
"dataset_key": cfg.data.dataset_key,
|
||||
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
||||
"whisper_path": cfg.whisper_path,
|
||||
"min_face_size": cfg.data.min_face_size,
|
||||
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
||||
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
||||
"crop_type": cfg.crop_type,
|
||||
"random_margin_method": cfg.random_margin_method,
|
||||
})
|
||||
|
||||
dataloader_dict['train_dataloader'] = torch.utils.data.DataLoader(
|
||||
dataloader_dict['train_dataset'],
|
||||
batch_size=cfg.data.train_bs,
|
||||
shuffle=True,
|
||||
num_workers=cfg.data.num_workers,
|
||||
)
|
||||
|
||||
dataloader_dict['val_dataset'] = PortraitDataset(cfg={
|
||||
'image_size': cfg.data.image_size,
|
||||
'T': cfg.data.n_sample_frames,
|
||||
"sample_method": cfg.data.sample_method,
|
||||
'top_k_ratio': cfg.data.top_k_ratio,
|
||||
"contorl_face_min_size": cfg.data.contorl_face_min_size,
|
||||
"dataset_key": cfg.data.dataset_key,
|
||||
"padding_pixel_mouth": cfg.padding_pixel_mouth,
|
||||
"whisper_path": cfg.whisper_path,
|
||||
"min_face_size": cfg.data.min_face_size,
|
||||
"cropping_jaw2edge_margin_mean": cfg.cropping_jaw2edge_margin_mean,
|
||||
"cropping_jaw2edge_margin_std": cfg.cropping_jaw2edge_margin_std,
|
||||
"crop_type": cfg.crop_type,
|
||||
"random_margin_method": cfg.random_margin_method,
|
||||
})
|
||||
|
||||
dataloader_dict['val_dataloader'] = torch.utils.data.DataLoader(
|
||||
dataloader_dict['val_dataset'],
|
||||
batch_size=cfg.data.train_bs,
|
||||
shuffle=True,
|
||||
num_workers=1,
|
||||
)
|
||||
|
||||
return dataloader_dict
|
||||
|
||||
def initialize_loss_functions(cfg, accelerator, scheduler_max_steps):
|
||||
"""Initialize loss functions and discriminators"""
|
||||
loss_dict = {
|
||||
'L1_loss': nn.L1Loss(reduction='mean'),
|
||||
'discriminator': None,
|
||||
'mouth_discriminator': None,
|
||||
'optimizer_D': None,
|
||||
'mouth_optimizer_D': None,
|
||||
'scheduler_D': None,
|
||||
'mouth_scheduler_D': None,
|
||||
'disc_scales': None,
|
||||
'discriminator_full': None,
|
||||
'mouth_discriminator_full': None
|
||||
}
|
||||
|
||||
if cfg.loss_params.gan_loss > 0:
|
||||
loss_dict['discriminator'] = MultiScaleDiscriminator(
|
||||
**cfg.model_params.discriminator_params).to(accelerator.device)
|
||||
loss_dict['discriminator_full'] = DiscriminatorFullModel(loss_dict['discriminator'])
|
||||
loss_dict['disc_scales'] = cfg.model_params.discriminator_params.scales
|
||||
loss_dict['optimizer_D'] = optim.AdamW(
|
||||
loss_dict['discriminator'].parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
loss_dict['scheduler_D'] = CosineAnnealingLR(
|
||||
loss_dict['optimizer_D'],
|
||||
T_max=scheduler_max_steps,
|
||||
eta_min=1e-6
|
||||
)
|
||||
|
||||
if cfg.loss_params.mouth_gan_loss > 0:
|
||||
loss_dict['mouth_discriminator'] = MultiScaleDiscriminator(
|
||||
**cfg.model_params.discriminator_params).to(accelerator.device)
|
||||
loss_dict['mouth_discriminator_full'] = DiscriminatorFullModel(loss_dict['mouth_discriminator'])
|
||||
loss_dict['mouth_optimizer_D'] = optim.AdamW(
|
||||
loss_dict['mouth_discriminator'].parameters(),
|
||||
lr=cfg.discriminator_train_params.lr,
|
||||
weight_decay=cfg.discriminator_train_params.weight_decay,
|
||||
betas=cfg.discriminator_train_params.betas,
|
||||
eps=cfg.discriminator_train_params.eps)
|
||||
loss_dict['mouth_scheduler_D'] = CosineAnnealingLR(
|
||||
loss_dict['mouth_optimizer_D'],
|
||||
T_max=scheduler_max_steps,
|
||||
eta_min=1e-6
|
||||
)
|
||||
|
||||
return loss_dict
|
||||
|
||||
def initialize_syncnet(cfg, accelerator, weight_dtype):
|
||||
"""Initialize SyncNet model"""
|
||||
if cfg.loss_params.sync_loss > 0 or cfg.use_adapted_weight:
|
||||
if cfg.data.n_sample_frames != 16:
|
||||
raise ValueError(
|
||||
f"Invalid n_sample_frames {cfg.data.n_sample_frames} for sync_loss, it should be 16."
|
||||
)
|
||||
syncnet_config = OmegaConf.load(cfg.syncnet_config_path)
|
||||
syncnet = SyncNet(OmegaConf.to_container(
|
||||
syncnet_config.model)).to(accelerator.device)
|
||||
print(
|
||||
f"Load SyncNet checkpoint from: {syncnet_config.ckpt.inference_ckpt_path}")
|
||||
checkpoint = torch.load(
|
||||
syncnet_config.ckpt.inference_ckpt_path, map_location=accelerator.device)
|
||||
syncnet.load_state_dict(checkpoint["state_dict"])
|
||||
syncnet.to(dtype=weight_dtype)
|
||||
syncnet.requires_grad_(False)
|
||||
syncnet.eval()
|
||||
return syncnet
|
||||
return None
|
||||
|
||||
def initialize_vgg(cfg, accelerator):
|
||||
"""Initialize VGG model"""
|
||||
if cfg.loss_params.vgg_loss > 0:
|
||||
vgg_IN = vgg_face.Vgg19().to(accelerator.device,)
|
||||
pyramid = vgg_face.ImagePyramide(
|
||||
cfg.loss_params.pyramid_scale, 3).to(accelerator.device)
|
||||
vgg_IN.eval()
|
||||
downsampler = Interpolate(
|
||||
size=(224, 224), mode='bilinear', align_corners=False).to(accelerator.device)
|
||||
return vgg_IN, pyramid, downsampler
|
||||
return None, None, None
|
||||
|
||||
def validation(
|
||||
cfg,
|
||||
val_dataloader,
|
||||
net,
|
||||
vae,
|
||||
wav2vec,
|
||||
accelerator,
|
||||
save_dir,
|
||||
global_step,
|
||||
weight_dtype,
|
||||
syncnet_score=1,
|
||||
):
|
||||
"""Validation function for model evaluation"""
|
||||
net.eval() # Set the model to evaluation mode
|
||||
for batch in val_dataloader:
|
||||
# The same ref_latents
|
||||
ref_pixel_values = batch["pixel_values_ref_img"].to(weight_dtype).to(
|
||||
accelerator.device, non_blocking=True
|
||||
)
|
||||
pixel_values = batch["pixel_values_vid"].to(weight_dtype).to(
|
||||
accelerator.device, non_blocking=True
|
||||
)
|
||||
bsz, num_frames, c, h, w = ref_pixel_values.shape
|
||||
|
||||
audio_prompts = process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype)
|
||||
# audio feature for unet
|
||||
audio_prompts = rearrange(
|
||||
audio_prompts,
|
||||
'b f c h w-> (b f) c h w'
|
||||
)
|
||||
audio_prompts = rearrange(
|
||||
audio_prompts,
|
||||
'(b f) c h w -> (b f) (c h) w',
|
||||
b=bsz
|
||||
)
|
||||
# different masked_latents
|
||||
image_pred_train = get_image_pred(
|
||||
pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
||||
image_pred_infer = get_image_pred(
|
||||
ref_pixel_values, ref_pixel_values, audio_prompts, vae, net, weight_dtype)
|
||||
|
||||
process_and_save_images(
|
||||
batch,
|
||||
image_pred_train,
|
||||
image_pred_infer,
|
||||
save_dir,
|
||||
global_step,
|
||||
accelerator,
|
||||
cfg.num_images_to_keep,
|
||||
syncnet_score
|
||||
)
|
||||
# only infer 1 image in validation
|
||||
break
|
||||
net.train() # Set the model back to training mode
|
||||
319
models/MuseTalk/musetalk/utils/utils.py
Normal file
319
models/MuseTalk/musetalk/utils/utils.py
Normal file
@@ -0,0 +1,319 @@
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Union, List
|
||||
import torch.nn.functional as F
|
||||
from einops import rearrange
|
||||
import shutil
|
||||
import os.path as osp
|
||||
|
||||
from musetalk.models.vae import VAE
|
||||
from musetalk.models.unet import UNet,PositionalEncoding
|
||||
|
||||
|
||||
def load_all_model(
|
||||
unet_model_path=os.path.join("models", "musetalkV15", "unet.pth"),
|
||||
vae_type="sd-vae",
|
||||
unet_config=os.path.join("models", "musetalkV15", "musetalk.json"),
|
||||
device=None,
|
||||
):
|
||||
vae = VAE(
|
||||
model_path = os.path.join("models", vae_type),
|
||||
)
|
||||
print(f"load unet model from {unet_model_path}")
|
||||
unet = UNet(
|
||||
unet_config=unet_config,
|
||||
model_path=unet_model_path,
|
||||
device=device
|
||||
)
|
||||
pe = PositionalEncoding(d_model=384)
|
||||
return vae, unet, pe
|
||||
|
||||
def get_file_type(video_path):
|
||||
_, ext = os.path.splitext(video_path)
|
||||
|
||||
if ext.lower() in ['.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff']:
|
||||
return 'image'
|
||||
elif ext.lower() in ['.avi', '.mp4', '.mov', '.flv', '.mkv']:
|
||||
return 'video'
|
||||
else:
|
||||
return 'unsupported'
|
||||
|
||||
def get_video_fps(video_path):
|
||||
video = cv2.VideoCapture(video_path)
|
||||
fps = video.get(cv2.CAP_PROP_FPS)
|
||||
video.release()
|
||||
return fps
|
||||
|
||||
def datagen(
|
||||
whisper_chunks,
|
||||
vae_encode_latents,
|
||||
batch_size=8,
|
||||
delay_frame=0,
|
||||
device="cuda:0",
|
||||
):
|
||||
whisper_batch, latent_batch = [], []
|
||||
for i, w in enumerate(whisper_chunks):
|
||||
idx = (i+delay_frame)%len(vae_encode_latents)
|
||||
latent = vae_encode_latents[idx]
|
||||
whisper_batch.append(w)
|
||||
latent_batch.append(latent)
|
||||
|
||||
if len(latent_batch) >= batch_size:
|
||||
whisper_batch = torch.stack(whisper_batch)
|
||||
latent_batch = torch.cat(latent_batch, dim=0)
|
||||
yield whisper_batch, latent_batch
|
||||
whisper_batch, latent_batch = [], []
|
||||
|
||||
# the last batch may smaller than batch size
|
||||
if len(latent_batch) > 0:
|
||||
whisper_batch = torch.stack(whisper_batch)
|
||||
latent_batch = torch.cat(latent_batch, dim=0)
|
||||
|
||||
yield whisper_batch.to(device), latent_batch.to(device)
|
||||
|
||||
def cast_training_params(
|
||||
model: Union[torch.nn.Module, List[torch.nn.Module]],
|
||||
dtype=torch.float32,
|
||||
):
|
||||
if not isinstance(model, list):
|
||||
model = [model]
|
||||
for m in model:
|
||||
for param in m.parameters():
|
||||
# only upcast trainable parameters into fp32
|
||||
if param.requires_grad:
|
||||
param.data = param.to(dtype)
|
||||
|
||||
def rand_log_normal(
|
||||
shape,
|
||||
loc=0.,
|
||||
scale=1.,
|
||||
device='cpu',
|
||||
dtype=torch.float32,
|
||||
generator=None
|
||||
):
|
||||
"""Draws samples from an lognormal distribution."""
|
||||
rnd_normal = torch.randn(
|
||||
shape, device=device, dtype=dtype, generator=generator) # N(0, I)
|
||||
sigma = (rnd_normal * scale + loc).exp()
|
||||
return sigma
|
||||
|
||||
def get_mouth_region(frames, image_pred, pixel_values_face_mask):
|
||||
# Initialize lists to store the results for each image in the batch
|
||||
mouth_real_list = []
|
||||
mouth_generated_list = []
|
||||
|
||||
# Process each image in the batch
|
||||
for b in range(frames.shape[0]):
|
||||
# Find the non-zero area in the face mask
|
||||
non_zero_indices = torch.nonzero(pixel_values_face_mask[b])
|
||||
# If there are no non-zero indices, skip this image
|
||||
if non_zero_indices.numel() == 0:
|
||||
continue
|
||||
|
||||
min_y, max_y = torch.min(non_zero_indices[:, 1]), torch.max(
|
||||
non_zero_indices[:, 1])
|
||||
min_x, max_x = torch.min(non_zero_indices[:, 2]), torch.max(
|
||||
non_zero_indices[:, 2])
|
||||
|
||||
# Crop the frames and image_pred according to the non-zero area
|
||||
frames_cropped = frames[b, :, min_y:max_y, min_x:max_x]
|
||||
image_pred_cropped = image_pred[b, :, min_y:max_y, min_x:max_x]
|
||||
# Resize the cropped images to 256*256
|
||||
frames_resized = F.interpolate(frames_cropped.unsqueeze(
|
||||
0), size=(256, 256), mode='bilinear', align_corners=False)
|
||||
image_pred_resized = F.interpolate(image_pred_cropped.unsqueeze(
|
||||
0), size=(256, 256), mode='bilinear', align_corners=False)
|
||||
|
||||
# Append the resized images to the result lists
|
||||
mouth_real_list.append(frames_resized)
|
||||
mouth_generated_list.append(image_pred_resized)
|
||||
|
||||
# Convert the lists to tensors if they are not empty
|
||||
mouth_real = torch.cat(mouth_real_list, dim=0) if mouth_real_list else None
|
||||
mouth_generated = torch.cat(
|
||||
mouth_generated_list, dim=0) if mouth_generated_list else None
|
||||
|
||||
return mouth_real, mouth_generated
|
||||
|
||||
def get_image_pred(pixel_values,
|
||||
ref_pixel_values,
|
||||
audio_prompts,
|
||||
vae,
|
||||
net,
|
||||
weight_dtype):
|
||||
with torch.no_grad():
|
||||
bsz, num_frames, c, h, w = pixel_values.shape
|
||||
|
||||
masked_pixel_values = pixel_values.clone()
|
||||
masked_pixel_values[:, :, :, h//2:, :] = -1
|
||||
|
||||
masked_frames = rearrange(
|
||||
masked_pixel_values, 'b f c h w -> (b f) c h w')
|
||||
masked_latents = vae.encode(masked_frames).latent_dist.mode()
|
||||
masked_latents = masked_latents * vae.config.scaling_factor
|
||||
masked_latents = masked_latents.float()
|
||||
|
||||
ref_frames = rearrange(ref_pixel_values, 'b f c h w-> (b f) c h w')
|
||||
ref_latents = vae.encode(ref_frames).latent_dist.mode()
|
||||
ref_latents = ref_latents * vae.config.scaling_factor
|
||||
ref_latents = ref_latents.float()
|
||||
|
||||
input_latents = torch.cat([masked_latents, ref_latents], dim=1)
|
||||
input_latents = input_latents.to(weight_dtype)
|
||||
timesteps = torch.tensor([0], device=input_latents.device)
|
||||
latents_pred = net(
|
||||
input_latents,
|
||||
timesteps,
|
||||
audio_prompts,
|
||||
)
|
||||
latents_pred = (1 / vae.config.scaling_factor) * latents_pred
|
||||
image_pred = vae.decode(latents_pred).sample
|
||||
image_pred = image_pred.float()
|
||||
|
||||
return image_pred
|
||||
|
||||
def process_audio_features(cfg, batch, wav2vec, bsz, num_frames, weight_dtype):
|
||||
with torch.no_grad():
|
||||
audio_feature_length_per_frame = 2 * \
|
||||
(cfg.data.audio_padding_length_left +
|
||||
cfg.data.audio_padding_length_right + 1)
|
||||
audio_feats = batch['audio_feature'].to(weight_dtype)
|
||||
audio_feats = wav2vec.encoder(
|
||||
audio_feats, output_hidden_states=True).hidden_states
|
||||
audio_feats = torch.stack(audio_feats, dim=2).to(weight_dtype) # [B, T, 10, 5, 384]
|
||||
|
||||
start_ts = batch['audio_offset']
|
||||
step_ts = batch['audio_step']
|
||||
audio_feats = torch.cat([torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_left]),
|
||||
audio_feats,
|
||||
torch.zeros_like(audio_feats[:, :2*cfg.data.audio_padding_length_right])], 1)
|
||||
audio_prompts = []
|
||||
for bb in range(bsz):
|
||||
audio_feats_list = []
|
||||
for f in range(num_frames):
|
||||
cur_t = (start_ts[bb] + f * step_ts[bb]) * 2
|
||||
audio_clip = audio_feats[bb:bb+1,
|
||||
cur_t: cur_t+audio_feature_length_per_frame]
|
||||
|
||||
audio_feats_list.append(audio_clip)
|
||||
audio_feats_list = torch.stack(audio_feats_list, 1)
|
||||
audio_prompts.append(audio_feats_list)
|
||||
audio_prompts = torch.cat(audio_prompts) # B, T, 10, 5, 384
|
||||
return audio_prompts
|
||||
|
||||
def save_checkpoint(model, save_dir, ckpt_num, name="appearance_net", total_limit=None, logger=None):
|
||||
save_path = os.path.join(save_dir, f"{name}-{ckpt_num}.pth")
|
||||
|
||||
if total_limit is not None:
|
||||
checkpoints = os.listdir(save_dir)
|
||||
checkpoints = [d for d in checkpoints if d.endswith(".pth")]
|
||||
checkpoints = [d for d in checkpoints if name in d]
|
||||
checkpoints = sorted(
|
||||
checkpoints, key=lambda x: int(x.split("-")[1].split(".")[0])
|
||||
)
|
||||
|
||||
if len(checkpoints) >= total_limit:
|
||||
num_to_remove = len(checkpoints) - total_limit + 1
|
||||
removing_checkpoints = checkpoints[0:num_to_remove]
|
||||
logger.info(
|
||||
f"{len(checkpoints)} checkpoints already exist, removing {len(removing_checkpoints)} checkpoints"
|
||||
)
|
||||
logger.info(
|
||||
f"removing checkpoints: {', '.join(removing_checkpoints)}")
|
||||
|
||||
for removing_checkpoint in removing_checkpoints:
|
||||
removing_checkpoint = os.path.join(
|
||||
save_dir, removing_checkpoint)
|
||||
os.remove(removing_checkpoint)
|
||||
|
||||
state_dict = model.state_dict()
|
||||
torch.save(state_dict, save_path)
|
||||
|
||||
def save_models(accelerator, net, save_dir, global_step, cfg, logger=None):
|
||||
unwarp_net = accelerator.unwrap_model(net)
|
||||
save_checkpoint(
|
||||
unwarp_net.unet,
|
||||
save_dir,
|
||||
global_step,
|
||||
name="unet",
|
||||
total_limit=cfg.total_limit,
|
||||
logger=logger
|
||||
)
|
||||
|
||||
def delete_additional_ckpt(base_path, num_keep):
|
||||
dirs = []
|
||||
for d in os.listdir(base_path):
|
||||
if d.startswith("checkpoint-"):
|
||||
dirs.append(d)
|
||||
num_tot = len(dirs)
|
||||
if num_tot <= num_keep:
|
||||
return
|
||||
# ensure ckpt is sorted and delete the ealier!
|
||||
del_dirs = sorted(dirs, key=lambda x: int(x.split("-")[-1]))[: num_tot - num_keep]
|
||||
for d in del_dirs:
|
||||
path_to_dir = osp.join(base_path, d)
|
||||
if osp.exists(path_to_dir):
|
||||
shutil.rmtree(path_to_dir)
|
||||
|
||||
def seed_everything(seed):
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
|
||||
torch.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
np.random.seed(seed % (2**32))
|
||||
random.seed(seed)
|
||||
|
||||
def process_and_save_images(
|
||||
batch,
|
||||
image_pred,
|
||||
image_pred_infer,
|
||||
save_dir,
|
||||
global_step,
|
||||
accelerator,
|
||||
num_images_to_keep=10,
|
||||
syncnet_score=1
|
||||
):
|
||||
# Rearrange the tensors
|
||||
print("image_pred.shape: ", image_pred.shape)
|
||||
pixel_values_ref_img = rearrange(batch['pixel_values_ref_img'], "b f c h w -> (b f) c h w")
|
||||
pixel_values = rearrange(batch["pixel_values_vid"], 'b f c h w -> (b f) c h w')
|
||||
|
||||
# Create masked pixel values
|
||||
masked_pixel_values = batch["pixel_values_vid"].clone()
|
||||
_, _, _, h, _ = batch["pixel_values_vid"].shape
|
||||
masked_pixel_values[:, :, :, h//2:, :] = -1
|
||||
masked_pixel_values = rearrange(masked_pixel_values, 'b f c h w -> (b f) c h w')
|
||||
|
||||
# Keep only the specified number of images
|
||||
pixel_values = pixel_values[:num_images_to_keep, :, :, :]
|
||||
masked_pixel_values = masked_pixel_values[:num_images_to_keep, :, :, :]
|
||||
pixel_values_ref_img = pixel_values_ref_img[:num_images_to_keep, :, :, :]
|
||||
image_pred = image_pred.detach()[:num_images_to_keep, :, :, :]
|
||||
image_pred_infer = image_pred_infer.detach()[:num_images_to_keep, :, :, :]
|
||||
|
||||
# Concatenate images
|
||||
concat = torch.cat([
|
||||
masked_pixel_values * 0.5 + 0.5,
|
||||
pixel_values_ref_img * 0.5 + 0.5,
|
||||
image_pred * 0.5 + 0.5,
|
||||
pixel_values * 0.5 + 0.5,
|
||||
image_pred_infer * 0.5 + 0.5,
|
||||
], dim=2)
|
||||
print("concat.shape: ", concat.shape)
|
||||
|
||||
# Create the save directory if it doesn't exist
|
||||
os.makedirs(f'{save_dir}/samples/', exist_ok=True)
|
||||
|
||||
# Try to save the concatenated image
|
||||
try:
|
||||
# Concatenate images horizontally and convert to numpy array
|
||||
final_image = torch.cat([concat[i] for i in range(concat.shape[0])], dim=-1).permute(1, 2, 0).cpu().numpy()[:, :, [2, 1, 0]] * 255
|
||||
# Save the image
|
||||
cv2.imwrite(f'{save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg', final_image)
|
||||
print(f"Image saved successfully: {save_dir}/samples/sample_{global_step}_{accelerator.device}_SyncNetScore_{syncnet_score}.jpg")
|
||||
except Exception as e:
|
||||
print(f"Failed to save image: {e}")
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user