Compare commits
12 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4a3dd2b225 | ||
|
|
ee8cb9cfd2 | ||
|
|
c6c4b2313f | ||
|
|
f99bd336c9 | ||
|
|
c918dc6faf | ||
|
|
3a3df41904 | ||
|
|
561d74e16d | ||
|
|
cfe21d8337 | ||
|
|
3a76f9d0cf | ||
|
|
ad7ff7a385 | ||
|
|
c7e2b4d363 | ||
|
|
d5baa79448 |
@@ -27,12 +27,18 @@ node --version
|
||||
|
||||
# 检查 FFmpeg
|
||||
ffmpeg -version
|
||||
|
||||
# 检查 pm2 (用于服务管理)
|
||||
pm2 --version
|
||||
```
|
||||
|
||||
如果缺少 FFmpeg:
|
||||
如果缺少依赖:
|
||||
```bash
|
||||
sudo apt update
|
||||
sudo apt install ffmpeg
|
||||
|
||||
# 安装 pm2
|
||||
npm install -g pm2
|
||||
```
|
||||
|
||||
---
|
||||
@@ -48,7 +54,30 @@ cd /home/rongye/ProgramFiles/ViGent2
|
||||
|
||||
---
|
||||
|
||||
## 步骤 3: 安装后端依赖
|
||||
## 步骤 3: 部署 AI 模型 (LatentSync 1.6)
|
||||
|
||||
> ⚠️ **重要**:LatentSync 需要独立的 Conda 环境和 **~18GB VRAM**。请**不要**直接安装在后端环境中。
|
||||
|
||||
请参考详细的独立部署指南:
|
||||
**[LatentSync 部署指南](../models/LatentSync/DEPLOY.md)**
|
||||
|
||||
该指南包含以下关键步骤,请务必严格按照文档操作:
|
||||
1. 创建独立的 `latentsync` Conda 环境
|
||||
2. 安装 PyTorch 2.5.1 和相关依赖
|
||||
3. 下载模型权重 (HuggingFace CLI)
|
||||
4. 复制核心推理代码
|
||||
5. 验证推理脚本
|
||||
|
||||
**验证 LatentSync 部署**:
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
|
||||
conda activate latentsync
|
||||
python -m scripts.server # 测试能否启动,Ctrl+C 退出
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 4: 安装后端依赖
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/backend
|
||||
@@ -69,51 +98,54 @@ playwright install chromium
|
||||
|
||||
---
|
||||
|
||||
## 步骤 4: 部署 AI 模型 (LatentSync 1.6)
|
||||
## 步骤 5: 部署用户认证系统 (Supabase + Auth)
|
||||
|
||||
> ⚠️ **重要**:LatentSync 需要独立的 Conda 环境和 **~18GB VRAM**。请**不要**直接安装在后端环境中。
|
||||
> 🔐 **包含**: 登录/注册、Supabase 数据库配置、JWT 认证、管理员后台
|
||||
|
||||
请参考详细的独立部署指南:
|
||||
**[LatentSync 部署指南](../models/LatentSync/DEPLOY.md)**
|
||||
|
||||
该指南包含以下关键步骤,请务必严格按照文档操作:
|
||||
1. 创建独立的 `latentsync` Conda 环境
|
||||
2. 安装 PyTorch 2.5.1 和相关依赖
|
||||
3. 下载模型权重 (HuggingFace CLI)
|
||||
4. 复制核心推理代码
|
||||
5. 验证推理脚本
|
||||
|
||||
确保 LatentSync 部署成功后,再继续后续步骤。
|
||||
请参考独立的认证系统部署指南:
|
||||
**[用户认证系统部署指南](AUTH_DEPLOY.md)**
|
||||
|
||||
---
|
||||
|
||||
## 步骤 5: 启动 LatentSync 常驻加速服务 (可选)
|
||||
## 步骤 6: 配置 Supabase RLS 策略 (重要)
|
||||
|
||||
> ⚠️ **注意**:为了支持前端直传文件,必须配置存储桶的行级安全策略 (RLS)。
|
||||
|
||||
1. 确保 Supabase 容器正在运行 (`docker ps`).
|
||||
2. 将项目根目录下的 `supabase_rls.sql` (如果有) 或以下 SQL 内容在数据库中执行。
|
||||
3. **执行命令**:
|
||||
```bash
|
||||
# 进入后端目录
|
||||
cd /home/rongye/ProgramFiles/ViGent2/backend
|
||||
|
||||
# 执行 SQL (允许 anon 角色上传/读取 materials 桶)
|
||||
docker exec -i supabase-db psql -U postgres <<EOF
|
||||
INSERT INTO storage.buckets (id, name, public) VALUES ('materials', 'materials', true) ON CONFLICT (id) DO NOTHING;
|
||||
INSERT INTO storage.buckets (id, name, public) VALUES ('outputs', 'outputs', true) ON CONFLICT (id) DO NOTHING;
|
||||
CREATE POLICY "Allow public uploads" ON storage.objects FOR INSERT TO anon WITH CHECK (bucket_id = 'materials');
|
||||
CREATE POLICY "Allow public read" ON storage.objects FOR SELECT TO anon USING (bucket_id = 'materials' OR bucket_id = 'outputs');
|
||||
EOF
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 7: 配置环境变量
|
||||
|
||||
为了消除每次生成视频时的 30-40秒 模型加载时间,建议启动常驻服务:
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
|
||||
|
||||
# 后台启动服务 (自动读取 backend/.env 中的 GPU 配置)
|
||||
nohup python -m scripts.server > server.log 2>&1 &
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 7: 配置环境变量
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/backend
|
||||
|
||||
# 复制配置模板 (默认配置已经就绪)
|
||||
# 复制配置模板
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
> 💡 **说明**:`.env.example` 已包含正确的 LatentSync 默认配置,直接复制即可使用。
|
||||
> 💡 **说明**:`.env.example` 已包含正确的默认配置,直接复制即可使用。
|
||||
> 如需自定义,可编辑 `.env` 修改以下参数:
|
||||
|
||||
| 配置项 | 默认值 | 说明 |
|
||||
|--------|--------|------|
|
||||
| `SUPABASE_URL` | `http://localhost:8008` | Supabase API 内部地址 |
|
||||
| `SUPABASE_PUBLIC_URL` | `https://api.hbyrkj.top` | Supabase API 公网地址 (前端访问) |
|
||||
| `LATENTSYNC_GPU_ID` | 1 | GPU 选择 (0 或 1) |
|
||||
| `LATENTSYNC_USE_SERVER` | false | 设为 true 以启用常驻服务加速 |
|
||||
| `LATENTSYNC_INFERENCE_STEPS` | 20 | 推理步数 (20-50) |
|
||||
@@ -129,13 +161,18 @@ cd /home/rongye/ProgramFiles/ViGent2/frontend
|
||||
|
||||
# 安装依赖
|
||||
npm install
|
||||
|
||||
# 生产环境构建 (可选)
|
||||
npm run build
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 9: 测试运行
|
||||
|
||||
### 启动后端
|
||||
> 💡 先手动启动测试,确认一切正常后再配置 pm2 常驻服务。
|
||||
|
||||
### 启动后端 (终端 1)
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/backend
|
||||
@@ -143,16 +180,22 @@ source venv/bin/activate
|
||||
uvicorn app.main:app --host 0.0.0.0 --port 8006
|
||||
```
|
||||
|
||||
### 启动前端 (新开终端)
|
||||
### 启动前端 (终端 2)
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/frontend
|
||||
npm run dev -- -H 0.0.0.0 --port 3002
|
||||
```
|
||||
|
||||
---
|
||||
### 启动 LatentSync (终端 3, 可选加速)
|
||||
|
||||
## 步骤 10: 验证
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
|
||||
conda activate latentsync
|
||||
python -m scripts.server
|
||||
```
|
||||
|
||||
### 验证
|
||||
|
||||
1. 访问 http://服务器IP:3002 查看前端
|
||||
2. 访问 http://服务器IP:8006/docs 查看 API 文档
|
||||
@@ -160,59 +203,158 @@ npm run dev -- -H 0.0.0.0 --port 3002
|
||||
|
||||
---
|
||||
|
||||
## 使用 systemd 管理服务 (可选)
|
||||
## 步骤 10: 使用 pm2 管理常驻服务
|
||||
|
||||
### 后端服务
|
||||
> 推荐使用 pm2 管理所有服务,支持自动重启和日志管理。
|
||||
|
||||
创建 `/etc/systemd/system/vigent2-backend.service`:
|
||||
```ini
|
||||
[Unit]
|
||||
Description=ViGent2 Backend API
|
||||
After=network.target
|
||||
### 1. 启动后端服务 (FastAPI)
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=rongye
|
||||
WorkingDirectory=/home/rongye/ProgramFiles/ViGent2/backend
|
||||
Environment="PATH=/home/rongye/ProgramFiles/ViGent2/backend/venv/bin"
|
||||
ExecStart=/home/rongye/ProgramFiles/ViGent2/backend/venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8006
|
||||
Restart=always
|
||||
建议使用 Shell 脚本启动以避免环境问题。
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
1. 创建启动脚本 `run_backend.sh`:
|
||||
```bash
|
||||
cat > run_backend.sh << 'EOF'
|
||||
#!/bin/bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/backend
|
||||
./venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8006
|
||||
EOF
|
||||
chmod +x run_backend.sh
|
||||
```
|
||||
|
||||
### 前端服务
|
||||
|
||||
创建 `/etc/systemd/system/vigent2-frontend.service`:
|
||||
```ini
|
||||
[Unit]
|
||||
Description=ViGent2 Frontend
|
||||
After=network.target
|
||||
|
||||
[Service]
|
||||
Type=simple
|
||||
User=rongye
|
||||
WorkingDirectory=/home/rongye/ProgramFiles/ViGent2/frontend
|
||||
ExecStart=/usr/bin/npm run start
|
||||
Restart=always
|
||||
|
||||
[Install]
|
||||
WantedBy=multi-user.target
|
||||
2. 使用 pm2 启动:
|
||||
```bash
|
||||
pm2 start ./run_backend.sh --name vigent2-backend
|
||||
```
|
||||
|
||||
### 启用服务
|
||||
### 2. 启动前端服务 (Next.js)
|
||||
|
||||
⚠️ **注意**:生产模式启动前必须先进行构建。
|
||||
|
||||
```bash
|
||||
sudo systemctl daemon-reload
|
||||
sudo systemctl enable vigent2-backend vigent2-frontend
|
||||
sudo systemctl start vigent2-backend vigent2-frontend
|
||||
cd /home/rongye/ProgramFiles/ViGent2/frontend
|
||||
|
||||
# 1. 构建项目 (如果之前没跑过或代码有更新)
|
||||
npm run build
|
||||
|
||||
# 2. 启动服务
|
||||
pm2 start npm --name vigent2-frontend -- run start -- -p 3002
|
||||
```
|
||||
|
||||
### 3. 启动 LatentSync 模型服务
|
||||
|
||||
1. 创建启动脚本 `run_latentsync.sh` (使用你的 conda python 路径):
|
||||
```bash
|
||||
cat > run_latentsync.sh << 'EOF'
|
||||
#!/bin/bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/LatentSync
|
||||
# 替换为你的实际 Python 路径
|
||||
/home/rongye/ProgramFiles/miniconda3/envs/latentsync/bin/python -m scripts.server
|
||||
EOF
|
||||
chmod +x run_latentsync.sh
|
||||
```
|
||||
|
||||
2. 使用 pm2 启动:
|
||||
```bash
|
||||
pm2 start ./run_latentsync.sh --name vigent2-latentsync
|
||||
```
|
||||
|
||||
### 4. 保存当前列表 (开机自启)
|
||||
|
||||
```bash
|
||||
pm2 save
|
||||
pm2 startup
|
||||
```
|
||||
|
||||
### pm2 常用命令
|
||||
|
||||
```bash
|
||||
pm2 status # 查看所有服务状态
|
||||
pm2 logs # 查看所有日志
|
||||
pm2 logs vigent2-backend # 查看后端日志
|
||||
pm2 restart all # 重启所有服务
|
||||
pm2 stop vigent2-latentsync # 停止 LatentSync 服务
|
||||
pm2 delete all # 删除所有服务
|
||||
```
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 步骤 11: 配置 Nginx HTTPS (可选 - 公网访问)
|
||||
|
||||
如果您需要通过公网域名 HTTPS 访问 (如 `https://vigent.hbyrkj.top`),请参考以下 Nginx 配置。
|
||||
|
||||
**前置条件**:
|
||||
1. 已申请 SSL 证书 (如 Let's Encrypt)。
|
||||
2. 使用 FRP 或其他方式将本地 3002 端口映射到服务器。
|
||||
|
||||
**配置示例** (`/etc/nginx/conf.d/vigent.conf`):
|
||||
|
||||
```nginx
|
||||
server {
|
||||
listen 80;
|
||||
server_name your.domain.com;
|
||||
return 301 https://$host$request_uri;
|
||||
}
|
||||
|
||||
server {
|
||||
listen 443 ssl http2;
|
||||
server_name your.domain.com;
|
||||
|
||||
ssl_certificate /path/to/fullchain.pem;
|
||||
ssl_certificate_key /path/to/privkey.pem;
|
||||
|
||||
location / {
|
||||
proxy_pass http://127.0.0.1:3002; # 转发给 Next.js 前端
|
||||
|
||||
# 必须配置 WebSocket 支持,否则热更和即时通信失效
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
|
||||
proxy_set_header Host $host;
|
||||
proxy_set_header X-Real-IP $remote_addr;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
|
||||
---
|
||||
|
||||
## 步骤 12: 配置阿里云 Nginx 网关 (关键)
|
||||
|
||||
> ⚠️ **CRITICAL**: 如果使用 `api.hbyrkj.top` 等域名作为入口,必须在阿里云 (或公网入口) 的 Nginx 配置中解除上传限制。
|
||||
> **这是导致 500/413 错误的核心原因。**
|
||||
|
||||
**关键配置项**:
|
||||
```nginx
|
||||
server {
|
||||
listen 443 ssl;
|
||||
server_name api.hbyrkj.top;
|
||||
|
||||
# ... 其他 SSL 配置 ...
|
||||
|
||||
# 允许大文件上传 (0 表示不限制,或设置为 100M, 500M)
|
||||
client_max_body_size 0;
|
||||
|
||||
location / {
|
||||
proxy_pass http://127.0.0.1:YOUR_FRP_PORT;
|
||||
|
||||
# 延长超时时间
|
||||
proxy_read_timeout 600s;
|
||||
proxy_send_timeout 600s;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
**后果**:如果没有这个配置,上传会在 ~1MB 或 ~10MB 时直接断开,报 413 Payload Too Large 或 500/502 错误。
|
||||
|
||||
---
|
||||
|
||||
## 故障排除
|
||||
|
||||
|
||||
### GPU 不可用
|
||||
|
||||
```bash
|
||||
@@ -227,14 +369,58 @@ python3 -c "import torch; print(torch.cuda.is_available())"
|
||||
# 查看端口占用
|
||||
sudo lsof -i :8006
|
||||
sudo lsof -i :3002
|
||||
sudo lsof -i :8007
|
||||
```
|
||||
|
||||
### 查看日志
|
||||
|
||||
```bash
|
||||
# 后端日志
|
||||
journalctl -u vigent2-backend -f
|
||||
|
||||
# 前端日志
|
||||
journalctl -u vigent2-frontend -f
|
||||
# pm2 日志
|
||||
pm2 logs vigent2-backend
|
||||
pm2 logs vigent2-frontend
|
||||
pm2 logs vigent2-latentsync
|
||||
```
|
||||
|
||||
### SSH 连接卡顿 / 系统响应慢
|
||||
|
||||
**原因**:LatentSync 模型服务启动时会占用大量 I/O 和 CPU 资源,或者模型加载到 GPU 时导致瞬时负载过高。
|
||||
|
||||
**解决**:
|
||||
1. 检查系统负载:`top` 或 `htop`
|
||||
2. 如果不需要实时生成视频,可以暂时停止 LatentSync 服务:
|
||||
```bash
|
||||
pm2 stop vigent2-latentsync
|
||||
```
|
||||
3. 确保服务器有足够的 RAM 和 Swap 空间。
|
||||
4. **代码级优化**:已在 `scripts/server.py` 和 `scripts/inference.py` 中强制限制 `OMP_NUM_THREADS=8`,防止 PyTorch 占用所有 CPU 核心导致系统假死。
|
||||
|
||||
---
|
||||
|
||||
## 依赖清单
|
||||
|
||||
### 后端关键依赖
|
||||
|
||||
| 依赖 | 用途 |
|
||||
|------|------|
|
||||
| `fastapi` | Web API 框架 |
|
||||
| `uvicorn` | ASGI 服务器 |
|
||||
| `edge-tts` | 微软 TTS 配音 |
|
||||
| `playwright` | 社交媒体自动发布 |
|
||||
| `biliup` | B站视频上传 |
|
||||
| `loguru` | 日志管理 |
|
||||
|
||||
### 前端关键依赖
|
||||
|
||||
| 依赖 | 用途 |
|
||||
|------|------|
|
||||
| `next` | React 框架 |
|
||||
| `swr` | 数据请求与缓存 |
|
||||
| `tailwindcss` | CSS 样式 |
|
||||
|
||||
### LatentSync 关键依赖
|
||||
|
||||
| 依赖 | 用途 |
|
||||
|------|------|
|
||||
| `torch` 2.5.1 | PyTorch GPU 推理 |
|
||||
| `diffusers` | Latent Diffusion 模型 |
|
||||
| `accelerate` | 模型加速 |
|
||||
|
||||
122
Docs/DevLogs/Day10.md
Normal file
122
Docs/DevLogs/Day10.md
Normal file
@@ -0,0 +1,122 @@
|
||||
---
|
||||
|
||||
## 🔧 隧道访问与视频播放修复 (11:00)
|
||||
|
||||
### 问题描述
|
||||
在通过 FRP 隧道 (如 `http://8.148.x.x:3002`) 访问时发现:
|
||||
1. **视频无法播放**:后端返回 404 (Not Found)。
|
||||
2. **发布页账号列表为空**:后端返回 500 (Internal Server Error)。
|
||||
|
||||
### 解决方案
|
||||
|
||||
#### 1. 视频播放修复
|
||||
- **后端 (`main.py`)**:这是根源问题。后端缺少 `uploads` 目录的挂载,导致静态资源无法访问。
|
||||
```python
|
||||
app.mount("/uploads", StaticFiles(directory=str(settings.UPLOAD_DIR)), name="uploads")
|
||||
```
|
||||
- **前端 (`next.config.ts`)**:添加反向代理规则,将 `/outputs` 和 `/uploads` 转发到后端端口 8006。
|
||||
```typescript
|
||||
{
|
||||
source: '/uploads/:path*',
|
||||
destination: 'http://localhost:8006/uploads/:path*',
|
||||
},
|
||||
{
|
||||
source: '/outputs/:path*',
|
||||
destination: 'http://localhost:8006/outputs/:path*',
|
||||
}
|
||||
```
|
||||
|
||||
#### 2. 账号列表 500 错误修复
|
||||
- **根源**:`backend/app/core/paths.py` 中的白名单缺少 `weixin` 和 `kuaishou`。
|
||||
- **现象**:当 `PublishService` 遍历所有平台时,遇到未在白名单的平台直接抛出 `ValueError`,导致整个接口崩溃。
|
||||
- **修复**:更新白名单。
|
||||
```python
|
||||
VALID_PLATFORMS: Set[str] = {"bilibili", "douyin", "xiaohongshu", "weixin", "kuaishou"}
|
||||
```
|
||||
|
||||
### 结果
|
||||
- ✅ 视频预览和历史视频均可正常播放。
|
||||
- ✅ 发布页账号列表恢复显示。
|
||||
|
||||
---
|
||||
|
||||
## 🚀 Nginx HTTPS 部署 (11:30)
|
||||
|
||||
### 需求
|
||||
用户在阿里云服务器上配置了 SSL 证书,需要通过 HTTPS 访问应用。
|
||||
|
||||
### 解决方案
|
||||
提供了 Nginx 配置文件 `nginx_vigent.conf`,配置了:
|
||||
1. **HTTP -> HTTPS 重定向**。
|
||||
2. **SSL 证书路径** (`/etc/letsencrypt/live/vigent.hbyrkj.top/...`)。
|
||||
3. **反向代理** 到本地 FRP 端口 (3002)。
|
||||
4. **WebSocket 支持** (用于 Next.js 热更和通信)。
|
||||
|
||||
### 结果
|
||||
- ✅ 用户可通过 `https://vigent.hbyrkj.top` 安全访问。
|
||||
- ✅ 代码自适应:前端 `API_BASE` 为空字符串,自动适配 HTTPS 协议,无需修改代码。
|
||||
|
||||
---
|
||||
|
||||
## 🎨 UI 细节优化 (11:45)
|
||||
|
||||
### 修改
|
||||
- 修改 `frontend/src/app/layout.tsx` 中的 Metadata。
|
||||
- 标题从 `Create Next App` 改为 `ViGent`。
|
||||
|
||||
### 结果
|
||||
- ✅ 浏览器标签页名称已更新。
|
||||
|
||||
---
|
||||
|
||||
## 🚪 用户登录退出功能 (12:00)
|
||||
|
||||
### 需求
|
||||
用户反馈没有退出的入口。
|
||||
|
||||
### 解决方案
|
||||
- **UI 修改**:在首页和发布管理页面的顶部导航栏添加红色的“退出”按钮 (位于最右侧)。
|
||||
- **逻辑实现**:
|
||||
```javascript
|
||||
onClick={async () => {
|
||||
if (confirm('确定要退出登录吗?')) {
|
||||
await fetch(`${API_BASE}/api/auth/logout`, { method: 'POST' });
|
||||
window.location.href = '/login';
|
||||
}
|
||||
}}
|
||||
```
|
||||
- **部署**:已同步代码并重建前端。
|
||||
|
||||
---
|
||||
|
||||
## 🚢 Supabase 服务部署 (16:10)
|
||||
|
||||
### 需求
|
||||
由于需要多用户隔离和更完善的权限管理,决定从纯本地文件存储迁移到 Supabase BaaS 架构。
|
||||
|
||||
### 实施步骤
|
||||
|
||||
1. **Docker 部署 (Ubuntu)**
|
||||
- 使用官方 `docker-compose.yml`。
|
||||
- **端口冲突解决**:
|
||||
- `Moodist` 占用 4000 -> 迁移 Analytics 到 **4004**。
|
||||
- `code-server` 占用 8443 -> 迁移 Kong HTTPS 到 **8444**。
|
||||
- 自定义端口:Studio (**3003**), API (**8008**)。
|
||||
|
||||
2. **安全加固 (Aliyun Nginx)**
|
||||
- **双域名策略**:
|
||||
- `supabase.hbyrkj.top` -> Studio (3003)
|
||||
- `api.hbyrkj.top` -> API (8008)
|
||||
- **SSL**:配置 Let's Encrypt 证书。
|
||||
- **访问控制**:为 Studio 域名添加 `auth_basic` (htpasswd),防止未授权访问管理后台。
|
||||
- **WebSocket**:Nginx 配置 `Upgrade` 头支持 Realtime 功能。
|
||||
|
||||
3. **数据库初始化**
|
||||
- 使用 `backend/database/schema.sql` 初始化了 `users`, `social_accounts` 等表结构。
|
||||
|
||||
### 下一步计划 (Storage Migration)
|
||||
目前文件仍存储在本地磁盘,无法通过 RLS 进行隔离。
|
||||
**计划改造 LatentSync 流程**:
|
||||
1. 后端集成 Supabase Storage SDK。
|
||||
2. 实现 `Download (Storage) -> Local Process (LatentSync) -> Upload (Storage)` 闭环。
|
||||
3. 前端改为请求 Signed URL 进行播放。
|
||||
278
Docs/DevLogs/Day11.md
Normal file
278
Docs/DevLogs/Day11.md
Normal file
@@ -0,0 +1,278 @@
|
||||
|
||||
## 🔧 上传架构重构 (Direct Upload)
|
||||
|
||||
### 🚨 问题描述 (10:30)
|
||||
**现象**:上传大于 7MB 的文件时,后端返回 500 Internal Server Error,实际为 `ClientDisconnect`。
|
||||
**ROOT CAUSE (关键原因)**:
|
||||
- **Aliyun Nginx 网关限制**:`api.hbyrkj.top` 域名的 Nginx 配置缺少 `client_max_body_size 0;`。
|
||||
- **默认限制**:Nginx 默认限制请求体为 1MB (或少量),导致大文件上传时连接被网关强制截断。
|
||||
- **误判**:初期待查方向集中在 FRP 和 Backend Proxy 超时,实际是网关层的硬限制。
|
||||
|
||||
### ✅ 解决方案:前端直传 Supabase + 网关配置 (14:00)
|
||||
|
||||
**核心变更**:
|
||||
1. **网关配置**:在 Aliyun Nginx 的 `api.hbyrkj.top` 配置块中添加 `client_max_body_size 0;` (解除大小限制)。
|
||||
2. **架构优化**:移除后端文件转发逻辑,改由前端直接上传到 Supabase Storage (减少链路节点)。
|
||||
|
||||
#### 1. 前端改造 (`frontend/src/app/page.tsx`)
|
||||
- 引入 `@supabase/supabase-js` 客户端。
|
||||
- 使用 `supabase.storage.from('materials').upload()` 直接上传。
|
||||
- 移除旧的 `XMLHttpRequest` 代理上传逻辑。
|
||||
- 添加文件重命名策略:`{timestamp}_{sanitized_filename}`。
|
||||
|
||||
```typescript
|
||||
// V2: Direct Upload (Bypass Backend)
|
||||
const { data, error } = await supabase.storage
|
||||
.from('materials')
|
||||
.upload(path, file, {
|
||||
cacheControl: '3600',
|
||||
upsert: false
|
||||
});
|
||||
```
|
||||
|
||||
#### 2. 后端适配 (`backend/app/api/materials.py`)
|
||||
- **上传接口**:(已废弃/保留用于极小文件) 主要流量走直传。
|
||||
- **列表接口**:更新为返回 **签名 URL (Signed URL)**,而非本地路径。
|
||||
- **兼容性**:前端直接接收 `path` 字段为完整 URL,无需再次拼接。
|
||||
|
||||
#### 3. 权限控制 (RLS)
|
||||
- Supabase 默认禁止匿名写入。
|
||||
- 执行 SQL 策略允许 `anon` 角色对 `materials` 桶的 `INSERT` 和 `SELECT` 权限。
|
||||
|
||||
```sql
|
||||
-- Allow anonymous uploads
|
||||
CREATE POLICY "Allow public uploads"
|
||||
ON storage.objects FOR INSERT
|
||||
TO anon WITH CHECK (bucket_id = 'materials');
|
||||
```
|
||||
|
||||
### 结果
|
||||
- ✅ **彻底解决超时**:上传不再经过 Nginx/FRP,直接走 Supabase CDN。
|
||||
- ✅ **解除大小限制**:不再受限于后端服务的 `client_max_body_size`。
|
||||
- ✅ **用户体验提升**:上传速度更快,进度条更准确。
|
||||
|
||||
|
||||
|
||||
## 🔧 Supabase 部署与 RLS 配置
|
||||
|
||||
### 相关文件
|
||||
- `supabase_rls.sql`: 定义存储桶权限的 SQL 脚本。
|
||||
- `docker-compose.yml`: 确认 Storage 服务配置正常。
|
||||
|
||||
### 操作步骤
|
||||
1. 将 `supabase_rls.sql` 上传至服务器。
|
||||
2. 通过 Docker 执行 SQL:
|
||||
```bash
|
||||
cat supabase_rls.sql | docker exec -i supabase-db psql -U postgres
|
||||
```
|
||||
3. 验证前端上传成功。
|
||||
|
||||
---
|
||||
|
||||
## 🔐 用户隔离实现 (15:00)
|
||||
|
||||
### 问题描述
|
||||
不同账户登录后能看到其他用户上传的素材和生成的视频,缺乏数据隔离。
|
||||
|
||||
### 解决方案:存储路径前缀隔离
|
||||
|
||||
#### 1. 素材模块 (`backend/app/api/materials.py`)
|
||||
|
||||
```python
|
||||
# 上传时添加用户ID前缀
|
||||
storage_path = f"{user_id}/{timestamp}_{safe_name}"
|
||||
|
||||
# 列表时只查询当前用户目录
|
||||
files_obj = await storage_service.list_files(
|
||||
bucket=storage_service.BUCKET_MATERIALS,
|
||||
path=user_id # 只列出用户目录下的文件
|
||||
)
|
||||
|
||||
# 删除时验证权限
|
||||
if not material_id.startswith(f"{user_id}/"):
|
||||
raise HTTPException(403, "无权删除此素材")
|
||||
```
|
||||
|
||||
#### 2. 视频模块 (`backend/app/api/videos.py`)
|
||||
|
||||
```python
|
||||
# 生成视频时使用用户ID目录
|
||||
storage_path = f"{user_id}/{task_id}_output.mp4"
|
||||
|
||||
# 列表/删除同样基于用户目录隔离
|
||||
```
|
||||
|
||||
#### 3. 发布模块 (`backend/app/services/publish_service.py`)
|
||||
- Cookie 存储支持用户隔离:`cookies/{user_id}/{platform}.json`
|
||||
|
||||
### 存储结构
|
||||
```
|
||||
Supabase Storage/
|
||||
├── materials/
|
||||
│ ├── {user_id_1}/
|
||||
│ │ ├── 1737000001_video1.mp4
|
||||
│ │ └── 1737000002_video2.mp4
|
||||
│ └── {user_id_2}/
|
||||
│ └── 1737000003_video3.mp4
|
||||
└── outputs/
|
||||
├── {user_id_1}/
|
||||
│ └── {task_id}_output.mp4
|
||||
└── {user_id_2}/
|
||||
└── ...
|
||||
```
|
||||
|
||||
### 结果
|
||||
- ✅ 不同用户数据完全隔离
|
||||
- ✅ Cookie 和登录状态按用户存储
|
||||
- ✅ 删除操作验证所有权
|
||||
|
||||
---
|
||||
|
||||
## 🌐 Storage URL 修复 (16:00)
|
||||
|
||||
### 问题描述
|
||||
生成的视频 URL 为 `http://localhost:8008/...`,前端无法访问。
|
||||
|
||||
### 解决方案
|
||||
|
||||
#### 1. 后端配置 (`backend/.env`)
|
||||
```ini
|
||||
SUPABASE_URL=http://localhost:8008 # 内部访问
|
||||
SUPABASE_PUBLIC_URL=https://api.hbyrkj.top # 公网访问
|
||||
```
|
||||
|
||||
#### 2. URL 转换 (`backend/app/services/storage.py`)
|
||||
```python
|
||||
def _convert_to_public_url(self, url: str) -> str:
|
||||
"""将内部 URL 转换为公网可访问的 URL"""
|
||||
if settings.SUPABASE_PUBLIC_URL and settings.SUPABASE_URL:
|
||||
internal_url = settings.SUPABASE_URL.rstrip('/')
|
||||
public_url = settings.SUPABASE_PUBLIC_URL.rstrip('/')
|
||||
return url.replace(internal_url, public_url)
|
||||
return url
|
||||
```
|
||||
|
||||
### 结果
|
||||
- ✅ 前端获取的 URL 可正常访问
|
||||
- ✅ 视频预览和下载功能正常
|
||||
|
||||
---
|
||||
|
||||
## ⚡ 发布服务优化 - 本地文件直读 (16:30)
|
||||
|
||||
### 问题描述
|
||||
发布视频时需要先通过 HTTP 下载 Supabase Storage 文件到临时目录,效率低且浪费资源。
|
||||
|
||||
### 发现
|
||||
Supabase Storage 文件实际存储在本地磁盘:
|
||||
```
|
||||
/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub/{bucket}/{path}/{internal_uuid}
|
||||
```
|
||||
|
||||
### 解决方案
|
||||
|
||||
#### 1. 添加本地路径获取方法 (`storage.py`)
|
||||
```python
|
||||
SUPABASE_STORAGE_LOCAL_PATH = Path("/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub")
|
||||
|
||||
def get_local_file_path(self, bucket: str, path: str) -> Optional[str]:
|
||||
"""获取 Storage 文件的本地磁盘路径"""
|
||||
dir_path = SUPABASE_STORAGE_LOCAL_PATH / bucket / path
|
||||
if not dir_path.exists():
|
||||
return None
|
||||
files = list(dir_path.iterdir())
|
||||
return str(files[0]) if files else None
|
||||
```
|
||||
|
||||
#### 2. 发布服务优先使用本地文件 (`publish_service.py`)
|
||||
```python
|
||||
# 解析 URL 获取 bucket 和 path
|
||||
match = re.search(r'/storage/v1/object/sign/([^/]+)/(.+?)\?', video_path)
|
||||
if match:
|
||||
bucket, storage_path = match.group(1), match.group(2)
|
||||
local_video_path = storage_service.get_local_file_path(bucket, storage_path)
|
||||
|
||||
if local_video_path and os.path.exists(local_video_path):
|
||||
logger.info(f"[发布] 直接使用本地文件: {local_video_path}")
|
||||
else:
|
||||
# Fallback: HTTP 下载
|
||||
```
|
||||
|
||||
### 结果
|
||||
- ✅ 发布速度显著提升(跳过下载步骤)
|
||||
- ✅ 减少临时文件占用
|
||||
- ✅ 保留 HTTP 下载作为 Fallback
|
||||
|
||||
---
|
||||
|
||||
## 🔧 Supabase Studio 配置 (17:00)
|
||||
|
||||
### 修改内容
|
||||
更新 `/home/rongye/ProgramFiles/Supabase/.env`:
|
||||
```ini
|
||||
# 修改前
|
||||
SUPABASE_PUBLIC_URL=http://localhost:8000
|
||||
|
||||
# 修改后
|
||||
SUPABASE_PUBLIC_URL=https://api.hbyrkj.top
|
||||
```
|
||||
|
||||
### 原因
|
||||
通过 `supabase.hbyrkj.top` 公网访问 Studio 时,需要正确的 API 公网地址。
|
||||
|
||||
### 操作
|
||||
```bash
|
||||
docker compose restart studio
|
||||
```
|
||||
|
||||
### 待解决
|
||||
- 🔄 Studio Settings 页面加载问题(401 Unauthorized)- 可能与 Nginx Basic Auth 配置冲突
|
||||
|
||||
---
|
||||
|
||||
## 📁 今日修改文件清单
|
||||
|
||||
| 文件 | 变更类型 | 说明 |
|
||||
|------|----------|------|
|
||||
| `backend/app/api/materials.py` | 修改 | 添加用户隔离 |
|
||||
| `backend/app/api/videos.py` | 修改 | 添加用户隔离 |
|
||||
| `backend/app/services/storage.py` | 修改 | URL转换 + 本地路径获取 |
|
||||
| `backend/app/services/publish_service.py` | 修改 | 本地文件直读优化 |
|
||||
| `backend/.env` | 修改 | 添加 SUPABASE_PUBLIC_URL |
|
||||
| `Supabase/.env` | 修改 | SUPABASE_PUBLIC_URL |
|
||||
| `frontend/src/app/page.tsx` | 修改 | 改用后端API上传 |
|
||||
|
||||
---
|
||||
|
||||
## 📅 明日任务规划 (Day 12)
|
||||
|
||||
### 🎯 目标:部署 Qwen3-TTS 0.6B 声音克隆系统
|
||||
|
||||
**任务背景**:
|
||||
- 当前使用 EdgeTTS(微软云端 TTS),音色固定,无法自定义
|
||||
- Qwen3-TTS 支持**零样本声音克隆**,可用少量音频克隆任意人声
|
||||
|
||||
**核心任务**:
|
||||
1. **模型部署**
|
||||
- 创建独立 Conda 环境 (`qwen-tts`)
|
||||
- 下载 Qwen3-TTS 0.6B 模型权重
|
||||
- 配置 GPU 推理环境
|
||||
|
||||
2. **后端集成**
|
||||
- 新增 `qwen_tts_service.py` 服务
|
||||
- 支持声音克隆:上传参考音频 → 生成克隆语音
|
||||
- 兼容现有 `tts_service.py` 接口
|
||||
|
||||
3. **前端适配**
|
||||
- 添加"声音克隆"选项
|
||||
- 支持上传参考音频(3-10秒)
|
||||
- 音色预览功能
|
||||
|
||||
**预期成果**:
|
||||
- ✅ 用户可上传自己的声音样本
|
||||
- ✅ 生成的口播视频使用克隆后的声音
|
||||
- ✅ 保留 EdgeTTS 作为备选方案
|
||||
|
||||
**参考资源**:
|
||||
- 模型:[Qwen/Qwen3-TTS-0.6B](https://huggingface.co/Qwen/Qwen3-TTS-0.6B)
|
||||
- 显存需求:~4GB (0.6B 参数量)
|
||||
347
Docs/DevLogs/Day12.md
Normal file
347
Docs/DevLogs/Day12.md
Normal file
@@ -0,0 +1,347 @@
|
||||
# Day 12 - iOS 兼容与移动端 UI 优化
|
||||
|
||||
**日期**:2026-01-28
|
||||
|
||||
---
|
||||
|
||||
## 🔐 Axios 全局拦截器优化
|
||||
|
||||
### 背景
|
||||
统一处理 API 请求的认证失败场景,避免各页面重复处理 401/403 错误。
|
||||
|
||||
### 实现 (`frontend/src/lib/axios.ts`)
|
||||
|
||||
```typescript
|
||||
import axios from 'axios';
|
||||
|
||||
// 动态获取 API 地址:服务端使用 localhost,客户端使用当前域名
|
||||
const API_BASE = typeof window === 'undefined'
|
||||
? 'http://localhost:8006'
|
||||
: '';
|
||||
|
||||
// 防止重复跳转
|
||||
let isRedirecting = false;
|
||||
|
||||
const api = axios.create({
|
||||
baseURL: API_BASE,
|
||||
withCredentials: true, // 自动携带 HttpOnly cookie
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
});
|
||||
|
||||
// 响应拦截器 - 全局处理 401/403
|
||||
api.interceptors.response.use(
|
||||
(response) => response,
|
||||
async (error) => {
|
||||
const status = error.response?.status;
|
||||
|
||||
if ((status === 401 || status === 403) && !isRedirecting) {
|
||||
isRedirecting = true;
|
||||
|
||||
// 调用 logout API 清除 HttpOnly cookie
|
||||
try {
|
||||
await fetch('/api/auth/logout', { method: 'POST' });
|
||||
} catch (e) { /* 忽略 */ }
|
||||
|
||||
// 跳转登录页
|
||||
if (typeof window !== 'undefined') {
|
||||
window.location.replace('/login');
|
||||
}
|
||||
}
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
export default api;
|
||||
```
|
||||
|
||||
### 关键特性
|
||||
- ✅ **自动携带 Cookie**: `withCredentials: true` 确保 HttpOnly JWT cookie 被发送
|
||||
- ✅ **401/403 自动跳转**: 认证失败时自动清理并跳转登录页
|
||||
- ✅ **防重复跳转**: `isRedirecting` 标志避免多个请求同时触发跳转
|
||||
- ✅ **SSR 兼容**: 服务端渲染时使用 `localhost`,客户端使用相对路径
|
||||
|
||||
---
|
||||
|
||||
## 🔧 iOS Safari 安全区域白边修复
|
||||
|
||||
### 问题描述
|
||||
iPhone Safari 浏览器底部和顶部显示白色区域,安卓正常。原因是 iOS Safari 有安全区域 (Safe Area),页面背景没有延伸到该区域。
|
||||
|
||||
### 根本原因
|
||||
1. 缺少 `viewport-fit=cover` 配置
|
||||
2. `min-h-screen` (100vh) 在 iOS Safari 中不包含安全区域
|
||||
3. 背景渐变在页面 div 上,而非 body 上,导致安全区域显示纯色
|
||||
|
||||
### 解决方案
|
||||
|
||||
#### 1. 添加 viewport 配置 (`layout.tsx`)
|
||||
```typescript
|
||||
export const viewport: Viewport = {
|
||||
width: 'device-width',
|
||||
initialScale: 1,
|
||||
viewportFit: 'cover', // 允许内容延伸到安全区域
|
||||
themeColor: '#0f172a', // 顶部状态栏颜色
|
||||
};
|
||||
```
|
||||
|
||||
#### 2. 统一渐变背景到 body (`layout.tsx`)
|
||||
```tsx
|
||||
<html lang="en" style={{ backgroundColor: '#0f172a' }}>
|
||||
<body
|
||||
style={{
|
||||
margin: 0,
|
||||
minHeight: '100dvh',
|
||||
background: 'linear-gradient(to bottom, #0f172a 0%, #0f172a 5%, #581c87 50%, #0f172a 95%, #0f172a 100%)',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</body>
|
||||
</html>
|
||||
```
|
||||
|
||||
#### 3. CSS 安全区域支持 (`globals.css`)
|
||||
```css
|
||||
html {
|
||||
background-color: #0f172a !important;
|
||||
min-height: 100%;
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0 !important;
|
||||
min-height: 100dvh;
|
||||
padding-top: env(safe-area-inset-top);
|
||||
padding-bottom: env(safe-area-inset-bottom);
|
||||
}
|
||||
```
|
||||
|
||||
#### 4. 移除页面独立渐变背景
|
||||
各页面的根 div 移除 `bg-gradient-to-br` 类,统一使用 body 渐变:
|
||||
- `page.tsx`
|
||||
- `login/page.tsx`
|
||||
- `publish/page.tsx`
|
||||
- `admin/page.tsx`
|
||||
- `register/page.tsx`
|
||||
|
||||
### 结果
|
||||
- ✅ 顶部状态栏颜色与页面一致 (themeColor)
|
||||
- ✅ 底部安全区域颜色与渐变边缘一致
|
||||
- ✅ 消除分层感,背景统一
|
||||
|
||||
---
|
||||
|
||||
## 📱 移动端 Header 响应式优化
|
||||
|
||||
### 问题描述
|
||||
移动端顶部导航按钮(视频生成、发布管理、退出)过于拥挤,文字换行显示。
|
||||
|
||||
### 解决方案
|
||||
|
||||
#### 首页 Header (`page.tsx`)
|
||||
```tsx
|
||||
<header className="border-b border-white/10 bg-black/20 backdrop-blur-sm">
|
||||
<div className="max-w-6xl mx-auto px-4 sm:px-6 py-3 sm:py-4 flex items-center justify-between">
|
||||
<Link href="/" className="text-xl sm:text-2xl font-bold ...">
|
||||
<span className="text-3xl sm:text-4xl">🎬</span>
|
||||
ViGent
|
||||
</Link>
|
||||
<div className="flex items-center gap-1 sm:gap-4">
|
||||
<span className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base ...">
|
||||
视频生成
|
||||
</span>
|
||||
<!-- 其他按钮同样处理 -->
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
```
|
||||
|
||||
#### 发布管理页 Header (`publish/page.tsx`)
|
||||
同步应用相同的响应式类名。
|
||||
|
||||
### 关键改动
|
||||
| 属性 | 移动端 | 桌面端 |
|
||||
|------|--------|--------|
|
||||
| 容器内边距 | `px-4 py-3` | `px-6 py-4` |
|
||||
| 按钮间距 | `gap-1` | `gap-4` |
|
||||
| 按钮内边距 | `px-2 py-1` | `px-4 py-2` |
|
||||
| 字体大小 | `text-sm` | `text-base` |
|
||||
| Logo 大小 | `text-xl` + `text-3xl` | `text-2xl` + `text-4xl` |
|
||||
|
||||
### 结果
|
||||
- ✅ 移动端按钮紧凑排列,不再换行
|
||||
- ✅ 桌面端保持原有宽松布局
|
||||
|
||||
---
|
||||
|
||||
## 🚀 发布页面 UI 重构
|
||||
|
||||
### 问题描述
|
||||
原有设计将"发布时间"选项放在表单中,用户可能误选"定时发布"但忘记设置时间。
|
||||
|
||||
### 解决方案
|
||||
将"一键发布"按钮改为两个独立按钮:
|
||||
- **立即发布** (绿色,占 3/4 宽度) - 主要操作
|
||||
- **定时** (占 1/4 宽度) - 点击展开时间选择器
|
||||
|
||||
#### 新布局 (`publish/page.tsx`)
|
||||
```tsx
|
||||
{/* 发布按钮区域 */}
|
||||
<div className="space-y-3">
|
||||
<div className="flex gap-3">
|
||||
{/* 立即发布 - 占 3/4 */}
|
||||
<button
|
||||
onClick={() => { setScheduleMode("now"); handlePublish(); }}
|
||||
className="flex-[3] py-4 rounded-xl font-bold text-lg bg-gradient-to-r from-green-600 to-teal-600 ..."
|
||||
>
|
||||
🚀 立即发布
|
||||
</button>
|
||||
|
||||
{/* 定时发布 - 占 1/4 */}
|
||||
<button
|
||||
onClick={() => setScheduleMode(scheduleMode === "scheduled" ? "now" : "scheduled")}
|
||||
className="flex-1 py-4 rounded-xl font-bold text-base ..."
|
||||
>
|
||||
⏰ 定时
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* 定时发布时间选择器 (展开时显示) */}
|
||||
{scheduleMode === "scheduled" && (
|
||||
<div className="flex gap-3 items-center">
|
||||
<input type="datetime-local" ... />
|
||||
<button>确认定时</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
```
|
||||
|
||||
### 结果
|
||||
- ✅ 主操作(立即发布)更醒目
|
||||
- ✅ 定时发布需要二次确认,防止误触
|
||||
- ✅ 从表单区域移除发布时间选项,界面更简洁
|
||||
|
||||
---
|
||||
|
||||
## 🛤️ 后续优化项
|
||||
|
||||
### 后端定时发布 (待实现)
|
||||
**当前状态**:定时发布使用平台端定时(在各平台设置发布时间)
|
||||
|
||||
**优化方向**:改为后端定时任务
|
||||
- 使用 APScheduler 实现任务调度
|
||||
- 存储定时任务到数据库
|
||||
- 到时间后端自动触发发布 API
|
||||
- 支持查看/取消定时任务
|
||||
|
||||
**优势**:
|
||||
- 统一逻辑,不依赖平台定时 UI
|
||||
- 更灵活,可管理定时任务
|
||||
- 平台页面更新不影响功能
|
||||
|
||||
---
|
||||
|
||||
## 🤖 Qwen3-TTS 0.6B 声音克隆部署
|
||||
|
||||
### 背景
|
||||
为实现用户自定义声音克隆功能,部署 Qwen3-TTS 0.6B-Base 模型,支持 3 秒参考音频快速克隆。
|
||||
|
||||
### GPU 分配
|
||||
| GPU | 服务 | 模型 |
|
||||
|-----|------|------|
|
||||
| GPU0 | Qwen3-TTS | 0.6B-Base (声音克隆) |
|
||||
| GPU1 | LatentSync | 1.6 (唇形同步) |
|
||||
|
||||
### 部署步骤
|
||||
|
||||
#### 1. 克隆仓库
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models
|
||||
git clone https://github.com/QwenLM/Qwen3-TTS.git
|
||||
```
|
||||
|
||||
#### 2. 创建 conda 环境
|
||||
```bash
|
||||
conda create -n qwen-tts python=3.10 -y
|
||||
conda activate qwen-tts
|
||||
```
|
||||
|
||||
#### 3. 安装依赖
|
||||
```bash
|
||||
cd Qwen3-TTS
|
||||
pip install -e .
|
||||
conda install -y -c conda-forge sox # 音频处理依赖
|
||||
```
|
||||
|
||||
#### 4. 下载模型权重 (使用 ModelScope,国内更快)
|
||||
```bash
|
||||
pip install modelscope
|
||||
# Tokenizer (651MB)
|
||||
modelscope download --model Qwen/Qwen3-TTS-Tokenizer-12Hz --local_dir ./checkpoints/Tokenizer
|
||||
# 0.6B-Base 模型 (2.4GB)
|
||||
modelscope download --model Qwen/Qwen3-TTS-12Hz-0.6B-Base --local_dir ./checkpoints/0.6B-Base
|
||||
```
|
||||
|
||||
#### 5. 测试推理
|
||||
```python
|
||||
# test_inference.py
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
||||
model = Qwen3TTSModel.from_pretrained(
|
||||
"./checkpoints/0.6B-Base",
|
||||
device_map="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
|
||||
wavs, sr = model.generate_voice_clone(
|
||||
text="测试文本",
|
||||
language="Chinese",
|
||||
ref_audio="./examples/myvoice.wav",
|
||||
ref_text="参考音频的文字内容",
|
||||
)
|
||||
sf.write("output.wav", wavs[0], sr)
|
||||
```
|
||||
|
||||
### 测试结果
|
||||
- ✅ 模型加载成功 (GPU0)
|
||||
- ✅ 声音克隆推理成功
|
||||
- ✅ 输出音频 24000Hz,质量良好
|
||||
|
||||
### 目录结构
|
||||
```
|
||||
models/Qwen3-TTS/
|
||||
├── checkpoints/
|
||||
│ ├── Tokenizer/ # 651MB
|
||||
│ └── 0.6B-Base/ # 2.4GB
|
||||
├── qwen_tts/ # 源码
|
||||
├── examples/
|
||||
│ └── myvoice.wav # 参考音频
|
||||
└── test_inference.py # 测试脚本
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 📁 今日修改文件清单
|
||||
|
||||
| 文件 | 变更类型 | 说明 |
|
||||
|------|----------|------|
|
||||
| `frontend/src/lib/axios.ts` | 修改 | Axios 全局拦截器 (401/403 自动跳转) |
|
||||
| `frontend/src/app/layout.tsx` | 修改 | viewport 配置 + body 渐变背景 |
|
||||
| `frontend/src/app/globals.css` | 修改 | 安全区域 CSS 支持 |
|
||||
| `frontend/src/app/page.tsx` | 修改 | 移除独立渐变 + Header 响应式 |
|
||||
| `frontend/src/app/login/page.tsx` | 修改 | 移除独立渐变 |
|
||||
| `frontend/src/app/publish/page.tsx` | 修改 | Header 响应式 + 发布按钮重构 |
|
||||
| `frontend/src/app/admin/page.tsx` | 修改 | 移除独立渐变 |
|
||||
| `frontend/src/app/register/page.tsx` | 修改 | 移除独立渐变 |
|
||||
| `README.md` | 修改 | 添加 "iOS/Android 移动端适配" 功能说明 |
|
||||
| `Docs/FRONTEND_DEV.md` | 修改 | iOS Safari 安全区域兼容规范 + 移动端响应式规则 |
|
||||
| `models/Qwen3-TTS/` | 新增 | Qwen3-TTS 声音克隆模型部署 |
|
||||
| `Docs/QWEN3_TTS_DEPLOY.md` | 新增 | Qwen3-TTS 部署指南 |
|
||||
|
||||
---
|
||||
|
||||
## 🔗 相关文档
|
||||
|
||||
- [task_complete.md](../task_complete.md) - 任务总览
|
||||
- [Day11.md](./Day11.md) - 上传架构重构
|
||||
- [QWEN3_TTS_DEPLOY.md](../QWEN3_TTS_DEPLOY.md) - Qwen3-TTS 部署指南
|
||||
113
Docs/DevLogs/Day8.md
Normal file
113
Docs/DevLogs/Day8.md
Normal file
@@ -0,0 +1,113 @@
|
||||
# Day 8: 用户体验优化
|
||||
|
||||
**日期**: 2026-01-22
|
||||
**目标**: 文件名保留 + 视频持久化 + 界面优化
|
||||
|
||||
---
|
||||
|
||||
## 📋 任务概览
|
||||
|
||||
| 任务 | 状态 |
|
||||
|------|------|
|
||||
| 文件名保留 | ✅ 完成 |
|
||||
| 视频持久化 | ✅ 完成 |
|
||||
| 历史视频列表 | ✅ 完成 |
|
||||
| 删除功能 | ✅ 完成 |
|
||||
|
||||
---
|
||||
|
||||
## 🎉 实施成果
|
||||
|
||||
### 后端改动
|
||||
|
||||
**修改文件**:
|
||||
- `backend/app/api/materials.py`
|
||||
- ✅ `sanitize_filename()` 文件名安全化
|
||||
- ✅ 时间戳前缀避免冲突 (`{timestamp}_{原始文件名}`)
|
||||
- ✅ `list_materials` 显示原始文件名
|
||||
- ✅ `DELETE /api/materials/{id}` 删除素材
|
||||
|
||||
- `backend/app/api/videos.py`
|
||||
- ✅ `GET /api/videos/generated` 历史视频列表
|
||||
- ✅ `DELETE /api/videos/generated/{id}` 删除视频
|
||||
|
||||
### 前端改动
|
||||
|
||||
**修改文件**:
|
||||
- `frontend/src/app/page.tsx`
|
||||
- ✅ `GeneratedVideo` 类型定义
|
||||
- ✅ `generatedVideos` 状态管理
|
||||
- ✅ `fetchGeneratedVideos()` 获取历史
|
||||
- ✅ `deleteMaterial()` / `deleteVideo()` 删除功能
|
||||
- ✅ 素材卡片添加删除按钮 (hover 显示)
|
||||
- ✅ 历史视频列表组件 (右侧预览区下方)
|
||||
- ✅ 生成完成后自动刷新历史列表
|
||||
|
||||
---
|
||||
|
||||
## 🔧 API 变更
|
||||
|
||||
### 新增端点
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| GET | `/api/videos/generated` | 获取生成视频列表 |
|
||||
| DELETE | `/api/videos/generated/{id}` | 删除生成视频 |
|
||||
| DELETE | `/api/materials/{id}` | 删除素材 |
|
||||
|
||||
### 文件命名规则
|
||||
|
||||
```
|
||||
原始: 测试视频.mp4
|
||||
保存: 1737518400_测试视频.mp4
|
||||
显示: 测试视频.mp4 (前端自动去除时间戳前缀)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## ✅ 完成总结
|
||||
|
||||
1. **文件名保留** - 上传保留原始名称,时间戳前缀避免冲突
|
||||
2. **视频持久化** - 从文件系统读取,刷新不丢失
|
||||
3. **历史列表** - 右侧显示历史视频,点击切换播放
|
||||
4. **删除功能** - 素材和视频均支持删除
|
||||
|
||||
---
|
||||
|
||||
## 📊 测试清单
|
||||
|
||||
- [x] 上传视频后检查素材列表显示原始文件名
|
||||
- [x] 刷新页面后检查历史视频列表持久化
|
||||
- [x] 测试删除素材功能
|
||||
- [x] 测试删除生成视频功能
|
||||
- [x] 测试历史视频列表点击切换播放
|
||||
|
||||
---
|
||||
|
||||
## 🔧 发布功能修复 (Day 8 下半场)
|
||||
|
||||
> 以下修复在用户体验优化后进行
|
||||
|
||||
### 问题
|
||||
|
||||
1. **抖音 QR 登录假成功** - 前端检测到旧 Cookie 文件就显示"登录成功",实际可能已过期
|
||||
2. **抖音上传循环卡死** - 发布后检测逻辑不完善,`while True` 无超时
|
||||
3. **前端轮询不规范** - 使用 `setInterval` 手动轮询,不符合 React 最佳实践
|
||||
|
||||
### 修复
|
||||
|
||||
**后端**:
|
||||
- `publish_service.py` - 添加 `logout()` 方法、修复 `get_login_session_status()` 优先检查活跃会话
|
||||
- `api/publish.py` - 新增 `POST /api/publish/logout/{platform}` 端点
|
||||
- `douyin_uploader.py` - 添加 `import time`,修复发布按钮点击竞态条件
|
||||
|
||||
**前端**:
|
||||
- `publish/page.tsx` - 使用 `useSWR` 替代 `setInterval` 轮询登录状态
|
||||
- `package.json` - 添加 `swr` 依赖
|
||||
|
||||
### 新增 API
|
||||
|
||||
| 方法 | 路径 | 说明 |
|
||||
|------|------|------|
|
||||
| POST | `/api/publish/logout/{platform}` | 注销平台登录 |
|
||||
|
||||
320
Docs/DevLogs/Day9.md
Normal file
320
Docs/DevLogs/Day9.md
Normal file
@@ -0,0 +1,320 @@
|
||||
# Day 9: 发布模块代码优化
|
||||
|
||||
**日期**: 2026-01-23
|
||||
**目标**: 代码质量优化 + 发布功能验证
|
||||
|
||||
---
|
||||
|
||||
## 📋 任务概览
|
||||
|
||||
| 任务 | 状态 |
|
||||
|------|------|
|
||||
| B站/抖音发布验证 | ✅ 完成 |
|
||||
| 资源清理保障 (try-finally) | ✅ 完成 |
|
||||
| 超时保护 (消除无限循环) | ✅ 完成 |
|
||||
| 小红书 headless 模式修复 | ✅ 完成 |
|
||||
| API 输入验证 | ✅ 完成 |
|
||||
| 类型提示完善 | ✅ 完成 |
|
||||
| 服务层代码优化 | ✅ 完成 |
|
||||
| 扫码登录等待界面 | ✅ 完成 |
|
||||
| 抖音登录策略优化 | ✅ 完成 |
|
||||
| 发布成功审核提示 | ✅ 完成 |
|
||||
| 用户认证系统规划 | ✅ 计划完成 |
|
||||
|
||||
---
|
||||
|
||||
## 🎉 发布验证结果
|
||||
|
||||
### 登录功能
|
||||
- ✅ **B站登录成功** - 策略3(Text)匹配,Cookie已保存
|
||||
- ✅ **抖音登录成功** - 策略3(Text)匹配,Cookie已保存
|
||||
|
||||
### 发布功能
|
||||
- ✅ **抖音发布成功** - 自动关闭弹窗、跳转管理页面
|
||||
- ✅ **B站发布成功** - API返回 `bvid: BV14izPBQEbd`
|
||||
|
||||
---
|
||||
|
||||
## 🔧 代码优化
|
||||
|
||||
### 1. 资源清理保障
|
||||
|
||||
**问题**:Playwright 浏览器在异常路径可能未关闭
|
||||
|
||||
**修复**:`try-finally` 模式确保资源释放
|
||||
```python
|
||||
browser = None
|
||||
context = None
|
||||
try:
|
||||
browser = await playwright.chromium.launch(headless=True)
|
||||
context = await browser.new_context(...)
|
||||
# ... 业务逻辑 ...
|
||||
finally:
|
||||
if context:
|
||||
try: await context.close()
|
||||
except Exception: pass
|
||||
if browser:
|
||||
try: await browser.close()
|
||||
except Exception: pass
|
||||
```
|
||||
|
||||
### 2. 超时保护
|
||||
|
||||
**问题**:`while True` 循环可能导致任务卡死
|
||||
|
||||
**修复**:添加类级别超时常量
|
||||
```python
|
||||
class DouyinUploader(BaseUploader):
|
||||
UPLOAD_TIMEOUT = 300 # 视频上传超时
|
||||
PUBLISH_TIMEOUT = 180 # 发布检测超时
|
||||
PAGE_REDIRECT_TIMEOUT = 60 # 页面跳转超时
|
||||
```
|
||||
|
||||
### 3. B站 bvid 提取修复
|
||||
|
||||
**问题**:API 返回的 bvid 在 `data` 字段内
|
||||
|
||||
**修复**:同时检查多个位置
|
||||
```python
|
||||
bvid = ret.get('data', {}).get('bvid') or ret.get('bvid', '')
|
||||
aid = ret.get('data', {}).get('aid') or ret.get('aid', '')
|
||||
```
|
||||
|
||||
### 4. API 输入验证
|
||||
|
||||
**修复**:所有端点添加平台验证
|
||||
```python
|
||||
SUPPORTED_PLATFORMS = {"bilibili", "douyin", "xiaohongshu"}
|
||||
|
||||
if platform not in SUPPORTED_PLATFORMS:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 🎨 用户体验优化
|
||||
|
||||
### 1. 扫码登录等待界面
|
||||
|
||||
**问题**:点击登录后,二维码获取需要几秒,用户无反馈
|
||||
|
||||
**优化**:
|
||||
- 点击登录后立即显示加载弹窗
|
||||
- 加载动画 (旋转圈 + "正在获取二维码...")
|
||||
- 二维码获取成功后自动切换显示
|
||||
|
||||
### 2. 抖音登录策略优化
|
||||
|
||||
**问题**:抖音登录需要约 23 秒获取二维码 (策略1/2超时)
|
||||
|
||||
**原因分析**:
|
||||
| 策略 | 抖音耗时 | B站耗时 | 结果 |
|
||||
|------|----------|---------|------|
|
||||
| Role | 10s 超时 | N/A | ❌ |
|
||||
| CSS | 8s 超时 | 8s 超时 | ❌ |
|
||||
| Text | ~1s | ~1s | ✅ |
|
||||
|
||||
**优化**:
|
||||
```python
|
||||
# 抖音/B站:Text 策略优先
|
||||
if self.platform in ("douyin", "bilibili"):
|
||||
qr_element = await self._try_text_strategy(page) # 优先
|
||||
if not qr_element:
|
||||
await page.wait_for_selector(..., timeout=3000) # CSS 备用
|
||||
else:
|
||||
# 其他平台保持 CSS 优先
|
||||
```
|
||||
|
||||
**效果**:
|
||||
- 抖音登录二维码获取:~23s → ~5s
|
||||
- B站登录二维码获取:~13s → ~5s
|
||||
|
||||
### 3. 发布成功审核提示
|
||||
|
||||
**问题**:发布成功后,用户不知道需要审核
|
||||
|
||||
**优化**:
|
||||
- 后端消息改为 "发布成功,待审核"
|
||||
- 前端增加提示 "⏳ 审核一般需要几分钟,请耐心等待"
|
||||
- 发布结果 10 秒后自动消失
|
||||
|
||||
---
|
||||
|
||||
## 📁 修改文件列表
|
||||
|
||||
### 后端
|
||||
|
||||
| 文件 | 修改内容 |
|
||||
|------|----------|
|
||||
| `app/api/publish.py` | 输入验证、平台常量、文档改进 |
|
||||
| `app/services/publish_service.py` | 类型提示、平台 enabled 标记 |
|
||||
| `app/services/qr_login_service.py` | **策略顺序优化**、超时缩短 |
|
||||
| `app/services/uploader/base_uploader.py` | 类型提示 |
|
||||
| `app/services/uploader/bilibili_uploader.py` | **发布消息改为"待审核"** |
|
||||
| `app/services/uploader/douyin_uploader.py` | **发布消息改为"待审核"** |
|
||||
| `app/services/uploader/xiaohongshu_uploader.py` | **发布消息改为"待审核"** |
|
||||
|
||||
### 前端
|
||||
|
||||
| 文件 | 修改内容 |
|
||||
|------|----------|
|
||||
| `src/app/publish/page.tsx` | **加载动画、审核提示、结果自动消失** |
|
||||
|
||||
---
|
||||
|
||||
## ✅ 完成总结
|
||||
|
||||
1. **发布功能验证通过** - B站/抖音登录和发布均正常
|
||||
2. **代码健壮性提升** - 资源清理、超时保护、异常处理
|
||||
3. **代码可维护性** - 完整类型提示、常量化配置
|
||||
4. **服务器兼容性** - 小红书 headless 模式修复
|
||||
5. **用户体验优化** - 加载状态、策略顺序、审核提示
|
||||
|
||||
---
|
||||
|
||||
## 🔐 用户认证系统规划
|
||||
|
||||
> 规划完成,待下一阶段实施
|
||||
|
||||
### 技术方案
|
||||
|
||||
| 项目 | 方案 |
|
||||
|------|------|
|
||||
| 认证框架 | FastAPI + JWT (HttpOnly Cookie) |
|
||||
| 数据库 | Supabase (PostgreSQL + RLS) |
|
||||
| 管理员 | .env 预设 + startup 自动初始化 |
|
||||
| 授权期限 | expires_at 字段,可设定有效期 |
|
||||
| 单设备登录 | 后踢前模式 + Session Token 强校验 |
|
||||
| 账号隔离 | 规范化 Cookie 路径 `user_data/{user_id}/` |
|
||||
|
||||
### 安全增强
|
||||
|
||||
1. **HttpOnly Cookie** - 防 XSS 窃取 Token
|
||||
2. **Session Token 校验** - JWT 包含 session_token,每次请求验证
|
||||
3. **Startup 初始化管理员** - 服务启动自动创建
|
||||
4. **RLS 最后防线** - Supabase 行级安全策略
|
||||
5. **Cookie 路径规范化** - UUID 格式验证 + 白名单平台校验
|
||||
|
||||
### 数据库表
|
||||
|
||||
```sql
|
||||
-- users (用户)
|
||||
-- user_sessions (单设备登录)
|
||||
-- social_accounts (社交账号绑定)
|
||||
```
|
||||
|
||||
> 详细设计见 [implementation_plan.md](file:///C:/Users/danny/.gemini/antigravity/brain/06e7632c-12c6-4e80-b321-e1e642144560/implementation_plan.md)
|
||||
|
||||
### 后端实现进度
|
||||
|
||||
**状态**:✅ 核心模块完成
|
||||
|
||||
| 文件 | 说明 | 状态 |
|
||||
|------|------|------|
|
||||
| `requirements.txt` | 添加 supabase, python-jose, passlib | ✅ |
|
||||
| `app/core/config.py` | 添加 Supabase/JWT/管理员配置 | ✅ |
|
||||
| `app/core/supabase.py` | Supabase 客户端单例 | ✅ |
|
||||
| `app/core/security.py` | JWT + 密码 + HttpOnly Cookie | ✅ |
|
||||
| `app/core/paths.py` | Cookie 路径规范化 | ✅ |
|
||||
| `app/core/deps.py` | 依赖注入 (当前用户/管理员) | ✅ |
|
||||
| `app/api/auth.py` | 注册/登录/登出 API | ✅ |
|
||||
| `app/api/admin.py` | 用户管理 API | ✅ |
|
||||
| `app/main.py` | startup 初始化管理员 | ✅ |
|
||||
| `database/schema.sql` | Supabase 数据库表 + RLS | ✅ |
|
||||
|
||||
### 前端实现进度
|
||||
|
||||
**状态**:✅ 核心页面完成
|
||||
|
||||
| 文件 | 说明 | 状态 |
|
||||
|------|------|------|
|
||||
| `src/lib/auth.ts` | 认证工具函数 | ✅ |
|
||||
| `src/app/login/page.tsx` | 登录页 | ✅ |
|
||||
| `src/app/register/page.tsx` | 注册页 | ✅ |
|
||||
| `src/app/admin/page.tsx` | 管理后台 | ✅ |
|
||||
| `src/proxy.ts` | 路由保护 | ✅ |
|
||||
|
||||
### 账号隔离集成
|
||||
|
||||
**状态**:✅ 完成
|
||||
|
||||
| 文件 | 修改内容 | 状态 |
|
||||
|------|----------|------|
|
||||
| `app/services/publish_service.py` | 重写支持 user_id 隔离 Cookie | ✅ |
|
||||
| `app/api/publish.py` | 添加认证依赖,传递 user_id | ✅ |
|
||||
|
||||
**Cookie 存储路径**:
|
||||
- 已登录用户: `user_data/{user_id}/cookies/{platform}_cookies.json`
|
||||
- 未登录用户: `app/cookies/{platform}_cookies.json` (兼容旧版)
|
||||
|
||||
---
|
||||
|
||||
## 🔐 用户认证系统实现 (2026-01-23)
|
||||
|
||||
### 问题描述
|
||||
为了支持多用户管理和资源隔离,需要实现一套完整的用户认证系统,取代以前的单用户模式。要求:
|
||||
- 使用 Supabase 作为数据库
|
||||
- 支持注册、登录、登出
|
||||
- 管理员审核机制 (is_active)
|
||||
- 单设备登录限制
|
||||
- HttpOnly Cookie 存储 Token
|
||||
|
||||
### 解决方案
|
||||
|
||||
#### 1. 数据库设计 (Supabase)
|
||||
创建了三张核心表:
|
||||
- `users`: 存储邮箱、密码哈希、角色、激活状态
|
||||
- `user_sessions`: 存储 Session Token,实现单设备登录 (后踢前)
|
||||
- `social_accounts`: 社交账号绑定信息 (B站/抖音Cookie)
|
||||
|
||||
#### 2. 后端实现 (FastAPI)
|
||||
- **依赖注入** (`deps.py`): `get_current_user` 自动验证 Token 和 Session
|
||||
- **安全模块** (`security.py`): JWT 生成与验证,密码 bcrypt 哈希
|
||||
- **路由模块** (`auth.py`):
|
||||
- `/register`: 注册后默认为 `pending` 状态
|
||||
- `/login`: 验证通过后生成 JWT 并写入 HttpOnly Cookie
|
||||
- `/me`: 获取当前用户信息
|
||||
|
||||
#### 3. 部署方案
|
||||
- 采用 Supabase 云端免费版
|
||||
- 为了防止 7 天不活跃暂停,配置了 GitHub Actions / Crontab 自动保活
|
||||
- 创建了独立的部署文档 `Docs/AUTH_DEPLOY.md`
|
||||
|
||||
### 结果
|
||||
- ✅ 成功实现了完整的 JWT 认证流程
|
||||
- ✅ 管理员可以控制用户激活状态
|
||||
- ✅ 实现了安全的无感 Token 刷新 (Session Token)
|
||||
- ✅ 敏感配置 (Supabase Key) 通过环境变量管理
|
||||
|
||||
---
|
||||
|
||||
## 🔗 相关文档
|
||||
|
||||
- [用户认证系统实现计划](file:///C:/Users/danny/.gemini/antigravity/brain/06e7632c-12c6-4e80-b321-e1e642144560/implementation_plan.md)
|
||||
- [代码审核报告](file:///C:/Users/danny/.gemini/antigravity/brain/a28bb1a6-2929-4c55-b837-c989943844e1/walkthrough.md)
|
||||
- [部署手册](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/DEPLOY_MANUAL.md)
|
||||
|
||||
---
|
||||
|
||||
## 🛠️ 部署调试记录 (2026-01-23)
|
||||
|
||||
### 1. 服务启动方式修正
|
||||
- **问题**: pm2 直接启动 python/uvicorn 会导致 `SyntaxError` (Node.js 尝试解释 Python)
|
||||
- **解决**: 改用 `.sh` 脚本封装启动命令
|
||||
|
||||
### 2. 依赖缺失与兼容性
|
||||
- **问题 1**: `ImportError: email-validator is not installed` (Pydantic 依赖)
|
||||
- **修复**: 添加 `email-validator>=2.1.0`
|
||||
- **问题 2**: `AttributeError: module 'bcrypt' has no attribute '__about__'` (Passlib 兼容性)
|
||||
- **修复**: 锁定 `bcrypt==4.0.1`
|
||||
|
||||
### 3. 前端生产环境构建
|
||||
- **问题**: `Error: Could not find a production build`
|
||||
- **解决**: 启动前必须执行 `npm run build`
|
||||
|
||||
### 4. 性能调优
|
||||
- **现象**: SSH 远程连接出现显著卡顿
|
||||
- **排查**: `vigent2-latentsync` 启动时模型加载占用大量系统资源
|
||||
- **优化**: 生产环境建议按需开启 LatentSync 服务,或确保服务器 IO/带宽充足。停止该服务后 SSH 恢复流畅。
|
||||
|
||||
|
||||
@@ -16,6 +16,22 @@
|
||||
|
||||
---
|
||||
|
||||
## 🧾 全局文档更新清单 (Checklist)
|
||||
|
||||
> **每次提交重要变更时,请核对以下文件是否需要同步:**
|
||||
|
||||
| 优先级 | 文件路径 | 检查重点 |
|
||||
| :---: | :--- | :--- |
|
||||
| 🔥 **High** | `Docs/DevLogs/DayN.md` | **(最新日志)** 详细记录变更、修复、代码片段 |
|
||||
| 🔥 **High** | `Docs/task_complete.md` | **(任务总览)** 更新 `[x]`、进度条、时间线 |
|
||||
| ⚡ **Med** | `README.md` | **(项目主页)** 功能特性、技术栈、最新截图 |
|
||||
| ⚡ **Med** | `Docs/DEPLOY_MANUAL.md` | **(部署手册)** 环境变量、依赖包、启动命令变更 |
|
||||
| ⚡ **Med** | `Docs/FRONTEND_DEV.md` | **(前端规范)** API封装、日期格式化、新页面规范 |
|
||||
| 🧊 **Low** | `Docs/implementation_plan.md` | **(实施计划)** 核对计划与实际实现的差异 |
|
||||
| 🧊 **Low** | `frontend/README.md` | **(前端文档)** 新页面路由、组件用法、UI变更 |
|
||||
|
||||
---
|
||||
|
||||
## 🔍 修改原内容的判断标准
|
||||
|
||||
### 场景 1:错误修正 → **替换/删除**
|
||||
@@ -120,7 +136,7 @@
|
||||
|
||||
---
|
||||
|
||||
## <EFBFBD>️ 工具使用规范
|
||||
## ️ 工具使用规范
|
||||
|
||||
> **核心原则**:使用正确的工具,避免字符编码问题
|
||||
|
||||
@@ -185,12 +201,15 @@ replace_file_content(
|
||||
|
||||
---
|
||||
|
||||
## <EFBFBD>📁 文件结构
|
||||
## 📁 文件结构
|
||||
|
||||
```
|
||||
ViGent/Docs/
|
||||
├── task_complete.md # 任务总览(仅按需更新)
|
||||
├── Doc_Rules.md # 本文件
|
||||
├── FRONTEND_DEV.md # 前端开发规范
|
||||
├── DEPLOY_MANUAL.md # 部署手册
|
||||
├── SUPABASE_DEPLOY.md # Supabase 部署文档
|
||||
└── DevLogs/
|
||||
├── Day1.md # 开发日志
|
||||
└── ...
|
||||
@@ -203,8 +222,8 @@ ViGent/Docs/
|
||||
### 新建判断 (对话开始前)
|
||||
1. **回顾进度**:查看 `task_complete.md` 了解当前状态
|
||||
2. **检查日期**:查看最新 `DayN.md`
|
||||
- **今天** → 追加到现有文件
|
||||
- **之前** → 创建 `Day{N+1}.md`
|
||||
- **今天 (与当前日期相同)** → 🚨 **绝对禁止创建新文件**,必须**追加**到现有 `DayN.md` 末尾!即使是完全不同的功能模块。
|
||||
- **之前 (昨天或更早)** → 创建 `Day{N+1}.md`
|
||||
|
||||
### 追加格式
|
||||
```markdown
|
||||
@@ -286,4 +305,4 @@ ViGent/Docs/
|
||||
|
||||
---
|
||||
|
||||
**最后更新**:2026-01-21
|
||||
**最后更新**:2026-01-23
|
||||
|
||||
182
Docs/FRONTEND_DEV.md
Normal file
182
Docs/FRONTEND_DEV.md
Normal file
@@ -0,0 +1,182 @@
|
||||
# 前端开发规范
|
||||
|
||||
## 目录结构
|
||||
|
||||
```
|
||||
frontend/src/
|
||||
├── app/ # Next.js App Router 页面
|
||||
│ ├── page.tsx # 首页(视频生成)
|
||||
│ ├── publish/ # 发布页面
|
||||
│ ├── admin/ # 管理员页面
|
||||
│ ├── login/ # 登录页面
|
||||
│ └── register/ # 注册页面
|
||||
├── lib/ # 公共工具函数
|
||||
│ ├── axios.ts # Axios 实例(含 401/403 拦截器)
|
||||
│ └── auth.ts # 认证相关函数
|
||||
└── proxy.ts # 路由代理(原 middleware)
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## iOS Safari 安全区域兼容
|
||||
|
||||
### 问题
|
||||
iPhone Safari 浏览器顶部(刘海/灵动岛)和底部(Home 指示条)有安全区域,默认情况下页面背景不会延伸到这些区域,导致白边。
|
||||
|
||||
### 解决方案(三层配合)
|
||||
|
||||
#### 1. Viewport 配置 (`layout.tsx`)
|
||||
```typescript
|
||||
import type { Viewport } from "next";
|
||||
|
||||
export const viewport: Viewport = {
|
||||
width: 'device-width',
|
||||
initialScale: 1,
|
||||
viewportFit: 'cover', // 允许内容延伸到安全区域
|
||||
themeColor: '#0f172a', // 顶部状态栏颜色(与背景一致)
|
||||
};
|
||||
```
|
||||
|
||||
#### 2. 全局背景统一到 body (`layout.tsx`)
|
||||
```tsx
|
||||
<html lang="en" style={{ backgroundColor: '#0f172a' }}>
|
||||
<body
|
||||
style={{
|
||||
margin: 0,
|
||||
minHeight: '100dvh', // 使用 dvh 而非 vh
|
||||
background: 'linear-gradient(to bottom, #0f172a 0%, #0f172a 5%, #581c87 50%, #0f172a 95%, #0f172a 100%)',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</body>
|
||||
</html>
|
||||
```
|
||||
|
||||
#### 3. CSS 安全区域支持 (`globals.css`)
|
||||
```css
|
||||
html {
|
||||
background-color: #0f172a !important;
|
||||
min-height: 100%;
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0 !important;
|
||||
min-height: 100dvh;
|
||||
padding-top: env(safe-area-inset-top);
|
||||
padding-bottom: env(safe-area-inset-bottom);
|
||||
}
|
||||
```
|
||||
|
||||
### 关键要点
|
||||
- **渐变背景放 body,不放页面 div** - 安全区域在 div 之外
|
||||
- **使用 `100dvh` 而非 `100vh`** - dvh 是动态视口高度,适配移动端
|
||||
- **themeColor 与背景边缘色一致** - 避免状态栏色差
|
||||
- **页面 div 移除独立背景** - 使用透明,继承 body 渐变
|
||||
|
||||
---
|
||||
|
||||
## 移动端响应式规范
|
||||
|
||||
### Header 按钮布局
|
||||
```tsx
|
||||
// 移动端紧凑,桌面端宽松
|
||||
<div className="flex items-center gap-1 sm:gap-4">
|
||||
<button className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base ...">
|
||||
按钮
|
||||
</button>
|
||||
</div>
|
||||
```
|
||||
|
||||
### 常用响应式断点
|
||||
| 断点 | 宽度 | 用途 |
|
||||
|------|------|------|
|
||||
| 默认 | < 640px | 移动端 |
|
||||
| `sm:` | ≥ 640px | 平板/桌面 |
|
||||
| `lg:` | ≥ 1024px | 大屏桌面 |
|
||||
|
||||
---
|
||||
|
||||
## API 请求规范
|
||||
|
||||
### 必须使用 `api` (axios 实例)
|
||||
|
||||
所有需要认证的 API 请求**必须**使用 `@/lib/axios` 导出的 axios 实例。该实例已配置:
|
||||
- 自动携带 `credentials: include`
|
||||
- 遇到 401/403 时自动清除 cookie 并跳转登录页
|
||||
|
||||
**使用方式:**
|
||||
|
||||
```typescript
|
||||
import api from '@/lib/axios';
|
||||
|
||||
// GET 请求
|
||||
const { data } = await api.get('/api/materials');
|
||||
|
||||
// POST 请求
|
||||
const { data } = await api.post('/api/videos/generate', {
|
||||
text: '...',
|
||||
voice: '...',
|
||||
});
|
||||
|
||||
// DELETE 请求
|
||||
await api.delete(`/api/materials/${id}`);
|
||||
|
||||
// 带上传进度的文件上传
|
||||
await api.post('/api/materials', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
onUploadProgress: (e) => {
|
||||
if (e.total) {
|
||||
const progress = Math.round((e.loaded / e.total) * 100);
|
||||
setProgress(progress);
|
||||
}
|
||||
},
|
||||
});
|
||||
```
|
||||
|
||||
### SWR 配合使用
|
||||
|
||||
```typescript
|
||||
import api from '@/lib/axios';
|
||||
|
||||
// SWR fetcher 使用 axios
|
||||
const fetcher = (url: string) => api.get(url).then(res => res.data);
|
||||
|
||||
const { data } = useSWR('/api/xxx', fetcher, { refreshInterval: 2000 });
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 日期格式化规范
|
||||
|
||||
### 禁止使用 `toLocaleString()`
|
||||
|
||||
`toLocaleString()` 在服务端和客户端可能返回不同格式,导致 Hydration 错误。
|
||||
|
||||
**错误示例:**
|
||||
```typescript
|
||||
// ❌ 会导致 Hydration 错误
|
||||
new Date(timestamp * 1000).toLocaleString('zh-CN')
|
||||
```
|
||||
|
||||
**正确做法:**
|
||||
```typescript
|
||||
// ✅ 使用固定格式
|
||||
const formatDate = (timestamp: number) => {
|
||||
const d = new Date(timestamp * 1000);
|
||||
const year = d.getFullYear();
|
||||
const month = String(d.getMonth() + 1).padStart(2, '0');
|
||||
const day = String(d.getDate()).padStart(2, '0');
|
||||
const hour = String(d.getHours()).padStart(2, '0');
|
||||
const minute = String(d.getMinutes()).padStart(2, '0');
|
||||
return `${year}/${month}/${day} ${hour}:${minute}`;
|
||||
};
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 新增页面 Checklist
|
||||
|
||||
1. [ ] 导入 `import api from '@/lib/axios'`
|
||||
2. [ ] 所有 API 请求使用 `api.get/post/delete()` 而非原生 `fetch`
|
||||
3. [ ] 日期格式化使用固定格式函数,不用 `toLocaleString()`
|
||||
4. [ ] 添加 `'use client'` 指令(如需客户端交互)
|
||||
29
Docs/Logs.md
Normal file
29
Docs/Logs.md
Normal file
@@ -0,0 +1,29 @@
|
||||
rongye@r730-ubuntu:~/ProgramFiles/Supabase$ docker compose up -d
|
||||
[+] up 136/136
|
||||
✔ Image timberio/vector:0.28.1-alpine Pulled 63.3ss
|
||||
✔ Image supabase/storage-api:v1.33.0 Pulled 78.6ss
|
||||
✔ Image darthsim/imgproxy:v3.30.1 Pulled 151.9s
|
||||
✔ Image supabase/postgres-meta:v0.95.1 Pulled 87.5ss
|
||||
✔ Image supabase/logflare:1.27.0 Pulled 229.2s
|
||||
✔ Image supabase/postgres:15.8.1.085 Pulled 268.3s
|
||||
✔ Image supabase/supavisor:2.7.4 Pulled 101.6s
|
||||
✔ Image supabase/realtime:v2.68.0 Pulled 56.5ss
|
||||
✔ Image postgrest/postgrest:v14.1 Pulled 201.8s
|
||||
✔ Image supabase/edge-runtime:v1.69.28 Pulled 254.0s
|
||||
✔ Network supabase_default Created 0.1s
|
||||
✔ Volume supabase_db-config Created 0.1s
|
||||
✔ Container supabase-vector Healthy 16.9s
|
||||
✔ Container supabase-imgproxy Created 7.4s
|
||||
✔ Container supabase-db Healthy 20.6s
|
||||
✔ Container supabase-analytics Created 0.4s
|
||||
✔ Container supabase-edge-functions Created 1.8s
|
||||
✔ Container supabase-auth Created 1.7s
|
||||
✔ Container supabase-studio Created 2.0s
|
||||
✔ Container realtime-dev.supabase-realtime Created 1.7s
|
||||
✔ Container supabase-pooler Created 1.8s
|
||||
✔ Container supabase-kong Created 1.7s
|
||||
✔ Container supabase-meta Created 2.0s
|
||||
✔ Container supabase-rest Created 0.9s
|
||||
✔ Container supabase-storage Created 1.4s
|
||||
Error response from daemon: failed to set up container networking: driver failed programming external connectivity on endpoint supabase-analytics (2fd60a510a1f16bf29f8f5140f14ef457a284c5b65a2567b7be250a4f9708f34): failed to bind host port 0.0.0.0:4000/tcp: address already in use
|
||||
[ble: exit 1]
|
||||
253
Docs/QWEN3_TTS_DEPLOY.md
Normal file
253
Docs/QWEN3_TTS_DEPLOY.md
Normal file
@@ -0,0 +1,253 @@
|
||||
# Qwen3-TTS 0.6B 部署指南
|
||||
|
||||
> 本文档描述如何在 Ubuntu 服务器上部署 Qwen3-TTS 0.6B-Base 声音克隆模型。
|
||||
|
||||
## 系统要求
|
||||
|
||||
| 要求 | 规格 |
|
||||
|------|------|
|
||||
| GPU | NVIDIA RTX 3090 24GB (或更高) |
|
||||
| VRAM | ≥ 4GB (推理), ≥ 8GB (带 flash-attn) |
|
||||
| CUDA | 12.1+ |
|
||||
| Python | 3.10.x |
|
||||
| 系统 | Ubuntu 20.04+ |
|
||||
|
||||
---
|
||||
|
||||
## GPU 分配
|
||||
|
||||
| GPU | 服务 | 模型 |
|
||||
|-----|------|------|
|
||||
| GPU0 | **Qwen3-TTS** | 0.6B-Base (声音克隆) |
|
||||
| GPU1 | LatentSync | 1.6 (唇形同步) |
|
||||
|
||||
---
|
||||
|
||||
## 步骤 1: 克隆仓库
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models
|
||||
git clone https://github.com/QwenLM/Qwen3-TTS.git
|
||||
cd Qwen3-TTS
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 2: 创建 Conda 环境
|
||||
|
||||
```bash
|
||||
# 创建新的 conda 环境
|
||||
conda create -n qwen-tts python=3.10 -y
|
||||
conda activate qwen-tts
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 3: 安装 Python 依赖
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/Qwen3-TTS
|
||||
|
||||
# 安装 qwen-tts 包 (editable mode)
|
||||
pip install -e .
|
||||
|
||||
# 安装 sox 音频处理库 (必须)
|
||||
conda install -y -c conda-forge sox
|
||||
```
|
||||
|
||||
### 可选: 安装 FlashAttention (推荐)
|
||||
|
||||
FlashAttention 可以显著提升推理速度并减少显存占用:
|
||||
|
||||
```bash
|
||||
pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
如果内存不足,可以限制编译并发数:
|
||||
|
||||
```bash
|
||||
MAX_JOBS=4 pip install -U flash-attn --no-build-isolation
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 4: 下载模型权重
|
||||
|
||||
### 方式 A: ModelScope (推荐,国内更快)
|
||||
|
||||
```bash
|
||||
pip install modelscope
|
||||
|
||||
# 下载 Tokenizer (651MB)
|
||||
modelscope download --model Qwen/Qwen3-TTS-Tokenizer-12Hz --local_dir ./checkpoints/Tokenizer
|
||||
|
||||
# 下载 0.6B-Base 模型 (2.4GB)
|
||||
modelscope download --model Qwen/Qwen3-TTS-12Hz-0.6B-Base --local_dir ./checkpoints/0.6B-Base
|
||||
```
|
||||
|
||||
### 方式 B: HuggingFace
|
||||
|
||||
```bash
|
||||
pip install -U "huggingface_hub[cli]"
|
||||
|
||||
huggingface-cli download Qwen/Qwen3-TTS-Tokenizer-12Hz --local-dir ./checkpoints/Tokenizer
|
||||
huggingface-cli download Qwen/Qwen3-TTS-12Hz-0.6B-Base --local-dir ./checkpoints/0.6B-Base
|
||||
```
|
||||
|
||||
下载完成后,目录结构应如下:
|
||||
|
||||
```
|
||||
checkpoints/
|
||||
├── Tokenizer/ # ~651MB
|
||||
│ ├── config.json
|
||||
│ ├── model.safetensors
|
||||
│ └── ...
|
||||
└── 0.6B-Base/ # ~2.4GB
|
||||
├── config.json
|
||||
├── model.safetensors
|
||||
└── ...
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 步骤 5: 验证安装
|
||||
|
||||
### 5.1 检查环境
|
||||
|
||||
```bash
|
||||
conda activate qwen-tts
|
||||
|
||||
# 检查 PyTorch 和 CUDA
|
||||
python -c "import torch; print(f'PyTorch: {torch.__version__}'); print(f'CUDA: {torch.cuda.is_available()}')"
|
||||
|
||||
# 检查 sox
|
||||
sox --version
|
||||
```
|
||||
|
||||
### 5.2 运行推理测试
|
||||
|
||||
创建测试脚本 `test_inference.py`:
|
||||
|
||||
```python
|
||||
"""Qwen3-TTS 声音克隆测试"""
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
||||
print("Loading Qwen3-TTS model on GPU:0...")
|
||||
model = Qwen3TTSModel.from_pretrained(
|
||||
"./checkpoints/0.6B-Base",
|
||||
device_map="cuda:0",
|
||||
dtype=torch.bfloat16,
|
||||
)
|
||||
print("Model loaded!")
|
||||
|
||||
# 测试声音克隆 (需要准备参考音频)
|
||||
ref_audio = "./examples/myvoice.wav" # 3-20秒的参考音频
|
||||
ref_text = "参考音频的文字内容"
|
||||
|
||||
test_text = "这是一段测试文本,用于验证声音克隆功能是否正常工作。"
|
||||
|
||||
print("Generating cloned voice...")
|
||||
wavs, sr = model.generate_voice_clone(
|
||||
text=test_text,
|
||||
language="Chinese",
|
||||
ref_audio=ref_audio,
|
||||
ref_text=ref_text,
|
||||
)
|
||||
|
||||
sf.write("test_output.wav", wavs[0], sr)
|
||||
print(f"✅ Saved: test_output.wav | {sr}Hz | {len(wavs[0])/sr:.2f}s")
|
||||
```
|
||||
|
||||
运行测试:
|
||||
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/ViGent2/models/Qwen3-TTS
|
||||
python test_inference.py
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 目录结构
|
||||
|
||||
部署完成后,目录结构应如下:
|
||||
|
||||
```
|
||||
/home/rongye/ProgramFiles/ViGent2/models/Qwen3-TTS/
|
||||
├── checkpoints/
|
||||
│ ├── Tokenizer/ # 语音编解码器
|
||||
│ └── 0.6B-Base/ # 声音克隆模型
|
||||
├── qwen_tts/ # 源码
|
||||
│ ├── inference/
|
||||
│ ├── models/
|
||||
│ └── ...
|
||||
├── examples/
|
||||
│ └── myvoice.wav # 参考音频
|
||||
├── pyproject.toml
|
||||
├── requirements.txt
|
||||
└── test_inference.py # 测试脚本
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 模型说明
|
||||
|
||||
### 可用模型
|
||||
|
||||
| 模型 | 功能 | 大小 |
|
||||
|------|------|------|
|
||||
| 0.6B-Base | 3秒快速声音克隆 | 2.4GB |
|
||||
| 0.6B-CustomVoice | 9种预设音色 | 2.4GB |
|
||||
| 1.7B-Base | 声音克隆 (更高质量) | 6.8GB |
|
||||
| 1.7B-VoiceDesign | 自然语言描述生成声音 | 6.8GB |
|
||||
|
||||
### 支持语言
|
||||
|
||||
中文、英语、日语、韩语、德语、法语、俄语、葡萄牙语、西班牙语、意大利语
|
||||
|
||||
---
|
||||
|
||||
## 故障排除
|
||||
|
||||
### sox 未找到
|
||||
|
||||
```
|
||||
SoX could not be found!
|
||||
```
|
||||
|
||||
**解决**: 通过 conda 安装 sox:
|
||||
|
||||
```bash
|
||||
conda install -y -c conda-forge sox
|
||||
```
|
||||
|
||||
### CUDA 内存不足
|
||||
|
||||
Qwen3-TTS 0.6B 通常只需要 4-6GB VRAM。如果遇到 OOM:
|
||||
|
||||
1. 确保 GPU0 没有运行其他程序
|
||||
2. 不使用 flash-attn (会增加显存占用)
|
||||
3. 使用更小的参考音频 (3-5秒)
|
||||
|
||||
### 模型加载失败
|
||||
|
||||
确保以下文件存在:
|
||||
- `checkpoints/0.6B-Base/config.json`
|
||||
- `checkpoints/0.6B-Base/model.safetensors`
|
||||
|
||||
### 音频输出质量问题
|
||||
|
||||
1. 参考音频质量:使用清晰、无噪音的 3-10 秒音频
|
||||
2. ref_text 准确性:参考音频的转写文字必须准确
|
||||
3. 语言设置:确保 `language` 参数与文本语言一致
|
||||
|
||||
---
|
||||
|
||||
## 参考链接
|
||||
|
||||
- [Qwen3-TTS GitHub](https://github.com/QwenLM/Qwen3-TTS)
|
||||
- [ModelScope 模型](https://modelscope.cn/collections/Qwen/Qwen3-TTS)
|
||||
- [HuggingFace 模型](https://huggingface.co/collections/Qwen/qwen3-tts)
|
||||
- [技术报告](https://arxiv.org/abs/2601.15621)
|
||||
- [官方博客](https://qwen.ai/blog?id=qwen3tts-0115)
|
||||
291
Docs/SUPABASE_DEPLOY.md
Normal file
291
Docs/SUPABASE_DEPLOY.md
Normal file
@@ -0,0 +1,291 @@
|
||||
# Supabase 全栈部署指南 (Infrastructure + Auth)
|
||||
|
||||
本文档涵盖了 Supabase 基础设施的 Docker 部署、密钥配置、Nginx 安全加固以及用户认证系统的数据库初始化。
|
||||
|
||||
---
|
||||
|
||||
## 第一部分:基础设施部署 (Infrastructure)
|
||||
|
||||
### 1. 准备 Docker 环境 (Ubuntu)
|
||||
|
||||
Supabase 严重依赖官方目录结构(挂载配置文件),**必须包含完整的 `docker` 目录**。
|
||||
|
||||
```bash
|
||||
# 1. 创建目录
|
||||
mkdir -p /home/rongye/ProgramFiles/Supabase
|
||||
cd /home/rongye/ProgramFiles/Supabase
|
||||
|
||||
# 2. 获取官方配置
|
||||
# 克隆仓库并提取 docker 目录
|
||||
git clone --depth 1 https://github.com/supabase/supabase.git temp_repo
|
||||
mv temp_repo/docker/* .
|
||||
rm -rf temp_repo
|
||||
|
||||
# 3. 复制环境变量模板
|
||||
cp .env.example .env
|
||||
```
|
||||
|
||||
### 2. 生成安全密钥
|
||||
|
||||
**警告**:官方模板使用的是公开的弱密钥。生产环境必须重新生成。
|
||||
使用项目提供的脚本自动生成全套强密钥:
|
||||
|
||||
```bash
|
||||
# 在 ViGent2 项目目录下
|
||||
cd /home/rongye/ProgramFiles/ViGent2/backend
|
||||
python generate_keys.py
|
||||
```
|
||||
|
||||
将脚本生成的输出(包括 `JWT_SECRET`, `ANON_KEY`, `SERVICE_ROLE_KEY` 等)复制并**覆盖** `/home/rongye/ProgramFiles/Supabase/.env` 中的对应内容。
|
||||
|
||||
### 3. 配置端口与冲突解决
|
||||
|
||||
编辑 Supabase 的 `.env` 文件,修改以下端口以避免与现有服务(Code-Server, Moodist)冲突:
|
||||
|
||||
```ini
|
||||
# --- Port Configuration ---
|
||||
# 避免与 Code-Server (8443) 冲突
|
||||
KONG_HTTPS_PORT=8444
|
||||
|
||||
# 自定义 API 端口 (默认 8000)
|
||||
KONG_HTTP_PORT=8008
|
||||
|
||||
# 自定义管理后台端口 (默认 3000)
|
||||
STUDIO_PORT=3003
|
||||
|
||||
# 外部访问 URL (重要:填入你的公网 API 域名/IP)
|
||||
# 如果配置了 Nginx 反代: https://api.hbyrkj.top
|
||||
# 如果直连: http://8.148.25.142:8008
|
||||
API_EXTERNAL_URL=https://api.hbyrkj.top
|
||||
|
||||
# Studio 公网 API 地址 (通过公网访问 Studio 时必须配置)
|
||||
# 用于 Studio 前端调用 API
|
||||
SUPABASE_PUBLIC_URL=https://api.hbyrkj.top
|
||||
```
|
||||
|
||||
### 4. 启动服务
|
||||
|
||||
```bash
|
||||
docker compose up -d
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第二部分:Storage 本地文件结构
|
||||
|
||||
### 1. 存储路径
|
||||
|
||||
Supabase Storage 使用本地文件系统存储,路径结构如下:
|
||||
|
||||
```
|
||||
/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub/
|
||||
├── materials/ # 素材桶
|
||||
│ └── {user_id}/ # 用户目录 (隔离)
|
||||
│ └── {timestamp}_{filename}/
|
||||
│ └── {internal_uuid} # 实际文件 (Supabase 内部 UUID)
|
||||
└── outputs/ # 输出桶
|
||||
└── {user_id}/
|
||||
└── {task_id}_output.mp4/
|
||||
└── {internal_uuid}
|
||||
```
|
||||
|
||||
### 2. 用户隔离策略
|
||||
|
||||
所有用户数据通过路径前缀实现隔离:
|
||||
|
||||
| 资源类型 | 路径格式 | 示例 |
|
||||
|----------|----------|------|
|
||||
| 素材 | `{bucket}/{user_id}/{timestamp}_{filename}` | `materials/abc123/1737000001_video.mp4` |
|
||||
| 输出 | `{bucket}/{user_id}/{task_id}_output.mp4` | `outputs/abc123/uuid-xxx_output.mp4` |
|
||||
| Cookie | `cookies/{user_id}/{platform}.json` | `cookies/abc123/bilibili.json` |
|
||||
|
||||
### 3. 直接访问本地文件
|
||||
|
||||
后端可以直接读取本地文件(跳过 HTTP),提升发布等操作的效率:
|
||||
|
||||
```python
|
||||
# storage.py
|
||||
SUPABASE_STORAGE_LOCAL_PATH = Path("/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub")
|
||||
|
||||
def get_local_file_path(self, bucket: str, path: str) -> Optional[str]:
|
||||
dir_path = SUPABASE_STORAGE_LOCAL_PATH / bucket / path
|
||||
files = list(dir_path.iterdir())
|
||||
return str(files[0]) if files else None
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第三部分:安全访问配置 (Nginx)
|
||||
|
||||
建议在阿里云公网网关上配置 Nginx 反向代理,通过 Frp 隧道连接内网服务。
|
||||
|
||||
### 1. 域名规划
|
||||
- **管理后台**: `https://supabase.hbyrkj.top` -> 内网 3003
|
||||
- **API 接口**: `https://api.hbyrkj.top` -> 内网 8008
|
||||
|
||||
### 2. Nginx 配置示例
|
||||
|
||||
```nginx
|
||||
# Studio (需要密码保护,但静态资源和内部API需排除)
|
||||
server {
|
||||
server_name supabase.hbyrkj.top;
|
||||
|
||||
# SSL 配置略...
|
||||
|
||||
# 静态资源不需要认证
|
||||
location ~ ^/(favicon|_next|static)/ {
|
||||
auth_basic off;
|
||||
proxy_pass http://127.0.0.1:3003;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
}
|
||||
|
||||
# Studio 内部 API 调用不需要认证
|
||||
location /api/ {
|
||||
auth_basic off;
|
||||
proxy_pass http://127.0.0.1:3003;
|
||||
proxy_set_header Host $host;
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
}
|
||||
|
||||
# 其他路径需要 Basic Auth 认证
|
||||
location / {
|
||||
auth_basic "Restricted Studio";
|
||||
auth_basic_user_file /etc/nginx/.htpasswd;
|
||||
proxy_pass http://127.0.0.1:3003;
|
||||
|
||||
# WebSocket 支持 (Realtime 必须)
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
}
|
||||
}
|
||||
|
||||
# API (公开访问)
|
||||
server {
|
||||
server_name api.hbyrkj.top;
|
||||
|
||||
# SSL 配置略...
|
||||
|
||||
# ⚠️ 重要:解除上传大小限制
|
||||
client_max_body_size 0;
|
||||
|
||||
location / {
|
||||
proxy_pass http://127.0.0.1:8008;
|
||||
|
||||
# 允许 WebSocket
|
||||
proxy_http_version 1.1;
|
||||
proxy_set_header Upgrade $http_upgrade;
|
||||
proxy_set_header Connection "upgrade";
|
||||
|
||||
# 大文件上传超时设置
|
||||
proxy_read_timeout 600s;
|
||||
proxy_send_timeout 600s;
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 3. 关键配置说明
|
||||
|
||||
| 配置项 | 作用 | 必要性 |
|
||||
|--------|------|--------|
|
||||
| `client_max_body_size 0` | 解除上传大小限制(默认 1MB) | **必须** |
|
||||
| `proxy_read_timeout 600s` | 大文件上传/下载超时 | 推荐 |
|
||||
| `proxy_http_version 1.1` | WebSocket 支持 | Realtime 必须 |
|
||||
| `auth_basic` | Studio 访问保护 | 推荐 |
|
||||
|
||||
---
|
||||
|
||||
## 第四部分:数据库与认证配置 (Database & Auth)
|
||||
|
||||
### 1. 初始化表结构 (Schema)
|
||||
|
||||
访问管理后台 (Studio) 的 **SQL Editor**,执行以下 SQL 来初始化 ViGent2 所需的表结构:
|
||||
|
||||
```sql
|
||||
-- 1. 用户表 (扩展 auth.users 或独立存储)
|
||||
-- 注意:这里使用独立表设计,与 FastAPI 逻辑解耦
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
username TEXT,
|
||||
role TEXT DEFAULT 'pending' CHECK (role IN ('pending', 'user', 'admin')),
|
||||
is_active BOOLEAN DEFAULT FALSE,
|
||||
expires_at TIMESTAMP WITH TIME ZONE,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- 2. 会话表 (单设备登录控制)
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID REFERENCES users(id) ON DELETE CASCADE UNIQUE,
|
||||
session_token TEXT UNIQUE NOT NULL,
|
||||
device_info TEXT,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- 3. 社交媒体账号绑定表
|
||||
CREATE TABLE IF NOT EXISTS social_accounts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID REFERENCES users(id) ON DELETE CASCADE,
|
||||
platform TEXT NOT NULL CHECK (platform IN ('bilibili', 'douyin', 'xiaohongshu')),
|
||||
logged_in BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
||||
UNIQUE(user_id, platform)
|
||||
);
|
||||
|
||||
-- 4. 性能索引
|
||||
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON user_sessions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_social_user_platform ON social_accounts(user_id, platform);
|
||||
```
|
||||
|
||||
### 2. 后端集成配置 (FastAPI)
|
||||
|
||||
修改 `ViGent2/backend/.env` 以连接到自托管的 Supabase:
|
||||
|
||||
```ini
|
||||
# =============== Supabase 配置 ===============
|
||||
# 指向 Docker 部署的 API 端口 (内网直连推荐用 Localhost)
|
||||
SUPABASE_URL=http://localhost:8008
|
||||
|
||||
# 使用生成的 SERVICE_ROLE_KEY (后端需要管理员权限)
|
||||
SUPABASE_KEY=eyJhbGciOiJIUzI1Ni...
|
||||
|
||||
# =============== JWT 配置 ===============
|
||||
# 必须与 Supabase .env 中的 JWT_SECRET 保持一致!
|
||||
JWT_SECRET_KEY=填入_generate_keys.py_生成的_JWT_SECRET
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_EXPIRE_HOURS=168
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## 第五部分:常用维护命令
|
||||
|
||||
**查看服务状态**:
|
||||
```bash
|
||||
cd /home/rongye/ProgramFiles/Supabase
|
||||
docker compose ps
|
||||
```
|
||||
|
||||
**查看密钥**:
|
||||
```bash
|
||||
grep -E "ANON|SERVICE|SECRET" .env
|
||||
```
|
||||
|
||||
**重启服务**:
|
||||
```bash
|
||||
docker compose restart
|
||||
```
|
||||
|
||||
**完全重置数据库 (慎用)**:
|
||||
```bash
|
||||
docker compose down -v
|
||||
rm -rf volumes/db/data
|
||||
docker compose up -d
|
||||
```
|
||||
@@ -22,7 +22,7 @@
|
||||
┌─────────────────────────────────────────────────────────┐
|
||||
│ 后端 (FastAPI) │
|
||||
├─────────────────────────────────────────────────────────┤
|
||||
│ Celery 任务队列 (Redis) │
|
||||
│ 异步任务队列 (asyncio) │
|
||||
│ ├── 视频生成任务 │
|
||||
│ ├── TTS 配音任务 │
|
||||
│ └── 自动发布任务 │
|
||||
@@ -30,7 +30,7 @@
|
||||
│ │ │
|
||||
▼ ▼ ▼
|
||||
┌──────────┐ ┌──────────┐ ┌──────────┐
|
||||
│ MuseTalk │ │ FFmpeg │ │Playwright│
|
||||
│LatentSync│ │ FFmpeg │ │Playwright│
|
||||
│ 唇形同步 │ │ 视频合成 │ │ 自动发布 │
|
||||
└──────────┘ └──────────┘ └──────────┘
|
||||
```
|
||||
@@ -45,7 +45,7 @@
|
||||
| **UI 组件库** | Tailwind + shadcn/ui | Ant Design |
|
||||
| **后端框架** | FastAPI | Flask |
|
||||
| **任务队列** | Celery + Redis | RQ / Dramatiq |
|
||||
| **唇形同步** | MuseTalk | Wav2Lip / SadTalker |
|
||||
| **唇形同步** | **LatentSync 1.6** | MuseTalk / Wav2Lip |
|
||||
| **TTS 配音** | EdgeTTS | CosyVoice |
|
||||
| **声音克隆** | GPT-SoVITS (可选) | - |
|
||||
| **视频处理** | FFmpeg | MoviePy |
|
||||
@@ -141,12 +141,12 @@ backend/
|
||||
|
||||
| 端点 | 方法 | 功能 |
|
||||
|------|------|------|
|
||||
| `/api/materials` | POST | 上传素材视频 |
|
||||
| `/api/materials` | GET | 获取素材列表 |
|
||||
| `/api/videos/generate` | POST | 创建视频生成任务 |
|
||||
| `/api/tasks/{id}` | GET | 查询任务状态 |
|
||||
| `/api/videos/{id}/download` | GET | 下载生成的视频 |
|
||||
| `/api/publish` | POST | 发布到社交平台 |
|
||||
| `/api/materials` | POST | 上传素材视频 | ✅ |
|
||||
| `/api/materials` | GET | 获取素材列表 | ✅ |
|
||||
| `/api/videos/generate` | POST | 创建视频生成任务 | ✅ |
|
||||
| `/api/tasks/{id}` | GET | 查询任务状态 | ✅ |
|
||||
| `/api/videos/{id}/download` | GET | 下载生成的视频 | ✅ |
|
||||
| `/api/publish` | POST | 发布到社交平台 | ✅ |
|
||||
|
||||
#### 2.3 Celery 任务定义
|
||||
|
||||
@@ -221,7 +221,7 @@ cp -r SuperIPAgent/social-auto-upload backend/social_upload
|
||||
| **声音克隆** | 集成 GPT-SoVITS,用自己的声音 |
|
||||
| **批量生成** | 上传 Excel/CSV,批量生成视频 |
|
||||
| **字幕编辑器** | 可视化调整字幕样式、位置 |
|
||||
| **Docker 部署** | 一键部署到云服务器 |
|
||||
| **Docker 部署** | 一键部署到云服务器 | ✅ |
|
||||
|
||||
---
|
||||
|
||||
@@ -269,6 +269,60 @@ cp -r SuperIPAgent/social-auto-upload backend/social_upload
|
||||
- [x] **常驻模型服务** (Persistent Server, 0s 加载)
|
||||
- [x] **GPU 并发控制** (串行队列防崩溃)
|
||||
|
||||
### 阶段十一:社交媒体发布完善 (Day 7) ✅
|
||||
|
||||
> **目标**:实现全自动扫码登录和多平台发布
|
||||
|
||||
- [x] QR码自动登录 (Playwright headless + Stealth)
|
||||
- [x] 多平台上传器架构 (B站/抖音/小红书)
|
||||
- [x] Cookie 自动管理
|
||||
- [x] 定时发布功能
|
||||
|
||||
### 阶段十二:用户体验优化 (Day 8) ✅
|
||||
|
||||
> **目标**:提升文件管理和历史记录功能
|
||||
|
||||
- [x] 文件名保留 (时间戳前缀 + 原始名称)
|
||||
- [x] 视频持久化 (历史视频列表 API)
|
||||
- [x] 素材/视频删除功能
|
||||
|
||||
### 阶段十三:发布模块优化 (Day 9) ✅
|
||||
|
||||
> **目标**:代码质量优化 + 发布功能验证
|
||||
|
||||
- [x] B站/抖音登录+发布验证通过
|
||||
- [x] 资源清理保障 (try-finally)
|
||||
- [x] 超时保护 (消除无限循环)
|
||||
- [x] 完整类型提示
|
||||
|
||||
### 阶段十四:用户认证系统 (Day 9) ✅
|
||||
|
||||
> **目标**:实现安全、隔离的多用户认证体系
|
||||
|
||||
- [x] Supabase 云数据库集成 (本地自托管)
|
||||
- [x] JWT + HttpOnly Cookie 认证架构
|
||||
- [x] 用户表与权限表设计 (RLS 准备)
|
||||
- [x] 认证部署文档 (Docs/SUPABASE_DEPLOY.md)
|
||||
|
||||
### 阶段十五:部署稳定性优化 (Day 9) ✅
|
||||
|
||||
> **目标**:确保生产环境服务长期稳定
|
||||
|
||||
- [x] 依赖冲突修复 (bcrypt)
|
||||
- [x] 前端构建修复 (Production Build)
|
||||
- [x] PM2 进程守护配置
|
||||
- [x] 部署手册更新 (Docs/DEPLOY_MANUAL.md)
|
||||
|
||||
### 阶段十六:HTTPS 全栈部署 (Day 10) ✅
|
||||
|
||||
> **目标**:实现安全的公网 HTTPS 访问
|
||||
|
||||
- [x] 阿里云 Nginx 反向代理配置
|
||||
- [x] Let's Encrypt SSL 证书集成
|
||||
- [x] Supabase 自托管部署 (Docker)
|
||||
- [x] 端口冲突解决 (3003/8008/8444)
|
||||
- [x] Basic Auth 管理后台保护
|
||||
|
||||
---
|
||||
|
||||
## 项目目录结构 (最终)
|
||||
|
||||
@@ -1,23 +1,24 @@
|
||||
# ViGent 数字人口播系统 - 开发任务清单
|
||||
|
||||
**项目**:ViGent2 数字人口播视频生成系统
|
||||
**服务器**:Dell R730 (2× RTX 3090 24GB)
|
||||
**更新时间**:2026-01-21
|
||||
**整体进度**:100%(Day 7 社交发布完成)
|
||||
**项目**:ViGent2 数字人口播视频生成系统
|
||||
**服务器**:Dell R730 (2× RTX 3090 24GB)
|
||||
**更新时间**:2026-01-28
|
||||
**整体进度**:100%(Day 12 iOS 兼容、移动端优化、Qwen3-TTS 部署)
|
||||
|
||||
## 📖 快速导航
|
||||
|
||||
| 章节 | 说明 |
|
||||
|------|------|
|
||||
| [已完成任务](#-已完成任务) | Day 1-4 完成的功能 |
|
||||
| [已完成任务](#-已完成任务) | Day 1-12 完成的功能 |
|
||||
| [后续规划](#️-后续规划) | 待办项目 |
|
||||
| [进度统计](#-进度统计) | 各模块完成度 |
|
||||
| [里程碑](#-里程碑) | 关键节点 |
|
||||
| [时间线](#-时间线) | 开发历程 |
|
||||
|
||||
**相关文档**:
|
||||
- [Day 日志](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/DevLogs/) (Day1-Day7)
|
||||
- [Day 日志](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/DevLogs/) (Day1-Day12)
|
||||
- [部署指南](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/DEPLOY_MANUAL.md)
|
||||
- [Qwen3-TTS 部署](file:///d:/CodingProjects/Antigravity/ViGent2/Docs/QWEN3_TTS_DEPLOY.md)
|
||||
|
||||
---
|
||||
|
||||
@@ -102,22 +103,87 @@
|
||||
- [x] QR登录超时修复 (Stealth模式、多选择器fallback)
|
||||
- [x] 文档规则优化 (智能修改标准、工具使用规范)
|
||||
|
||||
### 阶段十二:用户体验优化 (Day 8)
|
||||
- [x] 文件名保留 (时间戳前缀 + 原始名称)
|
||||
- [x] 视频持久化 (从文件系统读取历史)
|
||||
- [x] 历史视频列表组件
|
||||
- [x] 素材/视频删除功能
|
||||
- [x] 登出功能 (Logout API + 前端按钮)
|
||||
- [x] 前端 SWR 轮询优化
|
||||
- [x] QR 登录状态检测修复
|
||||
|
||||
### 阶段十三:发布模块优化 (Day 9)
|
||||
- [x] B站/抖音发布验证通过
|
||||
- [x] 资源清理保障 (try-finally)
|
||||
- [x] 超时保护 (消除无限循环)
|
||||
- [x] 小红书 headless 模式修复
|
||||
- [x] API 输入验证
|
||||
- [x] 完整类型提示
|
||||
- [x] 扫码登录等待界面 (加载动画)
|
||||
- [x] 抖音/B站登录策略优化 (Text优先)
|
||||
- [x] 发布成功审核提示
|
||||
|
||||
### 阶段十四:用户认证系统 (Day 9)
|
||||
- [x] Supabase 数据库表设计与部署
|
||||
- [x] JWT 认证 (HttpOnly Cookie)
|
||||
- [x] 用户注册/登录/登出 API
|
||||
- [x] 管理员权限控制 (is_active)
|
||||
- [x] 单设备登录限制 (Session Token)
|
||||
- [x] 防止 Supabase 暂停 (GitHub Actions/Crontab)
|
||||
- [x] 认证部署文档 (AUTH_DEPLOY.md)
|
||||
|
||||
### 阶段十五:部署稳定性优化 (Day 9)
|
||||
- [x] 后端依赖修复 (bcrypt/email-validator)
|
||||
- [x] 前端生产环境构建修复 (npm run build)
|
||||
- [x] LatentSync 性能卡顿修复 (OMP_NUM_THREADS限制)
|
||||
- [x] 部署服务自愈 (PM2 配置优化)
|
||||
- [x] 部署手册全量更新 (DEPLOY_MANUAL.md)
|
||||
|
||||
### 阶段十六:HTTPS 部署与细节完善 (Day 10)
|
||||
- [x] 隧道访问修复 (StaticFiles 挂载 + Rewrite)
|
||||
- [x] 平台账号列表 500 错误修复 (paths.py)
|
||||
- [x] Nginx HTTPS 配置 (反向代理 + SSL)
|
||||
- [x] 浏览器标题修改 (ViGent)
|
||||
- [x] 代码自适应 HTTPS 验证
|
||||
- [x] **Supabase 自托管部署** (Docker, 3003/8008端口)
|
||||
- [x] **安全加固** (Basic Auth 保护后台)
|
||||
- [x] **端口冲突解决** (迁移 Analytics/Kong)
|
||||
|
||||
### 阶段十七:上传架构重构 (Day 11)
|
||||
- [x] **直传改造** (前端直接上传 Supabase,绕过后端代理)
|
||||
- [x] **后端适配** (Signed URL 签名生成)
|
||||
- [x] **RLS 策略部署** (SQL 脚本自动化权限配置)
|
||||
- [x] **超时问题根治** (彻底解决 Nginx/FRP 30s 限制)
|
||||
- [x] **前端依赖更新** (@supabase/supabase-js 集成)
|
||||
|
||||
### 阶段十八:用户隔离与存储优化 (Day 11)
|
||||
- [x] **用户数据隔离** (素材/视频/Cookie 按用户ID目录隔离)
|
||||
- [x] **Storage URL 修复** (SUPABASE_PUBLIC_URL 配置,修复 localhost 问题)
|
||||
- [x] **发布服务优化** (直接读取本地 Supabase Storage 文件,跳过 HTTP 下载)
|
||||
- [x] **Supabase Studio 配置** (公网访问配置)
|
||||
|
||||
### 阶段十九:iOS 兼容与移动端 UI 优化 (Day 12)
|
||||
- [x] **Axios 全局拦截器** (401/403 自动跳转登录,防重复跳转)
|
||||
- [x] **iOS Safari 安全区域修复** (viewport-fit: cover, themeColor, 渐变背景统一)
|
||||
- [x] **移动端 Header 优化** (按钮紧凑布局,响应式间距)
|
||||
- [x] **发布页面 UI 重构** (立即发布/定时发布按钮分离,防误触设计)
|
||||
- [x] **Qwen3-TTS 0.6B 部署** (声音克隆模型,GPU0,3秒参考音频快速克隆)
|
||||
|
||||
---
|
||||
|
||||
## 🛤️ 后续规划
|
||||
|
||||
### 🔴 优先待办
|
||||
- [x] 视频合成最终验证 (MP4生成) ✅ Day 4 完成
|
||||
- [x] 端到端流程完整测试 ✅ Day 4 完成
|
||||
- [ ] 社交媒体发布测试 (B站/抖音已登录)
|
||||
- [ ] **Qwen3-TTS 集成到 ViGent2** - 前端 UI + 后端服务集成
|
||||
- [ ] 批量视频生成架构设计
|
||||
|
||||
### 🟠 功能完善
|
||||
- [ ] 定时发布功能
|
||||
- [x] 定时发布功能 ✅ Day 7 完成
|
||||
- [ ] **后端定时发布** - 替代平台端定时,使用 APScheduler 实现任务调度
|
||||
- [ ] 批量视频生成
|
||||
- [ ] 字幕样式编辑器
|
||||
|
||||
### 🔵 长期探索
|
||||
- [ ] 声音克隆 (GPT-SoVITS)
|
||||
- [ ] Docker 容器化
|
||||
- [ ] Celery 分布式任务队列
|
||||
|
||||
@@ -139,8 +205,9 @@
|
||||
| TTS 配音 | 100% | ✅ 完成 |
|
||||
| 视频合成 | 100% | ✅ 完成 |
|
||||
| 唇形同步 | 100% | ✅ LatentSync 1.6 升级完成 |
|
||||
| 社交发布 | 100% | ✅ 完成 (待验证) |
|
||||
| 服务器部署 | 100% | ✅ 完成 |
|
||||
| 社交发布 | 100% | ✅ Day 9 验证通过 |
|
||||
| 用户认证 | 100% | ✅ Day 9 Supabase+JWT |
|
||||
| 服务器部署 | 100% | ✅ Day 9 稳定性优化完成 |
|
||||
|
||||
---
|
||||
|
||||
@@ -175,11 +242,26 @@
|
||||
- Latent Diffusion 架构升级
|
||||
- 性能优化 (视频预压缩、进度更新)
|
||||
|
||||
### Milestone 5: 用户认证系统 ✅
|
||||
**完成时间**: Day 9
|
||||
**成果**:
|
||||
- Supabase 云数据库集成
|
||||
- 安全的 JWT + HttpOnly Cookie 认证
|
||||
- 管理员后台与用户隔离
|
||||
- 完善的部署与保活方案
|
||||
|
||||
### Milestone 6: 生产环境部署稳定化 ✅
|
||||
**完成时间**: Day 9
|
||||
**成果**:
|
||||
- 修复了后端 (bcrypt) 和前端 (build) 的启动崩溃问题
|
||||
- 解决了 LatentSync 占用全量 CPU 导致服务器卡顿的严重问题
|
||||
- 完善了部署手册,记录了关键的 Troubleshooting 步骤
|
||||
- 实现了服务 Long-term 稳定运行 (Reset PM2 counter)
|
||||
|
||||
---
|
||||
|
||||
## 📅 时间线
|
||||
|
||||
```
|
||||
Day 1: 项目初始化 + 核心功能 ✅ 完成
|
||||
- 后端 API 框架
|
||||
- 前端 UI
|
||||
@@ -224,5 +306,55 @@ Day 7: 社交媒体发布完善 ✅ 完成
|
||||
- 多平台发布 (B站/抖音/小红书)
|
||||
- UI 一致性优化
|
||||
- 文档规则体系优化
|
||||
```
|
||||
|
||||
Day 8: 用户体验优化 ✅ 完成
|
||||
- 文件名保留 (时间戳前缀)
|
||||
- 视频持久化 (历史视频API)
|
||||
- 历史视频列表组件
|
||||
- 素材/视频删除功能
|
||||
|
||||
Day 9: 发布模块优化 ✅ 完成
|
||||
- B站/抖音登录+发布验证通过
|
||||
- 资源清理保障 (try-finally)
|
||||
- 超时保护 (消除无限循环)
|
||||
- 小红书 headless 模式修复
|
||||
- 扫码登录等待界面 (加载动画)
|
||||
- 抖音/B站登录策略优化 (Text优先)
|
||||
- 发布成功审核提示
|
||||
- 用户认证系统规划 (FastAPI+Supabase)
|
||||
- Supabase 表结构设计 (users/sessions)
|
||||
- 后端 JWT 认证实现 (auth.py/deps.py)
|
||||
- 数据库配置与 SQL 部署
|
||||
- 独立认证部署文档 (AUTH_DEPLOY.md)
|
||||
- 自动保活机制 (Crontab/Actions)
|
||||
- 部署稳定性优化 (Backend依赖修复)
|
||||
- 前端生产构建流程修复
|
||||
- LatentSync 严重卡顿修复 (线程数限制)
|
||||
- 部署手册全量更新
|
||||
|
||||
Day 10: HTTPS 部署与细节完善 ✅ 完成
|
||||
- 隧道访问视频修正 (挂载 uploads)
|
||||
- 账号列表 Bug 修复 (paths.py 白名单)
|
||||
- 阿里云 Nginx HTTPS 部署
|
||||
- UI 细节优化 (Title 更新)
|
||||
|
||||
Day 11: 上传架构重构 ✅ 完成
|
||||
- **核心修复**: Aliyun Nginx `client_max_body_size 0` 配置
|
||||
- 500 错误根治 (Direct Upload + Gateway Config)
|
||||
- Supabase RLS 权限策略部署
|
||||
- 前端集成 supabase-js
|
||||
- 彻底解决大文件上传超时 (30s 限制)
|
||||
- **用户数据隔离** (素材/视频/Cookie 按用户目录存储)
|
||||
- **Storage URL 修复** (SUPABASE_PUBLIC_URL 公网地址配置)
|
||||
- **发布服务优化** (本地文件直读,跳过 HTTP 下载)
|
||||
|
||||
Day 12: iOS 兼容与移动端优化 ✅ 完成
|
||||
- Axios 全局拦截器 (401/403 自动跳转登录)
|
||||
- iOS Safari 安全区域白边修复 (viewport-fit: cover)
|
||||
- themeColor 配置 (状态栏颜色适配)
|
||||
- 渐变背景统一 (body 全局渐变,消除分层)
|
||||
- 移动端 Header 响应式优化 (按钮紧凑布局)
|
||||
- 发布页面 UI 重构 (立即发布 3/4 + 定时 1/4)
|
||||
- **Qwen3-TTS 0.6B 部署** (声音克隆模型,GPU0)
|
||||
- **部署文档** (QWEN3_TTS_DEPLOY.md)
|
||||
|
||||
|
||||
28
README.md
28
README.md
@@ -10,9 +10,11 @@
|
||||
|
||||
- 🎬 **唇形同步** - LatentSync 1.6 驱动,512×512 高分辨率 Diffusion 模型
|
||||
- 🎙️ **TTS 配音** - EdgeTTS 多音色支持(云溪、晓晓等)
|
||||
- 📱 **一键发布** - Playwright 自动发布到抖音、小红书、B站等
|
||||
- 🖥️ **Web UI** - Next.js 现代化界面
|
||||
- 🚀 **性能优化** - 视频预压缩、常驻模型服务 (0s加载)
|
||||
- 📱 **全自动发布** - 扫码登录 + Cookie持久化,支持多平台(B站/抖音/小红书)定时发布
|
||||
- 🖥️ **Web UI** - Next.js 现代化界面,iOS/Android 移动端适配
|
||||
- 🔐 **用户系统** - Supabase + JWT 认证,支持管理员后台、注册/登录
|
||||
- 👥 **多用户隔离** - 素材/视频/Cookie 按用户独立存储,数据完全隔离
|
||||
- 🚀 **性能优化** - 视频预压缩、常驻模型服务 (0s加载)、本地文件直读
|
||||
|
||||
## 🛠️ 技术栈
|
||||
|
||||
@@ -20,6 +22,9 @@
|
||||
|------|------|
|
||||
| 前端 | Next.js 14 + TypeScript + TailwindCSS |
|
||||
| 后端 | FastAPI + Python 3.10 |
|
||||
| 数据库 | **Supabase** (PostgreSQL) 自托管 Docker |
|
||||
| 存储 | **Supabase Storage** (本地文件系统) |
|
||||
| 认证 | **JWT** + HttpOnly Cookie |
|
||||
| 唇形同步 | **LatentSync 1.6** (Latent Diffusion, 512×512) |
|
||||
| TTS | EdgeTTS |
|
||||
| 视频处理 | FFmpeg |
|
||||
@@ -45,6 +50,7 @@ ViGent2/
|
||||
│ └── DEPLOY.md # LatentSync 部署指南
|
||||
└── Docs/ # 文档
|
||||
├── DEPLOY_MANUAL.md # 部署手册
|
||||
├── AUTH_DEPLOY.md # 认证部署指南
|
||||
├── task_complete.md
|
||||
└── DevLogs/
|
||||
```
|
||||
@@ -129,19 +135,21 @@ nohup python -m scripts.server > server.log 2>&1 &
|
||||
|
||||
## 🌐 访问地址
|
||||
|
||||
| 服务 | 地址 |
|
||||
|------|------|
|
||||
| 视频生成 | http://服务器IP:3002 |
|
||||
| 发布管理 | http://服务器IP:3002/publish |
|
||||
| API 文档 | http://服务器IP:8006/docs |
|
||||
| 模型API | http://服务器IP:8007/docs |
|
||||
| 服务 | 地址 | 说明 |
|
||||
|------|------|------|
|
||||
| **视频生成 (UI)** | `https://vigent.hbyrkj.top` | 用户访问入口 |
|
||||
| **API 服务** | `http://<服务器IP>:8006` | 后端 Swagger |
|
||||
| **认证管理 (Studio)** | `https://supabase.hbyrkj.top` | 需要 Basic Auth |
|
||||
| **认证 API (Kong)** | `https://api.hbyrkj.top` | Supabase 接口 |
|
||||
| **模型服务** | `http://<服务器IP>:8007` | LatentSync |
|
||||
|
||||
---
|
||||
|
||||
## 📖 文档
|
||||
|
||||
- [LatentSync 部署指南](models/LatentSync/DEPLOY.md)
|
||||
- [手动部署指南](Docs/DEPLOY_MANUAL.md)
|
||||
- [Supabase 部署指南](Docs/SUPABASE_DEPLOY.md)
|
||||
- [LatentSync 部署指南](models/LatentSync/DEPLOY.md)
|
||||
- [开发日志](Docs/DevLogs/)
|
||||
- [任务进度](Docs/task_complete.md)
|
||||
|
||||
|
||||
@@ -45,3 +45,19 @@ MAX_UPLOAD_SIZE_MB=500
|
||||
# FFmpeg 路径 (如果不在系统 PATH 中)
|
||||
# FFMPEG_PATH=/usr/bin/ffmpeg
|
||||
|
||||
# =============== Supabase 配置 ===============
|
||||
# 从 Supabase 项目设置 > API 获取
|
||||
SUPABASE_URL=http://localhost:8008/
|
||||
SUPABASE_PUBLIC_URL=https://api.hbyrkj.top
|
||||
SUPABASE_KEY=eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJyb2xlIjogInNlcnZpY2Vfcm9sZSIsICJpc3MiOiAic3VwYWJhc2UiLCAiaWF0IjogMTc2OTQwNzU2NSwgImV4cCI6IDIwODQ3Njc1NjV9.LBPaimygpnM9o3mZ2Pi-iL8taJ90JjGbQ0HW6yFlmhg
|
||||
|
||||
# =============== JWT 配置 ===============
|
||||
# 用于签名 JWT Token 的密钥 (请更换为随机字符串)
|
||||
JWT_SECRET_KEY=F4MagRkf7nJsN-ag9AB7Q-30MbZRe7Iu4E9p9xRzyic
|
||||
JWT_ALGORITHM=HS256
|
||||
JWT_EXPIRE_HOURS=168
|
||||
|
||||
# =============== 管理员配置 ===============
|
||||
# 服务启动时自动创建的管理员账号
|
||||
ADMIN_EMAIL=lamnickdavid@gmail.com
|
||||
ADMIN_PASSWORD=lam1988324
|
||||
|
||||
185
backend/app/api/admin.py
Normal file
185
backend/app/api/admin.py
Normal file
@@ -0,0 +1,185 @@
|
||||
"""
|
||||
管理员 API:用户管理
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Depends, status
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional, List
|
||||
from datetime import datetime, timezone, timedelta
|
||||
from app.core.supabase import get_supabase
|
||||
from app.core.deps import get_current_admin
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter(prefix="/api/admin", tags=["管理"])
|
||||
|
||||
|
||||
class UserListItem(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
username: Optional[str]
|
||||
role: str
|
||||
is_active: bool
|
||||
expires_at: Optional[str]
|
||||
created_at: str
|
||||
|
||||
|
||||
class ActivateRequest(BaseModel):
|
||||
expires_days: Optional[int] = None # 授权天数,None 表示永久
|
||||
|
||||
|
||||
@router.get("/users", response_model=List[UserListItem])
|
||||
async def list_users(admin: dict = Depends(get_current_admin)):
|
||||
"""获取所有用户列表"""
|
||||
try:
|
||||
supabase = get_supabase()
|
||||
result = supabase.table("users").select("*").order("created_at", desc=True).execute()
|
||||
|
||||
return [
|
||||
UserListItem(
|
||||
id=u["id"],
|
||||
email=u["email"],
|
||||
username=u.get("username"),
|
||||
role=u["role"],
|
||||
is_active=u["is_active"],
|
||||
expires_at=u.get("expires_at"),
|
||||
created_at=u["created_at"]
|
||||
)
|
||||
for u in result.data
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户列表失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="获取用户列表失败"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/activate")
|
||||
async def activate_user(
|
||||
user_id: str,
|
||||
request: ActivateRequest,
|
||||
admin: dict = Depends(get_current_admin)
|
||||
):
|
||||
"""
|
||||
激活用户
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
request.expires_days: 授权天数 (None 表示永久)
|
||||
"""
|
||||
try:
|
||||
supabase = get_supabase()
|
||||
|
||||
# 计算过期时间
|
||||
expires_at = None
|
||||
if request.expires_days:
|
||||
expires_at = (datetime.now(timezone.utc) + timedelta(days=request.expires_days)).isoformat()
|
||||
|
||||
# 更新用户
|
||||
result = supabase.table("users").update({
|
||||
"is_active": True,
|
||||
"role": "user",
|
||||
"expires_at": expires_at
|
||||
}).eq("id", user_id).execute()
|
||||
|
||||
if not result.data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
logger.info(f"管理员 {admin['email']} 激活用户 {user_id}, 有效期: {request.expires_days or '永久'} 天")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"用户已激活,有效期: {request.expires_days or '永久'} 天"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"激活用户失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="激活用户失败"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/deactivate")
|
||||
async def deactivate_user(
|
||||
user_id: str,
|
||||
admin: dict = Depends(get_current_admin)
|
||||
):
|
||||
"""停用用户"""
|
||||
try:
|
||||
supabase = get_supabase()
|
||||
|
||||
# 不能停用管理员
|
||||
user_result = supabase.table("users").select("role").eq("id", user_id).single().execute()
|
||||
if user_result.data and user_result.data["role"] == "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="不能停用管理员账号"
|
||||
)
|
||||
|
||||
# 更新用户
|
||||
result = supabase.table("users").update({
|
||||
"is_active": False
|
||||
}).eq("id", user_id).execute()
|
||||
|
||||
# 清除用户 session
|
||||
supabase.table("user_sessions").delete().eq("user_id", user_id).execute()
|
||||
|
||||
logger.info(f"管理员 {admin['email']} 停用用户 {user_id}")
|
||||
|
||||
return {"success": True, "message": "用户已停用"}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"停用用户失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="停用用户失败"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/users/{user_id}/extend")
|
||||
async def extend_user(
|
||||
user_id: str,
|
||||
request: ActivateRequest,
|
||||
admin: dict = Depends(get_current_admin)
|
||||
):
|
||||
"""延长用户授权期限"""
|
||||
try:
|
||||
supabase = get_supabase()
|
||||
|
||||
if not request.expires_days:
|
||||
# 设为永久
|
||||
expires_at = None
|
||||
else:
|
||||
# 获取当前过期时间
|
||||
user_result = supabase.table("users").select("expires_at").eq("id", user_id).single().execute()
|
||||
user = user_result.data
|
||||
|
||||
if user and user.get("expires_at"):
|
||||
current_expires = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
|
||||
base_time = max(current_expires, datetime.now(timezone.utc))
|
||||
else:
|
||||
base_time = datetime.now(timezone.utc)
|
||||
|
||||
expires_at = (base_time + timedelta(days=request.expires_days)).isoformat()
|
||||
|
||||
result = supabase.table("users").update({
|
||||
"expires_at": expires_at
|
||||
}).eq("id", user_id).execute()
|
||||
|
||||
logger.info(f"管理员 {admin['email']} 延长用户 {user_id} 授权 {request.expires_days or '永久'} 天")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": f"授权已延长 {request.expires_days or '永久'} 天"
|
||||
}
|
||||
except Exception as e:
|
||||
logger.error(f"延长授权失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="延长授权失败"
|
||||
)
|
||||
223
backend/app/api/auth.py
Normal file
223
backend/app/api/auth.py
Normal file
@@ -0,0 +1,223 @@
|
||||
"""
|
||||
认证 API:注册、登录、登出
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, Response, status, Request
|
||||
from pydantic import BaseModel, EmailStr
|
||||
from app.core.supabase import get_supabase
|
||||
from app.core.security import (
|
||||
get_password_hash,
|
||||
verify_password,
|
||||
create_access_token,
|
||||
generate_session_token,
|
||||
set_auth_cookie,
|
||||
clear_auth_cookie,
|
||||
decode_access_token
|
||||
)
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
router = APIRouter(prefix="/api/auth", tags=["认证"])
|
||||
|
||||
|
||||
class RegisterRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
username: Optional[str] = None
|
||||
|
||||
|
||||
class LoginRequest(BaseModel):
|
||||
email: EmailStr
|
||||
password: str
|
||||
|
||||
|
||||
class UserResponse(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
username: Optional[str]
|
||||
role: str
|
||||
is_active: bool
|
||||
|
||||
|
||||
@router.post("/register")
|
||||
async def register(request: RegisterRequest):
|
||||
"""
|
||||
用户注册
|
||||
|
||||
注册后状态为 pending,需要管理员激活
|
||||
"""
|
||||
try:
|
||||
supabase = get_supabase()
|
||||
|
||||
# 检查邮箱是否已存在
|
||||
existing = supabase.table("users").select("id").eq(
|
||||
"email", request.email
|
||||
).execute()
|
||||
|
||||
if existing.data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="该邮箱已注册"
|
||||
)
|
||||
|
||||
# 创建用户
|
||||
password_hash = get_password_hash(request.password)
|
||||
|
||||
result = supabase.table("users").insert({
|
||||
"email": request.email,
|
||||
"password_hash": password_hash,
|
||||
"username": request.username or request.email.split("@")[0],
|
||||
"role": "pending",
|
||||
"is_active": False
|
||||
}).execute()
|
||||
|
||||
logger.info(f"新用户注册: {request.email}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "注册成功,请等待管理员审核激活"
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"注册失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="注册失败,请稍后重试"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/login")
|
||||
async def login(request: LoginRequest, response: Response):
|
||||
"""
|
||||
用户登录
|
||||
|
||||
- 验证密码
|
||||
- 检查是否激活
|
||||
- 实现"后踢前"单设备登录
|
||||
"""
|
||||
try:
|
||||
supabase = get_supabase()
|
||||
|
||||
# 查找用户
|
||||
user_result = supabase.table("users").select("*").eq(
|
||||
"email", request.email
|
||||
).single().execute()
|
||||
|
||||
user = user_result.data
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="邮箱或密码错误"
|
||||
)
|
||||
|
||||
# 验证密码
|
||||
if not verify_password(request.password, user["password_hash"]):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="邮箱或密码错误"
|
||||
)
|
||||
|
||||
# 检查是否激活
|
||||
if not user["is_active"]:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="账号未激活,请等待管理员审核"
|
||||
)
|
||||
|
||||
# 检查授权是否过期
|
||||
if user.get("expires_at"):
|
||||
from datetime import datetime, timezone
|
||||
expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="授权已过期,请联系管理员续期"
|
||||
)
|
||||
|
||||
# 生成新的 session_token (后踢前)
|
||||
session_token = generate_session_token()
|
||||
|
||||
# 删除旧 session,插入新 session
|
||||
supabase.table("user_sessions").delete().eq(
|
||||
"user_id", user["id"]
|
||||
).execute()
|
||||
|
||||
supabase.table("user_sessions").insert({
|
||||
"user_id": user["id"],
|
||||
"session_token": session_token,
|
||||
"device_info": None # 可以从 request headers 获取
|
||||
}).execute()
|
||||
|
||||
# 生成 JWT Token
|
||||
token = create_access_token(user["id"], session_token)
|
||||
|
||||
# 设置 HttpOnly Cookie
|
||||
set_auth_cookie(response, token)
|
||||
|
||||
logger.info(f"用户登录: {request.email}")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "登录成功",
|
||||
"user": UserResponse(
|
||||
id=user["id"],
|
||||
email=user["email"],
|
||||
username=user.get("username"),
|
||||
role=user["role"],
|
||||
is_active=user["is_active"]
|
||||
)
|
||||
}
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"登录失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="登录失败,请稍后重试"
|
||||
)
|
||||
|
||||
|
||||
@router.post("/logout")
|
||||
async def logout(response: Response):
|
||||
"""用户登出"""
|
||||
clear_auth_cookie(response)
|
||||
return {"success": True, "message": "已登出"}
|
||||
|
||||
|
||||
@router.get("/me")
|
||||
async def get_me(request: Request):
|
||||
"""获取当前用户信息"""
|
||||
# 从 Cookie 获取用户
|
||||
token = request.cookies.get("access_token")
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未登录"
|
||||
)
|
||||
|
||||
token_data = decode_access_token(token)
|
||||
if not token_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token 无效"
|
||||
)
|
||||
|
||||
supabase = get_supabase()
|
||||
user_result = supabase.table("users").select("*").eq(
|
||||
"id", token_data.user_id
|
||||
).single().execute()
|
||||
|
||||
user = user_result.data
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
return UserResponse(
|
||||
id=user["id"],
|
||||
email=user["email"],
|
||||
username=user.get("username"),
|
||||
role=user["role"],
|
||||
is_active=user["is_active"]
|
||||
)
|
||||
@@ -1,53 +1,331 @@
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException
|
||||
from fastapi import APIRouter, UploadFile, File, HTTPException, Request, BackgroundTasks, Depends
|
||||
from app.core.config import settings
|
||||
import shutil
|
||||
import uuid
|
||||
from app.core.deps import get_current_user
|
||||
from app.services.storage import storage_service
|
||||
import re
|
||||
import time
|
||||
import traceback
|
||||
import os
|
||||
import aiofiles
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@router.post("/")
|
||||
async def upload_material(file: UploadFile = File(...)):
|
||||
if not file.filename.lower().endswith(('.mp4', '.mov', '.avi')):
|
||||
raise HTTPException(400, "Invalid format")
|
||||
|
||||
file_id = str(uuid.uuid4())
|
||||
ext = Path(file.filename).suffix
|
||||
save_path = settings.UPLOAD_DIR / "materials" / f"{file_id}{ext}"
|
||||
|
||||
# Save file
|
||||
with open(save_path, "wb") as buffer:
|
||||
shutil.copyfileobj(file.file, buffer)
|
||||
|
||||
# Calculate size
|
||||
size_mb = save_path.stat().st_size / (1024 * 1024)
|
||||
|
||||
return {
|
||||
"id": file_id,
|
||||
"name": file.filename,
|
||||
"path": f"uploads/materials/{file_id}{ext}",
|
||||
"size_mb": size_mb,
|
||||
"type": "video"
|
||||
}
|
||||
def sanitize_filename(filename: str) -> str:
|
||||
safe_name = re.sub(r'[<>:"/\\|?*]', '_', filename)
|
||||
if len(safe_name) > 100:
|
||||
ext = Path(safe_name).suffix
|
||||
safe_name = safe_name[:100 - len(ext)] + ext
|
||||
return safe_name
|
||||
|
||||
@router.get("/")
|
||||
async def list_materials():
|
||||
materials_dir = settings.UPLOAD_DIR / "materials"
|
||||
files = []
|
||||
if materials_dir.exists():
|
||||
for f in materials_dir.glob("*"):
|
||||
async def process_and_upload(temp_file_path: str, original_filename: str, content_type: str, user_id: str):
|
||||
"""Background task to strip multipart headers and upload to Supabase"""
|
||||
try:
|
||||
logger.info(f"Processing raw upload: {temp_file_path} for user {user_id}")
|
||||
|
||||
# 1. Analyze file to find actual video content (strip multipart boundaries)
|
||||
# This is a simplified manual parser for a SINGLE file upload.
|
||||
# Structure:
|
||||
# --boundary
|
||||
# Content-Disposition: form-data; name="file"; filename="..."
|
||||
# Content-Type: video/mp4
|
||||
# \r\n\r\n
|
||||
# [DATA]
|
||||
# \r\n--boundary--
|
||||
|
||||
# We need to read the first few KB to find the header end
|
||||
start_offset = 0
|
||||
end_offset = 0
|
||||
boundary = b""
|
||||
|
||||
file_size = os.path.getsize(temp_file_path)
|
||||
|
||||
with open(temp_file_path, 'rb') as f:
|
||||
# Read first 4KB to find header
|
||||
head = f.read(4096)
|
||||
|
||||
# Find boundary
|
||||
first_line_end = head.find(b'\r\n')
|
||||
if first_line_end == -1:
|
||||
raise Exception("Could not find boundary in multipart body")
|
||||
|
||||
boundary = head[:first_line_end] # e.g. --boundary123
|
||||
logger.info(f"Detected boundary: {boundary}")
|
||||
|
||||
# Find end of headers (\r\n\r\n)
|
||||
header_end = head.find(b'\r\n\r\n')
|
||||
if header_end == -1:
|
||||
raise Exception("Could not find end of multipart headers")
|
||||
|
||||
start_offset = header_end + 4
|
||||
logger.info(f"Video data starts at offset: {start_offset}")
|
||||
|
||||
# Find end boundary (read from end of file)
|
||||
# It should be \r\n + boundary + -- + \r\n
|
||||
# We seek to end-200 bytes
|
||||
f.seek(max(0, file_size - 200))
|
||||
tail = f.read()
|
||||
|
||||
# The closing boundary is usually --boundary--
|
||||
# We look for the last occurrence of the boundary
|
||||
last_boundary_pos = tail.rfind(boundary)
|
||||
if last_boundary_pos != -1:
|
||||
# The data ends before \r\n + boundary
|
||||
# The tail buffer relative position needs to be converted to absolute
|
||||
end_pos_in_tail = last_boundary_pos
|
||||
# We also need to check for the preceding \r\n
|
||||
if end_pos_in_tail >= 2 and tail[end_pos_in_tail-2:end_pos_in_tail] == b'\r\n':
|
||||
end_pos_in_tail -= 2
|
||||
|
||||
# Absolute end offset
|
||||
end_offset = (file_size - 200) + last_boundary_pos
|
||||
# Correction for CRLF before boundary
|
||||
# Actually, simply: read until (file_size - len(tail) + last_boundary_pos) - 2
|
||||
end_offset = (max(0, file_size - 200) + last_boundary_pos) - 2
|
||||
else:
|
||||
logger.warning("Could not find closing boundary, assuming EOF")
|
||||
end_offset = file_size
|
||||
|
||||
logger.info(f"Video data ends at offset: {end_offset}. Total video size: {end_offset - start_offset}")
|
||||
|
||||
# 2. Extract and Upload to Supabase
|
||||
# Since we have the file on disk, we can just pass the file object (seeked) to upload_file?
|
||||
# Or if upload_file expects bytes/path, checking storage.py...
|
||||
# It takes `file_data` (bytes) or file-like?
|
||||
# supabase-py's `upload` method handles parsing if we pass a file object.
|
||||
# But we need to pass ONLY the video slice.
|
||||
# So we create a generator or a sliced file object?
|
||||
# Simpler: Read the slice into memory if < 1GB? Or copy to new temp file?
|
||||
# Copying to new temp file is safer for memory.
|
||||
|
||||
video_path = temp_file_path + "_video.mp4"
|
||||
with open(temp_file_path, 'rb') as src, open(video_path, 'wb') as dst:
|
||||
src.seek(start_offset)
|
||||
# Copy in chunks
|
||||
bytes_to_copy = end_offset - start_offset
|
||||
copied = 0
|
||||
while copied < bytes_to_copy:
|
||||
chunk_size = min(1024*1024*10, bytes_to_copy - copied) # 10MB chunks
|
||||
chunk = src.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
dst.write(chunk)
|
||||
copied += len(chunk)
|
||||
|
||||
logger.info(f"Extracted video content to {video_path}")
|
||||
|
||||
# 3. Upload to Supabase with user isolation
|
||||
timestamp = int(time.time())
|
||||
safe_name = re.sub(r'[^a-zA-Z0-9._-]', '', original_filename)
|
||||
# 使用 user_id 作为目录前缀实现隔离
|
||||
storage_path = f"{user_id}/{timestamp}_{safe_name}"
|
||||
|
||||
# Use storage service (this calls Supabase which might do its own http request)
|
||||
# We read the cleaned video file
|
||||
with open(video_path, 'rb') as f:
|
||||
file_content = f.read() # Still reading into memory for simple upload call, but server has 32GB RAM so ok for 500MB
|
||||
await storage_service.upload_file(
|
||||
bucket=storage_service.BUCKET_MATERIALS,
|
||||
path=storage_path,
|
||||
file_data=file_content,
|
||||
content_type=content_type
|
||||
)
|
||||
|
||||
logger.info(f"Upload to Supabase complete: {storage_path}")
|
||||
|
||||
# Cleanup
|
||||
os.remove(temp_file_path)
|
||||
os.remove(video_path)
|
||||
|
||||
return storage_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Background upload processing failed: {e}\n{traceback.format_exc()}")
|
||||
raise
|
||||
|
||||
|
||||
@router.post("")
|
||||
async def upload_material(
|
||||
request: Request,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = current_user["id"]
|
||||
logger.info(f"ENTERED upload_material (Streaming Mode) for user {user_id}. Headers: {request.headers}")
|
||||
|
||||
filename = "unknown_video.mp4" # Fallback
|
||||
content_type = "video/mp4"
|
||||
|
||||
# Try to parse filename from header if possible (unreliable in raw stream)
|
||||
# We will rely on post-processing or client hint
|
||||
# Frontend sends standard multipart.
|
||||
|
||||
# Create temp file
|
||||
timestamp = int(time.time())
|
||||
temp_filename = f"upload_{timestamp}.raw"
|
||||
temp_path = os.path.join("/tmp", temp_filename) # Use /tmp on Linux
|
||||
# Ensure /tmp exists (it does) but verify paths
|
||||
if os.name == 'nt': # Local dev
|
||||
temp_path = f"d:/tmp/{temp_filename}"
|
||||
os.makedirs("d:/tmp", exist_ok=True)
|
||||
|
||||
try:
|
||||
total_size = 0
|
||||
last_log = 0
|
||||
|
||||
async with aiofiles.open(temp_path, 'wb') as f:
|
||||
async for chunk in request.stream():
|
||||
await f.write(chunk)
|
||||
total_size += len(chunk)
|
||||
|
||||
# Log progress every 20MB
|
||||
if total_size - last_log > 20 * 1024 * 1024:
|
||||
logger.info(f"Receiving stream... Processed {total_size / (1024*1024):.2f} MB")
|
||||
last_log = total_size
|
||||
|
||||
logger.info(f"Stream reception complete. Total size: {total_size} bytes. Saved to {temp_path}")
|
||||
|
||||
if total_size == 0:
|
||||
raise HTTPException(400, "Received empty body")
|
||||
|
||||
# Attempt to extract filename from the saved file's first bytes?
|
||||
# Or just accept it as "uploaded_video.mp4" for now to prove it works.
|
||||
# We can try to regex the header in the file content we just wrote.
|
||||
# Implemented in background task to return success immediately.
|
||||
|
||||
# Wait, if we return immediately, the user's UI might not show the file yet?
|
||||
# The prompt says "Wait for upload".
|
||||
# But to avoid User Waiting Timeout, maybe returning early is better?
|
||||
# NO, user expects the file to be in the list.
|
||||
# So we Must await the processing.
|
||||
# But "Processing" (Strip + Upload to Supabase) takes time.
|
||||
# Receiving took time.
|
||||
# If we await Supabase upload, does it timeout?
|
||||
# Supabase upload is outgoing. Usually faster/stable.
|
||||
|
||||
# Let's await the processing to ensure "List Materials" shows it.
|
||||
# We need to extract the filename for the list.
|
||||
|
||||
# Quick extract filename from first 4kb
|
||||
with open(temp_path, 'rb') as f:
|
||||
head = f.read(4096).decode('utf-8', errors='ignore')
|
||||
match = re.search(r'filename="([^"]+)"', head)
|
||||
if match:
|
||||
filename = match.group(1)
|
||||
logger.info(f"Extracted filename from body: {filename}")
|
||||
|
||||
# Run processing sync (in await)
|
||||
storage_path = await process_and_upload(temp_path, filename, content_type, user_id)
|
||||
|
||||
# Get signed URL (it exists now)
|
||||
signed_url = await storage_service.get_signed_url(
|
||||
bucket=storage_service.BUCKET_MATERIALS,
|
||||
path=storage_path
|
||||
)
|
||||
|
||||
size_mb = total_size / (1024 * 1024) # Approximate (includes headers)
|
||||
|
||||
# 从 storage_path 提取显示名
|
||||
display_name = storage_path.split('/')[-1] # 去掉 user_id 前缀
|
||||
if '_' in display_name:
|
||||
parts = display_name.split('_', 1)
|
||||
if parts[0].isdigit():
|
||||
display_name = parts[1]
|
||||
|
||||
return {
|
||||
"id": storage_path,
|
||||
"name": display_name,
|
||||
"path": signed_url,
|
||||
"size_mb": size_mb,
|
||||
"type": "video"
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Streaming upload failed: {str(e)}"
|
||||
detail_msg = f"Exception: {repr(e)}\nArgs: {e.args}\n{traceback.format_exc()}"
|
||||
logger.error(error_msg + "\n" + detail_msg)
|
||||
|
||||
# Write to debug file
|
||||
try:
|
||||
with open("debug_upload.log", "a") as logf:
|
||||
logf.write(f"\n--- Error at {time.ctime()} ---\n")
|
||||
logf.write(detail_msg)
|
||||
logf.write("\n-----------------------------\n")
|
||||
except:
|
||||
pass
|
||||
|
||||
if os.path.exists(temp_path):
|
||||
try:
|
||||
stat = f.stat()
|
||||
files.append({
|
||||
"id": f.stem,
|
||||
"name": f.name,
|
||||
"path": f"uploads/materials/{f.name}",
|
||||
"size_mb": stat.st_size / (1024 * 1024),
|
||||
"type": "video",
|
||||
"created_at": stat.st_ctime
|
||||
})
|
||||
except Exception:
|
||||
os.remove(temp_path)
|
||||
except:
|
||||
pass
|
||||
raise HTTPException(500, f"Upload failed. Check server logs. Error: {str(e)}")
|
||||
|
||||
|
||||
@router.get("")
|
||||
async def list_materials(current_user: dict = Depends(get_current_user)):
|
||||
user_id = current_user["id"]
|
||||
try:
|
||||
# 只列出当前用户目录下的文件
|
||||
files_obj = await storage_service.list_files(
|
||||
bucket=storage_service.BUCKET_MATERIALS,
|
||||
path=user_id
|
||||
)
|
||||
materials = []
|
||||
for f in files_obj:
|
||||
name = f.get('name')
|
||||
if not name or name == '.emptyFolderPlaceholder':
|
||||
continue
|
||||
# Sort by creation time desc
|
||||
files.sort(key=lambda x: x.get("created_at", 0), reverse=True)
|
||||
return {"materials": files}
|
||||
display_name = name
|
||||
if '_' in name:
|
||||
parts = name.split('_', 1)
|
||||
if parts[0].isdigit():
|
||||
display_name = parts[1]
|
||||
# 完整路径包含 user_id
|
||||
full_path = f"{user_id}/{name}"
|
||||
signed_url = await storage_service.get_signed_url(
|
||||
bucket=storage_service.BUCKET_MATERIALS,
|
||||
path=full_path
|
||||
)
|
||||
metadata = f.get('metadata', {})
|
||||
size = metadata.get('size', 0)
|
||||
# created_at 在顶层,是 ISO 字符串
|
||||
created_at_str = f.get('created_at', '')
|
||||
created_at = 0
|
||||
if created_at_str:
|
||||
from datetime import datetime
|
||||
try:
|
||||
dt = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||
created_at = int(dt.timestamp())
|
||||
except:
|
||||
pass
|
||||
materials.append({
|
||||
"id": full_path, # ID 使用完整路径
|
||||
"name": display_name,
|
||||
"path": signed_url,
|
||||
"size_mb": size / (1024 * 1024),
|
||||
"type": "video",
|
||||
"created_at": created_at
|
||||
})
|
||||
materials.sort(key=lambda x: x['id'], reverse=True)
|
||||
return {"materials": materials}
|
||||
except Exception as e:
|
||||
logger.error(f"List materials failed: {e}")
|
||||
return {"materials": []}
|
||||
|
||||
|
||||
@router.delete("/{material_id:path}")
|
||||
async def delete_material(material_id: str, current_user: dict = Depends(get_current_user)):
|
||||
user_id = current_user["id"]
|
||||
# 验证 material_id 属于当前用户
|
||||
if not material_id.startswith(f"{user_id}/"):
|
||||
raise HTTPException(403, "无权删除此素材")
|
||||
try:
|
||||
await storage_service.delete_file(
|
||||
bucket=storage_service.BUCKET_MATERIALS,
|
||||
path=material_id
|
||||
)
|
||||
return {"success": True, "message": "素材已删除"}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"删除失败: {str(e)}")
|
||||
|
||||
@@ -1,17 +1,19 @@
|
||||
"""
|
||||
发布管理 API
|
||||
发布管理 API (支持用户认证)
|
||||
"""
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
from loguru import logger
|
||||
from app.services.publish_service import PublishService
|
||||
from app.core.deps import get_current_user_optional
|
||||
|
||||
router = APIRouter()
|
||||
publish_service = PublishService()
|
||||
|
||||
class PublishRequest(BaseModel):
|
||||
"""Video publish request model"""
|
||||
video_path: str
|
||||
platform: str
|
||||
title: str
|
||||
@@ -20,13 +22,43 @@ class PublishRequest(BaseModel):
|
||||
publish_time: Optional[datetime] = None
|
||||
|
||||
class PublishResponse(BaseModel):
|
||||
"""Video publish response model"""
|
||||
success: bool
|
||||
message: str
|
||||
platform: str
|
||||
url: Optional[str] = None
|
||||
|
||||
@router.post("/", response_model=PublishResponse)
|
||||
async def publish_video(request: PublishRequest, background_tasks: BackgroundTasks):
|
||||
# Supported platforms for validation
|
||||
SUPPORTED_PLATFORMS = {"bilibili", "douyin", "xiaohongshu"}
|
||||
|
||||
|
||||
def _get_user_id(request: Request) -> Optional[str]:
|
||||
"""从请求中获取用户 ID (兼容未登录场景)"""
|
||||
try:
|
||||
from app.core.security import decode_access_token
|
||||
token = request.cookies.get("access_token")
|
||||
if token:
|
||||
token_data = decode_access_token(token)
|
||||
if token_data:
|
||||
return token_data.user_id
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
@router.post("", response_model=PublishResponse)
|
||||
async def publish_video(request: PublishRequest, req: Request, background_tasks: BackgroundTasks):
|
||||
"""发布视频到指定平台"""
|
||||
# Validate platform
|
||||
if request.platform not in SUPPORTED_PLATFORMS:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"不支持的平台: {request.platform}。支持的平台: {', '.join(SUPPORTED_PLATFORMS)}"
|
||||
)
|
||||
|
||||
# 获取用户 ID (可选)
|
||||
user_id = _get_user_id(req)
|
||||
|
||||
try:
|
||||
result = await publish_service.publish(
|
||||
video_path=request.video_path,
|
||||
@@ -34,7 +66,8 @@ async def publish_video(request: PublishRequest, background_tasks: BackgroundTas
|
||||
title=request.title,
|
||||
tags=request.tags,
|
||||
description=request.description,
|
||||
publish_time=request.publish_time
|
||||
publish_time=request.publish_time,
|
||||
user_id=user_id
|
||||
)
|
||||
return PublishResponse(
|
||||
success=result.get("success", False),
|
||||
@@ -48,33 +81,48 @@ async def publish_video(request: PublishRequest, background_tasks: BackgroundTas
|
||||
|
||||
@router.get("/platforms")
|
||||
async def list_platforms():
|
||||
return {"platforms": [{"id": pid, **pinfo} for pid, pinfo in publish_service.PLATFORMS.items()]}
|
||||
return {"platforms": [{**pinfo, "id": pid} for pid, pinfo in publish_service.PLATFORMS.items()]}
|
||||
|
||||
@router.get("/accounts")
|
||||
async def list_accounts():
|
||||
return {"accounts": publish_service.get_accounts()}
|
||||
async def list_accounts(req: Request):
|
||||
user_id = _get_user_id(req)
|
||||
return {"accounts": publish_service.get_accounts(user_id)}
|
||||
|
||||
@router.post("/login/{platform}")
|
||||
async def login_platform(platform: str):
|
||||
result = await publish_service.login(platform)
|
||||
async def login_platform(platform: str, req: Request):
|
||||
"""触发平台QR码登录"""
|
||||
if platform not in SUPPORTED_PLATFORMS:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
|
||||
|
||||
user_id = _get_user_id(req)
|
||||
result = await publish_service.login(platform, user_id)
|
||||
|
||||
if result.get("success"):
|
||||
return result
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail=result.get("message"))
|
||||
|
||||
@router.get("/login/status/{platform}")
|
||||
async def get_login_status(platform: str):
|
||||
"""检查登录状态"""
|
||||
# 这里简化处理,实际应该维护一个登录会话字典
|
||||
cookie_file = publish_service.cookies_dir / f"{platform}_cookies.json"
|
||||
@router.post("/logout/{platform}")
|
||||
async def logout_platform(platform: str, req: Request):
|
||||
"""注销平台登录"""
|
||||
if platform not in SUPPORTED_PLATFORMS:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
|
||||
|
||||
if cookie_file.exists():
|
||||
return {"success": True, "message": "已登录"}
|
||||
else:
|
||||
return {"success": False, "message": "未登录"}
|
||||
user_id = _get_user_id(req)
|
||||
result = publish_service.logout(platform, user_id)
|
||||
return result
|
||||
|
||||
@router.get("/login/status/{platform}")
|
||||
async def get_login_status(platform: str, req: Request):
|
||||
"""检查登录状态 (优先检查活跃的扫码会话)"""
|
||||
if platform not in SUPPORTED_PLATFORMS:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
|
||||
|
||||
user_id = _get_user_id(req)
|
||||
return publish_service.get_login_session_status(platform, user_id)
|
||||
|
||||
@router.post("/cookies/save/{platform}")
|
||||
async def save_platform_cookie(platform: str, cookie_data: dict):
|
||||
async def save_platform_cookie(platform: str, cookie_data: dict, req: Request):
|
||||
"""
|
||||
保存从客户端浏览器提取的Cookie
|
||||
|
||||
@@ -82,8 +130,15 @@ async def save_platform_cookie(platform: str, cookie_data: dict):
|
||||
platform: 平台ID
|
||||
cookie_data: {"cookie_string": "document.cookie的内容"}
|
||||
"""
|
||||
if platform not in SUPPORTED_PLATFORMS:
|
||||
raise HTTPException(status_code=400, detail=f"不支持的平台: {platform}")
|
||||
|
||||
cookie_string = cookie_data.get("cookie_string", "")
|
||||
result = await publish_service.save_cookie_string(platform, cookie_string)
|
||||
if not cookie_string:
|
||||
raise HTTPException(status_code=400, detail="cookie_string 不能为空")
|
||||
|
||||
user_id = _get_user_id(req)
|
||||
result = await publish_service.save_cookie_string(platform, cookie_string, user_id)
|
||||
|
||||
if result.get("success"):
|
||||
return result
|
||||
|
||||
@@ -1,14 +1,19 @@
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks
|
||||
from fastapi import APIRouter, HTTPException, BackgroundTasks, Depends, Request
|
||||
from pydantic import BaseModel
|
||||
from typing import Optional
|
||||
from pathlib import Path
|
||||
from loguru import logger
|
||||
import uuid
|
||||
import traceback
|
||||
import time
|
||||
import httpx
|
||||
import os
|
||||
from app.services.tts_service import TTSService
|
||||
from app.services.video_service import VideoService
|
||||
from app.services.lipsync_service import LipSyncService
|
||||
from app.services.storage import storage_service
|
||||
from app.core.config import settings
|
||||
from app.core.deps import get_current_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
@@ -47,42 +52,73 @@ async def _check_lipsync_ready(force: bool = False) -> bool:
|
||||
print(f"[LipSync] Health check: ready={_lipsync_ready}")
|
||||
return _lipsync_ready
|
||||
|
||||
async def _process_video_generation(task_id: str, req: GenerateRequest):
|
||||
async def _download_material(path_or_url: str, temp_path: Path):
|
||||
"""下载素材到临时文件 (流式下载,节省内存)"""
|
||||
if path_or_url.startswith("http"):
|
||||
# Download from URL
|
||||
timeout = httpx.Timeout(None) # Disable timeout for large files
|
||||
async with httpx.AsyncClient(timeout=timeout) as client:
|
||||
async with client.stream("GET", path_or_url) as resp:
|
||||
resp.raise_for_status()
|
||||
with open(temp_path, "wb") as f:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
f.write(chunk)
|
||||
else:
|
||||
# Local file (legacy or absolute path)
|
||||
src = Path(path_or_url)
|
||||
if not src.is_absolute():
|
||||
src = settings.BASE_DIR.parent / path_or_url
|
||||
|
||||
if src.exists():
|
||||
import shutil
|
||||
shutil.copy(src, temp_path)
|
||||
else:
|
||||
raise FileNotFoundError(f"Material not found: {path_or_url}")
|
||||
|
||||
async def _process_video_generation(task_id: str, req: GenerateRequest, user_id: str):
|
||||
temp_files = [] # Track files to clean up
|
||||
try:
|
||||
start_time = time.time()
|
||||
|
||||
# Resolve path if it's relative
|
||||
input_material_path = Path(req.material_path)
|
||||
if not input_material_path.is_absolute():
|
||||
input_material_path = settings.BASE_DIR.parent / req.material_path
|
||||
|
||||
|
||||
tasks[task_id]["status"] = "processing"
|
||||
tasks[task_id]["progress"] = 5
|
||||
tasks[task_id]["message"] = "正在初始化..."
|
||||
|
||||
tasks[task_id]["message"] = "正在下载素材..."
|
||||
|
||||
# Prepare temp dir
|
||||
temp_dir = settings.UPLOAD_DIR / "temp"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 0. Download Material
|
||||
input_material_path = temp_dir / f"{task_id}_input.mp4"
|
||||
temp_files.append(input_material_path)
|
||||
|
||||
await _download_material(req.material_path, input_material_path)
|
||||
|
||||
# 1. TTS - 进度 5% -> 25%
|
||||
tasks[task_id]["message"] = "正在生成语音 (TTS)..."
|
||||
tasks[task_id]["progress"] = 10
|
||||
|
||||
|
||||
tts = TTSService()
|
||||
audio_path = settings.OUTPUT_DIR / f"{task_id}_audio.mp3"
|
||||
audio_path = temp_dir / f"{task_id}_audio.mp3"
|
||||
temp_files.append(audio_path)
|
||||
await tts.generate_audio(req.text, req.voice, str(audio_path))
|
||||
|
||||
|
||||
tts_time = time.time() - start_time
|
||||
print(f"[Pipeline] TTS completed in {tts_time:.1f}s")
|
||||
tasks[task_id]["progress"] = 25
|
||||
|
||||
|
||||
# 2. LipSync - 进度 25% -> 85%
|
||||
tasks[task_id]["message"] = "正在合成唇形 (LatentSync)..."
|
||||
tasks[task_id]["progress"] = 30
|
||||
|
||||
|
||||
lipsync = _get_lipsync_service()
|
||||
lipsync_video_path = settings.OUTPUT_DIR / f"{task_id}_lipsync.mp4"
|
||||
|
||||
lipsync_video_path = temp_dir / f"{task_id}_lipsync.mp4"
|
||||
temp_files.append(lipsync_video_path)
|
||||
|
||||
# 使用缓存的健康检查结果
|
||||
lipsync_start = time.time()
|
||||
is_ready = await _check_lipsync_ready()
|
||||
|
||||
|
||||
if is_ready:
|
||||
print(f"[LipSync] Starting LatentSync inference...")
|
||||
tasks[task_id]["progress"] = 35
|
||||
@@ -98,34 +134,72 @@ async def _process_video_generation(task_id: str, req: GenerateRequest):
|
||||
lipsync_time = time.time() - lipsync_start
|
||||
print(f"[Pipeline] LipSync completed in {lipsync_time:.1f}s")
|
||||
tasks[task_id]["progress"] = 85
|
||||
|
||||
|
||||
# 3. Composition - 进度 85% -> 100%
|
||||
tasks[task_id]["message"] = "正在合成最终视频..."
|
||||
tasks[task_id]["progress"] = 90
|
||||
|
||||
|
||||
video = VideoService()
|
||||
final_output = settings.OUTPUT_DIR / f"{task_id}_output.mp4"
|
||||
await video.compose(str(lipsync_video_path), str(audio_path), str(final_output))
|
||||
|
||||
final_output_local_path = temp_dir / f"{task_id}_output.mp4"
|
||||
temp_files.append(final_output_local_path)
|
||||
|
||||
await video.compose(str(lipsync_video_path), str(audio_path), str(final_output_local_path))
|
||||
|
||||
total_time = time.time() - start_time
|
||||
|
||||
# 4. Upload to Supabase with user isolation
|
||||
tasks[task_id]["message"] = "正在上传结果..."
|
||||
tasks[task_id]["progress"] = 95
|
||||
|
||||
# 使用 user_id 作为目录前缀实现隔离
|
||||
storage_path = f"{user_id}/{task_id}_output.mp4"
|
||||
with open(final_output_local_path, "rb") as f:
|
||||
file_data = f.read()
|
||||
await storage_service.upload_file(
|
||||
bucket=storage_service.BUCKET_OUTPUTS,
|
||||
path=storage_path,
|
||||
file_data=file_data,
|
||||
content_type="video/mp4"
|
||||
)
|
||||
|
||||
# Get Signed URL
|
||||
signed_url = await storage_service.get_signed_url(
|
||||
bucket=storage_service.BUCKET_OUTPUTS,
|
||||
path=storage_path
|
||||
)
|
||||
|
||||
print(f"[Pipeline] Total generation time: {total_time:.1f}s")
|
||||
|
||||
|
||||
tasks[task_id]["status"] = "completed"
|
||||
tasks[task_id]["progress"] = 100
|
||||
tasks[task_id]["message"] = f"生成完成!耗时 {total_time:.0f} 秒"
|
||||
tasks[task_id]["output"] = str(final_output)
|
||||
tasks[task_id]["download_url"] = f"/outputs/{final_output.name}"
|
||||
tasks[task_id]["output"] = storage_path
|
||||
tasks[task_id]["download_url"] = signed_url
|
||||
|
||||
except Exception as e:
|
||||
tasks[task_id]["status"] = "failed"
|
||||
tasks[task_id]["message"] = f"错误: {str(e)}"
|
||||
tasks[task_id]["error"] = traceback.format_exc()
|
||||
logger.error(f"Generate video failed: {e}")
|
||||
finally:
|
||||
# Cleanup temp files
|
||||
for f in temp_files:
|
||||
try:
|
||||
if f.exists():
|
||||
f.unlink()
|
||||
except Exception as e:
|
||||
print(f"Error cleaning up {f}: {e}")
|
||||
|
||||
@router.post("/generate")
|
||||
async def generate_video(req: GenerateRequest, background_tasks: BackgroundTasks):
|
||||
async def generate_video(
|
||||
req: GenerateRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
current_user: dict = Depends(get_current_user)
|
||||
):
|
||||
user_id = current_user["id"]
|
||||
task_id = str(uuid.uuid4())
|
||||
tasks[task_id] = {"status": "pending", "task_id": task_id, "progress": 0}
|
||||
background_tasks.add_task(_process_video_generation, task_id, req)
|
||||
tasks[task_id] = {"status": "pending", "task_id": task_id, "progress": 0, "user_id": user_id}
|
||||
background_tasks.add_task(_process_video_generation, task_id, req, user_id)
|
||||
return {"task_id": task_id}
|
||||
|
||||
@router.get("/tasks/{task_id}")
|
||||
@@ -141,3 +215,85 @@ async def lipsync_health():
|
||||
"""获取 LipSync 服务健康状态"""
|
||||
lipsync = _get_lipsync_service()
|
||||
return await lipsync.check_health()
|
||||
|
||||
|
||||
@router.get("/generated")
|
||||
async def list_generated_videos(current_user: dict = Depends(get_current_user)):
|
||||
"""从 Storage 读取当前用户生成的视频列表"""
|
||||
user_id = current_user["id"]
|
||||
try:
|
||||
# 只列出当前用户目录下的文件
|
||||
files_obj = await storage_service.list_files(
|
||||
bucket=storage_service.BUCKET_OUTPUTS,
|
||||
path=user_id
|
||||
)
|
||||
|
||||
videos = []
|
||||
for f in files_obj:
|
||||
name = f.get('name')
|
||||
if not name or name == '.emptyFolderPlaceholder':
|
||||
continue
|
||||
|
||||
# 过滤非 output.mp4 文件
|
||||
if not name.endswith("_output.mp4"):
|
||||
continue
|
||||
|
||||
# 获取 ID (即文件名去除后缀)
|
||||
video_id = Path(name).stem
|
||||
|
||||
# 完整路径包含 user_id
|
||||
full_path = f"{user_id}/{name}"
|
||||
|
||||
# 获取签名链接
|
||||
signed_url = await storage_service.get_signed_url(
|
||||
bucket=storage_service.BUCKET_OUTPUTS,
|
||||
path=full_path
|
||||
)
|
||||
|
||||
metadata = f.get('metadata', {})
|
||||
size = metadata.get('size', 0)
|
||||
# created_at 在顶层,是 ISO 字符串,转换为 Unix 时间戳
|
||||
created_at_str = f.get('created_at', '')
|
||||
created_at = 0
|
||||
if created_at_str:
|
||||
from datetime import datetime
|
||||
try:
|
||||
dt = datetime.fromisoformat(created_at_str.replace('Z', '+00:00'))
|
||||
created_at = int(dt.timestamp())
|
||||
except:
|
||||
pass
|
||||
|
||||
videos.append({
|
||||
"id": video_id,
|
||||
"name": name,
|
||||
"path": signed_url, # Direct playable URL
|
||||
"size_mb": size / (1024 * 1024),
|
||||
"created_at": created_at
|
||||
})
|
||||
|
||||
# Sort by created_at desc (newest first)
|
||||
# Supabase API usually returns ISO string, simpler string sort works for ISO
|
||||
videos.sort(key=lambda x: x.get("created_at", ""), reverse=True)
|
||||
return {"videos": videos}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"List generated videos failed: {e}")
|
||||
return {"videos": []}
|
||||
|
||||
|
||||
@router.delete("/generated/{video_id}")
|
||||
async def delete_generated_video(video_id: str, current_user: dict = Depends(get_current_user)):
|
||||
"""删除生成的视频"""
|
||||
user_id = current_user["id"]
|
||||
try:
|
||||
# video_id 通常是 uuid_output,完整路径需要加上 user_id
|
||||
storage_path = f"{user_id}/{video_id}.mp4"
|
||||
|
||||
await storage_service.delete_file(
|
||||
bucket=storage_service.BUCKET_OUTPUTS,
|
||||
path=storage_path
|
||||
)
|
||||
return {"success": True, "message": "视频已删除"}
|
||||
except Exception as e:
|
||||
raise HTTPException(500, f"删除失败: {str(e)}")
|
||||
|
||||
|
||||
@@ -26,6 +26,20 @@ class Settings(BaseSettings):
|
||||
LATENTSYNC_SEED: int = 1247 # 随机种子 (-1 则随机)
|
||||
LATENTSYNC_USE_SERVER: bool = False # 使用常驻服务 (Persistent Server) 加速
|
||||
|
||||
# Supabase 配置
|
||||
SUPABASE_URL: str = ""
|
||||
SUPABASE_PUBLIC_URL: str = "" # 公网访问地址,用于生成前端可访问的 URL
|
||||
SUPABASE_KEY: str = ""
|
||||
|
||||
# JWT 配置
|
||||
JWT_SECRET_KEY: str = "your-secret-key-change-in-production"
|
||||
JWT_ALGORITHM: str = "HS256"
|
||||
JWT_EXPIRE_HOURS: int = 24
|
||||
|
||||
# 管理员配置
|
||||
ADMIN_EMAIL: str = ""
|
||||
ADMIN_PASSWORD: str = ""
|
||||
|
||||
@property
|
||||
def LATENTSYNC_DIR(self) -> Path:
|
||||
"""LatentSync 目录路径 (动态计算)"""
|
||||
|
||||
141
backend/app/core/deps.py
Normal file
141
backend/app/core/deps.py
Normal file
@@ -0,0 +1,141 @@
|
||||
"""
|
||||
依赖注入模块:认证和用户获取
|
||||
"""
|
||||
from typing import Optional
|
||||
from fastapi import Request, HTTPException, Depends, status
|
||||
from app.core.security import decode_access_token, TokenData
|
||||
from app.core.supabase import get_supabase
|
||||
from loguru import logger
|
||||
|
||||
|
||||
async def get_token_from_cookie(request: Request) -> Optional[str]:
|
||||
"""从 Cookie 中获取 Token"""
|
||||
return request.cookies.get("access_token")
|
||||
|
||||
|
||||
async def get_current_user_optional(
|
||||
request: Request
|
||||
) -> Optional[dict]:
|
||||
"""
|
||||
获取当前用户 (可选,未登录返回 None)
|
||||
"""
|
||||
token = await get_token_from_cookie(request)
|
||||
if not token:
|
||||
return None
|
||||
|
||||
token_data = decode_access_token(token)
|
||||
if not token_data:
|
||||
return None
|
||||
|
||||
# 验证 session_token 是否有效 (单设备登录检查)
|
||||
try:
|
||||
supabase = get_supabase()
|
||||
result = supabase.table("user_sessions").select("*").eq(
|
||||
"user_id", token_data.user_id
|
||||
).eq(
|
||||
"session_token", token_data.session_token
|
||||
).execute()
|
||||
|
||||
if not result.data:
|
||||
logger.warning(f"Session token 无效: user_id={token_data.user_id}")
|
||||
return None
|
||||
|
||||
# 获取用户信息
|
||||
user_result = supabase.table("users").select("*").eq(
|
||||
"id", token_data.user_id
|
||||
).single().execute()
|
||||
|
||||
return user_result.data
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户信息失败: {e}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_current_user(
|
||||
request: Request
|
||||
) -> dict:
|
||||
"""
|
||||
获取当前用户 (必须登录)
|
||||
|
||||
Raises:
|
||||
HTTPException 401: 未登录
|
||||
HTTPException 403: 会话失效或授权过期
|
||||
"""
|
||||
token = await get_token_from_cookie(request)
|
||||
if not token:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="未登录,请先登录"
|
||||
)
|
||||
|
||||
token_data = decode_access_token(token)
|
||||
if not token_data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="Token 无效或已过期"
|
||||
)
|
||||
|
||||
try:
|
||||
supabase = get_supabase()
|
||||
|
||||
# 验证 session_token (单设备登录)
|
||||
session_result = supabase.table("user_sessions").select("*").eq(
|
||||
"user_id", token_data.user_id
|
||||
).eq(
|
||||
"session_token", token_data.session_token
|
||||
).execute()
|
||||
|
||||
if not session_result.data:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="会话已失效,请重新登录(可能已在其他设备登录)"
|
||||
)
|
||||
|
||||
# 获取用户信息
|
||||
user_result = supabase.table("users").select("*").eq(
|
||||
"id", token_data.user_id
|
||||
).single().execute()
|
||||
|
||||
user = user_result.data
|
||||
if not user:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail="用户不存在"
|
||||
)
|
||||
|
||||
# 检查授权是否过期
|
||||
if user.get("expires_at"):
|
||||
from datetime import datetime, timezone
|
||||
expires_at = datetime.fromisoformat(user["expires_at"].replace("Z", "+00:00"))
|
||||
if datetime.now(timezone.utc) > expires_at:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="授权已过期,请联系管理员续期"
|
||||
)
|
||||
|
||||
return user
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"获取用户信息失败: {e}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail="服务器错误"
|
||||
)
|
||||
|
||||
|
||||
async def get_current_admin(
|
||||
current_user: dict = Depends(get_current_user)
|
||||
) -> dict:
|
||||
"""
|
||||
获取当前管理员用户
|
||||
|
||||
Raises:
|
||||
HTTPException 403: 非管理员
|
||||
"""
|
||||
if current_user.get("role") != "admin":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="需要管理员权限"
|
||||
)
|
||||
return current_user
|
||||
98
backend/app/core/paths.py
Normal file
98
backend/app/core/paths.py
Normal file
@@ -0,0 +1,98 @@
|
||||
"""
|
||||
路径规范化模块:按用户隔离 Cookie 存储
|
||||
"""
|
||||
from pathlib import Path
|
||||
import re
|
||||
from typing import Set
|
||||
|
||||
# 基础目录
|
||||
BASE_DIR = Path(__file__).parent.parent.parent
|
||||
USER_DATA_DIR = BASE_DIR / "user_data"
|
||||
|
||||
# 有效的平台列表
|
||||
VALID_PLATFORMS: Set[str] = {"bilibili", "douyin", "xiaohongshu", "weixin", "kuaishou"}
|
||||
|
||||
# UUID 格式正则
|
||||
UUID_PATTERN = re.compile(r'^[a-f0-9]{8}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{4}-[a-f0-9]{12}$', re.IGNORECASE)
|
||||
|
||||
|
||||
def validate_user_id(user_id: str) -> bool:
|
||||
"""验证 user_id 格式 (防止路径遍历攻击)"""
|
||||
return bool(UUID_PATTERN.match(user_id))
|
||||
|
||||
|
||||
def validate_platform(platform: str) -> bool:
|
||||
"""验证平台名称"""
|
||||
return platform in VALID_PLATFORMS
|
||||
|
||||
|
||||
def get_user_data_dir(user_id: str) -> Path:
|
||||
"""
|
||||
获取用户数据根目录
|
||||
|
||||
Args:
|
||||
user_id: 用户 UUID
|
||||
|
||||
Returns:
|
||||
用户数据目录路径
|
||||
|
||||
Raises:
|
||||
ValueError: user_id 格式无效
|
||||
"""
|
||||
if not validate_user_id(user_id):
|
||||
raise ValueError(f"Invalid user_id format: {user_id}")
|
||||
|
||||
user_dir = USER_DATA_DIR / user_id
|
||||
user_dir.mkdir(parents=True, exist_ok=True)
|
||||
return user_dir
|
||||
|
||||
|
||||
def get_user_cookie_dir(user_id: str) -> Path:
|
||||
"""
|
||||
获取用户 Cookie 目录
|
||||
|
||||
Args:
|
||||
user_id: 用户 UUID
|
||||
|
||||
Returns:
|
||||
Cookie 目录路径
|
||||
"""
|
||||
cookie_dir = get_user_data_dir(user_id) / "cookies"
|
||||
cookie_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cookie_dir
|
||||
|
||||
|
||||
def get_platform_cookie_path(user_id: str, platform: str) -> Path:
|
||||
"""
|
||||
获取平台 Cookie 文件路径
|
||||
|
||||
Args:
|
||||
user_id: 用户 UUID
|
||||
platform: 平台名称 (bilibili/douyin/xiaohongshu)
|
||||
|
||||
Returns:
|
||||
Cookie 文件路径
|
||||
|
||||
Raises:
|
||||
ValueError: 平台名称无效
|
||||
"""
|
||||
if not validate_platform(platform):
|
||||
raise ValueError(f"Invalid platform: {platform}. Valid: {VALID_PLATFORMS}")
|
||||
|
||||
return get_user_cookie_dir(user_id) / f"{platform}_cookies.json"
|
||||
|
||||
|
||||
# === 兼容旧代码的路径 (无用户隔离) ===
|
||||
|
||||
def get_legacy_cookie_dir() -> Path:
|
||||
"""获取旧版 Cookie 目录 (无用户隔离)"""
|
||||
cookie_dir = BASE_DIR / "app" / "cookies"
|
||||
cookie_dir.mkdir(parents=True, exist_ok=True)
|
||||
return cookie_dir
|
||||
|
||||
|
||||
def get_legacy_cookie_path(platform: str) -> Path:
|
||||
"""获取旧版 Cookie 路径 (无用户隔离)"""
|
||||
if not validate_platform(platform):
|
||||
raise ValueError(f"Invalid platform: {platform}")
|
||||
return get_legacy_cookie_dir() / f"{platform}_cookies.json"
|
||||
112
backend/app/core/security.py
Normal file
112
backend/app/core/security.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
安全工具模块:JWT Token 和密码处理
|
||||
"""
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Optional, Any
|
||||
from jose import jwt, JWTError
|
||||
from passlib.context import CryptContext
|
||||
from pydantic import BaseModel
|
||||
from fastapi import Response
|
||||
from app.core.config import settings
|
||||
import uuid
|
||||
|
||||
# 密码加密上下文
|
||||
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
|
||||
|
||||
|
||||
class TokenData(BaseModel):
|
||||
"""JWT Token 数据结构"""
|
||||
user_id: str
|
||||
session_token: str
|
||||
exp: datetime
|
||||
|
||||
|
||||
def verify_password(plain_password: str, hashed_password: str) -> bool:
|
||||
"""验证密码"""
|
||||
return pwd_context.verify(plain_password, hashed_password)
|
||||
|
||||
|
||||
def get_password_hash(password: str) -> str:
|
||||
"""生成密码哈希"""
|
||||
return pwd_context.hash(password)
|
||||
|
||||
|
||||
def create_access_token(user_id: str, session_token: str) -> str:
|
||||
"""
|
||||
创建 JWT Access Token
|
||||
|
||||
Args:
|
||||
user_id: 用户 ID
|
||||
session_token: 会话 Token (用于单设备登录验证)
|
||||
"""
|
||||
expire = datetime.now(timezone.utc) + timedelta(hours=settings.JWT_EXPIRE_HOURS)
|
||||
|
||||
to_encode = {
|
||||
"sub": user_id,
|
||||
"session_token": session_token,
|
||||
"exp": expire
|
||||
}
|
||||
|
||||
return jwt.encode(
|
||||
to_encode,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithm=settings.JWT_ALGORITHM
|
||||
)
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> Optional[TokenData]:
|
||||
"""
|
||||
解码并验证 JWT Token
|
||||
|
||||
Returns:
|
||||
TokenData 或 None (如果验证失败)
|
||||
"""
|
||||
try:
|
||||
payload = jwt.decode(
|
||||
token,
|
||||
settings.JWT_SECRET_KEY,
|
||||
algorithms=[settings.JWT_ALGORITHM]
|
||||
)
|
||||
|
||||
user_id = payload.get("sub")
|
||||
session_token = payload.get("session_token")
|
||||
exp = payload.get("exp")
|
||||
|
||||
if not user_id or not session_token:
|
||||
return None
|
||||
|
||||
return TokenData(
|
||||
user_id=user_id,
|
||||
session_token=session_token,
|
||||
exp=datetime.fromtimestamp(exp, tz=timezone.utc)
|
||||
)
|
||||
except JWTError:
|
||||
return None
|
||||
|
||||
|
||||
def generate_session_token() -> str:
|
||||
"""生成新的会话 Token"""
|
||||
return str(uuid.uuid4())
|
||||
|
||||
|
||||
def set_auth_cookie(response: Response, token: str) -> None:
|
||||
"""
|
||||
设置 HttpOnly Cookie
|
||||
|
||||
Args:
|
||||
response: FastAPI Response 对象
|
||||
token: JWT Token
|
||||
"""
|
||||
response.set_cookie(
|
||||
key="access_token",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=not settings.DEBUG, # 开发/测试环境(DEBUG=True)允许非HTTPS
|
||||
samesite="lax",
|
||||
max_age=settings.JWT_EXPIRE_HOURS * 3600
|
||||
)
|
||||
|
||||
|
||||
def clear_auth_cookie(response: Response) -> None:
|
||||
"""清除认证 Cookie"""
|
||||
response.delete_cookie(key="access_token")
|
||||
26
backend/app/core/supabase.py
Normal file
26
backend/app/core/supabase.py
Normal file
@@ -0,0 +1,26 @@
|
||||
"""
|
||||
Supabase 客户端初始化
|
||||
"""
|
||||
from supabase import create_client, Client
|
||||
from app.core.config import settings
|
||||
from loguru import logger
|
||||
from typing import Optional
|
||||
|
||||
_supabase_client: Optional[Client] = None
|
||||
|
||||
|
||||
def get_supabase() -> Client:
|
||||
"""获取 Supabase 客户端单例"""
|
||||
global _supabase_client
|
||||
|
||||
if _supabase_client is None:
|
||||
if not settings.SUPABASE_URL or not settings.SUPABASE_KEY:
|
||||
raise ValueError("SUPABASE_URL 和 SUPABASE_KEY 必须在 .env 中配置")
|
||||
|
||||
_supabase_client = create_client(
|
||||
settings.SUPABASE_URL,
|
||||
settings.SUPABASE_KEY
|
||||
)
|
||||
logger.info("Supabase 客户端已初始化")
|
||||
|
||||
return _supabase_client
|
||||
@@ -2,12 +2,36 @@ from fastapi import FastAPI
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from app.core import config
|
||||
from app.api import materials, videos, publish, login_helper
|
||||
from app.api import materials, videos, publish, login_helper, auth, admin
|
||||
from loguru import logger
|
||||
import os
|
||||
|
||||
settings = config.settings
|
||||
|
||||
app = FastAPI(title="ViGent TalkingHead Agent")
|
||||
|
||||
from fastapi import Request
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
import time
|
||||
import traceback
|
||||
|
||||
class LoggingMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
start_time = time.time()
|
||||
logger.info(f"START Request: {request.method} {request.url}")
|
||||
logger.info(f"HEADERS: {dict(request.headers)}")
|
||||
try:
|
||||
response = await call_next(request)
|
||||
process_time = time.time() - start_time
|
||||
logger.info(f"END Request: {request.method} {request.url} - Status: {response.status_code} - Duration: {process_time:.2f}s")
|
||||
return response
|
||||
except Exception as e:
|
||||
process_time = time.time() - start_time
|
||||
logger.error(f"EXCEPTION during request {request.method} {request.url}: {str(e)}\n{traceback.format_exc()}")
|
||||
raise e
|
||||
|
||||
app.add_middleware(LoggingMiddleware)
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@@ -22,11 +46,56 @@ settings.OUTPUT_DIR.mkdir(parents=True, exist_ok=True)
|
||||
(settings.UPLOAD_DIR / "materials").mkdir(exist_ok=True)
|
||||
|
||||
app.mount("/outputs", StaticFiles(directory=str(settings.OUTPUT_DIR)), name="outputs")
|
||||
app.mount("/uploads", StaticFiles(directory=str(settings.UPLOAD_DIR)), name="uploads")
|
||||
|
||||
# 注册路由
|
||||
app.include_router(materials.router, prefix="/api/materials", tags=["Materials"])
|
||||
app.include_router(videos.router, prefix="/api/videos", tags=["Videos"])
|
||||
app.include_router(publish.router, prefix="/api/publish", tags=["Publish"])
|
||||
app.include_router(login_helper.router, prefix="/api", tags=["LoginHelper"])
|
||||
app.include_router(auth.router) # /api/auth
|
||||
app.include_router(admin.router) # /api/admin
|
||||
|
||||
|
||||
@app.on_event("startup")
|
||||
async def init_admin():
|
||||
"""
|
||||
服务启动时初始化管理员账号
|
||||
"""
|
||||
admin_email = settings.ADMIN_EMAIL
|
||||
admin_password = settings.ADMIN_PASSWORD
|
||||
|
||||
if not admin_email or not admin_password:
|
||||
logger.warning("未配置 ADMIN_EMAIL 和 ADMIN_PASSWORD,跳过管理员初始化")
|
||||
return
|
||||
|
||||
try:
|
||||
from app.core.supabase import get_supabase
|
||||
from app.core.security import get_password_hash
|
||||
|
||||
supabase = get_supabase()
|
||||
|
||||
# 检查是否已存在
|
||||
existing = supabase.table("users").select("id").eq("email", admin_email).execute()
|
||||
|
||||
if existing.data:
|
||||
logger.info(f"管理员账号已存在: {admin_email}")
|
||||
return
|
||||
|
||||
# 创建管理员
|
||||
supabase.table("users").insert({
|
||||
"email": admin_email,
|
||||
"password_hash": get_password_hash(admin_password),
|
||||
"username": "Admin",
|
||||
"role": "admin",
|
||||
"is_active": True,
|
||||
"expires_at": None # 永不过期
|
||||
}).execute()
|
||||
|
||||
logger.success(f"管理员账号已创建: {admin_email}")
|
||||
except Exception as e:
|
||||
logger.error(f"初始化管理员失败: {e}")
|
||||
|
||||
|
||||
@app.get("/health")
|
||||
def health():
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
"""
|
||||
发布服务 (基于 social-auto-upload 架构)
|
||||
发布服务 (支持用户隔离)
|
||||
"""
|
||||
import json
|
||||
import os
|
||||
import re
|
||||
import tempfile
|
||||
import httpx
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any
|
||||
from loguru import logger
|
||||
from app.core.config import settings
|
||||
from app.core.paths import get_user_cookie_dir, get_platform_cookie_path, get_legacy_cookie_dir, get_legacy_cookie_path
|
||||
from app.services.storage import storage_service
|
||||
|
||||
# Import platform uploaders
|
||||
from .uploader.bilibili_uploader import BilibiliUploader
|
||||
@@ -14,30 +21,50 @@ from .uploader.xiaohongshu_uploader import XiaohongshuUploader
|
||||
|
||||
|
||||
class PublishService:
|
||||
"""Social media publishing service"""
|
||||
|
||||
PLATFORMS = {
|
||||
"bilibili": {"name": "B站", "url": "https://member.bilibili.com/platform/upload/video/frame"},
|
||||
"douyin": {"name": "抖音", "url": "https://creator.douyin.com/"},
|
||||
"xiaohongshu": {"name": "小红书", "url": "https://creator.xiaohongshu.com/"},
|
||||
"weixin": {"name": "微信视频号", "url": "https://channels.weixin.qq.com/"},
|
||||
"kuaishou": {"name": "快手", "url": "https://cp.kuaishou.com/"},
|
||||
"""Social media publishing service (with user isolation)"""
|
||||
|
||||
# 支持的平台配置
|
||||
PLATFORMS: Dict[str, Dict[str, Any]] = {
|
||||
"bilibili": {"name": "B站", "url": "https://member.bilibili.com/platform/upload/video/frame", "enabled": True},
|
||||
"douyin": {"name": "抖音", "url": "https://creator.douyin.com/", "enabled": True},
|
||||
"xiaohongshu": {"name": "小红书", "url": "https://creator.xiaohongshu.com/", "enabled": True},
|
||||
"weixin": {"name": "微信视频号", "url": "https://channels.weixin.qq.com/", "enabled": False},
|
||||
"kuaishou": {"name": "快手", "url": "https://cp.kuaishou.com/", "enabled": False},
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.cookies_dir = settings.BASE_DIR / "cookies"
|
||||
self.cookies_dir.mkdir(exist_ok=True)
|
||||
def __init__(self) -> None:
|
||||
# 存储活跃的登录会话,用于跟踪登录状态
|
||||
# key 格式: "{user_id}_{platform}" 或 "{platform}" (兼容旧版)
|
||||
self.active_login_sessions: Dict[str, Any] = {}
|
||||
|
||||
def get_accounts(self):
|
||||
def _get_cookies_dir(self, user_id: Optional[str] = None) -> Path:
|
||||
"""获取 Cookie 目录 (支持用户隔离)"""
|
||||
if user_id:
|
||||
return get_user_cookie_dir(user_id)
|
||||
return get_legacy_cookie_dir()
|
||||
|
||||
def _get_cookie_path(self, platform: str, user_id: Optional[str] = None) -> Path:
|
||||
"""获取 Cookie 文件路径 (支持用户隔离)"""
|
||||
if user_id:
|
||||
return get_platform_cookie_path(user_id, platform)
|
||||
return get_legacy_cookie_path(platform)
|
||||
|
||||
def _get_session_key(self, platform: str, user_id: Optional[str] = None) -> str:
|
||||
"""获取会话 key"""
|
||||
if user_id:
|
||||
return f"{user_id}_{platform}"
|
||||
return platform
|
||||
|
||||
def get_accounts(self, user_id: Optional[str] = None) -> List[Dict[str, Any]]:
|
||||
"""Get list of platform accounts with login status"""
|
||||
accounts = []
|
||||
for pid, pinfo in self.PLATFORMS.items():
|
||||
cookie_file = self.cookies_dir / f"{pid}_cookies.json"
|
||||
cookie_file = self._get_cookie_path(pid, user_id)
|
||||
accounts.append({
|
||||
"platform": pid,
|
||||
"name": pinfo["name"],
|
||||
"logged_in": cookie_file.exists(),
|
||||
"enabled": True
|
||||
"enabled": pinfo.get("enabled", True)
|
||||
})
|
||||
return accounts
|
||||
|
||||
@@ -49,8 +76,9 @@ class PublishService:
|
||||
tags: List[str],
|
||||
description: str = "",
|
||||
publish_time: Optional[datetime] = None,
|
||||
**kwargs
|
||||
):
|
||||
user_id: Optional[str] = None,
|
||||
**kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Publish video to specified platform
|
||||
|
||||
@@ -61,6 +89,7 @@ class PublishService:
|
||||
tags: List of tags
|
||||
description: Video description
|
||||
publish_time: Scheduled publish time (None = immediate)
|
||||
user_id: User ID for cookie isolation
|
||||
**kwargs: Additional platform-specific parameters
|
||||
|
||||
Returns:
|
||||
@@ -75,30 +104,81 @@ class PublishService:
|
||||
"platform": platform
|
||||
}
|
||||
|
||||
# Get account file path
|
||||
account_file = self.cookies_dir / f"{platform}_cookies.json"
|
||||
# Get account file path (with user isolation)
|
||||
account_file = self._get_cookie_path(platform, user_id)
|
||||
|
||||
if not account_file.exists():
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"请先登录 {self.PLATFORMS[platform]['name']}",
|
||||
"platform": platform
|
||||
}
|
||||
|
||||
logger.info(f"[发布] 平台: {self.PLATFORMS[platform]['name']}")
|
||||
logger.info(f"[发布] 视频: {video_path}")
|
||||
logger.info(f"[发布] 标题: {title}")
|
||||
|
||||
logger.info(f"[发布] 用户: {user_id or 'legacy'}")
|
||||
|
||||
temp_file = None
|
||||
try:
|
||||
# 处理视频路径
|
||||
if video_path.startswith('http://') or video_path.startswith('https://'):
|
||||
# 尝试从 URL 解析 bucket 和 path,直接使用本地文件
|
||||
local_video_path = None
|
||||
|
||||
# URL 格式: .../storage/v1/object/sign/{bucket}/{path}?token=...
|
||||
match = re.search(r'/storage/v1/object/sign/([^/]+)/(.+?)\?', video_path)
|
||||
if match:
|
||||
bucket = match.group(1)
|
||||
storage_path = match.group(2)
|
||||
logger.info(f"[发布] 解析 URL: bucket={bucket}, path={storage_path}")
|
||||
|
||||
# 尝试获取本地文件路径
|
||||
local_video_path = storage_service.get_local_file_path(bucket, storage_path)
|
||||
|
||||
if local_video_path and os.path.exists(local_video_path):
|
||||
logger.info(f"[发布] 直接使用本地文件: {local_video_path}")
|
||||
else:
|
||||
# 本地文件不存在,通过 HTTP 下载
|
||||
logger.info(f"[发布] 本地文件不存在,通过 HTTP 下载...")
|
||||
temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.mp4')
|
||||
temp_file.close()
|
||||
|
||||
# 将公网 URL 替换为内网 URL
|
||||
download_url = video_path
|
||||
if settings.SUPABASE_PUBLIC_URL and settings.SUPABASE_URL:
|
||||
public_url = settings.SUPABASE_PUBLIC_URL.rstrip('/')
|
||||
internal_url = settings.SUPABASE_URL.rstrip('/')
|
||||
download_url = video_path.replace(public_url, internal_url)
|
||||
|
||||
async with httpx.AsyncClient(timeout=httpx.Timeout(None)) as client:
|
||||
async with client.stream("GET", download_url) as resp:
|
||||
resp.raise_for_status()
|
||||
with open(temp_file.name, 'wb') as f:
|
||||
async for chunk in resp.aiter_bytes():
|
||||
f.write(chunk)
|
||||
local_video_path = temp_file.name
|
||||
logger.info(f"[发布] 视频已下载到: {local_video_path}")
|
||||
else:
|
||||
# 本地相对路径
|
||||
local_video_path = str(settings.BASE_DIR.parent / video_path)
|
||||
|
||||
# Select appropriate uploader
|
||||
if platform == "bilibili":
|
||||
uploader = BilibiliUploader(
|
||||
title=title,
|
||||
file_path=str(settings.BASE_DIR / video_path), # Convert to absolute path
|
||||
file_path=local_video_path,
|
||||
tags=tags,
|
||||
publish_date=publish_time,
|
||||
account_file=str(account_file),
|
||||
description=description,
|
||||
tid=kwargs.get('tid', 122), # Category ID
|
||||
copyright=kwargs.get('copyright', 1) # 1=original
|
||||
tid=kwargs.get('tid', 122),
|
||||
copyright=kwargs.get('copyright', 1)
|
||||
)
|
||||
elif platform == "douyin":
|
||||
uploader = DouyinUploader(
|
||||
title=title,
|
||||
file_path=str(settings.BASE_DIR / video_path),
|
||||
file_path=local_video_path,
|
||||
tags=tags,
|
||||
publish_date=publish_time,
|
||||
account_file=str(account_file),
|
||||
@@ -107,7 +187,7 @@ class PublishService:
|
||||
elif platform == "xiaohongshu":
|
||||
uploader = XiaohongshuUploader(
|
||||
title=title,
|
||||
file_path=str(settings.BASE_DIR / video_path),
|
||||
file_path=local_video_path,
|
||||
tags=tags,
|
||||
publish_date=publish_time,
|
||||
account_file=str(account_file),
|
||||
@@ -125,7 +205,7 @@ class PublishService:
|
||||
result = await uploader.main()
|
||||
result['platform'] = platform
|
||||
return result
|
||||
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"[发布] 上传异常: {e}")
|
||||
return {
|
||||
@@ -133,11 +213,23 @@ class PublishService:
|
||||
"message": f"上传异常: {str(e)}",
|
||||
"platform": platform
|
||||
}
|
||||
finally:
|
||||
# 清理临时文件
|
||||
if temp_file and os.path.exists(temp_file.name):
|
||||
try:
|
||||
os.remove(temp_file.name)
|
||||
logger.info(f"[发布] 已清理临时文件: {temp_file.name}")
|
||||
except Exception as e:
|
||||
logger.warning(f"[发布] 清理临时文件失败: {e}")
|
||||
|
||||
async def login(self, platform: str):
|
||||
async def login(self, platform: str, user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
启动QR码登录流程
|
||||
|
||||
Args:
|
||||
platform: 平台 ID
|
||||
user_id: 用户 ID (用于 Cookie 隔离)
|
||||
|
||||
Returns:
|
||||
dict: 包含二维码base64图片
|
||||
"""
|
||||
@@ -147,8 +239,15 @@ class PublishService:
|
||||
try:
|
||||
from .qr_login_service import QRLoginService
|
||||
|
||||
# 获取用户专属的 Cookie 目录
|
||||
cookies_dir = self._get_cookies_dir(user_id)
|
||||
|
||||
# 创建QR登录服务
|
||||
qr_service = QRLoginService(platform, self.cookies_dir)
|
||||
qr_service = QRLoginService(platform, cookies_dir)
|
||||
|
||||
# 存储活跃会话 (带用户隔离)
|
||||
session_key = self._get_session_key(platform, user_id)
|
||||
self.active_login_sessions[session_key] = qr_service
|
||||
|
||||
# 启动登录并获取二维码
|
||||
result = await qr_service.start_login()
|
||||
@@ -161,17 +260,67 @@ class PublishService:
|
||||
"success": False,
|
||||
"message": f"登录失败: {str(e)}"
|
||||
}
|
||||
|
||||
def get_login_session_status(self, platform: str, user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""获取活跃登录会话的状态"""
|
||||
session_key = self._get_session_key(platform, user_id)
|
||||
|
||||
# 1. 如果有活跃的扫码会话,优先检查它
|
||||
if session_key in self.active_login_sessions:
|
||||
qr_service = self.active_login_sessions[session_key]
|
||||
status = qr_service.get_login_status()
|
||||
|
||||
# 如果登录成功且Cookie已保存,清理会话
|
||||
if status["success"] and status["cookies_saved"]:
|
||||
del self.active_login_sessions[session_key]
|
||||
return {"success": True, "message": "登录成功"}
|
||||
|
||||
return {"success": False, "message": "等待扫码..."}
|
||||
|
||||
# 2. 检查本地Cookie文件是否存在
|
||||
cookie_file = self._get_cookie_path(platform, user_id)
|
||||
if cookie_file.exists():
|
||||
return {"success": True, "message": "已登录 (历史状态)"}
|
||||
|
||||
return {"success": False, "message": "未登录"}
|
||||
|
||||
async def save_cookie_string(self, platform: str, cookie_string: str):
|
||||
def logout(self, platform: str, user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
Logout from platform (delete cookie file)
|
||||
"""
|
||||
if platform not in self.PLATFORMS:
|
||||
return {"success": False, "message": "不支持的平台"}
|
||||
|
||||
try:
|
||||
session_key = self._get_session_key(platform, user_id)
|
||||
|
||||
# 1. 移除活跃会话
|
||||
if session_key in self.active_login_sessions:
|
||||
del self.active_login_sessions[session_key]
|
||||
|
||||
# 2. 删除Cookie文件
|
||||
cookie_file = self._get_cookie_path(platform, user_id)
|
||||
if cookie_file.exists():
|
||||
cookie_file.unlink()
|
||||
logger.info(f"[登出] {platform} Cookie已删除 (user: {user_id or 'legacy'})")
|
||||
|
||||
return {"success": True, "message": "已注销"}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"[登出] 失败: {e}")
|
||||
return {"success": False, "message": f"注销失败: {str(e)}"}
|
||||
|
||||
async def save_cookie_string(self, platform: str, cookie_string: str, user_id: Optional[str] = None) -> Dict[str, Any]:
|
||||
"""
|
||||
保存从客户端浏览器提取的Cookie字符串
|
||||
|
||||
Args:
|
||||
platform: 平台ID
|
||||
cookie_string: document.cookie 格式的Cookie字符串
|
||||
user_id: 用户 ID (用于 Cookie 隔离)
|
||||
"""
|
||||
try:
|
||||
account_file = self.cookies_dir / f"{platform}_cookies.json"
|
||||
account_file = self._get_cookie_path(platform, user_id)
|
||||
|
||||
# 解析Cookie字符串
|
||||
cookie_dict = {}
|
||||
@@ -180,7 +329,7 @@ class PublishService:
|
||||
name, value = item.split('=', 1)
|
||||
cookie_dict[name] = value
|
||||
|
||||
# 对B站进行特殊处理,提取biliup需要的字段
|
||||
# 对B站进行特殊处理
|
||||
if platform == "bilibili":
|
||||
bilibili_cookies = {}
|
||||
required_fields = ['SESSDATA', 'bili_jct', 'DedeUserID', 'DedeUserID__ckMd5']
|
||||
@@ -189,7 +338,7 @@ class PublishService:
|
||||
if field in cookie_dict:
|
||||
bilibili_cookies[field] = cookie_dict[field]
|
||||
|
||||
if len(bilibili_cookies) < 3: # 至少需要3个关键字段
|
||||
if len(bilibili_cookies) < 3:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Cookie不完整,请确保已登录"
|
||||
@@ -197,12 +346,14 @@ class PublishService:
|
||||
|
||||
cookie_dict = bilibili_cookies
|
||||
|
||||
# 确保目录存在
|
||||
account_file.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# 保存Cookie
|
||||
import json
|
||||
with open(account_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(cookie_dict, f, indent=2)
|
||||
|
||||
logger.success(f"[登录] {platform} Cookie已保存")
|
||||
logger.success(f"[登录] {platform} Cookie已保存 (user: {user_id or 'legacy'})")
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
|
||||
@@ -4,22 +4,30 @@ QR码自动登录服务
|
||||
"""
|
||||
import asyncio
|
||||
import base64
|
||||
from pathlib import Path
|
||||
from playwright.async_api import async_playwright, Page
|
||||
from loguru import logger
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Optional, Dict, Any, List
|
||||
from playwright.async_api import async_playwright, Page, BrowserContext, Browser, Playwright as PW
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class QRLoginService:
|
||||
"""QR码登录服务"""
|
||||
|
||||
def __init__(self, platform: str, cookies_dir: Path):
|
||||
# 登录监控超时 (秒)
|
||||
LOGIN_TIMEOUT = 120
|
||||
|
||||
def __init__(self, platform: str, cookies_dir: Path) -> None:
|
||||
self.platform = platform
|
||||
self.cookies_dir = cookies_dir
|
||||
self.qr_code_image = None
|
||||
self.login_success = False
|
||||
self.cookies_data = None
|
||||
self.qr_code_image: Optional[str] = None
|
||||
self.login_success: bool = False
|
||||
self.cookies_data: Optional[Dict[str, Any]] = None
|
||||
|
||||
# Playwright 资源 (手动管理生命周期)
|
||||
self.playwright: Optional[PW] = None
|
||||
self.browser: Optional[Browser] = None
|
||||
self.context: Optional[BrowserContext] = None
|
||||
|
||||
# 每个平台使用多个选择器 (使用逗号分隔,Playwright会同时等待它们)
|
||||
self.platform_configs = {
|
||||
@@ -56,7 +64,7 @@ class QRLoginService:
|
||||
}
|
||||
}
|
||||
|
||||
async def start_login(self):
|
||||
async def start_login(self) -> Dict[str, Any]:
|
||||
"""
|
||||
启动登录流程
|
||||
|
||||
@@ -129,109 +137,112 @@ class QRLoginService:
|
||||
await self._cleanup()
|
||||
return {"success": False, "message": f"启动失败: {str(e)}"}
|
||||
|
||||
async def _extract_qr_code(self, page: Page, selectors: list) -> str:
|
||||
async def _extract_qr_code(self, page: Page, selectors: List[str]) -> Optional[str]:
|
||||
"""
|
||||
提取二维码图片(并行执行 CSS策略 和 文本策略)
|
||||
提取二维码图片 (优化策略顺序)
|
||||
根据日志分析:抖音和B站使用 Text 策略成功率最高
|
||||
"""
|
||||
async def strategy_css():
|
||||
qr_element = None
|
||||
|
||||
# 针对抖音和B站:优先使用 Text 策略 (成功率最高,速度最快)
|
||||
if self.platform in ("douyin", "bilibili"):
|
||||
# 尝试最多2次 (首次 + 1次重试)
|
||||
for attempt in range(2):
|
||||
if attempt > 0:
|
||||
logger.info(f"[{self.platform}] 等待页面加载后重试...")
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# 策略1: Text (优先,成功率最高)
|
||||
qr_element = await self._try_text_strategy(page)
|
||||
if qr_element:
|
||||
try:
|
||||
screenshot = await qr_element.screenshot()
|
||||
return base64.b64encode(screenshot).decode()
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.platform}] Text策略截图失败: {e}")
|
||||
qr_element = None
|
||||
|
||||
# 策略2: CSS (备用)
|
||||
if not qr_element:
|
||||
try:
|
||||
combined_selector = ", ".join(selectors)
|
||||
logger.debug(f"[{self.platform}] 策略2(CSS): 开始等待...")
|
||||
# 增加超时到5秒,抖音页面加载较慢
|
||||
el = await page.wait_for_selector(combined_selector, state="visible", timeout=5000)
|
||||
if el:
|
||||
logger.info(f"[{self.platform}] 策略2(CSS): 匹配成功")
|
||||
screenshot = await el.screenshot()
|
||||
return base64.b64encode(screenshot).decode()
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.platform}] 策略2(CSS) 失败: {e}")
|
||||
|
||||
# 如果已成功,退出循环
|
||||
if qr_element:
|
||||
break
|
||||
else:
|
||||
# 其他平台 (小红书等):保持原顺序 CSS -> Text
|
||||
# 策略1: CSS 选择器
|
||||
try:
|
||||
combined_selector = ", ".join(selectors)
|
||||
logger.debug(f"[{self.platform}] 策略1(CSS): 开始等待...")
|
||||
el = await page.wait_for_selector(combined_selector, state="visible", timeout=5000)
|
||||
if el:
|
||||
logger.info(f"[{self.platform}] 策略1(CSS): 匹配成功")
|
||||
return el
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
async def strategy_text():
|
||||
# 扩展支持 Bilibili 和 Douyin
|
||||
if self.platform not in ["bilibili", "douyin"]: return None
|
||||
try:
|
||||
logger.debug(f"[{self.platform}] 策略2(Text): 开始搜索...")
|
||||
# 关键词列表
|
||||
keywords = ["扫码登录", "打开抖音", "抖音APP", "使用手机抖音扫码"]
|
||||
scan_text = None
|
||||
|
||||
# 遍历尝试关键词 (带等待)
|
||||
for kw in keywords:
|
||||
try:
|
||||
t = page.get_by_text(kw, exact=False).first
|
||||
# 稍微等待一下文字渲染
|
||||
await t.wait_for(state="visible", timeout=2000)
|
||||
scan_text = t
|
||||
logger.debug(f"[{self.platform}] 找到关键词: {kw}")
|
||||
break
|
||||
except:
|
||||
continue
|
||||
|
||||
if scan_text:
|
||||
# 尝试定位周边的图片
|
||||
parent_locator = scan_text
|
||||
# 向上查找5层(扩大范围)
|
||||
for _ in range(5):
|
||||
parent_locator = parent_locator.locator("..")
|
||||
|
||||
# 找图片
|
||||
img = parent_locator.locator("img").first
|
||||
if await img.is_visible():
|
||||
# 过滤掉头像等小图标,确保尺寸足够大
|
||||
bbox = await img.bounding_box()
|
||||
if bbox and bbox['width'] > 100:
|
||||
logger.info(f"[{self.platform}] 策略2(Text): 定位成功(Img)")
|
||||
return img
|
||||
|
||||
# 找Canvas
|
||||
canvas = parent_locator.locator("canvas").first
|
||||
if await canvas.is_visible():
|
||||
logger.info(f"[{self.platform}] 策略2(Text): 定位成功(Canvas)")
|
||||
return canvas
|
||||
|
||||
qr_element = el
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.platform}] 策略2异常: {e}")
|
||||
return None
|
||||
|
||||
# 并行执行两个策略,谁先找到算谁的
|
||||
tasks = [
|
||||
asyncio.create_task(strategy_css()),
|
||||
asyncio.create_task(strategy_text())
|
||||
]
|
||||
|
||||
qr_element = None
|
||||
pending = set(tasks)
|
||||
|
||||
while pending:
|
||||
done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED)
|
||||
logger.warning(f"[{self.platform}] 策略1(CSS) 失败: {e}")
|
||||
|
||||
for task in done:
|
||||
result = await task
|
||||
if result:
|
||||
qr_element = result
|
||||
break
|
||||
# 策略2: Text
|
||||
if not qr_element:
|
||||
qr_element = await self._try_text_strategy(page)
|
||||
|
||||
# 如果找到元素,截图返回
|
||||
if qr_element:
|
||||
break
|
||||
try:
|
||||
screenshot = await qr_element.screenshot()
|
||||
return base64.b64encode(screenshot).decode()
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.platform}] 截图失败: {e}")
|
||||
|
||||
# 取消剩下的任务 (如果找到了)
|
||||
for task in pending:
|
||||
task.cancel()
|
||||
|
||||
if qr_element:
|
||||
try:
|
||||
screenshot = await qr_element.screenshot()
|
||||
return base64.b64encode(screenshot).decode()
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.platform}] 截图失败: {e}")
|
||||
|
||||
# 失败处理
|
||||
logger.warning(f"[{self.platform}] 所有策略失败,保存全页截图")
|
||||
# 所有策略失败
|
||||
logger.error(f"[{self.platform}] 所有QR码提取策略失败")
|
||||
|
||||
# 保存调试截图
|
||||
debug_dir = Path(__file__).parent.parent.parent / 'debug_screenshots'
|
||||
debug_dir.mkdir(exist_ok=True)
|
||||
await page.screenshot(path=str(debug_dir / f"{self.platform}_debug.png"))
|
||||
|
||||
screenshot = await page.screenshot()
|
||||
return base64.b64encode(screenshot).decode()
|
||||
return None
|
||||
|
||||
async def _try_text_strategy(self, page: Page) -> Optional[Any]:
|
||||
"""基于文本查找二维码图片"""
|
||||
try:
|
||||
logger.debug(f"[{self.platform}] 策略Text: 开始搜索...")
|
||||
keywords = ["扫码登录", "二维码", "打开抖音", "抖音APP", "使用APP扫码"]
|
||||
|
||||
for kw in keywords:
|
||||
try:
|
||||
text_el = page.get_by_text(kw, exact=False).first
|
||||
await text_el.wait_for(state="visible", timeout=2000)
|
||||
|
||||
# 向上查找图片
|
||||
parent = text_el
|
||||
for _ in range(5):
|
||||
parent = parent.locator("..")
|
||||
imgs = parent.locator("img")
|
||||
|
||||
for i in range(await imgs.count()):
|
||||
img = imgs.nth(i)
|
||||
if await img.is_visible():
|
||||
bbox = await img.bounding_box()
|
||||
if bbox and bbox['width'] > 100:
|
||||
logger.info(f"[{self.platform}] 策略Text: 成功")
|
||||
return img
|
||||
except Exception:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.warning(f"[{self.platform}] 策略Text 失败: {e}")
|
||||
return None
|
||||
|
||||
async def _monitor_login_status(self, page: Page, success_url: str):
|
||||
"""监控登录状态"""
|
||||
@@ -240,7 +251,7 @@ class QRLoginService:
|
||||
key_cookies = {"bilibili": "SESSDATA", "douyin": "sessionid", "xiaohongshu": "web_session"}
|
||||
target_cookie = key_cookies.get(self.platform, "")
|
||||
|
||||
for i in range(120):
|
||||
for i in range(self.LOGIN_TIMEOUT):
|
||||
await asyncio.sleep(1)
|
||||
|
||||
try:
|
||||
@@ -275,37 +286,57 @@ class QRLoginService:
|
||||
finally:
|
||||
await self._cleanup()
|
||||
|
||||
async def _cleanup(self):
|
||||
async def _cleanup(self) -> None:
|
||||
"""清理资源"""
|
||||
if hasattr(self, 'context') and self.context:
|
||||
try: await self.context.close()
|
||||
except: pass
|
||||
if hasattr(self, 'browser') and self.browser:
|
||||
try: await self.browser.close()
|
||||
except: pass
|
||||
if hasattr(self, 'playwright') and self.playwright:
|
||||
try: await self.playwright.stop()
|
||||
except: pass
|
||||
if self.context:
|
||||
try:
|
||||
await self.context.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.context = None
|
||||
if self.browser:
|
||||
try:
|
||||
await self.browser.close()
|
||||
except Exception:
|
||||
pass
|
||||
self.browser = None
|
||||
if self.playwright:
|
||||
try:
|
||||
await self.playwright.stop()
|
||||
except Exception:
|
||||
pass
|
||||
self.playwright = None
|
||||
|
||||
async def _save_cookies(self, cookies: list):
|
||||
async def _save_cookies(self, cookies: List[Dict[str, Any]]) -> None:
|
||||
"""保存Cookie到文件"""
|
||||
try:
|
||||
cookie_file = self.cookies_dir / f"{self.platform}_cookies.json"
|
||||
cookie_dict = {c['name']: c['value'] for c in cookies}
|
||||
|
||||
if self.platform == "bilibili":
|
||||
# Bilibili 使用简单格式 (biliup库需要)
|
||||
cookie_dict = {c['name']: c['value'] for c in cookies}
|
||||
required = ['SESSDATA', 'bili_jct', 'DedeUserID', 'DedeUserID__ckMd5']
|
||||
cookie_dict = {k: v for k, v in cookie_dict.items() if k in required}
|
||||
|
||||
with open(cookie_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(cookie_dict, f, indent=2)
|
||||
self.cookies_data = cookie_dict
|
||||
else:
|
||||
# Douyin/Xiaohongshu 使用 Playwright storage_state 完整格式
|
||||
# 这样可以直接用 browser.new_context(storage_state=file)
|
||||
storage_state = {
|
||||
"cookies": cookies,
|
||||
"origins": []
|
||||
}
|
||||
with open(cookie_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(storage_state, f, indent=2)
|
||||
self.cookies_data = storage_state
|
||||
|
||||
with open(cookie_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(cookie_dict, f, indent=2)
|
||||
|
||||
self.cookies_data = cookie_dict
|
||||
logger.success(f"[{self.platform}] Cookie已保存")
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.platform}] 保存Cookie失败: {e}")
|
||||
|
||||
def get_login_status(self):
|
||||
def get_login_status(self) -> Dict[str, Any]:
|
||||
"""获取登录状态"""
|
||||
return {
|
||||
"success": self.login_success,
|
||||
|
||||
148
backend/app/services/storage.py
Normal file
148
backend/app/services/storage.py
Normal file
@@ -0,0 +1,148 @@
|
||||
from supabase import Client
|
||||
from app.core.supabase import get_supabase
|
||||
from app.core.config import settings
|
||||
from loguru import logger
|
||||
from typing import Optional, Union, Dict, List, Any
|
||||
from pathlib import Path
|
||||
import asyncio
|
||||
import functools
|
||||
import os
|
||||
|
||||
# Supabase Storage 本地存储根目录
|
||||
SUPABASE_STORAGE_LOCAL_PATH = Path("/home/rongye/ProgramFiles/Supabase/volumes/storage/stub/stub")
|
||||
|
||||
class StorageService:
|
||||
def __init__(self):
|
||||
self.supabase: Client = get_supabase()
|
||||
self.BUCKET_MATERIALS = "materials"
|
||||
self.BUCKET_OUTPUTS = "outputs"
|
||||
|
||||
def _convert_to_public_url(self, url: str) -> str:
|
||||
"""将内部 URL 转换为公网可访问的 URL"""
|
||||
if settings.SUPABASE_PUBLIC_URL and settings.SUPABASE_URL:
|
||||
# 去掉末尾斜杠进行替换
|
||||
internal_url = settings.SUPABASE_URL.rstrip('/')
|
||||
public_url = settings.SUPABASE_PUBLIC_URL.rstrip('/')
|
||||
return url.replace(internal_url, public_url)
|
||||
return url
|
||||
|
||||
def get_local_file_path(self, bucket: str, path: str) -> Optional[str]:
|
||||
"""
|
||||
获取 Storage 文件的本地磁盘路径
|
||||
|
||||
Supabase Storage 文件存储结构:
|
||||
{STORAGE_ROOT}/{bucket}/{path}/{internal_uuid}
|
||||
|
||||
Returns:
|
||||
本地文件路径,如果不存在返回 None
|
||||
"""
|
||||
try:
|
||||
# 构建目录路径
|
||||
dir_path = SUPABASE_STORAGE_LOCAL_PATH / bucket / path
|
||||
|
||||
if not dir_path.exists():
|
||||
logger.warning(f"Storage 目录不存在: {dir_path}")
|
||||
return None
|
||||
|
||||
# 目录下只有一个文件(internal_uuid)
|
||||
files = list(dir_path.iterdir())
|
||||
if not files:
|
||||
logger.warning(f"Storage 目录为空: {dir_path}")
|
||||
return None
|
||||
|
||||
local_path = str(files[0])
|
||||
logger.info(f"获取本地文件路径: {local_path}")
|
||||
return local_path
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"获取本地文件路径失败: {e}")
|
||||
return None
|
||||
|
||||
async def upload_file(self, bucket: str, path: str, file_data: bytes, content_type: str) -> str:
|
||||
"""
|
||||
异步上传文件到 Supabase Storage
|
||||
"""
|
||||
try:
|
||||
# 运行在线程池中,避免阻塞事件循环
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
functools.partial(
|
||||
self.supabase.storage.from_(bucket).upload,
|
||||
path=path,
|
||||
file=file_data,
|
||||
file_options={"content-type": content_type, "upsert": "true"}
|
||||
)
|
||||
)
|
||||
logger.info(f"Storage upload success: {path}")
|
||||
return path
|
||||
except Exception as e:
|
||||
logger.error(f"Storage upload failed: {e}")
|
||||
raise e
|
||||
|
||||
async def get_signed_url(self, bucket: str, path: str, expires_in: int = 3600) -> str:
|
||||
"""异步获取签名访问链接"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
res = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.supabase.storage.from_(bucket).create_signed_url(path, expires_in)
|
||||
)
|
||||
|
||||
# 兼容处理
|
||||
url = ""
|
||||
if isinstance(res, dict) and "signedURL" in res:
|
||||
url = res["signedURL"]
|
||||
elif isinstance(res, str):
|
||||
url = res
|
||||
else:
|
||||
logger.warning(f"Unexpected signed_url response: {res}")
|
||||
url = res.get("signedURL", "") if isinstance(res, dict) else str(res)
|
||||
|
||||
# 转换为公网可访问的 URL
|
||||
return self._convert_to_public_url(url)
|
||||
except Exception as e:
|
||||
logger.error(f"Get signed URL failed: {e}")
|
||||
return ""
|
||||
|
||||
async def get_public_url(self, bucket: str, path: str) -> str:
|
||||
"""获取公开访问链接"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
res = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.supabase.storage.from_(bucket).get_public_url(path)
|
||||
)
|
||||
# 转换为公网可访问的 URL
|
||||
return self._convert_to_public_url(res)
|
||||
except Exception as e:
|
||||
logger.error(f"Get public URL failed: {e}")
|
||||
return ""
|
||||
|
||||
async def delete_file(self, bucket: str, path: str):
|
||||
"""异步删除文件"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.supabase.storage.from_(bucket).remove([path])
|
||||
)
|
||||
logger.info(f"Deleted file: {bucket}/{path}")
|
||||
except Exception as e:
|
||||
logger.error(f"Delete file failed: {e}")
|
||||
pass
|
||||
|
||||
async def list_files(self, bucket: str, path: str) -> List[Any]:
|
||||
"""异步列出文件"""
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
res = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: self.supabase.storage.from_(bucket).list(path)
|
||||
)
|
||||
return res or []
|
||||
except Exception as e:
|
||||
logger.error(f"List files failed: {e}")
|
||||
return []
|
||||
|
||||
storage_service = StorageService()
|
||||
@@ -3,7 +3,7 @@ Base uploader class for all social media platforms
|
||||
"""
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Dict, Any, Union
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ class BaseUploader(ABC):
|
||||
self.description = description
|
||||
|
||||
@abstractmethod
|
||||
async def main(self):
|
||||
async def main(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Main upload method - must be implemented by subclasses
|
||||
|
||||
@@ -50,7 +50,7 @@ class BaseUploader(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
def _get_timestamp(self, dt):
|
||||
def _get_timestamp(self, dt: Union[datetime, int]) -> int:
|
||||
"""
|
||||
Convert datetime to Unix timestamp
|
||||
|
||||
|
||||
@@ -2,9 +2,11 @@
|
||||
Bilibili uploader using biliup library
|
||||
"""
|
||||
import json
|
||||
import asyncio
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any
|
||||
from datetime import datetime
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
try:
|
||||
from biliup.plugins.bili_webup import BiliBili, Data
|
||||
@@ -15,6 +17,9 @@ except ImportError:
|
||||
from loguru import logger
|
||||
from .base_uploader import BaseUploader
|
||||
|
||||
# Thread pool for running sync biliup code
|
||||
_executor = ThreadPoolExecutor(max_workers=2)
|
||||
|
||||
|
||||
class BilibiliUploader(BaseUploader):
|
||||
"""Bilibili video uploader using biliup library"""
|
||||
@@ -46,13 +51,19 @@ class BilibiliUploader(BaseUploader):
|
||||
"biliup library not installed. Please run: pip install biliup"
|
||||
)
|
||||
|
||||
async def main(self):
|
||||
async def main(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Upload video to Bilibili
|
||||
|
||||
Returns:
|
||||
dict: Upload result
|
||||
"""
|
||||
# Run sync upload in thread pool to avoid asyncio.run() conflict
|
||||
loop = asyncio.get_event_loop()
|
||||
return await loop.run_in_executor(_executor, self._upload_sync)
|
||||
|
||||
def _upload_sync(self) -> Dict[str, Any]:
|
||||
"""Synchronous upload logic (runs in thread pool)"""
|
||||
try:
|
||||
# 1. Load cookie data
|
||||
if not self.account_file or not Path(self.account_file).exists():
|
||||
@@ -66,6 +77,22 @@ class BilibiliUploader(BaseUploader):
|
||||
with open(self.account_file, 'r', encoding='utf-8') as f:
|
||||
cookie_data = json.load(f)
|
||||
|
||||
# Convert simple cookie format to biliup format if needed
|
||||
if 'cookie_info' not in cookie_data and 'SESSDATA' in cookie_data:
|
||||
# Transform to biliup expected format
|
||||
cookie_data = {
|
||||
'cookie_info': {
|
||||
'cookies': [
|
||||
{'name': k, 'value': v} for k, v in cookie_data.items()
|
||||
]
|
||||
},
|
||||
'token_info': {
|
||||
'access_token': cookie_data.get('access_token', ''),
|
||||
'refresh_token': cookie_data.get('refresh_token', '')
|
||||
}
|
||||
}
|
||||
logger.info("[B站] Cookie格式已转换")
|
||||
|
||||
# 2. Prepare video data
|
||||
data = Data()
|
||||
data.copyright = self.copyright
|
||||
@@ -97,17 +124,39 @@ class BilibiliUploader(BaseUploader):
|
||||
# Submit
|
||||
ret = bili.submit()
|
||||
|
||||
# Debug: log full response
|
||||
logger.debug(f"[B站] API响应: {ret}")
|
||||
|
||||
if ret.get('code') == 0:
|
||||
bvid = ret.get('bvid', '')
|
||||
logger.success(f"[B站] 上传成功: {bvid}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "上传成功" if data.dtime == 0 else "已设置定时发布",
|
||||
"url": f"https://www.bilibili.com/video/{bvid}" if bvid else None
|
||||
}
|
||||
# Try multiple keys for bvid (API may vary)
|
||||
bvid = ret.get('data', {}).get('bvid') or ret.get('bvid', '')
|
||||
aid = ret.get('data', {}).get('aid') or ret.get('aid', '')
|
||||
|
||||
if bvid:
|
||||
logger.success(f"[B站] 上传成功: {bvid}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "发布成功,待审核" if data.dtime == 0 else "已设置定时发布",
|
||||
"url": f"https://www.bilibili.com/video/{bvid}"
|
||||
}
|
||||
elif aid:
|
||||
logger.success(f"[B站] 上传成功: av{aid}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "发布成功,待审核" if data.dtime == 0 else "已设置定时发布",
|
||||
"url": f"https://www.bilibili.com/video/av{aid}"
|
||||
}
|
||||
else:
|
||||
# No bvid/aid but code=0, still consider success
|
||||
logger.warning(f"[B站] 上传返回code=0但无bvid/aid: {ret}")
|
||||
return {
|
||||
"success": True,
|
||||
"message": "发布成功,待审核",
|
||||
"url": None
|
||||
}
|
||||
else:
|
||||
error_msg = ret.get('message', '未知错误')
|
||||
logger.error(f"[B站] 上传失败: {error_msg}")
|
||||
logger.error(f"[B站] 上传失败: {error_msg} (完整响应: {ret})")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"上传失败: {error_msg}",
|
||||
|
||||
@@ -1,169 +1,585 @@
|
||||
"""
|
||||
Douyin (抖音) uploader using Playwright
|
||||
Based on social-auto-upload implementation
|
||||
"""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
import asyncio
|
||||
|
||||
from playwright.async_api import Playwright, async_playwright
|
||||
from loguru import logger
|
||||
|
||||
from .base_uploader import BaseUploader
|
||||
from .cookie_utils import set_init_script
|
||||
|
||||
|
||||
class DouyinUploader(BaseUploader):
|
||||
"""Douyin video uploader using Playwright"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str,
|
||||
file_path: str,
|
||||
tags: List[str],
|
||||
publish_date: Optional[datetime] = None,
|
||||
account_file: Optional[str] = None,
|
||||
description: str = ""
|
||||
):
|
||||
super().__init__(title, file_path, tags, publish_date, account_file, description)
|
||||
self.upload_url = "https://creator.douyin.com/creator-micro/content/upload"
|
||||
|
||||
async def set_schedule_time(self, page, publish_date):
|
||||
"""Set scheduled publish time"""
|
||||
try:
|
||||
# Click "定时发布" radio button
|
||||
label_element = page.locator("[class^='radio']:has-text('定时发布')")
|
||||
await label_element.click()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Format time
|
||||
publish_date_hour = publish_date.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
# Fill datetime input
|
||||
await page.locator('.semi-input[placeholder="日期和时间"]').click()
|
||||
await page.keyboard.press("Control+KeyA")
|
||||
await page.keyboard.type(str(publish_date_hour))
|
||||
await page.keyboard.press("Enter")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
logger.info(f"[抖音] 已设置定时发布: {publish_date_hour}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[抖音] 设置定时发布失败: {e}")
|
||||
|
||||
async def upload(self, playwright: Playwright):
|
||||
"""Main upload logic"""
|
||||
try:
|
||||
# Launch browser
|
||||
browser = await playwright.chromium.launch(headless=False)
|
||||
context = await browser.new_context(storage_state=self.account_file)
|
||||
context = await set_init_script(context)
|
||||
|
||||
page = await context.new_page()
|
||||
|
||||
# Go to upload page
|
||||
await page.goto(self.upload_url)
|
||||
logger.info(f"[抖音] 正在上传: {self.file_path.name}")
|
||||
|
||||
# Upload video file
|
||||
await page.set_input_files("input[type='file']", str(self.file_path))
|
||||
|
||||
# Wait for redirect to publish page
|
||||
while True:
|
||||
try:
|
||||
await page.wait_for_url(
|
||||
"https://creator.douyin.com/creator-micro/content/publish?enter_from=publish_page",
|
||||
timeout=3000
|
||||
)
|
||||
logger.info("[抖音] 成功进入发布页面")
|
||||
break
|
||||
except:
|
||||
try:
|
||||
await page.wait_for_url(
|
||||
"https://creator.douyin.com/creator-micro/content/post/video?enter_from=publish_page",
|
||||
timeout=3000
|
||||
)
|
||||
logger.info("[抖音] 成功进入发布页面 (版本2)")
|
||||
break
|
||||
except:
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Fill title
|
||||
await asyncio.sleep(1)
|
||||
logger.info("[抖音] 正在填充标题和话题...")
|
||||
|
||||
title_container = page.get_by_text('作品描述').locator("..").locator("..").locator(
|
||||
"xpath=following-sibling::div[1]").locator("input")
|
||||
|
||||
if await title_container.count():
|
||||
await title_container.fill(self.title[:30])
|
||||
|
||||
# Add tags
|
||||
css_selector = ".zone-container"
|
||||
for tag in self.tags:
|
||||
await page.type(css_selector, "#" + tag)
|
||||
await page.press(css_selector, "Space")
|
||||
|
||||
logger.info(f"[抖音] 总共添加 {len(self.tags)} 个话题")
|
||||
|
||||
# Wait for upload to complete
|
||||
while True:
|
||||
try:
|
||||
number = await page.locator('[class^="long-card"] div:has-text("重新上传")').count()
|
||||
if number > 0:
|
||||
logger.success("[抖音] 视频上传完毕")
|
||||
break
|
||||
else:
|
||||
logger.info("[抖音] 正在上传视频中...")
|
||||
await asyncio.sleep(2)
|
||||
except:
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Set scheduled publish time if needed
|
||||
if self.publish_date != 0:
|
||||
await self.set_schedule_time(page, self.publish_date)
|
||||
|
||||
# Click publish button
|
||||
while True:
|
||||
try:
|
||||
publish_button = page.get_by_role('button', name="发布", exact=True)
|
||||
if await publish_button.count():
|
||||
await publish_button.click()
|
||||
|
||||
await page.wait_for_url(
|
||||
"https://creator.douyin.com/creator-micro/content/manage**",
|
||||
timeout=3000
|
||||
)
|
||||
logger.success("[抖音] 视频发布成功")
|
||||
break
|
||||
except:
|
||||
logger.info("[抖音] 视频正在发布中...")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Save updated cookies
|
||||
await context.storage_state(path=self.account_file)
|
||||
logger.success("[抖音] Cookie 更新完毕")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
await context.close()
|
||||
await browser.close()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "上传成功" if self.publish_date == 0 else "已设置定时发布",
|
||||
"url": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"[抖音] 上传失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"上传失败: {str(e)}",
|
||||
"url": None
|
||||
}
|
||||
|
||||
async def main(self):
|
||||
"""Execute upload"""
|
||||
async with async_playwright() as playwright:
|
||||
return await self.upload(playwright)
|
||||
"""
|
||||
Douyin (抖音) uploader using Playwright
|
||||
Based on social-auto-upload implementation
|
||||
"""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List, Dict, Any
|
||||
import asyncio
|
||||
import time
|
||||
|
||||
from playwright.async_api import Playwright, async_playwright
|
||||
from loguru import logger
|
||||
|
||||
from .base_uploader import BaseUploader
|
||||
from .cookie_utils import set_init_script
|
||||
|
||||
|
||||
class DouyinUploader(BaseUploader):
|
||||
"""Douyin video uploader using Playwright"""
|
||||
|
||||
# 超时配置 (秒)
|
||||
UPLOAD_TIMEOUT = 300 # 视频上传超时
|
||||
PUBLISH_TIMEOUT = 180 # 发布检测超时
|
||||
PAGE_REDIRECT_TIMEOUT = 60 # 页面跳转超时
|
||||
POLL_INTERVAL = 2 # 轮询间隔
|
||||
MAX_CLICK_RETRIES = 3 # 按钮点击重试次数
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str,
|
||||
file_path: str,
|
||||
tags: List[str],
|
||||
publish_date: Optional[datetime] = None,
|
||||
account_file: Optional[str] = None,
|
||||
description: str = ""
|
||||
):
|
||||
super().__init__(title, file_path, tags, publish_date, account_file, description)
|
||||
self.upload_url = "https://creator.douyin.com/creator-micro/content/upload"
|
||||
|
||||
async def _is_text_visible(self, page, text: str, exact: bool = False) -> bool:
|
||||
try:
|
||||
return await page.get_by_text(text, exact=exact).first.is_visible()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
async def _first_visible_locator(self, locator, timeout: int = 1000):
|
||||
try:
|
||||
if await locator.count() == 0:
|
||||
return None
|
||||
candidate = locator.first
|
||||
if await candidate.is_visible(timeout=timeout):
|
||||
return candidate
|
||||
except Exception:
|
||||
return None
|
||||
return None
|
||||
|
||||
async def _wait_for_publish_result(self, page, max_wait_time: int = 180):
|
||||
success_texts = ["发布成功", "作品已发布", "再发一条", "查看作品", "审核中", "待审核"]
|
||||
weak_texts = ["发布完成"]
|
||||
failure_texts = ["发布失败", "发布异常", "发布出错", "请完善", "请补充", "请先上传"]
|
||||
start_time = time.time()
|
||||
poll_interval = 2
|
||||
weak_reason = None
|
||||
|
||||
while time.time() - start_time < max_wait_time:
|
||||
if page.is_closed():
|
||||
return False, "页面已关闭", False
|
||||
|
||||
current_url = page.url
|
||||
if "content/manage" in current_url:
|
||||
return True, f"已跳转到管理页面 (URL: {current_url})", False
|
||||
|
||||
for text in success_texts:
|
||||
if await self._is_text_visible(page, text, exact=False):
|
||||
return True, f"检测到成功提示: {text}", False
|
||||
|
||||
for text in failure_texts:
|
||||
if await self._is_text_visible(page, text, exact=False):
|
||||
return False, f"检测到失败提示: {text}", False
|
||||
|
||||
for text in weak_texts:
|
||||
if await self._is_text_visible(page, text, exact=False):
|
||||
weak_reason = text
|
||||
|
||||
logger.info("[抖音] 视频正在发布中...")
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
if weak_reason:
|
||||
return False, f"检测到提示: {weak_reason}", True
|
||||
|
||||
return False, "发布检测超时", True
|
||||
|
||||
async def _fill_title(self, page, title: str) -> bool:
|
||||
title_text = title[:30]
|
||||
locator_candidates = []
|
||||
|
||||
try:
|
||||
label_locator = page.get_by_text("作品描述").locator("..").locator("..").locator(
|
||||
"xpath=following-sibling::div[1]"
|
||||
).locator("textarea, input, div[contenteditable='true']")
|
||||
locator_candidates.append(label_locator)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
locator_candidates.extend([
|
||||
page.locator("textarea[placeholder*='作品描述']"),
|
||||
page.locator("textarea[placeholder*='描述']"),
|
||||
page.locator("input[placeholder*='作品描述']"),
|
||||
page.locator("input[placeholder*='描述']"),
|
||||
page.locator("div[contenteditable='true']"),
|
||||
])
|
||||
|
||||
for locator in locator_candidates:
|
||||
try:
|
||||
if await locator.count() > 0:
|
||||
target = locator.first
|
||||
await target.fill(title_text)
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
async def _select_cover_if_needed(self, page) -> bool:
|
||||
try:
|
||||
cover_button = page.get_by_text("选择封面", exact=False).first
|
||||
if await cover_button.is_visible():
|
||||
await cover_button.click()
|
||||
logger.info("[抖音] 尝试选择封面")
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
dialog = page.locator(
|
||||
"div.dy-creator-content-modal-wrap, div[role='dialog'], "
|
||||
"div[class*='modal'], div[class*='dialog']"
|
||||
).last
|
||||
scopes = [dialog] if await dialog.count() > 0 else [page]
|
||||
|
||||
switched = False
|
||||
for scope in scopes:
|
||||
for selector in [
|
||||
"button:has-text('设置横封面')",
|
||||
"div:has-text('设置横封面')",
|
||||
"span:has-text('设置横封面')",
|
||||
]:
|
||||
try:
|
||||
button = await self._first_visible_locator(scope.locator(selector))
|
||||
if button:
|
||||
await button.click()
|
||||
logger.info("[抖音] 已切换到横封面设置")
|
||||
await asyncio.sleep(0.5)
|
||||
switched = True
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if switched:
|
||||
break
|
||||
|
||||
selected = False
|
||||
for scope in scopes:
|
||||
for selector in [
|
||||
"div[class*='cover'] img",
|
||||
"div[class*='cover']",
|
||||
"div[class*='frame'] img",
|
||||
"div[class*='frame']",
|
||||
"div[class*='preset']",
|
||||
"img",
|
||||
]:
|
||||
try:
|
||||
candidate = await self._first_visible_locator(scope.locator(selector))
|
||||
if candidate:
|
||||
await candidate.click()
|
||||
logger.info("[抖音] 已选择封面帧")
|
||||
selected = True
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
if selected:
|
||||
break
|
||||
|
||||
confirm_selectors = [
|
||||
"button:has-text('完成')",
|
||||
"button:has-text('确定')",
|
||||
"button:has-text('保存')",
|
||||
"button:has-text('确认')",
|
||||
]
|
||||
for selector in confirm_selectors:
|
||||
try:
|
||||
button = await self._first_visible_locator(page.locator(selector))
|
||||
if button:
|
||||
if not await button.is_enabled():
|
||||
for _ in range(8):
|
||||
if await button.is_enabled():
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
await button.click()
|
||||
logger.info(f"[抖音] 封面已确认: {selector}")
|
||||
await asyncio.sleep(0.5)
|
||||
if await dialog.count() > 0:
|
||||
try:
|
||||
await dialog.wait_for(state="hidden", timeout=5000)
|
||||
except Exception:
|
||||
pass
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return selected
|
||||
except Exception as e:
|
||||
logger.warning(f"[抖音] 选择封面失败: {e}")
|
||||
return False
|
||||
|
||||
async def _click_publish_confirm_modal(self, page):
|
||||
confirm_selectors = [
|
||||
"button:has-text('确认发布')",
|
||||
"button:has-text('继续发布')",
|
||||
"button:has-text('确定发布')",
|
||||
"button:has-text('发布确认')",
|
||||
]
|
||||
for selector in confirm_selectors:
|
||||
try:
|
||||
button = page.locator(selector).first
|
||||
if await button.is_visible():
|
||||
await button.click()
|
||||
logger.info(f"[抖音] 点击了发布确认按钮: {selector}")
|
||||
await asyncio.sleep(1)
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
return False
|
||||
|
||||
async def _dismiss_blocking_modal(self, page) -> bool:
|
||||
modal_locator = page.locator(
|
||||
"div.dy-creator-content-modal-wrap, div[role='dialog'], "
|
||||
"div[class*='modal'], div[class*='dialog']"
|
||||
)
|
||||
try:
|
||||
count = await modal_locator.count()
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
if count == 0:
|
||||
return False
|
||||
|
||||
button_texts = [
|
||||
"我知道了",
|
||||
"知道了",
|
||||
"确定",
|
||||
"继续",
|
||||
"继续发布",
|
||||
"确认",
|
||||
"同意并继续",
|
||||
"完成",
|
||||
"好的",
|
||||
"明白了",
|
||||
]
|
||||
close_selectors = [
|
||||
"button[class*='close']",
|
||||
"span[class*='close']",
|
||||
"i[class*='close']",
|
||||
]
|
||||
|
||||
for index in range(count):
|
||||
modal = modal_locator.nth(index)
|
||||
try:
|
||||
if not await modal.is_visible():
|
||||
continue
|
||||
|
||||
for text in button_texts:
|
||||
try:
|
||||
button = modal.get_by_role("button", name=text).first
|
||||
if await button.is_visible():
|
||||
await button.click()
|
||||
logger.info(f"[抖音] 关闭弹窗: {text}")
|
||||
await asyncio.sleep(0.5)
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
for selector in close_selectors:
|
||||
try:
|
||||
close_button = modal.locator(selector).first
|
||||
if await close_button.is_visible():
|
||||
await close_button.click()
|
||||
logger.info("[抖音] 关闭弹窗: close")
|
||||
await asyncio.sleep(0.5)
|
||||
return True
|
||||
except Exception:
|
||||
continue
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
return False
|
||||
|
||||
async def _verify_publish_in_manage(self, page):
|
||||
manage_url = "https://creator.douyin.com/creator-micro/content/manage"
|
||||
try:
|
||||
await page.goto(manage_url)
|
||||
await page.wait_for_load_state("domcontentloaded")
|
||||
await asyncio.sleep(2)
|
||||
title_text = self.title[:30]
|
||||
title_locator = page.get_by_text(title_text, exact=False).first
|
||||
if await title_locator.is_visible():
|
||||
return True, "内容管理中检测到新作品"
|
||||
if await self._is_text_visible(page, "审核中", exact=False):
|
||||
return True, "内容管理显示审核中"
|
||||
except Exception as e:
|
||||
return False, f"无法验证内容管理: {e}"
|
||||
return False, "内容管理中未找到视频"
|
||||
|
||||
async def set_schedule_time(self, page, publish_date):
|
||||
"""Set scheduled publish time"""
|
||||
try:
|
||||
# Click "定时发布" radio button
|
||||
label_element = page.locator("[class^='radio']:has-text('定时发布')")
|
||||
await label_element.click()
|
||||
await asyncio.sleep(1)
|
||||
|
||||
# Format time
|
||||
publish_date_hour = publish_date.strftime("%Y-%m-%d %H:%M")
|
||||
|
||||
# Fill datetime input
|
||||
await page.locator('.semi-input[placeholder="日期和时间"]').click()
|
||||
await page.keyboard.press("Control+KeyA")
|
||||
await page.keyboard.type(str(publish_date_hour))
|
||||
await page.keyboard.press("Enter")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
logger.info(f"[抖音] 已设置定时发布: {publish_date_hour}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"[抖音] 设置定时发布失败: {e}")
|
||||
|
||||
async def upload(self, playwright: Playwright) -> dict:
|
||||
"""Main upload logic with guaranteed resource cleanup"""
|
||||
browser = None
|
||||
context = None
|
||||
try:
|
||||
# Launch browser in headless mode for server deployment
|
||||
browser = await playwright.chromium.launch(headless=True)
|
||||
context = await browser.new_context(storage_state=self.account_file)
|
||||
context = await set_init_script(context)
|
||||
|
||||
page = await context.new_page()
|
||||
|
||||
# Go to upload page
|
||||
await page.goto(self.upload_url)
|
||||
await page.wait_for_load_state('domcontentloaded')
|
||||
await asyncio.sleep(2)
|
||||
|
||||
logger.info(f"[抖音] 正在上传: {self.file_path.name}")
|
||||
|
||||
# Check if redirected to login page (more reliable than text detection)
|
||||
current_url = page.url
|
||||
if "login" in current_url or "passport" in current_url:
|
||||
logger.error("[抖音] Cookie 已失效,被重定向到登录页")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "Cookie 已失效,请重新登录",
|
||||
"url": None
|
||||
}
|
||||
|
||||
# Ensure we're on the upload page
|
||||
if "content/upload" not in page.url:
|
||||
logger.info("[抖音] 当前不在上传页面,强制跳转...")
|
||||
await page.goto(self.upload_url)
|
||||
await asyncio.sleep(2)
|
||||
|
||||
# Try multiple selectors for the file input (page structure varies)
|
||||
file_uploaded = False
|
||||
selectors = [
|
||||
"div[class^='container'] input", # Primary selector from SuperIPAgent
|
||||
"input[type='file']", # Fallback selector
|
||||
"div[class^='upload'] input[type='file']", # Alternative
|
||||
]
|
||||
|
||||
for selector in selectors:
|
||||
try:
|
||||
logger.info(f"[抖音] 尝试选择器: {selector}")
|
||||
locator = page.locator(selector).first
|
||||
if await locator.count() > 0:
|
||||
await locator.set_input_files(str(self.file_path))
|
||||
file_uploaded = True
|
||||
logger.info(f"[抖音] 文件上传成功使用选择器: {selector}")
|
||||
break
|
||||
except Exception as e:
|
||||
logger.warning(f"[抖音] 选择器 {selector} 失败: {e}")
|
||||
continue
|
||||
|
||||
if not file_uploaded:
|
||||
logger.error("[抖音] 所有选择器都失败,无法上传文件")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法找到上传按钮,页面可能已更新",
|
||||
"url": None
|
||||
}
|
||||
|
||||
# Wait for redirect to publish page (with timeout)
|
||||
redirect_start = time.time()
|
||||
while time.time() - redirect_start < self.PAGE_REDIRECT_TIMEOUT:
|
||||
current_url = page.url
|
||||
if "content/publish" in current_url or "content/post/video" in current_url:
|
||||
logger.info("[抖音] 成功进入发布页面")
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
logger.error("[抖音] 等待发布页面超时")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "等待发布页面超时",
|
||||
"url": None
|
||||
}
|
||||
|
||||
# Fill title
|
||||
await asyncio.sleep(1)
|
||||
logger.info("[抖音] 正在填充标题和话题...")
|
||||
|
||||
if not await self._fill_title(page, self.title):
|
||||
logger.warning("[抖音] 未找到作品描述输入框")
|
||||
|
||||
# Add tags
|
||||
css_selector = ".zone-container"
|
||||
for tag in self.tags:
|
||||
await page.type(css_selector, "#" + tag)
|
||||
await page.press(css_selector, "Space")
|
||||
|
||||
logger.info(f"[抖音] 总共添加 {len(self.tags)} 个话题")
|
||||
|
||||
cover_selected = await self._select_cover_if_needed(page)
|
||||
if not cover_selected:
|
||||
logger.warning("[抖音] 未确认封面选择,可能影响发布")
|
||||
|
||||
# Wait for upload to complete (with timeout)
|
||||
upload_start = time.time()
|
||||
while time.time() - upload_start < self.UPLOAD_TIMEOUT:
|
||||
try:
|
||||
number = await page.locator('[class^="long-card"] div:has-text("重新上传")').count()
|
||||
if number > 0:
|
||||
logger.success("[抖音] 视频上传完毕")
|
||||
break
|
||||
else:
|
||||
logger.info("[抖音] 正在上传视频中...")
|
||||
await asyncio.sleep(self.POLL_INTERVAL)
|
||||
except Exception:
|
||||
await asyncio.sleep(self.POLL_INTERVAL)
|
||||
else:
|
||||
logger.error("[抖音] 视频上传超时")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "视频上传超时",
|
||||
"url": None
|
||||
}
|
||||
|
||||
# Set scheduled publish time if needed
|
||||
if self.publish_date != 0:
|
||||
await self.set_schedule_time(page, self.publish_date)
|
||||
|
||||
# Click publish button
|
||||
# 使用更稳健的点击逻辑
|
||||
try:
|
||||
publish_label = "定时发布" if self.publish_date != 0 else "发布"
|
||||
publish_button = page.get_by_role('button', name=publish_label, exact=True)
|
||||
# 等待按钮出现
|
||||
await publish_button.wait_for(state="visible", timeout=10000)
|
||||
if not await publish_button.is_enabled():
|
||||
logger.error("[抖音] 发布按钮不可点击,可能需要补充封面或确认信息")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "发布按钮不可点击,请检查封面/声明等必填项",
|
||||
"url": None
|
||||
}
|
||||
await asyncio.sleep(1) # 额外等待以确保可交互
|
||||
|
||||
clicked = False
|
||||
for attempt in range(self.MAX_CLICK_RETRIES):
|
||||
await self._dismiss_blocking_modal(page)
|
||||
try:
|
||||
await publish_button.click(timeout=5000)
|
||||
logger.info(f"[抖音] 点击了{publish_label}按钮")
|
||||
clicked = True
|
||||
break
|
||||
except Exception as click_error:
|
||||
logger.warning(f"[抖音] 点击发布按钮失败,重试 {attempt + 1}/{self.MAX_CLICK_RETRIES}: {click_error}")
|
||||
try:
|
||||
await page.keyboard.press("Escape")
|
||||
except Exception:
|
||||
pass
|
||||
await asyncio.sleep(1)
|
||||
|
||||
if not clicked:
|
||||
raise RuntimeError("点击发布按钮失败")
|
||||
except Exception as e:
|
||||
logger.error(f"[抖音] 点击发布按钮失败: {e}")
|
||||
# 尝试备用选择器
|
||||
try:
|
||||
fallback_selectors = ["button:has-text('发布')", "button:has-text('定时发布')"]
|
||||
clicked = False
|
||||
for selector in fallback_selectors:
|
||||
try:
|
||||
await page.click(selector, timeout=5000)
|
||||
logger.info(f"[抖音] 使用备用选择器点击了按钮: {selector}")
|
||||
clicked = True
|
||||
break
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
if not clicked:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法点击发布按钮,请检查页面状态",
|
||||
"url": None
|
||||
}
|
||||
except Exception:
|
||||
return {
|
||||
"success": False,
|
||||
"message": "无法点击发布按钮,请检查页面状态",
|
||||
"url": None
|
||||
}
|
||||
|
||||
await self._click_publish_confirm_modal(page)
|
||||
|
||||
# 4. 检测发布完成
|
||||
publish_success, publish_reason, is_timeout = await self._wait_for_publish_result(page)
|
||||
if not publish_success and is_timeout:
|
||||
verify_success, verify_reason = await self._verify_publish_in_manage(page)
|
||||
if verify_success:
|
||||
publish_success = True
|
||||
publish_reason = verify_reason
|
||||
else:
|
||||
publish_reason = f"{publish_reason}; {verify_reason}"
|
||||
if publish_success:
|
||||
logger.success(f"[抖音] 发布成功: {publish_reason}")
|
||||
else:
|
||||
if is_timeout:
|
||||
logger.warning("[抖音] 发布检测超时,但这不一定代表失败")
|
||||
else:
|
||||
logger.warning(f"[抖音] 发布未成功: {publish_reason}")
|
||||
|
||||
# Save updated cookies
|
||||
await context.storage_state(path=self.account_file)
|
||||
logger.success("[抖音] Cookie 更新完毕")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
|
||||
if publish_success:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "发布成功,待审核",
|
||||
"url": None
|
||||
}
|
||||
if is_timeout:
|
||||
return {
|
||||
"success": True,
|
||||
"message": "发布检测超时,请到抖音后台确认",
|
||||
"url": None
|
||||
}
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"发布失败: {publish_reason}",
|
||||
"url": None
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"[抖音] 上传失败: {e}")
|
||||
return {
|
||||
"success": False,
|
||||
"message": f"上传失败: {str(e)}",
|
||||
"url": None
|
||||
}
|
||||
finally:
|
||||
# 确保资源释放
|
||||
if context:
|
||||
try:
|
||||
await context.close()
|
||||
except Exception:
|
||||
pass
|
||||
if browser:
|
||||
try:
|
||||
await browser.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def main(self) -> Dict[str, Any]:
|
||||
"""Execute upload"""
|
||||
async with async_playwright() as playwright:
|
||||
return await self.upload(playwright)
|
||||
|
||||
@@ -4,7 +4,7 @@ Based on social-auto-upload implementation
|
||||
"""
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from typing import Optional, List
|
||||
from typing import Optional, List, Dict, Any
|
||||
import asyncio
|
||||
|
||||
from playwright.async_api import Playwright, async_playwright
|
||||
@@ -17,6 +17,11 @@ from .cookie_utils import set_init_script
|
||||
class XiaohongshuUploader(BaseUploader):
|
||||
"""Xiaohongshu video uploader using Playwright"""
|
||||
|
||||
# 超时配置 (秒)
|
||||
UPLOAD_TIMEOUT = 300 # 视频上传超时
|
||||
PUBLISH_TIMEOUT = 120 # 发布检测超时
|
||||
POLL_INTERVAL = 1 # 轮询间隔
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
title: str,
|
||||
@@ -54,11 +59,13 @@ class XiaohongshuUploader(BaseUploader):
|
||||
except Exception as e:
|
||||
logger.error(f"[小红书] 设置定时发布失败: {e}")
|
||||
|
||||
async def upload(self, playwright: Playwright):
|
||||
"""Main upload logic"""
|
||||
async def upload(self, playwright: Playwright) -> dict:
|
||||
"""Main upload logic with guaranteed resource cleanup"""
|
||||
browser = None
|
||||
context = None
|
||||
try:
|
||||
# Launch browser
|
||||
browser = await playwright.chromium.launch(headless=False)
|
||||
# Launch browser (headless for server deployment)
|
||||
browser = await playwright.chromium.launch(headless=True)
|
||||
context = await browser.new_context(
|
||||
viewport={"width": 1600, "height": 900},
|
||||
storage_state=self.account_file
|
||||
@@ -74,8 +81,10 @@ class XiaohongshuUploader(BaseUploader):
|
||||
# Upload video file
|
||||
await page.locator("div[class^='upload-content'] input[class='upload-input']").set_input_files(str(self.file_path))
|
||||
|
||||
# Wait for upload to complete
|
||||
while True:
|
||||
# Wait for upload to complete (with timeout)
|
||||
import time
|
||||
upload_start = time.time()
|
||||
while time.time() - upload_start < self.UPLOAD_TIMEOUT:
|
||||
try:
|
||||
upload_input = await page.wait_for_selector('input.upload-input', timeout=3000)
|
||||
preview_new = await upload_input.query_selector(
|
||||
@@ -100,11 +109,18 @@ class XiaohongshuUploader(BaseUploader):
|
||||
else:
|
||||
logger.info("[小红书] 未找到预览元素,继续等待...")
|
||||
|
||||
await asyncio.sleep(1)
|
||||
await asyncio.sleep(self.POLL_INTERVAL)
|
||||
|
||||
except Exception as e:
|
||||
logger.info(f"[小红书] 检测过程: {str(e)},重新尝试...")
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
logger.error("[小红书] 视频上传超时")
|
||||
return {
|
||||
"success": False,
|
||||
"message": "视频上传超时",
|
||||
"url": None
|
||||
}
|
||||
|
||||
# Fill title and tags
|
||||
await asyncio.sleep(1)
|
||||
@@ -126,8 +142,9 @@ class XiaohongshuUploader(BaseUploader):
|
||||
if self.publish_date != 0:
|
||||
await self.set_schedule_time(page, self.publish_date)
|
||||
|
||||
# Click publish button
|
||||
while True:
|
||||
# Click publish button (with timeout)
|
||||
publish_start = time.time()
|
||||
while time.time() - publish_start < self.PUBLISH_TIMEOUT:
|
||||
try:
|
||||
if self.publish_date != 0:
|
||||
await page.locator('button:has-text("定时发布")').click()
|
||||
@@ -140,21 +157,21 @@ class XiaohongshuUploader(BaseUploader):
|
||||
)
|
||||
logger.success("[小红书] 视频发布成功")
|
||||
break
|
||||
except:
|
||||
except Exception:
|
||||
logger.info("[小红书] 视频正在发布中...")
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
logger.warning("[小红书] 发布检测超时,请手动确认")
|
||||
|
||||
# Save updated cookies
|
||||
await context.storage_state(path=self.account_file)
|
||||
logger.success("[小红书] Cookie 更新完毕")
|
||||
|
||||
await asyncio.sleep(2)
|
||||
await context.close()
|
||||
await browser.close()
|
||||
|
||||
return {
|
||||
"success": True,
|
||||
"message": "上传成功" if self.publish_date == 0 else "已设置定时发布",
|
||||
"message": "发布成功,待审核" if self.publish_date == 0 else "已设置定时发布",
|
||||
"url": None
|
||||
}
|
||||
|
||||
@@ -165,8 +182,20 @@ class XiaohongshuUploader(BaseUploader):
|
||||
"message": f"上传失败: {str(e)}",
|
||||
"url": None
|
||||
}
|
||||
finally:
|
||||
# 确保资源释放
|
||||
if context:
|
||||
try:
|
||||
await context.close()
|
||||
except Exception:
|
||||
pass
|
||||
if browser:
|
||||
try:
|
||||
await browser.close()
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
async def main(self):
|
||||
async def main(self) -> Dict[str, Any]:
|
||||
"""Execute upload"""
|
||||
async with async_playwright() as playwright:
|
||||
return await self.upload(playwright)
|
||||
|
||||
73
backend/database/schema.sql
Normal file
73
backend/database/schema.sql
Normal file
@@ -0,0 +1,73 @@
|
||||
-- ViGent 用户认证系统数据库表
|
||||
-- 在 Supabase SQL Editor 中执行
|
||||
|
||||
-- 1. 创建 users 表
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
email TEXT UNIQUE NOT NULL,
|
||||
password_hash TEXT NOT NULL,
|
||||
username TEXT,
|
||||
role TEXT DEFAULT 'pending' CHECK (role IN ('pending', 'user', 'admin')),
|
||||
is_active BOOLEAN DEFAULT FALSE,
|
||||
expires_at TIMESTAMP WITH TIME ZONE,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
||||
updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- 2. 创建 user_sessions 表 (单设备登录)
|
||||
CREATE TABLE IF NOT EXISTS user_sessions (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID REFERENCES users(id) ON DELETE CASCADE UNIQUE,
|
||||
session_token TEXT UNIQUE NOT NULL,
|
||||
device_info TEXT,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW()
|
||||
);
|
||||
|
||||
-- 3. 创建 social_accounts 表 (社交账号绑定)
|
||||
CREATE TABLE IF NOT EXISTS social_accounts (
|
||||
id UUID PRIMARY KEY DEFAULT gen_random_uuid(),
|
||||
user_id UUID REFERENCES users(id) ON DELETE CASCADE,
|
||||
platform TEXT NOT NULL CHECK (platform IN ('bilibili', 'douyin', 'xiaohongshu')),
|
||||
logged_in BOOLEAN DEFAULT FALSE,
|
||||
created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(),
|
||||
UNIQUE(user_id, platform)
|
||||
);
|
||||
|
||||
-- 4. 创建索引
|
||||
CREATE INDEX IF NOT EXISTS idx_users_email ON users(email);
|
||||
CREATE INDEX IF NOT EXISTS idx_sessions_user_id ON user_sessions(user_id);
|
||||
CREATE INDEX IF NOT EXISTS idx_social_user_platform ON social_accounts(user_id, platform);
|
||||
|
||||
-- 5. 启用 RLS (行级安全)
|
||||
ALTER TABLE users ENABLE ROW LEVEL SECURITY;
|
||||
ALTER TABLE user_sessions ENABLE ROW LEVEL SECURITY;
|
||||
ALTER TABLE social_accounts ENABLE ROW LEVEL SECURITY;
|
||||
|
||||
-- 6. RLS 策略 (Service Role 可以绑过 RLS,所以后端使用 service_role key 时不受限)
|
||||
-- 以下策略仅对 anon key 生效
|
||||
|
||||
-- users: 仅管理员可查看所有用户,普通用户只能查看自己
|
||||
CREATE POLICY "Users can view own profile" ON users
|
||||
FOR SELECT USING (auth.uid()::text = id::text);
|
||||
|
||||
-- user_sessions: 用户只能访问自己的 session
|
||||
CREATE POLICY "Users can access own sessions" ON user_sessions
|
||||
FOR ALL USING (user_id::text = auth.uid()::text);
|
||||
|
||||
-- social_accounts: 用户只能访问自己的社交账号
|
||||
CREATE POLICY "Users can access own social accounts" ON social_accounts
|
||||
FOR ALL USING (user_id::text = auth.uid()::text);
|
||||
|
||||
-- 7. 更新时间自动更新触发器
|
||||
CREATE OR REPLACE FUNCTION update_updated_at()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = NOW();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER users_updated_at
|
||||
BEFORE UPDATE ON users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at();
|
||||
93
backend/generate_keys.py
Normal file
93
backend/generate_keys.py
Normal file
@@ -0,0 +1,93 @@
|
||||
import hmac
|
||||
import hashlib
|
||||
import base64
|
||||
import json
|
||||
import time
|
||||
import secrets
|
||||
import string
|
||||
|
||||
def generate_secure_secret(length=64):
|
||||
"""生成安全的随机十六进制字符串"""
|
||||
return secrets.token_hex(length // 2)
|
||||
|
||||
def generate_random_string(length=32):
|
||||
"""生成包含字母数字的随机字符串 (用于密码等)"""
|
||||
chars = string.ascii_letters + string.digits
|
||||
return ''.join(secrets.choice(chars) for _ in range(length))
|
||||
|
||||
def base64url_encode(input_bytes):
|
||||
return base64.urlsafe_b64encode(input_bytes).decode('utf-8').rstrip('=')
|
||||
|
||||
def generate_jwt(role, secret):
|
||||
# 1. Header
|
||||
header = {
|
||||
"alg": "HS256",
|
||||
"typ": "JWT"
|
||||
}
|
||||
|
||||
# 2. Payload
|
||||
now = int(time.time())
|
||||
payload = {
|
||||
"role": role,
|
||||
"iss": "supabase",
|
||||
"iat": now,
|
||||
"exp": now + 315360000 # 10年有效期
|
||||
}
|
||||
|
||||
# Encode parts
|
||||
header_b64 = base64url_encode(json.dumps(header).encode('utf-8'))
|
||||
payload_b64 = base64url_encode(json.dumps(payload).encode('utf-8'))
|
||||
|
||||
# 3. Signature
|
||||
signing_input = f"{header_b64}.{payload_b64}".encode('utf-8')
|
||||
signature = hmac.new(
|
||||
secret.encode('utf-8'),
|
||||
signing_input,
|
||||
hashlib.sha256
|
||||
).digest()
|
||||
signature_b64 = base64url_encode(signature)
|
||||
|
||||
return f"{header_b64}.{payload_b64}.{signature_b64}"
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("=" * 60)
|
||||
print("🔐 Supabase 全自动配置生成器 (Zero Dependency)")
|
||||
print("=" * 60)
|
||||
print("正在生成所有密钥...\n")
|
||||
|
||||
# 1. 自动生成主密钥
|
||||
jwt_secret = generate_secure_secret(64)
|
||||
|
||||
# 2. 基于主密钥生成 JWT
|
||||
anon_key = generate_jwt("anon", jwt_secret)
|
||||
service_key = generate_jwt("service_role", jwt_secret)
|
||||
|
||||
# 3. 生成其他加密 Key和密码
|
||||
vault_key = generate_secure_secret(32)
|
||||
meta_key = generate_secure_secret(32)
|
||||
secret_key_base = generate_secure_secret(64)
|
||||
|
||||
db_password = generate_random_string(20)
|
||||
dashboard_password = generate_random_string(16)
|
||||
|
||||
# 4. 输出结果
|
||||
print(f"✅ 生成完成!请直接复制以下内容覆盖您的 .env 文件中的对应部分:\n")
|
||||
|
||||
print("-" * 20 + " [ 复制开始 ] " + "-" * 20)
|
||||
print(f"# === 数据库安全配置 ===")
|
||||
print(f"POSTGRES_PASSWORD={db_password}")
|
||||
print(f"JWT_SECRET={jwt_secret}")
|
||||
print(f"ANON_KEY={anon_key}")
|
||||
print(f"SERVICE_ROLE_KEY={service_key}")
|
||||
print(f"SECRET_KEY_BASE={secret_key_base}")
|
||||
print(f"VAULT_ENC_KEY={vault_key}")
|
||||
print(f"PG_META_CRYPTO_KEY={meta_key}")
|
||||
print(f"\n# === 管理后台配置 ===")
|
||||
print(f"DASHBOARD_USERNAME=admin")
|
||||
print(f"DASHBOARD_PASSWORD={dashboard_password}")
|
||||
print("-" * 20 + " [ 复制结束 ] " + "-" * 20)
|
||||
|
||||
print("\n💡 提示:")
|
||||
print(f"1. 数据库密码: {db_password}")
|
||||
print(f"2. 后台登录密码: {dashboard_password}")
|
||||
print("请妥善保管这些密码!")
|
||||
@@ -21,3 +21,10 @@ requests>=2.31.0
|
||||
|
||||
# 社交媒体发布
|
||||
biliup>=0.4.0
|
||||
|
||||
# 用户认证
|
||||
email-validator>=2.1.0
|
||||
supabase>=2.0.0
|
||||
python-jose[cryptography]>=3.3.0
|
||||
passlib[bcrypt]>=1.7.4
|
||||
bcrypt==4.0.1
|
||||
|
||||
@@ -1,36 +1,72 @@
|
||||
This is a [Next.js](https://nextjs.org) project bootstrapped with [`create-next-app`](https://nextjs.org/docs/app/api-reference/cli/create-next-app).
|
||||
# ViGent2 Frontend
|
||||
|
||||
## Getting Started
|
||||
ViGent2 的前端界面,采用 Next.js 14 + TailwindCSS 构建。
|
||||
|
||||
First, run the development server:
|
||||
## ✨ 核心功能
|
||||
|
||||
### 1. 视频生成 (`/`)
|
||||
- **素材管理**: 拖拽上传人物视频,实时预览。
|
||||
- **文案配音**: 集成 EdgeTTS,支持多音色选择 (云溪 / 晓晓)。
|
||||
- **进度追踪**: 实时显示视频生成进度 (10% -> 100%)。
|
||||
- **结果预览**: 生成完成后直接播放下载。
|
||||
|
||||
### 2. 全自动发布 (`/publish`) [Day 7 新增]
|
||||
- **多平台管理**: 统一管理 B站、抖音、小红书账号状态。
|
||||
- **扫码登录**:
|
||||
- 集成后端 Playwright 生成的 QR Code。
|
||||
- 实时检测扫码状态 (Wait/Success)。
|
||||
- Cookie 自动保存与状态同步。
|
||||
- **发布配置**: 设置视频标题、标签、简介。
|
||||
- **定时任务**: 支持 "立即发布" 或 "定时发布"。
|
||||
|
||||
## 🛠️ 技术栈
|
||||
|
||||
- **框架**: Next.js 14 (App Router)
|
||||
- **样式**: TailwindCSS
|
||||
- **图标**: Lucide React
|
||||
- **组件**: 自定义现代化组件 (Glassmorphism 风格)
|
||||
- **API**: Fetch API (对接后端 FastAPI :8006)
|
||||
|
||||
## 🚀 开发指南
|
||||
|
||||
### 安装依赖
|
||||
|
||||
```bash
|
||||
npm install
|
||||
```
|
||||
|
||||
### 启动开发服务器
|
||||
|
||||
默认运行在 **3002** 端口 (通过 `package.json` 配置):
|
||||
|
||||
```bash
|
||||
npm run dev
|
||||
# or
|
||||
yarn dev
|
||||
# or
|
||||
pnpm dev
|
||||
# or
|
||||
bun dev
|
||||
# 访问: http://localhost:3002
|
||||
```
|
||||
|
||||
Open [http://localhost:3000](http://localhost:3000) with your browser to see the result.
|
||||
### 目录结构
|
||||
|
||||
You can start editing the page by modifying `app/page.tsx`. The page auto-updates as you edit the file.
|
||||
```
|
||||
src/
|
||||
├── app/
|
||||
│ ├── page.tsx # 视频生成主页
|
||||
│ ├── publish/ # 发布管理页
|
||||
│ │ └── page.tsx
|
||||
│ └── layout.tsx # 全局布局 (导航栏)
|
||||
├── components/ # UI 组件
|
||||
│ ├── VideoUploader.tsx # 视频上传
|
||||
│ ├── StatusBadge.tsx # 状态徽章
|
||||
│ └── ...
|
||||
└── lib/ # 工具函数
|
||||
```
|
||||
|
||||
This project uses [`next/font`](https://nextjs.org/docs/app/building-your-application/optimizing/fonts) to automatically optimize and load [Geist](https://vercel.com/font), a new font family for Vercel.
|
||||
## 🔌 后端对接
|
||||
|
||||
## Learn More
|
||||
- **Base URL**: `http://localhost:8006`
|
||||
- **代理配置**: Next.js Rewrites (如需) 或直接 CORS。
|
||||
|
||||
To learn more about Next.js, take a look at the following resources:
|
||||
## 🎨 设计规范
|
||||
|
||||
- [Next.js Documentation](https://nextjs.org/docs) - learn about Next.js features and API.
|
||||
- [Learn Next.js](https://nextjs.org/learn) - an interactive Next.js tutorial.
|
||||
|
||||
You can check out [the Next.js GitHub repository](https://github.com/vercel/next.js) - your feedback and contributions are welcome!
|
||||
|
||||
## Deploy on Vercel
|
||||
|
||||
The easiest way to deploy your Next.js app is to use the [Vercel Platform](https://vercel.com/new?utm_medium=default-template&filter=next.js&utm_source=create-next-app&utm_campaign=create-next-app-readme) from the creators of Next.js.
|
||||
|
||||
Check out our [Next.js deployment documentation](https://nextjs.org/docs/app/building-your-application/deploying) for more details.
|
||||
- **主色调**: 深紫/黑色系 (Dark Mode)
|
||||
- **交互**: 悬停微动画 (Hover Effects)
|
||||
- **响应式**: 适配桌面端大屏操作
|
||||
|
||||
@@ -8,6 +8,14 @@ const nextConfig: NextConfig = {
|
||||
source: '/api/:path*',
|
||||
destination: 'http://localhost:8006/api/:path*', // 服务器本地代理
|
||||
},
|
||||
{
|
||||
source: '/uploads/:path*',
|
||||
destination: 'http://localhost:8006/uploads/:path*', // 转发上传的素材
|
||||
},
|
||||
{
|
||||
source: '/outputs/:path*',
|
||||
destination: 'http://localhost:8006/outputs/:path*', // 转发生成的视频
|
||||
},
|
||||
];
|
||||
},
|
||||
};
|
||||
|
||||
280
frontend/package-lock.json
generated
280
frontend/package-lock.json
generated
@@ -8,9 +8,12 @@
|
||||
"name": "frontend",
|
||||
"version": "0.1.0",
|
||||
"dependencies": {
|
||||
"@supabase/supabase-js": "^2.93.1",
|
||||
"axios": "^1.13.4",
|
||||
"next": "16.1.1",
|
||||
"react": "19.2.3",
|
||||
"react-dom": "19.2.3"
|
||||
"react-dom": "19.2.3",
|
||||
"swr": "^2.3.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tailwindcss/postcss": "^4",
|
||||
@@ -67,7 +70,6 @@
|
||||
"integrity": "sha512-H3mcG6ZDLTlYfaSNi0iOKkigqMFvkTKlGUYlD8GW7nNOYRrevuA46iTypPyv+06V3fEmvvazfntkBU34L0azAw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@babel/code-frame": "^7.28.6",
|
||||
"@babel/generator": "^7.28.6",
|
||||
@@ -1234,6 +1236,80 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@supabase/auth-js": {
|
||||
"version": "2.93.1",
|
||||
"resolved": "https://registry.npmjs.org/@supabase/auth-js/-/auth-js-2.93.1.tgz",
|
||||
"integrity": "sha512-pC0Ek4xk4z6q7A/3+UuZ/eYgfFUUQTg3DhapzrAgJnFGDJDFDyGCj6v9nIz8+3jfLqSZ3QKGe6AoEodYjShghg==",
|
||||
"dependencies": {
|
||||
"tslib": "2.8.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@supabase/functions-js": {
|
||||
"version": "2.93.1",
|
||||
"resolved": "https://registry.npmjs.org/@supabase/functions-js/-/functions-js-2.93.1.tgz",
|
||||
"integrity": "sha512-Ott2IcIXHGupaC0nX9WNEiJAX4OdlGRu9upkkURaQHbaLdz9JuCcHxlwTERgtgjMpikbIWHfMM1M9QTQFYABiA==",
|
||||
"dependencies": {
|
||||
"tslib": "2.8.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@supabase/postgrest-js": {
|
||||
"version": "2.93.1",
|
||||
"resolved": "https://registry.npmjs.org/@supabase/postgrest-js/-/postgrest-js-2.93.1.tgz",
|
||||
"integrity": "sha512-uRKKQJBDnfi6XFNFPNMh9+u3HT2PCgp065PcMPmG7e0xGuqvLtN89QxO2/SZcGbw2y1+mNBz0yUs5KmyNqF2fA==",
|
||||
"dependencies": {
|
||||
"tslib": "2.8.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@supabase/realtime-js": {
|
||||
"version": "2.93.1",
|
||||
"resolved": "https://registry.npmjs.org/@supabase/realtime-js/-/realtime-js-2.93.1.tgz",
|
||||
"integrity": "sha512-2WaP/KVHPlQDjWM6qe4wOZz6zSRGaXw1lfXf4thbfvk3C3zPPKqXRyspyYnk3IhphyxSsJ2hQ/cXNOz48008tg==",
|
||||
"dependencies": {
|
||||
"@types/phoenix": "^1.6.6",
|
||||
"@types/ws": "^8.18.1",
|
||||
"tslib": "2.8.1",
|
||||
"ws": "^8.18.2"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@supabase/storage-js": {
|
||||
"version": "2.93.1",
|
||||
"resolved": "https://registry.npmjs.org/@supabase/storage-js/-/storage-js-2.93.1.tgz",
|
||||
"integrity": "sha512-3KVwd4S1i1BVPL6KIywe5rnruNQXSkLyvrdiJmwnqwbCcDujQumARdGWBPesqCjOPKEU2M9ORWKAsn+2iLzquA==",
|
||||
"dependencies": {
|
||||
"iceberg-js": "^0.8.1",
|
||||
"tslib": "2.8.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@supabase/supabase-js": {
|
||||
"version": "2.93.1",
|
||||
"resolved": "https://registry.npmjs.org/@supabase/supabase-js/-/supabase-js-2.93.1.tgz",
|
||||
"integrity": "sha512-FJTgS5s0xEgRQ3u7gMuzGObwf3jA4O5Ki/DgCDXx94w1pihLM4/WG3XFa4BaCJYfuzLxLcv6zPPA5tDvBUjAUg==",
|
||||
"dependencies": {
|
||||
"@supabase/auth-js": "2.93.1",
|
||||
"@supabase/functions-js": "2.93.1",
|
||||
"@supabase/postgrest-js": "2.93.1",
|
||||
"@supabase/realtime-js": "2.93.1",
|
||||
"@supabase/storage-js": "2.93.1"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@swc/helpers": {
|
||||
"version": "0.5.15",
|
||||
"resolved": "https://registry.npmjs.org/@swc/helpers/-/helpers-0.5.15.tgz",
|
||||
@@ -1550,19 +1626,22 @@
|
||||
"version": "20.19.28",
|
||||
"resolved": "https://registry.npmjs.org/@types/node/-/node-20.19.28.tgz",
|
||||
"integrity": "sha512-VyKBr25BuFDzBFCK5sUM6ZXiWfqgCTwTAOK8qzGV/m9FCirXYDlmczJ+d5dXBAQALGCdRRdbteKYfJ84NGEusw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"undici-types": "~6.21.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/phoenix": {
|
||||
"version": "1.6.7",
|
||||
"resolved": "https://registry.npmjs.org/@types/phoenix/-/phoenix-1.6.7.tgz",
|
||||
"integrity": "sha512-oN9ive//QSBkf19rfDv45M7eZPi0eEXylht2OLEXicu5b4KoQ1OzXIw+xDSGWxSxe1JmepRR/ZH283vsu518/Q=="
|
||||
},
|
||||
"node_modules/@types/react": {
|
||||
"version": "19.2.8",
|
||||
"resolved": "https://registry.npmjs.org/@types/react/-/react-19.2.8.tgz",
|
||||
"integrity": "sha512-3MbSL37jEchWZz2p2mjntRZtPt837ij10ApxKfgmXCTuHWagYg7iA5bqPw6C8BMPfwidlvfPI/fxOc42HLhcyg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"csstype": "^3.2.2"
|
||||
}
|
||||
@@ -1577,6 +1656,14 @@
|
||||
"@types/react": "^19.2.0"
|
||||
}
|
||||
},
|
||||
"node_modules/@types/ws": {
|
||||
"version": "8.18.1",
|
||||
"resolved": "https://registry.npmjs.org/@types/ws/-/ws-8.18.1.tgz",
|
||||
"integrity": "sha512-ThVF6DCVhA8kUGy+aazFQ4kXQ7E1Ty7A3ypFOe0IcJV8O/M511G99AW24irKrW56Wt44yG9+ij8FaqoBGkuBXg==",
|
||||
"dependencies": {
|
||||
"@types/node": "*"
|
||||
}
|
||||
},
|
||||
"node_modules/@typescript-eslint/eslint-plugin": {
|
||||
"version": "8.53.0",
|
||||
"resolved": "https://registry.npmjs.org/@typescript-eslint/eslint-plugin/-/eslint-plugin-8.53.0.tgz",
|
||||
@@ -1622,7 +1709,6 @@
|
||||
"integrity": "sha512-npiaib8XzbjtzS2N4HlqPvlpxpmZ14FjSJrteZpPxGUaYPlvhzlzUZ4mZyABo0EFrOWnvyd0Xxroq//hKhtAWg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@typescript-eslint/scope-manager": "8.53.0",
|
||||
"@typescript-eslint/types": "8.53.0",
|
||||
@@ -2122,7 +2208,6 @@
|
||||
"integrity": "sha512-NZyJarBfL7nWwIq+FDL6Zp/yHEhePMNnnJ0y3qfieCrmNvYct8uvtiV41UvlSe6apAfk0fY1FbWx+NwfmpvtTg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"acorn": "bin/acorn"
|
||||
},
|
||||
@@ -2367,6 +2452,12 @@
|
||||
"node": ">= 0.4"
|
||||
}
|
||||
},
|
||||
"node_modules/asynckit": {
|
||||
"version": "0.4.0",
|
||||
"resolved": "https://registry.npmjs.org/asynckit/-/asynckit-0.4.0.tgz",
|
||||
"integrity": "sha512-Oei9OH4tRh0YqU3GxhX79dM/mwVgvbZJaSNaRk+bshkj0S5cfHcgYakreBjrHwatXKbz+IoIdYLxrKim2MjW0Q==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/available-typed-arrays": {
|
||||
"version": "1.0.7",
|
||||
"resolved": "https://registry.npmjs.org/available-typed-arrays/-/available-typed-arrays-1.0.7.tgz",
|
||||
@@ -2393,6 +2484,17 @@
|
||||
"node": ">=4"
|
||||
}
|
||||
},
|
||||
"node_modules/axios": {
|
||||
"version": "1.13.4",
|
||||
"resolved": "https://registry.npmjs.org/axios/-/axios-1.13.4.tgz",
|
||||
"integrity": "sha512-1wVkUaAO6WyaYtCkcYCOx12ZgpGf9Zif+qXa4n+oYzK558YryKqiL6UWwd5DqiH3VRW0GYhTZQ/vlgJrCoNQlg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"follow-redirects": "^1.15.6",
|
||||
"form-data": "^4.0.4",
|
||||
"proxy-from-env": "^1.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/axobject-query": {
|
||||
"version": "4.1.0",
|
||||
"resolved": "https://registry.npmjs.org/axobject-query/-/axobject-query-4.1.0.tgz",
|
||||
@@ -2463,7 +2565,6 @@
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"baseline-browser-mapping": "^2.9.0",
|
||||
"caniuse-lite": "^1.0.30001759",
|
||||
@@ -2501,7 +2602,6 @@
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/call-bind-apply-helpers/-/call-bind-apply-helpers-1.0.2.tgz",
|
||||
"integrity": "sha512-Sp1ablJ0ivDkSzjcaJdxEunN5/XvksFJ2sMBFfq6x0ryhQV/2b/KwFe21cMpmHtPOSij8K99/wSfoEuTObmuMQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"es-errors": "^1.3.0",
|
||||
@@ -2601,6 +2701,18 @@
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/combined-stream": {
|
||||
"version": "1.0.8",
|
||||
"resolved": "https://registry.npmjs.org/combined-stream/-/combined-stream-1.0.8.tgz",
|
||||
"integrity": "sha512-FQN4MRfuJeHf7cBbBMJFXhKSDq+2kAArBlmRBvcvFE5BB1HZKXtSFASDhdlz9zOYwxh8lDdnvmMOe/+5cdoEdg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"delayed-stream": "~1.0.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.8"
|
||||
}
|
||||
},
|
||||
"node_modules/concat-map": {
|
||||
"version": "0.0.1",
|
||||
"resolved": "https://registry.npmjs.org/concat-map/-/concat-map-0.0.1.tgz",
|
||||
@@ -2759,6 +2871,24 @@
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/delayed-stream": {
|
||||
"version": "1.0.0",
|
||||
"resolved": "https://registry.npmjs.org/delayed-stream/-/delayed-stream-1.0.0.tgz",
|
||||
"integrity": "sha512-ZySD7Nf91aLB0RxL4KGrKHBXl7Eds1DAmEdcoVawXnLD7SDhpNgtuII2aAkg7a7QS41jxPSZ17p4VdGnMHk3MQ==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=0.4.0"
|
||||
}
|
||||
},
|
||||
"node_modules/dequal": {
|
||||
"version": "2.0.3",
|
||||
"resolved": "https://registry.npmjs.org/dequal/-/dequal-2.0.3.tgz",
|
||||
"integrity": "sha512-0je+qPKHEMohvfRTCEo3CrPG6cAzAYgmzKyxRiYSSDkS6eGJdyVJm7WaYA5ECaAD9wLB2T4EEeymA5aFVcYXCA==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=6"
|
||||
}
|
||||
},
|
||||
"node_modules/detect-libc": {
|
||||
"version": "2.1.2",
|
||||
"resolved": "https://registry.npmjs.org/detect-libc/-/detect-libc-2.1.2.tgz",
|
||||
@@ -2786,7 +2916,6 @@
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/dunder-proto/-/dunder-proto-1.0.1.tgz",
|
||||
"integrity": "sha512-KIN/nDJBQRcXw0MLVhZE9iQHmG68qAVIBg9CqmUYjmQIhgij9U5MFvrqkUL5FbtyyzZuOeOt0zdeRe4UY7ct+A==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"call-bind-apply-helpers": "^1.0.1",
|
||||
@@ -2898,7 +3027,6 @@
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/es-define-property/-/es-define-property-1.0.1.tgz",
|
||||
"integrity": "sha512-e3nRfgfUZ4rNGL232gUgX06QNyyez04KdjFrF+LTRoOXmrOgFKDg4BCdsjW8EnT69eqdYGmRpJwiPVYNrCaW3g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
@@ -2908,7 +3036,6 @@
|
||||
"version": "1.3.0",
|
||||
"resolved": "https://registry.npmjs.org/es-errors/-/es-errors-1.3.0.tgz",
|
||||
"integrity": "sha512-Zf5H2Kxt2xjTvbJvP2ZWLEICxA6j+hAmMzIlypy4xcBg1vKVnx89Wy0GbS+kf5cwCVFFzdCFh2XSCFNULS6csw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
@@ -2946,7 +3073,6 @@
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/es-object-atoms/-/es-object-atoms-1.1.1.tgz",
|
||||
"integrity": "sha512-FGgH2h8zKNim9ljj7dankFPcICIK9Cp5bm+c2gQSYePhpaG5+esrLODihIorn+Pe6FGJzWhXQotPv73jTaldXA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"es-errors": "^1.3.0"
|
||||
@@ -2959,7 +3085,6 @@
|
||||
"version": "2.1.0",
|
||||
"resolved": "https://registry.npmjs.org/es-set-tostringtag/-/es-set-tostringtag-2.1.0.tgz",
|
||||
"integrity": "sha512-j6vWzfrGVfyXxge+O0x5sh6cvxAog0a/4Rdd2K36zCMV5eJ+/+tOAngRO8cODMNWbVRdVlmGZQL2YS3yR8bIUA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"es-errors": "^1.3.0",
|
||||
@@ -3031,7 +3156,6 @@
|
||||
"integrity": "sha512-LEyamqS7W5HB3ujJyvi0HQK/dtVINZvd5mAAp9eT5S/ujByGjiZLCzPcHVzuXbpJDJF/cxwHlfceVUDZ2lnSTw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@eslint-community/eslint-utils": "^4.8.0",
|
||||
"@eslint-community/regexpp": "^4.12.1",
|
||||
@@ -3217,7 +3341,6 @@
|
||||
"integrity": "sha512-whOE1HFo/qJDyX4SnXzP4N6zOWn79WhnCUY/iDR0mPfQZO8wcYE4JClzI2oZrhBnnMUCBCHZhO6VQyoBU95mZA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"@rtsao/scc": "^1.1.0",
|
||||
"array-includes": "^3.1.9",
|
||||
@@ -3576,6 +3699,26 @@
|
||||
"dev": true,
|
||||
"license": "ISC"
|
||||
},
|
||||
"node_modules/follow-redirects": {
|
||||
"version": "1.15.11",
|
||||
"resolved": "https://registry.npmjs.org/follow-redirects/-/follow-redirects-1.15.11.tgz",
|
||||
"integrity": "sha512-deG2P0JfjrTxl50XGCDyfI97ZGVCxIpfKYmfyrQ54n5FO/0gfIES8C/Psl6kWVDolizcaaxZJnTS0QSMxvnsBQ==",
|
||||
"funding": [
|
||||
{
|
||||
"type": "individual",
|
||||
"url": "https://github.com/sponsors/RubenVerborgh"
|
||||
}
|
||||
],
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">=4.0"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"debug": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/for-each": {
|
||||
"version": "0.3.5",
|
||||
"resolved": "https://registry.npmjs.org/for-each/-/for-each-0.3.5.tgz",
|
||||
@@ -3592,11 +3735,26 @@
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/form-data": {
|
||||
"version": "4.0.5",
|
||||
"resolved": "https://registry.npmjs.org/form-data/-/form-data-4.0.5.tgz",
|
||||
"integrity": "sha512-8RipRLol37bNs2bhoV67fiTEvdTrbMUYcFTiy3+wuuOnUog2QBHCZWXDRijWQfAkhBj2Uf5UnVaiWwA5vdd82w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"asynckit": "^0.4.0",
|
||||
"combined-stream": "^1.0.8",
|
||||
"es-set-tostringtag": "^2.1.0",
|
||||
"hasown": "^2.0.2",
|
||||
"mime-types": "^2.1.12"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 6"
|
||||
}
|
||||
},
|
||||
"node_modules/function-bind": {
|
||||
"version": "1.1.2",
|
||||
"resolved": "https://registry.npmjs.org/function-bind/-/function-bind-1.1.2.tgz",
|
||||
"integrity": "sha512-7XHNxH7qX9xG5mIwxkhumTox/MIRNcOgDrxWsMt2pAr23WHp6MrRlN7FBSFpCpr+oVO0F744iUgR82nJMfG2SA==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
@@ -3657,7 +3815,6 @@
|
||||
"version": "1.3.0",
|
||||
"resolved": "https://registry.npmjs.org/get-intrinsic/-/get-intrinsic-1.3.0.tgz",
|
||||
"integrity": "sha512-9fSjSaos/fRIVIp+xSJlE6lfwhES7LNtKaCBIamHsjr2na1BiABJPo0mOjjz8GJDURarmCPGqaiVg5mfjb98CQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"call-bind-apply-helpers": "^1.0.2",
|
||||
@@ -3682,7 +3839,6 @@
|
||||
"version": "1.0.1",
|
||||
"resolved": "https://registry.npmjs.org/get-proto/-/get-proto-1.0.1.tgz",
|
||||
"integrity": "sha512-sTSfBjoXBp89JvIKIefqw7U2CCebsc74kiY6awiGogKtoSGbgjYE/G/+l9sF3MWFPNc9IcoOC4ODfKHfxFmp0g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"dunder-proto": "^1.0.1",
|
||||
@@ -3770,7 +3926,6 @@
|
||||
"version": "1.2.0",
|
||||
"resolved": "https://registry.npmjs.org/gopd/-/gopd-1.2.0.tgz",
|
||||
"integrity": "sha512-ZUKRh6/kUFoAiTAtTYPZJ3hw9wNxx+BIBOijnlG9PnrJsCcSjs1wyyD6vJpaYtgnzDrKYRSqf3OO6Rfa93xsRg==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
@@ -3842,7 +3997,6 @@
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/has-symbols/-/has-symbols-1.1.0.tgz",
|
||||
"integrity": "sha512-1cDNdwJ2Jaohmb3sg4OmKaMBwuC48sYni5HUw2DvsC8LjGTLK9h+eb1X6RyuOHe4hT0ULCW68iomhjUoKUqlPQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
@@ -3855,7 +4009,6 @@
|
||||
"version": "1.0.2",
|
||||
"resolved": "https://registry.npmjs.org/has-tostringtag/-/has-tostringtag-1.0.2.tgz",
|
||||
"integrity": "sha512-NqADB8VjPFLM2V0VvHUewwwsw0ZWBaIdgo+ieHtK3hasLz4qeCRjYcqfB6AQrBggRKppKF8L52/VqdVsO47Dlw==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"has-symbols": "^1.0.3"
|
||||
@@ -3871,7 +4024,6 @@
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/hasown/-/hasown-2.0.2.tgz",
|
||||
"integrity": "sha512-0hJU9SCPvmMzIBdZFqNPXWa6dqh7WdH0cII9y+CyS8rG3nL48Bclra9HmKhVVUHyPWNH5Y7xDwAB7bfgSjkUMQ==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"function-bind": "^1.1.2"
|
||||
@@ -3897,6 +4049,14 @@
|
||||
"hermes-estree": "0.25.1"
|
||||
}
|
||||
},
|
||||
"node_modules/iceberg-js": {
|
||||
"version": "0.8.1",
|
||||
"resolved": "https://registry.npmjs.org/iceberg-js/-/iceberg-js-0.8.1.tgz",
|
||||
"integrity": "sha512-1dhVQZXhcHje7798IVM+xoo/1ZdVfzOMIc8/rgVSijRK38EDqOJoGula9N/8ZI5RD8QTxNQtK/Gozpr+qUqRRA==",
|
||||
"engines": {
|
||||
"node": ">=20.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/ignore": {
|
||||
"version": "5.3.2",
|
||||
"resolved": "https://registry.npmjs.org/ignore/-/ignore-5.3.2.tgz",
|
||||
@@ -4854,7 +5014,6 @@
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/math-intrinsics/-/math-intrinsics-1.1.0.tgz",
|
||||
"integrity": "sha512-/IXtbwEk5HTPyEwyKX6hGkYXxM9nbj64B+ilVJnC/R6B0pH5G4V3b0pVbL7DBj4tkhBAppbQUlf6F6Xl9LHu1g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.4"
|
||||
@@ -4884,6 +5043,27 @@
|
||||
"node": ">=8.6"
|
||||
}
|
||||
},
|
||||
"node_modules/mime-db": {
|
||||
"version": "1.52.0",
|
||||
"resolved": "https://registry.npmjs.org/mime-db/-/mime-db-1.52.0.tgz",
|
||||
"integrity": "sha512-sPU4uV7dYlvtWJxwwxHD0PuihVNiE7TyAbQ5SWxDCB9mUYvOgroQOwYQQOKPJ8CIbE+1ETVlOoK1UC2nU3gYvg==",
|
||||
"license": "MIT",
|
||||
"engines": {
|
||||
"node": ">= 0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/mime-types": {
|
||||
"version": "2.1.35",
|
||||
"resolved": "https://registry.npmjs.org/mime-types/-/mime-types-2.1.35.tgz",
|
||||
"integrity": "sha512-ZDY+bPm5zTTF+YpCrAU9nK0UgICYPT0QtT1NZWFv4s++TNkcgVaT0g6+4R2uI4MjQjzysHB1zxuWL50hzaeXiw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"mime-db": "1.52.0"
|
||||
},
|
||||
"engines": {
|
||||
"node": ">= 0.6"
|
||||
}
|
||||
},
|
||||
"node_modules/minimatch": {
|
||||
"version": "3.1.2",
|
||||
"resolved": "https://registry.npmjs.org/minimatch/-/minimatch-3.1.2.tgz",
|
||||
@@ -5354,6 +5534,12 @@
|
||||
"react-is": "^16.13.1"
|
||||
}
|
||||
},
|
||||
"node_modules/proxy-from-env": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/proxy-from-env/-/proxy-from-env-1.1.0.tgz",
|
||||
"integrity": "sha512-D+zkORCbA9f1tdWRK0RaCR3GPv50cMxcrz4X8k5LTSUD1Dkw47mKJEZQNunItRTkWwgtaUSo1RVFRIG9ZXiFYg==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/punycode": {
|
||||
"version": "2.3.1",
|
||||
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.3.1.tgz",
|
||||
@@ -5390,7 +5576,6 @@
|
||||
"resolved": "https://registry.npmjs.org/react/-/react-19.2.3.tgz",
|
||||
"integrity": "sha512-Ku/hhYbVjOQnXDZFv2+RibmLFGwFdeeKHFcOTlrt7xplBnya5OGn/hIRDsqDiSUcfORsDC7MPxwork8jBwsIWA==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
@@ -5400,7 +5585,6 @@
|
||||
"resolved": "https://registry.npmjs.org/react-dom/-/react-dom-19.2.3.tgz",
|
||||
"integrity": "sha512-yELu4WmLPw5Mr/lmeEpox5rw3RETacE++JgHqQzd2dg+YbJuat3jH4ingc+WPZhxaoFzdv9y33G+F7Nl5O0GBg==",
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"dependencies": {
|
||||
"scheduler": "^0.27.0"
|
||||
},
|
||||
@@ -6027,6 +6211,19 @@
|
||||
"url": "https://github.com/sponsors/ljharb"
|
||||
}
|
||||
},
|
||||
"node_modules/swr": {
|
||||
"version": "2.3.8",
|
||||
"resolved": "https://registry.npmjs.org/swr/-/swr-2.3.8.tgz",
|
||||
"integrity": "sha512-gaCPRVoMq8WGDcWj9p4YWzCMPHzE0WNl6W8ADIx9c3JBEIdMkJGMzW+uzXvxHMltwcYACr9jP+32H8/hgwMR7w==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"dequal": "^2.0.3",
|
||||
"use-sync-external-store": "^1.6.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"react": "^16.11.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/tailwindcss": {
|
||||
"version": "4.1.18",
|
||||
"resolved": "https://registry.npmjs.org/tailwindcss/-/tailwindcss-4.1.18.tgz",
|
||||
@@ -6089,7 +6286,6 @@
|
||||
"integrity": "sha512-5gTmgEY/sqK6gFXLIsQNH19lWb4ebPDLA4SdLP7dsWkIXHWlG66oPuVvXSGFPppYZz8ZDZq0dYYrbHfBCVUb1Q==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"engines": {
|
||||
"node": ">=12"
|
||||
},
|
||||
@@ -6252,7 +6448,6 @@
|
||||
"integrity": "sha512-jl1vZzPDinLr9eUt3J/t7V6FgNEw9QjvBPdysz9KfQDD41fQrC2Y4vKQdiaUpFT4bXlb1RHhLpp8wtm6M5TgSw==",
|
||||
"dev": true,
|
||||
"license": "Apache-2.0",
|
||||
"peer": true,
|
||||
"bin": {
|
||||
"tsc": "bin/tsc",
|
||||
"tsserver": "bin/tsserver"
|
||||
@@ -6308,7 +6503,6 @@
|
||||
"version": "6.21.0",
|
||||
"resolved": "https://registry.npmjs.org/undici-types/-/undici-types-6.21.0.tgz",
|
||||
"integrity": "sha512-iwDZqg0QAGrg9Rav5H4n0M64c3mkR59cJ6wQp+7C4nI0gsmExaedaYLNO44eT4AtBBwjbTiGPMlt2Md0T9H9JQ==",
|
||||
"dev": true,
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/unrs-resolver": {
|
||||
@@ -6387,6 +6581,15 @@
|
||||
"punycode": "^2.1.0"
|
||||
}
|
||||
},
|
||||
"node_modules/use-sync-external-store": {
|
||||
"version": "1.6.0",
|
||||
"resolved": "https://registry.npmjs.org/use-sync-external-store/-/use-sync-external-store-1.6.0.tgz",
|
||||
"integrity": "sha512-Pp6GSwGP/NrPIrxVFAIkOQeyw8lFenOHijQWkUTrDvrF4ALqylP2C/KCkeS9dpUM3KvYRQhna5vt7IL95+ZQ9w==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"react": "^16.8.0 || ^17.0.0 || ^18.0.0 || ^19.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/which": {
|
||||
"version": "2.0.2",
|
||||
"resolved": "https://registry.npmjs.org/which/-/which-2.0.2.tgz",
|
||||
@@ -6502,6 +6705,26 @@
|
||||
"node": ">=0.10.0"
|
||||
}
|
||||
},
|
||||
"node_modules/ws": {
|
||||
"version": "8.19.0",
|
||||
"resolved": "https://registry.npmjs.org/ws/-/ws-8.19.0.tgz",
|
||||
"integrity": "sha512-blAT2mjOEIi0ZzruJfIhb3nps74PRWTCz1IjglWEEpQl5XS/UNama6u2/rjFkDDouqr4L67ry+1aGIALViWjDg==",
|
||||
"engines": {
|
||||
"node": ">=10.0.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"bufferutil": "^4.0.1",
|
||||
"utf-8-validate": ">=5.0.2"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"bufferutil": {
|
||||
"optional": true
|
||||
},
|
||||
"utf-8-validate": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/yallist": {
|
||||
"version": "3.1.1",
|
||||
"resolved": "https://registry.npmjs.org/yallist/-/yallist-3.1.1.tgz",
|
||||
@@ -6528,7 +6751,6 @@
|
||||
"integrity": "sha512-k7Nwx6vuWx1IJ9Bjuf4Zt1PEllcwe7cls3VNzm4CQ1/hgtFUK2bRNG3rvnpPUhFjmqJKAKtjV576KnUkHocg/g==",
|
||||
"dev": true,
|
||||
"license": "MIT",
|
||||
"peer": true,
|
||||
"funding": {
|
||||
"url": "https://github.com/sponsors/colinhacks"
|
||||
}
|
||||
|
||||
@@ -9,9 +9,12 @@
|
||||
"lint": "eslint"
|
||||
},
|
||||
"dependencies": {
|
||||
"@supabase/supabase-js": "^2.93.1",
|
||||
"axios": "^1.13.4",
|
||||
"next": "16.1.1",
|
||||
"react": "19.2.3",
|
||||
"react-dom": "19.2.3"
|
||||
"react-dom": "19.2.3",
|
||||
"swr": "^2.3.8"
|
||||
},
|
||||
"devDependencies": {
|
||||
"@tailwindcss/postcss": "^4",
|
||||
@@ -23,4 +26,4 @@
|
||||
"tailwindcss": "^4",
|
||||
"typescript": "^5"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
190
frontend/src/app/admin/page.tsx
Normal file
190
frontend/src/app/admin/page.tsx
Normal file
@@ -0,0 +1,190 @@
|
||||
'use client';
|
||||
|
||||
import { useState, useEffect } from 'react';
|
||||
import { useRouter } from 'next/navigation';
|
||||
import { getCurrentUser, User } from '@/lib/auth';
|
||||
import api from '@/lib/axios';
|
||||
|
||||
interface UserListItem {
|
||||
id: string;
|
||||
email: string;
|
||||
username: string | null;
|
||||
role: string;
|
||||
is_active: boolean;
|
||||
expires_at: string | null;
|
||||
created_at: string;
|
||||
}
|
||||
|
||||
export default function AdminPage() {
|
||||
const router = useRouter();
|
||||
const [currentUser, setCurrentUser] = useState<User | null>(null);
|
||||
const [users, setUsers] = useState<UserListItem[]>([]);
|
||||
const [loading, setLoading] = useState(true);
|
||||
const [error, setError] = useState('');
|
||||
const [activatingId, setActivatingId] = useState<string | null>(null);
|
||||
const [expireDays, setExpireDays] = useState<number>(30);
|
||||
|
||||
useEffect(() => {
|
||||
checkAdmin();
|
||||
fetchUsers();
|
||||
}, []);
|
||||
|
||||
const checkAdmin = async () => {
|
||||
const user = await getCurrentUser();
|
||||
if (!user || user.role !== 'admin') {
|
||||
router.push('/login');
|
||||
return;
|
||||
}
|
||||
setCurrentUser(user);
|
||||
};
|
||||
|
||||
const fetchUsers = async () => {
|
||||
try {
|
||||
const { data } = await api.get('/api/admin/users');
|
||||
setUsers(data);
|
||||
} catch (err) {
|
||||
setError('获取用户列表失败');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
const activateUser = async (userId: string) => {
|
||||
setActivatingId(userId);
|
||||
try {
|
||||
await api.post(`/api/admin/users/${userId}/activate`, {
|
||||
expires_days: expireDays || null
|
||||
});
|
||||
fetchUsers();
|
||||
} catch (err) {
|
||||
// axios interceptor handles 401/403
|
||||
} finally {
|
||||
setActivatingId(null);
|
||||
}
|
||||
};
|
||||
|
||||
const deactivateUser = async (userId: string) => {
|
||||
if (!confirm('确定要停用该用户吗?')) return;
|
||||
|
||||
try {
|
||||
await api.post(`/api/admin/users/${userId}/deactivate`);
|
||||
fetchUsers();
|
||||
} catch (err) {
|
||||
alert('操作失败');
|
||||
}
|
||||
};
|
||||
|
||||
const formatDate = (dateStr: string | null) => {
|
||||
if (!dateStr) return '永久';
|
||||
return new Date(dateStr).toLocaleDateString('zh-CN');
|
||||
};
|
||||
|
||||
const getRoleBadge = (role: string, isActive: boolean) => {
|
||||
if (role === 'admin') {
|
||||
return <span className="px-2 py-1 text-xs rounded-full bg-purple-500/20 text-purple-300">管理员</span>;
|
||||
}
|
||||
if (role === 'pending') {
|
||||
return <span className="px-2 py-1 text-xs rounded-full bg-yellow-500/20 text-yellow-300">待审核</span>;
|
||||
}
|
||||
if (!isActive) {
|
||||
return <span className="px-2 py-1 text-xs rounded-full bg-red-500/20 text-red-300">已停用</span>;
|
||||
}
|
||||
return <span className="px-2 py-1 text-xs rounded-full bg-green-500/20 text-green-300">正常</span>;
|
||||
};
|
||||
|
||||
if (loading) {
|
||||
return (
|
||||
<div className="min-h-dvh flex items-center justify-center">
|
||||
<div className="animate-spin rounded-full h-12 w-12 border-t-2 border-b-2 border-purple-500"></div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-dvh p-8">
|
||||
<div className="max-w-6xl mx-auto">
|
||||
<div className="flex justify-between items-center mb-8">
|
||||
<h1 className="text-3xl font-bold text-white">用户管理</h1>
|
||||
<a href="/" className="text-purple-300 hover:text-purple-200">
|
||||
← 返回首页
|
||||
</a>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="mb-4 p-3 bg-red-500/20 border border-red-500/50 rounded-lg text-red-200">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<div className="mb-4 flex items-center gap-4">
|
||||
<label className="text-gray-300">默认授权天数:</label>
|
||||
<input
|
||||
type="number"
|
||||
value={expireDays}
|
||||
onChange={(e) => setExpireDays(parseInt(e.target.value) || 0)}
|
||||
className="w-24 px-3 py-2 bg-white/5 border border-white/10 rounded text-white"
|
||||
placeholder="0=永久"
|
||||
/>
|
||||
<span className="text-gray-400 text-sm">(0 表示永久)</span>
|
||||
</div>
|
||||
|
||||
<div className="bg-white/5 backdrop-blur-lg rounded-xl border border-white/10 overflow-hidden">
|
||||
<table className="w-full">
|
||||
<thead className="bg-white/5">
|
||||
<tr>
|
||||
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300">用户</th>
|
||||
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300">状态</th>
|
||||
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300">过期时间</th>
|
||||
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300">注册时间</th>
|
||||
<th className="px-6 py-4 text-left text-sm font-medium text-gray-300">操作</th>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="divide-y divide-white/5">
|
||||
{users.map((user) => (
|
||||
<tr key={user.id} className="hover:bg-white/5">
|
||||
<td className="px-6 py-4">
|
||||
<div>
|
||||
<div className="text-white font-medium">{user.username || user.email.split('@')[0]}</div>
|
||||
<div className="text-gray-400 text-sm">{user.email}</div>
|
||||
</div>
|
||||
</td>
|
||||
<td className="px-6 py-4">
|
||||
{getRoleBadge(user.role, user.is_active)}
|
||||
</td>
|
||||
<td className="px-6 py-4 text-gray-300">
|
||||
{formatDate(user.expires_at)}
|
||||
</td>
|
||||
<td className="px-6 py-4 text-gray-400 text-sm">
|
||||
{formatDate(user.created_at)}
|
||||
</td>
|
||||
<td className="px-6 py-4">
|
||||
{user.role !== 'admin' && (
|
||||
<div className="flex gap-2">
|
||||
{!user.is_active || user.role === 'pending' ? (
|
||||
<button
|
||||
onClick={() => activateUser(user.id)}
|
||||
disabled={activatingId === user.id}
|
||||
className="px-3 py-1 bg-green-600 hover:bg-green-700 text-white text-sm rounded disabled:opacity-50"
|
||||
>
|
||||
{activatingId === user.id ? '...' : '激活'}
|
||||
</button>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => deactivateUser(user.id)}
|
||||
className="px-3 py-1 bg-red-600 hover:bg-red-700 text-white text-sm rounded"
|
||||
>
|
||||
停用
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
</td>
|
||||
</tr>
|
||||
))}
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -19,18 +19,73 @@
|
||||
}
|
||||
}
|
||||
|
||||
body {
|
||||
background: var(--background);
|
||||
color: var(--foreground);
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
}
|
||||
|
||||
/* 隐藏滚动条但保留滚动功能 */
|
||||
/* iOS Safari 安全区域支持 + 滚动条隐藏 */
|
||||
html {
|
||||
scrollbar-width: none; /* Firefox */
|
||||
-ms-overflow-style: none; /* IE 和 Edge */
|
||||
background-color: #0f172a !important;
|
||||
min-height: 100%;
|
||||
scrollbar-width: none;
|
||||
-ms-overflow-style: none;
|
||||
}
|
||||
|
||||
html::-webkit-scrollbar {
|
||||
display: none; /* Chrome, Safari, Opera */
|
||||
display: none;
|
||||
}
|
||||
|
||||
body {
|
||||
margin: 0 !important;
|
||||
min-height: 100dvh;
|
||||
color: var(--foreground);
|
||||
font-family: Arial, Helvetica, sans-serif;
|
||||
padding-top: env(safe-area-inset-top);
|
||||
padding-bottom: env(safe-area-inset-bottom);
|
||||
}
|
||||
|
||||
/* 自定义滚动条样式 - 深色主题 */
|
||||
.custom-scrollbar {
|
||||
scrollbar-width: thin;
|
||||
scrollbar-color: rgba(147, 51, 234, 0.5) transparent;
|
||||
}
|
||||
|
||||
.custom-scrollbar::-webkit-scrollbar {
|
||||
width: 6px;
|
||||
}
|
||||
|
||||
.custom-scrollbar::-webkit-scrollbar-track {
|
||||
background: transparent;
|
||||
}
|
||||
|
||||
.custom-scrollbar::-webkit-scrollbar-thumb {
|
||||
background: rgba(147, 51, 234, 0.5);
|
||||
border-radius: 3px;
|
||||
}
|
||||
|
||||
.custom-scrollbar::-webkit-scrollbar-thumb:hover {
|
||||
background: rgba(147, 51, 234, 0.8);
|
||||
}
|
||||
|
||||
/* 完全隐藏滚动条 */
|
||||
.hide-scrollbar {
|
||||
scrollbar-width: none;
|
||||
-ms-overflow-style: none;
|
||||
}
|
||||
|
||||
.hide-scrollbar::-webkit-scrollbar {
|
||||
display: none;
|
||||
}
|
||||
|
||||
/* 自定义 select 下拉菜单 */
|
||||
.custom-select {
|
||||
appearance: none;
|
||||
-webkit-appearance: none;
|
||||
-moz-appearance: none;
|
||||
background-image: url("data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' width='12' height='12' fill='%239ca3af' viewBox='0 0 16 16'%3E%3Cpath d='M8 11L3 6h10l-5 5z'/%3E%3C/svg%3E");
|
||||
background-repeat: no-repeat;
|
||||
background-position: right 12px center;
|
||||
padding-right: 36px;
|
||||
}
|
||||
|
||||
.custom-select option {
|
||||
background: #1a1a2e;
|
||||
color: white;
|
||||
padding: 12px;
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
import type { Metadata } from "next";
|
||||
import type { Metadata, Viewport } from "next";
|
||||
import { Geist, Geist_Mono } from "next/font/google";
|
||||
import "./globals.css";
|
||||
|
||||
@@ -13,8 +13,15 @@ const geistMono = Geist_Mono({
|
||||
});
|
||||
|
||||
export const metadata: Metadata = {
|
||||
title: "Create Next App",
|
||||
description: "Generated by create next app",
|
||||
title: "ViGent",
|
||||
description: "ViGent Talking Head Agent",
|
||||
};
|
||||
|
||||
export const viewport: Viewport = {
|
||||
width: 'device-width',
|
||||
initialScale: 1,
|
||||
viewportFit: 'cover',
|
||||
themeColor: '#0f172a',
|
||||
};
|
||||
|
||||
export default function RootLayout({
|
||||
@@ -23,9 +30,14 @@ export default function RootLayout({
|
||||
children: React.ReactNode;
|
||||
}>) {
|
||||
return (
|
||||
<html lang="en">
|
||||
<html lang="en" style={{ backgroundColor: '#0f172a' }}>
|
||||
<body
|
||||
className={`${geistSans.variable} ${geistMono.variable} antialiased`}
|
||||
style={{
|
||||
margin: 0,
|
||||
minHeight: '100dvh',
|
||||
background: 'linear-gradient(to bottom, #0f172a 0%, #0f172a 5%, #581c87 50%, #0f172a 95%, #0f172a 100%)',
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</body>
|
||||
|
||||
101
frontend/src/app/login/page.tsx
Normal file
101
frontend/src/app/login/page.tsx
Normal file
@@ -0,0 +1,101 @@
|
||||
'use client';
|
||||
|
||||
import { useState } from 'react';
|
||||
import { useRouter } from 'next/navigation';
|
||||
import { login } from '@/lib/auth';
|
||||
|
||||
export default function LoginPage() {
|
||||
const router = useRouter();
|
||||
const [email, setEmail] = useState('');
|
||||
const [password, setPassword] = useState('');
|
||||
const [error, setError] = useState('');
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setError('');
|
||||
setLoading(true);
|
||||
|
||||
try {
|
||||
const result = await login(email, password);
|
||||
if (result.success) {
|
||||
router.push('/');
|
||||
} else {
|
||||
setError(result.message || '登录失败');
|
||||
}
|
||||
} catch (err) {
|
||||
setError('网络错误,请稍后重试');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="min-h-dvh flex items-center justify-center">
|
||||
<div className="w-full max-w-md p-8 bg-white/10 backdrop-blur-lg rounded-2xl shadow-2xl border border-white/20">
|
||||
<div className="text-center mb-8">
|
||||
<h1 className="text-3xl font-bold text-white mb-2">ViGent</h1>
|
||||
<p className="text-gray-300">AI 视频生成平台</p>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-6">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-200 mb-2">
|
||||
邮箱
|
||||
</label>
|
||||
<input
|
||||
type="email"
|
||||
value={email}
|
||||
onChange={(e) => setEmail(e.target.value)}
|
||||
required
|
||||
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500 focus:border-transparent"
|
||||
placeholder="your@email.com"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-200 mb-2">
|
||||
密码
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={(e) => setPassword(e.target.value)}
|
||||
required
|
||||
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500 focus:border-transparent"
|
||||
placeholder="••••••••"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="p-3 bg-red-500/20 border border-red-500/50 rounded-lg text-red-200 text-sm">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={loading}
|
||||
className="w-full py-3 px-4 bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white font-semibold rounded-lg shadow-lg transition-all duration-200 disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
>
|
||||
{loading ? (
|
||||
<span className="flex items-center justify-center">
|
||||
<svg className="animate-spin -ml-1 mr-3 h-5 w-5 text-white" xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24">
|
||||
<circle className="opacity-25" cx="12" cy="12" r="10" stroke="currentColor" strokeWidth="4"></circle>
|
||||
<path className="opacity-75" fill="currentColor" d="M4 12a8 8 0 018-8V0C5.373 0 0 5.373 0 12h4zm2 5.291A7.962 7.962 0 014 12H0c0 3.042 1.135 5.824 3 7.938l3-2.647z"></path>
|
||||
</svg>
|
||||
登录中...
|
||||
</span>
|
||||
) : '登录'}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
<div className="mt-6 text-center">
|
||||
<a href="/register" className="text-purple-300 hover:text-purple-200 text-sm">
|
||||
还没有账号?立即注册
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -3,11 +3,11 @@
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import Link from "next/link";
|
||||
import api from "@/lib/axios";
|
||||
|
||||
// 动态获取 API 地址:服务端使用 localhost,客户端使用当前域名
|
||||
const API_BASE = typeof window !== 'undefined'
|
||||
? `http://${window.location.hostname}:8006`
|
||||
: 'http://localhost:8006';
|
||||
const API_BASE = typeof window === 'undefined'
|
||||
? 'http://localhost:8006'
|
||||
: '';
|
||||
|
||||
// 类型定义
|
||||
interface Material {
|
||||
@@ -26,6 +26,25 @@ interface Task {
|
||||
download_url?: string;
|
||||
}
|
||||
|
||||
interface GeneratedVideo {
|
||||
id: string;
|
||||
name: string;
|
||||
path: string;
|
||||
size_mb: number;
|
||||
created_at: number;
|
||||
}
|
||||
|
||||
// 格式化日期(避免 Hydration 错误)
|
||||
const formatDate = (timestamp: number) => {
|
||||
const d = new Date(timestamp * 1000);
|
||||
const year = d.getFullYear();
|
||||
const month = String(d.getMonth() + 1).padStart(2, '0');
|
||||
const day = String(d.getDate()).padStart(2, '0');
|
||||
const hour = String(d.getHours()).padStart(2, '0');
|
||||
const minute = String(d.getMinutes()).padStart(2, '0');
|
||||
return `${year}/${month}/${day} ${hour}:${minute}`;
|
||||
};
|
||||
|
||||
export default function Home() {
|
||||
const [materials, setMaterials] = useState<Material[]>([]);
|
||||
const [selectedMaterial, setSelectedMaterial] = useState<string>("");
|
||||
@@ -41,6 +60,10 @@ export default function Home() {
|
||||
const [isUploading, setIsUploading] = useState(false);
|
||||
const [uploadProgress, setUploadProgress] = useState(0);
|
||||
const [uploadError, setUploadError] = useState<string | null>(null);
|
||||
const [uploadData, setUploadData] = useState<string>("");
|
||||
const [generatedVideos, setGeneratedVideos] = useState<GeneratedVideo[]>([]);
|
||||
|
||||
const [selectedVideoId, setSelectedVideoId] = useState<string | null>(null);
|
||||
|
||||
// 可选音色
|
||||
const voices = [
|
||||
@@ -51,9 +74,10 @@ export default function Home() {
|
||||
{ id: "zh-CN-XiaoyiNeural", name: "晓伊 (女声-温柔)" },
|
||||
];
|
||||
|
||||
// 加载素材列表
|
||||
// 加载素材列表和历史视频
|
||||
useEffect(() => {
|
||||
fetchMaterials();
|
||||
fetchGeneratedVideos();
|
||||
}, []);
|
||||
|
||||
const fetchMaterials = async () => {
|
||||
@@ -61,18 +85,8 @@ export default function Home() {
|
||||
setFetchError(null);
|
||||
setDebugData("Loading...");
|
||||
|
||||
// Add timestamp to prevent caching
|
||||
const url = `${API_BASE}/api/materials/?t=${new Date().getTime()}`;
|
||||
const res = await fetch(url);
|
||||
|
||||
if (!res.ok) {
|
||||
throw new Error(`HTTP ${res.status} ${res.statusText}`);
|
||||
}
|
||||
|
||||
const text = await res.text(); // Get raw text first
|
||||
setDebugData(text.substring(0, 200) + (text.length > 200 ? "..." : "")); // Show preview
|
||||
|
||||
const data = JSON.parse(text);
|
||||
const { data } = await api.get(`/api/materials?t=${new Date().getTime()}`);
|
||||
setDebugData(JSON.stringify(data).substring(0, 200));
|
||||
setMaterials(data.materials || []);
|
||||
|
||||
if (data.materials?.length > 0) {
|
||||
@@ -87,7 +101,46 @@ export default function Home() {
|
||||
}
|
||||
};
|
||||
|
||||
// 上传视频
|
||||
// 获取已生成的视频列表(持久化)
|
||||
const fetchGeneratedVideos = async () => {
|
||||
try {
|
||||
const { data } = await api.get('/api/videos/generated');
|
||||
setGeneratedVideos(data.videos || []);
|
||||
} catch (error) {
|
||||
console.error("获取历史视频失败:", error);
|
||||
}
|
||||
};
|
||||
|
||||
// 删除素材
|
||||
const deleteMaterial = async (materialId: string) => {
|
||||
if (!confirm("确定要删除这个素材吗?")) return;
|
||||
try {
|
||||
await api.delete(`/api/materials/${materialId}`);
|
||||
fetchMaterials();
|
||||
if (selectedMaterial === materialId) {
|
||||
setSelectedMaterial("");
|
||||
}
|
||||
} catch (error) {
|
||||
alert("删除失败: " + error);
|
||||
}
|
||||
};
|
||||
|
||||
// 删除生成的视频
|
||||
const deleteVideo = async (videoId: string) => {
|
||||
if (!confirm("确定要删除这个视频吗?")) return;
|
||||
try {
|
||||
await api.delete(`/api/videos/generated/${videoId}`);
|
||||
fetchGeneratedVideos();
|
||||
if (selectedVideoId === videoId) {
|
||||
setSelectedVideoId(null);
|
||||
setGeneratedVideo(null);
|
||||
}
|
||||
} catch (error) {
|
||||
alert("删除失败: " + error);
|
||||
}
|
||||
};
|
||||
|
||||
// 上传视频 - 使用 axios 支持进度显示
|
||||
const handleUpload = async (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
const file = e.target.files?.[0];
|
||||
if (!file) return;
|
||||
@@ -104,41 +157,37 @@ export default function Home() {
|
||||
setUploadProgress(0);
|
||||
setUploadError(null);
|
||||
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
try {
|
||||
const formData = new FormData();
|
||||
formData.append('file', file);
|
||||
|
||||
// 使用 XMLHttpRequest 以获取上传进度
|
||||
const xhr = new XMLHttpRequest();
|
||||
await api.post('/api/materials', formData, {
|
||||
headers: { 'Content-Type': 'multipart/form-data' },
|
||||
onUploadProgress: (progressEvent) => {
|
||||
if (progressEvent.total) {
|
||||
const progress = Math.round((progressEvent.loaded / progressEvent.total) * 100);
|
||||
setUploadProgress(progress);
|
||||
}
|
||||
},
|
||||
});
|
||||
|
||||
xhr.upload.onprogress = (event) => {
|
||||
if (event.lengthComputable) {
|
||||
const progress = Math.round((event.loaded / event.total) * 100);
|
||||
setUploadProgress(progress);
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onload = () => {
|
||||
setUploadProgress(100);
|
||||
setIsUploading(false);
|
||||
if (xhr.status >= 200 && xhr.status < 300) {
|
||||
fetchMaterials(); // 刷新素材列表
|
||||
setUploadProgress(100);
|
||||
} else {
|
||||
setUploadError(`上传失败: ${xhr.statusText}`);
|
||||
}
|
||||
};
|
||||
|
||||
xhr.onerror = () => {
|
||||
fetchMaterials();
|
||||
setUploadData("");
|
||||
} catch (err: any) {
|
||||
console.error("Upload failed:", err);
|
||||
setIsUploading(false);
|
||||
setUploadError('网络错误,上传失败');
|
||||
};
|
||||
|
||||
xhr.open('POST', `${API_BASE}/api/materials/`);
|
||||
xhr.send(formData);
|
||||
const errorMsg = err.response?.data?.detail || err.message || String(err);
|
||||
setUploadError(`上传失败: ${errorMsg}`);
|
||||
}
|
||||
|
||||
// 清空 input 以便可以再次选择同一文件
|
||||
e.target.value = '';
|
||||
};
|
||||
|
||||
|
||||
|
||||
// 生成视频
|
||||
const handleGenerate = async () => {
|
||||
if (!selectedMaterial || !text.trim()) {
|
||||
@@ -158,34 +207,34 @@ export default function Home() {
|
||||
}
|
||||
|
||||
// 创建生成任务
|
||||
const res = await fetch(`${API_BASE}/api/videos/generate`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
material_path: materialObj.path,
|
||||
text: text,
|
||||
voice: voice,
|
||||
add_subtitle: true,
|
||||
}),
|
||||
const { data } = await api.post('/api/videos/generate', {
|
||||
material_path: materialObj.path,
|
||||
text: text,
|
||||
voice: voice,
|
||||
add_subtitle: true,
|
||||
});
|
||||
|
||||
const data = await res.json();
|
||||
const taskId = data.task_id;
|
||||
|
||||
// 轮询任务状态
|
||||
const pollTask = async () => {
|
||||
const taskRes = await fetch(`${API_BASE}/api/videos/tasks/${taskId}`);
|
||||
const taskData: Task = await taskRes.json();
|
||||
setCurrentTask(taskData);
|
||||
try {
|
||||
const { data: taskData } = await api.get(`/api/videos/tasks/${taskId}`);
|
||||
setCurrentTask(taskData);
|
||||
|
||||
if (taskData.status === "completed") {
|
||||
setGeneratedVideo(`${API_BASE}${taskData.download_url}`);
|
||||
if (taskData.status === "completed") {
|
||||
setGeneratedVideo(`${API_BASE}${taskData.download_url}`);
|
||||
setIsGenerating(false);
|
||||
fetchGeneratedVideos(); // 刷新历史视频列表
|
||||
} else if (taskData.status === "failed") {
|
||||
alert("视频生成失败: " + taskData.message);
|
||||
setIsGenerating(false);
|
||||
} else {
|
||||
setTimeout(pollTask, 1000);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("轮询任务失败:", error);
|
||||
setIsGenerating(false);
|
||||
} else if (taskData.status === "failed") {
|
||||
alert("视频生成失败: " + taskData.message);
|
||||
setIsGenerating(false);
|
||||
} else {
|
||||
setTimeout(pollTask, 1000);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -197,7 +246,7 @@ export default function Home() {
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-gradient-to-br from-slate-900 via-purple-900 to-slate-900">
|
||||
<div className="min-h-dvh">
|
||||
{/* Header <header className="border-b border-white/10 bg-black/20 backdrop-blur-sm">
|
||||
<div className="max-w-6xl mx-auto px-6 py-4 flex items-center justify-between">
|
||||
<h1 className="text-2xl font-bold text-white flex items-center gap-3">
|
||||
@@ -218,21 +267,34 @@ export default function Home() {
|
||||
</div>
|
||||
</header> */}
|
||||
<header className="border-b border-white/10 bg-black/20 backdrop-blur-sm">
|
||||
<div className="max-w-6xl mx-auto px-6 py-4 flex items-center justify-between">
|
||||
<Link href="/" className="text-2xl font-bold text-white flex items-center gap-3 hover:opacity-80 transition-opacity">
|
||||
<span className="text-4xl">🎬</span>
|
||||
<div className="max-w-6xl mx-auto px-4 sm:px-6 py-3 sm:py-4 flex items-center justify-between">
|
||||
<Link href="/" className="text-xl sm:text-2xl font-bold text-white flex items-center gap-2 sm:gap-3 hover:opacity-80 transition-opacity">
|
||||
<span className="text-3xl sm:text-4xl">🎬</span>
|
||||
ViGent
|
||||
</Link>
|
||||
<div className="flex items-center gap-4">
|
||||
<span className="px-4 py-2 bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg font-semibold">
|
||||
<div className="flex items-center gap-1 sm:gap-4">
|
||||
<span className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg font-semibold">
|
||||
视频生成
|
||||
</span>
|
||||
<Link
|
||||
href="/publish"
|
||||
className="px-4 py-2 bg-white/10 hover:bg-white/20 text-white rounded-lg transition-colors"
|
||||
className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-white/10 hover:bg-white/20 text-white rounded-lg transition-colors"
|
||||
>
|
||||
发布管理
|
||||
</Link>
|
||||
<button
|
||||
onClick={async () => {
|
||||
if (confirm('确定要退出登录吗?')) {
|
||||
try {
|
||||
await api.post('/api/auth/logout');
|
||||
} catch (e) { }
|
||||
window.location.href = '/login';
|
||||
}
|
||||
}}
|
||||
className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-red-500/10 hover:bg-red-500/20 text-red-200 rounded-lg transition-colors"
|
||||
>
|
||||
退出
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
@@ -242,12 +304,12 @@ export default function Home() {
|
||||
{/* 左侧: 输入区域 */}
|
||||
<div className="space-y-6">
|
||||
{/* 素材选择 */}
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<h2 className="text-lg font-semibold text-white flex items-center gap-2">
|
||||
<div className="bg-white/5 rounded-2xl p-4 sm:p-6 border border-white/10 backdrop-blur-sm">
|
||||
<div className="flex justify-between items-center gap-2 mb-4">
|
||||
<h2 className="text-base sm:text-lg font-semibold text-white flex items-center gap-2 whitespace-nowrap">
|
||||
📹 选择素材视频
|
||||
</h2>
|
||||
<div className="flex gap-2">
|
||||
<div className="flex gap-1.5">
|
||||
{/* 隐藏的文件输入 */}
|
||||
<input
|
||||
type="file"
|
||||
@@ -258,16 +320,16 @@ export default function Home() {
|
||||
/>
|
||||
<label
|
||||
htmlFor="video-upload"
|
||||
className={`px-3 py-1 text-xs rounded cursor-pointer transition-all ${isUploading
|
||||
className={`px-2 py-1 text-xs rounded cursor-pointer transition-all whitespace-nowrap ${isUploading
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white"
|
||||
}`}
|
||||
>
|
||||
📤 上传视频
|
||||
📤 上传
|
||||
</label>
|
||||
<button
|
||||
onClick={fetchMaterials}
|
||||
className="px-3 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300"
|
||||
className="px-2 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300 whitespace-nowrap"
|
||||
>
|
||||
🔄 刷新
|
||||
</button>
|
||||
@@ -320,21 +382,35 @@ export default function Home() {
|
||||
) : (
|
||||
<div className="grid grid-cols-2 gap-3">
|
||||
{materials.map((m) => (
|
||||
<button
|
||||
<div
|
||||
key={m.id}
|
||||
onClick={() => setSelectedMaterial(m.id)}
|
||||
className={`p-4 rounded-xl border-2 transition-all text-left ${selectedMaterial === m.id
|
||||
className={`p-4 rounded-xl border-2 transition-all text-left relative group ${selectedMaterial === m.id
|
||||
? "border-purple-500 bg-purple-500/20"
|
||||
: "border-white/10 bg-white/5 hover:border-white/30"
|
||||
}`}
|
||||
>
|
||||
<div className="text-white font-medium truncate">
|
||||
{m.scene || m.name}
|
||||
</div>
|
||||
<div className="text-gray-400 text-sm mt-1">
|
||||
{m.size_mb.toFixed(1)} MB
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setSelectedMaterial(m.id)}
|
||||
className="w-full text-left"
|
||||
>
|
||||
<div className="text-white font-medium truncate pr-6">
|
||||
{m.scene || m.name}
|
||||
</div>
|
||||
<div className="text-gray-400 text-sm mt-1">
|
||||
{m.size_mb.toFixed(1)} MB
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
deleteMaterial(m.id);
|
||||
}}
|
||||
className="absolute top-2 right-2 p-1 text-gray-500 hover:text-red-400 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
title="删除素材"
|
||||
>
|
||||
🗑️
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
@@ -471,16 +547,66 @@ export default function Home() {
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 历史视频列表 */}
|
||||
<div className="bg-white/5 rounded-2xl p-6 border border-white/10 backdrop-blur-sm">
|
||||
<div className="flex justify-between items-center mb-4">
|
||||
<h2 className="text-lg font-semibold text-white flex items-center gap-2">
|
||||
📂 历史视频
|
||||
</h2>
|
||||
<button
|
||||
onClick={fetchGeneratedVideos}
|
||||
className="px-3 py-1 text-xs bg-white/10 hover:bg-white/20 rounded text-gray-300"
|
||||
>
|
||||
🔄 刷新
|
||||
</button>
|
||||
</div>
|
||||
{generatedVideos.length === 0 ? (
|
||||
<div className="text-center py-4 text-gray-500">
|
||||
<p>暂无生成的视频</p>
|
||||
</div>
|
||||
) : (
|
||||
<div className="space-y-2 max-h-64 overflow-y-auto hide-scrollbar">
|
||||
{generatedVideos.map((v) => (
|
||||
<div
|
||||
key={v.id}
|
||||
className={`p-3 rounded-lg border transition-all flex items-center justify-between group ${selectedVideoId === v.id
|
||||
? "border-purple-500 bg-purple-500/20"
|
||||
: "border-white/10 bg-white/5 hover:border-white/30"
|
||||
}`}
|
||||
>
|
||||
<button
|
||||
onClick={() => {
|
||||
setSelectedVideoId(v.id);
|
||||
setGeneratedVideo(`${API_BASE}${v.path}`);
|
||||
}}
|
||||
className="flex-1 text-left"
|
||||
>
|
||||
<div className="text-white text-sm truncate">
|
||||
{formatDate(v.created_at)}
|
||||
</div>
|
||||
<div className="text-gray-400 text-xs">
|
||||
{v.size_mb.toFixed(1)} MB
|
||||
</div>
|
||||
</button>
|
||||
<button
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
deleteVideo(v.id);
|
||||
}}
|
||||
className="p-1 text-gray-500 hover:text-red-400 opacity-0 group-hover:opacity-100 transition-opacity"
|
||||
title="删除视频"
|
||||
>
|
||||
🗑️
|
||||
</button>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</main>
|
||||
|
||||
{/* Footer */}
|
||||
<footer className="border-t border-white/10 mt-12">
|
||||
<div className="max-w-6xl mx-auto px-6 py-4 text-center text-gray-500 text-sm">
|
||||
ViGent - 基于 MuseTalk + EdgeTTS
|
||||
</div>
|
||||
</footer>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,28 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
import useSWR from 'swr';
|
||||
import Link from "next/link";
|
||||
import api from "@/lib/axios";
|
||||
|
||||
// SWR fetcher 使用 axios(自动处理 401/403)
|
||||
const fetcher = (url: string) => api.get(url).then((res) => res.data);
|
||||
|
||||
// 动态获取 API 地址:服务端使用 localhost,客户端使用当前域名
|
||||
const API_BASE = typeof window !== 'undefined'
|
||||
? `http://${window.location.hostname}:8006`
|
||||
: 'http://localhost:8006';
|
||||
const API_BASE = typeof window === 'undefined'
|
||||
? 'http://localhost:8006'
|
||||
: '';
|
||||
|
||||
// 格式化日期(避免 Hydration 错误)
|
||||
const formatDate = (timestamp: number) => {
|
||||
const d = new Date(timestamp * 1000);
|
||||
const year = d.getFullYear();
|
||||
const month = String(d.getMonth() + 1).padStart(2, '0');
|
||||
const day = String(d.getDate()).padStart(2, '0');
|
||||
const hour = String(d.getHours()).padStart(2, '0');
|
||||
const minute = String(d.getMinutes()).padStart(2, '0');
|
||||
return `${year}/${month}/${day} ${hour}:${minute}`;
|
||||
};
|
||||
|
||||
interface Account {
|
||||
platform: string;
|
||||
@@ -33,6 +49,7 @@ export default function PublishPage() {
|
||||
const [publishTime, setPublishTime] = useState<string>("");
|
||||
const [qrCodeImage, setQrCodeImage] = useState<string | null>(null);
|
||||
const [qrPlatform, setQrPlatform] = useState<string | null>(null);
|
||||
const [isLoadingQR, setIsLoadingQR] = useState(false);
|
||||
|
||||
// 加载账号和视频列表
|
||||
useEffect(() => {
|
||||
@@ -42,8 +59,7 @@ export default function PublishPage() {
|
||||
|
||||
const fetchAccounts = async () => {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/api/publish/accounts`);
|
||||
const data = await res.json();
|
||||
const { data } = await api.get('/api/publish/accounts');
|
||||
setAccounts(data.accounts || []);
|
||||
} catch (error) {
|
||||
console.error("获取账号失败:", error);
|
||||
@@ -52,20 +68,16 @@ export default function PublishPage() {
|
||||
|
||||
const fetchVideos = async () => {
|
||||
try {
|
||||
// 获取已生成的视频列表 (从 outputs 目录)
|
||||
const res = await fetch(`${API_BASE}/api/videos/tasks`);
|
||||
const data = await res.json();
|
||||
const { data } = await api.get('/api/videos/generated');
|
||||
|
||||
const completedVideos = data.tasks
|
||||
?.filter((t: any) => t.status === "completed")
|
||||
.map((t: any) => ({
|
||||
name: `${t.task_id}_output.mp4`,
|
||||
path: `outputs/${t.task_id}_output.mp4`,
|
||||
})) || [];
|
||||
const videos = (data.videos || []).map((v: any) => ({
|
||||
name: formatDate(v.created_at) + ` (${v.size_mb.toFixed(1)}MB)`,
|
||||
path: v.path.startsWith('/') ? v.path.slice(1) : v.path,
|
||||
}));
|
||||
|
||||
setVideos(completedVideos);
|
||||
if (completedVideos.length > 0) {
|
||||
setSelectedVideo(completedVideos[0].path);
|
||||
setVideos(videos);
|
||||
if (videos.length > 0) {
|
||||
setSelectedVideo(videos[0].path);
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("获取视频失败:", error);
|
||||
@@ -93,27 +105,29 @@ export default function PublishPage() {
|
||||
|
||||
for (const platform of selectedPlatforms) {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/api/publish/`, {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
video_path: selectedVideo,
|
||||
platform,
|
||||
title,
|
||||
tags: tagList,
|
||||
description: "",
|
||||
publish_time: scheduleMode === "scheduled" && publishTime
|
||||
? new Date(publishTime).toISOString()
|
||||
: null
|
||||
}),
|
||||
const { data: result } = await api.post('/api/publish', {
|
||||
video_path: selectedVideo,
|
||||
platform,
|
||||
title,
|
||||
tags: tagList,
|
||||
description: "",
|
||||
publish_time: scheduleMode === "scheduled" && publishTime
|
||||
? new Date(publishTime).toISOString()
|
||||
: null
|
||||
});
|
||||
|
||||
const result = await res.json();
|
||||
setPublishResults((prev) => [...prev, result]);
|
||||
} catch (error) {
|
||||
// 发布成功后10秒自动清除结果
|
||||
if (result.success) {
|
||||
setTimeout(() => {
|
||||
setPublishResults((prev) => prev.filter((r) => r !== result));
|
||||
}, 10000);
|
||||
}
|
||||
} catch (error: any) {
|
||||
const message = error.response?.data?.detail || String(error);
|
||||
setPublishResults((prev) => [
|
||||
...prev,
|
||||
{ platform, success: false, message: String(error) },
|
||||
{ platform, success: false, message },
|
||||
]);
|
||||
}
|
||||
}
|
||||
@@ -121,45 +135,71 @@ export default function PublishPage() {
|
||||
setIsPublishing(false);
|
||||
};
|
||||
|
||||
// SWR Polling for Login Status
|
||||
const { data: loginStatus } = useSWR(
|
||||
qrPlatform ? `${API_BASE}/api/publish/login/status/${qrPlatform}` : null,
|
||||
fetcher,
|
||||
{
|
||||
refreshInterval: 2000,
|
||||
onSuccess: (data) => {
|
||||
if (data.success) {
|
||||
setQrCodeImage(null);
|
||||
setQrPlatform(null);
|
||||
alert('✅ 登录成功!');
|
||||
fetchAccounts();
|
||||
}
|
||||
}
|
||||
}
|
||||
);
|
||||
|
||||
// Timeout logic for QR code (business logic: stop after 2 mins)
|
||||
useEffect(() => {
|
||||
let timer: NodeJS.Timeout;
|
||||
if (qrPlatform) {
|
||||
timer = setTimeout(() => {
|
||||
if (qrPlatform) { // Double check active
|
||||
setQrPlatform(null);
|
||||
setQrCodeImage(null);
|
||||
alert('登录超时,请重试');
|
||||
}
|
||||
}, 120000);
|
||||
}
|
||||
return () => clearTimeout(timer);
|
||||
}, [qrPlatform]);
|
||||
|
||||
const handleLogin = async (platform: string) => {
|
||||
setIsLoadingQR(true);
|
||||
setQrPlatform(platform); // 立即显示加载弹窗
|
||||
setQrCodeImage(null); // 清空旧二维码
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/api/publish/login/${platform}`, {
|
||||
method: 'POST'
|
||||
});
|
||||
const result = await res.json();
|
||||
const { data: result } = await api.post(`/api/publish/login/${platform}`);
|
||||
|
||||
if (result.success && result.qr_code) {
|
||||
// 显示二维码
|
||||
setQrCodeImage(result.qr_code);
|
||||
setQrPlatform(platform);
|
||||
|
||||
// 轮询登录状态
|
||||
const checkInterval = setInterval(async () => {
|
||||
const statusRes = await fetch(`${API_BASE}/api/publish/login/status/${platform}`);
|
||||
const statusData = await statusRes.json();
|
||||
|
||||
if (statusData.success) {
|
||||
clearInterval(checkInterval);
|
||||
setQrCodeImage(null);
|
||||
setQrPlatform(null);
|
||||
alert('✅ 登录成功!');
|
||||
fetchAccounts(); // 刷新账号状态
|
||||
}
|
||||
}, 2000); // 每2秒检查一次
|
||||
|
||||
// 2分钟后停止轮询
|
||||
setTimeout(() => {
|
||||
clearInterval(checkInterval);
|
||||
if (qrCodeImage) {
|
||||
setQrCodeImage(null);
|
||||
alert('登录超时,请重试');
|
||||
}
|
||||
}, 120000);
|
||||
} else {
|
||||
setQrPlatform(null);
|
||||
alert(result.message || '登录失败');
|
||||
}
|
||||
} catch (error) {
|
||||
alert(`登录失败: ${error}`);
|
||||
} catch (error: any) {
|
||||
setQrPlatform(null);
|
||||
alert(`登录失败: ${error.response?.data?.detail || error.message}`);
|
||||
} finally {
|
||||
setIsLoadingQR(false);
|
||||
}
|
||||
};
|
||||
|
||||
const handleLogout = async (platform: string) => {
|
||||
if (!confirm('确定要注销登录吗?')) return;
|
||||
try {
|
||||
const { data: result } = await api.post(`/api/publish/logout/${platform}`);
|
||||
if (result.success) {
|
||||
alert('已注销');
|
||||
fetchAccounts();
|
||||
} else {
|
||||
alert(result.message || '注销失败');
|
||||
}
|
||||
} catch (error: any) {
|
||||
alert(`注销失败: ${error.response?.data?.detail || error.message}`);
|
||||
}
|
||||
};
|
||||
|
||||
@@ -172,22 +212,31 @@ export default function PublishPage() {
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="min-h-screen bg-gradient-to-br from-gray-900 via-purple-900 to-gray-900">
|
||||
<div className="min-h-dvh">
|
||||
{/* QR码弹窗 */}
|
||||
{qrCodeImage && (
|
||||
{qrPlatform && (
|
||||
<div className="fixed inset-0 bg-black/80 flex items-center justify-center z-50">
|
||||
<div className="bg-white rounded-2xl p-8 max-w-md">
|
||||
<div className="bg-white rounded-2xl p-8 max-w-md min-w-[320px]">
|
||||
<h2 className="text-2xl font-bold mb-4 text-center">🔐 扫码登录 {qrPlatform}</h2>
|
||||
<img
|
||||
src={`data:image/png;base64,${qrCodeImage}`}
|
||||
alt="QR Code"
|
||||
className="w-full h-auto"
|
||||
/>
|
||||
<p className="text-center text-gray-600 mt-4">
|
||||
请使用手机扫码登录
|
||||
</p>
|
||||
{isLoadingQR ? (
|
||||
<div className="flex flex-col items-center py-8">
|
||||
<div className="animate-spin w-16 h-16 border-4 border-purple-500 border-t-transparent rounded-full" />
|
||||
<p className="text-gray-600 mt-4">正在获取二维码...</p>
|
||||
</div>
|
||||
) : qrCodeImage ? (
|
||||
<>
|
||||
<img
|
||||
src={`data:image/png;base64,${qrCodeImage}`}
|
||||
alt="QR Code"
|
||||
className="w-full h-auto"
|
||||
/>
|
||||
<p className="text-center text-gray-600 mt-4">
|
||||
请使用手机扫码登录
|
||||
</p>
|
||||
</>
|
||||
) : null}
|
||||
<button
|
||||
onClick={() => setQrCodeImage(null)}
|
||||
onClick={() => { setQrCodeImage(null); setQrPlatform(null); }}
|
||||
className="w-full mt-4 px-4 py-2 bg-gray-200 rounded-lg hover:bg-gray-300"
|
||||
>
|
||||
取消
|
||||
@@ -198,21 +247,34 @@ export default function PublishPage() {
|
||||
|
||||
{/* Header - 统一样式 */}
|
||||
<header className="border-b border-white/10 bg-black/20 backdrop-blur-sm">
|
||||
<div className="max-w-6xl mx-auto px-6 py-4 flex items-center justify-between">
|
||||
<Link href="/" className="text-2xl font-bold text-white flex items-center gap-3 hover:opacity-80 transition-opacity">
|
||||
<span className="text-4xl">🎬</span>
|
||||
<div className="max-w-6xl mx-auto px-4 sm:px-6 py-3 sm:py-4 flex items-center justify-between">
|
||||
<Link href="/" className="text-xl sm:text-2xl font-bold text-white flex items-center gap-2 sm:gap-3 hover:opacity-80 transition-opacity">
|
||||
<span className="text-3xl sm:text-4xl">🎬</span>
|
||||
ViGent
|
||||
</Link>
|
||||
<div className="flex items-center gap-4">
|
||||
<div className="flex items-center gap-1 sm:gap-4">
|
||||
<Link
|
||||
href="/"
|
||||
className="px-4 py-2 bg-white/10 hover:bg-white/20 text-white rounded-lg transition-colors"
|
||||
className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-white/10 hover:bg-white/20 text-white rounded-lg transition-colors"
|
||||
>
|
||||
视频生成
|
||||
返回创作
|
||||
</Link>
|
||||
<span className="px-4 py-2 bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg font-semibold">
|
||||
<span className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-gradient-to-r from-purple-600 to-pink-600 text-white rounded-lg font-semibold">
|
||||
发布管理
|
||||
</span>
|
||||
<button
|
||||
onClick={async () => {
|
||||
if (confirm('确定要退出登录吗?')) {
|
||||
try {
|
||||
await api.post('/api/auth/logout');
|
||||
} catch (e) { }
|
||||
window.location.href = '/login';
|
||||
}
|
||||
}}
|
||||
className="px-2 sm:px-4 py-1 sm:py-2 text-sm sm:text-base bg-red-500/10 hover:bg-red-500/20 text-red-200 rounded-lg transition-colors"
|
||||
>
|
||||
退出
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</header>
|
||||
@@ -250,12 +312,31 @@ export default function PublishPage() {
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<button
|
||||
onClick={() => handleLogin(account.platform)}
|
||||
className="px-3 py-1 bg-purple-600 hover:bg-purple-700 text-white text-sm rounded-lg transition-colors"
|
||||
>
|
||||
🔐 扫码登录
|
||||
</button>
|
||||
<div className="flex gap-2">
|
||||
{account.logged_in ? (
|
||||
<>
|
||||
<button
|
||||
onClick={() => handleLogin(account.platform)}
|
||||
className="px-3 py-1 bg-white/10 hover:bg-white/20 text-white text-sm rounded-lg transition-colors"
|
||||
>
|
||||
↻ 重新登录
|
||||
</button>
|
||||
<button
|
||||
onClick={() => handleLogout(account.platform)}
|
||||
className="px-3 py-1 bg-red-500/80 hover:bg-red-600 text-white text-sm rounded-lg transition-colors"
|
||||
>
|
||||
注销
|
||||
</button>
|
||||
</>
|
||||
) : (
|
||||
<button
|
||||
onClick={() => handleLogin(account.platform)}
|
||||
className="px-3 py-1 bg-purple-600 hover:bg-purple-700 text-white text-sm rounded-lg transition-colors"
|
||||
>
|
||||
🔐 扫码登录
|
||||
</button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
@@ -281,7 +362,7 @@ export default function PublishPage() {
|
||||
<select
|
||||
value={selectedVideo}
|
||||
onChange={(e) => setSelectedVideo(e.target.value)}
|
||||
className="w-full p-3 bg-black/30 border border-white/10 rounded-xl text-white"
|
||||
className="w-full p-3 bg-black/30 border border-white/10 rounded-xl text-white custom-select cursor-pointer hover:border-purple-500/50 transition-colors"
|
||||
>
|
||||
{videos.map((v) => (
|
||||
<option key={v.path} value={v.path}>
|
||||
@@ -321,40 +402,6 @@ export default function PublishPage() {
|
||||
className="w-full p-3 bg-black/30 border border-white/10 rounded-xl text-white placeholder-gray-500"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<label className="block text-gray-400 text-sm mb-2">
|
||||
发布时间
|
||||
</label>
|
||||
<div className="flex gap-3 mb-3">
|
||||
<button
|
||||
onClick={() => setScheduleMode("now")}
|
||||
className={`flex-1 px-4 py-2 rounded-lg font-medium transition-colors ${scheduleMode === "now"
|
||||
? "bg-purple-600 text-white"
|
||||
: "bg-black/30 text-gray-400 hover:bg-black/50"
|
||||
}`}
|
||||
>
|
||||
⚡ 立即发布
|
||||
</button>
|
||||
<button
|
||||
onClick={() => setScheduleMode("scheduled")}
|
||||
className={`flex-1 px-4 py-2 rounded-lg font-medium transition-colors ${scheduleMode === "scheduled"
|
||||
? "bg-purple-600 text-white"
|
||||
: "bg-black/30 text-gray-400 hover:bg-black/50"
|
||||
}`}
|
||||
>
|
||||
⏰ 定时发布
|
||||
</button>
|
||||
</div>
|
||||
{scheduleMode === "scheduled" && (
|
||||
<input
|
||||
type="datetime-local"
|
||||
value={publishTime}
|
||||
onChange={(e) => setPublishTime(e.target.value)}
|
||||
min={new Date().toISOString().slice(0, 16)}
|
||||
className="w-full p-3 bg-black/30 border border-white/10 rounded-xl text-white"
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -389,17 +436,61 @@ export default function PublishPage() {
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 发布按钮 */}
|
||||
<button
|
||||
onClick={handlePublish}
|
||||
disabled={isPublishing || selectedPlatforms.length === 0}
|
||||
className={`w-full py-4 rounded-xl font-bold text-lg transition-all ${isPublishing || selectedPlatforms.length === 0
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: "bg-gradient-to-r from-green-600 to-teal-600 hover:from-green-700 hover:to-teal-700 text-white"
|
||||
}`}
|
||||
>
|
||||
{isPublishing ? "发布中..." : "🚀 一键发布"}
|
||||
</button>
|
||||
{/* 发布按钮区域 */}
|
||||
<div className="space-y-3">
|
||||
<div className="flex gap-3">
|
||||
{/* 立即发布 - 占 3/4 */}
|
||||
<button
|
||||
onClick={() => {
|
||||
setScheduleMode("now");
|
||||
handlePublish();
|
||||
}}
|
||||
disabled={isPublishing || selectedPlatforms.length === 0}
|
||||
className={`flex-[3] py-4 rounded-xl font-bold text-lg transition-all ${isPublishing || selectedPlatforms.length === 0
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: "bg-gradient-to-r from-green-600 to-teal-600 hover:from-green-700 hover:to-teal-700 text-white"
|
||||
}`}
|
||||
>
|
||||
{isPublishing && scheduleMode === "now" ? "发布中..." : "🚀 立即发布"}
|
||||
</button>
|
||||
{/* 定时发布 - 占 1/4 */}
|
||||
<button
|
||||
onClick={() => setScheduleMode(scheduleMode === "scheduled" ? "now" : "scheduled")}
|
||||
disabled={isPublishing || selectedPlatforms.length === 0}
|
||||
className={`flex-1 py-4 rounded-xl font-bold text-base transition-all ${isPublishing || selectedPlatforms.length === 0
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: scheduleMode === "scheduled"
|
||||
? "bg-purple-600 text-white"
|
||||
: "bg-white/10 hover:bg-white/20 text-white"
|
||||
}`}
|
||||
>
|
||||
⏰ 定时
|
||||
</button>
|
||||
</div>
|
||||
|
||||
{/* 定时发布时间选择器 */}
|
||||
{scheduleMode === "scheduled" && (
|
||||
<div className="flex gap-3 items-center">
|
||||
<input
|
||||
type="datetime-local"
|
||||
value={publishTime}
|
||||
onChange={(e) => setPublishTime(e.target.value)}
|
||||
min={new Date().toISOString().slice(0, 16)}
|
||||
className="flex-1 p-3 bg-black/30 border border-white/10 rounded-xl text-white"
|
||||
/>
|
||||
<button
|
||||
onClick={handlePublish}
|
||||
disabled={isPublishing || selectedPlatforms.length === 0 || !publishTime}
|
||||
className={`px-6 py-3 rounded-xl font-bold transition-all ${isPublishing || selectedPlatforms.length === 0 || !publishTime
|
||||
? "bg-gray-600 cursor-not-allowed text-gray-400"
|
||||
: "bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white"
|
||||
}`}
|
||||
>
|
||||
{isPublishing && scheduleMode === "scheduled" ? "设置中..." : "确认定时"}
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* 发布结果 */}
|
||||
{publishResults.length > 0 && (
|
||||
@@ -417,6 +508,11 @@ export default function PublishPage() {
|
||||
<span className="text-white">
|
||||
{platformIcons[result.platform]} {result.message}
|
||||
</span>
|
||||
{result.success && (
|
||||
<p className="text-green-400/80 text-sm mt-1">
|
||||
⏳ 审核一般需要几分钟,请耐心等待
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
158
frontend/src/app/register/page.tsx
Normal file
158
frontend/src/app/register/page.tsx
Normal file
@@ -0,0 +1,158 @@
|
||||
'use client';
|
||||
|
||||
import { useState } from 'react';
|
||||
import { useRouter } from 'next/navigation';
|
||||
import { register } from '@/lib/auth';
|
||||
|
||||
export default function RegisterPage() {
|
||||
const router = useRouter();
|
||||
const [email, setEmail] = useState('');
|
||||
const [password, setPassword] = useState('');
|
||||
const [confirmPassword, setConfirmPassword] = useState('');
|
||||
const [username, setUsername] = useState('');
|
||||
const [error, setError] = useState('');
|
||||
const [success, setSuccess] = useState(false);
|
||||
const [loading, setLoading] = useState(false);
|
||||
|
||||
const handleSubmit = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
setError('');
|
||||
|
||||
if (password !== confirmPassword) {
|
||||
setError('两次输入的密码不一致');
|
||||
return;
|
||||
}
|
||||
|
||||
if (password.length < 6) {
|
||||
setError('密码长度至少 6 位');
|
||||
return;
|
||||
}
|
||||
|
||||
setLoading(true);
|
||||
|
||||
try {
|
||||
const result = await register(email, password, username || undefined);
|
||||
if (result.success) {
|
||||
setSuccess(true);
|
||||
} else {
|
||||
setError(result.message || '注册失败');
|
||||
}
|
||||
} catch (err) {
|
||||
setError('网络错误,请稍后重试');
|
||||
} finally {
|
||||
setLoading(false);
|
||||
}
|
||||
};
|
||||
|
||||
if (success) {
|
||||
return (
|
||||
<div className="min-h-dvh flex items-center justify-center">
|
||||
<div className="w-full max-w-md p-8 bg-white/10 backdrop-blur-lg rounded-2xl shadow-2xl border border-white/20 text-center">
|
||||
<div className="mb-6">
|
||||
<svg className="w-16 h-16 mx-auto text-green-400" fill="none" stroke="currentColor" viewBox="0 0 24 24">
|
||||
<path strokeLinecap="round" strokeLinejoin="round" strokeWidth={2} d="M9 12l2 2 4-4m6 2a9 9 0 11-18 0 9 9 0 0118 0z" />
|
||||
</svg>
|
||||
</div>
|
||||
<h2 className="text-2xl font-bold text-white mb-4">注册成功!</h2>
|
||||
<p className="text-gray-300 mb-6">
|
||||
您的账号已创建,请等待管理员审核激活后即可登录。
|
||||
</p>
|
||||
<a
|
||||
href="/login"
|
||||
className="inline-block py-3 px-6 bg-gradient-to-r from-purple-600 to-pink-600 text-white font-semibold rounded-lg"
|
||||
>
|
||||
返回登录
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<div className="min-h-dvh flex items-center justify-center">
|
||||
<div className="w-full max-w-md p-8 bg-white/10 backdrop-blur-lg rounded-2xl shadow-2xl border border-white/20">
|
||||
<div className="text-center mb-8">
|
||||
<h1 className="text-3xl font-bold text-white mb-2">注册账号</h1>
|
||||
<p className="text-gray-300">创建您的 ViGent 账号</p>
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit} className="space-y-5">
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-200 mb-2">
|
||||
邮箱 <span className="text-red-400">*</span>
|
||||
</label>
|
||||
<input
|
||||
type="email"
|
||||
value={email}
|
||||
onChange={(e) => setEmail(e.target.value)}
|
||||
required
|
||||
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500"
|
||||
placeholder="your@email.com"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-200 mb-2">
|
||||
用户名 <span className="text-gray-500">(可选)</span>
|
||||
</label>
|
||||
<input
|
||||
type="text"
|
||||
value={username}
|
||||
onChange={(e) => setUsername(e.target.value)}
|
||||
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500"
|
||||
placeholder="您的昵称"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-200 mb-2">
|
||||
密码 <span className="text-red-400">*</span>
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
value={password}
|
||||
onChange={(e) => setPassword(e.target.value)}
|
||||
required
|
||||
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500"
|
||||
placeholder="至少 6 位"
|
||||
/>
|
||||
</div>
|
||||
|
||||
<div>
|
||||
<label className="block text-sm font-medium text-gray-200 mb-2">
|
||||
确认密码 <span className="text-red-400">*</span>
|
||||
</label>
|
||||
<input
|
||||
type="password"
|
||||
value={confirmPassword}
|
||||
onChange={(e) => setConfirmPassword(e.target.value)}
|
||||
required
|
||||
className="w-full px-4 py-3 bg-white/5 border border-white/10 rounded-lg text-white placeholder-gray-400 focus:outline-none focus:ring-2 focus:ring-purple-500"
|
||||
placeholder="再次输入密码"
|
||||
/>
|
||||
</div>
|
||||
|
||||
{error && (
|
||||
<div className="p-3 bg-red-500/20 border border-red-500/50 rounded-lg text-red-200 text-sm">
|
||||
{error}
|
||||
</div>
|
||||
)}
|
||||
|
||||
<button
|
||||
type="submit"
|
||||
disabled={loading}
|
||||
className="w-full py-3 px-4 bg-gradient-to-r from-purple-600 to-pink-600 hover:from-purple-700 hover:to-pink-700 text-white font-semibold rounded-lg shadow-lg transition-all duration-200 disabled:opacity-50"
|
||||
>
|
||||
{loading ? '注册中...' : '注册'}
|
||||
</button>
|
||||
</form>
|
||||
|
||||
<div className="mt-6 text-center">
|
||||
<a href="/login" className="text-purple-300 hover:text-purple-200 text-sm">
|
||||
已有账号?立即登录
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
89
frontend/src/lib/auth.ts
Normal file
89
frontend/src/lib/auth.ts
Normal file
@@ -0,0 +1,89 @@
|
||||
/**
|
||||
* 认证工具函数
|
||||
*/
|
||||
|
||||
const API_BASE = typeof window === 'undefined'
|
||||
? (process.env.NEXT_PUBLIC_API_URL || 'http://localhost:8006')
|
||||
: '';
|
||||
|
||||
export interface User {
|
||||
id: string;
|
||||
email: string;
|
||||
username: string | null;
|
||||
role: string;
|
||||
is_active: boolean;
|
||||
}
|
||||
|
||||
export interface AuthResponse {
|
||||
success: boolean;
|
||||
message: string;
|
||||
user?: User;
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户注册
|
||||
*/
|
||||
export async function register(email: string, password: string, username?: string): Promise<AuthResponse> {
|
||||
const res = await fetch(`${API_BASE}/api/auth/register`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include',
|
||||
body: JSON.stringify({ email, password, username })
|
||||
});
|
||||
return res.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户登录
|
||||
*/
|
||||
export async function login(email: string, password: string): Promise<AuthResponse> {
|
||||
const res = await fetch(`${API_BASE}/api/auth/login`, {
|
||||
method: 'POST',
|
||||
headers: { 'Content-Type': 'application/json' },
|
||||
credentials: 'include',
|
||||
body: JSON.stringify({ email, password })
|
||||
});
|
||||
return res.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* 用户登出
|
||||
*/
|
||||
export async function logout(): Promise<AuthResponse> {
|
||||
const res = await fetch(`${API_BASE}/api/auth/logout`, {
|
||||
method: 'POST',
|
||||
credentials: 'include'
|
||||
});
|
||||
return res.json();
|
||||
}
|
||||
|
||||
/**
|
||||
* 获取当前用户
|
||||
*/
|
||||
export async function getCurrentUser(): Promise<User | null> {
|
||||
try {
|
||||
const res = await fetch(`${API_BASE}/api/auth/me`, {
|
||||
credentials: 'include'
|
||||
});
|
||||
if (!res.ok) return null;
|
||||
return res.json();
|
||||
} catch {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否已登录
|
||||
*/
|
||||
export async function isAuthenticated(): Promise<boolean> {
|
||||
const user = await getCurrentUser();
|
||||
return user !== null;
|
||||
}
|
||||
|
||||
/**
|
||||
* 检查是否是管理员
|
||||
*/
|
||||
export async function isAdmin(): Promise<boolean> {
|
||||
const user = await getCurrentUser();
|
||||
return user?.role === 'admin';
|
||||
}
|
||||
50
frontend/src/lib/axios.ts
Normal file
50
frontend/src/lib/axios.ts
Normal file
@@ -0,0 +1,50 @@
|
||||
/**
|
||||
* Axios 实例配置
|
||||
* 全局拦截 401/403 响应,自动跳转登录页
|
||||
*/
|
||||
import axios from 'axios';
|
||||
|
||||
// 动态获取 API 地址:服务端使用 localhost,客户端使用当前域名
|
||||
const API_BASE = typeof window === 'undefined'
|
||||
? 'http://localhost:8006'
|
||||
: '';
|
||||
|
||||
// 防止重复跳转
|
||||
let isRedirecting = false;
|
||||
|
||||
// 创建 axios 实例
|
||||
const api = axios.create({
|
||||
baseURL: API_BASE,
|
||||
withCredentials: true, // 自动携带 cookie
|
||||
headers: {
|
||||
'Content-Type': 'application/json',
|
||||
},
|
||||
});
|
||||
|
||||
// 响应拦截器 - 全局处理 401/403
|
||||
api.interceptors.response.use(
|
||||
(response) => response,
|
||||
async (error) => {
|
||||
const status = error.response?.status;
|
||||
|
||||
if ((status === 401 || status === 403) && !isRedirecting) {
|
||||
isRedirecting = true;
|
||||
|
||||
// 调用 logout API 清除 HttpOnly cookie
|
||||
try {
|
||||
await fetch('/api/auth/logout', { method: 'POST' });
|
||||
} catch (e) {
|
||||
// 忽略错误
|
||||
}
|
||||
|
||||
// 跳转登录页
|
||||
if (typeof window !== 'undefined') {
|
||||
window.location.replace('/login');
|
||||
}
|
||||
}
|
||||
|
||||
return Promise.reject(error);
|
||||
}
|
||||
);
|
||||
|
||||
export default api;
|
||||
33
frontend/src/middleware.ts
Normal file
33
frontend/src/middleware.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import type { NextRequest } from 'next/server';
|
||||
|
||||
// 需要登录才能访问的路径
|
||||
const protectedPaths = ['/', '/publish', '/admin'];
|
||||
|
||||
// 公开路径 (无需登录)
|
||||
const publicPaths = ['/login', '/register'];
|
||||
|
||||
export function middleware(request: NextRequest) {
|
||||
const { pathname } = request.nextUrl;
|
||||
|
||||
// 检查是否有 access_token cookie
|
||||
const token = request.cookies.get('access_token');
|
||||
|
||||
// 访问受保护页面但未登录 → 重定向到登录页
|
||||
if (protectedPaths.some(path => pathname === path || pathname.startsWith(path + '/')) && !token) {
|
||||
const loginUrl = new URL('/login', request.url);
|
||||
loginUrl.searchParams.set('from', pathname);
|
||||
return NextResponse.redirect(loginUrl);
|
||||
}
|
||||
|
||||
// 已登录用户访问登录/注册页 → 重定向到首页
|
||||
if (publicPaths.includes(pathname) && token) {
|
||||
return NextResponse.redirect(new URL('/', request.url));
|
||||
}
|
||||
|
||||
return NextResponse.next();
|
||||
}
|
||||
|
||||
export const config = {
|
||||
matcher: ['/', '/publish/:path*', '/admin/:path*', '/login', '/register']
|
||||
};
|
||||
33
frontend/src/proxy.ts
Normal file
33
frontend/src/proxy.ts
Normal file
@@ -0,0 +1,33 @@
|
||||
import { NextResponse } from 'next/server';
|
||||
import type { NextRequest } from 'next/server';
|
||||
|
||||
// 需要登录才能访问的路径
|
||||
const protectedPaths = ['/', '/publish', '/admin'];
|
||||
|
||||
// 公开路径 (无需登录)
|
||||
const publicPaths = ['/login', '/register'];
|
||||
|
||||
export function proxy(request: NextRequest) {
|
||||
const { pathname } = request.nextUrl;
|
||||
|
||||
// 检查是否有 access_token cookie
|
||||
const token = request.cookies.get('access_token');
|
||||
|
||||
// 访问受保护页面但未登录 → 重定向到登录页
|
||||
if (protectedPaths.some(path => pathname === path || pathname.startsWith(path + '/')) && !token) {
|
||||
const loginUrl = new URL('/login', request.url);
|
||||
loginUrl.searchParams.set('from', pathname);
|
||||
return NextResponse.redirect(loginUrl);
|
||||
}
|
||||
|
||||
// 已登录用户访问登录/注册页 → 重定向到首页
|
||||
if (publicPaths.includes(pathname) && token) {
|
||||
return NextResponse.redirect(new URL('/', request.url));
|
||||
}
|
||||
|
||||
return NextResponse.next();
|
||||
}
|
||||
|
||||
export const config = {
|
||||
matcher: ['/', '/publish/:path*', '/admin/:path*', '/login', '/register']
|
||||
};
|
||||
@@ -14,6 +14,12 @@
|
||||
|
||||
import argparse
|
||||
import os
|
||||
|
||||
# --- 性能优化: 限制 CPU 线程数 ---
|
||||
os.environ["OMP_NUM_THREADS"] = "8"
|
||||
os.environ["MKL_NUM_THREADS"] = "8"
|
||||
os.environ["TORCH_NUM_THREADS"] = "8"
|
||||
|
||||
from omegaconf import OmegaConf
|
||||
import torch
|
||||
from diffusers import AutoencoderKL, DDIMScheduler
|
||||
|
||||
@@ -37,6 +37,14 @@ def load_gpu_config():
|
||||
|
||||
load_gpu_config()
|
||||
|
||||
# --- 性能优化: 限制 CPU 线程数 ---
|
||||
# 防止 PyTorch 默认占用所有 CPU 核心 (56线程) 导致系统卡顿
|
||||
# 预留资源给 Backend, Frontend 和 SSH
|
||||
os.environ["OMP_NUM_THREADS"] = "8"
|
||||
os.environ["MKL_NUM_THREADS"] = "8"
|
||||
os.environ["TORCH_NUM_THREADS"] = "8"
|
||||
print("⚙️ 已限制 PyTorch CPU 线程数为 8,防止系统卡顿")
|
||||
|
||||
import torch
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI, HTTPException
|
||||
|
||||
24
models/Qwen3-TTS/.gitignore
vendored
Normal file
24
models/Qwen3-TTS/.gitignore
vendored
Normal file
@@ -0,0 +1,24 @@
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
*.so
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
.idea/
|
||||
.vscode/
|
||||
venv/
|
||||
env/
|
||||
201
models/Qwen3-TTS/LICENSE
Normal file
201
models/Qwen3-TTS/LICENSE
Normal file
@@ -0,0 +1,201 @@
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2026 Alibaba Cloud
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
13
models/Qwen3-TTS/MANIFEST.in
Normal file
13
models/Qwen3-TTS/MANIFEST.in
Normal file
@@ -0,0 +1,13 @@
|
||||
global-exclude *
|
||||
|
||||
recursive-include qwen_tts *.py *.pyi py.typed
|
||||
recursive-include qwen_tts *.npz
|
||||
|
||||
include LICENSE
|
||||
include MANIFEST.in
|
||||
include pyproject.toml
|
||||
|
||||
prune assets
|
||||
prune examples
|
||||
prune finetuning
|
||||
prune qwen_tts.egg-info
|
||||
1361
models/Qwen3-TTS/README.md
Normal file
1361
models/Qwen3-TTS/README.md
Normal file
File diff suppressed because it is too large
Load Diff
121
models/Qwen3-TTS/finetuning/README.md
Normal file
121
models/Qwen3-TTS/finetuning/README.md
Normal file
@@ -0,0 +1,121 @@
|
||||
## Fine Tuning Qwen3-TTS-12Hz-1.7B/0.6B-Base
|
||||
|
||||
The Qwen3-TTS-12Hz-1.7B/0.6B-Base model series currently supports single-speaker fine-tuning. Please run `pip install qwen-tts` first, then run the command below:
|
||||
|
||||
```
|
||||
git clone https://github.com/QwenLM/Qwen3-TTS.git
|
||||
cd Qwen3-TTS/finetuning
|
||||
```
|
||||
|
||||
Then follow the steps below to complete the entire fine-tuning workflow. Multi-speaker fine-tuning and other advanced fine-tuning features will be supported in future releases.
|
||||
|
||||
### 1) Input JSONL format
|
||||
|
||||
Prepare your training file as a JSONL (one JSON object per line). Each line must contain:
|
||||
|
||||
- `audio`: path to the target training audio (wav)
|
||||
- `text`: transcript corresponding to `audio`
|
||||
- `ref_audio`: path to the reference speaker audio (wav)
|
||||
|
||||
Example:
|
||||
```jsonl
|
||||
{"audio":"./data/utt0001.wav","text":"其实我真的有发现,我是一个特别善于观察别人情绪的人。","ref_audio":"./data/ref.wav"}
|
||||
{"audio":"./data/utt0002.wav","text":"She said she would be here by noon.","ref_audio":"./data/ref.wav"}
|
||||
```
|
||||
|
||||
`ref_audio` recommendation:
|
||||
- Strongly recommended: use the same `ref_audio` for all samples.
|
||||
- Keeping `ref_audio` identical across the dataset usually improves speaker consistency and stability during generation.
|
||||
|
||||
|
||||
### 2) Prepare data (extract `audio_codes`)
|
||||
|
||||
Convert `train_raw.jsonl` into a training JSONL that includes `audio_codes`:
|
||||
|
||||
```bash
|
||||
python prepare_data.py \
|
||||
--device cuda:0 \
|
||||
--tokenizer_model_path Qwen/Qwen3-TTS-Tokenizer-12Hz \
|
||||
--input_jsonl train_raw.jsonl \
|
||||
--output_jsonl train_with_codes.jsonl
|
||||
```
|
||||
|
||||
|
||||
### 3) Fine-tune
|
||||
|
||||
Run SFT using the prepared JSONL:
|
||||
|
||||
```bash
|
||||
python sft_12hz.py \
|
||||
--init_model_path Qwen/Qwen3-TTS-12Hz-1.7B-Base \
|
||||
--output_model_path output \
|
||||
--train_jsonl train_with_codes.jsonl \
|
||||
--batch_size 2 \
|
||||
--lr 2e-5 \
|
||||
--num_epochs 3 \
|
||||
--speaker_name speaker_test
|
||||
```
|
||||
|
||||
Checkpoints will be written to:
|
||||
- `output/checkpoint-epoch-0`
|
||||
- `output/checkpoint-epoch-1`
|
||||
- `output/checkpoint-epoch-2`
|
||||
- ...
|
||||
|
||||
|
||||
### 4) Quick inference test
|
||||
|
||||
```python
|
||||
import torch
|
||||
import soundfile as sf
|
||||
from qwen_tts import Qwen3TTSModel
|
||||
|
||||
device = "cuda:0"
|
||||
tts = Qwen3TTSModel.from_pretrained(
|
||||
"output/checkpoint-epoch-2",
|
||||
device_map=device,
|
||||
dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
|
||||
wavs, sr = tts.generate_custom_voice(
|
||||
text="She said she would be here by noon.",
|
||||
speaker="speaker_test",
|
||||
)
|
||||
sf.write("output.wav", wavs[0], sr)
|
||||
```
|
||||
|
||||
### One-click shell script example
|
||||
|
||||
```bash
|
||||
#!/usr/bin/env bash
|
||||
set -e
|
||||
|
||||
DEVICE="cuda:0"
|
||||
TOKENIZER_MODEL_PATH="Qwen/Qwen3-TTS-Tokenizer-12Hz"
|
||||
INIT_MODEL_PATH="Qwen/Qwen3-TTS-12Hz-1.7B-Base"
|
||||
|
||||
RAW_JSONL="train_raw.jsonl"
|
||||
TRAIN_JSONL="train_with_codes.jsonl"
|
||||
OUTPUT_DIR="output"
|
||||
|
||||
BATCH_SIZE=2
|
||||
LR=2e-5
|
||||
EPOCHS=3
|
||||
SPEAKER_NAME="speaker_1"
|
||||
|
||||
python prepare_data.py \
|
||||
--device ${DEVICE} \
|
||||
--tokenizer_model_path ${TOKENIZER_MODEL_PATH} \
|
||||
--input_jsonl ${RAW_JSONL} \
|
||||
--output_jsonl ${TRAIN_JSONL}
|
||||
|
||||
python sft_12hz.py \
|
||||
--init_model_path ${INIT_MODEL_PATH} \
|
||||
--output_model_path ${OUTPUT_DIR} \
|
||||
--train_jsonl ${TRAIN_JSONL} \
|
||||
--batch_size ${BATCH_SIZE} \
|
||||
--lr ${LR} \
|
||||
--num_epochs ${EPOCHS} \
|
||||
--speaker_name ${SPEAKER_NAME}
|
||||
```
|
||||
218
models/Qwen3-TTS/finetuning/dataset.py
Normal file
218
models/Qwen3-TTS/finetuning/dataset.py
Normal file
@@ -0,0 +1,218 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from typing import Any, List, Tuple, Union
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import torch
|
||||
from qwen_tts.core.models.configuration_qwen3_tts import Qwen3TTSConfig
|
||||
from qwen_tts.core.models.modeling_qwen3_tts import mel_spectrogram
|
||||
from torch.utils.data import Dataset
|
||||
|
||||
AudioLike = Union[
|
||||
str, # wav path, URL, base64
|
||||
np.ndarray, # waveform (requires sr)
|
||||
Tuple[np.ndarray, int], # (waveform, sr)
|
||||
]
|
||||
|
||||
MaybeList = Union[Any, List[Any]]
|
||||
|
||||
class TTSDataset(Dataset):
|
||||
def __init__(self, data_list, processor, config:Qwen3TTSConfig, lag_num = -1):
|
||||
self.data_list = data_list
|
||||
self.processor = processor
|
||||
self.lag_num = lag_num
|
||||
self.config = config
|
||||
|
||||
def __len__(self):
|
||||
return len(self.data_list)
|
||||
|
||||
def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
|
||||
|
||||
audio, sr = librosa.load(x, sr=None, mono=True)
|
||||
|
||||
if audio.ndim > 1:
|
||||
audio = np.mean(audio, axis=-1)
|
||||
|
||||
return audio.astype(np.float32), int(sr)
|
||||
|
||||
def _normalize_audio_inputs(self, audios: Union[AudioLike, List[AudioLike]]) -> List[Tuple[np.ndarray, int]]:
|
||||
"""
|
||||
Normalize audio inputs into a list of (waveform, sr).
|
||||
|
||||
Supported forms:
|
||||
- str: wav path / URL / base64 audio string
|
||||
- np.ndarray: waveform (NOT allowed alone here because sr is unknown)
|
||||
- (np.ndarray, sr): waveform + sampling rate
|
||||
- list of the above
|
||||
|
||||
Args:
|
||||
audios:
|
||||
Audio input(s).
|
||||
|
||||
Returns:
|
||||
List[Tuple[np.ndarray, int]]:
|
||||
List of (float32 waveform, original sr).
|
||||
|
||||
Raises:
|
||||
ValueError: If a numpy waveform is provided without sr.
|
||||
"""
|
||||
if isinstance(audios, list):
|
||||
items = audios
|
||||
else:
|
||||
items = [audios]
|
||||
|
||||
out: List[Tuple[np.ndarray, int]] = []
|
||||
for a in items:
|
||||
if isinstance(a, str):
|
||||
out.append(self._load_audio_to_np(a))
|
||||
elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
|
||||
out.append((a[0].astype(np.float32), int(a[1])))
|
||||
elif isinstance(a, np.ndarray):
|
||||
raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
|
||||
else:
|
||||
raise TypeError(f"Unsupported audio input type: {type(a)}")
|
||||
return out
|
||||
|
||||
|
||||
def _build_assistant_text(self, text: str) -> str:
|
||||
return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def _ensure_list(self, x: MaybeList) -> List[Any]:
|
||||
return x if isinstance(x, list) else [x]
|
||||
|
||||
def _tokenize_texts(self, text) -> List[torch.Tensor]:
|
||||
input = self.processor(text=text, return_tensors="pt", padding=True)
|
||||
input_id = input["input_ids"]
|
||||
input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id
|
||||
return input_id
|
||||
|
||||
@torch.inference_mode()
|
||||
def extract_mels(self, audio, sr):
|
||||
assert sr == 24000, "Only support 24kHz audio"
|
||||
mels = mel_spectrogram(
|
||||
torch.from_numpy(audio).unsqueeze(0),
|
||||
n_fft=1024,
|
||||
num_mels=128,
|
||||
sampling_rate=24000,
|
||||
hop_size=256,
|
||||
win_size=1024,
|
||||
fmin=0,
|
||||
fmax=12000
|
||||
).transpose(1, 2)
|
||||
return mels
|
||||
|
||||
|
||||
|
||||
def __getitem__(self, idx):
|
||||
item = self.data_list[idx]
|
||||
|
||||
audio_path = item["audio"]
|
||||
text = item["text"]
|
||||
audio_codes = item["audio_codes"]
|
||||
language = item.get('language','Auto')
|
||||
ref_audio_path = item['ref_audio']
|
||||
|
||||
text = self._build_assistant_text(text)
|
||||
text_ids = self._tokenize_texts(text)
|
||||
|
||||
audio_codes = torch.tensor(audio_codes, dtype=torch.long)
|
||||
|
||||
ref_audio_list = self._ensure_list(ref_audio_path)
|
||||
normalized = self._normalize_audio_inputs(ref_audio_list)
|
||||
wav,sr = normalized[0]
|
||||
|
||||
ref_mel = self.extract_mels(audio=wav, sr=sr)
|
||||
|
||||
return {
|
||||
"text_ids": text_ids[:,:-5], # 1 , t
|
||||
"audio_codes":audio_codes, # t, 16
|
||||
"ref_mel":ref_mel
|
||||
}
|
||||
|
||||
def collate_fn(self, batch):
|
||||
assert self.lag_num == -1
|
||||
|
||||
item_length = [b['text_ids'].shape[1] + b['audio_codes'].shape[0] for b in batch]
|
||||
max_length = max(item_length) + 8
|
||||
b,t = len(batch),max_length
|
||||
|
||||
input_ids = torch.zeros((b,t,2),dtype=torch.long)
|
||||
codec_ids = torch.zeros((b,t,16),dtype=torch.long)
|
||||
text_embedding_mask = torch.zeros((b,t),dtype=torch.bool)
|
||||
codec_embedding_mask = torch.zeros((b,t),dtype=torch.bool)
|
||||
codec_mask = torch.zeros((b,t),dtype=torch.bool)
|
||||
attention_mask = torch.zeros((b,t),dtype=torch.long)
|
||||
codec_0_labels = torch.full((b, t), -100, dtype=torch.long)
|
||||
|
||||
for i,data in enumerate(batch):
|
||||
text_ids = data['text_ids']
|
||||
audio_codec_0 = data['audio_codes'][:,0]
|
||||
audio_codecs = data['audio_codes']
|
||||
|
||||
text_ids_len = text_ids.shape[1]
|
||||
codec_ids_len = audio_codec_0.shape[0]
|
||||
|
||||
# text channel
|
||||
input_ids[i, :3, 0] = text_ids[0,:3]
|
||||
input_ids[i, 3:7, 0] = self.config.tts_pad_token_id
|
||||
input_ids[i, 7, 0] = self.config.tts_bos_token_id
|
||||
input_ids[i, 8:8+text_ids_len-3, 0] = text_ids[0,3:]
|
||||
input_ids[i, 8+text_ids_len-3, 0] = self.config.tts_eos_token_id
|
||||
input_ids[i, 8+text_ids_len-2:8+text_ids_len+codec_ids_len , 0] = self.config.tts_pad_token_id
|
||||
text_embedding_mask[i, :8+text_ids_len+codec_ids_len] = True
|
||||
|
||||
# codec channel
|
||||
# input_ids[i, :3, 1] = 0
|
||||
input_ids[i, 3:8 ,1] = torch.tensor(
|
||||
[
|
||||
self.config.talker_config.codec_nothink_id,
|
||||
self.config.talker_config.codec_think_bos_id,
|
||||
self.config.talker_config.codec_think_eos_id,
|
||||
0, # for speaker embedding
|
||||
self.config.talker_config.codec_pad_id
|
||||
]
|
||||
)
|
||||
input_ids[i, 8:8+text_ids_len-3 ,1] = self.config.talker_config.codec_pad_id
|
||||
input_ids[i, 8+text_ids_len-3 ,1] = self.config.talker_config.codec_pad_id
|
||||
input_ids[i, 8+text_ids_len-2 ,1] = self.config.talker_config.codec_bos_id
|
||||
input_ids[i, 8+text_ids_len-1:8+text_ids_len-1+codec_ids_len, 1] = audio_codec_0
|
||||
input_ids[i, 8+text_ids_len-1+codec_ids_len, 1] = self.config.talker_config.codec_eos_token_id
|
||||
|
||||
codec_0_labels[i, 8+text_ids_len-1:8+text_ids_len-1+codec_ids_len] = audio_codec_0
|
||||
codec_0_labels[i, 8+text_ids_len-1+codec_ids_len] = self.config.talker_config.codec_eos_token_id
|
||||
|
||||
codec_ids[i, 8+text_ids_len-1:8+text_ids_len-1+codec_ids_len,:] = audio_codecs
|
||||
|
||||
codec_embedding_mask[i, 3:8+text_ids_len+codec_ids_len] = True
|
||||
codec_embedding_mask[i, 6] = False # for speaker embedding
|
||||
|
||||
codec_mask[i, 8+text_ids_len-1:8+text_ids_len-1+codec_ids_len] = True
|
||||
attention_mask[i, :8+text_ids_len+codec_ids_len] = True
|
||||
|
||||
ref_mels = [data['ref_mel'] for data in batch]
|
||||
ref_mels = torch.cat(ref_mels,dim=0)
|
||||
|
||||
return {
|
||||
'input_ids':input_ids,
|
||||
'ref_mels':ref_mels,
|
||||
'attention_mask':attention_mask,
|
||||
'text_embedding_mask':text_embedding_mask.unsqueeze(-1),
|
||||
'codec_embedding_mask':codec_embedding_mask.unsqueeze(-1),
|
||||
'codec_0_labels':codec_0_labels,
|
||||
'codec_ids': codec_ids,
|
||||
'codec_mask':codec_mask
|
||||
}
|
||||
71
models/Qwen3-TTS/finetuning/prepare_data.py
Normal file
71
models/Qwen3-TTS/finetuning/prepare_data.py
Normal file
@@ -0,0 +1,71 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
|
||||
from qwen_tts import Qwen3TTSTokenizer
|
||||
|
||||
BATCH_INFER_NUM = 32
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--device", type=str, default="cuda:0")
|
||||
parser.add_argument("--tokenizer_model_path", type=str, default="Qwen/Qwen3-TTS-Tokenizer-12Hz")
|
||||
parser.add_argument("--input_jsonl", type=str, required=True)
|
||||
parser.add_argument("--output_jsonl", type=str, required=True)
|
||||
args = parser.parse_args()
|
||||
|
||||
tokenizer_12hz = Qwen3TTSTokenizer.from_pretrained(
|
||||
args.tokenizer_model_path,
|
||||
device_map=args.device,
|
||||
)
|
||||
|
||||
total_lines = open(args.input_jsonl).readlines()
|
||||
total_lines = [json.loads(line.strip()) for line in total_lines]
|
||||
|
||||
final_lines = []
|
||||
batch_lines = []
|
||||
batch_audios = []
|
||||
for line in total_lines:
|
||||
|
||||
batch_lines.append(line)
|
||||
batch_audios.append(line['audio'])
|
||||
|
||||
if len(batch_lines) >= BATCH_INFER_NUM:
|
||||
enc_res = tokenizer_12hz.encode(batch_audios)
|
||||
for code, line in zip(enc_res.audio_codes, batch_lines):
|
||||
line['audio_codes'] = code.cpu().tolist()
|
||||
final_lines.append(line)
|
||||
batch_lines.clear()
|
||||
batch_audios.clear()
|
||||
|
||||
if len(batch_audios) > 0:
|
||||
enc_res = tokenizer_12hz.encode(batch_audios)
|
||||
for code, line in zip(enc_res.audio_codes, batch_lines):
|
||||
line['audio_codes'] = code.cpu().tolist()
|
||||
final_lines.append(line)
|
||||
batch_lines.clear()
|
||||
batch_audios.clear()
|
||||
|
||||
final_lines = [json.dumps(line, ensure_ascii=False) for line in final_lines]
|
||||
|
||||
with open(args.output_jsonl, 'w') as f:
|
||||
for line in final_lines:
|
||||
f.writelines(line + '\n')
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
161
models/Qwen3-TTS/finetuning/sft_12hz.py
Normal file
161
models/Qwen3-TTS/finetuning/sft_12hz.py
Normal file
@@ -0,0 +1,161 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import argparse
|
||||
import json
|
||||
import os
|
||||
import shutil
|
||||
|
||||
import torch
|
||||
from accelerate import Accelerator
|
||||
from dataset import TTSDataset
|
||||
from qwen_tts.inference.qwen3_tts_model import Qwen3TTSModel
|
||||
from safetensors.torch import save_file
|
||||
from torch.optim import AdamW
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers import AutoConfig
|
||||
|
||||
target_speaker_embedding = None
|
||||
def train():
|
||||
global target_speaker_embedding
|
||||
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument("--init_model_path", type=str, default="Qwen/Qwen3-TTS-12Hz-1.7B-Base")
|
||||
parser.add_argument("--output_model_path", type=str, default="output")
|
||||
parser.add_argument("--train_jsonl", type=str, required=True)
|
||||
parser.add_argument("--batch_size", type=int, default=2)
|
||||
parser.add_argument("--lr", type=float, default=2e-5)
|
||||
parser.add_argument("--num_epochs", type=int, default=3)
|
||||
parser.add_argument("--speaker_name", type=str, default="speaker_test")
|
||||
args = parser.parse_args()
|
||||
|
||||
accelerator = Accelerator(gradient_accumulation_steps=4, mixed_precision="bf16", log_with="tensorboard")
|
||||
|
||||
MODEL_PATH = args.init_model_path
|
||||
|
||||
qwen3tts = Qwen3TTSModel.from_pretrained(
|
||||
MODEL_PATH,
|
||||
torch_dtype=torch.bfloat16,
|
||||
attn_implementation="flash_attention_2",
|
||||
)
|
||||
config = AutoConfig.from_pretrained(MODEL_PATH)
|
||||
|
||||
train_data = open(args.train_jsonl).readlines()
|
||||
train_data = [json.loads(line) for line in train_data]
|
||||
dataset = TTSDataset(train_data, qwen3tts.processor, config)
|
||||
train_dataloader = DataLoader(dataset, batch_size=args.batch_size, shuffle=True, collate_fn=dataset.collate_fn)
|
||||
|
||||
optimizer = AdamW(qwen3tts.model.parameters(), lr=args.lr, weight_decay=0.01)
|
||||
|
||||
model, optimizer, train_dataloader = accelerator.prepare(
|
||||
qwen3tts.model, optimizer, train_dataloader
|
||||
)
|
||||
|
||||
num_epochs = args.num_epochs
|
||||
model.train()
|
||||
|
||||
for epoch in range(num_epochs):
|
||||
for step, batch in enumerate(train_dataloader):
|
||||
with accelerator.accumulate(model):
|
||||
|
||||
input_ids = batch['input_ids']
|
||||
codec_ids = batch['codec_ids']
|
||||
ref_mels = batch['ref_mels']
|
||||
text_embedding_mask = batch['text_embedding_mask']
|
||||
codec_embedding_mask = batch['codec_embedding_mask']
|
||||
attention_mask = batch['attention_mask']
|
||||
codec_0_labels = batch['codec_0_labels']
|
||||
codec_mask = batch['codec_mask']
|
||||
|
||||
speaker_embedding = model.speaker_encoder(ref_mels.to(model.device).to(model.dtype)).detach()
|
||||
if target_speaker_embedding is None:
|
||||
target_speaker_embedding = speaker_embedding
|
||||
|
||||
input_text_ids = input_ids[:, :, 0]
|
||||
input_codec_ids = input_ids[:, :, 1]
|
||||
|
||||
input_text_embedding = model.talker.model.text_embedding(input_text_ids) * text_embedding_mask
|
||||
input_codec_embedding = model.talker.model.codec_embedding(input_codec_ids) * codec_embedding_mask
|
||||
input_codec_embedding[:, 6, :] = speaker_embedding
|
||||
|
||||
input_embeddings = input_text_embedding + input_codec_embedding
|
||||
|
||||
for i in range(1, 16):
|
||||
codec_i_embedding = model.talker.code_predictor.get_input_embeddings()[i - 1](codec_ids[:, :, i])
|
||||
codec_i_embedding = codec_i_embedding * codec_mask.unsqueeze(-1)
|
||||
input_embeddings = input_embeddings + codec_i_embedding
|
||||
|
||||
outputs = model.talker(
|
||||
inputs_embeds=input_embeddings[:, :-1, :],
|
||||
attention_mask=attention_mask[:, :-1],
|
||||
labels=codec_0_labels[:, 1:],
|
||||
output_hidden_states=True
|
||||
)
|
||||
|
||||
hidden_states = outputs.hidden_states[0][-1]
|
||||
talker_hidden_states = hidden_states[codec_mask[:, 1:]]
|
||||
talker_codec_ids = codec_ids[codec_mask]
|
||||
|
||||
sub_talker_logits, sub_talker_loss = model.talker.forward_sub_talker_finetune(talker_codec_ids, talker_hidden_states)
|
||||
|
||||
loss = outputs.loss + sub_talker_loss
|
||||
|
||||
accelerator.backward(loss)
|
||||
|
||||
if accelerator.sync_gradients:
|
||||
accelerator.clip_grad_norm_(model.parameters(), 1.0)
|
||||
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % 10 == 0:
|
||||
accelerator.print(f"Epoch {epoch} | Step {step} | Loss: {loss.item():.4f}")
|
||||
|
||||
if accelerator.is_main_process:
|
||||
output_dir = os.path.join(args.output_model_path, f"checkpoint-epoch-{epoch}")
|
||||
shutil.copytree(MODEL_PATH, output_dir, dirs_exist_ok=True)
|
||||
|
||||
input_config_file = os.path.join(MODEL_PATH, "config.json")
|
||||
output_config_file = os.path.join(output_dir, "config.json")
|
||||
with open(input_config_file, 'r', encoding='utf-8') as f:
|
||||
config_dict = json.load(f)
|
||||
config_dict["tts_model_type"] = "custom_voice"
|
||||
talker_config = config_dict.get("talker_config", {})
|
||||
talker_config["spk_id"] = {
|
||||
args.speaker_name: 3000
|
||||
}
|
||||
talker_config["spk_is_dialect"] = {
|
||||
args.speaker_name: False
|
||||
}
|
||||
config_dict["talker_config"] = talker_config
|
||||
|
||||
with open(output_config_file, 'w', encoding='utf-8') as f:
|
||||
json.dump(config_dict, f, indent=2, ensure_ascii=False)
|
||||
|
||||
unwrapped_model = accelerator.unwrap_model(model)
|
||||
state_dict = {k: v.detach().to("cpu") for k, v in unwrapped_model.state_dict().items()}
|
||||
|
||||
drop_prefix = "speaker_encoder"
|
||||
keys_to_drop = [k for k in state_dict.keys() if k.startswith(drop_prefix)]
|
||||
for k in keys_to_drop:
|
||||
del state_dict[k]
|
||||
|
||||
weight = state_dict['talker.model.codec_embedding.weight']
|
||||
state_dict['talker.model.codec_embedding.weight'][3000] = target_speaker_embedding[0].detach().to(weight.device).to(weight.dtype)
|
||||
save_path = os.path.join(output_dir, "model.safetensors")
|
||||
save_file(state_dict, save_path)
|
||||
|
||||
if __name__ == "__main__":
|
||||
train()
|
||||
46
models/Qwen3-TTS/pyproject.toml
Normal file
46
models/Qwen3-TTS/pyproject.toml
Normal file
@@ -0,0 +1,46 @@
|
||||
[build-system]
|
||||
requires = ["setuptools>=68", "wheel"]
|
||||
build-backend = "setuptools.build_meta"
|
||||
|
||||
[project]
|
||||
name = "qwen-tts"
|
||||
version = "0.0.4"
|
||||
description = "Qwen-TTS python package"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.9"
|
||||
classifiers = [
|
||||
"Programming Language :: Python :: 3",
|
||||
"Programming Language :: Python :: 3.9",
|
||||
"Programming Language :: Python :: 3.10",
|
||||
"Programming Language :: Python :: 3.11",
|
||||
"Programming Language :: Python :: 3.12",
|
||||
"Programming Language :: Python :: 3.13",
|
||||
]
|
||||
license = { text = "Apache-2.0" }
|
||||
authors = [{ name = "Alibaba Qwen Team" }]
|
||||
|
||||
dependencies = [
|
||||
"transformers==4.57.3",
|
||||
"accelerate==1.12.0",
|
||||
"gradio",
|
||||
"librosa",
|
||||
"torchaudio",
|
||||
"soundfile",
|
||||
"sox",
|
||||
"onnxruntime",
|
||||
"einops",
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
Homepage = "https://github.com/Qwen/Qwen3-TTS"
|
||||
Repository = "https://github.com/Qwen/Qwen3-TTS"
|
||||
|
||||
[project.scripts]
|
||||
qwen-tts-demo = "qwen_tts.cli.demo:main"
|
||||
|
||||
[tool.setuptools]
|
||||
packages = { find = { where = ["."] , include = ["qwen_tts*"] } }
|
||||
include-package-data = true
|
||||
|
||||
[tool.setuptools.package-data]
|
||||
qwen_tts = ["py.typed", "**/*.npz"]
|
||||
24
models/Qwen3-TTS/qwen_tts/__init__.py
Normal file
24
models/Qwen3-TTS/qwen_tts/__init__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
qwen_tts: Qwen-TTS package.
|
||||
"""
|
||||
|
||||
from .inference.qwen3_tts_model import Qwen3TTSModel, VoiceClonePromptItem
|
||||
from .inference.qwen3_tts_tokenizer import Qwen3TTSTokenizer
|
||||
|
||||
__all__ = ["__version__"]
|
||||
24
models/Qwen3-TTS/qwen_tts/__main__.py
Normal file
24
models/Qwen3-TTS/qwen_tts/__main__.py
Normal file
@@ -0,0 +1,24 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
def main():
|
||||
print(
|
||||
"qwen_tts package.\n"
|
||||
"Use CLI entrypoints:\n"
|
||||
" - qwen-tts-demo\n"
|
||||
)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
634
models/Qwen3-TTS/qwen_tts/cli/demo.py
Normal file
634
models/Qwen3-TTS/qwen_tts/cli/demo.py
Normal file
@@ -0,0 +1,634 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
A gradio demo for Qwen3 TTS models.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import tempfile
|
||||
from dataclasses import asdict
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
import gradio as gr
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from .. import Qwen3TTSModel, VoiceClonePromptItem
|
||||
|
||||
|
||||
def _title_case_display(s: str) -> str:
|
||||
s = (s or "").strip()
|
||||
s = s.replace("_", " ")
|
||||
return " ".join([w[:1].upper() + w[1:] if w else "" for w in s.split()])
|
||||
|
||||
|
||||
def _build_choices_and_map(items: Optional[List[str]]) -> Tuple[List[str], Dict[str, str]]:
|
||||
if not items:
|
||||
return [], {}
|
||||
display = [_title_case_display(x) for x in items]
|
||||
mapping = {d: r for d, r in zip(display, items)}
|
||||
return display, mapping
|
||||
|
||||
|
||||
def _dtype_from_str(s: str) -> torch.dtype:
|
||||
s = (s or "").strip().lower()
|
||||
if s in ("bf16", "bfloat16"):
|
||||
return torch.bfloat16
|
||||
if s in ("fp16", "float16", "half"):
|
||||
return torch.float16
|
||||
if s in ("fp32", "float32"):
|
||||
return torch.float32
|
||||
raise ValueError(f"Unsupported torch dtype: {s}. Use bfloat16/float16/float32.")
|
||||
|
||||
|
||||
def _maybe(v):
|
||||
return v if v is not None else gr.update()
|
||||
|
||||
|
||||
def build_parser() -> argparse.ArgumentParser:
|
||||
parser = argparse.ArgumentParser(
|
||||
prog="qwen-tts-demo",
|
||||
description=(
|
||||
"Launch a Gradio demo for Qwen3 TTS models (CustomVoice / VoiceDesign / Base).\n\n"
|
||||
"Examples:\n"
|
||||
" qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice\n"
|
||||
" qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-VoiceDesign --port 8000 --ip 127.0.0.01\n"
|
||||
" qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-Base --device cuda:0\n"
|
||||
" qwen-tts-demo Qwen/Qwen3-TTS-12Hz-1.7B-CustomVoice --dtype bfloat16 --no-flash-attn\n"
|
||||
),
|
||||
formatter_class=argparse.RawTextHelpFormatter,
|
||||
add_help=True,
|
||||
)
|
||||
|
||||
# Positional checkpoint (also supports -c/--checkpoint)
|
||||
parser.add_argument(
|
||||
"checkpoint_pos",
|
||||
nargs="?",
|
||||
default=None,
|
||||
help="Model checkpoint path or HuggingFace repo id (positional).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"-c",
|
||||
"--checkpoint",
|
||||
default=None,
|
||||
help="Model checkpoint path or HuggingFace repo id (optional if positional is provided).",
|
||||
)
|
||||
|
||||
# Model loading / from_pretrained args
|
||||
parser.add_argument(
|
||||
"--device",
|
||||
default="cuda:0",
|
||||
help="Device for device_map, e.g. cpu, cuda, cuda:0 (default: cuda:0).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dtype",
|
||||
default="bfloat16",
|
||||
choices=["bfloat16", "bf16", "float16", "fp16", "float32", "fp32"],
|
||||
help="Torch dtype for loading the model (default: bfloat16).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--flash-attn/--no-flash-attn",
|
||||
dest="flash_attn",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Enable FlashAttention-2 (default: enabled).",
|
||||
)
|
||||
|
||||
# Gradio server args
|
||||
parser.add_argument(
|
||||
"--ip",
|
||||
default="0.0.0.0",
|
||||
help="Server bind IP for Gradio (default: 0.0.0.0).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--port",
|
||||
type=int,
|
||||
default=8000,
|
||||
help="Server port for Gradio (default: 8000).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--share/--no-share",
|
||||
dest="share",
|
||||
default=False,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Whether to create a public Gradio link (default: disabled).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--concurrency",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Gradio queue concurrency (default: 16).",
|
||||
)
|
||||
|
||||
# HTTPS args
|
||||
parser.add_argument(
|
||||
"--ssl-certfile",
|
||||
default=None,
|
||||
help="Path to SSL certificate file for HTTPS (optional).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-keyfile",
|
||||
default=None,
|
||||
help="Path to SSL key file for HTTPS (optional).",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--ssl-verify/--no-ssl-verify",
|
||||
dest="ssl_verify",
|
||||
default=True,
|
||||
action=argparse.BooleanOptionalAction,
|
||||
help="Whether to verify SSL certificate (default: enabled).",
|
||||
)
|
||||
|
||||
# Optional generation args
|
||||
parser.add_argument("--max-new-tokens", type=int, default=None, help="Max new tokens for generation (optional).")
|
||||
parser.add_argument("--temperature", type=float, default=None, help="Sampling temperature (optional).")
|
||||
parser.add_argument("--top-k", type=int, default=None, help="Top-k sampling (optional).")
|
||||
parser.add_argument("--top-p", type=float, default=None, help="Top-p sampling (optional).")
|
||||
parser.add_argument("--repetition-penalty", type=float, default=None, help="Repetition penalty (optional).")
|
||||
parser.add_argument("--subtalker-top-k", type=int, default=None, help="Subtalker top-k (optional, only for tokenizer v2).")
|
||||
parser.add_argument("--subtalker-top-p", type=float, default=None, help="Subtalker top-p (optional, only for tokenizer v2).")
|
||||
parser.add_argument(
|
||||
"--subtalker-temperature", type=float, default=None, help="Subtalker temperature (optional, only for tokenizer v2)."
|
||||
)
|
||||
|
||||
return parser
|
||||
|
||||
|
||||
def _resolve_checkpoint(args: argparse.Namespace) -> str:
|
||||
ckpt = args.checkpoint or args.checkpoint_pos
|
||||
if not ckpt:
|
||||
raise SystemExit(0) # main() prints help
|
||||
return ckpt
|
||||
|
||||
|
||||
def _collect_gen_kwargs(args: argparse.Namespace) -> Dict[str, Any]:
|
||||
mapping = {
|
||||
"max_new_tokens": args.max_new_tokens,
|
||||
"temperature": args.temperature,
|
||||
"top_k": args.top_k,
|
||||
"top_p": args.top_p,
|
||||
"repetition_penalty": args.repetition_penalty,
|
||||
"subtalker_top_k": args.subtalker_top_k,
|
||||
"subtalker_top_p": args.subtalker_top_p,
|
||||
"subtalker_temperature": args.subtalker_temperature,
|
||||
}
|
||||
return {k: v for k, v in mapping.items() if v is not None}
|
||||
|
||||
|
||||
def _normalize_audio(wav, eps=1e-12, clip=True):
|
||||
x = np.asarray(wav)
|
||||
|
||||
if np.issubdtype(x.dtype, np.integer):
|
||||
info = np.iinfo(x.dtype)
|
||||
|
||||
if info.min < 0:
|
||||
y = x.astype(np.float32) / max(abs(info.min), info.max)
|
||||
else:
|
||||
mid = (info.max + 1) / 2.0
|
||||
y = (x.astype(np.float32) - mid) / mid
|
||||
|
||||
elif np.issubdtype(x.dtype, np.floating):
|
||||
y = x.astype(np.float32)
|
||||
m = np.max(np.abs(y)) if y.size else 0.0
|
||||
|
||||
if m <= 1.0 + 1e-6:
|
||||
pass
|
||||
else:
|
||||
y = y / (m + eps)
|
||||
else:
|
||||
raise TypeError(f"Unsupported dtype: {x.dtype}")
|
||||
|
||||
if clip:
|
||||
y = np.clip(y, -1.0, 1.0)
|
||||
|
||||
if y.ndim > 1:
|
||||
y = np.mean(y, axis=-1).astype(np.float32)
|
||||
|
||||
return y
|
||||
|
||||
|
||||
def _audio_to_tuple(audio: Any) -> Optional[Tuple[np.ndarray, int]]:
|
||||
if audio is None:
|
||||
return None
|
||||
|
||||
if isinstance(audio, tuple) and len(audio) == 2 and isinstance(audio[0], int):
|
||||
sr, wav = audio
|
||||
wav = _normalize_audio(wav)
|
||||
return wav, int(sr)
|
||||
|
||||
if isinstance(audio, dict) and "sampling_rate" in audio and "data" in audio:
|
||||
sr = int(audio["sampling_rate"])
|
||||
wav = _normalize_audio(audio["data"])
|
||||
return wav, sr
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def _wav_to_gradio_audio(wav: np.ndarray, sr: int) -> Tuple[int, np.ndarray]:
|
||||
wav = np.asarray(wav, dtype=np.float32)
|
||||
return sr, wav
|
||||
|
||||
|
||||
def _detect_model_kind(ckpt: str, tts: Qwen3TTSModel) -> str:
|
||||
mt = getattr(tts.model, "tts_model_type", None)
|
||||
if mt in ("custom_voice", "voice_design", "base"):
|
||||
return mt
|
||||
else:
|
||||
raise ValueError(f"Unknown Qwen-TTS model type: {mt}")
|
||||
|
||||
|
||||
def build_demo(tts: Qwen3TTSModel, ckpt: str, gen_kwargs_default: Dict[str, Any]) -> gr.Blocks:
|
||||
model_kind = _detect_model_kind(ckpt, tts)
|
||||
|
||||
supported_langs_raw = None
|
||||
if callable(getattr(tts.model, "get_supported_languages", None)):
|
||||
supported_langs_raw = tts.model.get_supported_languages()
|
||||
|
||||
supported_spks_raw = None
|
||||
if callable(getattr(tts.model, "get_supported_speakers", None)):
|
||||
supported_spks_raw = tts.model.get_supported_speakers()
|
||||
|
||||
lang_choices_disp, lang_map = _build_choices_and_map([x for x in (supported_langs_raw or [])])
|
||||
spk_choices_disp, spk_map = _build_choices_and_map([x for x in (supported_spks_raw or [])])
|
||||
|
||||
def _gen_common_kwargs() -> Dict[str, Any]:
|
||||
return dict(gen_kwargs_default)
|
||||
|
||||
theme = gr.themes.Soft(
|
||||
font=[gr.themes.GoogleFont("Source Sans Pro"), "Arial", "sans-serif"],
|
||||
)
|
||||
|
||||
css = ".gradio-container {max-width: none !important;}"
|
||||
|
||||
with gr.Blocks(theme=theme, css=css) as demo:
|
||||
gr.Markdown(
|
||||
f"""
|
||||
# Qwen3 TTS Demo
|
||||
**Checkpoint:** `{ckpt}`
|
||||
**Model Type:** `{model_kind}`
|
||||
"""
|
||||
)
|
||||
|
||||
if model_kind == "custom_voice":
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
text_in = gr.Textbox(
|
||||
label="Text (待合成文本)",
|
||||
lines=4,
|
||||
placeholder="Enter text to synthesize (输入要合成的文本).",
|
||||
)
|
||||
with gr.Row():
|
||||
lang_in = gr.Dropdown(
|
||||
label="Language (语种)",
|
||||
choices=lang_choices_disp,
|
||||
value="Auto",
|
||||
interactive=True,
|
||||
)
|
||||
spk_in = gr.Dropdown(
|
||||
label="Speaker (说话人)",
|
||||
choices=spk_choices_disp,
|
||||
value="Vivian",
|
||||
interactive=True,
|
||||
)
|
||||
instruct_in = gr.Textbox(
|
||||
label="Instruction (Optional) (控制指令,可不输入)",
|
||||
lines=2,
|
||||
placeholder="e.g. Say it in a very angry tone (例如:用特别伤心的语气说).",
|
||||
)
|
||||
btn = gr.Button("Generate (生成)", variant="primary")
|
||||
with gr.Column(scale=3):
|
||||
audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
|
||||
err = gr.Textbox(label="Status (状态)", lines=2)
|
||||
|
||||
def run_instruct(text: str, lang_disp: str, spk_disp: str, instruct: str):
|
||||
try:
|
||||
if not text or not text.strip():
|
||||
return None, "Text is required (必须填写文本)."
|
||||
if not spk_disp:
|
||||
return None, "Speaker is required (必须选择说话人)."
|
||||
language = lang_map.get(lang_disp, "Auto")
|
||||
speaker = spk_map.get(spk_disp, spk_disp)
|
||||
kwargs = _gen_common_kwargs()
|
||||
wavs, sr = tts.generate_custom_voice(
|
||||
text=text.strip(),
|
||||
language=language,
|
||||
speaker=speaker,
|
||||
instruct=(instruct or "").strip() or None,
|
||||
**kwargs,
|
||||
)
|
||||
return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
|
||||
except Exception as e:
|
||||
return None, f"{type(e).__name__}: {e}"
|
||||
|
||||
btn.click(run_instruct, inputs=[text_in, lang_in, spk_in, instruct_in], outputs=[audio_out, err])
|
||||
|
||||
elif model_kind == "voice_design":
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
text_in = gr.Textbox(
|
||||
label="Text (待合成文本)",
|
||||
lines=4,
|
||||
value="It's in the top drawer... wait, it's empty? No way, that's impossible! I'm sure I put it there!"
|
||||
)
|
||||
with gr.Row():
|
||||
lang_in = gr.Dropdown(
|
||||
label="Language (语种)",
|
||||
choices=lang_choices_disp,
|
||||
value="Auto",
|
||||
interactive=True,
|
||||
)
|
||||
design_in = gr.Textbox(
|
||||
label="Voice Design Instruction (音色描述)",
|
||||
lines=3,
|
||||
value="Speak in an incredulous tone, but with a hint of panic beginning to creep into your voice."
|
||||
)
|
||||
btn = gr.Button("Generate (生成)", variant="primary")
|
||||
with gr.Column(scale=3):
|
||||
audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
|
||||
err = gr.Textbox(label="Status (状态)", lines=2)
|
||||
|
||||
def run_voice_design(text: str, lang_disp: str, design: str):
|
||||
try:
|
||||
if not text or not text.strip():
|
||||
return None, "Text is required (必须填写文本)."
|
||||
if not design or not design.strip():
|
||||
return None, "Voice design instruction is required (必须填写音色描述)."
|
||||
language = lang_map.get(lang_disp, "Auto")
|
||||
kwargs = _gen_common_kwargs()
|
||||
wavs, sr = tts.generate_voice_design(
|
||||
text=text.strip(),
|
||||
language=language,
|
||||
instruct=design.strip(),
|
||||
**kwargs,
|
||||
)
|
||||
return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
|
||||
except Exception as e:
|
||||
return None, f"{type(e).__name__}: {e}"
|
||||
|
||||
btn.click(run_voice_design, inputs=[text_in, lang_in, design_in], outputs=[audio_out, err])
|
||||
|
||||
else: # voice_clone for base
|
||||
with gr.Tabs():
|
||||
with gr.Tab("Clone & Generate (克隆并合成)"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
ref_audio = gr.Audio(
|
||||
label="Reference Audio (参考音频)",
|
||||
)
|
||||
ref_text = gr.Textbox(
|
||||
label="Reference Text (参考音频文本)",
|
||||
lines=2,
|
||||
placeholder="Required if not set use x-vector only (不勾选use x-vector only时必填).",
|
||||
)
|
||||
xvec_only = gr.Checkbox(
|
||||
label="Use x-vector only (仅用说话人向量,效果有限,但不用传入参考音频文本)",
|
||||
value=False,
|
||||
)
|
||||
|
||||
with gr.Column(scale=2):
|
||||
text_in = gr.Textbox(
|
||||
label="Target Text (待合成文本)",
|
||||
lines=4,
|
||||
placeholder="Enter text to synthesize (输入要合成的文本).",
|
||||
)
|
||||
lang_in = gr.Dropdown(
|
||||
label="Language (语种)",
|
||||
choices=lang_choices_disp,
|
||||
value="Auto",
|
||||
interactive=True,
|
||||
)
|
||||
btn = gr.Button("Generate (生成)", variant="primary")
|
||||
|
||||
with gr.Column(scale=3):
|
||||
audio_out = gr.Audio(label="Output Audio (合成结果)", type="numpy")
|
||||
err = gr.Textbox(label="Status (状态)", lines=2)
|
||||
|
||||
def run_voice_clone(ref_aud, ref_txt: str, use_xvec: bool, text: str, lang_disp: str):
|
||||
try:
|
||||
if not text or not text.strip():
|
||||
return None, "Target text is required (必须填写待合成文本)."
|
||||
at = _audio_to_tuple(ref_aud)
|
||||
if at is None:
|
||||
return None, "Reference audio is required (必须上传参考音频)."
|
||||
if (not use_xvec) and (not ref_txt or not ref_txt.strip()):
|
||||
return None, (
|
||||
"Reference text is required when use x-vector only is NOT enabled.\n"
|
||||
"(未勾选 use x-vector only 时,必须提供参考音频文本;否则请勾选 use x-vector only,但效果会变差.)"
|
||||
)
|
||||
language = lang_map.get(lang_disp, "Auto")
|
||||
kwargs = _gen_common_kwargs()
|
||||
wavs, sr = tts.generate_voice_clone(
|
||||
text=text.strip(),
|
||||
language=language,
|
||||
ref_audio=at,
|
||||
ref_text=(ref_txt.strip() if ref_txt else None),
|
||||
x_vector_only_mode=bool(use_xvec),
|
||||
**kwargs,
|
||||
)
|
||||
return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
|
||||
except Exception as e:
|
||||
return None, f"{type(e).__name__}: {e}"
|
||||
|
||||
btn.click(
|
||||
run_voice_clone,
|
||||
inputs=[ref_audio, ref_text, xvec_only, text_in, lang_in],
|
||||
outputs=[audio_out, err],
|
||||
)
|
||||
|
||||
with gr.Tab("Save / Load Voice (保存/加载克隆音色)"):
|
||||
with gr.Row():
|
||||
with gr.Column(scale=2):
|
||||
gr.Markdown(
|
||||
"""
|
||||
### Save Voice (保存音色)
|
||||
Upload reference audio and text, choose use x-vector only or not, then save a reusable voice prompt file.
|
||||
(上传参考音频和参考文本,选择是否使用 use x-vector only 模式后保存为可复用的音色文件)
|
||||
"""
|
||||
)
|
||||
ref_audio_s = gr.Audio(label="Reference Audio (参考音频)", type="numpy")
|
||||
ref_text_s = gr.Textbox(
|
||||
label="Reference Text (参考音频文本)",
|
||||
lines=2,
|
||||
placeholder="Required if not set use x-vector only (不勾选use x-vector only时必填).",
|
||||
)
|
||||
xvec_only_s = gr.Checkbox(
|
||||
label="Use x-vector only (仅用说话人向量,效果有限,但不用传入参考音频文本)",
|
||||
value=False,
|
||||
)
|
||||
save_btn = gr.Button("Save Voice File (保存音色文件)", variant="primary")
|
||||
prompt_file_out = gr.File(label="Voice File (音色文件)")
|
||||
|
||||
with gr.Column(scale=2):
|
||||
gr.Markdown(
|
||||
"""
|
||||
### Load Voice & Generate (加载音色并合成)
|
||||
Upload a previously saved voice file, then synthesize new text.
|
||||
(上传已保存提示文件后,输入新文本进行合成)
|
||||
"""
|
||||
)
|
||||
prompt_file_in = gr.File(label="Upload Prompt File (上传提示文件)")
|
||||
text_in2 = gr.Textbox(
|
||||
label="Target Text (待合成文本)",
|
||||
lines=4,
|
||||
placeholder="Enter text to synthesize (输入要合成的文本).",
|
||||
)
|
||||
lang_in2 = gr.Dropdown(
|
||||
label="Language (语种)",
|
||||
choices=lang_choices_disp,
|
||||
value="Auto",
|
||||
interactive=True,
|
||||
)
|
||||
gen_btn2 = gr.Button("Generate (生成)", variant="primary")
|
||||
|
||||
with gr.Column(scale=3):
|
||||
audio_out2 = gr.Audio(label="Output Audio (合成结果)", type="numpy")
|
||||
err2 = gr.Textbox(label="Status (状态)", lines=2)
|
||||
|
||||
def save_prompt(ref_aud, ref_txt: str, use_xvec: bool):
|
||||
try:
|
||||
at = _audio_to_tuple(ref_aud)
|
||||
if at is None:
|
||||
return None, "Reference audio is required (必须上传参考音频)."
|
||||
if (not use_xvec) and (not ref_txt or not ref_txt.strip()):
|
||||
return None, (
|
||||
"Reference text is required when use x-vector only is NOT enabled.\n"
|
||||
"(未勾选 use x-vector only 时,必须提供参考音频文本;否则请勾选 use x-vector only,但效果会变差.)"
|
||||
)
|
||||
items = tts.create_voice_clone_prompt(
|
||||
ref_audio=at,
|
||||
ref_text=(ref_txt.strip() if ref_txt else None),
|
||||
x_vector_only_mode=bool(use_xvec),
|
||||
)
|
||||
payload = {
|
||||
"items": [asdict(it) for it in items],
|
||||
}
|
||||
fd, out_path = tempfile.mkstemp(prefix="voice_clone_prompt_", suffix=".pt")
|
||||
os.close(fd)
|
||||
torch.save(payload, out_path)
|
||||
return out_path, "Finished. (生成完成)"
|
||||
except Exception as e:
|
||||
return None, f"{type(e).__name__}: {e}"
|
||||
|
||||
def load_prompt_and_gen(file_obj, text: str, lang_disp: str):
|
||||
try:
|
||||
if file_obj is None:
|
||||
return None, "Voice file is required (必须上传音色文件)."
|
||||
if not text or not text.strip():
|
||||
return None, "Target text is required (必须填写待合成文本)."
|
||||
|
||||
path = getattr(file_obj, "name", None) or getattr(file_obj, "path", None) or str(file_obj)
|
||||
payload = torch.load(path, map_location="cpu", weights_only=True)
|
||||
if not isinstance(payload, dict) or "items" not in payload:
|
||||
return None, "Invalid file format (文件格式不正确)."
|
||||
|
||||
items_raw = payload["items"]
|
||||
if not isinstance(items_raw, list) or len(items_raw) == 0:
|
||||
return None, "Empty voice items (音色为空)."
|
||||
|
||||
items: List[VoiceClonePromptItem] = []
|
||||
for d in items_raw:
|
||||
if not isinstance(d, dict):
|
||||
return None, "Invalid item format in file (文件内部格式错误)."
|
||||
ref_code = d.get("ref_code", None)
|
||||
if ref_code is not None and not torch.is_tensor(ref_code):
|
||||
ref_code = torch.tensor(ref_code)
|
||||
ref_spk = d.get("ref_spk_embedding", None)
|
||||
if ref_spk is None:
|
||||
return None, "Missing ref_spk_embedding (缺少说话人向量)."
|
||||
if not torch.is_tensor(ref_spk):
|
||||
ref_spk = torch.tensor(ref_spk)
|
||||
|
||||
items.append(
|
||||
VoiceClonePromptItem(
|
||||
ref_code=ref_code,
|
||||
ref_spk_embedding=ref_spk,
|
||||
x_vector_only_mode=bool(d.get("x_vector_only_mode", False)),
|
||||
icl_mode=bool(d.get("icl_mode", not bool(d.get("x_vector_only_mode", False)))),
|
||||
ref_text=d.get("ref_text", None),
|
||||
)
|
||||
)
|
||||
|
||||
language = lang_map.get(lang_disp, "Auto")
|
||||
kwargs = _gen_common_kwargs()
|
||||
wavs, sr = tts.generate_voice_clone(
|
||||
text=text.strip(),
|
||||
language=language,
|
||||
voice_clone_prompt=items,
|
||||
**kwargs,
|
||||
)
|
||||
return _wav_to_gradio_audio(wavs[0], sr), "Finished. (生成完成)"
|
||||
except Exception as e:
|
||||
return None, (
|
||||
f"Failed to read or use voice file. Check file format/content.\n"
|
||||
f"(读取或使用音色文件失败,请检查文件格式或内容)\n"
|
||||
f"{type(e).__name__}: {e}"
|
||||
)
|
||||
|
||||
save_btn.click(save_prompt, inputs=[ref_audio_s, ref_text_s, xvec_only_s], outputs=[prompt_file_out, err2])
|
||||
gen_btn2.click(load_prompt_and_gen, inputs=[prompt_file_in, text_in2, lang_in2], outputs=[audio_out2, err2])
|
||||
|
||||
gr.Markdown(
|
||||
"""
|
||||
**Disclaimer (免责声明)**
|
||||
- The audio is automatically generated/synthesized by an AI model solely to demonstrate the model’s capabilities; it may be inaccurate or inappropriate, does not represent the views of the developer/operator, and does not constitute professional advice. You are solely responsible for evaluating, using, distributing, or relying on this audio; to the maximum extent permitted by applicable law, the developer/operator disclaims liability for any direct, indirect, incidental, or consequential damages arising from the use of or inability to use the audio, except where liability cannot be excluded by law. Do not use this service to intentionally generate or replicate unlawful, harmful, defamatory, fraudulent, deepfake, or privacy/publicity/copyright/trademark‑infringing content; if a user prompts, supplies materials, or otherwise facilitates any illegal or infringing conduct, the user bears all legal consequences and the developer/operator is not responsible.
|
||||
- 音频由人工智能模型自动生成/合成,仅用于体验与展示模型效果,可能存在不准确或不当之处;其内容不代表开发者/运营方立场,亦不构成任何专业建议。用户应自行评估并承担使用、传播或依赖该音频所产生的一切风险与责任;在适用法律允许的最大范围内,开发者/运营方不对因使用或无法使用本音频造成的任何直接、间接、附带或后果性损失承担责任(法律另有强制规定的除外)。严禁利用本服务故意引导生成或复制违法、有害、诽谤、欺诈、深度伪造、侵犯隐私/肖像/著作权/商标等内容;如用户通过提示词、素材或其他方式实施或促成任何违法或侵权行为,相关法律后果由用户自行承担,与开发者/运营方无关。
|
||||
"""
|
||||
)
|
||||
|
||||
return demo
|
||||
|
||||
|
||||
def main(argv=None) -> int:
|
||||
parser = build_parser()
|
||||
args = parser.parse_args(argv)
|
||||
|
||||
if not args.checkpoint and not args.checkpoint_pos:
|
||||
parser.print_help()
|
||||
return 0
|
||||
|
||||
ckpt = _resolve_checkpoint(args)
|
||||
|
||||
dtype = _dtype_from_str(args.dtype)
|
||||
attn_impl = "flash_attention_2" if args.flash_attn else None
|
||||
|
||||
tts = Qwen3TTSModel.from_pretrained(
|
||||
ckpt,
|
||||
device_map=args.device,
|
||||
dtype=dtype,
|
||||
attn_implementation=attn_impl,
|
||||
)
|
||||
|
||||
gen_kwargs_default = _collect_gen_kwargs(args)
|
||||
demo = build_demo(tts, ckpt, gen_kwargs_default)
|
||||
|
||||
launch_kwargs: Dict[str, Any] = dict(
|
||||
server_name=args.ip,
|
||||
server_port=args.port,
|
||||
share=args.share,
|
||||
ssl_verify=True if args.ssl_verify else False,
|
||||
)
|
||||
if args.ssl_certfile is not None:
|
||||
launch_kwargs["ssl_certfile"] = args.ssl_certfile
|
||||
if args.ssl_keyfile is not None:
|
||||
launch_kwargs["ssl_keyfile"] = args.ssl_keyfile
|
||||
|
||||
demo.queue(default_concurrency_limit=int(args.concurrency)).launch(**launch_kwargs)
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
raise SystemExit(main())
|
||||
19
models/Qwen3-TTS/qwen_tts/core/__init__.py
Normal file
19
models/Qwen3-TTS/qwen_tts/core/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .tokenizer_25hz.configuration_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Config
|
||||
from .tokenizer_25hz.modeling_qwen3_tts_tokenizer_v1 import Qwen3TTSTokenizerV1Model
|
||||
from .tokenizer_12hz.configuration_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Config
|
||||
from .tokenizer_12hz.modeling_qwen3_tts_tokenizer_v2 import Qwen3TTSTokenizerV2Model
|
||||
18
models/Qwen3-TTS/qwen_tts/core/models/__init__.py
Normal file
18
models/Qwen3-TTS/qwen_tts/core/models/__init__.py
Normal file
@@ -0,0 +1,18 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from .configuration_qwen3_tts import Qwen3TTSConfig
|
||||
from .modeling_qwen3_tts import Qwen3TTSForConditionalGeneration
|
||||
from .processing_qwen3_tts import Qwen3TTSProcessor
|
||||
502
models/Qwen3-TTS/qwen_tts/core/models/configuration_qwen3_tts.py
Normal file
502
models/Qwen3-TTS/qwen_tts/core/models/configuration_qwen3_tts.py
Normal file
@@ -0,0 +1,502 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from transformers.configuration_utils import PretrainedConfig, layer_type_validation
|
||||
from transformers.modeling_rope_utils import rope_config_validation
|
||||
from transformers.utils import logging
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3TTSSpeakerEncoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3TTSSpeakerEncoder`].
|
||||
It is used to instantiate a Qwen3TTS speaker encoder model according to the specified arguments, defining the model
|
||||
architecture. The architecture is based on the ECAPA-TDNN model.
|
||||
|
||||
Args:
|
||||
mel_dim (`int`, *optional*, defaults to 128):
|
||||
The dimension of the input mel-spectrogram.
|
||||
enc_dim (`int`, *optional*, defaults to 192):
|
||||
The dimension of the final speaker embedding.
|
||||
enc_channels (`list[int]`, *optional*, defaults to `[512, 512, 512, 512, 1536]`):
|
||||
A list of output channels for each TDNN/SERes2Net layer in the encoder. The first channel size is for the initial TDNN layer,
|
||||
the intermediate ones for the `SqueezeExcitationRes2NetBlock` layers, and the last one for the multi-layer feature aggregation.
|
||||
enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
|
||||
A list of kernel sizes for each layer in the encoder, corresponding to `enc_channels`.
|
||||
enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
|
||||
A list of dilations for each layer in the encoder, corresponding to `enc_channels`.
|
||||
enc_attention_channels (`int`, *optional*, defaults to 128):
|
||||
The number of attention channels in the `AttentiveStatisticsPooling` layer.
|
||||
enc_res2net_scale (`int`, *optional*,defaults to 8):
|
||||
The scale of the `Res2NetBlock` in the encoder.
|
||||
enc_se_channels (`int`, *optional*, defaults to 128):
|
||||
The number of channels in the squeeze part of the `SqueezeExcitationBlock`.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
mel_dim=128,
|
||||
enc_dim=1024,
|
||||
enc_channels=[512, 512, 512, 512, 1536],
|
||||
enc_kernel_sizes=[5, 3, 3, 3, 1],
|
||||
enc_dilations=[1, 2, 3, 4, 1],
|
||||
enc_attention_channels=128,
|
||||
enc_res2net_scale=8,
|
||||
enc_se_channels=128,
|
||||
sample_rate=24000,
|
||||
):
|
||||
self.mel_dim = mel_dim
|
||||
self.enc_dim = enc_dim
|
||||
self.enc_channels = enc_channels
|
||||
self.enc_kernel_sizes = enc_kernel_sizes
|
||||
self.enc_dilations = enc_dilations
|
||||
self.enc_attention_channels = enc_attention_channels
|
||||
self.enc_res2net_scale = enc_res2net_scale
|
||||
self.enc_se_channels = enc_se_channels
|
||||
self.sample_rate = sample_rate
|
||||
|
||||
|
||||
class Qwen3TTSTalkerCodePredictorConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3TTSTalkerCodePredictorModel`]. It is used to instantiate a
|
||||
Qwen3TTSTalkerCodePredictor model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151936):
|
||||
Vocabulary size of the Qwen3TTSTalkerCodePredictor model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen3TTSTalkerCodePredictorModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 4096):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 22016):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 32):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 32):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
||||
head_dim (`int`, *optional*, defaults to 128):
|
||||
The attention head dimension.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
max_window_layers (`int`, *optional*, defaults to 28):
|
||||
The number of layers using full attention. The first `max_window_layers` layers will use full attention, while any
|
||||
additional layer afterwards will use SWA (Sliding Window Attention).
|
||||
layer_types (`list`, *optional*):
|
||||
Attention pattern for each layer.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts_talker_code_predictor"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Qwen3TTSTalkerCodePredictor`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=2048,
|
||||
hidden_size=1024,
|
||||
intermediate_size=3072,
|
||||
num_hidden_layers=5,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=8,
|
||||
head_dim=128,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=0.000001,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
max_window_layers=28,
|
||||
layer_types=None,
|
||||
attention_dropout=0,
|
||||
num_code_groups=32,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window if self.use_sliding_window else None
|
||||
self.max_window_layers = max_window_layers
|
||||
|
||||
# for backward compatibility
|
||||
if num_key_value_heads is None:
|
||||
num_key_value_heads = num_attention_heads
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.head_dim = head_dim
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
rope_config_validation(self)
|
||||
|
||||
self.layer_types = layer_types
|
||||
if self.layer_types is None:
|
||||
self.layer_types = [
|
||||
"sliding_attention"
|
||||
if self.sliding_window is not None and i >= self.max_window_layers
|
||||
else "full_attention"
|
||||
for i in range(self.num_hidden_layers)
|
||||
]
|
||||
layer_type_validation(self.layer_types)
|
||||
self.num_code_groups = num_code_groups
|
||||
|
||||
|
||||
class Qwen3TTSTalkerConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3TTSTalkerModel`]. It is used to instantiate a
|
||||
Qwen3TTSTalker model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 151936):
|
||||
Vocabulary size of the Qwen3TTSTalker model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Qwen3TTSTalkerModel`]
|
||||
hidden_size (`int`, *optional*, defaults to 2048):
|
||||
Dimension of the hidden representations.
|
||||
intermediate_size (`int`, *optional*, defaults to 6144):
|
||||
Dimension of the MLP representations.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 24):
|
||||
Number of hidden layers in the Transformer encoder.
|
||||
num_attention_heads (`int`, *optional*, defaults to 32):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 4):
|
||||
This is the number of key_value heads that should be used to implement Grouped Query Attention. If
|
||||
`num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
|
||||
`num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
|
||||
converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
|
||||
by meanpooling all the original heads within that group. For more details, check out [this
|
||||
paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to `32`.
|
||||
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function (function or string) in the decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 32768):
|
||||
The maximum sequence length that this model might ever be used with.
|
||||
initializer_range (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-06):
|
||||
The epsilon used by the rms normalization layers.
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models). Only
|
||||
relevant if `config.is_decoder=True`.
|
||||
tie_word_embeddings (`bool`, *optional*, defaults to `False`):
|
||||
Whether the model's input and output word embeddings should be tied.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period of the RoPE embeddings.
|
||||
rope_scaling (`Dict`, *optional*):
|
||||
Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
|
||||
and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
|
||||
accordingly.
|
||||
Expected contents:
|
||||
`rope_type` (`str`):
|
||||
The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
|
||||
'llama3'], with 'default' being the original RoPE implementation.
|
||||
`factor` (`float`, *optional*):
|
||||
Used with all rope types except 'default'. The scaling factor to apply to the RoPE embeddings. In
|
||||
most scaling types, a `factor` of x will enable the model to handle sequences of length x *
|
||||
original maximum pre-trained length.
|
||||
`original_max_position_embeddings` (`int`, *optional*):
|
||||
Used with 'dynamic', 'longrope' and 'llama3'. The original max position embeddings used during
|
||||
pretraining.
|
||||
`attention_factor` (`float`, *optional*):
|
||||
Used with 'yarn' and 'longrope'. The scaling factor to be applied on the attention
|
||||
computation. If unspecified, it defaults to value recommended by the implementation, using the
|
||||
`factor` field to infer the suggested value.
|
||||
`beta_fast` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for extrapolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 32.
|
||||
`beta_slow` (`float`, *optional*):
|
||||
Only used with 'yarn'. Parameter to set the boundary for interpolation (only) in the linear
|
||||
ramp function. If unspecified, it defaults to 1.
|
||||
`short_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to short contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`long_factor` (`list[float]`, *optional*):
|
||||
Only used with 'longrope'. The scaling factor to be applied to long contexts (<
|
||||
`original_max_position_embeddings`). Must be a list of numbers with the same length as the hidden
|
||||
size divided by the number of attention heads divided by 2
|
||||
`low_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to low frequency components of the RoPE
|
||||
`high_freq_factor` (`float`, *optional*):
|
||||
Only used with 'llama3'. Scaling factor applied to high frequency components of the RoPE
|
||||
attention_bias (`bool`, defaults to `False`, *optional*, defaults to `False`):
|
||||
Whether to use a bias in the query, key, value and output projection layers during self-attention.
|
||||
use_sliding_window (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use sliding window attention.
|
||||
sliding_window (`int`, *optional*, defaults to 4096):
|
||||
Sliding window attention (SWA) window size. If not specified, will default to `4096`.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts_talker"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
# Default tensor parallel plan for base model `Qwen3TTSTalker`
|
||||
base_model_tp_plan = {
|
||||
"layers.*.self_attn.q_proj": "colwise",
|
||||
"layers.*.self_attn.k_proj": "colwise",
|
||||
"layers.*.self_attn.v_proj": "colwise",
|
||||
"layers.*.self_attn.o_proj": "rowwise",
|
||||
"layers.*.mlp.gate_proj": "colwise",
|
||||
"layers.*.mlp.up_proj": "colwise",
|
||||
"layers.*.mlp.down_proj": "rowwise",
|
||||
}
|
||||
base_model_pp_plan = {
|
||||
"embed_tokens": (["input_ids"], ["inputs_embeds"]),
|
||||
"layers": (["hidden_states", "attention_mask"], ["hidden_states"]),
|
||||
"norm": (["hidden_states"], ["hidden_states"]),
|
||||
}
|
||||
sub_configs = {"code_predictor_config": Qwen3TTSTalkerCodePredictorConfig}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
code_predictor_config=None,
|
||||
vocab_size=3072,
|
||||
hidden_size=1024,
|
||||
intermediate_size=2048,
|
||||
num_hidden_layers=20,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=2,
|
||||
hidden_act="silu",
|
||||
max_position_embeddings=32768,
|
||||
initializer_range=0.02,
|
||||
rms_norm_eps=0.000001,
|
||||
use_cache=True,
|
||||
tie_word_embeddings=False,
|
||||
rope_theta=10000,
|
||||
rope_scaling=None,
|
||||
attention_bias=False,
|
||||
use_sliding_window=False,
|
||||
sliding_window=4096,
|
||||
attention_dropout=0,
|
||||
num_code_groups=32,
|
||||
text_hidden_size=2048,
|
||||
codec_eos_token_id=4198,
|
||||
codec_think_id=4202,
|
||||
codec_nothink_id=4203,
|
||||
codec_think_bos_id=4204,
|
||||
codec_think_eos_id=4205,
|
||||
codec_pad_id=4196,
|
||||
codec_bos_id=4197,
|
||||
spk_id=None,
|
||||
spk_is_dialect=None,
|
||||
codec_language_id=None,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(
|
||||
tie_word_embeddings=tie_word_embeddings,
|
||||
**kwargs,
|
||||
)
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.hidden_size = hidden_size
|
||||
self.intermediate_size = intermediate_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.use_sliding_window = use_sliding_window
|
||||
self.sliding_window = sliding_window if use_sliding_window else None
|
||||
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.hidden_act = hidden_act
|
||||
self.initializer_range = initializer_range
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.use_cache = use_cache
|
||||
self.rope_theta = rope_theta
|
||||
self.rope_scaling = rope_scaling
|
||||
self.attention_bias = attention_bias
|
||||
self.attention_dropout = attention_dropout
|
||||
# Validate the correctness of rotary position embeddings parameters
|
||||
# BC: if there is a 'type' field, move it to 'rope_type'.
|
||||
if self.rope_scaling is not None and "type" in self.rope_scaling:
|
||||
self.rope_scaling["rope_type"] = self.rope_scaling["type"]
|
||||
|
||||
if code_predictor_config is None:
|
||||
code_predictor_config = {}
|
||||
self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig()
|
||||
logger.info("code_predictor_config is None. Initializing code_predictor model with default values")
|
||||
elif isinstance(code_predictor_config, Qwen3TTSTalkerCodePredictorConfig):
|
||||
self.code_predictor_config = code_predictor_config
|
||||
else:
|
||||
self.code_predictor_config = Qwen3TTSTalkerCodePredictorConfig(**code_predictor_config)
|
||||
self.num_code_groups = num_code_groups
|
||||
self.text_hidden_size = text_hidden_size
|
||||
self.codec_eos_token_id = codec_eos_token_id
|
||||
self.codec_think_id = codec_think_id
|
||||
self.codec_language_id = codec_language_id
|
||||
self.codec_nothink_id = codec_nothink_id
|
||||
self.codec_think_bos_id = codec_think_bos_id
|
||||
self.codec_think_eos_id = codec_think_eos_id
|
||||
self.codec_pad_id = codec_pad_id
|
||||
self.codec_bos_id = codec_bos_id
|
||||
self.spk_id = spk_id
|
||||
self.spk_is_dialect = spk_is_dialect
|
||||
|
||||
|
||||
class Qwen3TTSConfig(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3TTSForConditionalGeneration`].
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts"
|
||||
sub_configs = {
|
||||
"talker_config": Qwen3TTSTalkerConfig,
|
||||
"speaker_encoder_config": Qwen3TTSSpeakerEncoderConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
talker_config=None,
|
||||
speaker_encoder_config=None,
|
||||
tokenizer_type=None,
|
||||
tts_model_size=None,
|
||||
tts_model_type=None,
|
||||
im_start_token_id=151644,
|
||||
im_end_token_id=151645,
|
||||
tts_pad_token_id=151671,
|
||||
tts_bos_token_id=151672,
|
||||
tts_eos_token_id=151673,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
if talker_config is None:
|
||||
talker_config = {}
|
||||
logger.info("talker_config is None. Initializing talker model with default values")
|
||||
if speaker_encoder_config is None:
|
||||
speaker_encoder_config = {}
|
||||
logger.info("speaker_encoder_config is None. Initializing talker model with default values")
|
||||
|
||||
self.talker_config = Qwen3TTSTalkerConfig(**talker_config)
|
||||
self.speaker_encoder_config = Qwen3TTSSpeakerEncoderConfig(**speaker_encoder_config)
|
||||
|
||||
self.tokenizer_type = tokenizer_type
|
||||
self.tts_model_size = tts_model_size
|
||||
self.tts_model_type = tts_model_type
|
||||
|
||||
self.im_start_token_id = im_start_token_id
|
||||
self.im_end_token_id = im_end_token_id
|
||||
self.tts_pad_token_id = tts_pad_token_id
|
||||
self.tts_bos_token_id = tts_bos_token_id
|
||||
self.tts_eos_token_id = tts_eos_token_id
|
||||
|
||||
|
||||
__all__ = ["Qwen3TTSConfig", "Qwen3TTSTalkerConfig", "Qwen3TTSSpeakerEncoderConfig"]
|
||||
2299
models/Qwen3-TTS/qwen_tts/core/models/modeling_qwen3_tts.py
Normal file
2299
models/Qwen3-TTS/qwen_tts/core/models/modeling_qwen3_tts.py
Normal file
File diff suppressed because it is too large
Load Diff
106
models/Qwen3-TTS/qwen_tts/core/models/processing_qwen3_tts.py
Normal file
106
models/Qwen3-TTS/qwen_tts/core/models/processing_qwen3_tts.py
Normal file
@@ -0,0 +1,106 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
from transformers.processing_utils import ProcessingKwargs, ProcessorMixin
|
||||
|
||||
|
||||
class Qwen3TTSProcessorKwargs(ProcessingKwargs, total=False):
|
||||
_defaults = {
|
||||
"text_kwargs": {
|
||||
"padding": False,
|
||||
"padding_side": "left",
|
||||
}
|
||||
}
|
||||
|
||||
class Qwen3TTSProcessor(ProcessorMixin):
|
||||
r"""
|
||||
Constructs a Qwen3TTS processor.
|
||||
|
||||
Args:
|
||||
tokenizer ([`Qwen2TokenizerFast`], *optional*):
|
||||
The text tokenizer.
|
||||
chat_template (`Optional[str]`, *optional*):
|
||||
The Jinja template to use for formatting the conversation. If not provided, the default chat template is used.
|
||||
"""
|
||||
|
||||
attributes = ["tokenizer"]
|
||||
tokenizer_class = ("Qwen2Tokenizer", "Qwen2TokenizerFast")
|
||||
|
||||
def __init__(
|
||||
self, tokenizer=None, chat_template=None
|
||||
):
|
||||
super().__init__(tokenizer, chat_template=chat_template)
|
||||
|
||||
def __call__(self, text=None, **kwargs) -> BatchFeature:
|
||||
"""
|
||||
Main method to prepare for the model one or several sequences(s) and audio(s). This method forwards the `text`
|
||||
and `kwargs` arguments to Qwen2TokenizerFast's [`~Qwen2TokenizerFast.__call__`] if `text` is not `None` to encode
|
||||
the text.
|
||||
|
||||
Args:
|
||||
text (`str`, `List[str]`, `List[List[str]]`):
|
||||
The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
|
||||
(pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
|
||||
`is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
|
||||
"""
|
||||
|
||||
if text is None:
|
||||
raise ValueError("You need to specify either a `text` input to process.")
|
||||
|
||||
output_kwargs = self._merge_kwargs(
|
||||
Qwen3TTSProcessorKwargs,
|
||||
tokenizer_init_kwargs=self.tokenizer.init_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if not isinstance(text, list):
|
||||
text = [text]
|
||||
|
||||
texts_inputs = self.tokenizer(text, **output_kwargs["text_kwargs"])
|
||||
|
||||
return BatchFeature(
|
||||
data={**texts_inputs},
|
||||
tensor_type=kwargs.get("return_tensors"),
|
||||
)
|
||||
|
||||
def batch_decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.batch_decode`]. Please
|
||||
refer to the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.batch_decode(*args, **kwargs)
|
||||
|
||||
def decode(self, *args, **kwargs):
|
||||
"""
|
||||
This method forwards all its arguments to Qwen2TokenizerFast's [`~PreTrainedTokenizer.decode`]. Please refer to
|
||||
the docstring of this method for more information.
|
||||
"""
|
||||
return self.tokenizer.decode(*args, **kwargs)
|
||||
|
||||
def apply_chat_template(self, conversations, chat_template=None, **kwargs):
|
||||
if isinstance(conversations[0], dict):
|
||||
conversations = [conversations]
|
||||
return super().apply_chat_template(conversations, chat_template, **kwargs)
|
||||
|
||||
@property
|
||||
def model_input_names(self):
|
||||
tokenizer_input_names = self.tokenizer.model_input_names
|
||||
return list(
|
||||
dict.fromkeys(
|
||||
tokenizer_input_names
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
__all__ = ["Qwen3TTSProcessor"]
|
||||
@@ -0,0 +1,172 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Qwen3TTSTokenizerV2 model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
from transformers import MimiConfig
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3TTSTokenizerV2DecoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2DecoderConfig`].
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
codebook_size (`int`, *optional*, defaults to 2048):
|
||||
Number of entries in each residual codebook used for acoustic token quantization.
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the hidden states and embeddings in the autoregressive transformer decoder.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 8000):
|
||||
Maximum sequence length that the autoregressive decoder can handle. Determines positional embedding size.
|
||||
rope_theta (`float`, *optional*, defaults to 10000.0):
|
||||
The base period for rotary position embeddings (RoPE) applied to attention layers.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the decoder.
|
||||
num_key_value_heads (`int`, *optional*, defaults to 16):
|
||||
Number of key and value attention heads used in grouped-query attention (if applicable).
|
||||
attention_bias (`bool`, *optional*, defaults to `False`):
|
||||
Whether to use bias in the attention projection layers.
|
||||
sliding_window (`int`, *optional*, defaults to 72):
|
||||
Window size for local attention mechanism, limiting attention context to improve efficiency.
|
||||
intermediate_size (`int`, *optional*, defaults to 3072):
|
||||
Dimensionality of the feed-forward (intermediate) layer in each transformer block.
|
||||
hidden_act (`str` or `function`, *optional*, defaults to `"silu"`):
|
||||
The non-linear activation function used in the feed-forward layers. Supports `"silu"`, `"relu"`, `"gelu"`, etc.
|
||||
layer_scale_initial_scale (`float`, *optional*, defaults to 0.01):
|
||||
Initial value for LayerScale applied in transformer blocks, helping stabilize training.
|
||||
rms_norm_eps (`float`, *optional*, defaults to 1e-5):
|
||||
Epsilon value for RMS normalization layers to prevent division by zero.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 8):
|
||||
Number of transformer blocks in the autoregressive decoder.
|
||||
num_quantizers (`int`, *optional*, defaults to 16):
|
||||
Number of residual vector quantizers used in the vocoder for fine-grained audio reconstruction.
|
||||
upsample_rates (`Tuple[int]`, *optional*, defaults to `(8, 5, 4, 3)`):
|
||||
Rate at which features are upsampled in the final waveform synthesis stage.
|
||||
upsampling_ratios (`Tuple[int]`, *optional*, defaults to `(2, 2)`):
|
||||
Ratios used in transposed convolutional layers to progressively upsample feature maps to waveform.
|
||||
decoder_dim (`int`, *optional*, defaults to 1536):
|
||||
Final dimensionality of the decoder's output before waveform generation.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
Dropout probability applied to attention weights in the decoder.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
codebook_size=2048,
|
||||
hidden_size=1024,
|
||||
latent_dim=1024,
|
||||
max_position_embeddings=8000,
|
||||
rope_theta=10000,
|
||||
num_attention_heads=16,
|
||||
num_key_value_heads=16,
|
||||
attention_bias=False,
|
||||
sliding_window=72,
|
||||
intermediate_size=3072,
|
||||
hidden_act="silu",
|
||||
layer_scale_initial_scale=0.01,
|
||||
rms_norm_eps=1e-5,
|
||||
num_hidden_layers=8,
|
||||
num_quantizers=16,
|
||||
upsample_rates=(8, 5, 4, 3),
|
||||
upsampling_ratios=(2, 2),
|
||||
decoder_dim=1536,
|
||||
attention_dropout=0.0,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.codebook_size = codebook_size
|
||||
self.hidden_size = hidden_size
|
||||
self.latent_dim = latent_dim
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.rope_theta = rope_theta
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.num_key_value_heads = num_key_value_heads
|
||||
self.attention_bias = attention_bias
|
||||
self.sliding_window = sliding_window
|
||||
self.intermediate_size = intermediate_size
|
||||
self.hidden_act = hidden_act
|
||||
self.layer_scale_initial_scale = layer_scale_initial_scale
|
||||
self.rms_norm_eps = rms_norm_eps
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_quantizers = num_quantizers
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsampling_ratios = upsampling_ratios
|
||||
self.decoder_dim = decoder_dim
|
||||
self.attention_dropout = attention_dropout
|
||||
|
||||
@property
|
||||
def layer_types(self):
|
||||
"""
|
||||
All layer in code2wav should be sliding attention
|
||||
"""
|
||||
return ["sliding_attention"] * self.num_hidden_layers
|
||||
|
||||
|
||||
class Qwen3TTSTokenizerV2Config(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV2Config`]. It is used to instantiate a Qwen3TTSTokenizerV2Model
|
||||
model according to the specified sub-models configurations, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
encoder_config (`dict`, *optional*): Configuration of the underlying encoder sub-model.
|
||||
decoder_config (`dict`, *optional*): Configuration of the underlying decoder sub-model.
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts_tokenizer_12hz"
|
||||
sub_configs = {
|
||||
"encoder_config": MimiConfig,
|
||||
"decoder_config": Qwen3TTSTokenizerV2DecoderConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_config=None,
|
||||
decoder_config=None,
|
||||
encoder_valid_num_quantizers=16,
|
||||
input_sample_rate=24000,
|
||||
output_sample_rate=24000,
|
||||
decode_upsample_rate=1920,
|
||||
encode_downsample_rate=1920,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if encoder_config is None:
|
||||
encoder_config = {}
|
||||
logger.info("encoder_config is None. Initializing encoder with default values")
|
||||
if decoder_config is None:
|
||||
decoder_config = {}
|
||||
logger.info("decoder_config is None. Initializing decoder with default values")
|
||||
|
||||
self.encoder_config = MimiConfig(**encoder_config)
|
||||
self.decoder_config = Qwen3TTSTokenizerV2DecoderConfig(**decoder_config)
|
||||
|
||||
self.encoder_valid_num_quantizers = encoder_valid_num_quantizers
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.decode_upsample_rate = decode_upsample_rate
|
||||
self.encode_downsample_rate = encode_downsample_rate
|
||||
|
||||
|
||||
__all__ = ["Qwen3TTSTokenizerV2Config", "Qwen3TTSTokenizerV2DecoderConfig"]
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,332 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Qwen team, Alibaba Group and the HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Qwen3TTSTokenizerV1 model configuration"""
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Qwen3TTSTokenizerV1DecoderDiTConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavDiT.
|
||||
It defines the architecture of the DiT model, which is used for generating mel-spectrograms from tokens.
|
||||
|
||||
Args:
|
||||
hidden_size (`int`, *optional*, defaults to 1024):
|
||||
The dimension of the model.
|
||||
num_hidden_layers (`int`, *optional*, defaults to 22):
|
||||
The number of transformer blocks in the DiT model.
|
||||
num_attention_heads (`int`, *optional*, defaults to 16):
|
||||
The number of attention heads in each transformer block.
|
||||
ff_mult (`int`, *optional*, defaults to 2):
|
||||
The multiplier for the feedforward layer in each transformer block.
|
||||
emb_dim (`int`, *optional*, defaults to 512):
|
||||
The dimension of the embedding layer.
|
||||
head_dim (`int`, *optional*, defaults to 64):
|
||||
The dimension of each attention head.
|
||||
repeats (`int`, *optional*, defaults to 2):
|
||||
The number of times the codec embeddings are repeated.
|
||||
num_embeds (`int`, *optional*, defaults to 8193):
|
||||
The number of unique embeddings in the codec.
|
||||
mel_dim (`int`, *optional*, defaults to 80):
|
||||
The dimension of the mel-spectrogram.
|
||||
dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout rate for the transformer blocks.
|
||||
|
||||
enc_emb_dim (`int`, *optional*, defaults to 192):
|
||||
The dimension of the pre-trained speaker embedding.
|
||||
enc_dim (`int`, *optional*, defaults to 128):
|
||||
The dimension of the encoder output.
|
||||
enc_channels (`list[int]`, *optional*, defaults to `[256, 256, 256, 256, 768]`):
|
||||
A list of output channels for each TDNN/SERes2Net layer in the encoder.
|
||||
enc_kernel_sizes (`list[int]`, *optional*, defaults to `[5, 3, 3, 3, 1]`):
|
||||
A list of kernel sizes for each layer in the encoder.
|
||||
enc_dilations (`list[int]`, *optional*, defaults to `[1, 2, 3, 4, 1]`):
|
||||
A list of dilations for each layer in the encoder.
|
||||
enc_attention_channels (`int`, *optional*, defaults to 64):
|
||||
The number of attention channels in the SqueezeExcitationBlock.
|
||||
enc_res2net_scale (`int`, *optional*, defaults to 2):
|
||||
The scale of the Res2Net block in the encoder.
|
||||
enc_se_channels (`int`, *optional*, defaults to 64):
|
||||
The number of output channels after squeeze in the SqueezeExcitationBlock.
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts_tokenizer_v1_decoder_dit"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size=1024,
|
||||
num_hidden_layers=22,
|
||||
num_attention_heads=16,
|
||||
ff_mult=2,
|
||||
emb_dim=512,
|
||||
head_dim=64,
|
||||
rope_theta=10000.0,
|
||||
max_position_embeddings=32768,
|
||||
block_size=24,
|
||||
look_ahead_layers=[10],
|
||||
look_backward_layers=[0, 20],
|
||||
repeats=2,
|
||||
num_embeds=8193,
|
||||
mel_dim=80,
|
||||
dropout=0.1,
|
||||
enc_emb_dim=192,
|
||||
enc_dim=128,
|
||||
enc_channels=[256, 256, 256, 256, 768],
|
||||
enc_kernel_sizes=[5, 3, 3, 3, 1],
|
||||
enc_dilations=[1, 2, 3, 4, 1],
|
||||
enc_attention_channels=64,
|
||||
enc_res2net_scale=2,
|
||||
enc_se_channels=64,
|
||||
**kwargs,
|
||||
):
|
||||
self.hidden_size = hidden_size
|
||||
self.num_hidden_layers = num_hidden_layers
|
||||
self.num_attention_heads = num_attention_heads
|
||||
self.ff_mult = ff_mult
|
||||
self.emb_dim = emb_dim
|
||||
self.head_dim = head_dim
|
||||
self.rope_theta = rope_theta
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.block_size = block_size
|
||||
self.look_ahead_layers = look_ahead_layers
|
||||
self.look_backward_layers = look_backward_layers
|
||||
self.repeats = repeats
|
||||
self.num_embeds = num_embeds
|
||||
self.mel_dim = mel_dim
|
||||
self.dropout = dropout
|
||||
self.enc_emb_dim = enc_emb_dim
|
||||
self.enc_dim = enc_dim
|
||||
self.enc_channels = enc_channels
|
||||
self.enc_kernel_sizes = enc_kernel_sizes
|
||||
self.enc_dilations = enc_dilations
|
||||
self.enc_attention_channels = enc_attention_channels
|
||||
self.enc_res2net_scale = enc_res2net_scale
|
||||
self.enc_se_channels = enc_se_channels
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Qwen3TTSTokenizerV1DecoderBigVGANConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1DecoderToken2WavBigVGAN module.
|
||||
It defines the architecture of the BigVGAN model, which is used for converting mel-spectrograms to waveforms.
|
||||
|
||||
Args:
|
||||
mel_dim (`int`, *optional*, defaults to 80):
|
||||
The dimension of the mel-spectrogram.
|
||||
upsample_initial_channel (`int`, *optional*, defaults to 1536):
|
||||
The number of channels in the initial upsampling layer.
|
||||
resblock_kernel_sizes (`list[int]`, *optional*, defaults to `[3, 7, 11]`):
|
||||
A list of kernel sizes for each residual block.
|
||||
resblock_dilation_sizes (`list[list[int]]`, *optional*, defaults to `[[1, 3, 5], [1, 3, 5], [1, 3, 5]]`):
|
||||
A list of dilation sizes for each residual block.
|
||||
upsample_rates (`list[int]`, *optional*, defaults to `[5, 3, 2, 2, 2, 2]`):
|
||||
A list of upsampling rates for each upsampling layer.
|
||||
upsample_kernel_sizes (`list[int]`, *optional*, defaults to `[11, 7, 4, 4, 4, 4]`):
|
||||
A list of kernel sizes for each upsampling layer.
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts_tokenizer_v1_decoder_bigvgan"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
mel_dim=80,
|
||||
upsample_initial_channel=1536,
|
||||
resblock_kernel_sizes=[3, 7, 11],
|
||||
resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
|
||||
upsample_rates=[5, 3, 2, 2, 2, 2],
|
||||
upsample_kernel_sizes=[11, 7, 4, 4, 4, 4],
|
||||
**kwargs,
|
||||
):
|
||||
self.mel_dim = mel_dim
|
||||
self.upsample_initial_channel = upsample_initial_channel
|
||||
self.resblock_kernel_sizes = resblock_kernel_sizes
|
||||
self.resblock_dilation_sizes = resblock_dilation_sizes
|
||||
self.upsample_rates = upsample_rates
|
||||
self.upsample_kernel_sizes = upsample_kernel_sizes
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Qwen3TTSTokenizerV1DecoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV1DecoderConfig`].
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
dit_config ([`DiT_Args`], *optional*):
|
||||
Configuration class for the Diffusion Transformer (DiT) module responsible for generating mel-spectrograms.
|
||||
bigvgan_config ([`BigVGAN_Args`], *optional*):
|
||||
Configuration class for the BigVGAN module responsible for converting mel-spectrograms to waveforms.
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts_tokenizer_v1_decoder"
|
||||
sub_configs = {
|
||||
"dit_config": Qwen3TTSTokenizerV1DecoderDiTConfig,
|
||||
"bigvgan_config": Qwen3TTSTokenizerV1DecoderBigVGANConfig,
|
||||
}
|
||||
|
||||
def __init__(self, dit_config=None, bigvgan_config=None, **kwargs):
|
||||
if dit_config is None:
|
||||
dit_config = {}
|
||||
if bigvgan_config is None:
|
||||
bigvgan_config = {}
|
||||
self.dit_config = Qwen3TTSTokenizerV1DecoderDiTConfig(**dit_config)
|
||||
self.bigvgan_config = Qwen3TTSTokenizerV1DecoderBigVGANConfig(**bigvgan_config)
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Qwen3TTSTokenizerV1EncoderConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of the Qwen3TTSTokenizerV1 Encoder.
|
||||
|
||||
The encoder typically takes mel-spectrogram features and produces high-level audio representations, then (optionally)
|
||||
applies an Audio-VQ module (e.g., GRVQ) to discretize continuous representations into codes.
|
||||
|
||||
Args:
|
||||
n_mels (`int`, *optional*, defaults to 128):
|
||||
Number of mel bins in the input mel-spectrogram.
|
||||
n_ctx (`int`, *optional*, defaults to 1500):
|
||||
Maximum input sequence length (in frames/tokens) for the encoder.
|
||||
n_state (`int`, *optional*, defaults to 1280):
|
||||
Hidden size (model dimension) of the encoder transformer.
|
||||
n_head (`int`, *optional*, defaults to 20):
|
||||
Number of attention heads in each transformer layer.
|
||||
n_layer (`int`, *optional*, defaults to 32):
|
||||
Number of transformer layers.
|
||||
n_window (`int`, *optional*, defaults to 100):
|
||||
Window size used by the model for local attention / chunking (implementation-dependent).
|
||||
output_dim (`int`, *optional*, defaults to 3584):
|
||||
Output feature dimension produced by the encoder head (before/after projection, implementation-dependent).
|
||||
|
||||
grad_checkpointing (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable gradient checkpointing to reduce memory usage during training.
|
||||
enable_mp (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable model parallel features (implementation-dependent).
|
||||
audio_sequence_parallel (`bool`, *optional*, defaults to `False`):
|
||||
Whether to enable sequence parallelism for audio branch (implementation-dependent).
|
||||
|
||||
audio_vq_type (`str`, *optional*, defaults to `"GRVQ"`):
|
||||
Type of audio vector-quantization module. Common choices: `"GRVQ"`, `"RVQ"`, etc.
|
||||
audio_vq_layers (`int`, *optional*, defaults to 6):
|
||||
Number of VQ layers / quantizers (e.g., number of residual quantizers for RVQ/GRVQ-like designs).
|
||||
audio_vq_codebook_size (`int`, *optional*, defaults to 32768):
|
||||
Size of each codebook (number of entries).
|
||||
audio_vq_codebook_dim (`int`, *optional*, defaults to 1280):
|
||||
Dimension of codebook vectors (often equals encoder hidden size).
|
||||
audio_vq_pe (`bool`, *optional*, defaults to `True`):
|
||||
Whether to use positional encoding (or position embeddings) inside the VQ module.
|
||||
audio_vq_ds_rate (`int`, *optional*, defaults to 2):
|
||||
Downsampling rate applied before VQ (e.g., temporal downsample factor).
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts_tokenizer_v1_encoder"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
n_mels=128,
|
||||
n_ctx=1500,
|
||||
n_state=1280,
|
||||
n_head=20,
|
||||
n_layer=32,
|
||||
n_window=100,
|
||||
output_dim=3584,
|
||||
grad_checkpointing=False,
|
||||
enable_mp=False,
|
||||
audio_sequence_parallel=False,
|
||||
audio_vq_type="GRVQ",
|
||||
audio_vq_layers=6,
|
||||
audio_vq_codebook_size=32768,
|
||||
audio_vq_codebook_dim=1280,
|
||||
audio_vq_pe=True,
|
||||
audio_vq_ds_rate=2,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self.n_mels = n_mels
|
||||
self.n_ctx = n_ctx
|
||||
self.n_state = n_state
|
||||
self.n_head = n_head
|
||||
self.n_layer = n_layer
|
||||
self.n_window = n_window
|
||||
self.output_dim = output_dim
|
||||
self.grad_checkpointing = grad_checkpointing
|
||||
self.enable_mp = enable_mp
|
||||
self.audio_sequence_parallel = audio_sequence_parallel
|
||||
self.audio_vq_type = audio_vq_type
|
||||
self.audio_vq_layers = audio_vq_layers
|
||||
self.audio_vq_codebook_size = audio_vq_codebook_size
|
||||
self.audio_vq_codebook_dim = audio_vq_codebook_dim
|
||||
self.audio_vq_pe = audio_vq_pe
|
||||
self.audio_vq_ds_rate = audio_vq_ds_rate
|
||||
|
||||
|
||||
class Qwen3TTSTokenizerV1Config(PretrainedConfig):
|
||||
"""
|
||||
This is the configuration class to store the configuration of a [`Qwen3TTSTokenizerV1Config`]. It is used to instantiate a Qwen3TTSTokenizerV1Model
|
||||
model according to the specified sub-models configurations, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
encoder_config (`dict`, *optional*): Configuration of the underlying encoder sub-model.
|
||||
decoder_config (`dict`, *optional*): Configuration of the underlying decoder sub-model.
|
||||
"""
|
||||
|
||||
model_type = "qwen3_tts_tokenizer_25hz"
|
||||
sub_configs = {
|
||||
"encoder_config": Qwen3TTSTokenizerV1EncoderConfig,
|
||||
"decoder_config": Qwen3TTSTokenizerV1DecoderConfig,
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
encoder_config=None,
|
||||
decoder_config=None,
|
||||
input_sample_rate=24000,
|
||||
output_sample_rate=24000,
|
||||
decode_upsample_rate=1920,
|
||||
encode_downsample_rate=1920,
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
if encoder_config is None:
|
||||
encoder_config = {}
|
||||
logger.info("encoder_config is None. Initializing encoder with default values")
|
||||
if decoder_config is None:
|
||||
decoder_config = {}
|
||||
logger.info("decoder_config is None. Initializing decoder with default values")
|
||||
|
||||
self.encoder_config = Qwen3TTSTokenizerV1EncoderConfig(**encoder_config)
|
||||
self.decoder_config = Qwen3TTSTokenizerV1DecoderConfig(**decoder_config)
|
||||
|
||||
self.input_sample_rate = input_sample_rate
|
||||
self.output_sample_rate = output_sample_rate
|
||||
self.decode_upsample_rate = decode_upsample_rate
|
||||
self.encode_downsample_rate = encode_downsample_rate
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Qwen3TTSTokenizerV1Config",
|
||||
"Qwen3TTSTokenizerV1EncoderConfig",
|
||||
"Qwen3TTSTokenizerV1DecoderConfig",
|
||||
"Qwen3TTSTokenizerV1DecoderBigVGANConfig",
|
||||
"Qwen3TTSTokenizerV1DecoderDiTConfig"
|
||||
]
|
||||
File diff suppressed because it is too large
Load Diff
Binary file not shown.
523
models/Qwen3-TTS/qwen_tts/core/tokenizer_25hz/vq/core_vq.py
Normal file
523
models/Qwen3-TTS/qwen_tts/core/tokenizer_25hz/vq/core_vq.py
Normal file
@@ -0,0 +1,523 @@
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the license found in the
|
||||
# LICENSE file in the root directory of this source tree.
|
||||
#
|
||||
# This implementation is inspired from
|
||||
# https://github.com/lucidrains/vector-quantize-pytorch
|
||||
# which is released under MIT License. Hereafter, the original license:
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2020 Phil Wang
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
"""Core vector quantization implementation."""
|
||||
import random
|
||||
import typing as tp
|
||||
from random import randrange
|
||||
|
||||
import numpy as np
|
||||
from einops import rearrange, repeat
|
||||
from math import ceil
|
||||
import torch
|
||||
from torch import nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def round_up_multiple(num, mult):
|
||||
return ceil(num / mult) * mult
|
||||
|
||||
def default(val: tp.Any, d: tp.Any) -> tp.Any:
|
||||
return val if val is not None else d
|
||||
|
||||
|
||||
def ema_inplace(moving_avg, new, decay: float):
|
||||
moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
|
||||
|
||||
|
||||
def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5):
|
||||
return (x + epsilon) / (x.sum() + n_categories * epsilon)
|
||||
|
||||
|
||||
def uniform_init(*shape: int):
|
||||
t = torch.empty(shape)
|
||||
nn.init.kaiming_uniform_(t)
|
||||
return t
|
||||
|
||||
|
||||
def sample_vectors(samples, num: int):
|
||||
num_samples, device = samples.shape[0], samples.device
|
||||
|
||||
if num_samples >= num:
|
||||
indices = torch.randperm(num_samples, device=device)[:num]
|
||||
else:
|
||||
indices = torch.randint(0, num_samples, (num,), device=device)
|
||||
|
||||
return samples[indices]
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def kmeans(samples, num_clusters: int, num_iters: int = 10):
|
||||
dim, dtype = samples.shape[-1], samples.dtype
|
||||
|
||||
means = sample_vectors(samples, num_clusters)
|
||||
|
||||
for _ in range(num_iters):
|
||||
dists = -(
|
||||
samples.pow(2).sum(1, keepdim=True)
|
||||
- 2 * torch.matmul(samples, means.t())
|
||||
+ means.t().pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
|
||||
buckets = dists.max(dim=-1).indices
|
||||
del dists
|
||||
bins = torch.bincount(buckets, minlength=num_clusters)
|
||||
zero_mask = bins == 0
|
||||
bins_min_clamped = bins.masked_fill(zero_mask, 1)
|
||||
|
||||
new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype)
|
||||
new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples)
|
||||
new_means = new_means / bins_min_clamped[..., None]
|
||||
|
||||
means = torch.where(zero_mask[..., None], means, new_means)
|
||||
return means, bins
|
||||
|
||||
|
||||
def preprocess(x):
|
||||
x = rearrange(x, "... d -> (...) d")
|
||||
return x
|
||||
|
||||
|
||||
def postprocess_emb(embed_ind, shape):
|
||||
return embed_ind.view(*shape[:-1])
|
||||
|
||||
|
||||
class EuclideanCodebook(nn.Module):
|
||||
"""Codebook with Euclidean distance.
|
||||
Args:
|
||||
dim (int): Dimension.
|
||||
codebook_size (int): Codebook size.
|
||||
kmeans_init (bool): Whether to use k-means to initialize the codebooks.
|
||||
If set to true, run the k-means algorithm on the first training batch and use
|
||||
the learned centroids as initialization.
|
||||
kmeans_iters (int): Number of iterations used for k-means algorithm at initialization.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
kmeans_init: int = False,
|
||||
kmeans_iters: int = 10,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5,
|
||||
threshold_ema_dead_code: float = 2.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.decay = decay
|
||||
self.codebook_size = codebook_size
|
||||
self.kmeans_iters = kmeans_iters
|
||||
self.epsilon = epsilon
|
||||
self.threshold_ema_dead_code = threshold_ema_dead_code
|
||||
|
||||
self.inited = None
|
||||
self.cluster_size = None
|
||||
self.embed = None
|
||||
self.embed_avg = None
|
||||
self.training = True
|
||||
|
||||
def init_embed_(self, data):
|
||||
if self.inited:
|
||||
return
|
||||
|
||||
embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters)
|
||||
self.embed.data.copy_(embed)
|
||||
self.embed_avg.data.copy_(embed.clone())
|
||||
self.cluster_size.data.copy_(cluster_size)
|
||||
self.inited.data.copy_(torch.Tensor([True]))
|
||||
# Make sure all buffers across workers are in sync after initialization
|
||||
# distrib.broadcast_tensors([self.embed, self.embed_avg, self.cluster_size, self.inited])
|
||||
|
||||
def replace_(self, samples, mask):
|
||||
modified_codebook = torch.where(
|
||||
mask[..., None], sample_vectors(samples, self.codebook_size), self.embed
|
||||
)
|
||||
self.embed.data.copy_(modified_codebook)
|
||||
|
||||
def expire_codes_(self, batch_samples):
|
||||
if self.threshold_ema_dead_code == 0:
|
||||
return
|
||||
|
||||
cluster_size = self.cluster_size / sum(self.cluster_size) * self.codebook_size
|
||||
expired_codes = cluster_size < self.threshold_ema_dead_code
|
||||
if not torch.any(expired_codes):
|
||||
return
|
||||
else:
|
||||
print(f"VQ expire infos: num_expire={sum(expired_codes)}, cluster_size[:5]={cluster_size[:5]}")
|
||||
|
||||
batch_samples = rearrange(batch_samples, "... d -> (...) d")
|
||||
self.replace_(batch_samples, mask=expired_codes)
|
||||
# sync buffers outside for efficiency
|
||||
# distrib.broadcast_tensors(self.buffers())
|
||||
|
||||
def quantize(self, x):
|
||||
embed = self.embed.t()
|
||||
dist = -(
|
||||
x.pow(2).sum(1, keepdim=True)
|
||||
- 2 * x @ embed
|
||||
+ embed.pow(2).sum(0, keepdim=True)
|
||||
)
|
||||
embed_ind = dist.max(dim=-1).indices
|
||||
return embed_ind
|
||||
|
||||
def dequantize(self, embed_ind):
|
||||
quantize = F.embedding(embed_ind, self.embed)
|
||||
return quantize
|
||||
|
||||
def encode(self, x, buffers):
|
||||
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
|
||||
|
||||
shape = x.shape
|
||||
# pre-process
|
||||
x = preprocess(x)
|
||||
# quantize
|
||||
embed_ind = self.quantize(x)
|
||||
# post-process
|
||||
embed_ind = postprocess_emb(embed_ind, shape)
|
||||
return embed_ind
|
||||
|
||||
def decode(self, embed_ind, buffers):
|
||||
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
|
||||
|
||||
quantize = self.dequantize(embed_ind)
|
||||
return quantize
|
||||
|
||||
def forward(self, x, buffers):
|
||||
self.inited, self.cluster_size, self.embed, self.embed_avg = buffers
|
||||
|
||||
shape, dtype = x.shape, x.dtype
|
||||
x = preprocess(x)
|
||||
|
||||
self.init_embed_(x)
|
||||
if self.training:
|
||||
# We do the expiry of code at that point as buffers are in sync
|
||||
# and all the workers will take the same decision.
|
||||
self.expire_codes_(x)
|
||||
|
||||
embed_ind = self.quantize(x)
|
||||
embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype)
|
||||
embed_ind = postprocess_emb(embed_ind, shape)
|
||||
quantize = self.dequantize(embed_ind)
|
||||
|
||||
if self.training:
|
||||
ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay)
|
||||
embed_sum = x.t() @ embed_onehot
|
||||
ema_inplace(self.embed_avg, embed_sum.t(), self.decay)
|
||||
cluster_size = (
|
||||
laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon)
|
||||
* self.cluster_size.sum()
|
||||
)
|
||||
embed_normalized = self.embed_avg / cluster_size.unsqueeze(1)
|
||||
self.embed.data.copy_(embed_normalized)
|
||||
# Note: after ema update, there is a very small difference between codebooks on GPUs.
|
||||
# The impact can be very small, ignore it.
|
||||
|
||||
return quantize, embed_ind
|
||||
|
||||
|
||||
class VectorQuantization(nn.Module):
|
||||
"""Vector quantization implementation.
|
||||
Currently, supports only euclidean distance.
|
||||
Args:
|
||||
dim (int): Dimension
|
||||
codebook_size (int): Codebook size
|
||||
codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
|
||||
decay (float): Decay for exponential moving average over the codebooks.
|
||||
epsilon (float): Epsilon value for numerical stability.
|
||||
kmeans_init (bool): Whether to use kmeans to initialize the codebooks.
|
||||
kmeans_iters (int): Number of iterations used for kmeans initialization.
|
||||
threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes
|
||||
that have an exponential moving average cluster size less than the specified threshold with
|
||||
randomly selected vector from the current batch.
|
||||
commitment_weight (float): Weight for commitment loss.
|
||||
"""
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
codebook_dim: tp.Optional[int] = None,
|
||||
decay: float = 0.99,
|
||||
epsilon: float = 1e-5,
|
||||
kmeans_init: bool = True,
|
||||
kmeans_iters: int = 50,
|
||||
threshold_ema_dead_code: float = 2.0,
|
||||
commitment_weight: float = 1.,
|
||||
):
|
||||
super().__init__()
|
||||
_codebook_dim: int = default(codebook_dim, dim)
|
||||
|
||||
requires_projection = _codebook_dim != dim
|
||||
self.project_in = (nn.Linear(dim, _codebook_dim)) if requires_projection else (nn.Identity())
|
||||
self.project_out = (nn.Linear(_codebook_dim, dim)) if requires_projection else (nn.Identity())
|
||||
|
||||
self.epsilon = epsilon
|
||||
self.commitment_weight = commitment_weight
|
||||
|
||||
self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size,
|
||||
kmeans_init=kmeans_init, kmeans_iters=kmeans_iters,
|
||||
decay=decay, epsilon=epsilon,
|
||||
threshold_ema_dead_code=threshold_ema_dead_code)
|
||||
self.codebook_size = codebook_size
|
||||
self.training = True
|
||||
|
||||
@property
|
||||
def codebook(self):
|
||||
return self._codebook.embed
|
||||
|
||||
def encode(self, x, buffers):
|
||||
# x = rearrange(x, "b d n -> b n d")
|
||||
x = self.project_in(x)
|
||||
embed_in = self._codebook.encode(x, buffers)
|
||||
return embed_in
|
||||
|
||||
def decode(self, embed_ind, buffers):
|
||||
quantize = self._codebook.decode(embed_ind, buffers)
|
||||
quantize = self.project_out(quantize)
|
||||
# quantize = rearrange(quantize, "b n d -> b d n")
|
||||
return quantize
|
||||
|
||||
def forward(self, x, buffers):
|
||||
device = x.device
|
||||
# x = rearrange(x, "b d n -> b n d")
|
||||
x = self.project_in(x)
|
||||
|
||||
quantize, embed_ind = self._codebook(x, buffers)
|
||||
|
||||
if self.training:
|
||||
quantize = x + (quantize - x).detach()
|
||||
|
||||
loss = torch.tensor([0.0], device=device, requires_grad=self.training)
|
||||
|
||||
if self.training:
|
||||
if self.commitment_weight > 0:
|
||||
commit_loss = F.mse_loss(quantize.detach(), x)
|
||||
loss = loss + commit_loss * self.commitment_weight
|
||||
|
||||
quantize = self.project_out(quantize)
|
||||
# quantize = rearrange(quantize, "b n d -> b d n")
|
||||
return quantize, embed_ind, loss
|
||||
|
||||
|
||||
class DistributedResidualVectorQuantization(nn.Module):
|
||||
"""Efficient distributed residual vector quantization implementation.
|
||||
Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
|
||||
"""
|
||||
def __init__(self, *,
|
||||
num_quantizers,
|
||||
quantize_dropout: bool = False,
|
||||
rand_num_quant: tp.Optional[tp.List] = None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
"""
|
||||
dim: int,
|
||||
codebook_size: int,
|
||||
codebook_dim: tp.Optional[int] = None,
|
||||
"""
|
||||
codebook_size, codebook_dim = kwargs["codebook_size"], kwargs["codebook_dim"] if kwargs["codebook_dim"] else kwargs["dim"]
|
||||
kmeans_init = kwargs["kmeans_init"]
|
||||
if isinstance(kmeans_init, bool):
|
||||
if not kwargs["kmeans_init"]:
|
||||
# use uniform init
|
||||
embed = uniform_init(num_quantizers, codebook_size, codebook_dim)
|
||||
inited = True
|
||||
else:
|
||||
# to perform kmeans init on first batch
|
||||
embed = torch.zeros(num_quantizers, codebook_size, codebook_dim)
|
||||
inited = False
|
||||
elif isinstance(kmeans_init, str):
|
||||
# use prepared kmeans init
|
||||
embed = np.load(kmeans_init)
|
||||
embed = torch.from_numpy(embed)
|
||||
if embed.dim() == 2:
|
||||
embed = embed.unsqueeze(0)
|
||||
inited = True
|
||||
else:
|
||||
raise TypeError("kmeans_init should be either a bool or string path to init weights.")
|
||||
|
||||
self.register_buffer("inited", torch.Tensor([[inited] for _ in range(num_quantizers)]))
|
||||
self.register_buffer("cluster_size", torch.zeros(num_quantizers, codebook_size))
|
||||
self.register_buffer("embed", embed)
|
||||
self.register_buffer("embed_avg", embed.clone())
|
||||
|
||||
self.q0_ds_ratio = 1
|
||||
if "q0_ds_ratio" in kwargs:
|
||||
self.q0_ds_ratio = kwargs.pop("q0_ds_ratio")
|
||||
|
||||
self.layers = nn.ModuleList()
|
||||
for i in range(num_quantizers):
|
||||
vq_args = dict(**kwargs)
|
||||
vq = VectorQuantization(**vq_args)
|
||||
self.layers.append(vq)
|
||||
|
||||
self.quantize_dropout = quantize_dropout
|
||||
self.rand_num_quant = rand_num_quant
|
||||
|
||||
def forward(self, x, n_q: tp.Optional[int] = None):
|
||||
quantized_out = torch.zeros_like(x)
|
||||
residual = x
|
||||
bb, cc, tt = x.shape
|
||||
device = x.device
|
||||
|
||||
all_losses = []
|
||||
all_indices = []
|
||||
all_sub_quants = []
|
||||
n_q = n_q or len(self.layers)
|
||||
|
||||
should_quantize_dropout = self.training and self.quantize_dropout and self.rand_num_quant is not None
|
||||
if should_quantize_dropout:
|
||||
rand_quantize_dropout_index = random.choice(self.rand_num_quant)
|
||||
|
||||
null_indices_shape = (x.shape[0], x.shape[2])
|
||||
null_indices = torch.full(null_indices_shape, -1., device=device, dtype=torch.long)
|
||||
null_loss = torch.full((1,), 0., device=device, dtype=x.dtype)
|
||||
null_sub_quant = torch.full(x.shape, -1, device=device, dtype=x.dtype)
|
||||
|
||||
for quantizer_index, layer in enumerate(self.layers[:n_q]):
|
||||
# dropout except the first quantizer
|
||||
if should_quantize_dropout and quantizer_index >= rand_quantize_dropout_index:
|
||||
all_indices.append(null_indices)
|
||||
all_losses.append(null_loss)
|
||||
all_sub_quants.append(null_sub_quant)
|
||||
continue
|
||||
|
||||
quant_in = residual
|
||||
if self.q0_ds_ratio > 1 and quantizer_index == 0:
|
||||
quant_in = F.interpolate(quant_in, size=[tt//2])
|
||||
quantized, indices, loss = layer(quant_in, [
|
||||
self.inited[quantizer_index],
|
||||
self.cluster_size[quantizer_index],
|
||||
self.embed[quantizer_index],
|
||||
self.embed_avg[quantizer_index]
|
||||
])
|
||||
if self.q0_ds_ratio > 1 and quantizer_index == 0:
|
||||
quantized = F.interpolate(quantized, size=[tt])
|
||||
indices = F.interpolate(indices.unsqueeze(1).float(), size=[tt]).squeeze(1).long()
|
||||
residual = residual - quantized
|
||||
quantized_out = quantized_out + quantized
|
||||
|
||||
all_indices.append(indices)
|
||||
all_losses.append(loss)
|
||||
all_sub_quants.append(quantized)
|
||||
|
||||
# sync buffers after one forward step
|
||||
# distrib.broadcast_tensors(self.buffers())
|
||||
out_losses, out_indices, out_sub_quants = map(torch.stack, (all_losses, all_indices, all_sub_quants))
|
||||
|
||||
return quantized_out, out_indices, out_losses
|
||||
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
||||
residual = x
|
||||
all_indices = []
|
||||
n_q = n_q or len(self.layers)
|
||||
for i, layer in enumerate(self.layers[:n_q]):
|
||||
indices = layer.encode(residual, [
|
||||
self.inited[i],
|
||||
self.cluster_size[i],
|
||||
self.embed[i],
|
||||
self.embed_avg[i]
|
||||
])
|
||||
quantized = layer.decode(indices, [
|
||||
self.inited[i],
|
||||
self.cluster_size[i],
|
||||
self.embed[i],
|
||||
self.embed_avg[i]
|
||||
])
|
||||
residual = residual - quantized
|
||||
all_indices.append(indices)
|
||||
out_indices = torch.stack(all_indices)
|
||||
return out_indices
|
||||
|
||||
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
||||
quantized_out = torch.tensor(0.0, device=q_indices.device)
|
||||
for i, indices in enumerate(q_indices):
|
||||
layer = self.layers[i]
|
||||
quantized = layer.decode(indices, [
|
||||
self.inited[i],
|
||||
self.cluster_size[i],
|
||||
self.embed[i],
|
||||
self.embed_avg[i]
|
||||
])
|
||||
quantized_out = quantized_out + quantized
|
||||
return quantized_out
|
||||
|
||||
|
||||
class DistributedGroupResidualVectorQuantization(nn.Module):
|
||||
"""Efficient distributed group residual vector quantization implementation.
|
||||
Follows Algorithm 1. in https://arxiv.org/abs/2305.02765
|
||||
Group Then rvq
|
||||
"""
|
||||
def __init__(self, *,
|
||||
num_groups,
|
||||
num_quantizers,
|
||||
quantize_dropout: bool = False,
|
||||
rand_num_quant: tp.Optional[tp.List] = None,
|
||||
**kwargs):
|
||||
super().__init__()
|
||||
self.rvqs = nn.ModuleList(
|
||||
[
|
||||
DistributedResidualVectorQuantization(
|
||||
num_quantizers=num_quantizers,
|
||||
quantize_dropout=quantize_dropout,
|
||||
rand_num_quant=rand_num_quant,
|
||||
**kwargs
|
||||
)
|
||||
for _ in range(num_groups)
|
||||
]
|
||||
)
|
||||
self.num_groups = num_groups
|
||||
|
||||
def forward(self, x, n_q: tp.Optional[int] = None):
|
||||
x_lst = torch.chunk(x, chunks=self.num_groups, dim=1)
|
||||
all_quantized_out = []
|
||||
all_indices = []
|
||||
all_losses = []
|
||||
for mod, item in zip(self.rvqs, x_lst):
|
||||
quantized_out, out_indices, out_losses = mod(item, n_q)
|
||||
all_quantized_out.append(quantized_out)
|
||||
all_indices.append(out_indices)
|
||||
all_losses.append(out_losses)
|
||||
|
||||
out_losses = torch.stack(all_losses, dim=1).mean(dim=1)
|
||||
|
||||
return torch.cat(all_quantized_out, dim=1), torch.stack(all_indices, dim=1), out_losses
|
||||
|
||||
def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
|
||||
x_lst = torch.chunk(x, chunks=self.num_groups, dim=1)
|
||||
return torch.stack([mod.encode(item, n_q) for mod, item in zip(self.rvqs, x_lst)], dim=1)
|
||||
|
||||
def decode(self, q_indices: torch.Tensor) -> torch.Tensor:
|
||||
q_indices_lst = torch.chunk(q_indices, chunks=self.num_groups, dim=1)
|
||||
return torch.cat([mod.decode(item.squeeze(1)) for mod, item in zip(self.rvqs, q_indices_lst)], dim=1)
|
||||
357
models/Qwen3-TTS/qwen_tts/core/tokenizer_25hz/vq/speech_vq.py
Normal file
357
models/Qwen3-TTS/qwen_tts/core/tokenizer_25hz/vq/speech_vq.py
Normal file
@@ -0,0 +1,357 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import sox
|
||||
import copy
|
||||
import torch
|
||||
import operator
|
||||
import onnxruntime
|
||||
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
import torchaudio.compliance.kaldi as kaldi
|
||||
|
||||
from librosa.filters import mel as librosa_mel_fn
|
||||
from itertools import accumulate
|
||||
from typing import List
|
||||
from torch import Tensor
|
||||
|
||||
from .core_vq import DistributedGroupResidualVectorQuantization
|
||||
from .whisper_encoder import WhisperEncoder, Conv1d, ConvTranspose1d
|
||||
|
||||
|
||||
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
|
||||
return torch.log(torch.clamp(x, min=clip_val) * C)
|
||||
|
||||
def spectral_normalize_torch(magnitudes):
|
||||
output = dynamic_range_compression_torch(magnitudes)
|
||||
return output
|
||||
|
||||
class MelSpectrogramFeatures(nn.Module):
|
||||
"""
|
||||
Calculate the BigVGAN style mel spectrogram of an input signal.
|
||||
Args:
|
||||
filter_length (int): The number of samples in the filter window, used for the Fourier Transform. Default is 1024.
|
||||
hop_length (int): The number of samples between successive frames (stride of the STFT). Default is 160.
|
||||
win_length (int): The length of the window function applied to each frame, usually less than or equal to the filter length. Default is 640.
|
||||
n_mel_channels (int): The number of Mel-frequency channels to output from the Mel-scale spectrogram. Default is 80.
|
||||
mel_fmin (int): The minimum frequency (in Hz) of the Mel-scale spectrogram. Default is 0.
|
||||
mel_fmax (int): The maximum frequency (in Hz) of the Mel-scale spectrogram. Default is 8000.
|
||||
sampling_rate (int): The sampling rate of the audio data (in Hz). Default is 16000.
|
||||
sampling_rate_org (int, optional): The original sampling rate of the audio data before any resampling (in Hz), if applicable. Default is None.
|
||||
padding (str): The padding mode for the input signal. 'center' pads the signal symmetrically around its center. Default is 'center'.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Mel spectrogram.
|
||||
"""
|
||||
def __init__(self,
|
||||
filter_length=1024,
|
||||
hop_length=160,
|
||||
win_length=640,
|
||||
n_mel_channels=80,
|
||||
mel_fmin=0,
|
||||
mel_fmax=8000,
|
||||
sampling_rate=16000,
|
||||
sampling_rate_org=None,
|
||||
padding='center',
|
||||
use_db = False,
|
||||
):
|
||||
super().__init__()
|
||||
if padding not in ["center", "same"]:
|
||||
raise ValueError("Padding must be 'center' or 'same'.")
|
||||
self.padding = padding
|
||||
|
||||
self.filter_length = filter_length
|
||||
self.hop_length = hop_length
|
||||
self.win_length = win_length
|
||||
self.n_mel_channels = n_mel_channels
|
||||
self.mel_fmin = mel_fmin
|
||||
self.mel_fmax = mel_fmax
|
||||
self.sampling_rate = sampling_rate
|
||||
self.sampling_rate_org = sampling_rate_org if sampling_rate_org is not None else sampling_rate
|
||||
self.mel_basis = {}
|
||||
self.hann_window = {}
|
||||
|
||||
def forward(self, audio: torch.Tensor, **kwargs) -> torch.Tensor:
|
||||
with torch.no_grad():
|
||||
feats = self.extract(audio, **kwargs)
|
||||
return feats
|
||||
|
||||
def extract(self, audio, **kwargs):
|
||||
|
||||
if len(audio.shape) == 3:
|
||||
audio = audio.squeeze(1) if audio.shape[1] == 1 else audio.squeeze(2)
|
||||
assert len(audio.shape) == 2
|
||||
|
||||
y = audio
|
||||
if len(list(self.mel_basis.keys())) == 0:
|
||||
mel = librosa_mel_fn(sr=self.sampling_rate, n_fft=self.filter_length, n_mels=self.n_mel_channels, fmin=self.mel_fmin, fmax=self.mel_fmax)
|
||||
self.mel_basis[str(self.mel_fmax)+'_'+str(y.device)] = torch.from_numpy(mel).float().to(y.device)
|
||||
self.hann_window[str(y.device)] = torch.hann_window(self.win_length).to(y.device)
|
||||
|
||||
y = torch.nn.functional.pad(y.unsqueeze(1), (int((self.filter_length-self.hop_length)/2), int((self.filter_length-self.hop_length)/2)), mode='reflect')
|
||||
y = y.squeeze(1)
|
||||
|
||||
spec = torch.stft(y, self.filter_length, hop_length=self.hop_length, win_length=self.win_length, window=self.hann_window[str(y.device)],
|
||||
center=False, pad_mode='reflect', normalized=False, onesided=True, return_complex=True)
|
||||
spec = torch.view_as_real(spec)
|
||||
spec = torch.sqrt(spec.pow(2).sum(-1)+(1e-9))
|
||||
|
||||
spec = torch.matmul(self.mel_basis[str(self.mel_fmax)+'_'+str(y.device)], spec)
|
||||
spec = spectral_normalize_torch(spec)
|
||||
|
||||
return spec
|
||||
|
||||
|
||||
class XVectorExtractor(nn.Module):
|
||||
def __init__(self, audio_codec_with_xvector):
|
||||
super().__init__()
|
||||
option = onnxruntime.SessionOptions()
|
||||
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
|
||||
option.intra_op_num_threads = 1
|
||||
providers = ["CPUExecutionProvider"]
|
||||
self.ort_session = onnxruntime.InferenceSession(audio_codec_with_xvector, sess_options=option, providers=providers)
|
||||
|
||||
self.tfm = sox.Transformer()
|
||||
self.tfm.norm(db_level=-6)
|
||||
|
||||
self.mel_ext = MelSpectrogramFeatures(
|
||||
filter_length=1024,
|
||||
hop_length=160,
|
||||
win_length=640,
|
||||
n_mel_channels=80,
|
||||
mel_fmin=0,
|
||||
mel_fmax=8000,
|
||||
sampling_rate=16000
|
||||
)
|
||||
|
||||
def extract_code(self, audio):
|
||||
with torch.no_grad():
|
||||
norm_audio = self.sox_norm(audio)
|
||||
|
||||
norm_audio = torch.from_numpy(copy.deepcopy(norm_audio)).unsqueeze(0)
|
||||
feat = kaldi.fbank(norm_audio,
|
||||
num_mel_bins=80,
|
||||
dither=0,
|
||||
sample_frequency=16000)
|
||||
feat = feat - feat.mean(dim=0, keepdim=True)
|
||||
norm_embedding = self.ort_session.run(None, {self.ort_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten()
|
||||
norm_embedding = F.normalize(torch.from_numpy(norm_embedding), dim=0)
|
||||
|
||||
ref_mel = self.mel_ext.extract(audio=norm_audio)
|
||||
|
||||
return norm_embedding.numpy(), ref_mel.permute(0,2,1).squeeze(0).numpy()
|
||||
|
||||
def sox_norm(self, audio):
|
||||
wav_norm = self.tfm.build_array(input_array=audio, sample_rate_in=16000)
|
||||
return wav_norm
|
||||
|
||||
|
||||
class WhisperEncoderVQ(WhisperEncoder):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int,
|
||||
n_ctx: int,
|
||||
n_state: int,
|
||||
n_head: int,
|
||||
n_layer: int,
|
||||
n_window: int = 1500,
|
||||
output_dim: int = 512,
|
||||
grad_checkpointing: bool = False,
|
||||
enable_mp: bool = False,
|
||||
audio_sequence_parallel: bool = False,
|
||||
audio_vq_layers: int = -1,
|
||||
audio_vq_type: str = "NULL",
|
||||
audio_vq_codebook_size: int = 4096,
|
||||
audio_vq_pe: bool = False,
|
||||
audio_vq_commit_loss: float = 0.0,
|
||||
audio_vq_out_commit_loss: float = 0.0,
|
||||
audio_vq_no_quantize: bool = False,
|
||||
audio_vq_ff_layer: int = 0,
|
||||
audio_vq_threshold_ema_dead_code: float = 0.1,
|
||||
audio_vq_codebook_dim: int = None,
|
||||
audio_vq_ds_rate: int = None,
|
||||
):
|
||||
super().__init__(n_mels, n_ctx, n_state, n_head, n_layer, n_window, output_dim, grad_checkpointing, enable_mp, audio_sequence_parallel)
|
||||
|
||||
self.audio_vq_layers = audio_vq_layers
|
||||
self.audio_vq_type = audio_vq_type
|
||||
self.audio_vq_codebook_size = audio_vq_codebook_size
|
||||
self.audio_vq_pe = audio_vq_pe
|
||||
self.audio_vq_commit_loss = audio_vq_commit_loss
|
||||
self.audio_vq_out_commit_loss = audio_vq_out_commit_loss
|
||||
self.audio_vq_no_quantize = audio_vq_no_quantize
|
||||
self.audio_vq_ff_layer = audio_vq_ff_layer
|
||||
|
||||
if audio_vq_layers > 0:
|
||||
self.vq_feature_dim = self.n_state
|
||||
self.audio_vq_ds_rate = 1
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported audio_vq_layers: {audio_vq_layers}")
|
||||
|
||||
if self.audio_vq_ds_rate == audio_vq_ds_rate:
|
||||
self.audio_vq_downsample = nn.Identity()
|
||||
self.audio_vq_upsample = nn.Identity()
|
||||
else:
|
||||
assert audio_vq_ds_rate % self.audio_vq_ds_rate == 0
|
||||
stride = audio_vq_ds_rate // self.audio_vq_ds_rate
|
||||
self.audio_vq_downsample = Conv1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
|
||||
self.audio_vq_upsample = ConvTranspose1d(self.vq_feature_dim, self.vq_feature_dim, kernel_size=stride, stride=stride)
|
||||
self.audio_vq_ds_rate = audio_vq_ds_rate
|
||||
|
||||
if audio_vq_type == "GRVQ":
|
||||
self.audio_quantizer = DistributedGroupResidualVectorQuantization(
|
||||
codebook_size = audio_vq_codebook_size,
|
||||
dim = self.vq_feature_dim,
|
||||
codebook_dim = self.vq_codebook_dim if audio_vq_codebook_dim is None else audio_vq_codebook_dim,
|
||||
num_groups=1,
|
||||
num_quantizers=1,
|
||||
kmeans_init=False,
|
||||
threshold_ema_dead_code = audio_vq_threshold_ema_dead_code
|
||||
)
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported audio_vq_type: {audio_vq_type}")
|
||||
|
||||
if self.audio_vq_pe:
|
||||
self.project_after_vq_pe = nn.Linear(self.n_state, self.n_state)
|
||||
|
||||
def _calc_quantize_activities(self, indices):
|
||||
indices_onehot = F.one_hot(indices.long().flatten(), self.audio_vq_codebook_size).sum(dim=0)
|
||||
vq_num_activities = sum(indices_onehot>0)
|
||||
vq_num_tokens = sum(indices_onehot)
|
||||
return {
|
||||
"vq_num_activities": vq_num_activities,
|
||||
"vq_num_tokens": vq_num_tokens,
|
||||
}
|
||||
|
||||
def _do_quantize(self, x, pe=None, y=None):
|
||||
"""
|
||||
x: torch.Tensor, shape = (T, D)
|
||||
q: torch.Tensor, shape = (T, D)
|
||||
i: torch.Tensor, shape = (T)
|
||||
"""
|
||||
if self.audio_vq_out_commit_loss > 0:
|
||||
x_teacher = x.clone()
|
||||
x = x.unsqueeze(0)
|
||||
|
||||
x = self.audio_vq_downsample(x.transpose(1, 2))
|
||||
x = x.transpose(1, 2)
|
||||
|
||||
vq_stats = {}
|
||||
|
||||
if self.audio_vq_type == "GRVQ":
|
||||
if self.training:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
indices = self.audio_quantizer.encode(x)
|
||||
x = self.audio_quantizer.decode(indices)
|
||||
indices = indices.squeeze(2).squeeze(1)
|
||||
|
||||
vq_stats.update(self._calc_quantize_activities(indices))
|
||||
|
||||
x, indices = x.squeeze(0), indices.squeeze(0)
|
||||
if self.audio_vq_pe:
|
||||
x = x + pe
|
||||
x = self.project_after_vq_pe(x)
|
||||
|
||||
x = self.audio_vq_upsample(x.unsqueeze(0).transpose(1, 2))
|
||||
x = x.transpose(1, 2).squeeze(0)
|
||||
|
||||
if self.audio_vq_out_commit_loss > 0:
|
||||
vq_out_commit_loss = F.mse_loss(x_teacher.detach(), x)
|
||||
vq_stats["vq_out_commit_loss"] = vq_out_commit_loss * self.audio_vq_out_commit_loss
|
||||
|
||||
return x, indices, vq_stats
|
||||
|
||||
def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int], return_indices=False, audio_pitchs=None):
|
||||
"""
|
||||
x : torch.Tensor, shape = (n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
|
||||
aftercnn_x_list = []
|
||||
pe_for_vq_list = []
|
||||
for each_x in x_list:
|
||||
each_x_split_list = each_x.split(self.n_window * 2, dim=1)
|
||||
for each_x_split in each_x_split_list:
|
||||
each_x_split = F.gelu(self.conv1(each_x_split))
|
||||
each_x_split = F.gelu(self.conv2(each_x_split))
|
||||
each_x_split = each_x_split.permute(1, 0) # L,D
|
||||
|
||||
each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
|
||||
aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))
|
||||
|
||||
pe_for_vq_split = self.positional_embedding[:each_x_split.shape[0] // self.audio_vq_ds_rate]
|
||||
pe_for_vq_list.append(pe_for_vq_split.to(each_x_split.dtype))
|
||||
|
||||
pe_for_vq = torch.cat(pe_for_vq_list, dim=0)
|
||||
x = torch.cat(aftercnn_x_list, dim=0)
|
||||
src_len = x.size(0)
|
||||
|
||||
output_list = []
|
||||
for item in audio_aftercnnlens:
|
||||
while item > self.n_window:
|
||||
output_list.append(self.n_window)
|
||||
item -= self.n_window
|
||||
output_list.append(item)
|
||||
|
||||
cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
|
||||
cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)
|
||||
|
||||
layer_id = 0
|
||||
|
||||
for block in self.blocks:
|
||||
layer_id+=1
|
||||
|
||||
x = block(x, cu_seqlens=cu_seqlens)
|
||||
|
||||
if self.audio_vq_layers == layer_id: # vq inside encoder
|
||||
x, indices, vq_stats = self._do_quantize(x, pe_for_vq)
|
||||
if return_indices:
|
||||
return x, indices
|
||||
|
||||
if self.avg_pooler:
|
||||
x_list = x.split(audio_aftercnnlens, dim=0)
|
||||
token_x_list = []
|
||||
for x in x_list:
|
||||
x = x.permute(1, 0)
|
||||
x = self.avg_pooler(x)
|
||||
x = x.permute(1, 0)
|
||||
token_x_list.append(x)
|
||||
x = torch.cat(token_x_list, dim=0)
|
||||
|
||||
x = self.ln_post(x)
|
||||
|
||||
x = self.proj(x)
|
||||
|
||||
output = torch.zeros(
|
||||
(x.size(0) + len(audio_seqlens) * 2, x.size(1)),
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
|
||||
audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
|
||||
start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
|
||||
end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1
|
||||
|
||||
audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
|
||||
audio_tokens_mask[start_ids] = False
|
||||
audio_tokens_mask[end_ids] = False
|
||||
output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
|
||||
output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
|
||||
output[audio_tokens_mask] = x
|
||||
|
||||
if self.audio_vq_type != "NULL":
|
||||
return output, vq_stats
|
||||
return output
|
||||
@@ -0,0 +1,406 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import os
|
||||
import math
|
||||
import torch
|
||||
import operator
|
||||
|
||||
import numpy as np
|
||||
import torch.nn.functional as F
|
||||
|
||||
from functools import lru_cache
|
||||
from typing import Optional, Union, List
|
||||
from torch import nn, Tensor
|
||||
from itertools import accumulate
|
||||
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_varlen_func as flash_attn_varlen_func
|
||||
except ImportError:
|
||||
try:
|
||||
from flash_attn.flash_attn_interface import flash_attn_unpadded_func as flash_attn_varlen_func
|
||||
except ImportError:
|
||||
print("\n********\nWarning: flash-attn is not installed. Will only run the manual PyTorch version. Please install flash-attn for faster inference.\n********\n ")
|
||||
flash_attn_varlen_func = None
|
||||
|
||||
|
||||
N_FFT = 400
|
||||
HOP_LENGTH = 160
|
||||
|
||||
|
||||
@lru_cache(maxsize=None)
|
||||
def mel_filters(device, n_mels: int) -> torch.Tensor:
|
||||
"""
|
||||
load the mel filterbank matrix for projecting STFT into a Mel spectrogram.
|
||||
Allows decoupling librosa dependency; saved using:
|
||||
|
||||
np.savez_compressed(
|
||||
"mel_filters.npz",
|
||||
mel_80=librosa.filters.mel(sr=16000, n_fft=400, n_mels=80),
|
||||
mel_128=librosa.filters.mel(sr=16000, n_fft=400, n_mels=128),
|
||||
)
|
||||
"""
|
||||
assert n_mels in {80, 128}, f"Unsupported n_mels: {n_mels}"
|
||||
|
||||
filters_path = os.path.join(os.path.dirname(__file__), "assets", "mel_filters.npz")
|
||||
with np.load(filters_path, allow_pickle=False) as f:
|
||||
return torch.from_numpy(f[f"mel_{n_mels}"]).to(device)
|
||||
|
||||
|
||||
def log_mel_spectrogram(
|
||||
audio: Union[str, np.ndarray, torch.Tensor],
|
||||
n_mels: int = 80,
|
||||
padding: int = 0,
|
||||
device: Optional[Union[str, torch.device]] = None,
|
||||
):
|
||||
"""
|
||||
Compute the log-Mel spectrogram of
|
||||
|
||||
Parameters
|
||||
----------
|
||||
audio: Union[str, np.ndarray, torch.Tensor], shape = (*)
|
||||
The path to audio or either a NumPy array or Tensor containing the audio waveform in 16 kHz
|
||||
|
||||
n_mels: int
|
||||
The number of Mel-frequency filters, only 80 is supported
|
||||
|
||||
padding: int
|
||||
Number of zero samples to pad to the right
|
||||
|
||||
device: Optional[Union[str, torch.device]]
|
||||
If given, the audio tensor is moved to this device before STFT
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor, shape = (80, n_frames)
|
||||
A Tensor that contains the Mel spectrogram
|
||||
"""
|
||||
if not torch.is_tensor(audio):
|
||||
audio = torch.from_numpy(audio)
|
||||
|
||||
if device is not None:
|
||||
audio = audio.to(device)
|
||||
if padding > 0:
|
||||
audio = F.pad(audio, (0, padding))
|
||||
window = torch.hann_window(N_FFT).to(audio.device)
|
||||
stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
|
||||
magnitudes = stft[..., :-1].abs() ** 2
|
||||
|
||||
filters = mel_filters(audio.device, n_mels)
|
||||
mel_spec = filters @ magnitudes
|
||||
|
||||
log_spec = torch.clamp(mel_spec, min=1e-10).log10()
|
||||
log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
|
||||
log_spec = (log_spec + 4.0) / 4.0
|
||||
return log_spec
|
||||
|
||||
|
||||
def get_T_after_cnn(L_in, dilation=1):
|
||||
for (padding, kernel_size, stride) in eval("[(1,3,1)] + [(1,3,2)] "):
|
||||
L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
|
||||
L_out = 1 + L_out // stride
|
||||
L_in = L_out
|
||||
return L_out
|
||||
|
||||
|
||||
def get_mel_audio(audio, padding=False, audio_vq_ds_rate = 1, n_mels = 128):
|
||||
audio_len = len(audio)
|
||||
if padding:
|
||||
reduction = 160 * 2 * audio_vq_ds_rate
|
||||
audio_pad = math.ceil(audio_len / reduction) * reduction - audio_len
|
||||
mel = log_mel_spectrogram(audio, n_mels=n_mels, padding=audio_pad)
|
||||
else:
|
||||
mel = log_mel_spectrogram(audio, n_mels=n_mels) # [F,T]
|
||||
return mel
|
||||
|
||||
|
||||
def sinusoids(length, channels, max_timescale=10000):
|
||||
"""Returns sinusoids for positional embedding"""
|
||||
assert channels % 2 == 0
|
||||
log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
|
||||
inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
|
||||
scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
|
||||
return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
|
||||
|
||||
|
||||
class Conv1d(nn.Conv1d):
|
||||
def _conv_forward(
|
||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
class ConvTranspose1d(nn.ConvTranspose1d):
|
||||
def _conv_forward(
|
||||
self, x: Tensor, weight: Tensor, bias: Optional[Tensor]
|
||||
) -> Tensor:
|
||||
return super()._conv_forward(
|
||||
x, weight.to(x.dtype), None if bias is None else bias.to(x.dtype)
|
||||
)
|
||||
|
||||
|
||||
class Linear(nn.Linear):
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return F.linear(x, self.weight.to(x.dtype), None if self.bias is None else self.bias.to(x.dtype) )
|
||||
|
||||
|
||||
class MultiHeadAttention(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int):
|
||||
super().__init__()
|
||||
self.n_head = n_head
|
||||
self.query = Linear(n_state, n_state)
|
||||
self.key = Linear(n_state, n_state, bias=False)
|
||||
self.value = Linear(n_state, n_state)
|
||||
self.out = Linear(n_state, n_state)
|
||||
|
||||
self.use_flash_attention = True
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
cu_seqlens = None,
|
||||
):
|
||||
q = self.query(x)
|
||||
k = self.key(x)
|
||||
v = self.value(x)
|
||||
|
||||
if self.use_flash_attention:
|
||||
if flash_attn_varlen_func is None:
|
||||
x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
|
||||
else:
|
||||
if q.dtype not in [torch.float16, torch.bfloat16]:
|
||||
x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
|
||||
self.use_flash_attention = False
|
||||
else:
|
||||
x = self.qkv_flash_attention(q, k, v, cu_seqlens=cu_seqlens)
|
||||
else:
|
||||
x = self.qkv_attention_manual(q, k, v, cu_seqlens=cu_seqlens)
|
||||
|
||||
output = self.out(x)
|
||||
return output
|
||||
|
||||
def qkv_flash_attention(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens=None
|
||||
):
|
||||
n_ctx, n_state = q.shape
|
||||
# scale = (n_state // self.n_head) ** -0.25
|
||||
q = q.view(n_ctx, self.n_head, -1)# (batch_size, seqlen, nheads, headdim)
|
||||
k = k.view(n_ctx, self.n_head, -1)
|
||||
v = v.view(n_ctx, self.n_head, -1)
|
||||
|
||||
max_seqlen = (cu_seqlens[1:] - cu_seqlens[:-1]).max().item()
|
||||
|
||||
|
||||
x = flash_attn_varlen_func(
|
||||
q, k, v, cu_seqlens, cu_seqlens, max_seqlen, max_seqlen, dropout_p=0.0
|
||||
)
|
||||
x = x.reshape(n_ctx, n_state)
|
||||
return x
|
||||
|
||||
def qkv_attention_manual(
|
||||
self, q: Tensor, k: Tensor, v: Tensor, cu_seqlens: Tensor
|
||||
):
|
||||
n_ctx, n_state = q.shape
|
||||
head_dim = n_state // self.n_head
|
||||
scale = head_dim ** -0.5
|
||||
|
||||
q = q.view(n_ctx, self.n_head, head_dim)
|
||||
k = k.view(n_ctx, self.n_head, head_dim)
|
||||
v = v.view(n_ctx, self.n_head, head_dim)
|
||||
|
||||
seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist()
|
||||
batch_size = len(seqlens)
|
||||
max_seqlen = max(seqlens)
|
||||
|
||||
q_padded = torch.zeros(batch_size, max_seqlen, self.n_head, head_dim, dtype=q.dtype, device=q.device)
|
||||
k_padded = torch.zeros_like(q_padded)
|
||||
v_padded = torch.zeros_like(q_padded)
|
||||
|
||||
for i in range(batch_size):
|
||||
start_idx = cu_seqlens[i]
|
||||
end_idx = cu_seqlens[i+1]
|
||||
seq_len = seqlens[i]
|
||||
q_padded[i, :seq_len] = q[start_idx:end_idx]
|
||||
k_padded[i, :seq_len] = k[start_idx:end_idx]
|
||||
v_padded[i, :seq_len] = v[start_idx:end_idx]
|
||||
|
||||
q_padded = q_padded.transpose(1, 2)
|
||||
k_padded = k_padded.transpose(1, 2)
|
||||
v_padded = v_padded.transpose(1, 2)
|
||||
|
||||
attn_mask = torch.arange(max_seqlen, device=q.device)[None, :] < torch.tensor(seqlens, device=q.device)[:, None]
|
||||
attn_mask = attn_mask.unsqueeze(1).unsqueeze(2)
|
||||
|
||||
attn_mask = attn_mask.masked_fill(attn_mask == 0, -torch.finfo(q.dtype).max)
|
||||
|
||||
attn_scores = torch.matmul(q_padded, k_padded.transpose(-2, -1)) * scale
|
||||
attn_scores = attn_scores + attn_mask
|
||||
attn_weights = F.softmax(attn_scores, dim=-1)
|
||||
|
||||
context = torch.matmul(attn_weights, v_padded)
|
||||
|
||||
context = context.transpose(1, 2).contiguous().view(batch_size, max_seqlen, n_state)
|
||||
|
||||
output_packed = torch.cat([context[i, :seqlens[i]] for i in range(batch_size)], dim=0)
|
||||
|
||||
assert output_packed.shape == (n_ctx, n_state)
|
||||
|
||||
return output_packed
|
||||
|
||||
|
||||
class ResidualAttentionBlock(nn.Module):
|
||||
def __init__(self, n_state: int, n_head: int,
|
||||
enable_mp: bool = False, sequence_parallel: bool = False):
|
||||
super().__init__()
|
||||
n_mlp = n_state * 4
|
||||
self.attn_ln = nn.LayerNorm(n_state)
|
||||
self.mlp_ln = nn.LayerNorm(n_state)
|
||||
|
||||
self.attn = MultiHeadAttention(n_state, n_head)
|
||||
self.mlp = nn.Sequential(
|
||||
Linear(n_state, n_mlp), nn.GELU(), Linear(n_mlp, n_state)
|
||||
)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
x: Tensor,
|
||||
cu_seqlens = None
|
||||
):
|
||||
x = x + self.attn(self.attn_ln(x), cu_seqlens=cu_seqlens)
|
||||
x = x + self.mlp(self.mlp_ln(x))
|
||||
return x
|
||||
|
||||
|
||||
class WhisperEncoder(nn.Module):
|
||||
def __init__(
|
||||
self,
|
||||
n_mels: int,
|
||||
n_ctx: int,
|
||||
n_state: int,
|
||||
n_head: int,
|
||||
n_layer: int,
|
||||
n_window: int = 1500,
|
||||
output_dim: int = 512,
|
||||
grad_checkpointing: bool = False,
|
||||
enable_mp: bool = False,
|
||||
audio_sequence_parallel: bool = False,
|
||||
):
|
||||
super().__init__()
|
||||
self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
|
||||
self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
|
||||
self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
|
||||
self.n_layer = n_layer
|
||||
self.n_mels = n_mels
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[ResidualAttentionBlock(n_state, n_head, enable_mp=enable_mp, sequence_parallel=audio_sequence_parallel)
|
||||
for _ in range(n_layer)]
|
||||
)
|
||||
self.ln_post = nn.LayerNorm(n_state)
|
||||
self.avg_pooler = nn.AvgPool1d(2, stride=2)
|
||||
|
||||
self.proj = torch.nn.Linear(n_state, output_dim)
|
||||
|
||||
self.audio_bos_eos_token = nn.Embedding(2, output_dim)
|
||||
|
||||
self.output_dim = output_dim
|
||||
self.grad_checkpointing = grad_checkpointing
|
||||
self.enable_mp = enable_mp
|
||||
self.n_head = n_head
|
||||
self.n_state = n_state
|
||||
self.n_window = n_window
|
||||
|
||||
self.audio_sequence_parallel = audio_sequence_parallel
|
||||
|
||||
self.tp_world_size = 1
|
||||
|
||||
self.set_audio_sync()
|
||||
|
||||
def set_audio_sync(self):
|
||||
for name, param in self.named_parameters():
|
||||
if not name.startswith("blocks"):
|
||||
setattr(param, "audio_sync", True)
|
||||
|
||||
def forward(self, x_list: List[Tensor], audio_mellens:List[int], audio_aftercnnlens:List[int], audio_seqlens:List[int]):
|
||||
"""
|
||||
x : torch.Tensor, shape = (n_mels, n_ctx)
|
||||
the mel spectrogram of the audio
|
||||
"""
|
||||
|
||||
aftercnn_x_list = []
|
||||
for each_x in x_list:
|
||||
each_x_split_list = each_x.split(self.n_window * 2, dim=1)
|
||||
for each_x_split in each_x_split_list:
|
||||
each_x_split = F.gelu(self.conv1(each_x_split))
|
||||
each_x_split = F.gelu(self.conv2(each_x_split))
|
||||
each_x_split = each_x_split.permute(1, 0) # L,D
|
||||
each_positional_embedding_split = self.positional_embedding[:each_x_split.shape[0]]
|
||||
aftercnn_x_list.append(each_x_split+each_positional_embedding_split.to(each_x_split.dtype))
|
||||
|
||||
x = torch.cat(aftercnn_x_list, dim=0)
|
||||
src_len = x.size(0)
|
||||
|
||||
output_list = []
|
||||
for item in audio_aftercnnlens:
|
||||
while item > self.n_window:
|
||||
output_list.append(self.n_window)
|
||||
item -= self.n_window
|
||||
output_list.append(item)
|
||||
|
||||
cu_seqlens = list(accumulate(output_list, func=operator.add,initial=0))
|
||||
cu_seqlens = torch.Tensor(cu_seqlens).to(device=x.device, dtype=torch.int32)
|
||||
|
||||
layer_id = 0
|
||||
for block in self.blocks:
|
||||
layer_id+=1
|
||||
x = block(x, cu_seqlens=cu_seqlens)
|
||||
|
||||
if self.avg_pooler:
|
||||
x_list = x.split(audio_aftercnnlens, dim=0)
|
||||
token_x_list = []
|
||||
for x in x_list:
|
||||
x = x.permute(1, 0)
|
||||
x = self.avg_pooler(x)
|
||||
x = x.permute(1, 0)
|
||||
token_x_list.append(x)
|
||||
x = torch.cat(token_x_list, dim=0)
|
||||
|
||||
x = self.ln_post(x)
|
||||
x = self.proj(x)
|
||||
|
||||
output = torch.zeros(
|
||||
(x.size(0) + len(audio_seqlens) * 2, x.size(1)),
|
||||
device=x.device, dtype=x.dtype
|
||||
)
|
||||
|
||||
audio_seqlens_acc = list(accumulate(audio_seqlens, func=operator.add, initial=0))
|
||||
start_ids = torch.tensor(audio_seqlens_acc[:-1], device=x.device, dtype=torch.int32)
|
||||
end_ids = torch.tensor(audio_seqlens_acc[1:], device=x.device, dtype=torch.int32) - 1
|
||||
|
||||
audio_tokens_mask = torch.ones(output.size(0), device=x.device, dtype=torch.bool)
|
||||
audio_tokens_mask[start_ids] = False
|
||||
audio_tokens_mask[end_ids] = False
|
||||
output[start_ids] = self.audio_bos_eos_token.weight[0].to(x.dtype)
|
||||
output[end_ids] = self.audio_bos_eos_token.weight[1].to(x.dtype)
|
||||
output[audio_tokens_mask] = x
|
||||
return output
|
||||
|
||||
def lock(self, layers: int):
|
||||
self.conv1.requires_grad_(False)
|
||||
self.conv2.requires_grad_(False)
|
||||
for i in range(min(layers, len(self.blocks))):
|
||||
self.blocks[i].requires_grad_(False)
|
||||
877
models/Qwen3-TTS/qwen_tts/inference/qwen3_tts_model.py
Normal file
877
models/Qwen3-TTS/qwen_tts/inference/qwen3_tts_model.py
Normal file
@@ -0,0 +1,877 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import io
|
||||
import urllib.request
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from transformers import AutoConfig, AutoModel, AutoProcessor
|
||||
|
||||
from ..core.models import Qwen3TTSConfig, Qwen3TTSForConditionalGeneration, Qwen3TTSProcessor
|
||||
|
||||
AudioLike = Union[
|
||||
str, # wav path, URL, base64
|
||||
np.ndarray, # waveform (requires sr)
|
||||
Tuple[np.ndarray, int], # (waveform, sr)
|
||||
]
|
||||
|
||||
MaybeList = Union[Any, List[Any]]
|
||||
|
||||
|
||||
@dataclass
|
||||
class VoiceClonePromptItem:
|
||||
"""
|
||||
Container for one sample's voice-clone prompt information that can be fed to the model.
|
||||
|
||||
Fields are aligned with `Qwen3TTSForConditionalGeneration.generate(..., voice_clone_prompt=...)`.
|
||||
"""
|
||||
ref_code: Optional[torch.Tensor] # (T, Q) or (T,) depending on tokenizer 25Hz/12Hz
|
||||
ref_spk_embedding: torch.Tensor # (D,)
|
||||
x_vector_only_mode: bool
|
||||
icl_mode: bool
|
||||
ref_text: Optional[str] = None
|
||||
|
||||
|
||||
class Qwen3TTSModel:
|
||||
"""
|
||||
A HuggingFace-style wrapper for Qwen3 TTS models (CustomVoice/VoiceDesign/Base) that provides:
|
||||
- from_pretrained() initialization via AutoModel/AutoProcessor
|
||||
- generation APIs for:
|
||||
* CustomVoice: generate_custom_voice()
|
||||
* VoiceDesign: generate_voice_design()
|
||||
* Base: generate_voice_clone() + create_voice_clone_prompt()
|
||||
- consistent output: (wavs: List[np.ndarray], sample_rate: int)
|
||||
|
||||
Notes:
|
||||
- This wrapper expects the underlying model class to be `Qwen3TTSForConditionalGeneration`
|
||||
- Language / speaker validation is done via model methods:
|
||||
model.get_supported_languages(), model.get_supported_speakers()
|
||||
"""
|
||||
|
||||
def __init__(self, model: Qwen3TTSForConditionalGeneration, processor, generate_defaults: Optional[Dict[str, Any]] = None):
|
||||
self.model = model
|
||||
self.processor = processor
|
||||
self.generate_defaults = generate_defaults or {}
|
||||
|
||||
self.device = getattr(model, "device", None)
|
||||
if self.device is None:
|
||||
try:
|
||||
self.device = next(model.parameters()).device
|
||||
except StopIteration:
|
||||
self.device = torch.device("cpu")
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls,
|
||||
pretrained_model_name_or_path: str,
|
||||
**kwargs,
|
||||
) -> "Qwen3TTSModel":
|
||||
"""
|
||||
Load a Qwen3 TTS model and its processor in HuggingFace `from_pretrained` style.
|
||||
|
||||
This method:
|
||||
1) Loads config via AutoConfig (so your side can register model_type -> config/model).
|
||||
2) Loads the model via AutoModel.from_pretrained(...), forwarding `kwargs` unchanged.
|
||||
3) Loads the processor via AutoProcessor.from_pretrained(model_path).
|
||||
4) Loads optional `generate_config.json` from the model directory/repo snapshot if present.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (str):
|
||||
HuggingFace repo id or local directory of the model.
|
||||
**kwargs:
|
||||
Forwarded as-is into `AutoModel.from_pretrained(...)`.
|
||||
Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="flash_attention_2".
|
||||
|
||||
Returns:
|
||||
Qwen3TTSModel:
|
||||
Wrapper instance containing `model`, `processor`, and generation defaults.
|
||||
"""
|
||||
AutoConfig.register("qwen3_tts", Qwen3TTSConfig)
|
||||
AutoModel.register(Qwen3TTSConfig, Qwen3TTSForConditionalGeneration)
|
||||
AutoProcessor.register(Qwen3TTSConfig, Qwen3TTSProcessor)
|
||||
|
||||
model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
if not isinstance(model, Qwen3TTSForConditionalGeneration):
|
||||
raise TypeError(
|
||||
f"AutoModel returned {type(model)}, expected Qwen3TTSForConditionalGeneration. "
|
||||
)
|
||||
|
||||
processor = AutoProcessor.from_pretrained(pretrained_model_name_or_path, fix_mistral_regex=True,)
|
||||
|
||||
generate_defaults = model.generate_config
|
||||
return cls(model=model, processor=processor, generate_defaults=generate_defaults)
|
||||
|
||||
def _supported_languages_set(self) -> Optional[set]:
|
||||
langs = getattr(self.model, "get_supported_languages", None)
|
||||
if callable(langs):
|
||||
v = langs()
|
||||
if v is None:
|
||||
return None
|
||||
return set([str(x).lower() for x in v])
|
||||
return None
|
||||
|
||||
def _supported_speakers_set(self) -> Optional[set]:
|
||||
spks = getattr(self.model, "get_supported_speakers", None)
|
||||
if callable(spks):
|
||||
v = spks()
|
||||
if v is None:
|
||||
return None
|
||||
return set([str(x).lower() for x in v])
|
||||
return None
|
||||
|
||||
def _validate_languages(self, languages: List[str]) -> None:
|
||||
"""
|
||||
Validate that requested languages are supported by the model.
|
||||
|
||||
Args:
|
||||
languages (List[str]): Language names for each sample.
|
||||
|
||||
Raises:
|
||||
ValueError: If any language is not supported.
|
||||
"""
|
||||
supported = self._supported_languages_set()
|
||||
if supported is None:
|
||||
return
|
||||
|
||||
bad = []
|
||||
for lang in languages:
|
||||
if lang is None:
|
||||
bad.append(lang)
|
||||
continue
|
||||
if str(lang).lower() not in supported:
|
||||
bad.append(lang)
|
||||
if bad:
|
||||
raise ValueError(f"Unsupported languages: {bad}. Supported: {sorted(supported)}")
|
||||
|
||||
def _validate_speakers(self, speakers: List[Optional[str]]) -> None:
|
||||
"""
|
||||
Validate that requested speakers are supported by the Instruct model.
|
||||
|
||||
Args:
|
||||
speakers (List[Optional[str]]): Speaker names for each sample.
|
||||
|
||||
Raises:
|
||||
ValueError: If any speaker is not supported.
|
||||
"""
|
||||
supported = self._supported_speakers_set()
|
||||
if supported is None:
|
||||
return
|
||||
|
||||
bad = []
|
||||
for spk in speakers:
|
||||
if spk is None or spk == "":
|
||||
continue
|
||||
if str(spk).lower() not in supported:
|
||||
bad.append(spk)
|
||||
if bad:
|
||||
raise ValueError(f"Unsupported speakers: {bad}. Supported: {sorted(supported)}")
|
||||
|
||||
def _is_probably_base64(self, s: str) -> bool:
|
||||
if s.startswith("data:audio"):
|
||||
return True
|
||||
if ("/" not in s and "\\" not in s) and len(s) > 256:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_url(self, s: str) -> bool:
|
||||
try:
|
||||
u = urlparse(s)
|
||||
return u.scheme in ("http", "https") and bool(u.netloc)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
|
||||
if "," in b64 and b64.strip().startswith("data:"):
|
||||
b64 = b64.split(",", 1)[1]
|
||||
return base64.b64decode(b64)
|
||||
|
||||
def _load_audio_to_np(self, x: str) -> Tuple[np.ndarray, int]:
|
||||
if self._is_url(x):
|
||||
with urllib.request.urlopen(x) as resp:
|
||||
audio_bytes = resp.read()
|
||||
with io.BytesIO(audio_bytes) as f:
|
||||
audio, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
elif self._is_probably_base64(x):
|
||||
wav_bytes = self._decode_base64_to_wav_bytes(x)
|
||||
with io.BytesIO(wav_bytes) as f:
|
||||
audio, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
else:
|
||||
audio, sr = librosa.load(x, sr=None, mono=True)
|
||||
|
||||
if audio.ndim > 1:
|
||||
audio = np.mean(audio, axis=-1)
|
||||
|
||||
return audio.astype(np.float32), int(sr)
|
||||
|
||||
def _normalize_audio_inputs(self, audios: Union[AudioLike, List[AudioLike]]) -> List[Tuple[np.ndarray, int]]:
|
||||
"""
|
||||
Normalize audio inputs into a list of (waveform, sr).
|
||||
|
||||
Supported forms:
|
||||
- str: wav path / URL / base64 audio string
|
||||
- (np.ndarray, sr): waveform + sampling rate
|
||||
- list of the above
|
||||
|
||||
Args:
|
||||
audios:
|
||||
Audio input(s).
|
||||
|
||||
Returns:
|
||||
List[Tuple[np.ndarray, int]]:
|
||||
List of (float32 waveform, original sr).
|
||||
|
||||
Raises:
|
||||
ValueError: If a numpy waveform is provided without sr.
|
||||
"""
|
||||
if isinstance(audios, list):
|
||||
items = audios
|
||||
else:
|
||||
items = [audios]
|
||||
|
||||
out: List[Tuple[np.ndarray, int]] = []
|
||||
for a in items:
|
||||
if isinstance(a, str):
|
||||
out.append(self._load_audio_to_np(a))
|
||||
elif isinstance(a, tuple) and len(a) == 2 and isinstance(a[0], np.ndarray):
|
||||
out.append((a[0].astype(np.float32), int(a[1])))
|
||||
elif isinstance(a, np.ndarray):
|
||||
raise ValueError("For numpy waveform input, pass a tuple (audio, sr).")
|
||||
else:
|
||||
raise TypeError(f"Unsupported audio input type: {type(a)}")
|
||||
for i, a in enumerate(out):
|
||||
if a[0].ndim > 1:
|
||||
a[0] = np.mean(a[0], axis=-1).astype(np.float32)
|
||||
out[i] = (a[0], a[1])
|
||||
return out
|
||||
|
||||
def _ensure_list(self, x: MaybeList) -> List[Any]:
|
||||
return x if isinstance(x, list) else [x]
|
||||
|
||||
def _build_assistant_text(self, text: str) -> str:
|
||||
return f"<|im_start|>assistant\n{text}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def _build_ref_text(self, text: str) -> str:
|
||||
return f"<|im_start|>assistant\n{text}<|im_end|>\n"
|
||||
|
||||
def _build_instruct_text(self, instruct: str) -> str:
|
||||
return f"<|im_start|>user\n{instruct}<|im_end|>\n"
|
||||
|
||||
def _tokenize_texts(self, texts: List[str]) -> List[torch.Tensor]:
|
||||
input_ids = []
|
||||
for text in texts:
|
||||
input = self.processor(text=text, return_tensors="pt", padding=True)
|
||||
input_id = input["input_ids"].to(self.device)
|
||||
input_id = input_id.unsqueeze(0) if input_id.dim() == 1 else input_id
|
||||
input_ids.append(input_id)
|
||||
return input_ids
|
||||
|
||||
def _merge_generate_kwargs(
|
||||
self,
|
||||
do_sample: Optional[bool] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
repetition_penalty: Optional[float] = None,
|
||||
subtalker_dosample: Optional[bool] = None,
|
||||
subtalker_top_k: Optional[int] = None,
|
||||
subtalker_top_p: Optional[float] = None,
|
||||
subtalker_temperature: Optional[float] = None,
|
||||
max_new_tokens: Optional[int] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Merge user-provided generation arguments with defaults from `generate_config.json`.
|
||||
|
||||
Rule:
|
||||
- If the user explicitly passes a value (not None), use it.
|
||||
- Otherwise, use the value from generate_config.json if present.
|
||||
- Otherwise, fall back to the hard defaults.
|
||||
|
||||
Args:
|
||||
do_sample, top_k, top_p, temperature, repetition_penalty,
|
||||
subtalker_dosample, subtalker_top_k, subtalker_top_p, subtalker_temperature, max_new_tokens:
|
||||
Common generation parameters.
|
||||
**kwargs:
|
||||
Other arguments forwarded to model.generate().
|
||||
|
||||
Returns:
|
||||
Dict[str, Any]: Final kwargs to pass into model.generate().
|
||||
"""
|
||||
hard_defaults = dict(
|
||||
do_sample=True,
|
||||
top_k=50,
|
||||
top_p=1.0,
|
||||
temperature=0.9,
|
||||
repetition_penalty=1.05,
|
||||
subtalker_dosample=True,
|
||||
subtalker_top_k=50,
|
||||
subtalker_top_p=1.0,
|
||||
subtalker_temperature=0.9,
|
||||
max_new_tokens=2048,
|
||||
)
|
||||
|
||||
def pick(name: str, user_val: Any) -> Any:
|
||||
if user_val is not None:
|
||||
return user_val
|
||||
if name in self.generate_defaults:
|
||||
return self.generate_defaults[name]
|
||||
return hard_defaults[name]
|
||||
|
||||
merged = dict(kwargs)
|
||||
merged.update(
|
||||
do_sample=pick("do_sample", do_sample),
|
||||
top_k=pick("top_k", top_k),
|
||||
top_p=pick("top_p", top_p),
|
||||
temperature=pick("temperature", temperature),
|
||||
repetition_penalty=pick("repetition_penalty", repetition_penalty),
|
||||
subtalker_dosample=pick("subtalker_dosample", subtalker_dosample),
|
||||
subtalker_top_k=pick("subtalker_top_k", subtalker_top_k),
|
||||
subtalker_top_p=pick("subtalker_top_p", subtalker_top_p),
|
||||
subtalker_temperature=pick("subtalker_temperature", subtalker_temperature),
|
||||
max_new_tokens=pick("max_new_tokens", max_new_tokens),
|
||||
)
|
||||
return merged
|
||||
|
||||
# voice clone model
|
||||
@torch.inference_mode()
|
||||
def create_voice_clone_prompt(
|
||||
self,
|
||||
ref_audio: Union[AudioLike, List[AudioLike]],
|
||||
ref_text: Optional[Union[str, List[Optional[str]]]] = None,
|
||||
x_vector_only_mode: Union[bool, List[bool]] = False,
|
||||
) -> List[VoiceClonePromptItem]:
|
||||
"""
|
||||
Build voice-clone prompt items from reference audio (and optionally reference text) using Base model.
|
||||
|
||||
Modes:
|
||||
- x_vector_only_mode=True:
|
||||
Only speaker embedding is used to clone voice; ref_text/ref_code are ignored.
|
||||
This is mutually exclusive with ICL.
|
||||
- x_vector_only_mode=False:
|
||||
ICL mode is enabled automatically (icl_mode=True). In this case ref_text is required,
|
||||
because the model continues/conditions on the reference text + reference speech codes.
|
||||
|
||||
Batch behavior:
|
||||
- ref_audio can be a single item or a list.
|
||||
- ref_text and x_vector_only_mode can be scalars or lists.
|
||||
- If any of them are lists with length > 1, lengths must match.
|
||||
|
||||
Audio input:
|
||||
- str: local wav path / URL / base64
|
||||
- (np.ndarray, sr): waveform + sampling rate
|
||||
|
||||
Args:
|
||||
ref_audio:
|
||||
Reference audio(s) used to extract:
|
||||
- ref_code via `model.speech_tokenizer.encode(...)`
|
||||
- ref_spk_embedding via `model.extract_speaker_embedding(...)` (resampled to 24k)
|
||||
ref_text:
|
||||
Reference transcript(s). Required when x_vector_only_mode=False (ICL mode).
|
||||
x_vector_only_mode:
|
||||
Whether to use speaker embedding only. If False, ICL mode will be used.
|
||||
|
||||
Returns:
|
||||
List[VoiceClonePromptItem]:
|
||||
List of prompt items that can be converted into `voice_clone_prompt` dict.
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
- If x_vector_only_mode=False but ref_text is missing.
|
||||
- If batch lengths mismatch.
|
||||
"""
|
||||
if self.model.tts_model_type != "base":
|
||||
raise ValueError(
|
||||
f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
|
||||
f"tts_model_size: {self.model.tts_model_size}\n"
|
||||
f"tts_model_type: {self.model.tts_model_type}\n"
|
||||
"does not support create_voice_clone_prompt, Please check Model Card or Readme for more details."
|
||||
)
|
||||
|
||||
ref_audio_list = self._ensure_list(ref_audio)
|
||||
ref_text_list = self._ensure_list(ref_text) if isinstance(ref_text, list) else ([ref_text] * len(ref_audio_list))
|
||||
xvec_list = self._ensure_list(x_vector_only_mode) if isinstance(x_vector_only_mode, list) else ([x_vector_only_mode] * len(ref_audio_list))
|
||||
|
||||
if len(ref_text_list) != len(ref_audio_list) or len(xvec_list) != len(ref_audio_list):
|
||||
raise ValueError(
|
||||
f"Batch size mismatch: ref_audio={len(ref_audio_list)}, ref_text={len(ref_text_list)}, x_vector_only_mode={len(xvec_list)}"
|
||||
)
|
||||
|
||||
normalized = self._normalize_audio_inputs(ref_audio_list)
|
||||
|
||||
ref_wavs_for_code: List[np.ndarray] = []
|
||||
ref_sr_for_code: List[int] = []
|
||||
for wav, sr in normalized:
|
||||
ref_wavs_for_code.append(wav)
|
||||
ref_sr_for_code.append(sr)
|
||||
|
||||
if len(set(ref_sr_for_code)) == 1:
|
||||
enc = self.model.speech_tokenizer.encode(ref_wavs_for_code, sr=ref_sr_for_code[0])
|
||||
ref_codes = enc.audio_codes
|
||||
else:
|
||||
ref_codes = []
|
||||
for wav, sr in normalized:
|
||||
ref_codes.append(self.model.speech_tokenizer.encode(wav, sr=sr).audio_codes[0])
|
||||
|
||||
items: List[VoiceClonePromptItem] = []
|
||||
for i, ((wav, sr), code, rtext, xvec_only) in enumerate(zip(normalized, ref_codes, ref_text_list, xvec_list)):
|
||||
if not xvec_only:
|
||||
if rtext is None or rtext == "":
|
||||
raise ValueError(f"ref_text is required when x_vector_only_mode=False (ICL mode). Bad index={i}")
|
||||
|
||||
wav_resample = wav
|
||||
if sr != self.model.speaker_encoder_sample_rate:
|
||||
wav_resample = librosa.resample(y=wav_resample.astype(np.float32),
|
||||
orig_sr=int(sr),
|
||||
target_sr=self.model.speaker_encoder_sample_rate)
|
||||
|
||||
spk_emb = self.model.extract_speaker_embedding(audio=wav_resample,
|
||||
sr=self.model.speaker_encoder_sample_rate)
|
||||
|
||||
items.append(
|
||||
VoiceClonePromptItem(
|
||||
ref_code=None if xvec_only else code,
|
||||
ref_spk_embedding=spk_emb,
|
||||
x_vector_only_mode=bool(xvec_only),
|
||||
icl_mode=bool(not xvec_only),
|
||||
ref_text=rtext,
|
||||
)
|
||||
)
|
||||
return items
|
||||
|
||||
def _prompt_items_to_voice_clone_prompt(self, items: List[VoiceClonePromptItem]) -> Dict[str, Any]:
|
||||
return dict(
|
||||
ref_code=[it.ref_code for it in items],
|
||||
ref_spk_embedding=[it.ref_spk_embedding for it in items],
|
||||
x_vector_only_mode=[it.x_vector_only_mode for it in items],
|
||||
icl_mode=[it.icl_mode for it in items],
|
||||
)
|
||||
|
||||
# voice clone model
|
||||
@torch.no_grad()
|
||||
def generate_voice_clone(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
language: Union[str, List[str]] = None,
|
||||
ref_audio: Optional[Union[AudioLike, List[AudioLike]]] = None,
|
||||
ref_text: Optional[Union[str, List[Optional[str]]]] = None,
|
||||
x_vector_only_mode: Union[bool, List[bool]] = False,
|
||||
voice_clone_prompt: Optional[Union[Dict[str, Any], List[VoiceClonePromptItem]]] = None,
|
||||
non_streaming_mode: bool = False,
|
||||
**kwargs,
|
||||
) -> Tuple[List[np.ndarray], int]:
|
||||
"""
|
||||
Voice clone speech using the Base model.
|
||||
|
||||
You can provide either:
|
||||
- (ref_audio, ref_text, x_vector_only_mode) and let this method build the prompt, OR
|
||||
- `VoiceClonePromptItem` returned by `create_voice_clone_prompt`, OR
|
||||
- a list of `VoiceClonePromptItem` returned by `create_voice_clone_prompt`.
|
||||
|
||||
`ref_audio` Supported forms:
|
||||
- str: wav path / URL / base64 audio string
|
||||
- (np.ndarray, sr): waveform + sampling rate
|
||||
- list of the above
|
||||
|
||||
Input flexibility:
|
||||
- text/language can be scalar or list.
|
||||
- prompt can be single or batch.
|
||||
- If batch mode (len(text)>1), lengths must match.
|
||||
|
||||
Args:
|
||||
text:
|
||||
Text(s) to synthesize.
|
||||
language:
|
||||
Language(s) for each sample.
|
||||
ref_audio:
|
||||
Reference audio(s) for prompt building. Required if voice_clone_prompt is not provided.
|
||||
ref_text:
|
||||
Reference text(s) used for ICL mode (required when x_vector_only_mode=False).
|
||||
x_vector_only_mode:
|
||||
If True, only speaker embedding is used (ignores ref_text/ref_code).
|
||||
If False, ICL mode is used automatically.
|
||||
voice_clone_prompt:
|
||||
list[VoiceClonePromptItem] from `create_voice_clone_prompt`.
|
||||
non_streaming_mode:
|
||||
Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
|
||||
rather than enabling true streaming input or streaming generation.
|
||||
do_sample:
|
||||
Whether to use sampling, recommended to be set to `true` for most use cases.
|
||||
top_k:
|
||||
Top-k sampling parameter.
|
||||
top_p:
|
||||
Top-p sampling parameter.
|
||||
temperature:
|
||||
Sampling temperature; higher => more random.
|
||||
repetition_penalty:
|
||||
Penalty to reduce repeated tokens/codes.
|
||||
subtalker_dosample:
|
||||
Sampling switch for the sub-talker (only valid for qwen3-tts-tokenizer-v2) if applicable.
|
||||
subtalker_top_k:
|
||||
Top-k for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
subtalker_top_p:
|
||||
Top-p for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
subtalker_temperature:
|
||||
Temperature for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
max_new_tokens:
|
||||
Maximum number of new codec tokens to generate.
|
||||
**kwargs:
|
||||
Any other keyword arguments supported by HuggingFace Transformers `generate()` can be passed.
|
||||
They will be forwarded to the underlying `Qwen3TTSForConditionalGeneration.generate(...)`.
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], int]:
|
||||
(wavs, sample_rate)
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If batch sizes mismatch or required prompt inputs are missing.
|
||||
"""
|
||||
if self.model.tts_model_type != "base":
|
||||
raise ValueError(
|
||||
f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
|
||||
f"tts_model_size: {self.model.tts_model_size}\n"
|
||||
f"tts_model_type: {self.model.tts_model_type}\n"
|
||||
"does not support generate_voice_clone, Please check Model Card or Readme for more details."
|
||||
)
|
||||
|
||||
texts = self._ensure_list(text)
|
||||
languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
|
||||
if len(languages) == 1 and len(texts) > 1:
|
||||
languages = languages * len(texts)
|
||||
if len(texts) != len(languages):
|
||||
raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}")
|
||||
|
||||
self._validate_languages(languages)
|
||||
|
||||
if voice_clone_prompt is None:
|
||||
if ref_audio is None:
|
||||
raise ValueError("Either `voice_clone_prompt` or `ref_audio` must be provided.")
|
||||
prompt_items = self.create_voice_clone_prompt(ref_audio=ref_audio, ref_text=ref_text, x_vector_only_mode=x_vector_only_mode)
|
||||
if len(prompt_items) == 1 and len(texts) > 1:
|
||||
prompt_items = prompt_items * len(texts)
|
||||
if len(prompt_items) != len(texts):
|
||||
raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
|
||||
voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
|
||||
ref_texts_for_ids = [it.ref_text for it in prompt_items]
|
||||
else:
|
||||
if isinstance(voice_clone_prompt, list):
|
||||
prompt_items = voice_clone_prompt
|
||||
if len(prompt_items) == 1 and len(texts) > 1:
|
||||
prompt_items = prompt_items * len(texts)
|
||||
if len(prompt_items) != len(texts):
|
||||
raise ValueError(f"Batch size mismatch: prompt={len(prompt_items)}, text={len(texts)}")
|
||||
voice_clone_prompt_dict = self._prompt_items_to_voice_clone_prompt(prompt_items)
|
||||
ref_texts_for_ids = [it.ref_text for it in prompt_items]
|
||||
else:
|
||||
voice_clone_prompt_dict = voice_clone_prompt
|
||||
ref_texts_for_ids = None
|
||||
|
||||
input_texts = [self._build_assistant_text(t) for t in texts]
|
||||
input_ids = self._tokenize_texts(input_texts)
|
||||
|
||||
ref_ids = None
|
||||
if ref_texts_for_ids is not None:
|
||||
ref_ids = []
|
||||
for i, rt in enumerate(ref_texts_for_ids):
|
||||
if rt is None or rt == "":
|
||||
ref_ids.append(None)
|
||||
else:
|
||||
ref_tok = self._tokenize_texts([self._build_ref_text(rt)])[0]
|
||||
ref_ids.append(ref_tok)
|
||||
|
||||
gen_kwargs = self._merge_generate_kwargs(**kwargs)
|
||||
|
||||
talker_codes_list, _ = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
ref_ids=ref_ids,
|
||||
voice_clone_prompt=voice_clone_prompt_dict,
|
||||
languages=languages,
|
||||
non_streaming_mode=non_streaming_mode,
|
||||
**gen_kwargs,
|
||||
)
|
||||
|
||||
codes_for_decode = []
|
||||
for i, codes in enumerate(talker_codes_list):
|
||||
ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
|
||||
if ref_code_list is not None and ref_code_list[i] is not None:
|
||||
codes_for_decode.append(torch.cat([ref_code_list[i].to(codes.device), codes], dim=0))
|
||||
else:
|
||||
codes_for_decode.append(codes)
|
||||
|
||||
wavs_all, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in codes_for_decode])
|
||||
|
||||
wavs_out: List[np.ndarray] = []
|
||||
for i, wav in enumerate(wavs_all):
|
||||
ref_code_list = voice_clone_prompt_dict.get("ref_code", None)
|
||||
if ref_code_list is not None and ref_code_list[i] is not None:
|
||||
ref_len = int(ref_code_list[i].shape[0])
|
||||
total_len = int(codes_for_decode[i].shape[0])
|
||||
cut = int(ref_len / max(total_len, 1) * wav.shape[0])
|
||||
wavs_out.append(wav[cut:])
|
||||
else:
|
||||
wavs_out.append(wav)
|
||||
|
||||
return wavs_out, fs
|
||||
|
||||
# voice design model
|
||||
@torch.no_grad()
|
||||
def generate_voice_design(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
instruct: Union[str, List[str]],
|
||||
language: Union[str, List[str]] = None,
|
||||
non_streaming_mode: bool = True,
|
||||
**kwargs,
|
||||
) -> Tuple[List[np.ndarray], int]:
|
||||
"""
|
||||
Generate speech with the VoiceDesign model using natural-language style instructions.
|
||||
|
||||
Args:
|
||||
text:
|
||||
Text(s) to synthesize.
|
||||
language:
|
||||
Language(s) for each sample.
|
||||
instruct:
|
||||
Instruction(s) describing desired voice/style. Empty string is allowed (treated as no instruction).
|
||||
non_streaming_mode:
|
||||
Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
|
||||
rather than enabling true streaming input or streaming generation.
|
||||
do_sample:
|
||||
Whether to use sampling, recommended to be set to `true` for most use cases.
|
||||
top_k:
|
||||
Top-k sampling parameter.
|
||||
top_p:
|
||||
Top-p sampling parameter.
|
||||
temperature:
|
||||
Sampling temperature; higher => more random.
|
||||
repetition_penalty:
|
||||
Penalty to reduce repeated tokens/codes.
|
||||
subtalker_dosample:
|
||||
Sampling switch for the sub-talker (only valid for qwen3-tts-tokenizer-v2) if applicable.
|
||||
subtalker_top_k:
|
||||
Top-k for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
subtalker_top_p:
|
||||
Top-p for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
subtalker_temperature:
|
||||
Temperature for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
max_new_tokens:
|
||||
Maximum number of new codec tokens to generate.
|
||||
**kwargs:
|
||||
Any other keyword arguments supported by HuggingFace Transformers `generate()` can be passed.
|
||||
They will be forwarded to the underlying `Qwen3TTSForConditionalGeneration.generate(...)`.
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], int]:
|
||||
(wavs, sample_rate)
|
||||
"""
|
||||
if self.model.tts_model_type != "voice_design":
|
||||
raise ValueError(
|
||||
f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
|
||||
f"tts_model_size: {self.model.tts_model_size}\n"
|
||||
f"tts_model_type: {self.model.tts_model_type}\n"
|
||||
"does not support generate_voice_design, Please check Model Card or Readme for more details."
|
||||
)
|
||||
|
||||
texts = self._ensure_list(text)
|
||||
languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
|
||||
instructs = self._ensure_list(instruct)
|
||||
|
||||
if len(languages) == 1 and len(texts) > 1:
|
||||
languages = languages * len(texts)
|
||||
if len(instructs) == 1 and len(texts) > 1:
|
||||
instructs = instructs * len(texts)
|
||||
|
||||
if not (len(texts) == len(languages) == len(instructs)):
|
||||
raise ValueError(f"Batch size mismatch: text={len(texts)}, language={len(languages)}, instruct={len(instructs)}")
|
||||
|
||||
self._validate_languages(languages)
|
||||
|
||||
input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])
|
||||
|
||||
instruct_ids: List[Optional[torch.Tensor]] = []
|
||||
for ins in instructs:
|
||||
if ins is None or ins == "":
|
||||
instruct_ids.append(None)
|
||||
else:
|
||||
instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0])
|
||||
|
||||
gen_kwargs = self._merge_generate_kwargs(**kwargs)
|
||||
|
||||
talker_codes_list, _ = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
instruct_ids=instruct_ids,
|
||||
languages=languages,
|
||||
non_streaming_mode=non_streaming_mode,
|
||||
**gen_kwargs,
|
||||
)
|
||||
|
||||
wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
|
||||
return wavs, fs
|
||||
|
||||
# custom voice model
|
||||
@torch.no_grad()
|
||||
def generate_custom_voice(
|
||||
self,
|
||||
text: Union[str, List[str]],
|
||||
speaker: Union[str, List[str]],
|
||||
language: Union[str, List[str]] = None,
|
||||
instruct: Optional[Union[str, List[str]]] = None,
|
||||
non_streaming_mode: bool = True,
|
||||
**kwargs,
|
||||
) -> Tuple[List[np.ndarray], int]:
|
||||
"""
|
||||
Generate speech with the CustomVoice model using a predefined speaker id, optionally controlled by instruction text.
|
||||
|
||||
Args:
|
||||
text:
|
||||
Text(s) to synthesize.
|
||||
language:
|
||||
Language(s) for each sample.
|
||||
speaker:
|
||||
Speaker name(s). Will be validated against `model.get_supported_speakers()` (case-insensitive).
|
||||
instruct:
|
||||
Optional instruction(s). If None, treated as empty (no instruction).
|
||||
non_streaming_mode:
|
||||
Using non-streaming text input, this option currently only simulates streaming text input when set to `false`,
|
||||
rather than enabling true streaming input or streaming generation.
|
||||
do_sample:
|
||||
Whether to use sampling, recommended to be set to `true` for most use cases.
|
||||
top_k:
|
||||
Top-k sampling parameter.
|
||||
top_p:
|
||||
Top-p sampling parameter.
|
||||
temperature:
|
||||
Sampling temperature; higher => more random.
|
||||
repetition_penalty:
|
||||
Penalty to reduce repeated tokens/codes.
|
||||
subtalker_dosample:
|
||||
Sampling switch for the sub-talker (only valid for qwen3-tts-tokenizer-v2) if applicable.
|
||||
subtalker_top_k:
|
||||
Top-k for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
subtalker_top_p:
|
||||
Top-p for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
subtalker_temperature:
|
||||
Temperature for sub-talker sampling (only valid for qwen3-tts-tokenizer-v2).
|
||||
max_new_tokens:
|
||||
Maximum number of new codec tokens to generate.
|
||||
**kwargs:
|
||||
Any other keyword arguments supported by HuggingFace Transformers `generate()` can be passed.
|
||||
They will be forwarded to the underlying `Qwen3TTSForConditionalGeneration.generate(...)`.
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], int]:
|
||||
(wavs, sample_rate)
|
||||
|
||||
Raises:
|
||||
ValueError:
|
||||
If any speaker/language is unsupported or batch sizes mismatch.
|
||||
"""
|
||||
if self.model.tts_model_type != "custom_voice":
|
||||
raise ValueError(
|
||||
f"model with \ntokenizer_type: {self.model.tokenizer_type}\n"
|
||||
f"tts_model_size: {self.model.tts_model_size}\n"
|
||||
f"tts_model_type: {self.model.tts_model_type}\n"
|
||||
"does not support generate_custom_voice, Please check Model Card or Readme for more details."
|
||||
)
|
||||
|
||||
texts = self._ensure_list(text)
|
||||
languages = self._ensure_list(language) if isinstance(language, list) else ([language] * len(texts) if language is not None else ["Auto"] * len(texts))
|
||||
speakers = self._ensure_list(speaker)
|
||||
if self.model.tts_model_size in "0b6": # for 0b6 model, instruct is not supported
|
||||
instruct = None
|
||||
instructs = self._ensure_list(instruct) if isinstance(instruct, list) else ([instruct] * len(texts) if instruct is not None else [""] * len(texts))
|
||||
|
||||
if len(languages) == 1 and len(texts) > 1:
|
||||
languages = languages * len(texts)
|
||||
if len(speakers) == 1 and len(texts) > 1:
|
||||
speakers = speakers * len(texts)
|
||||
if len(instructs) == 1 and len(texts) > 1:
|
||||
instructs = instructs * len(texts)
|
||||
|
||||
if not (len(texts) == len(languages) == len(speakers) == len(instructs)):
|
||||
raise ValueError(
|
||||
f"Batch size mismatch: text={len(texts)}, language={len(languages)}, speaker={len(speakers)}, instruct={len(instructs)}"
|
||||
)
|
||||
|
||||
self._validate_languages(languages)
|
||||
self._validate_speakers(speakers)
|
||||
|
||||
input_ids = self._tokenize_texts([self._build_assistant_text(t) for t in texts])
|
||||
|
||||
instruct_ids: List[Optional[torch.Tensor]] = []
|
||||
for ins in instructs:
|
||||
if ins is None or ins == "":
|
||||
instruct_ids.append(None)
|
||||
else:
|
||||
instruct_ids.append(self._tokenize_texts([self._build_instruct_text(ins)])[0])
|
||||
|
||||
gen_kwargs = self._merge_generate_kwargs(**kwargs)
|
||||
|
||||
talker_codes_list, _ = self.model.generate(
|
||||
input_ids=input_ids,
|
||||
instruct_ids=instruct_ids,
|
||||
languages=languages,
|
||||
speakers=speakers,
|
||||
non_streaming_mode=non_streaming_mode,
|
||||
**gen_kwargs,
|
||||
)
|
||||
|
||||
wavs, fs = self.model.speech_tokenizer.decode([{"audio_codes": c} for c in talker_codes_list])
|
||||
return wavs, fs
|
||||
|
||||
|
||||
def get_supported_speakers(self) -> Optional[List[str]]:
|
||||
"""
|
||||
List supported speaker names for the current model.
|
||||
|
||||
This is a convenience wrapper around `model.get_supported_speakers()`.
|
||||
If the underlying model does not expose speaker constraints (returns None),
|
||||
this method also returns None.
|
||||
|
||||
Returns:
|
||||
Optional[List[str]]:
|
||||
- A sorted list of supported speaker names (lowercased), if available.
|
||||
- None if the model does not provide supported speakers.
|
||||
"""
|
||||
supported = self._supported_speakers_set()
|
||||
if supported is None:
|
||||
return None
|
||||
return sorted(supported)
|
||||
|
||||
|
||||
def get_supported_languages(self) -> Optional[List[str]]:
|
||||
"""
|
||||
List supported language names for the current model.
|
||||
|
||||
This is a convenience wrapper around `model.get_supported_languages()`.
|
||||
If the underlying model does not expose language constraints (returns None),
|
||||
this method also returns None.
|
||||
|
||||
Returns:
|
||||
Optional[List[str]]:
|
||||
- A sorted list of supported language names (lowercased), if available.
|
||||
- None if the model does not provide supported languages.
|
||||
"""
|
||||
supported = self._supported_languages_set()
|
||||
if supported is None:
|
||||
return None
|
||||
return sorted(supported)
|
||||
411
models/Qwen3-TTS/qwen_tts/inference/qwen3_tts_tokenizer.py
Normal file
411
models/Qwen3-TTS/qwen_tts/inference/qwen3_tts_tokenizer.py
Normal file
@@ -0,0 +1,411 @@
|
||||
# coding=utf-8
|
||||
# Copyright 2026 The Alibaba Qwen team.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import base64
|
||||
import io
|
||||
import urllib.request
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from urllib.parse import urlparse
|
||||
|
||||
import librosa
|
||||
import numpy as np
|
||||
import soundfile as sf
|
||||
import torch
|
||||
from torch.nn.utils.rnn import pad_sequence
|
||||
from transformers import AutoConfig, AutoFeatureExtractor, AutoModel
|
||||
|
||||
from ..core import (
|
||||
Qwen3TTSTokenizerV1Config,
|
||||
Qwen3TTSTokenizerV1Model,
|
||||
Qwen3TTSTokenizerV2Config,
|
||||
Qwen3TTSTokenizerV2Model,
|
||||
)
|
||||
|
||||
AudioInput = Union[
|
||||
str, # wav path, or base64 string
|
||||
np.ndarray, # 1-D float array
|
||||
List[str],
|
||||
List[np.ndarray],
|
||||
]
|
||||
|
||||
|
||||
class Qwen3TTSTokenizer:
|
||||
"""
|
||||
A wrapper for Qwen3 TTS Tokenizer 25Hz/12Hz with HuggingFace-style loading.
|
||||
|
||||
- from_pretrained(): loads speech tokenizer model via AutoModel and feature_extractor via AutoFeatureExtractor.
|
||||
- encode(): supports wav path(s), base64 audio string(s), numpy array(s).
|
||||
- decode(): accepts either the raw model encode output, or a minimal dict/list-of-dicts.
|
||||
|
||||
Notes:
|
||||
- For numpy array input, you must pass `sr` so the audio can be resampled to model sample rate.
|
||||
- Returned audio is float32 numpy arrays and the output sample rate.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.model = None
|
||||
self.feature_extractor = None
|
||||
self.config = None
|
||||
self.device = None
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(cls, pretrained_model_name_or_path: str, **kwargs) -> "Qwen3TTSTokenizer":
|
||||
"""
|
||||
Initialize tokenizer with HuggingFace `from_pretrained` style.
|
||||
|
||||
Args:
|
||||
pretrained_model_name_or_path (str):
|
||||
HuggingFace repo id or local directory.
|
||||
**kwargs (Any):
|
||||
Forwarded to `AutoModel.from_pretrained(...)` directly.
|
||||
Typical examples: device_map="cuda:0", dtype=torch.bfloat16, attn_implementation="eager".
|
||||
|
||||
Returns:
|
||||
Qwen3TTSTokenizer:
|
||||
Initialized instance with `model`, `feature_extractor`, `config`.
|
||||
"""
|
||||
inst = cls()
|
||||
|
||||
AutoConfig.register("qwen3_tts_tokenizer_25hz", Qwen3TTSTokenizerV1Config)
|
||||
AutoModel.register(Qwen3TTSTokenizerV1Config, Qwen3TTSTokenizerV1Model)
|
||||
|
||||
AutoConfig.register("qwen3_tts_tokenizer_12hz", Qwen3TTSTokenizerV2Config)
|
||||
AutoModel.register(Qwen3TTSTokenizerV2Config, Qwen3TTSTokenizerV2Model)
|
||||
|
||||
inst.feature_extractor = AutoFeatureExtractor.from_pretrained(pretrained_model_name_or_path)
|
||||
inst.model = AutoModel.from_pretrained(pretrained_model_name_or_path, **kwargs)
|
||||
inst.config = inst.model.config
|
||||
|
||||
inst.device = getattr(inst.model, "device", None)
|
||||
if inst.device is None:
|
||||
# fallback: infer from first parameter device
|
||||
try:
|
||||
inst.device = next(inst.model.parameters()).device
|
||||
except StopIteration:
|
||||
inst.device = torch.device("cpu")
|
||||
|
||||
return inst
|
||||
|
||||
def _is_probably_base64(self, s: str) -> bool:
|
||||
if s.startswith("data:audio"):
|
||||
return True
|
||||
# Heuristic: no filesystem path separators and long enough.
|
||||
if ("/" not in s and "\\" not in s) and len(s) > 256:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_url(self, s: str) -> bool:
|
||||
try:
|
||||
u = urlparse(s)
|
||||
return u.scheme in ("http", "https") and bool(u.netloc)
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _decode_base64_to_wav_bytes(self, b64: str) -> bytes:
|
||||
# Accept both "data:audio/wav;base64,...." and raw base64
|
||||
if "," in b64 and b64.strip().startswith("data:"):
|
||||
b64 = b64.split(",", 1)[1]
|
||||
return base64.b64decode(b64)
|
||||
|
||||
def load_audio(
|
||||
self,
|
||||
x: str,
|
||||
target_sr: int,
|
||||
) -> np.ndarray:
|
||||
"""
|
||||
Load audio from wav path or base64 string, then resample to target_sr.
|
||||
|
||||
Args:
|
||||
x (str):
|
||||
A wav file path, or a base64 audio string (raw or data URL).
|
||||
target_sr (int):
|
||||
Target sampling rate.
|
||||
|
||||
Returns:
|
||||
np.ndarray:
|
||||
1-D float32 waveform at target_sr.
|
||||
"""
|
||||
if self._is_url(x):
|
||||
with urllib.request.urlopen(x) as resp:
|
||||
audio_bytes = resp.read()
|
||||
with io.BytesIO(audio_bytes) as f:
|
||||
audio, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
elif self._is_probably_base64(x):
|
||||
wav_bytes = self._decode_base64_to_wav_bytes(x)
|
||||
with io.BytesIO(wav_bytes) as f:
|
||||
audio, sr = sf.read(f, dtype="float32", always_2d=False)
|
||||
else:
|
||||
audio, sr = librosa.load(x, sr=None, mono=True)
|
||||
|
||||
if audio.ndim > 1:
|
||||
audio = np.mean(audio, axis=-1)
|
||||
|
||||
if sr != target_sr:
|
||||
audio = librosa.resample(y=audio, orig_sr=sr, target_sr=target_sr)
|
||||
|
||||
return audio.astype(np.float32)
|
||||
|
||||
def _normalize_audio_inputs(
|
||||
self,
|
||||
audios: AudioInput,
|
||||
sr: Optional[int],
|
||||
) -> List[np.ndarray]:
|
||||
"""
|
||||
Normalize all supported input types into a list of 1-D numpy float32 waveforms
|
||||
at `self.feature_extractor.sampling_rate`.
|
||||
|
||||
Args:
|
||||
audios (AudioInput):
|
||||
- str: wav path OR base64 audio string
|
||||
- np.ndarray: raw waveform (sr must be provided)
|
||||
- list[str] / list[np.ndarray]
|
||||
sr (Optional[int]):
|
||||
Sampling rate for raw numpy input. Required if input is np.ndarray or list[np.ndarray].
|
||||
|
||||
Returns:
|
||||
List[np.ndarray]:
|
||||
List of float32 waveforms resampled to model input SR.
|
||||
"""
|
||||
target_sr = int(self.feature_extractor.sampling_rate)
|
||||
|
||||
if isinstance(audios, (str, np.ndarray)):
|
||||
audios = [audios]
|
||||
|
||||
if len(audios) == 0:
|
||||
return []
|
||||
|
||||
if isinstance(audios[0], str):
|
||||
# wav path list or base64 list
|
||||
return [self.load_audio(x, target_sr=target_sr) for x in audios] # type: ignore[arg-type]
|
||||
|
||||
# numpy list
|
||||
if sr is None:
|
||||
raise ValueError("For numpy waveform input, you must provide `sr` (original sampling rate).")
|
||||
|
||||
out: List[np.ndarray] = []
|
||||
for a in audios: # type: ignore[assignment]
|
||||
if not isinstance(a, np.ndarray):
|
||||
raise TypeError("Mixed input types are not supported. Use all paths/base64 or all numpy arrays.")
|
||||
if a.ndim > 1:
|
||||
a = np.mean(a, axis=-1)
|
||||
if int(sr) != target_sr:
|
||||
a = librosa.resample(y=a.astype(np.float32), orig_sr=int(sr), target_sr=target_sr)
|
||||
out.append(a.astype(np.float32))
|
||||
return out
|
||||
|
||||
def encode(
|
||||
self,
|
||||
audios: AudioInput,
|
||||
sr: Optional[int] = None,
|
||||
return_dict: bool = True,
|
||||
):
|
||||
"""
|
||||
Batch-encode audio into discrete codes (and optional conditioning, depending on 25Hz/12Hz).
|
||||
|
||||
Args:
|
||||
audios (AudioInput):
|
||||
Supported forms:
|
||||
- np.ndarray: waveform (requires sr)
|
||||
- list[np.ndarray]: waveforms (requires sr)
|
||||
- str: wav path OR base64 audio string
|
||||
- list[str]: wav paths and/or base64 strings
|
||||
sr (Optional[int], default=None):
|
||||
Original sampling rate for numpy waveform input.
|
||||
return_dict (bool, default=True):
|
||||
Forwarded to model.encode(...). If True, returns ModelOutput.
|
||||
|
||||
Returns:
|
||||
25Hz:
|
||||
Qwen3TTSTokenizerV1EncoderOutput (if return_dict=True) with fields:
|
||||
- audio_codes: List[torch.LongTensor] each (codes_len,)
|
||||
- xvectors: List[torch.FloatTensor] each (xvector_dim,)
|
||||
- ref_mels: List[torch.FloatTensor] each (mel_len, mel_dim)
|
||||
12Hz:
|
||||
Qwen3TTSTokenizerV2EncoderOutput (if return_dict=True) with fields:
|
||||
- audio_codes: List[torch.LongTensor] each (codes_len, num_quantizers)
|
||||
|
||||
If return_dict=False, returns the raw tuple from model.encode.
|
||||
"""
|
||||
wavs = self._normalize_audio_inputs(audios, sr=sr)
|
||||
|
||||
inputs = self.feature_extractor(
|
||||
raw_audio=wavs,
|
||||
sampling_rate=int(self.feature_extractor.sampling_rate),
|
||||
return_tensors="pt",
|
||||
)
|
||||
inputs = inputs.to(self.device).to(self.model.dtype)
|
||||
|
||||
with torch.inference_mode():
|
||||
# model.encode expects (B, T) and (B, T)
|
||||
enc = self.model.encode(
|
||||
inputs["input_values"].squeeze(1),
|
||||
inputs["padding_mask"].squeeze(1),
|
||||
return_dict=return_dict,
|
||||
)
|
||||
return enc
|
||||
|
||||
def decode(
|
||||
self,
|
||||
encoded,
|
||||
) -> Tuple[List[np.ndarray], int]:
|
||||
"""
|
||||
Decode back to waveform.
|
||||
|
||||
Usage:
|
||||
1) Pass the raw output of `encode(...)` directly (recommended).
|
||||
- 25Hz: expects fields audio_codes, xvectors, ref_mels
|
||||
- 12Hz: expects field audio_codes
|
||||
2) Pass a dict or list[dict] (minimal form) for custom pipelines:
|
||||
- 25Hz dict keys: {"audio_codes", "xvectors", "ref_mels"}
|
||||
- 12Hz dict keys: {"audio_codes"}
|
||||
Values can be torch tensors or numpy arrays.
|
||||
|
||||
Args:
|
||||
encoded (Any):
|
||||
- ModelOutput returned by `encode()`, OR
|
||||
- dict, OR
|
||||
- list[dict]
|
||||
|
||||
Returns:
|
||||
Tuple[List[np.ndarray], int]:
|
||||
- wavs: list of 1-D float32 numpy arrays
|
||||
- sample_rate: int, model output sampling rate
|
||||
"""
|
||||
model_type = self.model.get_model_type()
|
||||
|
||||
def _to_tensor(x, dtype=None):
|
||||
if isinstance(x, torch.Tensor):
|
||||
return x
|
||||
x = np.asarray(x)
|
||||
t = torch.from_numpy(x)
|
||||
if dtype is not None:
|
||||
t = t.to(dtype)
|
||||
return t
|
||||
|
||||
# Normalize `encoded` into the same shapes as the official demo uses.
|
||||
if hasattr(encoded, "audio_codes"):
|
||||
# ModelOutput from encode()
|
||||
audio_codes_list = encoded.audio_codes
|
||||
xvectors_list = getattr(encoded, "xvectors", None)
|
||||
ref_mels_list = getattr(encoded, "ref_mels", None)
|
||||
elif isinstance(encoded, dict):
|
||||
audio_codes_list = encoded["audio_codes"]
|
||||
xvectors_list = encoded.get("xvectors", None)
|
||||
ref_mels_list = encoded.get("ref_mels", None)
|
||||
elif isinstance(encoded, list):
|
||||
# list of dicts
|
||||
audio_codes_list = [e["audio_codes"] for e in encoded]
|
||||
xvectors_list = [e["xvectors"] for e in encoded] if ("xvectors" in encoded[0]) else None
|
||||
ref_mels_list = [e["ref_mels"] for e in encoded] if ("ref_mels" in encoded[0]) else None
|
||||
else:
|
||||
raise TypeError("`encoded` must be an encode output, a dict, or a list of dicts.")
|
||||
|
||||
# Ensure list form for per-sample tensors
|
||||
if isinstance(audio_codes_list, torch.Tensor):
|
||||
# Could be a single sample tensor or an already padded batch tensor.
|
||||
t = audio_codes_list
|
||||
if t.dim() == 1:
|
||||
# 25Hz single sample: (C,) -> (1, C)
|
||||
t = t.unsqueeze(0)
|
||||
elif t.dim() == 2:
|
||||
# 12Hz single sample: (C, Q) -> (1, C, Q)
|
||||
t = t.unsqueeze(0)
|
||||
audio_codes_padded = t.to(self.device)
|
||||
else:
|
||||
# List[Tensor/np]
|
||||
audio_codes_list = [_to_tensor(c, dtype=torch.long) for c in audio_codes_list]
|
||||
audio_codes_padded = pad_sequence(audio_codes_list, batch_first=True, padding_value=0).to(self.device)
|
||||
|
||||
with torch.inference_mode():
|
||||
if model_type == "qwen3_tts_tokenizer_25hz":
|
||||
if xvectors_list is None or ref_mels_list is None:
|
||||
raise ValueError("25Hz decode requires `xvectors` and `ref_mels`.")
|
||||
|
||||
if isinstance(xvectors_list, torch.Tensor):
|
||||
xvectors_batch = xvectors_list
|
||||
if xvectors_batch.dim() == 1: # (D,) -> (1, D)
|
||||
xvectors_batch = xvectors_batch.unsqueeze(0)
|
||||
xvectors_batch = xvectors_batch.to(self.device).to(self.model.dtype)
|
||||
else:
|
||||
xvectors_list = [_to_tensor(x, dtype=torch.float32) for x in xvectors_list]
|
||||
xvectors_batch = torch.stack(xvectors_list, dim=0).to(self.device).to(self.model.dtype)
|
||||
|
||||
if isinstance(ref_mels_list, torch.Tensor):
|
||||
ref_mels_padded = ref_mels_list
|
||||
if ref_mels_padded.dim() == 2: # (T, M) -> (1, T, M)
|
||||
ref_mels_padded = ref_mels_padded.unsqueeze(0)
|
||||
ref_mels_padded = ref_mels_padded.to(self.device).to(self.model.dtype)
|
||||
else:
|
||||
ref_mels_list = [_to_tensor(m, dtype=torch.float32) for m in ref_mels_list]
|
||||
ref_mels_padded = pad_sequence(ref_mels_list, batch_first=True, padding_value=0).to(self.device).to(self.model.dtype)
|
||||
|
||||
dec = self.model.decode(audio_codes_padded, xvectors_batch, ref_mels_padded, return_dict=True)
|
||||
wav_tensors = dec.audio_values
|
||||
|
||||
elif model_type == "qwen3_tts_tokenizer_12hz":
|
||||
dec = self.model.decode(audio_codes_padded, return_dict=True)
|
||||
wav_tensors = dec.audio_values
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown model type: {model_type}")
|
||||
|
||||
wavs = [w.to(torch.float32).detach().cpu().numpy() for w in wav_tensors]
|
||||
return wavs, int(self.model.get_output_sample_rate())
|
||||
|
||||
def get_model_type(self) -> str:
|
||||
"""
|
||||
Get the underlying tokenizer model type.
|
||||
|
||||
Returns:
|
||||
str: Model type string from `self.model.config.model_type`
|
||||
(e.g. "qwen3_tts_tokenizer_25hz" / "qwen3_tts_tokenizer_12hz").
|
||||
"""
|
||||
return self.model.get_model_type()
|
||||
|
||||
def get_input_sample_rate(self) -> int:
|
||||
"""
|
||||
Get the expected input sample rate for encoding.
|
||||
|
||||
Returns:
|
||||
int: Input sample rate (Hz).
|
||||
"""
|
||||
return int(self.model.get_input_sample_rate())
|
||||
|
||||
def get_output_sample_rate(self) -> int:
|
||||
"""
|
||||
Get the output sample rate for decoded waveforms.
|
||||
|
||||
Returns:
|
||||
int: Output sample rate (Hz).
|
||||
"""
|
||||
return int(self.model.get_output_sample_rate())
|
||||
|
||||
def get_encode_downsample_rate(self) -> int:
|
||||
"""
|
||||
Get the encoder downsample rate (waveform samples per code step).
|
||||
|
||||
Returns:
|
||||
int: Encode downsample rate.
|
||||
"""
|
||||
return int(self.model.get_encode_downsample_rate())
|
||||
|
||||
def get_decode_upsample_rate(self) -> int:
|
||||
"""
|
||||
Get the decoder upsample rate (waveform samples per code step).
|
||||
|
||||
Returns:
|
||||
int: Decode upsample rate.
|
||||
"""
|
||||
return int(self.model.get_decode_upsample_rate())
|
||||
4
run_backend.sh
Normal file
4
run_backend.sh
Normal file
@@ -0,0 +1,4 @@
|
||||
#!/bin/bash
|
||||
# 启动 ViGent2 后端 (FastAPI)
|
||||
cd "$(dirname "$0")/backend"
|
||||
./venv/bin/uvicorn app.main:app --host 0.0.0.0 --port 8006
|
||||
17
run_latentsync.sh
Normal file
17
run_latentsync.sh
Normal file
@@ -0,0 +1,17 @@
|
||||
#!/bin/bash
|
||||
# 启动 LatentSync 模型服务
|
||||
# 注意: 需要根据服务器实际情况修改 Python 路径
|
||||
|
||||
cd "$(dirname "$0")/models/LatentSync"
|
||||
|
||||
# 请根据您的 Miniconda/Anaconda 安装路径修改此处
|
||||
PYTHON_PATH="/home/rongye/ProgramFiles/miniconda3/envs/latentsync/bin/python"
|
||||
|
||||
if [ -f "$PYTHON_PATH" ]; then
|
||||
"$PYTHON_PATH" -m scripts.server
|
||||
else
|
||||
echo "❌ 错误: 找不到 Python 解释器: $PYTHON_PATH"
|
||||
echo "请编辑此脚本 (run_latentsync.sh) 修改 PYTHON_PATH 为您的实际路径:"
|
||||
echo "conda activate latentsync && which python"
|
||||
exit 1
|
||||
fi
|
||||
Reference in New Issue
Block a user