Compare commits


16 Commits

Author SHA1 Message Date
Kevin Wong
4a3dd2b225 Update 2026-01-28 17:22:31 +08:00
Kevin Wong
ee8cb9cfd2 Update 2026-01-27 16:52:40 +08:00
Kevin Wong
c6c4b2313f Update 2026-01-26 16:38:30 +08:00
Kevin Wong
f99bd336c9 Update 2026-01-26 12:18:54 +08:00
Kevin Wong
c918dc6faf Update 2026-01-23 18:09:12 +08:00
Kevin Wong
3a3df41904 UI improvements 2026-01-23 10:38:03 +08:00
Kevin Wong
561d74e16d Update 2026-01-23 10:07:35 +08:00
Kevin Wong
cfe21d8337 Update 2026-01-23 09:42:10 +08:00
Kevin Wong
3a76f9d0cf Update 2026-01-22 17:15:42 +08:00
Kevin Wong
ad7ff7a385 UI improvements 2026-01-22 11:14:42 +08:00
Kevin Wong
c7e2b4d363 Docs update 2026-01-22 09:54:32 +08:00
Kevin Wong
d5baa79448 Docs update 2026-01-22 09:52:29 +08:00
Kevin Wong
3db15cee4e Update 2026-01-22 09:22:23 +08:00
Kevin Wong
2543a270c1 Update docs 2026-01-21 10:40:07 +08:00
Kevin Wong
cbf840f472 Code cleanup 2026-01-21 10:30:32 +08:00
Kevin Wong
1890cea3ee Resolution fix 2026-01-20 17:33:57 +08:00
157 changed files with 132699 additions and 1179 deletions


@@ -27,12 +27,18 @@ node --version
# Check FFmpeg
ffmpeg -version
# Check pm2 (used for service management)
pm2 --version
```
If FFmpeg is missing:
If dependencies are missing:
```bash
sudo apt update
sudo apt install ffmpeg
# Install pm2
npm install -g pm2
```
---
@@ -48,28 +54,7 @@ cd /home/rongye/ProgramFiles/ViGent2
---
## 步骤 3: 安装后端依赖
```bash
cd /home/rongye/ProgramFiles/ViGent2/backend
# 创建虚拟环境
python3 -m venv venv
source venv/bin/activate
# 安装 PyTorch (CUDA 12.1)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# 安装其他依赖
pip install -r requirements.txt
# 安装 Playwright 浏览器 (社交发布用)
playwright install chromium
```
---
## 步骤 4: 部署 AI 模型 (LatentSync 1.6)
## 步骤 3: 部署 AI 模型 (LatentSync 1.6)
> ⚠️ **Important**: LatentSync requires its own Conda environment and **~18GB VRAM**. Do **not** install it into the backend environment.
@@ -83,25 +68,86 @@ playwright install chromium
4. 复制核心推理代码
5. 验证推理脚本
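Make sure LatentSync is deployed successfully before continuing with the following steps.
**Verify the LatentSync deployment**:
```bash
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
conda activate latentsync
python -m scripts.server # Check that the server starts; press Ctrl+C to exit
```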
---
## 步骤 7: 配置环境变量
## 步骤 4: 安装后端依赖
```bash
cd /home/rongye/ProgramFiles/ViGent2/backend
# 复制配置模板 (默认配置已经就绪)
# 创建虚拟环境
python3 -m venv venv
source venv/bin/activate
# 安装 PyTorch (CUDA 12.1)
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu121
# 安装 Python 依赖
pip install -r requirements.txt
# 安装 Playwright 浏览器(社交发布需要)
playwright install chromium
```
---
## 步骤 5: 部署用户认证系统 (Supabase + Auth)
> 🔐 **包含**: 登录/注册、Supabase 数据库配置、JWT 认证、管理员后台
请参考独立的认证系统部署指南:
**[用户认证系统部署指南](AUTH_DEPLOY.md)**
---
## 步骤 6: 配置 Supabase RLS 策略 (重要)
> ⚠️ **注意**:为了支持前端直传文件,必须配置存储桶的行级安全策略 (RLS)。
1. 确保 Supabase 容器正在运行 (`docker ps`).
2. 将项目根目录下的 `supabase_rls.sql` (如果有) 或以下 SQL 内容在数据库中执行。
3. **执行命令**:
```bash
# 进入后端目录
cd /home/rongye/ProgramFiles/ViGent2/backend
# 执行 SQL (允许 anon 角色上传/读取 materials 桶)
docker exec -i supabase-db psql -U postgres <<EOF
INSERT INTO storage.buckets (id, name, public) VALUES ('materials', 'materials', true) ON CONFLICT (id) DO NOTHING;
INSERT INTO storage.buckets (id, name, public) VALUES ('outputs', 'outputs', true) ON CONFLICT (id) DO NOTHING;
CREATE POLICY "Allow public uploads" ON storage.objects FOR INSERT TO anon WITH CHECK (bucket_id = 'materials');
CREATE POLICY "Allow public read" ON storage.objects FOR SELECT TO anon USING (bucket_id = 'materials' OR bucket_id = 'outputs');
EOF
```
---
## 步骤 7: 配置环境变量
```bash
cd /home/rongye/ProgramFiles/ViGent2/backend
# 复制配置模板
cp .env.example .env
```
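> 💡 **Note**: `.env.example` already contains the correct LatentSync default configuration; copying it is enough.
> 💡 **Note**: `.env.example` already contains the correct default configuration; copying it is enough.
> To customize, edit `.env` and change the following parameters:
| Setting | Default | Description |
|---------|---------|-------------|
| `SUPABASE_URL` | `http://localhost:8008` | Internal Supabase API address |
| `SUPABASE_PUBLIC_URL` | `https://api.hbyrkj.top` | Public Supabase API address (used by the frontend) |
| `LATENTSYNC_GPU_ID` | 1 | GPU selection (0 or 1) |
| `LATENTSYNC_USE_SERVER` | false | Set to true to enable the persistent server for faster inference |
| `LATENTSYNC_INFERENCE_STEPS` | 20 | Inference steps (20-50) |
| `LATENTSYNC_GUIDANCE_SCALE` | 1.5 | Guidance scale (1.0-3.0) |
| `DEBUG` | true | Set to false in production |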
@@ -115,13 +161,18 @@ cd /home/rongye/ProgramFiles/ViGent2/frontend
# 安装依赖
npm install
# 生产环境构建 (可选)
npm run build
```
---
## 步骤 9: 测试运行
### 启动后端
> 💡 先手动启动测试,确认一切正常后再配置 pm2 常驻服务。
### 启动后端 (终端 1)
```bash
cd /home/rongye/ProgramFiles/ViGent2/backend
@@ -129,16 +180,22 @@ source venv/bin/activate
uvicorn app.main:app --host 0.0.0.0 --port 8006
```
### 启动前端 (新开终端)
### 启动前端 (终端 2)
```bash
cd /home/rongye/ProgramFiles/ViGent2/frontend
npm run dev -- -H 0.0.0.0 --port 3002
```
---
### 启动 LatentSync (终端 3, 可选加速)
## 步骤 10: 验证
```bash
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
conda activate latentsync
python -m scripts.server
```
### 验证
1. 访问 http://服务器IP:3002 查看前端
2. 访问 http://服务器IP:8006/docs 查看 API 文档
@@ -146,59 +203,158 @@ npm run dev -- -H 0.0.0.0 --port 3002
---
## 使用 systemd 管理服务 (可选)
## 步骤 10: 使用 pm2 管理常驻服务
### 后端服务
> 推荐使用 pm2 管理所有服务,支持自动重启和日志管理。
创建 `/etc/systemd/system/vigent2-backend.service`:
```ini
[Unit]
Description=ViGent2 Backend API
After=network.target
### 1. 启动后端服务 (FastAPI)
[Service]
Type=simple
User=rongye
WorkingDirectory=/home/rongye/ProgramFiles/ViGent2/backend
Environment="PATH=/home/rongye/ProgramFiles/ViGent2/backend/venv/bin"
ExecStart=/home/rongye/ProgramFiles/ViGent2/backend/venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8006
Restart=always
建议使用 Shell 脚本启动以避免环境问题。
[Install]
WantedBy=multi-user.target
1. 创建启动脚本 `run_backend.sh`:
```bash
cat > run_backend.sh << 'EOF'
#!/bin/bash
cd /home/rongye/ProgramFiles/ViGent2/backend
./venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8006
EOF
chmod +x run_backend.sh
```
### 前端服务
创建 `/etc/systemd/system/vigent2-frontend.service`:
```ini
[Unit]
Description=ViGent2 Frontend
After=network.target
[Service]
Type=simple
User=rongye
WorkingDirectory=/home/rongye/ProgramFiles/ViGent2/frontend
ExecStart=/usr/bin/npm run start
Restart=always
[Install]
WantedBy=multi-user.target
2. 使用 pm2 启动:
```bash
pm2 start ./run_backend.sh --name vigent2-backend
```
### 启用服务
### 2. 启动前端服务 (Next.js)
⚠️ **注意**:生产模式启动前必须先进行构建。
```bash
sudo systemctl daemon-reload
sudo systemctl enable vigent2-backend vigent2-frontend
sudo systemctl start vigent2-backend vigent2-frontend
cd /home/rongye/ProgramFiles/ViGent2/frontend
# 1. 构建项目 (如果之前没跑过或代码有更新)
npm run build
# 2. 启动服务
pm2 start npm --name vigent2-frontend -- run start -- -p 3002
```
### 3. 启动 LatentSync 模型服务
1. 创建启动脚本 `run_latentsync.sh` (使用你的 conda python 路径):
```bash
cat > run_latentsync.sh << 'EOF'
#!/bin/bash
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
# 替换为你的实际 Python 路径
/home/rongye/ProgramFiles/miniconda3/envs/latentsync/bin/python -m scripts.server
EOF
chmod +x run_latentsync.sh
```
2. 使用 pm2 启动:
```bash
pm2 start ./run_latentsync.sh --name vigent2-latentsync
```
### 4. 保存当前列表 (开机自启)
```bash
pm2 save
pm2 startup
```
### pm2 常用命令
```bash
pm2 status # 查看所有服务状态
pm2 logs # 查看所有日志
pm2 logs vigent2-backend # 查看后端日志
pm2 restart all # 重启所有服务
pm2 stop vigent2-latentsync # 停止 LatentSync 服务
pm2 delete all # 删除所有服务
```
---
## 步骤 11: 配置 Nginx HTTPS (可选 - 公网访问)
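If you need HTTPS access through a public domain (e.g. `https://vigent.hbyrkj.top`), use the Nginx configuration below as a reference.
**Prerequisites**:
1. An SSL certificate has been issued (e.g. Let's Encrypt).
2. Local port 3002 is mapped to the server via FRP or a similar tool.
**Example configuration** (`/etc/nginx/conf.d/vigent.conf`):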
```nginx
server {
listen 80;
server_name your.domain.com;
return 301 https://$host$request_uri;
}
server {
listen 443 ssl http2;
server_name your.domain.com;
ssl_certificate /path/to/fullchain.pem;
ssl_certificate_key /path/to/privkey.pem;
location / {
proxy_pass http://127.0.0.1:3002; # 转发给 Next.js 前端
# 必须配置 WebSocket 支持,否则热更和即时通信失效
proxy_http_version 1.1;
proxy_set_header Upgrade $http_upgrade;
proxy_set_header Connection "upgrade";
proxy_set_header Host $host;
proxy_set_header X-Real-IP $remote_addr;
}
}
```
---
---
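## Step 12: Configure the Aliyun Nginx gateway (critical)
> ⚠️ **CRITICAL**: If a domain such as `api.hbyrkj.top` is used as the entry point, the upload size limit must be lifted in the Nginx configuration on Aliyun (or whichever host is the public entry point).
> **This is the core cause of the 500/413 errors.**
**Key settings**: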
```nginx
server {
listen 443 ssl;
server_name api.hbyrkj.top;
# ... 其他 SSL 配置 ...
# 允许大文件上传 (0 表示不限制,或设置为 100M, 500M)
client_max_body_size 0;
location / {
proxy_pass http://127.0.0.1:YOUR_FRP_PORT;
# 延长超时时间
proxy_read_timeout 600s;
proxy_send_timeout 600s;
}
}
```
**Consequence**: Without this setting, uploads are cut off at around 1MB or 10MB and fail with 413 Payload Too Large or 500/502 errors.
---
## 故障排除
### GPU 不可用
```bash
@@ -213,14 +369,58 @@ python3 -c "import torch; print(torch.cuda.is_available())"
# 查看端口占用
sudo lsof -i :8006
sudo lsof -i :3002
sudo lsof -i :8007
```
### 查看日志
```bash
# 后端日志
journalctl -u vigent2-backend -f
# 前端日志
journalctl -u vigent2-frontend -f
# pm2 日志
pm2 logs vigent2-backend
pm2 logs vigent2-frontend
pm2 logs vigent2-latentsync
```
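### SSH connection lag / slow system response
**Cause**: The LatentSync model service consumes a lot of I/O and CPU when it starts, or loading the model onto the GPU causes a momentary load spike.
**Fix**:
1. Check system load with `top` or `htop`.
2. If real-time video generation is not needed, stop the LatentSync service temporarily:
```bash
pm2 stop vigent2-latentsync
```
3. Make sure the server has enough RAM and swap space.
4. **Code-level mitigation**: `scripts/server.py` and `scripts/inference.py` now force `OMP_NUM_THREADS=8` so that PyTorch cannot occupy every CPU core and freeze the system.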
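The thread cap in item 4 can be sketched as follows. This is a minimal illustration, not the project's actual code: the helper name `limit_cpu_threads` is hypothetical, and it assumes the variable is set before `torch` (or any other OpenMP-backed library) is imported, since the OpenMP runtime reads it at initialization.

```python
import os

# Value taken from the troubleshooting note; adjust to the machine's core count.
THREAD_LIMIT = "8"


def limit_cpu_threads(limit: str = THREAD_LIMIT) -> None:
    """Cap the number of OpenMP worker threads for this process.

    Assumption: call this before importing torch so the limit is picked up.
    """
    os.environ["OMP_NUM_THREADS"] = limit


limit_cpu_threads()
# import torch  # would go here, after the limit is in place
```

Setting the variable in the script itself, rather than in the shell, means the cap also applies when the script is launched by pm2 or another process manager.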
---
## 依赖清单
### 后端关键依赖
| 依赖 | 用途 |
|------|------|
| `fastapi` | Web API 框架 |
| `uvicorn` | ASGI 服务器 |
| `edge-tts` | 微软 TTS 配音 |
| `playwright` | 社交媒体自动发布 |
| `biliup` | B站视频上传 |
| `loguru` | 日志管理 |
### 前端关键依赖
| 依赖 | 用途 |
|------|------|
| `next` | React 框架 |
| `swr` | 数据请求与缓存 |
| `tailwindcss` | CSS 样式 |
### LatentSync 关键依赖
| 依赖 | 用途 |
|------|------|
| `torch` 2.5.1 | PyTorch GPU 推理 |
| `diffusers` | Latent Diffusion 模型 |
| `accelerate` | 模型加速 |

122
Docs/DevLogs/Day10.md Normal file

@@ -0,0 +1,122 @@
---
## 🔧 隧道访问与视频播放修复 (11:00)
### 问题描述
在通过 FRP 隧道 (如 `http://8.148.x.x:3002`) 访问时发现:
1. **视频无法播放**:后端返回 404 (Not Found)。
2. **发布页账号列表为空**:后端返回 500 (Internal Server Error)。
### 解决方案
#### 1. 视频播放修复
- **后端 (`main.py`)**:这是根源问题。后端缺少 `uploads` 目录的挂载,导致静态资源无法访问。
```python
app.mount("/uploads", StaticFiles(directory=str(settings.UPLOAD_DIR)), name="uploads")
```
- **前端 (`next.config.ts`)**:添加反向代理规则,将 `/outputs` 和 `/uploads` 转发到后端端口 8006。
```typescript
{
source: '/uploads/:path*',
destination: 'http://localhost:8006/uploads/:path*',
},
{
source: '/outputs/:path*',
destination: 'http://localhost:8006/outputs/:path*',
}
```
#### 2. 账号列表 500 错误修复
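- **Root cause**: the whitelist in `backend/app/core/paths.py` was missing `weixin` and `kuaishou`.
- **Symptom**: when `PublishService` iterated over all platforms, any platform not in the whitelist raised `ValueError`, which crashed the whole endpoint.
- **Fix**: update the whitelist.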
```python
VALID_PLATFORMS: Set[str] = {"bilibili", "douyin", "xiaohongshu", "weixin", "kuaishou"}
```
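To make the failure mode concrete, here is a small sketch of a whitelist check of the kind described. Only `VALID_PLATFORMS` comes from the log; the function name `validate_platform` and the error message are assumptions for illustration and may differ from what `paths.py` actually does.

```python
from typing import Set

VALID_PLATFORMS: Set[str] = {"bilibili", "douyin", "xiaohongshu", "weixin", "kuaishou"}


def validate_platform(platform: str) -> str:
    """Return the platform name if it is whitelisted, otherwise raise ValueError.

    Hypothetical helper: a single missing entry in VALID_PLATFORMS makes every
    caller that loops over all platforms fail on that entry.
    """
    if platform not in VALID_PLATFORMS:
        raise ValueError(f"Unsupported platform: {platform}")
    return platform
```

A caller that iterates over every platform without catching `ValueError` will fail as a whole when one name is missing, which matches the 500 described above.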
### 结果
- ✅ 视频预览和历史视频均可正常播放。
- ✅ 发布页账号列表恢复显示。
---
## 🚀 Nginx HTTPS 部署 (11:30)
### 需求
用户在阿里云服务器上配置了 SSL 证书,需要通过 HTTPS 访问应用。
### 解决方案
提供了 Nginx 配置文件 `nginx_vigent.conf`,配置了:
1. **HTTP -> HTTPS 重定向**。
2. **SSL 证书路径** (`/etc/letsencrypt/live/vigent.hbyrkj.top/...`)。
3. **反向代理** 到本地 FRP 端口 (3002)。
4. **WebSocket 支持** (用于 Next.js 热更和通信)。
### 结果
- ✅ 用户可通过 `https://vigent.hbyrkj.top` 安全访问。
- ✅ 代码自适应:前端 `API_BASE` 为空字符串,自动适配 HTTPS 协议,无需修改代码。
---
## 🎨 UI 细节优化 (11:45)
### 修改
- 修改 `frontend/src/app/layout.tsx` 中的 Metadata。
- 标题从 `Create Next App` 改为 `ViGent`。
### 结果
- ✅ 浏览器标签页名称已更新。
---
## 🚪 用户登录退出功能 (12:00)
### 需求
用户反馈没有退出的入口。
### 解决方案
- **UI 修改**:在首页和发布管理页面的顶部导航栏添加红色的“退出”按钮 (位于最右侧)。
- **逻辑实现**
```javascript
onClick={async () => {
if (confirm('确定要退出登录吗?')) {
await fetch(`${API_BASE}/api/auth/logout`, { method: 'POST' });
window.location.href = '/login';
}
}}
```
- **部署**:已同步代码并重建前端。
---
## 🚢 Supabase 服务部署 (16:10)
### 需求
由于需要多用户隔离和更完善的权限管理,决定从纯本地文件存储迁移到 Supabase BaaS 架构。
### 实施步骤
1. **Docker 部署 (Ubuntu)**
- 使用官方 `docker-compose.yml`。
- **端口冲突解决**
- `Moodist` 占用 4000 -> 迁移 Analytics 到 **4004**。
- `code-server` 占用 8443 -> 迁移 Kong HTTPS 到 **8444**。
- 自定义端口Studio (**3003**), API (**8008**)。
2. **安全加固 (Aliyun Nginx)**
- **双域名策略**
- `supabase.hbyrkj.top` -> Studio (3003)
- `api.hbyrkj.top` -> API (8008)
- **SSL**:配置 Let's Encrypt 证书。
- **访问控制**:为 Studio 域名添加 `auth_basic` (htpasswd),防止未授权访问管理后台。
- **WebSocket**Nginx 配置 `Upgrade` 头支持 Realtime 功能。
3. **数据库初始化**
- 使用 `backend/database/schema.sql` 初始化了 `users`, `social_accounts` 等表结构。
### 下一步计划 (Storage Migration)
目前文件仍存储在本地磁盘,无法通过 RLS 进行隔离。
**计划改造 LatentSync 流程**
1. 后端集成 Supabase Storage SDK。
2. 实现 `Download (Storage) -> Local Process (LatentSync) -> Upload (Storage)` 闭环。
3. 前端改为请求 Signed URL 进行播放。

278
Docs/DevLogs/Day11.md Normal file

@@ -0,0 +1,278 @@
## 🔧 上传架构重构 (Direct Upload)
### 🚨 问题描述 (10:30)
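**Symptom**: Uploading files larger than 7MB made the backend return 500 Internal Server Error (the underlying error was `ClientDisconnect`).
**ROOT CAUSE**:
- **Aliyun Nginx gateway limit**: the Nginx configuration for the `api.hbyrkj.top` domain lacked `client_max_body_size 0;`.
- **Default limit**: Nginx caps the request body at 1MB by default (or a similarly small value), so the gateway forcibly cut the connection during large uploads.
- **Misdiagnosis**: the initial investigation focused on FRP and backend proxy timeouts, but the real cause was the hard limit at the gateway layer.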
### ✅ 解决方案:前端直传 Supabase + 网关配置 (14:00)
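**Core changes**:
1. **Gateway configuration**: add `client_max_body_size 0;` to the `api.hbyrkj.top` server block in the Aliyun Nginx configuration (removes the size limit).
2. **Architecture**: remove the backend file-forwarding logic and have the frontend upload directly to Supabase Storage (fewer hops).
#### 1. Frontend changes (`frontend/src/app/page.tsx`)
- Add the `@supabase/supabase-js` client.
- Upload directly with `supabase.storage.from('materials').upload()`.
- Remove the old `XMLHttpRequest` proxy upload logic.
- Add a file renaming scheme: `{timestamp}_{sanitized_filename}`.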
```typescript
// V2: Direct Upload (Bypass Backend)
const { data, error } = await supabase.storage
.from('materials')
.upload(path, file, {
cacheControl: '3600',
upsert: false
});
```
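#### 2. Backend adaptation (`backend/app/api/materials.py`)
- **Upload endpoint**: deprecated (kept for very small files); most traffic goes through direct upload.
- **List endpoint**: now returns a **signed URL** instead of a local path.
- **Compatibility**: the frontend receives the `path` field as a complete URL and does not need to join anything.
#### 3. Access control (RLS)
- Supabase forbids anonymous writes by default.
- A SQL policy grants the `anon` role `INSERT` and `SELECT` on the `materials` bucket.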
```sql
-- Allow anonymous uploads
CREATE POLICY "Allow public uploads"
ON storage.objects FOR INSERT
TO anon WITH CHECK (bucket_id = 'materials');
```
### 结果
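- ✅ **Timeouts resolved**: uploads no longer go through Nginx/FRP; they go directly to the Supabase CDN.
- ✅ **Size limit lifted**: uploads are no longer bound by the backend service's `client_max_body_size`.
- ✅ **Better user experience**: uploads are faster and the progress bar is more accurate.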
## 🔧 Supabase 部署与 RLS 配置
### 相关文件
- `supabase_rls.sql`: 定义存储桶权限的 SQL 脚本。
- `docker-compose.yml`: 确认 Storage 服务配置正常。
### Steps
1. Upload `supabase_rls.sql` to the server.
2. Run the SQL through Docker:
```bash
cat supabase_rls.sql | docker exec -i supabase-db psql -U postgres
```
3. Verify that frontend uploads succeed.
---
## 🔐 用户隔离实现 (15:00)
### 问题描述
不同账户登录后能看到其他用户上传的素材和生成的视频,缺乏数据隔离。
### 解决方案:存储路径前缀隔离
#### 1. 素材模块 (`backend/app/api/materials.py`)
```python
# Prefix the storage path with the user ID on upload
storage_path = f"{user_id}/{timestamp}_{safe_name}"
# List only the current user's directory
files_obj = await storage_service.list_files(
    bucket=storage_service.BUCKET_MATERIALS,
    path=user_id  # only files under the user's directory
)
# Verify ownership before deleting
if not material_id.startswith(f"{user_id}/"):
    raise HTTPException(403, "无权删除此素材")
```
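The prefix scheme above can be exercised in isolation. The sketch below is illustrative only: `sanitize_filename`, `build_storage_path` and `owns_object` are hypothetical names, and the sanitizing rule (replace anything outside `A-Z a-z 0-9 . _ -` with `_`) is an assumption, since the log does not show how `safe_name` is produced.

```python
import re
import time
from typing import Optional


def sanitize_filename(name: str) -> str:
    """Replace characters outside a conservative ASCII set with underscores."""
    return re.sub(r"[^A-Za-z0-9._-]", "_", name)


def build_storage_path(user_id: str, filename: str, timestamp: Optional[int] = None) -> str:
    """Build '{user_id}/{timestamp}_{safe_name}' as used for per-user isolation."""
    ts = int(time.time()) if timestamp is None else timestamp
    return f"{user_id}/{ts}_{sanitize_filename(filename)}"


def owns_object(user_id: str, object_path: str) -> bool:
    """True only if the object lives under the user's own directory.

    The trailing slash matters: without it, user 'u1' would match 'u10/...'.
    """
    return object_path.startswith(f"{user_id}/")
```

For example, `build_storage_path("u1", "my video.mp4", 1737000001)` yields `u1/1737000001_my_video.mp4`, and `owns_object("u1", "u10/file.mp4")` is false.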
#### 2. 视频模块 (`backend/app/api/videos.py`)
```python
# 生成视频时使用用户ID目录
storage_path = f"{user_id}/{task_id}_output.mp4"
# 列表/删除同样基于用户目录隔离
```
#### 3. 发布模块 (`backend/app/services/publish_service.py`)
- Cookie 存储支持用户隔离:`cookies/{user_id}/{platform}.json`
### 存储结构
```
Supabase Storage/
├── materials/
│ ├── {user_id_1}/
│ │ ├── 1737000001_video1.mp4
│ │ └── 1737000002_video2.mp4
│ └── {user_id_2}/
│ └── 1737000003_video3.mp4
└── outputs/
├── {user_id_1}/
│ └── {task_id}_output.mp4
└── {user_id_2}/
└── ...
```
### 结果
- ✅ 不同用户数据完全隔离
- ✅ Cookie 和登录状态按用户存储
- ✅ 删除操作验证所有权
---
## 🌐 Storage URL 修复 (16:00)
### 问题描述
生成的视频 URL 为 `http://localhost:8008/...`,前端无法访问。
### 解决方案
#### 1. 后端配置 (`backend/.env`)
```ini
SUPABASE_URL=http://localhost:8008 # 内部访问
SUPABASE_PUBLIC_URL=https://api.hbyrkj.top # 公网访问
```
#### 2. URL 转换 (`backend/app/services/storage.py`)
```python
def _convert_to_public_url(self, url: str) -> str:
    """Convert an internal URL into a publicly reachable one."""
    if settings.SUPABASE_PUBLIC_URL and settings.SUPABASE_URL:
        internal_url = settings.SUPABASE_URL.rstrip('/')
        public_url = settings.SUPABASE_PUBLIC_URL.rstrip('/')
        return url.replace(internal_url, public_url)
    return url
```
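A standalone variant of the same idea is sketched below so the behavior can be checked without the `settings` object. It is not the project's method: the function takes both base URLs as parameters (an assumption made for testability) and rewrites only a leading match rather than every occurrence, which is a slightly stricter choice than `str.replace`.

```python
def convert_to_public_url(url: str, internal_base: str, public_base: str) -> str:
    """Swap the internal Supabase base URL for the public one.

    Only a prefix match is rewritten, so an internal address that happens to
    appear inside a query string is left untouched.
    """
    if not internal_base or not public_base:
        return url
    internal = internal_base.rstrip("/")
    public = public_base.rstrip("/")
    if url.startswith(internal):
        return public + url[len(internal):]
    return url
```

With the values from `backend/.env`, a signed URL starting with `http://localhost:8008/storage/...` comes back starting with `https://api.hbyrkj.top/storage/...`, while URLs that are already public pass through unchanged.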
### 结果
- ✅ 前端获取的 URL 可正常访问
- ✅ 视频预览和下载功能正常
---
## ⚡ 发布服务优化 - 本地文件直读 (16:30)
### 问题描述
发布视频时需要先通过 HTTP 下载 Supabase Storage 文件到临时目录,效率低且浪费资源。
### 发现
Supabase Storage 文件实际存储在本地磁盘:
```
/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub/{bucket}/{path}/{internal_uuid}
```
### 解决方案
#### 1. 添加本地路径获取方法 (`storage.py`)
```python
from pathlib import Path
from typing import Optional

SUPABASE_STORAGE_LOCAL_PATH = Path("/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub")

def get_local_file_path(self, bucket: str, path: str) -> Optional[str]:
    """Return the on-disk path of a Storage file, or None if it is missing."""
    dir_path = SUPABASE_STORAGE_LOCAL_PATH / bucket / path
    if not dir_path.exists():
        return None
    files = list(dir_path.iterdir())
    return str(files[0]) if files else None
```
#### 2. 发布服务优先使用本地文件 (`publish_service.py`)
```python
# Parse the signed URL to get the bucket and path
match = re.search(r'/storage/v1/object/sign/([^/]+)/(.+?)\?', video_path)
if match:
    bucket, storage_path = match.group(1), match.group(2)
    local_video_path = storage_service.get_local_file_path(bucket, storage_path)
    if local_video_path and os.path.exists(local_video_path):
        logger.info(f"[发布] 直接使用本地文件: {local_video_path}")
    else:
        # Fallback: download over HTTP
        ...  # download logic elided in the original
```
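The URL parsing step can be tried on its own. The sketch below wraps the same regular expression in a hypothetical helper, `parse_signed_url`; the example URL and token in the usage note are made up for illustration.

```python
import re
from typing import Optional, Tuple

# Same pattern as in the snippet above: bucket name, then the object path
# up to the '?' that starts the signature query string.
SIGNED_URL_RE = re.compile(r"/storage/v1/object/sign/([^/]+)/(.+?)\?")


def parse_signed_url(url: str) -> Optional[Tuple[str, str]]:
    """Return (bucket, object_path) for a signed Storage URL, else None."""
    match = SIGNED_URL_RE.search(url)
    if match is None:
        return None
    return match.group(1), match.group(2)
```

For `https://api.hbyrkj.top/storage/v1/object/sign/outputs/u1/t1_output.mp4?token=abc` this returns `("outputs", "u1/t1_output.mp4")`. Because the pattern requires a `?`, a URL without a query string returns `None`, and the caller should then take the HTTP download fallback.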
### 结果
- ✅ 发布速度显著提升(跳过下载步骤)
- ✅ 减少临时文件占用
- ✅ 保留 HTTP 下载作为 Fallback
---
## 🔧 Supabase Studio 配置 (17:00)
### 修改内容
更新 `/home/rongye/ProgramFiles/Supabase/.env`
```ini
# 修改前
SUPABASE_PUBLIC_URL=http://localhost:8000
# 修改后
SUPABASE_PUBLIC_URL=https://api.hbyrkj.top
```
### 原因
通过 `supabase.hbyrkj.top` 公网访问 Studio 时,需要正确的 API 公网地址。
### 操作
```bash
docker compose restart studio
```
### 待解决
- 🔄 Studio Settings page fails to load (401 Unauthorized); possibly a conflict with the Nginx Basic Auth configuration
---
## 📁 今日修改文件清单
| 文件 | 变更类型 | 说明 |
|------|----------|------|
| `backend/app/api/materials.py` | 修改 | 添加用户隔离 |
| `backend/app/api/videos.py` | 修改 | 添加用户隔离 |
| `backend/app/services/storage.py` | 修改 | URL转换 + 本地路径获取 |
| `backend/app/services/publish_service.py` | 修改 | 本地文件直读优化 |
| `backend/.env` | 修改 | 添加 SUPABASE_PUBLIC_URL |
| `Supabase/.env` | 修改 | SUPABASE_PUBLIC_URL |
| `frontend/src/app/page.tsx` | 修改 | 改用后端API上传 |
---
## 📅 明日任务规划 (Day 12)
### 🎯 目标:部署 Qwen3-TTS 0.6B 声音克隆系统
**任务背景**
- 当前使用 EdgeTTS微软云端 TTS音色固定无法自定义
- Qwen3-TTS 支持**零样本声音克隆**,可用少量音频克隆任意人声
**核心任务**
1. **模型部署**
- 创建独立 Conda 环境 (`qwen-tts`)
- 下载 Qwen3-TTS 0.6B 模型权重
- 配置 GPU 推理环境
2. **后端集成**
- 新增 `qwen_tts_service.py` 服务
- 支持声音克隆:上传参考音频 → 生成克隆语音
- 兼容现有 `tts_service.py` 接口
3. **前端适配**
- 添加"声音克隆"选项
- 支持上传参考音频3-10秒
- 音色预览功能
**预期成果**
- ✅ 用户可上传自己的声音样本
- ✅ 生成的口播视频使用克隆后的声音
- ✅ 保留 EdgeTTS 作为备选方案
**参考资源**
- 模型:[Qwen/Qwen3-TTS-0.6B](https://huggingface.co/Qwen/Qwen3-TTS-0.6B)
- 显存需求:~4GB (0.6B 参数量)

347
Docs/DevLogs/Day12.md Normal file

@@ -0,0 +1,347 @@
# Day 12 - iOS 兼容与移动端 UI 优化
**日期**2026-01-28
---
## 🔐 Axios 全局拦截器优化
### 背景
统一处理 API 请求的认证失败场景,避免各页面重复处理 401/403 错误。
### 实现 (`frontend/src/lib/axios.ts`)
```typescript
import axios from 'axios';
// 动态获取 API 地址:服务端使用 localhost客户端使用当前域名
const API_BASE = typeof window === 'undefined'
? 'http://localhost:8006'
: '';
// 防止重复跳转
let isRedirecting = false;
const api = axios.create({
baseURL: API_BASE,
withCredentials: true, // 自动携带 HttpOnly cookie
headers: { 'Content-Type': 'application/json' },
});
// 响应拦截器 - 全局处理 401/403
api.interceptors.response.use(
(response) => response,
async (error) => {
const status = error.response?.status;
if ((status === 401 || status === 403) && !isRedirecting) {
isRedirecting = true;
// 调用 logout API 清除 HttpOnly cookie
try {
await fetch('/api/auth/logout', { method: 'POST' });
} catch (e) { /* 忽略 */ }
// 跳转登录页
if (typeof window !== 'undefined') {
window.location.replace('/login');
}
}
return Promise.reject(error);
}
);
export default api;
```
### 关键特性
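- ✅ **Cookies sent automatically**: `withCredentials: true` ensures the HttpOnly JWT cookie is sent
- ✅ **Automatic redirect on 401/403**: on authentication failure the session is cleared and the user is sent to the login page
- ✅ **No duplicate redirects**: the `isRedirecting` flag stops several failing requests from triggering the redirect at once
- ✅ **SSR compatible**: server-side rendering uses `localhost`, the client uses relative paths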
---
## 🔧 iOS Safari 安全区域白边修复
### 问题描述
iPhone Safari 浏览器底部和顶部显示白色区域,安卓正常。原因是 iOS Safari 有安全区域 (Safe Area),页面背景没有延伸到该区域。
### 根本原因
1. 缺少 `viewport-fit=cover` 配置
2. `min-h-screen` (100vh) 在 iOS Safari 中不包含安全区域
3. 背景渐变在页面 div 上,而非 body 上,导致安全区域显示纯色
### 解决方案
#### 1. 添加 viewport 配置 (`layout.tsx`)
```typescript
export const viewport: Viewport = {
width: 'device-width',
initialScale: 1,
viewportFit: 'cover', // 允许内容延伸到安全区域
themeColor: '#0f172a', // 顶部状态栏颜色
};
```
#### 2. 统一渐变背景到 body (`layout.tsx`)
```tsx
<html lang="en" style={{ backgroundColor: '#0f172a' }}>
<body
style={{
margin: 0,
minHeight: '100dvh',
background: 'linear-gradient(to bottom, #0f172a 0%, #0f172a 5%, #581c87 50%, #0f172a 95%, #0f172a 100%)',
}}
>
{children}
</body>
</html>
```
#### 3. CSS 安全区域支持 (`globals.css`)
```css
html {
background-color: #0f172a !important;
min-height: 100%;
}
body {
margin: 0 !important;
min-height: 100dvh;
padding-top: env(safe-area-inset-top);
padding-bottom: env(safe-area-inset-bottom);
}
```
#### 4. 移除页面独立渐变背景
各页面的根 div 移除 `bg-gradient-to-br` 类,统一使用 body 渐变:
- `page.tsx`
- `login/page.tsx`
- `publish/page.tsx`
- `admin/page.tsx`
- `register/page.tsx`
### 结果
- ✅ 顶部状态栏颜色与页面一致 (themeColor)
- ✅ 底部安全区域颜色与渐变边缘一致
- ✅ 消除分层感,背景统一
---
## 📱 移动端 Header 响应式优化
### 问题描述
移动端顶部导航按钮(视频生成、发布管理、退出)过于拥挤,文字换行显示。
### 解决方案
#### 首页 Header (`page.tsx`)
```tsx
<header className="border-b border-white/10 bg-black/20 backdrop-blur-sm">
  <div className="max-w-6xl mx-auto px-4 sm:px-6 py-3 sm:py-4 flex items-center justify-between">
    <Link href="/" className="text-xl sm:text-2xl font-bold ...">
      <span className="text-3xl sm:text-4xl">🎬</span>
      ViGent
    </Link>
    <div className="flex items-center gap-1 sm:gap-4">
      <span className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base ...">
      </span>
      {/* Other buttons get the same treatment */}
    </div>
  </div>
</header>
```
#### 发布管理页 Header (`publish/page.tsx`)
同步应用相同的响应式类名。
### 关键改动
| 属性 | 移动端 | 桌面端 |
|------|--------|--------|
| 容器内边距 | `px-4 py-3` | `px-6 py-4` |
| 按钮间距 | `gap-1` | `gap-4` |
| 按钮内边距 | `px-2 py-1` | `px-4 py-2` |
| 字体大小 | `text-sm` | `text-base` |
| Logo 大小 | `text-xl` + `text-3xl` | `text-2xl` + `text-4xl` |
### 结果
- ✅ 移动端按钮紧凑排列,不再换行
- ✅ 桌面端保持原有宽松布局
---
## 🚀 发布页面 UI 重构
### 问题描述
原有设计将"发布时间"选项放在表单中,用户可能误选"定时发布"但忘记设置时间。
### 解决方案
将"一键发布"按钮改为两个独立按钮:
- **立即发布** (绿色,占 3/4 宽度) - 主要操作
- **定时** (占 1/4 宽度) - 点击展开时间选择器
#### 新布局 (`publish/page.tsx`)
```tsx
{/* 发布按钮区域 */}
<div className="space-y-3">
<div className="flex gap-3">
{/* 立即发布 - 占 3/4 */}
<button
onClick={() => { setScheduleMode("now"); handlePublish(); }}
className="flex-[3] py-4 rounded-xl font-bold text-lg bg-gradient-to-r from-green-600 to-teal-600 ..."
>
🚀
</button>
{/* 定时发布 - 占 1/4 */}
<button
onClick={() => setScheduleMode(scheduleMode === "scheduled" ? "now" : "scheduled")}
className="flex-1 py-4 rounded-xl font-bold text-base ..."
>
</button>
</div>
{/* 定时发布时间选择器 (展开时显示) */}
{scheduleMode === "scheduled" && (
<div className="flex gap-3 items-center">
<input type="datetime-local" ... />
<button></button>
</div>
)}
</div>
```
### 结果
- ✅ 主操作(立即发布)更醒目
- ✅ 定时发布需要二次确认,防止误触
- ✅ 从表单区域移除发布时间选项,界面更简洁
---
## 🛤️ 后续优化项
### 后端定时发布 (待实现)
**当前状态**:定时发布使用平台端定时(在各平台设置发布时间)
**优化方向**:改为后端定时任务
- 使用 APScheduler 实现任务调度
- 存储定时任务到数据库
- 到时间后端自动触发发布 API
- 支持查看/取消定时任务
**优势**
- 统一逻辑,不依赖平台定时 UI
- 更灵活,可管理定时任务
- 平台页面更新不影响功能
---
## 🤖 Qwen3-TTS 0.6B 声音克隆部署
### 背景
为实现用户自定义声音克隆功能,部署 Qwen3-TTS 0.6B-Base 模型,支持 3 秒参考音频快速克隆。
### GPU 分配
| GPU | 服务 | 模型 |
|-----|------|------|
| GPU0 | Qwen3-TTS | 0.6B-Base (声音克隆) |
| GPU1 | LatentSync | 1.6 (唇形同步) |
### 部署步骤
#### 1. 克隆仓库
```bash
cd /home/rongye/ProgramFiles/ViGent2/models
git clone https://github.com/QwenLM/Qwen3-TTS.git
```
#### 2. 创建 conda 环境
```bash
conda create -n qwen-tts python=3.10 -y
conda activate qwen-tts
```
#### 3. 安装依赖
```bash
cd Qwen3-TTS
pip install -e .
conda install -y -c conda-forge sox # 音频处理依赖
```
#### 4. 下载模型权重 (使用 ModelScope国内更快)
```bash
pip install modelscope
# Tokenizer (651MB)
modelscope download --model Qwen/Qwen3-TTS-Tokenizer-12Hz --local_dir ./checkpoints/Tokenizer
# 0.6B-Base 模型 (2.4GB)
modelscope download --model Qwen/Qwen3-TTS-12Hz-0.6B-Base --local_dir ./checkpoints/0.6B-Base
```
#### 5. 测试推理
```python
# test_inference.py
import torch
import soundfile as sf
from qwen_tts import Qwen3TTSModel
model = Qwen3TTSModel.from_pretrained(
"./checkpoints/0.6B-Base",
device_map="cuda:0",
dtype=torch.bfloat16,
)
wavs, sr = model.generate_voice_clone(
text="测试文本",
language="Chinese",
ref_audio="./examples/myvoice.wav",
ref_text="参考音频的文字内容",
)
sf.write("output.wav", wavs[0], sr)
```
### 测试结果
- ✅ 模型加载成功 (GPU0)
- ✅ 声音克隆推理成功
- ✅ 输出音频 24000Hz质量良好
### 目录结构
```
models/Qwen3-TTS/
├── checkpoints/
│ ├── Tokenizer/ # 651MB
│ └── 0.6B-Base/ # 2.4GB
├── qwen_tts/ # 源码
├── examples/
│ └── myvoice.wav # 参考音频
└── test_inference.py # 测试脚本
```
---
## 📁 今日修改文件清单
| 文件 | 变更类型 | 说明 |
|------|----------|------|
| `frontend/src/lib/axios.ts` | 修改 | Axios 全局拦截器 (401/403 自动跳转) |
| `frontend/src/app/layout.tsx` | 修改 | viewport 配置 + body 渐变背景 |
| `frontend/src/app/globals.css` | 修改 | 安全区域 CSS 支持 |
| `frontend/src/app/page.tsx` | 修改 | 移除独立渐变 + Header 响应式 |
| `frontend/src/app/login/page.tsx` | 修改 | 移除独立渐变 |
| `frontend/src/app/publish/page.tsx` | 修改 | Header 响应式 + 发布按钮重构 |
| `frontend/src/app/admin/page.tsx` | 修改 | 移除独立渐变 |
| `frontend/src/app/register/page.tsx` | 修改 | 移除独立渐变 |
| `README.md` | 修改 | 添加 "iOS/Android 移动端适配" 功能说明 |
| `Docs/FRONTEND_DEV.md` | 修改 | iOS Safari 安全区域兼容规范 + 移动端响应式规则 |
| `models/Qwen3-TTS/` | 新增 | Qwen3-TTS 声音克隆模型部署 |
| `Docs/QWEN3_TTS_DEPLOY.md` | 新增 | Qwen3-TTS 部署指南 |
---
## 🔗 相关文档
- [task_complete.md](../task_complete.md) - 任务总览
- [Day11.md](./Day11.md) - 上传架构重构
- [QWEN3_TTS_DEPLOY.md](../QWEN3_TTS_DEPLOY.md) - Qwen3-TTS 部署指南


@@ -208,3 +208,36 @@ CUDA_VISIBLE_DEVICES=1 python -m scripts.inference \
- [LatentSync GitHub](https://github.com/bytedance/LatentSync)
- [HuggingFace 模型](https://huggingface.co/ByteDance/LatentSync-1.6)
- [论文](https://arxiv.org/abs/2412.09262)
---
## 🐛 修复:视频分辨率降低问题 (17:30)
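**Problem**: The generated video does not keep the resolution of the original video (the source was pre-compressed, so the output was 720p).
**Cause**: An earlier performance optimization force-compressed videos to 720p to speed up inference, so 1080p inputs came out downsampled.
**Fix**: The `_preprocess_video` call in `lipsync_service.py` is disabled and inference runs on the original video. `LatentSync` now outputs the same resolution as the input.
**Result**:
- ✅ Output videos keep their original resolution (1080p).
- ⚠️ Inference time increases accordingly (roughly 20-30% longer).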
---
## ⚡ 性能优化补全 (18:00)
### 1. 常驻模型服务 (Persistent Server)
**目标**: 消除每次生成视频时 30-40秒 的模型加载时间。
**实现**:
- 新增 `models/LatentSync/scripts/server.py` (FastAPI 服务)
- 自动加载后端 `.env` 配置
- 服务常驻显存,支持热调用
**效果**:
- 首次请求:正常加载 (~40s)
- 后续请求:**0s 加载**,直接推理
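### 2. GPU concurrency control (queue)
**Goal**: Prevent OOM (out of GPU memory) when several users send requests at once.
**Implementation**:
- Introduce `asyncio.Lock` in `lipsync_service.py`
- Build a global serial queue so that both remote and local calls are forced to wait their turn
**Effect**:
- Even if the frontend triggers several generations, the backend processes them one at a time, which keeps the system stable.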
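The serialization described above can be sketched with a plain `asyncio.Lock`. This is a toy model under stated assumptions: `run_inference` and the `asyncio.sleep` call stand in for the real LatentSync invocation, and the lock is created inside the running event loop rather than at module level, which avoids loop-binding issues on older Python versions.

```python
import asyncio
from typing import List


async def run_inference(job: str, lock: asyncio.Lock, log: List[str]) -> None:
    """Run one job while holding the lock so only one job uses the GPU."""
    async with lock:
        log.append(f"start {job}")
        await asyncio.sleep(0.01)  # placeholder for the GPU-bound work
        log.append(f"end {job}")


async def main() -> List[str]:
    lock = asyncio.Lock()
    log: List[str] = []
    # Three requests arrive at once; the lock forces them to run one by one.
    await asyncio.gather(*(run_inference(j, lock, log) for j in ("a", "b", "c")))
    return log


events = asyncio.run(main())
```

In `events`, every `start` entry is immediately followed by the matching `end` entry, which shows that no two jobs overlap even though all three were submitted concurrently.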

535
Docs/DevLogs/Day7.md Normal file

@@ -0,0 +1,535 @@
# Day 7: 社交媒体发布功能完善
**日期**: 2026-01-21
**目标**: 完成社交媒体发布模块 (80% → 100%)
---
## 📋 任务概览
| 任务 | 状态 |
|------|------|
| SuperIPAgent 架构分析 | ✅ 完成 |
| 优化技术方案制定 | ✅ 完成 |
| B站上传功能实现 | ⏳ 计划中 |
| 定时发布功能 | ⏳ 计划中 |
| 端到端测试 | ⏳ 待进行 |
---
## 🔍 架构优化分析
### SuperIPAgent social-auto-upload 优势
通过分析 `Temp\SuperIPAgent\social-auto-upload`,发现以下**更优设计**:
| 对比项 | 原方案 | 优化方案 ✅ |
|--------|--------|------------|
| **调度方式** | APScheduler (需额外依赖) | **平台 API 原生定时** |
| **B站上传** | Playwright 自动化 (不稳定) | **biliup 库 (官方)** |
| **架构** | 单文件服务 | **模块化 uploader/** |
| **Cookie** | 手动维护 | **自动扫码 + 持久化** |
### 核心优势
1. **更简单**: 无需 APScheduler,直接传时间给平台
2. **更稳定**: biliup 库比 Playwright 选择器可靠
3. **更易维护**: 每个平台独立 uploader 类
---
## 📝 技术方案变更
### 新增依赖
```bash
pip install "biliup>=0.4.0"  # quoted so the shell does not treat >= as a redirect
pip install playwright-stealth # optional, anti-detection
```
### 移除依赖
```diff
- apscheduler==3.10.4 # 不再需要
```
### 文件结构
```
backend/app/services/
├── publish_service.py # 简化,统一接口
+ ├── uploader/ # 新增: 平台上传器
+ │ ├── base_uploader.py # 基类
+ │ ├── bilibili_uploader.py # B站 (biliup)
+ │ └── douyin_uploader.py # 抖音 (Playwright)
```
---
## 🎯 关键代码模式
### 统一接口
```python
# publish_service.py
async def publish(video_path, platform, title, tags, publish_time=None):
    if platform == "bilibili":
        uploader = BilibiliUploader(...)
        result = await uploader.main()
        return result
```
### B站上传 (biliup 库)
```python
from biliup.plugins.bili_webup import BiliBili
with BiliBili(data) as bili:
    bili.login_by_cookies(cookie_data)
    video_part = bili.upload_file(video_path)
    ret = bili.submit()  # the platform handles scheduling
```
---
## 📅 开发计划
### 下午 (11:56 - 14:30)
- ✅ Add `biliup>=0.4.0` to `requirements.txt`
- ✅ 创建 `uploader/` 模块结构
- ✅ 实现 `base_uploader.py` 基类
- ✅ 实现 `bilibili_uploader.py` (biliup 库)
- ✅ 实现 `douyin_uploader.py` (Playwright)
- ✅ 实现 `xiaohongshu_uploader.py` (Playwright)
- ✅ 实现 `cookie_utils.py` (自动 Cookie 生成)
- ✅ 简化 `publish_service.py` (集成所有 uploader)
- ✅ 前端添加定时发布时间选择器
---
## 🎉 实施成果
### 后端改动
1. **新增文件**:
- `backend/app/services/uploader/__init__.py`
- `backend/app/services/uploader/base_uploader.py` (87行)
- `backend/app/services/uploader/bilibili_uploader.py` (135行) - biliup 库
- `backend/app/services/uploader/douyin_uploader.py` (173行) - Playwright
- `backend/app/services/uploader/xiaohongshu_uploader.py` (166行) - Playwright
- `backend/app/services/uploader/cookie_utils.py` (113行) - Cookie 自动生成
- `backend/app/services/uploader/stealth.min.js` - 反检测脚本
2. **修改文件**:
- `backend/requirements.txt`: 添加 `biliup>=0.4.0`
- `backend/app/services/publish_service.py`: 集成所有 uploader (170行)
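3. **Core features**:
- ✅ **Automatic cookie generation** (Playwright QR-code login)
- ✅ **Bilibili**: uses the `biliup` library (official, stable)
- ✅ **Douyin**: Playwright automation
- ✅ **Xiaohongshu**: Playwright automation
- ✅ Scheduled publishing (all platforms)
- ✅ stealth.js anti-detection (avoids being flagged as a bot)
- ✅ Modular architecture (easy to extend)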
### 前端改动
1. **修改文件**:
- `frontend/src/app/publish/page.tsx`: 添加定时发布 UI
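2. **New features**:
- ✅ Toggle between publish now and scheduled publish
- ✅ `datetime-local` time picker
- ✅ ISO-formatted time passed to the backend automatically
- ✅ One-click login button (opens the browser for QR scanning)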
---
## 🚀 部署步骤
### 1. 安装依赖
```bash
cd backend
pip install "biliup>=0.4.0"  # quoted so the shell does not treat >= as a redirect
# Or reinstall all dependencies
pip install -r requirements.txt
# Install the Playwright browser
playwright install chromium
```
### 2. 客户登录平台 (**极简3步**)
**操作流程**:
1. **拖拽书签**(仅首次)
- 点击前端"🔐 扫码登录"
- 将页面上的"保存登录"按钮拖到浏览器书签栏
2. **扫码登录**
- 点击"打开登录页"
- 扫码登录B站/抖音/小红书
3. **点击书签**
- 登录成功后,点击书签栏的"保存登录"书签
- 自动完成!
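**What the customer actually does**: drag once (first time only) + scan once + click the bookmark once = **3 steps total**
**On later logins**: scan + click the bookmark = **2 steps**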
### 3. 重启后端服务
```bash
cd backend
uvicorn app.main:app --host 0.0.0.0 --port 8006 --reload
```
---
## ✅ Day 7 完成总结
### 核心成果
1. **QR码自动登录** ⭐⭐⭐⭐⭐
- Playwright headless模式提取二维码
- 前端弹窗显示二维码
- 后端自动监控登录状态
- Cookie自动保存
2. **多平台上传器架构**
- B站: biliup官方库
- 抖音: Playwright自动化
- 小红书: Playwright自动化
- stealth.js反检测
3. **定时发布功能**
- 前端datetime-local时间选择
- 平台API原生调度
- 无需APScheduler
4. **用户体验优化**
- 首页添加发布入口
- 视频生成后直接发布按钮
- 一键扫码登录(仅扫码)
**后端** (13个):
- `backend/requirements.txt`
- `backend/app/main.py`
- `backend/app/services/publish_service.py`
- `backend/app/services/qr_login_service.py` (新建)
- `backend/app/services/uploader/__init__.py` (新建)
- `backend/app/services/uploader/base_uploader.py` (新建)
- `backend/app/services/uploader/bilibili_uploader.py` (新建)
- `backend/app/services/uploader/douyin_uploader.py` (新建)
- `backend/app/services/uploader/xiaohongshu_uploader.py` (新建)
- `backend/app/services/uploader/cookie_utils.py` (新建)
- `backend/app/services/uploader/stealth.min.js` (新建)
- `backend/app/api/publish.py`
- `backend/app/api/login_helper.py` (新建)
**前端** (2个):
- `frontend/src/app/page.tsx`
- `frontend/src/app/publish/page.tsx`
---
## 📝 TODO (Day 8优化项)
### 用户体验优化
- [ ] **文件名保留**: 上传视频后保留原始文件名
- [ ] **视频持久化**: 刷新页面后保留生成的视频
### 功能增强
- [ ] 抖音/小红书实际测试
- [ ] 批量发布功能
- [ ] 发布历史记录
---
## 📊 测试清单
- [ ] Playwright 浏览器安装成功
- [ ] B站 Cookie 自动生成测试
- [ ] 抖音 Cookie 自动生成测试
- [ ] 小红书 Cookie 自动生成测试
- [ ] 测试 B站立即发布功能
- [ ] 测试抖音立即发布功能
- [ ] 测试小红书立即发布功能
- [ ] 测试定时发布功能
---
## ⚠️ 注意事项
1. **B站 Cookie 获取**
- 参考 `social-auto-upload/examples/get_bilibili_cookie.py`
- 或手动登录后导出 JSON
2. **定时发布原理**
- 前端收集时间
- 后端传给平台 API
- **平台自行处理调度** (无需 APScheduler)
3. **biliup 优势**
- 官方 API 支持
- 社区活跃维护
- 比 Playwright 更稳定
---
## 🔗 相关文档
- [SuperIPAgent social-auto-upload](file:///d:/CodingProjects/Antigravity/Temp/SuperIPAgent/social-auto-upload)
- [优化实施计划](implementation_plan.md)
- [Task Checklist](task.md)
---
## 🎨 UI 一致性优化 (16:00 - 16:35)
**问题**:导航栏不一致、页面偏移
- 首页 Logo 无法点击,发布页可点击
- 发布页多余标题"📤 社交媒体发布"
- 首页因滚动条向左偏移 15px
**修复**
- `frontend/src/app/page.tsx` - Logo 改为 `<Link>` 组件
- `frontend/src/app/publish/page.tsx` - 删除页面标题和顶端 padding
- `frontend/src/app/globals.css` - 隐藏滚动条(保留滚动功能)
**状态**:✅ 两页面完全对齐
---
## 🔍 QR 登录问题诊断 (16:05)
**问题**:所有平台 QR 登录超时 `Page.wait_for_selector: Timeout 10000ms exceeded`
**原因**
1. Playwright headless 模式被检测
2. 缺少 stealth.js 反检测
3. CSS 选择器可能过时
**状态**:✅ 已修复
---
## 🔧 QR login fix (16:35 - 16:45)
### Implementation
#### 1. Enable stealth mode
```python
# Avoid headless detection
browser = await playwright.chromium.launch(
    headless=True,
    args=[
        '--disable-blink-features=AutomationControlled',
        '--no-sandbox',
        '--disable-dev-shm-usage'
    ]
)
```
#### 2. Configure realistic browser characteristics
```python
context = await browser.new_context(
    viewport={'width': 1920, 'height': 1080},
    user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) ...',
    locale='zh-CN',
    timezone_id='Asia/Shanghai'
)
```
#### 3. Inject the stealth.js script
```python
stealth_path = Path(__file__).parent / 'uploader' / 'stealth.min.js'
if stealth_path.exists():
    await page.add_init_script(path=str(stealth_path))
```
#### 4. Multi-selector fallback strategy
```python
"bilibili": {
    "qr_selectors": [
        ".qrcode-img img",
        "canvas.qrcode-img",
        "img[alt*='二维码']",
        ".login-scan-box img",
        "#qrcode-img"
    ]
}
# Douyin: 4 selectors, Xiaohongshu: 4 selectors
```
#### 5. Longer wait times
- Page load: 3s → 5s, plus `wait_until='networkidle'`
- Selector timeout: 10s → 30s
#### 6. Debugging aids
```python
# Save a debug screenshot to backend/debug_screenshots/
if not qr_element:
    screenshot_path = debug_dir / f"{platform}_debug.png"
    await page.screenshot(path=str(screenshot_path))
```
### Modified files
**Backend** (1 file):
- `backend/app/services/qr_login_service.py` - complete rework of the QR login logic
### Results
- ✅ Added anti-detection measures (stealth mode, realistic UA)
- ✅ Multi-selector fallback (4-5 selectors per platform)
- ✅ Tuned wait times (5s page load + 30s selector timeout)
- ✅ Debug screenshots are saved automatically
- 🔄 Pending verification on the server
---
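## 📋 Documentation rules update (16:42 - 17:10)
**Problem**: Doc_Rules needed revising to avoid accidental deletion of history, standardize tool usage, and prevent omissions in the task list
**Changes (final version)**:
1. **Criteria for deciding how to modify content**
   - Scenario 1: error correction → replace or delete directly
   - Scenario 2: improved approach → keep the original and append (V1/V2)
   - Scenario 3: several edits on the same day → merge into the final version
2. **Tool usage rules**
   - ✅ `replace_file_content` must be used
   - ❌ Command-line tools are prohibited (to avoid encoding errors)
3. **Completeness safeguards for task_complete** (new)
   - ✅ Added a "completeness checklist" (four sections, checked item by item)
   - ✅ Added a mnemonic: "Keep the dates at the top and bottom in sync, cover both tasks and planning, and don't forget the milestones"
4. **Structure cleanup**
   - Merged redundant sections
   - Removed the unrelated project components section
**Modified files**:
- `Docs/Doc_Rules.md` - final, complete version including the checklist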
---
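## ⚡ QR login performance and display optimization (17:30)
**Problems**:
1. **Slow**: each selector was awaited in sequence (30s timeout × N), so loading was extremely slow
2. **Garbled Bilibili display**: the fallback captured a full-page screenshot instead of the QR code region
**Optimizations**:
1. **Parallel waiting (performance)**:
   - Use a combined selector: `wait_for_selector("s1, s2, s3")`
   - Playwright returns as soon as any one of them appears (an immediate response instead of a plain sleep)
   - Timeout changed from 30s per selector to 15s in total
2. **Stronger selectors (accuracy)**:
   - The old selectors stopped matching after the Bilibili login page was redesigned
   - Added `div[class*='qrcode'] canvas` and `div[class*='qrcode'] img`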
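
The combined-selector idea can be sketched as a small helper. This is a minimal illustration, not the code in `qr_login_service.py`: the helper name `build_union_selector` and the selector list are assumptions made for the example. It relies only on the fact that a comma-separated CSS selector list matches when any member matches.

```python
def build_union_selector(selectors):
    """Join candidate CSS selectors into one comma-separated selector list.

    Hypothetical helper for illustration; blank entries are dropped and
    duplicates are removed while keeping the original order.
    """
    cleaned = [s.strip() for s in selectors if s and s.strip()]
    if not cleaned:
        raise ValueError("at least one selector is required")
    unique = list(dict.fromkeys(cleaned))  # order-preserving de-duplication
    return ", ".join(unique)


# Example selector list (assumed; taken from the selectors mentioned in this log).
BILIBILI_QR_SELECTORS = [
    "div[class*='qrcode'] canvas",
    "div[class*='qrcode'] img",
    ".qrcode-img img",
]

union = build_union_selector(BILIBILI_QR_SELECTORS)
# With Playwright, a single wait would then cover every candidate, roughly:
#   element = await page.wait_for_selector(union, timeout=15000)
```

One wait on the union replaces N sequential waits, so the worst case is a single timeout rather than the sum of all of them.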
**Modified files**:
- `backend/app/services/qr_login_service.py`
---
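## ⚡ QR login hardening (17:45)
**Problem**:
- Parallel waiting removed the sequential delay, but the **CSS selectors still failed to match** (Timeout 15000ms)
- Screenshots show that the QR code is visible, yet Playwright treats it as hidden or not found (possibly because of dynamic class names or DOM structure changes)
**Solution (three layers)**:
1. **Strategy 1**: combined CSS selector (timeout cut to 5s to fail fast)
2. **Strategy 2 (new)**: **text-anchor lookup**
   - No longer depends on brittle CSS class names
   - Searches the page directly for the text "扫码登录" ("scan to log in")
   - Looks for a `<canvas>` or `<img>` near that text
3. **Strategy 3 (debugging)**: **HTML source dump**
   - If everything fails, `bilibili_debug.html` is saved in addition to the screenshot
   - A last-resort tool for analyzing the page structure in full
**Modified files**:
- `backend/app/services/qr_login_service.py` (v3, final)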
---
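## ⚡ QR login final fix (17:55)
**Critical problems**:
1. **Monitor crash**: the backend used `async with async_playwright()`, so the browser closed automatically when the function returned. The background monitoring task (`_monitor_login_status`) then failed with `TargetClosedError` when it touched the closed page.
2. **Remaining delay**: despite the earlier improvements, the strategies still ran serially, so the 5s CSS timeout could not be avoided.
**Solution**:
1. **Lifecycle refactor (backend)**:
   - Removed the context manager; Playwright is now started manually via `self.playwright.start()`
   - The browser instance is kept on a class attribute (`self.browser`)
   - Resources are released manually (`_cleanup`) in a `finally` block, only after the monitoring task finishes or times out
2. **True parallel strategies**:
   - Uses `asyncio.wait(tasks, return_when=FIRST_COMPLETED)`
   - The CSS selector strategy and the text-anchor strategy **run at the same time**
   - Whichever finds the QR code first returns immediately, and the other task is cancelled
   - **Added latency drops to 0 seconds** (theoretical limit)
**Modified files**:
- `backend/app/services/qr_login_service.py` (v4, refactored)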
---
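## 🐛 Parallel logic bug fix (18:00)
**Symptoms**:
- Bilibili login worked, but **Douyin failed instantly** ("所有策略失败", i.e. "all strategies failed").
- Cause: the code used `asyncio.wait(FIRST_COMPLETED)`. When one strategy (for example the text strategy) does not apply to a platform, it returns `None` immediately.
- **Bug**: on receiving `None`, the code wrongly assumed the work was finished and cancelled the other strategy (the CSS strategy), which was still running.
**Fix**:
1. **Corrected parallel logic**:
   - If a task completes without a result (result is `None`), the other tasks are **not cancelled**.
   - The remaining `pending` tasks are awaited until one produces a result or all of them have finished.
2. **Extended text strategy**:
   - **Douyin** is now also on the list of platforms supported by text-anchor lookup.
   - Added the keywords `["扫码登录", "打开抖音", "抖音APP"]`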
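
The corrected waiting logic can be sketched with plain `asyncio`, independent of Playwright. This is a minimal sketch under stated assumptions: `first_non_none`, `text_strategy`, and `css_strategy` are hypothetical names for the example, and the strategies are stand-ins that only sleep. The key point is that a task finishing with `None` is skipped and the loop keeps waiting on the remaining tasks.

```python
import asyncio


async def first_non_none(coros):
    """Run coroutines concurrently and return the first non-None result.

    A task that finishes with None (or raises) is ignored and the others
    keep running. Leftover tasks are cancelled only once a real result has
    been found. Returns None if every strategy comes back empty.
    """
    pending = {asyncio.create_task(c) for c in coros}
    try:
        while pending:
            done, pending = await asyncio.wait(
                pending, return_when=asyncio.FIRST_COMPLETED
            )
            for task in done:
                if task.cancelled() or task.exception() is not None:
                    continue  # a failed strategy must not end the search
                result = task.result()
                if result is not None:
                    return result
        return None
    finally:
        # Cancel whatever is still running and wait for it to unwind.
        for task in pending:
            task.cancel()
        if pending:
            await asyncio.gather(*pending, return_exceptions=True)


async def _demo():
    async def text_strategy():
        return None  # not applicable on this platform: returns at once

    async def css_strategy():
        await asyncio.sleep(0.05)  # stand-in for a slower selector wait
        return "qr-element"

    return await first_non_none([text_strategy(), css_strategy()])


demo_result = asyncio.run(_demo())
```

In the demo the text strategy finishes first with `None`, yet the CSS strategy is still allowed to complete, which is the behaviour the fix above describes.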
**Modified files**:
- `backend/app/services/qr_login_service.py` (v5, corrected)
---
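## ⚡ Douyin text strategy optimization (18:10)
**Problem**:
- The Douyin page is also rendered dynamically, so the "扫码登录" text appears with a delay.
- The earlier `get_by_text(...).count()` check is instantaneous: if the text has not loaded yet, it returns 0 (failure).
- Result: while the CSS strategy was still waiting, the text strategy came back empty at once, so the QR code was still not found.
**Optimizations**:
1. **Smart waiting**: each keyword (e.g. "使用手机抖音扫码") now gets a `wait_for(timeout=2000)`, giving the page a little time to load.
2. **Wider search radius**: after the text is found, the search walks up **5 ancestor levels** (previously 3) to cope with Douyin's complex DOM.
3. **Size filter**: added a `width > 100` check to avoid matching avatars or small icons.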
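
The size filter can be illustrated on its own. A minimal sketch, assuming each candidate is described by a bounding-box dict with `width` and `height` keys (the shape Playwright's `bounding_box()` returns, or `None` when the element is not visible); `pick_qr_candidate` and `MIN_QR_WIDTH` are hypothetical names for the example.

```python
MIN_QR_WIDTH = 100  # pixels; mirrors the `width > 100` rule described above


def pick_qr_candidate(boxes):
    """Return the index of the first candidate wide enough to be a QR code.

    `boxes` is a list of bounding-box dicts (or None for invisible
    elements). Returns None when no candidate passes the filter.
    """
    for index, box in enumerate(boxes):
        if box is None:
            continue  # element is not laid out, so it cannot be the QR code
        if box.get("width", 0) > MIN_QR_WIDTH:
            return index
    return None


candidates = [
    {"x": 10, "y": 10, "width": 32, "height": 32},      # avatar-sized icon
    None,                                               # hidden element
    {"x": 200, "y": 120, "width": 180, "height": 180},  # plausible QR code
]
chosen = pick_qr_candidate(candidates)
```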
**Modified files**:
- `backend/app/services/qr_login_service.py` (v6, Douyin-enhanced)
**Status**: ✅ Douyin strategy strengthened
---
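## ✅ Verification results (18:15)
**User feedback**:
- Bilibili: cookie obtained successfully and the status shows "已登录" ("logged in").
- Douyin: cookie obtained successfully and the status shows "已登录" ("logged in").
- **Conclusions**:
  1. The parallel strategy (`asyncio.wait`) effectively removed the waiting delay.
  2. Text-anchor lookup (`get_by_text`) solved element lookup on dynamically rendered pages.
  3. The lifecycle refactor (`manual start/close`) fixed the background task crash.
**Next step**:
- Run a real video publishing test.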

113
Docs/DevLogs/Day8.md Normal file

@@ -0,0 +1,113 @@
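# Day 8: User experience improvements
**Date**: 2026-01-22
**Goal**: file name preservation + video persistence + UI improvements
---
## 📋 Task overview
| Task | Status |
|------|------|
| File name preservation | ✅ Done |
| Video persistence | ✅ Done |
| Video history list | ✅ Done |
| Delete support | ✅ Done |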
---
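## 🎉 Implementation results
### Backend changes
**Modified files**:
- `backend/app/api/materials.py`
  - ✅ `sanitize_filename()` makes file names safe
  - ✅ Timestamp prefix avoids collisions (`{timestamp}_{original file name}`)
  - ✅ `list_materials` shows the original file name
  - ✅ `DELETE /api/materials/{id}` deletes a material
- `backend/app/api/videos.py`
  - ✅ `GET /api/videos/generated` lists previously generated videos
  - ✅ `DELETE /api/videos/generated/{id}` deletes a video
### Frontend changes
**Modified files**:
- `frontend/src/app/page.tsx`
  - ✅ `GeneratedVideo` type definition
  - ✅ `generatedVideos` state management
  - ✅ `fetchGeneratedVideos()` fetches the history
  - ✅ `deleteMaterial()` / `deleteVideo()` delete actions
  - ✅ Delete button on material cards (shown on hover)
  - ✅ Video history list component (below the preview area on the right)
  - ✅ The history list refreshes automatically when generation completes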
---
## 🔧 API changes
### New endpoints
| Method | Path | Description |
|------|------|------|
| GET | `/api/videos/generated` | List generated videos |
| DELETE | `/api/videos/generated/{id}` | Delete a generated video |
| DELETE | `/api/materials/{id}` | Delete a material |
### File naming rules
```
Original: 测试视频.mp4
Saved:    1737518400_测试视频.mp4
Shown:    测试视频.mp4 (the frontend strips the timestamp prefix automatically)
```
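
The naming rule can be sketched in a few lines of Python. This is an illustration only: the real `sanitize_filename()` in `materials.py` is not shown in this log, so the body below, along with the helper names `stored_name` and `display_name`, is an assumption. In the project the prefix is stripped by the frontend; it is shown in Python here purely to make the round trip visible.

```python
import re
import time


def sanitize_filename(name):
    """Hypothetical sanitizer: keep the base name, replace unsafe characters."""
    name = name.replace("\\", "/").split("/")[-1]  # drop directory components
    name = re.sub(r'[<>:"|?*\x00-\x1f]', "_", name).strip(" .")
    return name or "unnamed"


def stored_name(original, now=None):
    """Prefix the sanitized name with a Unix timestamp to avoid collisions."""
    timestamp = int(time.time() if now is None else now)
    return f"{timestamp}_{sanitize_filename(original)}"


def display_name(stored):
    """Strip the leading `<digits>_` prefix to recover the name shown to users."""
    return re.sub(r"^\d+_", "", stored, count=1)


saved = stored_name("测试视频.mp4", now=1737518400)
shown = display_name(saved)
```

Only the first `<digits>_` prefix is removed, so an original name that itself starts with digits and an underscore is displayed unchanged.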
---
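## ✅ Summary
1. **File name preservation** - uploads keep the original name; a timestamp prefix avoids collisions
2. **Video persistence** - videos are read from the file system, so a refresh does not lose them
3. **History list** - previous videos are listed on the right; click one to switch playback
4. **Delete support** - both materials and videos can be deleted
---
## 📊 Test checklist
- [x] After uploading a video, the material list shows the original file name
- [x] After a page refresh, the video history list is still there
- [x] Deleting a material
- [x] Deleting a generated video
- [x] Clicking an item in the history list switches playback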
---
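## 🔧 Publishing fixes (Day 8, second half)
> The fixes below were made after the user experience improvements
### Problems
1. **False-positive Douyin QR login** - the frontend reported "登录成功" ("login successful") as soon as it found an old cookie file, even though the cookie might have expired
2. **Douyin upload loop hang** - the post-publish check was incomplete, and its `while True` loop had no timeout
3. **Non-idiomatic frontend polling** - login status was polled manually with `setInterval`, which is not React best practice
### Fixes
**Backend**:
- `publish_service.py` - added a `logout()` method; fixed `get_login_session_status()` to check the active session first
- `api/publish.py` - added the `POST /api/publish/logout/{platform}` endpoint
- `douyin_uploader.py` - added `import time`; fixed a race condition when clicking the publish button
**Frontend**:
- `publish/page.tsx` - replaced `setInterval` polling of the login status with `useSWR`
- `package.json` - added the `swr` dependency
### New API
| Method | Path | Description |
|------|------|------|
| POST | `/api/publish/logout/{platform}` | Log out of a platform |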

320
Docs/DevLogs/Day9.md Normal file

@@ -0,0 +1,320 @@
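# Day 9: Publishing module code cleanup
**Date**: 2026-01-23
**Goal**: code quality improvements + publishing verification
---
## 📋 Task overview
| Task | Status |
|------|------|
| Bilibili/Douyin publishing verification | ✅ Done |
| Guaranteed resource cleanup (try-finally) | ✅ Done |
| Timeout protection (no more infinite loops) | ✅ Done |
| Xiaohongshu headless mode fix | ✅ Done |
| API input validation | ✅ Done |
| Type hints completed | ✅ Done |
| Service layer cleanup | ✅ Done |
| QR login waiting screen | ✅ Done |
| Douyin login strategy optimization | ✅ Done |
| Review notice after successful publishing | ✅ Done |
| User authentication system planning | ✅ Plan complete |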
---
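## 🎉 Publishing verification results
### Login
- ✅ **Bilibili login succeeded** - matched by strategy 3 (Text); cookie saved
- ✅ **Douyin login succeeded** - matched by strategy 3 (Text); cookie saved
### Publishing
- ✅ **Douyin publish succeeded** - popup closed automatically, then redirected to the management page
- ✅ **Bilibili publish succeeded** - API returned `bvid: BV14izPBQEbd`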
---
## 🔧 Code improvements
### 1. Guaranteed resource cleanup
**Problem**: the Playwright browser might not be closed on error paths
**Fix**: a `try-finally` pattern ensures resources are released
```python
browser = None
context = None
try:
    browser = await playwright.chromium.launch(headless=True)
    context = await browser.new_context(...)
    # ... business logic ...
finally:
    if context:
        try: await context.close()
        except Exception: pass
    if browser:
        try: await browser.close()
        except Exception: pass
```
### 2. Timeout protection
**Problem**: a `while True` loop could leave the task hanging
**Fix**: added class-level timeout constants
```python
class DouyinUploader(BaseUploader):
    UPLOAD_TIMEOUT = 300  # video upload timeout
    PUBLISH_TIMEOUT = 180  # publish detection timeout
    PAGE_REDIRECT_TIMEOUT = 60  # page redirect timeout
```
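
How such a constant might bound a polling loop is sketched below. This is an assumption about usage, not the code in `douyin_uploader.py`: `wait_until` is a hypothetical helper, and the timeout is taken to be in seconds. It replaces an open-ended `while True` with a loop that stops at a monotonic-clock deadline.

```python
import asyncio
import time


async def wait_until(check, timeout, interval=0.5):
    """Poll the async `check()` until it is truthy or `timeout` seconds pass.

    Returns True if the condition was met and False on timeout, so the
    caller can report a failure instead of hanging forever.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        if await check():
            return True
        await asyncio.sleep(interval)
    return False


async def _demo():
    calls = {"n": 0}

    async def ready_on_third_call():
        calls["n"] += 1
        return calls["n"] >= 3

    async def never_ready():
        return False

    succeeded = await wait_until(ready_on_third_call, timeout=1, interval=0.01)
    timed_out = await wait_until(never_ready, timeout=0.05, interval=0.01)
    return succeeded, timed_out


demo_outcome = asyncio.run(_demo())
```

A caller could pass a class constant such as `PUBLISH_TIMEOUT` as `timeout` and treat a `False` return as a publish failure.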
### 3. Bilibili bvid extraction fix
**Problem**: the API returns the bvid inside the `data` field
**Fix**: check both locations
```python
bvid = ret.get('data', {}).get('bvid') or ret.get('bvid', '')
aid = ret.get('data', {}).get('aid') or ret.get('aid', '')
```
### 4. API input validation
**Fix**: every endpoint now validates the platform
```python
SUPPORTED_PLATFORMS = {"bilibili", "douyin", "xiaohongshu"}
if platform not in SUPPORTED_PLATFORMS:
    raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
```
---
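## 🎨 User experience improvements
### 1. QR login waiting screen
**Problem**: after login is clicked, fetching the QR code takes a few seconds and the user gets no feedback
**Improvements**:
- A loading dialog appears immediately after login is clicked
- Loading animation (spinner + "正在获取二维码...", i.e. "Fetching QR code...")
- The dialog switches to the QR code automatically once it has been fetched
### 2. Douyin login strategy optimization
**Problem**: Douyin login took about 23 seconds to fetch the QR code (strategies 1 and 2 timed out first)
**Root cause analysis**:
| Strategy | Douyin time | Bilibili time | Result |
|------|----------|---------|------|
| Role | 10s timeout | N/A | ❌ |
| CSS | 8s timeout | 8s timeout | ❌ |
| Text | ~1s | ~1s | ✅ |
**Improvement**:
```python
# Douyin/Bilibili: text strategy first
if self.platform in ("douyin", "bilibili"):
    qr_element = await self._try_text_strategy(page)  # first choice
    if not qr_element:
        await page.wait_for_selector(..., timeout=3000)  # CSS as fallback
else:
    # other platforms keep CSS first
    ...
```
**Results**:
- Douyin QR code fetch: ~23s → ~5s
- Bilibili QR code fetch: ~13s → ~5s
### 3. Review notice after successful publishing
**Problem**: after a successful publish, users did not know the video still had to pass review
**Improvements**:
- Backend message changed to "发布成功,待审核" ("Published, pending review")
- Frontend adds the hint "⏳ 审核一般需要几分钟,请耐心等待" ("Review usually takes a few minutes, please be patient")
- The publish result disappears automatically after 10 seconds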
---
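## 📁 Modified files
### Backend
| File | Changes |
|------|----------|
| `app/api/publish.py` | Input validation, platform constants, improved docs |
| `app/services/publish_service.py` | Type hints, per-platform `enabled` flag |
| `app/services/qr_login_service.py` | **Strategy order optimized**, shorter timeouts |
| `app/services/uploader/base_uploader.py` | Type hints |
| `app/services/uploader/bilibili_uploader.py` | **Publish message changed to "pending review"** |
| `app/services/uploader/douyin_uploader.py` | **Publish message changed to "pending review"** |
| `app/services/uploader/xiaohongshu_uploader.py` | **Publish message changed to "pending review"** |
### Frontend
| File | Changes |
|------|----------|
| `src/app/publish/page.tsx` | **Loading animation, review notice, auto-dismissed result** |
---
## ✅ Summary
1. **Publishing verified** - login and publishing both work on Bilibili and Douyin
2. **More robust code** - resource cleanup, timeout protection, exception handling
3. **More maintainable code** - complete type hints, configuration moved into constants
4. **Server compatibility** - Xiaohongshu headless mode fixed
5. **Better user experience** - loading states, strategy order, review notice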
---
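## 🔐 User authentication system plan
> Planning complete; implementation is scheduled for the next phase
### Technical approach
| Item | Approach |
|------|------|
| Auth framework | FastAPI + JWT (HttpOnly Cookie) |
| Database | Supabase (PostgreSQL + RLS) |
| Administrator | Preset in .env + created automatically on startup |
| License period | `expires_at` field with a configurable validity period |
| Single-device login | Newest login evicts the previous one + strict Session Token check |
| Account isolation | Normalized cookie path `user_data/{user_id}/` |
### Security hardening
1. **HttpOnly Cookie** - keeps the token out of reach of XSS
2. **Session Token check** - the JWT carries a session_token that is verified on every request
3. **Admin created at startup** - created automatically when the service starts
4. **RLS as the last line of defense** - Supabase row-level security policies
5. **Cookie path normalization** - UUID format validation + platform allowlist check
### Database tables
```sql
-- users (user accounts)
-- user_sessions (single-device login)
-- social_accounts (linked social accounts)
```
> See [implementation_plan.md](file:///C:/Users/danny/.gemini/antigravity/brain/06e7632c-12c6-4e80-b321-e1e642144560/implementation_plan.md) for the detailed design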
### Backend implementation progress
**Status**: ✅ Core modules complete
| File | Description | Status |
|------|------|------|
| `requirements.txt` | Added supabase, python-jose, passlib | ✅ |
| `app/core/config.py` | Added Supabase/JWT/admin settings | ✅ |
| `app/core/supabase.py` | Supabase client singleton | ✅ |
| `app/core/security.py` | JWT + passwords + HttpOnly Cookie | ✅ |
| `app/core/paths.py` | Cookie path normalization | ✅ |
| `app/core/deps.py` | Dependency injection (current user/admin) | ✅ |
| `app/api/auth.py` | Register/login/logout API | ✅ |
| `app/api/admin.py` | User management API | ✅ |
| `app/main.py` | Admin initialization on startup | ✅ |
| `database/schema.sql` | Supabase tables + RLS | ✅ |
### Frontend implementation progress
**Status**: ✅ Core pages complete
| File | Description | Status |
|------|------|------|
| `src/lib/auth.ts` | Authentication helpers | ✅ |
| `src/app/login/page.tsx` | Login page | ✅ |
| `src/app/register/page.tsx` | Registration page | ✅ |
| `src/app/admin/page.tsx` | Admin console | ✅ |
| `src/proxy.ts` | Route protection | ✅ |
### Account isolation integration
**Status**: ✅ Done
| File | Changes | Status |
|------|----------|------|
| `app/services/publish_service.py` | Rewritten to isolate cookies by user_id | ✅ |
| `app/api/publish.py` | Added the auth dependency; passes user_id through | ✅ |
**Cookie storage paths**:
- Logged-in users: `user_data/{user_id}/cookies/{platform}_cookies.json`
- Anonymous users: `app/cookies/{platform}_cookies.json` (kept for backward compatibility)
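
The path normalization described above (UUID format validation plus a platform allowlist) can be sketched as follows. This is a minimal illustration, not the contents of `app/core/paths.py`: the function name `cookie_path` and the `base` parameter are assumptions, while the platform set and the path layout come from this log.

```python
import uuid
from pathlib import Path

SUPPORTED_PLATFORMS = {"bilibili", "douyin", "xiaohongshu"}


def cookie_path(user_id, platform, base=Path("user_data")):
    """Build the per-user cookie path, rejecting anything that is not a
    UUID or an allowlisted platform (hypothetical helper)."""
    if platform not in SUPPORTED_PLATFORMS:
        raise ValueError(f"unsupported platform: {platform!r}")
    try:
        # uuid.UUID both validates the value and yields a canonical form.
        normalized = str(uuid.UUID(str(user_id)))
    except ValueError:
        raise ValueError("user_id must be a valid UUID") from None
    return base / normalized / "cookies" / f"{platform}_cookies.json"


example = cookie_path("06e7632c-12c6-4e80-b321-e1e642144560", "douyin")
```

Because the directory name is rebuilt from the parsed UUID rather than copied from the input, a value such as `../other_user` can never reach the file system.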
---
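## 🔐 User authentication system implementation (2026-01-23)
### Problem description
To support multi-user management and resource isolation, a complete user authentication system was needed to replace the earlier single-user mode. Requirements:
- Supabase as the database
- Registration, login, and logout
- Admin approval (is_active)
- Single-device login restriction
- Token stored in an HttpOnly Cookie
### Solution
#### 1. Database design (Supabase)
Three core tables were created:
- `users`: email, password hash, role, activation status
- `user_sessions`: Session Token, used for single-device login (newest login evicts the previous one)
- `social_accounts`: linked social account data (Bilibili/Douyin cookies)
#### 2. Backend implementation (FastAPI)
- **Dependency injection** (`deps.py`): `get_current_user` validates the token and session automatically
- **Security module** (`security.py`): JWT creation and verification, bcrypt password hashing
- **Routes** (`auth.py`):
  - `/register`: new accounts start in the `pending` state
  - `/login`: on success, issues a JWT and writes it to an HttpOnly Cookie
  - `/me`: returns the current user's information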
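
The "newest login evicts the previous one" rule can be illustrated with an in-memory stand-in for the `user_sessions` table. This is a sketch under stated assumptions: the class `SessionStore` and its methods are hypothetical, and the real implementation stores sessions in Supabase and embeds the session token in the JWT.

```python
import secrets


class SessionStore:
    """In-memory stand-in for `user_sessions`: one active token per user."""

    def __init__(self):
        self._active = {}

    def login(self, user_id):
        token = secrets.token_urlsafe(32)
        # Overwriting the entry is what evicts the earlier device.
        self._active[user_id] = token
        return token

    def is_valid(self, user_id, token):
        current = self._active.get(user_id)
        # Constant-time comparison of the presented and stored tokens.
        return current is not None and secrets.compare_digest(current, token)

    def logout(self, user_id):
        self._active.pop(user_id, None)


store = SessionStore()
first_device = store.login("user-1")
second_device = store.login("user-1")  # a later login on another device
```

A dependency like `get_current_user` would run the equivalent of `is_valid` on every request, so a request from the first device is rejected as soon as the second login happens.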
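#### 3. Deployment
- Uses the Supabase cloud free tier
- To keep the project from being paused after 7 days of inactivity, a GitHub Actions / crontab keep-alive job was configured
- A separate deployment document was created: `Docs/AUTH_DEPLOY.md`
### Results
- ✅ The complete JWT authentication flow is implemented
- ✅ Administrators can control whether a user is active
- ✅ Secure, transparent token refresh is implemented (Session Token)
- ✅ Sensitive settings (Supabase Key) are managed through environment variables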
---
## 🔗 Related documents
- [User authentication system implementation plan](file:///C:/Users/danny/.gemini/antigravity/brain/06e7632c-12c6-4e80-b321-e1e642144560/implementation_plan.md)
- [Code review report](file:///C:/Users/danny/.gemini/antigravity/brain/a28bb1a6-2929-4c55-b837-c989943844e1/walkthrough.md)
- [Deployment manual](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/DEPLOY_MANUAL.md)
---
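## 🛠️ Deployment debugging notes (2026-01-23)
### 1. Service startup method
- **Problem**: starting python/uvicorn directly with pm2 causes a `SyntaxError` (Node.js tries to interpret the Python code)
- **Solution**: wrap the startup command in a `.sh` script
### 2. Missing dependencies and compatibility
- **Problem 1**: `ImportError: email-validator is not installed` (a Pydantic dependency)
- **Fix**: added `email-validator>=2.1.0`
- **Problem 2**: `AttributeError: module 'bcrypt' has no attribute '__about__'` (Passlib compatibility)
- **Fix**: pinned `bcrypt==4.0.1`
### 3. Frontend production build
- **Problem**: `Error: Could not find a production build`
- **Solution**: `npm run build` must be run before starting the frontend
### 4. Performance tuning
- **Symptom**: noticeable lag over remote SSH connections
- **Investigation**: `vigent2-latentsync` consumes a large share of system resources while loading the model at startup
- **Recommendation**: in production, start the LatentSync service on demand, or make sure the server has enough I/O and bandwidth. SSH became responsive again after the service was stopped.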


@@ -10,8 +10,194 @@
|------|------|
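| **Default update** | Update only `DayN.md` |
| **On-demand update** | Update `task_complete.md` only when the user **explicitly asks** |
| **Incremental append** | Never overwrite or recreate. Use the replace/edit tools to insert new content. |
| **Smart modification** | Errors → replace; improvements → append (see the detailed rules below) |
| **Read before writing** | Check the file's current content before updating |
| **Same-day merging** | Merge multiple small same-day edits into the final version |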
---
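## 🧾 Global documentation update checklist
> **Whenever you commit a significant change, check whether the following files need to be updated as well:**
| Priority | File path | What to check |
| :---: | :--- | :--- |
| 🔥 **High** | `Docs/DevLogs/DayN.md` | **(Latest log)** detailed record of changes, fixes, code snippets |
| 🔥 **High** | `Docs/task_complete.md` | **(Task overview)** update `[x]` items, progress bars, timeline |
| ⚡ **Med** | `README.md` | **(Project home page)** features, tech stack, latest screenshots |
| ⚡ **Med** | `Docs/DEPLOY_MANUAL.md` | **(Deployment manual)** changes to environment variables, dependencies, startup commands |
| ⚡ **Med** | `Docs/FRONTEND_DEV.md` | **(Frontend conventions)** API wrapper, date formatting, new-page rules |
| 🧊 **Low** | `Docs/implementation_plan.md` | **(Implementation plan)** differences between the plan and what was actually built |
| 🧊 **Low** | `frontend/README.md` | **(Frontend docs)** new page routes, component usage, UI changes |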
---
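## 🔍 Criteria for modifying existing content
### Scenario 1: error correction → **replace/delete**
**Condition**: the earlier method or approach **does not work** or is **logically wrong**
**Action**:
- ✅ Replace it directly with the correct content
- ✅ Add a one-line correction note: `> **Correction (HH:MM)**: [cause of the error], updated`
- ❌ Do not keep the wrong method (it would mislead readers)
**Example**:
```markdown
## 🔧 XXX feature fix
~~Old method: increase the timeout (ineffective)~~
> **Correction (16:20)**: a longer timeout alone does not solve it; updated to stealth mode
### Solution
- Enable stealth mode...
```
### Scenario 2: improved approach → **keep + append**
**Condition**: the earlier method **works**, but a **better method** was found later
**Action**:
- ✅ Keep the original method (label the versions V1/V2)
- ✅ Append the new method
- ✅ Explain why it is an improvement
**Example**:
```markdown
## ⚡ Performance optimization
### V1: basic implementation (Day 5)
- Single-threaded processing ✅
### V2: performance optimization (Day 7)
- Multi-threaded concurrency
- 3x faster ⚡
```
### Scenario 3: multiple same-day edits → **merge**
**Condition**: several small changes to the same feature within one day
**Action**:
- ✅ Update directly to the final version
- ❌ Do not record every intermediate iteration
- ✅ You may note "after several rounds of optimization"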
---
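## 🔍 Pre-update checklist
> **Core principle**: search before appending, to avoid duplicates and omissions
### Required checks
**1. Skim the whole file** (using `view_file` or `grep_search`)
```markdown
# Check whether any of these exist:
- An older section on the same topic?
- Status markers waiting to be updated (🔄 Pending verification)?
- Unfinished TODO items?
```
**2. Decide the type of operation**
| Situation | Action |
|------|------|
| **Related old content exists and is wrong** | Replace (scenario 1) |
| **Related old content exists and can be improved** | Append V2 (scenario 2) |
| **A pending-verification status exists** | Update the status marker |
| **Entirely new, independent content** | Append at the end |
**3. Content that must be updated**
- ✅ **Status markers**: `🔄 Pending verification` → `✅ Fixed` / `❌ Failed`
- ✅ **Progress percentage**: update to the latest value
- ✅ **Modified files list**: add newly modified files
- ❌ **Prohibited**: creating duplicate section headings
### Example
**Wrong** (old content was not checked):
```markdown
## 🔧 QR login fix (15:00)
**Status**: 🔄 Pending verification
## 🔧 QR login fix (16:00) ❌ Duplicate!
**Status**: ✅ Fixed
```
**Correct**:
```markdown
## 🔧 QR login fix (15:00)
**Status**: ✅ Fixed ← update the original status in place
```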
---
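## Tool usage rules
> **Core principle**: use the right tool to avoid character encoding problems
### ✅ Recommended tool: replace_file_content
**When to use**:
- Appending a new section at the end of a file
- Modifying or replacing an existing section
- Updating status markers (🔄 → ✅)
- Correcting wrong content
**Advantages**:
- ✅ Handles character encoding automatically (Windows CRLF)
- ✅ Precise replacement that does not delete other content by accident
- ✅ Reports errors, which makes debugging easier
**Notes**:
```markdown
1. **Exact match required**: TargetContent must match the file exactly
2. **Mind the line endings**: the file uses \r\n; do not drop the \r
3. **Sensible range**: StartLine/EndLine should cover the target content
4. **Read before writing**: run view_file to confirm the content before editing
```
### ❌ Prohibited: command-line tools
**Prohibited uses**:
- ❌ Appending with `echo >>` (encoding problems)
- ❌ Editing documents directly with PowerShell (breaks formatting)
- ❌ Using sed/awk or similar command-line tools
**Reasons**:
- They easily corrupt UTF-8 encoding
- Windows CRLF and Unix LF get mixed up
- Changes are hard to trace and error-prone
**Sole exception**: simple global text replacement (such as bulk date updates), and only with the `-NoNewline` parameter
### 📝 Best-practice examples
**Appending a new section**:
```python
replace_file_content(
    TargetFile="path/to/DayN.md",
    TargetContent="## 🔗 Related documents\n\n...\n\n",  # content at the end of the file
    ReplacementContent="## 🔗 Related documents\n\n...\n\n---\n\n## 🆕 New section\nContent...",
    StartLine=280,
    EndLine=284
)
```
**Modifying existing content**:
```python
replace_file_content(
    TargetContent="**Status**: 🔄 Pending fix",
    ReplacementContent="**Status**: ✅ Fixed",
    StartLine=310,
    EndLine=310
)
```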
---
@@ -21,6 +207,9 @@
ViGent/Docs/
├── task_complete.md # Task overview (updated on demand only)
├── Doc_Rules.md # This file
├── FRONTEND_DEV.md # Frontend development conventions
├── DEPLOY_MANUAL.md # Deployment manual
├── SUPABASE_DEPLOY.md # Supabase deployment guide
└── DevLogs/
├── Day1.md # Development log
└── ...
@@ -30,10 +219,11 @@ ViGent/Docs/
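## 📅 Rules for updating DayN.md (daily updates)
### Deciding whether to create a new file
- Check the date of the latest `DayN.md`
- **Today** → append to the existing file
- **Earlier** → create `Day{N+1}.md`
### Deciding whether to create a new file (before the conversation starts)
1. **Review progress**: read `task_complete.md` to understand the current state
2. **Check the date**: look at the latest `DayN.md`
   - **Today (same as the current date)** → 🚨 **Never create a new file**; you must **append** to the end of the existing `DayN.md`, even for a completely different feature module.
   - **Earlier (yesterday or before)** → create `Day{N+1}.md`
### Append format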
```markdown
@@ -62,6 +252,24 @@ ViGent/Docs/
**Status**: ✅ Fixed / 🔄 Pending verification
```
---
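## 📏 Conciseness rules
### Code example length
- **Principle**: show only the key snippet (within 10-20 lines)
- **Very long code**: use `// ... omitted ...` or list only the file name and line numbers
- **Full code**: link to the file instead of pasting it in full
### Debug information
- **Temporary debugging output**: delete after verification (e.g. debug logs, test screenshots)
- **Valuable information**: keep (e.g. error logs, performance data)
### Status marker updates
- **🔄 Pending verification** → after verification, update to **✅ Fixed** or **❌ Failed**
- Edit the original status directly; do not append a new line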
---
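## 📝 Rules for updating task_complete.md (on demand only)
@@ -72,25 +280,29 @@ ViGent/Docs/
- **Consistent format**: follow the existing format of `task_complete.md` when appending content.
- **Progress updates**: update progress percentages only at stage milestones.
---
### 🔍 Completeness checklist (mandatory)
## 🚀 New conversation checklist
Every time you update `task_complete.md`, you must check **each** of the following sections:
1. Read `task_complete.md` → understand overall progress
2. Read the latest `DayN.md` → confirm which day today is
3. Decide by date whether to append or create a new Day file
1. **File header & navigation**
   - [ ] `更新时间` (last updated): must be today's date
   - [ ] `整体进度` (overall progress): briefly describe the current state
   - [ ] `快速导航` (quick navigation): the Day range matches the documents
2. **Core task area**
   - [ ] `已完成任务` (completed tasks): add the new [x] items
   - [ ] `后续规划` (next steps): maintain the three color-coded blocks (priority / debt / future)
3. **Statistics and review**
   - [ ] `进度统计` (progress statistics): update the status and percentage of the relevant modules
   - [ ] `里程碑` (milestones): if there was major progress, append `## Milestone N`
4. **Footer links**
   - [ ] `时间线` (timeline): append a summary of today
   - [ ] `相关文档` (related documents): update the DayLog link range
> **Mnemonic**: keep the dates at the top and bottom in sync, cover both tasks and planning, and don't forget the milestones.
---
## 🎯 Project components
| Component | Location |
|------|------|
| Backend (FastAPI) | `ViGent/backend/` |
| Frontend (Next.js) | `ViGent/frontend/` |
| AI model (MuseTalk) | `ViGent/models/` |
| Docs | `ViGent/Docs/` |
---
**Last updated**: 2026-01-13
**Last updated**: 2026-01-23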

182
Docs/FRONTEND_DEV.md Normal file

@@ -0,0 +1,182 @@
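# Frontend development conventions
## Directory structure
```
frontend/src/
├── app/              # Next.js App Router pages
│   ├── page.tsx      # Home page (video generation)
│   ├── publish/      # Publish page
│   ├── admin/        # Admin page
│   ├── login/        # Login page
│   └── register/     # Registration page
├── lib/              # Shared utilities
│   ├── axios.ts      # Axios instance (with 401/403 interceptors)
│   └── auth.ts       # Authentication helpers
└── proxy.ts          # Route proxy (formerly middleware)
```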
---
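## iOS Safari safe-area compatibility
### Problem
On iPhone Safari, the top of the screen (notch / Dynamic Island) and the bottom (Home indicator) are safe areas. By default the page background does not extend into them, which leaves white bands.
### Solution (three layers working together)
#### 1. Viewport configuration (`layout.tsx`)
```typescript
import type { Viewport } from "next";
export const viewport: Viewport = {
  width: 'device-width',
  initialScale: 1,
  viewportFit: 'cover', // let content extend into the safe areas
  themeColor: '#0f172a', // top status bar color (matches the background)
};
```
#### 2. Move the global background onto body (`layout.tsx`)
```tsx
<html lang="en" style={{ backgroundColor: '#0f172a' }}>
  <body
    style={{
      margin: 0,
      minHeight: '100dvh', // use dvh rather than vh
      background: 'linear-gradient(to bottom, #0f172a 0%, #0f172a 5%, #581c87 50%, #0f172a 95%, #0f172a 100%)',
    }}
  >
    {children}
  </body>
</html>
```
#### 3. CSS safe-area support (`globals.css`)
```css
html {
  background-color: #0f172a !important;
  min-height: 100%;
}
body {
  margin: 0 !important;
  min-height: 100dvh;
  padding-top: env(safe-area-inset-top);
  padding-bottom: env(safe-area-inset-bottom);
}
```
### Key points
- **Put the gradient background on body, not on the page div** - the safe areas lie outside the div
- **Use `100dvh` rather than `100vh`** - dvh is the dynamic viewport height and adapts to mobile browsers
- **Keep themeColor identical to the background edge color** - avoids a color mismatch with the status bar
- **Remove standalone backgrounds from page divs** - keep them transparent so they inherit the body gradient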
---
## Mobile responsive conventions
### Header button layout
```tsx
// Compact on mobile, roomier on desktop
<div className="flex items-center gap-1 sm:gap-4">
  <button className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base ...">
  </button>
</div>
```
### Common responsive breakpoints
| Breakpoint | Width | Use |
|------|------|------|
| default | < 640px | Mobile |
| `sm:` | ≥ 640px | Tablet/desktop |
| `lg:` | ≥ 1024px | Large desktop |
---
## API request conventions
### Always use `api` (the axios instance)
Every API request that requires authentication **must** use the axios instance exported from `@/lib/axios`. That instance is already configured to:
- Send cookies automatically (the equivalent of `credentials: include`)
- Clear the cookie and redirect to the login page on a 401/403 response
**Usage:**
```typescript
import api from '@/lib/axios';
// GET request
const { data } = await api.get('/api/materials');
// POST request
const { data } = await api.post('/api/videos/generate', {
  text: '...',
  voice: '...',
});
// DELETE request
await api.delete(`/api/materials/${id}`);
// File upload with progress reporting
await api.post('/api/materials', formData, {
  headers: { 'Content-Type': 'multipart/form-data' },
  onUploadProgress: (e) => {
    if (e.total) {
      const progress = Math.round((e.loaded / e.total) * 100);
      setProgress(progress);
    }
  },
});
```
### Using it with SWR
```typescript
import api from '@/lib/axios';
// SWR fetcher backed by axios
const fetcher = (url: string) => api.get(url).then(res => res.data);
const { data } = useSWR('/api/xxx', fetcher, { refreshInterval: 2000 });
```
---
## Date formatting conventions
### Do not use `toLocaleString()`
`toLocaleString()` may return different formats on the server and on the client, which causes hydration errors.
**Wrong:**
```typescript
// ❌ Causes a hydration error
new Date(timestamp * 1000).toLocaleString('zh-CN')
```
**Correct:**
```typescript
// ✅ Use a fixed format
const formatDate = (timestamp: number) => {
  const d = new Date(timestamp * 1000);
  const year = d.getFullYear();
  const month = String(d.getMonth() + 1).padStart(2, '0');
  const day = String(d.getDate()).padStart(2, '0');
  const hour = String(d.getHours()).padStart(2, '0');
  const minute = String(d.getMinutes()).padStart(2, '0');
  return `${year}/${month}/${day} ${hour}:${minute}`;
};
```
---
## New page checklist
1. [ ] Import the client with `import api from '@/lib/axios'`
2. [ ] Use `api.get/post/delete()` for every API request instead of native `fetch`
3. [ ] Format dates with a fixed-format function, not `toLocaleString()`
4. [ ] Add the `'use client'` directive (if the page needs client-side interaction)


@@ -1,46 +1,29 @@
(venv) rongye@r730-ubuntu:~/ProgramFiles/ViGent2/backend$ uvicorn app.main:app --host 0.0.0.0 --port 8006
INFO: Started server process [2398255]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://0.0.0.0:8006 (Press CTRL+C to quit)
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899244071 HTTP/1.1" 200 OK
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899248452 HTTP/1.1" 200 OK
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899250145 HTTP/1.1" 200 OK
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899250420 HTTP/1.1" 200 OK
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899250774 HTTP/1.1" 200 OK
INFO: 192.168.110.188:5826 - "GET /api/materials/?t=1768899251257 HTTP/1.1" 200 OK
INFO: 192.168.110.188:5826 - "OPTIONS /api/videos/generate HTTP/1.1" 200 OK
INFO: 192.168.110.188:5826 - "POST /api/videos/generate HTTP/1.1" 200 OK
2026-01-20 16:54:13.143 | INFO | app.services.tts_service:generate_audio:20 - TTS Generating: 大家好,欢迎来到我的频道,今天给大家分享... (zh-CN-YunxiNeural)
INFO: 192.168.110.188:5826 - "GET /api/videos/tasks/33c43a79-6e25-471f-873d-54d651d13474 HTTP/1.1" 200 OK
INFO: 192.168.110.188:5826 - "GET /api/videos/tasks/33c43a79-6e25-471f-873d-54d651d13474 HTTP/1.1" 200 OK
[Pipeline] TTS completed in 1.4s
2026-01-20 16:54:14.547 | INFO | app.services.lipsync_service:_check_weights:56 - ✅ LatentSync 权重文件已就绪
[LipSync] Health check: ready=True
[LipSync] Starting LatentSync inference...
2026-01-20 16:54:16.799 | INFO | app.services.lipsync_service:generate:172 - 🎬 唇形同步任务: 0bc1aa95-c567-4022-8d8b-cd3e439c78c0.mov + 33c43a79-6e25-471f-873d-54d651d13474_audio.mp3
2026-01-20 16:54:16.799 | INFO | app.services.lipsync_service:_local_generate:200 - 🔄 调用 LatentSync 推理 (subprocess)...
2026-01-20 16:54:17.004 | INFO | app.services.lipsync_service:_preprocess_video:111 - 📹 原始视频分辨率: 1920×1080
2026-01-20 16:54:17.005 | INFO | app.services.lipsync_service:_preprocess_video:128 - 📹 预处理视频: 1080p → 720p
2026-01-20 16:54:18.285 | INFO | app.services.lipsync_service:_preprocess_video:152 - ✅ 视频压缩完成: 14.9MB → 1.1MB
2026-01-20 16:54:18.285 | INFO | app.services.lipsync_service:_local_generate:237 - 🖥️ 执行命令: /home/rongye/ProgramFiles/miniconda3/envs/latentsync/bin/python -m scripts.inference --unet_config_path configs/unet/stage2_512.yaml --inference_ckpt_path checkpoints/latentsync_unet.pt --inference_steps...
2026-01-20 16:54:18.285 | INFO | app.services.lipsync_service:_local_generate:238 - 🖥️ GPU: CUDA_VISIBLE_DEVICES=1
2026-01-20 16:57:52.285 | INFO | app.services.lipsync_service:_local_generate:257 - LatentSync 输出:
: '0', 'arena_extend_strategy': 'kNextPowerOfTwo', 'use_ep_level_unified_stream': '0', 'device_id': '0', 'gpu_external_alloc': '0', 'sdpa_kernel': '0', 'cudnn_conv_algo_search': 'EXHAUSTIVE', 'gpu_external_free': '0', 'use_tf32': '1', 'cudnn_conv1d_pad_to_nc1d': '0', 'do_copy_in_default_stream': '1'}}
model ignore: checkpoints/auxiliary/models/buffalo_l/w600k_r50.onnx recognition
set det-size: (512, 512)
video in 25 FPS, audio idx in 50FPS
Affine transforming 135 faces...
Restoring 135 faces...
2026-01-20 16:57:52.287 | INFO | app.services.lipsync_service:_local_generate:262 - ✅ 唇形同步完成: /home/rongye/ProgramFiles/ViGent2/backend/outputs/33c43a79-6e25-471f-873d-54d651d13474_lipsync.mp4
[Pipeline] LipSync completed in 217.7s
2026-01-20 16:57:52.616 | DEBUG | app.services.video_service:_run_ffmpeg:17 - FFmpeg CMD: ffmpeg -y -i /home/rongye/ProgramFiles/ViGent2/backend/outputs/33c43a79-6e25-471f-873d-54d651d13474_lipsync.mp4 -i /home/rongye/ProgramFiles/ViGent2/backend/outputs/33c43a79-6e25-471f-873d-54d651d13474_audio.mp3 -c:v libx264 -c:a aac -shortest -map 0:v -map 1:a /home/rongye/ProgramFiles/ViGent2/backend/outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4
[Pipeline] Total generation time: 220.4s
INFO: 192.168.110.188:5826 - "GET /api/videos/tasks/33c43a79-6e25-471f-873d-54d651d13474 HTTP/1.1" 200 OK
INFO: 192.168.110.188:10104 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 206 Partial Content
INFO: 192.168.110.188:6759 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 206 Partial Content
INFO: 192.168.110.188:6759 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 304 Not Modified
INFO: 192.168.110.188:6759 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 206 Partial Content
INFO: 192.168.110.188:6759 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 206 Partial Content
INFO: 192.168.110.188:10233 - "GET /outputs/33c43a79-6e25-471f-873d-54d651d13474_output.mp4 HTTP/1.1" 304 Not Modified
rongye@r730-ubuntu:~/ProgramFiles/Supabase$ docker compose up -d
[+] up 136/136
✔ Image timberio/vector:0.28.1-alpine Pulled 63.3ss
✔ Image supabase/storage-api:v1.33.0 Pulled 78.6ss
✔ Image darthsim/imgproxy:v3.30.1 Pulled 151.9s
✔ Image supabase/postgres-meta:v0.95.1 Pulled 87.5ss
✔ Image supabase/logflare:1.27.0 Pulled 229.2s
✔ Image supabase/postgres:15.8.1.085 Pulled 268.3s
✔ Image supabase/supavisor:2.7.4 Pulled 101.6s
✔ Image supabase/realtime:v2.68.0 Pulled 56.5ss
✔ Image postgrest/postgrest:v14.1 Pulled 201.8s
✔ Image supabase/edge-runtime:v1.69.28 Pulled 254.0s
✔ Network supabase_default Created 0.1s
✔ Volume supabase_db-config Created 0.1s
✔ Container supabase-vector Healthy 16.9s
✔ Container supabase-imgproxy Created 7.4s
✔ Container supabase-db Healthy 20.6s
✔ Container supabase-analytics Created 0.4s
✔ Container supabase-edge-functions Created 1.8s
✔ Container supabase-auth Created 1.7s
✔ Container supabase-studio Created 2.0s
✔ Container realtime-dev.supabase-realtime Created 1.7s
✔ Container supabase-pooler Created 1.8s
✔ Container supabase-kong Created 1.7s
✔ Container supabase-meta Created 2.0s
✔ Container supabase-rest Created 0.9s
✔ Container supabase-storage Created 1.4s
Error response from daemon: failed to set up container networking: driver failed programming external connectivity on endpoint supabase-analytics (2fd60a510a1f16bf29f8f5140f14ef457a284c5b65a2567b7be250a4f9708f34): failed to bind host port 0.0.0.0:4000/tcp: address already in use
[ble: exit 1]


@@ -1,544 +0,0 @@
# MuseTalk
<strong>MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling</strong>
Yue Zhang<sup>\*</sup>,
Zhizhou Zhong<sup>\*</sup>,
Minhao Liu<sup>\*</sup>,
Zhaokang Chen,
Bin Wu<sup>†</sup>,
Yubin Zeng,
Chao Zhan,
Junxin Huang,
Yingjie He,
Wenjiang Zhou
(<sup>*</sup>Equal Contribution, <sup>†</sup>Corresponding Author, benbinwu@tencent.com)
Lyra Lab, Tencent Music Entertainment
**[github](https://github.com/TMElyralab/MuseTalk)** **[huggingface](https://huggingface.co/TMElyralab/MuseTalk)** **[space](https://huggingface.co/spaces/TMElyralab/MuseTalk)** **[Technical report](https://arxiv.org/abs/2410.10122)**
We introduce `MuseTalk`, a **real-time high quality** lip-syncing model (30fps+ on an NVIDIA Tesla V100). MuseTalk can be applied with input videos, e.g., generated by [MuseV](https://github.com/TMElyralab/MuseV), as a complete virtual human solution.
## 🔥 Updates
We're excited to unveil MuseTalk 1.5.
This version **(1)** integrates training with perceptual loss, GAN loss, and sync loss, significantly boosting its overall performance. **(2)** We've implemented a two-stage training strategy and a spatio-temporal data sampling approach to strike a balance between visual quality and lip-sync accuracy.
Learn more details [here](https://arxiv.org/abs/2410.10122).
**The inference codes, training codes and model weights of MuseTalk 1.5 are all available now!** 🚀
# Overview
`MuseTalk` is a real-time high quality audio-driven lip-syncing model trained in the latent space of `ft-mse-vae`, which
1. modifies an unseen face according to the input audio, with a size of face region of `256 x 256`.
1. supports audio in various languages, such as Chinese, English, and Japanese.
1. supports real-time inference with 30fps+ on an NVIDIA Tesla V100.
1. supports modification of the center point of the face region proposes, which **SIGNIFICANTLY** affects generation results.
1. checkpoint available trained on the HDTF and private dataset.
# News
- [04/05/2025] :mega: We are excited to announce that the training code is now open-sourced! You can now train your own MuseTalk model using our provided training scripts and configurations.
- [03/28/2025] We are thrilled to announce the release of our 1.5 version. This version is a significant improvement over the 1.0 version, with enhanced clarity, identity consistency, and precise lip-speech synchronization. We update the [technical report](https://arxiv.org/abs/2410.10122) with more details.
- [10/18/2024] We release the [technical report](https://arxiv.org/abs/2410.10122v2). Our report details a superior model to the open-source L1 loss version. It includes GAN and perceptual losses for improved clarity, and sync loss for enhanced performance.
- [04/17/2024] We release a pipeline that utilizes MuseTalk for real-time inference.
- [04/16/2024] Release Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk) on HuggingFace Spaces (thanks to HF team for their community grant)
- [04/02/2024] Release MuseTalk project and pretrained models.
## Model
![Model Structure](https://github.com/user-attachments/assets/02f4a214-1bdd-4326-983c-e70b478accba)
MuseTalk was trained in latent spaces, where the images were encoded by a freezed VAE. The audio was encoded by a freezed `whisper-tiny` model. The architecture of the generation network was borrowed from the UNet of the `stable-diffusion-v1-4`, where the audio embeddings were fused to the image embeddings by cross-attention.
Note that although we use a very similar architecture as Stable Diffusion, MuseTalk is distinct in that it is **NOT** a diffusion model. Instead, MuseTalk operates by inpainting in the latent space with a single step.
## Cases
<table>
<tr>
<td width="33%">
### Input Video
---
https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107
---
https://github.com/user-attachments/assets/1ce3e850-90ac-4a31-a45f-8dfa4f2960ac
---
https://github.com/user-attachments/assets/fa3b13a1-ae26-4d1d-899e-87435f8d22b3
---
https://github.com/user-attachments/assets/15800692-39d1-4f4c-99f2-aef044dc3251
---
https://github.com/user-attachments/assets/a843f9c9-136d-4ed4-9303-4a7269787a60
---
https://github.com/user-attachments/assets/6eb4e70e-9e19-48e9-85a9-bbfa589c5fcb
</td>
<td width="33%">
### MuseTalk 1.0
---
https://github.com/user-attachments/assets/c04f3cd5-9f77-40e9-aafd-61978380d0ef
---
https://github.com/user-attachments/assets/2051a388-1cef-4c1d-b2a2-3c1ceee5dc99
---
https://github.com/user-attachments/assets/b5f56f71-5cdc-4e2e-a519-454242000d32
---
https://github.com/user-attachments/assets/a5843835-04ab-4c31-989f-0995cfc22f34
---
https://github.com/user-attachments/assets/3dc7f1d7-8747-4733-bbdd-97874af0c028
---
https://github.com/user-attachments/assets/3c78064e-faad-4637-83ae-28452a22b09a
</td>
<td width="33%">
### MuseTalk 1.5
---
https://github.com/user-attachments/assets/999a6f5b-61dd-48e1-b902-bb3f9cbc7247
---
https://github.com/user-attachments/assets/d26a5c9a-003c-489d-a043-c9a331456e75
---
https://github.com/user-attachments/assets/471290d7-b157-4cf6-8a6d-7e899afa302c
---
https://github.com/user-attachments/assets/1ee77c4c-8c70-4add-b6db-583a12faa7dc
---
https://github.com/user-attachments/assets/370510ea-624c-43b7-bbb0-ab5333e0fcc4
---
https://github.com/user-attachments/assets/b011ece9-a332-4bc1-b8b7-ef6e383d7bde
</td>
</tr>
</table>
# TODO:
- [x] trained models and inference codes.
- [x] Huggingface Gradio [demo](https://huggingface.co/spaces/TMElyralab/MuseTalk).
- [x] codes for real-time inference.
- [x] [technical report](https://arxiv.org/abs/2410.10122v2).
- [x] a better model with updated [technical report](https://arxiv.org/abs/2410.10122).
- [x] realtime inference code for 1.5 version.
- [x] training and data preprocessing codes.
- [ ] **always** welcome to submit issues and PRs to improve this repository! 😊
# Getting Started
We provide a detailed tutorial about the installation and the basic usage of MuseTalk for new users:
## Third party integration
Thanks to the third-party integrations, which make installation and use more convenient for everyone.
Please note that we have not verified, maintained, or updated these third-party projects; refer to each project for specific results.
### [ComfyUI](https://github.com/chaojie/ComfyUI-MuseTalk)
## Installation
To prepare the Python environment and install additional packages such as opencv, diffusers, mmcv, etc., please follow the steps below:
### Build environment
We recommend Python 3.10 and CUDA 11.7. Set up your environment as follows:
```shell
conda create -n MuseTalk python==3.10
conda activate MuseTalk
```
### Install PyTorch 2.0.1
Choose one of the following installation methods:
```shell
# Option 1: Using pip
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 --index-url https://download.pytorch.org/whl/cu118
# Option 2: Using conda
conda install pytorch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2 pytorch-cuda=11.8 -c pytorch -c nvidia
```
### Install Dependencies
Install the remaining required packages:
```shell
pip install -r requirements.txt
```
### Install MMLab Packages
Install the MMLab ecosystem packages:
```bash
pip install --no-cache-dir -U openmim
mim install mmengine
mim install "mmcv==2.0.1"
mim install "mmdet==3.1.0"
mim install "mmpose==1.1.0"
```
### Setup FFmpeg
1. [Download](https://github.com/BtbN/FFmpeg-Builds/releases) the ffmpeg-static package
2. Configure FFmpeg based on your operating system:
For Linux:
```bash
export FFMPEG_PATH=/path/to/ffmpeg
# Example:
export FFMPEG_PATH=/musetalk/ffmpeg-4.4-amd64-static
```
For Windows:
Add the `ffmpeg-xxx\bin` directory to your system's PATH environment variable. Verify the installation by running `ffmpeg -version` in the command prompt - it should display the ffmpeg version information.
### Download weights
You can download weights in two ways:
#### Option 1: Using Download Scripts
We provide two scripts for automatic downloading:
For Linux:
```bash
sh ./download_weights.sh
```
For Windows:
```batch
REM Run the script
download_weights.bat
```
#### Option 2: Manual Download
You can also download the weights manually from the following links:
1. Download our trained [weights](https://huggingface.co/TMElyralab/MuseTalk/tree/main)
2. Download the weights of other components:
- [sd-vae-ft-mse](https://huggingface.co/stabilityai/sd-vae-ft-mse/tree/main)
- [whisper](https://huggingface.co/openai/whisper-tiny/tree/main)
- [dwpose](https://huggingface.co/yzd-v/DWPose/tree/main)
- [syncnet](https://huggingface.co/ByteDance/LatentSync/tree/main)
- [face-parse-bisent](https://drive.google.com/file/d/154JgKpzCPW82qINcVieuPH3fZ2e0P812/view?pli=1)
- [resnet18](https://download.pytorch.org/models/resnet18-5c106cde.pth)
Finally, these weights should be organized in `models` as follows:
```
./models/
├── musetalk
│ └── musetalk.json
│ └── pytorch_model.bin
├── musetalkV15
│ └── musetalk.json
│ └── unet.pth
├── syncnet
│ └── latentsync_syncnet.pt
├── dwpose
│ └── dw-ll_ucoco_384.pth
├── face-parse-bisent
│ ├── 79999_iter.pth
│ └── resnet18-5c106cde.pth
├── sd-vae
│ ├── config.json
│ └── diffusion_pytorch_model.bin
└── whisper
├── config.json
├── pytorch_model.bin
└── preprocessor_config.json
```
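Before running inference, it can help to confirm that the layout above is complete. The following is a minimal sketch and not part of MuseTalk: `EXPECTED_WEIGHTS` and `missing_weights` are hypothetical names, and the file list is simply copied from the tree above.

```python
from pathlib import Path

# Relative paths copied from the directory tree above.
EXPECTED_WEIGHTS = [
    "musetalk/musetalk.json",
    "musetalk/pytorch_model.bin",
    "musetalkV15/musetalk.json",
    "musetalkV15/unet.pth",
    "syncnet/latentsync_syncnet.pt",
    "dwpose/dw-ll_ucoco_384.pth",
    "face-parse-bisent/79999_iter.pth",
    "face-parse-bisent/resnet18-5c106cde.pth",
    "sd-vae/config.json",
    "sd-vae/diffusion_pytorch_model.bin",
    "whisper/config.json",
    "whisper/pytorch_model.bin",
    "whisper/preprocessor_config.json",
]


def missing_weights(models_dir="./models"):
    """Return the expected weight files that are absent under models_dir."""
    root = Path(models_dir)
    return [rel for rel in EXPECTED_WEIGHTS if not (root / rel).is_file()]


if __name__ == "__main__":
    missing = missing_weights()
    if missing:
        print("Missing weight files:")
        for rel in missing:
            print(f"  {rel}")
    else:
        print("All expected weight files are present.")
```

Run it from the repository root; an empty result means every file named in the tree was found.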
## Quickstart
### Inference
We provide inference scripts for both versions of MuseTalk:
#### Prerequisites
Before running inference, please ensure ffmpeg is installed and accessible:
```bash
# Check ffmpeg installation
ffmpeg -version
```
If ffmpeg is not found, please install it first:
- Windows: Download from [ffmpeg-static](https://github.com/BtbN/FFmpeg-Builds/releases) and add to PATH
- Linux: `sudo apt-get install ffmpeg`
#### Normal Inference
##### Linux Environment
```bash
# MuseTalk 1.5 (Recommended)
sh inference.sh v1.5 normal
# MuseTalk 1.0
sh inference.sh v1.0 normal
```
##### Windows Environment
Please ensure that you set the `ffmpeg_path` to match the actual location of your FFmpeg installation.
```bash
# MuseTalk 1.5 (Recommended)
python -m scripts.inference --inference_config configs\inference\test.yaml --result_dir results\test --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
# For MuseTalk 1.0, change:
# - models\musetalkV15 -> models\musetalk
# - unet.pth -> pytorch_model.bin
# - --version v15 -> --version v1
```
#### Real-time Inference
##### Linux Environment
```bash
# MuseTalk 1.5 (Recommended)
sh inference.sh v1.5 realtime
# MuseTalk 1.0
sh inference.sh v1.0 realtime
```
##### Windows Environment
```bash
# MuseTalk 1.5 (Recommended)
python -m scripts.realtime_inference --inference_config configs\inference\realtime.yaml --result_dir results\realtime --unet_model_path models\musetalkV15\unet.pth --unet_config models\musetalkV15\musetalk.json --version v15 --fps 25 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
# For MuseTalk 1.0, change:
# - models\musetalkV15 -> models\musetalk
# - unet.pth -> pytorch_model.bin
# - --version v15 -> --version v1
```
The configuration file `configs/inference/test.yaml` contains the inference settings, including:
- `video_path`: Path to the input video, image file, or directory of images
- `audio_path`: Path to the input audio file
Note: For optimal results, we recommend using input videos with 25fps, which is the same fps used during model training. If your video has a lower frame rate, you can use frame interpolation or convert it to 25fps using ffmpeg.
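As a rough illustration of the ffmpeg route mentioned above, the sketch below builds a command that resamples a video to 25fps with ffmpeg's `fps` filter. The helper names `build_fps_command` and `convert_to_25fps` are hypothetical, and copying the audio stream unchanged is an assumption; adjust the options to your own pipeline.

```python
import subprocess


def build_fps_command(src, dst, fps=25):
    """Build an ffmpeg command that resamples src to a constant frame rate."""
    return [
        "ffmpeg",
        "-y",                 # overwrite dst if it already exists
        "-i", src,
        "-vf", f"fps={fps}",  # fps filter duplicates or drops frames to reach the target rate
        "-c:a", "copy",       # leave the audio stream untouched
        dst,
    ]


def convert_to_25fps(src, dst):
    """Run the conversion; raises CalledProcessError if ffmpeg fails."""
    subprocess.run(build_fps_command(src, dst), check=True)
```

Note that the `fps` filter only duplicates or drops frames; for a low-frame-rate source, a dedicated frame interpolation tool may give smoother motion.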
Important notes for real-time inference:
1. Set `preparation` to `True` when processing a new avatar
2. After preparation, the avatar will generate videos using audio clips from `audio_clips`
3. The generation process can achieve 30fps+ on an NVIDIA Tesla V100
4. Set `preparation` to `False` for generating more videos with the same avatar
For faster generation without saving images, you can use:
```bash
python -m scripts.realtime_inference --inference_config configs/inference/realtime.yaml --skip_save_images
```
## Gradio Demo
We provide an intuitive web interface through Gradio for users to easily adjust input parameters. To optimize inference time, users can generate only the **first frame** to fine-tune the best lip-sync parameters, which helps reduce facial artifacts in the final output.
![para](assets/figs/gradio_2.png)
For minimum hardware requirements, we tested the system on a Windows environment using an NVIDIA GeForce RTX 3050 Ti Laptop GPU with 4GB VRAM. In fp16 mode, generating an 8-second video takes approximately 5 minutes. ![speed](assets/figs/gradio.png)
Both Linux and Windows users can launch the demo using the following command. Please ensure that the `ffmpeg_path` parameter matches your actual FFmpeg installation path:
```bash
# You can remove --use_float16 for better quality, but it will increase VRAM usage and inference time
python app.py --use_float16 --ffmpeg_path ffmpeg-master-latest-win64-gpl-shared\bin
```
## Training
### Data Preparation
To train MuseTalk, you need to prepare your dataset following these steps:
1. **Place your source videos**
For example, if you're using the HDTF dataset, place all your video files in `./dataset/HDTF/source`.
2. **Run the preprocessing script**
```bash
python -m scripts.preprocess --config ./configs/training/preprocess.yaml
```
This script will:
- Extract frames from videos
- Detect and align faces
- Generate audio features
- Create the necessary data structure for training
### Training Process
After data preprocessing, you can start the training process:
1. **First Stage**
```bash
sh train.sh stage1
```
2. **Second Stage**
```bash
sh train.sh stage2
```
### Configuration Adjustment
Before starting the training, you should adjust the configuration files according to your hardware and requirements:
1. **GPU Configuration** (`configs/training/gpu.yaml`):
- `gpu_ids`: Specify the GPU IDs you want to use (e.g., "0,1,2,3")
- `num_processes`: Set this to match the number of GPUs you're using
2. **Stage 1 Configuration** (`configs/training/stage1.yaml`):
- `data.train_bs`: Adjust batch size based on your GPU memory (default: 32)
- `data.n_sample_frames`: Number of sampled frames per video (default: 1)
3. **Stage 2 Configuration** (`configs/training/stage2.yaml`):
- `random_init_unet`: Must be set to `False` to use the model from stage 1
- `data.train_bs`: Smaller batch size due to high GPU memory cost (default: 2)
- `data.n_sample_frames`: Higher value for temporal consistency (default: 16)
- `solver.gradient_accumulation_steps`: Increase to simulate larger batch sizes (default: 8)
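The interaction between `data.train_bs`, `solver.gradient_accumulation_steps`, and the number of GPUs can be made concrete with the usual data-parallel rule of thumb: effective batch size = per-GPU batch size × accumulation steps × number of processes. The sketch below applies it to the defaults listed above. It assumes `train_bs` is a per-GPU value, 8 GPUs, and an accumulation value of 1 for stage 1; `effective_batch_size` is a hypothetical helper.

```python
def effective_batch_size(train_bs, grad_accum_steps, num_processes):
    """Samples contributing to one optimizer step under data-parallel training."""
    return train_bs * grad_accum_steps * num_processes


# Assumed setup: 8 GPUs, one process per GPU.
stage1 = effective_batch_size(train_bs=32, grad_accum_steps=1, num_processes=8)
stage2 = effective_batch_size(train_bs=2, grad_accum_steps=8, num_processes=8)
print(f"stage 1: {stage1} samples per optimizer step")  # 256
print(f"stage 2: {stage2} samples per optimizer step")  # 128
```

If you train on fewer GPUs, raising `gradient_accumulation_steps` by the same factor keeps the effective batch size unchanged under this rule.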
### GPU Memory Requirements
Based on our testing on a machine with 8 NVIDIA H20 GPUs:
#### Stage 1 Memory Usage
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|:----------:|:----------------------:|:--------------:|:--------------:|
| 8 | 1 | ~32GB | |
| 16 | 1 | ~45GB | |
| 32 | 1 | ~74GB | ✓ |
#### Stage 2 Memory Usage
| Batch Size | Gradient Accumulation | Memory per GPU | Recommendation |
|:----------:|:----------------------:|:--------------:|:--------------:|
| 1 | 8 | ~54GB | |
| 2 | 2 | ~80GB | |
| 2 | 8 | ~85GB | ✓ |
<details close>
## TestCases For 1.0
<table class="center">
<tr style="font-weight: bolder;text-align:center;">
<td width="33%">Image</td>
<td width="33%">MuseV</td>
<td width="33%">+MuseTalk</td>
</tr>
<tr>
<td>
<img src=assets/demo/musk/musk.png width="95%">
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/4a4bb2d1-9d14-4ca9-85c8-7f19c39f712e controls preload></video>
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/b2a879c2-e23a-4d39-911d-51f0343218e4 controls preload></video>
</td>
</tr>
<tr>
<td>
<img src=assets/demo/yongen/yongen.jpeg width="95%">
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/57ef9dee-a9fd-4dc8-839b-3fbbbf0ff3f4 controls preload></video>
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/94d8dcba-1bcd-4b54-9d1d-8b6fc53228f0 controls preload></video>
</td>
</tr>
<tr>
<td>
<img src=assets/demo/sit/sit.jpeg width="95%">
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/5fbab81b-d3f2-4c75-abb5-14c76e51769e controls preload></video>
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/f8100f4a-3df8-4151-8de2-291b09269f66 controls preload></video>
</td>
</tr>
<tr>
<td>
<img src=assets/demo/man/man.png width="95%">
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/a6e7d431-5643-4745-9868-8b423a454153 controls preload></video>
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/6ccf7bc7-cb48-42de-85bd-076d5ee8a623 controls preload></video>
</td>
</tr>
<tr>
<td>
<img src=assets/demo/monalisa/monalisa.png width="95%">
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/1568f604-a34f-4526-a13a-7d282aa2e773 controls preload></video>
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/a40784fc-a885-4c1f-9b7e-8f87b7caf4e0 controls preload></video>
</td>
</tr>
<tr>
<td>
<img src=assets/demo/sun1/sun.png width="95%">
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107 controls preload></video>
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/172f4ff1-d432-45bd-a5a7-a07dec33a26b controls preload></video>
</td>
</tr>
<tr>
<td>
<img src=assets/demo/sun2/sun.png width="95%">
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/37a3a666-7b90-4244-8d3a-058cb0e44107 controls preload></video>
</td>
<td >
<video src=https://github.com/TMElyralab/MuseTalk/assets/163980830/85a6873d-a028-4cce-af2b-6c59a1f2971d controls preload></video>
</td>
</tr>
</table >
#### Using bbox_shift to adjust results (for 1.0)
:mag_right: We have found that the upper bound of the mask has an important impact on mouth openness. To control the mask region, we suggest using the `bbox_shift` parameter: positive values (moving towards the lower half) increase mouth openness, while negative values (moving towards the upper half) decrease it.
You can start by running with the default configuration to obtain the adjustable value range, and then re-run the script with a value within this range.
For example, in the case of `Xinying Sun`, running the default configuration reports an adjustable value range of [-9, 9]. To decrease the mouth openness, we then set the value to `-7`.
```
python -m scripts.inference --inference_config configs/inference/test.yaml --bbox_shift -7
```
:pushpin: More technical details can be found in [bbox_shift](assets/BBOX_SHIFT.md).
#### Combining MuseV and MuseTalk
As a complete solution for virtual human generation, we suggest first applying [MuseV](https://github.com/TMElyralab/MuseV) to generate a video (text-to-video, image-to-video, or pose-to-video) by following [this guide](https://github.com/TMElyralab/MuseV?tab=readme-ov-file#text2video). Frame interpolation is recommended to increase the frame rate. You can then use `MuseTalk` to generate a lip-synced video by following [these instructions](https://github.com/TMElyralab/MuseTalk?tab=readme-ov-file#inference).
# Acknowledgement
1. We thank open-source components like [whisper](https://github.com/openai/whisper), [dwpose](https://github.com/IDEA-Research/DWPose), [face-alignment](https://github.com/1adrianb/face-alignment), [face-parsing](https://github.com/zllrunning/face-parsing.PyTorch), [S3FD](https://github.com/yxlijun/S3FD.pytorch) and [LatentSync](https://huggingface.co/ByteDance/LatentSync/tree/main).
1. MuseTalk draws heavily on [diffusers](https://github.com/huggingface/diffusers) and [isaacOnline/whisper](https://github.com/isaacOnline/whisper/tree/extract-embeddings).
1. MuseTalk was built on the [HDTF](https://github.com/MRzzm/HDTF) dataset.
Thanks for open-sourcing!
# Limitations
- Resolution: Although MuseTalk uses a face region size of 256 x 256, which makes it better than other open-source methods, it has not yet reached the theoretical resolution bound. We will continue to work on this problem.
If you need higher resolution, you can apply super-resolution models such as [GFPGAN](https://github.com/TencentARC/GFPGAN) in combination with MuseTalk.
- Identity preservation: Some details of the original face, such as mustache, lip shape, and color, are not well preserved.
- Jitter: Some jitter remains because the current pipeline generates frames one at a time.
# Citation
```bib
@article{musetalk,
title={MuseTalk: Real-Time High-Fidelity Video Dubbing via Spatio-Temporal Sampling},
author={Zhang, Yue and Zhong, Zhizhou and Liu, Minhao and Chen, Zhaokang and Wu, Bin and Zeng, Yubin and Zhan, Chao and He, Yingjie and Huang, Junxin and Zhou, Wenjiang},
journal={arxiv},
year={2025}
}
```
# Disclaimer/License
1. `code`: The code of MuseTalk is released under the MIT License. There is no restriction on academic or commercial usage.
1. `model`: The trained models are available for any purpose, even commercially.
1. `other opensource model`: Other open-source models used must comply with their licenses, such as `whisper`, `ft-mse-vae`, `dwpose`, `S3FD`, etc.
1. The test data are collected from the internet and are available for non-commercial research purposes only.
1. `AIGC`: This project strives to impact the domain of AI-driven video generation positively. Users are granted the freedom to create videos using this tool, but they are expected to comply with local laws and utilize it responsibly. The developers do not assume any responsibility for potential misuse by users.

253
Docs/QWEN3_TTS_DEPLOY.md Normal file

@@ -0,0 +1,253 @@
# Qwen3-TTS 0.6B Deployment Guide
> This document describes how to deploy the Qwen3-TTS 0.6B-Base voice cloning model on an Ubuntu server.
## System Requirements
| Requirement | Specification |
|------|------|
| GPU | NVIDIA RTX 3090 24GB (or better) |
| VRAM | ≥ 4GB (inference), ≥ 8GB (with flash-attn) |
| CUDA | 12.1+ |
| Python | 3.10.x |
| OS | Ubuntu 20.04+ |
---
## GPU Allocation
| GPU | Service | Model |
|-----|------|------|
| GPU0 | **Qwen3-TTS** | 0.6B-Base (voice cloning) |
| GPU1 | LatentSync | 1.6 (lip sync) |
---
## Step 1: Clone the Repository
```bash
cd /home/rongye/ProgramFiles/ViGent2/models
git clone https://github.com/QwenLM/Qwen3-TTS.git
cd Qwen3-TTS
```
---
## Step 2: Create the Conda Environment
```bash
# Create a new conda environment
conda create -n qwen-tts python=3.10 -y
conda activate qwen-tts
```
---
## Step 3: Install Python Dependencies
```bash
cd /home/rongye/ProgramFiles/ViGent2/models/Qwen3-TTS
# Install the qwen-tts package (editable mode)
pip install -e .
# Install the sox audio processing library (required)
conda install -y -c conda-forge sox
```
### Optional: Install FlashAttention (Recommended)
FlashAttention can significantly speed up inference and reduce VRAM usage:
```bash
pip install -U flash-attn --no-build-isolation
```
If the build runs out of memory, limit the number of parallel compile jobs:
```bash
MAX_JOBS=4 pip install -U flash-attn --no-build-isolation
```
---
## Step 4: Download Model Weights
### Option A: ModelScope (recommended; faster in mainland China)
```bash
pip install modelscope
# Download the Tokenizer (651MB)
modelscope download --model Qwen/Qwen3-TTS-Tokenizer-12Hz --local_dir ./checkpoints/Tokenizer
# Download the 0.6B-Base model (2.4GB)
modelscope download --model Qwen/Qwen3-TTS-12Hz-0.6B-Base --local_dir ./checkpoints/0.6B-Base
```
### Option B: HuggingFace
```bash
pip install -U "huggingface_hub[cli]"
huggingface-cli download Qwen/Qwen3-TTS-Tokenizer-12Hz --local-dir ./checkpoints/Tokenizer
huggingface-cli download Qwen/Qwen3-TTS-12Hz-0.6B-Base --local-dir ./checkpoints/0.6B-Base
```
After the download completes, the directory structure should look like this:
```
checkpoints/
├── Tokenizer/ # ~651MB
│   ├── config.json
│   ├── model.safetensors
│   └── ...
└── 0.6B-Base/ # ~2.4GB
    ├── config.json
    ├── model.safetensors
    └── ...
```
---
## Step 5: Verify the Installation
### 5.1 Check the Environment
```bash
conda activate qwen-tts
# Check PyTorch and CUDA
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.cuda.is_available()}')"
# Check sox
sox --version
```
### 5.2 Run an Inference Test
Create the test script `test_inference.py`:
```python
"""Qwen3-TTS voice cloning test"""
import torch
import soundfile as sf
from qwen_tts import Qwen3TTSModel
print("Loading Qwen3-TTS model on GPU:0...")
model = Qwen3TTSModel.from_pretrained(
    "./checkpoints/0.6B-Base",
    device_map="cuda:0",
    dtype=torch.bfloat16,
)
print("Model loaded!")
# Test voice cloning (requires a prepared reference audio file)
ref_audio = "./examples/myvoice.wav"  # reference audio, 3-20 seconds long
ref_text = "参考音频的文字内容"  # placeholder: replace with the exact transcript of the reference audio
test_text = "这是一段测试文本,用于验证声音克隆功能是否正常工作。"  # Chinese test sentence, matching language="Chinese"
print("Generating cloned voice...")
wavs, sr = model.generate_voice_clone(
    text=test_text,
    language="Chinese",
    ref_audio=ref_audio,
    ref_text=ref_text,
)
sf.write("test_output.wav", wavs[0], sr)
print(f"✅ Saved: test_output.wav | {sr}Hz | {len(wavs[0])/sr:.2f}s")
```
Run the test:
```bash
cd /home/rongye/ProgramFiles/ViGent2/models/Qwen3-TTS
python test_inference.py
```
---
## Directory Structure
After deployment, the directory structure should look like this:
```
/home/rongye/ProgramFiles/ViGent2/models/Qwen3-TTS/
├── checkpoints/
│   ├── Tokenizer/ # Speech codec
│   └── 0.6B-Base/ # Voice cloning model
├── qwen_tts/ # Source code
│   ├── inference/
│   ├── models/
│   └── ...
├── examples/
│   └── myvoice.wav # Reference audio
├── pyproject.toml
├── requirements.txt
└── test_inference.py # Test script
```
---
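## Model Notes
### Available Models
| Model | Function | Size |
|------|------|------|
| 0.6B-Base | Fast voice cloning from 3 seconds of audio | 2.4GB |
| 0.6B-CustomVoice | 9 preset voices | 2.4GB |
| 1.7B-Base | Voice cloning (higher quality) | 6.8GB |
| 1.7B-VoiceDesign | Voice generation from natural-language descriptions | 6.8GB |
### Supported Languages
Chinese, English, Japanese, Korean, German, French, Russian, Portuguese, Spanish, Italian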
---
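## Troubleshooting
### sox not found
```
SoX could not be found!
```
**Fix**: install sox through conda
```bash
conda install -y -c conda-forge sox
```
### CUDA out of memory
Qwen3-TTS 0.6B usually needs only 4-6GB of VRAM. If you hit an OOM error:
1. Make sure no other program is running on GPU0
2. Do not use flash-attn (it increases VRAM usage)
3. Use a shorter reference audio clip (3-5 seconds)
### Model fails to load
Make sure the following files exist:
- `checkpoints/0.6B-Base/config.json`
- `checkpoints/0.6B-Base/model.safetensors`
### Audio output quality issues
1. Reference audio quality: use a clear, noise-free clip of 3-10 seconds
2. `ref_text` accuracy: the transcript of the reference audio must be accurate
3. Language setting: make sure the `language` parameter matches the language of the text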
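A quick way to enforce the duration guidance above is to measure the reference clip before calling the model. The sketch below uses only the standard-library `wave` module, so it assumes an uncompressed PCM WAV file. `wav_duration_seconds`, `check_ref_audio`, and the 3-10 second default bounds are illustrative and not part of the qwen-tts package.

```python
import wave


def wav_duration_seconds(path):
    """Return the duration of a PCM WAV file in seconds."""
    with wave.open(path, "rb") as wav:
        return wav.getnframes() / wav.getframerate()


def check_ref_audio(path, min_s=3.0, max_s=10.0):
    """Return (ok, duration) where ok is True if the clip length is within bounds."""
    duration = wav_duration_seconds(path)
    return min_s <= duration <= max_s, duration
```

For example, `check_ref_audio("./examples/myvoice.wav")` returns a flag and the measured length, which can be logged before generation starts.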
---
## Reference Links
- [Qwen3-TTS GitHub](https://github.com/QwenLM/Qwen3-TTS)
- [ModelScope models](https://modelscope.cn/collections/Qwen/Qwen3-TTS)
- [HuggingFace models](https://huggingface.co/collections/Qwen/qwen3-tts)
- [Technical report](https://arxiv.org/abs/2601.15621)
- [Official blog](https://qwen.ai/blog?id=qwen3tts-0115)

291
Docs/SUPABASE_DEPLOY.md Normal file

@@ -0,0 +1,291 @@
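# Supabase Full-Stack Deployment Guide (Infrastructure + Auth)
This document covers the Docker deployment of the Supabase infrastructure, key configuration, Nginx security hardening, and database initialization for the user authentication system.
---
## Part 1: Infrastructure Deployment
### 1. Prepare the Docker Environment (Ubuntu)
Supabase depends heavily on the official directory layout (it mounts configuration files from it), so the deployment **must include the complete `docker` directory**.
```bash
# 1. Create the directory
mkdir -p /home/rongye/ProgramFiles/Supabase
cd /home/rongye/ProgramFiles/Supabase
# 2. Fetch the official configuration
# Clone the repository and extract the docker directory
git clone --depth 1 https://github.com/supabase/supabase.git temp_repo
mv temp_repo/docker/* .
# 3. Copy the environment variable template
# (the glob above skips dotfiles, so copy .env.example from the clone before deleting it)
cp temp_repo/docker/.env.example .env
rm -rf temp_repo
```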
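### 2. Generate Secure Keys
**Warning**: The official template ships with publicly known weak keys. You must regenerate them for production.
Use the script provided with the project to generate a full set of strong keys:
```bash
# From the ViGent2 project directory
cd /home/rongye/ProgramFiles/ViGent2/backend
python generate_keys.py
```
Copy the script's output (including `JWT_SECRET`, `ANON_KEY`, `SERVICE_ROLE_KEY`, and so on) and **overwrite** the corresponding entries in `/home/rongye/ProgramFiles/Supabase/.env`.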
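### 3. Configure Ports and Resolve Conflicts
Edit Supabase's `.env` file and change the following ports to avoid conflicts with existing services (Code-Server, Moodist):
```ini
# --- Port Configuration ---
# Avoid a conflict with Code-Server (8443)
KONG_HTTPS_PORT=8444
# Custom API port (default 8000)
KONG_HTTP_PORT=8008
# Custom Studio (admin dashboard) port (default 3000)
STUDIO_PORT=3003
# External access URL (important: set this to your public API domain/IP)
# With an Nginx reverse proxy: https://api.hbyrkj.top
# With a direct connection: http://8.148.25.142:8008
API_EXTERNAL_URL=https://api.hbyrkj.top
# Public API URL for Studio (required when Studio is accessed over the public network)
# Used by the Studio frontend to call the API
SUPABASE_PUBLIC_URL=https://api.hbyrkj.top
```
### 4. Start the Services
```bash
docker compose up -d
```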
---
## Part 2: Storage Local File Structure
### 1. Storage Path
Supabase Storage stores files on the local file system, with the following path structure:
```
/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub/
├── materials/ # Materials bucket
│   └── {user_id}/ # Per-user directory (isolation)
│       └── {timestamp}_{filename}/
│           └── {internal_uuid} # Actual file (Supabase internal UUID)
└── outputs/ # Outputs bucket
    └── {user_id}/
        └── {task_id}_output.mp4/
            └── {internal_uuid}
```
### 2. User Isolation Strategy
All user data is isolated by path prefix:
| Resource type | Path format | Example |
|----------|----------|------|
| Material | `{bucket}/{user_id}/{timestamp}_{filename}` | `materials/abc123/1737000001_video.mp4` |
| Output | `{bucket}/{user_id}/{task_id}_output.mp4` | `outputs/abc123/uuid-xxx_output.mp4` |
| Cookie | `cookies/{user_id}/{platform}.json` | `cookies/abc123/bilibili.json` |
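The path formats in the table above can be captured in a few small helpers. This is a minimal sketch with hypothetical function names; the actual backend code may build these paths differently.

```python
import time


def material_path(user_id, filename, timestamp=None):
    """Object path inside the materials bucket: {user_id}/{timestamp}_{filename}."""
    ts = int(time.time()) if timestamp is None else int(timestamp)
    return f"{user_id}/{ts}_{filename}"


def output_path(user_id, task_id):
    """Object path inside the outputs bucket: {user_id}/{task_id}_output.mp4."""
    return f"{user_id}/{task_id}_output.mp4"


def cookie_path(user_id, platform):
    """Cookie file path: cookies/{user_id}/{platform}.json."""
    return f"cookies/{user_id}/{platform}.json"


def belongs_to_user(object_path, user_id):
    """True if a bucket-relative object path sits under the given user's prefix."""
    return object_path.startswith(f"{user_id}/")
```

Including the trailing slash in the prefix check matters: without it, user `abc123` would also match paths owned by `abc1234`.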
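### 3. Direct Access to Local Files
The backend can read the local files directly (bypassing HTTP), which speeds up operations such as publishing:
```python
# storage.py
from pathlib import Path
from typing import Optional
SUPABASE_STORAGE_LOCAL_PATH = Path("/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub")
def get_local_file_path(self, bucket: str, path: str) -> Optional[str]:
    # Each stored object is a directory containing one file named by an internal UUID
    dir_path = SUPABASE_STORAGE_LOCAL_PATH / bucket / path
    if not dir_path.is_dir():
        return None
    files = list(dir_path.iterdir())
    return str(files[0]) if files else None
```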
---
## Part 3: Secure Access Configuration (Nginx)
We recommend configuring an Nginx reverse proxy on the Alibaba Cloud public gateway and reaching the internal services through an Frp tunnel.
### 1. Domain Plan
- **Admin dashboard**: `https://supabase.hbyrkj.top` -> internal port 3003
- **API**: `https://api.hbyrkj.top` -> internal port 8008
### 2. Nginx Configuration Example
```nginx
# Studio (password protected, except for static assets and internal API calls)
server {
    server_name supabase.hbyrkj.top;
    # SSL configuration omitted...
    # Static assets do not require authentication
    location ~ ^/(favicon|_next|static)/ {
        auth_basic off;
        proxy_pass http://127.0.0.1:3003;
        proxy_set_header Host $host;
        proxy_http_version 1.1;
    }
    # Studio's internal API calls do not require authentication
    location /api/ {
        auth_basic off;
        proxy_pass http://127.0.0.1:3003;
        proxy_set_header Host $host;
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
    }
    # All other paths require Basic Auth
    location / {
        auth_basic "Restricted Studio";
        auth_basic_user_file /etc/nginx/.htpasswd;
        proxy_pass http://127.0.0.1:3003;
        # WebSocket support (required for Realtime)
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
    }
}
# API (public access)
server {
    server_name api.hbyrkj.top;
    # SSL configuration omitted...
    # ⚠️ Important: remove the upload size limit
    client_max_body_size 0;
    location / {
        proxy_pass http://127.0.0.1:8008;
        # Allow WebSocket
        proxy_http_version 1.1;
        proxy_set_header Upgrade $http_upgrade;
        proxy_set_header Connection "upgrade";
        # Timeouts for large file uploads
        proxy_read_timeout 600s;
        proxy_send_timeout 600s;
    }
}
```
### 3. Key Settings
| Setting | Purpose | Necessity |
|--------|------|--------|
| `client_max_body_size 0` | Removes the upload size limit (default 1MB) | **Required** |
| `proxy_read_timeout 600s` | Timeout for large file uploads/downloads | Recommended |
| `proxy_http_version 1.1` | WebSocket support | Required for Realtime |
| `auth_basic` | Protects access to Studio | Recommended |
---
## Part 4: Database and Auth Configuration (Database & Auth)
### 1. Initialize the Table Structure (Schema)
Open the **SQL Editor** in the admin dashboard (Studio) and run the following SQL to create the tables ViGent2 needs:
```sql
-- 1. Users table (extend auth.users or store independently)
-- Note: a standalone table design is used here, decoupled from the FastAPI logic
CREATE TABLE IF NOT EXISTS users (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    email TEXT UNIQUE NOT NULL,
    password_hash TEXT NOT NULL,
    username TEXT,
    role TEXT DEFAULT 'pending' CHECK (role IN ('pending', 'user', 'admin')),
    is_active BOOLEAN DEFAULT FALSE,
    expires_at TIMESTAMP WITH TIME ZONE,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
    updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
-- 2. Sessions table (single-device login control)
CREATE TABLE IF NOT EXISTS user_sessions (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    user_id UUID REFERENCES users(id) ON DELETE CASCADE UNIQUE,
    session_token TEXT UNIQUE NOT NULL,
    device_info TEXT,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
-- 3. Social media account binding table
CREATE TABLE IF NOT EXISTS social_accounts (
    id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
    user_id UUID REFERENCES users(id) ON DELETE CASCADE,
    platform TEXT NOT NULL CHECK (platform IN ('bilibili', 'douyin', 'xiaohongshu')),
    logged_in BOOLEAN DEFAULT FALSE,
    created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
    UNIQUE(user_id, platform)
);
-- 4. Performance indexes
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON user_sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_social_user_platform ON social_accounts(user_id, platform);
```
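The `UNIQUE` constraint on `user_sessions.user_id` is what enforces single-device login: a user can hold at most one session row, so a new login has to replace the old one. The sketch below demonstrates the idea with the standard-library `sqlite3` module as a stand-in for PostgreSQL. The simplified table and the `create_schema`, `login`, and `is_session_valid` helpers are hypothetical and are not the ViGent2 backend code.

```python
import secrets
import sqlite3


def create_schema(conn):
    """Simplified stand-in for the user_sessions table (one row per user)."""
    conn.execute(
        """
        CREATE TABLE IF NOT EXISTS user_sessions (
            user_id TEXT UNIQUE NOT NULL,
            session_token TEXT UNIQUE NOT NULL,
            device_info TEXT
        )
        """
    )


def login(conn, user_id, device_info):
    """Issue a new token; any previous session for this user is replaced."""
    token = secrets.token_urlsafe(32)
    # INSERT OR REPLACE removes the row that conflicts on user_id before inserting.
    conn.execute(
        "INSERT OR REPLACE INTO user_sessions (user_id, session_token, device_info) "
        "VALUES (?, ?, ?)",
        (user_id, token, device_info),
    )
    conn.commit()
    return token


def is_session_valid(conn, user_id, token):
    """A token is valid only if it is the one currently stored for the user."""
    row = conn.execute(
        "SELECT session_token FROM user_sessions WHERE user_id = ?", (user_id,)
    ).fetchone()
    return row is not None and secrets.compare_digest(row[0], token)
```

In PostgreSQL the same effect is usually achieved with `INSERT ... ON CONFLICT (user_id) DO UPDATE`, so that logging in on a second device invalidates the token held by the first.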
### 2. Backend Integration Configuration (FastAPI)
Edit `ViGent2/backend/.env` to connect to the self-hosted Supabase:
```ini
# =============== Supabase configuration ===============
# Points to the API port of the Docker deployment (localhost is recommended for direct internal access)
SUPABASE_URL=http://localhost:8008
# Use the generated SERVICE_ROLE_KEY (the backend needs admin privileges)
SUPABASE_KEY=eyJhbGciOiJIUzI1Ni...
# =============== JWT configuration ===============
# Must match JWT_SECRET in the Supabase .env!
JWT_SECRET_KEY=<JWT_SECRET generated by generate_keys.py>
JWT_ALGORITHM=HS256
JWT_EXPIRE_HOURS=168
```
---
## Part 5: Common Maintenance Commands
**Check service status**:
```bash
cd /home/rongye/ProgramFiles/Supabase
docker compose ps
```
**View keys**:
```bash
grep -E "ANON|SERVICE|SECRET" .env
```
**Restart services**:
```bash
docker compose restart
```
**Fully reset the database (use with caution)**:
```bash
docker compose down -v
rm -rf volumes/db/data
docker compose up -d
```


@@ -22,7 +22,7 @@
┌─────────────────────────────────────────────────────────┐
│ Backend (FastAPI) │
├─────────────────────────────────────────────────────────┤
Celery task queue (Redis) │
Async task queue (asyncio) │
│ ├── Video generation tasks │
│ ├── TTS voice-over tasks │
│ └── Auto-publish tasks │
@@ -30,7 +30,7 @@
│ │ │
▼ ▼ ▼
┌──────────┐ ┌──────────┐ ┌──────────┐
MuseTalk │ │ FFmpeg │ │Playwright│
LatentSync│ │ FFmpeg │ │Playwright│
│ Lip sync │ │ Video compositing │ │ Auto-publish │
└──────────┘ └──────────┘ └──────────┘
```
@@ -45,7 +45,7 @@
| **UI component library** | Tailwind + shadcn/ui | Ant Design |
| **Backend framework** | FastAPI | Flask |
| **Task queue** | Celery + Redis | RQ / Dramatiq |
| **Lip sync** | MuseTalk | Wav2Lip / SadTalker |
| **Lip sync** | **LatentSync 1.6** | MuseTalk / Wav2Lip |
| **TTS voice-over** | EdgeTTS | CosyVoice |
| **Voice cloning** | GPT-SoVITS (optional) | - |
| **Video processing** | FFmpeg | MoviePy |
@@ -141,12 +141,12 @@ backend/
| Endpoint | Method | Function |
|------|------|------|
| `/api/materials` | POST | Upload a material video |
| `/api/materials` | GET | List materials |
| `/api/videos/generate` | POST | Create a video generation task |
| `/api/tasks/{id}` | GET | Query task status |
| `/api/videos/{id}/download` | GET | Download the generated video |
| `/api/publish` | POST | Publish to social platforms |
| `/api/materials` | POST | Upload a material video | ✅ |
| `/api/materials` | GET | List materials | ✅ |
| `/api/videos/generate` | POST | Create a video generation task | ✅ |
| `/api/tasks/{id}` | GET | Query task status | ✅ |
| `/api/videos/{id}/download` | GET | Download the generated video | ✅ |
| `/api/publish` | POST | Publish to social platforms | ✅ |
#### 2.3 Celery Task Definitions
@@ -221,7 +221,107 @@ cp -r SuperIPAgent/social-auto-upload backend/social_upload
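| **Voice cloning** | Integrate GPT-SoVITS to use your own voice |
| **Batch generation** | Upload Excel/CSV to generate videos in batch |
| **Subtitle editor** | Visually adjust subtitle style and position |
| **Docker deployment** | One-click deployment to a cloud server |
| **Docker deployment** | One-click deployment to a cloud server | ✅ |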
---
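### Phase 6: MuseTalk Server Deployment (Day 2-3) ✅
> **Goal**: Deploy the MuseTalk environment on the dual-GPU server
- [x] Conda environment setup (musetalk)
- [x] Model weights download (~7GB)
- [x] Subprocess-based invocation
- [x] Health check
### Phase 7: MuseTalk Full Fix (Day 4) ✅
> **Goal**: Resolve the various compatibility issues in the inference script
- [x] Weight detection path fix (symlink)
- [x] Audio/video length mismatch fix
- [x] Improved error logging in the inference script
- [x] Verified MP4 generation from video compositing
### Phase 8: Frontend Enhancements (Day 5) ✅
> **Goal**: Improve the user experience
- [x] Web video upload
- [x] Upload progress display
- [x] Automatic refresh of the materials list
### Phase 9: Lip-Sync Model Upgrade (Day 6) ✅
> **Goal**: Migrate from MuseTalk to LatentSync 1.6
- [x] MuseTalk → LatentSync 1.6 migration
- [x] Backend code adaptation (config.py, lipsync_service.py)
- [x] Latent Diffusion architecture (512x512 HD)
- [x] End-to-end verification on the server
### Phase 10: Performance Optimization (Day 6) ✅
> **Goal**: Improve system responsiveness and stability
- [x] Video pre-compression (automatic 1080p → 720p adaptation)
- [x] Finer-grained progress updates (real-time feedback)
- [x] **Persistent model service** (Persistent Server, 0s load time)
- [x] **GPU concurrency control** (serial queue to prevent crashes)
### Phase 11: Social Media Publishing (Day 7) ✅
> **Goal**: Fully automated QR-code login and multi-platform publishing
- [x] Automatic QR-code login (Playwright headless + Stealth)
- [x] Multi-platform uploader architecture (Bilibili/Douyin/Xiaohongshu)
- [x] Automatic cookie management
- [x] Scheduled publishing
### Phase 12: User Experience Improvements (Day 8) ✅
> **Goal**: Improve file management and history features
- [x] File name preservation (timestamp prefix + original name)
- [x] Video persistence (history video list API)
- [x] Material/video deletion
### Phase 13: Publishing Module Optimization (Day 9) ✅
> **Goal**: Code quality improvements and publishing verification
- [x] Bilibili/Douyin login and publishing verified
- [x] Guaranteed resource cleanup (try-finally)
- [x] Timeout protection (eliminates infinite loops)
- [x] Complete type hints
### Phase 14: User Authentication System (Day 9) ✅
> **Goal**: A secure, isolated multi-user authentication system
- [x] Supabase cloud database integration (self-hosted locally)
- [x] JWT + HttpOnly Cookie authentication architecture
- [x] User and permission table design (RLS-ready)
- [x] Authentication deployment doc (Docs/SUPABASE_DEPLOY.md)
### Phase 15: Deployment Stability (Day 9) ✅
> **Goal**: Keep production services stable over the long term
- [x] Dependency conflict fix (bcrypt)
- [x] Frontend build fix (Production Build)
- [x] PM2 process supervision configuration
- [x] Deployment manual update (Docs/DEPLOY_MANUAL.md)
### Phase 16: Full-Stack HTTPS Deployment (Day 10) ✅
> **Goal**: Secure public HTTPS access
- [x] Alibaba Cloud Nginx reverse proxy configuration
- [x] Let's Encrypt SSL certificate integration
- [x] Self-hosted Supabase deployment (Docker)
- [x] Port conflict resolution (3003/8008/8444)
- [x] Basic Auth protection for the admin dashboard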
---


@@ -1,23 +1,24 @@
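# ViGent Digital Human Talking-Head System - Development Task List
**Project**: ViGent2 digital human talking-head video generation system
**Server**: Dell R730 (2× RTX 3090 24GB)
**Last updated**: 2026-01-20
**Overall progress**: 100% (Day 6: LatentSync 1.6 upgrade complete)
**Project**: ViGent2 digital human talking-head video generation system
**Server**: Dell R730 (2× RTX 3090 24GB)
**Last updated**: 2026-01-28
**Overall progress**: 100% (Day 12: iOS compatibility, mobile optimization, Qwen3-TTS deployment)
## 📖 Quick Navigation
| Section | Description |
|------|------|
| [Completed tasks](#-已完成任务) | Features completed on Day 1-4 |
| [Completed tasks](#-已完成任务) | Features completed on Day 1-12 |
| [Roadmap](#-后续规划) | To-do items |
| [Progress statistics](#-进度统计) | Completion by module |
| [Milestones](#-里程碑) | Key checkpoints |
| [Timeline](#-时间线) | Development history |
**Related documents**
- [Day logs](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/DevLogs/) (Day1-6)
- [Day logs](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/DevLogs/) (Day1-Day12)
- [Deployment guide](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/DEPLOY_MANUAL.md)
- [Qwen3-TTS deployment](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/QWEN3_TTS_DEPLOY.md)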
---
@@ -45,7 +46,8 @@
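- [x] Playwright automation framework
- [x] Cookie management
- [x] Multi-platform publishing UI
- [ ] Scheduled publishing
- [x] Scheduled publishing (Day 7)
- [x] Automatic QR-code login (Day 7)
### Phase 5: Deployment and Documentation
- [x] Manual deployment guide (DEPLOY_MANUAL.md)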
@@ -86,25 +88,102 @@
- [x] LipSync 服务单例缓存
- [x] 健康检查缓存 (5分钟)
- [x] 异步子进程修复 (subprocess.run → asyncio)
- [ ] 预加载模型服务 (可选)
- [ ] 批量队列处理 (可选)
- [x] 预加载模型服务 (常驻 Server + FastAPI)
- [x] 批量队列处理 (GPU 并发控制)
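### Phase 11: Social Media Publishing Improvements (Day 7)
- [x] QR-code automatic login (Playwright headless)
- [x] Multi-platform uploader architecture (Bilibili/Douyin/Xiaohongshu)
- [x] Bilibili publishing (official biliup library)
- [x] Douyin/Xiaohongshu publishing (Playwright)
- [x] Scheduled publishing
- [x] Frontend publishing UI improvements
- [x] Automatic Cookie management
- [x] UI consistency fixes (navbar alignment, hidden scrollbars)
- [x] QR login timeout fix (Stealth mode, multi-selector fallback)
- [x] Documentation rule improvements (smart-edit standards, tool usage conventions)
### Phase 12: User Experience Optimization (Day 8)
- [x] Filename preservation (timestamp prefix + original name)
- [x] Video persistence (history read from the filesystem)
- [x] History video list component
- [x] Material/video deletion
- [x] Logout (Logout API + frontend button)
- [x] Frontend SWR polling optimization
- [x] QR login status detection fix
### Phase 13: Publishing Module Optimization (Day 9)
- [x] Bilibili/Douyin publishing verified
- [x] Guaranteed resource cleanup (try-finally)
- [x] Timeout protection (infinite loops eliminated)
- [x] Xiaohongshu headless mode fix
- [x] API input validation
- [x] Complete type hints
- [x] QR login waiting screen (loading animation)
- [x] Douyin/Bilibili login strategy optimization (Text first)
- [x] Post-publish review notice
### Phase 14: User Authentication System (Day 9)
- [x] Supabase database table design and deployment
- [x] JWT authentication (HttpOnly Cookie)
- [x] User registration/login/logout API
- [x] Admin permission control (is_active)
- [x] Single-device login limit (Session Token)
- [x] Supabase pause prevention (GitHub Actions/Crontab)
- [x] Authentication deployment doc (AUTH_DEPLOY.md)
### Phase 15: Deployment Stability Optimization (Day 9)
- [x] Backend dependency fixes (bcrypt/email-validator)
- [x] Frontend production build fix (npm run build)
- [x] LatentSync performance stall fix (OMP_NUM_THREADS limit)
- [x] Self-healing service deployment (PM2 configuration tuning)
- [x] Full deployment manual update (DEPLOY_MANUAL.md)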
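
The LatentSync stall fix in Phase 15 relies on capping `OMP_NUM_THREADS`, the standard OpenMP variable that limits how many CPU threads numeric libraries spawn. A minimal sketch of preparing such an environment for a child process follows. `build_inference_env` and the cap of 4 threads are assumptions for illustration; the value actually used on the server is not stated in this document.

```python
import os
from typing import Mapping, Optional


def build_inference_env(max_threads: int = 4,
                        base_env: Optional[Mapping[str, str]] = None) -> dict:
    """Return a copy of the environment with the OpenMP thread pool capped."""
    env = dict(os.environ if base_env is None else base_env)
    # Environment values must be strings, so the integer cap is converted here.
    env["OMP_NUM_THREADS"] = str(max_threads)
    return env


env = build_inference_env(4, base_env={"PATH": "/usr/bin"})
```

The resulting dictionary can be passed as the `env=` argument of `subprocess.run` or `asyncio.create_subprocess_exec`, so the cap applies only to the inference process and not to the whole backend.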
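### Phase 16: HTTPS Deployment and Polish (Day 10)
- [x] Tunnel access fix (StaticFiles mount + Rewrite)
- [x] Platform account list 500 error fix (paths.py)
- [x] Nginx HTTPS configuration (reverse proxy + SSL)
- [x] Browser title change (ViGent)
- [x] Verified that the code adapts to HTTPS
- [x] **Self-hosted Supabase deployment** (Docker, ports 3003/8008)
- [x] **Security hardening** (Basic Auth protects the admin console)
- [x] **Port conflict resolution** (Analytics/Kong migrated)
### Phase 17: Upload Architecture Refactor (Day 11)
- [x] **Direct upload** (the frontend uploads straight to Supabase, bypassing the backend proxy)
- [x] **Backend adaptation** (Signed URL generation)
- [x] **RLS policy deployment** (SQL scripts automate permission setup)
- [x] **Timeout root-cause fix** (removes the Nginx/FRP 30s limit)
- [x] **Frontend dependency update** (@supabase/supabase-js integration)
### Phase 18: User Isolation and Storage Optimization (Day 11)
- [x] **User data isolation** (materials/videos/Cookies isolated in per-user-ID directories)
- [x] **Storage URL fix** (SUPABASE_PUBLIC_URL setting fixes the localhost issue)
- [x] **Publishing service optimization** (reads local Supabase Storage files directly, skipping the HTTP download)
- [x] **Supabase Studio configuration** (public access setup)
### Phase 19: iOS Compatibility and Mobile UI Optimization (Day 12)
- [x] **Global Axios interceptor** (401/403 redirects to login, with duplicate-redirect protection)
- [x] **iOS Safari safe-area fix** (viewport-fit: cover, themeColor, unified gradient background)
- [x] **Mobile Header optimization** (compact button layout, responsive spacing)
- [x] **Publishing page UI refactor** (separate publish-now and scheduled buttons to prevent accidental taps)
- [x] **Qwen3-TTS 0.6B deployment** (voice-cloning model on GPU0; fast cloning from a 3-second reference clip)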
---
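## 🛤️ Roadmap
### 🔴 High-Priority To-Dos
- [x] Final video composition verification (MP4 generation) ✅ Done on Day 4
- [x] Full end-to-end workflow test ✅ Done on Day 4
- [ ] Social media publishing test
- [ ] **Integrate Qwen3-TTS into ViGent2** - frontend UI + backend service integration
- [ ] Batch video generation architecture design
### 🟠 Feature Completion
- [ ] Scheduled publishing
- [x] Scheduled publishing ✅ Done on Day 7
- [ ] **Backend scheduled publishing** - replace platform-side scheduling with task scheduling via APScheduler
- [ ] Batch video generation
- [ ] Subtitle style editor
### 🔵 Long-Term Exploration
- [ ] Voice cloning (GPT-SoVITS)
- [ ] Docker containerization
- [ ] Celery distributed task queue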
@@ -126,8 +205,9 @@
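| TTS voice-over | 100% | ✅ Done |
| Video composition | 100% | ✅ Done |
| Lip sync | 100% | ✅ LatentSync 1.6 upgrade complete |
| Social publishing | 80% | 🔄 Framework complete, pending testing |
| Server deployment | 100% | ✅ Done |
| Social publishing | 100% | ✅ Verified on Day 9 |
| User authentication | 100% | ✅ Day 9 Supabase+JWT |
| Server deployment | 100% | ✅ Day 9 stability optimization complete |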
---
@@ -162,11 +242,26 @@
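- Latent Diffusion architecture upgrade
- Performance optimization (video pre-compression, progress updates)
### Milestone 5: User Authentication System ✅
**Completed**: Day 9
**Results**:
- Supabase cloud database integration
- Secure JWT + HttpOnly Cookie authentication
- Admin console and user isolation
- Complete deployment and keep-alive plan
### Milestone 6: Production Deployment Stabilization ✅
**Completed**: Day 9
**Results**:
- Fixed startup crashes in the backend (bcrypt) and frontend (build)
- Resolved a severe issue where LatentSync used all CPU cores and stalled the server
- Improved the deployment manual with key Troubleshooting steps
- Achieved long-term stable service operation (Reset PM2 counter)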
---
## 📅 Timeline
```
Day 1: Project initialization + core features ✅ Done
- Backend API framework
- Frontend UI
@@ -204,5 +299,62 @@ Day 6: LatentSync 1.6 升级 ✅ 完成
- Model deployment guide
- Server deployment verification
- Performance optimization (video pre-compression, progress updates)
```
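Day 7: Social media publishing improvements ✅ Done
- QR-code automatic login (verified on Bilibili/Douyin)
- Smart locator strategy (CSS/Text in parallel)
- Multi-platform publishing (Bilibili/Douyin/Xiaohongshu)
- UI consistency improvements
- Documentation rule system improvements
Day 8: User experience optimization ✅ Done
- Filename preservation (timestamp prefix)
- Video persistence (history video API)
- History video list component
- Material/video deletion
Day 9: Publishing module optimization ✅ Done
- Bilibili/Douyin login + publishing verified
- Guaranteed resource cleanup (try-finally)
- Timeout protection (infinite loops eliminated)
- Xiaohongshu headless mode fix
- QR login waiting screen (loading animation)
- Douyin/Bilibili login strategy optimization (Text first)
- Post-publish review notice
- User authentication system planning (FastAPI+Supabase)
- Supabase table schema design (users/sessions)
- Backend JWT authentication (auth.py/deps.py)
- Database configuration and SQL deployment
- Standalone authentication deployment doc (AUTH_DEPLOY.md)
- Automatic keep-alive mechanism (Crontab/Actions)
- Deployment stability optimization (backend dependency fixes)
- Frontend production build process fix
- LatentSync severe stall fix (thread count limit)
- Full deployment manual update
Day 10: HTTPS deployment and polish ✅ Done
- Tunnel video access fix (uploads mount)
- Account list bug fix (paths.py whitelist)
- Alibaba Cloud Nginx HTTPS deployment
- UI detail polish (Title update)
Day 11: Upload architecture refactor ✅ Done
- **Core fix**: Aliyun Nginx `client_max_body_size 0` setting
- 500 error root-cause fix (Direct Upload + Gateway Config)
- Supabase RLS permission policy deployment
- Frontend supabase-js integration
- Large-file upload timeout fully resolved (30s limit)
- **User data isolation** (materials/videos/Cookies stored in per-user directories)
- **Storage URL fix** (SUPABASE_PUBLIC_URL public address setting)
- **Publishing service optimization** (direct local file reads, skipping the HTTP download)
Day 12: iOS compatibility and mobile optimization ✅ Done
- Global Axios interceptor (401/403 redirects to login)
- iOS Safari safe-area white-bar fix (viewport-fit: cover)
- themeColor setting (status bar color adaptation)
- Unified gradient background (global body gradient, no layering)
- Responsive mobile Header optimization (compact button layout)
- Publishing page UI refactor (publish-now 3/4 + scheduled 1/4)
- **Qwen3-TTS 0.6B deployment** (voice-cloning model on GPU0)
- **Deployment doc** (QWEN3_TTS_DEPLOY.md)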

View File

@@ -10,9 +10,11 @@
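- 🎬 **Lip sync** - Powered by LatentSync 1.6, a 512×512 high-resolution Diffusion model
- 🎙️ **TTS voice-over** - EdgeTTS with multiple voices (Yunxi, Xiaoxiao, etc.)
- 📱 **One-click publishing** - Playwright publishes automatically to Douyin, Xiaohongshu, Bilibili, and more
- 🖥️ **Web UI** - Modern Next.js interface
- 🚀 **Performance optimization** - Video pre-compression, health check caching
- 📱 **Fully automated publishing** - QR-code login + Cookie persistence, with scheduled publishing to multiple platforms (Bilibili/Douyin/Xiaohongshu)
- 🖥️ **Web UI** - Modern Next.js interface with iOS/Android mobile adaptation
- 🔐 **User system** - Supabase + JWT authentication, with an admin console and registration/login
- 👥 **Multi-user isolation** - Materials/videos/Cookies are stored per user, so data is fully isolated
- 🚀 **Performance optimization** - Video pre-compression, persistent model service (0s load time), direct local file reads
## 🛠️ Tech Stack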
@@ -20,6 +22,9 @@
|------|------|
| Frontend | Next.js 14 + TypeScript + TailwindCSS |
| Backend | FastAPI + Python 3.10 |
| Database | **Supabase** (PostgreSQL), self-hosted via Docker |
| Storage | **Supabase Storage** (local filesystem) |
| Authentication | **JWT** + HttpOnly Cookie |
| Lip sync | **LatentSync 1.6** (Latent Diffusion, 512×512) |
| TTS | EdgeTTS |
| Video processing | FFmpeg |
@@ -45,6 +50,7 @@ ViGent2/
│ └── DEPLOY.md # LatentSync deployment guide
└── Docs/ # Documentation
├── DEPLOY_MANUAL.md # Deployment manual
├── AUTH_DEPLOY.md # Authentication deployment guide
├── task_complete.md
└── DevLogs/
```
@@ -102,6 +108,10 @@ uvicorn app.main:app --host 0.0.0.0 --port 8006
# Terminal 2: Frontend (port 3002)
cd frontend
npm run dev -- -p 3002
# Terminal 3: LatentSync service (port 8007, recommended)
cd models/LatentSync
nohup python -m scripts.server > server.log 2>&1 &
```
---
@@ -125,18 +135,21 @@ npm run dev -- -p 3002
## 🌐 Access URLs
| Service | URL |
|------|------|
| Video generation | http://<server-IP>:3002 |
| Publishing management | http://<server-IP>:3002/publish |
| API docs | http://<server-IP>:8006/docs |
| Service | URL | Notes |
|------|------|------|
| **Video generation (UI)** | `https://vigent.hbyrkj.top` | User entry point |
| **API service** | `http://<server-IP>:8006` | Backend Swagger |
| **Auth admin (Studio)** | `https://supabase.hbyrkj.top` | Requires Basic Auth |
| **Auth API (Kong)** | `https://api.hbyrkj.top` | Supabase endpoints |
| **Model service** | `http://<server-IP>:8007` | LatentSync |
---
## 📖 Documentation
- [LatentSync Deployment Guide](models/LatentSync/DEPLOY.md)
- [Manual Deployment Guide](Docs/DEPLOY_MANUAL.md)
- [Supabase Deployment Guide](Docs/SUPABASE_DEPLOY.md)
- [LatentSync Deployment Guide](models/LatentSync/DEPLOY.md)
- [Development Logs](Docs/DevLogs/)
- [Task Progress](Docs/task_complete.md)

View File

@@ -15,11 +15,15 @@ DEFAULT_TTS_VOICE=zh-CN-YunxiNeural
# GPU selection (0 = first GPU, 1 = second GPU)
LATENTSYNC_GPU_ID=1
# Use local mode (true) or the remote API (false)
# Use local mode (true) or the remote API (false)
LATENTSYNC_LOCAL=true
# Remote API URL (used only when LATENTSYNC_LOCAL=false)
# LATENTSYNC_API_URL=http://localhost:8001
# Use the persistent server to speed up inference
LATENTSYNC_USE_SERVER=false
# Remote API URL (the persistent server defaults to port 8007)
# LATENTSYNC_API_URL=http://localhost:8007
# Inference steps (20-50; higher means better quality but slower)
LATENTSYNC_INFERENCE_STEPS=20
@@ -41,3 +45,19 @@ MAX_UPLOAD_SIZE_MB=500
# FFmpeg path (if it is not on the system PATH)
# FFMPEG_PATH=/usr/bin/ffmpeg
# =============== Supabase settings ===============
# Copy these from Supabase Project Settings > API
SUPABASE_URL=http://localhost:8008/
SUPABASE_PUBLIC_URL=https://api.hbyrkj.top
SUPABASE_KEY=eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJyb2xlIjogInNlcnZpY2Vfcm9sZSIsICJpc3MiOiAic3VwYWJhc2UiLCAiaWF0IjogMTc2OTQwNzU2NSwgImV4cCI6IDIwODQ3Njc1NjV9.LBPaimygpnM9o3mZ2Pi-iL8taJ90JjGbQ0HW6yFlmhg
# =============== JWT settings ===============
# Secret key used to sign JWT tokens (replace with a random string)
JWT_SECRET_KEY=F4MagRkf7nJsN-ag9AB7Q-30MbZRe7Iu4E9p9xRzyic
JWT_ALGORITHM=HS256
JWT_EXPIRE_HOURS=168
# =============== Admin settings ===============
# Admin account created automatically when the service starts
ADMIN_EMAIL=lamnickdavid@gmail.com
ADMIN_PASSWORD=lam1988324
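
The `JWT_SECRET_KEY` comment asks for a random string. One way to produce such a value is Python's standard `secrets` module, sketched below. `generate_jwt_secret` is a hypothetical helper name and the 32-byte length is an assumption, not a project requirement.

```python
import secrets


def generate_jwt_secret(num_bytes: int = 32) -> str:
    """Return a URL-safe random string suitable for use as a signing secret."""
    # token_urlsafe draws from the OS CSPRNG and Base64-encodes the bytes
    # without padding, so the result is safe to paste into a .env file.
    return secrets.token_urlsafe(num_bytes)


secret = generate_jwt_secret()
```

Generating the value on the target server and keeping the real `.env` out of version control avoids sharing one secret across environments.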

185
backend/app/api/admin.py Normal file
View File

@@ -0,0 +1,185 @@
"""
Admin API: user management
"""
from fastapi import APIRouter, HTTPException, Depends, status
from pydantic import BaseModel
from typing import Optional, List
from datetime import datetime, timezone, timedelta
from app.core.supabase import get_supabase
from app.core.deps import get_current_admin
from loguru import logger

router = APIRouter(prefix="/api/admin", tags=["管理"])


class UserListItem(BaseModel):
    id: str
    email: str
    username: Optional[str]
    role: str
    is_active: bool
    expires_at: Optional[str]
    created_at: str


class ActivateRequest(BaseModel):
    expires_days: Optional[int] = None  # Number of days to grant; None means permanent


@router.get("/users", response_model=List[UserListItem])
async def list_users(admin: dict = Depends(get_current_admin)):
    """List all users."""
    try:
        supabase = get_supabase()
        result = supabase.table("users").select("*").order("created_at", desc=True).execute()
        return [
            UserListItem(
                id=u["id"],
                email=u["email"],
                username=u.get("username"),
                role=u["role"],
                is_active=u["is_active"],
                expires_at=u.get("expires_at"),
                created_at=u["created_at"]
            )
            for u in result.data
        ]
    except Exception as e:
        logger.error(f"获取用户列表失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="获取用户列表失败"
        )


@router.post("/users/{user_id}/activate")
async def activate_user(
    user_id: str,
    request: ActivateRequest,
    admin: dict = Depends(get_current_admin)
):
    """
    Activate a user.

    Args:
        user_id: User ID
        request.expires_days: Number of days to grant (None means permanent)
    """
    try:
        supabase = get_supabase()

        # Compute the expiry time
        expires_at = None
        if request.expires_days:
            expires_at = (datetime.now(timezone.utc) + timedelta(days=request.expires_days)).isoformat()

        # Update the user
        result = supabase.table("users").update({
            "is_active": True,
            "role": "user",
            "expires_at": expires_at
        }).eq("id", user_id).execute()

        if not result.data:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail="用户不存在"
            )

        logger.info(f"管理员 {admin['email']} 激活用户 {user_id}, 有效期: {request.expires_days or '永久'}")
        return {
            "success": True,
            "message": f"用户已激活,有效期: {request.expires_days or '永久'}"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"激活用户失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="激活用户失败"
        )


@router.post("/users/{user_id}/deactivate")
async def deactivate_user(
    user_id: str,
    admin: dict = Depends(get_current_admin)
):
    """Deactivate a user."""
    try:
        supabase = get_supabase()

        # Admin accounts cannot be deactivated
        user_result = supabase.table("users").select("role").eq("id", user_id).single().execute()
        if user_result.data and user_result.data["role"] == "admin":
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="不能停用管理员账号"
            )

        # Update the user
        result = supabase.table("users").update({
            "is_active": False
        }).eq("id", user_id).execute()

        # Clear the user's sessions so existing logins stop working
        supabase.table("user_sessions").delete().eq("user_id", user_id).execute()

        logger.info(f"管理员 {admin['email']} 停用用户 {user_id}")
        return {"success": True, "message": "用户已停用"}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"停用用户失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="停用用户失败"
        )


@router.post("/users/{user_id}/extend")
async def extend_user(
    user_id: str,
    request: ActivateRequest,
    admin: dict = Depends(get_current_admin)
):
    """Extend a user's authorization period."""
    try:
        supabase = get_supabase()

        if not request.expires_days:
            # Make the authorization permanent
            expires_at = None
        else:
            # Fetch the current expiry time
            user_result = supabase.table("users").select("expires_at").eq("id", user_id).single().execute()
            user = user_result.data
            if user and user.get("expires_at"):
                current_expires = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
                # Extend from whichever is later: the current expiry or now
                base_time = max(current_expires, datetime.now(timezone.utc))
            else:
                base_time = datetime.now(timezone.utc)
            expires_at = (base_time + timedelta(days=request.expires_days)).isoformat()

        result = supabase.table("users").update({
            "expires_at": expires_at
        }).eq("id", user_id).execute()

        logger.info(f"管理员 {admin['email']} 延长用户 {user_id} 授权 {request.expires_days or '永久'}")
        return {
            "success": True,
            "message": f"授权已延长 {request.expires_days or '永久'}"
        }
    except Exception as e:
        logger.error(f"延长授权失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="延长授权失败"
        )
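
The expiry rule in `extend_user` is easier to check once it is separated from the database calls. The sketch below restates that rule as a pure function: extend from the later of the current expiry and the present moment, and treat a missing day count as permanent. `compute_new_expiry` and its `now` parameter are hypothetical names introduced here for illustration; the timestamps are made-up inputs.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def compute_new_expiry(current_expires_at: Optional[str],
                       extend_days: Optional[int],
                       now: Optional[datetime] = None) -> Optional[str]:
    """Mirror the extend_user rule without touching the database."""
    if not extend_days:
        return None  # no day count means the grant becomes permanent
    now = now or datetime.now(timezone.utc)
    base_time = now
    if current_expires_at:
        # Normalize a trailing "Z" so fromisoformat accepts it on older Pythons
        current = datetime.fromisoformat(current_expires_at.replace("Z", "+00:00"))
        base_time = max(current, now)
    return (base_time + timedelta(days=extend_days)).isoformat()


now = datetime(2026, 1, 28, tzinfo=timezone.utc)
still_valid = compute_new_expiry("2026-02-10T00:00:00Z", 30, now=now)
already_expired = compute_new_expiry("2026-01-01T00:00:00Z", 30, now=now)
permanent = compute_new_expiry("2026-02-10T00:00:00Z", None, now=now)
```

A user whose grant is still running keeps the remaining days, while a lapsed grant restarts from today instead of from the old expiry date.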

223
backend/app/api/auth.py Normal file
View File

@@ -0,0 +1,223 @@
"""
Authentication API: registration, login, logout
"""
from fastapi import APIRouter, HTTPException, Response, status, Request
from pydantic import BaseModel, EmailStr
from app.core.supabase import get_supabase
from app.core.security import (
    get_password_hash,
    verify_password,
    create_access_token,
    generate_session_token,
    set_auth_cookie,
    clear_auth_cookie,
    decode_access_token
)
from loguru import logger
from typing import Optional

router = APIRouter(prefix="/api/auth", tags=["认证"])


class RegisterRequest(BaseModel):
    email: EmailStr
    password: str
    username: Optional[str] = None


class LoginRequest(BaseModel):
    email: EmailStr
    password: str


class UserResponse(BaseModel):
    id: str
    email: str
    username: Optional[str]
    role: str
    is_active: bool


@router.post("/register")
async def register(request: RegisterRequest):
    """
    User registration.

    New accounts start in the pending state and must be activated by an admin.
    """
    try:
        supabase = get_supabase()

        # Check whether the email is already registered
        existing = supabase.table("users").select("id").eq(
            "email", request.email
        ).execute()
        if existing.data:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="该邮箱已注册"
            )

        # Create the user
        password_hash = get_password_hash(request.password)
        result = supabase.table("users").insert({
            "email": request.email,
            "password_hash": password_hash,
            "username": request.username or request.email.split("@")[0],
            "role": "pending",
            "is_active": False
        }).execute()

        logger.info(f"新用户注册: {request.email}")
        return {
            "success": True,
            "message": "注册成功,请等待管理员审核激活"
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"注册失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="注册失败,请稍后重试"
        )


@router.post("/login")
async def login(request: LoginRequest, response: Response):
    """
    User login.

    - Verify the password
    - Check that the account is activated
    - Enforce single-device login: the newest login invalidates the previous one
    """
    try:
        supabase = get_supabase()

        # Look up the user
        user_result = supabase.table("users").select("*").eq(
            "email", request.email
        ).single().execute()
        user = user_result.data
        if not user:
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="邮箱或密码错误"
            )

        # Verify the password
        if not verify_password(request.password, user["password_hash"]):
            raise HTTPException(
                status_code=status.HTTP_401_UNAUTHORIZED,
                detail="邮箱或密码错误"
            )

        # Check that the account is activated
        if not user["is_active"]:
            raise HTTPException(
                status_code=status.HTTP_403_FORBIDDEN,
                detail="账号未激活,请等待管理员审核"
            )

        # Check whether the authorization has expired
        if user.get("expires_at"):
            from datetime import datetime, timezone
            expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
            if datetime.now(timezone.utc) > expires_at:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN,
                    detail="授权已过期,请联系管理员续期"
                )

        # Generate a new session_token (the newest login wins)
        session_token = generate_session_token()

        # Delete the old session, then insert the new one
        supabase.table("user_sessions").delete().eq(
            "user_id", user["id"]
        ).execute()
        supabase.table("user_sessions").insert({
            "user_id": user["id"],
            "session_token": session_token,
            "device_info": None  # Could be taken from the request headers
        }).execute()

        # Generate the JWT token
        token = create_access_token(user["id"], session_token)

        # Set the HttpOnly cookie
        set_auth_cookie(response, token)

        logger.info(f"用户登录: {request.email}")
        return {
            "success": True,
            "message": "登录成功",
            "user": UserResponse(
                id=user["id"],
                email=user["email"],
                username=user.get("username"),
                role=user["role"],
                is_active=user["is_active"]
            )
        }
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"登录失败: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="登录失败,请稍后重试"
        )


@router.post("/logout")
async def logout(response: Response):
    """User logout."""
    clear_auth_cookie(response)
    return {"success": True, "message": "已登出"}


@router.get("/me")
async def get_me(request: Request):
    """Return the current user's profile."""
    # Read the token from the cookie
    token = request.cookies.get("access_token")
    if not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="未登录"
        )

    token_data = decode_access_token(token)
    if not token_data:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Token 无效"
        )

    supabase = get_supabase()
    user_result = supabase.table("users").select("*").eq(
        "id", token_data.user_id
    ).single().execute()
    user = user_result.data
    if not user:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="用户不存在"
        )

    return UserResponse(
        id=user["id"],
        email=user["email"],
        username=user.get("username"),
        role=user["role"],
        is_active=user["is_active"]
    )
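
The login handler replaces the stored session row on every login, so only the most recent token for a user can still match. The sketch below models that behaviour with an in-memory dictionary in place of the `user_sessions` table. `SessionStore` is a hypothetical class, and how the real dependency in `app.core.deps` compares tokens is an assumption, since that file is not shown here.

```python
import secrets
from typing import Dict


class SessionStore:
    """In-memory stand-in for the user_sessions table (illustration only)."""

    def __init__(self) -> None:
        self._sessions: Dict[str, str] = {}

    def login(self, user_id: str) -> str:
        token = secrets.token_urlsafe(32)
        # Overwriting the entry plays the role of "delete old row, insert new row".
        self._sessions[user_id] = token
        return token

    def is_valid(self, user_id: str, token: str) -> bool:
        current = self._sessions.get(user_id)
        # compare_digest avoids leaking match position through timing.
        return current is not None and secrets.compare_digest(current, token)


store = SessionStore()
phone_token = store.login("user-1")
laptop_token = store.login("user-1")  # second login replaces the first session
```

After the second login, `phone_token` no longer validates even though a JWT carrying it may not have expired, which is why the session token has to be checked on each request in addition to the JWT signature.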

View File

@@ -0,0 +1,221 @@
"""
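Helper page for one-click QR-code login from the frontend.
The customer scans the code in their own browser; JavaScript then extracts the Cookie and uploads it to the server.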
"""
from fastapi import APIRouter, Request
from fastapi.responses import HTMLResponse
from app.core.config import settings
router = APIRouter()
@router.get("/login-helper/{platform}", response_class=HTMLResponse)
async def login_helper_page(platform: str, request: Request):
"""
Serve an HTML page that lets users log in to the platform in their own browser.
After login, JavaScript extracts the Cookie and POSTs it back to the server.
"""
platform_urls = {
"bilibili": "https://www.bilibili.com/",
"douyin": "https://creator.douyin.com/",
"xiaohongshu": "https://creator.xiaohongshu.com/"
}
platform_names = {
"bilibili": "B站",
"douyin": "抖音",
"xiaohongshu": "小红书"
}
if platform not in platform_urls:
return "<h1>不支持的平台</h1>"
# 获取服务器地址用于回传Cookie
server_url = str(request.base_url).rstrip('/')
html_content = f"""
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<title>{platform_names[platform]} 一键登录</title>
<style>
body {{
font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
margin: 0;
padding: 20px;
min-height: 100vh;
display: flex;
align-items: center;
justify-content: center;
}}
.container {{
background: white;
border-radius: 20px;
padding: 50px;
box-shadow: 0 20px 60px rgba(0,0,0,0.3);
max-width: 700px;
width: 100%;
}}
h1 {{
color: #333;
margin: 0 0 30px 0;
text-align: center;
font-size: 32px;
}}
.step {{
display: flex;
align-items: flex-start;
margin: 25px 0;
padding: 20px;
background: linear-gradient(135deg, #f5f7fa 0%, #c3cfe2 100%);
border-radius: 12px;
border-left: 5px solid #667eea;
}}
.step-number {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
width: 40px;
height: 40px;
border-radius: 50%;
display: flex;
align-items: center;
justify-content: center;
font-weight: bold;
font-size: 20px;
margin-right: 20px;
flex-shrink: 0;
}}
.step-content {{
flex: 1;
}}
.step-title {{
font-weight: 600;
font-size: 18px;
margin-bottom: 8px;
color: #333;
}}
.step-desc {{
color: #666;
line-height: 1.6;
}}
.bookmarklet {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
padding: 15px 30px;
border-radius: 10px;
text-decoration: none;
display: inline-block;
font-weight: 600;
font-size: 18px;
margin: 20px 0;
cursor: move;
border: 3px dashed white;
transition: transform 0.2s;
}}
.bookmarklet:hover {{
transform: scale(1.05);
}}
.bookmarklet-container {{
text-align: center;
margin: 30px 0;
padding: 30px;
background: #f8f9fa;
border-radius: 12px;
}}
.instruction {{
font-size: 14px;
color: #666;
margin-top: 10px;
}}
.highlight {{
background: #fff3cd;
padding: 2px 6px;
border-radius: 4px;
font-weight: 600;
}}
.btn {{
background: linear-gradient(135deg, #667eea 0%, #764ba2 100%);
color: white;
border: none;
padding: 15px 40px;
border-radius: 10px;
font-size: 18px;
cursor: pointer;
font-weight: 600;
width: 100%;
margin-top: 20px;
transition: transform 0.2s;
}}
.btn:hover {{
transform: translateY(-2px);
}}
</style>
</head>
<body>
<div class="container">
<h1>🔐 {platform_names[platform]} 一键登录</h1>
<div class="step">
<div class="step-number">1</div>
<div class="step-content">
<div class="step-title">拖拽书签到书签栏</div>
<div class="step-desc">
将下方的"<span class="highlight">保存{platform_names[platform]}登录</span>"按钮拖拽到浏览器书签栏
<br><small>(如果书签栏未显示,按 Ctrl+Shift+B 显示)</small>
</div>
</div>
</div>
<div class="bookmarklet-container">
<a href="javascript:(function(){{var c=document.cookie;if(!c){{alert('请先登录{platform_names[platform]}');return;}}fetch('{server_url}/api/publish/cookies/save/{platform}',{{method:'POST',headers:{{'Content-Type':'application/json'}},body:JSON.stringify({{cookie_string:c}})}}).then(r=>r.json()).then(d=>{{if(d.success){{alert('✅ 登录成功!');window.opener&&window.opener.location.reload();}}else{{alert(''+d.message);}}}}
).catch(e=>alert('提交失败:'+e));}})();"
class="bookmarklet"
onclick="alert('请拖拽此按钮到书签栏,不要点击!'); return false;">
🔖 保存{platform_names[platform]}登录
</a>
<div class="instruction">
⬆️ <strong>拖拽此按钮到浏览器顶部书签栏</strong>
</div>
</div>
<div class="step">
<div class="step-number">2</div>
<div class="step-content">
<div class="step-title">登录 {platform_names[platform]}</div>
<div class="step-desc">
点击下方按钮打开{platform_names[platform]}登录页,扫码登录
</div>
</div>
</div>
<button class="btn" onclick="window.open('{platform_urls[platform]}', 'login_tab')">
🚀 打开{platform_names[platform]}登录页
</button>
<div class="step">
<div class="step-number">3</div>
<div class="step-content">
<div class="step-title">一键保存登录</div>
<div class="step-desc">
登录成功后,点击书签栏的"<span class="highlight">保存{platform_names[platform]}登录</span>"书签
<br>系统会自动提取并保存Cookie完成
</div>
</div>
</div>
<hr style="margin: 40px 0; border: none; border-top: 2px solid #eee;">
<div style="text-align: center; color: #999; font-size: 14px;">
<p>💡 <strong>提示</strong>:书签只需拖拽一次,下次登录直接点击书签即可</p>
<p>🔒 所有数据仅在您的浏览器和服务器之间传输,安全可靠</p>
</div>
</div>
</body>
</html>
"""
return HTMLResponse(content=html_content)

View File

@@ -1,53 +1,331 @@
from fastapi import APIRouter, UploadFile, File, HTTPException
from fastapi import APIRouter, UploadFile, File, HTTPException, Request, BackgroundTasks, Depends
from app.core.config import settings
import shutil
import uuid
from app.core.deps import get_current_user
from app.services.storage import storage_service
import re
import time
import traceback
import os
import aiofiles
from pathlib import Path
from loguru import logger
router = APIRouter()
@router.post("/")
async def upload_material(file: UploadFile = File(...)):
if not file.filename.lower().endswith(('.mp4', '.mov', '.avi')):
raise HTTPException(400, "Invalid format")
file_id = str(uuid.uuid4())
ext = Path(file.filename).suffix
save_path = settings.UPLOAD_DIR / "materials" / f"{file_id}{ext}"
# Save file
with open(save_path, "wb") as buffer:
shutil.copyfileobj(file.file, buffer)
# Calculate size
size_mb = save_path.stat().st_size / (1024 * 1024)
return {
"id": file_id,
"name": file.filename,
"path": f"uploads/materials/{file_id}{ext}",
"size_mb": size_mb,
"type": "video"
}
def sanitize_filename(filename: str) -> str:
safe_name = re.sub(r'[<>:"/\\|?*]', '_', filename)
if len(safe_name) > 100:
ext = Path(safe_name).suffix
safe_name = safe_name[:100 - len(ext)] + ext
return safe_name
@router.get("/")
async def list_materials():
materials_dir = settings.UPLOAD_DIR / "materials"
files = []
if materials_dir.exists():
for f in materials_dir.glob("*"):
async def process_and_upload(temp_file_path: str, original_filename: str, content_type: str, user_id: str):
"""Background task to strip multipart headers and upload to Supabase"""
try:
logger.info(f"Processing raw upload: {temp_file_path} for user {user_id}")
# 1. Analyze file to find actual video content (strip multipart boundaries)
# This is a simplified manual parser for a SINGLE file upload.
# Structure:
# --boundary
# Content-Disposition: form-data; name="file"; filename="..."
# Content-Type: video/mp4
# \r\n\r\n
# [DATA]
# \r\n--boundary--
# We need to read the first few KB to find the header end
start_offset = 0
end_offset = 0
boundary = b""
file_size = os.path.getsize(temp_file_path)
with open(temp_file_path, 'rb') as f:
# Read first 4KB to find header
head = f.read(4096)
# Find boundary
first_line_end = head.find(b'\r\n')
if first_line_end == -1:
raise Exception("Could not find boundary in multipart body")
boundary = head[:first_line_end] # e.g. --boundary123
logger.info(f"Detected boundary: {boundary}")
# Find end of headers (\r\n\r\n)
header_end = head.find(b'\r\n\r\n')
if header_end == -1:
raise Exception("Could not find end of multipart headers")
start_offset = header_end + 4
logger.info(f"Video data starts at offset: {start_offset}")
# Find end boundary (read from end of file)
# It should be \r\n + boundary + -- + \r\n
# We seek to end-200 bytes
f.seek(max(0, file_size - 200))
tail = f.read()
# The closing boundary is usually --boundary--
# We look for the last occurrence of the boundary
last_boundary_pos = tail.rfind(boundary)
if last_boundary_pos != -1:
# The data ends before \r\n + boundary
# The tail buffer relative position needs to be converted to absolute
end_pos_in_tail = last_boundary_pos
# We also need to check for the preceding \r\n
if end_pos_in_tail >= 2 and tail[end_pos_in_tail-2:end_pos_in_tail] == b'\r\n':
end_pos_in_tail -= 2
# Absolute end offset: start of the tail window plus the boundary position,
# already moved back past the preceding CRLF when one is present
end_offset = max(0, file_size - 200) + end_pos_in_tail
else:
logger.warning("Could not find closing boundary, assuming EOF")
end_offset = file_size
logger.info(f"Video data ends at offset: {end_offset}. Total video size: {end_offset - start_offset}")
# 2. Extract and Upload to Supabase
# Since we have the file on disk, we can just pass the file object (seeked) to upload_file?
# Or if upload_file expects bytes/path, checking storage.py...
# It takes `file_data` (bytes) or file-like?
# supabase-py's `upload` method handles parsing if we pass a file object.
# But we need to pass ONLY the video slice.
# So we create a generator or a sliced file object?
# Simpler: Read the slice into memory if < 1GB? Or copy to new temp file?
# Copying to new temp file is safer for memory.
video_path = temp_file_path + "_video.mp4"
with open(temp_file_path, 'rb') as src, open(video_path, 'wb') as dst:
src.seek(start_offset)
# Copy in chunks
bytes_to_copy = end_offset - start_offset
copied = 0
while copied < bytes_to_copy:
chunk_size = min(1024*1024*10, bytes_to_copy - copied) # 10MB chunks
chunk = src.read(chunk_size)
if not chunk:
break
dst.write(chunk)
copied += len(chunk)
logger.info(f"Extracted video content to {video_path}")
# 3. Upload to Supabase with user isolation
timestamp = int(time.time())
safe_name = re.sub(r'[^a-zA-Z0-9._-]', '', original_filename)
# Use user_id as the directory prefix to isolate each user's files
storage_path = f"{user_id}/{timestamp}_{safe_name}"
# Use storage service (this calls Supabase which might do its own http request)
# We read the cleaned video file
with open(video_path, 'rb') as f:
file_content = f.read() # Still reading into memory for simple upload call, but server has 32GB RAM so ok for 500MB
await storage_service.upload_file(
bucket=storage_service.BUCKET_MATERIALS,
path=storage_path,
file_data=file_content,
content_type=content_type
)
logger.info(f"Upload to Supabase complete: {storage_path}")
# Cleanup
os.remove(temp_file_path)
os.remove(video_path)
return storage_path
except Exception as e:
logger.error(f"Background upload processing failed: {e}\n{traceback.format_exc()}")
raise
@router.post("")
async def upload_material(
request: Request,
background_tasks: BackgroundTasks,
current_user: dict = Depends(get_current_user)
):
user_id = current_user["id"]
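# Log only non-sensitive headers; the full header set includes the auth cookie
logger.info(f"ENTERED upload_material (Streaming Mode) for user {user_id}. Content-Type: {request.headers.get('content-type')}, Content-Length: {request.headers.get('content-length')}")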
filename = "unknown_video.mp4" # Fallback
content_type = "video/mp4"
# Try to parse filename from header if possible (unreliable in raw stream)
# We will rely on post-processing or client hint
# Frontend sends standard multipart.
# Create temp file
timestamp = int(time.time())
temp_filename = f"upload_{timestamp}.raw"
temp_path = os.path.join("/tmp", temp_filename) # Use /tmp on Linux
# Ensure /tmp exists (it does) but verify paths
if os.name == 'nt': # Local dev
temp_path = f"d:/tmp/{temp_filename}"
os.makedirs("d:/tmp", exist_ok=True)
try:
total_size = 0
last_log = 0
async with aiofiles.open(temp_path, 'wb') as f:
async for chunk in request.stream():
await f.write(chunk)
total_size += len(chunk)
# Log progress every 20MB
if total_size - last_log > 20 * 1024 * 1024:
logger.info(f"Receiving stream... Processed {total_size / (1024*1024):.2f} MB")
last_log = total_size
logger.info(f"Stream reception complete. Total size: {total_size} bytes. Saved to {temp_path}")
if total_size == 0:
raise HTTPException(400, "Received empty body")
# Attempt to extract filename from the saved file's first bytes?
# Or just accept it as "uploaded_video.mp4" for now to prove it works.
# We can try to regex the header in the file content we just wrote.
# Implemented in background task to return success immediately.
# Wait, if we return immediately, the user's UI might not show the file yet?
# The prompt says "Wait for upload".
# But to avoid User Waiting Timeout, maybe returning early is better?
# NO, user expects the file to be in the list.
# So we Must await the processing.
# But "Processing" (Strip + Upload to Supabase) takes time.
# Receiving took time.
# If we await Supabase upload, does it timeout?
# Supabase upload is outgoing. Usually faster/stable.
# Let's await the processing to ensure "List Materials" shows it.
# We need to extract the filename for the list.
# Quick extract filename from first 4kb
with open(temp_path, 'rb') as f:
head = f.read(4096).decode('utf-8', errors='ignore')
match = re.search(r'filename="([^"]+)"', head)
if match:
filename = match.group(1)
logger.info(f"Extracted filename from body: {filename}")
# Run processing sync (in await)
storage_path = await process_and_upload(temp_path, filename, content_type, user_id)
# Get signed URL (it exists now)
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_MATERIALS,
path=storage_path
)
size_mb = total_size / (1024 * 1024) # Approximate (includes headers)
# Derive the display name from storage_path
display_name = storage_path.split('/')[-1] # Drop the user_id prefix
if '_' in display_name:
parts = display_name.split('_', 1)
if parts[0].isdigit():
display_name = parts[1]
return {
"id": storage_path,
"name": display_name,
"path": signed_url,
"size_mb": size_mb,
"type": "video"
}
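except HTTPException:
# Re-raise deliberate HTTP errors (such as the 400 for an empty body) instead of reporting them as 500
raise
except Exception as e: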
error_msg = f"Streaming upload failed: {str(e)}"
detail_msg = f"Exception: {repr(e)}\nArgs: {e.args}\n{traceback.format_exc()}"
logger.error(error_msg + "\n" + detail_msg)
# Write to debug file
try:
with open("debug_upload.log", "a") as logf:
logf.write(f"\n--- Error at {time.ctime()} ---\n")
logf.write(detail_msg)
logf.write("\n-----------------------------\n")
except:
pass
if os.path.exists(temp_path):
try:
stat = f.stat()
files.append({
"id": f.stem,
"name": f.name,
"path": f"uploads/materials/{f.name}",
"size_mb": stat.st_size / (1024 * 1024),
"type": "video",
"created_at": stat.st_ctime
})
except Exception:
os.remove(temp_path)
except:
pass
raise HTTPException(500, f"Upload failed. Check server logs. Error: {str(e)}")
@router.get("")
async def list_materials(current_user: dict = Depends(get_current_user)):
user_id = current_user["id"]
try:
# List only the files in the current user's directory
files_obj = await storage_service.list_files(
bucket=storage_service.BUCKET_MATERIALS,
path=user_id
)
materials = []
for f in files_obj:
name = f.get('name')
if not name or name == '.emptyFolderPlaceholder':
continue
# Sort by creation time desc
files.sort(key=lambda x: x.get("created_at", 0), reverse=True)
return {"materials": files}
display_name = name
if '_' in name:
parts = name.split('_', 1)
if parts[0].isdigit():
display_name = parts[1]
# The full path includes user_id
full_path = f"{user_id}/{name}"
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_MATERIALS,
path=full_path
)
metadata = f.get('metadata', {})
size = metadata.get('size', 0)
# created_at is a top-level field holding an ISO string
created_at_str = f.get('created_at', '')
created_at = 0
if created_at_str:
from datetime import datetime
try:
dt = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
created_at = int(dt.timestamp())
except:
pass
materials.append({
"id": full_path, # ID 使用完整路径
"name": display_name,
"path": signed_url,
"size_mb": size / (1024 * 1024),
"type": "video",
"created_at": created_at
})
materials.sort(key=lambda x: x['id'], reverse=True)
return {"materials": materials}
except Exception as e:
logger.error(f"List materials failed: {e}")
return {"materials": []}
@router.delete("/{material_id:path}")
async def delete_material(material_id: str, current_user: dict = Depends(get_current_user)):
user_id = current_user["id"]
# 验证 material_id 属于当前用户
if not material_id.startswith(f"{user_id}/"):
raise HTTPException(403, "无权删除此素材")
try:
await storage_service.delete_file(
bucket=storage_service.BUCKET_MATERIALS,
path=material_id
)
return {"success": True, "message": "素材已删除"}
except Exception as e:
raise HTTPException(500, f"删除失败: {str(e)}")
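The list and delete handlers above rely on two string conventions: an all-digit prefix before the first underscore is hidden from the display name, and a material ID must start with the caller's `user_id/`. The following is a minimal standalone sketch of both checks. The helper names `display_name_for` and `owns_material` are hypothetical and do not exist in the codebase.

```python
def display_name_for(stored_name: str) -> str:
    """Drop a leading all-digit prefix such as '1706000000_' (hypothetical helper)."""
    if "_" in stored_name:
        prefix, rest = stored_name.split("_", 1)
        if prefix.isdigit():
            return rest
    return stored_name


def owns_material(material_id: str, user_id: str) -> bool:
    """True when the storage path sits under the user's own directory prefix."""
    # The trailing slash matters: "u1/" must not match paths under "u10/".
    return material_id.startswith(f"{user_id}/")


print(display_name_for("1706000000_demo.mp4"))
print(owns_material("user-a/1706000000_demo.mp4", "user-b"))
```

A prefix check alone does not reject `..` segments inside the path. Whether the storage backend normalizes them is not visible in this diff, so that remains an open assumption.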

View File

@@ -1,17 +1,19 @@
"""
发布管理 API
发布管理 API (支持用户认证)
"""
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends, Request
from pydantic import BaseModel
from typing import List, Optional
from datetime import datetime
from loguru import logger
from app.services.publish_service import PublishService
from app.core.deps import get_current_user_optional
router = APIRouter()
publish_service = PublishService()
class PublishRequest(BaseModel):
"""Video publish request model"""
video_path: str
platform: str
title: str
@@ -20,13 +22,43 @@ class PublishRequest(BaseModel):
publish_time: Optional[datetime] = None
class PublishResponse(BaseModel):
"""Video publish response model"""
success: bool
message: str
platform: str
url: Optional[str] = None
@router.post("/", response_model=PublishResponse)
async def publish_video(request: PublishRequest, background_tasks: BackgroundTasks):
# Supported platforms for validation
SUPPORTED_PLATFORMS = {"bilibili", "douyin", "xiaohongshu"}
def _get_user_id(request: Request) -> Optional[str]:
"""从请求中获取用户 ID (兼容未登录场景)"""
try:
from app.core.security import decode_access_token
token = request.cookies.get("access_token")
if token:
token_data = decode_access_token(token)
if token_data:
return token_data.user_id
except Exception:
pass
return None
@router.post("", response_model=PublishResponse)
async def publish_video(request: PublishRequest, req: Request, background_tasks: BackgroundTasks):
"""发布视频到指定平台"""
# Validate platform
if request.platform not in SUPPORTED_PLATFORMS:
raise HTTPException(
status_code=400,
detail=f"不支持的平台: {request.platform}。支持的平台: {', '.join(SUPPORTED_PLATFORMS)}"
)
# 获取用户 ID (可选)
user_id = _get_user_id(req)
try:
result = await publish_service.publish(
video_path=request.video_path,
@@ -34,7 +66,8 @@ async def publish_video(request: PublishRequest, background_tasks: BackgroundTas
title=request.title,
tags=request.tags,
description=request.description,
publish_time=request.publish_time
publish_time=request.publish_time,
user_id=user_id
)
return PublishResponse(
success=result.get("success", False),
@@ -48,12 +81,66 @@ async def publish_video(request: PublishRequest, background_tasks: BackgroundTas
@router.get("/platforms")
async def list_platforms():
return {"platforms": [{"id": pid, **pinfo} for pid, pinfo in publish_service.PLATFORMS.items()]}
return {"platforms": [{**pinfo, "id": pid} for pid, pinfo in publish_service.PLATFORMS.items()]}
@router.get("/accounts")
async def list_accounts():
return {"accounts": publish_service.get_accounts()}
async def list_accounts(req: Request):
user_id = _get_user_id(req)
return {"accounts": publish_service.get_accounts(user_id)}
@router.post("/login/{platform}")
async def login_platform(platform: str):
return await publish_service.login(platform)
async def login_platform(platform: str, req: Request):
"""触发平台QR码登录"""
if platform not in SUPPORTED_PLATFORMS:
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
user_id = _get_user_id(req)
result = await publish_service.login(platform, user_id)
if result.get("success"):
return result
else:
raise HTTPException(status_code=400, detail=result.get("message"))
@router.post("/logout/{platform}")
async def logout_platform(platform: str, req: Request):
"""注销平台登录"""
if platform not in SUPPORTED_PLATFORMS:
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
user_id = _get_user_id(req)
result = publish_service.logout(platform, user_id)
return result
@router.get("/login/status/{platform}")
async def get_login_status(platform: str, req: Request):
"""检查登录状态 (优先检查活跃的扫码会话)"""
if platform not in SUPPORTED_PLATFORMS:
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
user_id = _get_user_id(req)
return publish_service.get_login_session_status(platform, user_id)
@router.post("/cookies/save/{platform}")
async def save_platform_cookie(platform: str, cookie_data: dict, req: Request):
"""
保存从客户端浏览器提取的Cookie
Args:
platform: 平台ID
cookie_data: {"cookie_string": "document.cookie的内容"}
"""
if platform not in SUPPORTED_PLATFORMS:
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
cookie_string = cookie_data.get("cookie_string", "")
if not cookie_string:
raise HTTPException(status_code=400, detail="cookie_string 不能为空")
user_id = _get_user_id(req)
result = await publish_service.save_cookie_string(platform, cookie_string, user_id)
if result.get("success"):
return result
else:
raise HTTPException(status_code=400, detail=result.get("message"))
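The `/cookies/save/{platform}` endpoint accepts the raw contents of `document.cookie`. The diff does not show how `save_cookie_string` parses that value, so the following is only a sketch of one plausible approach. The function name `parse_cookie_string` is an assumption made for illustration.

```python
from typing import Dict


def parse_cookie_string(cookie_string: str) -> Dict[str, str]:
    """Split 'a=1; b=2' into {'a': '1', 'b': '2'} (hypothetical helper)."""
    cookies: Dict[str, str] = {}
    for pair in cookie_string.split(";"):
        # partition keeps any further '=' characters inside the value
        name, sep, value = pair.strip().partition("=")
        if sep and name:
            cookies[name] = value
    return cookies


print(parse_cookie_string("SESSDATA=abc; bili_jct=xyz"))
```

Splitting on the first `=` only keeps values that themselves contain `=` (for example base64 padding) intact.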

View File

@@ -1,14 +1,19 @@
from fastapi import APIRouter, HTTPException, BackgroundTasks
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends, Request
from pydantic import BaseModel
from typing import Optional
from pathlib import Path
from loguru import logger
import uuid
import traceback
import time
import httpx
import os
from app.services.tts_service import TTSService
from app.services.video_service import VideoService
from app.services.lipsync_service import LipSyncService
from app.services.storage import storage_service
from app.core.config import settings
from app.core.deps import get_current_user
router = APIRouter()
@@ -47,42 +52,73 @@ async def _check_lipsync_ready(force: bool = False) -> bool:
print(f"[LipSync] Health check: ready={_lipsync_ready}")
return _lipsync_ready
async def _process_video_generation(task_id: str, req: GenerateRequest):
async def _download_material(path_or_url: str, temp_path: Path):
"""下载素材到临时文件 (流式下载,节省内存)"""
if path_or_url.startswith("http"):
# Download from URL
timeout = httpx.Timeout(None) # Disable timeout for large files
async with httpx.AsyncClient(timeout=timeout) as client:
async with client.stream("GET", path_or_url) as resp:
resp.raise_for_status()
with open(temp_path, "wb") as f:
async for chunk in resp.aiter_bytes():
f.write(chunk)
else:
# Local file (legacy or absolute path)
src = Path(path_or_url)
if not src.is_absolute():
src = settings.BASE_DIR.parent / path_or_url
if src.exists():
import shutil
shutil.copy(src, temp_path)
else:
raise FileNotFoundError(f"Material not found: {path_or_url}")
async def _process_video_generation(task_id: str, req: GenerateRequest, user_id: str):
temp_files = [] # Track files to clean up
try:
start_time = time.time()
# Resolve path if it's relative
input_material_path = Path(req.material_path)
if not input_material_path.is_absolute():
input_material_path = settings.BASE_DIR.parent / req.material_path
tasks[task_id]["status"] = "processing"
tasks[task_id]["progress"] = 5
tasks[task_id]["message"] = "正在初始化..."
tasks[task_id]["message"] = "正在下载素材..."
# Prepare temp dir
temp_dir = settings.UPLOAD_DIR / "temp"
temp_dir.mkdir(parents=True, exist_ok=True)
# 0. Download Material
input_material_path = temp_dir / f"{task_id}_input.mp4"
temp_files.append(input_material_path)
await _download_material(req.material_path, input_material_path)
# 1. TTS - 进度 5% -> 25%
tasks[task_id]["message"] = "正在生成语音 (TTS)..."
tasks[task_id]["progress"] = 10
tts = TTSService()
audio_path = settings.OUTPUT_DIR / f"{task_id}_audio.mp3"
audio_path = temp_dir / f"{task_id}_audio.mp3"
temp_files.append(audio_path)
await tts.generate_audio(req.text, req.voice, str(audio_path))
tts_time = time.time() - start_time
print(f"[Pipeline] TTS completed in {tts_time:.1f}s")
tasks[task_id]["progress"] = 25
# 2. LipSync - 进度 25% -> 85%
tasks[task_id]["message"] = "正在合成唇形 (LatentSync)..."
tasks[task_id]["progress"] = 30
lipsync = _get_lipsync_service()
lipsync_video_path = settings.OUTPUT_DIR / f"{task_id}_lipsync.mp4"
lipsync_video_path = temp_dir / f"{task_id}_lipsync.mp4"
temp_files.append(lipsync_video_path)
# 使用缓存的健康检查结果
lipsync_start = time.time()
is_ready = await _check_lipsync_ready()
if is_ready:
print(f"[LipSync] Starting LatentSync inference...")
tasks[task_id]["progress"] = 35
@@ -98,34 +134,72 @@ async def _process_video_generation(task_id: str, req: GenerateRequest):
lipsync_time = time.time() - lipsync_start
print(f"[Pipeline] LipSync completed in {lipsync_time:.1f}s")
tasks[task_id]["progress"] = 85
# 3. Composition - 进度 85% -> 100%
tasks[task_id]["message"] = "正在合成最终视频..."
tasks[task_id]["progress"] = 90
video = VideoService()
final_output = settings.OUTPUT_DIR / f"{task_id}_output.mp4"
await video.compose(str(lipsync_video_path), str(audio_path), str(final_output))
final_output_local_path = temp_dir / f"{task_id}_output.mp4"
temp_files.append(final_output_local_path)
await video.compose(str(lipsync_video_path), str(audio_path), str(final_output_local_path))
total_time = time.time() - start_time
# 4. Upload to Supabase with user isolation
tasks[task_id]["message"] = "正在上传结果..."
tasks[task_id]["progress"] = 95
# 使用 user_id 作为目录前缀实现隔离
storage_path = f"{user_id}/{task_id}_output.mp4"
with open(final_output_local_path, "rb") as f:
file_data = f.read()
await storage_service.upload_file(
bucket=storage_service.BUCKET_OUTPUTS,
path=storage_path,
file_data=file_data,
content_type="video/mp4"
)
# Get Signed URL
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_OUTPUTS,
path=storage_path
)
print(f"[Pipeline] Total generation time: {total_time:.1f}s")
tasks[task_id]["status"] = "completed"
tasks[task_id]["progress"] = 100
tasks[task_id]["message"] = f"生成完成!耗时 {total_time:.0f}"
tasks[task_id]["output"] = str(final_output)
tasks[task_id]["download_url"] = f"/outputs/{final_output.name}"
tasks[task_id]["output"] = storage_path
tasks[task_id]["download_url"] = signed_url
except Exception as e:
tasks[task_id]["status"] = "failed"
tasks[task_id]["message"] = f"错误: {str(e)}"
tasks[task_id]["error"] = traceback.format_exc()
logger.error(f"Generate video failed: {e}")
finally:
# Cleanup temp files
for f in temp_files:
try:
if f.exists():
f.unlink()
except Exception as e:
print(f"Error cleaning up {f}: {e}")
@router.post("/generate")
async def generate_video(req: GenerateRequest, background_tasks: BackgroundTasks):
async def generate_video(
req: GenerateRequest,
background_tasks: BackgroundTasks,
current_user: dict = Depends(get_current_user)
):
user_id = current_user["id"]
task_id = str(uuid.uuid4())
tasks[task_id] = {"status": "pending", "task_id": task_id, "progress": 0}
background_tasks.add_task(_process_video_generation, task_id, req)
tasks[task_id] = {"status": "pending", "task_id": task_id, "progress": 0, "user_id": user_id}
background_tasks.add_task(_process_video_generation, task_id, req, user_id)
return {"task_id": task_id}
@router.get("/tasks/{task_id}")
@@ -141,3 +215,85 @@ async def lipsync_health():
"""获取 LipSync 服务健康状态"""
lipsync = _get_lipsync_service()
return await lipsync.check_health()
@router.get("/generated")
async def list_generated_videos(current_user: dict = Depends(get_current_user)):
"""从 Storage 读取当前用户生成的视频列表"""
user_id = current_user["id"]
try:
# 只列出当前用户目录下的文件
files_obj = await storage_service.list_files(
bucket=storage_service.BUCKET_OUTPUTS,
path=user_id
)
videos = []
for f in files_obj:
name = f.get('name')
if not name or name == '.emptyFolderPlaceholder':
continue
# 过滤非 output.mp4 文件
if not name.endswith("_output.mp4"):
continue
# 获取 ID (即文件名去除后缀)
video_id = Path(name).stem
# 完整路径包含 user_id
full_path = f"{user_id}/{name}"
# 获取签名链接
signed_url = await storage_service.get_signed_url(
bucket=storage_service.BUCKET_OUTPUTS,
path=full_path
)
metadata = f.get('metadata', {})
size = metadata.get('size', 0)
# created_at 在顶层,是 ISO 字符串,转换为 Unix 时间戳
created_at_str = f.get('created_at', '')
created_at = 0
if created_at_str:
from datetime import datetime
try:
dt = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
created_at = int(dt.timestamp())
except (ValueError, TypeError):  # malformed timestamp; keep created_at = 0
pass
videos.append({
"id": video_id,
"name": name,
"path": signed_url, # Direct playable URL
"size_mb": size / (1024 * 1024),
"created_at": created_at
})
# Sort by created_at desc (newest first)
# created_at was converted to a Unix timestamp (int) above, so compare numerically
videos.sort(key=lambda x: x.get("created_at", 0), reverse=True)
return {"videos": videos}
except Exception as e:
logger.error(f"List generated videos failed: {e}")
return {"videos": []}
@router.delete("/generated/{video_id}")
async def delete_generated_video(video_id: str, current_user: dict = Depends(get_current_user)):
"""删除生成的视频"""
user_id = current_user["id"]
try:
# video_id is typically "<uuid>_output"; prefix it with user_id to build the full storage path
storage_path = f"{user_id}/{video_id}.mp4"
await storage_service.delete_file(
bucket=storage_service.BUCKET_OUTPUTS,
path=storage_path
)
return {"success": True, "message": "视频已删除"}
except Exception as e:
raise HTTPException(500, f"删除失败: {str(e)}")
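Both list endpoints convert the storage layer's ISO 8601 `created_at` string into a Unix timestamp before sorting. The conversion can be sketched in isolation as below. The helper name `iso_to_unix` and the sample records are assumptions for illustration, not part of the codebase.

```python
from datetime import datetime


def iso_to_unix(created_at: str) -> int:
    """Convert an ISO 8601 string (optionally ending in 'Z') to a Unix timestamp; 0 on failure."""
    if not created_at:
        return 0
    try:
        # Older Python versions do not accept a trailing 'Z', hence the replace().
        return int(datetime.fromisoformat(created_at.replace("Z", "+00:00")).timestamp())
    except ValueError:
        return 0


videos = [
    {"id": "a", "created_at": iso_to_unix("2026-01-20T09:00:00Z")},
    {"id": "b", "created_at": iso_to_unix("2026-01-28T09:00:00Z")},
    {"id": "c", "created_at": iso_to_unix("not-a-date")},
]
# Integer timestamps sort numerically; unparseable entries (0) end up last.
videos.sort(key=lambda v: v["created_at"], reverse=True)
print([v["id"] for v in videos])
```

Before Python 3.11, `datetime.fromisoformat` is strict about fractional seconds (it expects 3 or 6 digits), so timestamps with other precisions would fall into the failure branch and sort last.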

View File

@@ -18,11 +18,27 @@ class Settings(BaseSettings):
# LatentSync 配置
LATENTSYNC_GPU_ID: int = 1 # GPU ID (默认使用 GPU1)
LATENTSYNC_LOCAL: bool = True # 使用本地推理 (False 则使用远程 API)
LATENTSYNC_API_URL: str = "http://localhost:8001" # 远程 API 地址
LATENTSYNC_API_URL: str = "http://localhost:8007" # 远程 API 地址
LATENTSYNC_INFERENCE_STEPS: int = 20 # 推理步数 [20-50]
LATENTSYNC_GUIDANCE_SCALE: float = 1.5 # 引导系数 [1.0-3.0]
LATENTSYNC_ENABLE_DEEPCACHE: bool = True # 启用 DeepCache 加速
LATENTSYNC_ENABLE_DEEPCACHE: bool = True # 启用 DeepCache 加速
LATENTSYNC_SEED: int = 1247 # 随机种子 (-1 则随机)
LATENTSYNC_USE_SERVER: bool = False # 使用常驻服务 (Persistent Server) 加速
# Supabase 配置
SUPABASE_URL: str = ""
SUPABASE_PUBLIC_URL: str = "" # 公网访问地址,用于生成前端可访问的 URL
SUPABASE_KEY: str = ""
# JWT 配置
JWT_SECRET_KEY: str = "your-secret-key-change-in-production"
JWT_ALGORITHM: str = "HS256"
JWT_EXPIRE_HOURS: int = 24
# 管理员配置
ADMIN_EMAIL: str = ""
ADMIN_PASSWORD: str = ""
@property
def LATENTSYNC_DIR(self) -> Path:

141
backend/app/core/deps.py Normal file
View File

@@ -0,0 +1,141 @@
"""
依赖注入模块:认证和用户获取
"""
from typing import Optional
from fastapi import Request, HTTPException, Depends, status
from app.core.security import decode_access_token, TokenData
from app.core.supabase import get_supabase
from loguru import logger
async def get_token_from_cookie(request: Request) -> Optional[str]:
"""从 Cookie 中获取 Token"""
return request.cookies.get("access_token")
async def get_current_user_optional(
request: Request
) -> Optional[dict]:
"""
获取当前用户 (可选,未登录返回 None)
"""
token = await get_token_from_cookie(request)
if not token:
return None
token_data = decode_access_token(token)
if not token_data:
return None
# 验证 session_token 是否有效 (单设备登录检查)
try:
supabase = get_supabase()
result = supabase.table("user_sessions").select("*").eq(
"user_id", token_data.user_id
).eq(
"session_token", token_data.session_token
).execute()
if not result.data:
logger.warning(f"Session token 无效: user_id={token_data.user_id}")
return None
# 获取用户信息
user_result = supabase.table("users").select("*").eq(
"id", token_data.user_id
).single().execute()
return user_result.data
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
return None
async def get_current_user(
request: Request
) -> dict:
"""
获取当前用户 (必须登录)
Raises:
HTTPException 401: 未登录
HTTPException 403: 会话失效或授权过期
"""
token = await get_token_from_cookie(request)
if not token:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="未登录,请先登录"
)
token_data = decode_access_token(token)
if not token_data:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Token 无效或已过期"
)
try:
supabase = get_supabase()
# 验证 session_token (单设备登录)
session_result = supabase.table("user_sessions").select("*").eq(
"user_id", token_data.user_id
).eq(
"session_token", token_data.session_token
).execute()
if not session_result.data:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="会话已失效,请重新登录(可能已在其他设备登录)"
)
# 获取用户信息
user_result = supabase.table("users").select("*").eq(
"id", token_data.user_id
).single().execute()
user = user_result.data
if not user:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="用户不存在"
)
# 检查授权是否过期
if user.get("expires_at"):
from datetime import datetime, timezone
expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
if datetime.now(timezone.utc) > expires_at:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="授权已过期,请联系管理员续期"
)
return user
except HTTPException:
raise
except Exception as e:
logger.error(f"获取用户信息失败: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="服务器错误"
)
async def get_current_admin(
current_user: dict = Depends(get_current_user)
) -> dict:
"""
获取当前管理员用户
Raises:
HTTPException 403: 非管理员
"""
if current_user.get("role") != "admin":
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="需要管理员权限"
)
return current_user
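The license check inside `get_current_user` compares a timezone-aware "now" against the stored `expires_at` string, and a missing value means the account never expires. The following is a minimal sketch of that rule as a pure function. The name `is_license_expired` and the injectable `now` parameter are assumptions added to make the logic testable.

```python
from datetime import datetime, timezone
from typing import Optional


def is_license_expired(expires_at: Optional[str], now: Optional[datetime] = None) -> bool:
    """None or empty means 'never expires', as for the admin row created with expires_at=None."""
    if not expires_at:
        return False
    deadline = datetime.fromisoformat(expires_at.replace("Z", "+00:00"))
    current = now or datetime.now(timezone.utc)
    return current > deadline


fixed_now = datetime(2026, 1, 28, tzinfo=timezone.utc)
print(is_license_expired("2026-01-01T00:00:00Z", fixed_now))
print(is_license_expired(None, fixed_now))
```

If the stored string carries no UTC offset, the parsed value is naive and the comparison with an aware datetime raises `TypeError`. In the handler above, that error would surface through the generic 500 branch.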

98
backend/app/core/paths.py Normal file
View File

@@ -0,0 +1,98 @@
"""
路径规范化模块:按用户隔离 Cookie 存储
"""
from pathlib import Path
import re
from typing import Set
# 基础目录
BASE_DIR = Path(__file__).parent.parent.parent
USER_DATA_DIR = BASE_DIR / "user_data"
# 有效的平台列表
VALID_PLATFORMS: Set[str] = {"bilibili", "douyin", "xiaohongshu", "weixin", "kuaishou"}
# UUID 格式正则
UUID_PATTERN = re.compile(r'^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$', re.IGNORECASE)
def validate_user_id(user_id: str) -> bool:
"""验证 user_id 格式 (防止路径遍历攻击)"""
# fullmatch: "$" alone would also accept a trailing newline after the UUID
return bool(UUID_PATTERN.fullmatch(user_id))
def validate_platform(platform: str) -> bool:
"""验证平台名称"""
return platform in VALID_PLATFORMS
def get_user_data_dir(user_id: str) -> Path:
"""
获取用户数据根目录
Args:
user_id: 用户 UUID
Returns:
用户数据目录路径
Raises:
ValueError: user_id 格式无效
"""
if not validate_user_id(user_id):
raise ValueError(f"Invalid user_id format: {user_id}")
user_dir = USER_DATA_DIR / user_id
user_dir.mkdir(parents=True, exist_ok=True)
return user_dir
def get_user_cookie_dir(user_id: str) -> Path:
"""
获取用户 Cookie 目录
Args:
user_id: 用户 UUID
Returns:
Cookie 目录路径
"""
cookie_dir = get_user_data_dir(user_id) / "cookies"
cookie_dir.mkdir(parents=True, exist_ok=True)
return cookie_dir
def get_platform_cookie_path(user_id: str, platform: str) -> Path:
"""
获取平台 Cookie 文件路径
Args:
user_id: 用户 UUID
platform: 平台名称 (bilibili/douyin/xiaohongshu)
Returns:
Cookie 文件路径
Raises:
ValueError: 平台名称无效
"""
if not validate_platform(platform):
raise ValueError(f"Invalid platform: {platform}. Valid: {VALID_PLATFORMS}")
return get_user_cookie_dir(user_id) / f"{platform}_cookies.json"
# === 兼容旧代码的路径 (无用户隔离) ===
def get_legacy_cookie_dir() -> Path:
"""获取旧版 Cookie 目录 (无用户隔离)"""
cookie_dir = BASE_DIR / "app" / "cookies"
cookie_dir.mkdir(parents=True, exist_ok=True)
return cookie_dir
def get_legacy_cookie_path(platform: str) -> Path:
"""获取旧版 Cookie 路径 (无用户隔离)"""
if not validate_platform(platform):
raise ValueError(f"Invalid platform: {platform}")
return get_legacy_cookie_dir() / f"{platform}_cookies.json"
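The path module guards against path traversal by accepting only UUID-shaped user IDs before building a directory path. A standalone sketch of that guard follows. The function name `is_safe_user_id` is hypothetical. `fullmatch` is used here because in Python's `re` module `$` also matches just before a trailing newline.

```python
import re

UUID_PATTERN = re.compile(
    r"[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}",
    re.IGNORECASE,
)


def is_safe_user_id(user_id: str) -> bool:
    """Accept only strings that are exactly one UUID, with nothing before or after."""
    return UUID_PATTERN.fullmatch(user_id) is not None


valid = "123e4567-e89b-12d3-a456-426614174000"
print(is_safe_user_id(valid))
print(is_safe_user_id("../../etc/passwd"))
print(is_safe_user_id(valid + "\n"))
```

Because a UUID contains neither `/` nor `.`, any ID that passes this check cannot escape the per-user directory when it is joined onto a base path.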

View File

@@ -0,0 +1,112 @@
"""
Security utilities: JWT tokens and password handling
"""
from datetime import datetime, timedelta, timezone
from typing import Optional, Any
from jose import jwt, JWTError
from passlib.context import CryptContext
from pydantic import BaseModel
from fastapi import Response
from app.core.config import settings
import uuid
# 密码加密上下文
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
class TokenData(BaseModel):
"""JWT Token 数据结构"""
user_id: str
session_token: str
exp: datetime
def verify_password(plain_password: str, hashed_password: str) -> bool:
"""验证密码"""
return pwd_context.verify(plain_password, hashed_password)
def get_password_hash(password: str) -> str:
"""生成密码哈希"""
return pwd_context.hash(password)
def create_access_token(user_id: str, session_token: str) -> str:
"""
创建 JWT Access Token
Args:
user_id: 用户 ID
session_token: 会话 Token (用于单设备登录验证)
"""
expire = datetime.now(timezone.utc) + timedelta(hours=settings.JWT_EXPIRE_HOURS)
to_encode = {
"sub": user_id,
"session_token": session_token,
"exp": expire
}
return jwt.encode(
to_encode,
settings.JWT_SECRET_KEY,
algorithm=settings.JWT_ALGORITHM
)
def decode_access_token(token: str) -> Optional[TokenData]:
"""
解码并验证 JWT Token
Returns:
TokenData 或 None (如果验证失败)
"""
try:
payload = jwt.decode(
token,
settings.JWT_SECRET_KEY,
algorithms=[settings.JWT_ALGORITHM]
)
user_id = payload.get("sub")
session_token = payload.get("session_token")
exp = payload.get("exp")
if not user_id or not session_token:
return None
return TokenData(
user_id=user_id,
session_token=session_token,
exp=datetime.fromtimestamp(exp, tz=timezone.utc)
)
except JWTError:
return None
def generate_session_token() -> str:
"""生成新的会话 Token"""
return str(uuid.uuid4())
def set_auth_cookie(response: Response, token: str) -> None:
"""
设置 HttpOnly Cookie
Args:
response: FastAPI Response 对象
token: JWT Token
"""
response.set_cookie(
key="access_token",
value=token,
httponly=True,
secure=not settings.DEBUG, # 开发/测试环境(DEBUG=True)允许非HTTPS
samesite="lax",
max_age=settings.JWT_EXPIRE_HOURS * 3600
)
def clear_auth_cookie(response: Response) -> None:
"""清除认证 Cookie"""
response.delete_cookie(key="access_token")

View File

@@ -0,0 +1,26 @@
"""
Supabase 客户端初始化
"""
from supabase import create_client, Client
from app.core.config import settings
from loguru import logger
from typing import Optional
_supabase_client: Optional[Client] = None
def get_supabase() -> Client:
"""获取 Supabase 客户端单例"""
global _supabase_client
if _supabase_client is None:
if not settings.SUPABASE_URL or not settings.SUPABASE_KEY:
raise ValueError("SUPABASE_URL 和 SUPABASE_KEY 必须在 .env 中配置")
_supabase_client = create_client(
settings.SUPABASE_URL,
settings.SUPABASE_KEY
)
logger.info("Supabase 客户端已初始化")
return _supabase_client

View File

@@ -2,12 +2,36 @@ from fastapi import FastAPI
from fastapi.staticfiles import StaticFiles
from fastapi.middleware.cors import CORSMiddleware
from app.core import config
from app.api import materials, videos, publish
from app.api import materials, videos, publish, login_helper, auth, admin
from loguru import logger
import os
settings = config.settings
app = FastAPI(title="ViGent TalkingHead Agent")
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
import time
import traceback
class LoggingMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
start_time = time.time()
logger.info(f"START Request: {request.method} {request.url}")
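# Redact credentials so session cookies and tokens never reach the logs
safe_headers = {k: ("<redacted>" if k.lower() in ("cookie", "authorization") else v) for k, v in request.headers.items()}
logger.info(f"HEADERS: {safe_headers}")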
try:
response = await call_next(request)
process_time = time.time() - start_time
logger.info(f"END Request: {request.method} {request.url} - Status: {response.status_code} - Duration: {process_time:.2f}s")
return response
except Exception as e:
process_time = time.time() - start_time
logger.error(f"EXCEPTION during request {request.method} {request.url}: {str(e)}\n{traceback.format_exc()}")
raise e
app.add_middleware(LoggingMiddleware)
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
@@ -22,10 +46,56 @@ settings.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
(settings.UPLOAD_DIR / "materials").mkdir(exist_ok=True)
app.mount("/outputs", StaticFiles(directory=str(settings.OUTPUT_DIR)), name="outputs")
app.mount("/uploads", StaticFiles(directory=str(settings.UPLOAD_DIR)), name="uploads")
# 注册路由
app.include_router(materials.router, prefix="/api/materials", tags=["Materials"])
app.include_router(videos.router, prefix="/api/videos", tags=["Videos"])
app.include_router(publish.router, prefix="/api/publish", tags=["Publish"])
app.include_router(login_helper.router, prefix="/api", tags=["LoginHelper"])
app.include_router(auth.router) # /api/auth
app.include_router(admin.router) # /api/admin
@app.on_event("startup")
async def init_admin():
"""
服务启动时初始化管理员账号
"""
admin_email = settings.ADMIN_EMAIL
admin_password = settings.ADMIN_PASSWORD
if not admin_email or not admin_password:
logger.warning("ADMIN_EMAIL and ADMIN_PASSWORD are not configured; skipping admin initialization")
return
try:
from app.core.supabase import get_supabase
from app.core.security import get_password_hash
supabase = get_supabase()
# 检查是否已存在
existing = supabase.table("users").select("id").eq("email", admin_email).execute()
if existing.data:
logger.info(f"管理员账号已存在: {admin_email}")
return
# 创建管理员
supabase.table("users").insert({
"email": admin_email,
"password_hash": get_password_hash(admin_password),
"username": "Admin",
"role": "admin",
"is_active": True,
"expires_at": None # 永不过期
}).execute()
logger.success(f"管理员账号已创建: {admin_email}")
except Exception as e:
logger.error(f"初始化管理员失败: {e}")
@app.get("/health")
def health():

View File

@@ -7,6 +7,7 @@ import os
import shutil
import subprocess
import tempfile
import asyncio
import httpx
from pathlib import Path
from loguru import logger
@@ -23,6 +24,10 @@ class LipSyncService:
self.api_url = settings.LATENTSYNC_API_URL
self.latentsync_dir = settings.LATENTSYNC_DIR
self.gpu_id = settings.LATENTSYNC_GPU_ID
self.use_server = settings.LATENTSYNC_USE_SERVER
# GPU 并发锁 (Serial Queue)
self._lock = asyncio.Lock()
# Conda 环境 Python 路径
# 根据服务器实际情况调整
@@ -197,98 +202,163 @@ class LipSyncService:
shutil.copy(video_path, output_path)
return output_path
logger.info("🔄 调用 LatentSync 推理 (subprocess)...")
# 使用临时目录存放输出
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
temp_output = tmpdir / "output.mp4"
# 视频预处理:压缩高分辨率视频以加速处理
preprocessed_video = tmpdir / "preprocessed_input.mp4"
actual_video_path = self._preprocess_video(
video_path,
str(preprocessed_video),
target_height=720
)
# 构建命令
cmd = [
str(self.conda_python),
"-m", "scripts.inference",
"--unet_config_path", "configs/unet/stage2_512.yaml",
"--inference_ckpt_path", "checkpoints/latentsync_unet.pt",
"--inference_steps", str(settings.LATENTSYNC_INFERENCE_STEPS),
"--guidance_scale", str(settings.LATENTSYNC_GUIDANCE_SCALE),
"--video_path", str(actual_video_path), # 使用预处理后的视频
"--audio_path", str(audio_path),
"--video_out_path", str(temp_output),
"--seed", str(settings.LATENTSYNC_SEED),
"--temp_dir", str(tmpdir / "cache"),
]
if settings.LATENTSYNC_ENABLE_DEEPCACHE:
cmd.append("--enable_deepcache")
# 设置环境变量
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
logger.info(f"🖥️ 执行命令: {' '.join(cmd[:8])}...")
logger.info(f"🖥️ GPU: CUDA_VISIBLE_DEVICES={self.gpu_id}")
try:
import asyncio
logger.info("⏳ 等待 GPU 资源 (排队中)...")
async with self._lock:
if self.use_server:
# 模式 A: 调用常驻服务 (加速模式)
return await self._call_persistent_server(video_path, audio_path, output_path)
# 使用 asyncio subprocess 实现真正的异步执行
# 这样事件循环可以继续处理其他请求(如进度查询)
process = await asyncio.create_subprocess_exec(
*cmd,
cwd=str(self.latentsync_dir),
env=env,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
logger.info("🔄 调用 LatentSync 推理 (subprocess)...")
# 使用临时目录存放输出
with tempfile.TemporaryDirectory() as tmpdir:
tmpdir = Path(tmpdir)
temp_output = tmpdir / "output.mp4"
# Video preprocessing (downscaling high-resolution input to speed up inference)
# is temporarily disabled so that the output keeps the source resolution:
# preprocessed_video = tmpdir / "preprocessed_input.mp4"
# actual_video_path = self._preprocess_video(
# video_path,
# str(preprocessed_video),
# target_height=720
# )
actual_video_path = video_path
# 构建命令
cmd = [
str(self.conda_python),
"-m", "scripts.inference",
"--unet_config_path", "configs/unet/stage2_512.yaml",
"--inference_ckpt_path", "checkpoints/latentsync_unet.pt",
"--inference_steps", str(settings.LATENTSYNC_INFERENCE_STEPS),
"--guidance_scale", str(settings.LATENTSYNC_GUIDANCE_SCALE),
"--video_path", str(actual_video_path), # original video (preprocessing is currently disabled)
"--audio_path", str(audio_path),
"--video_out_path", str(temp_output),
"--seed", str(settings.LATENTSYNC_SEED),
"--temp_dir", str(tmpdir / "cache"),
]
if settings.LATENTSYNC_ENABLE_DEEPCACHE:
cmd.append("--enable_deepcache")
# 设置环境变量
env = os.environ.copy()
env["CUDA_VISIBLE_DEVICES"] = str(self.gpu_id)
logger.info(f"🖥️ 执行命令: {' '.join(cmd[:8])}...")
logger.info(f"🖥️ GPU: CUDA_VISIBLE_DEVICES={self.gpu_id}")
# 等待进程完成,带超时
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout=900 # 15分钟超时
# 使用 asyncio subprocess 实现真正的异步执行
# 这样事件循环可以继续处理其他请求(如进度查询)
process = await asyncio.create_subprocess_exec(
*cmd,
cwd=str(self.latentsync_dir),
env=env,
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
)
except asyncio.TimeoutError:
process.kill()
await process.wait()
logger.error("⏰ LatentSync 推理超时 (15分钟)")
shutil.copy(video_path, output_path)
return output_path
stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else ""
if process.returncode != 0:
logger.error(f"LatentSync 推理失败:\n{stderr_text}")
logger.error(f"stdout:\n{stdout_text[-1000:] if stdout_text else 'N/A'}")
# Fallback
shutil.copy(video_path, output_path)
return output_path
logger.info(f"LatentSync 输出:\n{stdout_text[-500:] if stdout_text else 'N/A'}")
# 检查输出文件
if temp_output.exists():
shutil.copy(temp_output, output_path)
logger.info(f"✅ 唇形同步完成: {output_path}")
return output_path
else:
logger.warning("⚠️ 未找到输出文件,使用 Fallback")
shutil.copy(video_path, output_path)
return output_path
except Exception as e:
logger.error(f"❌ 推理异常: {e}")
shutil.copy(video_path, output_path)
return output_path
# 等待进程完成,带超时
try:
stdout, stderr = await asyncio.wait_for(
process.communicate(),
timeout=900 # 15分钟超时
)
except asyncio.TimeoutError:
process.kill()
await process.wait()
logger.error("⏰ LatentSync 推理超时 (15分钟)")
shutil.copy(video_path, output_path)
return output_path
stdout_text = stdout.decode() if stdout else ""
stderr_text = stderr.decode() if stderr else ""
if process.returncode != 0:
logger.error(f"LatentSync 推理失败:\n{stderr_text}")
logger.error(f"stdout:\n{stdout_text[-1000:] if stdout_text else 'N/A'}")
# Fallback
shutil.copy(video_path, output_path)
return output_path
logger.info(f"LatentSync 输出:\n{stdout_text[-500:] if stdout_text else 'N/A'}")
# 检查输出文件
if temp_output.exists():
shutil.copy(temp_output, output_path)
logger.info(f"✅ 唇形同步完成: {output_path}")
return output_path
else:
logger.warning("⚠️ 未找到输出文件,使用 Fallback")
shutil.copy(video_path, output_path)
return output_path
except Exception as e:
logger.error(f"❌ 推理异常: {e}")
shutil.copy(video_path, output_path)
return output_path
async def _call_persistent_server(self, video_path: str, audio_path: str, output_path: str) -> str:
"""调用本地常驻服务 (server.py)"""
server_url = "http://localhost:8007"
logger.info(f"⚡ 调用常驻服务: {server_url}")
# 准备请求数据 (传递绝对路径)
payload = {
"video_path": str(Path(video_path).resolve()),
"audio_path": str(Path(audio_path).resolve()),
"video_out_path": str(Path(output_path).resolve()),
"inference_steps": settings.LATENTSYNC_INFERENCE_STEPS,
"guidance_scale": settings.LATENTSYNC_GUIDANCE_SCALE,
"seed": settings.LATENTSYNC_SEED,
"temp_dir": os.path.join(tempfile.gettempdir(), "latentsync_temp")
}
try:
async with httpx.AsyncClient(timeout=1200.0) as client:
# 先检查健康状态
try:
resp = await client.get(f"{server_url}/health", timeout=5.0)
if resp.status_code != 200:
logger.warning("⚠️ 常驻服务健康检查失败,回退到 subprocess")
return await self._local_generate_subprocess(video_path, audio_path, output_path)
except Exception:
logger.warning("⚠️ 无法连接常驻服务,回退到 subprocess")
return await self._local_generate_subprocess(video_path, audio_path, output_path)
# 发送生成请求
response = await client.post(f"{server_url}/lipsync", json=payload)
if response.status_code == 200:
result = response.json()
if Path(result["output_path"]).exists():
logger.info(f"✅ 常驻服务推理完成: {output_path}")
return output_path
logger.error(f"❌ 常驻服务报错: {response.text}")
raise RuntimeError(f"Server Error: {response.text}")
except Exception as e:
logger.error(f"❌ 常驻服务调用失败: {e}")
# 这里可以选择回退,或者直接报错
raise e
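async def _local_generate_subprocess(self, video_path: str, audio_path: str, output_path: str) -> str:
"""Placeholder for the original subprocess logic as a standalone method (not implemented)."""
# NOTE: placeholder only. To keep this change small, the subprocess logic was not
# split out of _local_generate; it remains in the second half of that method.
# Falling back automatically after a persistent-server failure could double the
# resource usage, so failing loudly and letting the operator check the server is
# the safer default, although a fallback would be friendlier to users.
# Intended strategy: when self.use_server is True, try the server first, return on
# success, and otherwise continue with the subprocess path inside _local_generate.
# Until that is implemented, raise instead of silently returning None to the
# fallback call sites in _call_persistent_server.
raise NotImplementedError("_local_generate_subprocess is a placeholder; subprocess inference lives in _local_generate")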
async def _remote_generate(
self,
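The health-check-then-fallback flow used by `_call_persistent_server` can be sketched in isolation. This is a minimal illustration under stated assumptions, not the project's implementation: `generate_with_fallback` and the three callables are hypothetical names, and the real HTTP and subprocess calls are replaced by stand-in coroutines.

```python
import asyncio
from typing import Awaitable, Callable


async def generate_with_fallback(
    health_check: Callable[[], Awaitable[bool]],
    primary: Callable[[], Awaitable[str]],
    fallback: Callable[[], Awaitable[str]],
) -> str:
    """Use `primary` when the health check passes, otherwise `fallback`.

    Errors raised by `primary` itself are propagated rather than retried, so a
    failed inference is not silently run a second time.
    """
    try:
        healthy = await health_check()
    except Exception:
        # An unreachable server counts as unhealthy.
        healthy = False
    if not healthy:
        return await fallback()
    return await primary()


async def _demo() -> list:
    async def up() -> bool:
        return True

    async def down() -> bool:
        raise ConnectionError("server not reachable")

    async def via_server() -> str:
        return "server"

    async def via_subprocess() -> str:
        return "subprocess"

    return [
        await generate_with_fallback(up, via_server, via_subprocess),
        await generate_with_fallback(down, via_server, via_subprocess),
    ]


results = asyncio.run(_demo())
```

Keeping the fallback call outside the `try` around the health check means an exception from the fallback is not mistaken for a failed health check.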


@@ -1,71 +1,368 @@
"""
Publishing service (Playwright)
Publishing service (with user isolation)
"""
from playwright.async_api import async_playwright
from pathlib import Path
import json
import asyncio
import os
import re
import tempfile
import httpx
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
from loguru import logger
from app.core.config import settings
from app.core.paths import get_user_cookie_dir, get_platform_cookie_path, get_legacy_cookie_dir, get_legacy_cookie_path
from app.services.storage import storage_service
# Import platform uploaders
from .uploader.bilibili_uploader import BilibiliUploader
from .uploader.douyin_uploader import DouyinUploader
from .uploader.xiaohongshu_uploader import XiaohongshuUploader
class PublishService:
PLATFORMS = {
"douyin": {"name": "抖音", "url": "https://creator.douyin.com/"},
"xiaohongshu": {"name": "小红书", "url": "https://creator.xiaohongshu.com/"},
"weixin": {"name": "微信视频号", "url": "https://channels.weixin.qq.com/"},
"kuaishou": {"name": "快手", "url": "https://cp.kuaishou.com/"},
"bilibili": {"name": "B站", "url": "https://member.bilibili.com/platform/upload/video/frame"},
"""Social media publishing service (with user isolation)"""
    # Supported platform configuration
PLATFORMS: Dict[str, Dict[str, Any]] = {
"bilibili": {"name": "B站", "url": "https://member.bilibili.com/platform/upload/video/frame", "enabled": True},
"douyin": {"name": "抖音", "url": "https://creator.douyin.com/", "enabled": True},
"xiaohongshu": {"name": "小红书", "url": "https://creator.xiaohongshu.com/", "enabled": True},
"weixin": {"name": "微信视频号", "url": "https://channels.weixin.qq.com/", "enabled": False},
"kuaishou": {"name": "快手", "url": "https://cp.kuaishou.com/", "enabled": False},
}
def __init__(self):
self.cookies_dir = settings.BASE_DIR / "cookies"
self.cookies_dir.mkdir(exist_ok=True)
def get_accounts(self):
    def __init__(self) -> None:
        # Active login sessions, used to track login state
        # Key format: "{user_id}_{platform}", or "{platform}" for legacy compatibility
        self.active_login_sessions: Dict[str, Any] = {}
    def _get_cookies_dir(self, user_id: Optional[str] = None) -> Path:
        """Get the cookie directory (with user isolation)"""
        if user_id:
            return get_user_cookie_dir(user_id)
        return get_legacy_cookie_dir()
    def _get_cookie_path(self, platform: str, user_id: Optional[str] = None) -> Path:
        """Get the cookie file path (with user isolation)"""
        if user_id:
            return get_platform_cookie_path(user_id, platform)
        return get_legacy_cookie_path(platform)
    def _get_session_key(self, platform: str, user_id: Optional[str] = None) -> str:
        """Get the session key"""
        if user_id:
            return f"{user_id}_{platform}"
        return platform
def get_accounts(self, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
"""Get list of platform accounts with login status"""
accounts = []
for pid, pinfo in self.PLATFORMS.items():
cookie_file = self.cookies_dir / f"{pid}_cookies.json"
cookie_file = self._get_cookie_path(pid, user_id)
accounts.append({
"platform": pid,
"name": pinfo["name"],
"logged_in": cookie_file.exists(),
"enabled": True
"enabled": pinfo.get("enabled", True)
})
return accounts
async def login(self, platform: str):
if platform not in self.PLATFORMS:
raise ValueError("Unsupported platform")
pinfo = self.PLATFORMS[platform]
logger.info(f"Logging in to {platform}...")
async def publish(
self,
video_path: str,
platform: str,
title: str,
tags: List[str],
description: str = "",
publish_time: Optional[datetime] = None,
user_id: Optional[str] = None,
**kwargs: Any
) -> Dict[str, Any]:
"""
Publish video to specified platform
async with async_playwright() as p:
browser = await p.chromium.launch(headless=False)
context = await browser.new_context()
page = await context.new_page()
Args:
video_path: Path to video file
platform: Platform ID (bilibili, douyin, etc.)
title: Video title
tags: List of tags
description: Video description
publish_time: Scheduled publish time (None = immediate)
user_id: User ID for cookie isolation
**kwargs: Additional platform-specific parameters
await page.goto(pinfo["url"])
logger.info("Please login manually in the browser window...")
# Wait for user input (naive check via title or url change, or explicit timeout)
# For simplicity in restore, wait for 60s or until manually closed?
# In a real API, this blocks.
# We implemented a simplistic wait in the previous iteration.
try:
await page.wait_for_timeout(45000) # Give user 45s to login
cookies = await context.cookies()
cookie_path = self.cookies_dir / f"{platform}_cookies.json"
with open(cookie_path, "w") as f:
json.dump(cookies, f)
return {"success": True, "message": f"Login {platform} successful"}
except Exception as e:
return {"success": False, "message": str(e)}
finally:
await browser.close()
Returns:
dict: Publish result
"""
# Validate platform
if platform not in self.PLATFORMS:
logger.error(f"[发布] 不支持的平台: {platform}")
return {
"success": False,
"message": f"不支持的平台: {platform}",
"platform": platform
}
# Get account file path (with user isolation)
account_file = self._get_cookie_path(platform, user_id)
if not account_file.exists():
return {
"success": False,
"message": f"请先登录 {self.PLATFORMS[platform]['name']}",
"platform": platform
}
logger.info(f"[发布] 平台: {self.PLATFORMS[platform]['name']}")
logger.info(f"[发布] 视频: {video_path}")
logger.info(f"[发布] 标题: {title}")
logger.info(f"[发布] 用户: {user_id or 'legacy'}")
async def publish(self, video_path: str, platform: str, title: str, **kwargs):
# Placeholder for actual automation logic
# Real implementation requires complex selectors per platform
await asyncio.sleep(2)
return {"success": True, "message": f"Published to {platform} (Mock)", "url": ""}
temp_file = None
try:
            # Resolve the video path
            if video_path.startswith(('http://', 'https://')):
                # Try to parse bucket and path from the URL and use the local file directly
                local_video_path = None
                # URL format: .../storage/v1/object/sign/{bucket}/{path}?token=...
                match = re.search(r'/storage/v1/object/sign/([^/]+)/(.+?)\?', video_path)
                if match:
                    bucket = match.group(1)
                    storage_path = match.group(2)
                    logger.info(f"[发布] 解析 URL: bucket={bucket}, path={storage_path}")
                    # Try to resolve the local file path
                    local_video_path = storage_service.get_local_file_path(bucket, storage_path)
                if local_video_path and os.path.exists(local_video_path):
                    logger.info(f"[发布] 直接使用本地文件: {local_video_path}")
                else:
                    # No local file available: download over HTTP
                    logger.info("[发布] 本地文件不存在,通过 HTTP 下载...")
                    temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
                    temp_file.close()
                    # Swap the public URL for the internal one
                    download_url = video_path
                    if settings.SUPABASE_PUBLIC_URL and settings.SUPABASE_URL:
                        public_url = settings.SUPABASE_PUBLIC_URL.rstrip('/')
                        internal_url = settings.SUPABASE_URL.rstrip('/')
                        download_url = video_path.replace(public_url, internal_url)
                    async with httpx.AsyncClient(timeout=httpx.Timeout(None)) as client:
                        async with client.stream("GET", download_url) as resp:
                            resp.raise_for_status()
                            with open(temp_file.name, 'wb') as f:
                                async for chunk in resp.aiter_bytes():
                                    f.write(chunk)
                    local_video_path = temp_file.name
                    logger.info(f"[发布] 视频已下载到: {local_video_path}")
            else:
                # Local relative path
                local_video_path = str(settings.BASE_DIR.parent / video_path)
# Select appropriate uploader
if platform == "bilibili":
uploader = BilibiliUploader(
title=title,
file_path=local_video_path,
tags=tags,
publish_date=publish_time,
account_file=str(account_file),
description=description,
tid=kwargs.get('tid', 122),
copyright=kwargs.get('copyright', 1)
)
elif platform == "douyin":
uploader = DouyinUploader(
title=title,
file_path=local_video_path,
tags=tags,
publish_date=publish_time,
account_file=str(account_file),
description=description
)
elif platform == "xiaohongshu":
uploader = XiaohongshuUploader(
title=title,
file_path=local_video_path,
tags=tags,
publish_date=publish_time,
account_file=str(account_file),
description=description
)
else:
logger.warning(f"[发布] {platform} 上传功能尚未实现")
return {
"success": False,
"message": f"{self.PLATFORMS[platform]['name']} 上传功能开发中",
"platform": platform
}
# Execute upload
result = await uploader.main()
result['platform'] = platform
return result
except Exception as e:
logger.exception(f"[发布] 上传异常: {e}")
return {
"success": False,
"message": f"上传异常: {str(e)}",
"platform": platform
}
finally:
# 清理临时文件
if temp_file and os.path.exists(temp_file.name):
try:
os.remove(temp_file.name)
logger.info(f"[发布] 已清理临时文件: {temp_file.name}")
except Exception as e:
logger.warning(f"[发布] 清理临时文件失败: {e}")
    async def login(self, platform: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Start the QR-code login flow
        Args:
            platform: Platform ID
            user_id: User ID (for cookie isolation)
        Returns:
            dict: Contains the QR code as a base64 image
        """
        if platform not in self.PLATFORMS:
            return {"success": False, "message": "不支持的平台"}
        try:
            from .qr_login_service import QRLoginService
            # Get the user's own cookie directory
            cookies_dir = self._get_cookies_dir(user_id)
            # Create the QR login service
            qr_service = QRLoginService(platform, cookies_dir)
            # Register the active session (user-isolated)
            session_key = self._get_session_key(platform, user_id)
            self.active_login_sessions[session_key] = qr_service
            # Start the login and fetch the QR code
            result = await qr_service.start_login()
            return result
        except Exception as e:
            logger.exception(f"[登录] QR码登录失败: {e}")
            return {
                "success": False,
                "message": f"登录失败: {str(e)}"
            }
    def get_login_session_status(self, platform: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """Get the status of the active login session"""
        session_key = self._get_session_key(platform, user_id)
        # 1. If there is an active QR session, check it first
        if session_key in self.active_login_sessions:
            qr_service = self.active_login_sessions[session_key]
            status = qr_service.get_login_status()
            # Once login succeeded and the cookies are saved, drop the session
            if status["success"] and status["cookies_saved"]:
                del self.active_login_sessions[session_key]
                return {"success": True, "message": "登录成功"}
            return {"success": False, "message": "等待扫码..."}
        # 2. Otherwise check whether a local cookie file exists
        cookie_file = self._get_cookie_path(platform, user_id)
        if cookie_file.exists():
            return {"success": True, "message": "已登录 (历史状态)"}
        return {"success": False, "message": "未登录"}
    def logout(self, platform: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Logout from platform (delete cookie file)
        """
        if platform not in self.PLATFORMS:
            return {"success": False, "message": "不支持的平台"}
        try:
            session_key = self._get_session_key(platform, user_id)
            # 1. Remove the active session
            if session_key in self.active_login_sessions:
                del self.active_login_sessions[session_key]
            # 2. Delete the cookie file
            cookie_file = self._get_cookie_path(platform, user_id)
            if cookie_file.exists():
                cookie_file.unlink()
                logger.info(f"[登出] {platform} Cookie已删除 (user: {user_id or 'legacy'})")
            return {"success": True, "message": "已注销"}
        except Exception as e:
            logger.exception(f"[登出] 失败: {e}")
            return {"success": False, "message": f"注销失败: {str(e)}"}
    async def save_cookie_string(self, platform: str, cookie_string: str, user_id: Optional[str] = None) -> Dict[str, Any]:
        """
        Save a cookie string extracted from the client's browser
        Args:
            platform: Platform ID
            cookie_string: Cookie string in document.cookie format
            user_id: User ID (for cookie isolation)
        """
        # Reject unknown platforms before anything is written to disk
        if platform not in self.PLATFORMS:
            return {"success": False, "message": "不支持的平台"}
        try:
            account_file = self._get_cookie_path(platform, user_id)
            # Parse the cookie string (tolerates ";" with or without a following space)
            cookie_dict = {}
            for item in cookie_string.split(';'):
                item = item.strip()
                if '=' in item:
                    name, value = item.split('=', 1)
                    cookie_dict[name] = value
            # Bilibili needs special handling
            if platform == "bilibili":
                bilibili_cookies = {}
                required_fields = ['SESSDATA', 'bili_jct', 'DedeUserID', 'DedeUserID__ckMd5']
                for field in required_fields:
                    if field in cookie_dict:
                        bilibili_cookies[field] = cookie_dict[field]
                if len(bilibili_cookies) < 3:
                    return {
                        "success": False,
                        "message": "Cookie不完整请确保已登录"
                    }
                cookie_dict = bilibili_cookies
            # Make sure the directory exists
            account_file.parent.mkdir(parents=True, exist_ok=True)
            # Save the cookies
            with open(account_file, 'w', encoding='utf-8') as f:
                json.dump(cookie_dict, f, indent=2)
            logger.success(f"[登录] {platform} Cookie已保存 (user: {user_id or 'legacy'})")
            return {
                "success": True,
                "message": f"{self.PLATFORMS[platform]['name']} 登录成功"
            }
        except Exception as e:
            logger.exception(f"[登录] Cookie保存失败: {e}")
            return {
                "success": False,
                "message": f"Cookie保存失败: {str(e)}"
            }
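The cookie-string handling in `save_cookie_string` can be tried out as a pure function. The sketch below is illustrative only: `parse_cookie_string` and `pick_required` are hypothetical helper names, and the sample string is made up. It splits on `;` and strips whitespace, so it also accepts strings whose separators lack the trailing space.

```python
from typing import Dict, Iterable, Optional


def parse_cookie_string(cookie_string: str) -> Dict[str, str]:
    """Parse a `document.cookie`-style string into a name -> value dict."""
    cookies: Dict[str, str] = {}
    for item in cookie_string.split(";"):
        # Only the first "=" separates name and value; values may contain "=".
        name, sep, value = item.strip().partition("=")
        if sep and name:
            cookies[name] = value
    return cookies


def pick_required(
    cookies: Dict[str, str], required: Iterable[str], minimum: int
) -> Optional[Dict[str, str]]:
    """Keep only `required` names; return None if fewer than `minimum` are present."""
    picked = {name: cookies[name] for name in required if name in cookies}
    return picked if len(picked) >= minimum else None


BILIBILI_FIELDS = ["SESSDATA", "bili_jct", "DedeUserID", "DedeUserID__ckMd5"]

raw = "SESSDATA=abc%2C123; bili_jct=tok=en;DedeUserID=42; theme=dark"
parsed = parse_cookie_string(raw)
bilibili = pick_required(parsed, BILIBILI_FIELDS, minimum=3)
```

The `minimum=3` threshold mirrors the `len(bilibili_cookies) < 3` check in the service, which accepts three of the four listed fields.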


@@ -0,0 +1,344 @@
"""
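QR-code automatic login service
The backend runs Playwright in headless mode to fetch the QR code; once the user scans it from the frontend, the cookies are saved automatically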
"""
import asyncio
import base64
import json
from pathlib import Path
from typing import Optional, Dict, Any, List
from playwright.async_api import async_playwright, Page, BrowserContext, Browser, Playwright as PW
from loguru import logger
class QRLoginService:
    """QR-code login service"""
    # Login monitoring timeout (seconds)
    LOGIN_TIMEOUT = 120
def __init__(self, platform: str, cookies_dir: Path) -> None:
self.platform = platform
self.cookies_dir = cookies_dir
self.qr_code_image: Optional[str] = None
self.login_success: bool = False
self.cookies_data: Optional[Dict[str, Any]] = None
        # Playwright resources (lifecycle managed manually)
        self.playwright: Optional[PW] = None
        self.browser: Optional[Browser] = None
        self.context: Optional[BrowserContext] = None
        # Several selectors per platform (joined with commas so Playwright waits for any of them)
        self.platform_configs = {
            "bilibili": {
                "url": "https://passport.bilibili.com/login",
                "qr_selectors": [
                    "div[class*='qrcode'] canvas",  # common canvas QR code
                    "div[class*='qrcode'] img",  # common image QR code
                    ".qrcode-img img",  # legacy layout
                    ".login-scan-box img",  # scan box
                    "div[class*='scan'] img"
                ],
                "success_indicator": "https://www.bilibili.com/"
            },
            "douyin": {
                "url": "https://creator.douyin.com/",
                "qr_selectors": [
                    ".qrcode img",  # tried first
                    "img[alt='qrcode']",
                    "canvas[class*='qr']",
                    "img[src*='qr']"
                ],
                "success_indicator": "https://creator.douyin.com/creator-micro"
            },
            "xiaohongshu": {
                "url": "https://creator.xiaohongshu.com/",
                "qr_selectors": [
                    ".qrcode img",
                    "img[alt*='二维码']",
                    "canvas.qr-code",
                    "img[class*='qr']"
                ],
                "success_indicator": "https://creator.xiaohongshu.com/publish"
            }
        }
async def start_login(self) -> Dict[str, Any]:
"""
        Start the login flow
        Returns:
            dict: Contains the base64 QR code and the status
"""
if self.platform not in self.platform_configs:
return {"success": False, "message": "不支持的平台"}
config = self.platform_configs[self.platform]
try:
# 1. 启动 Playwright (不使用 async with手动管理生命周期)
self.playwright = await async_playwright().start()
# Stealth模式启动浏览器
self.browser = await self.playwright.chromium.launch(
headless=True,
args=[
'--disable-blink-features=AutomationControlled',
'--no-sandbox',
'--disable-dev-shm-usage'
]
)
# 配置真实浏览器特征
self.context = await self.browser.new_context(
viewport={'width': 1920, 'height': 1080},
user_agent='Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/120.0.0.0 Safari/537.36',
locale='zh-CN',
timezone_id='Asia/Shanghai'
)
page = await self.context.new_page()
# 注入stealth.js
stealth_path = Path(__file__).parent / 'uploader' / 'stealth.min.js'
if stealth_path.exists():
await page.add_init_script(path=str(stealth_path))
logger.debug(f"[{self.platform}] Stealth模式已启用")
logger.info(f"[{self.platform}] 打开登录页...")
await page.goto(config["url"], wait_until='networkidle')
# 等待页面加载 (缩短等待)
await asyncio.sleep(2)
# 提取二维码 (并行策略)
qr_image = await self._extract_qr_code(page, config["qr_selectors"])
if not qr_image:
await self._cleanup()
return {"success": False, "message": "未找到二维码"}
logger.info(f"[{self.platform}] 二维码已获取,等待扫码...")
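            # Start the background monitoring task (the browser stays open).
            # Keep a reference on the instance so the task cannot be garbage-collected mid-run.
            self._monitor_task = asyncio.create_task(
                self._monitor_login_status(page, config["success_indicator"])
            )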
return {
"success": True,
"qr_code": qr_image,
"message": "请扫码登录"
}
except Exception as e:
logger.exception(f"[{self.platform}] 启动登录失败: {e}")
await self._cleanup()
return {"success": False, "message": f"启动失败: {str(e)}"}
async def _extract_qr_code(self, page: Page, selectors: List[str]) -> Optional[str]:
"""
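        Extract the QR code image (optimised strategy order)
        According to log analysis, the Text strategy has the highest success rate for Douyin and Bilibili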
"""
qr_element = None
# 针对抖音和B站优先使用 Text 策略 (成功率最高,速度最快)
if self.platform in ("douyin", "bilibili"):
# 尝试最多2次 (首次 + 1次重试)
for attempt in range(2):
if attempt > 0:
logger.info(f"[{self.platform}] 等待页面加载后重试...")
await asyncio.sleep(2)
# 策略1: Text (优先,成功率最高)
qr_element = await self._try_text_strategy(page)
if qr_element:
try:
screenshot = await qr_element.screenshot()
return base64.b64encode(screenshot).decode()
except Exception as e:
logger.warning(f"[{self.platform}] Text策略截图失败: {e}")
qr_element = None
# 策略2: CSS (备用)
if not qr_element:
try:
combined_selector = ", ".join(selectors)
logger.debug(f"[{self.platform}] 策略2(CSS): 开始等待...")
# 增加超时到5秒抖音页面加载较慢
el = await page.wait_for_selector(combined_selector, state="visible", timeout=5000)
if el:
logger.info(f"[{self.platform}] 策略2(CSS): 匹配成功")
screenshot = await el.screenshot()
return base64.b64encode(screenshot).decode()
except Exception as e:
logger.warning(f"[{self.platform}] 策略2(CSS) 失败: {e}")
# 如果已成功,退出循环
if qr_element:
break
else:
# 其他平台 (小红书等):保持原顺序 CSS -> Text
# 策略1: CSS 选择器
try:
combined_selector = ", ".join(selectors)
logger.debug(f"[{self.platform}] 策略1(CSS): 开始等待...")
el = await page.wait_for_selector(combined_selector, state="visible", timeout=5000)
if el:
logger.info(f"[{self.platform}] 策略1(CSS): 匹配成功")
qr_element = el
except Exception as e:
logger.warning(f"[{self.platform}] 策略1(CSS) 失败: {e}")
# 策略2: Text
if not qr_element:
qr_element = await self._try_text_strategy(page)
# 如果找到元素,截图返回
if qr_element:
try:
screenshot = await qr_element.screenshot()
return base64.b64encode(screenshot).decode()
except Exception as e:
logger.error(f"[{self.platform}] 截图失败: {e}")
# 所有策略失败
logger.error(f"[{self.platform}] 所有QR码提取策略失败")
# 保存调试截图
debug_dir = Path(__file__).parent.parent.parent / 'debug_screenshots'
debug_dir.mkdir(exist_ok=True)
await page.screenshot(path=str(debug_dir / f"{self.platform}_debug.png"))
return None
async def _try_text_strategy(self, page: Page) -> Optional[Any]:
"""基于文本查找二维码图片"""
try:
logger.debug(f"[{self.platform}] 策略Text: 开始搜索...")
keywords = ["扫码登录", "二维码", "打开抖音", "抖音APP", "使用APP扫码"]
for kw in keywords:
try:
text_el = page.get_by_text(kw, exact=False).first
await text_el.wait_for(state="visible", timeout=2000)
# 向上查找图片
parent = text_el
for _ in range(5):
parent = parent.locator("..")
imgs = parent.locator("img")
for i in range(await imgs.count()):
img = imgs.nth(i)
if await img.is_visible():
bbox = await img.bounding_box()
if bbox and bbox['width'] > 100:
logger.info(f"[{self.platform}] 策略Text: 成功")
return img
except Exception:
continue
except Exception as e:
logger.warning(f"[{self.platform}] 策略Text 失败: {e}")
return None
async def _monitor_login_status(self, page: Page, success_url: str):
"""监控登录状态"""
try:
logger.info(f"[{self.platform}] 开始监控登录状态...")
key_cookies = {"bilibili": "SESSDATA", "douyin": "sessionid", "xiaohongshu": "web_session"}
target_cookie = key_cookies.get(self.platform, "")
for i in range(self.LOGIN_TIMEOUT):
await asyncio.sleep(1)
try:
if not self.context: break # 避免意外关闭
cookies = await self.context.cookies()
current_url = page.url
has_cookie = any(c['name'] == target_cookie for c in cookies)
if i % 5 == 0:
logger.debug(f"[{self.platform}] 等待登录... HasCookie: {has_cookie}")
if success_url in current_url or has_cookie:
logger.success(f"[{self.platform}] 登录成功!")
self.login_success = True
await asyncio.sleep(2) # 缓冲
# 保存Cookie
final_cookies = await self.context.cookies()
await self._save_cookies(final_cookies)
break
except Exception as e:
logger.warning(f"[{self.platform}] 监控循环警告: {e}")
break
if not self.login_success:
logger.warning(f"[{self.platform}] 登录超时")
except Exception as e:
logger.error(f"[{self.platform}] 监控异常: {e}")
finally:
await self._cleanup()
async def _cleanup(self) -> None:
"""清理资源"""
if self.context:
try:
await self.context.close()
except Exception:
pass
self.context = None
if self.browser:
try:
await self.browser.close()
except Exception:
pass
self.browser = None
if self.playwright:
try:
await self.playwright.stop()
except Exception:
pass
self.playwright = None
    async def _save_cookies(self, cookies: List[Dict[str, Any]]) -> None:
        """Save the cookies to a file"""
        try:
            # Make sure the target directory exists before writing
            self.cookies_dir.mkdir(parents=True, exist_ok=True)
            cookie_file = self.cookies_dir / f"{self.platform}_cookies.json"
            if self.platform == "bilibili":
                # Bilibili uses the simple format (required by the biliup library)
                cookie_dict = {c['name']: c['value'] for c in cookies}
                required = ['SESSDATA', 'bili_jct', 'DedeUserID', 'DedeUserID__ckMd5']
                cookie_dict = {k: v for k, v in cookie_dict.items() if k in required}
                with open(cookie_file, 'w', encoding='utf-8') as f:
                    json.dump(cookie_dict, f, indent=2)
                self.cookies_data = cookie_dict
            else:
                # Douyin/Xiaohongshu use the full Playwright storage_state format,
                # so the file can be passed straight to browser.new_context(storage_state=file)
                storage_state = {
                    "cookies": cookies,
                    "origins": []
                }
                with open(cookie_file, 'w', encoding='utf-8') as f:
                    json.dump(storage_state, f, indent=2)
                self.cookies_data = storage_state
            logger.success(f"[{self.platform}] Cookie已保存")
        except Exception as e:
            logger.error(f"[{self.platform}] 保存Cookie失败: {e}")
def get_login_status(self) -> Dict[str, Any]:
"""获取登录状态"""
return {
"success": self.login_success,
"cookies_saved": self.cookies_data is not None
}
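The login monitor is a fire-and-forget task that polls until a condition holds or a timeout expires. The asyncio documentation recommends keeping a strong reference to such tasks, because the event loop itself only holds weak ones. The following is a minimal sketch of both ideas, assuming hypothetical helper names (`spawn`, `poll_until`) and a stand-in check instead of a real browser.

```python
import asyncio
from typing import Awaitable, Callable, Set

# Strong references to running background tasks.
_background_tasks: Set[asyncio.Task] = set()


def spawn(coro) -> asyncio.Task:
    """Start a background task and hold a reference until it finishes."""
    task = asyncio.create_task(coro)
    _background_tasks.add(task)
    task.add_done_callback(_background_tasks.discard)
    return task


async def poll_until(
    check: Callable[[], Awaitable[bool]], timeout_s: float, interval_s: float
) -> bool:
    """Call `check` every `interval_s` seconds until it is True or time runs out."""
    loop = asyncio.get_running_loop()
    deadline = loop.time() + timeout_s
    while loop.time() < deadline:
        if await check():
            return True
        await asyncio.sleep(interval_s)
    return False


async def _demo() -> tuple:
    calls = 0

    async def logged_in() -> bool:
        nonlocal calls
        calls += 1
        return calls >= 3  # pretend the user scans on the third poll

    async def never() -> bool:
        return False

    ok_task = spawn(poll_until(logged_in, timeout_s=1.0, interval_s=0.01))
    timeout_task = spawn(poll_until(never, timeout_s=0.05, interval_s=0.01))
    tracked_while_running = len(_background_tasks)
    ok, timed_out = await asyncio.gather(ok_task, timeout_task)
    await asyncio.sleep(0)  # give the done callbacks a chance to run
    return ok, timed_out, tracked_while_running, len(_background_tasks)


demo_result = asyncio.run(_demo())
```

Using a monotonic deadline instead of counting loop iterations keeps the timeout accurate even when each poll takes longer than the sleep interval.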


@@ -0,0 +1,148 @@
from supabase import Client
from app.core.supabase import get_supabase
from app.core.config import settings
from loguru import logger
from typing import Optional, Union, Dict, List, Any
from pathlib import Path
import asyncio
import functools
import os
# Local root directory of Supabase Storage
SUPABASE_STORAGE_LOCAL_PATH = Path("/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub")
class StorageService:
def __init__(self):
self.supabase: Client = get_supabase()
self.BUCKET_MATERIALS = "materials"
self.BUCKET_OUTPUTS = "outputs"
def _convert_to_public_url(self, url: str) -> str:
"""将内部 URL 转换为公网可访问的 URL"""
if settings.SUPABASE_PUBLIC_URL and settings.SUPABASE_URL:
# 去掉末尾斜杠进行替换
internal_url = settings.SUPABASE_URL.rstrip('/')
public_url = settings.SUPABASE_PUBLIC_URL.rstrip('/')
return url.replace(internal_url, public_url)
return url
    def get_local_file_path(self, bucket: str, path: str) -> Optional[str]:
        """
        Get the local disk path of a Storage file
        Supabase Storage lays files out as:
        {STORAGE_ROOT}/{bucket}/{path}/{internal_uuid}
        Returns:
            The local file path, or None if it does not exist
        """
        try:
            # Build the directory path
            dir_path = SUPABASE_STORAGE_LOCAL_PATH / bucket / path
            if not dir_path.exists():
                logger.warning(f"Storage 目录不存在: {dir_path}")
                return None
            # The directory is expected to hold a single file named by its internal UUID
            files = list(dir_path.iterdir())
            if not files:
                logger.warning(f"Storage 目录为空: {dir_path}")
                return None
            local_path = str(files[0])
            logger.info(f"获取本地文件路径: {local_path}")
            return local_path
        except Exception as e:
            logger.error(f"获取本地文件路径失败: {e}")
            return None
    async def upload_file(self, bucket: str, path: str, file_data: bytes, content_type: str) -> str:
        """
        Upload a file to Supabase Storage asynchronously
        """
        try:
            # Run in the thread pool so the event loop is not blocked
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                functools.partial(
                    self.supabase.storage.from_(bucket).upload,
                    path=path,
                    file=file_data,
                    file_options={"content-type": content_type, "upsert": "true"}
                )
            )
            logger.info(f"Storage upload success: {path}")
            return path
        except Exception as e:
            logger.error(f"Storage upload failed: {e}")
            raise
    async def get_signed_url(self, bucket: str, path: str, expires_in: int = 3600) -> str:
        """Get a signed access URL asynchronously"""
        try:
            loop = asyncio.get_running_loop()
            res = await loop.run_in_executor(
                None,
                lambda: self.supabase.storage.from_(bucket).create_signed_url(path, expires_in)
            )
            # Handle the different response shapes
            url = ""
            if isinstance(res, dict) and "signedURL" in res:
                url = res["signedURL"]
            elif isinstance(res, str):
                url = res
            else:
                logger.warning(f"Unexpected signed_url response: {res}")
                url = res.get("signedURL", "") if isinstance(res, dict) else str(res)
            # Convert to a publicly reachable URL
            return self._convert_to_public_url(url)
        except Exception as e:
            logger.error(f"Get signed URL failed: {e}")
            return ""
    async def get_public_url(self, bucket: str, path: str) -> str:
        """Get a public access URL"""
        try:
            loop = asyncio.get_running_loop()
            res = await loop.run_in_executor(
                None,
                lambda: self.supabase.storage.from_(bucket).get_public_url(path)
            )
            # Convert to a publicly reachable URL
            return self._convert_to_public_url(res)
        except Exception as e:
            logger.error(f"Get public URL failed: {e}")
            return ""
    async def delete_file(self, bucket: str, path: str):
        """Delete a file asynchronously (failures are logged, not raised)"""
        try:
            loop = asyncio.get_running_loop()
            await loop.run_in_executor(
                None,
                lambda: self.supabase.storage.from_(bucket).remove([path])
            )
            logger.info(f"Deleted file: {bucket}/{path}")
        except Exception as e:
            logger.error(f"Delete file failed: {e}")
    async def list_files(self, bucket: str, path: str) -> List[Any]:
        """List files asynchronously"""
        try:
            loop = asyncio.get_running_loop()
            res = await loop.run_in_executor(
                None,
                lambda: self.supabase.storage.from_(bucket).list(path)
            )
            return res or []
        except Exception as e:
            logger.error(f"List files failed: {e}")
            return []
storage_service = StorageService()
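Callers of `get_local_file_path` first have to turn a signed URL of the form `.../storage/v1/object/sign/{bucket}/{path}?token=...` into a bucket and an object path. A possible sketch using `urllib.parse` is shown below. `parse_signed_url` is a hypothetical helper, and two choices in it are assumptions rather than project behaviour: percent-decoding the path, and rejecting `..` segments before the result is joined to a local directory.

```python
from pathlib import PurePosixPath
from typing import Optional, Tuple
from urllib.parse import unquote, urlsplit

SIGN_PREFIX = "/storage/v1/object/sign/"


def parse_signed_url(url: str) -> Optional[Tuple[str, str]]:
    """Return (bucket, object_path) for a signed Storage URL, else None."""
    path = urlsplit(url).path  # the query string (token=...) is dropped here
    _, found, rest = path.partition(SIGN_PREFIX)
    if not found:
        return None
    bucket, _, object_path = unquote(rest).partition("/")
    if not bucket or not object_path:
        return None
    # Assumption: traversal segments are never legitimate in an object path.
    if ".." in PurePosixPath(object_path).parts:
        return None
    return bucket, object_path


example = parse_signed_url(
    "https://example.invalid/storage/v1/object/sign/outputs/user1/demo%20clip.mp4?token=abc"
)
```

Unlike a pattern that requires a literal `?`, this also handles URLs that carry no query string.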


@@ -0,0 +1,9 @@
"""
Platform uploader base classes and utilities
"""
from .base_uploader import BaseUploader
from .bilibili_uploader import BilibiliUploader
from .douyin_uploader import DouyinUploader
from .xiaohongshu_uploader import XiaohongshuUploader
__all__ = ['BaseUploader', 'BilibiliUploader', 'DouyinUploader', 'XiaohongshuUploader']


@@ -0,0 +1,65 @@
"""
Base uploader class for all social media platforms
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import List, Optional, Dict, Any, Union
from datetime import datetime
class BaseUploader(ABC):
"""Base class for all platform uploaders"""
def __init__(
self,
title: str,
file_path: str,
tags: List[str],
publish_date: Optional[datetime] = None,
account_file: Optional[str] = None,
description: str = ""
):
"""
Initialize base uploader
Args:
title: Video title
file_path: Path to video file
tags: List of tags/hashtags
publish_date: Scheduled publish time (None = publish immediately)
account_file: Path to account cookie/credentials file
description: Video description
"""
self.title = title
self.file_path = Path(file_path)
self.tags = tags
self.publish_date = publish_date if publish_date else 0 # 0 = immediate
self.account_file = account_file
self.description = description
@abstractmethod
async def main(self) -> Dict[str, Any]:
"""
Main upload method - must be implemented by subclasses
Returns:
dict: Upload result with keys:
- success (bool): Whether upload succeeded
- message (str): Result message
- url (str, optional): URL of published video
"""
pass
def _get_timestamp(self, dt: Union[datetime, int]) -> int:
"""
Convert datetime to Unix timestamp
Args:
dt: datetime object or 0 for immediate publish
Returns:
int: Unix timestamp or 0
"""
if dt == 0:
return 0
return int(dt.timestamp())
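`_get_timestamp` relies on `datetime.timestamp()`, which interprets a naive datetime in the server's local time zone. The sketch below shows one way to make that choice explicit. It is only an illustration: `to_publish_timestamp` and its `assume_tz` parameter are hypothetical, and `None` is used in place of the `0` sentinel for immediate publishing.

```python
from datetime import datetime, timedelta, timezone
from typing import Optional


def to_publish_timestamp(
    publish_date: Optional[datetime], assume_tz: timezone = timezone.utc
) -> int:
    """Return a Unix timestamp, or 0 when the video should go out immediately."""
    if publish_date is None:
        return 0
    if publish_date.tzinfo is None:
        # Attach the assumed zone instead of depending on the host's local time.
        publish_date = publish_date.replace(tzinfo=assume_tz)
    return int(publish_date.timestamp())


CST = timezone(timedelta(hours=8))  # UTC+8, matching the Asia/Shanghai setting used elsewhere

immediate = to_publish_timestamp(None)
aware = to_publish_timestamp(datetime(1970, 1, 2, tzinfo=timezone.utc))
naive_in_cst = to_publish_timestamp(datetime(1970, 1, 2, 8, 0), assume_tz=CST)
```

Both sample datetimes denote the same instant, one day after the epoch, so they map to the same timestamp.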


@@ -0,0 +1,172 @@
"""
Bilibili uploader using biliup library
"""
import json
import asyncio
from pathlib import Path
from typing import Optional, List, Dict, Any
from datetime import datetime
from concurrent.futures import ThreadPoolExecutor
try:
from biliup.plugins.bili_webup import BiliBili, Data
BILIUP_AVAILABLE = True
except ImportError:
BILIUP_AVAILABLE = False
from loguru import logger
from .base_uploader import BaseUploader
# Thread pool for running sync biliup code
_executor = ThreadPoolExecutor(max_workers=2)
class BilibiliUploader(BaseUploader):
"""Bilibili video uploader using biliup library"""
def __init__(
self,
title: str,
file_path: str,
tags: List[str],
publish_date: Optional[datetime] = None,
account_file: Optional[str] = None,
description: str = "",
        tid: int = 122,  # Category ID: 122 = 国内原创 (domestic original)
        copyright: int = 1  # 1 = original, 2 = repost
):
"""
Initialize Bilibili uploader
Args:
tid: Bilibili category ID (default: 122 for 国内原创)
copyright: 1 for original, 2 for repost
"""
super().__init__(title, file_path, tags, publish_date, account_file, description)
self.tid = tid
self.copyright = copyright
if not BILIUP_AVAILABLE:
raise ImportError(
"biliup library not installed. Please run: pip install biliup"
)
async def main(self) -> Dict[str, Any]:
"""
Upload video to Bilibili
Returns:
dict: Upload result
"""
        # Run the sync upload in a thread pool to avoid an asyncio.run() conflict
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(_executor, self._upload_sync)
def _upload_sync(self) -> Dict[str, Any]:
"""Synchronous upload logic (runs in thread pool)"""
try:
# 1. Load cookie data
if not self.account_file or not Path(self.account_file).exists():
logger.error(f"[B站] Cookie 文件不存在: {self.account_file}")
return {
"success": False,
"message": "Cookie 文件不存在,请先登录",
"url": None
}
with open(self.account_file, 'r', encoding='utf-8') as f:
cookie_data = json.load(f)
# Convert simple cookie format to biliup format if needed
if 'cookie_info' not in cookie_data and 'SESSDATA' in cookie_data:
# Transform to biliup expected format
cookie_data = {
'cookie_info': {
'cookies': [
{'name': k, 'value': v} for k, v in cookie_data.items()
]
},
'token_info': {
'access_token': cookie_data.get('access_token', ''),
'refresh_token': cookie_data.get('refresh_token', '')
}
}
logger.info("[B站] Cookie格式已转换")
# 2. Prepare video data
data = Data()
data.copyright = self.copyright
data.title = self.title
data.desc = self.description or f"标签: {', '.join(self.tags)}"
data.tid = self.tid
data.set_tag(self.tags)
data.dtime = self._get_timestamp(self.publish_date)
logger.info(f"[B站] 开始上传: {self.file_path.name}")
logger.info(f"[B站] 标题: {self.title}")
            logger.info(f"[B站] 定时发布: {'是' if data.dtime > 0 else '否'}")
# 3. Upload video
with BiliBili(data) as bili:
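                # Login with cookies
                bili.login_by_cookies(cookie_data)
                # After the format conversion above, the token sits under token_info
                bili.access_token = (
                    cookie_data.get('token_info', {}).get('access_token')
                    or cookie_data.get('access_token', '')
                )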
# Upload file (3 threads, auto line selection)
video_part = bili.upload_file(
str(self.file_path),
lines='AUTO',
tasks=3
)
video_part['title'] = self.title
data.append(video_part)
# Submit
ret = bili.submit()
# Debug: log full response
logger.debug(f"[B站] API响应: {ret}")
if ret.get('code') == 0:
# Try multiple keys for bvid (API may vary)
bvid = ret.get('data', {}).get('bvid') or ret.get('bvid', '')
aid = ret.get('data', {}).get('aid') or ret.get('aid', '')
if bvid:
logger.success(f"[B站] 上传成功: {bvid}")
return {
"success": True,
"message": "发布成功,待审核" if data.dtime == 0 else "已设置定时发布",
"url": f"https://www.bilibili.com/video/{bvid}"
}
elif aid:
logger.success(f"[B站] 上传成功: av{aid}")
return {
"success": True,
"message": "发布成功,待审核" if data.dtime == 0 else "已设置定时发布",
"url": f"https://www.bilibili.com/video/av{aid}"
}
else:
# No bvid/aid but code=0, still consider success
logger.warning(f"[B站] 上传返回code=0但无bvid/aid: {ret}")
return {
"success": True,
"message": "发布成功,待审核",
"url": None
}
else:
error_msg = ret.get('message', '未知错误')
logger.error(f"[B站] 上传失败: {error_msg} (完整响应: {ret})")
return {
"success": False,
"message": f"上传失败: {error_msg}",
"url": None
}
except Exception as e:
logger.exception(f"[B站] 上传异常: {e}")
return {
"success": False,
"message": f"上传异常: {str(e)}",
"url": None
}
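The cookie conversion in `_upload_sync` wraps a flat name-to-value dict in a nested `cookie_info` / `token_info` layout. The sketch below reproduces that layout as a pure function. The helper names are hypothetical and the layout is copied from the code above, not checked against the biliup documentation. Unlike the original, it keeps the token keys out of the cookie list.

```python
from typing import Any, Dict

TOKEN_KEYS = ("access_token", "refresh_token")


def to_nested_cookie_format(flat: Dict[str, Any]) -> Dict[str, Any]:
    """Wrap a flat name -> value cookie dict in the nested layout used above."""
    if "cookie_info" in flat:
        return dict(flat)  # already nested; return a shallow copy
    return {
        "cookie_info": {
            "cookies": [
                {"name": name, "value": value}
                for name, value in flat.items()
                if name not in TOKEN_KEYS
            ]
        },
        "token_info": {key: flat.get(key, "") for key in TOKEN_KEYS},
    }


def access_token_of(cookie_data: Dict[str, Any]) -> str:
    """Read the access token from either the flat or the nested layout."""
    nested = cookie_data.get("token_info", {}).get("access_token", "")
    return nested or cookie_data.get("access_token", "")


flat_cookies = {"SESSDATA": "s", "bili_jct": "j", "access_token": "tok"}
nested_cookies = to_nested_cookie_format(flat_cookies)
```

Reading the token through `access_token_of` works before and after conversion, whereas a top-level lookup on the converted dict always yields an empty string.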


@@ -0,0 +1,107 @@
"""
Utility functions for cookie management and Playwright setup
"""
from pathlib import Path
from playwright.async_api import async_playwright
import json
from loguru import logger
from app.core.config import settings
async def set_init_script(context):
"""
Add stealth script to prevent bot detection
Args:
context: Playwright browser context
Returns:
Modified context
"""
# Add stealth.js if available
stealth_js_path = settings.BASE_DIR / "app" / "services" / "uploader" / "stealth.min.js"
if stealth_js_path.exists():
await context.add_init_script(path=stealth_js_path)
# Grant geolocation permission
await context.grant_permissions(['geolocation'])
return context
async def generate_cookie_with_qr(platform: str, platform_url: str, account_file: str):
"""
Generate cookie by scanning QR code with Playwright
Args:
platform: Platform name (for logging)
platform_url: Platform login URL
account_file: Path to save cookies
Returns:
bool: Success status
"""
try:
logger.info(f"[{platform}] 开始自动生成 Cookie...")
async with async_playwright() as playwright:
browser = await playwright.chromium.launch(headless=False)
context = await browser.new_context()
# Add stealth script
context = await set_init_script(context)
page = await context.new_page()
await page.goto(platform_url)
logger.info(f"[{platform}] 请在浏览器中扫码登录...")
logger.info(f"[{platform}] 登录后点击 Playwright Inspector 的 '继续' 按钮")
# Pause for user to login
await page.pause()
# Save cookies
await context.storage_state(path=account_file)
await browser.close()
logger.success(f"[{platform}] Cookie 已保存到: {account_file}")
return True
except Exception as e:
logger.exception(f"[{platform}] Cookie 生成失败: {e}")
return False
async def extract_bilibili_cookies(account_file: str):
"""
Extract specific Bilibili cookies needed by biliup
Args:
account_file: Path to cookies file
Returns:
dict: Extracted cookies
"""
try:
# Read Playwright storage_state format
with open(account_file, 'r', encoding='utf-8') as f:
storage = json.load(f)
# Extract cookies
cookie_dict = {}
for cookie in storage.get('cookies', []):
if cookie['name'] in ['SESSDATA', 'bili_jct', 'DedeUserID', 'DedeUserID__ckMd5']:
cookie_dict[cookie['name']] = cookie['value']
# Save in biliup format
with open(account_file, 'w', encoding='utf-8') as f:
json.dump(cookie_dict, f, indent=2)
logger.info(f"[B站] Cookie 已转换为 biliup 格式")
return cookie_dict
except Exception as e:
logger.exception(f"[B站] Cookie 提取失败: {e}")
return {}
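Note on `extract_bilibili_cookies`: it writes the biliup-format dict back to the same `account_file` it just read, so a second call on the already converted file would find no `cookies` key and appears to overwrite the file with `{}`. A minimal sketch of a non-destructive variant, assuming a separate destination path is acceptable; `pick_bilibili_cookies`, `convert_storage_file`, and the `dst` argument are hypothetical names, not part of this commit:

```python
import json
from pathlib import Path

# Cookie names taken from the list used in extract_bilibili_cookies above.
BILIUP_COOKIE_NAMES = ("SESSDATA", "bili_jct", "DedeUserID", "DedeUserID__ckMd5")


def pick_bilibili_cookies(storage: dict) -> dict:
    """Return the biliup-relevant cookies from a Playwright storage_state dict."""
    return {
        cookie["name"]: cookie["value"]
        for cookie in storage.get("cookies", [])
        if cookie.get("name") in BILIUP_COOKIE_NAMES
    }


def convert_storage_file(src: Path, dst: Path) -> dict:
    """Read storage_state JSON from src and write biliup-style JSON to dst.

    Hypothetical helper: src is left untouched, and nothing is written
    when no expected cookie is found, so a bad input cannot wipe a login.
    """
    storage = json.loads(src.read_text(encoding="utf-8"))
    cookies = pick_bilibili_cookies(storage)
    if not cookies:
        raise ValueError(f"no Bilibili cookies found in {src}")
    dst.write_text(json.dumps(cookies, indent=2), encoding="utf-8")
    return cookies
```

Keeping the selection step as a pure function also makes it testable without touching the filesystem.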


@@ -0,0 +1,585 @@
"""
Douyin (抖音) uploader using Playwright
Based on social-auto-upload implementation
"""
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
import asyncio
import time
from playwright.async_api import Playwright, async_playwright
from loguru import logger
from .base_uploader import BaseUploader
from .cookie_utils import set_init_script
class DouyinUploader(BaseUploader):
"""Douyin video uploader using Playwright"""
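# Timeout settings (seconds)
UPLOAD_TIMEOUT = 300 # video upload timeout
PUBLISH_TIMEOUT = 180 # publish detection timeout
PAGE_REDIRECT_TIMEOUT = 60 # page redirect timeout
POLL_INTERVAL = 2 # polling interval
MAX_CLICK_RETRIES = 3 # number of publish-button click retries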
def __init__(
self,
title: str,
file_path: str,
tags: List[str],
publish_date: Optional[datetime] = None,
account_file: Optional[str] = None,
description: str = ""
):
super().__init__(title, file_path, tags, publish_date, account_file, description)
self.upload_url = "https://creator.douyin.com/creator-micro/content/upload"
async def _is_text_visible(self, page, text: str, exact: bool = False) -> bool:
try:
return await page.get_by_text(text, exact=exact).first.is_visible()
except Exception:
return False
async def _first_visible_locator(self, locator, timeout: int = 1000):
try:
if await locator.count() == 0:
return None
candidate = locator.first
if await candidate.is_visible(timeout=timeout):
return candidate
except Exception:
return None
return None
async def _wait_for_publish_result(self, page, max_wait_time: int = 180):
success_texts = ["发布成功", "作品已发布", "再发一条", "查看作品", "审核中", "待审核"]
weak_texts = ["发布完成"]
failure_texts = ["发布失败", "发布异常", "发布出错", "请完善", "请补充", "请先上传"]
start_time = time.time()
poll_interval = self.POLL_INTERVAL # reuse the class-level polling interval (2 s)
weak_reason = None
while time.time() - start_time < max_wait_time:
if page.is_closed():
return False, "页面已关闭", False
current_url = page.url
if "content/manage" in current_url:
return True, f"已跳转到管理页面 (URL: {current_url})", False
for text in success_texts:
if await self._is_text_visible(page, text, exact=False):
return True, f"检测到成功提示: {text}", False
for text in failure_texts:
if await self._is_text_visible(page, text, exact=False):
return False, f"检测到失败提示: {text}", False
for text in weak_texts:
if await self._is_text_visible(page, text, exact=False):
weak_reason = text
logger.info("[抖音] 视频正在发布中...")
await asyncio.sleep(poll_interval)
if weak_reason:
return False, f"检测到提示: {weak_reason}", True
return False, "发布检测超时", True
async def _fill_title(self, page, title: str) -> bool:
title_text = title[:30]
locator_candidates = []
try:
label_locator = page.get_by_text("作品描述").locator("..").locator("..").locator(
"xpath=following-sibling::div[1]"
).locator("textarea, input, div[contenteditable='true']")
locator_candidates.append(label_locator)
except Exception:
pass
locator_candidates.extend([
page.locator("textarea[placeholder*='作品描述']"),
page.locator("textarea[placeholder*='描述']"),
page.locator("input[placeholder*='作品描述']"),
page.locator("input[placeholder*='描述']"),
page.locator("div[contenteditable='true']"),
])
for locator in locator_candidates:
try:
if await locator.count() > 0:
target = locator.first
await target.fill(title_text)
return True
except Exception:
continue
return False
async def _select_cover_if_needed(self, page) -> bool:
try:
cover_button = page.get_by_text("选择封面", exact=False).first
if await cover_button.is_visible():
await cover_button.click()
logger.info("[抖音] 尝试选择封面")
await asyncio.sleep(0.5)
dialog = page.locator(
"div.dy-creator-content-modal-wrap, div[role='dialog'], "
"div[class*='modal'], div[class*='dialog']"
).last
scopes = [dialog] if await dialog.count() > 0 else [page]
switched = False
for scope in scopes:
for selector in [
"button:has-text('设置横封面')",
"div:has-text('设置横封面')",
"span:has-text('设置横封面')",
]:
try:
button = await self._first_visible_locator(scope.locator(selector))
if button:
await button.click()
logger.info("[抖音] 已切换到横封面设置")
await asyncio.sleep(0.5)
switched = True
break
except Exception:
continue
if switched:
break
selected = False
for scope in scopes:
for selector in [
"div[class*='cover'] img",
"div[class*='cover']",
"div[class*='frame'] img",
"div[class*='frame']",
"div[class*='preset']",
"img",
]:
try:
candidate = await self._first_visible_locator(scope.locator(selector))
if candidate:
await candidate.click()
logger.info("[抖音] 已选择封面帧")
selected = True
break
except Exception:
continue
if selected:
break
confirm_selectors = [
"button:has-text('完成')",
"button:has-text('确定')",
"button:has-text('保存')",
"button:has-text('确认')",
]
for selector in confirm_selectors:
try:
button = await self._first_visible_locator(page.locator(selector))
if button:
if not await button.is_enabled():
for _ in range(8):
if await button.is_enabled():
break
await asyncio.sleep(0.5)
await button.click()
logger.info(f"[抖音] 封面已确认: {selector}")
await asyncio.sleep(0.5)
if await dialog.count() > 0:
try:
await dialog.wait_for(state="hidden", timeout=5000)
except Exception:
pass
return True
except Exception:
continue
return selected
except Exception as e:
logger.warning(f"[抖音] 选择封面失败: {e}")
return False
async def _click_publish_confirm_modal(self, page):
confirm_selectors = [
"button:has-text('确认发布')",
"button:has-text('继续发布')",
"button:has-text('确定发布')",
"button:has-text('发布确认')",
]
for selector in confirm_selectors:
try:
button = page.locator(selector).first
if await button.is_visible():
await button.click()
logger.info(f"[抖音] 点击了发布确认按钮: {selector}")
await asyncio.sleep(1)
return True
except Exception:
continue
return False
async def _dismiss_blocking_modal(self, page) -> bool:
modal_locator = page.locator(
"div.dy-creator-content-modal-wrap, div[role='dialog'], "
"div[class*='modal'], div[class*='dialog']"
)
try:
count = await modal_locator.count()
except Exception:
return False
if count == 0:
return False
button_texts = [
"我知道了",
"知道了",
"确定",
"继续",
"继续发布",
"确认",
"同意并继续",
"完成",
"好的",
"明白了",
]
close_selectors = [
"button[class*='close']",
"span[class*='close']",
"i[class*='close']",
]
for index in range(count):
modal = modal_locator.nth(index)
try:
if not await modal.is_visible():
continue
for text in button_texts:
try:
button = modal.get_by_role("button", name=text).first
if await button.is_visible():
await button.click()
logger.info(f"[抖音] 关闭弹窗: {text}")
await asyncio.sleep(0.5)
return True
except Exception:
continue
for selector in close_selectors:
try:
close_button = modal.locator(selector).first
if await close_button.is_visible():
await close_button.click()
logger.info("[抖音] 关闭弹窗: close")
await asyncio.sleep(0.5)
return True
except Exception:
continue
except Exception:
continue
return False
async def _verify_publish_in_manage(self, page):
manage_url = "https://creator.douyin.com/creator-micro/content/manage"
try:
await page.goto(manage_url)
await page.wait_for_load_state("domcontentloaded")
await asyncio.sleep(2)
title_text = self.title[:30]
title_locator = page.get_by_text(title_text, exact=False).first
if await title_locator.is_visible():
return True, "内容管理中检测到新作品"
if await self._is_text_visible(page, "审核中", exact=False):
return True, "内容管理显示审核中"
except Exception as e:
return False, f"无法验证内容管理: {e}"
return False, "内容管理中未找到视频"
async def set_schedule_time(self, page, publish_date):
"""Set scheduled publish time"""
try:
# Click "定时发布" radio button
label_element = page.locator("[class^='radio']:has-text('定时发布')")
await label_element.click()
await asyncio.sleep(1)
# Format time
publish_date_hour = publish_date.strftime("%Y-%m-%d %H:%M")
# Fill datetime input
await page.locator('.semi-input[placeholder="日期和时间"]').click()
await page.keyboard.press("Control+KeyA")
await page.keyboard.type(str(publish_date_hour))
await page.keyboard.press("Enter")
await asyncio.sleep(1)
logger.info(f"[抖音] 已设置定时发布: {publish_date_hour}")
except Exception as e:
logger.error(f"[抖音] 设置定时发布失败: {e}")
async def upload(self, playwright: Playwright) -> dict:
"""Main upload logic with guaranteed resource cleanup"""
browser = None
context = None
try:
# Launch browser in headless mode for server deployment
browser = await playwright.chromium.launch(headless=True)
context = await browser.new_context(storage_state=self.account_file)
context = await set_init_script(context)
page = await context.new_page()
# Go to upload page
await page.goto(self.upload_url)
await page.wait_for_load_state('domcontentloaded')
await asyncio.sleep(2)
logger.info(f"[抖音] 正在上传: {self.file_path.name}")
# Check if redirected to login page (more reliable than text detection)
current_url = page.url
if "login" in current_url or "passport" in current_url:
logger.error("[抖音] Cookie 已失效,被重定向到登录页")
return {
"success": False,
"message": "Cookie 已失效,请重新登录",
"url": None
}
# Ensure we're on the upload page
if "content/upload" not in page.url:
logger.info("[抖音] 当前不在上传页面,强制跳转...")
await page.goto(self.upload_url)
await asyncio.sleep(2)
# Try multiple selectors for the file input (page structure varies)
file_uploaded = False
selectors = [
"div[class^='container'] input", # Primary selector from SuperIPAgent
"input[type='file']", # Fallback selector
"div[class^='upload'] input[type='file']", # Alternative
]
for selector in selectors:
try:
logger.info(f"[抖音] 尝试选择器: {selector}")
locator = page.locator(selector).first
if await locator.count() > 0:
await locator.set_input_files(str(self.file_path))
file_uploaded = True
logger.info(f"[抖音] 文件上传成功使用选择器: {selector}")
break
except Exception as e:
logger.warning(f"[抖音] 选择器 {selector} 失败: {e}")
continue
if not file_uploaded:
logger.error("[抖音] 所有选择器都失败,无法上传文件")
return {
"success": False,
"message": "无法找到上传按钮,页面可能已更新",
"url": None
}
# Wait for redirect to publish page (with timeout)
redirect_start = time.time()
while time.time() - redirect_start < self.PAGE_REDIRECT_TIMEOUT:
current_url = page.url
if "content/publish" in current_url or "content/post/video" in current_url:
logger.info("[抖音] 成功进入发布页面")
break
await asyncio.sleep(0.5)
else:
logger.error("[抖音] 等待发布页面超时")
return {
"success": False,
"message": "等待发布页面超时",
"url": None
}
# Fill title
await asyncio.sleep(1)
logger.info("[抖音] 正在填充标题和话题...")
if not await self._fill_title(page, self.title):
logger.warning("[抖音] 未找到作品描述输入框")
# Add tags
css_selector = ".zone-container"
for tag in self.tags:
await page.type(css_selector, "#" + tag)
await page.press(css_selector, "Space")
logger.info(f"[抖音] 总共添加 {len(self.tags)} 个话题")
cover_selected = await self._select_cover_if_needed(page)
if not cover_selected:
logger.warning("[抖音] 未确认封面选择,可能影响发布")
# Wait for upload to complete (with timeout)
upload_start = time.time()
while time.time() - upload_start < self.UPLOAD_TIMEOUT:
try:
number = await page.locator('[class^="long-card"] div:has-text("重新上传")').count()
if number > 0:
logger.success("[抖音] 视频上传完毕")
break
else:
logger.info("[抖音] 正在上传视频中...")
await asyncio.sleep(self.POLL_INTERVAL)
except Exception:
await asyncio.sleep(self.POLL_INTERVAL)
else:
logger.error("[抖音] 视频上传超时")
return {
"success": False,
"message": "视频上传超时",
"url": None
}
# Set scheduled publish time if needed
if self.publish_date != 0:
await self.set_schedule_time(page, self.publish_date)
# Click publish button
# 使用更稳健的点击逻辑
try:
publish_label = "定时发布" if self.publish_date != 0 else "发布"
publish_button = page.get_by_role('button', name=publish_label, exact=True)
# 等待按钮出现
await publish_button.wait_for(state="visible", timeout=10000)
if not await publish_button.is_enabled():
logger.error("[抖音] 发布按钮不可点击,可能需要补充封面或确认信息")
return {
"success": False,
"message": "发布按钮不可点击,请检查封面/声明等必填项",
"url": None
}
await asyncio.sleep(1) # 额外等待以确保可交互
clicked = False
for attempt in range(self.MAX_CLICK_RETRIES):
await self._dismiss_blocking_modal(page)
try:
await publish_button.click(timeout=5000)
logger.info(f"[抖音] 点击了{publish_label}按钮")
clicked = True
break
except Exception as click_error:
logger.warning(f"[抖音] 点击发布按钮失败,重试 {attempt + 1}/{self.MAX_CLICK_RETRIES}: {click_error}")
try:
await page.keyboard.press("Escape")
except Exception:
pass
await asyncio.sleep(1)
if not clicked:
raise RuntimeError("点击发布按钮失败")
except Exception as e:
logger.error(f"[抖音] 点击发布按钮失败: {e}")
# 尝试备用选择器
try:
fallback_selectors = ["button:has-text('发布')", "button:has-text('定时发布')"]
clicked = False
for selector in fallback_selectors:
try:
await page.click(selector, timeout=5000)
logger.info(f"[抖音] 使用备用选择器点击了按钮: {selector}")
clicked = True
break
except Exception:
continue
if not clicked:
return {
"success": False,
"message": "无法点击发布按钮,请检查页面状态",
"url": None
}
except Exception:
return {
"success": False,
"message": "无法点击发布按钮,请检查页面状态",
"url": None
}
await self._click_publish_confirm_modal(page)
# Detect whether publishing completed (PUBLISH_TIMEOUT equals the 180 s default)
publish_success, publish_reason, is_timeout = await self._wait_for_publish_result(page, self.PUBLISH_TIMEOUT)
if not publish_success and is_timeout:
verify_success, verify_reason = await self._verify_publish_in_manage(page)
if verify_success:
publish_success = True
publish_reason = verify_reason
else:
publish_reason = f"{publish_reason}; {verify_reason}"
if publish_success:
logger.success(f"[抖音] 发布成功: {publish_reason}")
else:
if is_timeout:
logger.warning("[抖音] 发布检测超时,但这不一定代表失败")
else:
logger.warning(f"[抖音] 发布未成功: {publish_reason}")
# Save updated cookies
await context.storage_state(path=self.account_file)
logger.success("[抖音] Cookie 更新完毕")
await asyncio.sleep(2)
if publish_success:
return {
"success": True,
"message": "发布成功,待审核",
"url": None
}
if is_timeout:
return {
"success": True,
"message": "发布检测超时,请到抖音后台确认",
"url": None
}
return {
"success": False,
"message": f"发布失败: {publish_reason}",
"url": None
}
except Exception as e:
logger.exception(f"[抖音] 上传失败: {e}")
return {
"success": False,
"message": f"上传失败: {str(e)}",
"url": None
}
finally:
# 确保资源释放
if context:
try:
await context.close()
except Exception:
pass
if browser:
try:
await browser.close()
except Exception:
pass
async def main(self) -> Dict[str, Any]:
"""Execute upload"""
async with async_playwright() as playwright:
return await self.upload(playwright)
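The upload, redirect, and publish checks above all follow the same pattern: poll a condition, sleep, and give up after a timeout. A minimal sketch of that pattern as one reusable coroutine; `poll_until` is a hypothetical helper, not something the uploader defines:

```python
import asyncio
import time
from typing import Awaitable, Callable


async def poll_until(
    check: Callable[[], Awaitable[bool]],
    timeout: float,
    interval: float,
) -> bool:
    """Call check() every `interval` seconds until it returns True.

    Returns True as soon as the check passes and False once `timeout`
    seconds have elapsed. Exceptions raised by check() are treated as
    "not ready yet", mirroring the try/except blocks in the loops above.
    """
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        try:
            if await check():
                return True
        except Exception:
            pass  # transient page errors: keep polling
        await asyncio.sleep(interval)
    return False
```

With such a helper, the "wait for upload to complete" loop could reduce to one call whose `check` counts the "重新上传" locator, with `timeout=self.UPLOAD_TIMEOUT` and `interval=self.POLL_INTERVAL`.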


@@ -0,0 +1,30 @@
// Stealth script to prevent bot detection
(() => {
// Overwrite the `webdriver` property so that it reports false.
Object.defineProperty(navigator, 'webdriver', {
get: () => false,
});
// Overwrite the `languages` property to use a custom getter.
Object.defineProperty(navigator, 'languages', {
get: () => ['zh-CN', 'zh', 'en'],
});
// Overwrite the `plugins` property to use a custom getter.
Object.defineProperty(navigator, 'plugins', {
get: () => [1, 2, 3, 4, 5],
});
// Pass the Chrome Test.
window.chrome = {
runtime: {},
};
// Pass the Permissions Test.
const originalQuery = window.navigator.permissions.query;
window.navigator.permissions.query = (parameters) => (
parameters.name === 'notifications' ?
Promise.resolve({ state: Notification.permission }) :
originalQuery(parameters)
);
})();


@@ -0,0 +1,201 @@
"""
Xiaohongshu (小红书) uploader using Playwright
Based on social-auto-upload implementation
"""
from datetime import datetime
from pathlib import Path
from typing import Optional, List, Dict, Any
import asyncio
from playwright.async_api import Playwright, async_playwright
from loguru import logger
from .base_uploader import BaseUploader
from .cookie_utils import set_init_script
class XiaohongshuUploader(BaseUploader):
"""Xiaohongshu video uploader using Playwright"""
# Timeout settings (seconds)
UPLOAD_TIMEOUT = 300 # video upload timeout
PUBLISH_TIMEOUT = 120 # publish detection timeout
POLL_INTERVAL = 1 # polling interval
def __init__(
self,
title: str,
file_path: str,
tags: List[str],
publish_date: Optional[datetime] = None,
account_file: Optional[str] = None,
description: str = ""
):
super().__init__(title, file_path, tags, publish_date, account_file, description)
self.upload_url = "https://creator.xiaohongshu.com/publish/publish?from=homepage&target=video"
async def set_schedule_time(self, page, publish_date):
"""Set scheduled publish time"""
try:
logger.info("[小红书] 正在设置定时发布时间...")
# Click "定时发布" label
label_element = page.locator("label:has-text('定时发布')")
await label_element.click()
await asyncio.sleep(1)
# Format time
publish_date_hour = publish_date.strftime("%Y-%m-%d %H:%M")
# Fill datetime input
await page.locator('.el-input__inner[placeholder="选择日期和时间"]').click()
await page.keyboard.press("Control+KeyA")
await page.keyboard.type(str(publish_date_hour))
await page.keyboard.press("Enter")
await asyncio.sleep(1)
logger.info(f"[小红书] 已设置定时发布: {publish_date_hour}")
except Exception as e:
logger.error(f"[小红书] 设置定时发布失败: {e}")
async def upload(self, playwright: Playwright) -> dict:
"""Main upload logic with guaranteed resource cleanup"""
browser = None
context = None
try:
# Launch browser (headless for server deployment)
browser = await playwright.chromium.launch(headless=True)
context = await browser.new_context(
viewport={"width": 1600, "height": 900},
storage_state=self.account_file
)
context = await set_init_script(context)
page = await context.new_page()
# Go to upload page
await page.goto(self.upload_url)
logger.info(f"[小红书] 正在上传: {self.file_path.name}")
# Upload video file
await page.locator("div[class^='upload-content'] input[class='upload-input']").set_input_files(str(self.file_path))
# Wait for upload to complete (with timeout)
import time
upload_start = time.time()
while time.time() - upload_start < self.UPLOAD_TIMEOUT:
try:
upload_input = await page.wait_for_selector('input.upload-input', timeout=3000)
preview_new = await upload_input.query_selector(
'xpath=following-sibling::div[contains(@class, "preview-new")]'
)
if preview_new:
stage_elements = await preview_new.query_selector_all('div.stage')
upload_success = False
for stage in stage_elements:
text_content = await page.evaluate('(element) => element.textContent', stage)
if '上传成功' in text_content:
upload_success = True
break
if upload_success:
logger.info("[小红书] 检测到上传成功标识")
break
else:
logger.info("[小红书] 未找到上传成功标识,继续等待...")
else:
logger.info("[小红书] 未找到预览元素,继续等待...")
await asyncio.sleep(self.POLL_INTERVAL)
except Exception as e:
logger.info(f"[小红书] 检测过程: {str(e)},重新尝试...")
await asyncio.sleep(0.5)
else:
logger.error("[小红书] 视频上传超时")
return {
"success": False,
"message": "视频上传超时",
"url": None
}
# Fill title and tags
await asyncio.sleep(1)
logger.info("[小红书] 正在填充标题和话题...")
title_container = page.locator('div.plugin.title-container').locator('input.d-text')
if await title_container.count():
await title_container.fill(self.title[:30])
# Add tags
css_selector = ".tiptap"
for tag in self.tags:
await page.type(css_selector, "#" + tag)
await page.press(css_selector, "Space")
logger.info(f"[小红书] 总共添加 {len(self.tags)} 个话题")
# Set scheduled publish time if needed
if self.publish_date != 0:
await self.set_schedule_time(page, self.publish_date)
# Click publish button (with timeout)
publish_start = time.time()
while time.time() - publish_start < self.PUBLISH_TIMEOUT:
try:
if self.publish_date != 0:
await page.locator('button:has-text("定时发布")').click()
else:
await page.locator('button:has-text("发布")').click()
await page.wait_for_url(
"https://creator.xiaohongshu.com/publish/success?**",
timeout=3000
)
logger.success("[小红书] 视频发布成功")
break
except Exception:
logger.info("[小红书] 视频正在发布中...")
await asyncio.sleep(0.5)
else:
logger.warning("[小红书] 发布检测超时,请手动确认")
# Save updated cookies
await context.storage_state(path=self.account_file)
logger.success("[小红书] Cookie 更新完毕")
await asyncio.sleep(2)
return {
"success": True,
"message": "发布成功,待审核" if self.publish_date == 0 else "已设置定时发布",
"url": None
}
except Exception as e:
logger.exception(f"[小红书] 上传失败: {e}")
return {
"success": False,
"message": f"上传失败: {str(e)}",
"url": None
}
finally:
# 确保资源释放
if context:
try:
await context.close()
except Exception:
pass
if browser:
try:
await browser.close()
except Exception:
pass
async def main(self) -> Dict[str, Any]:
"""Execute upload"""
async with async_playwright() as playwright:
return await self.upload(playwright)
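Both uploaders decide whether to schedule with `self.publish_date != 0`, while the constructors default `publish_date` to `None`. `BaseUploader` is not shown in this diff; unless it maps `None` to `0`, `None != 0` evaluates to `True` and an immediate publish would take the scheduling branch. A minimal sketch of a stricter check under that assumption; `wants_schedule` and `format_schedule` are hypothetical names:

```python
from datetime import datetime
from typing import Union


def wants_schedule(publish_date: Union[datetime, int, None]) -> bool:
    """True only when a concrete datetime was supplied.

    Treats both None (the constructor default) and 0 (the sentinel the
    uploaders compare against) as "publish immediately".
    """
    return isinstance(publish_date, datetime)


def format_schedule(publish_date: datetime) -> str:
    """Format the timestamp the way set_schedule_time types it into the page."""
    return publish_date.strftime("%Y-%m-%d %H:%M")
```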


@@ -0,0 +1,73 @@
-- ViGent user authentication database tables
-- Run in the Supabase SQL Editor
-- 1. Create the users table
CREATE TABLE IF NOT EXISTS users (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
email TEXT UNIQUE NOT NULL,
password_hash TEXT NOT NULL,
username TEXT,
role TEXT DEFAULT 'pending' CHECK (role IN ('pending', 'user', 'admin')),
is_active BOOLEAN DEFAULT FALSE,
expires_at TIMESTAMP WITH TIME ZONE,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
-- 2. Create the user_sessions table (single-device login)
CREATE TABLE IF NOT EXISTS user_sessions (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID REFERENCES users(id) ON DELETE CASCADE UNIQUE,
session_token TEXT UNIQUE NOT NULL,
device_info TEXT,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
);
-- 3. Create the social_accounts table (linked social accounts)
CREATE TABLE IF NOT EXISTS social_accounts (
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
user_id UUID REFERENCES users(id) ON DELETE CASCADE,
platform TEXT NOT NULL CHECK (platform IN ('bilibili', 'douyin', 'xiaohongshu')),
logged_in BOOLEAN DEFAULT FALSE,
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
UNIQUE(user_id, platform)
);
-- 4. Create indexes
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON user_sessions(user_id);
CREATE INDEX IF NOT EXISTS idx_social_user_platform ON social_accounts(user_id, platform);
-- 5. Enable RLS (row-level security)
ALTER TABLE users ENABLE ROW LEVEL SECURITY;
ALTER TABLE user_sessions ENABLE ROW LEVEL SECURITY;
ALTER TABLE social_accounts ENABLE ROW LEVEL SECURITY;
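-- 6. RLS policies (the service role bypasses RLS, so the backend is unrestricted when it uses the service_role key)
-- The policies below only take effect for the anon key
-- users: ordinary users can view only their own profile (no policy here grants admins access to all users; that relies on the service_role key)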
CREATE POLICY "Users can view own profile" ON users
FOR SELECT USING (auth.uid()::text = id::text);
-- user_sessions: users can access only their own session
CREATE POLICY "Users can access own sessions" ON user_sessions
FOR ALL USING (user_id::text = auth.uid()::text);
-- social_accounts: users can access only their own social accounts
CREATE POLICY "Users can access own social accounts" ON social_accounts
FOR ALL USING (user_id::text = auth.uid()::text);
-- 7. Trigger that refreshes updated_at automatically
CREATE OR REPLACE FUNCTION update_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER users_updated_at
BEFORE UPDATE ON users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at();
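The `UNIQUE` constraint on `user_sessions.user_id` is what makes single-device login possible: a new login can replace the previous row instead of adding a second one. A minimal sketch of that idea, using an in-memory SQLite table as a stand-in for the Postgres table (PostgreSQL accepts the same `ON CONFLICT ... DO UPDATE` form); `open_demo_db`, `login`, and `is_current` are hypothetical helpers, and the backend's real session code is not shown in this diff:

```python
import sqlite3
import uuid


def open_demo_db() -> sqlite3.Connection:
    """Create a reduced copy of user_sessions in memory (SQLite, not Supabase)."""
    db = sqlite3.connect(":memory:")
    db.execute(
        "CREATE TABLE user_sessions ("
        " user_id TEXT UNIQUE NOT NULL,"
        " session_token TEXT UNIQUE NOT NULL,"
        " device_info TEXT)"
    )
    return db


def login(db: sqlite3.Connection, user_id: str, device_info: str) -> str:
    """Issue a new token; the upsert overwrites any earlier session row."""
    token = uuid.uuid4().hex
    db.execute(
        "INSERT INTO user_sessions (user_id, session_token, device_info)"
        " VALUES (?, ?, ?)"
        " ON CONFLICT(user_id) DO UPDATE SET"
        " session_token = excluded.session_token,"
        " device_info = excluded.device_info",
        (user_id, token, device_info),
    )
    return token


def is_current(db: sqlite3.Connection, user_id: str, token: str) -> bool:
    """A token is valid only if it is the one stored for that user."""
    row = db.execute(
        "SELECT session_token FROM user_sessions WHERE user_id = ?", (user_id,)
    ).fetchone()
    return row is not None and row[0] == token
```

Logging in from a second device invalidates the first token because only one row per user can exist.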

backend/generate_keys.py Normal file

@@ -0,0 +1,93 @@
import hmac
import hashlib
import base64
import json
import time
import secrets
import string
def generate_secure_secret(length=64):
"""生成安全的随机十六进制字符串"""
return secrets.token_hex(length // 2)
def generate_random_string(length=32):
"""生成包含字母数字的随机字符串 (用于密码等)"""
chars = string.ascii_letters + string.digits
return ''.join(secrets.choice(chars) for _ in range(length))
def base64url_encode(input_bytes):
return base64.urlsafe_b64encode(input_bytes).decode('utf-8').rstrip('=')
def generate_jwt(role, secret):
# 1. Header
header = {
"alg": "HS256",
"typ": "JWT"
}
# 2. Payload
now = int(time.time())
payload = {
"role": role,
"iss": "supabase",
"iat": now,
"exp": now + 315360000 # valid for 10 years (10 * 365 * 86400 seconds)
}
# Encode parts
header_b64 = base64url_encode(json.dumps(header).encode('utf-8'))
payload_b64 = base64url_encode(json.dumps(payload).encode('utf-8'))
# 3. Signature
signing_input = f"{header_b64}.{payload_b64}".encode('utf-8')
signature = hmac.new(
secret.encode('utf-8'),
signing_input,
hashlib.sha256
).digest()
signature_b64 = base64url_encode(signature)
return f"{header_b64}.{payload_b64}.{signature_b64}"
if __name__ == "__main__":
print("=" * 60)
print("🔐 Supabase 全自动配置生成器 (Zero Dependency)")
print("=" * 60)
print("正在生成所有密钥...\n")
# 1. 自动生成主密钥
jwt_secret = generate_secure_secret(64)
# 2. 基于主密钥生成 JWT
anon_key = generate_jwt("anon", jwt_secret)
service_key = generate_jwt("service_role", jwt_secret)
# 3. 生成其他加密 Key和密码
vault_key = generate_secure_secret(32)
meta_key = generate_secure_secret(32)
secret_key_base = generate_secure_secret(64)
db_password = generate_random_string(20)
dashboard_password = generate_random_string(16)
# 4. 输出结果
print(f"✅ 生成完成!请直接复制以下内容覆盖您的 .env 文件中的对应部分:\n")
print("-" * 20 + " [ 复制开始 ] " + "-" * 20)
print(f"# === 数据库安全配置 ===")
print(f"POSTGRES_PASSWORD={db_password}")
print(f"JWT_SECRET={jwt_secret}")
print(f"ANON_KEY={anon_key}")
print(f"SERVICE_ROLE_KEY={service_key}")
print(f"SECRET_KEY_BASE={secret_key_base}")
print(f"VAULT_ENC_KEY={vault_key}")
print(f"PG_META_CRYPTO_KEY={meta_key}")
print(f"\n# === 管理后台配置 ===")
print(f"DASHBOARD_USERNAME=admin")
print(f"DASHBOARD_PASSWORD={dashboard_password}")
print("-" * 20 + " [ 复制结束 ] " + "-" * 20)
print("\n💡 提示:")
print(f"1. 数据库密码: {db_password}")
print(f"2. 后台登录密码: {dashboard_password}")
print("请妥善保管这些密码!")
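To sanity-check the generated `ANON_KEY` and `SERVICE_ROLE_KEY` before pasting them into `.env`, the signature can be recomputed with the same standard-library primitives. A minimal sketch, assuming HS256 tokens built as in `generate_jwt` above; `sign_hs256` and `verify_hs256` are hypothetical helpers added for illustration:

```python
import base64
import hashlib
import hmac
import json
import time


def b64url_encode(raw: bytes) -> str:
    """Base64url-encode and strip '=' padding, as generate_keys.py does."""
    return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=")


def b64url_decode(segment: str) -> bytes:
    """Decode base64url text whose '=' padding was stripped."""
    return base64.urlsafe_b64decode(segment + "=" * (-len(segment) % 4))


def sign_hs256(payload: dict, secret: str) -> str:
    """Build header.payload.signature with HMAC-SHA256."""
    head = b64url_encode(json.dumps({"alg": "HS256", "typ": "JWT"}).encode("utf-8"))
    body = b64url_encode(json.dumps(payload).encode("utf-8"))
    sig = hmac.new(secret.encode("utf-8"), f"{head}.{body}".encode("utf-8"), hashlib.sha256).digest()
    return f"{head}.{body}.{b64url_encode(sig)}"


def verify_hs256(token: str, secret: str) -> dict:
    """Return the payload if the signature matches and the token is unexpired."""
    head, body, sig = token.split(".")
    expected = hmac.new(secret.encode("utf-8"), f"{head}.{body}".encode("utf-8"), hashlib.sha256).digest()
    # compare_digest avoids leaking the mismatch position through timing
    if not hmac.compare_digest(expected, b64url_decode(sig)):
        raise ValueError("signature mismatch")
    payload = json.loads(b64url_decode(body))
    if payload.get("exp", 0) < time.time():
        raise ValueError("token expired")
    return payload
```

A token signed with one secret should verify with that secret and raise `ValueError` with any other.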


@@ -18,3 +18,13 @@ python-dotenv>=1.0.0
loguru>=0.7.2
playwright>=1.40.0
requests>=2.31.0
# Social media publishing
biliup>=0.4.0
# User authentication
email-validator>=2.1.0
supabase>=2.0.0
python-jose[cryptography]>=3.3.0
passlib[bcrypt]>=1.7.4
bcrypt==4.0.1


@@ -1,36 +1,72 @@
This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
# ViGent2 Frontend
## Getting Started
The frontend for ViGent2, built with Next.js 16 + TailwindCSS.
First, run the development server:
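## ✨ Core Features
### 1. Video Generation (`/`)
- **Asset management**: Drag and drop to upload a video of a person, with live preview.
- **Script voice-over**: Integrates EdgeTTS with a choice of voices (Yunxi / Xiaoxiao).
- **Progress tracking**: Shows video generation progress in real time (10% -> 100%).
- **Result preview**: Play or download the video as soon as generation finishes.
### 2. Fully Automated Publishing (`/publish`) [added on Day 7]
- **Multi-platform management**: Manage Bilibili, Douyin, and Xiaohongshu account status in one place.
- **QR-code login**:
  - Displays the QR code generated by Playwright on the backend.
  - Checks the scan status in real time (Wait/Success).
  - Saves cookies automatically and syncs the login state.
- **Publish settings**: Set the video title, tags, and description.
- **Scheduling**: Supports "publish now" or "scheduled publish".
## 🛠️ Tech Stack
- **Framework**: Next.js 16 (App Router)
- **Styling**: TailwindCSS
- **Icons**: Lucide React
- **Components**: Custom modern components (glassmorphism style)
- **API**: Fetch API (talks to the FastAPI backend on :8006)
## 🚀 Development Guide
### Install dependencies
```bash
npm install
```
### Start the development server
Runs on port **3002** by default (configured in `package.json`):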
```bash
npm run dev
# or
yarn dev
# or
pnpm dev
# or
bun dev
# Open: http://localhost:3002
```
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
### Directory structure
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
```
src/
├── app/
│ ├── page.tsx # Video generation home page
│ ├── publish/ # Publish management page
│ │ └── page.tsx
│ └── layout.tsx # Global layout (navigation bar)
├── components/ # UI components
│ ├── VideoUploader.tsx # Video upload
│ ├── StatusBadge.tsx # Status badge
│ └── ...
└── lib/ # Utility functions
```
This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
## 🔌 Backend Integration
## Learn More
- **Base URL**: `http://localhost:8006`
- **Proxy setup**: Next.js rewrites (if needed), or direct CORS.
To learn more about Next.js, take a look at the following resources:
## 🎨 Design Guidelines
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
## Deploy on Vercel
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
- **Primary palette**: Deep purple / black (dark mode)
- **Interaction**: Subtle hover animations (hover effects)
- **Responsive**: Tuned for large desktop screens


@@ -8,6 +8,14 @@ const nextConfig: NextConfig = {
source: '/api/:path*',
destination: 'http://localhost:8006/api/:path*', // 服务器本地代理
},
{
source: '/uploads/:path*',
destination: 'http://localhost:8006/uploads/:path*', // 转发上传的素材
},
{
source: '/outputs/:path*',
destination: 'http://localhost:8006/outputs/:path*', // 转发生成的视频
},
];
},
};


@@ -8,9 +8,12 @@
"name": "frontend",
"version": "0.1.0",
"dependencies": {
"@supabase/supabase-js": "^2.93.1",
"axios": "^1.13.4",
"next": "16.1.1",
"react": "19.2.3",
"react-dom": "19.2.3"
"react-dom": "19.2.3",
"swr": "^2.3.8"
},
"devDependencies": {
"@tailwindcss/postcss": "^4",
@@ -67,7 +70,6 @@
"integrity": "sha512-H3mcG6ZDLTlYfaSNi0iOKkigqMFvkTKlGUYlD8GW7nNOYRrevuA46iTypPyv+06V3fEmvvazfntkBU34L0azAw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@babel/code-frame": "^7.28.6",
"@babel/generator": "^7.28.6",
@@ -1234,6 +1236,80 @@
"dev": true,
"license": "MIT"
},
"node_modules/@supabase/auth-js": {
"version": "2.93.1",
"resolved": "https://registry.npmjs.org/@supabase/auth-js/-/auth-js-2.93.1.tgz",
"integrity": "sha512-pC0Ek4xk4z6q7A/3+UuZ/eYgfFUUQTg3DhapzrAgJnFGDJDFDyGCj6v9nIz8+3jfLqSZ3QKGe6AoEodYjShghg==",
"dependencies": {
"tslib": "2.8.1"
},
"engines": {
"node": ">=20.0.0"
}
},
"node_modules/@supabase/functions-js": {
"version": "2.93.1",
"resolved": "https://registry.npmjs.org/@supabase/functions-js/-/functions-js-2.93.1.tgz",
"integrity": "sha512-Ott2IcIXHGupaC0nX9WNEiJAX4OdlGRu9upkkURaQHbaLdz9JuCcHxlwTERgtgjMpikbIWHfMM1M9QTQFYABiA==",
"dependencies": {
"tslib": "2.8.1"
},
"engines": {
"node": ">=20.0.0"
}
},
"node_modules/@supabase/postgrest-js": {
"version": "2.93.1",
"resolved": "https://registry.npmjs.org/@supabase/postgrest-js/-/postgrest-js-2.93.1.tgz",
"integrity": "sha512-uRKKQJBDnfi6XFNFPNMh9+u3HT2PCgp065PcMPmG7e0xGuqvLtN89QxO2/SZcGbw2y1+mNBz0yUs5KmyNqF2fA==",
"dependencies": {
"tslib": "2.8.1"
},
"engines": {
"node": ">=20.0.0"
}
},
"node_modules/@supabase/realtime-js": {
"version": "2.93.1",
"resolved": "https://registry.npmjs.org/@supabase/realtime-js/-/realtime-js-2.93.1.tgz",
"integrity": "sha512-2WaP/KVHPlQDjWM6qe4wOZz6zSRGaXw1lfXf4thbfvk3C3zPPKqXRyspyYnk3IhphyxSsJ2hQ/cXNOz48008tg==",
"dependencies": {
"@types/phoenix": "^1.6.6",
"@types/ws": "^8.18.1",
"tslib": "2.8.1",
"ws": "^8.18.2"
},
"engines": {
"node": ">=20.0.0"
}
},
"node_modules/@supabase/storage-js": {
"version": "2.93.1",
"resolved": "https://registry.npmjs.org/@supabase/storage-js/-/storage-js-2.93.1.tgz",
"integrity": "sha512-3KVwd4S1i1BVPL6KIywe5rnruNQXSkLyvrdiJmwnqwbCcDujQumARdGWBPesqCjOPKEU2M9ORWKAsn+2iLzquA==",
"dependencies": {
"iceberg-js": "^0.8.1",
"tslib": "2.8.1"
},
"engines": {
"node": ">=20.0.0"
}
},
"node_modules/@supabase/supabase-js": {
"version": "2.93.1",
"resolved": "https://registry.npmjs.org/@supabase/supabase-js/-/supabase-js-2.93.1.tgz",
"integrity": "sha512-FJTgS5s0xEgRQ3u7gMuzGObwf3jA4O5Ki/DgCDXx94w1pihLM4/WG3XFa4BaCJYfuzLxLcv6zPPA5tDvBUjAUg==",
"dependencies": {
"@supabase/auth-js": "2.93.1",
"@supabase/functions-js": "2.93.1",
"@supabase/postgrest-js": "2.93.1",
"@supabase/realtime-js": "2.93.1",
"@supabase/storage-js": "2.93.1"
},
"engines": {
"node": ">=20.0.0"
}
},
"node_modules/@swc/helpers": {
"version": "0.5.15",
"resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.15.tgz",
@@ -1550,19 +1626,22 @@
"version": "20.19.28",
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.19.28.tgz",
"integrity": "sha512-VyKBr25BuFDzBFCK5sUM6ZXiWfqgCTwTAOK8qzGV/m9FCirXYDlmczJ+d5dXBAQALGCdRRdbteKYfJ84NGEusw==",
"dev": true,
"license": "MIT",
"dependencies": {
"undici-types": "~6.21.0"
}
},
"node_modules/@types/phoenix": {
"version": "1.6.7",
"resolved": "https://registry.npmjs.org/@types/phoenix/-/phoenix-1.6.7.tgz",
"integrity": "sha512-oN9ive//QSBkf19rfDv45M7eZPi0eEXylht2OLEXicu5b4KoQ1OzXIw+xDSGWxSxe1JmepRR/ZH283vsu518/Q=="
},
"node_modules/@types/react": {
"version": "19.2.8",
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.8.tgz",
"integrity": "sha512-3MbSL37jEchWZz2p2mjntRZtPt837ij10ApxKfgmXCTuHWagYg7iA5bqPw6C8BMPfwidlvfPI/fxOc42HLhcyg==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"csstype": "^3.2.2"
}
@@ -1577,6 +1656,14 @@
"@types/react": "^19.2.0"
}
},
"node_modules/@types/ws": {
"version": "8.18.1",
"resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz",
"integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==",
"dependencies": {
"@types/node": "*"
}
},
"node_modules/@typescript-eslint/eslint-plugin": {
"version": "8.53.0",
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.53.0.tgz",
@@ -1622,7 +1709,6 @@
"integrity": "sha512-npiaib8XzbjtzS2N4HlqPvlpxpmZ14FjSJrteZpPxGUaYPlvhzlzUZ4mZyABo0EFrOWnvyd0Xxroq//hKhtAWg==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@typescript-eslint/scope-manager": "8.53.0",
"@typescript-eslint/types": "8.53.0",
@@ -2122,7 +2208,6 @@
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
"dev": true,
"license": "MIT",
"peer": true,
"bin": {
"acorn": "bin/acorn"
},
@@ -2367,6 +2452,12 @@
"node": ">= 0.4"
}
},
"node_modules/asynckit": {
"version": "0.4.0",
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==",
"license": "MIT"
},
"node_modules/available-typed-arrays": {
"version": "1.0.7",
"resolved": "https://registry.npmjs.org/available-typed-arrays/-/available-typed-arrays-1.0.7.tgz",
@@ -2393,6 +2484,17 @@
"node": ">=4"
}
},
"node_modules/axios": {
"version": "1.13.4",
"resolved": "https://registry.npmjs.org/axios/-/axios-1.13.4.tgz",
"integrity": "sha512-1wVkUaAO6WyaYtCkcYCOx12ZgpGf9Zif+qXa4n+oYzK558YryKqiL6UWwd5DqiH3VRW0GYhTZQ/vlgJrCoNQlg==",
"license": "MIT",
"dependencies": {
"follow-redirects": "^1.15.6",
"form-data": "^4.0.4",
"proxy-from-env": "^1.1.0"
}
},
"node_modules/axobject-query": {
"version": "4.1.0",
"resolved": "https://registry.npmjs.org/axobject-query/-/axobject-query-4.1.0.tgz",
@@ -2463,7 +2565,6 @@
}
],
"license": "MIT",
"peer": true,
"dependencies": {
"baseline-browser-mapping": "^2.9.0",
"caniuse-lite": "^1.0.30001759",
@@ -2501,7 +2602,6 @@
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz",
"integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
@@ -2601,6 +2701,18 @@
"dev": true,
"license": "MIT"
},
"node_modules/combined-stream": {
"version": "1.0.8",
"resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz",
"integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==",
"license": "MIT",
"dependencies": {
"delayed-stream": "~1.0.0"
},
"engines": {
"node": ">= 0.8"
}
},
"node_modules/concat-map": {
"version": "0.0.1",
"resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz",
@@ -2759,6 +2871,24 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/delayed-stream": {
"version": "1.0.0",
"resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
"integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==",
"license": "MIT",
"engines": {
"node": ">=0.4.0"
}
},
"node_modules/dequal": {
"version": "2.0.3",
"resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz",
"integrity": "sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==",
"license": "MIT",
"engines": {
"node": ">=6"
}
},
"node_modules/detect-libc": {
"version": "2.1.2",
"resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.1.2.tgz",
@@ -2786,7 +2916,6 @@
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz",
"integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.1",
@@ -2898,7 +3027,6 @@
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz",
"integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
@@ -2908,7 +3036,6 @@
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz",
"integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
@@ -2946,7 +3073,6 @@
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz",
"integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0"
@@ -2959,7 +3085,6 @@
"version": "2.1.0",
"resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz",
"integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==",
"dev": true,
"license": "MIT",
"dependencies": {
"es-errors": "^1.3.0",
@@ -3031,7 +3156,6 @@
"integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@eslint-community/eslint-utils": "^4.8.0",
"@eslint-community/regexpp": "^4.12.1",
@@ -3217,7 +3341,6 @@
"integrity": "sha512-whOE1HFo/qJDyX4SnXzP4N6zOWn79WhnCUY/iDR0mPfQZO8wcYE4JClzI2oZrhBnnMUCBCHZhO6VQyoBU95mZA==",
"dev": true,
"license": "MIT",
"peer": true,
"dependencies": {
"@rtsao/scc": "^1.1.0",
"array-includes": "^3.1.9",
@@ -3576,6 +3699,26 @@
"dev": true,
"license": "ISC"
},
"node_modules/follow-redirects": {
"version": "1.15.11",
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.11.tgz",
"integrity": "sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==",
"funding": [
{
"type": "individual",
"url": "https://github.com/sponsors/RubenVerborgh"
}
],
"license": "MIT",
"engines": {
"node": ">=4.0"
},
"peerDependenciesMeta": {
"debug": {
"optional": true
}
}
},
"node_modules/for-each": {
"version": "0.3.5",
"resolved": "https://registry.npmjs.org/for-each/-/for-each-0.3.5.tgz",
@@ -3592,11 +3735,26 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/form-data": {
"version": "4.0.5",
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.5.tgz",
"integrity": "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==",
"license": "MIT",
"dependencies": {
"asynckit": "^0.4.0",
"combined-stream": "^1.0.8",
"es-set-tostringtag": "^2.1.0",
"hasown": "^2.0.2",
"mime-types": "^2.1.12"
},
"engines": {
"node": ">= 6"
}
},
"node_modules/function-bind": {
"version": "1.1.2",
"resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz",
"integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==",
"dev": true,
"license": "MIT",
"funding": {
"url": "https://github.com/sponsors/ljharb"
@@ -3657,7 +3815,6 @@
"version": "1.3.0",
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz",
"integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"call-bind-apply-helpers": "^1.0.2",
@@ -3682,7 +3839,6 @@
"version": "1.0.1",
"resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz",
"integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==",
"dev": true,
"license": "MIT",
"dependencies": {
"dunder-proto": "^1.0.1",
@@ -3770,7 +3926,6 @@
"version": "1.2.0",
"resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz",
"integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
@@ -3842,7 +3997,6 @@
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
"integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
@@ -3855,7 +4009,6 @@
"version": "1.0.2",
"resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz",
"integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==",
"dev": true,
"license": "MIT",
"dependencies": {
"has-symbols": "^1.0.3"
@@ -3871,7 +4024,6 @@
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz",
"integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==",
"dev": true,
"license": "MIT",
"dependencies": {
"function-bind": "^1.1.2"
@@ -3897,6 +4049,14 @@
"hermes-estree": "0.25.1"
}
},
"node_modules/iceberg-js": {
"version": "0.8.1",
"resolved": "https://registry.npmjs.org/iceberg-js/-/iceberg-js-0.8.1.tgz",
"integrity": "sha512-1dhVQZXhcHje7798IVM+xoo/1ZdVfzOMIc8/rgVSijRK38EDqOJoGula9N/8ZI5RD8QTxNQtK/Gozpr+qUqRRA==",
"engines": {
"node": ">=20.0.0"
}
},
"node_modules/ignore": {
"version": "5.3.2",
"resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz",
@@ -4854,7 +5014,6 @@
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz",
"integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==",
"dev": true,
"license": "MIT",
"engines": {
"node": ">= 0.4"
@@ -4884,6 +5043,27 @@
"node": ">=8.6"
}
},
"node_modules/mime-db": {
"version": "1.52.0",
"resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz",
"integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==",
"license": "MIT",
"engines": {
"node": ">= 0.6"
}
},
"node_modules/mime-types": {
"version": "2.1.35",
"resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz",
"integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==",
"license": "MIT",
"dependencies": {
"mime-db": "1.52.0"
},
"engines": {
"node": ">= 0.6"
}
},
"node_modules/minimatch": {
"version": "3.1.2",
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
@@ -5354,6 +5534,12 @@
"react-is": "^16.13.1"
}
},
"node_modules/proxy-from-env": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz",
"integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==",
"license": "MIT"
},
"node_modules/punycode": {
"version": "2.3.1",
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz",
@@ -5390,7 +5576,6 @@
"resolved": "https://registry.npmjs.org/react/-/react-19.2.3.tgz",
"integrity": "sha512-Ku/hhYbVjOQnXDZFv2+RibmLFGwFdeeKHFcOTlrt7xplBnya5OGn/hIRDsqDiSUcfORsDC7MPxwork8jBwsIWA==",
"license": "MIT",
"peer": true,
"engines": {
"node": ">=0.10.0"
}
@@ -5400,7 +5585,6 @@
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.3.tgz",
"integrity": "sha512-yELu4WmLPw5Mr/lmeEpox5rw3RETacE++JgHqQzd2dg+YbJuat3jH4ingc+WPZhxaoFzdv9y33G+F7Nl5O0GBg==",
"license": "MIT",
"peer": true,
"dependencies": {
"scheduler": "^0.27.0"
},
@@ -6027,6 +6211,19 @@
"url": "https://github.com/sponsors/ljharb"
}
},
"node_modules/swr": {
"version": "2.3.8",
"resolved": "https://registry.npmjs.org/swr/-/swr-2.3.8.tgz",
"integrity": "sha512-gaCPRVoMq8WGDcWj9p4YWzCMPHzE0WNl6W8ADIx9c3JBEIdMkJGMzW+uzXvxHMltwcYACr9jP+32H8/hgwMR7w==",
"license": "MIT",
"dependencies": {
"dequal": "^2.0.3",
"use-sync-external-store": "^1.6.0"
},
"peerDependencies": {
"react": "^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/tailwindcss": {
"version": "4.1.18",
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.18.tgz",
@@ -6089,7 +6286,6 @@
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
"dev": true,
"license": "MIT",
"peer": true,
"engines": {
"node": ">=12"
},
@@ -6252,7 +6448,6 @@
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
"dev": true,
"license": "Apache-2.0",
"peer": true,
"bin": {
"tsc": "bin/tsc",
"tsserver": "bin/tsserver"
@@ -6308,7 +6503,6 @@
"version": "6.21.0",
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz",
"integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==",
"dev": true,
"license": "MIT"
},
"node_modules/unrs-resolver": {
@@ -6387,6 +6581,15 @@
"punycode": "^2.1.0"
}
},
"node_modules/use-sync-external-store": {
"version": "1.6.0",
"resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz",
"integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==",
"license": "MIT",
"peerDependencies": {
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
}
},
"node_modules/which": {
"version": "2.0.2",
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",
@@ -6502,6 +6705,26 @@
"node": ">=0.10.0"
}
},
"node_modules/ws": {
"version": "8.19.0",
"resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz",
"integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==",
"engines": {
"node": ">=10.0.0"
},
"peerDependencies": {
"bufferutil": "^4.0.1",
"utf-8-validate": ">=5.0.2"
},
"peerDependenciesMeta": {
"bufferutil": {
"optional": true
},
"utf-8-validate": {
"optional": true
}
}
},
"node_modules/yallist": {
"version": "3.1.1",
"resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz",
@@ -6528,7 +6751,6 @@
"integrity": "sha512-k7Nwx6vuWx1IJ9Bjuf4Zt1PEllcwe7cls3VNzm4CQ1/hgtFUK2bRNG3rvnpPUhFjmqJKAKtjV576KnUkHocg/g==",
"dev": true,
"license": "MIT",
"peer": true,
"funding": {
"url": "https://github.com/sponsors/colinhacks"
}

View File

@@ -9,9 +9,12 @@
"lint": "eslint"
},
"dependencies": {
"@supabase/supabase-js": "^2.93.1",
"axios": "^1.13.4",
"next": "16.1.1",
"react": "19.2.3",
"react-dom": "19.2.3"
"react-dom": "19.2.3",
"swr": "^2.3.8"
},
"devDependencies": {
"@tailwindcss/postcss": "^4",
@@ -23,4 +26,4 @@
"tailwindcss": "^4",
"typescript": "^5"
}
}
}

View File

@@ -0,0 +1,190 @@
'use client';
import { useState, useEffect } from 'react';
import { useRouter } from 'next/navigation';
import { getCurrentUser, User } from '@/lib/auth';
import api from '@/lib/axios';
interface UserListItem {
id: string;
email: string;
username: string | null;
role: string;
is_active: boolean;
expires_at: string | null;
created_at: string;
}
export default function AdminPage() {
const router = useRouter();
const [currentUser, setCurrentUser] = useState<User | null>(null);
const [users, setUsers] = useState<UserListItem[]>([]);
const [loading, setLoading] = useState(true);
const [error, setError] = useState('');
const [activatingId, setActivatingId] = useState<string | null>(null);
const [expireDays, setExpireDays] = useState<number>(30);
useEffect(() => {
checkAdmin();
fetchUsers();
}, []);
const checkAdmin = async () => {
const user = await getCurrentUser();
if (!user || user.role !== 'admin') {
router.push('/login');
return;
}
setCurrentUser(user);
};
const fetchUsers = async () => {
try {
const { data } = await api.get('/api/admin/users');
setUsers(data);
} catch (err) {
setError('Failed to fetch the user list');
} finally {
setLoading(false);
}
};
const activateUser = async (userId: string) => {
setActivatingId(userId);
try {
await api.post(`/api/admin/users/${userId}/activate`, {
expires_days: expireDays || null
});
fetchUsers();
} catch (err) {
// axios interceptor handles 401/403
} finally {
setActivatingId(null);
}
};
const deactivateUser = async (userId: string) => {
if (!confirm('Are you sure you want to deactivate this user?')) return;
try {
await api.post(`/api/admin/users/${userId}/deactivate`);
fetchUsers();
} catch (err) {
alert('Operation failed');
}
};
const formatDate = (dateStr: string | null) => {
if (!dateStr) return 'Permanent';
return new Date(dateStr).toLocaleDateString('zh-CN');
};
const getRoleBadge = (role: string, isActive: boolean) => {
if (role === 'admin') {
return <span className="px-2 py-1 text-xs rounded-full bg-purple-500/20 text-purple-300"></span>;
}
if (role === 'pending') {
return <span className="px-2 py-1 text-xs rounded-full bg-yellow-500/20 text-yellow-300"></span>;
}
if (!isActive) {
return <span className="px-2 py-1 text-xs rounded-full bg-red-500/20 text-red-300"></span>;
}
return <span className="px-2 py-1 text-xs rounded-full bg-green-500/20 text-green-300"></span>;
};
if (loading) {
return (
<div className="min-h-dvh flex items-center justify-center">
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-purple-500"></div>
</div>
);
}
return (
<div className="min-h-dvh p-8">
<div className="max-w-6xl mx-auto">
<div className="flex justify-between items-center mb-8">
<h1 className="text-3xl font-bold text-white"></h1>
<a href="/" className="text-purple-300 hover:text-purple-200">
</a>
</div>
{error && (
<div className="mb-4 p-3 bg-red-500/20 border border-red-500/50 rounded-lg text-red-200">
{error}
</div>
)}
<div className="mb-4 flex items-center gap-4">
<label className="text-gray-300"></label>
<input
type="number"
value={expireDays}
onChange={(e) => setExpireDays(parseInt(e.target.value, 10) || 0)}
className="w-24 px-3 py-2 bg-white/5 border border-white/10 rounded text-white"
placeholder="0 = permanent"
/>
<span className="text-gray-400 text-sm">(0 = permanent)</span>
</div>
<div className="bg-white/5 backdrop-blur-lg rounded-xl border border-white/10 overflow-hidden">
<table className="w-full">
<thead className="bg-white/5">
<tr>
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300"></th>
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300"></th>
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300"></th>
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300"></th>
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300"></th>
</tr>
</thead>
<tbody className="divide-y divide-white/5">
{users.map((user) => (
<tr key={user.id} className="hover:bg-white/5">
<td className="px-6 py-4">
<div>
<div className="text-white font-medium">{user.username || user.email.split('@')[0]}</div>
<div className="text-gray-400 text-sm">{user.email}</div>
</div>
</td>
<td className="px-6 py-4">
{getRoleBadge(user.role, user.is_active)}
</td>
<td className="px-6 py-4 text-gray-300">
{formatDate(user.expires_at)}
</td>
<td className="px-6 py-4 text-gray-400 text-sm">
{formatDate(user.created_at)}
</td>
<td className="px-6 py-4">
{user.role !== 'admin' && (
<div className="flex gap-2">
{!user.is_active || user.role === 'pending' ? (
<button
onClick={() => activateUser(user.id)}
disabled={activatingId === user.id}
className="px-3 py-1 bg-green-600 hover:bg-green-700 text-white text-sm rounded disabled:opacity-50"
>
{activatingId === user.id ? '...' : 'Activate'}
</button>
) : (
<button
onClick={() => deactivateUser(user.id)}
className="px-3 py-1 bg-red-600 hover:bg-red-700 text-white text-sm rounded"
>
</button>
)}
</div>
)}
</td>
</tr>
))}
</tbody>
</table>
</div>
</div>
</div>
);
}
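Review note: the badge precedence in `getRoleBadge` and the `expires_days` payload in `activateUser` could be pulled into pure helpers so they can be checked outside React. This is only a sketch; `UserStatus`, `getUserStatus`, and `buildActivatePayload` are hypothetical names that do not exist in this change.

```typescript
// Hypothetical helper names; the precedence mirrors getRoleBadge in the admin page.
type UserStatus = 'admin' | 'pending' | 'disabled' | 'active';

function getUserStatus(role: string, isActive: boolean): UserStatus {
  if (role === 'admin') return 'admin';
  if (role === 'pending') return 'pending';
  if (!isActive) return 'disabled';
  return 'active';
}

// Mirrors `expires_days: expireDays || null`: 0 (or NaN) means no expiry.
function buildActivatePayload(expireDays: number): { expires_days: number | null } {
  return { expires_days: expireDays || null };
}
```

With helpers like these, the component would only map a `UserStatus` to a badge style, and the "0 means permanent" rule would live in one place.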

View File

@@ -19,8 +19,73 @@
}
}
/* iOS Safari safe-area support + hidden scrollbar */
html {
background-color: #0f172a !important;
min-height: 100%;
scrollbar-width: none;
-ms-overflow-style: none;
}
html::-webkit-scrollbar {
display: none;
}
body {
background: var(--background);
margin: 0 !important;
min-height: 100dvh;
color: var(--foreground);
font-family: Arial, Helvetica, sans-serif;
padding-top: env(safe-area-inset-top);
padding-bottom: env(safe-area-inset-bottom);
}
/* Custom scrollbar styles - dark theme */
.custom-scrollbar {
scrollbar-width: thin;
scrollbar-color: rgba(147, 51, 234, 0.5) transparent;
}
.custom-scrollbar::-webkit-scrollbar {
width: 6px;
}
.custom-scrollbar::-webkit-scrollbar-track {
background: transparent;
}
.custom-scrollbar::-webkit-scrollbar-thumb {
background: rgba(147, 51, 234, 0.5);
border-radius: 3px;
}
.custom-scrollbar::-webkit-scrollbar-thumb:hover {
background: rgba(147, 51, 234, 0.8);
}
/* Hide the scrollbar completely */
.hide-scrollbar {
scrollbar-width: none;
-ms-overflow-style: none;
}
.hide-scrollbar::-webkit-scrollbar {
display: none;
}
/* Custom select dropdown */
.custom-select {
appearance: none;
-webkit-appearance: none;
-moz-appearance: none;
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' fill='%239ca3af' viewBox='0 0 16 16'%3E%3Cpath d='M8 11L3 6h10l-5 5z'/%3E%3C/svg%3E");
background-repeat: no-repeat;
background-position: right 12px center;
padding-right: 36px;
}
.custom-select option {
background: #1a1a2e;
color: white;
padding: 12px;
}

View File

@@ -1,4 +1,4 @@
import type { Metadata } from "next";
import type { Metadata, Viewport } from "next";
import { Geist, Geist_Mono } from "next/font/google";
import "./globals.css";
@@ -13,8 +13,15 @@ const geistMono = Geist_Mono({
});
export const metadata: Metadata = {
title: "Create Next App",
description: "Generated by create next app",
title: "ViGent",
description: "ViGent Talking Head Agent",
};
export const viewport: Viewport = {
width: 'device-width',
initialScale: 1,
viewportFit: 'cover',
themeColor: '#0f172a',
};
export default function RootLayout({
@@ -23,9 +30,14 @@ export default function RootLayout({
children: React.ReactNode;
}>) {
return (
<html lang="en">
<html lang="en" style={{ backgroundColor: '#0f172a' }}>
<body
className={`${geistSans.variable} ${geistMono.variable} antialiased`}
style={{
margin: 0,
minHeight: '100dvh',
background: 'linear-gradient(to bottom, #0f172a 0%, #0f172a 5%, #581c87 50%, #0f172a 95%, #0f172a 100%)',
}}
>
{children}
</body>

View File

@@ -0,0 +1,101 @@
'use client';
import { useState } from 'react';
import { useRouter } from 'next/navigation';
import { login } from '@/lib/auth';
export default function LoginPage() {
const router = useRouter();
const [email, setEmail] = useState('');
const [password, setPassword] = useState('');
const [error, setError] = useState('');
const [loading, setLoading] = useState(false);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setError('');
setLoading(true);
try {
const result = await login(email, password);
if (result.success) {
router.push('/');
} else {
setError(result.message || 'Login failed');
}
} catch (err) {
setError('Network error, please try again later');
} finally {
setLoading(false);
}
};
return (
<div className="min-h-dvh flex items-center justify-center">
<div className="w-full max-w-md p-8 bg-white/10 backdrop-blur-lg rounded-2xl shadow-2xl border border-white/20">
<div className="text-center mb-8">
<h1 className="text-3xl font-bold text-white mb-2">ViGent</h1>
<p className="text-gray-300">AI </p>
</div>
<form onSubmit={handleSubmit} className="space-y-6">
<div>
<label className="block text-sm font-medium text-gray-200 mb-2">
</label>
<input
type="email"
value={email}
onChange={(e) => setEmail(e.target.value)}
required
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500 focus:border-transparent"
placeholder="your@email.com"
/>
</div>
<div>
<label className="block text-sm font-medium text-gray-200 mb-2">
</label>
<input
type="password"
value={password}
onChange={(e) => setPassword(e.target.value)}
required
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500 focus:border-transparent"
placeholder="••••••••"
/>
</div>
{error && (
<div className="p-3 bg-red-500/20 border border-red-500/50 rounded-lg text-red-200 text-sm">
{error}
</div>
)}
<button
type="submit"
disabled={loading}
className="w-full py-3 px-4 bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white font-semibold rounded-lg shadow-lg transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed"
>
{loading ? (
<span className="flex items-center justify-center">
<svg className="animate-spin -ml-1 mr-3 h-5 w-5 text-white" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
</svg>
...
</span>
) : 'Log in'}
</button>
</form>
<div className="mt-6 text-center">
<a href="/register" className="text-purple-300 hover:text-purple-200 text-sm">
</a>
</div>
</div>
</div>
);
}

View File

@@ -2,11 +2,12 @@
"use client";
import { useState, useEffect } from "react";
import Link from "next/link";
import api from "@/lib/axios";
// Resolve the API base dynamically: localhost on the server, the current origin on the client
const API_BASE = typeof window !== 'undefined'
? `http://${window.location.hostname}:8006`
: 'http://localhost:8006';
const API_BASE = typeof window === 'undefined'
? 'http://localhost:8006'
: '';
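Review note: the new `API_BASE` resolves to an empty string in the browser, so client requests become same-origin relative URLs. That presumably relies on a reverse proxy or rewrite forwarding `/api` to the backend, which is an assumption not visible in this diff. A minimal sketch of the same rule as testable functions; `resolveApiBase` and `buildDownloadUrl` are hypothetical names.

```typescript
// Hypothetical helpers; port 8006 is taken from the diff above.
function resolveApiBase(isServer: boolean): string {
  // On the server there is no current origin, so call the backend directly.
  // In the browser an empty base yields same-origin relative URLs.
  return isServer ? 'http://localhost:8006' : '';
}

// Joins the base with a server-provided path such as a task's download_url.
function buildDownloadUrl(apiBase: string, downloadUrl: string): string {
  return `${apiBase}${downloadUrl}`;
}
```

In the component this would be called as `resolveApiBase(typeof window === 'undefined')`.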
// Type definitions
interface Material {
@@ -25,6 +26,25 @@ interface Task {
download_url?: string;
}
interface GeneratedVideo {
id: string;
name: string;
path: string;
size_mb: number;
created_at: number;
}
// Format a date (avoids hydration errors)
const formatDate = (timestamp: number) => {
const d = new Date(timestamp * 1000);
const year = d.getFullYear();
const month = String(d.getMonth() + 1).padStart(2, '0');
const day = String(d.getDate()).padStart(2, '0');
const hour = String(d.getHours()).padStart(2, '0');
const minute = String(d.getMinutes()).padStart(2, '0');
return `${year}/${month}/${day} ${hour}:${minute}`;
};
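Review note: `formatDate` avoids locale-dependent formatting, but it still reads local-time fields, so server and client output match only when both run in the same time zone. If the function is ever called during server rendering, a UTC variant would be fully deterministic. The sketch below is one option, not the author's implementation; `formatDateUtc` is a hypothetical name.

```typescript
// Hypothetical UTC variant of formatDate; not part of this change.
function formatDateUtc(timestamp: number): string {
  // The API returns seconds, while Date expects milliseconds.
  const d = new Date(timestamp * 1000);
  const pad = (n: number): string => String(n).padStart(2, '0');
  return (
    `${d.getUTCFullYear()}/${pad(d.getUTCMonth() + 1)}/${pad(d.getUTCDate())} ` +
    `${pad(d.getUTCHours())}:${pad(d.getUTCMinutes())}`
  );
}
```

The trade-off is that users see UTC rather than local time, so this only fits if the timestamp is rendered on the server.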
export default function Home() {
const [materials, setMaterials] = useState<Material[]>([]);
const [selectedMaterial, setSelectedMaterial] = useState<string>("");
@@ -40,6 +60,10 @@ export default function Home() {
const [isUploading, setIsUploading] = useState(false);
const [uploadProgress, setUploadProgress] = useState(0);
const [uploadError, setUploadError] = useState<string | null>(null);
const [uploadData, setUploadData] = useState<string>("");
const [generatedVideos, setGeneratedVideos] = useState<GeneratedVideo[]>([]);
const [selectedVideoId, setSelectedVideoId] = useState<string | null>(null);
// Available voices
const voices = [
@@ -50,9 +74,10 @@ export default function Home() {
{ id: "zh-CN-XiaoyiNeural", name: "Xiaoyi (female, gentle)" },
];
// Load the material list
// Load the material list and previously generated videos
useEffect(() => {
fetchMaterials();
fetchGeneratedVideos();
}, []);
const fetchMaterials = async () => {
@@ -60,18 +85,8 @@ export default function Home() {
setFetchError(null);
setDebugData("Loading...");
// Add timestamp to prevent caching
const url = `${API_BASE}/api/materials/?t=${new Date().getTime()}`;
const res = await fetch(url);
if (!res.ok) {
throw new Error(`HTTP ${res.status} ${res.statusText}`);
}
const text = await res.text(); // Get raw text first
setDebugData(text.substring(0, 200) + (text.length > 200 ? "..." : "")); // Show preview
const data = JSON.parse(text);
const { data } = await api.get(`/api/materials?t=${new Date().getTime()}`);
setDebugData(JSON.stringify(data).substring(0, 200));
setMaterials(data.materials || []);
if (data.materials?.length > 0) {
@@ -86,7 +101,46 @@ export default function Home() {
}
};
// Upload a video
// Fetch the list of generated videos (persisted)
const fetchGeneratedVideos = async () => {
try {
const { data } = await api.get('/api/videos/generated');
setGeneratedVideos(data.videos || []);
} catch (error) {
console.error("Failed to fetch generated videos:", error);
}
};
// Delete a material
const deleteMaterial = async (materialId: string) => {
if (!confirm("Are you sure you want to delete this material?")) return;
try {
await api.delete(`/api/materials/${materialId}`);
fetchMaterials();
if (selectedMaterial === materialId) {
setSelectedMaterial("");
}
} catch (error) {
alert("Delete failed: " + error);
}
};
// Delete a generated video
const deleteVideo = async (videoId: string) => {
if (!confirm("Are you sure you want to delete this video?")) return;
try {
await api.delete(`/api/videos/generated/${videoId}`);
fetchGeneratedVideos();
if (selectedVideoId === videoId) {
setSelectedVideoId(null);
setGeneratedVideo(null);
}
} catch (error) {
alert("Delete failed: " + error);
}
};
// Upload a video - uses axios for upload progress reporting
const handleUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
const file = e.target.files?.[0];
if (!file) return;
@@ -103,41 +157,37 @@ export default function Home() {
setUploadProgress(0);
setUploadError(null);
const formData = new FormData();
formData.append('file', file);
try {
const formData = new FormData();
formData.append('file', file);
// Use XMLHttpRequest to get upload progress
const xhr = new XMLHttpRequest();
await api.post('/api/materials', formData, {
headers: { 'Content-Type': 'multipart/form-data' },
onUploadProgress: (progressEvent) => {
if (progressEvent.total) {
const progress = Math.round((progressEvent.loaded / progressEvent.total) * 100);
setUploadProgress(progress);
}
},
});
xhr.upload.onprogress = (event) => {
if (event.lengthComputable) {
const progress = Math.round((event.loaded / event.total) * 100);
setUploadProgress(progress);
}
};
xhr.onload = () => {
setUploadProgress(100);
setIsUploading(false);
if (xhr.status >= 200 && xhr.status < 300) {
fetchMaterials(); // Refresh the material list
setUploadProgress(100);
} else {
setUploadError(`Upload failed: ${xhr.statusText}`);
}
};
xhr.onerror = () => {
fetchMaterials();
setUploadData("");
} catch (err: any) {
console.error("Upload failed:", err);
setIsUploading(false);
setUploadError('Network error, upload failed');
};
xhr.open('POST', `${API_BASE}/api/materials/`);
xhr.send(formData);
const errorMsg = err.response?.data?.detail || err.message || String(err);
setUploadError(`Upload failed: ${errorMsg}`);
}
// Clear the input so the same file can be selected again
e.target.value = '';
};
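Review note: the percentage calculation and the error-message fallback chain in `handleUpload` are easy to isolate. The sketch below assumes the progress event exposes `loaded` and an optional `total`, as the handler above does, and that the backend reports errors under `detail`. `ProgressLike`, `ApiErrorLike`, `toPercent`, and `extractErrorMessage` are hypothetical names.

```typescript
// Hypothetical helpers extracted from handleUpload; the event shape is a
// minimal subset of what the upload progress callback receives.
interface ProgressLike {
  loaded: number;
  total?: number;
}

function toPercent(event: ProgressLike): number | null {
  // total may be missing when the request length is not computable.
  if (!event.total) return null;
  return Math.round((event.loaded / event.total) * 100);
}

interface ApiErrorLike {
  response?: { data?: { detail?: string } };
  message?: string;
}

// Prefer the backend's detail field, then the client-side error message.
function extractErrorMessage(err: ApiErrorLike): string {
  return err.response?.data?.detail || err.message || String(err);
}
```

Typing the error this way would also let the `catch (err: any)` clause be narrowed instead of using `any`.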
// Generate a video
const handleGenerate = async () => {
if (!selectedMaterial || !text.trim()) {
@@ -157,34 +207,34 @@ export default function Home() {
}
// Create the generation task
const res = await fetch(`${API_BASE}/api/videos/generate`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
material_path: materialObj.path,
text: text,
voice: voice,
add_subtitle: true,
}),
const { data } = await api.post('/api/videos/generate', {
material_path: materialObj.path,
text: text,
voice: voice,
add_subtitle: true,
});
const data = await res.json();
const taskId = data.task_id;
// Poll the task status
const pollTask = async () => {
const taskRes = await fetch(`${API_BASE}/api/videos/tasks/${taskId}`);
const taskData: Task = await taskRes.json();
setCurrentTask(taskData);
try {
const { data: taskData } = await api.get(`/api/videos/tasks/${taskId}`);
setCurrentTask(taskData);
if (taskData.status === "completed") {
setGeneratedVideo(`${API_BASE}${taskData.download_url}`);
if (taskData.status === "completed") {
setGeneratedVideo(`${API_BASE}${taskData.download_url}`);
setIsGenerating(false);
fetchGeneratedVideos(); // Refresh the generated video list
} else if (taskData.status === "failed") {
alert("Video generation failed: " + taskData.message);
setIsGenerating(false);
} else {
setTimeout(pollTask, 1000);
}
} catch (error) {
console.error("Task polling failed:", error);
setIsGenerating(false);
} else if (taskData.status === "failed") {
alert("Video generation failed: " + taskData.message);
setIsGenerating(false);
} else {
setTimeout(pollTask, 1000);
}
};
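Review note: `pollTask` retries every 1000 ms for as long as the task is neither completed nor failed. For long-running generation jobs, a capped backoff may reduce load on the backend. The sketch below is a suggestion under that assumption, not what this change implements; `PollAction`, `nextPollAction`, `nextDelayMs`, and the 10-second cap are hypothetical.

```typescript
type PollAction = 'done' | 'failed' | 'retry';

// Maps the task status strings used in pollTask to the next step.
function nextPollAction(status: string): PollAction {
  if (status === 'completed') return 'done';
  if (status === 'failed') return 'failed';
  return 'retry';
}

// Capped exponential backoff: 1s, 2s, 4s, 8s, then 10s for every later attempt.
function nextDelayMs(attempt: number, baseMs = 1000, capMs = 10000): number {
  return Math.min(capMs, baseMs * 2 ** attempt);
}
```

Inside `pollTask`, the retry branch could then call `setTimeout(pollTask, nextDelayMs(attempt))` with an attempt counter kept in the enclosing scope.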
@@ -196,14 +246,56 @@ export default function Home() {
};
return (
<div className="min-h-screen bg-gradient-to-br from-slate-900 via-purple-900 to-slate-900">
{/* Header */}
<div className="min-h-dvh">
{/* Header <header className="border-b border-white/10 bg-black/20 backdrop-blur-sm">
<div className="max-w-6xl mx-auto px-6 py-4 flex items-center justify-between">
<h1 className="text-2xl font-bold text-white flex items-center gap-3">
<span className="text-4xl">🎬</span>
ViGent
</h1>
<div className="flex items-center gap-4">
<span className="px-4 py-2 bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg font-semibold">
视频生成
</span>
<Link
href="/publish"
className="px-4 py-2 bg-white/10 hover:bg-white/20 text-white rounded-lg transition-colors"
>
发布管理
</Link>
</div>
</div>
</header> */}
<header className="border-b border-white/10 bg-black/20 backdrop-blur-sm">
<div className="max-w-6xl mx-auto px-6 py-4 flex items-center justify-between">
<h1 className="text-2xl font-bold text-white flex items-center gap-3">
<span className="text-3xl">🎬</span>
<div className="max-w-6xl mx-auto px-4 sm:px-6 py-3 sm:py-4 flex items-center justify-between">
<Link href="/" className="text-xl sm:text-2xl font-bold text-white flex items-center gap-2 sm:gap-3 hover:opacity-80 transition-opacity">
<span className="text-3xl sm:text-4xl">🎬</span>
ViGent
</h1>
</Link>
<div className="flex items-center gap-1 sm:gap-4">
<span className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg font-semibold">
</span>
<Link
href="/publish"
className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-white/10 hover:bg-white/20 text-white rounded-lg transition-colors"
>
</Link>
<button
onClick={async () => {
if (confirm('确定要退出登录吗?')) {
try {
await api.post('/api/auth/logout');
} catch (e) { }
window.location.href = '/login';
}
}}
className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-red-500/10 hover:bg-red-500/20 text-red-200 rounded-lg transition-colors"
>
退
</button>
</div>
</div>
</header>
@@ -212,12 +304,12 @@ export default function Home() {
{/* Left: input area */}
<div className="space-y-6">
{/* Material selection */}
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
<div className="flex justify-between items-center mb-4">
<h2 className="text-lg font-semibold text-white flex items-center gap-2">
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
<div className="flex justify-between items-center gap-2 mb-4">
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2 whitespace-nowrap">
📹
</h2>
<div className="flex gap-2">
<div className="flex gap-1.5">
{/* Hidden file input */}
<input
type="file"
@@ -228,16 +320,16 @@ export default function Home() {
/>
<label
htmlFor="video-upload"
className={`px-3 py-1 text-xs rounded cursor-pointer transition-all ${isUploading
className={`px-2 py-1 text-xs rounded cursor-pointer transition-all whitespace-nowrap ${isUploading
? "bg-gray-600 cursor-not-allowed text-gray-400"
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white"
}`}
>
📤
📤
</label>
<button
onClick={fetchMaterials}
className="px-3 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300"
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap"
>
🔄
</button>
@@ -290,21 +382,35 @@ export default function Home() {
) : (
<div className="grid grid-cols-2 gap-3">
{materials.map((m) => (
<button
<div
key={m.id}
onClick={() => setSelectedMaterial(m.id)}
className={`p-4 rounded-xl border-2 transition-all text-left ${selectedMaterial === m.id
className={`p-4 rounded-xl border-2 transition-all text-left relative group ${selectedMaterial === m.id
? "border-purple-500 bg-purple-500/20"
: "border-white/10 bg-white/5 hover:border-white/30"
}`}
>
<div className="text-white font-medium truncate">
{m.scene || m.name}
</div>
<div className="text-gray-400 text-sm mt-1">
{m.size_mb.toFixed(1)} MB
</div>
</button>
<button
onClick={() => setSelectedMaterial(m.id)}
className="w-full text-left"
>
<div className="text-white font-medium truncate pr-6">
{m.scene || m.name}
</div>
<div className="text-gray-400 text-sm mt-1">
{m.size_mb.toFixed(1)} MB
</div>
</button>
<button
onClick={(e) => {
e.stopPropagation();
deleteMaterial(m.id);
}}
className="absolute top-2 right-2 p-1 text-gray-500 hover:text-red-400 opacity-0 group-hover:opacity-100 transition-opacity"
title="删除素材"
>
🗑
</button>
</div>
))}
</div>
)}
@@ -424,25 +530,83 @@ export default function Home() {
</div>
{generatedVideo && (
<a
href={generatedVideo}
download
className="mt-4 w-full py-3 rounded-xl bg-green-600 hover:bg-green-700 text-white font-medium flex items-center justify-center gap-2 transition-colors"
<>
<a
href={generatedVideo}
download
className="mt-4 w-full py-3 rounded-xl bg-green-600 hover:bg-green-700 text-white font-medium flex items-center justify-center gap-2 transition-colors"
>
</a>
<Link
href="/publish"
className="mt-3 w-full py-3 rounded-xl bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white font-medium flex items-center justify-center gap-2 transition-colors"
>
📤
</Link>
</>
)}
</div>
{/* Video history list */}
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
<div className="flex justify-between items-center mb-4">
<h2 className="text-lg font-semibold text-white flex items-center gap-2">
📂
</h2>
<button
onClick={fetchGeneratedVideos}
className="px-3 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300"
>
</a>
🔄
</button>
</div>
{generatedVideos.length === 0 ? (
<div className="text-center py-4 text-gray-500">
<p></p>
</div>
) : (
<div className="space-y-2 max-h-64 overflow-y-auto hide-scrollbar">
{generatedVideos.map((v) => (
<div
key={v.id}
className={`p-3 rounded-lg border transition-all flex items-center justify-between group ${selectedVideoId === v.id
? "border-purple-500 bg-purple-500/20"
: "border-white/10 bg-white/5 hover:border-white/30"
}`}
>
<button
onClick={() => {
setSelectedVideoId(v.id);
setGeneratedVideo(`${API_BASE}${v.path}`);
}}
className="flex-1 text-left"
>
<div className="text-white text-sm truncate">
{formatDate(v.created_at)}
</div>
<div className="text-gray-400 text-xs">
{v.size_mb.toFixed(1)} MB
</div>
</button>
<button
onClick={(e) => {
e.stopPropagation();
deleteVideo(v.id);
}}
className="p-1 text-gray-500 hover:text-red-400 opacity-0 group-hover:opacity-100 transition-opacity"
title="删除视频"
>
🗑
</button>
</div>
))}
</div>
)}
</div>
</div>
</div>
</main>
{/* Footer */}
<footer className="border-t border-white/10 mt-12">
<div className="max-w-6xl mx-auto px-6 py-4 text-center text-gray-500 text-sm">
ViGent - MuseTalk + EdgeTTS
</div>
</footer>
</div>
);
}


@@ -1,12 +1,28 @@
"use client";
import { useState, useEffect } from "react";
import useSWR from 'swr';
import Link from "next/link";
import api from "@/lib/axios";
// SWR fetcher uses axios, which handles 401/403 automatically
const fetcher = (url: string) => api.get(url).then((res) => res.data);
// Resolve the API base dynamically: localhost on the server, the current host on the client
const API_BASE = typeof window !== 'undefined'
? `http://${window.location.hostname}:8006`
: 'http://localhost:8006';
const API_BASE = typeof window === 'undefined'
? 'http://localhost:8006'
: '';
// Format the date (avoids hydration errors)
const formatDate = (timestamp: number) => {
const d = new Date(timestamp * 1000);
const year = d.getFullYear();
const month = String(d.getMonth() + 1).padStart(2, '0');
const day = String(d.getDate()).padStart(2, '0');
const hour = String(d.getHours()).padStart(2, '0');
const minute = String(d.getMinutes()).padStart(2, '0');
return `${year}/${month}/${day} ${hour}:${minute}`;
};
interface Account {
platform: string;
@@ -29,6 +45,11 @@ export default function PublishPage() {
const [tags, setTags] = useState<string>("");
const [isPublishing, setIsPublishing] = useState(false);
const [publishResults, setPublishResults] = useState<any[]>([]);
const [scheduleMode, setScheduleMode] = useState<"now" | "scheduled">("now");
const [publishTime, setPublishTime] = useState<string>("");
const [qrCodeImage, setQrCodeImage] = useState<string | null>(null);
const [qrPlatform, setQrPlatform] = useState<string | null>(null);
const [isLoadingQR, setIsLoadingQR] = useState(false);
// Load the account and video lists
useEffect(() => {
@@ -38,8 +59,7 @@ export default function PublishPage() {
const fetchAccounts = async () => {
try {
const res = await fetch(`${API_BASE}/api/publish/accounts`);
const data = await res.json();
const { data } = await api.get('/api/publish/accounts');
setAccounts(data.accounts || []);
} catch (error) {
console.error("获取账号失败:", error);
@@ -48,20 +68,16 @@ export default function PublishPage() {
const fetchVideos = async () => {
try {
// Fetch the list of generated videos (from the outputs directory)
const res = await fetch(`${API_BASE}/api/videos/tasks`);
const data = await res.json();
const { data } = await api.get('/api/videos/generated');
const completedVideos = data.tasks
?.filter((t: any) => t.status === "completed")
.map((t: any) => ({
name: `${t.task_id}_output.mp4`,
path: `outputs/${t.task_id}_output.mp4`,
})) || [];
const videos = (data.videos || []).map((v: any) => ({
name: formatDate(v.created_at) + ` (${v.size_mb.toFixed(1)}MB)`,
path: v.path.startsWith('/') ? v.path.slice(1) : v.path,
}));
setVideos(completedVideos);
if (completedVideos.length > 0) {
setSelectedVideo(completedVideos[0].path);
setVideos(videos);
if (videos.length > 0) {
setSelectedVideo(videos[0].path);
}
} catch (error) {
console.error("获取视频失败:", error);
@@ -89,24 +105,29 @@ export default function PublishPage() {
for (const platform of selectedPlatforms) {
try {
const res = await fetch(`${API_BASE}/api/publish/`, {
method: "POST",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({
video_path: selectedVideo,
platform,
title,
tags: tagList,
description: "",
}),
const { data: result } = await api.post('/api/publish', {
video_path: selectedVideo,
platform,
title,
tags: tagList,
description: "",
publish_time: scheduleMode === "scheduled" && publishTime
? new Date(publishTime).toISOString()
: null
});
const result = await res.json();
setPublishResults((prev) => [...prev, result]);
} catch (error) {
// Clear the result automatically 10 seconds after a successful publish
if (result.success) {
setTimeout(() => {
setPublishResults((prev) => prev.filter((r) => r !== result));
}, 10000);
}
} catch (error: any) {
const message = error.response?.data?.detail || String(error);
setPublishResults((prev) => [
...prev,
{ platform, success: false, message: String(error) },
{ platform, success: false, message },
]);
}
}
@@ -114,10 +135,72 @@ export default function PublishPage() {
setIsPublishing(false);
};
// SWR Polling for Login Status
const { data: loginStatus } = useSWR(
qrPlatform ? `${API_BASE}/api/publish/login/status/${qrPlatform}` : null,
fetcher,
{
refreshInterval: 2000,
onSuccess: (data) => {
if (data.success) {
setQrCodeImage(null);
setQrPlatform(null);
alert('✅ 登录成功!');
fetchAccounts();
}
}
}
);
// Timeout logic for QR code (business logic: stop after 2 mins)
useEffect(() => {
let timer: NodeJS.Timeout;
if (qrPlatform) {
timer = setTimeout(() => {
if (qrPlatform) { // Double check active
setQrPlatform(null);
setQrCodeImage(null);
alert('登录超时,请重试');
}
}, 120000);
}
return () => clearTimeout(timer);
}, [qrPlatform]);
const handleLogin = async (platform: string) => {
alert(
`登录功能需要在服务端执行。\n\n请在终端运行:\ncurl -X POST http://localhost:8006/api/publish/login/${platform}`
);
setIsLoadingQR(true);
setQrPlatform(platform); // Show the loading dialog immediately
setQrCodeImage(null); // Clear the previous QR code
try {
const { data: result } = await api.post(`/api/publish/login/${platform}`);
if (result.success && result.qr_code) {
setQrCodeImage(result.qr_code);
} else {
setQrPlatform(null);
alert(result.message || '登录失败');
}
} catch (error: any) {
setQrPlatform(null);
alert(`登录失败: ${error.response?.data?.detail || error.message}`);
} finally {
setIsLoadingQR(false);
}
};
const handleLogout = async (platform: string) => {
if (!confirm('确定要注销登录吗?')) return;
try {
const { data: result } = await api.post(`/api/publish/logout/${platform}`);
if (result.success) {
alert('已注销');
fetchAccounts();
} else {
alert(result.message || '注销失败');
}
} catch (error: any) {
alert(`注销失败: ${error.response?.data?.detail || error.message}`);
}
};
const platformIcons: Record<string, string> = {
@@ -129,34 +212,74 @@ export default function PublishPage() {
};
return (
<div className="min-h-screen bg-gradient-to-br from-slate-900 via-purple-900 to-slate-900">
{/* Header */}
<div className="min-h-dvh">
{/* QR code dialog */}
{qrPlatform && (
<div className="fixed inset-0 bg-black/80 flex items-center justify-center z-50">
<div className="bg-white rounded-2xl p-8 max-w-md min-w-[320px]">
<h2 className="text-2xl font-bold mb-4 text-center">🔐 {qrPlatform}</h2>
{isLoadingQR ? (
<div className="flex flex-col items-center py-8">
<div className="animate-spin w-16 h-16 border-4 border-purple-500 border-t-transparent rounded-full" />
<p className="text-gray-600 mt-4">...</p>
</div>
) : qrCodeImage ? (
<>
<img
src={`data:image/png;base64,${qrCodeImage}`}
alt="QR Code"
className="w-full h-auto"
/>
<p className="text-center text-gray-600 mt-4">
使
</p>
</>
) : null}
<button
onClick={() => { setQrCodeImage(null); setQrPlatform(null); }}
className="w-full mt-4 px-4 py-2 bg-gray-200 rounded-lg hover:bg-gray-300"
>
</button>
</div>
</div>
)}
{/* Header - unified styling */}
<header className="border-b border-white/10 bg-black/20 backdrop-blur-sm">
<div className="max-w-6xl mx-auto px-6 py-4 flex items-center justify-between">
<Link href="/" className="text-2xl font-bold text-white flex items-center gap-3 hover:opacity-80">
<span className="text-3xl">🎬</span>
TalkingHead Agent
<div className="max-w-6xl mx-auto px-4 sm:px-6 py-3 sm:py-4 flex items-center justify-between">
<Link href="/" className="text-xl sm:text-2xl font-bold text-white flex items-center gap-2 sm:gap-3 hover:opacity-80 transition-opacity">
<span className="text-3xl sm:text-4xl">🎬</span>
ViGent
</Link>
<nav className="flex gap-4">
<div className="flex items-center gap-1 sm:gap-4">
<Link
href="/"
className="px-4 py-2 text-gray-400 hover:text-white transition-colors"
className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-white/10 hover:bg-white/20 text-white rounded-lg transition-colors"
>
</Link>
<Link
href="/publish"
className="px-4 py-2 text-white bg-purple-600 rounded-lg"
>
<span className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg font-semibold">
</Link>
</nav>
</span>
<button
onClick={async () => {
if (confirm('确定要退出登录吗?')) {
try {
await api.post('/api/auth/logout');
} catch (e) { }
window.location.href = '/login';
}
}}
className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-red-500/10 hover:bg-red-500/20 text-red-200 rounded-lg transition-colors"
>
退
</button>
</div>
</div>
</header>
<main className="max-w-6xl mx-auto px-6 py-8">
<h1 className="text-3xl font-bold text-white mb-8">📤 </h1>
<div className="grid grid-cols-1 lg:grid-cols-2 gap-8">
{/* Left: account management */}
<div className="space-y-6">
@@ -189,15 +312,31 @@ export default function PublishPage() {
</div>
</div>
</div>
<button
onClick={() => handleLogin(account.platform)}
className={`px-4 py-2 rounded-lg text-sm font-medium transition-colors ${account.logged_in
? "bg-gray-600 text-gray-300"
: "bg-purple-600 hover:bg-purple-700 text-white"
}`}
>
{account.logged_in ? "重新登录" : "登录"}
</button>
<div className="flex gap-2">
{account.logged_in ? (
<>
<button
onClick={() => handleLogin(account.platform)}
className="px-3 py-1 bg-white/10 hover:bg-white/20 text-white text-sm rounded-lg transition-colors"
>
</button>
<button
onClick={() => handleLogout(account.platform)}
className="px-3 py-1 bg-red-500/80 hover:bg-red-600 text-white text-sm rounded-lg transition-colors"
>
</button>
</>
) : (
<button
onClick={() => handleLogin(account.platform)}
className="px-3 py-1 bg-purple-600 hover:bg-purple-700 text-white text-sm rounded-lg transition-colors"
>
🔐
</button>
)}
</div>
</div>
))}
</div>
@@ -223,7 +362,7 @@ export default function PublishPage() {
<select
value={selectedVideo}
onChange={(e) => setSelectedVideo(e.target.value)}
className="w-full p-3 bg-black/30 border border-white/10 rounded-xl text-white"
className="w-full p-3 bg-black/30 border border-white/10 rounded-xl text-white custom-select cursor-pointer hover:border-purple-500/50 transition-colors"
>
{videos.map((v) => (
<option key={v.path} value={v.path}>
@@ -297,17 +436,61 @@ export default function PublishPage() {
)}
</div>
{/* Publish button */}
<button
onClick={handlePublish}
disabled={isPublishing || selectedPlatforms.length === 0}
className={`w-full py-4 rounded-xl font-bold text-lg transition-all ${isPublishing || selectedPlatforms.length === 0
? "bg-gray-600 cursor-not-allowed text-gray-400"
: "bg-gradient-to-r from-green-600 to-teal-600 hover:from-green-700 hover:to-teal-700 text-white"
}`}
>
{isPublishing ? "发布中..." : "🚀 一键发布"}
</button>
{/* Publish button area */}
<div className="space-y-3">
<div className="flex gap-3">
{/* Publish now - takes 3/4 of the width */}
<button
onClick={() => {
setScheduleMode("now");
handlePublish();
}}
disabled={isPublishing || selectedPlatforms.length === 0}
className={`flex-[3] py-4 rounded-xl font-bold text-lg transition-all ${isPublishing || selectedPlatforms.length === 0
? "bg-gray-600 cursor-not-allowed text-gray-400"
: "bg-gradient-to-r from-green-600 to-teal-600 hover:from-green-700 hover:to-teal-700 text-white"
}`}
>
{isPublishing && scheduleMode === "now" ? "发布中..." : "🚀 立即发布"}
</button>
{/* Scheduled publish - takes 1/4 of the width */}
<button
onClick={() => setScheduleMode(scheduleMode === "scheduled" ? "now" : "scheduled")}
disabled={isPublishing || selectedPlatforms.length === 0}
className={`flex-1 py-4 rounded-xl font-bold text-base transition-all ${isPublishing || selectedPlatforms.length === 0
? "bg-gray-600 cursor-not-allowed text-gray-400"
: scheduleMode === "scheduled"
? "bg-purple-600 text-white"
: "bg-white/10 hover:bg-white/20 text-white"
}`}
>
</button>
</div>
{/* Scheduled publish time picker */}
{scheduleMode === "scheduled" && (
<div className="flex gap-3 items-center">
<input
type="datetime-local"
value={publishTime}
onChange={(e) => setPublishTime(e.target.value)}
min={new Date().toISOString().slice(0, 16)}
className="flex-1 p-3 bg-black/30 border border-white/10 rounded-xl text-white"
/>
<button
onClick={handlePublish}
disabled={isPublishing || selectedPlatforms.length === 0 || !publishTime}
className={`px-6 py-3 rounded-xl font-bold transition-all ${isPublishing || selectedPlatforms.length === 0 || !publishTime
? "bg-gray-600 cursor-not-allowed text-gray-400"
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white"
}`}
>
{isPublishing && scheduleMode === "scheduled" ? "设置中..." : "确认定时"}
</button>
</div>
)}
</div>
{/* Publish results */}
{publishResults.length > 0 && (
@@ -325,6 +508,11 @@ export default function PublishPage() {
<span className="text-white">
{platformIcons[result.platform]} {result.message}
</span>
{result.success && (
<p className="text-green-400/80 text-sm mt-1">
</p>
)}
</div>
))}
</div>
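The `datetime-local` input in this diff derives its `min` attribute from `new Date().toISOString().slice(0, 16)`. `toISOString()` always returns UTC, while a `datetime-local` value is interpreted in the user's local time zone, so the lower bound can be off by the local UTC offset. A small sketch of a local-time formatter that could be used instead follows; `toDatetimeLocal` is a hypothetical helper and is not part of the diff.

```typescript
// Hypothetical helper: formats a Date as "YYYY-MM-DDTHH:mm" using the
// local time zone, which is the format a datetime-local input expects.
function toDatetimeLocal(d: Date): string {
  const pad = (n: number): string => String(n).padStart(2, "0");
  const datePart = `${d.getFullYear()}-${pad(d.getMonth() + 1)}-${pad(d.getDate())}`;
  const timePart = `${pad(d.getHours())}:${pad(d.getMinutes())}`;
  return `${datePart}T${timePart}`;
}
```

Under that assumption the attribute would read `min={toDatetimeLocal(new Date())}`. The `new Date(publishTime).toISOString()` conversion in `handlePublish` can stay as it is, because parsing a `datetime-local` string already treats it as local time.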


@@ -0,0 +1,158 @@
'use client';
import { useState } from 'react';
import { useRouter } from 'next/navigation';
import { register } from '@/lib/auth';
export default function RegisterPage() {
const router = useRouter();
const [email, setEmail] = useState('');
const [password, setPassword] = useState('');
const [confirmPassword, setConfirmPassword] = useState('');
const [username, setUsername] = useState('');
const [error, setError] = useState('');
const [success, setSuccess] = useState(false);
const [loading, setLoading] = useState(false);
const handleSubmit = async (e: React.FormEvent) => {
e.preventDefault();
setError('');
if (password !== confirmPassword) {
setError('两次输入的密码不一致');
return;
}
if (password.length < 6) {
setError('密码长度至少 6 位');
return;
}
setLoading(true);
try {
const result = await register(email, password, username || undefined);
if (result.success) {
setSuccess(true);
} else {
setError(result.message || '注册失败');
}
} catch (err) {
setError('网络错误,请稍后重试');
} finally {
setLoading(false);
}
};
if (success) {
return (
<div className="min-h-dvh flex items-center justify-center">
<div className="w-full max-w-md p-8 bg-white/10 backdrop-blur-lg rounded-2xl shadow-2xl border border-white/20 text-center">
<div className="mb-6">
<svg className="w-16 h-16 mx-auto text-green-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
</svg>
</div>
<h2 className="text-2xl font-bold text-white mb-4"></h2>
<p className="text-gray-300 mb-6">
</p>
<a
href="/login"
className="inline-block py-3 px-6 bg-gradient-to-r from-purple-600 to-pink-600 text-white font-semibold rounded-lg"
>
</a>
</div>
</div>
);
}
return (
<div className="min-h-dvh flex items-center justify-center">
<div className="w-full max-w-md p-8 bg-white/10 backdrop-blur-lg rounded-2xl shadow-2xl border border-white/20">
<div className="text-center mb-8">
<h1 className="text-3xl font-bold text-white mb-2"></h1>
<p className="text-gray-300"> ViGent </p>
</div>
<form onSubmit={handleSubmit} className="space-y-5">
<div>
<label className="block text-sm font-medium text-gray-200 mb-2">
<span className="text-red-400">*</span>
</label>
<input
type="email"
value={email}
onChange={(e) => setEmail(e.target.value)}
required
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500"
placeholder="your@email.com"
/>
</div>
<div>
<label className="block text-sm font-medium text-gray-200 mb-2">
<span className="text-gray-500">()</span>
</label>
<input
type="text"
value={username}
onChange={(e) => setUsername(e.target.value)}
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500"
placeholder="您的昵称"
/>
</div>
<div>
<label className="block text-sm font-medium text-gray-200 mb-2">
<span className="text-red-400">*</span>
</label>
<input
type="password"
value={password}
onChange={(e) => setPassword(e.target.value)}
required
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500"
placeholder="至少 6 位"
/>
</div>
<div>
<label className="block text-sm font-medium text-gray-200 mb-2">
<span className="text-red-400">*</span>
</label>
<input
type="password"
value={confirmPassword}
onChange={(e) => setConfirmPassword(e.target.value)}
required
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500"
placeholder="再次输入密码"
/>
</div>
{error && (
<div className="p-3 bg-red-500/20 border border-red-500/50 rounded-lg text-red-200 text-sm">
{error}
</div>
)}
<button
type="submit"
disabled={loading}
className="w-full py-3 px-4 bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white font-semibold rounded-lg shadow-lg transition-all duration-200 disabled:opacity-50"
>
{loading ? '注册中...' : '注册'}
</button>
</form>
<div className="mt-6 text-center">
<a href="/login" className="text-purple-300 hover:text-purple-200 text-sm">
</a>
</div>
</div>
</div>
);
}

89
frontend/src/lib/auth.ts Normal file

@@ -0,0 +1,89 @@
/**
* Authentication helper functions
*/
const API_BASE = typeof window === 'undefined'
? (process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8006')
: '';
export interface User {
id: string;
email: string;
username: string | null;
role: string;
is_active: boolean;
}
export interface AuthResponse {
success: boolean;
message: string;
user?: User;
}
/**
* Register a user
*/
export async function register(email: string, password: string, username?: string): Promise<AuthResponse> {
const res = await fetch(`${API_BASE}/api/auth/register`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
credentials: 'include',
body: JSON.stringify({ email, password, username })
});
return res.json();
}
/**
* Log a user in
*/
export async function login(email: string, password: string): Promise<AuthResponse> {
const res = await fetch(`${API_BASE}/api/auth/login`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
credentials: 'include',
body: JSON.stringify({ email, password })
});
return res.json();
}
/**
* Log the user out
*/
export async function logout(): Promise<AuthResponse> {
const res = await fetch(`${API_BASE}/api/auth/logout`, {
method: 'POST',
credentials: 'include'
});
return res.json();
}
/**
* Get the current user
*/
export async function getCurrentUser(): Promise<User | null> {
try {
const res = await fetch(`${API_BASE}/api/auth/me`, {
credentials: 'include'
});
if (!res.ok) return null;
return res.json();
} catch {
return null;
}
}
/**
* Check whether the user is logged in
*/
export async function isAuthenticated(): Promise<boolean> {
const user = await getCurrentUser();
return user !== null;
}
/**
* Check whether the user is an admin
*/
export async function isAdmin(): Promise<boolean> {
const user = await getCurrentUser();
return user?.role === 'admin';
}
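`register`, `login`, and `logout` above return `res.json()` without checking `res.ok`, so an error body that lacks `success` and `message` reaches callers as-is. The sketch below shows one possible way to normalize the parsed body. `normalizeAuthBody` is a hypothetical helper, and the `detail` field is an assumption based on how other files in this diff read error bodies (`error.response?.data?.detail`).

```typescript
interface AuthResult {
  success: boolean;
  message: string;
}

// Hypothetical helper: maps the HTTP ok flag and the parsed JSON body to a
// { success, message } shape. Reading `detail` is an assumption, not a
// confirmed backend contract.
function normalizeAuthBody(ok: boolean, body: unknown): AuthResult {
  const obj = (typeof body === "object" && body !== null ? body : {}) as Record<string, unknown>;
  if (ok && typeof obj.success === "boolean") {
    const message = typeof obj.message === "string" ? obj.message : "";
    return { success: obj.success, message };
  }
  // Fall back to whichever error text the body carries, if any.
  const text =
    typeof obj.detail === "string" ? obj.detail :
    typeof obj.message === "string" ? obj.message :
    "Request failed";
  return { success: false, message: text };
}
```

A caller could then write `return normalizeAuthBody(res.ok, await res.json())`, which would let the registration page show the server's error text rather than its generic fallback.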

50
frontend/src/lib/axios.ts Normal file

@@ -0,0 +1,50 @@
/**
* Axios instance configuration
* Globally intercepts 401/403 responses and redirects to the login page
*/
import axios from 'axios';
// Resolve the API base dynamically: localhost on the server, the current host on the client
const API_BASE = typeof window === 'undefined'
? 'http://localhost:8006'
: '';
// Prevent repeated redirects
let isRedirecting = false;
// Create the axios instance
const api = axios.create({
baseURL: API_BASE,
withCredentials: true, // Send cookies automatically
headers: {
'Content-Type': 'application/json',
},
});
// Response interceptor - handles 401/403 globally
api.interceptors.response.use(
(response) => response,
async (error) => {
const status = error.response?.status;
if ((status === 401 || status === 403) && !isRedirecting) {
isRedirecting = true;
// Call the logout API to clear the HttpOnly cookie
try {
await fetch('/api/auth/logout', { method: 'POST' });
} catch (e) {
// Ignore errors
}
// Redirect to the login page
if (typeof window !== 'undefined') {
window.location.replace('/login');
}
}
return Promise.reject(error);
}
);
export default api;
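The interceptor above relies on the module-level `isRedirecting` flag so that several requests failing with 401/403 at the same time trigger only one redirect. A minimal sketch of that guard in isolation follows; `createAuthGuard` and the injected `redirect` callback are hypothetical names introduced here so the behaviour can be exercised without a browser.

```typescript
// Sketch of the "redirect once" guard. `redirect` stands in for
// window.location.replace('/login') and is injected for testability.
function createAuthGuard(redirect: () => void): (status: number | undefined) => boolean {
  let isRedirecting = false;
  return (status: number | undefined): boolean => {
    if ((status === 401 || status === 403) && !isRedirecting) {
      isRedirecting = true; // later auth failures are ignored
      redirect();
      return true;
    }
    return false;
  };
}
```

Note that the flag is never reset in this sketch or in the diff; that is harmless when the redirect is a full page load, since the module state is discarded with the page.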


@@ -0,0 +1,33 @@
import { NextResponse } from 'next/server';
import type { NextRequest } from 'next/server';
// Paths that require login
const protectedPaths = ['/', '/publish', '/admin'];
// Public paths (no login required)
const publicPaths = ['/login', '/register'];
export function middleware(request: NextRequest) {
const { pathname } = request.nextUrl;
// Check for the access_token cookie
const token = request.cookies.get('access_token');
// Unauthenticated access to a protected page → redirect to the login page
if (protectedPaths.some(path => pathname === path || pathname.startsWith(path + '/')) && !token) {
const loginUrl = new URL('/login', request.url);
loginUrl.searchParams.set('from', pathname);
return NextResponse.redirect(loginUrl);
}
// Logged-in user visiting the login/register page → redirect to the home page
if (publicPaths.includes(pathname) && token) {
return NextResponse.redirect(new URL('/', request.url));
}
return NextResponse.next();
}
export const config = {
matcher: ['/', '/publish/:path*', '/admin/:path*', '/login', '/register']
};

33
frontend/src/proxy.ts Normal file

@@ -0,0 +1,33 @@
import { NextResponse } from 'next/server';
import type { NextRequest } from 'next/server';
// Paths that require login
const protectedPaths = ['/', '/publish', '/admin'];
// Public paths (no login required)
const publicPaths = ['/login', '/register'];
export function proxy(request: NextRequest) {
const { pathname } = request.nextUrl;
// Check for the access_token cookie
const token = request.cookies.get('access_token');
// Unauthenticated access to a protected page → redirect to the login page
if (protectedPaths.some(path => pathname === path || pathname.startsWith(path + '/')) && !token) {
const loginUrl = new URL('/login', request.url);
loginUrl.searchParams.set('from', pathname);
return NextResponse.redirect(loginUrl);
}
// Logged-in user visiting the login/register page → redirect to the home page
if (publicPaths.includes(pathname) && token) {
return NextResponse.redirect(new URL('/', request.url));
}
return NextResponse.next();
}
export const config = {
matcher: ['/', '/publish/:path*', '/admin/:path*', '/login', '/register']
};
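The protected-path check in `proxy` (and in the identical `middleware`) matches a path either exactly or as a prefix followed by `/`. A standalone sketch of that predicate, using the same `protectedPaths` list, shows which URLs it covers; `isProtected` is a hypothetical name for the inline expression.

```typescript
const protectedPaths: string[] = ["/", "/publish", "/admin"];

// Same expression as in the diff: exact match, or the path plus "/" as a prefix.
function isProtected(pathname: string): boolean {
  return protectedPaths.some(
    (path) => pathname === path || pathname.startsWith(path + "/")
  );
}
```

Because `"/" + "/"` is `"//"`, the root entry only ever matches `/` itself, so `/login` and `/register` are not caught by it. The trailing slash in the prefix test also keeps a sibling route such as `/publisher` from being treated as part of `/publish`.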


@@ -139,6 +139,45 @@ CUDA_VISIBLE_DEVICES=1 python -m scripts.inference \
---
---
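## Step 7: Performance Optimization (Preloaded Model Service)
To eliminate the 30-40 seconds of model loading on every video generation, run LatentSync as a resident service.
### 1. Install service dependencies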
```bash
conda activate latentsync
pip install fastapi uvicorn
```
### 2. Start the service
**Run in the foreground (for testing)**:
```bash
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
# Start the service (port 8007); it reads the GPU settings from backend/.env automatically
python -m scripts.server
```
**Run in the background (recommended)**:
```bash
nohup python -m scripts.server > server.log 2>&1 &
```
### 3. Update the configuration
Edit `ViGent2/backend/.env`:
```bash
LATENTSYNC_USE_SERVER=True
```
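The backend now calls the local resident service through its API, so each generation request no longer pays the model loading cost.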
---
## Troubleshooting
### CUDA out of memory


@@ -0,0 +1,23 @@
audio:
num_mels: 80 # Number of mel-spectrogram channels and local conditioning dimensionality
rescale: true # Whether to rescale audio prior to preprocessing
rescaling_max: 0.9 # Rescaling value
use_lws:
false # Use LWS (https://github.com/Jonathan-LeRoux/lws) for STFT and phase reconstruction
# It's preferred to set True to use with https://github.com/r9y9/wavenet_vocoder
# Does not work if n_fft is not a multiple of hop_size!!
n_fft: 800 # Extra window size is filled with 0 paddings to match this parameter
hop_size: 200 # For 16000Hz, 200 = 12.5 ms (0.0125 * sample_rate)
win_size: 800 # For 16000Hz, 800 = 50 ms (If None, win_size = n_fft) (0.05 * sample_rate)
sample_rate: 16000 # 16000Hz (corresponding to librispeech) (sox --i <filename>)
frame_shift_ms: null
signal_normalization: true
allow_clipping_in_normalization: true
symmetric_mels: true
max_abs_value: 4.0
preemphasize: true # whether to apply filter
preemphasis: 0.97 # filter coefficient.
min_level_db: -100
ref_level_db: 20
fmin: 55
fmax: 7600
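The comments on `hop_size` and `win_size` above state their durations at the 16000 Hz sample rate. The arithmetic can be checked with a short sketch; the variable names are illustrative and only the three numeric values come from the config.

```typescript
// Values copied from the audio config above.
const sampleRate = 16000; // Hz
const hopSize = 200;      // samples between consecutive STFT frames
const winSize = 800;      // samples per analysis window

// Convert a sample count to milliseconds at the given rate.
const samplesToMs = (samples: number, rate: number): number => (samples * 1000) / rate;

const hopMs = samplesToMs(hopSize, sampleRate);   // 12.5 ms, as the config comment states
const winMs = samplesToMs(winSize, sampleRate);   // 50 ms, as the config comment states
const melFramesPerSecond = sampleRate / hopSize;  // 80 mel frames per second of audio
```

The 80 frames per second figure follows directly from the hop size; it is unrelated to `num_mels: 80`, which is the number of mel channels.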


@@ -0,0 +1,12 @@
{
"_class_name": "DDIMScheduler",
"beta_end": 0.012,
"beta_schedule": "scaled_linear",
"beta_start": 0.00085,
"clip_sample": false,
"num_train_timesteps": 1000,
"set_alpha_to_one": false,
"steps_offset": 1,
"trained_betas": null,
"skip_prk_steps": true
}


@@ -0,0 +1,46 @@
model:
audio_encoder: # input (1, 80, 52)
in_channels: 1
block_out_channels: [32, 64, 128, 256, 512, 1024]
downsample_factors: [[2, 1], 2, 2, 2, 2, [2, 3]]
attn_blocks: [0, 0, 0, 0, 0, 0]
dropout: 0.0
visual_encoder: # input (64, 32, 32)
in_channels: 64
block_out_channels: [64, 128, 256, 256, 512, 1024]
downsample_factors: [2, 2, 2, 1, 2, 2]
attn_blocks: [0, 0, 0, 0, 0, 0]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: ""
save_ckpt_steps: 2500
data:
train_output_dir: debug/syncnet
num_val_samples: 1200
batch_size: 120 # 40
gradient_accumulation_steps: 1
num_workers: 12 # 12
latent_space: true
num_frames: 16
resolution: 256
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
val_fileslist: ""
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
lower_half: false
audio_sample_rate: 16000
video_fps: 25
optimizer:
lr: 1e-5
max_grad_norm: 1.0
run:
max_train_steps: 10000000
validation_steps: 2500
mixed_precision_training: true
seed: 42


@@ -0,0 +1,46 @@
model:
audio_encoder: # input (1, 80, 52)
in_channels: 1
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
attn_blocks: [0, 0, 0, 0, 0, 0, 0]
dropout: 0.0
visual_encoder: # input (48, 128, 256)
in_channels: 48
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
attn_blocks: [0, 0, 0, 0, 0, 0, 0, 0]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: ""
save_ckpt_steps: 2500
data:
train_output_dir: debug/syncnet
num_val_samples: 2048
batch_size: 256 # 256
gradient_accumulation_steps: 1
num_workers: 12 # 12
latent_space: false
num_frames: 16
resolution: 256
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
val_fileslist: ""
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
lower_half: true
audio_sample_rate: 16000
video_fps: 25
optimizer:
lr: 1e-5
max_grad_norm: 1.0
run:
max_train_steps: 10000000
validation_steps: 2500
mixed_precision_training: true
seed: 42


@@ -0,0 +1,46 @@
model:
audio_encoder: # input (1, 80, 52)
in_channels: 1
block_out_channels: [32, 64, 128, 256, 512, 1024, 2048]
downsample_factors: [[2, 1], 2, 2, 1, 2, 2, [2, 3]]
attn_blocks: [0, 0, 0, 1, 1, 0, 0]
dropout: 0.0
visual_encoder: # input (48, 128, 256)
in_channels: 48
block_out_channels: [64, 128, 256, 256, 512, 1024, 2048, 2048]
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
attn_blocks: [0, 0, 0, 0, 1, 1, 0, 0]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: checkpoints/stable_syncnet.pt
save_ckpt_steps: 2500
data:
train_output_dir: debug/syncnet
num_val_samples: 2048
batch_size: 256 # 256
gradient_accumulation_steps: 1
num_workers: 12 # 12
latent_space: false
num_frames: 16
resolution: 256
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
val_fileslist: ""
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
lower_half: true
audio_sample_rate: 16000
video_fps: 25
optimizer:
lr: 1e-5
max_grad_norm: 1.0
run:
max_train_steps: 10000000
validation_steps: 2500
mixed_precision_training: true
seed: 42

@@ -0,0 +1,44 @@
model:
audio_encoder: # input (1, 80, 80)
in_channels: 1
block_out_channels: [64, 128, 256, 256, 512, 1024]
downsample_factors: [2, 2, 2, 2, 2, 2]
dropout: 0.0
visual_encoder: # input (75, 128, 256)
in_channels: 75
block_out_channels: [128, 128, 256, 256, 512, 512, 1024, 1024]
downsample_factors: [[1, 2], 2, 2, 2, 2, 2, 2, 2]
dropout: 0.0
ckpt:
resume_ckpt_path: ""
inference_ckpt_path: ""
save_ckpt_steps: 2500
data:
train_output_dir: debug/syncnet
num_val_samples: 2048
batch_size: 64 # 64
gradient_accumulation_steps: 1
num_workers: 12 # 12
latent_space: false
num_frames: 25
resolution: 256
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
val_fileslist: ""
val_data_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality/val
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
lower_half: true
audio_sample_rate: 16000
video_fps: 25
optimizer:
lr: 1e-5
max_grad_norm: 1.0
run:
max_train_steps: 10000000
validation_steps: 2500
mixed_precision_training: true
seed: 42
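
The `# input (75, 128, 256)` comment in this config, and `# input (48, 128, 256)` in the 16-frame ones, are consistent with stacking `num_frames` RGB frames along the channel axis and keeping only the lower half of each frame when `lower_half` is true. The sketch below encodes that reading; the function name and the three-channel assumption are illustrative and not taken from the diff.

```python
from typing import Tuple


def visual_input_shape(
    num_frames: int, resolution: int, lower_half: bool, rgb_channels: int = 3
) -> Tuple[int, int, int]:
    """Return (channels, height, width) of the stacked visual input (assumed layout)."""
    channels = num_frames * rgb_channels
    # With lower_half enabled, only the bottom half of each frame is kept.
    height = resolution // 2 if lower_half else resolution
    return channels, height, resolution


print(visual_input_shape(num_frames=25, resolution=256, lower_half=True))
print(visual_input_shape(num_frames=16, resolution=256, lower_half=True))
```

If `num_frames` or `lower_half` is changed, `visual_encoder.in_channels` and the downsampling schedule presumably need to change with it.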

@@ -0,0 +1,96 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 24
num_workers: 12 # 12
num_frames: 16
resolution: 256
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: false
use_syncnet: false
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 10
inference_steps: 20
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: false
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true
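
One way to read `in_channels: 13 # 49` together with `out_channels: 4 # 16` is that the UNet input concatenates three latents of equal width (noisy, masked-frame and reference-frame) plus a one-channel mask. That layout is an assumption here; the diff shows only the numbers. The helper below is hypothetical and reproduces both pairs of values.

```python
def unet_in_channels(latent_channels: int, mask_channels: int = 1) -> int:
    """Channel count of the concatenated UNet input under the assumed layout."""
    # noisy latent + masked-frame latent + reference-frame latent + mask
    return 3 * latent_channels + mask_channels


print(unet_in_channels(4))   # latent width matching out_channels: 4
print(unet_in_channels(16))  # latent width matching the commented alternative
```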

@@ -0,0 +1,96 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 8
num_workers: 12 # 12
num_frames: 16
resolution: 512
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: false
use_syncnet: false
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 10
inference_steps: 20
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: false
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true

@@ -0,0 +1,99 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 4
num_workers: 12 # 12
num_frames: 16
resolution: 256
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: true
use_syncnet: true
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 10
inference_steps: 20
trainable_modules:
- motion_modules.
- attentions.
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: true
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true
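
The `trainable_modules` entries in this config (`motion_modules.` and `attentions.`) look like name fragments. A minimal sketch of how such a list could select parameters, assuming substring matching against parameter names; the training script is not part of this diff, and the parameter names below are made up for illustration.

```python
from typing import Iterable, List


def select_trainable(param_names: Iterable[str], trainable_modules: Iterable[str]) -> List[str]:
    """Keep parameter names that contain any of the configured fragments."""
    patterns = tuple(trainable_modules)
    return [name for name in param_names if any(p in name for p in patterns)]


# Hypothetical parameter names, only for demonstration.
names = [
    "down_blocks.0.attentions.0.transformer_blocks.0.attn2.to_q.weight",
    "down_blocks.0.motion_modules.0.temporal_transformer.proj_out.weight",
    "down_blocks.0.resnets.0.conv1.weight",
]

selected = select_trainable(names, ["motion_modules.", "attentions."])
print(selected)
```

With a real model, the same filter would typically be applied to `model.named_parameters()`, setting `requires_grad` only on the matches.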

@@ -0,0 +1,99 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 4
num_workers: 12 # 12
num_frames: 16
resolution: 512
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: true
use_syncnet: true
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 10
inference_steps: 20
trainable_modules:
- motion_modules.
- attentions.
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: true
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: false
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true

@@ -0,0 +1,99 @@
data:
syncnet_config_path: configs/syncnet/syncnet_16_pixel_attn.yaml
train_output_dir: debug/unet
train_fileslist: /mnt/bn/maliva-gen-ai-v2/chunyu.li/fileslist/data_v10_core.txt
train_data_dir: ""
audio_embeds_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/embeds
audio_mel_cache_dir: /mnt/bn/maliva-gen-ai-v2/chunyu.li/audio_cache/mel
val_video_path: assets/demo1_video.mp4
val_audio_path: assets/demo1_audio.wav
batch_size: 1 # 4
num_workers: 12 # 12
num_frames: 16
resolution: 256
mask_image_path: latentsync/utils/mask.png
audio_sample_rate: 16000
video_fps: 25
audio_feat_length: [2, 2]
ckpt:
resume_ckpt_path: checkpoints/latentsync_unet.pt
save_ckpt_steps: 10000
run:
pixel_space_supervise: true
use_syncnet: true
sync_loss_weight: 0.05
perceptual_loss_weight: 0.1 # 0.1
recon_loss_weight: 1 # 1
guidance_scale: 1.5 # [1.0 - 3.0]
trepa_loss_weight: 0
inference_steps: 20
trainable_modules:
- motion_modules.
- attn2.
seed: 1247
use_mixed_noise: true
mixed_noise_alpha: 1 # 1
mixed_precision_training: true
enable_gradient_checkpointing: true
max_train_steps: 10000000
max_train_epochs: -1
optimizer:
lr: 1e-5
scale_lr: false
max_grad_norm: 1.0
lr_scheduler: constant
lr_warmup_steps: 0
model:
act_fn: silu
add_audio_layer: true
attention_head_dim: 8
block_out_channels: [320, 640, 1280, 1280]
center_input_sample: false
cross_attention_dim: 384
down_block_types:
[
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
]
mid_block_type: UNetMidBlock3DCrossAttn
up_block_types:
[
"UpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
"CrossAttnUpBlock3D",
]
downsample_padding: 1
flip_sin_to_cos: true
freq_shift: 0
in_channels: 13 # 49
layers_per_block: 2
mid_block_scale_factor: 1
norm_eps: 1e-5
norm_num_groups: 32
out_channels: 4 # 16
sample_size: 64
resnet_time_scale_shift: default # Choose between [default, scale_shift]
use_motion_module: true
motion_module_resolutions: [1, 2, 4, 8]
motion_module_mid_block: false
motion_module_decoder_only: true
motion_module_type: Vanilla
motion_module_kwargs:
num_attention_heads: 8
num_transformer_block: 1
attention_block_types:
- Temporal_Self
- Temporal_Self
temporal_position_encoding: true
temporal_position_encoding_max_len: 24
temporal_attention_dim_div: 1
zero_initialize: true

@@ -0,0 +1,139 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import numpy as np
from torch.utils.data import Dataset
import torch
import random
from ..utils.util import gather_video_paths_recursively
from ..utils.image_processor import ImageProcessor
from ..utils.audio import melspectrogram
import math
from pathlib import Path
from decord import AudioReader, VideoReader, cpu
class SyncNetDataset(Dataset):
def __init__(self, data_dir: str, fileslist: str, config):
if fileslist != "":
with open(fileslist) as file:
self.video_paths = [line.rstrip() for line in file]
elif data_dir != "":
self.video_paths = gather_video_paths_recursively(data_dir)
else:
raise ValueError("data_dir and fileslist cannot both be empty")
self.resolution = config.data.resolution
self.num_frames = config.data.num_frames
self.mel_window_length = math.ceil(self.num_frames / 5 * 16)
self.audio_sample_rate = config.data.audio_sample_rate
self.video_fps = config.data.video_fps
self.image_processor = ImageProcessor(resolution=config.data.resolution)
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
Path(self.audio_mel_cache_dir).mkdir(parents=True, exist_ok=True)
def __len__(self):
return len(self.video_paths)
def read_audio(self, video_path: str):
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
return torch.from_numpy(original_mel)
def crop_audio_window(self, original_mel, start_index):
start_idx = int(80.0 * (start_index / float(self.video_fps)))
end_idx = start_idx + self.mel_window_length
return original_mel[:, start_idx:end_idx].unsqueeze(0)
def get_frames(self, video_reader: VideoReader):
total_num_frames = len(video_reader)
start_idx = random.randint(0, total_num_frames - self.num_frames)
frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
while True:
wrong_start_idx = random.randint(0, total_num_frames - self.num_frames)
if wrong_start_idx == start_idx:
continue
wrong_frames_index = np.arange(wrong_start_idx, wrong_start_idx + self.num_frames, dtype=int)
break
frames = video_reader.get_batch(frames_index).asnumpy()
wrong_frames = video_reader.get_batch(wrong_frames_index).asnumpy()
return frames, wrong_frames, start_idx
def worker_init_fn(self, worker_id):
self.worker_id = worker_id
def __getitem__(self, idx):
while True:
try:
idx = random.randint(0, len(self) - 1)
# Get video file path
video_path = self.video_paths[idx]
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
if len(vr) < 2 * self.num_frames:
continue
frames, wrong_frames, start_idx = self.get_frames(vr)
mel_cache_path = os.path.join(
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
)
if os.path.isfile(mel_cache_path):
try:
original_mel = torch.load(mel_cache_path, weights_only=True)
except Exception as e:
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
os.remove(mel_cache_path)
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
else:
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
mel = self.crop_audio_window(original_mel, start_idx)
if mel.shape[-1] != self.mel_window_length:
continue
if random.choice([True, False]):
y = torch.ones(1).float()
chosen_frames = frames
else:
y = torch.zeros(1).float()
chosen_frames = wrong_frames
chosen_frames = self.image_processor.process_images(chosen_frames)
vr.seek(0) # avoid memory leak
break
except Exception as e: # Handle the exception of face not detected
print(f"{type(e).__name__} - {e} - {video_path}")
if "vr" in locals():
vr.seek(0) # avoid memory leak
sample = dict(frames=chosen_frames, audio_samples=mel, y=y)
return sample
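
The two audio-indexing expressions in the dataset above can be checked in isolation. `crop_audio_window` multiplies seconds by `80.0`, which implies 80 mel columns per second (presumably set by the hop length inside `melspectrogram`, which is not shown in this diff). At 25 fps that is 3.2 columns per video frame, the same ratio as `num_frames / 5 * 16`. The standalone functions below mirror those expressions; their names are illustrative.

```python
import math

MEL_COLUMNS_PER_SECOND = 80.0  # constant taken from crop_audio_window


def mel_window_length(num_frames: int) -> int:
    """Mel columns covering num_frames video frames (same formula as the dataset)."""
    return math.ceil(num_frames / 5 * 16)


def mel_start_index(start_frame: int, video_fps: int = 25) -> int:
    """First mel column for a clip that starts at start_frame."""
    return int(MEL_COLUMNS_PER_SECOND * (start_frame / float(video_fps)))


print(mel_window_length(16), mel_window_length(25))
print(mel_start_index(25))
```

The results line up with the `# input (1, 80, 52)` and `# input (1, 80, 80)` comments in the SyncNet configs. Note that `num_frames / 5 * 16` hard-codes the 25 fps ratio, while the start index reads `video_fps` from the config, so the two only agree when `video_fps` is 25.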

@@ -0,0 +1,152 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import math
import numpy as np
from torch.utils.data import Dataset
import torch
import random
import cv2
from ..utils.image_processor import ImageProcessor, load_fixed_mask
from ..utils.audio import melspectrogram
from decord import AudioReader, VideoReader, cpu
import torch.nn.functional as F
from pathlib import Path
class UNetDataset(Dataset):
def __init__(self, train_data_dir: str, config):
if config.data.train_fileslist != "":
with open(config.data.train_fileslist) as file:
self.video_paths = [line.rstrip() for line in file]
elif train_data_dir != "":
self.video_paths = []
for file in os.listdir(train_data_dir):
if file.endswith(".mp4"):
self.video_paths.append(os.path.join(train_data_dir, file))
else:
raise ValueError("train_data_dir and train_fileslist cannot both be empty")
self.resolution = config.data.resolution
self.num_frames = config.data.num_frames
self.mel_window_length = math.ceil(self.num_frames / 5 * 16)
self.audio_sample_rate = config.data.audio_sample_rate
self.video_fps = config.data.video_fps
self.image_processor = ImageProcessor(
self.resolution, mask_image=load_fixed_mask(self.resolution, config.data.mask_image_path)
)
self.load_audio_data = config.model.add_audio_layer and config.run.use_syncnet
self.audio_mel_cache_dir = config.data.audio_mel_cache_dir
Path(self.audio_mel_cache_dir).mkdir(parents=True, exist_ok=True)
def __len__(self):
return len(self.video_paths)
def read_audio(self, video_path: str):
ar = AudioReader(video_path, ctx=cpu(self.worker_id), sample_rate=self.audio_sample_rate)
original_mel = melspectrogram(ar[:].asnumpy().squeeze(0))
return torch.from_numpy(original_mel)
def crop_audio_window(self, original_mel, start_index):
start_idx = int(80.0 * (start_index / float(self.video_fps)))
end_idx = start_idx + self.mel_window_length
return original_mel[:, start_idx:end_idx].unsqueeze(0)
def get_frames(self, video_reader: VideoReader):
total_num_frames = len(video_reader)
start_idx = random.randint(0, total_num_frames - self.num_frames)
gt_frames_index = np.arange(start_idx, start_idx + self.num_frames, dtype=int)
while True:
ref_start_idx = random.randint(0, total_num_frames - self.num_frames)
if ref_start_idx > start_idx - self.num_frames and ref_start_idx < start_idx + self.num_frames:
continue
ref_frames_index = np.arange(ref_start_idx, ref_start_idx + self.num_frames, dtype=int)
break
gt_frames = video_reader.get_batch(gt_frames_index).asnumpy()
ref_frames = video_reader.get_batch(ref_frames_index).asnumpy()
return gt_frames, ref_frames, start_idx
def worker_init_fn(self, worker_id):
self.worker_id = worker_id
def __getitem__(self, idx):
while True:
try:
idx = random.randint(0, len(self) - 1)
# Get video file path
video_path = self.video_paths[idx]
vr = VideoReader(video_path, ctx=cpu(self.worker_id))
if len(vr) < 3 * self.num_frames:
continue
gt_frames, ref_frames, start_idx = self.get_frames(vr)
if self.load_audio_data:
mel_cache_path = os.path.join(
self.audio_mel_cache_dir, os.path.basename(video_path).replace(".mp4", "_mel.pt")
)
if os.path.isfile(mel_cache_path):
try:
original_mel = torch.load(mel_cache_path, weights_only=True)
except Exception as e:
print(f"{type(e).__name__} - {e} - {mel_cache_path}")
os.remove(mel_cache_path)
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
else:
original_mel = self.read_audio(video_path)
torch.save(original_mel, mel_cache_path)
mel = self.crop_audio_window(original_mel, start_idx)
if mel.shape[-1] != self.mel_window_length:
continue
else:
mel = []
gt_pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
gt_frames, affine_transform=False
) # (f, c, h, w)
ref_pixel_values = self.image_processor.process_images(ref_frames)
vr.seek(0) # avoid memory leak
break
except Exception as e: # Handle the exception of face not detected
print(f"{type(e).__name__} - {e} - {video_path}")
if "vr" in locals():
vr.seek(0) # avoid memory leak
sample = dict(
gt_pixel_values=gt_pixel_values,
masked_pixel_values=masked_pixel_values,
ref_pixel_values=ref_pixel_values,
mel=mel,
masks=masks,
video_path=video_path,
start_idx=start_idx,
)
return sample
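
The reference-window sampling in `get_frames` above can be exercised without a video file. The sketch below reproduces the rejection loop with plain integers, so the non-overlap property is easy to test; the function name and the explicit `rng` argument are additions for illustration.

```python
import random
from typing import Tuple


def sample_gt_and_ref_starts(total_num_frames: int, num_frames: int, rng: random.Random) -> Tuple[int, int]:
    """Pick a ground-truth window and a reference window that do not overlap."""
    last_start = total_num_frames - num_frames
    start_idx = rng.randint(0, last_start)
    while True:
        ref_start_idx = rng.randint(0, last_start)
        # Reject reference windows that share any frame with the ground-truth window.
        if start_idx - num_frames < ref_start_idx < start_idx + num_frames:
            continue
        return start_idx, ref_start_idx


rng = random.Random(0)
print(sample_gt_and_ref_starts(total_num_frames=48, num_frames=16, rng=rng))
```

With at least `3 * num_frames` frames, which is the minimum `__getitem__` accepts, at least two valid reference starts always exist, so the loop terminates.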

@@ -0,0 +1,280 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
from dataclasses import dataclass
from typing import Optional
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.models.attention import FeedForward, AdaLayerNorm
from einops import rearrange, repeat
@dataclass
class Transformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
class Transformer3DModel(ModelMixin, ConfigMixin):
@register_to_config
def __init__(
self,
num_attention_heads: int = 16,
attention_head_dim: int = 88,
in_channels: Optional[int] = None,
num_layers: int = 1,
dropout: float = 0.0,
norm_num_groups: int = 32,
cross_attention_dim: Optional[int] = None,
attention_bias: bool = False,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
use_linear_projection: bool = False,
only_cross_attention: bool = False,
upcast_attention: bool = False,
add_audio_layer=False,
):
super().__init__()
self.use_linear_projection = use_linear_projection
self.num_attention_heads = num_attention_heads
self.attention_head_dim = attention_head_dim
inner_dim = num_attention_heads * attention_head_dim
# Define input layers
self.in_channels = in_channels
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
if use_linear_projection:
self.proj_in = nn.Linear(in_channels, inner_dim)
else:
self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
# Define transformers blocks
self.transformer_blocks = nn.ModuleList(
[
BasicTransformerBlock(
inner_dim,
num_attention_heads,
attention_head_dim,
dropout=dropout,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
num_embeds_ada_norm=num_embeds_ada_norm,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
add_audio_layer=add_audio_layer,
)
for d in range(num_layers)
]
)
# Define output layers
if use_linear_projection:
self.proj_out = nn.Linear(inner_dim, in_channels)
else:
self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
# Input
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
batch, channel, height, weight = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
if not self.use_linear_projection:
hidden_states = self.proj_in(hidden_states)
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
else:
inner_dim = hidden_states.shape[1]
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
hidden_states = self.proj_in(hidden_states)
# Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
timestep=timestep,
video_length=video_length,
)
# Output
if not self.use_linear_projection:
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
hidden_states = self.proj_out(hidden_states)
else:
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
if not return_dict:
return (output,)
return Transformer3DModelOutput(sample=output)
class BasicTransformerBlock(nn.Module):
def __init__(
self,
dim: int,
num_attention_heads: int,
attention_head_dim: int,
dropout=0.0,
cross_attention_dim: Optional[int] = None,
activation_fn: str = "geglu",
num_embeds_ada_norm: Optional[int] = None,
attention_bias: bool = False,
upcast_attention: bool = False,
add_audio_layer=False,
):
super().__init__()
self.use_ada_layer_norm = num_embeds_ada_norm is not None
self.add_audio_layer = add_audio_layer
self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
self.attn1 = Attention(
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
# Cross-attn
if add_audio_layer:
self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
self.attn2 = Attention(
query_dim=dim,
cross_attention_dim=cross_attention_dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
)
else:
self.attn2 = None
# Feed-forward
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.norm3 = nn.LayerNorm(dim)
def forward(
self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None
):
norm_hidden_states = (
self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
)
hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
if self.attn2 is not None and encoder_hidden_states is not None:
if encoder_hidden_states.dim() == 4:
encoder_hidden_states = rearrange(encoder_hidden_states, "b f s d -> (b f) s d")
norm_hidden_states = (
self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
)
hidden_states = (
self.attn2(
norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
+ hidden_states
)
# Feed-forward
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
return hidden_states
class Attention(nn.Module):
def __init__(
self,
query_dim: int,
cross_attention_dim: Optional[int] = None,
heads: int = 8,
dim_head: int = 64,
dropout: float = 0.0,
bias=False,
upcast_attention: bool = False,
upcast_softmax: bool = False,
norm_num_groups: Optional[int] = None,
):
super().__init__()
inner_dim = dim_head * heads
cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
self.upcast_attention = upcast_attention
self.upcast_softmax = upcast_softmax
self.scale = dim_head**-0.5
self.heads = heads
if norm_num_groups is not None:
self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
else:
self.group_norm = None
self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
self.to_out = nn.ModuleList([])
self.to_out.append(nn.Linear(inner_dim, query_dim))
self.to_out.append(nn.Dropout(dropout))
def split_heads(self, tensor):
batch_size, seq_len, dim = tensor.shape
tensor = tensor.reshape(batch_size, seq_len, self.heads, dim // self.heads)
tensor = tensor.permute(0, 2, 1, 3)
return tensor
def concat_heads(self, tensor):
batch_size, heads, seq_len, head_dim = tensor.shape
tensor = tensor.permute(0, 2, 1, 3)
tensor = tensor.reshape(batch_size, seq_len, heads * head_dim)
return tensor
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
query = self.split_heads(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.split_heads(key)
value = self.split_heads(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# Use PyTorch's native scaled dot-product attention (may dispatch to a fused kernel such as FlashAttention-2)
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
hidden_states = self.concat_heads(hidden_states)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
return hidden_states
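
The reshapes in `Transformer3DModel.forward` above are easier to follow as shape bookkeeping. The sketch below tracks shapes only and uses no tensors; the function name and the example sizes are illustrative and not taken from the diff.

```python
from typing import Dict, Tuple


def transformer3d_shapes(b: int, c: int, f: int, h: int, w: int, inner_dim: int) -> Dict[str, Tuple[int, ...]]:
    """Shapes seen by a (b, c, f, h, w) input on its way through the block."""
    return {
        "input": (b, c, f, h, w),
        # "b c f h w -> (b f) c h w": frames are folded into the batch axis.
        "frames_as_batch": (b * f, c, h, w),
        # After proj_in and flattening, each pixel becomes one token.
        "tokens": (b * f, h * w, inner_dim),
        # "(b f) c h w -> b c f h w": the frame axis is restored at the end.
        "output": (b, c, f, h, w),
    }


shapes = transformer3d_shapes(b=1, c=320, f=16, h=32, w=32, inner_dim=320)
for name, shape in shapes.items():
    print(name, shape)
```

Because frames are folded into the batch axis, the spatial attention in these blocks treats every frame independently; any mixing across frames has to come from other modules.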

@@ -0,0 +1,313 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
# Note: the motion module is not used in the final version of LatentSync.
# When we started the project, we built on the AnimateDiff codebase and tried its motion module,
# but the results were poor, so we decided to leave the code here for possible future use.
from dataclasses import dataclass
import torch
import torch.nn.functional as F
from torch import nn
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput
from diffusers.models.attention import FeedForward
from .attention import Attention
from einops import rearrange, repeat
import math
from .utils import zero_module
@dataclass
class TemporalTransformer3DModelOutput(BaseOutput):
sample: torch.FloatTensor
def get_motion_module(in_channels, motion_module_type: str, motion_module_kwargs: dict):
if motion_module_type == "Vanilla":
return VanillaTemporalModule(
in_channels=in_channels,
**motion_module_kwargs,
)
else:
raise ValueError(f"Unknown motion_module_type: {motion_module_type!r}")
class VanillaTemporalModule(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads=8,
num_transformer_block=2,
attention_block_types=("Temporal_Self", "Temporal_Self"),
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
temporal_attention_dim_div=1,
zero_initialize=True,
):
super().__init__()
self.temporal_transformer = TemporalTransformer3DModel(
in_channels=in_channels,
num_attention_heads=num_attention_heads,
attention_head_dim=in_channels // num_attention_heads // temporal_attention_dim_div,
num_layers=num_transformer_block,
attention_block_types=attention_block_types,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
if zero_initialize:
self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
def forward(self, input_tensor, temb, encoder_hidden_states, attention_mask=None, anchor_frame_idx=None):
hidden_states = input_tensor
hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
output = hidden_states
return output
class TemporalTransformer3DModel(nn.Module):
def __init__(
self,
in_channels,
num_attention_heads,
attention_head_dim,
num_layers,
attention_block_types=(
"Temporal_Self",
"Temporal_Self",
),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
):
super().__init__()
inner_dim = num_attention_heads * attention_head_dim
self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
self.proj_in = nn.Linear(in_channels, inner_dim)
self.transformer_blocks = nn.ModuleList(
[
TemporalTransformerBlock(
dim=inner_dim,
num_attention_heads=num_attention_heads,
attention_head_dim=attention_head_dim,
attention_block_types=attention_block_types,
dropout=dropout,
norm_num_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
activation_fn=activation_fn,
attention_bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
for d in range(num_layers)
]
)
self.proj_out = nn.Linear(inner_dim, in_channels)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
video_length = hidden_states.shape[2]
hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm(hidden_states)
hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, channel)
hidden_states = self.proj_in(hidden_states)
# Transformer Blocks
for block in self.transformer_blocks:
hidden_states = block(
hidden_states, encoder_hidden_states=encoder_hidden_states, video_length=video_length
)
# output
hidden_states = self.proj_out(hidden_states)
hidden_states = hidden_states.reshape(batch, height, width, channel).permute(0, 3, 1, 2).contiguous()
output = hidden_states + residual
output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
return output
class TemporalTransformerBlock(nn.Module):
def __init__(
self,
dim,
num_attention_heads,
attention_head_dim,
attention_block_types=(
"Temporal_Self",
"Temporal_Self",
),
dropout=0.0,
norm_num_groups=32,
cross_attention_dim=768,
activation_fn="geglu",
attention_bias=False,
upcast_attention=False,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
):
super().__init__()
attention_blocks = []
norms = []
for block_name in attention_block_types:
attention_blocks.append(
VersatileAttention(
attention_mode=block_name.split("_")[0],
cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
query_dim=dim,
heads=num_attention_heads,
dim_head=attention_head_dim,
dropout=dropout,
bias=attention_bias,
upcast_attention=upcast_attention,
cross_frame_attention_mode=cross_frame_attention_mode,
temporal_position_encoding=temporal_position_encoding,
temporal_position_encoding_max_len=temporal_position_encoding_max_len,
)
)
norms.append(nn.LayerNorm(dim))
self.attention_blocks = nn.ModuleList(attention_blocks)
self.norms = nn.ModuleList(norms)
self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
self.ff_norm = nn.LayerNorm(dim)
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
for attention_block, norm in zip(self.attention_blocks, self.norms):
norm_hidden_states = norm(hidden_states)
hidden_states = (
attention_block(
norm_hidden_states,
encoder_hidden_states=encoder_hidden_states if attention_block.is_cross_attention else None,
video_length=video_length,
)
+ hidden_states
)
hidden_states = self.ff(self.ff_norm(hidden_states)) + hidden_states
output = hidden_states
return output
class PositionalEncoding(nn.Module):
def __init__(self, d_model, dropout=0.0, max_len=24):
super().__init__()
self.dropout = nn.Dropout(p=dropout)
position = torch.arange(max_len).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))
pe = torch.zeros(1, max_len, d_model)
pe[0, :, 0::2] = torch.sin(position * div_term)
pe[0, :, 1::2] = torch.cos(position * div_term)
self.register_buffer("pe", pe)
def forward(self, x):
x = x + self.pe[:, : x.size(1)]
return self.dropout(x)
class VersatileAttention(Attention):
def __init__(
self,
attention_mode=None,
cross_frame_attention_mode=None,
temporal_position_encoding=False,
temporal_position_encoding_max_len=24,
*args,
**kwargs,
):
super().__init__(*args, **kwargs)
assert attention_mode == "Temporal"
self.attention_mode = attention_mode
self.is_cross_attention = kwargs.get("cross_attention_dim") is not None
self.pos_encoder = (
PositionalEncoding(kwargs["query_dim"], dropout=0.0, max_len=temporal_position_encoding_max_len)
if (temporal_position_encoding and attention_mode == "Temporal")
else None
)
def extra_repr(self):
return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
if self.attention_mode == "Temporal":
s = hidden_states.shape[1]
hidden_states = rearrange(hidden_states, "(b f) s c -> (b s) f c", f=video_length)
if self.pos_encoder is not None:
hidden_states = self.pos_encoder(hidden_states)
##### This section will not be executed #####
encoder_hidden_states = (
repeat(encoder_hidden_states, "b n c -> (b s) n c", s=s)
if encoder_hidden_states is not None
else encoder_hidden_states
)
#############################################
else:
raise NotImplementedError
if self.group_norm is not None:
hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
query = self.to_q(hidden_states)
query = self.split_heads(query)
encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
key = self.to_k(encoder_hidden_states)
value = self.to_v(encoder_hidden_states)
key = self.split_heads(key)
value = self.split_heads(value)
if attention_mask is not None:
if attention_mask.shape[-1] != query.shape[1]:
target_length = query.shape[1]
attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
# Use PyTorch native implementation of FlashAttention-2
hidden_states = F.scaled_dot_product_attention(query, key, value, attn_mask=attention_mask)
hidden_states = self.concat_heads(hidden_states)
# linear proj
hidden_states = self.to_out[0](hidden_states)
# dropout
hidden_states = self.to_out[1](hidden_states)
if self.attention_mode == "Temporal":
hidden_states = rearrange(hidden_states, "(b s) f c -> (b f) s c", s=s)
return hidden_states
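
`PositionalEncoding` above precomputes a sinusoidal table and adds its first `x.size(1)` rows to the input; in `VersatileAttention` that axis is the frame axis after the `(b f) s c -> (b s) f c` rearrange. The following is a minimal sketch of the same table in plain Python, assuming an even `d_model` as the tensor version does. The function name `sinusoidal_table` is hypothetical and only used for illustration.

```python
import math


def sinusoidal_table(max_len, d_model):
    """Return a max_len x d_model table: sin on even columns, cos on odd ones."""
    table = [[0.0] * d_model for _ in range(max_len)]
    for pos in range(max_len):
        for i in range(0, d_model, 2):
            # Same frequency term as div_term in the module:
            # exp(i * (-ln(10000) / d_model))
            angle = pos * math.exp(i * (-math.log(10000.0) / d_model))
            table[pos][i] = math.sin(angle)
            if i + 1 < d_model:  # guard for odd d_model, not needed in the module
                table[pos][i + 1] = math.cos(angle)
    return table


# 24 matches the default temporal_position_encoding_max_len in this file.
table = sinusoidal_table(max_len=24, d_model=8)
```

Because the buffer holds only `max_len` rows, a clip with more frames than `temporal_position_encoding_max_len` would not broadcast against the sliced table, so the limit is worth keeping in mind if the module is ever enabled.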

@@ -0,0 +1,228 @@
# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
class InflatedConv3d(nn.Conv2d):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class InflatedGroupNorm(nn.GroupNorm):
def forward(self, x):
video_length = x.shape[2]
x = rearrange(x, "b c f h w -> (b f) c h w")
x = super().forward(x)
x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
return x
class Upsample3D(nn.Module):
def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.use_conv_transpose = use_conv_transpose
self.name = name
self.conv = None
if use_conv_transpose:
raise NotImplementedError
elif use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
def forward(self, hidden_states, output_size=None):
assert hidden_states.shape[1] == self.channels
if self.use_conv_transpose:
raise NotImplementedError
# Cast to float32, as the 'upsample_nearest2d_out_frame' op does not support bfloat16
dtype = hidden_states.dtype
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(torch.float32)
# upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
if hidden_states.shape[0] >= 64:
hidden_states = hidden_states.contiguous()
# if `output_size` is passed we force the interpolation output
# size and do not make use of `scale_factor=2`
if output_size is None:
hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
else:
hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
# If the input is bfloat16, we cast back to bfloat16
if dtype == torch.bfloat16:
hidden_states = hidden_states.to(dtype)
# Only apply the convolution when it was created in __init__ (use_conv=True)
if self.use_conv:
hidden_states = self.conv(hidden_states)
return hidden_states
class Downsample3D(nn.Module):
def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
super().__init__()
self.channels = channels
self.out_channels = out_channels or channels
self.use_conv = use_conv
self.padding = padding
stride = 2
self.name = name
if use_conv:
self.conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
else:
raise NotImplementedError
def forward(self, hidden_states):
assert hidden_states.shape[1] == self.channels
if self.use_conv and self.padding == 0:
raise NotImplementedError
hidden_states = self.conv(hidden_states)
return hidden_states
class ResnetBlock3D(nn.Module):
def __init__(
self,
*,
in_channels,
out_channels=None,
conv_shortcut=False,
dropout=0.0,
temb_channels=512,
groups=32,
groups_out=None,
pre_norm=True,
eps=1e-6,
non_linearity="swish",
time_embedding_norm="default",
output_scale_factor=1.0,
use_in_shortcut=None,
use_inflated_groupnorm=False,
):
super().__init__()
self.pre_norm = pre_norm
self.pre_norm = True
self.in_channels = in_channels
out_channels = in_channels if out_channels is None else out_channels
self.out_channels = out_channels
self.use_conv_shortcut = conv_shortcut
self.time_embedding_norm = time_embedding_norm
self.output_scale_factor = output_scale_factor
if groups_out is None:
groups_out = groups
assert use_inflated_groupnorm is not None
if use_inflated_groupnorm:
self.norm1 = InflatedGroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
else:
self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
if temb_channels is not None:
if self.time_embedding_norm == "default":
time_emb_proj_out_channels = out_channels
elif self.time_embedding_norm == "scale_shift":
time_emb_proj_out_channels = out_channels * 2
else:
raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
else:
self.time_emb_proj = None
if self.time_embedding_norm == "scale_shift":
self.double_len_linear = torch.nn.Linear(time_emb_proj_out_channels, 2 * time_emb_proj_out_channels)
else:
self.double_len_linear = None
if use_inflated_groupnorm:
self.norm2 = InflatedGroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
else:
self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
self.dropout = torch.nn.Dropout(dropout)
self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if non_linearity == "swish":
self.nonlinearity = lambda x: F.silu(x)
elif non_linearity == "mish":
self.nonlinearity = Mish()
elif non_linearity == "silu":
self.nonlinearity = nn.SiLU()
self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
self.conv_shortcut = None
if self.use_in_shortcut:
self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
def forward(self, input_tensor, temb):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.conv1(hidden_states)
if temb is not None:
if temb.dim() == 2:
# input (1, 1280)
temb = self.time_emb_proj(self.nonlinearity(temb))
temb = temb[:, :, None, None, None] # unsqueeze
else:
# input (1, 1280, 16)
temb = temb.permute(0, 2, 1)
temb = self.time_emb_proj(self.nonlinearity(temb))
if self.double_len_linear is not None:
temb = self.double_len_linear(self.nonlinearity(temb))
temb = temb.permute(0, 2, 1)
temb = temb[:, :, :, None, None]
if temb is not None and self.time_embedding_norm == "default":
hidden_states = hidden_states + temb
hidden_states = self.norm2(hidden_states)
if temb is not None and self.time_embedding_norm == "scale_shift":
scale, shift = torch.chunk(temb, 2, dim=1)
hidden_states = hidden_states * (1 + scale) + shift
hidden_states = self.nonlinearity(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
return output_tensor
class Mish(torch.nn.Module):
def forward(self, hidden_states):
return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
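
`ResnetBlock3D.forward` above injects the time embedding in one of two ways, selected by `time_embedding_norm`: `"default"` adds it to the features, while `"scale_shift"` splits it into two halves along the channel axis and applies `h * (1 + scale) + shift`. The following is a minimal per-channel sketch of just that arithmetic. It deliberately leaves out the normalization, activation and convolution layers, and `apply_time_embedding` is a hypothetical helper, not a function from this file.

```python
def apply_time_embedding(hidden, temb, mode="default"):
    """Combine per-channel features with a time embedding.

    hidden: list of channel values
    temb:   list of len(hidden) values for "default",
            or 2 * len(hidden) values for "scale_shift"
    """
    if mode == "default":
        # Plain additive conditioning.
        return [h + t for h, t in zip(hidden, temb)]
    if mode == "scale_shift":
        # First half scales, second half shifts (torch.chunk(temb, 2, dim=1)).
        half = len(temb) // 2
        scale, shift = temb[:half], temb[half:]
        return [h * (1 + s) + b for h, s, b in zip(hidden, scale, shift)]
    raise ValueError(f"unknown time_embedding_norm: {mode}")


added = apply_time_embedding([1.0, 2.0], [0.5, 0.5])
modulated = apply_time_embedding([1.0, 2.0], [1.0, 0.0, 0.5, -1.0], mode="scale_shift")
```

With a zero embedding, `"scale_shift"` leaves the features unchanged, which is the reason for the `1 + scale` form rather than a bare `scale`.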

@@ -0,0 +1,233 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from diffusers.models.attention import FeedForward
from einops import rearrange
from .attention import Attention
class StableSyncNet(nn.Module):
def __init__(self, config, gradient_checkpointing=False):
super().__init__()
self.audio_encoder = DownEncoder2D(
in_channels=config["audio_encoder"]["in_channels"],
block_out_channels=config["audio_encoder"]["block_out_channels"],
downsample_factors=config["audio_encoder"]["downsample_factors"],
dropout=config["audio_encoder"]["dropout"],
attn_blocks=config["audio_encoder"]["attn_blocks"],
gradient_checkpointing=gradient_checkpointing,
)
self.visual_encoder = DownEncoder2D(
in_channels=config["visual_encoder"]["in_channels"],
block_out_channels=config["visual_encoder"]["block_out_channels"],
downsample_factors=config["visual_encoder"]["downsample_factors"],
dropout=config["visual_encoder"]["dropout"],
attn_blocks=config["visual_encoder"]["attn_blocks"],
gradient_checkpointing=gradient_checkpointing,
)
self.eval()
def forward(self, image_sequences, audio_sequences):
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
# Make them unit vectors
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
return vision_embeds, audio_embeds
class ResnetBlock2D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
dropout: float = 0.0,
norm_num_groups: int = 32,
eps: float = 1e-6,
act_fn: str = "silu",
downsample_factor=2,
):
super().__init__()
self.norm1 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=eps, affine=True)
self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
self.norm2 = nn.GroupNorm(num_groups=norm_num_groups, num_channels=out_channels, eps=eps, affine=True)
self.dropout = nn.Dropout(dropout)
self.conv2 = nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
if act_fn == "relu":
self.act_fn = nn.ReLU()
elif act_fn == "silu":
self.act_fn = nn.SiLU()
if in_channels != out_channels:
self.conv_shortcut = nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
else:
self.conv_shortcut = None
if isinstance(downsample_factor, list):
downsample_factor = tuple(downsample_factor)
if downsample_factor == 1:
self.downsample_conv = None
else:
self.downsample_conv = nn.Conv2d(
out_channels, out_channels, kernel_size=3, stride=downsample_factor, padding=0
)
self.pad = (0, 1, 0, 1)
if isinstance(downsample_factor, tuple):
if downsample_factor[0] == 1:
self.pad = (0, 1, 1, 1) # F.pad lists padding starting from the last dimension: (left, right, top, bottom)
elif downsample_factor[1] == 1:
self.pad = (1, 1, 0, 1)
def forward(self, input_tensor):
hidden_states = input_tensor
hidden_states = self.norm1(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.conv1(hidden_states)
hidden_states = self.norm2(hidden_states)
hidden_states = self.act_fn(hidden_states)
hidden_states = self.dropout(hidden_states)
hidden_states = self.conv2(hidden_states)
if self.conv_shortcut is not None:
input_tensor = self.conv_shortcut(input_tensor)
hidden_states += input_tensor
if self.downsample_conv is not None:
hidden_states = F.pad(hidden_states, self.pad, mode="constant", value=0)
hidden_states = self.downsample_conv(hidden_states)
return hidden_states
class AttentionBlock2D(nn.Module):
def __init__(self, query_dim, norm_num_groups=32, dropout=0.0):
super().__init__()
self.norm1 = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=query_dim, eps=1e-6, affine=True)
self.norm2 = nn.LayerNorm(query_dim)
self.norm3 = nn.LayerNorm(query_dim)
self.ff = FeedForward(query_dim, dropout=dropout, activation_fn="geglu")
self.conv_in = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
self.conv_out = nn.Conv2d(query_dim, query_dim, kernel_size=1, stride=1, padding=0)
self.attn = Attention(query_dim=query_dim, heads=8, dim_head=query_dim // 8, dropout=dropout, bias=True)
def forward(self, hidden_states):
assert hidden_states.dim() == 4, f"Expected hidden_states to have ndim=4, but got ndim={hidden_states.dim()}."
batch, channel, height, width = hidden_states.shape
residual = hidden_states
hidden_states = self.norm1(hidden_states)
hidden_states = self.conv_in(hidden_states)
hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
norm_hidden_states = self.norm2(hidden_states)
hidden_states = self.attn(norm_hidden_states, attention_mask=None) + hidden_states
hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=height, w=width).contiguous()
hidden_states = self.conv_out(hidden_states)
hidden_states = hidden_states + residual
return hidden_states
class DownEncoder2D(nn.Module):
def __init__(
self,
in_channels=4 * 16,
block_out_channels=[64, 128, 256, 256],
downsample_factors=[2, 2, 2, 2],
layers_per_block=2,
norm_num_groups=32,
attn_blocks=[1, 1, 1, 1],
dropout: float = 0.0,
act_fn="silu",
gradient_checkpointing=False,
):
super().__init__()
self.layers_per_block = layers_per_block
self.gradient_checkpointing = gradient_checkpointing
# in
self.conv_in = nn.Conv2d(in_channels, block_out_channels[0], kernel_size=3, stride=1, padding=1)
# down
self.down_blocks = nn.ModuleList([])
output_channels = block_out_channels[0]
for i, block_out_channel in enumerate(block_out_channels):
input_channels = output_channels
output_channels = block_out_channel
down_block = ResnetBlock2D(
in_channels=input_channels,
out_channels=output_channels,
downsample_factor=downsample_factors[i],
norm_num_groups=norm_num_groups,
dropout=dropout,
act_fn=act_fn,
)
self.down_blocks.append(down_block)
if attn_blocks[i] == 1:
attention_block = AttentionBlock2D(query_dim=output_channels, dropout=dropout)
self.down_blocks.append(attention_block)
# out
self.norm_out = nn.GroupNorm(num_channels=block_out_channels[-1], num_groups=norm_num_groups, eps=1e-6)
self.act_fn_out = nn.ReLU()
def forward(self, hidden_states):
hidden_states = self.conv_in(hidden_states)
# down
for down_block in self.down_blocks:
if self.gradient_checkpointing:
hidden_states = torch.utils.checkpoint.checkpoint(down_block, hidden_states, use_reentrant=False)
else:
hidden_states = down_block(hidden_states)
# post-process
hidden_states = self.norm_out(hidden_states)
hidden_states = self.act_fn_out(hidden_states)
return hidden_states
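
`StableSyncNet.forward` above returns L2-normalized visual and audio embeddings, so a dot product between the two is their cosine similarity. How the caller scores or trains on these embeddings is not shown in this hunk; the following is a minimal sketch assuming a cosine-similarity comparison. Both helper names are hypothetical, and the `eps` clamp follows the documented behaviour of `F.normalize`.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit length, clamping the norm to avoid division by zero."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def cosine_similarity(a, b):
    """Dot product of the two unit vectors; the result lies in [-1, 1]."""
    return sum(x * y for x, y in zip(l2_normalize(a), l2_normalize(b)))


in_sync = cosine_similarity([3.0, 4.0], [6.0, 8.0])   # same direction
unrelated = cosine_similarity([1.0, 0.0], [0.0, 1.0])  # orthogonal
```

Normalizing inside the model means downstream code can compare embeddings without caring about their magnitude.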

@@ -0,0 +1,512 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet.py
from dataclasses import dataclass
from typing import List, Optional, Tuple, Union
import copy
import torch
import torch.nn as nn
import torch.utils.checkpoint
from diffusers.configuration_utils import ConfigMixin, register_to_config
from diffusers.models import ModelMixin
from diffusers.utils import BaseOutput, logging
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from .unet_blocks import (
CrossAttnDownBlock3D,
CrossAttnUpBlock3D,
DownBlock3D,
UNetMidBlock3DCrossAttn,
UpBlock3D,
get_down_block,
get_up_block,
)
from .resnet import InflatedConv3d, InflatedGroupNorm
from ..utils.util import zero_rank_log
from .utils import zero_module
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@dataclass
class UNet3DConditionOutput(BaseOutput):
sample: torch.FloatTensor
class UNet3DConditionModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True
@register_to_config
def __init__(
self,
sample_size: Optional[int] = None,
in_channels: int = 4,
out_channels: int = 4,
center_input_sample: bool = False,
flip_sin_to_cos: bool = True,
freq_shift: int = 0,
down_block_types: Tuple[str] = (
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"CrossAttnDownBlock3D",
"DownBlock3D",
),
mid_block_type: str = "UNetMidBlock3DCrossAttn",
up_block_types: Tuple[str] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D"),
only_cross_attention: Union[bool, Tuple[bool]] = False,
block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
layers_per_block: int = 2,
downsample_padding: int = 1,
mid_block_scale_factor: float = 1,
act_fn: str = "silu",
norm_num_groups: int = 32,
norm_eps: float = 1e-5,
cross_attention_dim: int = 1280,
attention_head_dim: Union[int, Tuple[int]] = 8,
dual_cross_attention: bool = False,
use_linear_projection: bool = False,
class_embed_type: Optional[str] = None,
num_class_embeds: Optional[int] = None,
upcast_attention: bool = False,
resnet_time_scale_shift: str = "default",
use_inflated_groupnorm=False,
# Additional
use_motion_module=False,
motion_module_resolutions=(1, 2, 4, 8),
motion_module_mid_block=False,
motion_module_decoder_only=False,
motion_module_type=None,
motion_module_kwargs={},
add_audio_layer=False,
):
super().__init__()
self.sample_size = sample_size
time_embed_dim = block_out_channels[0] * 4
self.use_motion_module = use_motion_module
self.add_audio_layer = add_audio_layer
self.conv_in = zero_module(InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)))
# time
self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
timestep_input_dim = block_out_channels[0]
self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
# class embedding
if class_embed_type is None and num_class_embeds is not None:
self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
elif class_embed_type == "timestep":
self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
elif class_embed_type == "identity":
self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
else:
self.class_embedding = None
self.down_blocks = nn.ModuleList([])
self.mid_block = None
self.up_blocks = nn.ModuleList([])
if isinstance(only_cross_attention, bool):
only_cross_attention = [only_cross_attention] * len(down_block_types)
if isinstance(attention_head_dim, int):
attention_head_dim = (attention_head_dim,) * len(down_block_types)
# down
output_channel = block_out_channels[0]
for i, down_block_type in enumerate(down_block_types):
res = 2**i
input_channel = output_channel
output_channel = block_out_channels[i]
is_final_block = i == len(block_out_channels) - 1
down_block = get_down_block(
down_block_type,
num_layers=layers_per_block,
in_channels=input_channel,
out_channels=output_channel,
temb_channels=time_embed_dim,
add_downsample=not is_final_block,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[i],
downsample_padding=downsample_padding,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module
and (res in motion_module_resolutions)
and (not motion_module_decoder_only),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
)
self.down_blocks.append(down_block)
# mid
if mid_block_type == "UNetMidBlock3DCrossAttn":
self.mid_block = UNetMidBlock3DCrossAttn(
in_channels=block_out_channels[-1],
temb_channels=time_embed_dim,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
output_scale_factor=mid_block_scale_factor,
resnet_time_scale_shift=resnet_time_scale_shift,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attention_head_dim[-1],
resnet_groups=norm_num_groups,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and motion_module_mid_block,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
)
else:
raise ValueError(f"unknown mid_block_type : {mid_block_type}")
# count how many layers upsample the videos
self.num_upsamplers = 0
# up
reversed_block_out_channels = list(reversed(block_out_channels))
reversed_attention_head_dim = list(reversed(attention_head_dim))
only_cross_attention = list(reversed(only_cross_attention))
output_channel = reversed_block_out_channels[0]
for i, up_block_type in enumerate(up_block_types):
res = 2 ** (3 - i)
is_final_block = i == len(block_out_channels) - 1
prev_output_channel = output_channel
output_channel = reversed_block_out_channels[i]
input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
# add upsample block for all BUT final layer
if not is_final_block:
add_upsample = True
self.num_upsamplers += 1
else:
add_upsample = False
up_block = get_up_block(
up_block_type,
num_layers=layers_per_block + 1,
in_channels=input_channel,
out_channels=output_channel,
prev_output_channel=prev_output_channel,
temb_channels=time_embed_dim,
add_upsample=add_upsample,
resnet_eps=norm_eps,
resnet_act_fn=act_fn,
resnet_groups=norm_num_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=reversed_attention_head_dim[i],
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention[i],
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module and (res in motion_module_resolutions),
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
)
self.up_blocks.append(up_block)
prev_output_channel = output_channel
# out
if use_inflated_groupnorm:
self.conv_norm_out = InflatedGroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
else:
self.conv_norm_out = nn.GroupNorm(
num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps
)
self.conv_act = nn.SiLU()
self.conv_out = zero_module(InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1))
def set_attention_slice(self, slice_size):
r"""
Enable sliced attention computation.
When this option is enabled, the attention module will split the input tensor in slices, to compute attention
in several steps. This is useful to save some memory in exchange for a small speed decrease.
Args:
slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
`"max"`, the maximum amount of memory will be saved by running only one slice at a time. If a number is
provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
must be a multiple of `slice_size`.
"""
sliceable_head_dims = []
def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
if hasattr(module, "set_attention_slice"):
sliceable_head_dims.append(module.sliceable_head_dim)
for child in module.children():
fn_recursive_retrieve_slicable_dims(child)
# retrieve number of attention layers
for module in self.children():
fn_recursive_retrieve_slicable_dims(module)
num_slicable_layers = len(sliceable_head_dims)
if slice_size == "auto":
# half the attention head size is usually a good trade-off between
# speed and memory
slice_size = [dim // 2 for dim in sliceable_head_dims]
elif slice_size == "max":
# make smallest slice possible
slice_size = num_slicable_layers * [1]
slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
if len(slice_size) != len(sliceable_head_dims):
raise ValueError(
f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
)
for i in range(len(slice_size)):
size = slice_size[i]
dim = sliceable_head_dims[i]
if size is not None and size > dim:
raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
# Recursively walk through all the children.
# Any children which exposes the set_attention_slice method
# gets the message
def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
if hasattr(module, "set_attention_slice"):
module.set_attention_slice(slice_size.pop())
for child in module.children():
fn_recursive_set_attention_slice(child, slice_size)
reversed_slice_size = list(reversed(slice_size))
for module in self.children():
fn_recursive_set_attention_slice(module, reversed_slice_size)
def _set_gradient_checkpointing(self, module, value=False):
if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
module.gradient_checkpointing = value
def forward(
self,
sample: torch.FloatTensor,
timestep: Union[torch.Tensor, float, int],
encoder_hidden_states: Optional[torch.Tensor] = None,
class_labels: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
# support controlnet
down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
mid_block_additional_residual: Optional[torch.Tensor] = None,
return_dict: bool = True,
) -> Union[UNet3DConditionOutput, Tuple]:
r"""
Args:
sample (`torch.FloatTensor`): (batch, channel, frames, height, width) noisy inputs tensor
timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
return_dict (`bool`, *optional*, defaults to `True`):
Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
Returns:
[`UNet3DConditionOutput`] or `tuple`:
[`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
returning a tuple, the first element is the sample tensor.
"""
# By default, the spatial size of a sample has to be a multiple of the overall upsampling factor.
# The overall upsampling factor is equal to 2 ** (number of upsampling layers).
# However, the upsampling interpolation output size can be forced to fit any upsampling size
# on the fly if necessary.
default_overall_up_factor = 2**self.num_upsamplers
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
forward_upsample_size = False
upsample_size = None
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
logger.info("Forward upsample size to force interpolation output size.")
forward_upsample_size = True
# prepare attention_mask
if attention_mask is not None:
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
attention_mask = attention_mask.unsqueeze(1)
# center input if necessary
if self.config.center_input_sample:
sample = 2 * sample - 1.0
# time
timesteps = timestep
if not torch.is_tensor(timesteps):
# This would be a good case for the `match` statement (Python 3.10+)
is_mps = sample.device.type == "mps"
if isinstance(timestep, float):
dtype = torch.float32 if is_mps else torch.float64
else:
dtype = torch.int32 if is_mps else torch.int64
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
elif len(timesteps.shape) == 0:
timesteps = timesteps[None].to(sample.device)
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
timesteps = timesteps.expand(sample.shape[0])
t_emb = self.time_proj(timesteps)
# `time_proj` does not contain any weights and always returns f32 tensors,
# but `time_embedding` might actually be running in fp16, so we need to cast here.
# There might be better ways to encapsulate this.
t_emb = t_emb.to(dtype=self.dtype)
emb = self.time_embedding(t_emb)
if self.class_embedding is not None:
if class_labels is None:
raise ValueError("class_labels should be provided when num_class_embeds > 0")
if self.config.class_embed_type == "timestep":
class_labels = self.time_proj(class_labels)
class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
emb = emb + class_emb
# pre-process
sample = self.conv_in(sample)
# down
down_block_res_samples = (sample,)
for downsample_block in self.down_blocks:
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
sample, res_samples = downsample_block(
hidden_states=sample,
temb=emb,
encoder_hidden_states=encoder_hidden_states,
attention_mask=attention_mask,
)
else:
sample, res_samples = downsample_block(
hidden_states=sample, temb=emb, encoder_hidden_states=encoder_hidden_states
)
down_block_res_samples += res_samples
# support controlnet
down_block_res_samples = list(down_block_res_samples)
if down_block_additional_residuals is not None:
for i, down_block_additional_residual in enumerate(down_block_additional_residuals):
if down_block_additional_residual.dim() == 4:  # broadcast over the frame dimension
down_block_additional_residual = down_block_additional_residual.unsqueeze(2)
down_block_res_samples[i] = down_block_res_samples[i] + down_block_additional_residual
# mid
sample = self.mid_block(
sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
)
# support controlnet
if mid_block_additional_residual is not None:
if mid_block_additional_residual.dim() == 4:  # broadcast over the frame dimension
mid_block_additional_residual = mid_block_additional_residual.unsqueeze(2)
sample = sample + mid_block_additional_residual
# up
for i, upsample_block in enumerate(self.up_blocks):
is_final_block = i == len(self.up_blocks) - 1
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
# if we have not reached the final block and need to forward the
# upsample size, we do it here
if not is_final_block and forward_upsample_size:
upsample_size = down_block_res_samples[-1].shape[2:]
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
encoder_hidden_states=encoder_hidden_states,
upsample_size=upsample_size,
attention_mask=attention_mask,
)
else:
sample = upsample_block(
hidden_states=sample,
temb=emb,
res_hidden_states_tuple=res_samples,
upsample_size=upsample_size,
encoder_hidden_states=encoder_hidden_states,
)
# post-process
sample = self.conv_norm_out(sample)
sample = self.conv_act(sample)
sample = self.conv_out(sample)
if not return_dict:
return (sample,)
return UNet3DConditionOutput(sample=sample)
def load_state_dict(self, state_dict, strict=True):
# If the loaded checkpoint's in_channels or out_channels are different from config
if state_dict["conv_in.weight"].shape[1] != self.config.in_channels:
del state_dict["conv_in.weight"]
del state_dict["conv_in.bias"]
if state_dict["conv_out.weight"].shape[0] != self.config.out_channels:
del state_dict["conv_out.weight"]
del state_dict["conv_out.bias"]
# If the loaded checkpoint's cross_attention_dim is different from config
keys_to_remove = []
for key in state_dict:
if "attn2.to_k." in key or "attn2.to_v." in key:
if state_dict[key].shape[1] != self.config.cross_attention_dim:
keys_to_remove.append(key)
for key in keys_to_remove:
del state_dict[key]
return super().load_state_dict(state_dict=state_dict, strict=strict)
@classmethod
def from_pretrained(cls, model_config: dict, ckpt_path: str, device="cpu"):
unet = cls.from_config(model_config).to(device)
if ckpt_path != "":
zero_rank_log(logger, f"Load from checkpoint: {ckpt_path}")
ckpt = torch.load(ckpt_path, map_location=device, weights_only=True)
if "global_step" in ckpt:
zero_rank_log(logger, f"resume from global_step: {ckpt['global_step']}")
resume_global_step = ckpt["global_step"]
else:
resume_global_step = 0
unet.load_state_dict(ckpt["state_dict"], strict=False)
del ckpt
torch.cuda.empty_cache()
else:
resume_global_step = 0
return unet, resume_global_step
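The slice-size handling in `set_attention_slice` above can be sketched without PyTorch. The helper below is an illustrative assumption (the name `resolve_slice_sizes` is not part of this file); it mirrors only the normalization and validation steps, and it takes the list of sliceable head dimensions as a plain argument instead of collecting it from child modules.

```python
from typing import List, Union


def resolve_slice_sizes(
    sliceable_head_dims: List[int],
    slice_size: Union[str, int, List[int]] = "auto",
) -> List[int]:
    """Return one slice size per sliceable attention layer (hypothetical helper)."""
    num_layers = len(sliceable_head_dims)
    if slice_size == "auto":
        # Half of each head dimension: attention runs in two steps.
        sizes = [dim // 2 for dim in sliceable_head_dims]
    elif slice_size == "max":
        # Smallest possible slice: one slice at a time, lowest memory use.
        sizes = num_layers * [1]
    elif isinstance(slice_size, list):
        sizes = list(slice_size)
    else:
        # A single number is applied to every layer.
        sizes = num_layers * [slice_size]

    if len(sizes) != num_layers:
        raise ValueError(
            f"Got {len(sizes)} slice sizes for {num_layers} attention layers."
        )
    for size, dim in zip(sizes, sliceable_head_dims):
        if size is not None and size > dim:
            raise ValueError(f"size {size} has to be smaller than or equal to {dim}.")
    return sizes
```

For example, with hypothetical head dimensions `[8, 16, 32]`, `"auto"` resolves to `[4, 8, 16]` and `"max"` resolves to `[1, 1, 1]`. The method in the file then reverses this list and hands out one entry per attention module via `pop()`.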

View File

@@ -0,0 +1,777 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/unet_blocks.py
import torch
from torch import nn
from .attention import Transformer3DModel
from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
from .motion_module import get_motion_module
def get_down_block(
down_block_type,
num_layers,
in_channels,
out_channels,
temb_channels,
add_downsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
downsample_padding=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
):
down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
if down_block_type == "DownBlock3D":
return DownBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
elif down_block_type == "CrossAttnDownBlock3D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
return CrossAttnDownBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
add_downsample=add_downsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
downsample_padding=downsample_padding,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
)
raise ValueError(f"{down_block_type} does not exist.")
def get_up_block(
up_block_type,
num_layers,
in_channels,
out_channels,
prev_output_channel,
temb_channels,
add_upsample,
resnet_eps,
resnet_act_fn,
attn_num_head_channels,
resnet_groups=None,
cross_attention_dim=None,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
resnet_time_scale_shift="default",
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
):
up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
if up_block_type == "UpBlock3D":
return UpBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
elif up_block_type == "CrossAttnUpBlock3D":
if cross_attention_dim is None:
raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
return CrossAttnUpBlock3D(
num_layers=num_layers,
in_channels=in_channels,
out_channels=out_channels,
prev_output_channel=prev_output_channel,
temb_channels=temb_channels,
add_upsample=add_upsample,
resnet_eps=resnet_eps,
resnet_act_fn=resnet_act_fn,
resnet_groups=resnet_groups,
cross_attention_dim=cross_attention_dim,
attn_num_head_channels=attn_num_head_channels,
dual_cross_attention=dual_cross_attention,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
resnet_time_scale_shift=resnet_time_scale_shift,
use_inflated_groupnorm=use_inflated_groupnorm,
use_motion_module=use_motion_module,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
add_audio_layer=add_audio_layer,
)
raise ValueError(f"{up_block_type} does not exist.")
class UNetMidBlock3DCrossAttn(nn.Module):
def __init__(
self,
in_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
output_scale_factor=1.0,
cross_attention_dim=1280,
dual_cross_attention=False,
use_linear_projection=False,
upcast_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
):
super().__init__()
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
# there is always at least one resnet
resnets = [
ResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
]
attentions = []
motion_modules = []
for _ in range(num_layers):
if dual_cross_attention:
raise NotImplementedError
attentions.append(
Transformer3DModel(
attn_num_head_channels,
in_channels // attn_num_head_channels,
in_channels=in_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
upcast_attention=upcast_attention,
add_audio_layer=add_audio_layer,
)
)
motion_modules.append(
get_motion_module(
in_channels=in_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=in_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
hidden_states = self.resnets[0](hidden_states, temb)
for attn, resnet, motion_module in zip(self.attentions, self.resnets[1:], self.motion_modules):
hidden_states = attn(
hidden_states,
encoder_hidden_states=encoder_hidden_states,
return_dict=False,
)[0]
if motion_module is not None:
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
hidden_states = resnet(hidden_states, temb)
return hidden_states
class CrossAttnDownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
downsample_padding=1,
add_downsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
):
super().__init__()
resnets = []
attentions = []
motion_modules = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
raise NotImplementedError
attentions.append(
Transformer3DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
add_audio_layer=add_audio_layer,
)
)
motion_modules.append(
get_motion_module(
in_channels=out_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample3D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
output_states = ()
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
use_reentrant=False,
)[0]
if motion_module is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
encoder_hidden_states,
use_reentrant=False,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if motion_module is not None:
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class DownBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_downsample=True,
downsample_padding=1,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
):
super().__init__()
resnets = []
motion_modules = []
for i in range(num_layers):
in_channels = in_channels if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=in_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
get_motion_module(
in_channels=out_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_downsample:
self.downsamplers = nn.ModuleList(
[
Downsample3D(
out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
)
]
)
else:
self.downsamplers = None
self.gradient_checkpointing = False
def forward(self, hidden_states, temb=None, encoder_hidden_states=None):
output_states = ()
for resnet, motion_module in zip(self.resnets, self.motion_modules):
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
if motion_module is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
encoder_hidden_states,
use_reentrant=False,
)
else:
hidden_states = resnet(hidden_states, temb)
if motion_module is not None:
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
output_states += (hidden_states,)
if self.downsamplers is not None:
for downsampler in self.downsamplers:
hidden_states = downsampler(hidden_states)
output_states += (hidden_states,)
return hidden_states, output_states
class CrossAttnUpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
out_channels: int,
prev_output_channel: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
attn_num_head_channels=1,
cross_attention_dim=1280,
output_scale_factor=1.0,
add_upsample=True,
dual_cross_attention=False,
use_linear_projection=False,
only_cross_attention=False,
upcast_attention=False,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
add_audio_layer=False,
):
super().__init__()
resnets = []
attentions = []
motion_modules = []
self.has_cross_attention = True
self.attn_num_head_channels = attn_num_head_channels
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
if dual_cross_attention:
raise NotImplementedError
attentions.append(
Transformer3DModel(
attn_num_head_channels,
out_channels // attn_num_head_channels,
in_channels=out_channels,
num_layers=1,
cross_attention_dim=cross_attention_dim,
norm_num_groups=resnet_groups,
use_linear_projection=use_linear_projection,
only_cross_attention=only_cross_attention,
upcast_attention=upcast_attention,
add_audio_layer=add_audio_layer,
)
)
motion_modules.append(
get_motion_module(
in_channels=out_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
self.attentions = nn.ModuleList(attentions)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
encoder_hidden_states=None,
upsample_size=None,
attention_mask=None,
):
for resnet, attn, motion_module in zip(self.resnets, self.attentions, self.motion_modules):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module, return_dict=None):
def custom_forward(*inputs):
if return_dict is not None:
return module(*inputs, return_dict=return_dict)
else:
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(attn, return_dict=False),
hidden_states,
encoder_hidden_states,
use_reentrant=False,
)[0]
if motion_module is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
encoder_hidden_states,
use_reentrant=False,
)
else:
hidden_states = resnet(hidden_states, temb)
hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
if motion_module is not None:
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
class UpBlock3D(nn.Module):
def __init__(
self,
in_channels: int,
prev_output_channel: int,
out_channels: int,
temb_channels: int,
dropout: float = 0.0,
num_layers: int = 1,
resnet_eps: float = 1e-6,
resnet_time_scale_shift: str = "default",
resnet_act_fn: str = "swish",
resnet_groups: int = 32,
resnet_pre_norm: bool = True,
output_scale_factor=1.0,
add_upsample=True,
use_inflated_groupnorm=False,
use_motion_module=None,
motion_module_type=None,
motion_module_kwargs=None,
):
super().__init__()
resnets = []
motion_modules = []
for i in range(num_layers):
res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
resnet_in_channels = prev_output_channel if i == 0 else out_channels
resnets.append(
ResnetBlock3D(
in_channels=resnet_in_channels + res_skip_channels,
out_channels=out_channels,
temb_channels=temb_channels,
eps=resnet_eps,
groups=resnet_groups,
dropout=dropout,
time_embedding_norm=resnet_time_scale_shift,
non_linearity=resnet_act_fn,
output_scale_factor=output_scale_factor,
pre_norm=resnet_pre_norm,
use_inflated_groupnorm=use_inflated_groupnorm,
)
)
motion_modules.append(
get_motion_module(
in_channels=out_channels,
motion_module_type=motion_module_type,
motion_module_kwargs=motion_module_kwargs,
)
if use_motion_module
else None
)
self.resnets = nn.ModuleList(resnets)
self.motion_modules = nn.ModuleList(motion_modules)
if add_upsample:
self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
else:
self.upsamplers = None
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
res_hidden_states_tuple,
temb=None,
upsample_size=None,
encoder_hidden_states=None,
):
for resnet, motion_module in zip(self.resnets, self.motion_modules):
# pop res hidden states
res_hidden_states = res_hidden_states_tuple[-1]
res_hidden_states_tuple = res_hidden_states_tuple[:-1]
hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
if torch.is_grad_enabled() and self.gradient_checkpointing:
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs)
return custom_forward
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(resnet), hidden_states, temb, use_reentrant=False
)
if motion_module is not None:
hidden_states = torch.utils.checkpoint.checkpoint(
create_custom_forward(motion_module),
hidden_states,
temb,
encoder_hidden_states,
use_reentrant=False,
)
else:
hidden_states = resnet(hidden_states, temb)
if motion_module is not None:
hidden_states = motion_module(hidden_states, temb, encoder_hidden_states=encoder_hidden_states)
if self.upsamplers is not None:
for upsampler in self.upsamplers:
hidden_states = upsampler(hidden_states, upsample_size)
return hidden_states
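The channel bookkeeping shared by `CrossAttnUpBlock3D` and `UpBlock3D` is easier to follow with concrete numbers. The sketch below is an assumption for illustration (the function name `up_block_resnet_channels` and the example channel counts are not from this file); it reproduces only the arithmetic in the two constructors, where each resnet receives the previous hidden state concatenated with one skip connection along the channel axis.

```python
from typing import List, Tuple


def up_block_resnet_channels(
    in_channels: int,
    prev_output_channel: int,
    out_channels: int,
    num_layers: int,
) -> List[Tuple[int, int]]:
    """Return (input_channels, output_channels) for each resnet in an up block."""
    plan = []
    for i in range(num_layers):
        # The last resnet consumes the skip tensor from the block's own input level;
        # earlier resnets consume skips that already have `out_channels` channels.
        res_skip_channels = in_channels if i == num_layers - 1 else out_channels
        # The first resnet sees the output of the previous (deeper) block.
        resnet_in_channels = prev_output_channel if i == 0 else out_channels
        # torch.cat(..., dim=1) in forward() adds the two channel counts.
        plan.append((resnet_in_channels + res_skip_channels, out_channels))
    return plan


if __name__ == "__main__":
    # Hypothetical configuration chosen only to show the pattern.
    for layer in up_block_resnet_channels(320, 640, 640, 3):
        print(layer)
```

With these example values the three resnets take 1280, 1280, and 960 input channels, which is why the skip tensors must be popped in reverse order in `forward`.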

View File

@@ -0,0 +1,19 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module

View File

@@ -0,0 +1,90 @@
# Adapted from https://github.com/primepake/wav2lip_288x288/blob/master/models/syncnetv2.py
# The code here is for ablation study.
from torch import nn
from torch.nn import functional as F
class Wav2LipSyncNet(nn.Module):
def __init__(self, act_fn="leaky"):
super().__init__()
# input image sequences: (15, 128, 256)
self.visual_encoder = nn.Sequential(
Conv2d(15, 32, kernel_size=(7, 7), stride=1, padding=3, act_fn=act_fn), # (128, 256)
Conv2d(32, 64, kernel_size=5, stride=(1, 2), padding=1, act_fn=act_fn), # (126, 127)
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 128, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (63, 64)
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 256, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (21, 22)
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 512, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (11, 11)
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 1024, kernel_size=3, stride=2, padding=1, act_fn=act_fn), # (6, 6)
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(1024, 1024, kernel_size=3, stride=2, padding=1, act_fn="relu"), # (3, 3)
Conv2d(1024, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
)
# input audio sequences: (1, 80, 16)
self.audio_encoder = nn.Sequential(
Conv2d(1, 32, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1, act_fn=act_fn), # (27, 16)
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(64, 128, kernel_size=3, stride=3, padding=1, act_fn=act_fn), # (9, 6)
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1, act_fn=act_fn), # (3, 3)
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(256, 512, kernel_size=3, stride=1, padding=1, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 512, kernel_size=3, stride=1, padding=1, residual=True, act_fn=act_fn),
Conv2d(512, 1024, kernel_size=3, stride=1, padding=0, act_fn="relu"), # (1, 1)
Conv2d(1024, 1024, kernel_size=1, stride=1, padding=0, act_fn="relu"),
)
def forward(self, image_sequences, audio_sequences):
vision_embeds = self.visual_encoder(image_sequences) # (b, c, 1, 1)
audio_embeds = self.audio_encoder(audio_sequences) # (b, c, 1, 1)
vision_embeds = vision_embeds.reshape(vision_embeds.shape[0], -1) # (b, c)
audio_embeds = audio_embeds.reshape(audio_embeds.shape[0], -1) # (b, c)
# Make them unit vectors
vision_embeds = F.normalize(vision_embeds, p=2, dim=1)
audio_embeds = F.normalize(audio_embeds, p=2, dim=1)
return vision_embeds, audio_embeds
class Conv2d(nn.Module):
    def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, act_fn="relu", *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.conv_block = nn.Sequential(nn.Conv2d(cin, cout, kernel_size, stride, padding), nn.BatchNorm2d(cout))
        if act_fn == "relu":
            self.act_fn = nn.ReLU()
        elif act_fn == "tanh":
            self.act_fn = nn.Tanh()
        elif act_fn == "silu":
            self.act_fn = nn.SiLU()
        elif act_fn == "leaky":
            self.act_fn = nn.LeakyReLU(0.2, inplace=True)
        else:
            # Fail early instead of hitting an AttributeError in forward()
            raise ValueError(f"Unsupported act_fn: {act_fn!r}")
        self.residual = residual

    def forward(self, x):
        out = self.conv_block(x)
        if self.residual:
            # The skip connection requires cin == cout and a shape-preserving convolution
            out += x
        return self.act_fn(out)
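The shape comments in the audio encoder above, `(27, 16)`, `(9, 6)`, `(3, 3)` and `(1, 1)`, follow from the standard convolution output-size formula. The sketch below is not part of the repository; `conv_out`, `trace` and `AUDIO_DOWNSAMPLING` are hypothetical names used only to re-derive those comments for the `(80, 16)` mel input. Only the layers that change the spatial size are listed, since the stride-1, padding-1, 3x3 layers keep it unchanged.

```python
def conv_out(size, kernel, stride, padding):
    """Output length of a convolution along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def trace(shape, layers):
    """Apply each (kernel, (stride_h, stride_w), padding) layer and record the resulting (h, w)."""
    h, w = shape
    shapes = []
    for kernel, (stride_h, stride_w), padding in layers:
        h = conv_out(h, kernel, stride_h, padding)
        w = conv_out(w, kernel, stride_w, padding)
        shapes.append((h, w))
    return shapes


# Size-changing layers of the audio encoder, in order (hypothetical summary of the listing above)
AUDIO_DOWNSAMPLING = [
    (3, (3, 1), 1),  # Conv2d(32, 64, stride=(3, 1), padding=1)
    (3, (3, 3), 1),  # Conv2d(64, 128, stride=3, padding=1)
    (3, (3, 2), 1),  # Conv2d(128, 256, stride=(3, 2), padding=1)
    (3, (1, 1), 0),  # Conv2d(512, 1024, stride=1, padding=0)
]

print(trace((80, 16), AUDIO_DOWNSAMPLING))  # [(27, 16), (9, 6), (3, 3), (1, 1)]
```

The same helper can be pointed at the visual encoder's strided layers to check its `(11, 11)`, `(6, 6)`, `(3, 3)` and `(1, 1)` comments.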


@@ -0,0 +1,477 @@
# Adapted from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/pipelines/pipeline_animation.py
import inspect
import math
import os
import shutil
from typing import Callable, List, Optional, Union
import subprocess
import numpy as np
import torch
import torchvision
from torchvision import transforms
from packaging import version
from diffusers.configuration_utils import FrozenDict
from diffusers.models import AutoencoderKL
from diffusers.pipelines import DiffusionPipeline
from diffusers.schedulers import (
DDIMScheduler,
DPMSolverMultistepScheduler,
EulerAncestralDiscreteScheduler,
EulerDiscreteScheduler,
LMSDiscreteScheduler,
PNDMScheduler,
)
from diffusers.utils import deprecate, logging
from einops import rearrange
import cv2
from ..models.unet import UNet3DConditionModel
from ..utils.util import read_video, read_audio, write_video, check_ffmpeg_installed
from ..utils.image_processor import ImageProcessor, load_fixed_mask
from ..whisper.audio2feature import Audio2Feature
import tqdm
import soundfile as sf
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
class LipsyncPipeline(DiffusionPipeline):
_optional_components = []
def __init__(
self,
vae: AutoencoderKL,
audio_encoder: Audio2Feature,
unet: UNet3DConditionModel,
scheduler: Union[
DDIMScheduler,
PNDMScheduler,
LMSDiscreteScheduler,
EulerDiscreteScheduler,
EulerAncestralDiscreteScheduler,
DPMSolverMultistepScheduler,
],
):
super().__init__()
if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
"to update the config accordingly as leaving `steps_offset` might lead to incorrect results"
" in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
" it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
" file"
)
deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["steps_offset"] = 1
scheduler._internal_dict = FrozenDict(new_config)
if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
deprecation_message = (
f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
" `clip_sample` should be set to False in the configuration file. Please make sure to update the"
" config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
" future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
" nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
)
deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(scheduler.config)
new_config["clip_sample"] = False
scheduler._internal_dict = FrozenDict(new_config)
is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
version.parse(unet.config._diffusers_version).base_version
) < version.parse("0.9.0.dev0")
is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
deprecation_message = (
"The configuration file of the unet has set the default `sample_size` to smaller than"
" 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
" following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
" CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
" \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
" configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
" in the config might lead to incorrect results in future versions. If you have downloaded this"
" checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
" the `unet/config.json` file"
)
deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
new_config = dict(unet.config)
new_config["sample_size"] = 64
unet._internal_dict = FrozenDict(new_config)
self.register_modules(
vae=vae,
audio_encoder=audio_encoder,
unet=unet,
scheduler=scheduler,
)
self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
self.set_progress_bar_config(desc="Steps")
def enable_vae_slicing(self):
self.vae.enable_slicing()
def disable_vae_slicing(self):
self.vae.disable_slicing()
@property
def _execution_device(self):
if self.device != torch.device("meta") or not hasattr(self.unet, "_hf_hook"):
return self.device
for module in self.unet.modules():
if (
hasattr(module, "_hf_hook")
and hasattr(module._hf_hook, "execution_device")
and module._hf_hook.execution_device is not None
):
return torch.device(module._hf_hook.execution_device)
return self.device
def decode_latents(self, latents):
latents = latents / self.vae.config.scaling_factor + self.vae.config.shift_factor
latents = rearrange(latents, "b c f h w -> (b f) c h w")
decoded_latents = self.vae.decode(latents).sample
return decoded_latents
def prepare_extra_step_kwargs(self, generator, eta):
# prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
# eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
# eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
# and should be between [0, 1]
accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
extra_step_kwargs = {}
if accepts_eta:
extra_step_kwargs["eta"] = eta
# check if the scheduler accepts generator
accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
if accepts_generator:
extra_step_kwargs["generator"] = generator
return extra_step_kwargs
def check_inputs(self, height, width, callback_steps):
assert height == width, "Height and width must be equal"
if height % 8 != 0 or width % 8 != 0:
raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
if (callback_steps is None) or (
callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
):
raise ValueError(
f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
f" {type(callback_steps)}."
)
def prepare_latents(self, num_frames, num_channels_latents, height, width, dtype, device, generator):
shape = (
1,
num_channels_latents,
1,
height // self.vae_scale_factor,
width // self.vae_scale_factor,
) # (b, c, f, h, w)
rand_device = "cpu" if device.type == "mps" else device
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
latents = latents.repeat(1, 1, num_frames, 1, 1)
# scale the initial noise by the standard deviation required by the scheduler
latents = latents * self.scheduler.init_noise_sigma
return latents
def prepare_mask_latents(
self, mask, masked_image, height, width, dtype, device, generator, do_classifier_free_guidance
):
# resize the mask to latents shape as we concatenate the mask to the latents
# we do that before converting to dtype to avoid breaking in case we're using cpu_offload
# and half precision
mask = torch.nn.functional.interpolate(
mask, size=(height // self.vae_scale_factor, width // self.vae_scale_factor)
)
masked_image = masked_image.to(device=device, dtype=dtype)
# encode the masked image into latent space so we can concatenate it to the latents
masked_image_latents = self.vae.encode(masked_image).latent_dist.sample(generator=generator)
masked_image_latents = (masked_image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
# align device and dtype to prevent errors when concatenating with the latent model input
masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
mask = mask.to(device=device, dtype=dtype)
# assume batch size = 1
mask = rearrange(mask, "f c h w -> 1 c f h w")
masked_image_latents = rearrange(masked_image_latents, "f c h w -> 1 c f h w")
mask = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
masked_image_latents = (
torch.cat([masked_image_latents] * 2) if do_classifier_free_guidance else masked_image_latents
)
return mask, masked_image_latents
def prepare_image_latents(self, images, device, dtype, generator, do_classifier_free_guidance):
images = images.to(device=device, dtype=dtype)
image_latents = self.vae.encode(images).latent_dist.sample(generator=generator)
image_latents = (image_latents - self.vae.config.shift_factor) * self.vae.config.scaling_factor
image_latents = rearrange(image_latents, "f c h w -> 1 c f h w")
image_latents = torch.cat([image_latents] * 2) if do_classifier_free_guidance else image_latents
return image_latents
def set_progress_bar_config(self, **kwargs):
if not hasattr(self, "_progress_bar_config"):
self._progress_bar_config = {}
self._progress_bar_config.update(kwargs)
@staticmethod
def paste_surrounding_pixels_back(decoded_latents, pixel_values, masks, device, weight_dtype):
# Paste the surrounding pixels back, because we only want to change the mouth region
pixel_values = pixel_values.to(device=device, dtype=weight_dtype)
masks = masks.to(device=device, dtype=weight_dtype)
combined_pixel_values = decoded_latents * masks + pixel_values * (1 - masks)
return combined_pixel_values
@staticmethod
def pixel_values_to_images(pixel_values: torch.Tensor):
pixel_values = rearrange(pixel_values, "f c h w -> f h w c")
pixel_values = (pixel_values / 2 + 0.5).clamp(0, 1)
images = (pixel_values * 255).to(torch.uint8)
images = images.cpu().numpy()
return images
def affine_transform_video(self, video_frames: np.ndarray):
faces = []
boxes = []
affine_matrices = []
print(f"Affine transforming {len(video_frames)} faces...")
for frame in tqdm.tqdm(video_frames):
face, box, affine_matrix = self.image_processor.affine_transform(frame)
faces.append(face)
boxes.append(box)
affine_matrices.append(affine_matrix)
faces = torch.stack(faces)
return faces, boxes, affine_matrices
def restore_video(self, faces: torch.Tensor, video_frames: np.ndarray, boxes: list, affine_matrices: list):
video_frames = video_frames[: len(faces)]
out_frames = []
print(f"Restoring {len(faces)} faces...")
for index, face in enumerate(tqdm.tqdm(faces)):
x1, y1, x2, y2 = boxes[index]
height = int(y2 - y1)
width = int(x2 - x1)
face = torchvision.transforms.functional.resize(
face, size=(height, width), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
)
out_frame = self.image_processor.restorer.restore_img(video_frames[index], face, affine_matrices[index])
out_frames.append(out_frame)
return np.stack(out_frames, axis=0)
def loop_video(self, whisper_chunks: list, video_frames: np.ndarray):
# If the audio is longer than the video, we need to loop the video
if len(whisper_chunks) > len(video_frames):
faces, boxes, affine_matrices = self.affine_transform_video(video_frames)
num_loops = math.ceil(len(whisper_chunks) / len(video_frames))
loop_video_frames = []
loop_faces = []
loop_boxes = []
loop_affine_matrices = []
for i in range(num_loops):
if i % 2 == 0:
loop_video_frames.append(video_frames)
loop_faces.append(faces)
loop_boxes += boxes
loop_affine_matrices += affine_matrices
else:
loop_video_frames.append(video_frames[::-1])
loop_faces.append(faces.flip(0))
loop_boxes += boxes[::-1]
loop_affine_matrices += affine_matrices[::-1]
video_frames = np.concatenate(loop_video_frames, axis=0)[: len(whisper_chunks)]
faces = torch.cat(loop_faces, dim=0)[: len(whisper_chunks)]
boxes = loop_boxes[: len(whisper_chunks)]
affine_matrices = loop_affine_matrices[: len(whisper_chunks)]
else:
video_frames = video_frames[: len(whisper_chunks)]
faces, boxes, affine_matrices = self.affine_transform_video(video_frames)
return video_frames, faces, boxes, affine_matrices
@torch.no_grad()
def __call__(
self,
video_path: str,
audio_path: str,
video_out_path: str,
num_frames: int = 16,
video_fps: int = 25,
audio_sample_rate: int = 16000,
height: Optional[int] = None,
width: Optional[int] = None,
num_inference_steps: int = 20,
guidance_scale: float = 1.5,
weight_dtype: Optional[torch.dtype] = torch.float16,
eta: float = 0.0,
mask_image_path: str = "latentsync/utils/mask.png",
temp_dir: str = "temp",
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
callback_steps: Optional[int] = 1,
**kwargs,
):
is_train = self.unet.training
self.unet.eval()
check_ffmpeg_installed()
# 0. Define call parameters
device = self._execution_device
# 1. Default height and width to unet (resolved first, because the mask and image processor need a concrete size)
height = height or self.unet.config.sample_size * self.vae_scale_factor
width = width or self.unet.config.sample_size * self.vae_scale_factor
mask_image = load_fixed_mask(height, mask_image_path)
self.image_processor = ImageProcessor(height, device="cuda", mask_image=mask_image)
self.set_progress_bar_config(desc=f"Sample frames: {num_frames}")
# 2. Check inputs
self.check_inputs(height, width, callback_steps)
# here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
# corresponds to doing no classifier free guidance.
do_classifier_free_guidance = guidance_scale > 1.0
# 3. set timesteps
self.scheduler.set_timesteps(num_inference_steps, device=device)
timesteps = self.scheduler.timesteps
# 4. Prepare extra step kwargs.
extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
whisper_feature = self.audio_encoder.audio2feat(audio_path)
whisper_chunks = self.audio_encoder.feature2chunks(feature_array=whisper_feature, fps=video_fps)
audio_samples = read_audio(audio_path)
video_frames = read_video(video_path, use_decord=False)
video_frames, faces, boxes, affine_matrices = self.loop_video(whisper_chunks, video_frames)
synced_video_frames = []
num_channels_latents = self.vae.config.latent_channels
# Prepare latent variables
all_latents = self.prepare_latents(
len(whisper_chunks),
num_channels_latents,
height,
width,
weight_dtype,
device,
generator,
)
num_inferences = math.ceil(len(whisper_chunks) / num_frames)
for i in tqdm.tqdm(range(num_inferences), desc="Doing inference..."):
if self.unet.add_audio_layer:
audio_embeds = torch.stack(whisper_chunks[i * num_frames : (i + 1) * num_frames])
audio_embeds = audio_embeds.to(device, dtype=weight_dtype)
if do_classifier_free_guidance:
null_audio_embeds = torch.zeros_like(audio_embeds)
audio_embeds = torch.cat([null_audio_embeds, audio_embeds])
else:
audio_embeds = None
inference_faces = faces[i * num_frames : (i + 1) * num_frames]
latents = all_latents[:, :, i * num_frames : (i + 1) * num_frames]
ref_pixel_values, masked_pixel_values, masks = self.image_processor.prepare_masks_and_masked_images(
inference_faces, affine_transform=False
)
# 7. Prepare mask latent variables
mask_latents, masked_image_latents = self.prepare_mask_latents(
masks,
masked_pixel_values,
height,
width,
weight_dtype,
device,
generator,
do_classifier_free_guidance,
)
# 8. Prepare image latents
ref_latents = self.prepare_image_latents(
ref_pixel_values,
device,
weight_dtype,
generator,
do_classifier_free_guidance,
)
# 9. Denoising loop
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
with self.progress_bar(total=num_inference_steps) as progress_bar:
for j, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
unet_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
unet_input = self.scheduler.scale_model_input(unet_input, t)
# concat latents, mask, masked_image_latents in the channel dimension
unet_input = torch.cat([unet_input, mask_latents, masked_image_latents, ref_latents], dim=1)
# predict the noise residual
noise_pred = self.unet(unet_input, t, encoder_hidden_states=audio_embeds).sample
# perform guidance
if do_classifier_free_guidance:
noise_pred_uncond, noise_pred_audio = noise_pred.chunk(2)
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_audio - noise_pred_uncond)
# compute the previous noisy sample x_t -> x_t-1
latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
# call the callback, if provided
if j == len(timesteps) - 1 or ((j + 1) > num_warmup_steps and (j + 1) % self.scheduler.order == 0):
progress_bar.update()
if callback is not None and j % callback_steps == 0:
callback(j, t, latents)
# Recover the pixel values
decoded_latents = self.decode_latents(latents)
decoded_latents = self.paste_surrounding_pixels_back(
decoded_latents, ref_pixel_values, 1 - masks, device, weight_dtype
)
synced_video_frames.append(decoded_latents)
synced_video_frames = self.restore_video(torch.cat(synced_video_frames), video_frames, boxes, affine_matrices)
audio_samples_remain_length = int(synced_video_frames.shape[0] / video_fps * audio_sample_rate)
audio_samples = audio_samples[:audio_samples_remain_length].cpu().numpy()
if is_train:
self.unet.train()
if os.path.exists(temp_dir):
shutil.rmtree(temp_dir)
os.makedirs(temp_dir, exist_ok=True)
write_video(os.path.join(temp_dir, "video.mp4"), synced_video_frames, fps=video_fps)
sf.write(os.path.join(temp_dir, "audio.wav"), audio_samples, audio_sample_rate)
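# Pass the arguments as a list (no shell) so that paths with spaces or shell metacharacters are handled safely
command = ["ffmpeg", "-y", "-loglevel", "error", "-nostdin", "-i", os.path.join(temp_dir, "video.mp4"), "-i", os.path.join(temp_dir, "audio.wav"), "-c:v", "libx264", "-crf", "18", "-c:a", "aac", "-q:v", "0", "-q:a", "0", video_out_path]
subprocess.run(command, check=True)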
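When the audio has more chunks than the video has frames, `loop_video` above plays the clip forward, then reversed, and repeats until every audio chunk has a frame. The sketch below reproduces only that index pattern in plain Python. `ping_pong_indices` is a hypothetical helper, not a function from the repository, and it ignores the faces, boxes and affine matrices that the real method loops in the same order.

```python
import math
from typing import List


def ping_pong_indices(num_frames: int, num_chunks: int) -> List[int]:
    """Frame indices that cover num_chunks by alternating forward and reversed passes."""
    if num_frames <= 0:
        raise ValueError("num_frames must be positive")
    if num_chunks <= num_frames:
        # Video is long enough: simply truncate it to the audio length
        return list(range(num_chunks))
    forward = list(range(num_frames))
    indices: List[int] = []
    for i in range(math.ceil(num_chunks / num_frames)):
        # Even passes run forward, odd passes run backward
        indices += forward if i % 2 == 0 else forward[::-1]
    return indices[:num_chunks]


print(ping_pong_indices(4, 10))  # [0, 1, 2, 3, 3, 2, 1, 0, 0, 1]
```

Note that the boundary frame is shown twice at each turn (`3, 3` and `0, 0`), which matches concatenating the full clip with its full reversal as the method does.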


@@ -0,0 +1,67 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
import torch.nn.functional as F
from einops import rearrange
from .third_party.VideoMAEv2.utils import load_videomae_model
from ..utils.util import check_model_and_download
class TREPALoss:
def __init__(
self,
device="cuda",
ckpt_path="checkpoints/auxiliary/vit_g_hybrid_pt_1200e_ssv2_ft.pth",
with_cp=False,
):
check_model_and_download(ckpt_path)
self.model = load_videomae_model(device, ckpt_path, with_cp).eval().to(dtype=torch.float16)
self.model.requires_grad_(False)
def __call__(self, videos_fake, videos_real):
batch_size = videos_fake.shape[0]
num_frames = videos_fake.shape[2]
videos_fake = rearrange(videos_fake.clone(), "b c f h w -> (b f) c h w")
videos_real = rearrange(videos_real.clone(), "b c f h w -> (b f) c h w")
videos_fake = F.interpolate(videos_fake, size=(224, 224), mode="bicubic")
videos_real = F.interpolate(videos_real, size=(224, 224), mode="bicubic")
videos_fake = rearrange(videos_fake, "(b f) c h w -> b c f h w", f=num_frames)
videos_real = rearrange(videos_real, "(b f) c h w -> b c f h w", f=num_frames)
# The input pixel range is [-1, 1], but the model expects [0, 1]
videos_fake = (videos_fake / 2 + 0.5).clamp(0, 1)
videos_real = (videos_real / 2 + 0.5).clamp(0, 1)
feats_fake = self.model.forward_features(videos_fake)
feats_real = self.model.forward_features(videos_real)
feats_fake = F.normalize(feats_fake, p=2, dim=1)
feats_real = F.normalize(feats_real, p=2, dim=1)
return F.mse_loss(feats_fake, feats_real)
if __name__ == "__main__":
torch.manual_seed(42)
# input shape: (b, c, f, h, w)
videos_fake = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
videos_real = torch.randn(2, 3, 16, 256, 256, requires_grad=True).to(device="cuda", dtype=torch.float16)
trepa_loss = TREPALoss(device="cuda", with_cp=True)
loss = trepa_loss(videos_fake, videos_real)
print(loss)
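The last three lines of `TREPALoss.__call__` L2-normalize each feature vector and then take a mean squared error. A minimal pure-Python sketch of that arithmetic follows, assuming features are given as lists of floats; `l2_normalize` and `normalized_mse` are hypothetical names, and the VideoMAE feature extractor itself is left out. Because both vectors have unit length, the squared distance per sample equals `2 - 2 * cos(a, b)`, so the loss depends only on the direction of the features, not their magnitude.

```python
import math


def l2_normalize(vec, eps=1e-12):
    """Scale a vector to unit L2 norm, guarding against division by zero."""
    norm = math.sqrt(sum(x * x for x in vec))
    return [x / max(norm, eps) for x in vec]


def normalized_mse(feats_a, feats_b):
    """Mean squared error over all elements after per-sample L2 normalization."""
    total, count = 0.0, 0
    for a, b in zip(feats_a, feats_b):
        a, b = l2_normalize(a), l2_normalize(b)
        total += sum((x - y) ** 2 for x, y in zip(a, b))
        count += len(a)
    return total / count


# Same direction, different magnitude: the loss is (numerically) zero
print(normalized_mse([[3.0, 4.0]], [[6.0, 8.0]]))
# Orthogonal features: squared distance 2 spread over 2 elements
print(normalized_mse([[1.0, 0.0]], [[0.0, 1.0]]))
```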


@@ -0,0 +1,82 @@
import os
import torch
import requests
from tqdm import tqdm
from torchvision import transforms
from .videomaev2_finetune import vit_giant_patch14_224
def to_normalized_float_tensor(vid):
return vid.permute(3, 0, 1, 2).to(torch.float32) / 255
# NOTE: for those functions, which generally expect mini-batches, we keep them
# as non-minibatch so that they are applied as if they were 4d (thus image).
# this way, we only apply the transformation in the spatial domain
def resize(vid, size, interpolation="bilinear"):
# NOTE: using bilinear interpolation because we don't work on minibatches
# at this level
scale = None
if isinstance(size, int):
scale = float(size) / min(vid.shape[-2:])
size = None
return torch.nn.functional.interpolate(vid, size=size, scale_factor=scale, mode=interpolation, align_corners=False)
class ToFloatTensorInZeroOne(object):
def __call__(self, vid):
return to_normalized_float_tensor(vid)
class Resize(object):
def __init__(self, size):
self.size = size
def __call__(self, vid):
return resize(vid, self.size)
def preprocess_videomae(videos):
transform = transforms.Compose([ToFloatTensorInZeroOne(), Resize((224, 224))])
return torch.stack([transform(f) for f in torch.from_numpy(videos)])
def load_videomae_model(device, ckpt_path=None, with_cp=False):
if ckpt_path is None:
current_dir = os.path.dirname(os.path.abspath(__file__))
ckpt_path = os.path.join(current_dir, "vit_g_hybrid_pt_1200e_ssv2_ft.pth")
if not os.path.exists(ckpt_path):
# download the ckpt to the path
ckpt_url = "https://pjlab-gvm-data.oss-cn-shanghai.aliyuncs.com/internvideo/videomaev2/vit_g_hybrid_pt_1200e_ssv2_ft.pth"
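response = requests.get(ckpt_url, stream=True, allow_redirects=True)
# Abort on HTTP errors so that an error page is never written to disk as the checkpoint
response.raise_for_status()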
total_size = int(response.headers.get("content-length", 0))
block_size = 1024
with tqdm(total=total_size, unit="B", unit_scale=True) as progress_bar:
with open(ckpt_path, "wb") as fw:
for data in response.iter_content(block_size):
progress_bar.update(len(data))
fw.write(data)
model = vit_giant_patch14_224(
img_size=224,
pretrained=False,
num_classes=174,
all_frames=16,
tubelet_size=2,
drop_path_rate=0.3,
use_mean_pooling=True,
with_cp=with_cp,
)
ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=True)
for model_key in ["model", "module"]:
if model_key in ckpt:
ckpt = ckpt[model_key]
break
model.load_state_dict(ckpt)
del ckpt
torch.cuda.empty_cache()
return model.to(device)
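`load_videomae_model` above accepts checkpoints that store the weights either directly or nested under a `"model"` or `"module"` key. The sketch below isolates that unwrapping step using plain dictionaries in place of real tensors; `unwrap_state_dict` is a hypothetical helper name, not something defined in the repository.

```python
def unwrap_state_dict(ckpt, wrapper_keys=("model", "module")):
    """Return the nested state dict if the checkpoint wraps it, else the checkpoint itself."""
    for key in wrapper_keys:
        if key in ckpt:
            # The first matching wrapper key wins, mirroring the `break` in the loader
            return ckpt[key]
    return ckpt


wrapped = {"model": {"fc.weight": [0.1, 0.2]}, "epoch": 1200}
bare = {"fc.weight": [0.1, 0.2]}

print(unwrap_state_dict(wrapped))  # {'fc.weight': [0.1, 0.2]}
print(unwrap_state_dict(bare))     # {'fc.weight': [0.1, 0.2]}
```

One caveat of this pattern: a bare state dict that happens to contain a parameter literally named `model` or `module` would be misread as a wrapper.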


@@ -0,0 +1,543 @@
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
from functools import partial
import math
import warnings
import numpy as np
import collections.abc
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint as cp
from itertools import repeat
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Cut & paste from PyTorch official master until it's in a few official releases - RW
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. "
"The distribution of values may be incorrect.",
stacklevel=2,
)
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
# then using the inverse CDF for the normal distribution.
# Get upper and lower cdf values
l = norm_cdf((a - mean) / std)
u = norm_cdf((b - mean) / std)
# Uniformly fill tensor with values from [l, u], then translate to
# [2l-1, 2u-1].
tensor.uniform_(2 * l - 1, 2 * u - 1)
# Use inverse cdf transform for normal distribution to get truncated
# standard normal
tensor.erfinv_()
# Transform to proper mean, std
tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
tensor.clamp_(min=a, max=b)
return tensor
def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution. The values are effectively drawn from the
normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
with values outside :math:`[a, b]` redrawn until they are within
the bounds. The method used for generating the random values works
best when :math:`a \leq \text{mean} \leq b`.
Args:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.trunc_normal_(w)
"""
return _no_grad_trunc_normal_(tensor, mean, std, a, b)
def _ntuple(n):
def parse(x):
if isinstance(x, collections.abc.Iterable):
return x
return tuple(repeat(x, n))
return parse
to_2tuple = _ntuple(2)
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""
Adapted from timm codebase
"""
if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
return output
def _cfg(url="", **kwargs):
return {
"url": url,
"num_classes": 400,
"input_size": (3, 224, 224),
"pool_size": None,
"crop_pct": 0.9,
"interpolation": "bicubic",
"mean": (0.5, 0.5, 0.5),
"std": (0.5, 0.5, 0.5),
**kwargs,
}
class DropPath(nn.Module):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
def __init__(self, drop_prob=None):
super(DropPath, self).__init__()
self.drop_prob = drop_prob
def forward(self, x):
return drop_path(x, self.drop_prob, self.training)
def extra_repr(self) -> str:
return "p={}".format(self.drop_prob)
class Mlp(nn.Module):
def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
self.fc1 = nn.Linear(in_features, hidden_features)
self.act = act_layer()
self.fc2 = nn.Linear(hidden_features, out_features)
self.drop = nn.Dropout(drop)
def forward(self, x):
x = self.fc1(x)
x = self.act(x)
# x = self.drop(x)
# commented out to follow the original BERT implementation
x = self.fc2(x)
x = self.drop(x)
return x
class CosAttention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
# self.scale = qk_scale or head_dim**-0.5
# DO NOT RENAME [self.scale] (for no weight decay)
if qk_scale is None:
self.scale = nn.Parameter(torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True)
else:
self.scale = qk_scale
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
# torch.log(torch.tensor(1. / 0.01)) = 4.6052
logit_scale = torch.clamp(self.scale, max=4.6052).exp()
attn = attn * logit_scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)
x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Attention(nn.Module):
def __init__(
self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0.0, proj_drop=0.0, attn_head_dim=None
):
super().__init__()
self.num_heads = num_heads
head_dim = dim // num_heads
if attn_head_dim is not None:
head_dim = attn_head_dim
all_head_dim = head_dim * self.num_heads
self.scale = qk_scale or head_dim**-0.5
self.qkv = nn.Linear(dim, all_head_dim * 3, bias=False)
if qkv_bias:
self.q_bias = nn.Parameter(torch.zeros(all_head_dim))
self.v_bias = nn.Parameter(torch.zeros(all_head_dim))
else:
self.q_bias = None
self.v_bias = None
self.attn_drop = nn.Dropout(attn_drop)
self.proj = nn.Linear(all_head_dim, dim)
self.proj_drop = nn.Dropout(proj_drop)
def forward(self, x):
B, N, C = x.shape
qkv_bias = None
if self.q_bias is not None:
qkv_bias = torch.cat((self.q_bias, torch.zeros_like(self.v_bias, requires_grad=False), self.v_bias))
qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
qkv = qkv.reshape(B, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
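# Use PyTorch's native scaled_dot_product_attention, which can dispatch to a fused FlashAttention kernel.
# NOTE: this call uses the default 1/sqrt(head_dim) scaling and no dropout, so `self.scale` and `self.attn_drop` are not applied here.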
attn = F.scaled_dot_product_attention(q, k, v)
x = attn.transpose(1, 2).reshape(B, N, -1)
# Deprecated attn implementation, which consumes much more VRAM
# q = q * self.scale
# attn = q @ k.transpose(-2, -1)
# attn = attn.softmax(dim=-1)
# attn = self.attn_drop(attn)
# x = (attn @ v).transpose(1, 2).reshape(B, N, -1)
x = self.proj(x)
x = self.proj_drop(x)
return x
class Block(nn.Module):
def __init__(
self,
dim,
num_heads,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop=0.0,
attn_drop=0.0,
drop_path=0.0,
init_values=None,
act_layer=nn.GELU,
norm_layer=nn.LayerNorm,
attn_head_dim=None,
cos_attn=False,
):
super().__init__()
self.norm1 = norm_layer(dim)
if cos_attn:
self.attn = CosAttention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
attn_head_dim=attn_head_dim,
)
else:
self.attn = Attention(
dim,
num_heads=num_heads,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
proj_drop=drop,
attn_head_dim=attn_head_dim,
)
# NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
if init_values is not None and init_values > 0:
self.gamma_1 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
self.gamma_2 = nn.Parameter(init_values * torch.ones((dim)), requires_grad=True)
else:
self.gamma_1, self.gamma_2 = None, None
def forward(self, x):
if self.gamma_1 is None:
x = x + self.drop_path(self.attn(self.norm1(x)))
x = x + self.drop_path(self.mlp(self.norm2(x)))
else:
x = x + self.drop_path(self.gamma_1 * self.attn(self.norm1(x)))
x = x + self.drop_path(self.gamma_2 * self.mlp(self.norm2(x)))
return x
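`Block` creates the learnable per-channel scales `gamma_1` and `gamma_2` only when `init_values` is positive, while several constructors in this code base default `init_values` to `None`. In Python 3, comparing `None` with a number raises `TypeError`, so a guard along these lines is one way to treat `None` as "layer scale disabled". `layer_scale_enabled` is a hypothetical helper, not part of the source.

```python
def layer_scale_enabled(init_values):
    """Return True only for a positive numeric init value; None means disabled."""
    return init_values is not None and init_values > 0

# Demonstrate why the None check has to come first.
try:
    None > 0
    none_comparison_fails = False
except TypeError:
    none_comparison_fails = True
```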
class PatchEmbed(nn.Module):
"""Image to Patch Embedding"""
def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, num_frames=16, tubelet_size=2):
super().__init__()
img_size = to_2tuple(img_size)
patch_size = to_2tuple(patch_size)
num_spatial_patches = (img_size[0] // patch_size[0]) * (img_size[1] // patch_size[1])
num_patches = num_spatial_patches * (num_frames // tubelet_size)
self.img_size = img_size
self.tubelet_size = tubelet_size
self.patch_size = patch_size
self.num_patches = num_patches
self.proj = nn.Conv3d(
in_channels=in_chans,
out_channels=embed_dim,
kernel_size=(self.tubelet_size, patch_size[0], patch_size[1]),
stride=(self.tubelet_size, patch_size[0], patch_size[1]),
)
def forward(self, x, **kwargs):
B, C, T, H, W = x.shape
assert (
H == self.img_size[0] and W == self.img_size[1]
), f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
# b, c, l -> b, l, c
# [1, 1408, 8, 16, 16] -> [1, 1408, 2048] -> [1, 2048, 1408]
x = self.proj(x).flatten(2).transpose(1, 2)
return x
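The token count produced by `PatchEmbed` follows directly from the constructor arithmetic: spatial patches per frame times the number of tubelets along time. A small sketch of that calculation, assuming square inputs and patches (the function name `num_tubelet_tokens` is illustrative only):

```python
def num_tubelet_tokens(img_size=224, patch_size=16, num_frames=16, tubelet_size=2):
    """Number of tokens after the Conv3d patch projection."""
    spatial = (img_size // patch_size) * (img_size // patch_size)
    temporal = num_frames // tubelet_size
    return spatial * temporal

default_tokens = num_tubelet_tokens()            # 14 * 14 * 8
giant_tokens = num_tubelet_tokens(patch_size=14)  # 16 * 16 * 8
```

The second value agrees with the `[1, 2048, 1408]` shape noted in the forward-pass comment for the patch-14 configuration.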
# sin-cos position encoding
# https://github.com/jadore801120/attention-is-all-you-need-pytorch/blob/master/transformer/Models.py#L31
def get_sinusoid_encoding_table(n_position, d_hid):
"""Sinusoid position encoding table"""
# TODO: make it with torch instead of numpy
def get_position_angle_vec(position):
return [position / np.power(10000, 2 * (hid_j // 2) / d_hid) for hid_j in range(d_hid)]
sinusoid_table = np.array([get_position_angle_vec(pos_i) for pos_i in range(n_position)])
sinusoid_table[:, 0::2] = np.sin(sinusoid_table[:, 0::2]) # dim 2i
sinusoid_table[:, 1::2] = np.cos(sinusoid_table[:, 1::2]) # dim 2i+1
return torch.tensor(sinusoid_table, dtype=torch.float, requires_grad=False).unsqueeze(0)
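For readers who want to check individual entries of the table, here is a plain-Python sketch of the same formula: even columns hold the sine and odd columns the cosine of `pos / 10000^(2 * (j // 2) / d_hid)`. It is meant as a readable reference, not a replacement; `sinusoid_table` is a hypothetical name.

```python
import math

def sinusoid_table(n_position, d_hid):
    """Return an n_position x d_hid nested list of sinusoidal encodings."""
    table = []
    for pos in range(n_position):
        row = []
        for j in range(d_hid):
            # Columns 2i and 2i+1 share the same frequency
            angle = pos / (10000 ** (2 * (j // 2) / d_hid))
            row.append(math.sin(angle) if j % 2 == 0 else math.cos(angle))
        table.append(row)
    return table

table = sinusoid_table(n_position=3, d_hid=4)
```

Position 0 always encodes to alternating 0 and 1, which makes a quick sanity check.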
class VisionTransformer(nn.Module):
"""Vision Transformer with support for patch or hybrid CNN input stage"""
def __init__(
self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=1000,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.0,
qkv_bias=False,
qk_scale=None,
drop_rate=0.0,
attn_drop_rate=0.0,
drop_path_rate=0.0,
head_drop_rate=0.0,
norm_layer=nn.LayerNorm,
init_values=0.0,
use_learnable_pos_emb=False,
init_scale=0.0,
all_frames=16,
tubelet_size=2,
use_mean_pooling=True,
with_cp=False,
cos_attn=False,
):
super().__init__()
self.num_classes = num_classes
# num_features for consistency with other models
self.num_features = self.embed_dim = embed_dim
self.tubelet_size = tubelet_size
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
num_frames=all_frames,
tubelet_size=tubelet_size,
)
num_patches = self.patch_embed.num_patches
self.with_cp = with_cp
if use_learnable_pos_emb:
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
else:
# fixed sine-cosine positional embeddings
self.pos_embed = get_sinusoid_encoding_table(num_patches, embed_dim)
self.pos_drop = nn.Dropout(p=drop_rate)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
self.blocks = nn.ModuleList(
[
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
cos_attn=cos_attn,
)
for i in range(depth)
]
)
self.norm = nn.Identity() if use_mean_pooling else norm_layer(embed_dim)
self.fc_norm = norm_layer(embed_dim) if use_mean_pooling else None
self.head_dropout = nn.Dropout(head_drop_rate)
self.head = nn.Linear(embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if use_learnable_pos_emb:
trunc_normal_(self.pos_embed, std=0.02)
self.apply(self._init_weights)
self.head.weight.data.mul_(init_scale)
self.head.bias.data.mul_(init_scale)
self.num_frames = all_frames
def _init_weights(self, m):
if isinstance(m, nn.Linear):
trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {"pos_embed", "cls_token"}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=""):
self.num_classes = num_classes
self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def interpolate_pos_encoding(self, t):
T = 8
t0 = t // self.tubelet_size
if T == t0:
return self.pos_embed
dim = self.pos_embed.shape[-1]
patch_pos_embed = self.pos_embed.permute(0, 2, 1).reshape(1, dim, 8, 16, 16)
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
t0 = t0 + 0.1
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed,
scale_factor=(t0 / T, 1, 1),
mode="trilinear",
)
assert int(t0) == patch_pos_embed.shape[-3]
patch_pos_embed = patch_pos_embed.reshape(1, dim, -1).permute(0, 2, 1)
return patch_pos_embed
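`interpolate_pos_encoding` reshapes the positional table to a hard-coded 8 x 16 x 16 grid, which corresponds to 16 frames with `tubelet_size=2` and 224-pixel inputs with 14-pixel patches, and rescales only the temporal axis. The sketch below estimates the resulting token count for other clip lengths. It assumes the interpolated length is the floor of input length times scale factor, and `interpolated_token_count` is a hypothetical helper.

```python
import math

def interpolated_token_count(t, tubelet_size=2, base_t=8, grid_hw=(16, 16)):
    """Tokens in the positional table after temporal interpolation for t input frames."""
    t0 = t // tubelet_size
    if t0 == base_t:
        new_t = base_t  # table is returned unchanged
    else:
        # Mirrors the +0.1 nudge in the source; assumes size = floor(input size * scale factor)
        new_t = math.floor(base_t * ((t0 + 0.1) / base_t))
    return new_t * grid_hw[0] * grid_hw[1]

counts = {t: interpolated_token_count(t) for t in (8, 16, 32)}
```

Other patch sizes or resolutions would not fit the fixed 8 x 16 x 16 reshape, so the sketch keeps the same restriction.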
def forward_features(self, x):
# [1, 3, 16, 224, 224]
B = x.size(0)
T = x.size(2)
# [1, 2048, 1408]
x = self.patch_embed(x)
if self.pos_embed is not None:
x = x + self.interpolate_pos_encoding(T).expand(B, -1, -1).type_as(x).to(x.device).clone().detach()
x = self.pos_drop(x)
for blk in self.blocks:
if self.with_cp:
x = cp.checkpoint(blk, x, use_reentrant=False)
else:
x = blk(x)
# return self.fc_norm(x)
if self.fc_norm is not None:
return self.fc_norm(x.mean(1))
else:
return self.norm(x[:, 0])
def forward(self, x):
x = self.forward_features(x)
x = self.head_dropout(x)
x = self.head(x)
return x
def vit_giant_patch14_224(pretrained=False, **kwargs):
model = VisionTransformer(
patch_size=14,
embed_dim=1408,
depth=40,
num_heads=16,
mlp_ratio=48 / 11,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs,
)
model.default_cfg = _cfg()
return model

@@ -0,0 +1,469 @@
# --------------------------------------------------------
# Based on BEiT, timm, DINO and DeiT code bases
# https://github.com/microsoft/unilm/tree/master/beit
# https://github.com/rwightman/pytorch-image-models/tree/master/timm
# https://github.com/facebookresearch/deit
# https://github.com/facebookresearch/dino
# --------------------------------------------------------
from functools import partial
import torch
import torch.nn as nn
import torch.utils.checkpoint as cp
from .videomaev2_finetune import (
Block,
PatchEmbed,
_cfg,
get_sinusoid_encoding_table,
)
from .videomaev2_finetune import trunc_normal_ as __call_trunc_normal_
def trunc_normal_(tensor, mean=0., std=1.):
__call_trunc_normal_(tensor, mean=mean, std=std, a=-std, b=std)
class PretrainVisionTransformerEncoder(nn.Module):
""" Vision Transformer encoder that processes only the visible (unmasked) tokens
"""
def __init__(self,
img_size=224,
patch_size=16,
in_chans=3,
num_classes=0,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=None,
tubelet_size=2,
use_learnable_pos_emb=False,
with_cp=False,
all_frames=16,
cos_attn=False):
super().__init__()
self.num_classes = num_classes
# num_features for consistency with other models
self.num_features = self.embed_dim = embed_dim
self.patch_embed = PatchEmbed(
img_size=img_size,
patch_size=patch_size,
in_chans=in_chans,
embed_dim=embed_dim,
num_frames=all_frames,
tubelet_size=tubelet_size)
num_patches = self.patch_embed.num_patches
self.with_cp = with_cp
if use_learnable_pos_emb:
self.pos_embed = nn.Parameter(
torch.zeros(1, num_patches + 1, embed_dim))
else:
# sine-cosine positional embeddings
self.pos_embed = get_sinusoid_encoding_table(
num_patches, embed_dim)
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
cos_attn=cos_attn) for i in range(depth)
])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
if use_learnable_pos_emb:
trunc_normal_(self.pos_embed, std=.02)
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward_features(self, x, mask):
x = self.patch_embed(x)
x = x + self.pos_embed.type_as(x).to(x.device).clone().detach()
B, _, C = x.shape
x_vis = x[~mask].reshape(B, -1, C) # ~mask means visible
for blk in self.blocks:
if self.with_cp:
x_vis = cp.checkpoint(blk, x_vis)
else:
x_vis = blk(x_vis)
x_vis = self.norm(x_vis)
return x_vis
def forward(self, x, mask):
x = self.forward_features(x, mask)
x = self.head(x)
return x
class PretrainVisionTransformerDecoder(nn.Module):
""" Transformer decoder that maps token embeddings to per-token pixel predictions
"""
def __init__(self,
patch_size=16,
num_classes=768,
embed_dim=768,
depth=12,
num_heads=12,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=None,
num_patches=196,
tubelet_size=2,
with_cp=False,
cos_attn=False):
super().__init__()
self.num_classes = num_classes
assert num_classes == 3 * tubelet_size * patch_size**2
# num_features for consistency with other models
self.num_features = self.embed_dim = embed_dim
self.patch_size = patch_size
self.with_cp = with_cp
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)
] # stochastic depth decay rule
self.blocks = nn.ModuleList([
Block(
dim=embed_dim,
num_heads=num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
drop_path=dpr[i],
norm_layer=norm_layer,
init_values=init_values,
cos_attn=cos_attn) for i in range(depth)
])
self.norm = norm_layer(embed_dim)
self.head = nn.Linear(
embed_dim, num_classes) if num_classes > 0 else nn.Identity()
self.apply(self._init_weights)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token'}
def get_classifier(self):
return self.head
def reset_classifier(self, num_classes, global_pool=''):
self.num_classes = num_classes
self.head = nn.Linear(
self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
def forward(self, x, return_token_num):
for blk in self.blocks:
if self.with_cp:
x = cp.checkpoint(blk, x)
else:
x = blk(x)
if return_token_num > 0:
# only return the mask tokens predict pixels
x = self.head(self.norm(x[:, -return_token_num:]))
else:
# [B, N, 3*16^2]
x = self.head(self.norm(x))
return x
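The decoder asserts that `num_classes == 3 * tubelet_size * patch_size**2`, i.e. each output token predicts every pixel of one RGB tubelet. A short sketch of that arithmetic (the helper name `decoder_output_dim` is illustrative):

```python
def decoder_output_dim(patch_size, tubelet_size=2, channels=3):
    """Pixels predicted per token: channels x frames in a tubelet x patch area."""
    return channels * tubelet_size * patch_size ** 2

dims = {p: decoder_output_dim(p) for p in (14, 16)}
```

These are the values passed as `decoder_num_classes` by the factory functions for the patch-14 and patch-16 models.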
class PretrainVisionTransformer(nn.Module):
""" Masked-autoencoder pretraining model: an encoder over visible tokens followed by a decoder that predicts masked pixels
"""
def __init__(
self,
img_size=224,
patch_size=16,
encoder_in_chans=3,
encoder_num_classes=0,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
decoder_num_classes=1536, # decoder_num_classes=768
decoder_embed_dim=512,
decoder_depth=8,
decoder_num_heads=8,
mlp_ratio=4.,
qkv_bias=False,
qk_scale=None,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.,
norm_layer=nn.LayerNorm,
init_values=0.,
use_learnable_pos_emb=False,
tubelet_size=2,
num_classes=0, # avoid the error from create_fn in timm
in_chans=0, # avoid the error from create_fn in timm
with_cp=False,
all_frames=16,
cos_attn=False,
):
super().__init__()
self.encoder = PretrainVisionTransformerEncoder(
img_size=img_size,
patch_size=patch_size,
in_chans=encoder_in_chans,
num_classes=encoder_num_classes,
embed_dim=encoder_embed_dim,
depth=encoder_depth,
num_heads=encoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
init_values=init_values,
tubelet_size=tubelet_size,
use_learnable_pos_emb=use_learnable_pos_emb,
with_cp=with_cp,
all_frames=all_frames,
cos_attn=cos_attn)
self.decoder = PretrainVisionTransformerDecoder(
patch_size=patch_size,
num_patches=self.encoder.patch_embed.num_patches,
num_classes=decoder_num_classes,
embed_dim=decoder_embed_dim,
depth=decoder_depth,
num_heads=decoder_num_heads,
mlp_ratio=mlp_ratio,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
drop_rate=drop_rate,
attn_drop_rate=attn_drop_rate,
drop_path_rate=drop_path_rate,
norm_layer=norm_layer,
init_values=init_values,
tubelet_size=tubelet_size,
with_cp=with_cp,
cos_attn=cos_attn)
self.encoder_to_decoder = nn.Linear(
encoder_embed_dim, decoder_embed_dim, bias=False)
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
self.pos_embed = get_sinusoid_encoding_table(
self.encoder.patch_embed.num_patches, decoder_embed_dim)
trunc_normal_(self.mask_token, std=.02)
def _init_weights(self, m):
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def get_num_layers(self):
return len(self.encoder.blocks)
@torch.jit.ignore
def no_weight_decay(self):
return {'pos_embed', 'cls_token', 'mask_token'}
def forward(self, x, mask, decode_mask=None):
decode_vis = mask if decode_mask is None else ~decode_mask
x_vis = self.encoder(x, mask) # [B, N_vis, C_e]
x_vis = self.encoder_to_decoder(x_vis) # [B, N_vis, C_d]
B, N_vis, C = x_vis.shape
# we don't unshuffle the visible tokens back to their original order,
# but shuffle the pos embedding accordingly.
expand_pos_embed = self.pos_embed.expand(B, -1, -1).type_as(x).to(
x.device).clone().detach()
pos_emd_vis = expand_pos_embed[~mask].reshape(B, -1, C)
pos_emd_mask = expand_pos_embed[decode_vis].reshape(B, -1, C)
# [B, N, C_d]
x_full = torch.cat(
[x_vis + pos_emd_vis, self.mask_token + pos_emd_mask], dim=1)
# NOTE: if N_mask==0, the shape of x is [B, N_mask, 3 * 16 * 16]
x = self.decoder(x_full, pos_emd_mask.shape[1])
return x
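The forward pass concatenates the visible tokens first and the mask tokens last, then asks the decoder to return only the trailing `return_token_num` positions. The toy sketch below mimics that ordering with plain lists; the function name and the `"<mask>"` placeholder are hypothetical.

```python
def assemble_decoder_tokens(tokens, mask):
    """mask[i] is True when token i is hidden from the encoder."""
    visible = [t for t, m in zip(tokens, mask) if not m]
    masked = ["<mask>" for m in mask if m]
    # Visible tokens come first, mask tokens are appended at the end
    full = visible + masked
    return full, len(masked)

full, n_return = assemble_decoder_tokens(["a", "b", "c", "d"], [False, True, True, False])
predicted_slots = full[-n_return:]  # the positions the decoder head is applied to
```

When nothing is masked, a `[-0:]` slice would select every token, which is why the decoder branches on `return_token_num > 0` instead of always slicing.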
def pretrain_videomae_small_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=384,
encoder_depth=12,
encoder_num_heads=6,
encoder_num_classes=0,
decoder_num_classes=1536, # 16 * 16 * 3 * 2
decoder_embed_dim=192,
decoder_num_heads=3,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
def pretrain_videomae_base_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=768,
encoder_depth=12,
encoder_num_heads=12,
encoder_num_classes=0,
decoder_num_classes=1536, # 16 * 16 * 3 * 2
decoder_embed_dim=384,
decoder_num_heads=6,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
def pretrain_videomae_large_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=1024,
encoder_depth=24,
encoder_num_heads=16,
encoder_num_classes=0,
decoder_num_classes=1536, # 16 * 16 * 3 * 2
decoder_embed_dim=512,
decoder_num_heads=8,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
def pretrain_videomae_huge_patch16_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=16,
encoder_embed_dim=1280,
encoder_depth=32,
encoder_num_heads=16,
encoder_num_classes=0,
decoder_num_classes=1536, # 16 * 16 * 3 * 2
decoder_embed_dim=512,
decoder_num_heads=8,
mlp_ratio=4,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model
def pretrain_videomae_giant_patch14_224(pretrained=False, **kwargs):
model = PretrainVisionTransformer(
img_size=224,
patch_size=14,
encoder_embed_dim=1408,
encoder_depth=40,
encoder_num_heads=16,
encoder_num_classes=0,
decoder_num_classes=1176, # 14 * 14 * 3 * 2
decoder_embed_dim=512,
decoder_num_heads=8,
mlp_ratio=48 / 11,
qkv_bias=True,
norm_layer=partial(nn.LayerNorm, eps=1e-6),
**kwargs)
model.default_cfg = _cfg()
if pretrained:
checkpoint = torch.load(kwargs["init_ckpt"], map_location="cpu")
model.load_state_dict(checkpoint["model"])
return model

@@ -0,0 +1,321 @@
import os
import math
import os.path as osp
import random
import pickle
import warnings
import glob
import numpy as np
from PIL import Image
import torch
import torch.utils.data as data
import torch.nn.functional as F
import torch.distributed as dist
from torchvision.datasets.video_utils import VideoClips
IMG_EXTENSIONS = ['.jpg', '.JPG', '.jpeg', '.JPEG', '.png', '.PNG']
VID_EXTENSIONS = ['.avi', '.mp4', '.webm', '.mov', '.mkv', '.m4v']
def get_dataloader(data_path, image_folder, resolution=128, sequence_length=16, sample_every_n_frames=1,
batch_size=16, num_workers=8):
data = VideoData(data_path, image_folder, resolution, sequence_length, sample_every_n_frames, batch_size, num_workers)
loader = data._dataloader()
return loader
def is_image_file(filename):
return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)
def get_parent_dir(path):
return osp.basename(osp.dirname(path))
def preprocess(video, resolution, sequence_length=None, in_channels=3, sample_every_n_frames=1):
# video: THWC, {0, ..., 255}
assert in_channels == 3
video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
t, c, h, w = video.shape
# temporal crop
if sequence_length is not None:
assert sequence_length <= t
video = video[:sequence_length]
# skip frames
if sample_every_n_frames > 1:
video = video[::sample_every_n_frames]
# scale shorter side to resolution
scale = resolution / min(h, w)
if h < w:
target_size = (resolution, math.ceil(w * scale))
else:
target_size = (math.ceil(h * scale), resolution)
video = F.interpolate(video, size=target_size, mode='bilinear',
align_corners=False, antialias=True)
# center crop
t, c, h, w = video.shape
w_start = (w - resolution) // 2
h_start = (h - resolution) // 2
video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
video = video.permute(1, 0, 2, 3).contiguous() # CTHW
return {'video': video}
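The geometry in `preprocess` can be checked without loading any video: the shorter side is scaled to `resolution`, the longer side is rounded up, and a centered square is cropped. The sketch below reproduces only that arithmetic, plus the frame count left after the temporal crop and frame skipping. Both helper names are hypothetical.

```python
import math

def resize_and_crop_geometry(h, w, resolution):
    """Return ((resized_h, resized_w), (h_start, w_start)) as computed in preprocess."""
    scale = resolution / min(h, w)
    if h < w:
        target = (resolution, math.ceil(w * scale))
    else:
        target = (math.ceil(h * scale), resolution)
    h_start = (target[0] - resolution) // 2
    w_start = (target[1] - resolution) // 2
    return target, (h_start, w_start)

def frames_after_sampling(sequence_length, sample_every_n_frames):
    """Frames kept when the clip is cropped first and then strided."""
    return len(range(0, sequence_length, sample_every_n_frames))

landscape = resize_and_crop_geometry(240, 320, 128)
kept = frames_after_sampling(16, 4)
```

Because the temporal crop happens before the stride, a `sequence_length` of 16 with `sample_every_n_frames=4` yields 4 frames rather than 16.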
def preprocess_image(image):
# Convert the NumPy image to a tensor; values stay in [0, 1] (no rescaling to [-1, 1] is applied)
img = torch.from_numpy(image)
return img
class VideoData(data.Dataset):
""" Class to create dataloaders for video datasets
Args:
data_path: Path to the folder with video frames or videos.
image_folder: If True, the data is stored as images in folders.
resolution: Resolution of the returned videos.
sequence_length: Length of extracted video sequences.
sample_every_n_frames: Sample every n frames from the video.
batch_size: Batch size.
num_workers: Number of workers for the dataloader.
shuffle: If True, shuffle the data.
"""
def __init__(self, data_path: str, image_folder: bool, resolution: int, sequence_length: int,
sample_every_n_frames: int, batch_size: int, num_workers: int, shuffle: bool = True):
super().__init__()
self.data_path = data_path
self.image_folder = image_folder
self.resolution = resolution
self.sequence_length = sequence_length
self.sample_every_n_frames = sample_every_n_frames
self.batch_size = batch_size
self.num_workers = num_workers
self.shuffle = shuffle
def _dataset(self):
'''
Initializes and returns the dataset.
'''
if self.image_folder:
Dataset = FrameDataset
dataset = Dataset(self.data_path, self.sequence_length,
resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
else:
Dataset = VideoDataset
dataset = Dataset(self.data_path, self.sequence_length,
resolution=self.resolution, sample_every_n_frames=self.sample_every_n_frames)
return dataset
def _dataloader(self):
'''
Initializes and returns the dataloader.
'''
dataset = self._dataset()
if dist.is_initialized():
sampler = data.distributed.DistributedSampler(
dataset, num_replicas=dist.get_world_size(), rank=dist.get_rank()
)
else:
sampler = None
dataloader = data.DataLoader(
dataset,
batch_size=self.batch_size,
num_workers=self.num_workers,
pin_memory=True,
sampler=sampler,
shuffle=sampler is None and self.shuffle is True
)
return dataloader
class VideoDataset(data.Dataset):
"""
Generic dataset for video files stored in folders.
Videos of the same class are expected to be stored in a single folder. Multiple folders can exist in the provided directory.
The class depends on `torchvision.datasets.video_utils.VideoClips` to load the videos.
Each item is a dict with a CTHW `video` tensor in the range [0, 1] (BCTHW once batched).
Args:
data_folder: Path to the folder with corresponding videos stored.
sequence_length: Length of extracted video sequences.
resolution: Resolution of the returned videos.
sample_every_n_frames: Sample every n frames from the video.
"""
def __init__(self, data_folder: str, sequence_length: int = 16, resolution: int = 128, sample_every_n_frames: int = 1):
super().__init__()
self.sequence_length = sequence_length
self.resolution = resolution
self.sample_every_n_frames = sample_every_n_frames
folder = data_folder
files = sum([glob.glob(osp.join(folder, '**', f'*{ext}'), recursive=True)
for ext in VID_EXTENSIONS], [])
warnings.filterwarnings('ignore')
cache_file = osp.join(folder, f"metadata_{sequence_length}.pkl")
if not osp.exists(cache_file):
clips = VideoClips(files, sequence_length, num_workers=4)
try:
with open(cache_file, 'wb') as f:
pickle.dump(clips.metadata, f)
except Exception:
print(f"Failed to save metadata to {cache_file}")
else:
with open(cache_file, 'rb') as f:
metadata = pickle.load(f)
clips = VideoClips(files, sequence_length,
_precomputed_metadata=metadata)
self._clips = clips
# instead of uniformly sampling from all possible clips, we sample uniformly from all possible videos
self._clips.get_clip_location = self.get_random_clip_from_video
def get_random_clip_from_video(self, idx: int) -> tuple:
'''
Sample a random clip starting index from the video.
Args:
idx: Index of the video.
'''
# Some videos may not contain enough frames for a full clip; skip them here.
while self._clips.clips[idx].shape[0] <= 0:
idx += 1
n_clip = self._clips.clips[idx].shape[0]
clip_id = random.randint(0, n_clip - 1)
return idx, clip_id
def __len__(self):
return self._clips.num_videos()
def __getitem__(self, idx):
resolution = self.resolution
while True:
try:
video, _, _, idx = self._clips.get_clip(idx)
except Exception as e:
print(idx, e)
idx = (idx + 1) % self._clips.num_clips()
continue
break
return dict(**preprocess(video, resolution, sample_every_n_frames=self.sample_every_n_frames))
class FrameDataset(data.Dataset):
"""
Generic dataset for videos stored as images. The loader iterates over all the folders and subfolders
in the provided directory. Each leaf folder is assumed to contain frames from a single video.
Args:
data_folder: path to the folder with video frames. The folder
should contain folders with frames from each video.
sequence_length: length of extracted video sequences
resolution: resolution of the returned videos
sample_every_n_frames: sample every n frames from the video
"""
def __init__(self, data_folder, sequence_length, resolution=64, sample_every_n_frames=1):
self.resolution = resolution
self.sequence_length = sequence_length
self.sample_every_n_frames = sample_every_n_frames
self.data_all = self.load_video_frames(data_folder)
self.video_num = len(self.data_all)
def __getitem__(self, index):
batch_data = self.getTensor(index)
return_list = {'video': batch_data}
return return_list
def load_video_frames(self, dataroot: str) -> list:
'''
Loads all the video frames under the dataroot and returns a list of all the video frames.
Args:
dataroot: The root directory containing the video frames.
Returns:
A list of all the video frames.
'''
data_all = []
frame_list = os.walk(dataroot)
for _, meta in enumerate(frame_list):
root = meta[0]
try:
frames = sorted(meta[2], key=lambda item: int(item.split('.')[0].split('_')[-1]))
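except ValueError:
# File names without a trailing integer index cannot be ordered; skip this folder.
print(meta[0], meta[2])
continue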
if len(frames) < max(0, self.sequence_length * self.sample_every_n_frames):
continue
frames = [
os.path.join(root, item) for item in frames
if is_image_file(item)
]
if len(frames) > max(0, self.sequence_length * self.sample_every_n_frames):
data_all.append(frames)
return data_all
def getTensor(self, index: int) -> torch.Tensor:
'''
Returns a tensor of the video frames at the given index.
Args:
index: The index of the video frames to return.
Returns:
A CTHW tensor in the range `[0, 1]` of the video frames at the given index.
'''
video = self.data_all[index]
video_len = len(video)
# load the entire video when sequence_length == -1; sample_every_n_frames must then be 1
if self.sequence_length == -1:
assert self.sample_every_n_frames == 1
start_idx = 0
end_idx = video_len
else:
n_frames_interval = self.sequence_length * self.sample_every_n_frames
start_idx = random.randint(0, video_len - n_frames_interval)
end_idx = start_idx + n_frames_interval
img = Image.open(video[0])
h, w = img.height, img.width
if h > w:
half = (h - w) // 2
cropsize = (0, half, w, half + w) # left, upper, right, lower
elif w > h:
half = (w - h) // 2
cropsize = (half, 0, half + h, h)
images = []
for i in range(start_idx, end_idx,
self.sample_every_n_frames):
path = video[i]
img = Image.open(path)
if h != w:
img = img.crop(cropsize)
img = img.resize(
(self.resolution, self.resolution),
Image.LANCZOS) # Image.ANTIALIAS was removed in Pillow 10
img = np.asarray(img, dtype=np.float32)
img /= 255.
img_tensor = preprocess_image(img).unsqueeze(0)
images.append(img_tensor)
video_clip = torch.cat(images).permute(3, 0, 1, 2)
return video_clip
def __len__(self):
return self.video_num
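`getTensor` center-crops each frame to a square before resizing, using PIL's `(left, upper, right, lower)` box convention. A minimal sketch of the box computation, with a hypothetical helper name and an explicit pass-through case for square frames:

```python
def center_square_box(h, w):
    """Crop box (left, upper, right, lower) for the largest centered square."""
    if h > w:
        half = (h - w) // 2
        return (0, half, w, half + w)
    if w > h:
        half = (w - h) // 2
        return (half, 0, half + h, h)
    # Already square: the box covers the whole frame
    return (0, 0, w, h)

portrait_box = center_square_box(300, 200)
landscape_box = center_square_box(200, 300)
```

The source computes the box once from the first frame, so this assumes all frames in a folder share the same size.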

@@ -0,0 +1,161 @@
# Adapted from https://github.com/universome/stylegan-v/blob/master/src/metrics/metric_utils.py
import os
import random
import torch
import pickle
import numpy as np
from typing import List, Tuple
def seed_everything(seed):
random.seed(seed)
os.environ['PYTHONHASHSEED'] = str(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
class FeatureStats:
'''
Class to store statistics of features, including all features and mean/covariance.
Args:
capture_all: Whether to store all the features.
capture_mean_cov: Whether to store mean and covariance.
max_items: Maximum number of items to store.
'''
def __init__(self, capture_all: bool = False, capture_mean_cov: bool = False, max_items: int = None):
'''
'''
self.capture_all = capture_all
self.capture_mean_cov = capture_mean_cov
self.max_items = max_items
self.num_items = 0
self.num_features = None
self.all_features = None
self.raw_mean = None
self.raw_cov = None
def set_num_features(self, num_features: int):
'''
Set the number of feature dimensions.
Args:
num_features: Number of feature dimensions.
'''
if self.num_features is not None:
assert num_features == self.num_features
else:
self.num_features = num_features
self.all_features = []
self.raw_mean = np.zeros([num_features], dtype=np.float64)
self.raw_cov = np.zeros([num_features, num_features], dtype=np.float64)
def is_full(self) -> bool:
'''
Check if the maximum number of samples is reached.
Returns:
True if the storage is full, False otherwise.
'''
return (self.max_items is not None) and (self.num_items >= self.max_items)
def append(self, x: np.ndarray):
'''
Add the newly computed features to the list. Update the mean and covariance.
Args:
x: New features to record.
'''
x = np.asarray(x, dtype=np.float32)
assert x.ndim == 2
if (self.max_items is not None) and (self.num_items + x.shape[0] > self.max_items):
if self.num_items >= self.max_items:
return
x = x[:self.max_items - self.num_items]
self.set_num_features(x.shape[1])
self.num_items += x.shape[0]
if self.capture_all:
self.all_features.append(x)
if self.capture_mean_cov:
x64 = x.astype(np.float64)
self.raw_mean += x64.sum(axis=0)
self.raw_cov += x64.T @ x64
def append_torch(self, x: torch.Tensor, rank: int, num_gpus: int):
'''
Add the newly computed PyTorch features to the list. Update the mean and covariance.
Args:
x: New features to record.
rank: Rank of the current GPU.
num_gpus: Total number of GPUs.
'''
assert isinstance(x, torch.Tensor) and x.ndim == 2
assert 0 <= rank < num_gpus
if num_gpus > 1:
ys = []
for src in range(num_gpus):
y = x.clone()
torch.distributed.broadcast(y, src=src)
ys.append(y)
x = torch.stack(ys, dim=1).flatten(0, 1) # interleave samples
self.append(x.cpu().numpy())
def get_all(self) -> np.ndarray:
'''
Get all the stored features as a NumPy array.
Returns:
Concatenation of the stored features.
'''
assert self.capture_all
return np.concatenate(self.all_features, axis=0)
def get_all_torch(self) -> torch.Tensor:
'''
Get all the stored features as a PyTorch tensor.
Returns:
Concatenation of the stored features.
'''
return torch.from_numpy(self.get_all())
def get_mean_cov(self) -> Tuple[np.ndarray, np.ndarray]:
'''
Get the mean and covariance of the stored features.
Returns:
Mean and covariance of the stored features.
'''
assert self.capture_mean_cov
mean = self.raw_mean / self.num_items
cov = self.raw_cov / self.num_items
cov = cov - np.outer(mean, mean)
return mean, cov
def save(self, pkl_file: str):
'''
Save the features and statistics to a pickle file.
Args:
pkl_file: Path to the pickle file.
'''
with open(pkl_file, 'wb') as f:
pickle.dump(self.__dict__, f)
@staticmethod
def load(pkl_file: str) -> 'FeatureStats':
'''
Load the features and statistics from a pickle file.
Args:
pkl_file: Path to the pickle file.
'''
with open(pkl_file, 'rb') as f:
s = pickle.load(f)
obj = FeatureStats(capture_all=s['capture_all'], max_items=s['max_items'])
obj.__dict__.update(s)
print('Loaded %d features from %s' % (obj.num_items, pkl_file))
return obj
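`FeatureStats` never stores the full feature matrix when only `capture_mean_cov` is set: it accumulates the sum of the features and the sum of their outer products, then derives the covariance as E[x x^T] - mean mean^T (normalized by N, not N - 1). The following list-based sketch shows the same accumulation across batches; `streaming_mean_cov` is a hypothetical name.

```python
def streaming_mean_cov(batches):
    """Accumulate raw sums over batches, then return (mean, population covariance)."""
    n = 0
    raw_mean = None
    raw_cov = None
    for batch in batches:
        for x in batch:
            if raw_mean is None:
                d = len(x)
                raw_mean = [0.0] * d
                raw_cov = [[0.0] * d for _ in range(d)]
            n += 1
            for i, xi in enumerate(x):
                raw_mean[i] += xi
                for j, xj in enumerate(x):
                    raw_cov[i][j] += xi * xj  # running sum of outer products
    mean = [s / n for s in raw_mean]
    d = len(mean)
    cov = [[raw_cov[i][j] / n - mean[i] * mean[j] for j in range(d)] for i in range(d)]
    return mean, cov

mean, cov = streaming_mean_cov([[[1.0, 2.0]], [[3.0, 4.0]]])
```

Keeping only the raw sums makes memory use independent of the number of samples, at the cost of a d x d accumulator.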

@@ -0,0 +1,145 @@
# Adapted from https://github.com/guanjz20/StyleSync/blob/main/utils.py
import numpy as np
import cv2
import torch
from einops import rearrange
import kornia
class AlignRestore(object):
def __init__(self, align_points=3, resolution=256, device="cpu", dtype=torch.float16):
if align_points == 3:
self.upscale_factor = 1
ratio = resolution / 256 * 2.8
self.crop_ratio = (ratio, ratio)
self.face_template = np.array([[19 - 2, 30 - 10], [56 + 2, 30 - 10], [37.5, 45 - 5]])
self.face_template = self.face_template * ratio
self.face_size = (int(75 * self.crop_ratio[0]), int(100 * self.crop_ratio[1]))
self.p_bias = None
self.device = device
self.dtype = dtype
self.fill_value = torch.tensor([127, 127, 127], device=device, dtype=dtype)
self.mask = torch.ones((1, 1, self.face_size[1], self.face_size[0]), device=device, dtype=dtype)
def align_warp_face(self, img, landmarks3, smooth=True):
affine_matrix, self.p_bias = self.transformation_from_points(
landmarks3, self.face_template, smooth, self.p_bias
)
img = rearrange(torch.from_numpy(img).to(device=self.device, dtype=self.dtype), "h w c -> c h w").unsqueeze(0)
affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0)
cropped_face = kornia.geometry.transform.warp_affine(
img,
affine_matrix,
(self.face_size[1], self.face_size[0]),
mode="bilinear",
padding_mode="fill",
fill_value=self.fill_value,
)
cropped_face = rearrange(cropped_face.squeeze(0), "c h w -> h w c").cpu().numpy().astype(np.uint8)
return cropped_face, affine_matrix
def restore_img(self, input_img, face, affine_matrix):
h, w, _ = input_img.shape
if isinstance(affine_matrix, np.ndarray):
affine_matrix = torch.from_numpy(affine_matrix).to(device=self.device, dtype=self.dtype).unsqueeze(0)
inv_affine_matrix = kornia.geometry.transform.invert_affine_transform(affine_matrix)
face = face.to(dtype=self.dtype).unsqueeze(0)
inv_face = kornia.geometry.transform.warp_affine(
face, inv_affine_matrix, (h, w), mode="bilinear", padding_mode="fill", fill_value=self.fill_value
).squeeze(0)
inv_face = (inv_face / 2 + 0.5).clamp(0, 1) * 255
input_img = rearrange(torch.from_numpy(input_img).to(device=self.device, dtype=self.dtype), "h w c -> c h w")
inv_mask = kornia.geometry.transform.warp_affine(
self.mask, inv_affine_matrix, (h, w), padding_mode="zeros"
) # (1, 1, h_up, w_up)
inv_mask_erosion = kornia.morphology.erosion(
inv_mask,
torch.ones(
(int(2 * self.upscale_factor), int(2 * self.upscale_factor)), device=self.device, dtype=self.dtype
),
)
inv_mask_erosion_t = inv_mask_erosion.squeeze(0).expand_as(inv_face)
pasted_face = inv_mask_erosion_t * inv_face
total_face_area = torch.sum(inv_mask_erosion.float())
w_edge = int(total_face_area**0.5) // 20
erosion_radius = w_edge * 2
# This step will consume a large amount of GPU memory.
# inv_mask_center = kornia.morphology.erosion(
# inv_mask_erosion, torch.ones((erosion_radius, erosion_radius), device=self.device, dtype=self.dtype)
# )
# Run on CPU to avoid consuming a large amount of GPU memory.
inv_mask_erosion = inv_mask_erosion.squeeze().cpu().numpy().astype(np.float32)
inv_mask_center = cv2.erode(inv_mask_erosion, np.ones((erosion_radius, erosion_radius), np.uint8))
inv_mask_center = torch.from_numpy(inv_mask_center).to(device=self.device, dtype=self.dtype)[None, None, ...]
blur_size = w_edge * 2 + 1
sigma = 0.3 * ((blur_size - 1) * 0.5 - 1) + 0.8
inv_soft_mask = kornia.filters.gaussian_blur2d(
inv_mask_center, (blur_size, blur_size), (sigma, sigma)
).squeeze(0)
inv_soft_mask_3d = inv_soft_mask.expand_as(inv_face)
img_back = inv_soft_mask_3d * pasted_face + (1 - inv_soft_mask_3d) * input_img
img_back = rearrange(img_back, "c h w -> h w c").contiguous().to(dtype=torch.uint8)
img_back = img_back.cpu().numpy()
return img_back
def transformation_from_points(self, points1: torch.Tensor, points0: torch.Tensor, smooth=True, p_bias=None):
if isinstance(points0, np.ndarray):
points2 = torch.tensor(points0, device=self.device, dtype=torch.float32)
else:
points2 = points0.clone()
if isinstance(points1, np.ndarray):
points1_tensor = torch.tensor(points1, device=self.device, dtype=torch.float32)
else:
points1_tensor = points1.clone()
c1 = torch.mean(points1_tensor, dim=0)
c2 = torch.mean(points2, dim=0)
points1_centered = points1_tensor - c1
points2_centered = points2 - c2
s1 = torch.std(points1_centered)
s2 = torch.std(points2_centered)
points1_normalized = points1_centered / s1
points2_normalized = points2_centered / s2
covariance = torch.matmul(points1_normalized.T, points2_normalized)
U, S, V = torch.svd(covariance.float())
R = torch.matmul(V, U.T)
det = torch.det(R.float())
if det < 0:
V[:, -1] = -V[:, -1]
R = torch.matmul(V, U.T)
sR = (s2 / s1) * R
T = c2.reshape(2, 1) - (s2 / s1) * torch.matmul(R, c1.reshape(2, 1))
M = torch.cat((sR, T), dim=1)
if smooth:
bias = points2_normalized[2] - points1_normalized[2]
if p_bias is None:
p_bias = bias
else:
bias = p_bias * 0.2 + bias * 0.8
p_bias = bias
M[:, 2] = M[:, 2] + bias
return M.cpu().numpy(), p_bias
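
The core of `transformation_from_points` is a least-squares similarity fit (rotation, uniform scale, translation) between two point sets. Below is a minimal NumPy sketch of that step only, leaving out the temporal smoothing bias; `similarity_from_points` is a hypothetical name. Note that `np.linalg.svd` returns the transposed right factor `vh`, whereas `torch.svd` in the diff returns `V`.

```python
import numpy as np


def similarity_from_points(src: np.ndarray, dst: np.ndarray) -> np.ndarray:
    """Return a 2x3 matrix M such that dst_i ~= M[:, :2] @ src_i + M[:, 2]."""
    c1, c2 = src.mean(axis=0), dst.mean(axis=0)
    q1, q2 = src - c1, dst - c2            # center both point sets
    s1, s2 = q1.std(), q2.std()            # overall spread of each set
    cov = (q1 / s1).T @ (q2 / s2)          # 2x2 cross-covariance
    u, _, vh = np.linalg.svd(cov)
    r = vh.T @ u.T                         # rotation mapping src onto dst
    if np.linalg.det(r) < 0:               # avoid returning a reflection
        vh[-1, :] *= -1
        r = vh.T @ u.T
    scale = s2 / s1
    t = c2 - scale * (r @ c1)
    return np.hstack([scale * r, t[:, None]])


theta = np.pi / 6
rot = np.array([[np.cos(theta), -np.sin(theta)], [np.sin(theta), np.cos(theta)]])
src = np.array([[0.0, 0.0], [4.0, 0.0], [1.0, 3.0]])
dst = 2.0 * src @ rot.T + np.array([5.0, -3.0])
m = similarity_from_points(src, dst)
expected = np.hstack([2.0 * rot, np.array([[5.0], [-3.0]])])
```

For noise-free, non-collinear points the fit recovers the generating transform, which makes a round trip like the one above a convenient sanity check.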

@@ -0,0 +1,194 @@
# Adapted from https://github.com/Rudrabha/Wav2Lip/blob/master/audio.py
import librosa
import librosa.filters
import numpy as np
from scipy import signal
from scipy.io import wavfile
from omegaconf import OmegaConf
import torch
audio_config_path = "configs/audio.yaml"
config = OmegaConf.load(audio_config_path)
def load_wav(path, sr):
return librosa.core.load(path, sr=sr)[0]
def save_wav(wav, path, sr):
wav *= 32767 / max(0.01, np.max(np.abs(wav)))
# proposed by @dsmiller
wavfile.write(path, sr, wav.astype(np.int16))
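def save_wavenet_wav(wav, path, sr):
# NOTE: librosa.output was removed in librosa 0.8.0; this call only works with older librosa releases.
librosa.output.write_wav(path, wav, sr=sr)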
def preemphasis(wav, k, preemphasize=True):
if preemphasize:
return signal.lfilter([1, -k], [1], wav)
return wav
def inv_preemphasis(wav, k, inv_preemphasize=True):
if inv_preemphasize:
return signal.lfilter([1], [1, -k], wav)
return wav
def get_hop_size():
hop_size = config.audio.hop_size
if hop_size is None:
assert config.audio.frame_shift_ms is not None
hop_size = int(config.audio.frame_shift_ms / 1000 * config.audio.sample_rate)
return hop_size
def linearspectrogram(wav):
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
S = _amp_to_db(np.abs(D)) - config.audio.ref_level_db
if config.audio.signal_normalization:
return _normalize(S)
return S
def melspectrogram(wav):
D = _stft(preemphasis(wav, config.audio.preemphasis, config.audio.preemphasize))
S = _amp_to_db(_linear_to_mel(np.abs(D))) - config.audio.ref_level_db
if config.audio.signal_normalization:
return _normalize(S)
return S
def _lws_processor():
import lws
return lws.lws(config.audio.n_fft, get_hop_size(), fftsize=config.audio.win_size, mode="speech")
def _stft(y):
if config.audio.use_lws:
return _lws_processor().stft(y).T
else:
return librosa.stft(y=y, n_fft=config.audio.n_fft, hop_length=get_hop_size(), win_length=config.audio.win_size)
##########################################################
# These are only correct when using lws! (A mismatch here degraded WaveNet quality for a long time.)
def num_frames(length, fsize, fshift):
"""Compute number of time frames of spectrogram"""
pad = fsize - fshift
if length % fshift == 0:
M = (length + pad * 2 - fsize) // fshift + 1
else:
M = (length + pad * 2 - fsize) // fshift + 2
return M
def pad_lr(x, fsize, fshift):
"""Compute left and right padding"""
M = num_frames(len(x), fsize, fshift)
pad = fsize - fshift
T = len(x) + 2 * pad
r = (M - 1) * fshift + fsize - T
return pad, pad + r
##########################################################
# Librosa correct padding
def librosa_pad_lr(x, fsize, fshift):
return 0, (x.shape[0] // fshift + 1) * fshift - x.shape[0]
# Conversions
_mel_basis = None
def _linear_to_mel(spectogram):
global _mel_basis
if _mel_basis is None:
_mel_basis = _build_mel_basis()
return np.dot(_mel_basis, spectogram)
def _build_mel_basis():
assert config.audio.fmax <= config.audio.sample_rate // 2
return librosa.filters.mel(
sr=config.audio.sample_rate,
n_fft=config.audio.n_fft,
n_mels=config.audio.num_mels,
fmin=config.audio.fmin,
fmax=config.audio.fmax,
)
def _amp_to_db(x):
min_level = np.exp(config.audio.min_level_db / 20 * np.log(10))
return 20 * np.log10(np.maximum(min_level, x))
def _db_to_amp(x):
return np.power(10.0, (x) * 0.05)
def _normalize(S):
if config.audio.allow_clipping_in_normalization:
if config.audio.symmetric_mels:
return np.clip(
(2 * config.audio.max_abs_value) * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
- config.audio.max_abs_value,
-config.audio.max_abs_value,
config.audio.max_abs_value,
)
else:
return np.clip(
config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db)),
0,
config.audio.max_abs_value,
)
assert S.max() <= 0 and S.min() - config.audio.min_level_db >= 0
if config.audio.symmetric_mels:
return (2 * config.audio.max_abs_value) * (
(S - config.audio.min_level_db) / (-config.audio.min_level_db)
) - config.audio.max_abs_value
else:
return config.audio.max_abs_value * ((S - config.audio.min_level_db) / (-config.audio.min_level_db))
def _denormalize(D):
if config.audio.allow_clipping_in_normalization:
if config.audio.symmetric_mels:
return (
(np.clip(D, -config.audio.max_abs_value, config.audio.max_abs_value) + config.audio.max_abs_value)
* -config.audio.min_level_db
/ (2 * config.audio.max_abs_value)
) + config.audio.min_level_db
else:
return (
np.clip(D, 0, config.audio.max_abs_value) * -config.audio.min_level_db / config.audio.max_abs_value
) + config.audio.min_level_db
if config.audio.symmetric_mels:
return (
(D + config.audio.max_abs_value) * -config.audio.min_level_db / (2 * config.audio.max_abs_value)
) + config.audio.min_level_db
else:
return (D * -config.audio.min_level_db / config.audio.max_abs_value) + config.audio.min_level_db
def get_melspec_overlap(audio_samples, melspec_length=52):
mel_spec_overlap = melspectrogram(audio_samples.numpy())
mel_spec_overlap = torch.from_numpy(mel_spec_overlap)
i = 0
mel_spec_overlap_list = []
while i + melspec_length < mel_spec_overlap.shape[1] - 3:
mel_spec_overlap_list.append(mel_spec_overlap[:, i : i + melspec_length].unsqueeze(0))
i += 3
mel_spec_overlap = torch.stack(mel_spec_overlap_list)
return mel_spec_overlap
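
The windowing loop in `get_melspec_overlap` can be illustrated without the audio configuration. This is a sketch in plain NumPy under the assumption that the mel spectrogram has shape `(n_mels, n_frames)`; the function name and the `stride` and `margin` parameters are hypothetical and simply expose the two hard-coded `3`s from the loop.

```python
import numpy as np


def overlapping_windows(mel: np.ndarray, window: int = 52, stride: int = 3, margin: int = 3) -> np.ndarray:
    """Cut a (n_mels, n_frames) spectrogram into overlapping windows along time."""
    windows = []
    i = 0
    # Same stopping rule as the loop in the diff: stop `margin` frames before the end.
    while i + window < mel.shape[1] - margin:
        windows.append(mel[:, i : i + window][None])  # add a leading channel axis
        i += stride
    # Like torch.stack, np.stack raises if no window fits.
    return np.stack(windows)


mel = np.arange(80 * 64, dtype=np.float32).reshape(80, 64)
windows = overlapping_windows(mel)
```

With 64 frames and a 52-frame window, the loop accepts start offsets 0, 3 and 6, so the result has shape `(3, 1, 80, 52)`.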

@@ -0,0 +1,157 @@
# We modified decord's original AVReader class to fix a memory leak.
# For more details, refer to: https://github.com/dmlc/decord/issues/208
import numpy as np
from decord.video_reader import VideoReader
from decord.audio_reader import AudioReader
from decord.ndarray import cpu
from decord import ndarray as _nd
from decord.bridge import bridge_out
class AVReader(object):
"""Individual audio video reader with convenient indexing function.
Parameters
----------
uri: str
Path of file.
ctx: decord.Context
The context to decode the file, can be decord.cpu() or decord.gpu().
sample_rate: int, default is -1
Desired output sample rate of the audio, unchanged if `-1` is specified.
mono: bool, default is True
Desired output channel layout of the audio. `True` is mono layout. `False` is unchanged.
width : int, default is -1
Desired output width of the video, unchanged if `-1` is specified.
height : int, default is -1
Desired output height of the video, unchanged if `-1` is specified.
num_threads : int, default is 0
Number of decoding threads, auto if `0` is specified.
fault_tol : int, default is -1
The threshold of corrupted and recovered frames. This is to prevent silent fault
tolerance when for example 50% frames of a video cannot be decoded and duplicate
frames are returned. You may find the fault tolerant feature sweet in many cases,
but not for training models. Say `N = # recovered frames`
If `fault_tol` < 0, nothing will happen.
If 0 < `fault_tol` < 1.0, if N > `fault_tol * len(video)`, raise `DECORDLimitReachedError`.
If 1 < `fault_tol`, if N > `fault_tol`, raise `DECORDLimitReachedError`.
"""
def __init__(
self, uri, ctx=cpu(0), sample_rate=44100, mono=True, width=-1, height=-1, num_threads=0, fault_tol=-1
):
self.__audio_reader = AudioReader(uri, ctx, sample_rate, mono)
self.__audio_reader.add_padding()
if hasattr(uri, "read"):
uri.seek(0)
self.__video_reader = VideoReader(uri, ctx, width, height, num_threads, fault_tol)
self.__video_reader.seek(0)
def __len__(self):
"""Get length of the video. FFMPEG sometimes reports an inaccurate number of frames;
we always follow what FFMPEG reports.
Returns
-------
int
The number of frames in the video file.
"""
return len(self.__video_reader)
def __getitem__(self, idx):
"""Get audio samples and video frame at `idx`.
Parameters
----------
idx : int or slice
The frame index, can be negative which means it will index backwards,
or slice of frame indices.
Returns
-------
(ndarray/list of ndarray, ndarray)
First element is samples of shape CxS or a list of length N containing samples of shape CxS,
where N is the number of frames, C is the number of channels,
S is the number of samples of the corresponding frame.
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
where N is the length of the slice.
"""
assert self.__video_reader is not None and self.__audio_reader is not None
if isinstance(idx, slice):
return self.get_batch(range(*idx.indices(len(self.__video_reader))))
if idx < 0:
idx += len(self.__video_reader)
if idx >= len(self.__video_reader) or idx < 0:
raise IndexError("Index: {} out of bound: {}".format(idx, len(self.__video_reader)))
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
results = (self.__audio_reader[audio_start_idx:audio_end_idx], self.__video_reader[idx])
self.__video_reader.seek(0)
return results
def get_batch(self, indices):
"""Get entire batch of audio samples and video frames.
Parameters
----------
indices : list of integers
A list of frame indices. Negative indices are counted from the end of the video.
Returns
-------
(list of ndarray, ndarray)
First element is a list of length N containing samples of shape CxS,
where N is the number of frames, C is the number of channels,
S is the number of samples of the corresponding frame.
Second element is Frame of shape HxWx3 or batch of image frames with shape NxHxWx3,
where N is the length of the slice.
"""
assert self.__video_reader is not None and self.__audio_reader is not None
indices = self._validate_indices(indices)
audio_arr = []
prev_video_idx = None
prev_audio_end_idx = None
for idx in list(indices):
frame_start_time, frame_end_time = self.__video_reader.get_frame_timestamp(idx)
# timestamp and sample conversion could have some error that could cause non-continuous audio
# we detect if retrieving continuous frame and make the audio continuous
if prev_video_idx is not None and idx == prev_video_idx + 1:
audio_start_idx = prev_audio_end_idx
else:
audio_start_idx = self.__audio_reader._time_to_sample(frame_start_time)
audio_end_idx = self.__audio_reader._time_to_sample(frame_end_time)
audio_arr.append(self.__audio_reader[audio_start_idx:audio_end_idx])
prev_video_idx = idx
prev_audio_end_idx = audio_end_idx
results = (audio_arr, self.__video_reader.get_batch(indices))
self.__video_reader.seek(0)
return results
def _get_slice(self, sl):
audio_arr = np.empty(shape=(self.__audio_reader.shape()[0], 0), dtype="float32")
for idx in list(sl):
audio_start_idx, audio_end_idx = self.__video_reader.get_frame_timestamp(idx)
audio_start_idx = self.__audio_reader._time_to_sample(audio_start_idx)
audio_end_idx = self.__audio_reader._time_to_sample(audio_end_idx)
audio_arr = np.concatenate(
(audio_arr, self.__audio_reader[audio_start_idx:audio_end_idx].asnumpy()), axis=1
)
results = (bridge_out(_nd.array(audio_arr)), self.__video_reader.get_batch(sl))
self.__video_reader.seek(0)
return results
def _validate_indices(self, indices):
"""Validate int64 integers and convert negative integers to positive by backward search"""
assert self.__video_reader is not None and self.__audio_reader is not None
indices = np.array(indices, dtype=np.int64)
# process negative indices
indices[indices < 0] += len(self.__video_reader)
if not (indices >= 0).all():
raise IndexError("Invalid negative indices: {}".format(indices[indices < 0] + len(self.__video_reader)))
if not (indices < len(self.__video_reader)).all():
raise IndexError("Out of bound indices: {}".format(indices[indices >= len(self.__video_reader)]))
return indices
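
The audio-continuity rule in `get_batch` (reuse the previous frame's end sample when frames are consecutive) can be sketched without decord. Everything below is illustrative: `time_to_sample` is a hypothetical stand-in for the reader's own conversion, and the rounding it uses is an assumption.

```python
import math
from typing import List, Sequence, Tuple


def time_to_sample(t: float, sample_rate: int) -> int:
    # Hypothetical conversion; the real reader may round differently.
    return math.ceil(t * sample_rate)


def frame_audio_ranges(
    timestamps: Sequence[Tuple[float, float]], indices: Sequence[int], sample_rate: int
) -> List[Tuple[int, int]]:
    """Map each requested frame to a (start_sample, end_sample) audio range."""
    ranges = []
    prev_idx = None
    prev_end = None
    for idx in indices:
        start_t, end_t = timestamps[idx]
        if prev_idx is not None and idx == prev_idx + 1:
            # Consecutive frame: continue exactly where the previous range ended.
            start = prev_end
        else:
            start = time_to_sample(start_t, sample_rate)
        end = time_to_sample(end_t, sample_rate)
        ranges.append((start, end))
        prev_idx, prev_end = idx, end
    return ranges


# 16 fps video, so each frame lasts 0.0625 s (exactly representable in binary).
timestamps = [(i * 0.0625, (i + 1) * 0.0625) for i in range(4)]
ranges = frame_audio_ranges(timestamps, [0, 1, 3], sample_rate=16000)
```

Comparing `prev_idx` against `None` rather than testing its truthiness matters here, because frame index `0` is falsy in Python.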

@@ -0,0 +1,115 @@
from insightface.app import FaceAnalysis
import numpy as np
import torch
INSIGHTFACE_DETECT_SIZE = 512
class FaceDetector:
def __init__(self, device="cuda"):
self.app = FaceAnalysis(
allowed_modules=["detection", "landmark_2d_106"],
root="checkpoints/auxiliary",
providers=["CUDAExecutionProvider"],
)
self.app.prepare(ctx_id=cuda_to_int(device), det_size=(INSIGHTFACE_DETECT_SIZE, INSIGHTFACE_DETECT_SIZE))
def __call__(self, frame, threshold=0.5):
f_h, f_w, _ = frame.shape
faces = self.app.get(frame)
get_face_store = None
max_size = 0
if len(faces) == 0:
return None, None
else:
for face in faces:
bbox = face.bbox.astype(np.int_).tolist()
w, h = bbox[2] - bbox[0], bbox[3] - bbox[1]
if w < 50 or h < 80:
continue
if w / h > 1.5 or w / h < 0.2:
continue
if face.det_score < threshold:
continue
size_now = w * h
if size_now > max_size:
max_size = size_now
get_face_store = face
if get_face_store is None:
return None, None
else:
face = get_face_store
lmk = np.round(face.landmark_2d_106).astype(np.int_)
halk_face_coord = np.mean([lmk[74], lmk[73]], axis=0) # lmk[73]
sub_lmk = lmk[LMK_ADAPT_ORIGIN_ORDER]
halk_face_dist = np.max(sub_lmk[:, 1]) - halk_face_coord[1]
upper_bond = halk_face_coord[1] - halk_face_dist # *0.94
x1, y1, x2, y2 = (np.min(sub_lmk[:, 0]), int(upper_bond), np.max(sub_lmk[:, 0]), np.max(sub_lmk[:, 1]))
if y2 - y1 <= 0 or x2 - x1 <= 0 or x1 < 0:
x1, y1, x2, y2 = face.bbox.astype(np.int_).tolist()
y2 += int((x2 - x1) * 0.1)
x1 -= int((x2 - x1) * 0.05)
x2 += int((x2 - x1) * 0.05)
x1 = max(0, x1)
y1 = max(0, y1)
x2 = min(f_w, x2)
y2 = min(f_h, y2)
return (x1, y1, x2, y2), lmk
def cuda_to_int(cuda_str: str) -> int:
"""
Convert the string with format "cuda:X" to integer X.
"""
if cuda_str == "cuda":
return 0
device = torch.device(cuda_str)
if device.type != "cuda":
raise ValueError(f"Device type must be 'cuda', got: {device.type}")
return device.index
LMK_ADAPT_ORIGIN_ORDER = [
1,
10,
12,
14,
16,
3,
5,
7,
0,
23,
21,
19,
32,
30,
28,
26,
17,
43,
48,
49,
51,
50,
102,
103,
104,
105,
101,
73,
74,
86,
]
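
The face-selection rule in `FaceDetector.__call__` (discard small, oddly shaped, or low-confidence boxes, then keep the largest) can be exercised without insightface. The sketch below reuses the thresholds from the diff; the `Detection` dataclass and `pick_largest_face` are hypothetical stand-ins for insightface's face objects and are not part of the file.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class Detection:
    """Hypothetical minimal face record: box corners and detector confidence."""

    bbox: Tuple[int, int, int, int]  # x1, y1, x2, y2
    score: float


def pick_largest_face(dets: List[Detection], threshold: float = 0.5) -> Optional[Detection]:
    best, best_area = None, 0
    for det in dets:
        x1, y1, x2, y2 = det.bbox
        w, h = x2 - x1, y2 - y1
        if w < 50 or h < 80:                # too small to be usable
            continue
        if w / h > 1.5 or w / h < 0.2:      # implausible aspect ratio
            continue
        if det.score < threshold:           # low detector confidence
            continue
        if w * h > best_area:
            best, best_area = det, w * h
    return best


dets = [
    Detection((0, 0, 40, 90), 0.9),     # rejected: too narrow
    Detection((0, 0, 100, 120), 0.9),   # accepted
    Detection((0, 0, 200, 220), 0.3),   # rejected: low score
    Detection((0, 0, 300, 100), 0.9),   # rejected: aspect ratio 3.0
]
best = pick_largest_face(dets)
```

Returning `None` when nothing passes the filters mirrors the `(None, None)` result in the diff, so callers have to handle the no-face case explicitly.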

@@ -0,0 +1,122 @@
# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from latentsync.utils.util import read_video, write_video
from torchvision import transforms
import cv2
from einops import rearrange
import torch
import numpy as np
from typing import Union
from .affine_transform import AlignRestore
from .face_detector import FaceDetector
def load_fixed_mask(resolution: int, mask_image_path="latentsync/utils/mask.png") -> torch.Tensor:
mask_image = cv2.imread(mask_image_path)
mask_image = cv2.cvtColor(mask_image, cv2.COLOR_BGR2RGB)
mask_image = cv2.resize(mask_image, (resolution, resolution), interpolation=cv2.INTER_LANCZOS4) / 255.0
mask_image = rearrange(torch.from_numpy(mask_image), "h w c -> c h w")
return mask_image
class ImageProcessor:
def __init__(self, resolution: int = 512, device: str = "cpu", mask_image=None):
self.resolution = resolution
self.resize = transforms.Resize(
(resolution, resolution), interpolation=transforms.InterpolationMode.BICUBIC, antialias=True
)
self.normalize = transforms.Normalize([0.5], [0.5], inplace=True)
self.restorer = AlignRestore(resolution=resolution, device=device)
if mask_image is None:
self.mask_image = load_fixed_mask(resolution)
else:
self.mask_image = mask_image
if device == "cpu":
self.face_detector = None
else:
self.face_detector = FaceDetector(device=device)
def affine_transform(self, image: np.ndarray) -> tuple:  # returns (face, box, affine_matrix)
if self.face_detector is None:
raise NotImplementedError("Using the CPU for face detection is not supported")
bbox, landmark_2d_106 = self.face_detector(image)
if bbox is None:
raise RuntimeError("Face not detected")
pt_left_eye = np.mean(landmark_2d_106[[43, 48, 49, 51, 50]], axis=0) # left eyebrow center
pt_right_eye = np.mean(landmark_2d_106[101:106], axis=0) # right eyebrow center
pt_nose = np.mean(landmark_2d_106[[74, 77, 83, 86]], axis=0) # nose center
landmarks3 = np.round([pt_left_eye, pt_right_eye, pt_nose])
face, affine_matrix = self.restorer.align_warp_face(image.copy(), landmarks3=landmarks3, smooth=True)
box = [0, 0, face.shape[1], face.shape[0]] # x1, y1, x2, y2
face = cv2.resize(face, (self.resolution, self.resolution), interpolation=cv2.INTER_LANCZOS4)
face = rearrange(torch.from_numpy(face), "h w c -> c h w")
return face, box, affine_matrix
def preprocess_fixed_mask_image(self, image: torch.Tensor, affine_transform=False):
if affine_transform:
image, _, _ = self.affine_transform(image)
else:
image = self.resize(image)
pixel_values = self.normalize(image / 255.0)
masked_pixel_values = pixel_values * self.mask_image
return pixel_values, masked_pixel_values, self.mask_image[0:1]
def prepare_masks_and_masked_images(self, images: Union[torch.Tensor, np.ndarray], affine_transform=False):
if isinstance(images, np.ndarray):
images = torch.from_numpy(images)
if images.shape[3] == 3:
images = rearrange(images, "f h w c -> f c h w")
results = [self.preprocess_fixed_mask_image(image, affine_transform=affine_transform) for image in images]
pixel_values_list, masked_pixel_values_list, masks_list = list(zip(*results))
return torch.stack(pixel_values_list), torch.stack(masked_pixel_values_list), torch.stack(masks_list)
def process_images(self, images: Union[torch.Tensor, np.ndarray]):
if isinstance(images, np.ndarray):
images = torch.from_numpy(images)
if images.shape[3] == 3:
images = rearrange(images, "f h w c -> f c h w")
images = self.resize(images)
pixel_values = self.normalize(images / 255.0)
return pixel_values
class VideoProcessor:
def __init__(self, resolution: int = 512, device: str = "cpu"):
self.image_processor = ImageProcessor(resolution, device)
def affine_transform_video(self, video_path):
video_frames = read_video(video_path, change_fps=False)
results = []
for frame in video_frames:
frame, _, _ = self.image_processor.affine_transform(frame)
results.append(frame)
results = torch.stack(results)
results = rearrange(results, "f c h w -> f h w c").numpy()
return results
if __name__ == "__main__":
video_processor = VideoProcessor(256, "cuda")
video_frames = video_processor.affine_transform_video("assets/demo2_video.mp4")
write_video("output.mp4", video_frames, fps=25)
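
The pixel arithmetic used by `ImageProcessor` (scale to `[0, 1]`, normalize with mean 0.5 and standard deviation 0.5, multiply by a fixed mask) and the inverse mapping used in `restore_img` can be sketched in NumPy. This is an illustration under the assumption of channel-first `uint8` images; the three function names are hypothetical.

```python
import numpy as np


def normalize(image_u8: np.ndarray) -> np.ndarray:
    """Map uint8 pixels in [0, 255] to float32 values in [-1, 1]."""
    return (image_u8.astype(np.float32) / 255.0 - 0.5) / 0.5


def denormalize(pixel_values: np.ndarray) -> np.ndarray:
    """Inverse mapping, as in restore_img: [-1, 1] back to [0, 255]."""
    return np.clip(pixel_values / 2 + 0.5, 0, 1) * 255


def apply_mask(pixel_values: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Zero out the masked region (mask value 0) in normalized space."""
    return pixel_values * mask


image = np.full((3, 2, 2), 255, dtype=np.uint8)
image[:, 0, 0] = 0                       # one black pixel
mask = np.ones((3, 2, 2), dtype=np.float32)
mask[:, 1, :] = 0                        # mask out the bottom row

pixel_values = normalize(image)
masked = apply_mask(pixel_values, mask)
restored = denormalize(pixel_values)
```

Since the mask is applied after normalization, masked pixels become `0`, which corresponds to mid-gray rather than black once mapped back to `[0, 255]`.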

Binary file not shown. (Size: 1.8 KiB)

Binary file not shown. (Size: 1.2 KiB)

Binary file not shown. (Size: 1.1 KiB)

Binary file not shown. (Size: 1.2 KiB)

Some files were not shown because too many files have changed in this diff.