Init: 导入NaviGlassServer源码
This commit is contained in:
68
.env.performance
Normal file
68
.env.performance
Normal file
@@ -0,0 +1,68 @@
|
||||
# ============================================================
|
||||
# Day 22 性能优化配置文件
|
||||
# 复制此文件为 .env 并根据硬件调整参数
|
||||
# ============================================================
|
||||
|
||||
# === YOLO 盲道/斑马线检测 ===
|
||||
# 输入分辨率 (越小越快,但精度降低)
|
||||
# 建议: RTX 3090 用 640, GTX 1060 用 320-384
|
||||
AIGLASS_YOLO_IMGSZ=480
|
||||
|
||||
# 检测间隔 (每N帧检测一次)
|
||||
# 建议: 高端GPU用 6-8, 低端用 12-15
|
||||
AIGLASS_BLINDPATH_INTERVAL=10
|
||||
|
||||
# 启用FP16半精度 (1=启用, 0=禁用)
|
||||
# 注意: GTX 1060 的FP16性能不佳,建议设为0
|
||||
AIGLASS_YOLO_HALF=1
|
||||
|
||||
# === 障碍物检测 ===
|
||||
# 输入分辨率
|
||||
AIGLASS_OBS_IMGSZ=480
|
||||
|
||||
# 检测间隔
|
||||
AIGLASS_OBS_INTERVAL=18
|
||||
|
||||
# 缓存帧数
|
||||
AIGLASS_OBS_CACHE_FRAMES=12
|
||||
|
||||
# 置信度阈值
|
||||
AIGLASS_OBS_CONF=0.25
|
||||
|
||||
# 启用FP16
|
||||
AIGLASS_OBS_HALF=1
|
||||
|
||||
# === GPU 并发控制 ===
|
||||
# 同时推理的最大任务数
|
||||
AIGLASS_GPU_SLOTS=2
|
||||
|
||||
# 设备选择 (cuda:0, cuda:1, cpu)
|
||||
AIGLASS_DEVICE=cuda:0
|
||||
|
||||
# 混合精度模式 (fp16, bf16, off)
|
||||
# RTX 30系列支持 bf16, 其他用 fp16
|
||||
AIGLASS_AMP=fp16
|
||||
|
||||
# === 语音播报 ===
|
||||
# 直行提示间隔(秒)
|
||||
AIGLASS_STRAIGHT_INTERVAL=4.0
|
||||
|
||||
# 方向指令间隔(秒)
|
||||
AIGLASS_DIRECTION_INTERVAL=3.0
|
||||
|
||||
# 持续播报模式 (1=启用)
|
||||
AIGLASS_STRAIGHT_CONTINUOUS=1
|
||||
|
||||
# 限制模式下最大重复次数
|
||||
AIGLASS_STRAIGHT_LIMIT=2
|
||||
|
||||
# === 模型路径 (根据实际路径修改) ===
|
||||
# BLIND_PATH_MODEL=model/yolo-seg.pt
|
||||
# OBSTACLE_MODEL=model/yoloe-11l-seg.pt
|
||||
|
||||
# === 调试选项 ===
|
||||
# 启用ASR原始日志
|
||||
# ASR_DEBUG_RAW=1
|
||||
|
||||
# 启用红绿灯调试图像
|
||||
# AIGLASS_DEBUG_TRAFFIC_LIGHT=1
|
||||
90
CHANGELOG.md
Normal file
90
CHANGELOG.md
Normal file
@@ -0,0 +1,90 @@
|
||||
# 更新日志
|
||||
|
||||
本文档记录项目的所有重要变更。
|
||||
|
||||
格式基于 [Keep a Changelog](https://keepachangelog.com/zh-CN/1.0.0/),
|
||||
版本号遵循 [语义化版本](https://semver.org/lang/zh-CN/)。
|
||||
|
||||
## [未发布]
|
||||
|
||||
### 新增
|
||||
- 首次开源发布
|
||||
- 完整的 GitHub 文档(README, CONTRIBUTING, LICENSE 等)
|
||||
- Docker 支持
|
||||
- 环境变量配置模板
|
||||
|
||||
### 修改
|
||||
- 优化了 README 文档结构
|
||||
- 改进了代码注释
|
||||
|
||||
## [1.0.0] - 2025-01-XX
|
||||
|
||||
### 新增
|
||||
- 🚶 盲道导航系统
|
||||
- 实时盲道检测与分割
|
||||
- 智能语音引导
|
||||
- 障碍物检测与避障
|
||||
- 急转弯检测与提醒
|
||||
- 光流稳定算法
|
||||
|
||||
- 🚦 过马路辅助
|
||||
- 斑马线识别与方向检测
|
||||
- 红绿灯颜色识别
|
||||
- 对齐引导系统
|
||||
- 安全提醒
|
||||
|
||||
- 🔍 物品识别与查找
|
||||
- YOLO-E 开放词汇检测
|
||||
- MediaPipe 手部引导
|
||||
- 实时目标追踪
|
||||
- 抓取动作检测
|
||||
|
||||
- 🎙️ 实时语音交互
|
||||
- 阿里云 Paraformer ASR
|
||||
- Qwen-Omni-Turbo 多模态对话
|
||||
- 智能指令解析
|
||||
- 上下文感知
|
||||
|
||||
- 📹 视频与音频处理
|
||||
- WebSocket 实时推流
|
||||
- 音视频同步录制
|
||||
- IMU 数据融合
|
||||
- 多路音频混音
|
||||
|
||||
- 🎨 可视化与交互
|
||||
- Web 实时监控界面
|
||||
- IMU 3D 可视化
|
||||
- 状态面板
|
||||
- 中文友好界面
|
||||
|
||||
### 技术栈
|
||||
- FastAPI + WebSocket
|
||||
- YOLO11 / YOLO-E
|
||||
- MediaPipe
|
||||
- PyTorch + CUDA
|
||||
- OpenCV
|
||||
- DashScope API
|
||||
|
||||
### 已知问题
|
||||
- [ ] 在低端 GPU 上可能出现卡顿
|
||||
- [ ] macOS 上缺少 GPU 加速支持
|
||||
- [ ] 部分中文字体在 Linux 上显示不正确
|
||||
|
||||
---
|
||||
|
||||
## 版本说明
|
||||
|
||||
### 主版本(Major)
|
||||
- 不兼容的 API 更改
|
||||
|
||||
### 次版本(Minor)
|
||||
- 向后兼容的新功能
|
||||
|
||||
### 修订版本(Patch)
|
||||
- 向后兼容的问题修复
|
||||
|
||||
---
|
||||
|
||||
[未发布]: https://github.com/yourusername/aiglass/compare/v1.0.0...HEAD
|
||||
[1.0.0]: https://github.com/yourusername/aiglass/releases/tag/v1.0.0
|
||||
|
||||
59
Dockerfile
Normal file
59
Dockerfile
Normal file
@@ -0,0 +1,59 @@
|
||||
# AI Glass System - Dockerfile
# Based on the NVIDIA CUDA + cuDNN runtime image (Ubuntu 22.04)
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04

# Build-time only: using ARG keeps the noninteractive apt frontend out of the
# runtime environment (baking it into ENV is a known anti-pattern).
ARG DEBIAN_FRONTEND=noninteractive

# Runtime environment: unbuffered Python logs and CUDA toolchain paths.
ENV PYTHONUNBUFFERED=1 \
    CUDA_HOME=/usr/local/cuda
ENV PATH=${CUDA_HOME}/bin:${PATH} \
    LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}

WORKDIR /app

# System dependencies (sorted alphabetically, recommends skipped, apt lists
# removed in the same layer so the cache never reaches the image):
# - portaudio19-dev: microphone/speaker I/O
# - libgl1-mesa-glx, libglib2.0-0, libsm6, libxext6, libxrender-dev: OpenCV
# - curl: used by HEALTHCHECK below
RUN apt-get update && apt-get install -y --no-install-recommends \
        curl \
        git \
        libgl1-mesa-glx \
        libglib2.0-0 \
        libgomp1 \
        libsm6 \
        libxext6 \
        libxrender-dev \
        portaudio19-dev \
        python3-dev \
        python3-pip \
        python3.10 \
        wget \
    && apt-get clean \
    && rm -rf /var/lib/apt/lists/*

# Upgrade pip without leaving a wheel cache behind.
RUN python3 -m pip install --no-cache-dir --upgrade pip

# Copy only the dependency manifest first so this layer is cached until
# requirements.txt actually changes.
COPY requirements.txt .

# Install the pinned CUDA-matched torch/torchvision first so requirements.txt
# cannot drag in a CPU-only build; --no-cache-dir keeps the image lean.
RUN pip install --no-cache-dir torch==2.0.1+cu118 torchvision==0.15.2+cu118 \
        --index-url https://download.pytorch.org/whl/cu118 \
    && pip install --no-cache-dir -r requirements.txt

# Application code (use .dockerignore to keep secrets/build output out).
COPY . .

# Directories expected at runtime (models, recordings, static assets).
RUN mkdir -p recordings model music voice static templates

# 8081: HTTP/WebSocket service; 12345/udp: IMU data ingestion.
EXPOSE 8081 12345/udp

# Liveness probe against the app's own health endpoint.
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
    CMD curl -f http://localhost:8081/api/health || exit 1

# NOTE(review): the container currently runs as root; consider adding a
# non-root USER once audio/GPU device permissions are confirmed.
CMD ["python3", "app_main.py"]
|
||||
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2025 AI-FanGe
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
402
PROJECT_STRUCTURE.md
Normal file
402
PROJECT_STRUCTURE.md
Normal file
@@ -0,0 +1,402 @@
|
||||
# 项目结构说明
|
||||
|
||||
本文档详细说明项目的目录结构和主要文件的作用。
|
||||
|
||||
## 📁 目录结构
|
||||
|
||||
```
|
||||
rebuild1002/
|
||||
├── 📄 主要应用文件
|
||||
│ ├── app_main.py # 主应用入口(FastAPI 服务)
|
||||
│ ├── navigation_master.py # 导航统领器(状态机)
|
||||
│ ├── workflow_blindpath.py # 盲道导航工作流
|
||||
│ ├── workflow_crossstreet.py # 过马路导航工作流
|
||||
│ └── yolomedia.py # 物品查找工作流
|
||||
│
|
||||
├── 🎙️ 语音处理模块
|
||||
│ ├── asr_core.py # 语音识别核心
|
||||
│ ├── omni_client.py # Qwen-Omni 客户端
|
||||
│ ├── qwen_extractor.py # 标签提取(中文->英文)
|
||||
│ ├── audio_player.py # 音频播放器
|
||||
│ └── audio_stream.py # 音频流管理
|
||||
│
|
||||
├── 🤖 模型相关
|
||||
│ ├── yoloe_backend.py # YOLO-E 后端(开放词汇)
|
||||
│ ├── trafficlight_detection.py # 红绿灯检测
|
||||
│ ├── obstacle_detector_client.py # 障碍物检测客户端
|
||||
│ └── models.py # 模型定义
|
||||
│
|
||||
├── 🎥 视频处理
|
||||
│ ├── bridge_io.py # 线程安全的帧缓冲
|
||||
│ ├── sync_recorder.py # 音视频同步录制
|
||||
│ └── video_recorder.py # 视频录制(旧版)
|
||||
│
|
||||
├── 🌐 Web 前端
|
||||
│ ├── templates/
|
||||
│ │ └── index.html # 主界面 HTML
|
||||
│ ├── static/
|
||||
│ │ ├── main.js # 主 JS 脚本
|
||||
│ │ ├── vision.js # 视觉流处理
|
||||
│ │ ├── visualizer.js # 数据可视化
|
||||
│ │ ├── vision_renderer.js # 渲染器
|
||||
│ │ ├── vision.css # 样式表
|
||||
│ │ └── models/ # 3D 模型(IMU 可视化)
|
||||
│
|
||||
├── 🎵 音频资源
|
||||
│ ├── music/ # 系统提示音
|
||||
│ │ ├── converted_向上.wav
|
||||
│ │ ├── converted_向下.wav
|
||||
│ │ └── ...
|
||||
│ └── voice/ # 预录语音
|
||||
│ ├── voice_mapping.json
|
||||
│ └── *.wav
|
||||
│
|
||||
├── 🧠 模型文件
|
||||
│ └── model/
|
||||
│ ├── yolo-seg.pt # 盲道分割模型
|
||||
│ ├── yoloe-11l-seg.pt # YOLO-E 开放词汇模型
|
||||
│ ├── shoppingbest5.pt # 物品识别模型
|
||||
│ ├── trafficlight.pt # 红绿灯检测模型
|
||||
│ └── hand_landmarker.task # MediaPipe 手部模型
|
||||
│
|
||||
├── 📹 录制文件
|
||||
│ └── recordings/ # 自动保存的视频和音频
|
||||
│ ├── video_*.avi
|
||||
│ └── audio_*.wav
|
||||
│
|
||||
├── 🛠️ ESP32 固件
|
||||
│ └── compile/
|
||||
│ ├── compile.ino # Arduino 主程序
|
||||
│ ├── camera_pins.h # 摄像头引脚定义
|
||||
│ ├── ICM42688.cpp/h # IMU 驱动
|
||||
│ └── ESP32_VIDEO_OPTIMIZATION.md
|
||||
│
|
||||
├── 🧪 测试文件
|
||||
│ ├── test_recorder.py # 录制功能测试
|
||||
│ ├── test_traffic_light.py # 红绿灯检测测试
|
||||
│ ├── test_cross_street_blindpath.py # 导航测试
|
||||
│ └── test_crosswalk_awareness.py # 斑马线检测测试
|
||||
│
|
||||
├── 📚 文档
|
||||
│ ├── README.md # 项目主文档
|
||||
│ ├── INSTALLATION.md # 安装指南
|
||||
│ ├── CONTRIBUTING.md # 贡献指南
|
||||
│ ├── FAQ.md # 常见问题
|
||||
│ ├── CHANGELOG.md # 更新日志
|
||||
│ ├── SECURITY.md # 安全政策
|
||||
│ └── PROJECT_STRUCTURE.md # 本文件
|
||||
│
|
||||
├── 🐳 Docker 相关
|
||||
│ ├── Dockerfile # Docker 镜像定义
|
||||
│ ├── docker-compose.yml # Docker Compose 配置
|
||||
│ └── .dockerignore # Docker 忽略文件
|
||||
│
|
||||
├── ⚙️ 配置文件
|
||||
│ ├── .env.example # 环境变量模板
|
||||
│ ├── .gitignore # Git 忽略文件
|
||||
│ ├── requirements.txt # Python 依赖
|
||||
│ ├── setup.sh # Linux/macOS 安装脚本
|
||||
│ └── setup.bat # Windows 安装脚本
|
||||
│
|
||||
├── 📄 许可证
|
||||
│ └── LICENSE # MIT 许可证
|
||||
│
|
||||
└── 🔧 GitHub 相关
|
||||
└── .github/
|
||||
├── ISSUE_TEMPLATE/
|
||||
│ ├── bug_report.md
|
||||
│ └── feature_request.md
|
||||
└── pull_request_template.md
|
||||
```
|
||||
|
||||
## 🔑 核心文件说明
|
||||
|
||||
### 主应用层
|
||||
|
||||
#### `app_main.py`
|
||||
- **作用**: FastAPI 主服务,处理所有 WebSocket 连接
|
||||
- **主要功能**:
|
||||
- WebSocket 路由管理(/ws/camera, /ws_audio, /ws/viewer 等)
|
||||
- 模型加载与初始化
|
||||
- 状态协调与管理
|
||||
- 音视频流分发
|
||||
- **依赖**: 所有其他模块
|
||||
- **入口点**: `python app_main.py`
|
||||
|
||||
#### `navigation_master.py`
|
||||
- **作用**: 导航统领器,管理整个系统的状态机
|
||||
- **主要状态**:
|
||||
- IDLE: 空闲
|
||||
- CHAT: 对话模式
|
||||
- BLINDPATH_NAV: 盲道导航
|
||||
- CROSSING: 过马路
|
||||
- TRAFFIC_LIGHT_DETECTION: 红绿灯检测
|
||||
- ITEM_SEARCH: 物品查找
|
||||
- **核心方法**:
|
||||
- `process_frame()`: 处理每一帧
|
||||
- `start_blind_path_navigation()`: 启动盲道导航
|
||||
- `start_crossing()`: 启动过马路模式
|
||||
- `on_voice_command()`: 处理语音命令
|
||||
|
||||
### 工作流模块
|
||||
|
||||
#### `workflow_blindpath.py`
|
||||
- **作用**: 盲道导航核心逻辑
|
||||
- **主要功能**:
|
||||
- 盲道分割与检测
|
||||
- 障碍物检测
|
||||
- 转弯检测
|
||||
- 光流稳定
|
||||
- 方向引导生成
|
||||
- **状态机**:
|
||||
- ONBOARDING: 上盲道
|
||||
- NAVIGATING: 导航中
|
||||
- MANEUVERING_TURN: 转弯
|
||||
- AVOIDING_OBSTACLE: 避障
|
||||
|
||||
#### `workflow_crossstreet.py`
|
||||
- **作用**: 过马路导航逻辑
|
||||
- **主要功能**:
|
||||
- 斑马线检测
|
||||
- 方向对齐
|
||||
- 引导生成
|
||||
- **核心方法**:
|
||||
- `_is_crosswalk_near()`: 判断是否接近斑马线
|
||||
- `_compute_angle_and_offset()`: 计算角度和偏移
|
||||
|
||||
#### `yolomedia.py`
|
||||
- **作用**: 物品查找工作流
|
||||
- **主要功能**:
|
||||
- YOLO-E 文本提示检测
|
||||
- MediaPipe 手部追踪
|
||||
- 光流目标追踪
|
||||
- 手部引导(方向提示)
|
||||
- 抓取动作检测
|
||||
- **模式**:
|
||||
- SEGMENT: 检测模式
|
||||
- FLASH: 闪烁确认
|
||||
- CENTER_GUIDE: 居中引导
|
||||
- TRACK: 手部追踪
|
||||
|
||||
### 语音模块
|
||||
|
||||
#### `asr_core.py`
|
||||
- **作用**: 阿里云 Paraformer ASR 实时语音识别
|
||||
- **主要功能**:
|
||||
- 实时语音识别
|
||||
- VAD(语音活动检测)
|
||||
- 识别结果回调
|
||||
- **关键类**: `ASRCallback`
|
||||
|
||||
#### `omni_client.py`
|
||||
- **作用**: Qwen-Omni-Turbo 多模态对话客户端
|
||||
- **主要功能**:
|
||||
- 流式对话生成
|
||||
- 图像+文本输入
|
||||
- 语音输出
|
||||
- **核心函数**: `stream_chat()`
|
||||
|
||||
#### `audio_player.py`
|
||||
- **作用**: 统一的音频播放管理
|
||||
- **主要功能**:
|
||||
- TTS 语音播放
|
||||
- 多路音频混音
|
||||
- 音量控制
|
||||
- 线程安全播放
|
||||
- **核心函数**: `play_voice_text()`, `play_audio_threadsafe()`
|
||||
|
||||
### 模型后端
|
||||
|
||||
#### `yoloe_backend.py`
|
||||
- **作用**: YOLO-E 开放词汇检测后端
|
||||
- **主要功能**:
|
||||
- 文本提示设置
|
||||
- 实时分割
|
||||
- 目标追踪
|
||||
- **核心类**: `YoloEBackend`
|
||||
|
||||
#### `trafficlight_detection.py`
|
||||
- **作用**: 红绿灯检测模块
|
||||
- **检测方法**:
|
||||
1. YOLO 模型检测
|
||||
2. HSV 颜色分类(备用)
|
||||
- **输出**: 红灯/绿灯/黄灯/未知
|
||||
|
||||
#### `obstacle_detector_client.py`
|
||||
- **作用**: 障碍物检测客户端
|
||||
- **主要功能**:
|
||||
- 白名单类别过滤
|
||||
- 路径掩码内检测
|
||||
- 物体属性计算(面积、位置、危险度)
|
||||
|
||||
### 视频处理
|
||||
|
||||
#### `bridge_io.py`
|
||||
- **作用**: 线程安全的帧缓冲与分发
|
||||
- **主要功能**:
|
||||
- 生产者-消费者模式
|
||||
- 原始帧缓存
|
||||
- 处理后帧分发
|
||||
- **核心函数**:
|
||||
- `push_raw_jpeg()`: 接收 ESP32 帧
|
||||
- `wait_raw_bgr()`: 取原始帧
|
||||
- `send_vis_bgr()`: 发送处理后的帧
|
||||
|
||||
#### `sync_recorder.py`
|
||||
- **作用**: 音视频同步录制
|
||||
- **主要功能**:
|
||||
- 同步录制视频和音频
|
||||
- 自动文件命名(时间戳)
|
||||
- 线程安全
|
||||
- **输出**: `recordings/video_*.avi`, `audio_*.wav`
|
||||
|
||||
### 前端
|
||||
|
||||
#### `templates/index.html`
|
||||
- **作用**: Web 监控界面
|
||||
- **主要区域**:
|
||||
- 视频流显示
|
||||
- 状态面板
|
||||
- IMU 3D 可视化
|
||||
- 语音识别结果
|
||||
|
||||
#### `static/main.js`
|
||||
- **作用**: 主 JavaScript 逻辑
|
||||
- **主要功能**:
|
||||
- WebSocket 连接管理
|
||||
- UI 更新
|
||||
- 事件处理
|
||||
|
||||
#### `static/vision.js`
|
||||
- **作用**: 视觉流处理
|
||||
- **主要功能**:
|
||||
- WebSocket 接收视频帧
|
||||
- Canvas 渲染
|
||||
- FPS 计算
|
||||
|
||||
#### `static/visualizer.js`
|
||||
- **作用**: IMU 3D 可视化(Three.js)
|
||||
- **主要功能**:
|
||||
- 接收 IMU 数据
|
||||
- 实时渲染设备姿态
|
||||
- 动态灯光效果
|
||||
|
||||
## 🔄 数据流
|
||||
|
||||
### 视频流
|
||||
```
|
||||
ESP32-CAM
|
||||
→ [JPEG] WebSocket /ws/camera
|
||||
→ bridge_io.push_raw_jpeg()
|
||||
→ yolomedia / navigation_master
|
||||
→ bridge_io.send_vis_bgr()
|
||||
→ [JPEG] WebSocket /ws/viewer
|
||||
→ Browser Canvas
|
||||
```
|
||||
|
||||
### 音频流(上行)
|
||||
```
|
||||
ESP32-MIC
|
||||
→ [PCM16] WebSocket /ws_audio
|
||||
→ asr_core
|
||||
→ DashScope ASR
|
||||
→ 识别结果
|
||||
→ start_ai_with_text_custom()
|
||||
```
|
||||
|
||||
### 音频流(下行)
|
||||
```
|
||||
Qwen-Omni / TTS
|
||||
→ audio_player
|
||||
→ [PCM16] audio_stream
|
||||
→ [WAV] HTTP /stream.wav
|
||||
→ ESP32-Speaker
|
||||
```
|
||||
|
||||
### IMU 数据流
|
||||
```
|
||||
ESP32-IMU
|
||||
→ [JSON] UDP 12345
|
||||
→ process_imu_and_maybe_store()
|
||||
→ [JSON] WebSocket /ws
|
||||
→ visualizer.js (Three.js)
|
||||
```
|
||||
|
||||
## 🎯 关键设计模式
|
||||
|
||||
### 1. 状态机模式
|
||||
- **位置**: `navigation_master.py`
|
||||
- **作用**: 管理系统状态转换
|
||||
- **状态**: IDLE → CHAT / BLINDPATH_NAV / CROSSING / ...
|
||||
|
||||
### 2. 生产者-消费者模式
|
||||
- **位置**: `bridge_io.py`
|
||||
- **作用**: 解耦视频接收与处理
|
||||
- **实现**: 线程 + 队列
|
||||
|
||||
### 3. 策略模式
|
||||
- **位置**: 各 `workflow_*.py`
|
||||
- **作用**: 不同导航策略的实现
|
||||
- **实现**: 统一的 `process_frame()` 接口
|
||||
|
||||
### 4. 单例模式
|
||||
- **位置**: 模型加载
|
||||
- **作用**: 全局共享模型实例
|
||||
- **实现**: 全局变量 + 初始化检查
|
||||
|
||||
### 5. 观察者模式
|
||||
- **位置**: WebSocket 通信
|
||||
- **作用**: 多客户端订阅视频流
|
||||
- **实现**: `camera_viewers: Set[WebSocket]`
|
||||
|
||||
## 📦 依赖关系
|
||||
|
||||
```
|
||||
app_main.py
|
||||
├── navigation_master.py
|
||||
│ ├── workflow_blindpath.py
|
||||
│ │ ├── yoloe_backend.py
|
||||
│ │ └── obstacle_detector_client.py
|
||||
│ ├── workflow_crossstreet.py
|
||||
│ └── trafficlight_detection.py
|
||||
├── yolomedia.py
|
||||
│ └── yoloe_backend.py
|
||||
├── asr_core.py
|
||||
├── omni_client.py
|
||||
├── audio_player.py
|
||||
├── audio_stream.py
|
||||
├── bridge_io.py
|
||||
└── sync_recorder.py
|
||||
```
|
||||
|
||||
## 🚀 启动流程
|
||||
|
||||
1. **初始化阶段** (`app_main.py`)
|
||||
- 加载环境变量
|
||||
- 加载导航模型(YOLO、MediaPipe)
|
||||
- 初始化音频系统
|
||||
- 启动录制系统
|
||||
- 预加载红绿灯模型
|
||||
|
||||
2. **服务启动** (FastAPI)
|
||||
- 注册 WebSocket 路由
|
||||
- 挂载静态文件
|
||||
- 启动 UDP 监听(IMU)
|
||||
- 启动 HTTP 服务(8081 端口)
|
||||
|
||||
3. **运行阶段**
|
||||
- 等待 ESP32 连接
|
||||
- 接收视频/音频/IMU 数据
|
||||
- 处理用户语音指令
|
||||
- 实时推送处理结果
|
||||
|
||||
4. **关闭阶段**
|
||||
- 停止录制(保存文件)
|
||||
- 关闭所有 WebSocket 连接
|
||||
- 释放模型资源
|
||||
- 清理临时文件
|
||||
|
||||
---
|
||||
|
||||
**提示**: 如需了解某个模块的详细实现,请查看相应源文件的注释和 docstring。
|
||||
|
||||
506
README.md
506
README.md
@@ -1,2 +1,506 @@
|
||||
# NaviGlassServer
|
||||
# AI 智能盲人眼镜系统 🤖👓
|
||||
|
||||
<div align="center">
|
||||
|
||||
一个面向视障人士的智能导航与辅助系统,集成了盲道导航、过马路辅助、物品识别、实时语音交互等功能。 本项目仅为交流学习使用,请勿直接给视障人群使用。本项目内仅包含代码,模型地址:https://www.modelscope.cn/models/archifancy/AIGlasses_for_navigation 。下载后存放在/model 文件夹
|
||||
|
||||
[功能特性](#功能特性) • [快速开始](#快速开始) • [系统架构](#系统架构) • [使用说明](#使用说明) • [开发文档](#开发文档)
|
||||
|
||||
</div>
|
||||
|
||||
---
|
||||
<img width="2481" height="3508" alt="1" src="https://github.com/user-attachments/assets/e8dec4a6-8fa6-4d94-bd66-4e9864b67daf" />
|
||||
<img width="2480" height="3508" alt="2" src="https://github.com/user-attachments/assets/bc7d1aac-a9e9-4ef8-9d67-224708d0c9fd" />
|
||||
<img width="2481" height="3508" alt="4" src="https://github.com/user-attachments/assets/6dd19750-57af-4560-a007-9a7059956b53" />
|
||||
|
||||
## 📋 目录
|
||||
|
||||
- [功能特性](#功能特性)
|
||||
- [系统要求](#系统要求)
|
||||
- [快速开始](#快速开始)
|
||||
- [系统架构](#系统架构)
|
||||
- [使用说明](#使用说明)
|
||||
- [配置说明](#配置说明)
|
||||
- [开发文档](#开发文档)
|
||||
|
||||
## ✨ 功能特性
|
||||
|
||||
### 🚶 盲道导航系统
|
||||
- **实时盲道检测**:基于 YOLO 分割模型实时识别盲道
|
||||
- **智能语音引导**:提供精准的方向指引(左转、右转、直行等)
|
||||
- **障碍物检测与避障**:自动识别前方障碍物并规划避障路线
|
||||
- **转弯检测**:自动识别急转弯并提前提醒
|
||||
- **光流稳定**:使用 Lucas-Kanade 光流算法稳定掩码,减少抖动
|
||||
|
||||
### 🚦 过马路辅助
|
||||
- **斑马线识别**:实时检测斑马线位置和方向
|
||||
- **红绿灯识别**:基于颜色和形状的红绿灯状态检测
|
||||
- **对齐引导**:引导用户对准斑马线中心
|
||||
- **安全提醒**:绿灯时语音提示可以通行
|
||||
|
||||
### 🔍 物品识别与查找
|
||||
- **智能物品搜索**:语音指令查找物品(如"帮我找一下红牛")
|
||||
- **实时目标追踪**:使用 YOLO-E 开放词汇检测 + ByteTrack 追踪
|
||||
- **手部引导**:结合 MediaPipe 手部检测,引导用户手部靠近物品
|
||||
- **抓取检测**:检测手部握持动作,确认物品已拿到
|
||||
- **多模态反馈**:视觉标注 + 语音引导 + 居中提示
|
||||
|
||||
### 🎙️ 实时语音交互
|
||||
- **语音识别(ASR)**:基于阿里云 DashScope Paraformer 实时语音识别
|
||||
- **多模态对话**:Qwen-Omni-Turbo 支持图像+文本输入,语音输出
|
||||
- **智能指令解析**:自动识别导航、查找、对话等不同类型指令
|
||||
- **上下文感知**:在不同模式下智能过滤无关指令
|
||||
|
||||
### 📹 视频与音频处理
|
||||
- **实时视频流**:WebSocket 推流,支持多客户端同时观看
|
||||
- **音视频同步录制**:自动保存带时间戳的录像和音频文件
|
||||
- **IMU 数据融合**:接收 ESP32 的 IMU 数据,支持姿态估计
|
||||
- **多路音频混音**:支持系统语音、AI 回复、环境音同时播放
|
||||
|
||||
### 🎨 可视化与交互
|
||||
- **Web 实时监控**:浏览器端实时查看处理后的视频流
|
||||
- **IMU 3D 可视化**:Three.js 实时渲染设备姿态
|
||||
- **状态面板**:显示导航状态、检测信息、FPS 等
|
||||
- **中文友好**:所有界面和语音使用中文,支持自定义字体
|
||||
|
||||
## 💻 系统要求
|
||||
|
||||
### 硬件要求
|
||||
- **开发/服务器端**:
|
||||
- CPU: Intel i5 或以上(推荐 i7/i9)
|
||||
- GPU: NVIDIA GPU(支持 CUDA 11.8+,推荐 RTX 3060 或以上)
|
||||
- 内存: 8GB RAM(推荐 16GB)
|
||||
- 存储: 10GB 可用空间
|
||||
|
||||
- **客户端设备**(可选):
|
||||
- ESP32-CAM 或其他支持 WebSocket 的摄像头
|
||||
- 麦克风(用于语音输入)
|
||||
- 扬声器/耳机(用于语音输出)
|
||||
|
||||
### 软件要求
|
||||
- **操作系统**: Windows 10/11, Linux (Ubuntu 20.04+), macOS 10.15+
|
||||
- **Python**: 3.9 - 3.11
|
||||
- **CUDA**: 11.8 或更高版本(GPU 加速必需)
|
||||
- **浏览器**: Chrome 90+, Firefox 88+, Edge 90+(用于 Web 监控)
|
||||
|
||||
### API 密钥
|
||||
- **阿里云 DashScope API Key**(必需):
|
||||
- 用于语音识别(ASR)和 Qwen-Omni 对话
|
||||
- 申请地址:https://dashscope.console.aliyun.com/
|
||||
|
||||
## 🚀 快速开始
|
||||
|
||||
### 1. 克隆项目
|
||||
|
||||
```bash
|
||||
git clone https://github.com/yourusername/aiglass.git
|
||||
cd aiglass/rebuild1002
|
||||
```
|
||||
|
||||
### 2. 安装依赖
|
||||
|
||||
#### 创建虚拟环境(推荐)
|
||||
```bash
|
||||
python -m venv venv
|
||||
# Windows
|
||||
venv\Scripts\activate
|
||||
# Linux/macOS
|
||||
source venv/bin/activate
|
||||
```
|
||||
|
||||
#### 安装 Python 包
|
||||
```bash
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
#### 安装 CUDA 和 cuDNN(GPU 加速)
|
||||
请参考 [NVIDIA CUDA Toolkit 安装指南](https://developer.nvidia.com/cuda-downloads)
|
||||
|
||||
### 3. 下载模型文件
|
||||
|
||||
将以下模型文件放入 `model/` 目录:
|
||||
|
||||
| 模型文件 | 用途 | 大小 | 下载链接 |
|
||||
|---------|------|------|---------|
|
||||
| `yolo-seg.pt` | 盲道分割 | ~50MB | [待补充] |
|
||||
| `yoloe-11l-seg.pt` | 开放词汇检测 | ~80MB | [待补充] |
|
||||
| `shoppingbest5.pt` | 物品识别 | ~30MB | [待补充] |
|
||||
| `trafficlight.pt` | 红绿灯检测 | ~20MB | [待补充] |
|
||||
| `hand_landmarker.task` | 手部检测 | ~15MB | [MediaPipe Models](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker#models) |
|
||||
|
||||
### 4. 配置 API 密钥
|
||||
|
||||
创建 `.env` 文件:
|
||||
|
||||
```bash
|
||||
# .env
|
||||
DASHSCOPE_API_KEY=your_api_key_here
|
||||
```
|
||||
|
||||
或在代码中直接修改(不推荐):
|
||||
```python
|
||||
# app_main.py, line 50
|
||||
API_KEY = "your_api_key_here"
|
||||
```
|
||||
|
||||
### 5. 启动系统
|
||||
|
||||
```bash
|
||||
python app_main.py
|
||||
```
|
||||
|
||||
系统将在 `http://0.0.0.0:8081` 启动,打开浏览器访问即可看到实时监控界面。
|
||||
|
||||
### 6. 连接设备(可选)
|
||||
|
||||
如果使用 ESP32-CAM,请:
|
||||
1. 烧录 `compile/compile.ino` 到 ESP32
|
||||
2. 修改 WiFi 配置,连接到同一网络
|
||||
3. ESP32 自动连接到 WebSocket 端点
|
||||
|
||||
## 🏗️ 系统架构
|
||||
|
||||
### 整体架构
|
||||
|
||||
```
|
||||
┌─────────────────────────────────────────────────────────────┐
|
||||
│ 客户端层 │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ ESP32-CAM │ │ 浏览器 │ │ 移动端 │ │
|
||||
│ │ (视频/音频) │ │ (监控界面) │ │ (语音控制) │ │
|
||||
│ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │
|
||||
└─────────┼──────────────────┼──────────────────┼─────────────┘
|
||||
│ WebSocket │ HTTP/WS │ WebSocket
|
||||
┌─────────┼──────────────────┼──────────────────┼─────────────┐
|
||||
│ │ │ │ │
|
||||
│ ┌────▼──────────────────▼──────────────────▼────────┐ │
|
||||
│ │ FastAPI 主服务 (app_main.py) │ │
|
||||
│ │ - WebSocket 路由管理 │ │
|
||||
│ │ - 音视频流分发 │ │
|
||||
│ │ - 状态管理与协调 │ │
|
||||
│ └────┬────────────────┬────────────────┬─────────────┘ │
|
||||
│ │ │ │ │
|
||||
│ ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ │
|
||||
│ │ ASR 模块 │ │ Omni 对话 │ │ 音频播放 │ │
|
||||
│ │ (asr_core) │ │(omni_client)│ │(audio_player)│ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
│ │
|
||||
│ 应用层 │
|
||||
└───────────────────────────────────────────────────────────────┘
|
||||
│ │ │
|
||||
┌─────────▼──────────────────▼──────────────────▼──────────────┐
|
||||
│ 导航统领层 │
|
||||
│ ┌─────────────────────────────────────────────────┐ │
|
||||
│ │ NavigationMaster (navigation_master.py) │ │
|
||||
│ │ - 状态机:IDLE/CHAT/BLINDPATH_NAV/ │ │
|
||||
│ │ CROSSING/TRAFFIC_LIGHT/ITEM_SEARCH │ │
|
||||
│ │ - 模式切换与协调 │ │
|
||||
│ └───┬─────────────────────┬───────────────────┬───┘ │
|
||||
│ │ │ │ │
|
||||
│ ┌────▼────────┐ ┌────────▼────────┐ ┌─────▼──────┐ │
|
||||
│ │ 盲道导航 │ │ 过马路导航 │ │ 物品查找 │ │
|
||||
│ │(blindpath) │ │ (crossstreet) │ │(yolomedia) │ │
|
||||
│ └──────────────┘ └──────────────────┘ └─────────────┘ │
|
||||
└───────────────────────────────────────────────────────────────┘
|
||||
│ │ │
|
||||
┌─────────▼──────────────────▼──────────────────▼──────────────┐
|
||||
│ 模型推理层 │
|
||||
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ YOLO 分割 │ │ YOLO-E 检测 │ │ MediaPipe │ │
|
||||
│ │ (盲道/斑马线) │ │ (开放词汇) │ │ (手部检测) │ │
|
||||
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||
│ ┌──────────────┐ ┌──────────────┐ │
|
||||
│ │ 红绿灯检测 │ │ 光流稳定 │ │
|
||||
│ │(HSV+YOLO) │ │(Lucas-Kanade)│ │
|
||||
│ └──────────────┘ └──────────────┘ │
|
||||
└───────────────────────────────────────────────────────────────┘
|
||||
│
|
||||
┌─────────▼─────────────────────────────────────────────────────┐
|
||||
│ 外部服务层 │
|
||||
│ ┌──────────────────────────────────────────────┐ │
|
||||
│ │ 阿里云 DashScope API │ │
|
||||
│ │ - Paraformer ASR (实时语音识别) │ │
|
||||
│ │ - Qwen-Omni-Turbo (多模态对话) │ │
|
||||
│ │ - Qwen-Turbo (标签提取) │ │
|
||||
│ └──────────────────────────────────────────────┘ │
|
||||
└───────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
### 核心模块说明
|
||||
|
||||
| 模块 | 文件 | 功能 |
|
||||
|------|------|------|
|
||||
| **主应用** | `app_main.py` | FastAPI 服务、WebSocket 管理、状态协调 |
|
||||
| **导航统领** | `navigation_master.py` | 状态机管理、模式切换、语音节流 |
|
||||
| **盲道导航** | `workflow_blindpath.py` | 盲道检测、避障、转弯引导 |
|
||||
| **过马路导航** | `workflow_crossstreet.py` | 斑马线检测、红绿灯识别、对齐引导 |
|
||||
| **物品查找** | `yolomedia.py` | 物品检测、手部引导、抓取确认 |
|
||||
| **语音识别** | `asr_core.py` | 实时 ASR、VAD、指令解析 |
|
||||
| **语音合成** | `omni_client.py` | Qwen-Omni 流式语音生成 |
|
||||
| **音频播放** | `audio_player.py` | 多路混音、TTS 播放、音量控制 |
|
||||
| **视频录制** | `sync_recorder.py` | 音视频同步录制 |
|
||||
| **桥接 IO** | `bridge_io.py` | 线程安全的帧缓冲与分发 |
|
||||
|
||||
## 📖 使用说明
|
||||
|
||||
### 语音指令
|
||||
|
||||
系统支持以下语音指令(说话时无需唤醒词):
|
||||
|
||||
#### 导航控制
|
||||
```
|
||||
"开始导航" / "盲道导航" → 启动盲道导航
|
||||
"停止导航" / "结束导航" → 停止盲道导航
|
||||
"开始过马路" / "帮我过马路" → 启动过马路模式
|
||||
"过马路结束" / "结束过马路" → 停止过马路模式
|
||||
```
|
||||
|
||||
#### 红绿灯检测
|
||||
```
|
||||
"检测红绿灯" / "看红绿灯" → 启动红绿灯检测
|
||||
"停止检测" / "停止红绿灯" → 停止检测
|
||||
```
|
||||
|
||||
#### 物品查找
|
||||
```
|
||||
"帮我找一下 [物品名]" → 启动物品搜索
|
||||
示例:
|
||||
- "帮我找一下红牛"
|
||||
- "找一下AD钙奶"
|
||||
- "帮我找矿泉水"
|
||||
"找到了" / "拿到了" → 确认找到物品
|
||||
```
|
||||
|
||||
#### 智能对话
|
||||
```
|
||||
"帮我看看这是什么" → 拍照识别
|
||||
"这个东西能吃吗" → 物品咨询
|
||||
任何其他问题 → AI 对话
|
||||
```
|
||||
|
||||
### 导航状态说明
|
||||
|
||||
系统包含以下主要状态(自动切换):
|
||||
|
||||
1. **IDLE** - 空闲状态
|
||||
- 等待用户指令
|
||||
- 显示原始视频流
|
||||
|
||||
2. **CHAT** - 对话模式
|
||||
- 与 AI 进行多模态对话
|
||||
- 暂停导航功能
|
||||
|
||||
3. **BLINDPATH_NAV** - 盲道导航
|
||||
- **ONBOARDING**: 上盲道引导
|
||||
- ROTATION: 旋转对准盲道
|
||||
- TRANSLATION: 平移至盲道中心
|
||||
- **NAVIGATING**: 沿盲道行走
|
||||
- 实时方向修正
|
||||
- 障碍物检测
|
||||
- **MANEUVERING_TURN**: 转弯处理
|
||||
- **AVOIDING_OBSTACLE**: 避障
|
||||
|
||||
4. **CROSSING** - 过马路模式
|
||||
- **SEEKING_CROSSWALK**: 寻找斑马线
|
||||
- **WAIT_TRAFFIC_LIGHT**: 等待绿灯
|
||||
- **CROSSING**: 过马路中
|
||||
- **SEEKING_NEXT_BLINDPATH**: 寻找对面盲道
|
||||
|
||||
5. **ITEM_SEARCH** - 物品查找
|
||||
- 实时检测目标物品
|
||||
- 引导手部靠近
|
||||
- 确认抓取
|
||||
|
||||
6. **TRAFFIC_LIGHT_DETECTION** - 红绿灯检测
|
||||
- 实时检测红绿灯状态
|
||||
- 语音播报颜色变化
|
||||
|
||||
### Web 监控界面
|
||||
|
||||
打开浏览器访问 `http://localhost:8081`,可以看到:
|
||||
|
||||
- **实时视频流**:显示处理后的视频,包括导航标注
|
||||
- **状态面板**:当前模式、检测信息、FPS
|
||||
- **IMU 可视化**:设备姿态 3D 实时渲染
|
||||
- **语音识别结果**:显示识别的文字和 AI 回复
|
||||
|
||||
### WebSocket 端点
|
||||
|
||||
| 端点 | 用途 | 数据格式 |
|
||||
|------|------|---------|
|
||||
| `/ws/camera` | ESP32 相机推流 | Binary (JPEG) |
|
||||
| `/ws/viewer` | 浏览器订阅视频 | Binary (JPEG) |
|
||||
| `/ws_audio` | ESP32 音频上传 | Binary (PCM16) |
|
||||
| `/ws_ui` | UI 状态推送 | JSON |
|
||||
| `/ws` | IMU 数据接收 | JSON |
|
||||
| `/stream.wav` | 音频下载流 | Binary (WAV) |
|
||||
|
||||
## ⚙️ 配置说明
|
||||
|
||||
### 环境变量
|
||||
|
||||
创建 `.env` 文件配置以下参数:
|
||||
|
||||
```bash
|
||||
# 阿里云 API
|
||||
DASHSCOPE_API_KEY=sk-xxxxx
|
||||
|
||||
# 模型路径(可选,使用默认路径可不配置)
|
||||
BLIND_PATH_MODEL=model/yolo-seg.pt
|
||||
OBSTACLE_MODEL=model/yoloe-11l-seg.pt
|
||||
YOLOE_MODEL_PATH=model/yoloe-11l-seg.pt
|
||||
|
||||
# 导航参数
|
||||
AIGLASS_MASK_MIN_AREA=1500 # 最小掩码面积
|
||||
AIGLASS_MASK_MORPH=3 # 形态学核大小
|
||||
AIGLASS_MASK_MISS_TTL=6 # 掩码丢失容忍帧数
|
||||
AIGLASS_PANEL_SCALE=0.65 # 数据面板缩放
|
||||
|
||||
# 音频配置
|
||||
TTS_INTERVAL_SEC=1.0 # 语音播报间隔
|
||||
ENABLE_TTS=true # 启用语音播报
|
||||
```
|
||||
|
||||
### 修改模型路径
|
||||
|
||||
如果模型文件不在默认位置,可以在相应文件中修改:
|
||||
|
||||
```python
|
||||
# workflow_blindpath.py
|
||||
seg_model_path = "your/custom/path/yolo-seg.pt"
|
||||
|
||||
# yolomedia.py
|
||||
YOLO_MODEL_PATH = "your/custom/path/shoppingbest5.pt"
|
||||
HAND_TASK_PATH = "your/custom/path/hand_landmarker.task"
|
||||
```
|
||||
|
||||
### 调整性能参数
|
||||
|
||||
根据硬件性能调整:
|
||||
|
||||
```python
|
||||
# yolomedia.py
|
||||
HAND_DOWNSCALE = 0.8 # 手部检测降采样(越小越快,精度降低)
|
||||
HAND_FPS_DIV = 1 # 手部检测抽帧(2=隔帧,3=每3帧)
|
||||
|
||||
# workflow_blindpath.py
|
||||
FEATURE_PARAMS = dict(
|
||||
maxCorners=600, # 光流特征点数(越少越快)
|
||||
qualityLevel=0.001, # 特征点质量
|
||||
minDistance=5 # 特征点最小间距
|
||||
)
|
||||
```
|
||||
|
||||
## 🛠️ 开发文档
|
||||
|
||||
### 添加新的语音指令
|
||||
|
||||
1. 在 `app_main.py` 的 `start_ai_with_text_custom()` 函数中添加:
|
||||
|
||||
```python
|
||||
# 检查新指令
|
||||
if "新指令关键词" in user_text:
|
||||
# 执行自定义逻辑
|
||||
print("[CUSTOM] 新指令被触发")
|
||||
await ui_broadcast_final("[系统] 新功能已启动")
|
||||
return
|
||||
```
|
||||
|
||||
2. 如需修改指令过滤规则:
|
||||
|
||||
```python
|
||||
# 修改 allowed_keywords 列表
|
||||
allowed_keywords = ["帮我看", "帮我找", "你的新关键词"]
|
||||
```
|
||||
|
||||
### 扩展导航功能
|
||||
|
||||
1. 在 `workflow_blindpath.py` 添加新状态:
|
||||
|
||||
```python
|
||||
# 在 BlindPathNavigator.__init__() 中初始化
|
||||
self.your_new_state_var = False
|
||||
|
||||
# 在 process_frame() 中处理
|
||||
def process_frame(self, image):
|
||||
if self.your_new_state_var:
|
||||
# 自定义处理逻辑
|
||||
guidance_text = "新状态引导"
|
||||
# ...
|
||||
```
|
||||
|
||||
2. 在 `navigation_master.py` 添加状态机状态:
|
||||
|
||||
```python
|
||||
class NavigationMaster:
|
||||
def start_your_new_mode(self):
|
||||
self.state = "YOUR_NEW_MODE"
|
||||
# 初始化逻辑
|
||||
```
|
||||
|
||||
### 集成新模型
|
||||
|
||||
1. 创建模型包装类:
|
||||
|
||||
```python
|
||||
# your_model_wrapper.py
|
||||
class YourModelWrapper:
|
||||
def __init__(self, model_path):
|
||||
self.model = load_your_model(model_path)
|
||||
|
||||
def detect(self, image):
|
||||
# 推理逻辑
|
||||
return results
|
||||
```
|
||||
|
||||
2. 在 `app_main.py` 中加载:
|
||||
|
||||
```python
|
||||
your_model = YourModelWrapper("model/your_model.pt")
|
||||
```
|
||||
|
||||
3. 在相应的工作流中调用:
|
||||
|
||||
```python
|
||||
results = your_model.detect(image)
|
||||
```
|
||||
|
||||
### 调试技巧
|
||||
|
||||
1. **启用详细日志**:
|
||||
|
||||
```python
|
||||
# app_main.py 顶部
|
||||
import logging
|
||||
logging.basicConfig(level=logging.DEBUG)
|
||||
```
|
||||
|
||||
2. **查看帧率瓶颈**:
|
||||
|
||||
```python
|
||||
# yolomedia.py
|
||||
PERF_DEBUG = True # 打印处理时间
|
||||
```
|
||||
|
||||
3. **测试单个模块**:
|
||||
|
||||
```bash
|
||||
# 测试盲道导航
|
||||
python test_cross_street_blindpath.py
|
||||
|
||||
# 测试红绿灯检测
|
||||
python test_traffic_light.py
|
||||
|
||||
# 测试录制功能
|
||||
python test_recorder.py
|
||||
```
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
## 📄 许可证
|
||||
|
||||
本项目采用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件
|
||||
|
||||
|
||||
|
||||
154
ai_voice_pipeline.py
Normal file
154
ai_voice_pipeline.py
Normal file
@@ -0,0 +1,154 @@
|
||||
# ai_voice_pipeline.py
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
AI 语音交互管道 - Day 21
|
||||
|
||||
整合 SenseVoice + GLM-4.5-Flash + EdgeTTS
|
||||
|
||||
流程:
|
||||
1. 客户端 VAD 检测语音结束
|
||||
2. 发送完整音频到服务器
|
||||
3. SenseVoice 识别 → GLM 生成回复 → EdgeTTS 合成语音
|
||||
4. 流式返回 PCM 音频
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Optional, Callable, AsyncGenerator
|
||||
|
||||
# 导入各模块
|
||||
from sensevoice_asr import recognize as asr_recognize, init_sensevoice
|
||||
from glm_client import chat as llm_chat, chat_stream as llm_chat_stream
|
||||
from edge_tts_client import (
|
||||
text_to_speech_pcm_stream,
|
||||
text_to_speech_pcm,
|
||||
DEFAULT_VOICE,
|
||||
)
|
||||
|
||||
|
||||
async def init_pipeline():
    """Warm up the AI voice pipeline (call once at server startup)."""
    # SenseVoice is the only component here needing explicit initialization;
    # the LLM and TTS clients are stateless per call.
    await init_sensevoice()
    print("[AI Pipeline] 初始化完成")
|
||||
|
||||
async def process_voice(
    pcm_audio: bytes,
    image_base64: Optional[str] = None,
    on_text: Optional[Callable[[str], None]] = None,
    on_audio: Optional[Callable[[bytes], None]] = None,
) -> str:
    """Run the full voice round-trip: ASR -> LLM reply -> streamed TTS.

    Args:
        pcm_audio: PCM16 audio data (16 kHz, mono).
        image_base64: optional image for multimodal input.
        on_text: callback invoked with display text for the UI.
        on_audio: callback receiving synthesized PCM chunks for playback.

    Returns:
        The AI reply text, or "" when nothing was recognized or generated.
    """
    # Step 1: speech recognition.
    user_text = await asr_recognize(pcm_audio)
    if not user_text:
        print("[AI Pipeline] 未识别到有效语音")
        return ""

    print(f"[AI Pipeline] 用户说: {user_text}")
    if on_text:
        on_text(f"用户: {user_text}")

    # Step 2: LLM reply generation.
    ai_response = await llm_chat(user_text, image_base64)
    if not ai_response:
        print("[AI Pipeline] AI 无回复")
        return ""

    print(f"[AI Pipeline] AI 回复: {ai_response}")
    if on_text:
        on_text(f"AI: {ai_response}")

    # Step 3: synthesize and stream audio only when a sink was provided.
    if on_audio:
        async for chunk in text_to_speech_pcm_stream(ai_response):
            on_audio(chunk)

    return ai_response
|
||||
|
||||
|
||||
async def process_voice_stream(
|
||||
pcm_audio: bytes,
|
||||
image_base64: Optional[str] = None,
|
||||
) -> AsyncGenerator[tuple, None]:
|
||||
"""
|
||||
流式处理语音输入
|
||||
|
||||
Args:
|
||||
pcm_audio: PCM16 音频数据
|
||||
image_base64: 可选的图片
|
||||
|
||||
Yields:
|
||||
("text", str) - 文本片段
|
||||
("audio", bytes) - 音频片段
|
||||
"""
|
||||
# 1. 语音识别
|
||||
user_text = await asr_recognize(pcm_audio)
|
||||
|
||||
if not user_text:
|
||||
return
|
||||
|
||||
yield ("user_text", user_text)
|
||||
|
||||
# 2. LLM 流式生成 + 3. TTS 流式合成
|
||||
# 收集一定长度的文本后送 TTS
|
||||
buffer = ""
|
||||
punctuation = "。,!?;:,.!?;:"
|
||||
|
||||
async for text_chunk in llm_chat_stream(user_text, image_base64):
|
||||
yield ("ai_text", text_chunk)
|
||||
buffer += text_chunk
|
||||
|
||||
# 遇到标点时合成音频
|
||||
if buffer and buffer[-1] in punctuation:
|
||||
async for audio_chunk in text_to_speech_pcm_stream(buffer):
|
||||
yield ("audio", audio_chunk)
|
||||
buffer = ""
|
||||
|
||||
# 处理剩余文本
|
||||
if buffer.strip():
|
||||
async for audio_chunk in text_to_speech_pcm_stream(buffer):
|
||||
yield ("audio", audio_chunk)
|
||||
|
||||
|
||||
async def text_to_voice(text: str) -> bytes:
|
||||
"""
|
||||
文本转语音(用于导航提示等)
|
||||
|
||||
Args:
|
||||
text: 要合成的文本
|
||||
|
||||
Returns:
|
||||
PCM16 音频数据
|
||||
"""
|
||||
return await text_to_speech_pcm(text)
|
||||
|
||||
|
||||
async def text_to_voice_stream(text: str) -> AsyncGenerator[bytes, None]:
|
||||
"""
|
||||
流式文本转语音
|
||||
|
||||
Args:
|
||||
text: 要合成的文本
|
||||
|
||||
Yields:
|
||||
PCM16 音频块
|
||||
"""
|
||||
async for chunk in text_to_speech_pcm_stream(text):
|
||||
yield chunk
|
||||
1868
app_main.py
Normal file
1868
app_main.py
Normal file
File diff suppressed because it is too large
Load Diff
221
asr_core.py
Normal file
221
asr_core.py
Normal file
@@ -0,0 +1,221 @@
|
||||
# asr_core.py
|
||||
# -*- coding: utf-8 -*-
|
||||
import os, json, asyncio
|
||||
from typing import Any, Dict, List, Optional, Callable, Tuple
|
||||
|
||||
ASR_DEBUG_RAW = os.getenv("ASR_DEBUG_RAW", "0") == "1"
|
||||
|
||||
def _shorten(s: str, limit: int = 200) -> str:
|
||||
if not s:
|
||||
return ""
|
||||
return s if len(s) <= limit else (s[:limit] + "…")
|
||||
|
||||
def _safe_to_dict(x: Any) -> Dict[str, Any]:
|
||||
if isinstance(x, dict): return x
|
||||
for attr in ("to_dict", "model_dump", "__dict__"):
|
||||
try:
|
||||
v = getattr(x, attr, None)
|
||||
except Exception:
|
||||
v = None
|
||||
if callable(v):
|
||||
try:
|
||||
d = v()
|
||||
if isinstance(d, dict): return d
|
||||
except Exception:
|
||||
pass
|
||||
elif isinstance(v, dict):
|
||||
return v
|
||||
try:
|
||||
s = str(x)
|
||||
if s and s.lstrip().startswith("{") and s.rstrip().endswith("}"):
|
||||
return json.loads(s)
|
||||
except Exception:
|
||||
pass
|
||||
return {"_raw": str(x)}
|
||||
|
||||
def _extract_sentence(event_obj: Any) -> Tuple[Optional[str], Optional[bool]]:
|
||||
d = _safe_to_dict(event_obj)
|
||||
cands: List[Dict[str, Any]] = [d]
|
||||
for k in ("output", "data", "result"):
|
||||
v = d.get(k)
|
||||
if isinstance(v, dict):
|
||||
cands.append(v)
|
||||
for obj in cands:
|
||||
sent = obj.get("sentence")
|
||||
if isinstance(sent, dict):
|
||||
text = sent.get("text")
|
||||
is_end = sent.get("sentence_end")
|
||||
if is_end is not None:
|
||||
is_end = bool(is_end)
|
||||
return text, is_end
|
||||
for obj in cands:
|
||||
if "text" in obj and isinstance(obj.get("text"), str):
|
||||
return obj.get("text"), None
|
||||
return None, None
|
||||
|
||||
# ====== 仅热词触发的“全清零复位”配置 ======
|
||||
INTERRUPT_KEYWORDS = set(
|
||||
os.getenv("INTERRUPT_KEYWORDS", "停下,别说了,停止").split(",")
|
||||
)
|
||||
|
||||
# Day 21: 导航控制白名单 - 这些命令包含热词但不应触发重置
|
||||
# 例如"停止导航"包含"停止",但不应该触发 full_system_reset
|
||||
NAV_CONTROL_WHITELIST = [
|
||||
"停止导航", "结束导航", "停止检测", "停止红绿灯",
|
||||
"开始导航", "盲道导航", "开始过马路", "过马路结束",
|
||||
"帮我导航", "帮我过马路"
|
||||
]
|
||||
|
||||
|
||||
def _normalize_cn(s: str) -> str:
|
||||
try:
|
||||
import unicodedata
|
||||
s = "".join(" " if unicodedata.category(ch) == "Zs" else ch for ch in s)
|
||||
s = s.strip().lower()
|
||||
except Exception:
|
||||
s = (s or "").strip().lower()
|
||||
return s
|
||||
|
||||
# ============ ASR 全局总闸 ============
|
||||
_current_recognition: Optional[object] = None
|
||||
_rec_lock = asyncio.Lock()
|
||||
|
||||
async def set_current_recognition(r):
|
||||
global _current_recognition
|
||||
async with _rec_lock:
|
||||
_current_recognition = r
|
||||
|
||||
async def stop_current_recognition():
|
||||
global _current_recognition
|
||||
async with _rec_lock:
|
||||
r = _current_recognition
|
||||
_current_recognition = None
|
||||
if r:
|
||||
try:
|
||||
r.stop() # DashScope SDK 的实时识别停止
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ============ ASR 回调 ============
|
||||
class ASRCallback:
|
||||
"""
|
||||
设计目标:
|
||||
1) “停下 / 别说了 …”等热词一出现 → 立刻全清零复位(恢复到刚启动后的状态)。
|
||||
2) 除此之外【不接受打断】;AI 正在播报时,用户说话只做展示,不触发新一轮。
|
||||
3) 不再用 partial 叠加字符串;partial 只用于 UI 临时展示;只有 final sentence 用于驱动 AI。
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
on_sdk_error: Callable[[str], None],
|
||||
post: Callable[[asyncio.Future], None],
|
||||
ui_broadcast_partial,
|
||||
ui_broadcast_final,
|
||||
is_playing_now_fn: Callable[[], bool],
|
||||
start_ai_with_text_fn, # async (text)
|
||||
full_system_reset_fn, # async (reason)
|
||||
interrupt_lock: asyncio.Lock,
|
||||
):
|
||||
self._on_sdk_error = on_sdk_error
|
||||
self._post = post
|
||||
self._last_partial_for_ui: str = "" # 只用于 UI 展示
|
||||
self._last_final_text: str = "" # 以句末 final 为准
|
||||
self._hot_interrupted: bool = False # 本句是否因热词触发过复位(防抖)
|
||||
|
||||
self._ui_partial = ui_broadcast_partial
|
||||
self._ui_final = ui_broadcast_final
|
||||
self._is_playing = is_playing_now_fn
|
||||
self._start_ai = start_ai_with_text_fn
|
||||
self._full_reset = full_system_reset_fn
|
||||
self._interrupt_lock = interrupt_lock
|
||||
|
||||
def on_open(self): pass
|
||||
def on_close(self): pass
|
||||
def on_complete(self): pass
|
||||
|
||||
def on_error(self, err):
|
||||
try:
|
||||
self._post(self._ui_partial(""))
|
||||
self._on_sdk_error(str(err))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def on_result(self, result): self._handle(result)
|
||||
def on_event(self, event): self._handle(event)
|
||||
|
||||
def _has_hotword(self, text: str) -> bool:
|
||||
"""Day 21 修复: 检查是否包含热词,但排除导航控制命令"""
|
||||
t = _normalize_cn(text)
|
||||
if not t: return False
|
||||
|
||||
# Day 21: 先检查是否是导航控制命令(白名单),如果是则不触发热词
|
||||
for nav_cmd in NAV_CONTROL_WHITELIST:
|
||||
if _normalize_cn(nav_cmd) in t:
|
||||
return False # 导航命令不触发热词重置
|
||||
|
||||
# 检查热词
|
||||
for w in INTERRUPT_KEYWORDS:
|
||||
if w and _normalize_cn(w) in t:
|
||||
return True
|
||||
return False
|
||||
|
||||
def _handle(self, event: Any):
|
||||
if ASR_DEBUG_RAW:
|
||||
try:
|
||||
rawd = _safe_to_dict(event)
|
||||
print("[ASR EVENT RAW]", json.dumps(rawd, ensure_ascii=False), flush=True)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
text, is_end = _extract_sentence(event)
|
||||
if text is None:
|
||||
return
|
||||
text = text.strip()
|
||||
if not text:
|
||||
return
|
||||
|
||||
# ---------- ① 热词优先:命中就全清零并短路,绝不送 LLM ----------
|
||||
if not self._hot_interrupted and self._has_hotword(text):
|
||||
self._hot_interrupted = True
|
||||
|
||||
async def _hot_reset():
|
||||
async with self._interrupt_lock:
|
||||
print(f"[ASR HOTWORD] '{text}' -> FULL RESET, skip LLM", flush=True)
|
||||
await self._full_reset("Hotword interrupt")
|
||||
try:
|
||||
self._post(_hot_reset())
|
||||
except Exception:
|
||||
pass
|
||||
return
|
||||
|
||||
# ---------- ② partial:仅用于 UI 展示 ----------
|
||||
self._last_partial_for_ui = text
|
||||
try:
|
||||
print(f"[ASR PARTIAL] len={len(text)} text='{_shorten(text)}'", flush=True)
|
||||
self._post(self._ui_partial(self._last_partial_for_ui))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# ---------- ③ final:仅 final 驱动 LLM(若未在播报) ----------
|
||||
if is_end is True:
|
||||
final_text = text
|
||||
try:
|
||||
print(f"[ASR FINAL] len={len(final_text)} text='{final_text}'", flush=True)
|
||||
self._post(self._ui_final(final_text))
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
if (not self._is_playing()) and final_text:
|
||||
async def _run_final():
|
||||
async with self._interrupt_lock:
|
||||
print(f"[LLM INPUT TEXT] {final_text}", flush=True)
|
||||
await self._start_ai(final_text)
|
||||
try:
|
||||
self._post(_run_final())
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# 复位进入下一句
|
||||
self._last_partial_for_ui = ""
|
||||
self._last_final_text = ""
|
||||
self._hot_interrupted = False
|
||||
439
audio_compressor.py
Normal file
439
audio_compressor.py
Normal file
@@ -0,0 +1,439 @@
|
||||
# audio_compressor.py
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
音频压缩工具 - 用于减少网络带宽占用
|
||||
支持将16kHz 16bit PCM压缩为更小的格式
|
||||
"""
|
||||
import os
|
||||
import wave
|
||||
import struct
|
||||
import numpy as np
|
||||
from typing import Optional, Tuple
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class AudioCompressor:
|
||||
"""音频压缩器 - 支持多种压缩算法"""
|
||||
|
||||
@staticmethod
|
||||
def pcm16_to_ulaw(pcm_data: bytes) -> bytes:
|
||||
"""
|
||||
将16位PCM转换为8位μ-law
|
||||
压缩率:50%(16bit -> 8bit)
|
||||
"""
|
||||
# 解析16位PCM
|
||||
samples = np.frombuffer(pcm_data, dtype=np.int16)
|
||||
|
||||
# μ-law压缩
|
||||
ulaw_data = bytearray()
|
||||
for sample in samples:
|
||||
ulaw_byte = AudioCompressor._linear_to_ulaw(sample)
|
||||
ulaw_data.append(ulaw_byte)
|
||||
|
||||
return bytes(ulaw_data)
|
||||
|
||||
@staticmethod
|
||||
def ulaw_to_pcm16(ulaw_data: bytes) -> bytes:
|
||||
"""
|
||||
将8位μ-law转换回16位PCM
|
||||
"""
|
||||
pcm_samples = []
|
||||
for ulaw_byte in ulaw_data:
|
||||
pcm_sample = AudioCompressor._ulaw_to_linear(ulaw_byte)
|
||||
pcm_samples.append(pcm_sample)
|
||||
|
||||
return np.array(pcm_samples, dtype=np.int16).tobytes()
|
||||
|
||||
@staticmethod
|
||||
def _linear_to_ulaw(sample: int) -> int:
|
||||
"""
|
||||
16位线性PCM转μ-law
|
||||
"""
|
||||
# μ-law编码表
|
||||
ULAW_MAX = 0x1FFF
|
||||
ULAW_BIAS = 0x84
|
||||
|
||||
# 限制范围
|
||||
sample = max(-32768, min(32767, sample))
|
||||
|
||||
# 获取符号位
|
||||
sign = 0
|
||||
if sample < 0:
|
||||
sign = 0x80
|
||||
sample = -sample
|
||||
|
||||
# 添加偏置
|
||||
sample = sample + ULAW_BIAS
|
||||
|
||||
# 限制最大值
|
||||
if sample > ULAW_MAX:
|
||||
sample = ULAW_MAX
|
||||
|
||||
# 查找指数和尾数
|
||||
exponent = 7
|
||||
for exp in range(7, -1, -1):
|
||||
if sample & (0x4000 >> exp):
|
||||
exponent = exp
|
||||
break
|
||||
|
||||
mantissa = (sample >> (exponent + 3)) & 0x0F
|
||||
ulawbyte = ~(sign | (exponent << 4) | mantissa) & 0xFF
|
||||
|
||||
return ulawbyte
|
||||
|
||||
@staticmethod
|
||||
def _ulaw_to_linear(ulawbyte: int) -> int:
|
||||
"""
|
||||
μ-law转16位线性PCM
|
||||
"""
|
||||
ULAW_BIAS = 0x84
|
||||
|
||||
ulawbyte = ~ulawbyte & 0xFF
|
||||
sign = ulawbyte & 0x80
|
||||
exponent = (ulawbyte >> 4) & 0x07
|
||||
mantissa = ulawbyte & 0x0F
|
||||
|
||||
sample = ((mantissa << 3) + ULAW_BIAS) << exponent
|
||||
|
||||
if sign:
|
||||
sample = -sample
|
||||
|
||||
return sample
|
||||
|
||||
@staticmethod
|
||||
def pcm16_to_adpcm(pcm_data: bytes) -> bytes:
|
||||
"""
|
||||
将16位PCM转换为4位ADPCM
|
||||
压缩率:75%(16bit -> 4bit)
|
||||
保持较好的语音质量
|
||||
"""
|
||||
samples = np.frombuffer(pcm_data, dtype=np.int16)
|
||||
|
||||
# IMA ADPCM 步长表
|
||||
step_table = [
|
||||
7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
|
||||
19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
|
||||
50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
|
||||
130, 143, 157, 173, 190, 209, 230, 253, 279, 307,
|
||||
337, 371, 408, 449, 494, 544, 598, 658, 724, 796,
|
||||
876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
|
||||
2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358,
|
||||
5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
|
||||
15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767
|
||||
]
|
||||
|
||||
# 索引调整表
|
||||
index_table = [-1, -1, -1, -1, 2, 4, 6, 8]
|
||||
|
||||
# 初始化
|
||||
adpcm_data = bytearray()
|
||||
predicted = 0
|
||||
step_index = 0
|
||||
|
||||
# 每两个样本打包成一个字节
|
||||
for i in range(0, len(samples), 2):
|
||||
byte = 0
|
||||
|
||||
for j in range(2):
|
||||
if i + j < len(samples):
|
||||
sample = samples[i + j]
|
||||
|
||||
# 计算差值
|
||||
diff = sample - predicted
|
||||
|
||||
# 量化
|
||||
step = step_table[step_index]
|
||||
adpcm_sample = 0
|
||||
|
||||
if diff < 0:
|
||||
adpcm_sample = 8
|
||||
diff = -diff
|
||||
|
||||
if diff >= step:
|
||||
adpcm_sample |= 4
|
||||
diff -= step
|
||||
|
||||
step >>= 1
|
||||
if diff >= step:
|
||||
adpcm_sample |= 2
|
||||
diff -= step
|
||||
|
||||
step >>= 1
|
||||
if diff >= step:
|
||||
adpcm_sample |= 1
|
||||
|
||||
# 更新预测值
|
||||
step = step_table[step_index]
|
||||
diff = 0
|
||||
if adpcm_sample & 4:
|
||||
diff += step
|
||||
step >>= 1
|
||||
if adpcm_sample & 2:
|
||||
diff += step
|
||||
step >>= 1
|
||||
if adpcm_sample & 1:
|
||||
diff += step
|
||||
step >>= 1
|
||||
diff += step
|
||||
|
||||
if adpcm_sample & 8:
|
||||
predicted -= diff
|
||||
else:
|
||||
predicted += diff
|
||||
|
||||
# 限制预测值范围
|
||||
if predicted > 32767:
|
||||
predicted = 32767
|
||||
elif predicted < -32768:
|
||||
predicted = -32768
|
||||
|
||||
# 更新步长索引
|
||||
step_index += index_table[adpcm_sample & 7]
|
||||
if step_index < 0:
|
||||
step_index = 0
|
||||
elif step_index > 88:
|
||||
step_index = 88
|
||||
|
||||
# 打包到字节中
|
||||
if j == 0:
|
||||
byte = adpcm_sample
|
||||
else:
|
||||
byte |= (adpcm_sample << 4)
|
||||
|
||||
adpcm_data.append(byte)
|
||||
|
||||
# 添加头部信息:初始预测值和步长索引
|
||||
header = struct.pack('<hB', predicted, step_index)
|
||||
return header + bytes(adpcm_data)
|
||||
|
||||
@staticmethod
|
||||
def adpcm_to_pcm16(adpcm_data: bytes) -> bytes:
|
||||
"""
|
||||
将4位ADPCM转换回16位PCM
|
||||
"""
|
||||
if len(adpcm_data) < 3:
|
||||
return b''
|
||||
|
||||
# 读取头部
|
||||
predicted, step_index = struct.unpack('<hB', adpcm_data[:3])
|
||||
adpcm_bytes = adpcm_data[3:]
|
||||
|
||||
# IMA ADPCM 步长表
|
||||
step_table = [
|
||||
7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
|
||||
19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
|
||||
50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
|
||||
130, 143, 157, 173, 190, 209, 230, 253, 279, 307,
|
||||
337, 371, 408, 449, 494, 544, 598, 658, 724, 796,
|
||||
876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
|
||||
2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358,
|
||||
5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
|
||||
15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767
|
||||
]
|
||||
|
||||
# 索引调整表
|
||||
index_table = [-1, -1, -1, -1, 2, 4, 6, 8]
|
||||
|
||||
pcm_samples = []
|
||||
|
||||
for byte in adpcm_bytes:
|
||||
# 解码两个4位样本
|
||||
for shift in [0, 4]:
|
||||
adpcm_sample = (byte >> shift) & 0x0F
|
||||
|
||||
# 计算差值
|
||||
step = step_table[step_index]
|
||||
diff = 0
|
||||
|
||||
if adpcm_sample & 4:
|
||||
diff += step
|
||||
step >>= 1
|
||||
if adpcm_sample & 2:
|
||||
diff += step
|
||||
step >>= 1
|
||||
if adpcm_sample & 1:
|
||||
diff += step
|
||||
step >>= 1
|
||||
diff += step
|
||||
|
||||
if adpcm_sample & 8:
|
||||
predicted -= diff
|
||||
else:
|
||||
predicted += diff
|
||||
|
||||
# 限制范围
|
||||
if predicted > 32767:
|
||||
predicted = 32767
|
||||
elif predicted < -32768:
|
||||
predicted = -32768
|
||||
|
||||
pcm_samples.append(predicted)
|
||||
|
||||
# 更新步长索引
|
||||
step_index += index_table[adpcm_sample & 7]
|
||||
if step_index < 0:
|
||||
step_index = 0
|
||||
elif step_index > 88:
|
||||
step_index = 88
|
||||
|
||||
return np.array(pcm_samples, dtype=np.int16).tobytes()
|
||||
|
||||
@staticmethod
|
||||
def downsample_pcm16(pcm_data: bytes, from_rate: int = 16000, to_rate: int = 8000) -> bytes:
|
||||
"""
|
||||
降采样(可选)
|
||||
16kHz -> 8kHz 可以再减少50%数据量
|
||||
"""
|
||||
if from_rate == to_rate:
|
||||
return pcm_data
|
||||
|
||||
# 解析PCM数据
|
||||
samples = np.frombuffer(pcm_data, dtype=np.int16)
|
||||
|
||||
# 简单的降采样(每隔一个样本取一个)
|
||||
if from_rate == 16000 and to_rate == 8000:
|
||||
downsampled = samples[::2]
|
||||
else:
|
||||
# 更复杂的重采样需要scipy
|
||||
ratio = to_rate / from_rate
|
||||
new_length = int(len(samples) * ratio)
|
||||
downsampled = np.interp(
|
||||
np.linspace(0, len(samples) - 1, new_length),
|
||||
np.arange(len(samples)),
|
||||
samples
|
||||
).astype(np.int16)
|
||||
|
||||
return downsampled.tobytes()
|
||||
|
||||
|
||||
class CompressedAudioCache:
|
||||
"""压缩音频缓存"""
|
||||
|
||||
def __init__(self, compression_type: str = "adpcm", use_downsample: bool = False):
|
||||
"""
|
||||
compression_type: "none", "ulaw", "adpcm"
|
||||
"""
|
||||
self.compression_type = compression_type
|
||||
self.use_downsample = use_downsample
|
||||
self._cache = {} # {filepath: compressed_data}
|
||||
self._original_sizes = {} # {filepath: original_size}
|
||||
|
||||
def load_and_compress(self, filepath: str) -> Optional[bytes]:
|
||||
"""加载并压缩音频文件(统一转换为8kHz)"""
|
||||
if filepath in self._cache:
|
||||
return self._cache[filepath]
|
||||
|
||||
try:
|
||||
with wave.open(filepath, 'rb') as wav:
|
||||
# 检查格式
|
||||
channels = wav.getnchannels()
|
||||
sampwidth = wav.getsampwidth()
|
||||
framerate = wav.getframerate()
|
||||
|
||||
if channels != 1:
|
||||
logger.warning(f"{filepath} 不是单声道")
|
||||
if sampwidth != 2:
|
||||
logger.warning(f"{filepath} 不是16位音频")
|
||||
|
||||
# 读取所有数据
|
||||
frames = wav.readframes(wav.getnframes())
|
||||
|
||||
# 如果是立体声,转换为单声道
|
||||
if channels == 2:
|
||||
import audioop
|
||||
frames = audioop.tomono(frames, sampwidth, 1, 0)
|
||||
|
||||
# 【修改】始终转换为16kHz(匹配客户端播放器)
|
||||
if framerate != 16000:
|
||||
import audioop
|
||||
frames, _ = audioop.ratecv(frames, sampwidth, 1, framerate, 16000, None)
|
||||
framerate = 16000
|
||||
|
||||
# 记录原始大小(转换后的大小)
|
||||
self._original_sizes[filepath] = len(frames)
|
||||
|
||||
# 压缩
|
||||
if self.compression_type == "ulaw":
|
||||
compressed = AudioCompressor.pcm16_to_ulaw(frames)
|
||||
# 添加简单的头部信息(1字节标识 + 4字节原始长度)
|
||||
header = struct.pack('!BI', 0x01, len(frames)) # 0x01表示μ-law
|
||||
compressed = header + compressed
|
||||
elif self.compression_type == "adpcm":
|
||||
compressed = AudioCompressor.pcm16_to_adpcm(frames)
|
||||
# 添加简单的头部信息(1字节标识 + 4字节原始长度)
|
||||
header = struct.pack('!BI', 0x02, len(frames)) # 0x02表示ADPCM
|
||||
compressed = header + compressed
|
||||
else:
|
||||
compressed = frames
|
||||
|
||||
self._cache[filepath] = compressed
|
||||
|
||||
# 打印压缩率
|
||||
compression_ratio = len(compressed) / self._original_sizes[filepath]
|
||||
logger.info(f"[压缩] {os.path.basename(filepath)}: "
|
||||
f"{self._original_sizes[filepath]} -> {len(compressed)} bytes "
|
||||
f"({compression_ratio:.1%})")
|
||||
|
||||
return compressed
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"压缩音频失败 {filepath}: {e}")
|
||||
return None
|
||||
|
||||
def decompress(self, compressed_data: bytes) -> Optional[bytes]:
|
||||
"""解压音频数据"""
|
||||
if not compressed_data or len(compressed_data) < 5:
|
||||
return compressed_data
|
||||
|
||||
try:
|
||||
# 检查头部
|
||||
compression_type = compressed_data[0]
|
||||
if compression_type == 0x01: # μ-law标识
|
||||
header_size = 5
|
||||
original_length = struct.unpack('!I', compressed_data[1:5])[0]
|
||||
ulaw_data = compressed_data[header_size:]
|
||||
|
||||
# μ-law解压
|
||||
pcm_data = AudioCompressor.ulaw_to_pcm16(ulaw_data)
|
||||
|
||||
return pcm_data
|
||||
elif compression_type == 0x02: # ADPCM标识
|
||||
header_size = 5
|
||||
original_length = struct.unpack('!I', compressed_data[1:5])[0]
|
||||
adpcm_data = compressed_data[header_size:]
|
||||
|
||||
# ADPCM解压
|
||||
pcm_data = AudioCompressor.adpcm_to_pcm16(adpcm_data)
|
||||
|
||||
return pcm_data
|
||||
else:
|
||||
# 未压缩的数据
|
||||
return compressed_data
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"解压音频失败: {e}")
|
||||
return compressed_data
|
||||
|
||||
def get_compression_stats(self) -> dict:
|
||||
"""获取压缩统计信息"""
|
||||
total_original = sum(self._original_sizes.values())
|
||||
total_compressed = sum(len(data) for data in self._cache.values())
|
||||
|
||||
return {
|
||||
"files_cached": len(self._cache),
|
||||
"total_original_size": total_original,
|
||||
"total_compressed_size": total_compressed,
|
||||
"compression_ratio": total_compressed / total_original if total_original > 0 else 0,
|
||||
"bytes_saved": total_original - total_compressed
|
||||
}
|
||||
|
||||
|
||||
# 全局压缩音频缓存实例
|
||||
# 默认使用ADPCM压缩,音质更好,压缩率也不错(75%)
|
||||
# 可通过环境变量 AIGLASS_COMPRESS_TYPE 设置: none, ulaw, adpcm
|
||||
import os
|
||||
compression_type = os.getenv("AIGLASS_COMPRESS_TYPE", "adpcm").lower()
|
||||
if compression_type not in ["none", "ulaw", "adpcm"]:
|
||||
compression_type = "adpcm"
|
||||
compressed_audio_cache = CompressedAudioCache(compression_type=compression_type, use_downsample=False)
|
||||
392
audio_player.py
Normal file
392
audio_player.py
Normal file
@@ -0,0 +1,392 @@
|
||||
# audio_player.py
|
||||
# 处理预录音频文件的播放,通过ESP32扬声器输出
|
||||
|
||||
import os
|
||||
import wave
|
||||
import json
|
||||
import asyncio
|
||||
import threading
|
||||
import queue
|
||||
import time
|
||||
from audio_stream import broadcast_pcm16_realtime
|
||||
from audio_compressor import compressed_audio_cache, AudioCompressor
|
||||
|
||||
# 导入录制器(避免循环导入,在需要时动态导入)
|
||||
_recorder_imported = False
|
||||
_sync_recorder = None
|
||||
|
||||
def _get_recorder():
|
||||
"""延迟导入录制器"""
|
||||
global _recorder_imported, _sync_recorder
|
||||
if not _recorder_imported:
|
||||
try:
|
||||
import sync_recorder as sr
|
||||
_sync_recorder = sr
|
||||
_recorder_imported = True
|
||||
except Exception as e:
|
||||
print(f"[AUDIO] 无法导入录制器: {e}")
|
||||
_recorder_imported = True # 标记已尝试,避免重复
|
||||
return _sync_recorder
|
||||
|
||||
# 兼容旧工程中的示例音频(保留)- 改为相对路径
|
||||
AUDIO_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "music")
|
||||
|
||||
# 新增:voice 目录与映射表
|
||||
# 使用脚本所在目录的 voice 文件夹,避免工作目录问题
|
||||
VOICE_DIR = os.getenv("VOICE_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "voice"))
|
||||
VOICE_MAP_FILE = os.path.join(VOICE_DIR, "map.zh-CN.json")
|
||||
|
||||
# 音频文件映射(将合并 voice 映射)
|
||||
AUDIO_MAP = {
|
||||
"检测到物体": os.path.join(AUDIO_BASE_DIR, "音频1.wav"),
|
||||
"向上": os.path.join(AUDIO_BASE_DIR, "音频2.wav"),
|
||||
"向下": os.path.join(AUDIO_BASE_DIR, "音频3.wav"),
|
||||
"向左": os.path.join(AUDIO_BASE_DIR, "音频4.wav"),
|
||||
"向右": os.path.join(AUDIO_BASE_DIR, "音频5.wav"),
|
||||
"OK": os.path.join(AUDIO_BASE_DIR, "音频6.wav"),
|
||||
"向前": os.path.join(AUDIO_BASE_DIR, "音频7.wav"),
|
||||
"后退": os.path.join(AUDIO_BASE_DIR, "音频8.wav"),
|
||||
"拿到物体": os.path.join(AUDIO_BASE_DIR, "音频9.wav"),
|
||||
}
|
||||
|
||||
# 音频缓存,避免重复读取
|
||||
_audio_cache = {}
|
||||
|
||||
# 音频播放队列和工作线程 - 使用优先级队列
|
||||
_audio_queue = queue.PriorityQueue(maxsize=10)
|
||||
_audio_priority = 0 # 递增的优先级计数器
|
||||
_worker_thread = None
|
||||
_worker_loop = None
|
||||
_is_playing = False # 标记是否正在播放音频
|
||||
_playing_lock = threading.Lock() # 播放锁
|
||||
_initialized = False
|
||||
_last_play_ts = 0.0 # 记录上次播放结束时间,用于决定预热静音长度
|
||||
|
||||
def load_wav_file(filepath):
|
||||
"""加载WAV文件并返回PCM数据(自动转换为8kHz)"""
|
||||
if filepath in _audio_cache:
|
||||
return _audio_cache[filepath]
|
||||
|
||||
# 使用压缩缓存
|
||||
if os.getenv("AIGLASS_COMPRESS_AUDIO", "1") == "1":
|
||||
compressed_data = compressed_audio_cache.load_and_compress(filepath)
|
||||
if compressed_data:
|
||||
# 存储压缩后的数据
|
||||
_audio_cache[filepath] = compressed_data
|
||||
return compressed_data
|
||||
|
||||
# 原始加载方式(不压缩)
|
||||
try:
|
||||
with wave.open(filepath, 'rb') as wav:
|
||||
# 检查音频格式
|
||||
channels = wav.getnchannels()
|
||||
sampwidth = wav.getsampwidth()
|
||||
framerate = wav.getframerate()
|
||||
|
||||
if channels != 1:
|
||||
print(f"[AUDIO] 警告: {filepath} 不是单声道,将只使用第一个声道")
|
||||
if sampwidth != 2:
|
||||
print(f"[AUDIO] 警告: {filepath} 不是16位音频")
|
||||
|
||||
# 读取所有帧
|
||||
frames = wav.readframes(wav.getnframes())
|
||||
|
||||
# 如果是立体声,只取左声道
|
||||
if channels == 2:
|
||||
import audioop
|
||||
frames = audioop.tomono(frames, sampwidth, 1, 0)
|
||||
|
||||
# 统一转换为16kHz(使用ratecv保证音调和速度不变)
|
||||
if framerate != 16000:
|
||||
import audioop
|
||||
frames, _ = audioop.ratecv(frames, sampwidth, 1, framerate, 16000, None)
|
||||
print(f"[AUDIO] 重采样: {filepath} {framerate}Hz -> 16000Hz")
|
||||
|
||||
_audio_cache[filepath] = frames
|
||||
return frames
|
||||
|
||||
except Exception as e:
|
||||
print(f"[AUDIO] 加载音频文件失败 {filepath}: {e}")
|
||||
return None
|
||||
|
||||
def _merge_voice_map():
|
||||
"""读取 voice/map.zh-CN.json 并合并到 AUDIO_MAP"""
|
||||
try:
|
||||
if not os.path.exists(VOICE_MAP_FILE):
|
||||
print(f"[AUDIO] 未找到映射文件: {VOICE_MAP_FILE}")
|
||||
return
|
||||
with open(VOICE_MAP_FILE, "r", encoding="utf-8") as f:
|
||||
m = json.load(f)
|
||||
added = 0
|
||||
for text, info in (m or {}).items():
|
||||
files = (info or {}).get("files") or []
|
||||
if not files:
|
||||
continue
|
||||
fname = files[0]
|
||||
fpath = os.path.join(VOICE_DIR, fname)
|
||||
if os.path.exists(fpath):
|
||||
AUDIO_MAP[text] = fpath
|
||||
added += 1
|
||||
else:
|
||||
print(f"[AUDIO] 映射文件缺失: {fpath}")
|
||||
print(f"[AUDIO] 已合并 voice 映射 {added} 条")
|
||||
except Exception as e:
|
||||
print(f"[AUDIO] 读取 voice 映射失败: {e}")
|
||||
|
||||
def preload_all_audio():
|
||||
"""预加载所有音频文件到内存"""
|
||||
print("[AUDIO] 开始预加载音频文件...")
|
||||
loaded_count = 0
|
||||
|
||||
# 【暂时禁用变速】因为需要修改缓存机制
|
||||
# 需要加速的音频列表(斑马线相关)
|
||||
# speedup_keywords = ["斑马线", "画面"]
|
||||
# speedup_factor = 1.3 # 加速30%
|
||||
|
||||
for audio_key, filepath in AUDIO_MAP.items():
|
||||
if os.path.exists(filepath):
|
||||
# 【修复】暂时使用默认速度加载
|
||||
# need_speedup = any(keyword in audio_key for keyword in speedup_keywords)
|
||||
# speed = speedup_factor if need_speedup else 1.0
|
||||
|
||||
data = load_wav_file(filepath) # 使用默认参数
|
||||
if data:
|
||||
loaded_count += 1
|
||||
# if need_speedup:
|
||||
# print(f"[AUDIO] 加载(加速{speedup_factor}x): {audio_key}")
|
||||
else:
|
||||
# 降低噪声输出
|
||||
pass
|
||||
print(f"[AUDIO] 预加载完成,共加载 {loaded_count} 个音频文件")
|
||||
|
||||
def _audio_worker():
|
||||
"""音频播放工作线程"""
|
||||
global _worker_loop
|
||||
|
||||
# 尝试设置线程优先级(Windows特定)
|
||||
try:
|
||||
import ctypes
|
||||
import sys
|
||||
if sys.platform == "win32":
|
||||
# 设置线程为高优先级
|
||||
ctypes.windll.kernel32.SetThreadPriority(
|
||||
ctypes.windll.kernel32.GetCurrentThread(),
|
||||
1 # THREAD_PRIORITY_ABOVE_NORMAL
|
||||
)
|
||||
print("[AUDIO] 设置音频线程为高优先级")
|
||||
except Exception as e:
|
||||
print(f"[AUDIO] 设置线程优先级失败: {e}")
|
||||
|
||||
_worker_loop = asyncio.new_event_loop()
|
||||
asyncio.set_event_loop(_worker_loop)
|
||||
|
||||
async def process_queue():
|
||||
while True:
|
||||
try:
|
||||
# 从优先级队列获取数据
|
||||
priority_data = await asyncio.get_event_loop().run_in_executor(None, _audio_queue.get, True)
|
||||
if priority_data is None:
|
||||
break
|
||||
# 解包优先级和实际音频数据
|
||||
if isinstance(priority_data, tuple) and len(priority_data) == 2:
|
||||
_, audio_data = priority_data
|
||||
else:
|
||||
audio_data = priority_data
|
||||
await _broadcast_audio_optimized(audio_data)
|
||||
except Exception as e:
|
||||
print(f"[AUDIO] 工作线程错误: {e}")
|
||||
|
||||
_worker_loop.run_until_complete(process_queue())
|
||||
|
||||
async def _broadcast_audio_optimized(pcm_data: bytes):
|
||||
"""优化的音频广播:单次调用由底层按20ms节拍发送,移除重复节拍和Python层sleep"""
|
||||
global _last_play_ts, _is_playing
|
||||
try:
|
||||
# 设置播放标志
|
||||
with _playing_lock:
|
||||
_is_playing = True
|
||||
# 此时 pcm_data 应该已经是解压后的16位PCM数据了(8kHz)
|
||||
now = time.monotonic()
|
||||
idle_sec = now - (_last_play_ts or now)
|
||||
# 首次或长时间空闲后,预热更长静音;否则小静音
|
||||
lead_ms = 160 if idle_sec > 3.0 else 60
|
||||
tail_ms = 40
|
||||
|
||||
lead_silence = b'\x00' * (lead_ms * 16000 * 2 // 1000) # 16k * 2B
|
||||
tail_silence = b'\x00' * (tail_ms * 16000 * 2 // 1000)
|
||||
|
||||
# 完整音频数据(包含静音)
|
||||
full_audio = lead_silence + pcm_data + tail_silence
|
||||
|
||||
# 注意:录制在 broadcast_pcm16_realtime 中统一完成,避免重复
|
||||
|
||||
# 单次调用交给底层 pacing(20ms节拍在 broadcast_pcm16_realtime 内部实现)
|
||||
await broadcast_pcm16_realtime(full_audio)
|
||||
|
||||
_last_play_ts = time.monotonic()
|
||||
except Exception as e:
|
||||
print(f"[AUDIO] 广播音频失败: {e}")
|
||||
finally:
|
||||
# 清除播放标志
|
||||
with _playing_lock:
|
||||
_is_playing = False
|
||||
|
||||
def initialize_audio_system():
|
||||
"""初始化音频系统"""
|
||||
global _initialized, _worker_thread, _last_play_ts
|
||||
|
||||
if _initialized:
|
||||
return
|
||||
|
||||
# 先合并 voice 映射,再预加载
|
||||
_merge_voice_map()
|
||||
preload_all_audio()
|
||||
|
||||
_worker_thread = threading.Thread(target=_audio_worker, daemon=True)
|
||||
_worker_thread.start()
|
||||
_initialized = True
|
||||
_last_play_ts = 0.0
|
||||
|
||||
# 显示压缩统计
|
||||
if os.getenv("AIGLASS_COMPRESS_AUDIO", "1") == "1":
|
||||
stats = compressed_audio_cache.get_compression_stats()
|
||||
print(f"[AUDIO] 音频压缩统计:")
|
||||
print(f" - 文件数: {stats['files_cached']}")
|
||||
print(f" - 原始大小: {stats['total_original_size'] / 1024:.1f} KB")
|
||||
print(f" - 压缩后: {stats['total_compressed_size'] / 1024:.1f} KB")
|
||||
print(f" - 压缩率: {stats['compression_ratio']:.1%}")
|
||||
print(f" - 节省: {stats['bytes_saved'] / 1024:.1f} KB")
|
||||
|
||||
print("[AUDIO] 音频系统初始化完成(预加载+工作线程)")
|
||||
|
||||
def play_audio_threadsafe(audio_key):
|
||||
"""线程安全的音频播放函数"""
|
||||
global _audio_queue, _audio_priority
|
||||
|
||||
if not _initialized:
|
||||
initialize_audio_system()
|
||||
|
||||
if audio_key not in AUDIO_MAP:
|
||||
print(f"[AUDIO] 未知的音频键: {audio_key}")
|
||||
return
|
||||
|
||||
filepath = AUDIO_MAP[audio_key]
|
||||
pcm_data = _audio_cache.get(filepath)
|
||||
if pcm_data is None:
|
||||
print(f"[AUDIO] 音频未在缓存中: {audio_key}")
|
||||
return
|
||||
|
||||
# 如果是压缩的数据,先解压
|
||||
if pcm_data and len(pcm_data) > 5 and pcm_data[0] in [0x01, 0x02]:
|
||||
pcm_data = compressed_audio_cache.decompress(pcm_data)
|
||||
if not pcm_data:
|
||||
print(f"[AUDIO] 解压失败: {audio_key}")
|
||||
return
|
||||
|
||||
# 【优化】实时播报策略:保持队列最小化,避免积压延迟
|
||||
queue_size = _audio_queue.qsize()
|
||||
|
||||
# 检查是否正在播放
|
||||
with _playing_lock:
|
||||
currently_playing = _is_playing
|
||||
|
||||
# 实时策略:只允许1个积压,超过立即清空
|
||||
if queue_size > 0 and not currently_playing:
|
||||
# 未播放时立即清空,播放最新语音
|
||||
print(f"[AUDIO] 清空队列(当前{queue_size}个),播放最新语音")
|
||||
_audio_queue = queue.PriorityQueue(maxsize=10)
|
||||
elif queue_size > 1 and currently_playing:
|
||||
# 正在播放时,如果积压>1个则清空(保持实时性)
|
||||
print(f"[AUDIO] 队列积压({queue_size}个),清空以保持实时")
|
||||
_audio_queue = queue.PriorityQueue(maxsize=10)
|
||||
try:
|
||||
# 使用优先级队列,确保音频按顺序播放
|
||||
_audio_priority += 1
|
||||
_audio_queue.put_nowait((_audio_priority, pcm_data))
|
||||
if queue_size >= 1:
|
||||
print(f"[AUDIO] 播放队列当前大小: {queue_size + 1}")
|
||||
except queue.Full:
|
||||
# 播放队列满则丢弃,保持实时性
|
||||
print(f"[AUDIO] 队列满,丢弃: {audio_key}")
|
||||
pass
|
||||
|
||||
# 全局语音节流 - Day 22 优化: 降低冷却时间
|
||||
_last_voice_time = 0
|
||||
_last_voice_text = ""
|
||||
_voice_cooldown = 1.5 # 相同语音至少间隔1.5秒 (从3.0降低)
|
||||
|
||||
# 语音优先级定义
|
||||
VOICE_PRIORITY = {
|
||||
'obstacle': 100, # 障碍物 - 最高优先级
|
||||
'direction': 50, # 转向/平移 - 中等优先级
|
||||
'straight': 10, # 保持直行 - 最低优先级
|
||||
'other': 30 # 其他 - 默认优先级
|
||||
}
|
||||
|
||||
# 新增:根据中文提示文案直接播放(会做轻度规范化与降级)
|
||||
def play_voice_text(text: str):
|
||||
"""
|
||||
传入中文提示,自动匹配 voice 映射并播放。
|
||||
- 尝试原文
|
||||
- 尝试补全/去除句末标点(。.!!??)
|
||||
- 若包含“前方有…注意避让”但未命中,降级到“前方有障碍物,注意避让。”
|
||||
"""
|
||||
global _last_voice_time, _last_voice_text
|
||||
|
||||
if not text:
|
||||
return
|
||||
if not _initialized:
|
||||
initialize_audio_system()
|
||||
|
||||
# 全局节流:相同文本短时间内不重复播放
|
||||
current_time = time.time()
|
||||
if text == _last_voice_text and current_time - _last_voice_time < _voice_cooldown:
|
||||
return # 静默跳过
|
||||
|
||||
candidates = []
|
||||
t = text.strip()
|
||||
candidates.append(t)
|
||||
# 尝试补全句号
|
||||
if t[-1:] not in ("。", "!", "!", "?", "?", "."):
|
||||
candidates.append(t + "。")
|
||||
else:
|
||||
# 尝试去掉标点
|
||||
t2 = t.rstrip("。.!!??")
|
||||
if t2 and t2 != t:
|
||||
candidates.append(t2)
|
||||
|
||||
# 逐一尝试匹配
|
||||
for ck in candidates:
|
||||
if ck in AUDIO_MAP:
|
||||
play_audio_threadsafe(ck)
|
||||
_last_voice_text = text
|
||||
_last_voice_time = current_time
|
||||
return
|
||||
|
||||
# 针对“前方有…注意避让”降级
|
||||
if ("前方有" in t) and ("注意避让" in t):
|
||||
fallback = "前方有障碍物,注意避让。"
|
||||
if fallback in AUDIO_MAP:
|
||||
play_audio_threadsafe(fallback)
|
||||
_last_voice_text = text
|
||||
_last_voice_time = current_time
|
||||
return
|
||||
|
||||
# 针对“请向…平移/微调/转动”类词条,常见变体尝试
|
||||
base = t.rstrip("。.!!??")
|
||||
if base in AUDIO_MAP:
|
||||
play_audio_threadsafe(base)
|
||||
_last_voice_text = text
|
||||
_last_voice_time = current_time
|
||||
return
|
||||
if base + "。" in AUDIO_MAP:
|
||||
play_audio_threadsafe(base + "。")
|
||||
_last_voice_text = text
|
||||
_last_voice_time = current_time
|
||||
return
|
||||
|
||||
# 未匹配则输出日志(便于调试)
|
||||
print(f"[AUDIO] 未找到匹配语音: {text}")
|
||||
|
||||
# Backward-compatible alias for the legacy API name.
play_audio_on_esp32 = play_audio_threadsafe
299
audio_stream.py
Normal file
299
audio_stream.py
Normal file
@@ -0,0 +1,299 @@
|
||||
# audio_stream.py
# -*- coding: utf-8 -*-
import asyncio
from collections import deque
from dataclasses import dataclass
from typing import Optional, Set, List, Tuple, Any, Dict
from fastapi import Request
from fastapi.responses import StreamingResponse

# ===== Downstream WAV stream parameters =====
STREAM_SR = 8000  # 8 kHz — the rate the ESP32 supports
STREAM_CH = 1     # mono
STREAM_SW = 2     # 16-bit samples (2 bytes)
BYTES_PER_20MS_16K = STREAM_SR * STREAM_SW * 20 // 1000  # 320 B per 20 ms at 8 kHz

# ===== Day 13: TTS buffering queue =====
# When the WebSocket is down, TTS audio is buffered and flushed on reconnect.
TTS_BUFFER_MAX_SECONDS = 30  # cache at most 30 s of audio
TTS_BUFFER_MAX_BYTES = 16000 * 2 * TTS_BUFFER_MAX_SECONDS  # 16kHz * 2 bytes * 30s = ~960KB
tts_audio_buffer: deque = deque()  # elements are (timestamp, pcm16k_bytes)
tts_buffer_total_bytes = 0  # running byte total of tts_audio_buffer

# Day 13: dedicated TTS WebSocket reference.
# Saved before AI processing starts so ws_audio's finally block cannot clear it.
tts_websocket = None
|
||||
def set_tts_websocket(ws):
    """Store the WebSocket used exclusively for sending TTS audio."""
    global tts_websocket
    tts_websocket = ws
|
||||
def get_tts_websocket():
    """Return the TTS WebSocket: the saved reference when set, otherwise the
    already-loaded ``app_main`` module's ``esp32_audio_ws`` global, else None.

    Day 15 fix: ``import app_main`` would re-run its top-level code, so the
    module is looked up through ``sys.modules`` instead.
    """
    global tts_websocket
    ws = tts_websocket
    if ws is None:
        try:
            import sys
            app_main = sys.modules.get('app_main')
            if app_main is not None:
                ws = app_main.esp32_audio_ws
        except Exception:
            ws = None
    return ws
|
||||
|
||||
# ===== Master switch for the AI playback task =====
# The currently running LLM-speech asyncio task, if any.
current_ai_task: Optional[asyncio.Task] = None
async def cancel_current_ai():
    """Cancel the in-flight LLM speech task, if any, and wait for it to exit."""
    global current_ai_task
    pending, current_ai_task = current_ai_task, None
    if pending is None or pending.done():
        return
    pending.cancel()
    try:
        await pending
    except (asyncio.CancelledError, Exception):
        # Swallow cancellation and any error raised while unwinding.
        pass
||||
def is_playing_now() -> bool:
    """True while an AI speech task exists and has not finished."""
    task = current_ai_task
    if task is None:
        return False
    return not task.done()
|
||||
# ===== /stream.wav connection management =====
@dataclass(frozen=True)
class StreamClient:
    # Per-client chunk queue drained by the HTTP streaming generator.
    q: asyncio.Queue
    # Set to force-disconnect this client.
    abort_event: asyncio.Event

stream_clients: "Set[StreamClient]" = set()
STREAM_QUEUE_MAX = 96  # small buffer to avoid backlog
|
||||
def _wav_header_unknown_size(sr=16000, ch=1, sw=2) -> bytes:
|
||||
import struct
|
||||
byte_rate = sr * ch * sw
|
||||
block_align = ch * sw
|
||||
data_size = 0x7FFFFFF0
|
||||
riff_size = 36 + data_size
|
||||
return struct.pack(
|
||||
"<4sI4s4sIHHIIHH4sI",
|
||||
b"RIFF", riff_size, b"WAVE",
|
||||
b"fmt ", 16,
|
||||
1, ch, sr, byte_rate, block_align, sw * 8,
|
||||
b"data", data_size
|
||||
)
|
||||
|
||||
async def hard_reset_audio(reason: str = ""):
|
||||
"""
|
||||
**一键清场**:取消当前AI任务。
|
||||
注意:不再断开 HTTP /stream.wav 连接,因为 Avaota F1 使用这个通道播放 TTS。
|
||||
"""
|
||||
# Day 14: 不再断开 HTTP 连接,只取消 AI 任务
|
||||
# 因为 Avaota F1 的 HTTP TTS 客户端需要保持长连接
|
||||
# 断开会导致客户端收不到后续 TTS 音频
|
||||
|
||||
|
||||
# 2) 取消当前AI任务
|
||||
await cancel_current_ai()
|
||||
|
||||
# 3) 日志
|
||||
if reason:
|
||||
print(f"[HARD-RESET] {reason}")
|
||||
|
||||
async def flush_tts_buffer(ws) -> int:
    """
    Day 13: flush the TTS buffer — send all cached audio over the WebSocket.
    Returns the number of bytes sent.
    """
    global tts_audio_buffer, tts_buffer_total_bytes
    from starlette.websockets import WebSocketState

    if not tts_audio_buffer:
        return 0

    total_sent = 0
    # Snapshot and clear the buffer up front so concurrent producers start
    # from an empty queue while we send.
    items_to_send = list(tts_audio_buffer)
    tts_audio_buffer.clear()
    tts_buffer_total_bytes = 0

    try:
        for _, audio_data in items_to_send:
            # Stop early if the socket dropped mid-flush.
            if hasattr(ws, 'client_state') and ws.client_state != WebSocketState.CONNECTED:
                print(f"[TTS->WS] ⚠️ WebSocket disconnected while flushing buffer")
                break
            await ws.send_bytes(audio_data)
            total_sent += len(audio_data)

        if total_sent > 0:
            # Duration assumes 16 kHz, 16-bit mono PCM.
            duration = total_sent / (16000 * 2)
            print(f"[TTS->WS] 📤 Flushed {total_sent} bytes ({duration:.1f}s) of cached TTS audio")
    except Exception as e:
        print(f"[TTS->WS] ❌ Error flushing buffer: {e}")

    return total_sent
|
||||
async def broadcast_pcm16_realtime(pcm16: bytes):
    """
    Day 14 optimization: send to the WebSocket immediately; the paced HTTP
    broadcast runs as a background task so its 20 ms pacing cannot stall
    WebSocket TTS delivery.
    """
    # Record the audio before fan-out so it is captured in one piece.
    try:
        import sync_recorder
        sync_recorder.record_audio(pcm16, text="[Omni对话]")
    except Exception:
        pass  # best-effort: recording must never break playback

    # Day 13: also forward to the WebSocket client (Avaota F1).
    # NOTE(review): since Day 21 the input is assumed to already be 16 kHz
    # PCM16 (app_main converts 24 kHz -> 16 kHz upstream) — confirm.
    global tts_audio_buffer, tts_buffer_total_bytes
    import time as _time

    try:
        # Historically used for resampling; unused since the Day 21 change.
        import audioop

        # Day 13: resolve the WebSocket via get_tts_websocket(), which
        # prefers the saved reference over the global that ws_audio's
        # finally block may clear.
        ws = get_tts_websocket()

        # Day 21: input is already 16 kHz; no conversion needed here.
        pcm16k = pcm16

        sent_ok = False
        if ws is not None:
            try:
                # Day 13 fix: do not gate on client_state — the check can be
                # stale and would wrongly divert audio into the buffer.

                # First drain anything buffered while disconnected.
                while tts_audio_buffer:
                    _, buffered_audio = tts_audio_buffer.popleft()
                    tts_buffer_total_bytes -= len(buffered_audio)
                    await ws.send_bytes(buffered_audio)
                    if not getattr(broadcast_pcm16_realtime, '_flush_logged', False):
                        print(f"[TTS->WS] 📤 Flushing buffered TTS audio...")
                        broadcast_pcm16_realtime._flush_logged = True

                # Then send the current chunk.
                await ws.send_bytes(pcm16k)
                sent_ok = True

                if len(pcm16k) > 320:
                    print(f"[TTS->WS] 📤 Sent {len(pcm16k)} bytes (16kHz) to Avaota")

                # Reset one-shot warning flags now that sending works again.
                broadcast_pcm16_realtime._ws_warned = False
                broadcast_pcm16_realtime._buffer_warned = False
                broadcast_pcm16_realtime._flush_logged = False
            except Exception as send_err:
                # Send failed; the chunk will be buffered below.
                if not getattr(broadcast_pcm16_realtime, '_send_err_warned', False):
                    print(f"[TTS->WS] ❌ Send error: {send_err}, will buffer")
                    broadcast_pcm16_realtime._send_err_warned = True

        # Buffer the chunk when the socket is missing or the send failed.
        if not sent_ok:
            tts_audio_buffer.append((_time.time(), pcm16k))
            tts_buffer_total_bytes += len(pcm16k)

            # Evict the oldest chunks once the buffer exceeds its byte budget.
            while tts_buffer_total_bytes > TTS_BUFFER_MAX_BYTES and tts_audio_buffer:
                _, old_audio = tts_audio_buffer.popleft()
                tts_buffer_total_bytes -= len(old_audio)

            if not getattr(broadcast_pcm16_realtime, '_buffer_warned', False):
                buffer_secs = tts_buffer_total_bytes / (16000 * 2)
                print(f"[TTS->WS] 📦 Buffering TTS audio ({buffer_secs:.1f}s cached), will send when reconnected")
                broadcast_pcm16_realtime._buffer_warned = True

    except Exception:
        pass  # deliberate: the playback path must never raise

    # Day 14: run the 20 ms-paced HTTP broadcast in the background so the
    # next Omni audio chunk can be processed without waiting for pacing.
    if stream_clients:
        asyncio.create_task(_http_pacing_broadcast(pcm16))
||||
|
||||
async def _http_pacing_broadcast(pcm16: bytes):
    """
    Day 14: 20 ms-paced broadcast of one PCM16 buffer to the HTTP
    /stream.wav clients, running as an independent background task.

    This was originally inlined in broadcast_pcm16_realtime, where the
    pacing sleeps blocked WebSocket sends.  Dead clients (aborted or
    erroring) are pruned as they are discovered.
    """
    # Fix: inside a coroutine, asyncio.get_running_loop() is the correct,
    # non-deprecated way to obtain the loop; get_event_loop() emits a
    # DeprecationWarning on Python 3.10+ and may create a fresh loop in
    # edge cases.
    loop = asyncio.get_running_loop()
    next_tick = loop.time()
    off = 0
    while off < len(pcm16):
        # Carve off at most one 20 ms slice.
        take = min(BYTES_PER_20MS_16K, len(pcm16) - off)
        piece = pcm16[off:off + take]

        dead: List[StreamClient] = []
        # Day 14 debug: confirm stream_clients has listeners (first slice only).
        if len(stream_clients) > 0 and off == 0:
            print(f"[TTS->HTTP] 📤 Sending to {len(stream_clients)} HTTP stream client(s)")
        for sc in list(stream_clients):
            if sc.abort_event.is_set():
                dead.append(sc)
                continue
            try:
                # Drop the oldest queued piece instead of blocking the pacer.
                if sc.q.full():
                    try:
                        sc.q.get_nowait()
                    except Exception:
                        pass
                sc.q.put_nowait(piece)
            except Exception:
                dead.append(sc)
        for sc in dead:
            try:
                stream_clients.discard(sc)
            except Exception:
                pass

        # Sleep until the next 20 ms boundary; if we are late, resync.
        next_tick += 0.020
        now = loop.time()
        if now < next_tick:
            await asyncio.sleep(next_tick - now)
        else:
            next_tick = now
        off += take
||||
# ===== FastAPI route registrar =====
def register_stream_route(app):
    """Register GET /stream.wav on ``app``: an endless 8 kHz mono WAV stream
    fed from per-client queues filled by _http_pacing_broadcast."""
    @app.get("/stream.wav")
    async def stream_wav(_: Request):
        # Enforce (near) single-connection: abort all existing clients first.
        for sc in list(stream_clients):
            try: sc.abort_event.set()
            except Exception: pass
        stream_clients.clear()

        # Register this connection's queue + abort flag.
        q: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=STREAM_QUEUE_MAX)
        abort_event = asyncio.Event()
        sc = StreamClient(q=q, abort_event=abort_event)
        stream_clients.add(sc)

        async def gen():
            # Emit a WAV header with an "unknown" (huge) data size, then
            # stream PCM chunks as they arrive.
            yield _wav_header_unknown_size(STREAM_SR, STREAM_CH, STREAM_SW)
            try:
                while True:
                    if abort_event.is_set():
                        break
                    try:
                        # Short timeout so aborts are noticed promptly.
                        chunk = await asyncio.wait_for(q.get(), timeout=0.5)
                    except asyncio.TimeoutError:
                        continue
                    if abort_event.is_set():
                        break
                    if chunk is None:
                        # None is the end-of-stream sentinel.
                        break
                    if chunk:
                        yield chunk
            finally:
                # Always deregister, however the stream ends.
                stream_clients.discard(sc)
        return StreamingResponse(gen(), media_type="audio/wav")
||||
92
bridge_io.py
Normal file
92
bridge_io.py
Normal file
@@ -0,0 +1,92 @@
|
||||
# bridge_io.py
# Minimal bridge: receive raw JPEG -> expose BGR frames to external
# algorithms; algorithms produce BGR -> broadcast to the front end.
import threading
from collections import deque
import time
import cv2
import numpy as np

# Raw JPEG frame buffer (keeps only the newest N frames).
_MAX_BUF = 4
_frames = deque(maxlen=_MAX_BUF)
_cond = threading.Condition()

# Callback that sends JPEG to the front end; registered by app_main.py at startup.
_sender_lock = threading.Lock()
_sender_cb = None

# Callback that sends UI text to the front end (registered by app_main.py at startup).
_ui_sender_lock = threading.Lock()
_ui_sender_cb = None
||||
def set_sender(cb):
    """Called by app_main.py to register a function: cb(jpeg_bytes) -> None."""
    global _sender_cb
    with _sender_lock:
        _sender_cb = cb
||||
def set_ui_sender(cb):
    """Called by app_main.py to register a function: cb(text: str) -> None."""
    global _ui_sender_cb
    with _ui_sender_lock:
        _ui_sender_cb = cb
||||
def push_raw_jpeg(jpeg_bytes: bytes):
    """Called by app_main.py whenever a /ws/camera frame arrives; stores the
    timestamped JPEG and wakes any waiting consumers."""
    if jpeg_bytes:
        stamped = (time.time(), jpeg_bytes)
        with _cond:
            _frames.append(stamped)
            _cond.notify_all()
|
||||
def wait_raw_bgr(timeout_sec: float = 0.5):
    """Called by the YOLO/MediaPipe scripts: wait for and return the newest
    frame as a BGR image; returns None on timeout or persistent decode
    failure.

    NOTE(review): this polls with 10 ms sleeps rather than using
    ``_cond.wait`` — cheap but not event-driven; confirm this is intended.
    """
    t_end = time.time() + timeout_sec
    last = None
    while time.time() < t_end:
        with _cond:
            if _frames:
                last = _frames[-1]
        if last is None:
            # No frame yet — poll again shortly.
            time.sleep(0.01)
            continue
        # Decode the JPEG into a BGR image.
        ts, jpeg = last
        arr = np.frombuffer(jpeg, dtype=np.uint8)
        bgr = cv2.imdecode(arr, cv2.IMREAD_COLOR)
        if bgr is not None:
            # Mirror at the source if ever needed (currently disabled).
            #bgr = cv2.flip(bgr, 1)
            return bgr
        # Decode failed — wait briefly and retry.
        time.sleep(0.01)
    return None
|
||||
def send_vis_bgr(bgr, quality: int = 80):
    """Called by the YOLO/MediaPipe scripts: push a processed frame to the
    front-end viewer as JPEG (quality 0-100)."""
    if bgr is None:
        return

    # Encode directly; no enhancement is applied.
    ok, enc = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)])
    if not ok:
        return
    # Read the callback under the lock, then invoke it outside.
    with _sender_lock:
        cb = _sender_cb
    if cb:
        try:
            cb(enc.tobytes())
        except Exception:
            # The bridge must not crash on a failing front-end callback.
            pass
|
||||
def send_ui_final(text: str):
    """Push one line of UI copy to the front end as the final answer
    (through the thread-safely registered callback)."""
    if not text:
        return
    with _ui_sender_lock:
        deliver = _ui_sender_cb
    if deliver is None:
        return
    try:
        deliver(str(text))
    except Exception:
        # A failing front-end callback must not propagate.
        pass
||||
342
crosswalk_awareness.py
Normal file
342
crosswalk_awareness.py
Normal file
@@ -0,0 +1,342 @@
|
||||
# -*- coding: utf-8 -*-
"""
Crosswalk awareness monitor.

Detects zebra crossings from mask-area changes and produces voice
guidance only — it never drives state transitions.
"""
import time
import numpy as np
from collections import deque
from typing import Optional, Dict, Any
import logging

logger = logging.getLogger(__name__)
|
||||
class CrosswalkAwarenessMonitor:
    """Crosswalk awareness monitor — a voice-prompt-only module."""

    def __init__(self):
        # Fixed area-ratio anchors for each guidance stage.
        self.THRESHOLDS = {
            'discover': 0.01,      # 1%  - discovered
            'approaching': 0.08,   # 8%  - approaching
            'near': 0.18,          # 18% - very close
            'arrival': 0.25,       # 25% - arrived (safe to cross)
        }

        # Thresholds already announced (prevents repeats).
        self.broadcasted_thresholds = set()

        # Area history (last 30 frames).
        self.area_history = deque(maxlen=30)

        # Timestamps of the last broadcast and the first "arrival" broadcast.
        self.last_broadcast_time = 0
        self.arrival_first_broadcast_time = 0

        # Whether we are in the "safe to cross" state.
        self.in_arrival_state = False
        # Position zone announced last time.
        self.last_position_zone = None

        # Repeat-broadcast intervals (tunable; smaller = more frequent).
        # All intervals were divided by 1.5 to speak 1.5x more often.
        self.REPEAT_INTERVALS = {
            'approaching': 6.7,  # repeat every 6.7 s (was 10 s / 1.5)
            'near': 3.3,         # repeat every 3.3 s (was 5 s / 1.5)
            'arrival': 5.3,      # repeat every 5.3 s (was 8 s / 1.5)
        }
        # Tuning hint: smaller values = more frequent announcements,
        # larger values = sparser announcements.

        # Overlap ratio above which the crosswalk counts as occluded.
        self.OCCLUSION_THRESHOLD = 0.30  # >30% overlap = occluded
||||
    def process_frame(self, crosswalk_mask, blind_path_mask=None) -> Optional[Dict[str, Any]]:
        """
        Process one frame of crosswalk detection.

        Args:
            crosswalk_mask: binary crosswalk mask (None when absent).
            blind_path_mask: optional tactile-path mask, used for the
                occlusion check.

        Returns a dict:
            {
                'voice_text': prompt text,
                'priority': prompt priority,
                'should_broadcast': whether to announce,
                'area': current area ratio,
                'position': position description,
            }
        or None when nothing should be announced.
        """
        # No crosswalk in view: reset accumulated state.
        if crosswalk_mask is None:
            self._reset_if_needed()
            return None

        # 1. Area ratio of the crosswalk within the frame.
        total_pixels = crosswalk_mask.size
        crosswalk_pixels = np.sum(crosswalk_mask > 0)
        area_ratio = crosswalk_pixels / total_pixels

        # 2. Normalized centroid of the mask.
        y_coords, x_coords = np.where(crosswalk_mask > 0)
        if len(y_coords) == 0:
            return None

        center_x_ratio = np.mean(x_coords) / crosswalk_mask.shape[1]
        center_y_ratio = np.mean(y_coords) / crosswalk_mask.shape[0]

        # 3. Append to the rolling history.
        current_time = time.time()
        self.area_history.append({
            'area': area_ratio,
            'center_x': center_x_ratio,
            'center_y': center_y_ratio,
            'time': current_time
        })

        # 4. Occlusion check against the tactile-path mask.
        has_occlusion = self._check_occlusion(crosswalk_mask, blind_path_mask)

        # 5. Decide the current stage and build the voice guidance.
        return self._generate_guidance(area_ratio, center_x_ratio, center_y_ratio,
                                       has_occlusion, current_time)
||||
def _check_occlusion(self, crosswalk_mask, blind_path_mask) -> bool:
|
||||
"""检查斑马线是否被盲道遮挡"""
|
||||
if blind_path_mask is None:
|
||||
return False
|
||||
|
||||
crosswalk_area = crosswalk_mask > 0
|
||||
blind_path_area = blind_path_mask > 0
|
||||
|
||||
# 计算重叠
|
||||
overlap = np.logical_and(crosswalk_area, blind_path_area)
|
||||
overlap_ratio = np.sum(overlap) / max(np.sum(crosswalk_area), 1)
|
||||
|
||||
# 重叠超过阈值认为有遮挡
|
||||
return overlap_ratio > self.OCCLUSION_THRESHOLD
|
||||
|
||||
def _get_position_description(self, center_x_ratio) -> str:
|
||||
"""获取方位描述(3分法)"""
|
||||
if center_x_ratio < 0.40:
|
||||
return "在画面左侧"
|
||||
elif center_x_ratio < 0.60:
|
||||
return "在画面中间"
|
||||
else:
|
||||
return "在画面右侧"
|
||||
|
||||
    def _generate_guidance(self, area_ratio, center_x_ratio, center_y_ratio,
                           has_occlusion, current_time) -> Optional[Dict[str, Any]]:
        """Build the guidance payload for the current stage, or None."""

        # Ignore jittery frames: only speak when the area is stable.
        if not self._is_area_stable(area_ratio):
            return None

        position_desc = self._get_position_description(center_x_ratio)

        # Stage 1: discovered (0.01 - 0.08) — announced exactly once.
        if area_ratio >= self.THRESHOLDS['discover'] and area_ratio < self.THRESHOLDS['approaching']:
            if self.THRESHOLDS['discover'] not in self.broadcasted_thresholds:
                self.broadcasted_thresholds.add(self.THRESHOLDS['discover'])
                return {
                    'voice_text': f"远处发现斑马线,{position_desc}",
                    'priority': 55,  # raised to 55, above tactile-path direction cues (50)
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }

        # Stage 2: approaching (0.08 - 0.18).
        elif area_ratio >= self.THRESHOLDS['approaching'] and area_ratio < self.THRESHOLDS['near']:
            # First announcement for this stage.
            if self.THRESHOLDS['approaching'] not in self.broadcasted_thresholds:
                self.broadcasted_thresholds.add(self.THRESHOLDS['approaching'])
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return {
                    'voice_text': f"正在靠近斑马线,{position_desc}",
                    'priority': 55,  # raised to 55
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }
            # Repeat on interval expiry or position-zone change.
            elif (current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['approaching'] or
                  position_desc != self.last_position_zone):
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return {
                    'voice_text': f"正在靠近斑马线,{position_desc}",
                    'priority': 55,  # raised to 55
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }

        # Stage 3: very close (0.18 - 0.25).
        elif area_ratio >= self.THRESHOLDS['near'] and area_ratio < self.THRESHOLDS['arrival']:
            # First announcement for this stage.
            if self.THRESHOLDS['near'] not in self.broadcasted_thresholds:
                self.broadcasted_thresholds.add(self.THRESHOLDS['near'])
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return {
                    'voice_text': f"接近斑马线,{position_desc}",
                    'priority': 60,
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }
            # Repeat on interval expiry or position-zone change.
            elif (current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['near'] or
                  position_desc != self.last_position_zone):
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return {
                    'voice_text': f"接近斑马线,{position_desc}",
                    'priority': 60,
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }

        # Stage 4: arrival (area >= 0.25, unoccluded).
        elif area_ratio >= self.THRESHOLDS['arrival']:
            # Crossing is only suggested when the view is unoccluded.
            if has_occlusion:
                # Occluded: suppress the crossing prompt, stay at stage 3.
                logger.info(f"[斑马线] 面积达到{area_ratio:.2f}但被遮挡,暂不提示过马路")
                return None

            # First arrival.
            if not self.in_arrival_state:
                self.in_arrival_state = True
                self.arrival_first_broadcast_time = current_time
                self.last_broadcast_time = current_time
                logger.info(f"[斑马线] 到达状态:area={area_ratio:.2f}, 无遮挡")
                return {
                    'voice_text': "斑马线到了可以过马路",
                    'priority': 80,
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': '到达'
                }
            # Repeat every REPEAT_INTERVALS['arrival'] seconds.
            elif current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['arrival']:
                self.last_broadcast_time = current_time
                return {
                    'voice_text': "斑马线到了可以过马路",
                    'priority': 80,
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': '到达'
                }
            # Timeout: auto-exit the arrival state after 30 s.
            elif current_time - self.arrival_first_broadcast_time > 30.0:
                logger.info("[斑马线] 到达状态超时30秒,自动退出")
                self.in_arrival_state = False
                return None

        # Downgrade: leave the arrival state when the area shrinks again.
        if self.in_arrival_state and area_ratio < 0.20:
            logger.info(f"[斑马线] 面积降至{area_ratio:.2f},退出到达状态")
            self.in_arrival_state = False
            # Clear the arrival mark so it can be announced again.
            self.broadcasted_thresholds.discard(self.THRESHOLDS['arrival'])

        return None
|
||||
def _is_area_stable(self, area_ratio, stability_frames=5) -> bool:
|
||||
"""检查面积是否稳定(避免抖动触发)"""
|
||||
if len(self.area_history) < stability_frames:
|
||||
return True # 初始阶段,认为稳定
|
||||
|
||||
recent_areas = [h['area'] for h in list(self.area_history)[-stability_frames:]]
|
||||
|
||||
# 检查最近N帧是否都在当前面积附近(±20%)
|
||||
for recent_area in recent_areas:
|
||||
if abs(recent_area - area_ratio) / max(area_ratio, 0.001) > 0.20:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _reset_if_needed(self):
|
||||
"""重置状态(斑马线消失时)"""
|
||||
if len(self.area_history) > 0:
|
||||
logger.info("[斑马线] 斑马线消失,重置状态")
|
||||
|
||||
self.broadcasted_thresholds.clear()
|
||||
self.area_history.clear()
|
||||
self.in_arrival_state = False
|
||||
self.last_position_zone = None
|
||||
|
||||
    def reset(self):
        """Unconditionally reset every piece of monitor state
        (thresholds, history, timers and flags)."""
        self.broadcasted_thresholds.clear()
        self.area_history.clear()
        self.in_arrival_state = False
        self.last_broadcast_time = 0
        self.arrival_first_broadcast_time = 0
        self.last_position_zone = None
        logger.info("[斑马线] 感知监控器已重置")
|
||||
    def is_in_arrival_state(self) -> bool:
        """Whether we are in the arrival state (callers use this to pause
        tactile-path voice prompts)."""
        return self.in_arrival_state
|
||||
def get_current_area(self) -> float:
|
||||
"""获取当前面积"""
|
||||
if len(self.area_history) > 0:
|
||||
return self.area_history[-1]['area']
|
||||
return 0.0
|
||||
|
||||
    def get_visualization_data(self, crosswalk_mask, area_ratio, center_x_ratio, center_y_ratio, has_occlusion) -> Dict[str, Any]:
        """
        Build the visualization payload (stage label, overlay color,
        position and state flags) for external drawing.
        Returns an empty dict when there is no mask.
        """
        if crosswalk_mask is None:
            return {}

        # Pick the stage label (overlay is orange throughout, with
        # opacity increasing by stage).
        if area_ratio >= self.THRESHOLDS['arrival']:
            stage = "到达"
            stage_color = "rgba(255, 165, 0, 0.5)"  # orange
        elif area_ratio >= self.THRESHOLDS['near']:
            stage = "接近"
            stage_color = "rgba(255, 165, 0, 0.45)"  # orange
        elif area_ratio >= self.THRESHOLDS['approaching']:
            stage = "靠近"
            stage_color = "rgba(255, 165, 0, 0.40)"  # orange
        else:
            stage = "发现"
            stage_color = "rgba(255, 165, 0, 0.35)"  # orange

        # Position label.
        position = self._get_position_description(center_x_ratio)

        return {
            'area_ratio': area_ratio,
            'stage': stage,
            'stage_color': stage_color,
            'position': position.replace("在画面", ""),  # strip the "在画面" prefix
            'center_x_ratio': center_x_ratio,
            'center_y_ratio': center_y_ratio,
            'has_occlusion': has_occlusion,
            'in_arrival': self.in_arrival_state
        }
||||
|
||||
|
||||
# Helper functions
def split_combined_voice(combined_text: str) -> list:
    """Split a combined prompt on the Chinese comma into separate clips.

    e.g. "远处发现斑马线,在画面左侧" -> ["远处发现斑马线", "在画面左侧"]
    """
    if ',' not in combined_text:
        return [combined_text]
    pieces = (part.strip() for part in combined_text.split(','))
    return [piece for piece in pieces if piece]
|
||||
77
docker-compose.yml
Normal file
77
docker-compose.yml
Normal file
@@ -0,0 +1,77 @@
|
||||
version: '3.8'  # NOTE(review): the top-level "version" key is obsolete in Compose v2 — consider removing

services:
  aiglass:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: aiglass
    restart: unless-stopped

    # Port mappings
    ports:
      - "8081:8081"        # web service
      - "12345:12345/udp"  # IMU UDP

    # Environment variables (read from the .env file)
    environment:
      - DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY}
      - BLIND_PATH_MODEL=${BLIND_PATH_MODEL:-model/yolo-seg.pt}
      - OBSTACLE_MODEL=${OBSTACLE_MODEL:-model/yoloe-11l-seg.pt}
      - YOLOE_MODEL_PATH=${YOLOE_MODEL_PATH:-model/yoloe-11l-seg.pt}
      - ENABLE_TTS=${ENABLE_TTS:-true}
      - TTS_INTERVAL_SEC=${TTS_INTERVAL_SEC:-1.0}
      - LOG_LEVEL=${LOG_LEVEL:-INFO}

    # Volume mounts
    volumes:
      - ./model:/app/model:ro          # model files (read-only)
      - ./recordings:/app/recordings   # recordings (writable)
      - ./music:/app/music:ro          # audio clips (read-only)
      - ./voice:/app/voice:ro          # voice clips (read-only)
      - ./static:/app/static:ro        # static assets (read-only)
      - ./templates:/app/templates:ro  # templates (read-only)

    # GPU support (requires nvidia-docker)
    deploy:
      resources:
        reservations:
          devices:
            - driver: nvidia
              count: 1
              capabilities: [gpu]

    # Dependent services (optional, e.g. a database)
    # depends_on:
    #   - redis

    # Network mode
    network_mode: bridge

    # Health check
    healthcheck:
      test: ["CMD", "curl", "-f", "http://localhost:8081/api/health"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 40s

# Optional: Redis for caching
# redis:
#   image: redis:7-alpine
#   container_name: aiglass-redis
#   restart: unless-stopped
#   ports:
#     - "6379:6379"
#   volumes:
#     - redis-data:/data

# Optional: data volumes
# volumes:
#   redis-data:

# Optional: custom network
# networks:
#   aiglass-network:
#     driver: bridge
|
||||
202
edge_tts_client.py
Normal file
202
edge_tts_client.py
Normal file
@@ -0,0 +1,202 @@
|
||||
# edge_tts_client.py
# -*- coding: utf-8 -*-
"""
EdgeTTS streaming speech-synthesis client — Day 21.

Highlights:
- completely free
- streaming output (synthesize while playing)
- low latency
"""

import os
import asyncio
import edge_tts
from typing import AsyncGenerator, Optional

# Default voice.
DEFAULT_VOICE = os.getenv("EDGE_TTS_VOICE", "zh-CN-XiaoxiaoNeural")

# Speaking-rate adjustment ("+0%", "+10%", "-10%", ...).
DEFAULT_RATE = os.getenv("EDGE_TTS_RATE", "+0%")

# Volume adjustment.
DEFAULT_VOLUME = os.getenv("EDGE_TTS_VOLUME", "+0%")
|
||||
async def text_to_speech_stream(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    volume: str = DEFAULT_VOLUME,
) -> AsyncGenerator[bytes, None]:
    """
    Streaming text-to-speech.

    Args:
        text: text to synthesize
        voice: voice name
        rate: speaking rate
        volume: volume

    Yields:
        MP3 audio chunks
    """
    if not text or not text.strip():
        return

    try:
        communicate = edge_tts.Communicate(
            text=text,
            voice=voice,
            rate=rate,
            volume=volume,
        )

        # Forward only audio chunks; metadata events are skipped.
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                yield chunk["data"]

    except Exception as e:
        print(f"[EdgeTTS] 合成失败: {e}")
|
||||
|
||||
async def text_to_speech(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    volume: str = DEFAULT_VOLUME,
) -> bytes:
    """Synthesize the whole text and return the complete MP3 payload.

    Args:
        text: text to synthesize
        voice: voice name
        rate: speaking rate
        volume: volume

    Returns:
        complete MP3 audio bytes
    """
    buffer = bytearray()
    async for piece in text_to_speech_stream(text, voice, rate, volume):
        buffer.extend(piece)
    return bytes(buffer)
|
||||
|
||||
|
||||
async def text_to_speech_pcm(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    target_sample_rate: int = 16000,
) -> bytes:
    """
    Text to PCM16 audio (for direct playback).

    Args:
        text: text to synthesize
        voice: voice name
        rate: speaking rate
        target_sample_rate: target sample rate

    Returns:
        PCM16 audio bytes (mono, 16-bit); b"" on failure
    """
    import io
    from pydub import AudioSegment

    # Synthesize the complete MP3 first.
    mp3_data = await text_to_speech(text, voice, rate)

    if not mp3_data:
        return b""

    try:
        # MP3 -> PCM conversion.
        audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))

        # Normalize sample rate / channel count / sample width.
        audio = audio.set_frame_rate(target_sample_rate)
        audio = audio.set_channels(1)      # mono
        audio = audio.set_sample_width(2)  # 16-bit

        return audio.raw_data

    except Exception as e:
        print(f"[EdgeTTS] PCM 转换失败: {e}")
        return b""
|
||||
|
||||
|
||||
async def text_to_speech_pcm_stream(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    target_sample_rate: int = 16000,
) -> AsyncGenerator[bytes, None]:
    """
    Streaming text to PCM16 audio.

    Because the MP3 of an utterance must be decoded as a whole, the text is
    synthesized segment by segment: a new segment ends at every punctuation
    mark.

    Args:
        text: text to synthesize
        voice: voice name
        rate: speaking rate
        target_sample_rate: target sample rate

    Yields:
        PCM16 audio chunks
    """
    import io
    from pydub import AudioSegment

    # Split the text on Chinese and ASCII punctuation.
    punctuation = "。,!?;:,.!?;:"
    segments = []
    current = ""

    for char in text:
        current += char
        if char in punctuation:
            segments.append(current.strip())
            current = ""

    # Keep any trailing text without closing punctuation.
    if current.strip():
        segments.append(current.strip())

    # Synthesize each segment in turn.
    for segment in segments:
        if not segment:
            continue

        try:
            mp3_data = await text_to_speech(segment, voice, rate)

            if mp3_data:
                # Decode and normalize to mono 16-bit at the target rate.
                audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))
                audio = audio.set_frame_rate(target_sample_rate)
                audio = audio.set_channels(1)
                audio = audio.set_sample_width(2)

                yield audio.raw_data

        except Exception as e:
            print(f"[EdgeTTS] 分段合成失败: {e}")
|
||||
|
||||
# Voice catalog (common Chinese voices).
CHINESE_VOICES = [
    "zh-CN-XiaoxiaoNeural",  # female, natural
    "zh-CN-YunxiNeural",     # male, natural
    "zh-CN-XiaoyiNeural",    # female, lively
    "zh-CN-YunjianNeural",   # male, newscast
    "zh-CN-XiaochenNeural",  # female, gentle
]
|
||||
|
||||
async def list_voices() -> list:
    """Return every available Chinese-locale voice."""
    all_voices = await edge_tts.list_voices()
    zh_voices = []
    for voice in all_voices:
        if voice["Locale"].startswith("zh"):
            zh_voices.append(voice)
    return zh_voices
||||
208
glm_client.py
Normal file
208
glm_client.py
Normal file
@@ -0,0 +1,208 @@
|
||||
# glm_client.py
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
GLM-4.6v-Flash LLM 客户端 - Day 22
|
||||
|
||||
使用官方 zai-sdk + glm-4.6v-flash 模型
|
||||
"""
|
||||
|
||||
import os
|
||||
import asyncio
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Optional
|
||||
from zai import ZhipuAiClient
|
||||
|
||||
# API configuration
# NOTE(security): a live-looking API key is hard-coded here as the env-var
# fallback. This credential is leaked in source control — rotate it and
# remove the default so the key is read from GLM_API_KEY only.
API_KEY = os.getenv(
    "GLM_API_KEY",
    "5915240ea48d4e93b454bc2412d1cc54.e054ej4pPqi9G6rc"
)
MODEL = "glm-4.6v-flash"  # upgraded to glm-4.6v-flash (vision-capable)

# Weekday names; index 0 = Monday, matching datetime.weekday()
WEEKDAY_MAP = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"]
|
||||
|
||||
|
||||
def get_system_prompt() -> str:
    """Build the system prompt dynamically so it carries the current date/time."""
    now = datetime.now()
    # Inline the formatted values directly; output is identical to composing
    # them in intermediate variables.
    return f"""你是一个视障辅助AI助手,安装在智能导盲眼镜上。
当前时间:{now.strftime("%H:%M")}
今天日期:{now.strftime("%Y年%m月%d日")} {WEEKDAY_MAP[now.weekday()]}

请用极简短的语言回答,每次回答不超过2-3句话。
避免冗长解释,只提供最关键的信息。
语气友好但简洁。"""
|
||||
|
||||
|
||||
# Client singleton and shared conversation state
_client = None               # lazily created ZhipuAiClient (see _get_client)
_conversation_history = []   # alternating {"role": ..., "content": ...} messages
MAX_HISTORY_TURNS = 5  # keep only the most recent 5 turns (user+assistant pairs)
|
||||
|
||||
|
||||
def _get_client() -> ZhipuAiClient:
    """Return the process-wide Zhipu AI client, creating it on first use."""
    global _client
    if _client is not None:
        return _client
    _client = ZhipuAiClient(api_key=API_KEY)
    return _client
|
||||
|
||||
|
||||
def clear_conversation_history():
    """Drop all remembered turns so the next chat call starts a fresh context."""
    global _conversation_history
    # Rebind (rather than clear in place) so any外部 alias keeps the old list.
    _conversation_history = []
    print("[GLM] 对话历史已清除")
|
||||
|
||||
|
||||
async def chat(user_message: str, image_base64: Optional[str] = None) -> str:
    """
    One chat turn with GLM-4.6v-Flash, with rolling context memory.

    The user message is appended to the shared history BEFORE the API call and
    rolled back if every retry fails, so failed turns leave no trace.

    Args:
        user_message: User message text.
        image_base64: Optional Base64-encoded JPEG image (enables vision input).

    Returns:
        The assistant reply text, "" when the API returns no choices, or a
        fixed apology string after all retries fail.
    """
    global _conversation_history
    client = _get_client()

    # Build the user message: a multimodal content list when an image is attached.
    if image_base64:
        user_content = [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
        ]
    else:
        user_content = user_message

    # Append the user message to the shared history.
    _conversation_history.append({"role": "user", "content": user_content})

    # Trim history (one turn = 1 user + 1 assistant = 2 messages).
    max_messages = MAX_HISTORY_TURNS * 2
    if len(_conversation_history) > max_messages:
        _conversation_history = _conversation_history[-max_messages:]

    # Full message list; the system prompt is regenerated so the embedded
    # clock time stays current.
    messages = [{"role": "system", "content": get_system_prompt()}] + _conversation_history

    # Day 22: retry with exponential backoff to ride out rate limiting.
    max_retries = 3
    retry_delay = 1  # initial delay, seconds

    for attempt in range(max_retries):
        try:
            # Day 22: upgraded to glm-4.6v-flash.
            # Per official docs the `thinking` parameter is required, even for
            # the vision model.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model=MODEL,
                messages=messages,
                thinking={"type": "disabled"},  # explicitly disable thinking to cut latency
            )

            if response.choices and len(response.choices) > 0:
                ai_reply = response.choices[0].message.content.strip()
                # Remember the assistant reply for the next turn.
                _conversation_history.append({"role": "assistant", "content": ai_reply})
                print(f"[GLM] 回复: {ai_reply[:50]}..." if len(ai_reply) > 50 else f"[GLM] 回复: {ai_reply}")
                return ai_reply
            return ""

        except Exception as e:
            error_str = str(e)
            # Rate-limit errors: HTTP 429 or vendor error code 1305.
            if "429" in error_str or "1305" in error_str or "请求过多" in error_str:
                if attempt < max_retries - 1:
                    print(f"[GLM] 速率限制,{retry_delay}秒后重试... (尝试 {attempt + 1}/{max_retries})")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
                    continue

            print(f"[GLM] 调用失败: {e}")
            import traceback
            traceback.print_exc()
            break

    # All retries failed: roll back the pending user message.
    if _conversation_history and _conversation_history[-1]["role"] == "user":
        _conversation_history.pop()
    return "抱歉,我暂时无法回答。"
|
||||
|
||||
|
||||
async def chat_stream(user_message: str, image_base64: Optional[str] = None) -> AsyncGenerator[str, None]:
    """
    Streaming chat (yields reply fragments as they arrive) — GLM-4.6v-Flash.

    NOTE(review): only the initial `create(...)` call is moved off the event
    loop via asyncio.to_thread; the `for chunk in response` loop below performs
    blocking network reads inside this async generator and stalls the event
    loop between chunks — consider pulling chunks via a thread as well.

    Args:
        user_message: User message text.
        image_base64: Optional Base64-encoded image.

    Yields:
        Text fragments of the AI reply, or a fixed apology string on failure.
    """
    global _conversation_history
    client = _get_client()

    # Build the user message (multimodal content list when an image is attached).
    if image_base64:
        user_content = [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
        ]
    else:
        user_content = user_message

    # Append the user message to the shared history.
    _conversation_history.append({"role": "user", "content": user_content})

    # Trim history to the last MAX_HISTORY_TURNS turns.
    max_messages = MAX_HISTORY_TURNS * 2
    if len(_conversation_history) > max_messages:
        _conversation_history = _conversation_history[-max_messages:]

    # Full message list; the system prompt is regenerated for the current time.
    messages = [{"role": "system", "content": get_system_prompt()}] + _conversation_history

    full_response = ""

    try:
        # Streaming call.
        # Day 22: upgraded to glm-4.6v-flash.
        # Per official docs the `thinking` parameter is required.
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model=MODEL,
            messages=messages,
            thinking={"type": "disabled"},
            stream=True,
        )

        for chunk in response:
            if chunk.choices[0].delta.content:
                text = chunk.choices[0].delta.content
                full_response += text
                yield text

        # Store the complete reply in history.
        if full_response:
            _conversation_history.append({"role": "assistant", "content": full_response})
            print(f"[GLM] 流式完成: {full_response[:50]}..." if len(full_response) > 50 else f"[GLM] 流式完成: {full_response}")

    except Exception as e:
        print(f"[GLM] 流式调用失败: {e}")
        import traceback
        traceback.print_exc()
        # Roll back the user message appended above.
        if _conversation_history and _conversation_history[-1]["role"] == "user":
            _conversation_history.pop()
        yield "抱歉,我暂时无法回答。"
|
||||
271
gpu_parallel.py
Normal file
271
gpu_parallel.py
Normal file
@@ -0,0 +1,271 @@
|
||||
# gpu_parallel.py
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
Day 20: GPU 并行推理优化模块
|
||||
使用 CUDA Stream 让盲道检测和障碍物检测并行执行
|
||||
"""
|
||||
import os
|
||||
import time
|
||||
import torch
|
||||
import numpy as np
|
||||
from typing import Tuple, Optional, List, Any
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Module-level singletons (lazily initialized on first use)
_cuda_streams = None       # list of torch.cuda.Stream, [] on failure, None before init
_parallel_executor = None  # shared ThreadPoolExecutor for CPU post-processing
|
||||
|
||||
def _init_cuda_streams():
    """Create two CUDA streams once; returns the cached list (None without CUDA)."""
    global _cuda_streams
    # Already initialized, or no CUDA available — nothing to do.
    if _cuda_streams is not None or not torch.cuda.is_available():
        return _cuda_streams
    try:
        created = []
        for _ in range(2):
            created.append(torch.cuda.Stream())
        _cuda_streams = created
        logger.info("[GPU_PARALLEL] 已创建 2 个 CUDA Stream")
    except Exception as e:
        logger.warning(f"[GPU_PARALLEL] 创建 CUDA Stream 失败: {e}")
        _cuda_streams = []
    return _cuda_streams
|
||||
|
||||
def _init_parallel_executor():
    """Create (once) the two-worker thread pool used for detection post-processing."""
    global _parallel_executor
    if _parallel_executor is not None:
        return _parallel_executor
    _parallel_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="gpu_post")
    logger.info("[GPU_PARALLEL] 已创建 GPU 后处理线程池")
    return _parallel_executor
|
||||
|
||||
class ParallelDetector:
    """
    Parallel detector — runs blind-path/crosswalk segmentation and obstacle
    detection concurrently (thread pool), with a serial fallback.

    Usage:
        detector = ParallelDetector(yolo_model, obstacle_detector)
        blind_mask, cross_mask, obstacles = detector.detect_all(image, path_mask)
    """

    def __init__(self, yolo_model, obstacle_detector):
        # yolo_model: Ultralytics segmentation model (class 0=crosswalk, 1=blind path).
        # obstacle_detector: object exposing detect(image, path_mask=...) -> list.
        self.yolo_model = yolo_model
        self.obstacle_detector = obstacle_detector

        # Detection parameters (overridable via environment variables).
        self.imgsz = int(os.getenv("AIGLASS_YOLO_IMGSZ", "480"))
        self.use_half = os.getenv("AIGLASS_YOLO_HALF", "1") == "1"
        self.blind_conf_threshold = 0.20
        self.cross_conf_threshold = 0.30

        # Eagerly create the shared CUDA streams and the post-processing pool.
        _init_cuda_streams()
        _init_parallel_executor()

        logger.info(f"[GPU_PARALLEL] ParallelDetector 初始化完成: imgsz={self.imgsz}, half={self.use_half}")

    def detect_all(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray] = None
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """
        Run all detections, in parallel when CUDA streams are available.

        Args:
            image: BGR image.
            path_mask: Blind-path mask (used for obstacle filtering).

        Returns:
            (blind_path_mask, crosswalk_mask, obstacles)
        """
        t0 = time.perf_counter()  # NOTE(review): t0 is never read — dead timing code

        streams = _init_cuda_streams()

        if streams and len(streams) >= 2:
            return self._detect_with_streams(image, path_mask, streams)
        else:
            # Fall back to serial execution (no CUDA / stream creation failed).
            return self._detect_serial(image, path_mask)

    def _detect_with_streams(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray],
        streams: List[torch.cuda.Stream]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """Day 21: true parallelism via ThreadPoolExecutor (replaces the
        ineffective CUDA Stream approach).

        NOTE(review): `streams` is accepted but never used since the Day 21
        rewrite; only its presence (checked by detect_all) gates this path.
        NOTE(review): as_completed(..., timeout=2.0) raises TimeoutError when a
        task overruns; that exception propagates uncaught — confirm intended.
        """

        blind_mask = None
        cross_mask = None
        obstacles = []

        executor = _init_parallel_executor()

        # Define the two detection tasks.
        def task_blind_path():
            # One YOLO call segments both blind path and crosswalk.
            if self.yolo_model is None:
                return None, None
            try:
                results = self.yolo_model.predict(
                    image,
                    verbose=False,
                    conf=min(self.blind_conf_threshold, self.cross_conf_threshold),
                    classes=[0, 1],  # 0=crosswalk, 1=blind_path
                    imgsz=self.imgsz,
                    half=self.use_half
                )
                if results and results[0] and results[0].masks is not None:
                    return self._parse_seg_results(results[0], image.shape)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 盲道检测失败: {e}")
            return None, None

        def task_obstacles():
            # Obstacle detection restricted to the supplied path mask.
            if self.obstacle_detector is None:
                return []
            try:
                return self.obstacle_detector.detect(image, path_mask=path_mask)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 障碍物检测失败: {e}")
                return []

        # Submit both tasks concurrently.
        from concurrent.futures import as_completed
        futures = {
            executor.submit(task_blind_path): 'blind',
            executor.submit(task_obstacles): 'obstacle'
        }

        # Collect results as tasks finish.
        for future in as_completed(futures, timeout=2.0):
            task_type = futures[future]
            try:
                result = future.result()
                if task_type == 'blind':
                    blind_mask, cross_mask = result
                else:
                    obstacles = result
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] {task_type}任务异常: {e}")

        return blind_mask, cross_mask, obstacles

    def _detect_serial(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """Serial detection (fallback mode when streams are unavailable)."""

        blind_mask = None
        cross_mask = None
        obstacles = []

        # Blind-path / crosswalk segmentation.
        if self.yolo_model is not None:
            try:
                results = self.yolo_model.predict(
                    image,
                    verbose=False,
                    conf=min(self.blind_conf_threshold, self.cross_conf_threshold),
                    classes=[0, 1],
                    imgsz=self.imgsz,
                    half=self.use_half
                )

                if results and results[0] and results[0].masks is not None:
                    blind_mask, cross_mask = self._parse_seg_results(results[0], image.shape)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 盲道检测失败: {e}")

        # Obstacle detection.
        if self.obstacle_detector is not None:
            try:
                obstacles = self.obstacle_detector.detect(image, path_mask=path_mask)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 障碍物检测失败: {e}")

        return blind_mask, cross_mask, obstacles

    def _parse_seg_results(
        self,
        result,
        image_shape: Tuple[int, int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Split a YOLO segmentation result into per-class union masks.

        Returns (blind_path_mask, crosswalk_mask); each is a full-resolution
        uint8 0/255 mask, or None when no instance of that class passed its
        confidence threshold.
        """

        blind_mask = None
        cross_mask = None

        h, w = image_shape[:2]

        if result.masks is None or result.boxes is None:
            return None, None

        for mask_tensor, conf_tensor, cls_tensor in zip(
            result.masks.data, result.boxes.conf, result.boxes.cls
        ):
            class_id = int(cls_tensor.item())
            confidence = float(conf_tensor.item())

            # Per-class confidence filtering.
            if class_id == 1 and confidence < self.blind_conf_threshold:
                continue
            if class_id == 0 and confidence < self.cross_conf_threshold:
                continue

            # Convert the mask tensor to a full-size binary mask.
            current_mask = self._tensor_to_mask(mask_tensor, w, h)

            if class_id == 1:  # blind path
                if blind_mask is None:
                    blind_mask = current_mask
                else:
                    blind_mask = np.bitwise_or(blind_mask, current_mask)
            elif class_id == 0:  # crosswalk
                if cross_mask is None:
                    cross_mask = current_mask
                else:
                    cross_mask = np.bitwise_or(cross_mask, current_mask)

        return blind_mask, cross_mask

    def _tensor_to_mask(
        self,
        mask_tensor: torch.Tensor,
        out_w: int,
        out_h: int
    ) -> np.ndarray:
        """Convert a PyTorch mask tensor to a (out_h, out_w) uint8 0/255 mask."""
        import cv2

        # Move to host memory.
        if mask_tensor.is_cuda:
            mask_np = mask_tensor.cpu().numpy()
        else:
            mask_np = mask_tensor.numpy()

        # Resize; nearest-neighbor keeps mask values from blending.
        if mask_np.shape[0] != out_h or mask_np.shape[1] != out_w:
            mask_np = cv2.resize(mask_np, (out_w, out_h), interpolation=cv2.INTER_NEAREST)

        # Binarize to 0/255.
        mask_np = (mask_np > 0.5).astype(np.uint8) * 255

        return mask_np
|
||||
|
||||
|
||||
def detect_all_parallel(
    yolo_model,
    obstacle_detector,
    image: np.ndarray,
    path_mask: Optional[np.ndarray] = None
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
    """
    Convenience wrapper: run all detections in parallel.

    Intended as a drop-in replacement for the serial detection in
    workflow_blindpath.py.

    NOTE(review): builds a fresh ParallelDetector on every call, re-reading
    env vars and logging each time — consider caching one per model pair.
    """
    detector = ParallelDetector(yolo_model, obstacle_detector)
    return detector.detect_all(image, path_mask)
|
||||
BIN
hand_landmarker.task
Normal file
BIN
hand_landmarker.task
Normal file
Binary file not shown.
BIN
model/SenseVoiceSmall/chn_jpn_yue_eng_ko_spectok.bpe.model
Normal file
BIN
model/SenseVoiceSmall/chn_jpn_yue_eng_ko_spectok.bpe.model
Normal file
Binary file not shown.
97
model/SenseVoiceSmall/config.yaml
Normal file
97
model/SenseVoiceSmall/config.yaml
Normal file
@@ -0,0 +1,97 @@
|
||||
encoder: SenseVoiceEncoderSmall
|
||||
encoder_conf:
|
||||
output_size: 512
|
||||
attention_heads: 4
|
||||
linear_units: 2048
|
||||
num_blocks: 50
|
||||
tp_blocks: 20
|
||||
dropout_rate: 0.1
|
||||
positional_dropout_rate: 0.1
|
||||
attention_dropout_rate: 0.1
|
||||
input_layer: pe
|
||||
pos_enc_class: SinusoidalPositionEncoder
|
||||
normalize_before: true
|
||||
kernel_size: 11
|
||||
sanm_shfit: 0
|
||||
selfattention_layer_type: sanm
|
||||
|
||||
|
||||
model: SenseVoiceSmall
|
||||
model_conf:
|
||||
length_normalized_loss: true
|
||||
sos: 1
|
||||
eos: 2
|
||||
ignore_id: -1
|
||||
|
||||
tokenizer: SentencepiecesTokenizer
|
||||
tokenizer_conf:
|
||||
bpemodel: null
|
||||
unk_symbol: <unk>
|
||||
split_with_space: true
|
||||
|
||||
frontend: WavFrontend
|
||||
frontend_conf:
|
||||
fs: 16000
|
||||
window: hamming
|
||||
n_mels: 80
|
||||
frame_length: 25
|
||||
frame_shift: 10
|
||||
lfr_m: 7
|
||||
lfr_n: 6
|
||||
cmvn_file: null
|
||||
|
||||
|
||||
dataset: SenseVoiceCTCDataset
|
||||
dataset_conf:
|
||||
index_ds: IndexDSJsonl
|
||||
batch_sampler: EspnetStyleBatchSampler
|
||||
data_split_num: 32
|
||||
batch_type: token
|
||||
batch_size: 14000
|
||||
max_token_length: 2000
|
||||
min_token_length: 60
|
||||
max_source_length: 2000
|
||||
min_source_length: 60
|
||||
max_target_length: 200
|
||||
min_target_length: 0
|
||||
shuffle: true
|
||||
num_workers: 4
|
||||
sos: ${model_conf.sos}
|
||||
eos: ${model_conf.eos}
|
||||
IndexDSJsonl: IndexDSJsonl
|
||||
retry: 20
|
||||
|
||||
train_conf:
|
||||
accum_grad: 1
|
||||
grad_clip: 5
|
||||
max_epoch: 20
|
||||
keep_nbest_models: 10
|
||||
avg_nbest_model: 10
|
||||
log_interval: 100
|
||||
resume: true
|
||||
validate_interval: 10000
|
||||
save_checkpoint_interval: 10000
|
||||
|
||||
optim: adamw
|
||||
optim_conf:
|
||||
lr: 0.00002
|
||||
scheduler: warmuplr
|
||||
scheduler_conf:
|
||||
warmup_steps: 25000
|
||||
|
||||
specaug: SpecAugLFR
|
||||
specaug_conf:
|
||||
apply_time_warp: false
|
||||
time_warp_window: 5
|
||||
time_warp_mode: bicubic
|
||||
apply_freq_mask: true
|
||||
freq_mask_width_range:
|
||||
- 0
|
||||
- 30
|
||||
lfr_rate: 6
|
||||
num_freq_mask: 1
|
||||
apply_time_mask: true
|
||||
time_mask_width_range:
|
||||
- 0
|
||||
- 12
|
||||
num_time_mask: 1
|
||||
14
model/SenseVoiceSmall/configuration.json
Normal file
14
model/SenseVoiceSmall/configuration.json
Normal file
@@ -0,0 +1,14 @@
|
||||
{
|
||||
"framework": "pytorch",
|
||||
"task" : "auto-speech-recognition",
|
||||
"model": {"type" : "funasr"},
|
||||
"pipeline": {"type":"funasr-pipeline"},
|
||||
"model_name_in_hub": {
|
||||
"ms":"",
|
||||
"hf":""},
|
||||
"file_path_metas": {
|
||||
"init_param":"model.pt",
|
||||
"config":"config.yaml",
|
||||
"tokenizer_conf": {"bpemodel": "chn_jpn_yue_eng_ko_spectok.bpe.model"},
|
||||
"frontend_conf":{"cmvn_file": "am.mvn"}}
|
||||
}
|
||||
37
model_utils.py
Normal file
37
model_utils.py
Normal file
@@ -0,0 +1,37 @@
|
||||
# model_utils.py - Day 20 TensorRT 模型加载工具
|
||||
"""
|
||||
优先加载 TensorRT .engine 文件,不存在时回退到 .pt
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def get_best_model_path(pt_path: str) -> str:
    """
    Pick the best model file for a given .pt path:
    1. Prefer a sibling .engine file (TensorRT-accelerated) when it exists.
    2. Otherwise fall back to the original .pt file.

    Args:
        pt_path: Original .pt model path (e.g. 'model/yolo-seg.pt').

    Returns:
        The preferred model path (.engine or .pt). Non-.pt paths are returned
        unchanged.
    """
    if not pt_path.endswith('.pt'):
        return pt_path

    # Swap only the trailing '.pt' suffix. The previous
    # pt_path.replace('.pt', '.engine') replaced the FIRST occurrence anywhere
    # in the string, corrupting paths whose directories also contain '.pt'
    # (e.g. 'runs/v1.pt/model.pt' -> 'runs/v1.engine/model.engine').
    engine_path = pt_path[:-3] + '.engine'

    if os.path.exists(engine_path):
        print(f"[MODEL] 🚀 使用 TensorRT 加速: {engine_path}")
        return engine_path
    else:
        print(f"[MODEL] 使用 PyTorch 模型: {pt_path}")
        return pt_path
|
||||
|
||||
|
||||
def is_tensorrt_engine(model_path: str) -> bool:
    """True when the path points at a TensorRT engine (.engine or .trt file)."""
    return model_path.endswith(('.engine', '.trt'))
|
||||
|
||||
159
models.py
Normal file
159
models.py
Normal file
@@ -0,0 +1,159 @@
|
||||
# app/models.py
import os
import logging
import torch
from threading import Semaphore
from contextlib import contextmanager
from typing import List
from app.cloud.obstacle_detector_client import ObstacleDetectorClient
# ==========================================================
# 0. Import all required model wrapper classes (clients) and Ultralytics bases
# ==========================================================
# Wrapper classes used by the street-crossing workflow
from app.cloud.crosswalk_detector_client import CrosswalkDetector
from app.cloud.coco_perception_client import COCOClient
# NOTE(review): this import shadows the app.cloud ObstacleDetectorClient
# imported above — the top-level module's class is the one actually bound.
# Confirm which module is intended and drop the duplicate import.
from obstacle_detector_client import ObstacleDetectorClient

# Day 20: TensorRT model-loading helpers
from model_utils import get_best_model_path, is_tensorrt_engine

# Ultralytics classes used directly by the blind-path workflow
from ultralytics import YOLO, YOLOE
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ==========================================================
|
||||
# 1. 全局设备与并发控制 (统一管理)
|
||||
# ==========================================================
|
||||
# Inference device, e.g. "cuda:0", "cuda:1" or "cpu".
DEVICE = os.getenv("AIGLASS_DEVICE", "cuda:0")
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    logger.warning(f"AIGLASS_DEVICE={DEVICE} 但未检测到 CUDA,将回退到 CPU")
    DEVICE = "cpu"
IS_CUDA = DEVICE.startswith("cuda")

# AMP (automatic mixed precision) policy: "bf16", "fp16" or "off".
AMP_POLICY = os.getenv("AIGLASS_AMP", "bf16").lower()
# NOTE(review): conditional-expression grouping means the trailing
# `if IS_CUDA else None` guard applies only to the fp16 branch; with
# AMP_POLICY == "bf16" this yields torch.bfloat16 even on CPU. Harmless
# today because gpu_infer_slot() re-checks IS_CUDA, but fragile — confirm.
AMP_DTYPE = torch.bfloat16 if AMP_POLICY == "bf16" else (
    torch.float16 if AMP_POLICY == "fp16" else None) if IS_CUDA else None

# Core: the single global GPU concurrency semaphore shared by all workflows.
GPU_SLOTS = int(os.getenv("AIGLASS_GPU_SLOTS", "2"))
gpu_semaphore = Semaphore(GPU_SLOTS)
|
||||
|
||||
|
||||
# Unified inference context manager — every workflow should call models through it
@contextmanager
def gpu_infer_slot():
    """
    Centralizes: GPU concurrency limiting + torch.inference_mode() + AMP autocast.
    """
    with gpu_semaphore:
        amp_enabled = IS_CUDA and AMP_POLICY != "off" and AMP_DTYPE is not None
        if amp_enabled:
            with torch.inference_mode(), torch.amp.autocast('cuda', dtype=AMP_DTYPE):
                yield
        else:
            with torch.inference_mode():
                yield
|
||||
|
||||
|
||||
# cuDNN autotuner: picks the fastest conv algorithms for fixed input shapes.
try:
    if IS_CUDA:
        torch.backends.cudnn.benchmark = True
except Exception:
    # Best-effort: missing cuDNN support must not block startup.
    pass
|
||||
|
||||
# ==========================================================
# 2. Global model instances (all initialized to None)
# ==========================================================

# --- Street-crossing workflow models (wrapped in client classes) ---
crosswalk_detector_client: CrosswalkDetector = None
coco_client: COCOClient = None
# ObstacleDetectorClient serves as the shared obstacle detector for all scenes
obstacle_detector_client: ObstacleDetectorClient = None

# --- Blind-path workflow models (Ultralytics classes used directly) ---
# They handle segmentation and path planning; detection logic differs from the
# street-crossing scenario.
blindpath_seg_model: YOLO = None
# Obstacle detection reuses obstacle_detector_client, but the YOLOE text
# embeddings must be kept separately.
blindpath_whitelist_embeddings = None

# Global load-state flag (guards against double initialization)
models_are_loaded = False
|
||||
|
||||
|
||||
# ==========================================================
# 3. Unified model loader (called by celery.py at startup)
# ==========================================================
def init_all_models():
    """
    Called exactly once when the Celery worker process starts.
    Loads every model required by all workflows into the module globals above.
    """
    global models_are_loaded
    if models_are_loaded:
        return

    logger.info(f"========= 🚀 开始全局模型预加载 (目标设备: {DEVICE}) =========")

    try:
        # --- [1] Shared obstacle detector (ObstacleDetectorClient) ---
        global obstacle_detector_client
        logger.info("[1/4] 正在加载通用障碍物检测模型 (ObstacleDetectorClient)...")
        # Day 20: prefer the TensorRT engine when present
        obs_model_path = get_best_model_path('model/yoloe-11l-seg.pt')
        obstacle_detector_client = ObstacleDetectorClient(model_path=obs_model_path)

        # Day 20: TensorRT engines must not be moved with .to()
        if not is_tensorrt_engine(obs_model_path):
            if hasattr(obstacle_detector_client, 'model') and obstacle_detector_client.model is not None:
                obstacle_detector_client.model.to(DEVICE)

        logger.info("...通用障碍物检测模型加载成功。")

        # --- [2] Street-crossing models (client wrappers) ---
        global crosswalk_detector_client, coco_client
        logger.info("[2/4] 正在加载过马路分割模型 (CrosswalkDetector)...")
        # Day 20: prefer the TensorRT engine when present
        crosswalk_model_path = get_best_model_path('model/yolo-seg.pt')
        crosswalk_detector_client = CrosswalkDetector(model_path=crosswalk_model_path)
        # Day 20: TensorRT engines must not be moved with .to()
        if not is_tensorrt_engine(crosswalk_model_path):
            if hasattr(crosswalk_detector_client, 'model') and crosswalk_detector_client.model is not None:
                crosswalk_detector_client.model.to(DEVICE)
        logger.info("...过马路分割模型加载成功。")

        logger.info("[3/4] 正在加载通用感知模型 (COCOClient)...")
        coco_client = COCOClient(model_path='model/yolov8l-world.pt')
        # Move its internal YOLO model to the target device
        if hasattr(coco_client, 'model') and coco_client.model is not None:
            coco_client.model.to(DEVICE)
        logger.info("...通用感知模型加载成功。")

        # --- [4] Blind-path models ---
        global blindpath_seg_model, blindpath_whitelist_embeddings
        logger.info("[4/4] 正在加载盲道专用分割模型 (YOLO)...")
        # Day 20: prefer the TensorRT engine when present
        blindpath_model_path = get_best_model_path('model/yolo-seg.pt')
        blindpath_seg_model = YOLO(blindpath_model_path)
        # Day 20: TensorRT engines support neither .to() nor .fuse()
        if not is_tensorrt_engine(blindpath_model_path):
            blindpath_seg_model.to(DEVICE)
            blindpath_seg_model.fuse()
        logger.info("...盲道专用分割模型加载成功。")

        # Keep a reference to the YOLOE text embeddings for the blind-path workflow
        if obstacle_detector_client:
            blindpath_whitelist_embeddings = obstacle_detector_client.whitelist_embeddings
            logger.info("...已为盲道工作流链接障碍物模型特征。")

        # All models loaded.
        models_are_loaded = True
        logger.info("========= ✅ 所有模型已成功预加载。Worker准备就绪! =========")

    except Exception as e:
        logger.error(f"模型预加载过程中发生严重错误: {e}", exc_info=True)
        # Re-raise so the Celery worker fails to start — a worker without
        # models is useless; surfacing the problem early is the right behavior.
        raise
|
||||
BIN
music/converted_向上.wav
Normal file
BIN
music/converted_向上.wav
Normal file
Binary file not shown.
BIN
music/converted_向下.wav
Normal file
BIN
music/converted_向下.wav
Normal file
Binary file not shown.
BIN
music/converted_向前.wav
Normal file
BIN
music/converted_向前.wav
Normal file
Binary file not shown.
BIN
music/converted_向右.wav
Normal file
BIN
music/converted_向右.wav
Normal file
Binary file not shown.
BIN
music/converted_向后.wav
Normal file
BIN
music/converted_向后.wav
Normal file
Binary file not shown.
BIN
music/converted_向左.wav
Normal file
BIN
music/converted_向左.wav
Normal file
Binary file not shown.
BIN
music/converted_已对中.wav
Normal file
BIN
music/converted_已对中.wav
Normal file
Binary file not shown.
BIN
music/converted_找到啦.wav
Normal file
BIN
music/converted_找到啦.wav
Normal file
Binary file not shown.
BIN
music/converted_拿到啦.wav
Normal file
BIN
music/converted_拿到啦.wav
Normal file
Binary file not shown.
BIN
music/converted_音频1.WAV
Normal file
BIN
music/converted_音频1.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频2.WAV
Normal file
BIN
music/converted_音频2.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频3.WAV
Normal file
BIN
music/converted_音频3.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频4.WAV
Normal file
BIN
music/converted_音频4.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频5.WAV
Normal file
BIN
music/converted_音频5.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频6.WAV
Normal file
BIN
music/converted_音频6.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频7.WAV
Normal file
BIN
music/converted_音频7.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频8.WAV
Normal file
BIN
music/converted_音频8.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频9.WAV
Normal file
BIN
music/converted_音频9.WAV
Normal file
Binary file not shown.
1
music/向上.txt
Normal file
1
music/向上.txt
Normal file
@@ -0,0 +1 @@
|
||||
向上。如果还有其他想法,你可以随时告诉我哦。
|
||||
BIN
music/向上.wav
Normal file
BIN
music/向上.wav
Normal file
Binary file not shown.
1
music/向下.txt
Normal file
1
music/向下.txt
Normal file
@@ -0,0 +1 @@
|
||||
向下。如果还有啥想法,你可以再跟我说哦。
|
||||
BIN
music/向下.wav
Normal file
BIN
music/向下.wav
Normal file
Binary file not shown.
1
music/向前.txt
Normal file
1
music/向前.txt
Normal file
@@ -0,0 +1 @@
|
||||
向前。如果还有啥想法,你可以再跟我说哦。
|
||||
BIN
music/向前.wav
Normal file
BIN
music/向前.wav
Normal file
Binary file not shown.
1
music/向右.txt
Normal file
1
music/向右.txt
Normal file
@@ -0,0 +1 @@
|
||||
向右。如果还有啥想法,你可以再跟我说哦。
|
||||
BIN
music/向右.wav
Normal file
BIN
music/向右.wav
Normal file
Binary file not shown.
1
music/向后.txt
Normal file
1
music/向后.txt
Normal file
@@ -0,0 +1 @@
|
||||
向后。如果还有啥想法,你可以再跟我说哦。
|
||||
BIN
music/向后.wav
Normal file
BIN
music/向后.wav
Normal file
Binary file not shown.
1
music/向左.txt
Normal file
1
music/向左.txt
Normal file
@@ -0,0 +1 @@
|
||||
向左。如果还有其他想法,你可以随时告诉我哦。
|
||||
BIN
music/向左.wav
Normal file
BIN
music/向左.wav
Normal file
Binary file not shown.
BIN
music/在画面中间.WAV
Normal file
BIN
music/在画面中间.WAV
Normal file
Binary file not shown.
1
music/在画面中间.txt
Normal file
1
music/在画面中间.txt
Normal file
@@ -0,0 +1 @@
|
||||
“在画面中间”
|
||||
BIN
music/在画面中间_24k.wav
Normal file
BIN
music/在画面中间_24k.wav
Normal file
Binary file not shown.
BIN
music/在画面右侧.WAV
Normal file
BIN
music/在画面右侧.WAV
Normal file
Binary file not shown.
1
music/在画面右侧.txt
Normal file
1
music/在画面右侧.txt
Normal file
@@ -0,0 +1 @@
|
||||
“在画面右侧”
|
||||
BIN
music/在画面右侧_24k.wav
Normal file
BIN
music/在画面右侧_24k.wav
Normal file
Binary file not shown.
BIN
music/在画面左侧.WAV
Normal file
BIN
music/在画面左侧.WAV
Normal file
Binary file not shown.
1
music/在画面左侧.txt
Normal file
1
music/在画面左侧.txt
Normal file
@@ -0,0 +1 @@
|
||||
“在画面左侧”
|
||||
BIN
music/在画面左侧_24k.wav
Normal file
BIN
music/在画面左侧_24k.wav
Normal file
Binary file not shown.
1
music/已对中.txt
Normal file
1
music/已对中.txt
Normal file
@@ -0,0 +1 @@
|
||||
已对正!如果还有其他想法或者问题,你可以随时告诉我哦。
|
||||
BIN
music/已对中.wav
Normal file
BIN
music/已对中.wav
Normal file
Binary file not shown.
1
music/找到啦.txt
Normal file
1
music/找到啦.txt
Normal file
@@ -0,0 +1 @@
|
||||
找到了!
|
||||
BIN
music/找到啦.wav
Normal file
BIN
music/找到啦.wav
Normal file
Binary file not shown.
1
music/拿到啦.txt
Normal file
1
music/拿到啦.txt
Normal file
@@ -0,0 +1 @@
|
||||
拿到了!如果还有啥问题,你可以再跟我说哦。
|
||||
BIN
music/拿到啦.wav
Normal file
BIN
music/拿到啦.wav
Normal file
Binary file not shown.
BIN
music/接近斑马线.WAV
Normal file
BIN
music/接近斑马线.WAV
Normal file
Binary file not shown.
1
music/接近斑马线.txt
Normal file
1
music/接近斑马线.txt
Normal file
@@ -0,0 +1 @@
|
||||
“接近斑马线”
|
||||
BIN
music/接近斑马线_24k.wav
Normal file
BIN
music/接近斑马线_24k.wav
Normal file
Binary file not shown.
BIN
music/斑马线到了可以过马路.WAV
Normal file
BIN
music/斑马线到了可以过马路.WAV
Normal file
Binary file not shown.
1
music/斑马线到了可以过马路.txt
Normal file
1
music/斑马线到了可以过马路.txt
Normal file
@@ -0,0 +1 @@
|
||||
“斑马线到了可以过马路”。如果还有类似的问题或者其他想法,你可以随时告诉我哦。
|
||||
BIN
music/斑马线到了可以过马路_24k.wav
Normal file
BIN
music/斑马线到了可以过马路_24k.wav
Normal file
Binary file not shown.
BIN
music/正在靠近斑马线.WAV
Normal file
BIN
music/正在靠近斑马线.WAV
Normal file
Binary file not shown.
1
music/正在靠近斑马线.txt
Normal file
1
music/正在靠近斑马线.txt
Normal file
@@ -0,0 +1 @@
|
||||
“正在靠近斑马线”
|
||||
BIN
music/正在靠近斑马线_24k.wav
Normal file
BIN
music/正在靠近斑马线_24k.wav
Normal file
Binary file not shown.
BIN
music/红灯.WAV
Normal file
BIN
music/红灯.WAV
Normal file
Binary file not shown.
BIN
music/绿灯.WAV
Normal file
BIN
music/绿灯.WAV
Normal file
Binary file not shown.
BIN
music/远处发现斑马线.WAV
Normal file
BIN
music/远处发现斑马线.WAV
Normal file
Binary file not shown.
1
music/远处发现斑马线.txt
Normal file
1
music/远处发现斑马线.txt
Normal file
@@ -0,0 +1 @@
|
||||
“远处发现斑马线”
|
||||
BIN
music/远处发现斑马线_24k.wav
Normal file
BIN
music/远处发现斑马线_24k.wav
Normal file
Binary file not shown.
BIN
music/音频1.WAV
Normal file
BIN
music/音频1.WAV
Normal file
Binary file not shown.
BIN
music/音频2.WAV
Normal file
BIN
music/音频2.WAV
Normal file
Binary file not shown.
BIN
music/音频3.WAV
Normal file
BIN
music/音频3.WAV
Normal file
Binary file not shown.
BIN
music/音频4.WAV
Normal file
BIN
music/音频4.WAV
Normal file
Binary file not shown.
BIN
music/音频5.WAV
Normal file
BIN
music/音频5.WAV
Normal file
Binary file not shown.
BIN
music/音频6.WAV
Normal file
BIN
music/音频6.WAV
Normal file
Binary file not shown.
BIN
music/音频7.WAV
Normal file
BIN
music/音频7.WAV
Normal file
Binary file not shown.
BIN
music/音频8.WAV
Normal file
BIN
music/音频8.WAV
Normal file
Binary file not shown.
BIN
music/音频9.WAV
Normal file
BIN
music/音频9.WAV
Normal file
Binary file not shown.
BIN
music/黄灯.WAV
Normal file
BIN
music/黄灯.WAV
Normal file
Binary file not shown.
700
navigation_master.py
Normal file
700
navigation_master.py
Normal file
@@ -0,0 +1,700 @@
|
||||
# navigation_master.py
|
||||
# -*- coding: utf-8 -*-
|
||||
import time
|
||||
import math
|
||||
import cv2
|
||||
import numpy as np
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional, Dict, Any, Deque, List, Tuple
|
||||
from collections import deque
|
||||
|
||||
# 工作流导入(与现有文件解耦)
|
||||
from workflow_blindpath import BlindPathNavigator, ProcessingResult as BlindResult
|
||||
from workflow_crossstreet import CrossStreetNavigator, CrossStreetResult as CrossResult
|
||||
|
||||
# ========== State constants (orchestrator FSM) ==========
IDLE = "IDLE"                          # idle / navigation not engaged
CHAT = "CHAT"                          # chat mode (no navigation; raw frames passed through)
BLINDPATH_NAV = "BLINDPATH_NAV"        # following the tactile path (delegates to BlindPathNavigator)
SEEKING_CROSSWALK = "SEEKING_CROSSWALK"# crosswalk spotted while on tactile path; aligning / approaching
WAIT_TRAFFIC_LIGHT = "WAIT_TRAFFIC_LIGHT" # at the crosswalk, waiting for the traffic light (optional/placeholder)
CROSSING = "CROSSING"                  # crossing the street (delegates to CrossStreetNavigator)
SEEKING_NEXT_BLINDPATH = "SEEKING_NEXT_BLINDPATH" # after crossing, looking for the next tactile-path entry
RECOVERY = "RECOVERY"                  # fallback/recovery state when perception is temporarily lost
TRAFFIC_LIGHT_DETECTION = "TRAFFIC_LIGHT_DETECTION" # standalone traffic-light detection mode
ITEM_SEARCH = "ITEM_SEARCH"            # item-search mode (navigation paused; frames handled by yolomedia)
|
||||
|
||||
# ========== Result structure ==========
@dataclass
class OrchestratorResult:
    """Per-frame output of the orchestrator.

    annotated_image: frame to display (may be the unmodified input, or None)
    guidance_text:   TTS guidance for this frame; "" when nothing should be spoken
    state:           the orchestrator state after processing this frame
    extras:          free-form diagnostic / debugging payload
    """
    annotated_image: Optional[np.ndarray]
    guidance_text: str
    state: str
    extras: Dict[str, Any]
|
||||
|
||||
# ========== Utility: signal smoothing via majority vote ==========
class MajorityFilter:
    """Sliding-window majority vote over string labels.

    The window is bounded (old samples fall out automatically). The label
    "unknown" is always out-voted by any concrete label, no matter how
    frequently it occurs in the window.
    """

    def __init__(self, size: int = 8):
        # Bounded deque: pushing past `size` silently evicts the oldest entry.
        self.buf: Deque[str] = deque(maxlen=size)

    def push(self, v: str):
        """Append one observation to the window."""
        self.buf.append(v)

    def majority(self) -> str:
        """Return the dominant label, or "unknown" for an empty window."""
        if not self.buf:
            return "unknown"
        tally = {}
        for label in self.buf:
            tally[label] = tally.get(label, 0) + 1
        # Rank concrete labels above "unknown", then by frequency. On ties the
        # first-inserted label wins — the same outcome as the stable sort used
        # previously, since max() returns the first maximal item encountered.
        winner, _ = max(tally.items(), key=lambda kv: (kv[0] != "unknown", kv[1]))
        return winner

    def history(self) -> List[str]:
        """Snapshot of the current window contents, oldest first."""
        return list(self.buf)

    def clear(self):
        """Drop all buffered observations."""
        self.buf.clear()
|
||||
|
||||
# ========== Traffic-light recognition ==========
class TrafficLightDetector:
    """Traffic-light detector with a two-tier strategy:

    1) Preferred: a yoloe_backend-style detector, probed duck-typed at
       construction time (interface varies between deployments).
    2) Fallback: with no model available, an HSV color heuristic searches
       the upper half of the frame for bright red/yellow/green blobs.

    Output: ('red' | 'green' | 'yellow' | 'unknown', meta-dict)
    """
    def __init__(self):
        self.has_backend = False
        self.backend = None
        try:
            # Optional dynamic import; adjust to the local yoloe_backend API.
            import yoloe_backend as _yeb  # noqa
            self.backend = _yeb
            self.has_backend = True
        except Exception:
            # Any failure (missing module, bad install) degrades to heuristic mode.
            self.has_backend = False
            self.backend = None

    def _try_backend(self, bgr: np.ndarray) -> Tuple[str, Dict[str, Any]]:
        """Call a yoloe_backend-style API leniently, since implementations differ:

        - first try backend.detect(image, target_classes=['traffic light'])
        - then  backend.infer_image(image), filtering for 'traffic light'
        - otherwise (or on any error) return ('unknown', reason)

        Result items are expected to carry a 'box' or 'bbox' [x1, y1, x2, y2];
        the color decision itself is done by HSV sampling inside the best box.
        """
        if not self.has_backend or self.backend is None:
            return "unknown", {"reason": "backend_not_available"}

        res = None
        try:
            if hasattr(self.backend, "detect"):
                # Assumed shape: [{'name': 'traffic light', 'box': [x1,y1,x2,y2], ...}, ...]
                res = self.backend.detect(bgr, target_classes=["traffic light"])
            elif hasattr(self.backend, "infer_image"):
                # Assumed shape: [{'label': 'traffic light', 'bbox': [x1,y1,x2,y2], ...}, ...]
                res = self.backend.infer_image(bgr)
            else:
                return "unknown", {"reason": "backend_no_suitable_api"}
        except Exception as e:
            return "unknown", {"reason": f"backend_failed:{e}"}

        if not res or len(res) == 0:
            return "unknown", {"reason": "no_detection"}

        # Take the largest box as the primary light, then classify by HSV.
        H, W = bgr.shape[:2]
        best = None
        best_area = 0
        boxes = []
        for item in res:
            # Normalize the box field ('box' or 'bbox'); skip malformed entries.
            if "box" in item and isinstance(item["box"], (list, tuple)) and len(item["box"]) == 4:
                x1, y1, x2, y2 = item["box"]
            elif "bbox" in item and isinstance(item["bbox"], (list, tuple)) and len(item["bbox"]) == 4:
                x1, y1, x2, y2 = item["bbox"]
            else:
                continue
            # Clamp to image bounds and discard degenerate boxes.
            x1 = int(max(0, min(W-1, x1))); x2 = int(max(0, min(W-1, x2)))
            y1 = int(max(0, min(H-1, y1))); y2 = int(max(0, min(H-1, y2)))
            if x2 <= x1 or y2 <= y1:
                continue
            area = (x2 - x1) * (y2 - y1)
            boxes.append((x1, y1, x2, y2, area))
            if area > best_area:
                best_area = area
                best = (x1, y1, x2, y2)

        if best is None:
            return "unknown", {"reason": "no_valid_bbox", "raw": len(res)}

        x1, y1, x2, y2 = best
        roi = bgr[y1:y2, x1:x2]
        color = self._classify_color_hsv(roi)
        return color, {"bbox": best, "count": len(res), "boxes": boxes}

    def _classify_color_hsv(self, roi_bgr: np.ndarray) -> str:
        """Threshold-based red/yellow/green decision on an ROI in HSV space;
        picks the color with the largest pixel coverage."""
        if roi_bgr is None or roi_bgr.size == 0:
            return "unknown"
        hsv = cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV)

        # Red hue wraps around 0/180 in OpenCV's HSV, hence two ranges.
        lower_red1 = np.array([0, 80, 120]); upper_red1 = np.array([10, 255, 255])
        lower_red2 = np.array([160, 80, 120]); upper_red2 = np.array([180, 255, 255])
        mask_r1 = cv2.inRange(hsv, lower_red1, upper_red1)
        mask_r2 = cv2.inRange(hsv, lower_red2, upper_red2)
        mask_red = cv2.bitwise_or(mask_r1, mask_r2)

        # Green range
        lower_green = np.array([40, 60, 120]); upper_green = np.array([90, 255, 255])
        mask_green = cv2.inRange(hsv, lower_green, upper_green)

        # Yellow range
        lower_yellow = np.array([18, 80, 150]); upper_yellow = np.array([35, 255, 255])
        mask_yellow = cv2.inRange(hsv, lower_yellow, upper_yellow)

        # Coverage ratios relative to the ROI area.
        total = roi_bgr.shape[0] * roi_bgr.shape[1] + 1e-6
        r_ratio = float(np.count_nonzero(mask_red)) / total
        g_ratio = float(np.count_nonzero(mask_green)) / total
        y_ratio = float(np.count_nonzero(mask_yellow)) / total

        # Minimum-coverage threshold suppresses weak responses from cluttered
        # backgrounds.
        thr = 0.03
        candidates = []
        if r_ratio > thr: candidates.append(("red", r_ratio))
        if g_ratio > thr: candidates.append(("green", g_ratio))
        if y_ratio > thr: candidates.append(("yellow", y_ratio))
        if not candidates:
            return "unknown"
        candidates.sort(key=lambda x: x[1], reverse=True)
        return candidates[0][0]

    def detect(self, bgr: np.ndarray) -> Tuple[str, Dict[str, Any]]:
        """Main entry point: try the backend first; if it yields nothing,
        fall back to an HSV heuristic over the upper half of the frame."""
        # 1) Backend attempt
        if self.has_backend:
            color, meta = self._try_backend(bgr)
            if color != "unknown":
                return color, {"method": "backend", **meta}

        # 2) Fallback: classify the whole upper-half ROI by dominant HSV color.
        H, W = bgr.shape[:2]
        roi = bgr[:int(H * 0.5), :]
        hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)

        # Brightness mask (intended to suppress dark regions / car lights).
        # NOTE(review): `bright` is only reported in the meta dict; it does not
        # gate the color decision below — confirm whether that is intentional.
        v = hsv[:, :, 2]
        bright = (v > 140).astype(np.uint8) * 255

        # Coarse color classification over the ROI.
        col = self._classify_color_hsv(roi)
        return col, {"method": "fallback", "note": "no_backend", "bright_ratio": float(np.mean(bright > 0))}
|
||||
|
||||
# ========== 视觉辅助工具 ==========
|
||||
def _color_bgr(name: str) -> Tuple[int, int, int]:
|
||||
if name == "red": return (0, 0, 255)
|
||||
if name == "green": return (0, 255, 0)
|
||||
if name == "yellow": return (0, 255, 255)
|
||||
if name == "blue": return (255, 0, 0)
|
||||
if name == "orange": return (0, 165, 255)
|
||||
if name == "cyan": return (255, 255, 0)
|
||||
if name == "magenta": return (255, 0, 255)
|
||||
if name == "gray": return (128, 128, 128)
|
||||
if name == "white": return (255, 255, 255)
|
||||
return (200, 200, 200)
|
||||
|
||||
def _put_text(img, text, org, color=(255, 255, 255), scale=0.7, thick=2, outline=True):
    """Draw text on `img`, optionally with a 1-pixel black outline for contrast."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    if outline:
        # Stamp the text in black at the eight surrounding offsets to form the outline.
        offsets = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)]
        for dx, dy in offsets:
            cv2.putText(img, text, (org[0] + dx, org[1] + dy), font, scale, (0, 0, 0), thick + 1)
    cv2.putText(img, text, org, font, scale, color, thick)
|
||||
|
||||
def _draw_badge(img, text, pos=(10, 28), fg="white", bg="blue"):
    """Draw a small filled rectangle with `text` on top of it (a status badge)."""
    fg_bgr = _color_bgr(fg)
    bg_bgr = _color_bgr(bg)
    (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
    x, y = pos
    pad = 6
    # Background rectangle sized to the measured text plus a small margin.
    cv2.rectangle(img, (x - 4, y - th - pad), (x + tw + 8, y + pad // 2), bg_bgr, -1)
    _put_text(img, text, (x, y), color=fg_bgr, scale=0.6, thick=2, outline=False)
|
||||
|
||||
def _draw_state_panel(img, kv: Dict[str, Any], pos=(10, 60)):
    """Render key/value pairs as stacked "key: value" text lines starting at `pos`."""
    x, y = pos
    line_h = 22
    for row, (key, val) in enumerate(kv.items()):
        _put_text(img, f"{key}: {val}", (x, y + row * line_h), color=(255, 255, 255), scale=0.6, thick=2)
|
||||
|
||||
def _draw_frame_border(img, color=(0, 255, 0), thickness=3):
    """Draw a rectangle along the outermost pixels of the image."""
    height, width = img.shape[:2]
    cv2.rectangle(img, (0, 0), (width - 1, height - 1), color, thickness)
|
||||
|
||||
def _draw_progress_bar(img, ratio: float, pos=(10, 90), size=(180, 10), color="cyan"):
    """Draw a horizontal progress bar; `ratio` is clamped to [0, 1]."""
    clamped = min(1.0, max(0.0, float(ratio)))
    x, y = pos
    w, h = size
    # Outer frame in dark gray, then the filled portion inset by one pixel.
    cv2.rectangle(img, (x, y), (x + w, y + h), (80, 80, 80), 1)
    fill = int((w - 2) * clamped)
    cv2.rectangle(img, (x + 1, y + 1), (x + 1 + fill, y + h - 1), _color_bgr(color), -1)
|
||||
|
||||
# ========== Orchestrator ==========
class NavigationMaster:
    """Top-level navigation state machine.

    Delegates per-frame perception to a BlindPathNavigator (tactile-path
    following) and a CrossStreetNavigator (street crossing), and decides when
    to switch between states based on debounced frame counters plus a short
    cooldown that prevents state flapping. Also rate-limits spoken guidance.
    """
    def __init__(self,
                 blind_nav: BlindPathNavigator,
                 cross_nav: CrossStreetNavigator,
                 *,
                 min_tts_interval: float = 1.2):
        self.blind = blind_nav
        self.cross = cross_nav
        self.state = IDLE
        # Timestamp of the last spoken guidance; used by _say() for rate limiting.
        self.last_guidance_ts = 0.0
        self.min_tts_interval = min_tts_interval

        # Debounce / stability counters
        self.cnt_crosswalk_seen = 0   # blind-path side: crosswalk seen (approaching/ready)
        self.cnt_align_ready = 0      # crosswalk 'ready' + alignment criteria met
        self.cnt_cross_end = 0        # accumulated street-crossing end conditions
        self.cnt_lost = 0             # accumulated perception loss (triggers RECOVERY)

        # Cooldown window to avoid rapid state oscillation
        self.cooldown_until = 0.0

        # State to return to after emergency recovery
        self.prev_target_state = BLINDPATH_NAV

        # Traffic light detection + temporal smoothing
        self.tld = TrafficLightDetector()
        self.tl_major = MajorityFilter(size=8)
        self.tl_last_color = "unknown"

        # Tunable parameters (adjust on site as needed)
        self.FRAMES_CROSS_SEEN = 8
        self.FRAMES_ALIGN_READY = 12
        self.FRAMES_CROSS_END = 12
        self.FRAMES_NEXT_BLIND_OK = 8
        self.FRAMES_LOST_MAX = 45

        self.ANGLE_ALIGN_THR_DEG = 12.0
        self.OFFSET_ALIGN_THR = 0.15

        self.COOLDOWN_SEC = 0.6

        # Item-search state management
        # NOTE(review): `_last_wait_light_announce` is initialized in reset()
        # and lazily via hasattr() in process_frame(), but not here — consider
        # initializing it in __init__ as well.
        self.prev_nav_state_before_search = None  # nav state saved before item search, used to restore

    # ----- External interaction -----
    def get_state(self) -> str:
        return self.state

    def start_blind_path_navigation(self):
        """Start blind-path (tactile path) navigation mode."""
        self.state = BLINDPATH_NAV
        self.cooldown_until = time.time() + self.COOLDOWN_SEC
        if self.blind:
            self.blind.reset()

    def stop_navigation(self):
        """Stop navigation and return to chat mode."""
        self.state = CHAT
        self.cooldown_until = time.time() + self.COOLDOWN_SEC
        if self.blind:
            self.blind.reset()

    def start_crossing(self):
        """Start street-crossing mode."""
        self.state = CROSSING
        self.cooldown_until = time.time() + self.COOLDOWN_SEC
        if self.cross:
            self.cross.reset()

    def start_traffic_light_detection(self):
        """Start standalone traffic-light detection mode."""
        self.state = TRAFFIC_LIGHT_DETECTION
        self.cooldown_until = time.time() + self.COOLDOWN_SEC

    def is_in_navigation_mode(self):
        """Return True when in an active navigation state (not chat/idle/detection/search)."""
        return self.state not in ["CHAT", "IDLE", "TRAFFIC_LIGHT_DETECTION", "ITEM_SEARCH"]

    def start_item_search(self):
        """Enter item-search mode, pausing any active navigation."""
        # Save the current navigation state (if navigating) so it can be restored.
        if self.state in [BLINDPATH_NAV, SEEKING_CROSSWALK, WAIT_TRAFFIC_LIGHT, CROSSING, SEEKING_NEXT_BLINDPATH]:
            self.prev_nav_state_before_search = self.state
            print(f"[NAV MASTER] 暂停导航状态 {self.state},切换到找物品模式")
        else:
            self.prev_nav_state_before_search = None

        self.state = ITEM_SEARCH
        self.cooldown_until = time.time() + self.COOLDOWN_SEC

    def stop_item_search(self, restore_nav: bool = True):
        """Leave item-search mode, optionally restoring the paused navigation state."""
        # Restore the previous navigation state when requested and available.
        if restore_nav and self.prev_nav_state_before_search:
            self.state = self.prev_nav_state_before_search
            print(f"[NAV MASTER] 找物品结束,恢复到导航状态 {self.state}")
            self.prev_nav_state_before_search = None
        else:
            # Otherwise fall back to chat mode.
            self.state = CHAT
            print(f"[NAV MASTER] 找物品结束,回到对话模式")

        self.cooldown_until = time.time() + self.COOLDOWN_SEC

    def force_state(self, s: str):
        # Unconditional state override (no validation of `s` is performed here).
        self.state = s
        self.cooldown_until = time.time() + self.COOLDOWN_SEC

    def on_voice_command(self, text: str):
        """React to a recognized voice command by substring matching (Chinese phrases)."""
        t = (text or "").strip()
        if "开始过马路" in t:
            # Go to the wait-for-light state (in low-traffic settings one may cross directly).
            if self.state in (BLINDPATH_NAV, SEEKING_CROSSWALK, WAIT_TRAFFIC_LIGHT, IDLE, RECOVERY, SEEKING_NEXT_BLINDPATH):
                self.state = WAIT_TRAFFIC_LIGHT
                self.cooldown_until = time.time() + self.COOLDOWN_SEC
        elif "立即通过" in t or "现在通过" in t:
            self.state = CROSSING
            self.cooldown_until = time.time() + self.COOLDOWN_SEC
        elif "停止" in t or "结束" in t:
            self.state = IDLE
        elif "继续" in t:
            if self.state == IDLE:
                self.state = BLINDPATH_NAV

    def reset(self):
        """Reset all orchestrator state, counters, and both sub-navigators."""
        self.state = IDLE
        self.cnt_crosswalk_seen = 0
        self.cnt_align_ready = 0
        self.cnt_cross_end = 0
        self.cnt_lost = 0
        self.tl_major.clear()
        self.tl_last_color = "unknown"
        self.prev_target_state = BLINDPATH_NAV
        self._last_wait_light_announce = 0  # reset the wait-for-green announcement timer
        try:
            self.blind.reset()
        except Exception:
            pass
        try:
            self.cross.reset()
        except Exception:
            pass

    # ----- Internal helpers -----
    def _say(self, now: float, text: str) -> str:
        """Rate-limit guidance: return `text` at most once per min_tts_interval, else ""."""
        if not text:
            return ""
        if now - self.last_guidance_ts >= self.min_tts_interval:
            self.last_guidance_ts = now
            return text
        return ""

    def _draw_tl_status(self, img: np.ndarray, color: str, meta: Dict[str, Any]):
        """Overlay the current traffic-light color, its bbox (if any), and the
        recent vote history onto `img` (debug visualization)."""
        if img is None:
            return
        color_bgr = _color_bgr(color)
        # Corner indicator dot plus text label
        cv2.circle(img, (24, 24), 10, color_bgr, -1)
        _put_text(img, f"信号灯: {color}", (40, 30), color=color_bgr, scale=0.6, thick=2, outline=False)
        # Draw the bbox when the backend supplied one
        if meta and "bbox" in meta:
            x1, y1, x2, y2 = meta["bbox"]
            cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color_bgr, 2)

        # Majority-vote history (most recent samples)
        hist = self.tl_major.history()
        if hist:
            x0, y0 = 10, 50
            r = 6
            gap = 16
            for i, hcol in enumerate(hist[-12:]):
                cv2.circle(img, (x0 + i*gap, y0), r, _color_bgr(hcol), -1)
            _put_text(img, "信号历史", (x0, y0+20), color=(255,255,255), scale=0.5, thick=1)

    # ----- Main loop -----
    def process_frame(self, bgr: np.ndarray) -> OrchestratorResult:
        """Process one BGR frame through the state machine and return the result."""
        now = time.time()

        # IDLE defaults to CHAT mode rather than auto-starting navigation.
        if self.state == IDLE:
            self.state = CHAT
            self.cooldown_until = now + self.COOLDOWN_SEC

        # CHAT mode: pass the raw frame through, no navigation.
        if self.state == CHAT:
            return OrchestratorResult(
                annotated_image=bgr,
                guidance_text="",
                state="CHAT",
                extras={"mode": "对话模式"}
            )

        # Traffic-light detection mode: raw frame only; a dedicated module handles it.
        if self.state == TRAFFIC_LIGHT_DETECTION:
            return OrchestratorResult(
                annotated_image=bgr,
                guidance_text="",
                state="TRAFFIC_LIGHT_DETECTION",
                extras={"mode": "红绿灯检测模式"}
            )

        # Item-search mode: raw frame only; handled by yolomedia.
        if self.state == ITEM_SEARCH:
            return OrchestratorResult(
                annotated_image=bgr,
                guidance_text="",
                state="ITEM_SEARCH",
                extras={"mode": "找物品模式", "prev_nav_state": self.prev_nav_state_before_search}
            )

        # During cooldown frames are still processed; only state transitions are held back.
        in_cooldown = now < self.cooldown_until

        # Per-state handling
        if self.state in (BLINDPATH_NAV, SEEKING_CROSSWALK, SEEKING_NEXT_BLINDPATH, RECOVERY):
            # -- Blind-path side: all four states delegate to the blind-path navigator.
            try:
                bres: BlindResult = self.blind.process_frame(bgr)
            except Exception as e:
                # Exception -> enter recovery state
                self.state = RECOVERY
                self.cnt_lost += 5
                ann_err = bgr.copy()
                # Visualization overlays disabled
                # _draw_badge(ann_err, "NAV ERROR", (10, 28), fg="white", bg="red")
                # _put_text(ann_err, str(e), (10, 56), color=(255,255,255), scale=0.55)
                return OrchestratorResult(ann_err, self._say(now, ""), self.state, {"error": str(e)})

            ann = bres.annotated_image if bres.annotated_image is not None else bgr.copy()
            say = bres.guidance_text or ""

            state_info = bres.state_info or {}
            cross_stage = state_info.get("crosswalk_stage", "not_detected")
            blind_state = state_info.get("state", "UNKNOWN")
            # Optional fields (the workflow may supply these in the future)
            angle = float(state_info.get("last_angle", 0.0))
            center_x_ratio = float(state_info.get("last_center_x_ratio", 0.5))

            # -- Blind path -> crosswalk spotted (approaching/ready)
            if self.state == BLINDPATH_NAV:
                if cross_stage in ("approaching", "ready"):
                    self.cnt_crosswalk_seen += 1
                else:
                    self.cnt_crosswalk_seen = max(0, self.cnt_crosswalk_seen - 1)

                if self.cnt_crosswalk_seen >= self.FRAMES_CROSS_SEEN and not in_cooldown:
                    self.state = SEEKING_CROSSWALK
                    self.cooldown_until = now + self.COOLDOWN_SEC
                    say = "正在接近斑马线,为您对准方向。"

                # Visualization overlays disabled
                # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="blue")
                # _draw_state_panel(ann, {
                #     "盲道状态": blind_state,
                #     "斑马线阶段": cross_stage,
                #     "靠近计数": self.cnt_crosswalk_seen,
                # }, pos=(10, 60))
                # _draw_progress_bar(ann, max(0.0, min(1.0, self.cnt_crosswalk_seen / max(1, self.FRAMES_CROSS_SEEN))), pos=(10, 120), size=(180, 10), color="cyan")
                # _draw_frame_border(ann, color=_color_bgr("blue"), thickness=3)

            # -- Alignment phase: uses the blind navigator's crosswalk angle/offset (if provided)
            elif self.state == SEEKING_CROSSWALK:
                aligned = (abs(angle) <= self.ANGLE_ALIGN_THR_DEG and abs(center_x_ratio - 0.5) <= self.OFFSET_ALIGN_THR)
                if cross_stage == "ready" and aligned:
                    self.cnt_align_ready += 1
                else:
                    self.cnt_align_ready = max(0, self.cnt_align_ready - 1)

                if self.cnt_align_ready >= self.FRAMES_ALIGN_READY and not in_cooldown:
                    self.state = WAIT_TRAFFIC_LIGHT
                    self.cooldown_until = now + self.COOLDOWN_SEC
                    say = "已到达斑马线,请等待红绿灯。"

                # Visualization overlays disabled
                # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="orange")
                # panel = {
                #     "阶段": cross_stage,
                #     "对准计数": self.cnt_align_ready,
                # }
                # if "last_angle" in state_info:
                #     panel["角度(°)"] = f"{angle:.1f}"
                # if "last_center_x_ratio" in state_info:
                #     panel["偏移"] = f"{(center_x_ratio-0.5):+.2f}"
                # _draw_state_panel(ann, panel, pos=(10, 60))
                # _draw_progress_bar(ann, max(0.0, min(1.0, self.cnt_align_ready / max(1, self.FRAMES_ALIGN_READY))), pos=(10, 120), size=(220, 10), color="yellow")
                # _draw_frame_border(ann, color=_color_bgr("orange"), thickness=3)

            # -- After crossing: look for the next tactile-path segment (re-boarding)
            elif self.state == SEEKING_NEXT_BLINDPATH:
                if blind_state == "NAVIGATING":
                    self.cnt_cross_end += 1
                else:
                    self.cnt_cross_end = max(0, self.cnt_cross_end - 1)
                if self.cnt_cross_end >= self.FRAMES_NEXT_BLIND_OK and not in_cooldown:
                    self.state = BLINDPATH_NAV
                    self.cooldown_until = now + self.COOLDOWN_SEC
                    say = "方向正确,请继续前进。"

                # Visualization overlays disabled
                # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="green")
                # _draw_state_panel(ann, {
                #     "盲道状态": blind_state,
                #     "回归计数": self.cnt_cross_end
                # }, pos=(10, 60))
                # _draw_progress_bar(ann, max(0.0, min(1.0, self.cnt_cross_end / max(1, self.FRAMES_NEXT_BLIND_OK))), pos=(10, 120), size=(200, 10), color="green")
                # _draw_frame_border(ann, color=_color_bgr("green"), thickness=3)

            # -- Recovery: go back to blind-path navigation as soon as it becomes usable
            elif self.state == RECOVERY:
                if blind_state in ("ONBOARDING", "NAVIGATING"):
                    self.state = BLINDPATH_NAV
                    self.cooldown_until = now + self.COOLDOWN_SEC
                    say = ""
                else:
                    say = ""
                # Visualization overlays disabled
                # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="red")
                # _draw_state_panel(ann, {
                #     "提示": "请缓慢环顾/抬头/降低手机角度",
                #     "丢失计数": self.cnt_lost
                # }, pos=(10, 60))
                # _draw_frame_border(ann, color=_color_bgr("red"), thickness=3)

            # Perception-loss counter (safety net across all blind-path states)
            if blind_state == "UNKNOWN" and cross_stage == "not_detected":
                self.cnt_lost += 1
            else:
                self.cnt_lost = max(0, self.cnt_lost - 2)
            if self.cnt_lost >= self.FRAMES_LOST_MAX and self.state != RECOVERY:
                self.prev_target_state = self.state
                self.state = RECOVERY
                self.cooldown_until = now + self.COOLDOWN_SEC
                say = "环境复杂,进入恢复模式。"

            # Cooldown progress bar disabled
            # if in_cooldown:
            #     remain = max(0.0, self.cooldown_until - now)
            #     ratio = 1.0 - min(1.0, remain / self.COOLDOWN_SEC)
            #     _draw_progress_bar(ann, ratio, pos=(10, 140), size=(160, 8), color="gray")

            return OrchestratorResult(ann, self._say(now, say), self.state, {"source": "blind", "cross_stage": cross_stage, "blind_state": blind_state})

        if self.state == WAIT_TRAFFIC_LIGHT:
            ann = bgr.copy()
            # Traffic-light recognition (majority vote + cooldown)
            color, meta = self.tld.detect(bgr)
            self.tl_major.push(color)
            major = self.tl_major.majority()
            self.tl_last_color = major

            # Visualization overlays disabled
            # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="magenta")
            # self._draw_tl_status(ann, major, meta)
            # _draw_state_panel(ann, {
            #     "提示": "请等待绿灯或语音确认"立即通过"",
            #     "冷却": f"{max(0.0, self.cooldown_until - now):.1f}s"
            # }, pos=(10, 80))
            # _draw_frame_border(ann, color=_color_bgr("magenta"), thickness=3)

            say = ""
            if major == "green" and not in_cooldown:
                self.state = CROSSING
                self.cooldown_until = now + self.COOLDOWN_SEC
                say = "绿灯稳定,开始通行。"
            else:
                # Announce only on entering the state or periodically thereafter.
                if not hasattr(self, '_last_wait_light_announce'):
                    self._last_wait_light_announce = 0
                if now - self._last_wait_light_announce > 5.0:  # announce every 5 seconds
                    say = "正在等待绿灯…"
                    self._last_wait_light_announce = now

            # Cooldown progress bar disabled
            # if in_cooldown:
            #     remain = max(0.0, self.cooldown_until - now)
            #     ratio = 1.0 - min(1.0, remain / self.COOLDOWN_SEC)
            #     _draw_progress_bar(ann, ratio, pos=(10, 140), size=(160, 8), color="gray")

            return OrchestratorResult(ann, self._say(now, say), self.state, {"traffic_light": major})

        if self.state == CROSSING:
            try:
                cres: CrossResult = self.cross.process_frame(bgr)
            except Exception as e:
                # Exception -> recovery
                self.state = RECOVERY
                ann_err = bgr.copy()
                # Visualization overlays disabled
                # _draw_badge(ann_err, "CROSS ERROR", (10, 28), fg="white", bg="red")
                # _put_text(ann_err, str(e), (10, 56), color=(255,255,255), scale=0.55)
                return OrchestratorResult(ann_err, self._say(now, ""), self.state, {"error": str(e)})

            ann = cres.annotated_image if cres.annotated_image is not None else bgr.copy()
            say = cres.guidance_text or ""

            # Check whether a tactile path was detected (optional result fields).
            blind_path_detected = getattr(cres, 'blind_path_detected', False)
            blind_path_guidance = getattr(cres, 'blind_path_guidance', "")

            # If a tactile path is detected with guidance, prioritize it.
            if blind_path_detected and blind_path_guidance:
                # When the path is close enough, switch straight to blind-path navigation.
                if hasattr(cres, "should_switch_to_blindpath") and cres.should_switch_to_blindpath:
                    if not in_cooldown:
                        self.state = BLINDPATH_NAV
                        self.cooldown_until = now + self.COOLDOWN_SEC
                        say = "已到盲道跟前,切换到盲道导航。"  # matches an existing audio file
                        self.cnt_cross_end = 0  # reset the end-of-crossing counter
                        # Reset the blind-path navigator state
                        if hasattr(self.blind, 'reset'):
                            self.blind.reset()
                else:
                    # Path is still far: keep crossing; `say` already carries the
                    # path guidance from cres.guidance_text.
                    pass

            # Original end condition: several consecutive "looking for crosswalk" frames.
            end_hint = False
            if "寻找斑马线" in (say or ""):
                end_hint = True
            # Note: should_switch_to_blindpath alone no longer ends the crossing.
            # if hasattr(cres, "should_switch_to_blindpath") and cres.should_switch_to_blindpath:
            #     end_hint = True

            self.cnt_cross_end = self.cnt_cross_end + 1 if end_hint else max(0, self.cnt_cross_end - 1)

            if self.cnt_cross_end >= self.FRAMES_CROSS_END and not in_cooldown:
                self.state = SEEKING_NEXT_BLINDPATH
                self.cooldown_until = now + self.COOLDOWN_SEC
                say = "过马路结束,准备上人行道。"

            # Visualization overlays disabled
            # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="cyan")
            # _draw_state_panel(ann, {
            #     "结束计数": self.cnt_cross_end,
            #     "冷却": f"{max(0.0, self.cooldown_until - now):.1f}s"
            # }, pos=(10, 60))
            # _draw_progress_bar(ann, max(0.0, min(1.0, self.cnt_cross_end / max(1, self.FRAMES_CROSS_END))), pos=(10, 120), size=(220, 10), color="cyan")
            # _draw_frame_border(ann, color=_color_bgr("cyan"), thickness=3)
            # if in_cooldown:
            #     remain = max(0.0, self.cooldown_until - now)
            #     ratio = 1.0 - min(1.0, remain / self.COOLDOWN_SEC)
            #     _draw_progress_bar(ann, ratio, pos=(10, 140), size=(160, 8), color="gray")

            return OrchestratorResult(ann, self._say(now, say), self.state, {"source": "cross", "end_cnt": self.cnt_cross_end})

        # Fallback for any unhandled state: pass the frame through unchanged.
        ann = bgr.copy()
        # Visualization overlays disabled
        # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="gray")
        # _draw_frame_border(ann, color=_color_bgr("gray"), thickness=2)
        return OrchestratorResult(ann, "", self.state, {})
|
||||
|
||||
|
||||
185
numba_utils.py
Normal file
185
numba_utils.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# numba_utils.py - Day 20 Numba 多核加速工具
|
||||
"""
|
||||
使用 Numba JIT 编译加速 numpy 密集操作,绕过 Python GIL 实现真正多核并行
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from numba import jit, prange
|
||||
NUMBA_AVAILABLE = True
|
||||
except ImportError:
|
||||
NUMBA_AVAILABLE = False
|
||||
print("[NUMBA] Numba 未安装,使用 numpy 回退实现")
|
||||
|
||||
|
||||
if NUMBA_AVAILABLE:
    # NOTE(review): `sum_x += j`-style accumulations are documented Numba prange
    # reductions, but the conditional min/max assignments in
    # compute_mask_stats_numba (e.g. `if j < min_x: min_x = j`) are not among
    # the documented reduction patterns — under parallel=True they may race.
    # Confirm against the Numba docs, or compute the bbox serially.
    @jit(nopython=True, parallel=True, cache=True)
    def count_mask_pixels_numba(mask: np.ndarray) -> int:
        """Count non-zero pixels in a 2-D mask (parallel over rows)."""
        count = 0
        h, w = mask.shape
        for i in prange(h):
            for j in range(w):
                if mask[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def compute_mask_stats_numba(mask: np.ndarray) -> tuple:
        """Compute summary statistics of a 2-D mask (parallel over rows).

        Returns: (area, center_x, center_y, min_x, max_x, min_y, max_y).
        For an all-zero mask the min/max values stay at their sentinel
        initializers (min_x=w, max_x=0, min_y=h, max_y=0).
        """
        h, w = mask.shape
        count = 0
        sum_x = 0.0
        sum_y = 0.0
        min_x = w
        max_x = 0
        min_y = h
        max_y = 0

        for i in prange(h):
            for j in range(w):
                if mask[i, j] > 0:
                    count += 1
                    sum_x += j
                    sum_y += i
                    if j < min_x:
                        min_x = j
                    if j > max_x:
                        max_x = j
                    if i < min_y:
                        min_y = i
                    if i > max_y:
                        max_y = i

        if count > 0:
            center_x = sum_x / count
            center_y = sum_y / count
        else:
            center_x = 0.0
            center_y = 0.0

        return (count, center_x, center_y, min_x, max_x, min_y, max_y)

    @jit(nopython=True, parallel=True, cache=True)
    def bitwise_and_count_numba(mask1: np.ndarray, mask2: np.ndarray) -> int:
        """Count pixels where both masks are non-zero (parallel over rows)."""
        h, w = mask1.shape
        count = 0
        for i in prange(h):
            for j in range(w):
                if mask1[i, j] > 0 and mask2[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def resize_mask_nearest_numba(mask: np.ndarray, new_h: int, new_w: int) -> np.ndarray:
        """Nearest-neighbor resize of a uint8 mask (parallel over output rows).

        Simplified mapping (truncating scale), sufficient for most mask uses.
        """
        old_h, old_w = mask.shape
        result = np.zeros((new_h, new_w), dtype=np.uint8)

        scale_y = old_h / new_h
        scale_x = old_w / new_w

        for i in prange(new_h):
            for j in range(new_w):
                src_y = int(i * scale_y)
                src_x = int(j * scale_x)
                # Clamp to the source bounds to guard against rounding overshoot.
                if src_y >= old_h:
                    src_y = old_h - 1
                if src_x >= old_w:
                    src_x = old_w - 1
                result[i, j] = mask[src_y, src_x]

        return result
|
||||
|
||||
|
||||
# 对外接口:根据 Numba 是否可用选择实现
|
||||
def count_mask_pixels(mask: np.ndarray) -> int:
    """Count the non-zero pixels of `mask`, via Numba when available."""
    if not NUMBA_AVAILABLE:
        # numpy fallback: boolean sum over the whole mask.
        return int(np.sum(mask > 0))
    return count_mask_pixels_numba(mask)
|
||||
|
||||
|
||||
def compute_mask_stats(mask: np.ndarray) -> dict:
    """Compute summary statistics of a binary mask.

    Returns: {'area': int, 'center_x': float, 'center_y': float,
              'bbox': (x1, y1, x2, y2)}.
    An empty mask yields {'area': 0, 'center_x': 0, 'center_y': 0,
    'bbox': (0, 0, 0, 0)} on BOTH code paths.

    Bug fix: the Numba kernel leaves its min/max accumulators at sentinel
    values (min_x = w, max_x = 0, min_y = h, max_y = 0) when no pixel is set,
    so the previous code returned an inverted bbox like (w, h, 0, 0) for an
    empty mask while the numpy fallback returned (0, 0, 0, 0). The empty case
    is now normalized to match the fallback.
    """
    if NUMBA_AVAILABLE:
        area, cx, cy, min_x, max_x, min_y, max_y = compute_mask_stats_numba(mask)
        if area == 0:
            # Normalize the empty-mask result to match the numpy fallback.
            return {'area': 0, 'center_x': 0, 'center_y': 0, 'bbox': (0, 0, 0, 0)}
        return {
            'area': int(area),
            'center_x': float(cx),
            'center_y': float(cy),
            'bbox': (int(min_x), int(min_y), int(max_x), int(max_y))
        }
    else:
        # numpy fallback: derive everything from the coordinates of set pixels.
        y_coords, x_coords = np.where(mask > 0)
        if len(y_coords) == 0:
            return {'area': 0, 'center_x': 0, 'center_y': 0, 'bbox': (0, 0, 0, 0)}
        return {
            'area': len(y_coords),
            'center_x': float(np.mean(x_coords)),
            'center_y': float(np.mean(y_coords)),
            'bbox': (int(np.min(x_coords)), int(np.min(y_coords)),
                     int(np.max(x_coords)), int(np.max(y_coords)))
        }
|
||||
|
||||
|
||||
def bitwise_and_count(mask1: np.ndarray, mask2: np.ndarray) -> int:
    """Count pixels where both masks are non-zero, via Numba when available."""
    if not NUMBA_AVAILABLE:
        # numpy fallback: elementwise AND, then count the survivors.
        return int(np.sum(np.bitwise_and(mask1, mask2) > 0))
    return bitwise_and_count_numba(mask1.astype(np.uint8), mask2.astype(np.uint8))
|
||||
|
||||
|
||||
# 预热 JIT 编译(首次调用时编译,之后使用缓存)
|
||||
# Warm up JIT compilation (compiled on first call, cached afterwards)
def warmup():
    """Pre-compile the Numba kernels so the first real call has no JIT latency."""
    if not NUMBA_AVAILABLE:
        return
    probe = np.zeros((10, 10), dtype=np.uint8)
    probe[5, 5] = 255
    count_mask_pixels_numba(probe)
    compute_mask_stats_numba(probe)
    bitwise_and_count_numba(probe, probe)
    print("[NUMBA] JIT 编译预热完成,已启用多核加速")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Quick benchmark comparing the numpy and Numba implementations.
    import time

    # Test fixture: 480x640 mask with a 200x200 filled rectangle.
    test_mask = np.zeros((480, 640), dtype=np.uint8)
    test_mask[100:300, 200:400] = 255

    # Time the numpy version.
    start = time.perf_counter()
    for _ in range(100):
        np.sum(test_mask > 0)
    numpy_time = (time.perf_counter() - start) * 1000

    # Time the Numba version.
    if NUMBA_AVAILABLE:
        # Warm-up call so JIT compilation is excluded from the measurement.
        count_mask_pixels_numba(test_mask)

        start = time.perf_counter()
        for _ in range(100):
            count_mask_pixels_numba(test_mask)
        numba_time = (time.perf_counter() - start) * 1000

        # NOTE(review): numba_time could in principle round to 0.0 and divide
        # by zero on the last line — unlikely at 100 iterations, but worth a guard.
        print(f"numpy: {numpy_time:.2f}ms / 100 次")
        print(f"numba: {numba_time:.2f}ms / 100 次")
        print(f"加速比: {numpy_time / numba_time:.1f}x")
|
||||
244
obstacle_detector_client.py
Normal file
244
obstacle_detector_client.py
Normal file
@@ -0,0 +1,244 @@
|
||||
# app/cloud/obstacle_detector_client.py (新文件)
|
||||
import logging
|
||||
import os
|
||||
import cv2
|
||||
import numpy as np
|
||||
import torch
|
||||
from threading import Semaphore
|
||||
from contextlib import contextmanager
|
||||
from ultralytics import YOLOE
|
||||
from typing import List, Dict, Any
|
||||
|
||||
# Day 20: optional Numba multi-core acceleration for mask post-processing.
# When numba_utils is missing, the pure-numpy paths below are used instead.
try:
    from numba_utils import count_mask_pixels, compute_mask_stats, bitwise_and_count, warmup as numba_warmup
    NUMBA_ENABLED = True
except ImportError:
    NUMBA_ENABLED = False

logger = logging.getLogger(__name__)

# --- GPU/CPU & AMP configuration (migrated from the blindpath workflow; kept consistent) ---
# Requested device; falls back to CPU when CUDA is requested but unavailable.
DEVICE = os.getenv("AIGLASS_DEVICE", "cuda:0")
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    logger.warning(f"AIGLASS_DEVICE={DEVICE} 但未检测到 CUDA,将回退到 CPU")
    DEVICE = "cpu"
IS_CUDA = DEVICE.startswith("cuda")

# Mixed-precision policy: "bf16", "fp16" or "off"; anything else coerces to fp16.
AMP_POLICY = os.getenv("AIGLASS_AMP", "fp16").lower()
if AMP_POLICY not in ("bf16", "fp16", "off"):
    AMP_POLICY = "fp16"
# Autocast dtype derived from the policy (None when AMP is off).
AMP_DTYPE = torch.bfloat16 if AMP_POLICY == "bf16" else (torch.float16 if AMP_POLICY == "fp16" else None)

# --- GPU concurrency throttling (migrated from the blindpath workflow; kept consistent) ---
# Day 20: default raised from 2 to 4 slots — an RTX 3090 handles more concurrency.
GPU_SLOTS = int(os.getenv("AIGLASS_GPU_SLOTS", "4"))
_gpu_slots = Semaphore(GPU_SLOTS)

# Enable cuDNN autotuning; best-effort (harmless if the backend is absent).
try:
    torch.backends.cudnn.benchmark = True
except Exception:
    pass
|
||||
|
||||
@contextmanager
def gpu_infer_slot():
    """Acquire a GPU slot and enter inference mode (+AMP autocast on CUDA).

    At most GPU_SLOTS inferences run concurrently (module-level semaphore).
    The caller's body executes under torch.inference_mode(); when running on
    CUDA with a non-"off" AMP policy, torch.amp.autocast is stacked on top.
    """
    with _gpu_slots:
        amp_active = IS_CUDA and AMP_POLICY != "off"
        if not amp_active:
            with torch.inference_mode():
                yield
            return
        # Modern API: torch.amp.autocast(device_type='cuda', dtype=...)
        with torch.inference_mode(), torch.amp.autocast(device_type='cuda', dtype=AMP_DTYPE):
            yield
||||
|
||||
|
||||
class ObstacleDetectorClient:
|
||||
    def __init__(self, model_path: str = 'model/yoloe-11l-seg.pt'):
        """Load the YOLOE obstacle model and precompute text-prompt features.

        model_path: path to a .pt checkpoint; may be upgraded to a TensorRT
        .engine by model_utils.get_best_model_path when that helper exists.
        Raises on any load/feature-computation failure (logged first).
        """
        self.model = None
        # Text-prompt embeddings for WHITELIST_CLASSES; stays None in TensorRT mode.
        self.whitelist_embeddings = None
        # Open-vocabulary prompt classes used with YOLOE's text-prompt interface.
        self.WHITELIST_CLASSES = [
            'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'animal', 'scooter', 'stroller', 'dog',
            'pole', 'post', 'column', 'pillar', 'stanchion', 'bollard', 'utility pole',
            'telegraph pole', 'light pole', 'street pole', 'signpost', 'support post',
            'vertical post', 'bench', 'chair', 'potted plant', 'hydrant', 'cone', 'stone', 'box'
        ]
        # COCO class whitelist — used for post-filtering in TensorRT mode,
        # where the engine falls back to the default COCO label set.
        # Subset of the 80 COCO classes that can plausibly be obstacles.
        self.COCO_WHITELIST = {
            'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck',  # traffic
            'dog', 'cat', 'horse', 'cow', 'sheep',  # animals
            'bench', 'chair', 'potted plant', 'fire hydrant', 'stop sign',  # street furniture
            'parking meter', 'suitcase', 'backpack', 'umbrella', 'handbag',  # objects
            'sports ball', 'skateboard', 'surfboard', 'bottle', 'cup',  # possible obstacles
        }
        try:
            # Day 20: prefer a TensorRT engine when model_utils can locate one.
            try:
                from model_utils import get_best_model_path, is_tensorrt_engine
                model_path = get_best_model_path(model_path)
            except ImportError:
                # Fallback detector: an .engine suffix means TensorRT.
                def is_tensorrt_engine(p): return p.endswith('.engine')

            logger.info(f"正在加载 YOLOE 障碍物模型: {model_path}")
            self.model = YOLOE(model_path)

            # Day 20: a TensorRT engine needs neither .to() nor .fuse().
            if is_tensorrt_engine(model_path):
                logger.info(f"TensorRT 引擎已加载,跳过 .to() 和 .fuse()")
                # TensorRT engines lack get_text_pe — skip whitelist features;
                # detect() will use COCO_WHITELIST post-filtering instead.
                self.whitelist_embeddings = None
                logger.info("TensorRT 模式:跳过白名单特征预计算")
            else:
                self.model.to(DEVICE)
                self.model.fuse()
                logger.info(f"YOLOE 障碍物模型加载成功,使用设备: {DEVICE}")

                # Precompute text-prompt features, under autocast when AMP is on.
                logger.info("正在为 YOLOE 预计算白名单文本特征...")
                if IS_CUDA and AMP_DTYPE is not None:
                    with torch.inference_mode(), torch.amp.autocast(device_type='cuda', dtype=AMP_DTYPE):
                        self.whitelist_embeddings = self.model.get_text_pe(self.WHITELIST_CLASSES)
                else:
                    self.whitelist_embeddings = self.model.get_text_pe(self.WHITELIST_CLASSES)
                logger.info("YOLOE 特征预计算完成。")
        except Exception as e:
            logger.error(f"YOLOE 模型加载或特征计算失败: {e}", exc_info=True)
            raise
||||
def tensor_to_numpy_mask(mask_tensor):
|
||||
"""安全地将各种类型的张量转换为 numpy 掩码"""
|
||||
# 处理不同的数据类型
|
||||
if mask_tensor.dtype in (torch.bfloat16, torch.float16):
|
||||
mask_tensor = mask_tensor.float()
|
||||
|
||||
# 转换为 numpy
|
||||
mask = mask_tensor.cpu().numpy()
|
||||
|
||||
# 确保是二值掩码
|
||||
if mask.max() <= 1.0:
|
||||
mask = (mask > 0.5).astype(np.uint8) * 255
|
||||
else:
|
||||
mask = mask.astype(np.uint8)
|
||||
|
||||
return mask
|
||||
    def detect(self, image: np.ndarray, path_mask: np.ndarray = None) -> List[Dict[str, Any]]:
        """
        Find obstacles using the whitelist classes as prompts.

        When path_mask is given, spatial filtering keeps only obstacles that
        overlap the walkable path; when it is None, detection is global.

        Args:
            image: BGR frame, shape (H, W, 3).
            path_mask: optional uint8 mask of the walkable path, same H/W.

        Returns:
            List of dicts with keys: name, mask (uint8, HxW), area,
            area_ratio, center_x, center_y, bottom_y_ratio.
        """
        if self.model is None:
            return []

        H, W = image.shape[:2]

        # TensorRT mode has no embeddings — skip set_classes; the engine then
        # detects with its default COCO label set (filtered further below).
        if self.whitelist_embeddings is not None:
            try:
                self.model.set_classes(self.WHITELIST_CLASSES, self.whitelist_embeddings)
            except Exception as e:
                logger.error(f"设置 YOLOE 提示词失败: {e}")
                return []

        conf_thr = float(os.getenv("AIGLASS_OBS_CONF", "0.25"))
        # Day 22: dynamic input size and FP16 for speed.
        imgsz = int(os.getenv("AIGLASS_OBS_IMGSZ", "480"))  # lowered from the 640 default
        use_half = os.getenv("AIGLASS_OBS_HALF", "1") == "1"

        # Inference runs under the shared GPU slot / inference_mode / AMP guard.
        with gpu_infer_slot():
            results = self.model.predict(
                image,
                verbose=False,
                conf=conf_thr,
                imgsz=imgsz,      # smaller input resolution
                half=use_half     # FP16 half-precision acceleration
            )

        if not (results and results[0].masks):
            return []

        # --- Filtering & post-processing (kept consistent with the blindpath workflow) ---
        final_obstacles = []
        num_masks = len(results[0].masks.data)
        # Masks and boxes can disagree in length; guard indexing by box count.
        num_boxes = len(results[0].boxes.cls) if getattr(results[0].boxes, "cls", None) is not None else 0

        for i, mask_tensor in enumerate(results[0].masks.data):
            if i >= num_boxes: continue

            # [Fix] handle BFloat16 masks: upcast to float32 first, since
            # numpy does not support bfloat16.
            if mask_tensor.dtype == torch.bfloat16:
                mask_tensor = mask_tensor.float()

            mask = mask_tensor.cpu().numpy()

            # Probability mask (values in 0..1) vs already-binary mask.
            if mask.max() <= 1.0:
                # Probability map — binarize at 0.5.
                mask = (mask > 0.5).astype(np.uint8) * 255
            else:
                # Already binary — just cast.
                mask = mask.astype(np.uint8)

            # Upscale to frame resolution; NEAREST preserves the binary values.
            mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)

            # Day 20: mask statistics via the Numba multi-core kernels when available.
            if NUMBA_ENABLED:
                stats = compute_mask_stats(mask)
                area = stats['area']
                center_x = stats['center_x']
                center_y = stats['center_y']
                min_y, max_y = stats['bbox'][1], stats['bbox'][3]
            else:
                area = int(np.sum(mask > 0))
                y_coords, x_coords = np.where(mask > 0)
                if len(y_coords) == 0:
                    continue
                center_x = float(np.mean(x_coords))
                center_y = float(np.mean(y_coords))
                min_y, max_y = int(np.min(y_coords)), int(np.max(y_coords))

            # Size filter: something covering >70% of the frame (e.g. the whole
            # ground plane) is almost certainly a false positive.
            if (area / (H * W)) > 0.7: continue
            if area == 0: continue

            # Spatial filter: with a path_mask, keep only obstacles on the path.
            if path_mask is not None:
                # Day 20: Numba-accelerated intersection count.
                if NUMBA_ENABLED:
                    intersection_area = bitwise_and_count(mask, path_mask)
                else:
                    intersection_area = int(np.sum(cv2.bitwise_and(mask, path_mask) > 0))
                # Require sufficient overlap (absolute pixels AND fraction of the obstacle).
                if intersection_area < 100 or (intersection_area / area) < 0.01:
                    continue

            cls_id = int(results[0].boxes.cls[i])
            class_names_map = results[0].names
            class_name = "Unknown"
            if isinstance(class_names_map, dict):
                # Dict mapping — safe lookup with a default.
                class_name = class_names_map.get(cls_id, "Unknown")
            elif isinstance(class_names_map, list) and 0 <= cls_id < len(class_names_map):
                # List mapping — bounds-checked index.
                class_name = class_names_map[cls_id]

            # TensorRT mode: filter with the COCO whitelist so only classes
            # that can plausibly be obstacles survive.
            if self.whitelist_embeddings is None:  # TensorRT mode
                if class_name.lower().strip() not in self.COCO_WHITELIST:
                    continue  # skip non-whitelisted classes

            final_obstacles.append({
                'name': class_name.strip(),
                'mask': mask,
                'area': area,
                'area_ratio': area / (H * W),
                'center_x': center_x,
                'center_y': center_y,
                'bottom_y_ratio': max_y / H
            })

        return final_obstacles
||||
117
omni_client.py
Normal file
117
omni_client.py
Normal file
@@ -0,0 +1,117 @@
|
||||
# omni_client.py
# -*- coding: utf-8 -*-
import os, base64, asyncio, threading
from typing import AsyncGenerator, Dict, Any, List, Optional, Tuple

from openai import OpenAI

# ===== OpenAI-compatible access (DashScope compatible mode) =====
# SECURITY FIX: a real API key used to be hard-coded as the getenv() default,
# which both leaked the credential into version control and made the
# "is it set?" check below unreachable. The leaked key must be rotated;
# DASHSCOPE_API_KEY is now genuinely required from the environment.
API_KEY = os.getenv("DASHSCOPE_API_KEY")
if not API_KEY:
    raise RuntimeError("未设置 DASHSCOPE_API_KEY")

# Model used for all streaming chat rounds in this module.
QWEN_MODEL = "qwen-omni-turbo"

# Compatible-mode client.
oai_client = OpenAI(
    api_key=API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
||||
class OmniStreamPiece:
    """A single streamed increment: a text delta, an audio chunk, or both."""

    def __init__(self, text_delta: Optional[str] = None, audio_b64: Optional[str] = None):
        # Text produced since the previous piece (None for audio-only pieces).
        self.text_delta = text_delta
        # Base64-encoded audio chunk (None for text-only pieces).
        self.audio_b64 = audio_b64
|
||||
async def stream_chat(
    content_list: List[Dict[str, Any]],
    voice: str = "Cherry",
    audio_format: str = "wav",
) -> AsyncGenerator[OmniStreamPiece, None]:
    """
    Run one streaming Omni-Turbo ChatCompletions round.

    Args:
        content_list: OpenAI chat `content` payload (multimodal image_url/text).
        voice: TTS voice name for the audio modality.
        audio_format: container format of the returned audio chunks.

    Yields:
        OmniStreamPiece increments carrying text_delta and/or audio_b64.

    Day 13 fix: the synchronous SDK call runs on a worker thread and hands
    results to the event loop through an asyncio.Queue, so the blocking
    HTTP stream never stalls the loop.
    """
    # Queue bridges the worker thread (producer) and this coroutine (consumer).
    # Items: OmniStreamPiece, an Exception (to re-raise), or None (end marker).
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def _sync_stream():
        """Run the blocking API call on a dedicated thread."""
        try:
            # Day 21: system prompt keeps answers brief — the glasses use case
            # needs fast, short replies.
            system_prompt = """你是一个视障辅助AI助手,安装在智能导盲眼镜上。
请用极简短的语言回答,每次回答不超过2-3句话。
避免冗长解释,只提供最关键的信息。
语气友好但简洁。"""

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content_list}
            ]

            completion = oai_client.chat.completions.create(
                model=QWEN_MODEL,
                messages=messages,
                modalities=["text", "audio"],
                audio={"voice": voice, "format": audio_format},
                stream=True,
                stream_options={"include_usage": True},
            )

            for chunk in completion:
                text_delta: Optional[str] = None
                audio_b64: Optional[str] = None

                if getattr(chunk, "choices", None):
                    c0 = chunk.choices[0]
                    delta = getattr(c0, "delta", None)
                    # Text increment.
                    if delta and getattr(delta, "content", None):
                        piece = delta.content
                        if piece:
                            text_delta = piece
                    # Audio chunk — delta.audio may be a dict or an object.
                    if delta and getattr(delta, "audio", None):
                        aud = delta.audio
                        audio_b64 = aud.get("data") if isinstance(aud, dict) else getattr(aud, "data", None)
                    # Fallback: some chunks carry audio on message instead of delta.
                    if audio_b64 is None:
                        msg = getattr(c0, "message", None)
                        if msg and getattr(msg, "audio", None):
                            ma = msg.audio
                            audio_b64 = ma.get("data") if isinstance(ma, dict) else getattr(ma, "data", None)

                if (text_delta is not None) or (audio_b64 is not None):
                    # Thread-safe hand-off into the event loop's queue.
                    loop.call_soon_threadsafe(
                        queue.put_nowait,
                        OmniStreamPiece(text_delta=text_delta, audio_b64=audio_b64)
                    )
        except Exception as e:
            # Propagate the failure to the consumer side.
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            # Always send the end-of-stream marker.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    # Launch the synchronous API call on its own daemon thread.
    thread = threading.Thread(target=_sync_stream, daemon=True)
    thread.start()

    # Consume the queue asynchronously.
    while True:
        item = await queue.get()
        if item is None:
            # Stream finished.
            break
        if isinstance(item, Exception):
            # Worker thread failed — surface the error to the caller.
            raise item
        yield item
||||
64
qwen_extractor.py
Normal file
64
qwen_extractor.py
Normal file
@@ -0,0 +1,64 @@
|
||||
# qwen_extractor.py
# -*- coding: utf-8 -*-
from typing import List, Tuple
import os
from openai import OpenAI

# -- Local-first CN -> EN label mapping (extend/rename freely) --
# Consulted before any network call; keys are matched lowercase.
LOCAL_CN2EN = {
    "红牛": "Red_Bull",
    "ad钙奶": "AD_milk",
    "ad 钙奶": "AD_milk",
    "ad": "AD_milk",
    "钙奶": "AD_milk",
    "矿泉水": "bottle",
    "水瓶": "bottle",
    "可乐": "coke",
    "雪碧": "sprite",
}
|
||||
def _make_client() -> OpenAI:
    """Build an OpenAI-compatible client for the DashScope endpoint.

    Both settings come from the environment:
    - DASHSCOPE_COMPAT_BASE: endpoint URL (sensible default provided)
    - DASHSCOPE_API_KEY: credential (required)

    Raises:
        RuntimeError: if DASHSCOPE_API_KEY is not set. The only caller wraps
        this function in a try/except and falls back gracefully.
    """
    base_url = os.getenv("DASHSCOPE_COMPAT_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1")
    # SECURITY FIX: the key was hard-coded here despite the comment promising
    # env-var support; a committed key is compromised and must be rotated.
    api_key = os.getenv("DASHSCOPE_API_KEY")
    if not api_key:
        raise RuntimeError("未设置 DASHSCOPE_API_KEY")
    return OpenAI(api_key=api_key, base_url=base_url)
|
||||
# System prompt instructing Qwen to act as a strict label normalizer:
# one short lowercase English class name, nothing else.
PROMPT_SYS = (
    "You are a label normalizer. Convert the given Chinese object "
    "description into a short, lowercase English YOLO/vision class name "
    "(1~3 words). If multiple are given, return the single most likely one. "
    "Output ONLY the label, no punctuation."
)
|
||||
def extract_english_label(query_cn: str) -> Tuple[str, str]:
    """
    Map a Chinese object description to an English detector label.

    Returns:
        (label_en, source) with source ∈ {'local', 'qwen', 'fallback'}:
        - 'local'   : resolved via LOCAL_CN2EN (exact or substring hit)
        - 'qwen'    : resolved by calling the Qwen chat model
        - 'fallback': the API call failed; 'bottle' is the safe default
    """
    # Normalize: None-safe, trimmed, lowercase (table keys are lowercase).
    q = (query_cn or "").strip().lower()
    if q in LOCAL_CN2EN:
        return LOCAL_CN2EN[q], "local"

    # Substring rule: strips leading modifiers (e.g. "一瓶可乐" -> "可乐").
    # NOTE(review): short keys like "ad" can match inside unrelated queries.
    for k, v in LOCAL_CN2EN.items():
        if k in q:
            return v, "local"

    # Call Qwen Turbo via the Chat Completions-compatible endpoint.
    try:
        client = _make_client()
        msgs = [
            {"role": "system", "content": PROMPT_SYS},
            {"role": "user", "content": query_cn.strip()},
        ]
        rsp = client.chat.completions.create(
            model=os.getenv("QWEN_MODEL", "qwen-turbo"),
            messages=msgs,
            stream=False
        )
        label = (rsp.choices[0].message.content or "").strip()
        # Clean up: drop punctuation, then collapse every run of whitespace
        # to a single space. BUG FIX: the original `replace(" ", " ")` was a
        # no-op (space replaced by space); split/join normalizes all
        # whitespace, including full-width spaces.
        label = label.replace(".", "").replace(",", "")
        label = " ".join(label.split())
        # Guard against an empty answer.
        return (label or "bottle"), "qwen"
    except Exception:
        return "bottle", "fallback"
28
qwenturbo_template.py
Normal file
28
qwenturbo_template.py
Normal file
@@ -0,0 +1,28 @@
|
||||
from openai import OpenAI
import os

# Demo: streaming Qwen-Turbo chat with "thinking" enabled, printing the
# reasoning phase and the final answer separately.
client = OpenAI(
    # SECURITY FIX: the API key was hard-coded in source; it must come from
    # the environment (the committed key is compromised and must be rotated).
    api_key=os.getenv("DASHSCOPE_API_KEY"),
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)

messages = [{"role": "user", "content": "你是谁"}]
completion = client.chat.completions.create(
    model="qwen-turbo",  # can be swapped for other deep-thinking models
    messages=messages,
    extra_body={"enable_thinking": True},
    stream=True
)
is_answering = False  # whether the reply phase has started
print("\n" + "=" * 20 + "思考过程" + "=" * 20)
for chunk in completion:
    delta = chunk.choices[0].delta
    # Reasoning tokens stream before the answer; print them until the
    # answer phase begins.
    if hasattr(delta, "reasoning_content") and delta.reasoning_content is not None:
        if not is_answering:
            print(delta.reasoning_content, end="", flush=True)
    # Answer tokens: print a separator header once, then stream the content.
    if hasattr(delta, "content") and delta.content:
        if not is_answering:
            print("\n" + "=" * 20 + "完整回复" + "=" * 20)
            is_answering = True
        print(delta.content, end="", flush=True)
||||
71
requirements.txt
Normal file
71
requirements.txt
Normal file
@@ -0,0 +1,71 @@
|
||||
# AI Glass System - Python Dependencies
|
||||
# Python 3.9 - 3.11 supported
|
||||
|
||||
# Core Web Framework
|
||||
fastapi==0.104.1
|
||||
uvicorn[standard]==0.24.0
|
||||
websockets==12.0
|
||||
python-multipart==0.0.6
|
||||
starlette==0.27.0
|
||||
|
||||
# Computer Vision & Deep Learning
|
||||
opencv-python==4.8.1.78
|
||||
numpy==1.24.3
|
||||
Pillow==10.1.0
|
||||
ultralytics==8.3.200
|
||||
torch==2.0.1
|
||||
torchvision==0.15.2
|
||||
|
||||
# MediaPipe (Hand Detection)
|
||||
mediapipe==0.10.8
|
||||
|
||||
# Audio Processing
|
||||
pyaudio==0.2.14
|
||||
pydub==0.25.1
|
||||
pygame==2.5.2
|
||||
|
||||
# Aliyun DashScope SDK (ASR & Qwen-Omni) - 旧管道
|
||||
dashscope==1.14.1
|
||||
openai==1.3.5 # For DashScope compatibility mode
|
||||
|
||||
# Day 21: 新 AI 管道 (SenseVoice + GLM + EdgeTTS)
|
||||
# torchaudio 需与 torch 版本匹配,使用以下命令安装:
|
||||
# pip install torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
funasr>=1.2.0 # SenseVoice 本地 ASR
|
||||
edge-tts>=6.1.0 # 免费 TTS
|
||||
|
||||
# Environment & Configuration
|
||||
python-dotenv==1.0.0
|
||||
|
||||
# Utilities
|
||||
opencv-contrib-python==4.8.1.78 # Extended OpenCV modules
|
||||
|
||||
# Optional Performance Optimizations
|
||||
# Uncomment if needed:
|
||||
# onnxruntime-gpu==1.16.3 # For ONNX model acceleration
|
||||
# tensorrt==8.6.1 # For TensorRT optimization
|
||||
PyTurboJPEG>=1.7.0 # Day 19: 2-3x faster JPEG encode/decode than cv2
|
||||
# NOTE: PyTurboJPEG requires system library: sudo apt-get install libturbojpeg (Ubuntu)
|
||||
numba>=0.58.0 # Day 20: JIT compilation for multi-core parallel numpy operations
|
||||
|
||||
# Development & Testing (Optional)
|
||||
# pytest==7.4.3
|
||||
# pytest-asyncio==0.21.1
|
||||
# black==23.11.0
|
||||
# flake8==6.1.0
|
||||
|
||||
# Platform-specific dependencies
|
||||
# Windows:
|
||||
# - PyAudio requires separate installation: https://www.lfd.uci.edu/~gohlke/pythonlibs/#pyaudio
|
||||
# Linux:
|
||||
# - Install PyAudio dependencies: sudo apt-get install portaudio19-dev python3-pyaudio
|
||||
# - Install OpenCV dependencies: sudo apt-get install libgl1-mesa-glx
|
||||
# macOS:
|
||||
# - Install PyAudio: brew install portaudio && pip install pyaudio
|
||||
|
||||
# CUDA Dependencies (GPU acceleration)
|
||||
# - CUDA Toolkit 11.8+: https://developer.nvidia.com/cuda-downloads
|
||||
# - cuDNN 8.6+: https://developer.nvidia.com/cudnn
|
||||
# PyTorch with CUDA support:
|
||||
# pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user