Init: 导入NaviGlassServer源码
This commit is contained in:
68
.env.performance
Normal file
68
.env.performance
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
# ============================================================
|
||||||
|
# Day 22 性能优化配置文件
|
||||||
|
# 复制此文件为 .env 并根据硬件调整参数
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
# === YOLO 盲道/斑马线检测 ===
|
||||||
|
# 输入分辨率 (越小越快,但精度降低)
|
||||||
|
# 建议: RTX 3090 用 640, GTX 1060 用 320-384
|
||||||
|
AIGLASS_YOLO_IMGSZ=480
|
||||||
|
|
||||||
|
# 检测间隔 (每N帧检测一次)
|
||||||
|
# 建议: 高端GPU用 6-8, 低端用 12-15
|
||||||
|
AIGLASS_BLINDPATH_INTERVAL=10
|
||||||
|
|
||||||
|
# 启用FP16半精度 (1=启用, 0=禁用)
|
||||||
|
# 注意: GTX 1060 的FP16性能不佳,建议设为0
|
||||||
|
AIGLASS_YOLO_HALF=1
|
||||||
|
|
||||||
|
# === 障碍物检测 ===
|
||||||
|
# 输入分辨率
|
||||||
|
AIGLASS_OBS_IMGSZ=480
|
||||||
|
|
||||||
|
# 检测间隔
|
||||||
|
AIGLASS_OBS_INTERVAL=18
|
||||||
|
|
||||||
|
# 缓存帧数
|
||||||
|
AIGLASS_OBS_CACHE_FRAMES=12
|
||||||
|
|
||||||
|
# 置信度阈值
|
||||||
|
AIGLASS_OBS_CONF=0.25
|
||||||
|
|
||||||
|
# 启用FP16
|
||||||
|
AIGLASS_OBS_HALF=1
|
||||||
|
|
||||||
|
# === GPU 并发控制 ===
|
||||||
|
# 同时推理的最大任务数
|
||||||
|
AIGLASS_GPU_SLOTS=2
|
||||||
|
|
||||||
|
# 设备选择 (cuda:0, cuda:1, cpu)
|
||||||
|
AIGLASS_DEVICE=cuda:0
|
||||||
|
|
||||||
|
# 混合精度模式 (fp16, bf16, off)
|
||||||
|
# RTX 30系列支持 bf16, 其他用 fp16
|
||||||
|
AIGLASS_AMP=fp16
|
||||||
|
|
||||||
|
# === 语音播报 ===
|
||||||
|
# 直行提示间隔(秒)
|
||||||
|
AIGLASS_STRAIGHT_INTERVAL=4.0
|
||||||
|
|
||||||
|
# 方向指令间隔(秒)
|
||||||
|
AIGLASS_DIRECTION_INTERVAL=3.0
|
||||||
|
|
||||||
|
# 持续播报模式 (1=启用)
|
||||||
|
AIGLASS_STRAIGHT_CONTINUOUS=1
|
||||||
|
|
||||||
|
# 限制模式下最大重复次数
|
||||||
|
AIGLASS_STRAIGHT_LIMIT=2
|
||||||
|
|
||||||
|
# === 模型路径 (根据实际路径修改) ===
|
||||||
|
# BLIND_PATH_MODEL=models/yolo-seg.pt
|
||||||
|
# OBSTACLE_MODEL=models/yoloe-11l-seg.pt
|
||||||
|
|
||||||
|
# === 调试选项 ===
|
||||||
|
# 启用ASR原始日志
|
||||||
|
# ASR_DEBUG_RAW=1
|
||||||
|
|
||||||
|
# 启用红绿灯调试图像
|
||||||
|
# AIGLASS_DEBUG_TRAFFIC_LIGHT=1
|
||||||
352
.gitignore
vendored
352
.gitignore
vendored
@@ -1,176 +1,176 @@
|
|||||||
# ---> Python
|
# ---> Python
|
||||||
# Byte-compiled / optimized / DLL files
|
# Byte-compiled / optimized / DLL files
|
||||||
__pycache__/
|
__pycache__/
|
||||||
*.py[cod]
|
*.py[cod]
|
||||||
*$py.class
|
*$py.class
|
||||||
|
|
||||||
# C extensions
|
# C extensions
|
||||||
*.so
|
*.so
|
||||||
|
|
||||||
# Distribution / packaging
|
# Distribution / packaging
|
||||||
.Python
|
.Python
|
||||||
build/
|
build/
|
||||||
develop-eggs/
|
develop-eggs/
|
||||||
dist/
|
dist/
|
||||||
downloads/
|
downloads/
|
||||||
eggs/
|
eggs/
|
||||||
.eggs/
|
.eggs/
|
||||||
lib/
|
lib/
|
||||||
lib64/
|
lib64/
|
||||||
parts/
|
parts/
|
||||||
sdist/
|
sdist/
|
||||||
var/
|
var/
|
||||||
wheels/
|
wheels/
|
||||||
share/python-wheels/
|
share/python-wheels/
|
||||||
*.egg-info/
|
*.egg-info/
|
||||||
.installed.cfg
|
.installed.cfg
|
||||||
*.egg
|
*.egg
|
||||||
MANIFEST
|
MANIFEST
|
||||||
|
|
||||||
# PyInstaller
|
# PyInstaller
|
||||||
# Usually these files are written by a python script from a template
|
# Usually these files are written by a python script from a template
|
||||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||||
*.manifest
|
*.manifest
|
||||||
*.spec
|
*.spec
|
||||||
|
|
||||||
# Installer logs
|
# Installer logs
|
||||||
pip-log.txt
|
pip-log.txt
|
||||||
pip-delete-this-directory.txt
|
pip-delete-this-directory.txt
|
||||||
|
|
||||||
# Unit test / coverage reports
|
# Unit test / coverage reports
|
||||||
htmlcov/
|
htmlcov/
|
||||||
.tox/
|
.tox/
|
||||||
.nox/
|
.nox/
|
||||||
.coverage
|
.coverage
|
||||||
.coverage.*
|
.coverage.*
|
||||||
.cache
|
.cache
|
||||||
nosetests.xml
|
nosetests.xml
|
||||||
coverage.xml
|
coverage.xml
|
||||||
*.cover
|
*.cover
|
||||||
*.py,cover
|
*.py,cover
|
||||||
.hypothesis/
|
.hypothesis/
|
||||||
.pytest_cache/
|
.pytest_cache/
|
||||||
cover/
|
cover/
|
||||||
|
|
||||||
# Translations
|
# Translations
|
||||||
*.mo
|
*.mo
|
||||||
*.pot
|
*.pot
|
||||||
|
|
||||||
# Django stuff:
|
# Django stuff:
|
||||||
*.log
|
*.log
|
||||||
local_settings.py
|
local_settings.py
|
||||||
db.sqlite3
|
db.sqlite3
|
||||||
db.sqlite3-journal
|
db.sqlite3-journal
|
||||||
|
|
||||||
# Flask stuff:
|
# Flask stuff:
|
||||||
instance/
|
instance/
|
||||||
.webassets-cache
|
.webassets-cache
|
||||||
|
|
||||||
# Scrapy stuff:
|
# Scrapy stuff:
|
||||||
.scrapy
|
.scrapy
|
||||||
|
|
||||||
# Sphinx documentation
|
# Sphinx documentation
|
||||||
docs/_build/
|
docs/_build/
|
||||||
|
|
||||||
# PyBuilder
|
# PyBuilder
|
||||||
.pybuilder/
|
.pybuilder/
|
||||||
target/
|
target/
|
||||||
|
|
||||||
# Jupyter Notebook
|
# Jupyter Notebook
|
||||||
.ipynb_checkpoints
|
.ipynb_checkpoints
|
||||||
|
|
||||||
# IPython
|
# IPython
|
||||||
profile_default/
|
profile_default/
|
||||||
ipython_config.py
|
ipython_config.py
|
||||||
|
|
||||||
# pyenv
|
# pyenv
|
||||||
# For a library or package, you might want to ignore these files since the code is
|
# For a library or package, you might want to ignore these files since the code is
|
||||||
# intended to run in multiple environments; otherwise, check them in:
|
# intended to run in multiple environments; otherwise, check them in:
|
||||||
# .python-version
|
# .python-version
|
||||||
|
|
||||||
# pipenv
|
# pipenv
|
||||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||||
# install all needed dependencies.
|
# install all needed dependencies.
|
||||||
#Pipfile.lock
|
#Pipfile.lock
|
||||||
|
|
||||||
# UV
|
# UV
|
||||||
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control.
|
||||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
# commonly ignored for libraries.
|
# commonly ignored for libraries.
|
||||||
#uv.lock
|
#uv.lock
|
||||||
|
|
||||||
# poetry
|
# poetry
|
||||||
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
||||||
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
||||||
# commonly ignored for libraries.
|
# commonly ignored for libraries.
|
||||||
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
||||||
#poetry.lock
|
#poetry.lock
|
||||||
|
|
||||||
# pdm
|
# pdm
|
||||||
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
||||||
#pdm.lock
|
#pdm.lock
|
||||||
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
||||||
# in version control.
|
# in version control.
|
||||||
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
|
||||||
.pdm.toml
|
.pdm.toml
|
||||||
.pdm-python
|
.pdm-python
|
||||||
.pdm-build/
|
.pdm-build/
|
||||||
|
|
||||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
||||||
__pypackages__/
|
__pypackages__/
|
||||||
|
|
||||||
# Celery stuff
|
# Celery stuff
|
||||||
celerybeat-schedule
|
celerybeat-schedule
|
||||||
celerybeat.pid
|
celerybeat.pid
|
||||||
|
|
||||||
# SageMath parsed files
|
# SageMath parsed files
|
||||||
*.sage.py
|
*.sage.py
|
||||||
|
|
||||||
# Environments
|
# Environments
|
||||||
.env
|
.env
|
||||||
.venv
|
.venv
|
||||||
env/
|
env/
|
||||||
venv/
|
venv/
|
||||||
ENV/
|
ENV/
|
||||||
env.bak/
|
env.bak/
|
||||||
venv.bak/
|
venv.bak/
|
||||||
|
|
||||||
# Spyder project settings
|
# Spyder project settings
|
||||||
.spyderproject
|
.spyderproject
|
||||||
.spyproject
|
.spyproject
|
||||||
|
|
||||||
# Rope project settings
|
# Rope project settings
|
||||||
.ropeproject
|
.ropeproject
|
||||||
|
|
||||||
# mkdocs documentation
|
# mkdocs documentation
|
||||||
/site
|
/site
|
||||||
|
|
||||||
# mypy
|
# mypy
|
||||||
.mypy_cache/
|
.mypy_cache/
|
||||||
.dmypy.json
|
.dmypy.json
|
||||||
dmypy.json
|
dmypy.json
|
||||||
|
|
||||||
# Pyre type checker
|
# Pyre type checker
|
||||||
.pyre/
|
.pyre/
|
||||||
|
|
||||||
# pytype static type analyzer
|
# pytype static type analyzer
|
||||||
.pytype/
|
.pytype/
|
||||||
|
|
||||||
# Cython debug symbols
|
# Cython debug symbols
|
||||||
cython_debug/
|
cython_debug/
|
||||||
|
|
||||||
# PyCharm
|
# PyCharm
|
||||||
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
||||||
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
||||||
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
||||||
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
||||||
#.idea/
|
#.idea/
|
||||||
|
|
||||||
# Ruff stuff:
|
# Ruff stuff:
|
||||||
.ruff_cache/
|
.ruff_cache/
|
||||||
|
|
||||||
# PyPI configuration file
|
# PyPI configuration file
|
||||||
.pypirc
|
.pypirc
|
||||||
|
|
||||||
|
|||||||
90
CHANGELOG.md
Normal file
90
CHANGELOG.md
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
# 更新日志
|
||||||
|
|
||||||
|
本文档记录项目的所有重要变更。
|
||||||
|
|
||||||
|
格式基于 [Keep a Changelog](https://keepachangelog.com/zh-CN/1.0.0/),
|
||||||
|
版本号遵循 [语义化版本](https://semver.org/lang/zh-CN/)。
|
||||||
|
|
||||||
|
## [未发布]
|
||||||
|
|
||||||
|
### 新增
|
||||||
|
- 首次开源发布
|
||||||
|
- 完整的 GitHub 文档(README, CONTRIBUTING, LICENSE 等)
|
||||||
|
- Docker 支持
|
||||||
|
- 环境变量配置模板
|
||||||
|
|
||||||
|
### 修改
|
||||||
|
- 优化了 README 文档结构
|
||||||
|
- 改进了代码注释
|
||||||
|
|
||||||
|
## [1.0.0] - 2025-01-XX
|
||||||
|
|
||||||
|
### 新增
|
||||||
|
- 🚶 盲道导航系统
|
||||||
|
- 实时盲道检测与分割
|
||||||
|
- 智能语音引导
|
||||||
|
- 障碍物检测与避障
|
||||||
|
- 急转弯检测与提醒
|
||||||
|
- 光流稳定算法
|
||||||
|
|
||||||
|
- 🚦 过马路辅助
|
||||||
|
- 斑马线识别与方向检测
|
||||||
|
- 红绿灯颜色识别
|
||||||
|
- 对齐引导系统
|
||||||
|
- 安全提醒
|
||||||
|
|
||||||
|
- 🔍 物品识别与查找
|
||||||
|
- YOLO-E 开放词汇检测
|
||||||
|
- MediaPipe 手部引导
|
||||||
|
- 实时目标追踪
|
||||||
|
- 抓取动作检测
|
||||||
|
|
||||||
|
- 🎙️ 实时语音交互
|
||||||
|
- 阿里云 Paraformer ASR
|
||||||
|
- Qwen-Omni-Turbo 多模态对话
|
||||||
|
- 智能指令解析
|
||||||
|
- 上下文感知
|
||||||
|
|
||||||
|
- 📹 视频与音频处理
|
||||||
|
- WebSocket 实时推流
|
||||||
|
- 音视频同步录制
|
||||||
|
- IMU 数据融合
|
||||||
|
- 多路音频混音
|
||||||
|
|
||||||
|
- 🎨 可视化与交互
|
||||||
|
- Web 实时监控界面
|
||||||
|
- IMU 3D 可视化
|
||||||
|
- 状态面板
|
||||||
|
- 中文友好界面
|
||||||
|
|
||||||
|
### 技术栈
|
||||||
|
- FastAPI + WebSocket
|
||||||
|
- YOLO11 / YOLO-E
|
||||||
|
- MediaPipe
|
||||||
|
- PyTorch + CUDA
|
||||||
|
- OpenCV
|
||||||
|
- DashScope API
|
||||||
|
|
||||||
|
### 已知问题
|
||||||
|
- [ ] 在低端 GPU 上可能出现卡顿
|
||||||
|
- [ ] macOS 上缺少 GPU 加速支持
|
||||||
|
- [ ] 部分中文字体在 Linux 上显示不正确
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
## 版本说明
|
||||||
|
|
||||||
|
### 主版本(Major)
|
||||||
|
- 不兼容的 API 更改
|
||||||
|
|
||||||
|
### 次版本(Minor)
|
||||||
|
- 向后兼容的新功能
|
||||||
|
|
||||||
|
### 修订版本(Patch)
|
||||||
|
- 向后兼容的问题修复
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
[未发布]: https://github.com/yourusername/aiglass/compare/v1.0.0...HEAD
|
||||||
|
[1.0.0]: https://github.com/yourusername/aiglass/releases/tag/v1.0.0
|
||||||
|
|
||||||
59
Dockerfile
Normal file
59
Dockerfile
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
# AI Glass System - Dockerfile
|
||||||
|
# 基于 NVIDIA CUDA 的 Python 镜像
|
||||||
|
|
||||||
|
FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04
|
||||||
|
|
||||||
|
# 设置环境变量
|
||||||
|
ENV DEBIAN_FRONTEND=noninteractive
|
||||||
|
ENV PYTHONUNBUFFERED=1
|
||||||
|
ENV CUDA_HOME=/usr/local/cuda
|
||||||
|
ENV PATH=${CUDA_HOME}/bin:${PATH}
|
||||||
|
ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH}
|
||||||
|
|
||||||
|
# 设置工作目录
|
||||||
|
WORKDIR /app
|
||||||
|
|
||||||
|
# 安装系统依赖
|
||||||
|
RUN apt-get update && apt-get install -y \
|
||||||
|
python3.10 \
|
||||||
|
python3-pip \
|
||||||
|
python3-dev \
|
||||||
|
portaudio19-dev \
|
||||||
|
libgl1-mesa-glx \
|
||||||
|
libglib2.0-0 \
|
||||||
|
libsm6 \
|
||||||
|
libxext6 \
|
||||||
|
libxrender-dev \
|
||||||
|
libgomp1 \
|
||||||
|
git \
|
||||||
|
wget \
|
||||||
|
curl \
|
||||||
|
&& apt-get clean \
|
||||||
|
&& rm -rf /var/lib/apt/lists/*
|
||||||
|
|
||||||
|
# 升级 pip
|
||||||
|
RUN python3 -m pip install --upgrade pip
|
||||||
|
|
||||||
|
# 复制 requirements.txt
|
||||||
|
COPY requirements.txt .
|
||||||
|
|
||||||
|
# 安装 Python 依赖
|
||||||
|
RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118
|
||||||
|
RUN pip install --no-cache-dir -r requirements.txt
|
||||||
|
|
||||||
|
# 复制应用代码
|
||||||
|
COPY . .
|
||||||
|
|
||||||
|
# 创建必要的目录
|
||||||
|
RUN mkdir -p recordings model music voice static templates
|
||||||
|
|
||||||
|
# 暴露端口
|
||||||
|
EXPOSE 8081 12345/udp
|
||||||
|
|
||||||
|
# 健康检查
|
||||||
|
HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \
|
||||||
|
CMD curl -f http://localhost:8081/api/health || exit 1
|
||||||
|
|
||||||
|
# 启动命令
|
||||||
|
CMD ["python3", "app_main.py"]
|
||||||
|
|
||||||
21
LICENSE
Normal file
21
LICENSE
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
MIT License
|
||||||
|
|
||||||
|
Copyright (c) 2025 AI-FanGe
|
||||||
|
|
||||||
|
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||||
|
of this software and associated documentation files (the "Software"), to deal
|
||||||
|
in the Software without restriction, including without limitation the rights
|
||||||
|
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||||
|
copies of the Software, and to permit persons to whom the Software is
|
||||||
|
furnished to do so, subject to the following conditions:
|
||||||
|
|
||||||
|
The above copyright notice and this permission notice shall be included in all
|
||||||
|
copies or substantial portions of the Software.
|
||||||
|
|
||||||
|
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||||
|
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||||
|
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||||
|
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||||
|
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||||
|
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||||
|
SOFTWARE.
|
||||||
402
PROJECT_STRUCTURE.md
Normal file
402
PROJECT_STRUCTURE.md
Normal file
@@ -0,0 +1,402 @@
|
|||||||
|
# 项目结构说明
|
||||||
|
|
||||||
|
本文档详细说明项目的目录结构和主要文件的作用。
|
||||||
|
|
||||||
|
## 📁 目录结构
|
||||||
|
|
||||||
|
```
|
||||||
|
rebuild1002/
|
||||||
|
├── 📄 主要应用文件
|
||||||
|
│ ├── app_main.py # 主应用入口(FastAPI 服务)
|
||||||
|
│ ├── navigation_master.py # 导航统领器(状态机)
|
||||||
|
│ ├── workflow_blindpath.py # 盲道导航工作流
|
||||||
|
│ ├── workflow_crossstreet.py # 过马路导航工作流
|
||||||
|
│ └── yolomedia.py # 物品查找工作流
|
||||||
|
│
|
||||||
|
├── 🎙️ 语音处理模块
|
||||||
|
│ ├── asr_core.py # 语音识别核心
|
||||||
|
│ ├── omni_client.py # Qwen-Omni 客户端
|
||||||
|
│ ├── qwen_extractor.py # 标签提取(中文->英文)
|
||||||
|
│ ├── audio_player.py # 音频播放器
|
||||||
|
│ └── audio_stream.py # 音频流管理
|
||||||
|
│
|
||||||
|
├── 🤖 模型相关
|
||||||
|
│ ├── yoloe_backend.py # YOLO-E 后端(开放词汇)
|
||||||
|
│ ├── trafficlight_detection.py # 红绿灯检测
|
||||||
|
│ ├── obstacle_detector_client.py # 障碍物检测客户端
|
||||||
|
│ └── models.py # 模型定义
|
||||||
|
│
|
||||||
|
├── 🎥 视频处理
|
||||||
|
│ ├── bridge_io.py # 线程安全的帧缓冲
|
||||||
|
│ ├── sync_recorder.py # 音视频同步录制
|
||||||
|
│ └── video_recorder.py # 视频录制(旧版)
|
||||||
|
│
|
||||||
|
├── 🌐 Web 前端
|
||||||
|
│ ├── templates/
|
||||||
|
│ │ └── index.html # 主界面 HTML
|
||||||
|
│ ├── static/
|
||||||
|
│ │ ├── main.js # 主 JS 脚本
|
||||||
|
│ │ ├── vision.js # 视觉流处理
|
||||||
|
│ │ ├── visualizer.js # 数据可视化
|
||||||
|
│ │ ├── vision_renderer.js # 渲染器
|
||||||
|
│ │ ├── vision.css # 样式表
|
||||||
|
│ │ └── models/ # 3D 模型(IMU 可视化)
|
||||||
|
│
|
||||||
|
├── 🎵 音频资源
|
||||||
|
│ ├── music/ # 系统提示音
|
||||||
|
│ │ ├── converted_向上.wav
|
||||||
|
│ │ ├── converted_向下.wav
|
||||||
|
│ │ └── ...
|
||||||
|
│ └── voice/ # 预录语音
|
||||||
|
│ ├── voice_mapping.json
|
||||||
|
│ └── *.wav
|
||||||
|
│
|
||||||
|
├── 🧠 模型文件
|
||||||
|
│ └── model/
|
||||||
|
│ ├── yolo-seg.pt # 盲道分割模型
|
||||||
|
│ ├── yoloe-11l-seg.pt # YOLO-E 开放词汇模型
|
||||||
|
│ ├── shoppingbest5.pt # 物品识别模型
|
||||||
|
│ ├── trafficlight.pt # 红绿灯检测模型
|
||||||
|
│ └── hand_landmarker.task # MediaPipe 手部模型
|
||||||
|
│
|
||||||
|
├── 📹 录制文件
|
||||||
|
│ └── recordings/ # 自动保存的视频和音频
|
||||||
|
│ ├── video_*.avi
|
||||||
|
│ └── audio_*.wav
|
||||||
|
│
|
||||||
|
├── 🛠️ ESP32 固件
|
||||||
|
│ └── compile/
|
||||||
|
│ ├── compile.ino # Arduino 主程序
|
||||||
|
│ ├── camera_pins.h # 摄像头引脚定义
|
||||||
|
│ ├── ICM42688.cpp/h # IMU 驱动
|
||||||
|
│ └── ESP32_VIDEO_OPTIMIZATION.md
|
||||||
|
│
|
||||||
|
├── 🧪 测试文件
|
||||||
|
│ ├── test_recorder.py # 录制功能测试
|
||||||
|
│ ├── test_traffic_light.py # 红绿灯检测测试
|
||||||
|
│ ├── test_cross_street_blindpath.py # 导航测试
|
||||||
|
│ └── test_crosswalk_awareness.py # 斑马线检测测试
|
||||||
|
│
|
||||||
|
├── 📚 文档
|
||||||
|
│ ├── README.md # 项目主文档
|
||||||
|
│ ├── INSTALLATION.md # 安装指南
|
||||||
|
│ ├── CONTRIBUTING.md # 贡献指南
|
||||||
|
│ ├── FAQ.md # 常见问题
|
||||||
|
│ ├── CHANGELOG.md # 更新日志
|
||||||
|
│ ├── SECURITY.md # 安全政策
|
||||||
|
│ └── PROJECT_STRUCTURE.md # 本文件
|
||||||
|
│
|
||||||
|
├── 🐳 Docker 相关
|
||||||
|
│ ├── Dockerfile # Docker 镜像定义
|
||||||
|
│ ├── docker-compose.yml # Docker Compose 配置
|
||||||
|
│ └── .dockerignore # Docker 忽略文件
|
||||||
|
│
|
||||||
|
├── ⚙️ 配置文件
|
||||||
|
│ ├── .env.example # 环境变量模板
|
||||||
|
│ ├── .gitignore # Git 忽略文件
|
||||||
|
│ ├── requirements.txt # Python 依赖
|
||||||
|
│ ├── setup.sh # Linux/macOS 安装脚本
|
||||||
|
│ └── setup.bat # Windows 安装脚本
|
||||||
|
│
|
||||||
|
├── 📄 许可证
|
||||||
|
│ └── LICENSE # MIT 许可证
|
||||||
|
│
|
||||||
|
└── 🔧 GitHub 相关
|
||||||
|
└── .github/
|
||||||
|
├── ISSUE_TEMPLATE/
|
||||||
|
│ ├── bug_report.md
|
||||||
|
│ └── feature_request.md
|
||||||
|
└── pull_request_template.md
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🔑 核心文件说明
|
||||||
|
|
||||||
|
### 主应用层
|
||||||
|
|
||||||
|
#### `app_main.py`
|
||||||
|
- **作用**: FastAPI 主服务,处理所有 WebSocket 连接
|
||||||
|
- **主要功能**:
|
||||||
|
- WebSocket 路由管理(/ws/camera, /ws_audio, /ws/viewer 等)
|
||||||
|
- 模型加载与初始化
|
||||||
|
- 状态协调与管理
|
||||||
|
- 音视频流分发
|
||||||
|
- **依赖**: 所有其他模块
|
||||||
|
- **入口点**: `python app_main.py`
|
||||||
|
|
||||||
|
#### `navigation_master.py`
|
||||||
|
- **作用**: 导航统领器,管理整个系统的状态机
|
||||||
|
- **主要状态**:
|
||||||
|
- IDLE: 空闲
|
||||||
|
- CHAT: 对话模式
|
||||||
|
- BLINDPATH_NAV: 盲道导航
|
||||||
|
- CROSSING: 过马路
|
||||||
|
- TRAFFIC_LIGHT_DETECTION: 红绿灯检测
|
||||||
|
- ITEM_SEARCH: 物品查找
|
||||||
|
- **核心方法**:
|
||||||
|
- `process_frame()`: 处理每一帧
|
||||||
|
- `start_blind_path_navigation()`: 启动盲道导航
|
||||||
|
- `start_crossing()`: 启动过马路模式
|
||||||
|
- `on_voice_command()`: 处理语音命令
|
||||||
|
|
||||||
|
### 工作流模块
|
||||||
|
|
||||||
|
#### `workflow_blindpath.py`
|
||||||
|
- **作用**: 盲道导航核心逻辑
|
||||||
|
- **主要功能**:
|
||||||
|
- 盲道分割与检测
|
||||||
|
- 障碍物检测
|
||||||
|
- 转弯检测
|
||||||
|
- 光流稳定
|
||||||
|
- 方向引导生成
|
||||||
|
- **状态机**:
|
||||||
|
- ONBOARDING: 上盲道
|
||||||
|
- NAVIGATING: 导航中
|
||||||
|
- MANEUVERING_TURN: 转弯
|
||||||
|
- AVOIDING_OBSTACLE: 避障
|
||||||
|
|
||||||
|
#### `workflow_crossstreet.py`
|
||||||
|
- **作用**: 过马路导航逻辑
|
||||||
|
- **主要功能**:
|
||||||
|
- 斑马线检测
|
||||||
|
- 方向对齐
|
||||||
|
- 引导生成
|
||||||
|
- **核心方法**:
|
||||||
|
- `_is_crosswalk_near()`: 判断是否接近斑马线
|
||||||
|
- `_compute_angle_and_offset()`: 计算角度和偏移
|
||||||
|
|
||||||
|
#### `yolomedia.py`
|
||||||
|
- **作用**: 物品查找工作流
|
||||||
|
- **主要功能**:
|
||||||
|
- YOLO-E 文本提示检测
|
||||||
|
- MediaPipe 手部追踪
|
||||||
|
- 光流目标追踪
|
||||||
|
- 手部引导(方向提示)
|
||||||
|
- 抓取动作检测
|
||||||
|
- **模式**:
|
||||||
|
- SEGMENT: 检测模式
|
||||||
|
- FLASH: 闪烁确认
|
||||||
|
- CENTER_GUIDE: 居中引导
|
||||||
|
- TRACK: 手部追踪
|
||||||
|
|
||||||
|
### 语音模块
|
||||||
|
|
||||||
|
#### `asr_core.py`
|
||||||
|
- **作用**: 阿里云 Paraformer ASR 实时语音识别
|
||||||
|
- **主要功能**:
|
||||||
|
- 实时语音识别
|
||||||
|
- VAD(语音活动检测)
|
||||||
|
- 识别结果回调
|
||||||
|
- **关键类**: `ASRCallback`
|
||||||
|
|
||||||
|
#### `omni_client.py`
|
||||||
|
- **作用**: Qwen-Omni-Turbo 多模态对话客户端
|
||||||
|
- **主要功能**:
|
||||||
|
- 流式对话生成
|
||||||
|
- 图像+文本输入
|
||||||
|
- 语音输出
|
||||||
|
- **核心函数**: `stream_chat()`
|
||||||
|
|
||||||
|
#### `audio_player.py`
|
||||||
|
- **作用**: 统一的音频播放管理
|
||||||
|
- **主要功能**:
|
||||||
|
- TTS 语音播放
|
||||||
|
- 多路音频混音
|
||||||
|
- 音量控制
|
||||||
|
- 线程安全播放
|
||||||
|
- **核心函数**: `play_voice_text()`, `play_audio_threadsafe()`
|
||||||
|
|
||||||
|
### 模型后端
|
||||||
|
|
||||||
|
#### `yoloe_backend.py`
|
||||||
|
- **作用**: YOLO-E 开放词汇检测后端
|
||||||
|
- **主要功能**:
|
||||||
|
- 文本提示设置
|
||||||
|
- 实时分割
|
||||||
|
- 目标追踪
|
||||||
|
- **核心类**: `YoloEBackend`
|
||||||
|
|
||||||
|
#### `trafficlight_detection.py`
|
||||||
|
- **作用**: 红绿灯检测模块
|
||||||
|
- **检测方法**:
|
||||||
|
1. YOLO 模型检测
|
||||||
|
2. HSV 颜色分类(备用)
|
||||||
|
- **输出**: 红灯/绿灯/黄灯/未知
|
||||||
|
|
||||||
|
#### `obstacle_detector_client.py`
|
||||||
|
- **作用**: 障碍物检测客户端
|
||||||
|
- **主要功能**:
|
||||||
|
- 白名单类别过滤
|
||||||
|
- 路径掩码内检测
|
||||||
|
- 物体属性计算(面积、位置、危险度)
|
||||||
|
|
||||||
|
### 视频处理
|
||||||
|
|
||||||
|
#### `bridge_io.py`
|
||||||
|
- **作用**: 线程安全的帧缓冲与分发
|
||||||
|
- **主要功能**:
|
||||||
|
- 生产者-消费者模式
|
||||||
|
- 原始帧缓存
|
||||||
|
- 处理后帧分发
|
||||||
|
- **核心函数**:
|
||||||
|
- `push_raw_jpeg()`: 接收 ESP32 帧
|
||||||
|
- `wait_raw_bgr()`: 取原始帧
|
||||||
|
- `send_vis_bgr()`: 发送处理后的帧
|
||||||
|
|
||||||
|
#### `sync_recorder.py`
|
||||||
|
- **作用**: 音视频同步录制
|
||||||
|
- **主要功能**:
|
||||||
|
- 同步录制视频和音频
|
||||||
|
- 自动文件命名(时间戳)
|
||||||
|
- 线程安全
|
||||||
|
- **输出**: `recordings/video_*.avi`, `audio_*.wav`
|
||||||
|
|
||||||
|
### 前端
|
||||||
|
|
||||||
|
#### `templates/index.html`
|
||||||
|
- **作用**: Web 监控界面
|
||||||
|
- **主要区域**:
|
||||||
|
- 视频流显示
|
||||||
|
- 状态面板
|
||||||
|
- IMU 3D 可视化
|
||||||
|
- 语音识别结果
|
||||||
|
|
||||||
|
#### `static/main.js`
|
||||||
|
- **作用**: 主 JavaScript 逻辑
|
||||||
|
- **主要功能**:
|
||||||
|
- WebSocket 连接管理
|
||||||
|
- UI 更新
|
||||||
|
- 事件处理
|
||||||
|
|
||||||
|
#### `static/vision.js`
|
||||||
|
- **作用**: 视觉流处理
|
||||||
|
- **主要功能**:
|
||||||
|
- WebSocket 接收视频帧
|
||||||
|
- Canvas 渲染
|
||||||
|
- FPS 计算
|
||||||
|
|
||||||
|
#### `static/visualizer.js`
|
||||||
|
- **作用**: IMU 3D 可视化(Three.js)
|
||||||
|
- **主要功能**:
|
||||||
|
- 接收 IMU 数据
|
||||||
|
- 实时渲染设备姿态
|
||||||
|
- 动态灯光效果
|
||||||
|
|
||||||
|
## 🔄 数据流
|
||||||
|
|
||||||
|
### 视频流
|
||||||
|
```
|
||||||
|
ESP32-CAM
|
||||||
|
→ [JPEG] WebSocket /ws/camera
|
||||||
|
→ bridge_io.push_raw_jpeg()
|
||||||
|
→ yolomedia / navigation_master
|
||||||
|
→ bridge_io.send_vis_bgr()
|
||||||
|
→ [JPEG] WebSocket /ws/viewer
|
||||||
|
→ Browser Canvas
|
||||||
|
```
|
||||||
|
|
||||||
|
### 音频流(上行)
|
||||||
|
```
|
||||||
|
ESP32-MIC
|
||||||
|
→ [PCM16] WebSocket /ws_audio
|
||||||
|
→ asr_core
|
||||||
|
→ DashScope ASR
|
||||||
|
→ 识别结果
|
||||||
|
→ start_ai_with_text_custom()
|
||||||
|
```
|
||||||
|
|
||||||
|
### 音频流(下行)
|
||||||
|
```
|
||||||
|
Qwen-Omni / TTS
|
||||||
|
→ audio_player
|
||||||
|
→ [PCM16] audio_stream
|
||||||
|
→ [WAV] HTTP /stream.wav
|
||||||
|
→ ESP32-Speaker
|
||||||
|
```
|
||||||
|
|
||||||
|
### IMU 数据流
|
||||||
|
```
|
||||||
|
ESP32-IMU
|
||||||
|
→ [JSON] UDP 12345
|
||||||
|
→ process_imu_and_maybe_store()
|
||||||
|
→ [JSON] WebSocket /ws
|
||||||
|
→ visualizer.js (Three.js)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🎯 关键设计模式
|
||||||
|
|
||||||
|
### 1. 状态机模式
|
||||||
|
- **位置**: `navigation_master.py`
|
||||||
|
- **作用**: 管理系统状态转换
|
||||||
|
- **状态**: IDLE → CHAT / BLINDPATH_NAV / CROSSING / ...
|
||||||
|
|
||||||
|
### 2. 生产者-消费者模式
|
||||||
|
- **位置**: `bridge_io.py`
|
||||||
|
- **作用**: 解耦视频接收与处理
|
||||||
|
- **实现**: 线程 + 队列
|
||||||
|
|
||||||
|
### 3. 策略模式
|
||||||
|
- **位置**: 各 `workflow_*.py`
|
||||||
|
- **作用**: 不同导航策略的实现
|
||||||
|
- **实现**: 统一的 `process_frame()` 接口
|
||||||
|
|
||||||
|
### 4. 单例模式
|
||||||
|
- **位置**: 模型加载
|
||||||
|
- **作用**: 全局共享模型实例
|
||||||
|
- **实现**: 全局变量 + 初始化检查
|
||||||
|
|
||||||
|
### 5. 观察者模式
|
||||||
|
- **位置**: WebSocket 通信
|
||||||
|
- **作用**: 多客户端订阅视频流
|
||||||
|
- **实现**: `camera_viewers: Set[WebSocket]`
|
||||||
|
|
||||||
|
## 📦 依赖关系
|
||||||
|
|
||||||
|
```
|
||||||
|
app_main.py
|
||||||
|
├── navigation_master.py
|
||||||
|
│ ├── workflow_blindpath.py
|
||||||
|
│ │ ├── yoloe_backend.py
|
||||||
|
│ │ └── obstacle_detector_client.py
|
||||||
|
│ ├── workflow_crossstreet.py
|
||||||
|
│ └── trafficlight_detection.py
|
||||||
|
├── yolomedia.py
|
||||||
|
│ └── yoloe_backend.py
|
||||||
|
├── asr_core.py
|
||||||
|
├── omni_client.py
|
||||||
|
├── audio_player.py
|
||||||
|
├── audio_stream.py
|
||||||
|
├── bridge_io.py
|
||||||
|
└── sync_recorder.py
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🚀 启动流程
|
||||||
|
|
||||||
|
1. **初始化阶段** (`app_main.py`)
|
||||||
|
- 加载环境变量
|
||||||
|
- 加载导航模型(YOLO、MediaPipe)
|
||||||
|
- 初始化音频系统
|
||||||
|
- 启动录制系统
|
||||||
|
- 预加载红绿灯模型
|
||||||
|
|
||||||
|
2. **服务启动** (FastAPI)
|
||||||
|
- 注册 WebSocket 路由
|
||||||
|
- 挂载静态文件
|
||||||
|
- 启动 UDP 监听(IMU)
|
||||||
|
- 启动 HTTP 服务(8081 端口)
|
||||||
|
|
||||||
|
3. **运行阶段**
|
||||||
|
- 等待 ESP32 连接
|
||||||
|
- 接收视频/音频/IMU 数据
|
||||||
|
- 处理用户语音指令
|
||||||
|
- 实时推送处理结果
|
||||||
|
|
||||||
|
4. **关闭阶段**
|
||||||
|
- 停止录制(保存文件)
|
||||||
|
- 关闭所有 WebSocket 连接
|
||||||
|
- 释放模型资源
|
||||||
|
- 清理临时文件
|
||||||
|
|
||||||
|
---
|
||||||
|
|
||||||
|
**提示**: 如需了解某个模块的详细实现,请查看相应源文件的注释和 docstring。
|
||||||
|
|
||||||
508
README.md
508
README.md
@@ -1,2 +1,506 @@
|
|||||||
# NaviGlassServer
|
# AI 智能盲人眼镜系统 🤖👓
|
||||||
|
|
||||||
|
<div align="center">
|
||||||
|
|
||||||
|
一个面向视障人士的智能导航与辅助系统,集成了盲道导航、过马路辅助、物品识别、实时语音交互等功能。 本项目仅为交流学习使用,请勿直接给视障人群使用。本项目内仅包含代码,模型地址:https://www.modelscope.cn/models/archifancy/AIGlasses_for_navigation 。下载后存放在/model 文件夹
|
||||||
|
|
||||||
|
[功能特性](#功能特性) • [快速开始](#快速开始) • [系统架构](#系统架构) • [使用说明](#使用说明) • [开发文档](#开发文档)
|
||||||
|
|
||||||
|
</div>
|
||||||
|
|
||||||
|
---
|
||||||
|
<img width="2481" height="3508" alt="1" src="https://github.com/user-attachments/assets/e8dec4a6-8fa6-4d94-bd66-4e9864b67daf" />
|
||||||
|
<img width="2480" height="3508" alt="2" src="https://github.com/user-attachments/assets/bc7d1aac-a9e9-4ef8-9d67-224708d0c9fd" />
|
||||||
|
<img width="2481" height="3508" alt="4" src="https://github.com/user-attachments/assets/6dd19750-57af-4560-a007-9a7059956b53" />
|
||||||
|
|
||||||
|
## 📋 目录
|
||||||
|
|
||||||
|
- [功能特性](#功能特性)
|
||||||
|
- [系统要求](#系统要求)
|
||||||
|
- [快速开始](#快速开始)
|
||||||
|
- [系统架构](#系统架构)
|
||||||
|
- [使用说明](#使用说明)
|
||||||
|
- [配置说明](#配置说明)
|
||||||
|
- [开发文档](#开发文档)
|
||||||
|
|
||||||
|
## ✨ 功能特性
|
||||||
|
|
||||||
|
### 🚶 盲道导航系统
|
||||||
|
- **实时盲道检测**:基于 YOLO 分割模型实时识别盲道
|
||||||
|
- **智能语音引导**:提供精准的方向指引(左转、右转、直行等)
|
||||||
|
- **障碍物检测与避障**:自动识别前方障碍物并规划避障路线
|
||||||
|
- **转弯检测**:自动识别急转弯并提前提醒
|
||||||
|
- **光流稳定**:使用 Lucas-Kanade 光流算法稳定掩码,减少抖动
|
||||||
|
|
||||||
|
### 🚦 过马路辅助
|
||||||
|
- **斑马线识别**:实时检测斑马线位置和方向
|
||||||
|
- **红绿灯识别**:基于颜色和形状的红绿灯状态检测
|
||||||
|
- **对齐引导**:引导用户对准斑马线中心
|
||||||
|
- **安全提醒**:绿灯时语音提示可以通行
|
||||||
|
|
||||||
|
### 🔍 物品识别与查找
|
||||||
|
- **智能物品搜索**:语音指令查找物品(如"帮我找一下红牛")
|
||||||
|
- **实时目标追踪**:使用 YOLO-E 开放词汇检测 + ByteTrack 追踪
|
||||||
|
- **手部引导**:结合 MediaPipe 手部检测,引导用户手部靠近物品
|
||||||
|
- **抓取检测**:检测手部握持动作,确认物品已拿到
|
||||||
|
- **多模态反馈**:视觉标注 + 语音引导 + 居中提示
|
||||||
|
|
||||||
|
### 🎙️ 实时语音交互
|
||||||
|
- **语音识别(ASR)**:基于阿里云 DashScope Paraformer 实时语音识别
|
||||||
|
- **多模态对话**:Qwen-Omni-Turbo 支持图像+文本输入,语音输出
|
||||||
|
- **智能指令解析**:自动识别导航、查找、对话等不同类型指令
|
||||||
|
- **上下文感知**:在不同模式下智能过滤无关指令
|
||||||
|
|
||||||
|
### 📹 视频与音频处理
|
||||||
|
- **实时视频流**:WebSocket 推流,支持多客户端同时观看
|
||||||
|
- **音视频同步录制**:自动保存带时间戳的录像和音频文件
|
||||||
|
- **IMU 数据融合**:接收 ESP32 的 IMU 数据,支持姿态估计
|
||||||
|
- **多路音频混音**:支持系统语音、AI 回复、环境音同时播放
|
||||||
|
|
||||||
|
### 🎨 可视化与交互
|
||||||
|
- **Web 实时监控**:浏览器端实时查看处理后的视频流
|
||||||
|
- **IMU 3D 可视化**:Three.js 实时渲染设备姿态
|
||||||
|
- **状态面板**:显示导航状态、检测信息、FPS 等
|
||||||
|
- **中文友好**:所有界面和语音使用中文,支持自定义字体
|
||||||
|
|
||||||
|
## 💻 系统要求
|
||||||
|
|
||||||
|
### 硬件要求
|
||||||
|
- **开发/服务器端**:
|
||||||
|
- CPU: Intel i5 或以上(推荐 i7/i9)
|
||||||
|
- GPU: NVIDIA GPU(支持 CUDA 11.8+,推荐 RTX 3060 或以上)
|
||||||
|
- 内存: 8GB RAM(推荐 16GB)
|
||||||
|
- 存储: 10GB 可用空间
|
||||||
|
|
||||||
|
- **客户端设备**(可选):
|
||||||
|
- ESP32-CAM 或其他支持 WebSocket 的摄像头
|
||||||
|
- 麦克风(用于语音输入)
|
||||||
|
- 扬声器/耳机(用于语音输出)
|
||||||
|
|
||||||
|
### 软件要求
|
||||||
|
- **操作系统**: Windows 10/11, Linux (Ubuntu 20.04+), macOS 10.15+
|
||||||
|
- **Python**: 3.9 - 3.11
|
||||||
|
- **CUDA**: 11.8 或更高版本(GPU 加速必需)
|
||||||
|
- **浏览器**: Chrome 90+, Firefox 88+, Edge 90+(用于 Web 监控)
|
||||||
|
|
||||||
|
### API 密钥
|
||||||
|
- **阿里云 DashScope API Key**(必需):
|
||||||
|
- 用于语音识别(ASR)和 Qwen-Omni 对话
|
||||||
|
- 申请地址:https://dashscope.console.aliyun.com/
|
||||||
|
|
||||||
|
## 🚀 快速开始
|
||||||
|
|
||||||
|
### 1. 克隆项目
|
||||||
|
|
||||||
|
```bash
|
||||||
|
git clone https://github.com/yourusername/aiglass.git
|
||||||
|
cd aiglass/rebuild1002
|
||||||
|
```
|
||||||
|
|
||||||
|
### 2. 安装依赖
|
||||||
|
|
||||||
|
#### 创建虚拟环境(推荐)
|
||||||
|
```bash
|
||||||
|
python -m venv venv
|
||||||
|
# Windows
|
||||||
|
venv\Scripts\activate
|
||||||
|
# Linux/macOS
|
||||||
|
source venv/bin/activate
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 安装 Python 包
|
||||||
|
```bash
|
||||||
|
pip install -r requirements.txt
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 安装 CUDA 和 cuDNN(GPU 加速)
|
||||||
|
请参考 [NVIDIA CUDA Toolkit 安装指南](https://developer.nvidia.com/cuda-downloads)
|
||||||
|
|
||||||
|
### 3. 下载模型文件
|
||||||
|
|
||||||
|
将以下模型文件放入 `model/` 目录:
|
||||||
|
|
||||||
|
| 模型文件 | 用途 | 大小 | 下载链接 |
|
||||||
|
|---------|------|------|---------|
|
||||||
|
| `yolo-seg.pt` | 盲道分割 | ~50MB | [待补充] |
|
||||||
|
| `yoloe-11l-seg.pt` | 开放词汇检测 | ~80MB | [待补充] |
|
||||||
|
| `shoppingbest5.pt` | 物品识别 | ~30MB | [待补充] |
|
||||||
|
| `trafficlight.pt` | 红绿灯检测 | ~20MB | [待补充] |
|
||||||
|
| `hand_landmarker.task` | 手部检测 | ~15MB | [MediaPipe Models](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker#models) |
|
||||||
|
|
||||||
|
### 4. 配置 API 密钥
|
||||||
|
|
||||||
|
创建 `.env` 文件:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# .env
|
||||||
|
DASHSCOPE_API_KEY=your_api_key_here
|
||||||
|
```
|
||||||
|
|
||||||
|
或在代码中直接修改(不推荐):
|
||||||
|
```python
|
||||||
|
# app_main.py, line 50
|
||||||
|
API_KEY = "your_api_key_here"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 5. 启动系统
|
||||||
|
|
||||||
|
```bash
|
||||||
|
python app_main.py
|
||||||
|
```
|
||||||
|
|
||||||
|
系统将在 `http://0.0.0.0:8081` 启动,打开浏览器访问即可看到实时监控界面。
|
||||||
|
|
||||||
|
### 6. 连接设备(可选)
|
||||||
|
|
||||||
|
如果使用 ESP32-CAM,请:
|
||||||
|
1. 烧录 `compile/compile.ino` 到 ESP32
|
||||||
|
2. 修改 WiFi 配置,连接到同一网络
|
||||||
|
3. ESP32 自动连接到 WebSocket 端点
|
||||||
|
|
||||||
|
## 🏗️ 系统架构
|
||||||
|
|
||||||
|
### 整体架构
|
||||||
|
|
||||||
|
```
|
||||||
|
┌─────────────────────────────────────────────────────────────┐
|
||||||
|
│ 客户端层 │
|
||||||
|
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||||
|
│ │ ESP32-CAM │ │ 浏览器 │ │ 移动端 │ │
|
||||||
|
│ │ (视频/音频) │ │ (监控界面) │ │ (语音控制) │ │
|
||||||
|
│ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │
|
||||||
|
└─────────┼──────────────────┼──────────────────┼─────────────┘
|
||||||
|
│ WebSocket │ HTTP/WS │ WebSocket
|
||||||
|
┌─────────┼──────────────────┼──────────────────┼─────────────┐
|
||||||
|
│ │ │ │ │
|
||||||
|
│ ┌────▼──────────────────▼──────────────────▼────────┐ │
|
||||||
|
│ │ FastAPI 主服务 (app_main.py) │ │
|
||||||
|
│ │ - WebSocket 路由管理 │ │
|
||||||
|
│ │ - 音视频流分发 │ │
|
||||||
|
│ │ - 状态管理与协调 │ │
|
||||||
|
│ └────┬────────────────┬────────────────┬─────────────┘ │
|
||||||
|
│ │ │ │ │
|
||||||
|
│ ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ │
|
||||||
|
│ │ ASR 模块 │ │ Omni 对话 │ │ 音频播放 │ │
|
||||||
|
│ │ (asr_core) │ │(omni_client)│ │(audio_player)│ │
|
||||||
|
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||||
|
│ │
|
||||||
|
│ 应用层 │
|
||||||
|
└───────────────────────────────────────────────────────────────┘
|
||||||
|
│ │ │
|
||||||
|
┌─────────▼──────────────────▼──────────────────▼──────────────┐
|
||||||
|
│ 导航统领层 │
|
||||||
|
│ ┌─────────────────────────────────────────────────┐ │
|
||||||
|
│ │ NavigationMaster (navigation_master.py) │ │
|
||||||
|
│ │ - 状态机:IDLE/CHAT/BLINDPATH_NAV/ │ │
|
||||||
|
│ │ CROSSING/TRAFFIC_LIGHT/ITEM_SEARCH │ │
|
||||||
|
│ │ - 模式切换与协调 │ │
|
||||||
|
│ └───┬─────────────────────┬───────────────────┬───┘ │
|
||||||
|
│ │ │ │ │
|
||||||
|
│ ┌────▼────────┐ ┌────────▼────────┐ ┌─────▼──────┐ │
|
||||||
|
│ │ 盲道导航 │ │ 过马路导航 │ │ 物品查找 │ │
|
||||||
|
│ │(blindpath) │ │ (crossstreet) │ │(yolomedia) │ │
|
||||||
|
│ └──────────────┘ └──────────────────┘ └─────────────┘ │
|
||||||
|
└───────────────────────────────────────────────────────────────┘
|
||||||
|
│ │ │
|
||||||
|
┌─────────▼──────────────────▼──────────────────▼──────────────┐
|
||||||
|
│ 模型推理层 │
|
||||||
|
│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │
|
||||||
|
│ │ YOLO 分割 │ │ YOLO-E 检测 │ │ MediaPipe │ │
|
||||||
|
│ │ (盲道/斑马线) │ │ (开放词汇) │ │ (手部检测) │ │
|
||||||
|
│ └──────────────┘ └──────────────┘ └──────────────┘ │
|
||||||
|
│ ┌──────────────┐ ┌──────────────┐ │
|
||||||
|
│ │ 红绿灯检测 │ │ 光流稳定 │ │
|
||||||
|
│ │(HSV+YOLO) │ │(Lucas-Kanade)│ │
|
||||||
|
│ └──────────────┘ └──────────────┘ │
|
||||||
|
└───────────────────────────────────────────────────────────────┘
|
||||||
|
│
|
||||||
|
┌─────────▼─────────────────────────────────────────────────────┐
|
||||||
|
│ 外部服务层 │
|
||||||
|
│ ┌──────────────────────────────────────────────┐ │
|
||||||
|
│ │ 阿里云 DashScope API │ │
|
||||||
|
│ │ - Paraformer ASR (实时语音识别) │ │
|
||||||
|
│ │ - Qwen-Omni-Turbo (多模态对话) │ │
|
||||||
|
│ │ - Qwen-Turbo (标签提取) │ │
|
||||||
|
│ └──────────────────────────────────────────────┘ │
|
||||||
|
└───────────────────────────────────────────────────────────────┘
|
||||||
|
```
|
||||||
|
|
||||||
|
### 核心模块说明
|
||||||
|
|
||||||
|
| 模块 | 文件 | 功能 |
|
||||||
|
|------|------|------|
|
||||||
|
| **主应用** | `app_main.py` | FastAPI 服务、WebSocket 管理、状态协调 |
|
||||||
|
| **导航统领** | `navigation_master.py` | 状态机管理、模式切换、语音节流 |
|
||||||
|
| **盲道导航** | `workflow_blindpath.py` | 盲道检测、避障、转弯引导 |
|
||||||
|
| **过马路导航** | `workflow_crossstreet.py` | 斑马线检测、红绿灯识别、对齐引导 |
|
||||||
|
| **物品查找** | `yolomedia.py` | 物品检测、手部引导、抓取确认 |
|
||||||
|
| **语音识别** | `asr_core.py` | 实时 ASR、VAD、指令解析 |
|
||||||
|
| **语音合成** | `omni_client.py` | Qwen-Omni 流式语音生成 |
|
||||||
|
| **音频播放** | `audio_player.py` | 多路混音、TTS 播放、音量控制 |
|
||||||
|
| **视频录制** | `sync_recorder.py` | 音视频同步录制 |
|
||||||
|
| **桥接 IO** | `bridge_io.py` | 线程安全的帧缓冲与分发 |
|
||||||
|
|
||||||
|
## 📖 使用说明
|
||||||
|
|
||||||
|
### 语音指令
|
||||||
|
|
||||||
|
系统支持以下语音指令(说话时无需唤醒词):
|
||||||
|
|
||||||
|
#### 导航控制
|
||||||
|
```
|
||||||
|
"开始导航" / "盲道导航" → 启动盲道导航
|
||||||
|
"停止导航" / "结束导航" → 停止盲道导航
|
||||||
|
"开始过马路" / "帮我过马路" → 启动过马路模式
|
||||||
|
"过马路结束" / "结束过马路" → 停止过马路模式
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 红绿灯检测
|
||||||
|
```
|
||||||
|
"检测红绿灯" / "看红绿灯" → 启动红绿灯检测
|
||||||
|
"停止检测" / "停止红绿灯" → 停止检测
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 物品查找
|
||||||
|
```
|
||||||
|
"帮我找一下 [物品名]" → 启动物品搜索
|
||||||
|
示例:
|
||||||
|
- "帮我找一下红牛"
|
||||||
|
- "找一下AD钙奶"
|
||||||
|
- "帮我找矿泉水"
|
||||||
|
"找到了" / "拿到了" → 确认找到物品
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 智能对话
|
||||||
|
```
|
||||||
|
"帮我看看这是什么" → 拍照识别
|
||||||
|
"这个东西能吃吗" → 物品咨询
|
||||||
|
任何其他问题 → AI 对话
|
||||||
|
```
|
||||||
|
|
||||||
|
### 导航状态说明
|
||||||
|
|
||||||
|
系统包含以下主要状态(自动切换):
|
||||||
|
|
||||||
|
1. **IDLE** - 空闲状态
|
||||||
|
- 等待用户指令
|
||||||
|
- 显示原始视频流
|
||||||
|
|
||||||
|
2. **CHAT** - 对话模式
|
||||||
|
- 与 AI 进行多模态对话
|
||||||
|
- 暂停导航功能
|
||||||
|
|
||||||
|
3. **BLINDPATH_NAV** - 盲道导航
|
||||||
|
- **ONBOARDING**: 上盲道引导
|
||||||
|
- ROTATION: 旋转对准盲道
|
||||||
|
- TRANSLATION: 平移至盲道中心
|
||||||
|
- **NAVIGATING**: 沿盲道行走
|
||||||
|
- 实时方向修正
|
||||||
|
- 障碍物检测
|
||||||
|
- **MANEUVERING_TURN**: 转弯处理
|
||||||
|
- **AVOIDING_OBSTACLE**: 避障
|
||||||
|
|
||||||
|
4. **CROSSING** - 过马路模式
|
||||||
|
- **SEEKING_CROSSWALK**: 寻找斑马线
|
||||||
|
- **WAIT_TRAFFIC_LIGHT**: 等待绿灯
|
||||||
|
- **CROSSING**: 过马路中
|
||||||
|
- **SEEKING_NEXT_BLINDPATH**: 寻找对面盲道
|
||||||
|
|
||||||
|
5. **ITEM_SEARCH** - 物品查找
|
||||||
|
- 实时检测目标物品
|
||||||
|
- 引导手部靠近
|
||||||
|
- 确认抓取
|
||||||
|
|
||||||
|
6. **TRAFFIC_LIGHT_DETECTION** - 红绿灯检测
|
||||||
|
- 实时检测红绿灯状态
|
||||||
|
- 语音播报颜色变化
|
||||||
|
|
||||||
|
### Web 监控界面
|
||||||
|
|
||||||
|
打开浏览器访问 `http://localhost:8081`,可以看到:
|
||||||
|
|
||||||
|
- **实时视频流**:显示处理后的视频,包括导航标注
|
||||||
|
- **状态面板**:当前模式、检测信息、FPS
|
||||||
|
- **IMU 可视化**:设备姿态 3D 实时渲染
|
||||||
|
- **语音识别结果**:显示识别的文字和 AI 回复
|
||||||
|
|
||||||
|
### WebSocket 端点
|
||||||
|
|
||||||
|
| 端点 | 用途 | 数据格式 |
|
||||||
|
|------|------|---------|
|
||||||
|
| `/ws/camera` | ESP32 相机推流 | Binary (JPEG) |
|
||||||
|
| `/ws/viewer` | 浏览器订阅视频 | Binary (JPEG) |
|
||||||
|
| `/ws_audio` | ESP32 音频上传 | Binary (PCM16) |
|
||||||
|
| `/ws_ui` | UI 状态推送 | JSON |
|
||||||
|
| `/ws` | IMU 数据接收 | JSON |
|
||||||
|
| `/stream.wav` | 音频下载流 | Binary (WAV) |
|
||||||
|
|
||||||
|
## ⚙️ 配置说明
|
||||||
|
|
||||||
|
### 环境变量
|
||||||
|
|
||||||
|
创建 `.env` 文件配置以下参数:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 阿里云 API
|
||||||
|
DASHSCOPE_API_KEY=sk-xxxxx
|
||||||
|
|
||||||
|
# 模型路径(可选,使用默认路径可不配置)
|
||||||
|
BLIND_PATH_MODEL=model/yolo-seg.pt
|
||||||
|
OBSTACLE_MODEL=model/yoloe-11l-seg.pt
|
||||||
|
YOLOE_MODEL_PATH=model/yoloe-11l-seg.pt
|
||||||
|
|
||||||
|
# 导航参数
|
||||||
|
AIGLASS_MASK_MIN_AREA=1500 # 最小掩码面积
|
||||||
|
AIGLASS_MASK_MORPH=3 # 形态学核大小
|
||||||
|
AIGLASS_MASK_MISS_TTL=6 # 掩码丢失容忍帧数
|
||||||
|
AIGLASS_PANEL_SCALE=0.65 # 数据面板缩放
|
||||||
|
|
||||||
|
# 音频配置
|
||||||
|
TTS_INTERVAL_SEC=1.0 # 语音播报间隔
|
||||||
|
ENABLE_TTS=true # 启用语音播报
|
||||||
|
```
|
||||||
|
|
||||||
|
### 修改模型路径
|
||||||
|
|
||||||
|
如果模型文件不在默认位置,可以在相应文件中修改:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# workflow_blindpath.py
|
||||||
|
seg_model_path = "your/custom/path/yolo-seg.pt"
|
||||||
|
|
||||||
|
# yolomedia.py
|
||||||
|
YOLO_MODEL_PATH = "your/custom/path/shoppingbest5.pt"
|
||||||
|
HAND_TASK_PATH = "your/custom/path/hand_landmarker.task"
|
||||||
|
```
|
||||||
|
|
||||||
|
### 调整性能参数
|
||||||
|
|
||||||
|
根据硬件性能调整:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# yolomedia.py
|
||||||
|
HAND_DOWNSCALE = 0.8 # 手部检测降采样(越小越快,精度降低)
|
||||||
|
HAND_FPS_DIV = 1 # 手部检测抽帧(2=隔帧,3=每3帧)
|
||||||
|
|
||||||
|
# workflow_blindpath.py
|
||||||
|
FEATURE_PARAMS = dict(
|
||||||
|
maxCorners=600, # 光流特征点数(越少越快)
|
||||||
|
qualityLevel=0.001, # 特征点质量
|
||||||
|
minDistance=5 # 特征点最小间距
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 🛠️ 开发文档
|
||||||
|
|
||||||
|
### 添加新的语音指令
|
||||||
|
|
||||||
|
1. 在 `app_main.py` 的 `start_ai_with_text_custom()` 函数中添加:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 检查新指令
|
||||||
|
if "新指令关键词" in user_text:
|
||||||
|
# 执行自定义逻辑
|
||||||
|
print("[CUSTOM] 新指令被触发")
|
||||||
|
await ui_broadcast_final("[系统] 新功能已启动")
|
||||||
|
return
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 如需修改指令过滤规则:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 修改 allowed_keywords 列表
|
||||||
|
allowed_keywords = ["帮我看", "帮我找", "你的新关键词"]
|
||||||
|
```
|
||||||
|
|
||||||
|
### 扩展导航功能
|
||||||
|
|
||||||
|
1. 在 `workflow_blindpath.py` 添加新状态:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# 在 BlindPathNavigator.__init__() 中初始化
|
||||||
|
self.your_new_state_var = False
|
||||||
|
|
||||||
|
# 在 process_frame() 中处理
|
||||||
|
def process_frame(self, image):
|
||||||
|
if self.your_new_state_var:
|
||||||
|
# 自定义处理逻辑
|
||||||
|
guidance_text = "新状态引导"
|
||||||
|
# ...
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 在 `navigation_master.py` 添加状态机状态:
|
||||||
|
|
||||||
|
```python
|
||||||
|
class NavigationMaster:
|
||||||
|
def start_your_new_mode(self):
|
||||||
|
self.state = "YOUR_NEW_MODE"
|
||||||
|
# 初始化逻辑
|
||||||
|
```
|
||||||
|
|
||||||
|
### 集成新模型
|
||||||
|
|
||||||
|
1. 创建模型包装类:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# your_model_wrapper.py
|
||||||
|
class YourModelWrapper:
|
||||||
|
def __init__(self, model_path):
|
||||||
|
self.model = load_your_model(model_path)
|
||||||
|
|
||||||
|
def detect(self, image):
|
||||||
|
# 推理逻辑
|
||||||
|
return results
|
||||||
|
```
|
||||||
|
|
||||||
|
2. 在 `app_main.py` 中加载:
|
||||||
|
|
||||||
|
```python
|
||||||
|
your_model = YourModelWrapper("model/your_model.pt")
|
||||||
|
```
|
||||||
|
|
||||||
|
3. 在相应的工作流中调用:
|
||||||
|
|
||||||
|
```python
|
||||||
|
results = your_model.detect(image)
|
||||||
|
```
|
||||||
|
|
||||||
|
### 调试技巧
|
||||||
|
|
||||||
|
1. **启用详细日志**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# app_main.py 顶部
|
||||||
|
import logging
|
||||||
|
logging.basicConfig(level=logging.DEBUG)
|
||||||
|
```
|
||||||
|
|
||||||
|
2. **查看帧率瓶颈**:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# yolomedia.py
|
||||||
|
PERF_DEBUG = True # 打印处理时间
|
||||||
|
```
|
||||||
|
|
||||||
|
3. **测试单个模块**:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
# 测试盲道导航
|
||||||
|
python test_cross_street_blindpath.py
|
||||||
|
|
||||||
|
# 测试红绿灯检测
|
||||||
|
python test_traffic_light.py
|
||||||
|
|
||||||
|
# 测试录制功能
|
||||||
|
python test_recorder.py
|
||||||
|
```
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
## 📄 许可证
|
||||||
|
|
||||||
|
本项目采用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
154
ai_voice_pipeline.py
Normal file
154
ai_voice_pipeline.py
Normal file
@@ -0,0 +1,154 @@
|
|||||||
|
# ai_voice_pipeline.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
AI 语音交互管道 - Day 21
|
||||||
|
|
||||||
|
整合 SenseVoice + GLM-4.5-Flash + EdgeTTS
|
||||||
|
|
||||||
|
流程:
|
||||||
|
1. 客户端 VAD 检测语音结束
|
||||||
|
2. 发送完整音频到服务器
|
||||||
|
3. SenseVoice 识别 → GLM 生成回复 → EdgeTTS 合成语音
|
||||||
|
4. 流式返回 PCM 音频
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
from typing import Optional, Callable, AsyncGenerator
|
||||||
|
|
||||||
|
# 导入各模块
|
||||||
|
from sensevoice_asr import recognize as asr_recognize, init_sensevoice
|
||||||
|
from glm_client import chat as llm_chat, chat_stream as llm_chat_stream
|
||||||
|
from edge_tts_client import (
|
||||||
|
text_to_speech_pcm_stream,
|
||||||
|
text_to_speech_pcm,
|
||||||
|
DEFAULT_VOICE,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def init_pipeline():
    """Initialize the AI pipeline (call once at server startup).

    Currently this only warms up the SenseVoice ASR backend via
    ``init_sensevoice()`` so the first request does not pay the model
    loading cost; a success message is printed afterwards.
    """
    await init_sensevoice()
    print("[AI Pipeline] 初始化完成")
|
||||||
|
|
||||||
|
|
||||||
|
async def process_voice(
    pcm_audio: bytes,
    image_base64: Optional[str] = None,
    on_text: Optional[Callable[[str], None]] = None,
    on_audio: Optional[Callable[[bytes], None]] = None,
) -> str:
    """
    Run one complete voice turn: ASR -> LLM -> streaming TTS.

    Args:
        pcm_audio: PCM16 audio data (16 kHz, mono).
        image_base64: Optional image for multimodal input.
        on_text: Callback used to surface text in the UI.
        on_audio: Callback fed PCM chunks for streaming playback.

    Returns:
        The AI reply text, or "" when ASR or the LLM produced nothing.
    """
    # Step 1: speech recognition.
    recognized = await asr_recognize(pcm_audio)

    if not recognized:
        print("[AI Pipeline] 未识别到有效语音")
        return ""

    print(f"[AI Pipeline] 用户说: {recognized}")

    # Mirror the transcript to the UI.
    if on_text:
        on_text(f"用户: {recognized}")

    # Step 2: LLM reply (optionally multimodal).
    reply = await llm_chat(recognized, image_base64)

    if not reply:
        print("[AI Pipeline] AI 无回复")
        return ""

    print(f"[AI Pipeline] AI 回复: {reply}")

    # Mirror the reply to the UI.
    if on_text:
        on_text(f"AI: {reply}")

    # Step 3: synthesize speech and stream it straight to the audio callback.
    if on_audio:
        async for pcm_chunk in text_to_speech_pcm_stream(reply):
            on_audio(pcm_chunk)

    return reply
|
||||||
|
|
||||||
|
|
||||||
|
async def process_voice_stream(
    pcm_audio: bytes,
    image_base64: Optional[str] = None,
) -> AsyncGenerator[tuple, None]:
    """
    Streaming voice turn: ASR once, then interleaved LLM text and TTS audio.

    Args:
        pcm_audio: PCM16 audio data.
        image_base64: Optional image for multimodal input.

    Yields:
        ("user_text", str) - the recognized utterance (once, first)
        ("ai_text", str)   - LLM text fragments as they arrive
        ("audio", bytes)   - synthesized PCM chunks
    """
    # Step 1: speech recognition; silently end the stream on empty input.
    recognized = await asr_recognize(pcm_audio)

    if not recognized:
        return

    yield ("user_text", recognized)

    # Steps 2+3: stream the LLM reply, handing sentence-sized pieces to the
    # TTS engine as soon as a trailing punctuation mark closes them.
    flush_marks = "。,!?;:,.!?;:"
    pending = ""

    async for piece in llm_chat_stream(recognized, image_base64):
        yield ("ai_text", piece)
        pending += piece

        # A trailing punctuation mark triggers synthesis of the buffered text.
        if pending and pending[-1] in flush_marks:
            async for pcm_chunk in text_to_speech_pcm_stream(pending):
                yield ("audio", pcm_chunk)
            pending = ""

    # Synthesize whatever text is left after the LLM stream ends.
    if pending.strip():
        async for pcm_chunk in text_to_speech_pcm_stream(pending):
            yield ("audio", pcm_chunk)
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_voice(text: str) -> bytes:
    """
    Synthesize *text* to speech (used for navigation prompts etc.).

    Args:
        text: Text to synthesize.

    Returns:
        PCM16 audio data.
    """
    pcm = await text_to_speech_pcm(text)
    return pcm
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_voice_stream(text: str) -> AsyncGenerator[bytes, None]:
    """
    Streaming text-to-speech.

    Args:
        text: Text to synthesize.

    Yields:
        PCM16 audio chunks as they are produced.
    """
    async for pcm_block in text_to_speech_pcm_stream(text):
        yield pcm_block
|
||||||
1868
app_main.py
Normal file
1868
app_main.py
Normal file
File diff suppressed because it is too large
Load Diff
221
asr_core.py
Normal file
221
asr_core.py
Normal file
@@ -0,0 +1,221 @@
|
|||||||
|
# asr_core.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os, json, asyncio
|
||||||
|
from typing import Any, Dict, List, Optional, Callable, Tuple
|
||||||
|
|
||||||
|
ASR_DEBUG_RAW = os.getenv("ASR_DEBUG_RAW", "0") == "1"
|
||||||
|
|
||||||
|
def _shorten(s: str, limit: int = 200) -> str:
    """Truncate *s* to at most *limit* characters, appending an ellipsis when cut."""
    if not s:
        return ""
    if len(s) <= limit:
        return s
    return s[:limit] + "…"
|
||||||
|
|
||||||
|
def _safe_to_dict(x: Any) -> Dict[str, Any]:
    """Best-effort conversion of *x* into a plain dict.

    Tries, in order: the object itself, the common converter attributes
    (``to_dict`` / ``model_dump`` / ``__dict__``), then parsing ``str(x)``
    as JSON when it looks like an object literal. Falls back to
    ``{"_raw": str(x)}`` when nothing works.
    """
    if isinstance(x, dict):
        return x

    for candidate_attr in ("to_dict", "model_dump", "__dict__"):
        try:
            converter = getattr(x, candidate_attr, None)
        except Exception:
            converter = None
        if callable(converter):
            try:
                converted = converter()
            except Exception:
                continue
            if isinstance(converted, dict):
                return converted
        elif isinstance(converter, dict):
            return converter

    try:
        rendered = str(x)
        if rendered and rendered.lstrip().startswith("{") and rendered.rstrip().endswith("}"):
            return json.loads(rendered)
    except Exception:
        pass
    return {"_raw": str(x)}


def _extract_sentence(event_obj: Any) -> Tuple[Optional[str], Optional[bool]]:
    """Pull ``(text, sentence_end)`` out of an ASR event of unknown shape.

    Looks for a "sentence" dict at the top level or nested under
    output/data/result; falls back to a bare "text" field. Returns
    ``(None, None)`` when nothing recognizable is present.
    """
    payload = _safe_to_dict(event_obj)

    containers: List[Dict[str, Any]] = [payload]
    for key in ("output", "data", "result"):
        nested = payload.get(key)
        if isinstance(nested, dict):
            containers.append(nested)

    # Preferred shape: a nested "sentence" object carrying the end flag.
    for container in containers:
        sentence = container.get("sentence")
        if isinstance(sentence, dict):
            end_flag = sentence.get("sentence_end")
            if end_flag is not None:
                end_flag = bool(end_flag)
            return sentence.get("text"), end_flag

    # Fallback shape: a bare "text" string with no end-of-sentence flag.
    for container in containers:
        if isinstance(container.get("text"), str):
            return container.get("text"), None

    return None, None
|
||||||
|
|
||||||
|
# ====== Hotword-triggered "full reset" configuration ======
# Override via the INTERRUPT_KEYWORDS env var (comma-separated). Surrounding
# whitespace is stripped and empty entries dropped so values such as
# "停下, 别说了, 停止" parse into clean keywords instead of " 别说了" etc.
INTERRUPT_KEYWORDS = {
    w.strip()
    for w in os.getenv("INTERRUPT_KEYWORDS", "停下,别说了,停止").split(",")
    if w.strip()
}

# Day 21: navigation-control whitelist - these commands contain a hotword
# substring (e.g. "停止导航" contains "停止") but must NOT trigger
# full_system_reset.
NAV_CONTROL_WHITELIST = [
    "停止导航", "结束导航", "停止检测", "停止红绿灯",
    "开始导航", "盲道导航", "开始过马路", "过马路结束",
    "帮我导航", "帮我过马路"
]
|
||||||
|
|
||||||
|
|
||||||
|
def _normalize_cn(s: str) -> str:
    """Normalize text for hotword matching: unify Unicode space separators
    to ASCII spaces, then strip and lowercase. Falls back to a plain
    strip/lower (treating falsy input as "") if anything goes wrong.
    """
    try:
        import unicodedata
        chars = (" " if unicodedata.category(ch) == "Zs" else ch for ch in s)
        normalized = "".join(chars).strip().lower()
    except Exception:
        normalized = (s or "").strip().lower()
    return normalized
|
||||||
|
|
||||||
|
# ============ ASR global master switch ============
_current_recognition: Optional[object] = None
_rec_lock = asyncio.Lock()


async def set_current_recognition(r):
    """Record *r* as the currently active recognition session."""
    global _current_recognition
    async with _rec_lock:
        _current_recognition = r


async def stop_current_recognition():
    """Detach the active recognition session and stop it, ignoring SDK errors."""
    global _current_recognition
    async with _rec_lock:
        active, _current_recognition = _current_recognition, None
        if active:
            try:
                active.stop()  # stop the DashScope SDK realtime recognition
            except Exception:
                pass
|
||||||
|
|
||||||
|
# ============ ASR callback ============
class ASRCallback:
    """
    DashScope realtime-ASR event callback.

    Design goals:
    1) As soon as a hotword ("停下 / 别说了 / ...") appears -> immediately do a
       full reset (back to the freshly-started state).
    2) Apart from that, NO barge-in is accepted; while the AI is speaking,
       user speech is only displayed and never starts a new turn.
    3) Partials are no longer accumulated by string concatenation; partial
       results are only for transient UI display, and only a final sentence
       drives the AI.
    """

    def __init__(
        self,
        on_sdk_error: Callable[[str], None],     # called with the SDK error message
        post: Callable[[asyncio.Future], None],  # schedules the given awaitable (see _post usage below)
        ui_broadcast_partial,                    # awaitable factory: push partial text to UI
        ui_broadcast_final,                      # awaitable factory: push final text to UI
        is_playing_now_fn: Callable[[], bool],   # True while AI playback is in progress
        start_ai_with_text_fn,  # async (text)
        full_system_reset_fn,   # async (reason)
        interrupt_lock: asyncio.Lock,
    ):
        self._on_sdk_error = on_sdk_error
        self._post = post
        self._last_partial_for_ui: str = ""   # only used for UI display
        self._last_final_text: str = ""       # the end-of-sentence final is authoritative
        self._hot_interrupted: bool = False   # whether this sentence already triggered a hotword reset (debounce)

        self._ui_partial = ui_broadcast_partial
        self._ui_final = ui_broadcast_final
        self._is_playing = is_playing_now_fn
        self._start_ai = start_ai_with_text_fn
        self._full_reset = full_system_reset_fn
        self._interrupt_lock = interrupt_lock

    # SDK lifecycle hooks -- intentionally no-ops.
    def on_open(self): pass
    def on_close(self): pass
    def on_complete(self): pass

    def on_error(self, err):
        # Clear any stale partial from the UI, then surface the SDK error.
        try:
            self._post(self._ui_partial(""))
            self._on_sdk_error(str(err))
        except Exception:
            pass

    # Both SDK entry points funnel into the same handler.
    def on_result(self, result): self._handle(result)
    def on_event(self, event): self._handle(event)

    def _has_hotword(self, text: str) -> bool:
        """Day 21 fix: check whether *text* contains a hotword, but exclude navigation-control commands."""
        t = _normalize_cn(text)
        if not t: return False

        # Day 21: check the navigation-command whitelist first; commands like
        # "停止导航" contain the hotword "停止" but must not trigger a reset.
        for nav_cmd in NAV_CONTROL_WHITELIST:
            if _normalize_cn(nav_cmd) in t:
                return False  # navigation commands never trigger the hotword reset

        # Now check the hotwords themselves.
        for w in INTERRUPT_KEYWORDS:
            if w and _normalize_cn(w) in t:
                return True
        return False

    def _handle(self, event: Any):
        """Process one ASR event: hotword short-circuit, partial UI echo, final -> LLM."""
        if ASR_DEBUG_RAW:
            try:
                rawd = _safe_to_dict(event)
                print("[ASR EVENT RAW]", json.dumps(rawd, ensure_ascii=False), flush=True)
            except Exception:
                pass

        text, is_end = _extract_sentence(event)
        if text is None:
            return
        text = text.strip()
        if not text:
            return

        # ---------- (1) Hotword first: on a hit, full reset and short-circuit; never reaches the LLM ----------
        if not self._hot_interrupted and self._has_hotword(text):
            self._hot_interrupted = True

            async def _hot_reset():
                async with self._interrupt_lock:
                    print(f"[ASR HOTWORD] '{text}' -> FULL RESET, skip LLM", flush=True)
                    await self._full_reset("Hotword interrupt")
            try:
                self._post(_hot_reset())
            except Exception:
                pass
            return

        # ---------- (2) partial: UI display only ----------
        self._last_partial_for_ui = text
        try:
            print(f"[ASR PARTIAL] len={len(text)} text='{_shorten(text)}'", flush=True)
            self._post(self._ui_partial(self._last_partial_for_ui))
        except Exception:
            pass

        # ---------- (3) final: only finals drive the LLM (and only when not playing) ----------
        if is_end is True:
            final_text = text
            try:
                print(f"[ASR FINAL] len={len(final_text)} text='{final_text}'", flush=True)
                self._post(self._ui_final(final_text))
            except Exception:
                pass

            if (not self._is_playing()) and final_text:
                async def _run_final():
                    async with self._interrupt_lock:
                        print(f"[LLM INPUT TEXT] {final_text}", flush=True)
                        await self._start_ai(final_text)
                try:
                    self._post(_run_final())
                except Exception:
                    pass

            # Reset state for the next sentence.
            self._last_partial_for_ui = ""
            self._last_final_text = ""
            self._hot_interrupted = False
|
||||||
439
audio_compressor.py
Normal file
439
audio_compressor.py
Normal file
@@ -0,0 +1,439 @@
|
|||||||
|
# audio_compressor.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
音频压缩工具 - 用于减少网络带宽占用
|
||||||
|
支持将16kHz 16bit PCM压缩为更小的格式
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import wave
|
||||||
|
import struct
|
||||||
|
import numpy as np
|
||||||
|
from typing import Optional, Tuple
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class AudioCompressor:
|
||||||
|
"""音频压缩器 - 支持多种压缩算法"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def pcm16_to_ulaw(pcm_data: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
将16位PCM转换为8位μ-law
|
||||||
|
压缩率:50%(16bit -> 8bit)
|
||||||
|
"""
|
||||||
|
# 解析16位PCM
|
||||||
|
samples = np.frombuffer(pcm_data, dtype=np.int16)
|
||||||
|
|
||||||
|
# μ-law压缩
|
||||||
|
ulaw_data = bytearray()
|
||||||
|
for sample in samples:
|
||||||
|
ulaw_byte = AudioCompressor._linear_to_ulaw(sample)
|
||||||
|
ulaw_data.append(ulaw_byte)
|
||||||
|
|
||||||
|
return bytes(ulaw_data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def ulaw_to_pcm16(ulaw_data: bytes) -> bytes:
|
||||||
|
"""
|
||||||
|
将8位μ-law转换回16位PCM
|
||||||
|
"""
|
||||||
|
pcm_samples = []
|
||||||
|
for ulaw_byte in ulaw_data:
|
||||||
|
pcm_sample = AudioCompressor._ulaw_to_linear(ulaw_byte)
|
||||||
|
pcm_samples.append(pcm_sample)
|
||||||
|
|
||||||
|
return np.array(pcm_samples, dtype=np.int16).tobytes()
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _linear_to_ulaw(sample: int) -> int:
|
||||||
|
"""
|
||||||
|
16位线性PCM转μ-law
|
||||||
|
"""
|
||||||
|
# μ-law编码表
|
||||||
|
ULAW_MAX = 0x1FFF
|
||||||
|
ULAW_BIAS = 0x84
|
||||||
|
|
||||||
|
# 限制范围
|
||||||
|
sample = max(-32768, min(32767, sample))
|
||||||
|
|
||||||
|
# 获取符号位
|
||||||
|
sign = 0
|
||||||
|
if sample < 0:
|
||||||
|
sign = 0x80
|
||||||
|
sample = -sample
|
||||||
|
|
||||||
|
# 添加偏置
|
||||||
|
sample = sample + ULAW_BIAS
|
||||||
|
|
||||||
|
# 限制最大值
|
||||||
|
if sample > ULAW_MAX:
|
||||||
|
sample = ULAW_MAX
|
||||||
|
|
||||||
|
# 查找指数和尾数
|
||||||
|
exponent = 7
|
||||||
|
for exp in range(7, -1, -1):
|
||||||
|
if sample & (0x4000 >> exp):
|
||||||
|
exponent = exp
|
||||||
|
break
|
||||||
|
|
||||||
|
mantissa = (sample >> (exponent + 3)) & 0x0F
|
||||||
|
ulawbyte = ~(sign | (exponent << 4) | mantissa) & 0xFF
|
||||||
|
|
||||||
|
return ulawbyte
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _ulaw_to_linear(ulawbyte: int) -> int:
|
||||||
|
"""
|
||||||
|
μ-law转16位线性PCM
|
||||||
|
"""
|
||||||
|
ULAW_BIAS = 0x84
|
||||||
|
|
||||||
|
ulawbyte = ~ulawbyte & 0xFF
|
||||||
|
sign = ulawbyte & 0x80
|
||||||
|
exponent = (ulawbyte >> 4) & 0x07
|
||||||
|
mantissa = ulawbyte & 0x0F
|
||||||
|
|
||||||
|
sample = ((mantissa << 3) + ULAW_BIAS) << exponent
|
||||||
|
|
||||||
|
if sign:
|
||||||
|
sample = -sample
|
||||||
|
|
||||||
|
return sample
|
||||||
|
|
||||||
|
@staticmethod
def pcm16_to_adpcm(pcm_data: bytes) -> bytes:
    """
    Convert 16-bit PCM to 4-bit IMA ADPCM (75% size reduction, decent speech
    quality).

    Output layout: struct('<hB') header holding the INITIAL predictor value
    and step index, followed by the nibble stream (low nibble = first sample
    of each byte pair).

    BUGFIX: the header used to be packed AFTER the encode loop, i.e. with the
    FINAL predictor/step-index, while adpcm_to_pcm16 reads the header as the
    decoder's INITIAL state — so every decoded stream started from garbage
    state.  The header now carries the initial state (0, 0), which is what
    the decoder needs to reproduce the encoder's trace.
    """
    samples = np.frombuffer(pcm_data, dtype=np.int16)

    # Standard IMA ADPCM step-size table (89 entries).
    step_table = [
        7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
        19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
        50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
        130, 143, 157, 173, 190, 209, 230, 253, 279, 307,
        337, 371, 408, 449, 494, 544, 598, 658, 724, 796,
        876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
        2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358,
        5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
        15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767
    ]

    # Step-index adjustment per 3-bit magnitude code.
    index_table = [-1, -1, -1, -1, 2, 4, 6, 8]

    # Encoder state; the header must advertise these *initial* values.
    adpcm_data = bytearray()
    predicted = 0
    step_index = 0
    header = struct.pack('<hB', predicted, step_index)

    # Pack two 4-bit samples per output byte (low nibble first).
    for i in range(0, len(samples), 2):
        byte = 0
        for j in range(2):
            if i + j < len(samples):
                # Cast to Python int: np.int16 arithmetic could overflow
                # when computing the difference against the predictor.
                sample = int(samples[i + j])

                # Quantize the prediction error against the current step.
                diff = sample - predicted
                step = step_table[step_index]
                adpcm_sample = 0

                if diff < 0:
                    adpcm_sample = 8  # sign bit
                    diff = -diff

                if diff >= step:
                    adpcm_sample |= 4
                    diff -= step
                step >>= 1
                if diff >= step:
                    adpcm_sample |= 2
                    diff -= step
                step >>= 1
                if diff >= step:
                    adpcm_sample |= 1

                # Reconstruct the quantized difference exactly as the
                # decoder will, and update the predictor with it.
                step = step_table[step_index]
                diff = 0
                if adpcm_sample & 4:
                    diff += step
                step >>= 1
                if adpcm_sample & 2:
                    diff += step
                step >>= 1
                if adpcm_sample & 1:
                    diff += step
                step >>= 1
                diff += step  # +step/8 rounding term

                if adpcm_sample & 8:
                    predicted -= diff
                else:
                    predicted += diff

                # Clamp the predictor to the 16-bit range.
                if predicted > 32767:
                    predicted = 32767
                elif predicted < -32768:
                    predicted = -32768

                # Adapt the step size for the next sample.
                step_index += index_table[adpcm_sample & 7]
                if step_index < 0:
                    step_index = 0
                elif step_index > 88:
                    step_index = 88

                # Pack the nibble: first sample low, second sample high.
                if j == 0:
                    byte = adpcm_sample
                else:
                    byte |= (adpcm_sample << 4)

        adpcm_data.append(byte)

    return header + bytes(adpcm_data)
|
||||||
|
|
||||||
|
@staticmethod
def adpcm_to_pcm16(adpcm_data: bytes) -> bytes:
    """
    Decode 4-bit IMA ADPCM back to 16-bit PCM.

    Input layout: a 3-byte struct('<hB') header (predictor seed + step index
    for the decoder's initial state) followed by the nibble stream, low
    nibble first within each byte.

    NOTE(review): every byte yields two output samples, so an odd-length
    original stream gains one trailing sample from the pad nibble — confirm
    callers tolerate that.
    """
    # Too short to even hold the header.
    if len(adpcm_data) < 3:
        return b''

    # Decoder state seeded from the header.
    predicted, step_index = struct.unpack('<hB', adpcm_data[:3])
    adpcm_bytes = adpcm_data[3:]

    # Standard IMA ADPCM step-size table (89 entries).
    step_table = [
        7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
        19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
        50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
        130, 143, 157, 173, 190, 209, 230, 253, 279, 307,
        337, 371, 408, 449, 494, 544, 598, 658, 724, 796,
        876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
        2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358,
        5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
        15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767
    ]

    # Step-index adjustment per 3-bit magnitude code.
    index_table = [-1, -1, -1, -1, 2, 4, 6, 8]

    pcm_samples = []

    for byte in adpcm_bytes:
        # Two 4-bit samples per byte: low nibble first, then high.
        for shift in [0, 4]:
            adpcm_sample = (byte >> shift) & 0x0F

            # Expand the quantized difference: step contributions for
            # each magnitude bit, plus a +step/8 rounding term.
            step = step_table[step_index]
            diff = 0

            if adpcm_sample & 4:
                diff += step
            step >>= 1
            if adpcm_sample & 2:
                diff += step
            step >>= 1
            if adpcm_sample & 1:
                diff += step
            step >>= 1
            diff += step

            # Bit 8 is the sign of the difference.
            if adpcm_sample & 8:
                predicted -= diff
            else:
                predicted += diff

            # Clamp the predictor to the 16-bit range.
            if predicted > 32767:
                predicted = 32767
            elif predicted < -32768:
                predicted = -32768

            pcm_samples.append(predicted)

            # Adapt the step size, clamped to the table bounds.
            step_index += index_table[adpcm_sample & 7]
            if step_index < 0:
                step_index = 0
            elif step_index > 88:
                step_index = 88

    return np.array(pcm_samples, dtype=np.int16).tobytes()
|
||||||
|
|
||||||
|
@staticmethod
def downsample_pcm16(pcm_data: bytes, from_rate: int = 16000, to_rate: int = 8000) -> bytes:
    """Resample 16-bit PCM between sample rates (optional space saver).

    16 kHz -> 8 kHz halves the payload again.  The common 16k->8k case keeps
    every second sample; any other rate pair falls back to linear
    interpolation.
    """
    # Nothing to do when the rates already match.
    if from_rate == to_rate:
        return pcm_data

    src = np.frombuffer(pcm_data, dtype=np.int16)

    if from_rate == 16000 and to_rate == 8000:
        # Cheap decimation: drop every other sample.
        resampled = src[::2]
    else:
        # General case: linear interpolation onto the new grid
        # (a proper resampler would need scipy).
        target_len = int(len(src) * (to_rate / from_rate))
        grid = np.linspace(0, len(src) - 1, target_len)
        resampled = np.interp(grid, np.arange(len(src)), src).astype(np.int16)

    return resampled.tobytes()
|
||||||
|
|
||||||
|
|
||||||
|
class CompressedAudioCache:
    """In-memory cache of compressed prompt audio, keyed by file path."""

    def __init__(self, compression_type: str = "adpcm", use_downsample: bool = False):
        """
        Args:
            compression_type: one of "none", "ulaw", "adpcm".
            use_downsample: reserved flag; not consulted by this class.
        """
        self.compression_type = compression_type
        self.use_downsample = use_downsample
        self._cache = {}  # {filepath: compressed_data}
        self._original_sizes = {}  # {filepath: original (post-conversion) PCM size}

    def load_and_compress(self, filepath: str) -> Optional[bytes]:
        """Load a WAV file, normalise it to 16 kHz mono PCM16, compress and cache it.

        Returns the compressed bytes (with a 5-byte type header for ulaw/adpcm),
        or None on failure.
        """
        if filepath in self._cache:
            return self._cache[filepath]

        try:
            with wave.open(filepath, 'rb') as wav:
                # Inspect the container format; we only warn, then convert below.
                channels = wav.getnchannels()
                sampwidth = wav.getsampwidth()
                framerate = wav.getframerate()

                if channels != 1:
                    logger.warning(f"{filepath} 不是单声道")
                if sampwidth != 2:
                    logger.warning(f"{filepath} 不是16位音频")

                # Read the whole file into memory.
                frames = wav.readframes(wav.getnframes())

                # Stereo -> mono (left channel only).
                if channels == 2:
                    import audioop
                    frames = audioop.tomono(frames, sampwidth, 1, 0)

                # Always convert to 16 kHz to match the client-side player.
                # NOTE: audioop was removed in Python 3.13 — verify runtime version.
                if framerate != 16000:
                    import audioop
                    frames, _ = audioop.ratecv(frames, sampwidth, 1, framerate, 16000, None)
                    framerate = 16000

                # Record the post-conversion size for the compression stats.
                self._original_sizes[filepath] = len(frames)

                # Compress, prefixing a header: 1 type byte + 4-byte original length.
                if self.compression_type == "ulaw":
                    compressed = AudioCompressor.pcm16_to_ulaw(frames)
                    header = struct.pack('!BI', 0x01, len(frames))  # 0x01 = mu-law
                    compressed = header + compressed
                elif self.compression_type == "adpcm":
                    compressed = AudioCompressor.pcm16_to_adpcm(frames)
                    header = struct.pack('!BI', 0x02, len(frames))  # 0x02 = ADPCM
                    compressed = header + compressed
                else:
                    # "none": store raw PCM without a header.
                    compressed = frames

                self._cache[filepath] = compressed

                # Log the compression ratio for this file.
                compression_ratio = len(compressed) / self._original_sizes[filepath]
                logger.info(f"[压缩] {os.path.basename(filepath)}: "
                            f"{self._original_sizes[filepath]} -> {len(compressed)} bytes "
                            f"({compression_ratio:.1%})")

                return compressed

        except Exception as e:
            logger.error(f"压缩音频失败 {filepath}: {e}")
            return None

    def decompress(self, compressed_data: bytes) -> Optional[bytes]:
        """Expand cached audio back to PCM16; passes unrecognised data through unchanged."""
        # Too short to carry the 5-byte header — treat as raw.
        if not compressed_data or len(compressed_data) < 5:
            return compressed_data

        try:
            # First byte identifies the codec used at compression time.
            compression_type = compressed_data[0]
            if compression_type == 0x01:  # mu-law tag
                header_size = 5
                original_length = struct.unpack('!I', compressed_data[1:5])[0]
                ulaw_data = compressed_data[header_size:]

                # mu-law expansion.
                pcm_data = AudioCompressor.ulaw_to_pcm16(ulaw_data)

                return pcm_data
            elif compression_type == 0x02:  # ADPCM tag
                header_size = 5
                original_length = struct.unpack('!I', compressed_data[1:5])[0]
                adpcm_data = compressed_data[header_size:]

                # ADPCM expansion.
                pcm_data = AudioCompressor.adpcm_to_pcm16(adpcm_data)

                return pcm_data
            else:
                # Unknown tag: assume uncompressed PCM.
                return compressed_data

        except Exception as e:
            logger.error(f"解压音频失败: {e}")
            return compressed_data

    def get_compression_stats(self) -> dict:
        """Aggregate byte counts across all cached files."""
        total_original = sum(self._original_sizes.values())
        total_compressed = sum(len(data) for data in self._cache.values())

        return {
            "files_cached": len(self._cache),
            "total_original_size": total_original,
            "total_compressed_size": total_compressed,
            "compression_ratio": total_compressed / total_original if total_original > 0 else 0,
            "bytes_saved": total_original - total_compressed
        }
|
||||||
|
|
||||||
|
|
||||||
|
# Global compressed-audio cache instance.
# ADPCM is the default: better quality than mu-law at a ~75% size reduction.
# Select the codec via env var AIGLASS_COMPRESS_TYPE: none, ulaw, adpcm.
import os
compression_type = os.getenv("AIGLASS_COMPRESS_TYPE", "adpcm").lower()
if compression_type not in ["none", "ulaw", "adpcm"]:
    # Unknown value — fall back to the default codec.
    compression_type = "adpcm"
compressed_audio_cache = CompressedAudioCache(compression_type=compression_type, use_downsample=False)
|
||||||
392
audio_player.py
Normal file
392
audio_player.py
Normal file
@@ -0,0 +1,392 @@
|
|||||||
|
# audio_player.py
|
||||||
|
# 处理预录音频文件的播放,通过ESP32扬声器输出
|
||||||
|
|
||||||
|
import os
|
||||||
|
import wave
|
||||||
|
import json
|
||||||
|
import asyncio
|
||||||
|
import threading
|
||||||
|
import queue
|
||||||
|
import time
|
||||||
|
from audio_stream import broadcast_pcm16_realtime
|
||||||
|
from audio_compressor import compressed_audio_cache, AudioCompressor
|
||||||
|
|
||||||
|
# 导入录制器(避免循环导入,在需要时动态导入)
|
||||||
|
_recorder_imported = False
|
||||||
|
_sync_recorder = None
|
||||||
|
|
||||||
|
def _get_recorder():
|
||||||
|
"""延迟导入录制器"""
|
||||||
|
global _recorder_imported, _sync_recorder
|
||||||
|
if not _recorder_imported:
|
||||||
|
try:
|
||||||
|
import sync_recorder as sr
|
||||||
|
_sync_recorder = sr
|
||||||
|
_recorder_imported = True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[AUDIO] 无法导入录制器: {e}")
|
||||||
|
_recorder_imported = True # 标记已尝试,避免重复
|
||||||
|
return _sync_recorder
|
||||||
|
|
||||||
|
# Legacy sample clips kept for compatibility — resolved relative to this file.
AUDIO_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "music")

# "voice" directory and its phrase->file mapping table.
# Anchored at the script's own directory so the working directory is irrelevant;
# override with the VOICE_DIR env var.
VOICE_DIR = os.getenv("VOICE_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "voice"))
VOICE_MAP_FILE = os.path.join(VOICE_DIR, "map.zh-CN.json")

# Phrase -> WAV path lookup.  The voice map file is merged into this at init.
AUDIO_MAP = {
    "检测到物体": os.path.join(AUDIO_BASE_DIR, "音频1.wav"),
    "向上": os.path.join(AUDIO_BASE_DIR, "音频2.wav"),
    "向下": os.path.join(AUDIO_BASE_DIR, "音频3.wav"),
    "向左": os.path.join(AUDIO_BASE_DIR, "音频4.wav"),
    "向右": os.path.join(AUDIO_BASE_DIR, "音频5.wav"),
    "OK": os.path.join(AUDIO_BASE_DIR, "音频6.wav"),
    "向前": os.path.join(AUDIO_BASE_DIR, "音频7.wav"),
    "后退": os.path.join(AUDIO_BASE_DIR, "音频8.wav"),
    "拿到物体": os.path.join(AUDIO_BASE_DIR, "音频9.wav"),
}
|
||||||
|
|
||||||
|
# Audio cache keyed by file path; avoids re-reading/re-decoding files.
_audio_cache = {}

# Playback queue + worker-thread state.  A priority queue ordered by a
# monotonically increasing counter keeps clips in submission order.
_audio_queue = queue.PriorityQueue(maxsize=10)
_audio_priority = 0  # ever-increasing ordering counter for queue entries
_worker_thread = None  # background thread that owns the playback loop
_worker_loop = None  # asyncio loop created inside the worker thread
_is_playing = False  # True while a clip is being broadcast
_playing_lock = threading.Lock()  # guards _is_playing
_initialized = False  # set once initialize_audio_system() has run
_last_play_ts = 0.0  # monotonic time playback last ended; sizes warm-up silence
|
||||||
|
|
||||||
|
def load_wav_file(filepath):
    """Load a WAV file and return its PCM bytes, normalised to 16 kHz mono.

    When AIGLASS_COMPRESS_AUDIO=1 (the default) the compressed cache is used
    and the returned bytes are the COMPRESSED representation (with its type
    header), not raw PCM; callers must decompress before playback.
    Returns None on failure.
    """
    # Fast path: already cached (raw or compressed, depending on mode).
    if filepath in _audio_cache:
        return _audio_cache[filepath]

    # Compressed-cache path (default).
    if os.getenv("AIGLASS_COMPRESS_AUDIO", "1") == "1":
        compressed_data = compressed_audio_cache.load_and_compress(filepath)
        if compressed_data:
            # Store the compressed payload; decompression happens at play time.
            _audio_cache[filepath] = compressed_data
            return compressed_data

    # Uncompressed fallback path.
    try:
        with wave.open(filepath, 'rb') as wav:
            # Inspect the container format; conversion happens below.
            channels = wav.getnchannels()
            sampwidth = wav.getsampwidth()
            framerate = wav.getframerate()

            if channels != 1:
                print(f"[AUDIO] 警告: {filepath} 不是单声道,将只使用第一个声道")
            if sampwidth != 2:
                print(f"[AUDIO] 警告: {filepath} 不是16位音频")

            # Read every frame into memory.
            frames = wav.readframes(wav.getnframes())

            # Stereo -> mono (left channel only).
            if channels == 2:
                import audioop
                frames = audioop.tomono(frames, sampwidth, 1, 0)

            # Normalise to 16 kHz; ratecv preserves pitch and speed.
            # NOTE: audioop was removed in Python 3.13 — verify runtime version.
            if framerate != 16000:
                import audioop
                frames, _ = audioop.ratecv(frames, sampwidth, 1, framerate, 16000, None)
                print(f"[AUDIO] 重采样: {filepath} {framerate}Hz -> 16000Hz")

            _audio_cache[filepath] = frames
            return frames

    except Exception as e:
        print(f"[AUDIO] 加载音频文件失败 {filepath}: {e}")
        return None
|
||||||
|
|
||||||
|
def _merge_voice_map():
    """Merge voice/map.zh-CN.json entries into AUDIO_MAP (first file per phrase)."""
    try:
        if not os.path.exists(VOICE_MAP_FILE):
            print(f"[AUDIO] 未找到映射文件: {VOICE_MAP_FILE}")
            return
        with open(VOICE_MAP_FILE, "r", encoding="utf-8") as fh:
            mapping = json.load(fh)
        merged = 0
        for phrase, entry in (mapping or {}).items():
            filenames = (entry or {}).get("files") or []
            if not filenames:
                continue  # phrase has no recorded clip
            clip_path = os.path.join(VOICE_DIR, filenames[0])
            if not os.path.exists(clip_path):
                print(f"[AUDIO] 映射文件缺失: {clip_path}")
                continue
            AUDIO_MAP[phrase] = clip_path
            merged += 1
        print(f"[AUDIO] 已合并 voice 映射 {merged} 条")
    except Exception as e:
        print(f"[AUDIO] 读取 voice 映射失败: {e}")
|
||||||
|
|
||||||
|
def preload_all_audio():
    """Warm the in-memory cache by loading every clip referenced in AUDIO_MAP."""
    print("[AUDIO] 开始预加载音频文件...")
    loaded_count = 0

    # NOTE: per-clip speed-up (for zebra-crossing prompts) is disabled until
    # the cache can key on playback speed.
    for clip_path in AUDIO_MAP.values():
        if not os.path.exists(clip_path):
            # Missing files are skipped quietly to reduce log noise.
            continue
        if load_wav_file(clip_path):
            loaded_count += 1

    print(f"[AUDIO] 预加载完成,共加载 {loaded_count} 个音频文件")
|
||||||
|
|
||||||
|
def _audio_worker():
    """Playback worker-thread entry point.

    Creates a private asyncio loop and drains the shared priority queue
    forever, broadcasting each dequeued clip.  A None sentinel stops the loop.
    """
    global _worker_loop

    # Best-effort: raise the thread priority (Windows only).
    try:
        import ctypes
        import sys
        if sys.platform == "win32":
            ctypes.windll.kernel32.SetThreadPriority(
                ctypes.windll.kernel32.GetCurrentThread(),
                1  # THREAD_PRIORITY_ABOVE_NORMAL
            )
            print("[AUDIO] 设置音频线程为高优先级")
    except Exception as e:
        print(f"[AUDIO] 设置线程优先级失败: {e}")

    # This thread owns its own event loop for the async broadcast calls.
    _worker_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(_worker_loop)

    async def process_queue():
        while True:
            try:
                # Blocking queue.get runs in the default executor so the
                # loop itself is never blocked.  NOTE(review): the bound
                # method is re-resolved from the module global each pass,
                # but a get() already in flight keeps waiting on the old
                # queue object if the global is rebound elsewhere.
                priority_data = await asyncio.get_event_loop().run_in_executor(None, _audio_queue.get, True)
                if priority_data is None:
                    break  # shutdown sentinel
                # Entries are (priority, pcm) tuples; tolerate bare payloads.
                if isinstance(priority_data, tuple) and len(priority_data) == 2:
                    _, audio_data = priority_data
                else:
                    audio_data = priority_data
                await _broadcast_audio_optimized(audio_data)
            except Exception as e:
                print(f"[AUDIO] 工作线程错误: {e}")

    _worker_loop.run_until_complete(process_queue())
|
||||||
|
|
||||||
|
async def _broadcast_audio_optimized(pcm_data: bytes):
    """Broadcast one decompressed PCM clip in a single call; the 20 ms pacing
    lives in broadcast_pcm16_realtime, so no Python-level sleeps here.

    Wraps the clip in warm-up/tail silence sized by how long the line has
    been idle, and maintains the _is_playing flag around the broadcast.
    """
    global _last_play_ts, _is_playing
    try:
        # Mark playback in progress for the queue-trimming logic.
        with _playing_lock:
            _is_playing = True
        # pcm_data is already decompressed 16-bit PCM here (silence math
        # below assumes 16 kHz; an upstream comment said 8 kHz — 16 kHz is
        # what the byte counts encode).
        now = time.monotonic()
        idle_sec = now - (_last_play_ts or now)
        # Longer warm-up silence after a long idle gap; short one otherwise.
        lead_ms = 160 if idle_sec > 3.0 else 60
        tail_ms = 40

        lead_silence = b'\x00' * (lead_ms * 16000 * 2 // 1000)  # 16 kHz * 2 B/sample
        tail_silence = b'\x00' * (tail_ms * 16000 * 2 // 1000)

        # Complete payload including the padding silence.
        full_audio = lead_silence + pcm_data + tail_silence

        # Recording happens once inside broadcast_pcm16_realtime — no
        # duplicate capture here.

        # Single hand-off; the 20 ms pacing is implemented downstream.
        await broadcast_pcm16_realtime(full_audio)

        _last_play_ts = time.monotonic()
    except Exception as e:
        print(f"[AUDIO] 广播音频失败: {e}")
    finally:
        # Always clear the playing flag, even on failure.
        with _playing_lock:
            _is_playing = False
|
||||||
|
|
||||||
|
def initialize_audio_system():
    """One-time setup: merge the voice map, preload clips, start the worker.

    Idempotent — subsequent calls return immediately.
    """
    global _initialized, _worker_thread, _last_play_ts

    if _initialized:
        return

    # Merge the voice map first so preloading sees every clip.
    _merge_voice_map()
    preload_all_audio()

    # Daemon worker drains the playback queue for the process lifetime.
    _worker_thread = threading.Thread(target=_audio_worker, daemon=True)
    _worker_thread.start()
    _initialized = True
    _last_play_ts = 0.0

    # Report compression savings when the compressed cache is active.
    if os.getenv("AIGLASS_COMPRESS_AUDIO", "1") == "1":
        stats = compressed_audio_cache.get_compression_stats()
        print(f"[AUDIO] 音频压缩统计:")
        print(f"  - 文件数: {stats['files_cached']}")
        print(f"  - 原始大小: {stats['total_original_size'] / 1024:.1f} KB")
        print(f"  - 压缩后: {stats['total_compressed_size'] / 1024:.1f} KB")
        print(f"  - 压缩率: {stats['compression_ratio']:.1%}")
        print(f"  - 节省: {stats['bytes_saved'] / 1024:.1f} KB")

    print("[AUDIO] 音频系统初始化完成(预加载+工作线程)")
|
||||||
|
|
||||||
|
def play_audio_threadsafe(audio_key):
    """Thread-safe playback entry: resolve *audio_key* via AUDIO_MAP,
    decompress the cached payload if needed, and enqueue it for the worker.

    BUGFIX: the old code "cleared" the backlog by REBINDING the module
    global (_audio_queue = queue.PriorityQueue(...)).  The worker thread can
    be blocked in .get() on the OLD queue object, so clips put on the new
    queue would never be consumed and playback silently stopped.  The queue
    is now drained in place instead of being replaced.
    """
    global _audio_priority

    if not _initialized:
        initialize_audio_system()

    if audio_key not in AUDIO_MAP:
        print(f"[AUDIO] 未知的音频键: {audio_key}")
        return

    filepath = AUDIO_MAP[audio_key]
    pcm_data = _audio_cache.get(filepath)
    if pcm_data is None:
        print(f"[AUDIO] 音频未在缓存中: {audio_key}")
        return

    # Compressed payloads carry a 1-byte codec tag (0x01=u-law, 0x02=ADPCM).
    if pcm_data and len(pcm_data) > 5 and pcm_data[0] in (0x01, 0x02):
        pcm_data = compressed_audio_cache.decompress(pcm_data)
        if not pcm_data:
            print(f"[AUDIO] 解压失败: {audio_key}")
            return

    # Real-time policy: keep the backlog minimal so prompts never lag.
    queue_size = _audio_queue.qsize()

    with _playing_lock:
        currently_playing = _is_playing

    # Idle with pending items: drop them all and play the newest prompt.
    if queue_size > 0 and not currently_playing:
        print(f"[AUDIO] 清空队列(当前{queue_size}个),播放最新语音")
        _drain_audio_queue()
    elif queue_size > 1 and currently_playing:
        # Playing with >1 queued: trim the backlog to stay real-time.
        print(f"[AUDIO] 队列积压({queue_size}个),清空以保持实时")
        _drain_audio_queue()

    try:
        # Monotonic counter keeps priority-queue entries in submission order.
        _audio_priority += 1
        _audio_queue.put_nowait((_audio_priority, pcm_data))
        if queue_size >= 1:
            print(f"[AUDIO] 播放队列当前大小: {queue_size + 1}")
    except queue.Full:
        # Queue full: drop the clip to preserve real-time behaviour.
        print(f"[AUDIO] 队列满,丢弃: {audio_key}")


def _drain_audio_queue():
    """Empty the shared playback queue in place (never rebind the global —
    the worker thread may hold a blocking get() on the current object)."""
    while True:
        try:
            _audio_queue.get_nowait()
        except queue.Empty:
            break
|
||||||
|
|
||||||
|
# Global voice throttling — Day 22: cooldown reduced from 3.0s to 1.5s.
_last_voice_time = 0  # time.time() of the last clip actually played
_last_voice_text = ""  # text of that clip (throttle compares exact text)
_voice_cooldown = 1.5  # identical text is suppressed within this window (s)

# Voice priority levels.
# NOTE(review): defined here but not consulted in this file's playback path —
# confirm where these weights are consumed.
VOICE_PRIORITY = {
    'obstacle': 100,  # obstacles — highest priority
    'direction': 50,  # turn / shift instructions — medium
    'straight': 10,  # "keep straight" — lowest
    'other': 30  # everything else — default
}
|
||||||
|
|
||||||
|
# Play a pre-recorded clip matched from a Chinese prompt string (with light
# normalisation and graceful fallbacks).
def play_voice_text(text: str):
    """
    Map a Chinese prompt to a voice clip and play it.

    Matching order:
      1. the stripped original text;
      2. the text with a trailing full stop added, or trailing punctuation
         (。.!!??) removed;
      3. "前方有…注意避让" phrases fall back to the generic obstacle clip;
      4. the punctuation-stripped base, with and without a trailing full stop.
    Identical text inside the cooldown window is silently skipped.
    """
    global _last_voice_time, _last_voice_text

    if not text:
        return
    if not _initialized:
        initialize_audio_system()

    # Global throttle: never repeat the same text within the cooldown.
    current_time = time.time()
    if text == _last_voice_text and current_time - _last_voice_time < _voice_cooldown:
        return  # silently skipped

    candidates = []
    t = text.strip()
    candidates.append(t)
    # Try adding a terminal full stop...
    if t[-1:] not in ("。", "!", "!", "?", "?", "."):
        candidates.append(t + "。")
    else:
        # ...or stripping the terminal punctuation.
        t2 = t.rstrip("。.!!??")
        if t2 and t2 != t:
            candidates.append(t2)

    # Try each candidate in order.
    for ck in candidates:
        if ck in AUDIO_MAP:
            play_audio_threadsafe(ck)
            _last_voice_text = text
            _last_voice_time = current_time
            return

    # Fallback for "obstacle ahead, watch out" style phrases.
    if ("前方有" in t) and ("注意避让" in t):
        fallback = "前方有障碍物,注意避让。"
        if fallback in AUDIO_MAP:
            play_audio_threadsafe(fallback)
            _last_voice_text = text
            _last_voice_time = current_time
            return

    # Common variants for "please shift/adjust/turn ..." style prompts.
    base = t.rstrip("。.!!??")
    if base in AUDIO_MAP:
        play_audio_threadsafe(base)
        _last_voice_text = text
        _last_voice_time = current_time
        return
    if base + "。" in AUDIO_MAP:
        play_audio_threadsafe(base + "。")
        _last_voice_text = text
        _last_voice_time = current_time
        return

    # No match — log it to aid debugging of the voice map.
    print(f"[AUDIO] 未找到匹配语音: {text}")
|
||||||
|
|
||||||
|
# Backwards-compatible alias for the legacy interface.
play_audio_on_esp32 = play_audio_threadsafe
|
||||||
299
audio_stream.py
Normal file
299
audio_stream.py
Normal file
@@ -0,0 +1,299 @@
|
|||||||
|
# audio_stream.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import asyncio
|
||||||
|
from collections import deque
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Set, List, Tuple, Any, Dict
|
||||||
|
from fastapi import Request
|
||||||
|
from fastapi.responses import StreamingResponse
|
||||||
|
|
||||||
|
# ===== Downlink WAV stream parameters =====
STREAM_SR = 8000  # 8 kHz — the rate the ESP32 sink supports
STREAM_CH = 1  # mono
STREAM_SW = 2  # 16-bit samples
# 8000 * 2 * 20 / 1000 = 320 bytes per 20 ms chunk at STREAM_SR.
# NOTE(review): the name says "16K" but the value derives from the 8 kHz
# STREAM_SR; data broadcast elsewhere in this module is described as 16 kHz,
# so one 320-byte chunk is only 10 ms of 16 kHz audio — confirm the intended
# pacing rate.
BYTES_PER_20MS_16K = STREAM_SR * STREAM_SW * 20 // 1000  # 320B (8kHz)

# ===== Day 13: TTS backlog queue =====
# While the WebSocket is disconnected, TTS audio is buffered here and
# replayed once the client reconnects.
TTS_BUFFER_MAX_SECONDS = 30  # cap the backlog at 30 s of audio
TTS_BUFFER_MAX_BYTES = 16000 * 2 * TTS_BUFFER_MAX_SECONDS  # 16kHz * 2 bytes * 30s = ~960KB
tts_audio_buffer: deque = deque()  # items are (timestamp, pcm16k_bytes)
tts_buffer_total_bytes = 0  # running byte count of the backlog
|
||||||
|
|
||||||
|
# Day 13: dedicated WebSocket reference for TTS delivery.
# Saved before AI processing starts so ws_audio's finally block cannot
# clear it out from under us.
tts_websocket = None


def set_tts_websocket(ws):
    """Store the WebSocket used exclusively for TTS delivery."""
    global tts_websocket
    tts_websocket = ws


def get_tts_websocket():
    """Return the TTS WebSocket: the saved reference first, else app_main's.

    Day 15 fix: never `import app_main` (its top-level code would re-run);
    read the already-loaded module out of sys.modules instead.

    BUGFIX: the fallback used a bare `except:`, which also swallowed
    BaseException (KeyboardInterrupt/SystemExit); narrowed to Exception.
    """
    if tts_websocket is not None:
        return tts_websocket
    try:
        import sys
        if 'app_main' in sys.modules:
            return sys.modules['app_main'].esp32_audio_ws
    except Exception:
        # Attribute missing or module half-initialised — treat as "no socket".
        pass
    return None
|
||||||
|
|
||||||
|
|
||||||
|
# ===== Master switch for the AI playback task =====
current_ai_task: Optional[asyncio.Task] = None


async def cancel_current_ai():
    """Cancel the in-flight AI speech task (if any) and wait for it to exit."""
    global current_ai_task
    task, current_ai_task = current_ai_task, None
    if task is None or task.done():
        return
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        # Expected outcome of the cancellation.
        pass
    except Exception:
        # The task's own failure is irrelevant once we are tearing it down.
        pass


def is_playing_now() -> bool:
    """True while an AI playback task exists and has not finished."""
    snapshot = current_ai_task  # snapshot to avoid a check-then-act race
    return snapshot is not None and not snapshot.done()
|
||||||
|
|
||||||
|
# ===== /stream.wav connection management =====
@dataclass(frozen=True)
class StreamClient:
    """One connected HTTP /stream.wav listener."""
    # Per-client chunk queue, drained by the streaming response generator.
    q: asyncio.Queue
    # Set to force-disconnect this client.
    abort_event: asyncio.Event


stream_clients: "Set[StreamClient]" = set()  # currently connected listeners
STREAM_QUEUE_MAX = 96  # small per-client buffer to avoid backlog
|
||||||
|
|
||||||
|
def _wav_header_unknown_size(sr=16000, ch=1, sw=2) -> bytes:
|
||||||
|
import struct
|
||||||
|
byte_rate = sr * ch * sw
|
||||||
|
block_align = ch * sw
|
||||||
|
data_size = 0x7FFFFFF0
|
||||||
|
riff_size = 36 + data_size
|
||||||
|
return struct.pack(
|
||||||
|
"<4sI4s4sIHHIIHH4sI",
|
||||||
|
b"RIFF", riff_size, b"WAVE",
|
||||||
|
b"fmt ", 16,
|
||||||
|
1, ch, sr, byte_rate, block_align, sw * 8,
|
||||||
|
b"data", data_size
|
||||||
|
)
|
||||||
|
|
||||||
|
async def hard_reset_audio(reason: str = ""):
    """Sweep the deck: cancel the current AI playback task.

    Day 14: HTTP /stream.wav connections are deliberately LEFT OPEN — the
    Avaota F1 plays TTS over that long-lived channel, and dropping it would
    cut off subsequent TTS audio.
    """
    await cancel_current_ai()
    if reason:
        print(f"[HARD-RESET] {reason}")
|
||||||
|
|
||||||
|
async def flush_tts_buffer(ws) -> int:
    """Day 13: replay every buffered TTS chunk over *ws* after a reconnect.

    Returns the number of bytes actually sent.
    """
    global tts_audio_buffer, tts_buffer_total_bytes
    from starlette.websockets import WebSocketState

    if not tts_audio_buffer:
        return 0

    # Take ownership of the backlog and reset the accounting up front.
    pending = list(tts_audio_buffer)
    tts_audio_buffer.clear()
    tts_buffer_total_bytes = 0

    total_sent = 0
    try:
        for _, chunk in pending:
            disconnected = (hasattr(ws, 'client_state')
                            and ws.client_state != WebSocketState.CONNECTED)
            if disconnected:
                print(f"[TTS->WS] ⚠️ WebSocket disconnected while flushing buffer")
                break
            await ws.send_bytes(chunk)
            total_sent += len(chunk)

        if total_sent > 0:
            duration = total_sent / (16000 * 2)  # bytes -> seconds at 16 kHz PCM16
            print(f"[TTS->WS] 📤 Flushed {total_sent} bytes ({duration:.1f}s) of cached TTS audio")
    except Exception as e:
        print(f"[TTS->WS] ❌ Error flushing buffer: {e}")

    return total_sent
|
||||||
|
|
||||||
|
async def broadcast_pcm16_realtime(pcm16: bytes):
    """
    Day 14 optimisation: send to the WebSocket immediately; the paced HTTP
    broadcast runs as a background task so its 20 ms cadence never blocks
    WebSocket TTS delivery.

    Transient state (warn/flush flags) is stored as attributes on this
    function itself to rate-limit the log output.
    """
    # Record the audio once, whole, before fan-out (avoids fragmented capture).
    try:
        import sync_recorder
        sync_recorder.record_audio(pcm16, text="[Omni对话]")
    except Exception:
        pass  # best-effort: recording failures must not affect playback

    # Day 13: also deliver to the WebSocket client (Avaota F1).
    # The payload is 16 kHz PCM16 already (see Day 21 note below).
    global tts_audio_buffer, tts_buffer_total_bytes
    import time as _time

    try:
        import audioop

        # Day 13: resolve the WebSocket via get_tts_websocket() — the saved
        # reference survives ws_audio's finally block clearing the global.
        ws = get_tts_websocket()

        # Day 21: input is already 16 kHz; app_main.py converts 24 kHz -> 16 kHz
        # upstream, so no resampling is needed here.
        pcm16k = pcm16

        sent_ok = False
        if ws is not None:
            try:
                # Day 13 fix: do NOT gate on client_state — it can be stale
                # and would wrongly divert audio into the backlog.

                # First drain any backlog accumulated while disconnected.
                while tts_audio_buffer:
                    _, buffered_audio = tts_audio_buffer.popleft()
                    tts_buffer_total_bytes -= len(buffered_audio)
                    await ws.send_bytes(buffered_audio)
                    # Log the flush once per disconnect/reconnect cycle.
                    if not getattr(broadcast_pcm16_realtime, '_flush_logged', False):
                        print(f"[TTS->WS] 📤 Flushing buffered TTS audio...")
                        broadcast_pcm16_realtime._flush_logged = True

                # Then send the current chunk.
                await ws.send_bytes(pcm16k)
                sent_ok = True

                if len(pcm16k) > 320:
                    print(f"[TTS->WS] 📤 Sent {len(pcm16k)} bytes (16kHz) to Avaota")

                # Successful send resets the one-shot warning flags.
                broadcast_pcm16_realtime._ws_warned = False
                broadcast_pcm16_realtime._buffer_warned = False
                broadcast_pcm16_realtime._flush_logged = False
            except Exception as send_err:
                # Send failed — the chunk will be buffered below.
                if not getattr(broadcast_pcm16_realtime, '_send_err_warned', False):
                    print(f"[TTS->WS] ❌ Send error: {send_err}, will buffer")
                    broadcast_pcm16_realtime._send_err_warned = True

        # No socket, or the send failed: stash the chunk in the backlog.
        if not sent_ok:
            tts_audio_buffer.append((_time.time(), pcm16k))
            tts_buffer_total_bytes += len(pcm16k)

            # Evict the oldest chunks once the backlog exceeds its cap.
            while tts_buffer_total_bytes > TTS_BUFFER_MAX_BYTES and tts_audio_buffer:
                _, old_audio = tts_audio_buffer.popleft()
                tts_buffer_total_bytes -= len(old_audio)

            if not getattr(broadcast_pcm16_realtime, '_buffer_warned', False):
                buffer_secs = tts_buffer_total_bytes / (16000 * 2)
                print(f"[TTS->WS] 📦 Buffering TTS audio ({buffer_secs:.1f}s cached), will send when reconnected")
                broadcast_pcm16_realtime._buffer_warned = True

    except Exception:
        pass  # deliberate best-effort: never let delivery errors propagate

    # Day 14: pace the HTTP clients in a detached background task so the next
    # Omni audio chunk can be processed without waiting for the 20 ms cadence.
    if stream_clients:
        asyncio.create_task(_http_pacing_broadcast(pcm16))
|
||||||
|
|
||||||
|
|
||||||
|
async def _http_pacing_broadcast(pcm16: bytes):
    """
    Day 14: pace PCM16 audio out to HTTP /stream.wav clients in 20 ms ticks.

    Runs as an independent background task so it never blocks the WebSocket
    send path (it was originally embedded in broadcast_pcm16_realtime).

    Args:
        pcm16: raw 16 kHz mono PCM16 audio to fan out to every stream client.
    """
    # BUGFIX: get_running_loop() is the correct call inside a coroutine;
    # get_event_loop() is deprecated here and can bind the wrong loop.
    loop = asyncio.get_running_loop()
    next_tick = loop.time()
    off = 0
    while off < len(pcm16):
        take = min(BYTES_PER_20MS_16K, len(pcm16) - off)
        piece = pcm16[off:off + take]

        dead: List[StreamClient] = []
        # Day 14 debug: report client count on the first slice only.
        if len(stream_clients) > 0 and off == 0:
            print(f"[TTS->HTTP] 📤 Sending to {len(stream_clients)} HTTP stream client(s)")
        for sc in list(stream_clients):
            if sc.abort_event.is_set():
                dead.append(sc)
                continue
            try:
                # Drop the oldest queued piece instead of blocking when full.
                if sc.q.full():
                    try: sc.q.get_nowait()
                    except Exception: pass
                sc.q.put_nowait(piece)
            except Exception:
                dead.append(sc)
        for sc in dead:
            try: stream_clients.discard(sc)
            except Exception: pass

        # 20 ms pacing; if we fell behind, reset the clock instead of bursting.
        next_tick += 0.020
        now = loop.time()
        if now < next_tick:
            await asyncio.sleep(next_tick - now)
        else:
            next_tick = now
        off += take
|
||||||
|
|
||||||
|
# ===== FastAPI 路由注册器 =====
|
||||||
|
def register_stream_route(app):
    """Register the /stream.wav endpoint on a FastAPI app.

    The endpoint streams an endless WAV (unknown-size header followed by raw
    PCM chunks) fed from each client's asyncio queue by _http_pacing_broadcast.
    """
    @app.get("/stream.wav")
    async def stream_wav(_: Request):
        # -- Enforce a single (or very few) connection(s): abort all old clients first --
        for sc in list(stream_clients):
            try: sc.abort_event.set()
            except Exception: pass
        stream_clients.clear()

        q: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=STREAM_QUEUE_MAX)
        abort_event = asyncio.Event()
        sc = StreamClient(q=q, abort_event=abort_event)
        stream_clients.add(sc)

        async def gen():
            # WAV header with unknown data size so the stream can run indefinitely.
            yield _wav_header_unknown_size(STREAM_SR, STREAM_CH, STREAM_SW)
            try:
                while True:
                    if abort_event.is_set():
                        break
                    try:
                        # Short timeout so abort_event gets polled regularly.
                        chunk = await asyncio.wait_for(q.get(), timeout=0.5)
                    except asyncio.TimeoutError:
                        continue
                    if abort_event.is_set():
                        break
                    if chunk is None:
                        # None is the end-of-stream sentinel.
                        break
                    if chunk:
                        yield chunk
            finally:
                # Always deregister this client when the response ends.
                stream_clients.discard(sc)
        return StreamingResponse(gen(), media_type="audio/wav")
|
||||||
92
bridge_io.py
Normal file
92
bridge_io.py
Normal file
@@ -0,0 +1,92 @@
|
|||||||
|
# bridge_io.py
|
||||||
|
# 极简桥:接原始JPEG → 提供BGR帧给外部算法;外部算法产出BGR → 广播给前端
|
||||||
|
import threading
|
||||||
|
from collections import deque
|
||||||
|
import time
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
# Raw JPEG frame buffer (keeps only the newest N frames).
_MAX_BUF = 4
_frames = deque(maxlen=_MAX_BUF)
_cond = threading.Condition()

# Callback that pushes JPEG bytes to the front end; registered by app_main.py at startup.
_sender_lock = threading.Lock()
_sender_cb = None

# Callback that pushes UI text to the front end (registered by app_main.py at startup).
_ui_sender_lock = threading.Lock()
_ui_sender_cb = None
|
||||||
|
|
||||||
|
def set_sender(cb):
    """Called by app_main.py to register a frame callback: cb(jpeg_bytes) -> None."""
    global _sender_cb
    with _sender_lock:
        _sender_cb = cb
|
||||||
|
|
||||||
|
def set_ui_sender(cb):
    """Called by app_main.py to register a UI-text callback: cb(text: str) -> None."""
    global _ui_sender_cb
    with _ui_sender_lock:
        _ui_sender_cb = cb
|
||||||
|
|
||||||
|
def push_raw_jpeg(jpeg_bytes: bytes):
    """Called by app_main.py when a frame arrives on /ws/camera."""
    if not jpeg_bytes:
        return
    with _cond:
        _frames.append((time.time(), jpeg_bytes))
        # Wake any consumer blocked in wait_raw_bgr().
        _cond.notify_all()
|
||||||
|
|
||||||
|
def wait_raw_bgr(timeout_sec: float = 0.5):
    """Called by the YOLO/MediaPipe scripts: wait for the newest frame as BGR.

    Polls the shared frame buffer until *timeout_sec* elapses; returns the
    decoded BGR image, or None on timeout / persistent decode failure.
    """
    deadline = time.time() + timeout_sec
    while time.time() < deadline:
        latest = None
        with _cond:
            if _frames:
                latest = _frames[-1]
        if latest is None:
            # Nothing buffered yet — back off briefly and retry.
            time.sleep(0.01)
            continue
        # Decode the JPEG payload to BGR.
        _, jpeg = latest
        raw = np.frombuffer(jpeg, dtype=np.uint8)
        bgr = cv2.imdecode(raw, cv2.IMREAD_COLOR)
        if bgr is not None:
            # Mirroring at the very source (disabled):
            # bgr = cv2.flip(bgr, 1)
            return bgr
        # Decode failed — wait a moment and retry.
        time.sleep(0.01)
    return None
|
||||||
|
|
||||||
|
def send_vis_bgr(bgr, quality: int = 80):
    """Called by the YOLO/MediaPipe scripts: push a processed frame to the web viewer.

    Encodes *bgr* as JPEG at the given quality and hands it to the registered
    sender callback; any callback error is swallowed (best-effort delivery).
    """
    if bgr is None:
        return

    # Encode directly — no enhancement pass.
    ok, enc = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)])
    if not ok:
        return
    with _sender_lock:
        cb = _sender_cb
        if cb:
            try:
                cb(enc.tobytes())
            except Exception:
                pass
|
||||||
|
|
||||||
|
def send_ui_final(text: str):
    """Push one line of UI copy to the front end as a final answer (thread-safe callback).

    Best-effort: callback errors are silently ignored.
    """
    if not text:
        return
    with _ui_sender_lock:
        cb = _ui_sender_cb
        if cb:
            try:
                cb(str(text))
            except Exception:
                pass
|
||||||
342
crosswalk_awareness.py
Normal file
342
crosswalk_awareness.py
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
斑马线感知监控器
|
||||||
|
基于面积变化的斑马线检测和语音提示
|
||||||
|
不涉及状态切换,只提供语音引导
|
||||||
|
"""
|
||||||
|
import time
|
||||||
|
import numpy as np
|
||||||
|
from collections import deque
|
||||||
|
from typing import Optional, Dict, Any
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class CrosswalkAwarenessMonitor:
    """Crosswalk awareness monitor — a voice-prompt-only module.

    Tracks the crosswalk mask's screen-area ratio frame to frame and emits
    staged voice guidance (discover -> approaching -> near -> arrival).
    It performs no state switching elsewhere; it only produces prompts.
    """

    def __init__(self):
        # Area-ratio anchors for each guidance stage (fraction of the frame).
        self.THRESHOLDS = {
            'discover': 0.01,     # 1%  - first sighting
            'approaching': 0.08,  # 8%  - getting closer
            'near': 0.18,         # 18% - very close
            'arrival': 0.25,      # 25% - arrived (safe to cross)
        }

        # Thresholds already announced (avoids duplicate prompts).
        self.broadcasted_thresholds = set()

        # Rolling area history.
        self.area_history = deque(maxlen=30)  # keep the most recent 30 frames

        # Timing bookkeeping.
        self.last_broadcast_time = 0
        self.arrival_first_broadcast_time = 0

        # State flags.
        self.in_arrival_state = False  # currently in the "safe to cross" state
        self.last_position_zone = None  # position zone used in the last broadcast

        # Repeat-broadcast intervals (tunable — smaller value = more frequent).
        # [Tuning] All intervals divided by 1.5 to raise announcement rate 1.5x.
        self.REPEAT_INTERVALS = {
            'approaching': 6.7,  # approaching stage: repeat every 6.7 s (was 10 s / 1.5)
            'near': 3.3,         # near stage: repeat every 3.3 s (was 5 s / 1.5)
            'arrival': 5.3,      # arrival stage: repeat every 5.3 s (was 8 s / 1.5)
        }
        # Tip: tweak these numbers to change announcement frequency —
        # - smaller = more frequent
        # - larger  = sparser

        # Occlusion decision threshold.
        self.OCCLUSION_THRESHOLD = 0.30  # overlap > 30% counts as occluded

    def process_frame(self, crosswalk_mask, blind_path_mask=None) -> Optional[Dict[str, Any]]:
        """
        Process one frame of crosswalk detection.

        Returns:
            {
                'voice_text': prompt text,
                'priority': priority,
                'should_broadcast': whether to announce,
                'area': current area ratio,
                'position': position description,
                'visualization': visualization info (for external drawing)
            }
            or None (nothing to announce)
        """
        # No crosswalk in view — reset state and bail out.
        if crosswalk_mask is None:
            self._reset_if_needed()
            return None

        # 1. Compute the area ratio.
        total_pixels = crosswalk_mask.size
        crosswalk_pixels = np.sum(crosswalk_mask > 0)
        area_ratio = crosswalk_pixels / total_pixels

        # 2. Compute the mask centroid.
        y_coords, x_coords = np.where(crosswalk_mask > 0)
        if len(y_coords) == 0:
            return None

        center_x_ratio = np.mean(x_coords) / crosswalk_mask.shape[1]
        center_y_ratio = np.mean(y_coords) / crosswalk_mask.shape[0]

        # 3. Record history.
        current_time = time.time()
        self.area_history.append({
            'area': area_ratio,
            'center_x': center_x_ratio,
            'center_y': center_y_ratio,
            'time': current_time
        })

        # 4. Check occlusion by the blind path.
        has_occlusion = self._check_occlusion(crosswalk_mask, blind_path_mask)

        # 5. Decide the current stage and build the voice prompt.
        return self._generate_guidance(area_ratio, center_x_ratio, center_y_ratio,
                                       has_occlusion, current_time)

    def _check_occlusion(self, crosswalk_mask, blind_path_mask) -> bool:
        """Return True when the crosswalk is occluded by the blind path."""
        if blind_path_mask is None:
            return False

        crosswalk_area = crosswalk_mask > 0
        blind_path_area = blind_path_mask > 0

        # Overlap as a fraction of the crosswalk area.
        overlap = np.logical_and(crosswalk_area, blind_path_area)
        overlap_ratio = np.sum(overlap) / max(np.sum(crosswalk_area), 1)

        # Overlap above the threshold counts as occluded.
        return overlap_ratio > self.OCCLUSION_THRESHOLD

    def _get_position_description(self, center_x_ratio) -> str:
        """Describe the horizontal position using thirds of the frame."""
        if center_x_ratio < 0.40:
            return "在画面左侧"
        elif center_x_ratio < 0.60:
            return "在画面中间"
        else:
            return "在画面右侧"

    def _generate_guidance(self, area_ratio, center_x_ratio, center_y_ratio,
                           has_occlusion, current_time) -> Optional[Dict[str, Any]]:
        """Generate the guidance voice prompt for the current frame."""

        # Require a stable area reading (avoid jitter-triggered prompts).
        if not self._is_area_stable(area_ratio):
            return None

        position_desc = self._get_position_description(center_x_ratio)

        # Stage 1: discover (0.01-0.08)
        if area_ratio >= self.THRESHOLDS['discover'] and area_ratio < self.THRESHOLDS['approaching']:
            if self.THRESHOLDS['discover'] not in self.broadcasted_thresholds:
                self.broadcasted_thresholds.add(self.THRESHOLDS['discover'])
                return {
                    'voice_text': f"远处发现斑马线,{position_desc}",
                    'priority': 55,  # raised to 55, above blind-path direction commands (50)
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }

        # Stage 2: approaching (0.08-0.18)
        elif area_ratio >= self.THRESHOLDS['approaching'] and area_ratio < self.THRESHOLDS['near']:
            # First broadcast for this stage.
            if self.THRESHOLDS['approaching'] not in self.broadcasted_thresholds:
                self.broadcasted_thresholds.add(self.THRESHOLDS['approaching'])
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return {
                    'voice_text': f"正在靠近斑马线,{position_desc}",
                    'priority': 55,  # raised to 55
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }
            # Repeat broadcast (interval elapsed, or position zone changed).
            elif (current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['approaching'] or
                  position_desc != self.last_position_zone):
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return {
                    'voice_text': f"正在靠近斑马线,{position_desc}",
                    'priority': 55,  # raised to 55
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }

        # Stage 3: near (0.18-0.25)
        elif area_ratio >= self.THRESHOLDS['near'] and area_ratio < self.THRESHOLDS['arrival']:
            # First broadcast for this stage.
            if self.THRESHOLDS['near'] not in self.broadcasted_thresholds:
                self.broadcasted_thresholds.add(self.THRESHOLDS['near'])
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return {
                    'voice_text': f"接近斑马线,{position_desc}",
                    'priority': 60,
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }
            # Repeat broadcast (interval elapsed, or position zone changed).
            elif (current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['near'] or
                  position_desc != self.last_position_zone):
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return {
                    'voice_text': f"接近斑马线,{position_desc}",
                    'priority': 60,
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': position_desc
                }

        # Stage 4: arrival (area >= 0.25, unoccluded)
        elif area_ratio >= self.THRESHOLDS['arrival']:
            # Only prompt to cross when the crosswalk is unoccluded.
            if has_occlusion:
                # Occluded: suppress the crossing prompt and stay at stage 3.
                logger.info(f"[斑马线] 面积达到{area_ratio:.2f}但被遮挡,暂不提示过马路")
                return None

            # First arrival.
            if not self.in_arrival_state:
                self.in_arrival_state = True
                self.arrival_first_broadcast_time = current_time
                self.last_broadcast_time = current_time
                logger.info(f"[斑马线] 到达状态:area={area_ratio:.2f}, 无遮挡")
                return {
                    'voice_text': "斑马线到了可以过马路",
                    'priority': 80,
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': '到达'
                }
            # Repeat broadcast (every REPEAT_INTERVALS['arrival'] seconds).
            elif current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['arrival']:
                self.last_broadcast_time = current_time
                return {
                    'voice_text': "斑马线到了可以过马路",
                    'priority': 80,
                    'should_broadcast': True,
                    'area': area_ratio,
                    'position': '到达'
                }
            # Timeout: leave the arrival state automatically after 30 s.
            elif current_time - self.arrival_first_broadcast_time > 30.0:
                logger.info("[斑马线] 到达状态超时30秒,自动退出")
                self.in_arrival_state = False
                return None

        # Downgrade: area shrank while in the arrival state.
        if self.in_arrival_state and area_ratio < 0.20:
            logger.info(f"[斑马线] 面积降至{area_ratio:.2f},退出到达状态")
            self.in_arrival_state = False
            # Clear part of the broadcast marks so arrival can be announced again.
            self.broadcasted_thresholds.discard(self.THRESHOLDS['arrival'])

        return None

    def _is_area_stable(self, area_ratio, stability_frames=5) -> bool:
        """Return True when the area reading is stable (guards against jitter)."""
        if len(self.area_history) < stability_frames:
            return True  # warm-up period — treat as stable

        recent_areas = [h['area'] for h in list(self.area_history)[-stability_frames:]]

        # All of the last N frames must be within +/-20% of the current area.
        for recent_area in recent_areas:
            if abs(recent_area - area_ratio) / max(area_ratio, 0.001) > 0.20:
                return False

        return True

    def _reset_if_needed(self):
        """Reset state (called when the crosswalk disappears from view)."""
        if len(self.area_history) > 0:
            logger.info("[斑马线] 斑马线消失,重置状态")

        self.broadcasted_thresholds.clear()
        self.area_history.clear()
        self.in_arrival_state = False
        self.last_position_zone = None

    def reset(self):
        """Full reset of every piece of monitor state."""
        self.broadcasted_thresholds.clear()
        self.area_history.clear()
        self.in_arrival_state = False
        self.last_broadcast_time = 0
        self.arrival_first_broadcast_time = 0
        self.last_position_zone = None
        logger.info("[斑马线] 感知监控器已重置")

    def is_in_arrival_state(self) -> bool:
        """True while in the arrival state (callers may pause blind-path voice)."""
        return self.in_arrival_state

    def get_current_area(self) -> float:
        """Return the latest recorded area ratio (0.0 when no history)."""
        if len(self.area_history) > 0:
            return self.area_history[-1]['area']
        return 0.0

    def get_visualization_data(self, crosswalk_mask, area_ratio, center_x_ratio, center_y_ratio, has_occlusion) -> Dict[str, Any]:
        """
        Collect visualization data.

        Returns a dict with every element needed for external drawing.
        """
        if crosswalk_mask is None:
            return {}

        # Determine the current stage (uniform orange palette, varying alpha).
        if area_ratio >= self.THRESHOLDS['arrival']:
            stage = "到达"
            stage_color = "rgba(255, 165, 0, 0.5)"  # orange
        elif area_ratio >= self.THRESHOLDS['near']:
            stage = "接近"
            stage_color = "rgba(255, 165, 0, 0.45)"  # orange
        elif area_ratio >= self.THRESHOLDS['approaching']:
            stage = "靠近"
            stage_color = "rgba(255, 165, 0, 0.40)"  # orange
        else:
            stage = "发现"
            stage_color = "rgba(255, 165, 0, 0.35)"  # orange

        # Position description.
        position = self._get_position_description(center_x_ratio)

        return {
            'area_ratio': area_ratio,
            'stage': stage,
            'stage_color': stage_color,
            'position': position.replace("在画面", ""),  # strip the "在画面" prefix
            'center_x_ratio': center_x_ratio,
            'center_y_ratio': center_y_ratio,
            'has_occlusion': has_occlusion,
            'in_arrival': self.in_arrival_state
        }
|
||||||
|
|
||||||
|
|
||||||
|
# 辅助函数
|
||||||
|
def split_combined_voice(combined_text: str) -> list:
|
||||||
|
"""
|
||||||
|
将组合语音拆分为多个独立语音
|
||||||
|
例如:"远处发现斑马线,在画面左侧" → ["远处发现斑马线", "在画面左侧"]
|
||||||
|
"""
|
||||||
|
if ',' in combined_text:
|
||||||
|
parts = combined_text.split(',')
|
||||||
|
return [p.strip() for p in parts if p.strip()]
|
||||||
|
return [combined_text]
|
||||||
|
|
||||||
77
docker-compose.yml
Normal file
77
docker-compose.yml
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
version: '3.8'
|
||||||
|
|
||||||
|
services:
|
||||||
|
aiglass:
|
||||||
|
build:
|
||||||
|
context: .
|
||||||
|
dockerfile: Dockerfile
|
||||||
|
container_name: aiglass
|
||||||
|
restart: unless-stopped
|
||||||
|
|
||||||
|
# 端口映射
|
||||||
|
ports:
|
||||||
|
- "8081:8081" # Web 服务
|
||||||
|
- "12345:12345/udp" # IMU UDP
|
||||||
|
|
||||||
|
# 环境变量(从 .env 文件读取)
|
||||||
|
environment:
|
||||||
|
- DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY}
|
||||||
|
- BLIND_PATH_MODEL=${BLIND_PATH_MODEL:-model/yolo-seg.pt}
|
||||||
|
- OBSTACLE_MODEL=${OBSTACLE_MODEL:-model/yoloe-11l-seg.pt}
|
||||||
|
- YOLOE_MODEL_PATH=${YOLOE_MODEL_PATH:-model/yoloe-11l-seg.pt}
|
||||||
|
- ENABLE_TTS=${ENABLE_TTS:-true}
|
||||||
|
- TTS_INTERVAL_SEC=${TTS_INTERVAL_SEC:-1.0}
|
||||||
|
- LOG_LEVEL=${LOG_LEVEL:-INFO}
|
||||||
|
|
||||||
|
# 卷挂载
|
||||||
|
volumes:
|
||||||
|
- ./model:/app/model:ro # 模型文件(只读)
|
||||||
|
- ./recordings:/app/recordings # 录制文件
|
||||||
|
- ./music:/app/music:ro # 音频文件(只读)
|
||||||
|
- ./voice:/app/voice:ro # 语音文件(只读)
|
||||||
|
- ./static:/app/static:ro # 静态文件(只读)
|
||||||
|
- ./templates:/app/templates:ro # 模板文件(只读)
|
||||||
|
|
||||||
|
# GPU 支持(需要 nvidia-docker)
|
||||||
|
deploy:
|
||||||
|
resources:
|
||||||
|
reservations:
|
||||||
|
devices:
|
||||||
|
- driver: nvidia
|
||||||
|
count: 1
|
||||||
|
capabilities: [gpu]
|
||||||
|
|
||||||
|
# 依赖服务(可选,如需添加数据库等)
|
||||||
|
# depends_on:
|
||||||
|
# - redis
|
||||||
|
|
||||||
|
# 网络模式
|
||||||
|
network_mode: bridge
|
||||||
|
|
||||||
|
# 健康检查
|
||||||
|
healthcheck:
|
||||||
|
test: ["CMD", "curl", "-f", "http://localhost:8081/api/health"]
|
||||||
|
interval: 30s
|
||||||
|
timeout: 10s
|
||||||
|
retries: 3
|
||||||
|
start_period: 40s
|
||||||
|
|
||||||
|
# 可选:添加 Redis 用于缓存
|
||||||
|
# redis:
|
||||||
|
# image: redis:7-alpine
|
||||||
|
# container_name: aiglass-redis
|
||||||
|
# restart: unless-stopped
|
||||||
|
# ports:
|
||||||
|
# - "6379:6379"
|
||||||
|
# volumes:
|
||||||
|
# - redis-data:/data
|
||||||
|
|
||||||
|
# 可选:数据卷
|
||||||
|
# volumes:
|
||||||
|
# redis-data:
|
||||||
|
|
||||||
|
# 可选:自定义网络
|
||||||
|
# networks:
|
||||||
|
# aiglass-network:
|
||||||
|
# driver: bridge
|
||||||
|
|
||||||
202
edge_tts_client.py
Normal file
202
edge_tts_client.py
Normal file
@@ -0,0 +1,202 @@
|
|||||||
|
# edge_tts_client.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
EdgeTTS 流式语音合成客户端 - Day 21
|
||||||
|
|
||||||
|
特点:
|
||||||
|
- 完全免费
|
||||||
|
- 流式输出(边合成边播放)
|
||||||
|
- 低延迟
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
import edge_tts
|
||||||
|
from typing import AsyncGenerator, Optional
|
||||||
|
|
||||||
|
# Default voice.
DEFAULT_VOICE = os.getenv("EDGE_TTS_VOICE", "zh-CN-XiaoxiaoNeural")

# Speaking-rate adjustment ("+0%", "+10%", "-10%", ...).
DEFAULT_RATE = os.getenv("EDGE_TTS_RATE", "+0%")

# Volume adjustment.
DEFAULT_VOLUME = os.getenv("EDGE_TTS_VOLUME", "+0%")
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_speech_stream(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    volume: str = DEFAULT_VOLUME,
) -> AsyncGenerator[bytes, None]:
    """
    Streaming text-to-speech.

    Args:
        text: text to synthesize
        voice: voice name
        rate: speaking rate
        volume: volume

    Yields:
        MP3 audio chunks
    """
    if not text or not text.strip():
        return

    try:
        communicate = edge_tts.Communicate(
            text=text,
            voice=voice,
            rate=rate,
            volume=volume,
        )

        # Forward only audio frames; metadata frames are dropped.
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                yield chunk["data"]

    except Exception as e:
        # Best-effort: log and end the stream on synthesis failure.
        print(f"[EdgeTTS] 合成失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_speech(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    volume: str = DEFAULT_VOLUME,
) -> bytes:
    """
    Full text-to-speech (returns the complete audio in one buffer).

    Args:
        text: text to synthesize
        voice: voice name
        rate: speaking rate
        volume: volume

    Returns:
        complete MP3 audio data
    """
    # Drain the streaming generator and join the chunks.
    chunks = [chunk async for chunk in text_to_speech_stream(text, voice, rate, volume)]
    return b"".join(chunks)
|
||||||
|
|
||||||
|
|
||||||
|
async def text_to_speech_pcm(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    target_sample_rate: int = 16000,
) -> bytes:
    """
    Text -> PCM16 audio (for direct playback).

    Args:
        text: text to synthesize
        voice: voice name
        rate: speaking rate
        target_sample_rate: output sample rate

    Returns:
        PCM16 audio data (mono, 16-bit)
    """
    import io
    from pydub import AudioSegment

    # Fetch the complete MP3 payload first.
    mp3_data = await text_to_speech(text, voice, rate)

    if not mp3_data:
        return b""

    try:
        # MP3 -> PCM conversion.
        audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))

        # Normalize sample rate and channel layout.
        audio = audio.set_frame_rate(target_sample_rate)
        audio = audio.set_channels(1)  # mono
        audio = audio.set_sample_width(2)  # 16-bit

        return audio.raw_data

    except Exception as e:
        print(f"[EdgeTTS] PCM 转换失败: {e}")
        return b""
|
||||||
|
|
||||||
|
|
||||||
|
def _split_text_by_punctuation(text: str) -> list:
    """Split *text* into stripped segments at CJK/ASCII punctuation.

    The delimiter stays attached to its segment; a trailing run without
    punctuation becomes the final segment. Empty segments are dropped at
    the boundary except for the stripped tail check.
    """
    punctuation = "。,!?;:,.!?;:"
    segments = []
    current = ""
    for char in text:
        current += char
        if char in punctuation:
            segments.append(current.strip())
            current = ""
    if current.strip():
        segments.append(current.strip())
    return segments


async def text_to_speech_pcm_stream(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    target_sample_rate: int = 16000,
) -> AsyncGenerator[bytes, None]:
    """
    Streaming text -> PCM16 audio.

    Note: because MP3 must be decoded as a whole, this synthesizes segment by
    segment — the text is cut at every punctuation mark and each segment is
    synthesized in turn.

    Args:
        text: text to synthesize
        voice: voice name
        rate: speaking rate
        target_sample_rate: output sample rate

    Yields:
        PCM16 audio chunks (mono, 16-bit, at target_sample_rate)
    """
    import io
    from pydub import AudioSegment

    # Segment-by-segment synthesis (splitting extracted into a helper).
    for segment in _split_text_by_punctuation(text):
        if not segment:
            continue

        try:
            mp3_data = await text_to_speech(segment, voice, rate)

            if mp3_data:
                audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))
                audio = audio.set_frame_rate(target_sample_rate)
                audio = audio.set_channels(1)
                audio = audio.set_sample_width(2)

                yield audio.raw_data

        except Exception as e:
            # Best-effort: skip the failing segment and continue.
            print(f"[EdgeTTS] 分段合成失败: {e}")
|
||||||
|
|
||||||
|
|
||||||
|
# Voice list (common Chinese voices).
CHINESE_VOICES = [
    "zh-CN-XiaoxiaoNeural",  # female, natural
    "zh-CN-YunxiNeural",     # male, natural
    "zh-CN-XiaoyiNeural",    # female, lively
    "zh-CN-YunjianNeural",   # male, newscast
    "zh-CN-XiaochenNeural",  # female, gentle
]
|
||||||
|
|
||||||
|
|
||||||
|
async def list_voices() -> list:
    """Return every available EdgeTTS voice whose locale is Chinese (zh-*)."""
    all_voices = await edge_tts.list_voices()
    return [voice for voice in all_voices if voice["Locale"].startswith("zh")]
|
||||||
208
glm_client.py
Normal file
208
glm_client.py
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
# glm_client.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
GLM-4.6v-Flash LLM 客户端 - Day 22
|
||||||
|
|
||||||
|
使用官方 zai-sdk + glm-4.6v-flash 模型
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
import asyncio
|
||||||
|
from datetime import datetime
|
||||||
|
from typing import AsyncGenerator, Optional
|
||||||
|
from zai import ZhipuAiClient
|
||||||
|
|
||||||
|
# API configuration.
# NOTE(review): a real API key is hard-coded as the fallback below — this is a
# credential leak. It should be removed, supplied only via the GLM_API_KEY
# environment variable, and the exposed key rotated.
API_KEY = os.getenv(
    "GLM_API_KEY",
    "5915240ea48d4e93b454bc2412d1cc54.e054ej4pPqi9G6rc"
)
MODEL = "glm-4.6v-flash"  # upgraded to glm-4.6v-flash (vision-capable)

# Weekday names (Monday..Sunday) used in the system prompt.
WEEKDAY_MAP = ["星期一", "星期二", "星期三", "星期四", "星期五", "星期六", "星期日"]
|
||||||
|
|
||||||
|
|
||||||
|
def get_system_prompt() -> str:
    """Build the system prompt dynamically, embedding the current time/date.

    Regenerated on every call so the model always sees a fresh timestamp.
    """
    now = datetime.now()
    current_time = now.strftime("%H:%M")
    current_date = now.strftime("%Y年%m月%d日")
    current_weekday = WEEKDAY_MAP[now.weekday()]

    return f"""你是一个视障辅助AI助手,安装在智能导盲眼镜上。
当前时间:{current_time}
今天日期:{current_date} {current_weekday}

请用极简短的语言回答,每次回答不超过2-3句话。
避免冗长解释,只提供最关键的信息。
语气友好但简洁。"""
|
||||||
|
|
||||||
|
|
||||||
|
# Client singleton and conversation history (module-level state).
_client = None
_conversation_history = []
MAX_HISTORY_TURNS = 5  # keep the most recent 5 turns
|
||||||
|
|
||||||
|
|
||||||
|
def _get_client() -> ZhipuAiClient:
    """Return the Zhipu AI client, creating the singleton lazily."""
    global _client
    if _client is None:
        _client = ZhipuAiClient(api_key=API_KEY)
    return _client
|
||||||
|
|
||||||
|
|
||||||
|
def clear_conversation_history():
    """Drop all remembered conversation turns."""
    global _conversation_history
    _conversation_history = []
    print("[GLM] 对话历史已清除")
|
||||||
|
|
||||||
|
|
||||||
|
async def chat(user_message: str, image_base64: Optional[str] = None) -> str:
    """
    Chat with GLM-4.6v-Flash (with conversational memory).

    Args:
        user_message: user message text
        image_base64: optional Base64-encoded image

    Returns:
        the AI reply text
    """
    global _conversation_history
    client = _get_client()

    # Build the user message.
    if image_base64:
        # Multimodal message (with image).
        user_content = [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
        ]
    else:
        user_content = user_message

    # Append the user message to history.
    _conversation_history.append({"role": "user", "content": user_content})

    # Trim history (one turn = 1 user + 1 assistant = 2 messages).
    max_messages = MAX_HISTORY_TURNS * 2
    if len(_conversation_history) > max_messages:
        _conversation_history = _conversation_history[-max_messages:]

    # Build the full message list (system prompt regenerated every call so the
    # embedded current time stays fresh).
    messages = [{"role": "system", "content": get_system_prompt()}] + _conversation_history

    # Day 22: retry logic for rate limiting.
    max_retries = 3
    retry_delay = 1  # initial delay: 1 second

    for attempt in range(max_retries):
        try:
            # Day 22: upgraded to glm-4.6v-flash.
            # [Fix] Per the official docs the `thinking` parameter is required
            # (even for the vision model).
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model=MODEL,
                messages=messages,
                thinking={"type": "disabled"},  # explicitly disabled to reduce latency
            )

            if response.choices and len(response.choices) > 0:
                ai_reply = response.choices[0].message.content.strip()
                # Record the assistant reply in history.
                _conversation_history.append({"role": "assistant", "content": ai_reply})
                print(f"[GLM] 回复: {ai_reply[:50]}..." if len(ai_reply) > 50 else f"[GLM] 回复: {ai_reply}")
                return ai_reply
            return ""

        except Exception as e:
            error_str = str(e)
            # Rate-limit error? (HTTP 429 or provider error code 1305)
            if "429" in error_str or "1305" in error_str or "请求过多" in error_str:
                if attempt < max_retries - 1:
                    print(f"[GLM] 速率限制,{retry_delay}秒后重试... (尝试 {attempt + 1}/{max_retries})")
                    await asyncio.sleep(retry_delay)
                    retry_delay *= 2  # exponential backoff
                    continue

            print(f"[GLM] 调用失败: {e}")
            import traceback
            traceback.print_exc()
            break

    # All retries failed: roll back the pending user message.
    if _conversation_history and _conversation_history[-1]["role"] == "user":
        _conversation_history.pop()
    return "抱歉,我暂时无法回答。"
|
||||||
|
|
||||||
|
|
||||||
|
async def chat_stream(user_message: str, image_base64: Optional[str] = None) -> AsyncGenerator[str, None]:
    """
    Streaming chat (yields text fragments) — GLM-4.6v-Flash.

    Args:
        user_message: user message text
        image_base64: optional Base64-encoded image

    Yields:
        fragments of the AI reply
    """
    global _conversation_history
    client = _get_client()

    # Build the user message (multimodal when an image is attached).
    if image_base64:
        user_content = [
            {"type": "text", "text": user_message},
            {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}}
        ]
    else:
        user_content = user_message

    # Append the user message to history.
    _conversation_history.append({"role": "user", "content": user_content})

    # Trim history.
    max_messages = MAX_HISTORY_TURNS * 2
    if len(_conversation_history) > max_messages:
        _conversation_history = _conversation_history[-max_messages:]

    # Full message list with a freshly timestamped system prompt.
    messages = [{"role": "system", "content": get_system_prompt()}] + _conversation_history

    full_response = ""

    try:
        # Streaming call.
        # Day 22: upgraded to glm-4.6v-flash.
        # [Fix] Per the official docs the `thinking` parameter is required.
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model=MODEL,
            messages=messages,
            thinking={"type": "disabled"},
            stream=True,
        )

        # BUGFIX: the original iterated the synchronous stream directly
        # (`for chunk in response:`), blocking the event loop for the entire
        # network stream. Pull each chunk on a worker thread instead.
        stream_iter = iter(response)
        _done = object()  # sentinel marking stream exhaustion
        while True:
            chunk = await asyncio.to_thread(next, stream_iter, _done)
            if chunk is _done:
                break
            # Guard against chunks with an empty choices list.
            if chunk.choices and chunk.choices[0].delta.content:
                text = chunk.choices[0].delta.content
                full_response += text
                yield text

        # Record the completed reply in history.
        if full_response:
            _conversation_history.append({"role": "assistant", "content": full_response})
            print(f"[GLM] 流式完成: {full_response[:50]}..." if len(full_response) > 50 else f"[GLM] 流式完成: {full_response}")

    except Exception as e:
        print(f"[GLM] 流式调用失败: {e}")
        import traceback
        traceback.print_exc()
        # Roll back the user message we just appended.
        if _conversation_history and _conversation_history[-1]["role"] == "user":
            _conversation_history.pop()
        yield "抱歉,我暂时无法回答。"
|
||||||
271
gpu_parallel.py
Normal file
271
gpu_parallel.py
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
# gpu_parallel.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
"""
|
||||||
|
Day 20: GPU 并行推理优化模块
|
||||||
|
使用 CUDA Stream 让盲道检测和障碍物检测并行执行
|
||||||
|
"""
|
||||||
|
import os
|
||||||
|
import time
|
||||||
|
import torch
|
||||||
|
import numpy as np
|
||||||
|
from typing import Tuple, Optional, List, Any
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
import logging
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# 全局 CUDA Stream(延迟初始化)
|
||||||
|
_cuda_streams = None
|
||||||
|
_parallel_executor = None
|
||||||
|
|
||||||
|
def _init_cuda_streams():
    """Lazily create the module-level pair of CUDA streams.

    Returns the cached stream list, ``[]`` if creation failed once, or
    ``None`` when CUDA is unavailable and creation was never attempted.
    """
    global _cuda_streams
    # Already initialized, or no CUDA device to create streams on.
    if _cuda_streams is not None or not torch.cuda.is_available():
        return _cuda_streams
    try:
        _cuda_streams = [torch.cuda.Stream(), torch.cuda.Stream()]
        logger.info("[GPU_PARALLEL] 已创建 2 个 CUDA Stream")
    except Exception as e:
        logger.warning(f"[GPU_PARALLEL] 创建 CUDA Stream 失败: {e}")
        _cuda_streams = []
    return _cuda_streams
|
||||||
|
|
||||||
|
def _init_parallel_executor():
    """Lazily build the shared two-worker pool for detection post-processing.

    Returns the cached ThreadPoolExecutor, creating it on first call.
    """
    global _parallel_executor
    if _parallel_executor is not None:
        return _parallel_executor
    _parallel_executor = ThreadPoolExecutor(
        max_workers=2, thread_name_prefix="gpu_post"
    )
    logger.info("[GPU_PARALLEL] 已创建 GPU 后处理线程池")
    return _parallel_executor
|
||||||
|
|
||||||
|
class ParallelDetector:
    """
    Parallel detector — runs blind-path/crosswalk segmentation and obstacle
    detection concurrently on the shared worker pool.

    Usage:
        detector = ParallelDetector(yolo_model, obstacle_detector)
        blind_mask, cross_mask, obstacles = detector.detect_all(image, path_mask)
    """

    def __init__(self, yolo_model, obstacle_detector):
        """
        Args:
            yolo_model: Ultralytics segmentation model
                (class 0 = crosswalk, class 1 = blind path).
            obstacle_detector: object exposing
                ``detect(image, path_mask=...) -> list`` — assumed thread-safe
                for concurrent use with the YOLO model (TODO confirm).
        """
        self.yolo_model = yolo_model
        self.obstacle_detector = obstacle_detector

        # Detection parameters (tunable via environment variables).
        self.imgsz = int(os.getenv("AIGLASS_YOLO_IMGSZ", "480"))
        self.use_half = os.getenv("AIGLASS_YOLO_HALF", "1") == "1"
        self.blind_conf_threshold = 0.20
        self.cross_conf_threshold = 0.30

        # Eagerly initialize the shared CUDA streams and thread pool.
        _init_cuda_streams()
        _init_parallel_executor()

        logger.info(f"[GPU_PARALLEL] ParallelDetector 初始化完成: imgsz={self.imgsz}, half={self.use_half}")

    def detect_all(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray] = None
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """
        Run every detection, in parallel when CUDA streams are available.

        Args:
            image: BGR image.
            path_mask: blind-path mask used to filter obstacles (optional).

        Returns:
            (blind_path_mask, crosswalk_mask, obstacles)
        """
        streams = _init_cuda_streams()

        if streams and len(streams) >= 2:
            return self._detect_with_streams(image, path_mask, streams)
        # Fall back to serial execution when no CUDA streams exist.
        return self._detect_serial(image, path_mask)

    def _detect_with_streams(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray],
        streams: List[torch.cuda.Stream]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """Day 21: truly parallel detection via ThreadPoolExecutor
        (replaces the ineffective CUDA-Stream approach)."""
        from concurrent.futures import as_completed, TimeoutError as FuturesTimeout

        blind_mask = None
        cross_mask = None
        obstacles: List[Any] = []

        executor = _init_parallel_executor()

        def task_blind_path():
            # Blind-path / crosswalk segmentation; returns (blind, cross) masks.
            if self.yolo_model is None:
                return None, None
            try:
                results = self.yolo_model.predict(
                    image,
                    verbose=False,
                    conf=min(self.blind_conf_threshold, self.cross_conf_threshold),
                    classes=[0, 1],  # 0=crosswalk, 1=blind_path
                    imgsz=self.imgsz,
                    half=self.use_half
                )
                if results and results[0] and results[0].masks is not None:
                    return self._parse_seg_results(results[0], image.shape)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 盲道检测失败: {e}")
            return None, None

        def task_obstacles():
            # Obstacle detection, filtered by the supplied path mask.
            if self.obstacle_detector is None:
                return []
            try:
                return self.obstacle_detector.detect(image, path_mask=path_mask)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 障碍物检测失败: {e}")
                return []

        # Submit both tasks concurrently.
        futures = {
            executor.submit(task_blind_path): 'blind',
            executor.submit(task_obstacles): 'obstacle'
        }

        # Collect results. FIX: as_completed raises TimeoutError when the
        # deadline passes — previously this propagated and crashed the caller,
        # losing any result that *had* finished. Keep partials instead.
        try:
            for future in as_completed(futures, timeout=2.0):
                task_type = futures[future]
                try:
                    result = future.result()
                    if task_type == 'blind':
                        blind_mask, cross_mask = result
                    else:
                        obstacles = result
                except Exception as e:
                    logger.error(f"[GPU_PARALLEL] {task_type}任务异常: {e}")
        except FuturesTimeout:
            logger.error("[GPU_PARALLEL] 检测任务超时(2.0s),返回部分结果")
            for future in futures:
                future.cancel()

        return blind_mask, cross_mask, obstacles

    def _detect_serial(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """Serial detection (fallback when CUDA streams are unavailable)."""
        blind_mask = None
        cross_mask = None
        obstacles: List[Any] = []

        # Blind-path / crosswalk segmentation.
        if self.yolo_model is not None:
            try:
                results = self.yolo_model.predict(
                    image,
                    verbose=False,
                    conf=min(self.blind_conf_threshold, self.cross_conf_threshold),
                    classes=[0, 1],
                    imgsz=self.imgsz,
                    half=self.use_half
                )
                if results and results[0] and results[0].masks is not None:
                    blind_mask, cross_mask = self._parse_seg_results(results[0], image.shape)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 盲道检测失败: {e}")

        # Obstacle detection.
        if self.obstacle_detector is not None:
            try:
                obstacles = self.obstacle_detector.detect(image, path_mask=path_mask)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 障碍物检测失败: {e}")

        return blind_mask, cross_mask, obstacles

    def _parse_seg_results(
        self,
        result,
        image_shape: Tuple[int, int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Split a YOLO segmentation result into (blind_path, crosswalk) masks.

        Detections below the per-class confidence threshold are dropped;
        masks of the same class are OR-merged into one binary mask.
        """
        blind_mask = None
        cross_mask = None

        h, w = image_shape[:2]

        if result.masks is None or result.boxes is None:
            return None, None

        for mask_tensor, conf_tensor, cls_tensor in zip(
            result.masks.data, result.boxes.conf, result.boxes.cls
        ):
            class_id = int(cls_tensor.item())
            confidence = float(conf_tensor.item())

            # Per-class confidence filtering.
            if class_id == 1 and confidence < self.blind_conf_threshold:
                continue
            if class_id == 0 and confidence < self.cross_conf_threshold:
                continue

            current_mask = self._tensor_to_mask(mask_tensor, w, h)

            if class_id == 1:  # blind path
                if blind_mask is None:
                    blind_mask = current_mask
                else:
                    blind_mask = np.bitwise_or(blind_mask, current_mask)
            elif class_id == 0:  # crosswalk
                if cross_mask is None:
                    cross_mask = current_mask
                else:
                    cross_mask = np.bitwise_or(cross_mask, current_mask)

        return blind_mask, cross_mask

    def _tensor_to_mask(
        self,
        mask_tensor: torch.Tensor,
        out_w: int,
        out_h: int
    ) -> np.ndarray:
        """Convert a float mask tensor to a 0/255 uint8 numpy mask
        resized to (out_h, out_w)."""
        import cv2

        # Move to host memory before the numpy conversion.
        if mask_tensor.is_cuda:
            mask_np = mask_tensor.cpu().numpy()
        else:
            mask_np = mask_tensor.numpy()

        # Resize to the output resolution (nearest keeps hard mask edges).
        if mask_np.shape[0] != out_h or mask_np.shape[1] != out_w:
            mask_np = cv2.resize(mask_np, (out_w, out_h), interpolation=cv2.INTER_NEAREST)

        # Binarize to 0/255.
        mask_np = (mask_np > 0.5).astype(np.uint8) * 255

        return mask_np
|
||||||
|
|
||||||
|
|
||||||
|
# Cached detector so per-frame calls don't rebuild one (re-reading env vars
# and logging) every time.
_cached_parallel_detector: Optional["ParallelDetector"] = None


def detect_all_parallel(
    yolo_model,
    obstacle_detector,
    image: np.ndarray,
    path_mask: Optional[np.ndarray] = None
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
    """
    Convenience wrapper: run all detections in parallel.

    Intended as a drop-in replacement for the serial detection in
    workflow_blindpath.py. FIX: the detector instance is now cached and
    reused while the same model objects are passed in, instead of being
    reconstructed on every frame.

    Args:
        yolo_model: Ultralytics segmentation model (or None).
        obstacle_detector: obstacle detector client (or None).
        image: BGR image.
        path_mask: optional blind-path mask for obstacle filtering.

    Returns:
        (blind_path_mask, crosswalk_mask, obstacles)
    """
    global _cached_parallel_detector
    cached = _cached_parallel_detector
    if (cached is None
            or cached.yolo_model is not yolo_model
            or cached.obstacle_detector is not obstacle_detector):
        cached = ParallelDetector(yolo_model, obstacle_detector)
        _cached_parallel_detector = cached
    return cached.detect_all(image, path_mask)
|
||||||
BIN
hand_landmarker.task
Normal file
BIN
hand_landmarker.task
Normal file
Binary file not shown.
BIN
model/SenseVoiceSmall/chn_jpn_yue_eng_ko_spectok.bpe.model
Normal file
BIN
model/SenseVoiceSmall/chn_jpn_yue_eng_ko_spectok.bpe.model
Normal file
Binary file not shown.
97
model/SenseVoiceSmall/config.yaml
Normal file
97
model/SenseVoiceSmall/config.yaml
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
encoder: SenseVoiceEncoderSmall
|
||||||
|
encoder_conf:
|
||||||
|
output_size: 512
|
||||||
|
attention_heads: 4
|
||||||
|
linear_units: 2048
|
||||||
|
num_blocks: 50
|
||||||
|
tp_blocks: 20
|
||||||
|
dropout_rate: 0.1
|
||||||
|
positional_dropout_rate: 0.1
|
||||||
|
attention_dropout_rate: 0.1
|
||||||
|
input_layer: pe
|
||||||
|
pos_enc_class: SinusoidalPositionEncoder
|
||||||
|
normalize_before: true
|
||||||
|
kernel_size: 11
|
||||||
|
sanm_shfit: 0
|
||||||
|
selfattention_layer_type: sanm
|
||||||
|
|
||||||
|
|
||||||
|
model: SenseVoiceSmall
|
||||||
|
model_conf:
|
||||||
|
length_normalized_loss: true
|
||||||
|
sos: 1
|
||||||
|
eos: 2
|
||||||
|
ignore_id: -1
|
||||||
|
|
||||||
|
tokenizer: SentencepiecesTokenizer
|
||||||
|
tokenizer_conf:
|
||||||
|
bpemodel: null
|
||||||
|
unk_symbol: <unk>
|
||||||
|
split_with_space: true
|
||||||
|
|
||||||
|
frontend: WavFrontend
|
||||||
|
frontend_conf:
|
||||||
|
fs: 16000
|
||||||
|
window: hamming
|
||||||
|
n_mels: 80
|
||||||
|
frame_length: 25
|
||||||
|
frame_shift: 10
|
||||||
|
lfr_m: 7
|
||||||
|
lfr_n: 6
|
||||||
|
cmvn_file: null
|
||||||
|
|
||||||
|
|
||||||
|
dataset: SenseVoiceCTCDataset
|
||||||
|
dataset_conf:
|
||||||
|
index_ds: IndexDSJsonl
|
||||||
|
batch_sampler: EspnetStyleBatchSampler
|
||||||
|
data_split_num: 32
|
||||||
|
batch_type: token
|
||||||
|
batch_size: 14000
|
||||||
|
max_token_length: 2000
|
||||||
|
min_token_length: 60
|
||||||
|
max_source_length: 2000
|
||||||
|
min_source_length: 60
|
||||||
|
max_target_length: 200
|
||||||
|
min_target_length: 0
|
||||||
|
shuffle: true
|
||||||
|
num_workers: 4
|
||||||
|
sos: ${model_conf.sos}
|
||||||
|
eos: ${model_conf.eos}
|
||||||
|
IndexDSJsonl: IndexDSJsonl
|
||||||
|
retry: 20
|
||||||
|
|
||||||
|
train_conf:
|
||||||
|
accum_grad: 1
|
||||||
|
grad_clip: 5
|
||||||
|
max_epoch: 20
|
||||||
|
keep_nbest_models: 10
|
||||||
|
avg_nbest_model: 10
|
||||||
|
log_interval: 100
|
||||||
|
resume: true
|
||||||
|
validate_interval: 10000
|
||||||
|
save_checkpoint_interval: 10000
|
||||||
|
|
||||||
|
optim: adamw
|
||||||
|
optim_conf:
|
||||||
|
lr: 0.00002
|
||||||
|
scheduler: warmuplr
|
||||||
|
scheduler_conf:
|
||||||
|
warmup_steps: 25000
|
||||||
|
|
||||||
|
specaug: SpecAugLFR
|
||||||
|
specaug_conf:
|
||||||
|
apply_time_warp: false
|
||||||
|
time_warp_window: 5
|
||||||
|
time_warp_mode: bicubic
|
||||||
|
apply_freq_mask: true
|
||||||
|
freq_mask_width_range:
|
||||||
|
- 0
|
||||||
|
- 30
|
||||||
|
lfr_rate: 6
|
||||||
|
num_freq_mask: 1
|
||||||
|
apply_time_mask: true
|
||||||
|
time_mask_width_range:
|
||||||
|
- 0
|
||||||
|
- 12
|
||||||
|
num_time_mask: 1
|
||||||
14
model/SenseVoiceSmall/configuration.json
Normal file
14
model/SenseVoiceSmall/configuration.json
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
{
|
||||||
|
"framework": "pytorch",
|
||||||
|
"task" : "auto-speech-recognition",
|
||||||
|
"model": {"type" : "funasr"},
|
||||||
|
"pipeline": {"type":"funasr-pipeline"},
|
||||||
|
"model_name_in_hub": {
|
||||||
|
"ms":"",
|
||||||
|
"hf":""},
|
||||||
|
"file_path_metas": {
|
||||||
|
"init_param":"model.pt",
|
||||||
|
"config":"config.yaml",
|
||||||
|
"tokenizer_conf": {"bpemodel": "chn_jpn_yue_eng_ko_spectok.bpe.model"},
|
||||||
|
"frontend_conf":{"cmvn_file": "am.mvn"}}
|
||||||
|
}
|
||||||
37
model_utils.py
Normal file
37
model_utils.py
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
# model_utils.py - Day 20 TensorRT 模型加载工具
|
||||||
|
"""
|
||||||
|
优先加载 TensorRT .engine 文件,不存在时回退到 .pt
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
|
||||||
|
def get_best_model_path(pt_path: str) -> str:
    """
    Pick the best model file for a given ``.pt`` path:

    1. Prefer a sibling ``.engine`` file (TensorRT acceleration).
    2. Fall back to the original ``.pt`` file when no engine exists.

    Args:
        pt_path: original .pt model path (e.g. ``'model/yolo-seg.pt'``).
            Non-``.pt`` paths are returned unchanged.

    Returns:
        Best model path (``.engine`` or ``.pt``).
    """
    if not pt_path.endswith('.pt'):
        return pt_path

    # FIX: swap only the final extension. The previous
    # pt_path.replace('.pt', '.engine') replaced EVERY '.pt' occurrence,
    # corrupting paths like 'ckpt.pt_backup/model.pt'.
    engine_path = os.path.splitext(pt_path)[0] + '.engine'

    if os.path.exists(engine_path):
        print(f"[MODEL] 🚀 使用 TensorRT 加速: {engine_path}")
        return engine_path
    else:
        print(f"[MODEL] 使用 PyTorch 模型: {pt_path}")
        return pt_path
|
||||||
|
|
||||||
|
|
||||||
|
def is_tensorrt_engine(model_path: str) -> bool:
    """Return True when *model_path* points at a TensorRT engine (.engine/.trt)."""
    return model_path.endswith(('.engine', '.trt'))
|
||||||
|
|
||||||
159
models.py
Normal file
159
models.py
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
# app/models.py
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import torch
|
||||||
|
from threading import Semaphore
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from typing import List
|
||||||
|
from app.cloud.obstacle_detector_client import ObstacleDetectorClient
|
||||||
|
# ==========================================================
|
||||||
|
# 0. 导入所有需要的模型封装类 (Clients) 和 Ultralytics 基类
|
||||||
|
# ==========================================================
|
||||||
|
# 这是过马路工作流使用的封装类
|
||||||
|
from app.cloud.crosswalk_detector_client import CrosswalkDetector
|
||||||
|
from app.cloud.coco_perception_client import COCOClient
|
||||||
|
from obstacle_detector_client import ObstacleDetectorClient
|
||||||
|
|
||||||
|
# Day 20: TensorRT 模型加载工具
|
||||||
|
from model_utils import get_best_model_path, is_tensorrt_engine
|
||||||
|
|
||||||
|
# 这是盲道工作流直接使用的 Ultralytics 类
|
||||||
|
from ultralytics import YOLO, YOLOE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# ==========================================================
|
||||||
|
# 1. 全局设备与并发控制 (统一管理)
|
||||||
|
# ==========================================================
|
||||||
|
# Target inference device; falls back to CPU when CUDA was requested but is
# not actually present.
DEVICE = os.getenv("AIGLASS_DEVICE", "cuda:0")
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    logger.warning(f"AIGLASS_DEVICE={DEVICE} 但未检测到 CUDA,将回退到 CPU")
    DEVICE = "cpu"
IS_CUDA = DEVICE.startswith("cuda")

# AMP (automatic mixed precision) policy: "bf16", "fp16", or "off".
AMP_POLICY = os.getenv("AIGLASS_AMP", "bf16").lower()
# FIX: the previous chained conditional expression parsed as
#   bf16 if policy=="bf16" else ((fp16 if policy=="fp16" else None) if IS_CUDA else None)
# so AMP_DTYPE could be torch.bfloat16 even on CPU. Guard every branch on
# IS_CUDA explicitly.
if IS_CUDA and AMP_POLICY == "bf16":
    AMP_DTYPE = torch.bfloat16
elif IS_CUDA and AMP_POLICY == "fp16":
    AMP_DTYPE = torch.float16
else:
    AMP_DTYPE = None

# 🔥 Core: the single GPU concurrency semaphore shared by every workflow.
GPU_SLOTS = int(os.getenv("AIGLASS_GPU_SLOTS", "2"))
gpu_semaphore = Semaphore(GPU_SLOTS)
|
||||||
|
|
||||||
|
|
||||||
|
# 统一的推理上下文管理器,所有工作流都应使用它来调用模型
|
||||||
|
@contextmanager
def gpu_infer_slot():
    """
    Unified inference context for every workflow: GPU concurrency
    throttling (shared semaphore) + torch.inference_mode() + AMP autocast
    when the policy and device allow it.
    """
    with gpu_semaphore:
        amp_active = IS_CUDA and AMP_POLICY != "off" and AMP_DTYPE is not None
        if not amp_active:
            with torch.inference_mode():
                yield
            return
        with torch.inference_mode(), torch.amp.autocast('cuda', dtype=AMP_DTYPE):
            yield
|
||||||
|
|
||||||
|
|
||||||
|
# cuDNN 加速优化
|
||||||
|
try:
|
||||||
|
if IS_CUDA:
|
||||||
|
torch.backends.cudnn.benchmark = True
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ==========================================================
|
||||||
|
# 2. 全局模型实例定义 (全部初始化为 None)
|
||||||
|
# ==========================================================
|
||||||
|
|
||||||
|
# --- 过马路工作流模型 (通过Client类封装) ---
|
||||||
|
crosswalk_detector_client: CrosswalkDetector = None
|
||||||
|
coco_client: COCOClient = None
|
||||||
|
# ObstacleDetectorClient 将作为所有场景的通用障碍物检测器
|
||||||
|
obstacle_detector_client: ObstacleDetectorClient = None
|
||||||
|
|
||||||
|
# --- 盲道工作流模型 (直接使用Ultralytics类) ---
|
||||||
|
# 它们主要用于分割和路径规划,与过马路场景的检测逻辑不同
|
||||||
|
blindpath_seg_model: YOLO = None
|
||||||
|
# 障碍物检测将复用 obstacle_detector_client,但YOLOE的文本特征需要单独保存
|
||||||
|
blindpath_whitelist_embeddings = None
|
||||||
|
|
||||||
|
# 全局加载状态标志
|
||||||
|
models_are_loaded = False
|
||||||
|
|
||||||
|
|
||||||
|
# ==========================================================
|
||||||
|
# 3. 统一的模型加载函数 (由 celery.py 在启动时调用)
|
||||||
|
# ==========================================================
|
||||||
|
def init_all_models():
    """
    Called once when the Celery worker process starts.

    Loads every model required by all workflows into the module-level
    globals. Raises on any failure so the worker fails fast instead of
    starting without models.
    """
    global models_are_loaded
    # Idempotent: a second call is a no-op.
    if models_are_loaded:
        return

    logger.info(f"========= 🚀 开始全局模型预加载 (目标设备: {DEVICE}) =========")

    try:
        # --- [1] Load the shared obstacle detector (ObstacleDetectorClient) ---
        global obstacle_detector_client
        logger.info("[1/4] 正在加载通用障碍物检测模型 (ObstacleDetectorClient)...")
        # Day 20: prefer a TensorRT engine next to the .pt file when present.
        obs_model_path = get_best_model_path('model/yoloe-11l-seg.pt')
        obstacle_detector_client = ObstacleDetectorClient(model_path=obs_model_path)

        # Day 20: TensorRT engines are already device-bound — no .to() needed.
        if not is_tensorrt_engine(obs_model_path):
            if hasattr(obstacle_detector_client, 'model') and obstacle_detector_client.model is not None:
                obstacle_detector_client.model.to(DEVICE)

        logger.info("...通用障碍物检测模型加载成功。")

        # --- [2] Load the street-crossing models (client wrappers) ---
        global crosswalk_detector_client, coco_client
        logger.info("[2/4] 正在加载过马路分割模型 (CrosswalkDetector)...")
        # Day 20: prefer a TensorRT engine when available.
        crosswalk_model_path = get_best_model_path('model/yolo-seg.pt')
        crosswalk_detector_client = CrosswalkDetector(model_path=crosswalk_model_path)
        # Day 20: TensorRT engines do not need .to().
        if not is_tensorrt_engine(crosswalk_model_path):
            if hasattr(crosswalk_detector_client, 'model') and crosswalk_detector_client.model is not None:
                crosswalk_detector_client.model.to(DEVICE)
        logger.info("...过马路分割模型加载成功。")

        logger.info("[3/4] 正在加载通用感知模型 (COCOClient)...")
        coco_client = COCOClient(model_path='model/yolov8l-world.pt')
        # Move the wrapped YOLO model onto the target device.
        if hasattr(coco_client, 'model') and coco_client.model is not None:
            coco_client.model.to(DEVICE)
        logger.info("...通用感知模型加载成功。")

        # --- [4] Load the blind-path-specific models ---
        global blindpath_seg_model, blindpath_whitelist_embeddings
        logger.info("[4/4] 正在加载盲道专用分割模型 (YOLO)...")
        # Day 20: prefer a TensorRT engine when available.
        blindpath_model_path = get_best_model_path('model/yolo-seg.pt')
        blindpath_seg_model = YOLO(blindpath_model_path)
        # Day 20: TensorRT engines need neither .to() nor .fuse().
        if not is_tensorrt_engine(blindpath_model_path):
            blindpath_seg_model.to(DEVICE)
            blindpath_seg_model.fuse()
        logger.info("...盲道专用分割模型加载成功。")

        # Share the YOLOE text-feature embeddings with the blind-path workflow.
        if obstacle_detector_client:
            blindpath_whitelist_embeddings = obstacle_detector_client.whitelist_embeddings
            logger.info("...已为盲道工作流链接障碍物模型特征。")

        # All models loaded successfully.
        models_are_loaded = True
        logger.info("========= ✅ 所有模型已成功预加载。Worker准备就绪! =========")

    except Exception as e:
        logger.error(f"模型预加载过程中发生严重错误: {e}", exc_info=True)
        # Re-raise so the Celery worker fails to start — a worker without
        # models is useless; surfacing the problem early is the right call.
        raise
|
||||||
BIN
music/converted_向上.wav
Normal file
BIN
music/converted_向上.wav
Normal file
Binary file not shown.
BIN
music/converted_向下.wav
Normal file
BIN
music/converted_向下.wav
Normal file
Binary file not shown.
BIN
music/converted_向前.wav
Normal file
BIN
music/converted_向前.wav
Normal file
Binary file not shown.
BIN
music/converted_向右.wav
Normal file
BIN
music/converted_向右.wav
Normal file
Binary file not shown.
BIN
music/converted_向后.wav
Normal file
BIN
music/converted_向后.wav
Normal file
Binary file not shown.
BIN
music/converted_向左.wav
Normal file
BIN
music/converted_向左.wav
Normal file
Binary file not shown.
BIN
music/converted_已对中.wav
Normal file
BIN
music/converted_已对中.wav
Normal file
Binary file not shown.
BIN
music/converted_找到啦.wav
Normal file
BIN
music/converted_找到啦.wav
Normal file
Binary file not shown.
BIN
music/converted_拿到啦.wav
Normal file
BIN
music/converted_拿到啦.wav
Normal file
Binary file not shown.
BIN
music/converted_音频1.WAV
Normal file
BIN
music/converted_音频1.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频2.WAV
Normal file
BIN
music/converted_音频2.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频3.WAV
Normal file
BIN
music/converted_音频3.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频4.WAV
Normal file
BIN
music/converted_音频4.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频5.WAV
Normal file
BIN
music/converted_音频5.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频6.WAV
Normal file
BIN
music/converted_音频6.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频7.WAV
Normal file
BIN
music/converted_音频7.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频8.WAV
Normal file
BIN
music/converted_音频8.WAV
Normal file
Binary file not shown.
BIN
music/converted_音频9.WAV
Normal file
BIN
music/converted_音频9.WAV
Normal file
Binary file not shown.
1
music/向上.txt
Normal file
1
music/向上.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
向上。如果还有其他想法,你可以随时告诉我哦。
|
||||||
BIN
music/向上.wav
Normal file
BIN
music/向上.wav
Normal file
Binary file not shown.
1
music/向下.txt
Normal file
1
music/向下.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
向下。如果还有啥想法,你可以再跟我说哦。
|
||||||
BIN
music/向下.wav
Normal file
BIN
music/向下.wav
Normal file
Binary file not shown.
1
music/向前.txt
Normal file
1
music/向前.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
向前。如果还有啥想法,你可以再跟我说哦。
|
||||||
BIN
music/向前.wav
Normal file
BIN
music/向前.wav
Normal file
Binary file not shown.
1
music/向右.txt
Normal file
1
music/向右.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
向右。如果还有啥想法,你可以再跟我说哦。
|
||||||
BIN
music/向右.wav
Normal file
BIN
music/向右.wav
Normal file
Binary file not shown.
1
music/向后.txt
Normal file
1
music/向后.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
向后。如果还有啥想法,你可以再跟我说哦。
|
||||||
BIN
music/向后.wav
Normal file
BIN
music/向后.wav
Normal file
Binary file not shown.
1
music/向左.txt
Normal file
1
music/向左.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
向左。如果还有其他想法,你可以随时告诉我哦。
|
||||||
BIN
music/向左.wav
Normal file
BIN
music/向左.wav
Normal file
Binary file not shown.
BIN
music/在画面中间.WAV
Normal file
BIN
music/在画面中间.WAV
Normal file
Binary file not shown.
1
music/在画面中间.txt
Normal file
1
music/在画面中间.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
“在画面中间”
|
||||||
BIN
music/在画面中间_24k.wav
Normal file
BIN
music/在画面中间_24k.wav
Normal file
Binary file not shown.
BIN
music/在画面右侧.WAV
Normal file
BIN
music/在画面右侧.WAV
Normal file
Binary file not shown.
1
music/在画面右侧.txt
Normal file
1
music/在画面右侧.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
“在画面右侧”
|
||||||
BIN
music/在画面右侧_24k.wav
Normal file
BIN
music/在画面右侧_24k.wav
Normal file
Binary file not shown.
BIN
music/在画面左侧.WAV
Normal file
BIN
music/在画面左侧.WAV
Normal file
Binary file not shown.
1
music/在画面左侧.txt
Normal file
1
music/在画面左侧.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
“在画面左侧”
|
||||||
BIN
music/在画面左侧_24k.wav
Normal file
BIN
music/在画面左侧_24k.wav
Normal file
Binary file not shown.
1
music/已对中.txt
Normal file
1
music/已对中.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
已对正!如果还有其他想法或者问题,你可以随时告诉我哦。
|
||||||
BIN
music/已对中.wav
Normal file
BIN
music/已对中.wav
Normal file
Binary file not shown.
1
music/找到啦.txt
Normal file
1
music/找到啦.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
找到了!
|
||||||
BIN
music/找到啦.wav
Normal file
BIN
music/找到啦.wav
Normal file
Binary file not shown.
1
music/拿到啦.txt
Normal file
1
music/拿到啦.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
拿到了!如果还有啥问题,你可以再跟我说哦。
|
||||||
BIN
music/拿到啦.wav
Normal file
BIN
music/拿到啦.wav
Normal file
Binary file not shown.
BIN
music/接近斑马线.WAV
Normal file
BIN
music/接近斑马线.WAV
Normal file
Binary file not shown.
1
music/接近斑马线.txt
Normal file
1
music/接近斑马线.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
“接近斑马线”
|
||||||
BIN
music/接近斑马线_24k.wav
Normal file
BIN
music/接近斑马线_24k.wav
Normal file
Binary file not shown.
BIN
music/斑马线到了可以过马路.WAV
Normal file
BIN
music/斑马线到了可以过马路.WAV
Normal file
Binary file not shown.
1
music/斑马线到了可以过马路.txt
Normal file
1
music/斑马线到了可以过马路.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
“斑马线到了可以过马路”。如果还有类似的问题或者其他想法,你可以随时告诉我哦。
|
||||||
BIN
music/斑马线到了可以过马路_24k.wav
Normal file
BIN
music/斑马线到了可以过马路_24k.wav
Normal file
Binary file not shown.
BIN
music/正在靠近斑马线.WAV
Normal file
BIN
music/正在靠近斑马线.WAV
Normal file
Binary file not shown.
1
music/正在靠近斑马线.txt
Normal file
1
music/正在靠近斑马线.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
“正在靠近斑马线”
|
||||||
BIN
music/正在靠近斑马线_24k.wav
Normal file
BIN
music/正在靠近斑马线_24k.wav
Normal file
Binary file not shown.
BIN
music/红灯.WAV
Normal file
BIN
music/红灯.WAV
Normal file
Binary file not shown.
BIN
music/绿灯.WAV
Normal file
BIN
music/绿灯.WAV
Normal file
Binary file not shown.
BIN
music/远处发现斑马线.WAV
Normal file
BIN
music/远处发现斑马线.WAV
Normal file
Binary file not shown.
1
music/远处发现斑马线.txt
Normal file
1
music/远处发现斑马线.txt
Normal file
@@ -0,0 +1 @@
|
|||||||
|
“远处发现斑马线”
|
||||||
BIN
music/远处发现斑马线_24k.wav
Normal file
BIN
music/远处发现斑马线_24k.wav
Normal file
Binary file not shown.
BIN
music/音频1.WAV
Normal file
BIN
music/音频1.WAV
Normal file
Binary file not shown.
BIN
music/音频2.WAV
Normal file
BIN
music/音频2.WAV
Normal file
Binary file not shown.
BIN
music/音频3.WAV
Normal file
BIN
music/音频3.WAV
Normal file
Binary file not shown.
BIN
music/音频4.WAV
Normal file
BIN
music/音频4.WAV
Normal file
Binary file not shown.
BIN
music/音频5.WAV
Normal file
BIN
music/音频5.WAV
Normal file
Binary file not shown.
BIN
music/音频6.WAV
Normal file
BIN
music/音频6.WAV
Normal file
Binary file not shown.
BIN
music/音频7.WAV
Normal file
BIN
music/音频7.WAV
Normal file
Binary file not shown.
BIN
music/音频8.WAV
Normal file
BIN
music/音频8.WAV
Normal file
Binary file not shown.
BIN
music/音频9.WAV
Normal file
BIN
music/音频9.WAV
Normal file
Binary file not shown.
BIN
music/黄灯.WAV
Normal file
BIN
music/黄灯.WAV
Normal file
Binary file not shown.
700
navigation_master.py
Normal file
700
navigation_master.py
Normal file
@@ -0,0 +1,700 @@
|
|||||||
|
# navigation_master.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import time
|
||||||
|
import math
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional, Dict, Any, Deque, List, Tuple
|
||||||
|
from collections import deque
|
||||||
|
|
||||||
|
# 工作流导入(与现有文件解耦)
|
||||||
|
from workflow_blindpath import BlindPathNavigator, ProcessingResult as BlindResult
|
||||||
|
from workflow_crossstreet import CrossStreetNavigator, CrossStreetResult as CrossResult
|
||||||
|
|
||||||
|
# ========== 状态常量 ==========
|
||||||
|
IDLE = "IDLE" # 空闲/未启用
|
||||||
|
CHAT = "CHAT" # 对话模式(不进行导航,只返回原始画面)
|
||||||
|
BLINDPATH_NAV = "BLINDPATH_NAV" # 正在走盲道(复用 BlindPathNavigator)
|
||||||
|
SEEKING_CROSSWALK = "SEEKING_CROSSWALK"# 盲道阶段发现斑马线,正对准/靠近
|
||||||
|
WAIT_TRAFFIC_LIGHT = "WAIT_TRAFFIC_LIGHT" # 到达斑马线后等待交通灯(可选/占位)
|
||||||
|
CROSSING = "CROSSING" # 正在过马路(复用 CrossStreetNavigator)
|
||||||
|
SEEKING_NEXT_BLINDPATH = "SEEKING_NEXT_BLINDPATH" # 过完马路后寻找下一段盲道入口(上盲道)
|
||||||
|
RECOVERY = "RECOVERY" # 兜底/恢复(感知暂时丢失时)
|
||||||
|
TRAFFIC_LIGHT_DETECTION = "TRAFFIC_LIGHT_DETECTION" # 红绿灯检测模式
|
||||||
|
ITEM_SEARCH = "ITEM_SEARCH" # 找物品模式(暂停导航,由yolomedia处理画面)
|
||||||
|
|
||||||
|
# ========== 返回结构 ==========
|
||||||
|
@dataclass
class OrchestratorResult:
    """Per-frame output of the navigation orchestrator."""
    annotated_image: Optional[np.ndarray]  # frame with overlays drawn, or None
    guidance_text: str                     # guidance sentence to speak/display
    state: str                             # current state constant (IDLE, CROSSING, ...)
    extras: Dict[str, Any]                 # workflow-specific extra payload
|
||||||
|
|
||||||
|
# ========== 实用:信号平滑/多数表决 ==========
|
||||||
|
class MajorityFilter:
    """Fixed-size sliding window with majority voting over string labels.

    "unknown" always ranks below any real label, so it only wins the vote
    when it is the sole label observed (or the window is empty).
    """

    def __init__(self, size: int = 8):
        # Bounded buffer: pushing past `size` evicts the oldest entry.
        self.buf: Deque[str] = deque(maxlen=size)

    def push(self, v: str):
        """Record one observation."""
        self.buf.append(v)

    def majority(self) -> str:
        """Return the dominant label in the window ("unknown" when empty)."""
        if not self.buf:
            return "unknown"
        tally = {}
        for label in self.buf:
            tally[label] = tally.get(label, 0) + 1

        # Rank real labels above "unknown", then by frequency; ties resolve
        # to the earliest-seen label (dict preserves insertion order).
        def rank(pair):
            label, count = pair
            return (0 if label == "unknown" else 1, count)

        best_label, _ = max(tally.items(), key=rank)
        return best_label

    def history(self) -> List[str]:
        """Snapshot of the window contents, oldest first."""
        return list(self.buf)

    def clear(self):
        """Drop all buffered observations."""
        self.buf.clear()
|
||||||
|
|
||||||
|
# ========== 红绿灯识别 ==========
|
||||||
|
class TrafficLightDetector:
    """
    Traffic-light recognizer:
    1) prefer a yoloe_backend-style detector when it can be imported;
    2) fallback: without a model, use an HSV color heuristic to find bright
       red/yellow/green "light blobs" in the upper half of the frame.
    Output: ('red'|'green'|'yellow'|'unknown', meta)
    """
    def __init__(self):
        # Whether the optional detection backend could be imported.
        self.has_backend = False
        self.backend = None
        try:
            # Optional dynamic import (adapt to the local yoloe_backend API).
            import yoloe_backend as _yeb  # noqa
            self.backend = _yeb
            self.has_backend = True
        except Exception:
            self.has_backend = False
            self.backend = None

    def _try_backend(self, bgr: np.ndarray) -> Tuple[str, Dict[str, Any]]:
        """
        Try a yoloe_backend-style API. Implementations differ between projects,
        so call it leniently:
        - first try backend.detect(image, target_classes=['traffic light'])
        - then backend.infer_image(image) and filter for 'traffic light'
        - otherwise return unknown
        Result entries are expected to carry a bbox or mask; the ROI color is
        then judged via HSV sampling (_classify_color_hsv).
        """
        if not self.has_backend or self.backend is None:
            return "unknown", {"reason": "backend_not_available"}

        res = None
        try:
            if hasattr(self.backend, "detect"):
                # Assumed shape: [{'name': 'traffic light', 'box':[x1,y1,x2,y2], ...}, ...]
                res = self.backend.detect(bgr, target_classes=["traffic light"])
            elif hasattr(self.backend, "infer_image"):
                # Assumed shape: [{'label': 'traffic light', 'bbox': [x1,y1,x2,y2], ...}, ...]
                res = self.backend.infer_image(bgr)
            else:
                return "unknown", {"reason": "backend_no_suitable_api"}
        except Exception as e:
            return "unknown", {"reason": f"backend_failed:{e}"}

        if not res or len(res) == 0:
            return "unknown", {"reason": "no_detection"}

        # Take the largest box as the primary light and classify its HSV color.
        H, W = bgr.shape[:2]
        best = None
        best_area = 0
        boxes = []
        for item in res:
            # Normalize the bounding-box field name across backend variants.
            if "box" in item and isinstance(item["box"], (list, tuple)) and len(item["box"]) == 4:
                x1, y1, x2, y2 = item["box"]
            elif "bbox" in item and isinstance(item["bbox"], (list, tuple)) and len(item["bbox"]) == 4:
                x1, y1, x2, y2 = item["bbox"]
            else:
                continue
            # Clamp to the frame (note: clamps to W-1/H-1, so the last row/column is excluded).
            x1 = int(max(0, min(W-1, x1))); x2 = int(max(0, min(W-1, x2)))
            y1 = int(max(0, min(H-1, y1))); y2 = int(max(0, min(H-1, y2)))
            if x2 <= x1 or y2 <= y1:
                continue
            area = (x2 - x1) * (y2 - y1)
            boxes.append((x1, y1, x2, y2, area))
            if area > best_area:
                best_area = area
                best = (x1, y1, x2, y2)

        if best is None:
            return "unknown", {"reason": "no_valid_bbox", "raw": len(res)}

        x1, y1, x2, y2 = best
        roi = bgr[y1:y2, x1:x2]
        color = self._classify_color_hsv(roi)
        return color, {"bbox": best, "count": len(res), "boxes": boxes}

    def _classify_color_hsv(self, roi_bgr: np.ndarray) -> str:
        """Simple HSV-threshold red/yellow/green classification of the ROI; the color covering the largest area wins."""
        if roi_bgr is None or roi_bgr.size == 0:
            return "unknown"
        hsv = cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV)

        # Red (hue wraps around 180 in OpenCV, so two bands)
        lower_red1 = np.array([0, 80, 120]); upper_red1 = np.array([10, 255, 255])
        lower_red2 = np.array([160, 80, 120]); upper_red2 = np.array([180, 255, 255])
        mask_r1 = cv2.inRange(hsv, lower_red1, upper_red1)
        mask_r2 = cv2.inRange(hsv, lower_red2, upper_red2)
        mask_red = cv2.bitwise_or(mask_r1, mask_r2)

        # Green
        lower_green = np.array([40, 60, 120]); upper_green = np.array([90, 255, 255])
        mask_green = cv2.inRange(hsv, lower_green, upper_green)

        # Yellow
        lower_yellow = np.array([18, 80, 150]); upper_yellow = np.array([35, 255, 255])
        mask_yellow = cv2.inRange(hsv, lower_yellow, upper_yellow)

        # Area ratios relative to the ROI size
        total = roi_bgr.shape[0] * roi_bgr.shape[1] + 1e-6
        r_ratio = float(np.count_nonzero(mask_red)) / total
        g_ratio = float(np.count_nonzero(mask_green)) / total
        y_ratio = float(np.count_nonzero(mask_yellow)) / total

        # Suppress weak responses caused by cluttered backgrounds
        thr = 0.03
        candidates = []
        if r_ratio > thr: candidates.append(("red", r_ratio))
        if g_ratio > thr: candidates.append(("green", g_ratio))
        if y_ratio > thr: candidates.append(("yellow", y_ratio))
        if not candidates:
            return "unknown"
        candidates.sort(key=lambda x: x[1], reverse=True)
        return candidates[0][0]

    def detect(self, bgr: np.ndarray) -> Tuple[str, Dict[str, Any]]:
        """
        Main entry: try the backend first; on failure, classify bright color
        blobs in the upper half of the frame (no bounding box required).
        """
        # 1) Backend attempt
        if self.has_backend:
            color, meta = self._try_backend(bgr)
            if color != "unknown":
                return color, {"method": "backend", **meta}

        # 2) Fallback: HSV classification over the upper half of the frame
        H, W = bgr.shape[:2]
        roi = bgr[:int(H * 0.5), :]
        hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV)

        # Brightness mask (suppresses dark areas; used only for the reported ratio)
        v = hsv[:, :, 2]
        bright = (v > 140).astype(np.uint8) * 255

        # Coarse color classification of the whole upper-half ROI
        col = self._classify_color_hsv(roi)
        return col, {"method": "fallback", "note": "no_backend", "bright_ratio": float(np.mean(bright > 0))}
|
||||||
|
|
||||||
|
# ========== 视觉辅助工具 ==========
|
||||||
|
def _color_bgr(name: str) -> Tuple[int, int, int]:
|
||||||
|
if name == "red": return (0, 0, 255)
|
||||||
|
if name == "green": return (0, 255, 0)
|
||||||
|
if name == "yellow": return (0, 255, 255)
|
||||||
|
if name == "blue": return (255, 0, 0)
|
||||||
|
if name == "orange": return (0, 165, 255)
|
||||||
|
if name == "cyan": return (255, 255, 0)
|
||||||
|
if name == "magenta": return (255, 0, 255)
|
||||||
|
if name == "gray": return (128, 128, 128)
|
||||||
|
if name == "white": return (255, 255, 255)
|
||||||
|
return (200, 200, 200)
|
||||||
|
|
||||||
|
def _put_text(img, text, org, color=(255,255,255), scale=0.7, thick=2, outline=True):
    """Draw text on `img`, optionally with a 1-pixel black outline for readability."""
    if outline:
        # Stamp the black halo at the eight surrounding offsets first.
        offsets = [(dx, dy) for dx in (-1, 0, 1) for dy in (-1, 0, 1) if (dx, dy) != (0, 0)]
        for dx, dy in offsets:
            cv2.putText(img, text, (org[0] + dx, org[1] + dy),
                        cv2.FONT_HERSHEY_SIMPLEX, scale, (0, 0, 0), thick + 1)
    cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, scale, color, thick)
|
||||||
|
|
||||||
|
def _draw_badge(img, text, pos=(10, 28), fg="white", bg="blue"):
    """Render `text` on a filled rectangular badge anchored at `pos`."""
    fg_bgr = _color_bgr(fg)
    bg_bgr = _color_bgr(bg)
    (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2)
    x, y = pos
    pad = 6
    # Background box sized to the measured text extent, then the label on top.
    cv2.rectangle(img, (x - 4, y - th - pad), (x + tw + 8, y + pad // 2), bg_bgr, -1)
    _put_text(img, text, (x, y), color=fg_bgr, scale=0.6, thick=2, outline=False)
|
||||||
|
|
||||||
|
def _draw_state_panel(img, kv: Dict[str, Any], pos=(10, 60)):
    """Print key/value pairs as stacked "k: v" text lines starting at `pos`."""
    x0, y = pos
    line_h = 22
    for key, val in kv.items():
        _put_text(img, f"{key}: {val}", (x0, y), color=(255, 255, 255), scale=0.6, thick=2)
        y += line_h
|
||||||
|
|
||||||
|
def _draw_frame_border(img, color=(0,255,0), thickness=3):
    """Outline the full image with a rectangle of the given color/thickness."""
    height, width = img.shape[:2]
    cv2.rectangle(img, (0, 0), (width - 1, height - 1), color, thickness)
|
||||||
|
|
||||||
|
def _draw_progress_bar(img, ratio: float, pos=(10, 90), size=(180, 10), color="cyan"):
    """Draw a horizontal progress bar filled to `ratio` (clamped into [0, 1])."""
    filled = min(1.0, max(0.0, float(ratio)))
    x, y = pos
    w, h = size
    # Outline first, then the filled inner portion.
    cv2.rectangle(img, (x, y), (x + w, y + h), (80, 80, 80), 1)
    inner_w = int((w - 2) * filled)
    cv2.rectangle(img, (x + 1, y + 1), (x + 1 + inner_w, y + h - 1), _color_bgr(color), -1)
|
||||||
|
|
||||||
|
# ========== 统领器 ==========
|
||||||
|
class NavigationMaster:
|
||||||
|
    def __init__(self,
                 blind_nav: BlindPathNavigator,
                 cross_nav: CrossStreetNavigator,
                 *,
                 min_tts_interval: float = 1.2):
        """Orchestrates blind-path navigation, crossing and traffic-light states.

        Args:
            blind_nav: blind-path navigator used for sidewalk guidance.
            cross_nav: street-crossing navigator used while in CROSSING.
            min_tts_interval: minimum seconds between spoken guidance messages.
        """
        self.blind = blind_nav
        self.cross = cross_nav
        self.state = IDLE
        self.last_guidance_ts = 0.0
        self.min_tts_interval = min_tts_interval

        # Debounce / stability counters
        self.cnt_crosswalk_seen = 0   # crosswalk seen from the blind-path side (approaching/ready)
        self.cnt_align_ready = 0      # crosswalk "ready" + alignment criteria met
        self.cnt_cross_end = 0        # accumulated end-of-crossing evidence
        self.cnt_lost = 0             # accumulated perception loss (triggers RECOVERY)

        # Cooldown window to avoid state flapping
        self.cooldown_until = 0.0

        # State to return to after emergency recovery
        self.prev_target_state = BLINDPATH_NAV

        # Traffic light: detector plus temporal smoothing of its votes
        self.tld = TrafficLightDetector()
        self.tl_major = MajorityFilter(size=8)
        self.tl_last_color = "unknown"

        # Tunable thresholds (adjust in the field)
        self.FRAMES_CROSS_SEEN = 8
        self.FRAMES_ALIGN_READY = 12
        self.FRAMES_CROSS_END = 12
        self.FRAMES_NEXT_BLIND_OK = 8
        self.FRAMES_LOST_MAX = 45

        self.ANGLE_ALIGN_THR_DEG = 12.0
        self.OFFSET_ALIGN_THR = 0.15

        self.COOLDOWN_SEC = 0.6

        # Item-search bookkeeping
        self.prev_nav_state_before_search = None  # navigation state saved before an item search, restored afterwards
|
||||||
|
|
||||||
|
# ----- 外部交互 -----
|
||||||
|
def get_state(self) -> str:
|
||||||
|
return self.state
|
||||||
|
|
||||||
|
def start_blind_path_navigation(self):
|
||||||
|
"""启动盲道导航模式"""
|
||||||
|
self.state = BLINDPATH_NAV
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
if self.blind:
|
||||||
|
self.blind.reset()
|
||||||
|
|
||||||
|
def stop_navigation(self):
|
||||||
|
"""停止导航,回到对话模式"""
|
||||||
|
self.state = CHAT
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
if self.blind:
|
||||||
|
self.blind.reset()
|
||||||
|
|
||||||
|
def start_crossing(self):
|
||||||
|
"""启动过马路模式"""
|
||||||
|
self.state = CROSSING
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
if self.cross:
|
||||||
|
self.cross.reset()
|
||||||
|
|
||||||
|
def start_traffic_light_detection(self):
|
||||||
|
"""启动红绿灯检测模式"""
|
||||||
|
self.state = TRAFFIC_LIGHT_DETECTION
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
|
||||||
|
def is_in_navigation_mode(self):
|
||||||
|
"""检查是否在导航模式(非对话模式)"""
|
||||||
|
return self.state not in ["CHAT", "IDLE", "TRAFFIC_LIGHT_DETECTION", "ITEM_SEARCH"]
|
||||||
|
|
||||||
|
def start_item_search(self):
|
||||||
|
"""启动找物品模式,暂停当前导航"""
|
||||||
|
# 保存当前导航状态(如果在导航中)
|
||||||
|
if self.state in [BLINDPATH_NAV, SEEKING_CROSSWALK, WAIT_TRAFFIC_LIGHT, CROSSING, SEEKING_NEXT_BLINDPATH]:
|
||||||
|
self.prev_nav_state_before_search = self.state
|
||||||
|
print(f"[NAV MASTER] 暂停导航状态 {self.state},切换到找物品模式")
|
||||||
|
else:
|
||||||
|
self.prev_nav_state_before_search = None
|
||||||
|
|
||||||
|
self.state = ITEM_SEARCH
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
|
||||||
|
def stop_item_search(self, restore_nav: bool = True):
|
||||||
|
"""停止找物品模式"""
|
||||||
|
# 如果需要恢复之前的导航状态
|
||||||
|
if restore_nav and self.prev_nav_state_before_search:
|
||||||
|
self.state = self.prev_nav_state_before_search
|
||||||
|
print(f"[NAV MASTER] 找物品结束,恢复到导航状态 {self.state}")
|
||||||
|
self.prev_nav_state_before_search = None
|
||||||
|
else:
|
||||||
|
# 否则回到对话模式
|
||||||
|
self.state = CHAT
|
||||||
|
print(f"[NAV MASTER] 找物品结束,回到对话模式")
|
||||||
|
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
|
||||||
|
def force_state(self, s: str):
|
||||||
|
self.state = s
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
|
||||||
|
def on_voice_command(self, text: str):
|
||||||
|
t = (text or "").strip()
|
||||||
|
if "开始过马路" in t:
|
||||||
|
# 直接进入等待/或立即过马路(低速环境可直过)
|
||||||
|
if self.state in (BLINDPATH_NAV, SEEKING_CROSSWALK, WAIT_TRAFFIC_LIGHT, IDLE, RECOVERY, SEEKING_NEXT_BLINDPATH):
|
||||||
|
self.state = WAIT_TRAFFIC_LIGHT
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
elif "立即通过" in t or "现在通过" in t:
|
||||||
|
self.state = CROSSING
|
||||||
|
self.cooldown_until = time.time() + self.COOLDOWN_SEC
|
||||||
|
elif "停止" in t or "结束" in t:
|
||||||
|
self.state = IDLE
|
||||||
|
elif "继续" in t:
|
||||||
|
if self.state == IDLE:
|
||||||
|
self.state = BLINDPATH_NAV
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
self.state = IDLE
|
||||||
|
self.cnt_crosswalk_seen = 0
|
||||||
|
self.cnt_align_ready = 0
|
||||||
|
self.cnt_cross_end = 0
|
||||||
|
self.cnt_lost = 0
|
||||||
|
self.tl_major.clear()
|
||||||
|
self.tl_last_color = "unknown"
|
||||||
|
self.prev_target_state = BLINDPATH_NAV
|
||||||
|
self._last_wait_light_announce = 0 # 重置等待绿灯播报时间
|
||||||
|
try:
|
||||||
|
self.blind.reset()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
try:
|
||||||
|
self.cross.reset()
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# ----- 内部工具 -----
|
||||||
|
def _say(self, now: float, text: str) -> str:
|
||||||
|
if not text:
|
||||||
|
return ""
|
||||||
|
if now - self.last_guidance_ts >= self.min_tts_interval:
|
||||||
|
self.last_guidance_ts = now
|
||||||
|
return text
|
||||||
|
return ""
|
||||||
|
|
||||||
|
def _draw_tl_status(self, img: np.ndarray, color: str, meta: Dict[str, Any]):
|
||||||
|
if img is None:
|
||||||
|
return
|
||||||
|
color_bgr = _color_bgr(color)
|
||||||
|
# 角标与文本
|
||||||
|
cv2.circle(img, (24, 24), 10, color_bgr, -1)
|
||||||
|
_put_text(img, f"信号灯: {color}", (40, 30), color=color_bgr, scale=0.6, thick=2, outline=False)
|
||||||
|
# 画 bbox(若有)
|
||||||
|
if meta and "bbox" in meta:
|
||||||
|
x1, y1, x2, y2 = meta["bbox"]
|
||||||
|
cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color_bgr, 2)
|
||||||
|
|
||||||
|
# 多数表决历史(最近8帧)
|
||||||
|
hist = self.tl_major.history()
|
||||||
|
if hist:
|
||||||
|
x0, y0 = 10, 50
|
||||||
|
r = 6
|
||||||
|
gap = 16
|
||||||
|
for i, hcol in enumerate(hist[-12:]):
|
||||||
|
cv2.circle(img, (x0 + i*gap, y0), r, _color_bgr(hcol), -1)
|
||||||
|
_put_text(img, "信号历史", (x0, y0+20), color=(255,255,255), scale=0.5, thick=1)
|
||||||
|
|
||||||
|
# ----- 主循环 -----
|
||||||
|
    def process_frame(self, bgr: np.ndarray) -> OrchestratorResult:
        """Run one frame through the state machine and return annotated output.

        Dispatches the frame to the sub-navigator matching the current state,
        applies debounce counters and cooldowns to state transitions, and
        rate-limits spoken guidance via _say(). On-frame debug overlays were
        deliberately removed; annotated frames come from the sub-navigators.
        """
        now = time.time()

        # IDLE defaults to CHAT mode instead of auto-starting navigation.
        if self.state == IDLE:
            self.state = CHAT
            self.cooldown_until = now + self.COOLDOWN_SEC

        # CHAT mode: pass the raw frame through, no navigation.
        if self.state == CHAT:
            return OrchestratorResult(
                annotated_image=bgr,
                guidance_text="",
                state="CHAT",
                extras={"mode": "对话模式"}
            )

        # Traffic-light detection mode: raw frame only; handled by the light module.
        if self.state == TRAFFIC_LIGHT_DETECTION:
            return OrchestratorResult(
                annotated_image=bgr,
                guidance_text="",
                state="TRAFFIC_LIGHT_DETECTION",
                extras={"mode": "红绿灯检测模式"}
            )

        # Item-search mode: raw frame only; handled by yolomedia.
        if self.state == ITEM_SEARCH:
            return OrchestratorResult(
                annotated_image=bgr,
                guidance_text="",
                state="ITEM_SEARCH",
                extras={"mode": "找物品模式", "prev_nav_state": self.prev_nav_state_before_search}
            )

        # During cooldown keep emitting frames but suppress instant re-transitions.
        in_cooldown = now < self.cooldown_until

        # Per-state handling
        if self.state in (BLINDPATH_NAV, SEEKING_CROSSWALK, SEEKING_NEXT_BLINDPATH, RECOVERY):
            # Blind-path side: always run the blind-path navigator.
            try:
                bres: BlindResult = self.blind.process_frame(bgr)
            except Exception as e:
                # Navigator failure -> enter recovery state.
                self.state = RECOVERY
                self.cnt_lost += 5
                ann_err = bgr.copy()
                return OrchestratorResult(ann_err, self._say(now, ""), self.state, {"error": str(e)})

            ann = bres.annotated_image if bres.annotated_image is not None else bgr.copy()
            say = bres.guidance_text or ""

            state_info = bres.state_info or {}
            cross_stage = state_info.get("crosswalk_stage", "not_detected")
            blind_state = state_info.get("state", "UNKNOWN")
            # Optional fields (may be supplied by future workflow versions)
            angle = float(state_info.get("last_angle", 0.0))
            center_x_ratio = float(state_info.get("last_center_x_ratio", 0.5))

            # Blind path -> crosswalk spotted (approaching/ready)
            if self.state == BLINDPATH_NAV:
                if cross_stage in ("approaching", "ready"):
                    self.cnt_crosswalk_seen += 1
                else:
                    self.cnt_crosswalk_seen = max(0, self.cnt_crosswalk_seen - 1)

                if self.cnt_crosswalk_seen >= self.FRAMES_CROSS_SEEN and not in_cooldown:
                    self.state = SEEKING_CROSSWALK
                    self.cooldown_until = now + self.COOLDOWN_SEC
                    say = "正在接近斑马线,为您对准方向。"

            # Alignment stage: uses angle/offset from the crosswalk tracker (if provided)
            elif self.state == SEEKING_CROSSWALK:
                aligned = (abs(angle) <= self.ANGLE_ALIGN_THR_DEG and abs(center_x_ratio - 0.5) <= self.OFFSET_ALIGN_THR)
                if cross_stage == "ready" and aligned:
                    self.cnt_align_ready += 1
                else:
                    self.cnt_align_ready = max(0, self.cnt_align_ready - 1)

                if self.cnt_align_ready >= self.FRAMES_ALIGN_READY and not in_cooldown:
                    self.state = WAIT_TRAFFIC_LIGHT
                    self.cooldown_until = now + self.COOLDOWN_SEC
                    say = "已到达斑马线,请等待红绿灯。"

            # After crossing: find the next blind-path segment (get back on the path)
            elif self.state == SEEKING_NEXT_BLINDPATH:
                if blind_state == "NAVIGATING":
                    self.cnt_cross_end += 1
                else:
                    self.cnt_cross_end = max(0, self.cnt_cross_end - 1)
                if self.cnt_cross_end >= self.FRAMES_NEXT_BLIND_OK and not in_cooldown:
                    self.state = BLINDPATH_NAV
                    self.cooldown_until = now + self.COOLDOWN_SEC
                    say = "方向正确,请继续前进。"

            # Recovery: return to blind-path navigation once perception is back
            elif self.state == RECOVERY:
                if blind_state in ("ONBOARDING", "NAVIGATING"):
                    self.state = BLINDPATH_NAV
                    self.cooldown_until = now + self.COOLDOWN_SEC
                    say = ""
                else:
                    say = ""

            # Loss counting (safety net) for all blind-side states
            # NOTE(review): original indentation was ambiguous; placed at this level
            # because the `self.state != RECOVERY` guard below is otherwise dead code.
            if blind_state == "UNKNOWN" and cross_stage == "not_detected":
                self.cnt_lost += 1
            else:
                self.cnt_lost = max(0, self.cnt_lost - 2)
            if self.cnt_lost >= self.FRAMES_LOST_MAX and self.state != RECOVERY:
                self.prev_target_state = self.state
                self.state = RECOVERY
                self.cooldown_until = now + self.COOLDOWN_SEC
                say = "环境复杂,进入恢复模式。"

            return OrchestratorResult(ann, self._say(now, say), self.state, {"source": "blind", "cross_stage": cross_stage, "blind_state": blind_state})

        if self.state == WAIT_TRAFFIC_LIGHT:
            ann = bgr.copy()
            # Traffic-light recognition (majority vote + cooldown)
            color, meta = self.tld.detect(bgr)
            self.tl_major.push(color)
            major = self.tl_major.majority()
            self.tl_last_color = major

            say = ""
            if major == "green" and not in_cooldown:
                self.state = CROSSING
                self.cooldown_until = now + self.COOLDOWN_SEC
                say = "绿灯稳定,开始通行。"
            else:
                # Announce only on state entry or periodically.
                if not hasattr(self, '_last_wait_light_announce'):
                    self._last_wait_light_announce = 0
                if now - self._last_wait_light_announce > 5.0:  # announce every 5 s
                    say = "正在等待绿灯…"
                    self._last_wait_light_announce = now

            return OrchestratorResult(ann, self._say(now, say), self.state, {"traffic_light": major})

        if self.state == CROSSING:
            try:
                cres: CrossResult = self.cross.process_frame(bgr)
            except Exception as e:
                # Failure -> recovery
                self.state = RECOVERY
                ann_err = bgr.copy()
                return OrchestratorResult(ann_err, self._say(now, ""), self.state, {"error": str(e)})

            ann = cres.annotated_image if cres.annotated_image is not None else bgr.copy()
            say = cres.guidance_text or ""

            # Check whether a blind path was detected on the far side.
            blind_path_detected = getattr(cres, 'blind_path_detected', False)
            blind_path_guidance = getattr(cres, 'blind_path_guidance', "")

            # If a blind path was detected with guidance, prioritize that guidance.
            if blind_path_detected and blind_path_guidance:
                # Very close to the blind path -> switch straight to blind-path navigation.
                if hasattr(cres, "should_switch_to_blindpath") and cres.should_switch_to_blindpath:
                    if not in_cooldown:
                        self.state = BLINDPATH_NAV
                        self.cooldown_until = now + self.COOLDOWN_SEC
                        say = "已到盲道跟前,切换到盲道导航。"  # matches an existing voice clip
                        self.cnt_cross_end = 0  # reset the end-of-crossing counter
                        # Reset the blind-path navigator state
                        if hasattr(self.blind, 'reset'):
                            self.blind.reset()
                else:
                    # Blind path still far: keep crossing; the guidance text
                    # already carries the blind-path hint (from cres.guidance_text).
                    pass

            # Original end condition: several consecutive "looking for crosswalk" frames
            end_hint = False
            if "寻找斑马线" in (say or ""):
                end_hint = True
            # NOTE: should_switch_to_blindpath alone no longer ends the crossing.

            self.cnt_cross_end = self.cnt_cross_end + 1 if end_hint else max(0, self.cnt_cross_end - 1)

            if self.cnt_cross_end >= self.FRAMES_CROSS_END and not in_cooldown:
                self.state = SEEKING_NEXT_BLINDPATH
                self.cooldown_until = now + self.COOLDOWN_SEC
                say = "过马路结束,准备上人行道。"

            return OrchestratorResult(ann, self._say(now, say), self.state, {"source": "cross", "end_cnt": self.cnt_cross_end})

        # Fallback: unknown state, return the frame untouched.
        ann = bgr.copy()
        return OrchestratorResult(ann, "", self.state, {})
|
||||||
|
|
||||||
|
|
||||||
185
numba_utils.py
Normal file
185
numba_utils.py
Normal file
@@ -0,0 +1,185 @@
|
|||||||
|
# numba_utils.py - Day 20 Numba 多核加速工具
|
||||||
|
"""
|
||||||
|
使用 Numba JIT 编译加速 numpy 密集操作,绕过 Python GIL 实现真正多核并行
|
||||||
|
"""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
|
||||||
|
try:
|
||||||
|
from numba import jit, prange
|
||||||
|
NUMBA_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
NUMBA_AVAILABLE = False
|
||||||
|
print("[NUMBA] Numba 未安装,使用 numpy 回退实现")
|
||||||
|
|
||||||
|
|
||||||
|
if NUMBA_AVAILABLE:
    @jit(nopython=True, parallel=True, cache=True)
    def count_mask_pixels_numba(mask: np.ndarray) -> int:
        """Count non-zero pixels of a 2-D mask (rows processed in parallel)."""
        count = 0
        h, w = mask.shape
        for i in prange(h):
            for j in range(w):
                if mask[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def compute_mask_stats_numba(mask: np.ndarray) -> tuple:
        """
        Compute mask statistics in one pass (rows processed in parallel).
        Returns: (area, center_x, center_y, min_x, max_x, min_y, max_y)

        NOTE(review): `count`/`sum_x`/`sum_y` follow the `+=` reduction pattern
        that Numba parallelizes safely, but the conditional min/max updates are
        not an obviously supported prange reduction — confirm against Numba's
        documented reductions, otherwise the bbox may race under parallel=True.
        """
        h, w = mask.shape
        count = 0
        sum_x = 0.0
        sum_y = 0.0
        min_x = w
        max_x = 0
        min_y = h
        max_y = 0

        for i in prange(h):
            for j in range(w):
                if mask[i, j] > 0:
                    count += 1
                    sum_x += j
                    sum_y += i
                    if j < min_x:
                        min_x = j
                    if j > max_x:
                        max_x = j
                    if i < min_y:
                        min_y = i
                    if i > max_y:
                        max_y = i

        if count > 0:
            center_x = sum_x / count
            center_y = sum_y / count
        else:
            center_x = 0.0
            center_y = 0.0

        return (count, center_x, center_y, min_x, max_x, min_y, max_y)

    @jit(nopython=True, parallel=True, cache=True)
    def bitwise_and_count_numba(mask1: np.ndarray, mask2: np.ndarray) -> int:
        """Count pixels that are non-zero in both masks (rows processed in parallel)."""
        h, w = mask1.shape
        count = 0
        for i in prange(h):
            for j in range(w):
                if mask1[i, j] > 0 and mask2[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def resize_mask_nearest_numba(mask: np.ndarray, new_h: int, new_w: int) -> np.ndarray:
        """
        Nearest-neighbor resize of a uint8 mask (output rows in parallel).
        Note: simplified implementation; adequate for most use cases.
        """
        old_h, old_w = mask.shape
        result = np.zeros((new_h, new_w), dtype=np.uint8)

        scale_y = old_h / new_h
        scale_x = old_w / new_w

        for i in prange(new_h):
            for j in range(new_w):
                src_y = int(i * scale_y)
                src_x = int(j * scale_x)
                # Clamp source indices against rounding past the last row/column.
                if src_y >= old_h:
                    src_y = old_h - 1
                if src_x >= old_w:
                    src_x = old_w - 1
                result[i, j] = mask[src_y, src_x]

        return result
|
||||||
|
|
||||||
|
|
||||||
|
# 对外接口:根据 Numba 是否可用选择实现
|
||||||
|
def count_mask_pixels(mask: np.ndarray) -> int:
    """Count pixels with value > 0 in `mask`.

    Uses the Numba kernel when available, otherwise a numpy fallback.

    Args:
        mask: 2-D numeric array.
    Returns:
        Number of positive pixels as a Python int.
    """
    if NUMBA_AVAILABLE:
        return count_mask_pixels_numba(mask)
    # count_nonzero is faster than summing the boolean array with np.sum.
    return int(np.count_nonzero(mask > 0))
|
||||||
|
|
||||||
|
|
||||||
|
def compute_mask_stats(mask: np.ndarray) -> dict:
    """
    Compute statistics of a binary mask.
    Returns: {'area': int, 'center_x': float, 'center_y': float, 'bbox': (x1, y1, x2, y2)}
    """
    if NUMBA_AVAILABLE:
        area, cx, cy, min_x, max_x, min_y, max_y = compute_mask_stats_numba(mask)
        return {
            'area': int(area),
            'center_x': float(cx),
            'center_y': float(cy),
            'bbox': (int(min_x), int(min_y), int(max_x), int(max_y))
        }
    # numpy fallback: gather coordinates of positive pixels once.
    ys, xs = np.where(mask > 0)
    if len(ys) == 0:
        return {'area': 0, 'center_x': 0, 'center_y': 0, 'bbox': (0, 0, 0, 0)}
    bbox = (int(np.min(xs)), int(np.min(ys)), int(np.max(xs)), int(np.max(ys)))
    return {
        'area': len(ys),
        'center_x': float(np.mean(xs)),
        'center_y': float(np.mean(ys)),
        'bbox': bbox
    }
|
||||||
|
|
||||||
|
|
||||||
|
def bitwise_and_count(mask1: np.ndarray, mask2: np.ndarray) -> int:
    """Count pixels that are non-zero in both masks."""
    if NUMBA_AVAILABLE:
        return bitwise_and_count_numba(mask1.astype(np.uint8), mask2.astype(np.uint8))
    overlap = np.bitwise_and(mask1, mask2)
    return int(np.sum(overlap > 0))
|
||||||
|
|
||||||
|
|
||||||
|
# 预热 JIT 编译(首次调用时编译,之后使用缓存)
|
||||||
|
def warmup():
    """Pre-trigger Numba JIT compilation so the first real calls are not delayed."""
    if not NUMBA_AVAILABLE:
        return
    # Tiny fixture with a single hot pixel is enough to force compilation.
    dummy = np.zeros((10, 10), dtype=np.uint8)
    dummy[5, 5] = 255
    count_mask_pixels_numba(dummy)
    compute_mask_stats_numba(dummy)
    bitwise_and_count_numba(dummy, dummy)
    print("[NUMBA] JIT 编译预热完成,已启用多核加速")
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
    # Quick benchmark: numpy vs numba pixel counting
    import time

    # Fixture: 640x480 mask with a 200x200 filled rectangle
    test_mask = np.zeros((480, 640), dtype=np.uint8)
    test_mask[100:300, 200:400] = 255

    # Time the numpy version
    start = time.perf_counter()
    for _ in range(100):
        np.sum(test_mask > 0)
    numpy_time = (time.perf_counter() - start) * 1000

    # Time the numba version
    if NUMBA_AVAILABLE:
        # Warm-up call triggers JIT compilation outside the timed loop
        count_mask_pixels_numba(test_mask)

        start = time.perf_counter()
        for _ in range(100):
            count_mask_pixels_numba(test_mask)
        numba_time = (time.perf_counter() - start) * 1000

        print(f"numpy: {numpy_time:.2f}ms / 100 次")
        print(f"numba: {numba_time:.2f}ms / 100 次")
        print(f"加速比: {numpy_time / numba_time:.1f}x")
|
||||||
244
obstacle_detector_client.py
Normal file
244
obstacle_detector_client.py
Normal file
@@ -0,0 +1,244 @@
|
|||||||
|
# app/cloud/obstacle_detector_client.py (新文件)
|
||||||
|
import logging
|
||||||
|
import os
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import torch
|
||||||
|
from threading import Semaphore
|
||||||
|
from contextlib import contextmanager
|
||||||
|
from ultralytics import YOLOE
|
||||||
|
from typing import List, Dict, Any
|
||||||
|
|
||||||
|
# Day 20: optional Numba multi-core acceleration — degrade gracefully when
# the helper module is not installed.
try:
    from numba_utils import count_mask_pixels, compute_mask_stats, bitwise_and_count, warmup as numba_warmup
    NUMBA_ENABLED = True
except ImportError:
    NUMBA_ENABLED = False

logger = logging.getLogger(__name__)

# --- Device selection & AMP policy (mirrors the blindpath workflow) ---
DEVICE = os.getenv("AIGLASS_DEVICE", "cuda:0")
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    logger.warning(f"AIGLASS_DEVICE={DEVICE} 但未检测到 CUDA,将回退到 CPU")
    DEVICE = "cpu"
IS_CUDA = DEVICE.startswith("cuda")

# Mixed-precision policy: bf16 / fp16 / off; anything else falls back to fp16.
AMP_POLICY = os.getenv("AIGLASS_AMP", "fp16").lower()
if AMP_POLICY not in ("bf16", "fp16", "off"):
    AMP_POLICY = "fp16"
if AMP_POLICY == "bf16":
    AMP_DTYPE = torch.bfloat16
elif AMP_POLICY == "fp16":
    AMP_DTYPE = torch.float16
else:
    AMP_DTYPE = None

# --- GPU concurrency throttle (mirrors the blindpath workflow) ---
# Day 20: default raised from 2 to 4 slots; an RTX 3090 copes with more
# concurrent inference tasks.
GPU_SLOTS = int(os.getenv("AIGLASS_GPU_SLOTS", "4"))
_gpu_slots = Semaphore(GPU_SLOTS)

try:
    # Let cuDNN auto-tune kernels for the (mostly fixed-size) inputs.
    torch.backends.cudnn.benchmark = True
except Exception:
    pass
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
def gpu_infer_slot():
    """Bundle GPU concurrency throttling with inference_mode + AMP autocast."""
    # Acquire a concurrency slot and disable autograd for the whole region.
    with _gpu_slots, torch.inference_mode():
        if IS_CUDA and AMP_POLICY != "off":
            # Modern API: torch.amp.autocast(device_type='cuda', dtype=...)
            with torch.amp.autocast(device_type='cuda', dtype=AMP_DTYPE):
                yield
        else:
            yield
|
||||||
|
|
||||||
|
|
||||||
|
class ObstacleDetectorClient:
    """YOLOE-based obstacle detector.

    In the PyTorch path an open-vocabulary whitelist is fed to the model as
    text prompts; in TensorRT-engine mode (no text encoder available) the
    detector falls back to post-hoc filtering against a COCO class whitelist.
    """

    def __init__(self, model_path: str = 'model/yoloe-11l-seg.pt'):
        """Load the model (TensorRT engine preferred) and precompute prompt
        embeddings when the backend supports it.

        Raises: re-raises any model-loading / embedding failure after logging.
        """
        self.model = None
        self.whitelist_embeddings = None
        # Open-vocabulary prompt classes used in the PyTorch (non-TensorRT) path.
        self.WHITELIST_CLASSES = [
            'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'animal', 'scooter', 'stroller', 'dog',
            'pole', 'post', 'column', 'pillar', 'stanchion', 'bollard', 'utility pole',
            'telegraph pole', 'light pole', 'street pole', 'signpost', 'support post',
            'vertical post', 'bench', 'chair', 'potted plant', 'hydrant', 'cone', 'stone', 'box'
        ]
        # COCO class whitelist — used for post-filtering in TensorRT mode.
        # Subset of the COCO-80 categories that can plausibly be obstacles.
        self.COCO_WHITELIST = {
            'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck',       # traffic
            'dog', 'cat', 'horse', 'cow', 'sheep',                          # animals
            'bench', 'chair', 'potted plant', 'fire hydrant', 'stop sign',  # street furniture
            'parking meter', 'suitcase', 'backpack', 'umbrella', 'handbag', # carried items
            'sports ball', 'skateboard', 'surfboard', 'bottle', 'cup',      # possible obstacles
        }
        try:
            # Day 20: prefer a TensorRT engine when one is available.
            try:
                from model_utils import get_best_model_path, is_tensorrt_engine
                model_path = get_best_model_path(model_path)
            except ImportError:
                # Minimal fallback when model_utils is absent.
                def is_tensorrt_engine(p): return p.endswith('.engine')

            logger.info(f"正在加载 YOLOE 障碍物模型: {model_path}")
            self.model = YOLOE(model_path)

            # Day 20: a TensorRT engine needs neither .to() nor .fuse().
            if is_tensorrt_engine(model_path):
                logger.info("TensorRT 引擎已加载,跳过 .to() 和 .fuse()")
                # TensorRT engines do not support get_text_pe, so the
                # whitelist text-embedding precomputation is skipped.
                self.whitelist_embeddings = None
                logger.info("TensorRT 模式:跳过白名单特征预计算")
            else:
                self.model.to(DEVICE)
                self.model.fuse()
                logger.info(f"YOLOE 障碍物模型加载成功,使用设备: {DEVICE}")

                logger.info("正在为 YOLOE 预计算白名单文本特征...")
                if IS_CUDA and AMP_DTYPE is not None:
                    with torch.inference_mode(), torch.amp.autocast(device_type='cuda', dtype=AMP_DTYPE):
                        self.whitelist_embeddings = self.model.get_text_pe(self.WHITELIST_CLASSES)
                else:
                    self.whitelist_embeddings = self.model.get_text_pe(self.WHITELIST_CLASSES)
                logger.info("YOLOE 特征预计算完成。")
        except Exception as e:
            logger.error(f"YOLOE 模型加载或特征计算失败: {e}", exc_info=True)
            raise

    @staticmethod
    def tensor_to_numpy_mask(mask_tensor):
        """Safely convert a mask tensor of various dtypes to a uint8 numpy mask.

        FIX: this helper was declared inside the class without ``self`` or
        ``@staticmethod``, so any call through an instance would have raised
        a TypeError. It is now a proper static method.
        """
        # numpy cannot represent bfloat16 (and half loses nothing here),
        # so widen to float32 before leaving torch.
        if mask_tensor.dtype in (torch.bfloat16, torch.float16):
            mask_tensor = mask_tensor.float()

        mask = mask_tensor.cpu().numpy()

        # Probability masks (values in [0, 1]) are thresholded at 0.5;
        # anything else is assumed to already be a binary mask.
        if mask.max() <= 1.0:
            mask = (mask > 0.5).astype(np.uint8) * 255
        else:
            mask = mask.astype(np.uint8)

        return mask

    def detect(self, image: np.ndarray, path_mask: np.ndarray = None) -> List[Dict[str, Any]]:
        """
        Find obstacles using the whitelist classes as prompts.

        If ``path_mask`` is provided, only obstacles spatially overlapping the
        walkable path are kept; if it is None, detection is global.

        Returns a list of dicts with keys: name, mask, area, area_ratio,
        center_x, center_y, bottom_y_ratio.
        """
        if self.model is None:
            return []

        H, W = image.shape[:2]

        # TensorRT mode has no text embeddings, so skip set_classes; the
        # model then detects using its default COCO classes instead.
        if self.whitelist_embeddings is not None:
            try:
                self.model.set_classes(self.WHITELIST_CLASSES, self.whitelist_embeddings)
            except Exception as e:
                logger.error(f"设置 YOLOE 提示词失败: {e}")
                return []

        conf_thr = float(os.getenv("AIGLASS_OBS_CONF", "0.25"))
        # Day 22: smaller input size and optional FP16 for speed.
        imgsz = int(os.getenv("AIGLASS_OBS_IMGSZ", "480"))  # down from the default 640
        use_half = os.getenv("AIGLASS_OBS_HALF", "1") == "1"

        with gpu_infer_slot():
            results = self.model.predict(
                image,
                verbose=False,
                conf=conf_thr,
                imgsz=imgsz,      # reduced input resolution
                half=use_half     # FP16 half-precision inference
            )

        if not (results and results[0].masks):
            return []

        # --- Filtering & post-processing (kept in sync with the blindpath workflow) ---
        final_obstacles = []
        num_boxes = len(results[0].boxes.cls) if getattr(results[0].boxes, "cls", None) is not None else 0

        for i, mask_tensor in enumerate(results[0].masks.data):
            # Masks without a matching box carry no class id — skip them.
            if i >= num_boxes: continue

            # Convert to a binary uint8 mask (handles bf16/fp16 tensors and
            # probability masks) and scale it up to the frame size.
            mask = self.tensor_to_numpy_mask(mask_tensor)
            mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)

            # Day 20: Numba multi-core mask statistics when available.
            if NUMBA_ENABLED:
                stats = compute_mask_stats(mask)
                area = stats['area']
                center_x = stats['center_x']
                center_y = stats['center_y']
                min_y, max_y = stats['bbox'][1], stats['bbox'][3]
            else:
                area = int(np.sum(mask > 0))
                y_coords, x_coords = np.where(mask > 0)
                if len(y_coords) == 0:
                    continue
                center_x = float(np.mean(x_coords))
                center_y = float(np.mean(y_coords))
                min_y, max_y = int(np.min(y_coords)), int(np.max(y_coords))

            # Size filter: very large blobs (e.g. the whole ground plane)
            # are usually misdetections; empty masks carry no information.
            if (area / (H * W)) > 0.7: continue
            if area == 0: continue

            # Spatial filter: with a path mask, keep only obstacles on the path.
            if path_mask is not None:
                # Day 20: Numba-accelerated intersection counting.
                if NUMBA_ENABLED:
                    intersection_area = bitwise_and_count(mask, path_mask)
                else:
                    intersection_area = int(np.sum(cv2.bitwise_and(mask, path_mask) > 0))
                # Require a minimal absolute and relative overlap with the path.
                if intersection_area < 100 or (intersection_area / area) < 0.01:
                    continue

            cls_id = int(results[0].boxes.cls[i])
            class_names_map = results[0].names
            class_name = "Unknown"
            if isinstance(class_names_map, dict):
                # Dict mapping: use .get() for a safe lookup.
                class_name = class_names_map.get(cls_id, "Unknown")
            elif isinstance(class_names_map, list) and 0 <= cls_id < len(class_names_map):
                # List mapping: index only after a bounds check.
                class_name = class_names_map[cls_id]

            # TensorRT mode: filter detections against the COCO whitelist,
            # keeping only categories that can plausibly be obstacles.
            if self.whitelist_embeddings is None:  # TensorRT mode
                if class_name.lower().strip() not in self.COCO_WHITELIST:
                    continue  # skip non-whitelisted classes

            final_obstacles.append({
                'name': class_name.strip(),
                'mask': mask,
                'area': area,
                'area_ratio': area / (H * W),
                'center_x': center_x,
                'center_y': center_y,
                'bottom_y_ratio': max_y / H
            })

        return final_obstacles
|
||||||
117
omni_client.py
Normal file
117
omni_client.py
Normal file
@@ -0,0 +1,117 @@
|
|||||||
|
# omni_client.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
import os, base64, asyncio, threading
|
||||||
|
from typing import AsyncGenerator, Dict, Any, List, Optional, Tuple
|
||||||
|
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# ===== OpenAI-compatible endpoint (DashScope compatible mode) =====
# SECURITY FIX: the API key previously had a hardcoded fallback baked into
# source control, which both leaked a credential and made the "missing key"
# check below unreachable. The key must now come from the environment.
API_KEY = os.getenv("DASHSCOPE_API_KEY")
if not API_KEY:
    raise RuntimeError("未设置 DASHSCOPE_API_KEY")

# Multimodal (text + audio) streaming model used by stream_chat.
QWEN_MODEL = "qwen-omni-turbo"

# OpenAI SDK client pointed at DashScope's compatible-mode endpoint.
oai_client = OpenAI(
    api_key=API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)
|
||||||
|
|
||||||
|
class OmniStreamPiece:
    """One streaming increment handed to callers: text, audio, or both."""

    def __init__(self, text_delta: Optional[str] = None, audio_b64: Optional[str] = None):
        # text_delta: incremental text; audio_b64: base64-encoded audio chunk.
        self.text_delta, self.audio_b64 = text_delta, audio_b64
|
||||||
|
|
||||||
|
async def stream_chat(
    content_list: List[Dict[str, Any]],
    voice: str = "Cherry",
    audio_format: str = "wav",
) -> AsyncGenerator[OmniStreamPiece, None]:
    """
    Run one streaming Omni-Turbo ChatCompletions round:
    - content_list: multimodal OpenAI chat content (image_url / text items)
    - the request is made with stream=True
    - yields increments as OmniStreamPiece(text_delta=?, audio_b64=?)

    Day 13 fix: a worker thread plus a queue decouples the *synchronous*
    SDK call from the asyncio event loop so the loop is never blocked.
    """
    # asyncio.Queue hands data from the worker thread to the async consumer.
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def _sync_stream():
        """Run the synchronous API call in a dedicated thread."""
        try:
            # Day 21: system prompt keeps answers short — the guide-glasses
            # scenario needs fast, terse replies.
            system_prompt = """你是一个视障辅助AI助手,安装在智能导盲眼镜上。
请用极简短的语言回答,每次回答不超过2-3句话。
避免冗长解释,只提供最关键的信息。
语气友好但简洁。"""

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content_list}
            ]

            completion = oai_client.chat.completions.create(
                model=QWEN_MODEL,
                messages=messages,
                modalities=["text", "audio"],
                audio={"voice": voice, "format": audio_format},
                stream=True,
                stream_options={"include_usage": True},
            )

            for chunk in completion:
                text_delta: Optional[str] = None
                audio_b64: Optional[str] = None

                if getattr(chunk, "choices", None):
                    c0 = chunk.choices[0]
                    delta = getattr(c0, "delta", None)
                    # Incremental text.
                    if delta and getattr(delta, "content", None):
                        piece = delta.content
                        if piece:
                            text_delta = piece
                    # Incremental audio chunk — may be a dict or an object.
                    if delta and getattr(delta, "audio", None):
                        aud = delta.audio
                        audio_b64 = aud.get("data") if isinstance(aud, dict) else getattr(aud, "data", None)
                    if audio_b64 is None:
                        # Fallback: audio delivered on the message object
                        # rather than the delta.
                        msg = getattr(c0, "message", None)
                        if msg and getattr(msg, "audio", None):
                            ma = msg.audio
                            audio_b64 = ma.get("data") if isinstance(ma, dict) else getattr(ma, "data", None)

                if (text_delta is not None) or (audio_b64 is not None):
                    # Thread-safe hand-off into the event loop's queue.
                    loop.call_soon_threadsafe(
                        queue.put_nowait,
                        OmniStreamPiece(text_delta=text_delta, audio_b64=audio_b64)
                    )
        except Exception as e:
            # Propagate the exception object itself to the consumer.
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            # Sentinel marking end-of-stream.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    # Start the synchronous API call on its own (daemon) thread.
    thread = threading.Thread(target=_sync_stream, daemon=True)
    thread.start()

    # Consume the queue asynchronously.
    while True:
        item = await queue.get()
        if item is None:
            # Stream finished.
            break
        if isinstance(item, Exception):
            # Worker raised — re-raise in the consumer's context.
            raise item
        yield item
|
||||||
|
|
||||||
64
qwen_extractor.py
Normal file
64
qwen_extractor.py
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
# qwen_extractor.py
|
||||||
|
# -*- coding: utf-8 -*-
|
||||||
|
from typing import List, Tuple
|
||||||
|
import os
|
||||||
|
from openai import OpenAI
|
||||||
|
|
||||||
|
# —— Local-first mapping (extend or rename freely) ——
# Known Chinese queries mapped directly to English detector labels, avoiding
# an LLM round-trip for common cases. Keys are matched against the lowercased
# query in extract_english_label.
LOCAL_CN2EN = {
    "红牛": "Red_Bull",
    "ad钙奶": "AD_milk",
    "ad 钙奶": "AD_milk",
    "ad": "AD_milk",
    "钙奶": "AD_milk",
    "矿泉水": "bottle",
    "水瓶": "bottle",
    "可乐": "coke",
    "雪碧": "sprite",
}
|
||||||
|
|
||||||
|
def _make_client() -> OpenAI:
    """Build an OpenAI client against the DashScope compatible-mode endpoint.

    SECURITY FIX: the original comment claimed env-var support, but the API
    key was a hardcoded credential leaked in source control. Both base URL
    and key now come from the environment. A missing key raises, which the
    caller (extract_english_label) already handles via its fallback path.
    """
    base_url = os.getenv("DASHSCOPE_COMPAT_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1")
    api_key = os.getenv("DASHSCOPE_API_KEY")
    if not api_key:
        raise RuntimeError("未设置 DASHSCOPE_API_KEY")
    return OpenAI(api_key=api_key, base_url=base_url)
|
||||||
|
|
||||||
|
# System prompt instructing the model to normalise a Chinese object
# description into a short English vision/detector class label.
PROMPT_SYS = (
    "You are a label normalizer. Convert the given Chinese object "
    "description into a short, lowercase English YOLO/vision class name "
    "(1~3 words). If multiple are given, return the single most likely one. "
    "Output ONLY the label, no punctuation."
)
|
||||||
|
|
||||||
|
def extract_english_label(query_cn: str) -> Tuple[str, str]:
    """
    Return (label_en, source); source ∈ {'local', 'qwen', 'fallback'}.

    Resolution order: exact local-map hit → substring hit in the local map
    → Qwen Turbo normalisation → hardcoded 'bottle' fallback.
    """
    q = (query_cn or "").strip().lower()
    if q in LOCAL_CN2EN:
        return LOCAL_CN2EN[q], "local"

    # Simple rule: a local key embedded anywhere in the query (e.g. after a
    # modifier word) still counts as a local hit.
    for k, v in LOCAL_CN2EN.items():
        if k in q:
            return v, "local"

    # Ask Qwen Turbo (Chat Completions compatible endpoint).
    try:
        client = _make_client()
        msgs = [
            {"role": "system", "content": PROMPT_SYS},
            {"role": "user", "content": query_cn.strip()},
        ]
        rsp = client.chat.completions.create(
            model=os.getenv("QWEN_MODEL", "qwen-turbo"),
            messages=msgs,
            stream=False
        )
        label = (rsp.choices[0].message.content or "").strip()
        # Light clean-up of punctuation/whitespace.
        # NOTE(review): the space replacement below appears to be a no-op;
        # it may originally have replaced a full-width space — confirm.
        label = label.replace(".", "").replace(",", "").replace(" ", " ").strip()
        # Fall back to 'bottle' when the model returned an empty label.
        return (label or "bottle"), "qwen"
    except Exception:
        # Any API failure degrades to the generic fallback label.
        return "bottle", "fallback"
|
||||||
28
qwenturbo_template.py
Normal file
28
qwenturbo_template.py
Normal file
@@ -0,0 +1,28 @@
|
|||||||
|
from openai import OpenAI
import os


def main():
    """Stream a Qwen chat completion, printing the reasoning then the answer.

    SECURITY FIX: the API key was hardcoded in source; it is now read from
    the DASHSCOPE_API_KEY environment variable. The request logic is wrapped
    in main() behind an import guard so importing this template no longer
    fires a network call.
    """
    # 如果没有配置环境变量,请先设置 DASHSCOPE_API_KEY(阿里云百炼API Key)
    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    messages = [{"role": "user", "content": "你是谁"}]
    completion = client.chat.completions.create(
        model="qwen-turbo",  # 您可以按需更换为其它深度思考模型
        messages=messages,
        extra_body={"enable_thinking": True},
        stream=True
    )

    is_answering = False  # becomes True once the final answer starts
    print("\n" + "=" * 20 + "思考过程" + "=" * 20)
    for chunk in completion:
        delta = chunk.choices[0].delta
        # Reasoning tokens stream first; print them until the answer begins.
        if hasattr(delta, "reasoning_content") and delta.reasoning_content is not None:
            if not is_answering:
                print(delta.reasoning_content, end="", flush=True)
        # Answer tokens: print the banner once, then stream the content.
        if hasattr(delta, "content") and delta.content:
            if not is_answering:
                print("\n" + "=" * 20 + "完整回复" + "=" * 20)
                is_answering = True
            print(delta.content, end="", flush=True)


if __name__ == "__main__":
    main()
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user