diff --git a/.env.performance b/.env.performance new file mode 100644 index 0000000..57bf4a3 --- /dev/null +++ b/.env.performance @@ -0,0 +1,68 @@ +# ============================================================ +# Day 22 性能优化配置文件 +# 复制此文件为 .env 并根据硬件调整参数 +# ============================================================ + +# === YOLO 盲道/斑马线检测 === +# 输入分辨率 (越小越快,但精度降低) +# 建议: RTX 3090 用 640, GTX 1060 用 320-384 +AIGLASS_YOLO_IMGSZ=480 + +# 检测间隔 (每N帧检测一次) +# 建议: 高端GPU用 6-8, 低端用 12-15 +AIGLASS_BLINDPATH_INTERVAL=10 + +# 启用FP16半精度 (1=启用, 0=禁用) +# 注意: GTX 1060 的FP16性能不佳,建议设为0 +AIGLASS_YOLO_HALF=1 + +# === 障碍物检测 === +# 输入分辨率 +AIGLASS_OBS_IMGSZ=480 + +# 检测间隔 +AIGLASS_OBS_INTERVAL=18 + +# 缓存帧数 +AIGLASS_OBS_CACHE_FRAMES=12 + +# 置信度阈值 +AIGLASS_OBS_CONF=0.25 + +# 启用FP16 +AIGLASS_OBS_HALF=1 + +# === GPU 并发控制 === +# 同时推理的最大任务数 +AIGLASS_GPU_SLOTS=2 + +# 设备选择 (cuda:0, cuda:1, cpu) +AIGLASS_DEVICE=cuda:0 + +# 混合精度模式 (fp16, bf16, off) +# RTX 30系列支持 bf16, 其他用 fp16 +AIGLASS_AMP=fp16 + +# === 语音播报 === +# 直行提示间隔(秒) +AIGLASS_STRAIGHT_INTERVAL=4.0 + +# 方向指令间隔(秒) +AIGLASS_DIRECTION_INTERVAL=3.0 + +# 持续播报模式 (1=启用) +AIGLASS_STRAIGHT_CONTINUOUS=1 + +# 限制模式下最大重复次数 +AIGLASS_STRAIGHT_LIMIT=2 + +# === 模型路径 (根据实际路径修改) === +# BLIND_PATH_MODEL=models/yolo-seg.pt +# OBSTACLE_MODEL=models/yoloe-11l-seg.pt + +# === 调试选项 === +# 启用ASR原始日志 +# ASR_DEBUG_RAW=1 + +# 启用红绿灯调试图像 +# AIGLASS_DEBUG_TRAFFIC_LIGHT=1 diff --git a/.gitignore b/.gitignore index 36b13f1..cd0bb7f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,176 +1,176 @@ -# ---> Python -# Byte-compiled / optimized / DLL files -__pycache__/ -*.py[cod] -*$py.class - -# C extensions -*.so - -# Distribution / packaging -.Python -build/ -develop-eggs/ -dist/ -downloads/ -eggs/ -.eggs/ -lib/ -lib64/ -parts/ -sdist/ -var/ -wheels/ -share/python-wheels/ -*.egg-info/ -.installed.cfg -*.egg -MANIFEST - -# PyInstaller -# Usually these files are written by a python script from a template -# before PyInstaller builds the exe, so as to inject date/other infos into it. 
-*.manifest -*.spec - -# Installer logs -pip-log.txt -pip-delete-this-directory.txt - -# Unit test / coverage reports -htmlcov/ -.tox/ -.nox/ -.coverage -.coverage.* -.cache -nosetests.xml -coverage.xml -*.cover -*.py,cover -.hypothesis/ -.pytest_cache/ -cover/ - -# Translations -*.mo -*.pot - -# Django stuff: -*.log -local_settings.py -db.sqlite3 -db.sqlite3-journal - -# Flask stuff: -instance/ -.webassets-cache - -# Scrapy stuff: -.scrapy - -# Sphinx documentation -docs/_build/ - -# PyBuilder -.pybuilder/ -target/ - -# Jupyter Notebook -.ipynb_checkpoints - -# IPython -profile_default/ -ipython_config.py - -# pyenv -# For a library or package, you might want to ignore these files since the code is -# intended to run in multiple environments; otherwise, check them in: -# .python-version - -# pipenv -# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. -# However, in case of collaboration, if having platform-specific dependencies or dependencies -# having no cross-platform support, pipenv may install dependencies that don't work, or not -# install all needed dependencies. -#Pipfile.lock - -# UV -# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -#uv.lock - -# poetry -# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. -# This is especially recommended for binary packages to ensure reproducibility, and is more -# commonly ignored for libraries. -# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control -#poetry.lock - -# pdm -# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. -#pdm.lock -# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it -# in version control. 
-# https://pdm.fming.dev/latest/usage/project/#working-with-version-control -.pdm.toml -.pdm-python -.pdm-build/ - -# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm -__pypackages__/ - -# Celery stuff -celerybeat-schedule -celerybeat.pid - -# SageMath parsed files -*.sage.py - -# Environments -.env -.venv -env/ -venv/ -ENV/ -env.bak/ -venv.bak/ - -# Spyder project settings -.spyderproject -.spyproject - -# Rope project settings -.ropeproject - -# mkdocs documentation -/site - -# mypy -.mypy_cache/ -.dmypy.json -dmypy.json - -# Pyre type checker -.pyre/ - -# pytype static type analyzer -.pytype/ - -# Cython debug symbols -cython_debug/ - -# PyCharm -# JetBrains specific template is maintained in a separate JetBrains.gitignore that can -# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore -# and can be added to the global gitignore or merged into this file. For a more nuclear -# option (not recommended) you can uncomment the following to ignore the entire idea folder. -#.idea/ - -# Ruff stuff: -.ruff_cache/ - -# PyPI configuration file -.pypirc - +# ---> Python +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. 
+*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# UV +# Similar to Pipfile.lock, it is generally recommended to include uv.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +#uv.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. 
+# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. 
+#.idea/ + +# Ruff stuff: +.ruff_cache/ + +# PyPI configuration file +.pypirc + diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 0000000..00f5255 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,90 @@ +# 更新日志 + +本文档记录项目的所有重要变更。 + +格式基于 [Keep a Changelog](https://keepachangelog.com/zh-CN/1.0.0/), +版本号遵循 [语义化版本](https://semver.org/lang/zh-CN/)。 + +## [未发布] + +### 新增 +- 首次开源发布 +- 完整的 GitHub 文档(README, CONTRIBUTING, LICENSE 等) +- Docker 支持 +- 环境变量配置模板 + +### 修改 +- 优化了 README 文档结构 +- 改进了代码注释 + +## [1.0.0] - 2025-01-XX + +### 新增 +- 🚶 盲道导航系统 + - 实时盲道检测与分割 + - 智能语音引导 + - 障碍物检测与避障 + - 急转弯检测与提醒 + - 光流稳定算法 + +- 🚦 过马路辅助 + - 斑马线识别与方向检测 + - 红绿灯颜色识别 + - 对齐引导系统 + - 安全提醒 + +- 🔍 物品识别与查找 + - YOLO-E 开放词汇检测 + - MediaPipe 手部引导 + - 实时目标追踪 + - 抓取动作检测 + +- 🎙️ 实时语音交互 + - 阿里云 Paraformer ASR + - Qwen-Omni-Turbo 多模态对话 + - 智能指令解析 + - 上下文感知 + +- 📹 视频与音频处理 + - WebSocket 实时推流 + - 音视频同步录制 + - IMU 数据融合 + - 多路音频混音 + +- 🎨 可视化与交互 + - Web 实时监控界面 + - IMU 3D 可视化 + - 状态面板 + - 中文友好界面 + +### 技术栈 +- FastAPI + WebSocket +- YOLO11 / YOLO-E +- MediaPipe +- PyTorch + CUDA +- OpenCV +- DashScope API + +### 已知问题 +- [ ] 在低端 GPU 上可能出现卡顿 +- [ ] macOS 上缺少 GPU 加速支持 +- [ ] 部分中文字体在 Linux 上显示不正确 + +--- + +## 版本说明 + +### 主版本(Major) +- 不兼容的 API 更改 + +### 次版本(Minor) +- 向后兼容的新功能 + +### 修订版本(Patch) +- 向后兼容的问题修复 + +--- + +[未发布]: https://github.com/yourusername/aiglass/compare/v1.0.0...HEAD +[1.0.0]: https://github.com/yourusername/aiglass/releases/tag/v1.0.0 + diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..1443dd9 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,59 @@ +# AI Glass System - Dockerfile +# 基于 NVIDIA CUDA 的 Python 镜像 + +FROM nvidia/cuda:11.8.0-cudnn8-runtime-ubuntu22.04 + +# 设置环境变量 +ENV DEBIAN_FRONTEND=noninteractive +ENV PYTHONUNBUFFERED=1 +ENV CUDA_HOME=/usr/local/cuda +ENV PATH=${CUDA_HOME}/bin:${PATH} +ENV LD_LIBRARY_PATH=${CUDA_HOME}/lib64:${LD_LIBRARY_PATH} + +# 设置工作目录 +WORKDIR /app + +# 安装系统依赖 +RUN apt-get update && apt-get install -y \ + python3.10 \ + python3-pip \ + 
python3-dev \ + portaudio19-dev \ + libgl1-mesa-glx \ + libglib2.0-0 \ + libsm6 \ + libxext6 \ + libxrender-dev \ + libgomp1 \ + git \ + wget \ + curl \ + && apt-get clean \ + && rm -rf /var/lib/apt/lists/* + +# 升级 pip +RUN python3 -m pip install --upgrade pip + +# 复制 requirements.txt +COPY requirements.txt . + +# 安装 Python 依赖 +RUN pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +RUN pip install --no-cache-dir -r requirements.txt + +# 复制应用代码 +COPY . . + +# 创建必要的目录 +RUN mkdir -p recordings model music voice static templates + +# 暴露端口 +EXPOSE 8081 12345/udp + +# 健康检查 +HEALTHCHECK --interval=30s --timeout=10s --start-period=40s --retries=3 \ + CMD curl -f http://localhost:8081/api/health || exit 1 + +# 启动命令 +CMD ["python3", "app_main.py"] + diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..4b5f221 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2025 AI-FanGe + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/PROJECT_STRUCTURE.md b/PROJECT_STRUCTURE.md new file mode 100644 index 0000000..0f19329 --- /dev/null +++ b/PROJECT_STRUCTURE.md @@ -0,0 +1,402 @@ +# 项目结构说明 + +本文档详细说明项目的目录结构和主要文件的作用。 + +## 📁 目录结构 + +``` +rebuild1002/ +├── 📄 主要应用文件 +│ ├── app_main.py # 主应用入口(FastAPI 服务) +│ ├── navigation_master.py # 导航统领器(状态机) +│ ├── workflow_blindpath.py # 盲道导航工作流 +│ ├── workflow_crossstreet.py # 过马路导航工作流 +│ └── yolomedia.py # 物品查找工作流 +│ +├── 🎙️ 语音处理模块 +│ ├── asr_core.py # 语音识别核心 +│ ├── omni_client.py # Qwen-Omni 客户端 +│ ├── qwen_extractor.py # 标签提取(中文->英文) +│ ├── audio_player.py # 音频播放器 +│ └── audio_stream.py # 音频流管理 +│ +├── 🤖 模型相关 +│ ├── yoloe_backend.py # YOLO-E 后端(开放词汇) +│ ├── trafficlight_detection.py # 红绿灯检测 +│ ├── obstacle_detector_client.py # 障碍物检测客户端 +│ └── models.py # 模型定义 +│ +├── 🎥 视频处理 +│ ├── bridge_io.py # 线程安全的帧缓冲 +│ ├── sync_recorder.py # 音视频同步录制 +│ └── video_recorder.py # 视频录制(旧版) +│ +├── 🌐 Web 前端 +│ ├── templates/ +│ │ └── index.html # 主界面 HTML +│ ├── static/ +│ │ ├── main.js # 主 JS 脚本 +│ │ ├── vision.js # 视觉流处理 +│ │ ├── visualizer.js # 数据可视化 +│ │ ├── vision_renderer.js # 渲染器 +│ │ ├── vision.css # 样式表 +│ │ └── models/ # 3D 模型(IMU 可视化) +│ +├── 🎵 音频资源 +│ ├── music/ # 系统提示音 +│ │ ├── converted_向上.wav +│ │ ├── converted_向下.wav +│ │ └── ... 
+│ └── voice/ # 预录语音 +│ ├── voice_mapping.json +│ └── *.wav +│ +├── 🧠 模型文件 +│ └── model/ +│ ├── yolo-seg.pt # 盲道分割模型 +│ ├── yoloe-11l-seg.pt # YOLO-E 开放词汇模型 +│ ├── shoppingbest5.pt # 物品识别模型 +│ ├── trafficlight.pt # 红绿灯检测模型 +│ └── hand_landmarker.task # MediaPipe 手部模型 +│ +├── 📹 录制文件 +│ └── recordings/ # 自动保存的视频和音频 +│ ├── video_*.avi +│ └── audio_*.wav +│ +├── 🛠️ ESP32 固件 +│ └── compile/ +│ ├── compile.ino # Arduino 主程序 +│ ├── camera_pins.h # 摄像头引脚定义 +│ ├── ICM42688.cpp/h # IMU 驱动 +│ └── ESP32_VIDEO_OPTIMIZATION.md +│ +├── 🧪 测试文件 +│ ├── test_recorder.py # 录制功能测试 +│ ├── test_traffic_light.py # 红绿灯检测测试 +│ ├── test_cross_street_blindpath.py # 导航测试 +│ └── test_crosswalk_awareness.py # 斑马线检测测试 +│ +├── 📚 文档 +│ ├── README.md # 项目主文档 +│ ├── INSTALLATION.md # 安装指南 +│ ├── CONTRIBUTING.md # 贡献指南 +│ ├── FAQ.md # 常见问题 +│ ├── CHANGELOG.md # 更新日志 +│ ├── SECURITY.md # 安全政策 +│ └── PROJECT_STRUCTURE.md # 本文件 +│ +├── 🐳 Docker 相关 +│ ├── Dockerfile # Docker 镜像定义 +│ ├── docker-compose.yml # Docker Compose 配置 +│ └── .dockerignore # Docker 忽略文件 +│ +├── ⚙️ 配置文件 +│ ├── .env.example # 环境变量模板 +│ ├── .gitignore # Git 忽略文件 +│ ├── requirements.txt # Python 依赖 +│ ├── setup.sh # Linux/macOS 安装脚本 +│ └── setup.bat # Windows 安装脚本 +│ +├── 📄 许可证 +│ └── LICENSE # MIT 许可证 +│ +└── 🔧 GitHub 相关 + └── .github/ + ├── ISSUE_TEMPLATE/ + │ ├── bug_report.md + │ └── feature_request.md + └── pull_request_template.md +``` + +## 🔑 核心文件说明 + +### 主应用层 + +#### `app_main.py` +- **作用**: FastAPI 主服务,处理所有 WebSocket 连接 +- **主要功能**: + - WebSocket 路由管理(/ws/camera, /ws_audio, /ws/viewer 等) + - 模型加载与初始化 + - 状态协调与管理 + - 音视频流分发 +- **依赖**: 所有其他模块 +- **入口点**: `python app_main.py` + +#### `navigation_master.py` +- **作用**: 导航统领器,管理整个系统的状态机 +- **主要状态**: + - IDLE: 空闲 + - CHAT: 对话模式 + - BLINDPATH_NAV: 盲道导航 + - CROSSING: 过马路 + - TRAFFIC_LIGHT_DETECTION: 红绿灯检测 + - ITEM_SEARCH: 物品查找 +- **核心方法**: + - `process_frame()`: 处理每一帧 + - `start_blind_path_navigation()`: 启动盲道导航 + - `start_crossing()`: 启动过马路模式 + - `on_voice_command()`: 处理语音命令 + +### 
工作流模块 + +#### `workflow_blindpath.py` +- **作用**: 盲道导航核心逻辑 +- **主要功能**: + - 盲道分割与检测 + - 障碍物检测 + - 转弯检测 + - 光流稳定 + - 方向引导生成 +- **状态机**: + - ONBOARDING: 上盲道 + - NAVIGATING: 导航中 + - MANEUVERING_TURN: 转弯 + - AVOIDING_OBSTACLE: 避障 + +#### `workflow_crossstreet.py` +- **作用**: 过马路导航逻辑 +- **主要功能**: + - 斑马线检测 + - 方向对齐 + - 引导生成 +- **核心方法**: + - `_is_crosswalk_near()`: 判断是否接近斑马线 + - `_compute_angle_and_offset()`: 计算角度和偏移 + +#### `yolomedia.py` +- **作用**: 物品查找工作流 +- **主要功能**: + - YOLO-E 文本提示检测 + - MediaPipe 手部追踪 + - 光流目标追踪 + - 手部引导(方向提示) + - 抓取动作检测 +- **模式**: + - SEGMENT: 检测模式 + - FLASH: 闪烁确认 + - CENTER_GUIDE: 居中引导 + - TRACK: 手部追踪 + +### 语音模块 + +#### `asr_core.py` +- **作用**: 阿里云 Paraformer ASR 实时语音识别 +- **主要功能**: + - 实时语音识别 + - VAD(语音活动检测) + - 识别结果回调 +- **关键类**: `ASRCallback` + +#### `omni_client.py` +- **作用**: Qwen-Omni-Turbo 多模态对话客户端 +- **主要功能**: + - 流式对话生成 + - 图像+文本输入 + - 语音输出 +- **核心函数**: `stream_chat()` + +#### `audio_player.py` +- **作用**: 统一的音频播放管理 +- **主要功能**: + - TTS 语音播放 + - 多路音频混音 + - 音量控制 + - 线程安全播放 +- **核心函数**: `play_voice_text()`, `play_audio_threadsafe()` + +### 模型后端 + +#### `yoloe_backend.py` +- **作用**: YOLO-E 开放词汇检测后端 +- **主要功能**: + - 文本提示设置 + - 实时分割 + - 目标追踪 +- **核心类**: `YoloEBackend` + +#### `trafficlight_detection.py` +- **作用**: 红绿灯检测模块 +- **检测方法**: + 1. YOLO 模型检测 + 2. 
HSV 颜色分类(备用) +- **输出**: 红灯/绿灯/黄灯/未知 + +#### `obstacle_detector_client.py` +- **作用**: 障碍物检测客户端 +- **主要功能**: + - 白名单类别过滤 + - 路径掩码内检测 + - 物体属性计算(面积、位置、危险度) + +### 视频处理 + +#### `bridge_io.py` +- **作用**: 线程安全的帧缓冲与分发 +- **主要功能**: + - 生产者-消费者模式 + - 原始帧缓存 + - 处理后帧分发 +- **核心函数**: + - `push_raw_jpeg()`: 接收 ESP32 帧 + - `wait_raw_bgr()`: 取原始帧 + - `send_vis_bgr()`: 发送处理后的帧 + +#### `sync_recorder.py` +- **作用**: 音视频同步录制 +- **主要功能**: + - 同步录制视频和音频 + - 自动文件命名(时间戳) + - 线程安全 +- **输出**: `recordings/video_*.avi`, `audio_*.wav` + +### 前端 + +#### `templates/index.html` +- **作用**: Web 监控界面 +- **主要区域**: + - 视频流显示 + - 状态面板 + - IMU 3D 可视化 + - 语音识别结果 + +#### `static/main.js` +- **作用**: 主 JavaScript 逻辑 +- **主要功能**: + - WebSocket 连接管理 + - UI 更新 + - 事件处理 + +#### `static/vision.js` +- **作用**: 视觉流处理 +- **主要功能**: + - WebSocket 接收视频帧 + - Canvas 渲染 + - FPS 计算 + +#### `static/visualizer.js` +- **作用**: IMU 3D 可视化(Three.js) +- **主要功能**: + - 接收 IMU 数据 + - 实时渲染设备姿态 + - 动态灯光效果 + +## 🔄 数据流 + +### 视频流 +``` +ESP32-CAM + → [JPEG] WebSocket /ws/camera + → bridge_io.push_raw_jpeg() + → yolomedia / navigation_master + → bridge_io.send_vis_bgr() + → [JPEG] WebSocket /ws/viewer + → Browser Canvas +``` + +### 音频流(上行) +``` +ESP32-MIC + → [PCM16] WebSocket /ws_audio + → asr_core + → DashScope ASR + → 识别结果 + → start_ai_with_text_custom() +``` + +### 音频流(下行) +``` +Qwen-Omni / TTS + → audio_player + → [PCM16] audio_stream + → [WAV] HTTP /stream.wav + → ESP32-Speaker +``` + +### IMU 数据流 +``` +ESP32-IMU + → [JSON] UDP 12345 + → process_imu_and_maybe_store() + → [JSON] WebSocket /ws + → visualizer.js (Three.js) +``` + +## 🎯 关键设计模式 + +### 1. 状态机模式 +- **位置**: `navigation_master.py` +- **作用**: 管理系统状态转换 +- **状态**: IDLE → CHAT / BLINDPATH_NAV / CROSSING / ... + +### 2. 生产者-消费者模式 +- **位置**: `bridge_io.py` +- **作用**: 解耦视频接收与处理 +- **实现**: 线程 + 队列 + +### 3. 策略模式 +- **位置**: 各 `workflow_*.py` +- **作用**: 不同导航策略的实现 +- **实现**: 统一的 `process_frame()` 接口 + +### 4. 单例模式 +- **位置**: 模型加载 +- **作用**: 全局共享模型实例 +- **实现**: 全局变量 + 初始化检查 + +### 5. 
观察者模式 +- **位置**: WebSocket 通信 +- **作用**: 多客户端订阅视频流 +- **实现**: `camera_viewers: Set[WebSocket]` + +## 📦 依赖关系 + +``` +app_main.py +├── navigation_master.py +│ ├── workflow_blindpath.py +│ │ ├── yoloe_backend.py +│ │ └── obstacle_detector_client.py +│ ├── workflow_crossstreet.py +│ └── trafficlight_detection.py +├── yolomedia.py +│ └── yoloe_backend.py +├── asr_core.py +├── omni_client.py +├── audio_player.py +├── audio_stream.py +├── bridge_io.py +└── sync_recorder.py +``` + +## 🚀 启动流程 + +1. **初始化阶段** (`app_main.py`) + - 加载环境变量 + - 加载导航模型(YOLO、MediaPipe) + - 初始化音频系统 + - 启动录制系统 + - 预加载红绿灯模型 + +2. **服务启动** (FastAPI) + - 注册 WebSocket 路由 + - 挂载静态文件 + - 启动 UDP 监听(IMU) + - 启动 HTTP 服务(8081 端口) + +3. **运行阶段** + - 等待 ESP32 连接 + - 接收视频/音频/IMU 数据 + - 处理用户语音指令 + - 实时推送处理结果 + +4. **关闭阶段** + - 停止录制(保存文件) + - 关闭所有 WebSocket 连接 + - 释放模型资源 + - 清理临时文件 + +--- + +**提示**: 如需了解某个模块的详细实现,请查看相应源文件的注释和 docstring。 + diff --git a/README.md b/README.md index 19f64e6..397294e 100644 --- a/README.md +++ b/README.md @@ -1,2 +1,506 @@ -# NaviGlassServer - +# AI 智能盲人眼镜系统 🤖👓 + +
+ +一个面向视障人士的智能导航与辅助系统,集成了盲道导航、过马路辅助、物品识别、实时语音交互等功能。 本项目仅为交流学习使用,请勿直接给视障人群使用。本项目内仅包含代码,模型地址:https://www.modelscope.cn/models/archifancy/AIGlasses_for_navigation 。下载后存放在/model 文件夹 + +[功能特性](#功能特性) • [快速开始](#快速开始) • [系统架构](#系统架构) • [使用说明](#使用说明) • [开发文档](#开发文档) + +
+ +--- +1 +2 +4 + +## 📋 目录 + +- [功能特性](#功能特性) +- [系统要求](#系统要求) +- [快速开始](#快速开始) +- [系统架构](#系统架构) +- [使用说明](#使用说明) +- [配置说明](#配置说明) +- [开发文档](#开发文档) + +## ✨ 功能特性 + +### 🚶 盲道导航系统 +- **实时盲道检测**:基于 YOLO 分割模型实时识别盲道 +- **智能语音引导**:提供精准的方向指引(左转、右转、直行等) +- **障碍物检测与避障**:自动识别前方障碍物并规划避障路线 +- **转弯检测**:自动识别急转弯并提前提醒 +- **光流稳定**:使用 Lucas-Kanade 光流算法稳定掩码,减少抖动 + +### 🚦 过马路辅助 +- **斑马线识别**:实时检测斑马线位置和方向 +- **红绿灯识别**:基于颜色和形状的红绿灯状态检测 +- **对齐引导**:引导用户对准斑马线中心 +- **安全提醒**:绿灯时语音提示可以通行 + +### 🔍 物品识别与查找 +- **智能物品搜索**:语音指令查找物品(如"帮我找一下红牛") +- **实时目标追踪**:使用 YOLO-E 开放词汇检测 + ByteTrack 追踪 +- **手部引导**:结合 MediaPipe 手部检测,引导用户手部靠近物品 +- **抓取检测**:检测手部握持动作,确认物品已拿到 +- **多模态反馈**:视觉标注 + 语音引导 + 居中提示 + +### 🎙️ 实时语音交互 +- **语音识别(ASR)**:基于阿里云 DashScope Paraformer 实时语音识别 +- **多模态对话**:Qwen-Omni-Turbo 支持图像+文本输入,语音输出 +- **智能指令解析**:自动识别导航、查找、对话等不同类型指令 +- **上下文感知**:在不同模式下智能过滤无关指令 + +### 📹 视频与音频处理 +- **实时视频流**:WebSocket 推流,支持多客户端同时观看 +- **音视频同步录制**:自动保存带时间戳的录像和音频文件 +- **IMU 数据融合**:接收 ESP32 的 IMU 数据,支持姿态估计 +- **多路音频混音**:支持系统语音、AI 回复、环境音同时播放 + +### 🎨 可视化与交互 +- **Web 实时监控**:浏览器端实时查看处理后的视频流 +- **IMU 3D 可视化**:Three.js 实时渲染设备姿态 +- **状态面板**:显示导航状态、检测信息、FPS 等 +- **中文友好**:所有界面和语音使用中文,支持自定义字体 + +## 💻 系统要求 + +### 硬件要求 +- **开发/服务器端**: + - CPU: Intel i5 或以上(推荐 i7/i9) + - GPU: NVIDIA GPU(支持 CUDA 11.8+,推荐 RTX 3060 或以上) + - 内存: 8GB RAM(推荐 16GB) + - 存储: 10GB 可用空间 + +- **客户端设备**(可选): + - ESP32-CAM 或其他支持 WebSocket 的摄像头 + - 麦克风(用于语音输入) + - 扬声器/耳机(用于语音输出) + +### 软件要求 +- **操作系统**: Windows 10/11, Linux (Ubuntu 20.04+), macOS 10.15+ +- **Python**: 3.9 - 3.11 +- **CUDA**: 11.8 或更高版本(GPU 加速必需) +- **浏览器**: Chrome 90+, Firefox 88+, Edge 90+(用于 Web 监控) + +### API 密钥 +- **阿里云 DashScope API Key**(必需): + - 用于语音识别(ASR)和 Qwen-Omni 对话 + - 申请地址:https://dashscope.console.aliyun.com/ + +## 🚀 快速开始 + +### 1. 克隆项目 + +```bash +git clone https://github.com/yourusername/aiglass.git +cd aiglass/rebuild1002 +``` + +### 2. 
安装依赖 + +#### 创建虚拟环境(推荐) +```bash +python -m venv venv +# Windows +venv\Scripts\activate +# Linux/macOS +source venv/bin/activate +``` + +#### 安装 Python 包 +```bash +pip install -r requirements.txt +``` + +#### 安装 CUDA 和 cuDNN(GPU 加速) +请参考 [NVIDIA CUDA Toolkit 安装指南](https://developer.nvidia.com/cuda-downloads) + +### 3. 下载模型文件 + +将以下模型文件放入 `model/` 目录: + +| 模型文件 | 用途 | 大小 | 下载链接 | +|---------|------|------|---------| +| `yolo-seg.pt` | 盲道分割 | ~50MB | [待补充] | +| `yoloe-11l-seg.pt` | 开放词汇检测 | ~80MB | [待补充] | +| `shoppingbest5.pt` | 物品识别 | ~30MB | [待补充] | +| `trafficlight.pt` | 红绿灯检测 | ~20MB | [待补充] | +| `hand_landmarker.task` | 手部检测 | ~15MB | [MediaPipe Models](https://developers.google.com/mediapipe/solutions/vision/hand_landmarker#models) | + +### 4. 配置 API 密钥 + +创建 `.env` 文件: + +```bash +# .env +DASHSCOPE_API_KEY=your_api_key_here +``` + +或在代码中直接修改(不推荐): +```python +# app_main.py, line 50 +API_KEY = "your_api_key_here" +``` + +### 5. 启动系统 + +```bash +python app_main.py +``` + +系统将在 `http://0.0.0.0:8081` 启动,打开浏览器访问即可看到实时监控界面。 + +### 6. 连接设备(可选) + +如果使用 ESP32-CAM,请: +1. 烧录 `compile/compile.ino` 到 ESP32 +2. 修改 WiFi 配置,连接到同一网络 +3. 
ESP32 自动连接到 WebSocket 端点 + +## 🏗️ 系统架构 + +### 整体架构 + +``` +┌─────────────────────────────────────────────────────────────┐ +│ 客户端层 │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ ESP32-CAM │ │ 浏览器 │ │ 移动端 │ │ +│ │ (视频/音频) │ │ (监控界面) │ │ (语音控制) │ │ +│ └──────┬───────┘ └──────┬───────┘ └──────┬───────┘ │ +└─────────┼──────────────────┼──────────────────┼─────────────┘ + │ WebSocket │ HTTP/WS │ WebSocket +┌─────────┼──────────────────┼──────────────────┼─────────────┐ +│ │ │ │ │ +│ ┌────▼──────────────────▼──────────────────▼────────┐ │ +│ │ FastAPI 主服务 (app_main.py) │ │ +│ │ - WebSocket 路由管理 │ │ +│ │ - 音视频流分发 │ │ +│ │ - 状态管理与协调 │ │ +│ └────┬────────────────┬────────────────┬─────────────┘ │ +│ │ │ │ │ +│ ┌──────▼──────┐ ┌──────▼──────┐ ┌──────▼──────┐ │ +│ │ ASR 模块 │ │ Omni 对话 │ │ 音频播放 │ │ +│ │ (asr_core) │ │(omni_client)│ │(audio_player)│ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ │ +│ 应用层 │ +└───────────────────────────────────────────────────────────────┘ + │ │ │ +┌─────────▼──────────────────▼──────────────────▼──────────────┐ +│ 导航统领层 │ +│ ┌─────────────────────────────────────────────────┐ │ +│ │ NavigationMaster (navigation_master.py) │ │ +│ │ - 状态机:IDLE/CHAT/BLINDPATH_NAV/ │ │ +│ │ CROSSING/TRAFFIC_LIGHT/ITEM_SEARCH │ │ +│ │ - 模式切换与协调 │ │ +│ └───┬─────────────────────┬───────────────────┬───┘ │ +│ │ │ │ │ +│ ┌────▼────────┐ ┌────────▼────────┐ ┌─────▼──────┐ │ +│ │ 盲道导航 │ │ 过马路导航 │ │ 物品查找 │ │ +│ │(blindpath) │ │ (crossstreet) │ │(yolomedia) │ │ +│ └──────────────┘ └──────────────────┘ └─────────────┘ │ +└───────────────────────────────────────────────────────────────┘ + │ │ │ +┌─────────▼──────────────────▼──────────────────▼──────────────┐ +│ 模型推理层 │ +│ ┌──────────────┐ ┌──────────────┐ ┌──────────────┐ │ +│ │ YOLO 分割 │ │ YOLO-E 检测 │ │ MediaPipe │ │ +│ │ (盲道/斑马线) │ │ (开放词汇) │ │ (手部检测) │ │ +│ └──────────────┘ └──────────────┘ └──────────────┘ │ +│ ┌──────────────┐ ┌──────────────┐ │ +│ │ 红绿灯检测 │ │ 光流稳定 │ │ +│ │(HSV+YOLO) │ 
│(Lucas-Kanade)│ │ +│ └──────────────┘ └──────────────┘ │ +└───────────────────────────────────────────────────────────────┘ + │ +┌─────────▼─────────────────────────────────────────────────────┐ +│ 外部服务层 │ +│ ┌──────────────────────────────────────────────┐ │ +│ │ 阿里云 DashScope API │ │ +│ │ - Paraformer ASR (实时语音识别) │ │ +│ │ - Qwen-Omni-Turbo (多模态对话) │ │ +│ │ - Qwen-Turbo (标签提取) │ │ +│ └──────────────────────────────────────────────┘ │ +└───────────────────────────────────────────────────────────────┘ +``` + +### 核心模块说明 + +| 模块 | 文件 | 功能 | +|------|------|------| +| **主应用** | `app_main.py` | FastAPI 服务、WebSocket 管理、状态协调 | +| **导航统领** | `navigation_master.py` | 状态机管理、模式切换、语音节流 | +| **盲道导航** | `workflow_blindpath.py` | 盲道检测、避障、转弯引导 | +| **过马路导航** | `workflow_crossstreet.py` | 斑马线检测、红绿灯识别、对齐引导 | +| **物品查找** | `yolomedia.py` | 物品检测、手部引导、抓取确认 | +| **语音识别** | `asr_core.py` | 实时 ASR、VAD、指令解析 | +| **语音合成** | `omni_client.py` | Qwen-Omni 流式语音生成 | +| **音频播放** | `audio_player.py` | 多路混音、TTS 播放、音量控制 | +| **视频录制** | `sync_recorder.py` | 音视频同步录制 | +| **桥接 IO** | `bridge_io.py` | 线程安全的帧缓冲与分发 | + +## 📖 使用说明 + +### 语音指令 + +系统支持以下语音指令(说话时无需唤醒词): + +#### 导航控制 +``` +"开始导航" / "盲道导航" → 启动盲道导航 +"停止导航" / "结束导航" → 停止盲道导航 +"开始过马路" / "帮我过马路" → 启动过马路模式 +"过马路结束" / "结束过马路" → 停止过马路模式 +``` + +#### 红绿灯检测 +``` +"检测红绿灯" / "看红绿灯" → 启动红绿灯检测 +"停止检测" / "停止红绿灯" → 停止检测 +``` + +#### 物品查找 +``` +"帮我找一下 [物品名]" → 启动物品搜索 + 示例: + - "帮我找一下红牛" + - "找一下AD钙奶" + - "帮我找矿泉水" +"找到了" / "拿到了" → 确认找到物品 +``` + +#### 智能对话 +``` +"帮我看看这是什么" → 拍照识别 +"这个东西能吃吗" → 物品咨询 +任何其他问题 → AI 对话 +``` + +### 导航状态说明 + +系统包含以下主要状态(自动切换): + +1. **IDLE** - 空闲状态 + - 等待用户指令 + - 显示原始视频流 + +2. **CHAT** - 对话模式 + - 与 AI 进行多模态对话 + - 暂停导航功能 + +3. **BLINDPATH_NAV** - 盲道导航 + - **ONBOARDING**: 上盲道引导 + - ROTATION: 旋转对准盲道 + - TRANSLATION: 平移至盲道中心 + - **NAVIGATING**: 沿盲道行走 + - 实时方向修正 + - 障碍物检测 + - **MANEUVERING_TURN**: 转弯处理 + - **AVOIDING_OBSTACLE**: 避障 + +4. 
**CROSSING** - 过马路模式 + - **SEEKING_CROSSWALK**: 寻找斑马线 + - **WAIT_TRAFFIC_LIGHT**: 等待绿灯 + - **CROSSING**: 过马路中 + - **SEEKING_NEXT_BLINDPATH**: 寻找对面盲道 + +5. **ITEM_SEARCH** - 物品查找 + - 实时检测目标物品 + - 引导手部靠近 + - 确认抓取 + +6. **TRAFFIC_LIGHT_DETECTION** - 红绿灯检测 + - 实时检测红绿灯状态 + - 语音播报颜色变化 + +### Web 监控界面 + +打开浏览器访问 `http://localhost:8081`,可以看到: + +- **实时视频流**:显示处理后的视频,包括导航标注 +- **状态面板**:当前模式、检测信息、FPS +- **IMU 可视化**:设备姿态 3D 实时渲染 +- **语音识别结果**:显示识别的文字和 AI 回复 + +### WebSocket 端点 + +| 端点 | 用途 | 数据格式 | +|------|------|---------| +| `/ws/camera` | ESP32 相机推流 | Binary (JPEG) | +| `/ws/viewer` | 浏览器订阅视频 | Binary (JPEG) | +| `/ws_audio` | ESP32 音频上传 | Binary (PCM16) | +| `/ws_ui` | UI 状态推送 | JSON | +| `/ws` | IMU 数据接收 | JSON | +| `/stream.wav` | 音频下载流 | Binary (WAV) | + +## ⚙️ 配置说明 + +### 环境变量 + +创建 `.env` 文件配置以下参数: + +```bash +# 阿里云 API +DASHSCOPE_API_KEY=sk-xxxxx + +# 模型路径(可选,使用默认路径可不配置) +BLIND_PATH_MODEL=model/yolo-seg.pt +OBSTACLE_MODEL=model/yoloe-11l-seg.pt +YOLOE_MODEL_PATH=model/yoloe-11l-seg.pt + +# 导航参数 +AIGLASS_MASK_MIN_AREA=1500 # 最小掩码面积 +AIGLASS_MASK_MORPH=3 # 形态学核大小 +AIGLASS_MASK_MISS_TTL=6 # 掩码丢失容忍帧数 +AIGLASS_PANEL_SCALE=0.65 # 数据面板缩放 + +# 音频配置 +TTS_INTERVAL_SEC=1.0 # 语音播报间隔 +ENABLE_TTS=true # 启用语音播报 +``` + +### 修改模型路径 + +如果模型文件不在默认位置,可以在相应文件中修改: + +```python +# workflow_blindpath.py +seg_model_path = "your/custom/path/yolo-seg.pt" + +# yolomedia.py +YOLO_MODEL_PATH = "your/custom/path/shoppingbest5.pt" +HAND_TASK_PATH = "your/custom/path/hand_landmarker.task" +``` + +### 调整性能参数 + +根据硬件性能调整: + +```python +# yolomedia.py +HAND_DOWNSCALE = 0.8 # 手部检测降采样(越小越快,精度降低) +HAND_FPS_DIV = 1 # 手部检测抽帧(2=隔帧,3=每3帧) + +# workflow_blindpath.py +FEATURE_PARAMS = dict( + maxCorners=600, # 光流特征点数(越少越快) + qualityLevel=0.001, # 特征点质量 + minDistance=5 # 特征点最小间距 +) +``` + +## 🛠️ 开发文档 + +### 添加新的语音指令 + +1. 
在 `app_main.py` 的 `start_ai_with_text_custom()` 函数中添加: + +```python +# 检查新指令 +if "新指令关键词" in user_text: + # 执行自定义逻辑 + print("[CUSTOM] 新指令被触发") + await ui_broadcast_final("[系统] 新功能已启动") + return +``` + +2. 如需修改指令过滤规则: + +```python +# 修改 allowed_keywords 列表 +allowed_keywords = ["帮我看", "帮我找", "你的新关键词"] +``` + +### 扩展导航功能 + +1. 在 `workflow_blindpath.py` 添加新状态: + +```python +# 在 BlindPathNavigator.__init__() 中初始化 +self.your_new_state_var = False + +# 在 process_frame() 中处理 +def process_frame(self, image): + if self.your_new_state_var: + # 自定义处理逻辑 + guidance_text = "新状态引导" + # ... +``` + +2. 在 `navigation_master.py` 添加状态机状态: + +```python +class NavigationMaster: + def start_your_new_mode(self): + self.state = "YOUR_NEW_MODE" + # 初始化逻辑 +``` + +### 集成新模型 + +1. 创建模型包装类: + +```python +# your_model_wrapper.py +class YourModelWrapper: + def __init__(self, model_path): + self.model = load_your_model(model_path) + + def detect(self, image): + # 推理逻辑 + return results +``` + +2. 在 `app_main.py` 中加载: + +```python +your_model = YourModelWrapper("model/your_model.pt") +``` + +3. 在相应的工作流中调用: + +```python +results = your_model.detect(image) +``` + +### 调试技巧 + +1. **启用详细日志**: + +```python +# app_main.py 顶部 +import logging +logging.basicConfig(level=logging.DEBUG) +``` + +2. **查看帧率瓶颈**: + +```python +# yolomedia.py +PERF_DEBUG = True # 打印处理时间 +``` + +3. **测试单个模块**: + +```bash +# 测试盲道导航 +python test_cross_street_blindpath.py + +# 测试红绿灯检测 +python test_traffic_light.py + +# 测试录制功能 +python test_recorder.py +``` + + + + + +## 📄 许可证 + +本项目采用 MIT 许可证 - 详见 [LICENSE](LICENSE) 文件 + + diff --git a/ai_voice_pipeline.py b/ai_voice_pipeline.py new file mode 100644 index 0000000..9d6798d --- /dev/null +++ b/ai_voice_pipeline.py @@ -0,0 +1,154 @@ +# ai_voice_pipeline.py +# -*- coding: utf-8 -*- +""" +AI 语音交互管道 - Day 21 + +整合 SenseVoice + GLM-4.5-Flash + EdgeTTS + +流程: +1. 客户端 VAD 检测语音结束 +2. 发送完整音频到服务器 +3. SenseVoice 识别 → GLM 生成回复 → EdgeTTS 合成语音 +4. 
流式返回 PCM 音频 +""" + +import asyncio +from typing import Optional, Callable, AsyncGenerator + +# 导入各模块 +from sensevoice_asr import recognize as asr_recognize, init_sensevoice +from glm_client import chat as llm_chat, chat_stream as llm_chat_stream +from edge_tts_client import ( + text_to_speech_pcm_stream, + text_to_speech_pcm, + DEFAULT_VOICE, +) + + +async def init_pipeline(): + """初始化 AI 管道(服务器启动时调用)""" + await init_sensevoice() + print("[AI Pipeline] 初始化完成") + + +async def process_voice( + pcm_audio: bytes, + image_base64: Optional[str] = None, + on_text: Optional[Callable[[str], None]] = None, + on_audio: Optional[Callable[[bytes], None]] = None, +) -> str: + """ + 处理语音输入,返回 AI 回复 + + Args: + pcm_audio: PCM16 音频数据 (16kHz, mono) + image_base64: 可选的图片(用于多模态) + on_text: 文本回调(用于 UI 显示) + on_audio: 音频回调(用于流式播放) + + Returns: + AI 回复文本 + """ + # 1. 语音识别 + user_text = await asr_recognize(pcm_audio) + + if not user_text: + print("[AI Pipeline] 未识别到有效语音") + return "" + + print(f"[AI Pipeline] 用户说: {user_text}") + + # 通知 UI + if on_text: + on_text(f"用户: {user_text}") + + # 2. LLM 生成回复 + ai_response = await llm_chat(user_text, image_base64) + + if not ai_response: + print("[AI Pipeline] AI 无回复") + return "" + + print(f"[AI Pipeline] AI 回复: {ai_response}") + + # 通知 UI + if on_text: + on_text(f"AI: {ai_response}") + + # 3. TTS 合成并播放 + if on_audio: + async for audio_chunk in text_to_speech_pcm_stream(ai_response): + on_audio(audio_chunk) + + return ai_response + + +async def process_voice_stream( + pcm_audio: bytes, + image_base64: Optional[str] = None, +) -> AsyncGenerator[tuple, None]: + """ + 流式处理语音输入 + + Args: + pcm_audio: PCM16 音频数据 + image_base64: 可选的图片 + + Yields: + ("text", str) - 文本片段 + ("audio", bytes) - 音频片段 + """ + # 1. 语音识别 + user_text = await asr_recognize(pcm_audio) + + if not user_text: + return + + yield ("user_text", user_text) + + # 2. LLM 流式生成 + 3. 
TTS 流式合成 + # 收集一定长度的文本后送 TTS + buffer = "" + punctuation = "。,!?;:,.!?;:" + + async for text_chunk in llm_chat_stream(user_text, image_base64): + yield ("ai_text", text_chunk) + buffer += text_chunk + + # 遇到标点时合成音频 + if buffer and buffer[-1] in punctuation: + async for audio_chunk in text_to_speech_pcm_stream(buffer): + yield ("audio", audio_chunk) + buffer = "" + + # 处理剩余文本 + if buffer.strip(): + async for audio_chunk in text_to_speech_pcm_stream(buffer): + yield ("audio", audio_chunk) + + +async def text_to_voice(text: str) -> bytes: + """ + 文本转语音(用于导航提示等) + + Args: + text: 要合成的文本 + + Returns: + PCM16 音频数据 + """ + return await text_to_speech_pcm(text) + + +async def text_to_voice_stream(text: str) -> AsyncGenerator[bytes, None]: + """ + 流式文本转语音 + + Args: + text: 要合成的文本 + + Yields: + PCM16 音频块 + """ + async for chunk in text_to_speech_pcm_stream(text): + yield chunk diff --git a/app_main.py b/app_main.py new file mode 100644 index 0000000..8ebbbb3 --- /dev/null +++ b/app_main.py @@ -0,0 +1,1868 @@ +# app_main.py +# -*- coding: utf-8 -*- +import os, sys, time, json, asyncio, base64, audioop +from typing import Any, Dict, Optional, Tuple, List, Callable, Set, Deque +from collections import deque +from dataclasses import dataclass +from concurrent.futures import ThreadPoolExecutor +import re +from dotenv import load_dotenv + +# 加载环境变量 (Day 18: 修复 GPU 选择等配置不生效的问题) +load_dotenv() +# 在其它 import 之后加: + +from qwen_extractor import extract_english_label +from navigation_master import NavigationMaster, OrchestratorResult +# 新增:导入盲道导航器 +from workflow_blindpath import BlindPathNavigator +# 新增:导入过马路导航器 +from workflow_crossstreet import CrossStreetNavigator +import torch +from fastapi import FastAPI, WebSocket, WebSocketDisconnect, Request +from fastapi.responses import HTMLResponse, PlainTextResponse +from fastapi.staticfiles import StaticFiles +from starlette.websockets import WebSocketState +import uvicorn +import cv2 +import numpy as np + +# 【Day 19 优化】TurboJPEG - 比 
cv2.imencode/imdecode 快 2-3 倍 +# Day 20: TensorRT 模型加载工具 +from model_utils import get_best_model_path +try: + from turbojpeg import TurboJPEG + _turbo_jpeg = TurboJPEG() + print("[INIT] TurboJPEG 加载成功,JPEG 编解码将使用加速版本") +except ImportError: + _turbo_jpeg = None + print("[INIT] TurboJPEG 未安装,使用 cv2 作为回退 (pip install PyTurboJPEG)") + +from ultralytics import YOLO +from obstacle_detector_client import ObstacleDetectorClient +from contextlib import asynccontextmanager + +# Day 18: 删除了重复的 import torch(已在 L17 导入) + + +import mediapipe as mp +import bridge_io +import threading +import yolomedia # 确保和 app_main.py 同目录,文件名就是 yolomedia.py +# ---- Windows 事件循环策略 ---- +if sys.platform.startswith("win"): + try: + asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy()) + except Exception: + pass + +# ---- .env ---- +try: + from dotenv import load_dotenv + load_dotenv() +except Exception: + pass + +# ---- Day 21: 新 AI 管道 (SenseVoice + GLM-4.6v-Flash + EdgeTTS) ---- +# 选择使用新的 AI 管道还是旧的 DashScope/Omni +USE_NEW_AI_PIPELINE = os.getenv("USE_NEW_AI_PIPELINE", "1") == "1" + +if USE_NEW_AI_PIPELINE: + # 新管道:本地 ASR + GLM + EdgeTTS + 服务器端 VAD + from sensevoice_asr import recognize as sensevoice_recognize, init_sensevoice + from glm_client import chat as glm_chat + from edge_tts_client import text_to_speech_pcm_stream + from server_vad import get_server_vad, reset_server_vad + print("[AI Pipeline] 使用新管道: SenseVoice + GLM-4.6v-Flash + EdgeTTS + Server VAD") +else: + # 旧管道:DashScope ASR + Omni + from dashscope import audio as dash_audio + API_KEY = os.getenv("DASHSCOPE_API_KEY", "sk-a9440db694924559ae4ebdc2023d2b9a") + MODEL = "paraformer-realtime-v2" + from omni_client import stream_chat, OmniStreamPiece + from asr_core import ASRCallback, set_current_recognition, stop_current_recognition + print("[AI Pipeline] 使用旧管道: DashScope + Qwen-Omni") + +# 通用常量 +AUDIO_FMT = "pcm" +SAMPLE_RATE = 16000 +SILENCE_CHUNK = b'\x00' * 640 # 20ms 静音 + +# 兼容层:当使用新管道时,提供 ASR 相关函数的 stub +if 
USE_NEW_AI_PIPELINE: + # 新管道不使用流式 ASR,但需要保持函数存在避免导入错误 + async def set_current_recognition(rec): pass + async def stop_current_recognition(): pass + class ASRCallback: + def __init__(self, **kwargs): pass + + +from audio_stream import ( + hard_reset_audio, + BYTES_PER_20MS_16K, + is_playing_now, + current_ai_task, + register_stream_route, + broadcast_pcm16_realtime, +) +from audio_player import initialize_audio_system, play_voice_text + +# ---- 同步录制器 ---- +import sync_recorder +import signal +import atexit + +# ---- IMU UDP ---- +UDP_IP = "0.0.0.0" +UDP_PORT = 12345 + + +# ---- 【新】lifespan 管理器(替代 on_event) ---- +@asynccontextmanager +async def lifespan(app: FastAPI): + """应用生命周期管理器 - 替代 on_event 装饰器""" + # === 启动逻辑(原 @app.on_event("startup") 的内容) === + print("[LIFESPAN] 应用启动中...") + + # 1. 注册 bridge_io 发送回调 + main_loop = asyncio.get_event_loop() + + def _sender(jpeg_bytes: bytes): + try: + if main_loop.is_closed(): + return + + global yolomedia_sending_frames + if not yolomedia_sending_frames: + yolomedia_sending_frames = True + print("[YOLOMEDIA] 开始发送处理后的帧,切换到YOLO画面", flush=True) + + async def _broadcast(): + if not camera_viewers: + return + dead = [] + for ws in list(camera_viewers): + try: + await ws.send_bytes(jpeg_bytes) + except Exception: + dead.append(ws) + for ws in dead: + try: + camera_viewers.remove(ws) + except Exception: + pass + + future = asyncio.run_coroutine_threadsafe(_broadcast(), main_loop) + except Exception as e: + if "Event loop is closed" not in str(e): + print(f"[DEBUG] _sender error: {e}", flush=True) + + bridge_io.set_sender(_sender) + + # 2. 初始化音频系统(后台线程) + def _init_audio(): + try: + initialize_audio_system() + except Exception as e: + print(f"[AUDIO] 初始化失败: {e}") + + threading.Thread(target=_init_audio, daemon=True).start() + + # 3. 启动 UDP 服务器 + loop = asyncio.get_running_loop() + await loop.create_datagram_endpoint(lambda: UDPProto(), local_addr=(UDP_IP, UDP_PORT)) + + # 4. 
Day 21: 预加载新 AI 管道模型(避免首次使用时延迟) + if USE_NEW_AI_PIPELINE: + async def _preload_models(): + try: + print("[PRELOAD] 预加载 Silero VAD...") + from server_vad import get_server_vad + get_server_vad() # 触发 VAD 模型加载 + + print("[PRELOAD] 预加载 SenseVoice ASR...") + from sensevoice_asr import init_sensevoice + await init_sensevoice() # 异步加载 ASR 模型 + + print("[PRELOAD] 新 AI 管道模型预加载完成") + except Exception as e: + print(f"[PRELOAD] 模型预加载失败: {e}") + + # 后台预加载,不阻塞启动 + asyncio.create_task(_preload_models()) + + print("[LIFESPAN] 应用启动完成") + + # === yield 表示应用开始运行 === + # Day 13: 使用 try-finally 确保关闭逻辑执行,并捕获 CancelledError + try: + yield + except asyncio.CancelledError: + # Ctrl+C 时 Starlette 会取消 lifespan,这是正常行为 + pass + finally: + # === 关闭逻辑(原 @app.on_event("shutdown") 的内容) === + print("[LIFESPAN] 应用关闭中...") + + # 停止YOLO媒体处理 + try: + stop_yolomedia() + except Exception: + pass + + # 停止音频和AI任务 + try: + await hard_reset_audio("shutdown") + except Exception: + pass + + # 【Day 15】关闭帧处理线程池 + try: + frame_processing_executor.shutdown(wait=False) + print("[LIFESPAN] 帧处理线程池已关闭") + except Exception: + pass + + print("[LIFESPAN] 应用关闭完成") + + + # Day 13: 强制退出进程,避免 uvicorn 挂起 + # 注意:不能在这里 import threading 或 os,否则会破坏 Python 作用域 + # 顶层已经导入了这些模块 + def _force_exit(): + import time as _time + import os as _os + _time.sleep(0.5) # 给其他清理一点时间 + _os._exit(0) + import threading as _threading + _threading.Thread(target=_force_exit, daemon=True).start() + + +app = FastAPI(lifespan=lifespan) + +# ====== 状态与容器 ====== +app.mount("/static", StaticFiles(directory="static"), name="static") + +ui_clients: Dict[int, WebSocket] = {} +current_partial: str = "" +recent_finals: List[str] = [] +RECENT_MAX = 50 +last_frames: Deque[Tuple[float, bytes]] = deque(maxlen=10) + +camera_viewers: Set[WebSocket] = set() +esp32_camera_ws: Optional[WebSocket] = None +imu_ws_clients: Set[WebSocket] = set() +esp32_audio_ws: Optional[WebSocket] = None + +# 【新增】盲道导航相关全局变量 +blind_path_navigator = None +navigation_active = False 
+yolo_seg_model = None +obstacle_detector = None + +# 【新增】过马路导航相关全局变量 +cross_street_navigator = None +cross_street_active = False +orchestrator = None # 新增 + +# 【新增】omni对话状态标志 +omni_conversation_active = False # 标记omni对话是否正在进行 +omni_previous_nav_state = None # 保存omni激活前的导航状态,用于恢复 + +# 【Day 15 性能优化】帧处理线程池 - Day 18 优化: 增加worker数量 +# 将 CPU 密集型的帧处理移至后台线程,避免阻塞事件循环 +frame_processing_executor = ThreadPoolExecutor(max_workers=3, thread_name_prefix="frame_proc") + +# 【Day 15 跳帧机制】异步帧处理状态 +# 避免 await 阻塞,使用后台任务 + 最新结果缓存 +_nav_processing_task = None # 当前的后台处理任务 +_nav_last_result_image = None # 最后一次成功处理的输出图像 +_nav_last_result_jpeg: bytes = None # 【Day 19 优化】缓存编码后的 JPEG,避免重复编码 +_nav_pending_frame = None # 等待处理的最新帧 +_nav_processing_lock = asyncio.Lock() # 确保单任务运行 +_nav_task_start_time = None # Day 20: 任务开始时间,用于计算处理耗时 + +# 【Day 18 性能优化】并行广播辅助函数 - 解决 WebSocket 顺序发送阻塞 + +# 【Day 19 优化】TurboJPEG 辅助函数 - 带回退逻辑 +def turbo_decode(jpeg_bytes: bytes): + """解码 JPEG 为 BGR numpy 数组,优先使用 TurboJPEG""" + if _turbo_jpeg: + return _turbo_jpeg.decode(jpeg_bytes) + else: + arr = np.frombuffer(jpeg_bytes, dtype=np.uint8) + return cv2.imdecode(arr, cv2.IMREAD_COLOR) + +def turbo_encode(bgr_image, quality: int = 80) -> bytes: + """编码 BGR numpy 数组为 JPEG bytes,优先使用 TurboJPEG""" + if _turbo_jpeg: + return _turbo_jpeg.encode(bgr_image, quality=quality) + else: + ok, enc = cv2.imencode(".jpg", bgr_image, [int(cv2.IMWRITE_JPEG_QUALITY), quality]) + return enc.tobytes() if ok else None + +async def _broadcast_to_viewers(jpeg_data: bytes) -> None: + """并行向所有 viewer 广播 JPEG 帧,避免顺序 await 阻塞事件循环""" + if not camera_viewers or not jpeg_data: + return + + viewers = list(camera_viewers) + if not viewers: + return + + # 使用 asyncio.gather 并行发送,return_exceptions=True 确保单个失败不影响其他 + async def _safe_send(ws): + try: + await ws.send_bytes(jpeg_data) + return None + except Exception: + return ws # 返回失败的 ws 以便移除 + + results = await asyncio.gather(*[_safe_send(ws) for ws in viewers], return_exceptions=True) + + # 清理失败的连接 + for 
r in results: + if r is not None and r in camera_viewers: + try: + camera_viewers.discard(r) + except Exception: + pass + + +def load_navigation_models(): + """加载盲道导航所需的模型""" + global yolo_seg_model, obstacle_detector + + try: + seg_model_path = os.getenv("BLIND_PATH_MODEL", "model/yolo-seg.pt") + # Day 20: 优先使用 TensorRT 引擎 + seg_model_path = get_best_model_path(seg_model_path) + #print(f"[NAVIGATION] 尝试加载模型: {seg_model_path}") + + if os.path.exists(seg_model_path): + print(f"[NAVIGATION] 模型文件存在,开始加载...") + yolo_seg_model = YOLO(seg_model_path) + + # Day 20: TensorRT 引擎不需要 .to() 和 .fuse() + from model_utils import is_tensorrt_engine + if is_tensorrt_engine(seg_model_path): + print(f"[NAVIGATION] TensorRT 引擎已加载,跳过 .to() 和 .fuse()") + elif torch.cuda.is_available(): + yolo_seg_model.to("cuda") + # Day 22 优化: 融合模型层以提升推理速度 + try: + yolo_seg_model.fuse() + print(f"[NAVIGATION] 模型层融合完成") + except Exception as e: + print(f"[NAVIGATION] 模型融合失败(非致命): {e}") + print(f"[NAVIGATION] 盲道分割模型加载成功并放到GPU: {yolo_seg_model.device}") + else: + print("[NAVIGATION] CUDA不可用,模型仍在CPU") + + # Day 22 优化: 使用配置的输入尺寸进行预热,并启用FP16 + try: + imgsz = int(os.getenv("AIGLASS_YOLO_IMGSZ", "480")) + use_half = os.getenv("AIGLASS_YOLO_HALF", "1") == "1" + test_img = np.zeros((imgsz, imgsz, 3), dtype=np.uint8) + + # 预热推理,让CUDA编译kernel + for _ in range(3): # 多次预热确保稳定 + results = yolo_seg_model.predict( + test_img, + device="cuda" if torch.cuda.is_available() else "cpu", + verbose=False, + imgsz=imgsz, + half=use_half + ) + print(f"[NAVIGATION] 模型预热成功 (imgsz={imgsz}, half={use_half})") + print(f"[NAVIGATION] 支持的类别数: {len(yolo_seg_model.names) if hasattr(yolo_seg_model, 'names') else '未知'}") + if hasattr(yolo_seg_model, 'names'): + print(f"[NAVIGATION] 模型类别: {yolo_seg_model.names}") + except Exception as e: + print(f"[NAVIGATION] 模型预热失败: {e}") + else: + print(f"[NAVIGATION] 错误:找不到模型文件: {seg_model_path}") + print(f"[NAVIGATION] 当前工作目录: {os.getcwd()}") + print(f"[NAVIGATION] 请检查文件路径是否正确") + + # 【修改开始】使用 
ObstacleDetectorClient 替代直接的 YOLO + obstacle_model_path = os.getenv("OBSTACLE_MODEL", "model/yoloe-11l-seg.pt") + # Day 20: 优先使用 TensorRT 引擎 + obstacle_model_path = get_best_model_path(obstacle_model_path) + print(f"[NAVIGATION] 尝试加载障碍物检测模型: {obstacle_model_path}") + + if os.path.exists(obstacle_model_path): + print(f"[NAVIGATION] 障碍物检测模型文件存在,开始加载...") + try: + # 使用 ObstacleDetectorClient 封装的 YOLO-E + obstacle_detector = ObstacleDetectorClient(model_path=obstacle_model_path) + print(f"[NAVIGATION] ========== YOLO-E 障碍物检测器加载成功 ==========") + + # 检查模型是否成功加载 + if hasattr(obstacle_detector, 'model') and obstacle_detector.model is not None: + print(f"[NAVIGATION] YOLO-E 模型已初始化") + # Day 20: TensorRT 引擎没有 .parameters(),跳过设备检查 + if not is_tensorrt_engine(obstacle_model_path): + try: + print(f"[NAVIGATION] 模型设备: {next(obstacle_detector.model.parameters()).device}") + except StopIteration: + pass + else: + print(f"[NAVIGATION] 警告:YOLO-E 模型初始化异常") + + # 检查白名单是否成功加载 + if hasattr(obstacle_detector, 'WHITELIST_CLASSES'): + print(f"[NAVIGATION] 白名单类别数: {len(obstacle_detector.WHITELIST_CLASSES)}") + print(f"[NAVIGATION] 白名单前10个类别: {', '.join(obstacle_detector.WHITELIST_CLASSES[:10])}") + else: + print(f"[NAVIGATION] 警告:白名单类别未定义") + + # 检查文本特征是否成功预计算 + if hasattr(obstacle_detector, 'whitelist_embeddings') and obstacle_detector.whitelist_embeddings is not None: + print(f"[NAVIGATION] YOLO-E 文本特征已预计算") + print(f"[NAVIGATION] 文本特征张量形状: {obstacle_detector.whitelist_embeddings.shape if hasattr(obstacle_detector.whitelist_embeddings, 'shape') else '未知'}") + else: + print(f"[NAVIGATION] 警告:YOLO-E 文本特征未预计算") + + # 测试障碍物检测功能 + print(f"[NAVIGATION] 开始测试 YOLO-E 检测功能...") + try: + test_img = np.zeros((640, 640, 3), dtype=np.uint8) + # 在测试图像中画一个白色矩形,模拟一个物体 + cv2.rectangle(test_img, (200, 200), (400, 400), (255, 255, 255), -1) + + # 测试检测(不提供 path_mask) + test_results = obstacle_detector.detect(test_img) + print(f"[NAVIGATION] YOLO-E 检测测试成功!") + print(f"[NAVIGATION] 测试检测结果数: 
{len(test_results)}") + + if len(test_results) > 0: + print(f"[NAVIGATION] 测试检测到的物体:") + for i, obj in enumerate(test_results): + print(f" - 物体 {i+1}: {obj.get('name', 'unknown')}, " + f"面积比例: {obj.get('area_ratio', 0):.3f}, " + f"位置: ({obj.get('center_x', 0):.0f}, {obj.get('center_y', 0):.0f})") + except Exception as e: + print(f"[NAVIGATION] YOLO-E 检测测试失败: {e}") + import traceback + traceback.print_exc() + + print(f"[NAVIGATION] ========== YOLO-E 障碍物检测器加载完成 ==========") + + except Exception as e: + print(f"[NAVIGATION] 障碍物检测器加载失败: {e}") + import traceback + traceback.print_exc() + obstacle_detector = None + else: + print(f"[NAVIGATION] 警告:找不到障碍物检测模型文件: {obstacle_model_path}") + + except Exception as e: + print(f"[NAVIGATION] 模型加载失败: {e}") + import traceback + traceback.print_exc() + +# 在程序启动时加载模型 +print("[NAVIGATION] 开始加载导航模型...") +load_navigation_models() +print(f"[NAVIGATION] 模型加载完成 - yolo_seg_model: {yolo_seg_model is not None}") + +# Day 14 优化: 在服务器启动时就预先创建导航器实例,避免客户端连接时延迟 +if yolo_seg_model is not None and blind_path_navigator is None: + print("[NAVIGATION] 预初始化盲道导航器...") + blind_path_navigator = BlindPathNavigator(yolo_seg_model, obstacle_detector) + print("[NAVIGATION] 盲道导航器已预初始化") + +if yolo_seg_model is not None and cross_street_navigator is None: + print("[CROSS_STREET] 预初始化过马路导航器...") + cross_street_navigator = CrossStreetNavigator( + seg_model=yolo_seg_model, + coco_model=None, + obs_model=None + ) + print("[CROSS_STREET] 过马路导航器已预初始化") + +if orchestrator is None and blind_path_navigator is not None and cross_street_navigator is not None: + print("[NAV MASTER] 预初始化统领状态机...") + orchestrator = NavigationMaster(blind_path_navigator, cross_street_navigator) + print("[NAV MASTER] 统领状态机已预初始化") + +# 【新增】启动同步录制 +print("[RECORDER] 启动同步录制系统...") +sync_recorder.start_recording() +print("[RECORDER] 录制系统已启动,将自动保存视频和音频") + +# 【新增】注册退出处理器,确保Ctrl+C时保存录制文件 +def cleanup_on_exit(): + """程序退出时的清理工作""" + print("\n[SYSTEM] 正在关闭录制器...") + try: + 
sync_recorder.stop_recording() + print("[SYSTEM] 录制文件已保存") + except Exception as e: + print(f"[SYSTEM] 关闭录制器时出错: {e}") + +def signal_handler(sig, frame): + """处理Ctrl+C信号""" + print("\n[SYSTEM] 收到中断信号,正在安全退出...") + cleanup_on_exit() + # Day 13: 使用 os._exit() 强制退出,避免 asyncio 事件循环干扰 + import os + os._exit(0) + +# 注册信号处理器 +signal.signal(signal.SIGINT, signal_handler) # Ctrl+C +signal.signal(signal.SIGTERM, signal_handler) # 终止信号 +atexit.register(cleanup_on_exit) # 正常退出时也调用 + +print("[RECORDER] 已注册退出处理器 - Ctrl+C时会自动保存录制文件") + + + +# 【新增】预加载红绿灯检测模型(避免进入WAIT_TRAFFIC_LIGHT状态时卡顿) +try: + import trafficlight_detection + print("[TRAFFIC_LIGHT] 开始预加载红绿灯检测模型...") + if trafficlight_detection.init_model(): + print("[TRAFFIC_LIGHT] 红绿灯检测模型预加载成功") + # 执行一次测试推理,完全预热模型 + try: + test_img = np.zeros((640, 640, 3), dtype=np.uint8) + _ = trafficlight_detection.process_single_frame(test_img) + print("[TRAFFIC_LIGHT] 模型预热完成") + except Exception as e: + print(f"[TRAFFIC_LIGHT] 模型预热失败: {e}") + else: + print("[TRAFFIC_LIGHT] 红绿灯检测模型预加载失败") +except Exception as e: + print(f"[TRAFFIC_LIGHT] 红绿灯模型预加载出错: {e}") + +# ============== 关键:系统级"硬重置"总闸 ================= +interrupt_lock = asyncio.Lock() + +# ============== YOLO媒体线程管理 ================= +yolomedia_thread: Optional[threading.Thread] = None +yolomedia_stop_event = threading.Event() +yolomedia_running = False +yolomedia_sending_frames = False # 新增:标记YOLO是否已经开始发送处理后的帧 + +# ============== 红绿灯检测跳帧机制 ================= +_traffic_light_task = None +_traffic_light_result_jpeg = None +_traffic_light_pending_frame = None + +# 物品名称到YOLO类别的映射 +ITEM_TO_CLASS_MAP = { + "红牛": "Red_Bull", + "AD钙奶": "AD_milk", + "ad钙奶": "AD_milk", + "钙奶": "AD_milk", +} + +async def ui_broadcast_raw(msg: str): + dead = [] + for k, ws in list(ui_clients.items()): + try: + await ws.send_text(msg) + except Exception: + dead.append(k) + for k in dead: + ui_clients.pop(k, None) + + +async def ui_broadcast_partial(text: str): + global current_partial + current_partial = text + await 
ui_broadcast_raw("PARTIAL:" + text) + +async def ui_broadcast_final(text: str): + global current_partial, recent_finals + current_partial = "" + recent_finals.append(text) + if len(recent_finals) > RECENT_MAX: + recent_finals = recent_finals[-RECENT_MAX:] + await ui_broadcast_raw("FINAL:" + text) + print(f"[ASR/AI FINAL] {text}", flush=True) + +async def full_system_reset(reason: str = ""): + """ + 回到刚启动后的状态: + 1) 停播 + 取消AI任务 + 切断所有/stream.wav(hard_reset_audio) + 2) 停止 ASR 实时识别流(关键) + 3) 清 UI 状态 + 4) 清最近相机帧(避免把旧帧又拼进下一轮) + 5) 告知 ESP32:RESET(可选) + """ + # 1) 音频&AI + await hard_reset_audio(reason or "full_system_reset") + + # 2) ASR + await stop_current_recognition() + + # 3) UI + global current_partial, recent_finals + current_partial = "" + recent_finals = [] + + # 4) 相机帧 + try: + last_frames.clear() + except Exception: + pass + + # 5) 通知 ESP32 + try: + if esp32_audio_ws and (esp32_audio_ws.client_state == WebSocketState.CONNECTED): + await esp32_audio_ws.send_text("RESET") + except Exception: + pass + + print("[SYSTEM] full reset done.", flush=True) + +# ========= 启动/停止 YOLO 媒体处理 ========= +def start_yolomedia_with_target(target_name: str): + """启动yolomedia线程,搜索指定物品""" + global yolomedia_thread, yolomedia_stop_event, yolomedia_running, yolomedia_sending_frames + + # 如果已经在运行,先停止 + if yolomedia_running: + stop_yolomedia() + + # 查找对应的YOLO类别 + yolo_class = ITEM_TO_CLASS_MAP.get(target_name, target_name) + print(f"[YOLOMEDIA] Starting with target: {target_name} -> YOLO class: {yolo_class}", flush=True) + print(f"[YOLOMEDIA] Available mappings: {ITEM_TO_CLASS_MAP}", flush=True) # 添加这行调试 + + yolomedia_stop_event.clear() + yolomedia_running = True + yolomedia_sending_frames = False # 重置发送帧状态 + + def _run(): + try: + # 传递目标类别名和停止事件 + yolomedia.main(headless=True, prompt_name=yolo_class, stop_event=yolomedia_stop_event) + except Exception as e: + print(f"[YOLOMEDIA] worker stopped: {e}", flush=True) + finally: + global yolomedia_running, yolomedia_sending_frames + 
yolomedia_running = False + yolomedia_sending_frames = False + + yolomedia_thread = threading.Thread(target=_run, daemon=True) + yolomedia_thread.start() + print(f"[YOLOMEDIA] background worker started for: {yolo_class}(正在初始化,暂时显示原始画面)", flush=True) + +def stop_yolomedia(): + """停止yolomedia线程""" + global yolomedia_thread, yolomedia_stop_event, yolomedia_running, yolomedia_sending_frames + + if yolomedia_running: + print("[YOLOMEDIA] Stopping worker...", flush=True) + yolomedia_stop_event.set() + + # 等待线程结束(最多等5秒) + if yolomedia_thread and yolomedia_thread.is_alive(): + yolomedia_thread.join(timeout=5.0) + + yolomedia_running = False + yolomedia_sending_frames = False + + # 【新增】如果orchestrator在找物品模式,结束时不自动恢复(由命令控制) + # 只清理标志位即可 + print("[YOLOMEDIA] Worker stopped, 等待状态切换.", flush=True) + +# ========= 自定义的 start_ai_with_text,支持识别特殊命令 ========= +async def start_ai_with_text_custom(user_text: str): + """扩展版的AI启动函数,支持识别特殊命令""" + global navigation_active, blind_path_navigator, cross_street_active, cross_street_navigator, orchestrator + + # 【修改】在导航模式和红绿灯检测模式下,只有特定词才进入omni对话 + if orchestrator: + current_state = orchestrator.get_state() + # 如果在导航模式或红绿灯检测模式(非CHAT模式) + if current_state not in ["CHAT", "IDLE"]: + # 检查是否是允许的对话触发词 + allowed_keywords = ["帮我看", "帮我看下", "帮我找", "找一下", "看看", "识别一下"] + is_allowed_query = any(keyword in user_text for keyword in allowed_keywords) + + # 检查是否是导航控制命令 + nav_control_keywords = ["开始过马路", "过马路结束", "开始导航", "盲道导航", "停止导航", "结束导航", + "检测红绿灯", "看红绿灯", "停止检测", "停止红绿灯"] + is_nav_control = any(keyword in user_text for keyword in nav_control_keywords) + + # 如果既不是允许的查询,也不是导航控制命令,则丢弃 + if not is_allowed_query and not is_nav_control: + mode_name = "红绿灯检测" if current_state == "TRAFFIC_LIGHT_DETECTION" else "导航" + print(f"[{mode_name}模式] 丢弃非对话语音: {user_text}") + return # 直接丢弃,不进入omni + + # 【修改】检查是否是过马路相关命令 - 使用orchestrator控制 + if "开始过马路" in user_text or "帮我过马路" in user_text: + # 【新增】如果正在找物品,先停止 + if yolomedia_running: + stop_yolomedia() + 
print("[ITEM_SEARCH] 从找物品模式切换到过马路") + + if orchestrator: + orchestrator.start_crossing() + print(f"[CROSS_STREET] 过马路模式已启动,状态: {orchestrator.get_state()}") + # 播放启动语音并广播到UI + play_voice_text("过马路模式已启动。") + await ui_broadcast_final("[系统] 过马路模式已启动") + else: + print("[CROSS_STREET] 警告:导航统领器未初始化!") + play_voice_text("启动过马路模式失败,请稍后重试。") + await ui_broadcast_final("[系统] 导航系统未就绪") + return + + if "过马路结束" in user_text or "结束过马路" in user_text: + if orchestrator: + orchestrator.stop_navigation() + print(f"[CROSS_STREET] 导航已停止,状态: {orchestrator.get_state()}") + # 播放停止语音并广播到UI + play_voice_text("已停止导航。") + await ui_broadcast_final("[系统] 过马路模式已停止") + else: + await ui_broadcast_final("[系统] 导航系统未运行") + return + + # 【修改】检查是否是红绿灯检测命令 - 实现与盲道导航互斥 + if "检测红绿灯" in user_text or "看红绿灯" in user_text: + try: + import trafficlight_detection + + # 切换orchestrator到红绿灯检测模式(暂停盲道导航) + if orchestrator: + orchestrator.start_traffic_light_detection() + print(f"[TRAFFIC] 切换到红绿灯检测模式,状态: {orchestrator.get_state()}") + + # 【改进】使用主线程模式而不是独立线程,避免掉帧 + success = trafficlight_detection.init_model() # 只初始化模型,不启动线程 + trafficlight_detection.reset_detection_state() # 重置状态 + + if success: + await ui_broadcast_final("[系统] 红绿灯检测已启动") + else: + await ui_broadcast_final("[系统] 红绿灯模型加载失败") + except Exception as e: + print(f"[TRAFFIC] 启动红绿灯检测失败: {e}") + await ui_broadcast_final(f"[系统] 启动失败: {e}") + return + + if "停止检测" in user_text or "停止红绿灯" in user_text: + try: + # 恢复到对话模式 + if orchestrator: + orchestrator.stop_navigation() # 回到CHAT模式 + print(f"[TRAFFIC] 红绿灯检测停止,恢复到{orchestrator.get_state()}模式") + + # 清除红绿灯检测缓存 + global _traffic_light_result_jpeg + _traffic_light_result_jpeg = None + + await ui_broadcast_final("[系统] 红绿灯检测已停止") + except Exception as e: + print(f"[TRAFFIC] 停止红绿灯检测失败: {e}") + await ui_broadcast_final(f"[系统] 停止失败: {e}") + return + + # 【修改】检查是否是导航相关命令 - 使用orchestrator控制 + if "开始导航" in user_text or "盲道导航" in user_text or "帮我导航" in user_text: + # 【新增】如果正在找物品,先停止 + if yolomedia_running: + stop_yolomedia() + 
print("[ITEM_SEARCH] 从找物品模式切换到盲道导航") + + if orchestrator: + orchestrator.start_blind_path_navigation() + print(f"[NAVIGATION] 盲道导航已启动,状态: {orchestrator.get_state()}") + await ui_broadcast_final("[系统] 盲道导航已启动") + else: + print("[NAVIGATION] 警告:导航统领器未初始化!") + await ui_broadcast_final("[系统] 导航系统未就绪") + return + + if "停止导航" in user_text or "结束导航" in user_text: + if orchestrator: + orchestrator.stop_navigation() + print(f"[NAVIGATION] 导航已停止,状态: {orchestrator.get_state()}") + await ui_broadcast_final("[系统] 盲道导航已停止") + else: + await ui_broadcast_final("[系统] 导航系统未运行") + return + + nav_cmd_keywords = ["开始过马路", "过马路结束", "开始导航", "盲道导航", "停止导航", "结束导航", "立即通过", "现在通过", "继续"] + if any(k in user_text for k in nav_cmd_keywords): + if orchestrator: + orchestrator.on_voice_command(user_text) + await ui_broadcast_final("[系统] 导航模式已更新") + else: + await ui_broadcast_final("[系统] 导航统领器未初始化") + return + + # 检查是否是"帮我找/识别一下xxx"的命令 + # 扩展正则表达式,支持更多关键词 + find_pattern = r"(?:^\s*帮我)?\s*找一下\s*(.+?)(?:。|!|?|$)" + match = re.search(find_pattern, user_text) + + if match: + # 提取中文物品名称 + item_cn = match.group(1).strip() + if item_cn: + # 【新增】用本地映射 + Qwen 提取英文类名 + label_en, src = extract_english_label(item_cn) + print(f"[COMMAND] Finder request: '{item_cn}' -> '{label_en}' (src={src})", flush=True) + + # 【新增】切换到找物品模式(暂停导航) + if orchestrator: + orchestrator.start_item_search() + print(f"[ITEM_SEARCH] 已切换到找物品模式,状态: {orchestrator.get_state()}") + + # 【关键】把英文类名传给 yolomedia(它会在找不到类时自动切 YOLOE) + start_yolomedia_with_target(label_en) + + # 给前端/语音来个确认反馈 + try: + await ui_broadcast_final(f"[找物品] 正在寻找 {item_cn}...") + except Exception: + pass + + return + + # 检查是否是"找到了"的命令 + if "找到了" in user_text or "拿到了" in user_text: + print("[COMMAND] Found command detected", flush=True) + # 停止yolomedia + stop_yolomedia() + + # 【新增】停止找物品模式,恢复之前的导航状态 + if orchestrator: + orchestrator.stop_item_search(restore_nav=True) + current_state = orchestrator.get_state() + print(f"[ITEM_SEARCH] 找物品结束,当前状态: {current_state}") + + # 
根据恢复的状态给出反馈 + if current_state in ["BLINDPATH_NAV", "SEEKING_CROSSWALK", "WAIT_TRAFFIC_LIGHT", "CROSSING", "SEEKING_NEXT_BLINDPATH"]: + await ui_broadcast_final("[找物品] 已找到物品,继续导航。") + else: + await ui_broadcast_final("[找物品] 已找到物品。") + else: + await ui_broadcast_final("[找物品] 已找到物品。") + + return + + # 【修改】omni对话开始时,切换到CHAT模式 + global omni_conversation_active, omni_previous_nav_state + omni_conversation_active = True + + # 保存当前导航状态并切换到CHAT模式 + if orchestrator: + current_state = orchestrator.get_state() + # 只有在导航模式下才需要保存和切换 + if current_state not in ["CHAT", "IDLE"]: + omni_previous_nav_state = current_state + orchestrator.force_state("CHAT") + print(f"[OMNI] 对话开始,从{current_state}切换到CHAT模式") + else: + omni_previous_nav_state = None + print(f"[OMNI] 对话开始(当前已在{current_state}模式)") + + # 如果不是特殊命令,执行原有的AI对话逻辑 + # 但如果yolomedia正在运行,暂时不处理普通对话 + if yolomedia_running: + print("[AI] YOLO media is running, skipping normal AI response", flush=True) + return + + # 原有的AI对话逻辑 + await start_ai_with_text(user_text) + +# ========= Omni 播放启动 ========= +async def start_ai_with_text(user_text: str): + """硬重置后,开启新的 AI 语音输出。""" + + # Day 13: 在 AI 处理开始前保存 WebSocket 引用 + from audio_stream import set_tts_websocket + set_tts_websocket(esp32_audio_ws) + + async def _runner_new_pipeline(): + """Day 21: 新管道 - GLM-4.5-Flash + EdgeTTS""" + txt_buf: List[str] = [] + + try: + # 获取图片(如果有) + img_b64 = None + if last_frames: + try: + _, jpeg_bytes = last_frames[-1] + img_b64 = base64.b64encode(jpeg_bytes).decode("ascii") + except Exception: + pass + + # 调用 GLM-4.5-Flash + print(f"[NEW AI] 调用 GLM: {user_text[:50]}...") + ai_response = await glm_chat(user_text, img_b64) + + if not ai_response: + print("[NEW AI] AI 无回复") + return + + txt_buf.append(ai_response) + print(f"[NEW AI] AI 回复: {ai_response}") + await ui_broadcast_partial("[AI] " + ai_response) + + # EdgeTTS 流式合成并发送 + # 设置 VAD TTS 播放状态,避免将 TTS 回声误识别为用户语音 + vad = get_server_vad() + vad.set_tts_playing(True) + + try: + async for audio_chunk in 
text_to_speech_pcm_stream(ai_response): + if audio_chunk: + await broadcast_pcm16_realtime(audio_chunk) + finally: + # TTS 播放结束,恢复 VAD 检测 + vad.set_tts_playing(False) + + print("[NEW AI] 音频播放完成") + + except asyncio.CancelledError: + raise + except Exception as e: + err_msg = f"AI Error: {str(e)}" + print(f"[NEW AI] 错误: {err_msg}") + import traceback + traceback.print_exc() + + # 1. 广播错误到 UI + try: + await ui_broadcast_final(f"[系统] {err_msg}") + except Exception: + pass + + # 2. 发送错误到客户端日志 + if esp32_audio_ws: + try: + await esp32_audio_ws.send_text(f"ERR:{str(e)[:50]}") + except Exception: + pass + + # 3. 语音播报错误 (可选,防止用户以为在思考) + try: + vad = get_server_vad() + vad.set_tts_playing(True) + async for audio_chunk in text_to_speech_pcm_stream("抱歉,我遇到了一些问题。"): + if audio_chunk: + await broadcast_pcm16_realtime(audio_chunk) + vad.set_tts_playing(False) + except Exception: + pass + finally: + global omni_conversation_active, omni_previous_nav_state + omni_conversation_active = False + + if orchestrator and omni_previous_nav_state: + orchestrator.force_state(omni_previous_nav_state) + print(f"[AI] 对话结束,恢复到{omni_previous_nav_state}模式") + omni_previous_nav_state = None + + from audio_stream import stream_clients + for sc in list(stream_clients): + if not sc.abort_event.is_set(): + try: sc.q.put_nowait(b"\x00"*BYTES_PER_20MS_16K) + except Exception: pass + try: sc.q.put_nowait(None) + except Exception: pass + + final_text = ("".join(txt_buf)).strip() or "(空响应)" + await ui_broadcast_final("[AI] " + final_text) + + async def _runner_old_pipeline(): + """旧管道 - Qwen-Omni (流式音频)""" + txt_buf: List[str] = [] + rate_state = None + + # 组装(图像+文本) + content_list = [] + if last_frames: + try: + _, jpeg_bytes = last_frames[-1] + img_b64 = base64.b64encode(jpeg_bytes).decode("ascii") + content_list.append({ + "type": "image_url", + "image_url": {"url": f"data:image/jpeg;base64,{img_b64}"} + }) + except Exception: + pass + content_list.append({"type": "text", "text": user_text}) + + try: + 
async for piece in stream_chat(content_list, voice="Cherry", audio_format="wav"): + if piece.text_delta: + txt_buf.append(piece.text_delta) + try: + await ui_broadcast_partial("[AI] " + "".join(txt_buf)) + except Exception: + pass + + if piece.audio_b64: + try: + pcm24 = base64.b64decode(piece.audio_b64) + except Exception: + pcm24 = b"" + if pcm24: + pcm16k, rate_state = audioop.ratecv(pcm24, 2, 1, 24000, 16000, rate_state) + pcm16k = audioop.mul(pcm16k, 2, 0.60) + if pcm16k: + await broadcast_pcm16_realtime(pcm16k) + + except asyncio.CancelledError: + raise + except Exception as e: + try: + await ui_broadcast_final(f"[AI] 发生错误:{e}") + except Exception: + pass + finally: + global omni_conversation_active, omni_previous_nav_state + omni_conversation_active = False + + if orchestrator and omni_previous_nav_state: + orchestrator.force_state(omni_previous_nav_state) + print(f"[OMNI] 对话结束,恢复到{omni_previous_nav_state}模式") + omni_previous_nav_state = None + else: + print(f"[OMNI] 对话结束(无需恢复导航状态)") + + from audio_stream import stream_clients + for sc in list(stream_clients): + if not sc.abort_event.is_set(): + try: sc.q.put_nowait(b"\x00"*BYTES_PER_20MS_16K) + except Exception: pass + try: sc.q.put_nowait(None) + except Exception: pass + + final_text = ("".join(txt_buf)).strip() or "(空响应)" + try: + await ui_broadcast_final("[AI] " + final_text) + except Exception: + pass + + # 真正启动前先硬重置 + await hard_reset_audio("start_ai_with_text") + loop = asyncio.get_running_loop() + from audio_stream import __dict__ as _as_dict + + # 根据配置选择管道 + if USE_NEW_AI_PIPELINE: + task = loop.create_task(_runner_new_pipeline()) + else: + task = loop.create_task(_runner_old_pipeline()) + + _as_dict["current_ai_task"] = task + +# ---------- 页面 / 健康 ---------- +@app.get("/", response_class=HTMLResponse) +def root(): + with open(os.path.join("templates", "index.html"), "r", encoding="utf-8") as f: + return HTMLResponse(f.read()) + +@app.get("/api/health", response_class=PlainTextResponse) +def 
health(): + return "OK" + +# 注册 /stream.wav +register_stream_route(app) + +# ---------- WebSocket:WebUI 文本(ASR/AI 状态推送) ---------- +@app.websocket("/ws_ui") +async def ws_ui(ws: WebSocket): + await ws.accept() + ui_clients[id(ws)] = ws + try: + init = {"partial": current_partial, "finals": recent_finals[-10:]} + await ws.send_text("INIT:" + json.dumps(init, ensure_ascii=False)) + while True: + await asyncio.sleep(60) + except (WebSocketDisconnect, asyncio.CancelledError): + pass + finally: + ui_clients.pop(id(ws), None) + + +# ---------- Day 21: 新版 AI 音频处理 (SenseVoice + GLM + EdgeTTS) ---------- +async def process_complete_audio_new_pipeline(audio_data: bytes, ws: WebSocket): + """ + 非流式音频处理: + 1. SenseVoice 识别完整音频 + 2. GLM 生成回复 + 3. EdgeTTS 流式合成并发送 + """ + try: + # 1. 语音识别 + print(f"[NEW AI] 开始识别音频: {len(audio_data)} bytes") + user_text = await sensevoice_recognize(audio_data) + + if not user_text or len(user_text.strip()) < 2: + print("[NEW AI] 未识别到有效语音") + return + + print(f"[NEW AI] 用户说: {user_text}") + await ui_broadcast_partial(f"[用户] {user_text}") + + # 检查是否是导航命令 + # 使用现有的 start_ai_with_text_custom 来处理特殊命令 + # 这样可以保持导航功能不变 + + # 2. 调用 GLM 生成回复 + ai_response = await glm_chat(user_text) + + if not ai_response: + print("[NEW AI] AI 无回复") + return + + print(f"[NEW AI] AI 回复: {ai_response}") + await ui_broadcast_final(f"[AI] {ai_response}") + + # 3. 
EdgeTTS 流式合成并发送 + # 设置 VAD TTS 播放状态 + vad = get_server_vad() + vad.set_tts_playing(True) + + try: + async for audio_chunk in text_to_speech_pcm_stream(ai_response): + if audio_chunk: + await broadcast_pcm16_realtime(audio_chunk) + finally: + vad.set_tts_playing(False) + + print("[NEW AI] 音频合成并发送完成") + + except Exception as e: + print(f"[NEW AI] 处理失败: {e}") + import traceback + traceback.print_exc() + + +# ---------- WebSocket:设备音频入口(ASR 上行) ---------- +@app.websocket("/ws_audio") +async def ws_audio(ws: WebSocket): + global esp32_audio_ws + esp32_audio_ws = ws + # Day 20: 连接建立时立即保存 TTS WebSocket 引用 + # 避免因引用丢失导致 TTS 音频无法发送 + from audio_stream import set_tts_websocket + set_tts_websocket(ws) + await ws.accept() + print("\n[AUDIO] client connected (TTS WebSocket reference saved)") + recognition = None + streaming = False + last_ts = time.monotonic() + keepalive_task: Optional[asyncio.Task] = None + audio_buffer = bytearray() # Day 21: 用于新管道收集音频 + + async def stop_rec(send_notice: Optional[str] = None): + nonlocal recognition, streaming, keepalive_task + if keepalive_task and not keepalive_task.done(): + keepalive_task.cancel() + try: await keepalive_task + except Exception: pass + keepalive_task = None + if recognition: + try: recognition.stop() + except Exception: pass + recognition = None + await set_current_recognition(None) + streaming = False + if send_notice: + try: await ws.send_text(send_notice) + except Exception: pass + + async def on_sdk_error(_msg: str): + await stop_rec(send_notice="RESTART") + + async def keepalive_loop(): + nonlocal last_ts, recognition, streaming + try: + while streaming and recognition is not None: + idle = time.monotonic() - last_ts + if idle > 0.35: + try: + for _ in range(30): # ~600ms 静音 + recognition.send_audio_frame(SILENCE_CHUNK) + last_ts = time.monotonic() + except Exception: + await on_sdk_error("keepalive send failed") + return + await asyncio.sleep(0.10) + except asyncio.CancelledError: + return + + try: + while True: + 
if WebSocketState and ws.client_state != WebSocketState.CONNECTED: + break + try: + msg = await ws.receive() + except WebSocketDisconnect: + break + except RuntimeError as e: + if "Cannot call \"receive\"" in str(e): + break + raise + + if "text" in msg and msg["text"] is not None: + raw = (msg["text"] or "").strip() + cmd = raw.upper() + + if cmd == "START": + print("[AUDIO] START received") + await stop_rec() + + # Day 13: 刷新 TTS 缓存 + try: + from audio_stream import flush_tts_buffer + flushed = await flush_tts_buffer(ws) + if flushed > 0: + print(f"[AUDIO] Flushed {flushed} bytes of cached TTS audio") + except Exception as e: + print(f"[AUDIO] Error flushing TTS buffer: {e}") + + if USE_NEW_AI_PIPELINE: + # Day 21: 新管道 - 服务器端 VAD + 非流式 SenseVoice + reset_server_vad() # 重置 VAD 状态 + + # 清除对话历史(新会话开始) + from glm_client import clear_conversation_history + clear_conversation_history() + + streaming = True + await ui_broadcast_partial("(已开始接收音频…)") + await ws.send_text("OK:STARTED") + print("[NEW ASR] 新管道已启动,服务器端 VAD 监听中") + else: + # 旧管道 - 流式 DashScope + loop = asyncio.get_running_loop() + def post(coro): + asyncio.run_coroutine_threadsafe(coro, loop) + + cb = ASRCallback( + on_sdk_error=lambda s: post(on_sdk_error(s)), + post=post, + ui_broadcast_partial=ui_broadcast_partial, + ui_broadcast_final=ui_broadcast_final, + is_playing_now_fn=is_playing_now, + start_ai_with_text_fn=start_ai_with_text_custom, + full_system_reset_fn=full_system_reset, + interrupt_lock=interrupt_lock, + ) + + recognition = dash_audio.asr.Recognition( + api_key=API_KEY, model=MODEL, format=AUDIO_FMT, + sample_rate=SAMPLE_RATE, callback=cb + ) + recognition.start() + await set_current_recognition(recognition) + streaming = True + last_ts = time.monotonic() + keepalive_task = asyncio.create_task(keepalive_loop()) + await ui_broadcast_partial("(已开始接收音频…)") + await ws.send_text("OK:STARTED") + + elif cmd == "STOP": + if recognition: + for _ in range(15): # ~300ms 静音 + try: 
recognition.send_audio_frame(SILENCE_CHUNK) + except Exception: break + await stop_rec(send_notice="OK:STOPPED") + + elif cmd == "RECOGNIZE" and USE_NEW_AI_PIPELINE: + # Day 21: 客户端 VAD 检测到语音结束,请求识别 + if audio_buffer and len(audio_buffer) > 3200: # 至少 100ms 音频 + print(f"[NEW ASR] 收到 RECOGNIZE 命令,音频长度: {len(audio_buffer)} bytes") + await ui_broadcast_partial("(正在识别…)") + + # 非流式识别 + user_text = await sensevoice_recognize(bytes(audio_buffer)) + audio_buffer.clear() + + if user_text and len(user_text.strip()) >= 2: + print(f"[NEW ASR] 识别结果: {user_text}") + await ui_broadcast_final(f"[用户] {user_text}") + + # 调用 AI 回复 + async with interrupt_lock: + await start_ai_with_text_custom(user_text) + await ws.send_text("OK:RECOGNIZED") + else: + print("[NEW ASR] 未识别到有效语音") + await ws.send_text("OK:EMPTY") + else: + print("[NEW ASR] 音频太短,忽略") + await ws.send_text("OK:TOO_SHORT") + + elif raw.startswith("PROMPT:"): + text = raw[len("PROMPT:"):].strip() + if text: + async with interrupt_lock: + await start_ai_with_text_custom(text) + await ws.send_text("OK:PROMPT_ACCEPTED") + else: + await ws.send_text("ERR:EMPTY_PROMPT") + + elif "bytes" in msg and msg["bytes"] is not None: + audio_bytes = msg["bytes"] + if not hasattr(ws_audio, '_audio_recv_count'): + ws_audio._audio_recv_count = 0 + ws_audio._audio_total_bytes = 0 + ws_audio._audio_recv_count += 1 + ws_audio._audio_total_bytes += len(audio_bytes) + + if ws_audio._audio_recv_count % 500 == 0: + print(f"[AUDIO] 📥 Received: {ws_audio._audio_recv_count} packets, {ws_audio._audio_total_bytes} bytes total") + + if USE_NEW_AI_PIPELINE: + # Day 21 改进: 使用服务器端 VAD 检测语音 + if streaming: + vad = get_server_vad() + vad_result = vad.process(audio_bytes) + + if vad_result['speech_started']: + await ui_broadcast_partial("(正在录音…)") + + if vad_result['speech_ended'] and vad_result['speech_audio']: + # VAD 检测到语音结束,自动触发识别 + speech_audio = vad_result['speech_audio'] + print(f"[VAD] 自动触发识别,音频长度: {len(speech_audio)} bytes") + await 
ui_broadcast_partial("(正在识别…)") + + # 非流式识别 + user_text = await sensevoice_recognize(speech_audio) + + if user_text and len(user_text.strip()) >= 2: + print(f"[NEW ASR] 识别结果: {user_text}") + await ui_broadcast_final(f"[用户] {user_text}") + + # 调用 AI 回复 + async with interrupt_lock: + await start_ai_with_text_custom(user_text) + await ws.send_text("OK:RECOGNIZED") + else: + print("[NEW ASR] 未识别到有效语音") + await ws.send_text("OK:EMPTY") + else: + # 旧管道:实时发送到 DashScope + if streaming and recognition: + try: + recognition.send_audio_frame(audio_bytes) + last_ts = time.monotonic() + except Exception: + await on_sdk_error("send_audio_frame failed") + + except Exception as e: + print(f"\n[WS ERROR] {e}") + finally: + await stop_rec() + try: + if WebSocketState is None or ws.client_state == WebSocketState.CONNECTED: + await ws.close(code=1000) + except Exception: + pass + if esp32_audio_ws is ws: + esp32_audio_ws = None + print("[WS] connection closed") + +# ---------- WebSocket:设备相机入口(JPEG 二进制) ---------- +@app.websocket("/ws/camera") +async def ws_camera_esp(ws: WebSocket): + global esp32_camera_ws, blind_path_navigator, cross_street_navigator, cross_street_active, navigation_active, orchestrator + if esp32_camera_ws is not None: + await ws.close(code=1013) + return + esp32_camera_ws = ws + await ws.accept() + print("[CAMERA] 设备已连接") + + # 【新增】初始化盲道导航器 + if blind_path_navigator is None and yolo_seg_model is not None: + blind_path_navigator = BlindPathNavigator(yolo_seg_model, obstacle_detector) + print("[NAVIGATION] 盲道导航器已初始化") + else: + if blind_path_navigator is not None: + print("[NAVIGATION] 导航器已存在,无需重新初始化") + elif yolo_seg_model is None: + print("[NAVIGATION] 警告:YOLO模型未加载,无法初始化导航器") + + # 【新增】初始化过马路导航器 + if cross_street_navigator is None: + if yolo_seg_model: + cross_street_navigator = CrossStreetNavigator( + seg_model=yolo_seg_model, + coco_model=None, # 不使用交通灯检测 + obs_model=None # 暂时也不用障碍物检测,让它更快 + ) + print("[CROSS_STREET] 过马路导航器已初始化(简化版 - 仅斑马线检测)") + else: + 
print("[CROSS_STREET] 错误:缺少分割模型,无法初始化过马路导航器") + + if not yolo_seg_model: + print("[CROSS_STREET] - 缺少分割模型 (yolo_seg_model)") + if not obstacle_detector: + print("[CROSS_STREET] - 缺少障碍物检测器 (obstacle_detector)") + + if orchestrator is None and blind_path_navigator is not None and cross_street_navigator is not None: + orchestrator = NavigationMaster(blind_path_navigator, cross_street_navigator) + print("[NAV MASTER] 统领状态机已初始化(托管模式)") + frame_counter = 0 # 添加帧计数器 + + # Day 20: 性能诊断变量 + _perf_last_frame_time = None + _perf_frame_intervals = [] + _perf_broadcast_times = [] + _perf_nav_times = [] + + try: + while True: + msg = await ws.receive() + if "bytes" in msg and msg["bytes"] is not None: + data = msg["bytes"] + frame_counter += 1 + + # Day 20: 记录帧接收时间 + _perf_frame_time = time.perf_counter() + if _perf_last_frame_time is not None: + _perf_frame_intervals.append(_perf_frame_time - _perf_last_frame_time) + _perf_last_frame_time = _perf_frame_time + + # 【新增】录制原始帧 + try: + sync_recorder.record_frame(data) + except Exception as e: + if frame_counter % 100 == 0: # 避免日志刷屏 + print(f"[RECORDER] 录制帧失败: {e}") + + try: + last_frames.append((time.time(), data)) + except Exception: + pass + + # 推送到bridge_io(供yolomedia使用) + bridge_io.push_raw_jpeg(data) + + # 【调试】检查导航条件 + if frame_counter % 60 == 0: # 每60帧输出一次(约5-6秒) + state_dbg = orchestrator.get_state() if orchestrator else "N/A" + + # Day 20: 性能诊断汇总 + if _perf_frame_intervals: + avg_interval = sum(_perf_frame_intervals) / len(_perf_frame_intervals) * 1000 + fps = 1000 / avg_interval if avg_interval > 0 else 0 + _perf_frame_intervals.clear() + else: + avg_interval = 0 + fps = 0 + + avg_broadcast = sum(_perf_broadcast_times) / len(_perf_broadcast_times) if _perf_broadcast_times else 0 + avg_nav = sum(_perf_nav_times) / len(_perf_nav_times) if _perf_nav_times else 0 + _perf_broadcast_times.clear() + _perf_nav_times.clear() + + print(f"[PERF] 帧:{frame_counter} | 客户端FPS:{fps:.1f} | 帧间隔:{avg_interval:.1f}ms | " + 
f"广播:{avg_broadcast:.1f}ms | 导航:{avg_nav:.1f}ms | state={state_dbg}") + + # 【Day 19 优化】延迟解码:只在需要处理时才解码,避免白白浪费 CPU + # 先检查是否需要导航处理 + needs_processing = (orchestrator and not yolomedia_running) + bgr = None # 延迟初始化 + + if needs_processing: + current_state = orchestrator.get_state() + + # 【Day 19】ITEM_SEARCH/CHAT/IDLE 模式无需处理,直接转发原始 JPEG + if current_state in ("ITEM_SEARCH", "CHAT", "IDLE"): + if not yolomedia_sending_frames and camera_viewers: + await _broadcast_to_viewers(data) # 零拷贝直传 + continue + + # 需要导航处理时才解码 + try: + bgr = turbo_decode(data) + if bgr is None or bgr.size == 0: + if frame_counter % 30 == 0: + print(f"[JPEG] 解码失败:数据长度={len(data)}") + bgr = None + except Exception as e: + if frame_counter % 30 == 0: + print(f"[JPEG] 解码异常: {e}") + bgr = None + + # 【托管】优先交给统领状态机(寻物未占用画面时) + if orchestrator and not yolomedia_running and bgr is not None: + out_img = bgr # 默认输出原图 + try: + # 【新增】检查是否在红绿灯检测模式 + if current_state == "TRAFFIC_LIGHT_DETECTION": + # 红绿灯检测模式:使用跳帧机制避免阻塞 + import trafficlight_detection + global _traffic_light_task, _traffic_light_result_jpeg, _traffic_light_pending_frame + + # 更新待处理帧 + _traffic_light_pending_frame = bgr + + # 如果没有正在运行的任务,启动一个 + if _traffic_light_task is None or _traffic_light_task.done(): + if _traffic_light_task is not None and _traffic_light_task.done(): + try: + result = _traffic_light_task.result() + if result and result.get('vis_image') is not None: + enc = turbo_encode(result['vis_image'], quality=80) + if enc: + _traffic_light_result_jpeg = enc + except Exception: + pass + + # 启动新任务 + if _traffic_light_pending_frame is not None: + frame = _traffic_light_pending_frame + _traffic_light_pending_frame = None + loop = asyncio.get_event_loop() + _traffic_light_task = loop.run_in_executor( + frame_processing_executor, + trafficlight_detection.process_single_frame, + frame, + None + ) + + # 广播红绿灯检测结果(独立于盲道导航缓存) + if camera_viewers: + if _traffic_light_result_jpeg is not None: + await 
_broadcast_to_viewers(_traffic_light_result_jpeg) + else: + await _broadcast_to_viewers(data) # 首帧回退 + continue # 跳过盲道导航的广播逻辑 + else: + # 【Day 15 跳帧机制】非阻塞式帧处理 + # 不等待处理完成,使用最后一次成功的结果 + global _nav_processing_task, _nav_last_result_image, _nav_last_result_jpeg, _nav_pending_frame + + # 更新待处理帧(始终是最新的) + _nav_pending_frame = bgr + + # 如果没有正在运行的任务,启动一个 + if _nav_processing_task is None or _nav_processing_task.done(): + # 检查上一个任务的结果 + if _nav_processing_task is not None and _nav_processing_task.done(): + # Day 20: 记录处理耗时 + global _nav_task_start_time + if _nav_task_start_time is not None: + nav_elapsed = (time.perf_counter() - _nav_task_start_time) * 1000 + _perf_nav_times.append(nav_elapsed) + _nav_task_start_time = None + + try: + res = _nav_processing_task.result() + if res is not None: + _nav_last_result_image = res.annotated_image + # 【Day 19 优化】立即编码并缓存 JPEG,避免每帧重复编码 + if _nav_last_result_image is not None: + # 使用 TurboJPEG 编码 + enc_result = turbo_encode(_nav_last_result_image, quality=80) + if enc_result: + _nav_last_result_jpeg = enc_result + # 语音引导 + if res.guidance_text: + try: + # Day 21 优化:视觉优先级中断 + # 当检测到近距离障碍物时,打断正在进行的 AI 对话 + obstacle_keywords = ['前方有', '停一下', '注意避让', '左侧有', '右侧有'] + is_obstacle_warning = any(kw in res.guidance_text for kw in obstacle_keywords) + + if is_obstacle_warning: + # 检查是否有正在进行的 AI 对话 + if is_playing_now(): + # 打断 AI 对话,优先播报障碍物警告 + print(f"[PRIORITY INTERRUPT] 检测到障碍物警告,打断AI对话: {res.guidance_text}") + asyncio.create_task(hard_reset_audio("Obstacle priority interrupt")) + + play_voice_text(res.guidance_text) + asyncio.create_task(ui_broadcast_final(f"[导航] {res.guidance_text}")) + except Exception: + pass + except Exception: + print(f"[NAV MASTER] 获取导航结果异常:") + traceback.print_exc() + + # 启动新的处理任务 + if _nav_pending_frame is not None: + frame_to_process = _nav_pending_frame + _nav_pending_frame = None + _nav_task_start_time = time.perf_counter() # Day 20: 记录开始时间 + loop = asyncio.get_event_loop() + _nav_processing_task = 
loop.run_in_executor( + frame_processing_executor, + orchestrator.process_frame, + frame_to_process + ) + + # 使用最后一次成功的结果(不阻塞等待) + out_img = _nav_last_result_image if _nav_last_result_image is not None else bgr + except Exception as e: + if frame_counter % 100 == 0: + print(f"[NAV MASTER] 处理帧时出错: {e}") + + # 【Day 19 优化】广播导航结果,优先使用缓存的 JPEG + if camera_viewers: + _t_broadcast = time.perf_counter() # Day 20: 计时 + # 如果有缓存的 JPEG(导航结果),直接使用 + if _nav_last_result_jpeg is not None: + await _broadcast_to_viewers(_nav_last_result_jpeg) + elif out_img is not None: + # 回退:使用 TurboJPEG 编码当前帧 + enc_result = turbo_encode(out_img, quality=80) + if enc_result: + await _broadcast_to_viewers(enc_result) + else: + # 【Day 23 修复】首帧回退:导航刚启动时无处理结果,直接广播原始帧 + await _broadcast_to_viewers(data) + _perf_broadcast_times.append((time.perf_counter() - _t_broadcast) * 1000) # Day 20 + # 已托管,进入下一帧 + continue + + # 【Day 19 优化】零拷贝直传:原始 JPEG 直接转发,无需解码再编码 + # 之前的问题:imdecode + imencode 浪费 CPU,原始 data 就是 JPEG + if not yolomedia_sending_frames and camera_viewers: + try: + # 直接转发原始 JPEG 数据,跳过解码-编码循环 + await _broadcast_to_viewers(data) + except Exception as e: + print(f"[CAMERA] Broadcast error: {e}") + + elif "type" in msg and msg["type"] in ("websocket.close", "websocket.disconnect"): + break + except WebSocketDisconnect: + pass + except Exception as e: + print(f"[CAMERA ERROR] {e}") + finally: + try: + if WebSocketState is None or ws.client_state == WebSocketState.CONNECTED: + await ws.close(code=1000) + except Exception: + pass + esp32_camera_ws = None + print("[CAMERA] 设备已断开") + + # 【新增】清理导航状态 + if blind_path_navigator: + blind_path_navigator.reset() + if cross_street_navigator: + cross_street_navigator.reset() + if orchestrator: + orchestrator.reset() + print("[NAV MASTER] 统领器已重置") + +# ---------- WebSocket:浏览器订阅相机帧 ---------- +@app.websocket("/ws/viewer") +async def ws_viewer(ws: WebSocket): + await ws.accept() + camera_viewers.add(ws) + print(f"[VIEWER] Browser connected. 
Total viewers: {len(camera_viewers)}", flush=True) + try: + while True: + # 保持连接活跃 + await asyncio.sleep(60) + except (WebSocketDisconnect, asyncio.CancelledError): + pass # 正常关闭,静默处理 + finally: + try: + camera_viewers.remove(ws) + except Exception: + pass + print(f"[VIEWER] Removed. Total viewers: {len(camera_viewers)}", flush=True) + +# ---------- WebSocket:浏览器订阅 IMU ---------- +@app.websocket("/ws") +async def ws_imu(ws: WebSocket): + await ws.accept() + imu_ws_clients.add(ws) + try: + while True: + await asyncio.sleep(60) + except (WebSocketDisconnect, asyncio.CancelledError): + pass # 正常关闭,静默处理 + finally: + imu_ws_clients.discard(ws) + +async def imu_broadcast(msg: str): + if not imu_ws_clients: return + dead = [] + for ws in list(imu_ws_clients): + try: + await ws.send_text(msg) + except Exception: + dead.append(ws) + for ws in dead: + imu_ws_clients.discard(ws) + +# ---------- 服务端 IMU 估计(原样保留) ---------- +from math import atan2, hypot, pi +GRAV_BETA = 0.98 +STILL_W = 0.4 +YAW_DB = 0.08 +YAW_LEAK = 0.2 +ANG_EMA = 0.15 +AUTO_REZERO = True +USE_PROJ = True +FREEZE_STILL= True +G = 9.807 +A_TOL = 0.08 * G +gLP = {"x":0.0, "y":0.0, "z":0.0} +gOff= {"x":0.0, "y":0.0, "z":0.0} +BIAS_ALPHA = 0.002 +yaw = 0.0 +Rf = Pf = Yf = 0.0 +ref = {"roll":0.0, "pitch":0.0, "yaw":0.0} +holdStart = 0.0 +isStill = False +last_ts_imu = 0.0 +last_wall = 0.0 +imu_store: List[Dict[str, Any]] = [] + +def _wrap180(a: float) -> float: + a = a % 360.0 + if a >= 180.0: a -= 360.0 + if a < -180.0: a += 360.0 + return a + +def process_imu_and_maybe_store(d: Dict[str, Any]): + global gLP, gOff, yaw, Rf, Pf, Yf, ref, holdStart, isStill, last_ts_imu, last_wall + + t_ms = float(d.get("ts", 0.0)) + now_wall = time.monotonic() + if t_ms <= 0.0: + t_ms = (now_wall * 1000.0) + if last_ts_imu <= 0.0 or t_ms <= last_ts_imu or (t_ms - last_ts_imu) > 3000.0: + dt = 0.02 + else: + dt = (t_ms - last_ts_imu) / 1000.0 + last_ts_imu = t_ms + + ax = float(((d.get("accel") or {}).get("x", 0.0))) + ay = 
float(((d.get("accel") or {}).get("y", 0.0))) + az = float(((d.get("accel") or {}).get("z", 0.0))) + wx = float(((d.get("gyro") or {}).get("x", 0.0))) + wy = float(((d.get("gyro") or {}).get("y", 0.0))) + wz = float(((d.get("gyro") or {}).get("z", 0.0))) + + gLP["x"] = GRAV_BETA * gLP["x"] + (1.0 - GRAV_BETA) * ax + gLP["y"] = GRAV_BETA * gLP["y"] + (1.0 - GRAV_BETA) * ay + gLP["z"] = GRAV_BETA * gLP["z"] + (1.0 - GRAV_BETA) * az + gmag = hypot(gLP["x"], gLP["y"], gLP["z"]) or 1.0 + gHat = {"x": gLP["x"]/gmag, "y": gLP["y"]/gmag, "z": gLP["z"]/gmag} + + roll = (atan2(az, ay) * 180.0 / pi) + pitch = (atan2(-ax, ay) * 180.0 / pi) + + aNorm = hypot(ax, ay, az); wNorm = hypot(wx, wy, wz) + nearFlat = (abs(roll) < 2.0 and abs(pitch) < 2.0) + stillCond = (abs(aNorm - G) < A_TOL) and (wNorm < STILL_W) + + if stillCond: + if holdStart <= 0.0: holdStart = t_ms + if not isStill and (t_ms - holdStart) > 350.0: isStill = True + gOff["x"] = (1.0 - BIAS_ALPHA)*gOff["x"] + BIAS_ALPHA*wx + gOff["y"] = (1.0 - BIAS_ALPHA)*gOff["y"] + BIAS_ALPHA*wy + gOff["z"] = (1.0 - BIAS_ALPHA)*gOff["z"] + BIAS_ALPHA*wz + else: + holdStart = 0.0; isStill = False + + if USE_PROJ: + yawdot = ((wx - gOff["x"])*gHat["x"] + (wy - gOff["y"])*gHat["y"] + (wz - gOff["z"])*gHat["z"]) + else: + yawdot = (wy - gOff["y"]) + + if abs(yawdot) < YAW_DB: yawdot = 0.0 + if FREEZE_STILL and stillCond: yawdot = 0.0 + + yaw = _wrap180(yaw + yawdot * dt) + + if (YAW_LEAK > 0.0) and nearFlat and stillCond and abs(yaw) > 0.0: + step = YAW_LEAK * dt * (-1.0 if yaw > 0 else (1.0 if yaw < 0 else 0.0)) + if abs(yaw) <= abs(step): yaw = 0.0 + else: yaw += step + + global Rf, Pf, Yf, ref, last_wall + Rf = ANG_EMA * roll + (1.0 - ANG_EMA) * Rf + Pf = ANG_EMA * pitch + (1.0 - ANG_EMA) * Pf + Yf = ANG_EMA * yaw + (1.0 - ANG_EMA) * Yf + + if AUTO_REZERO and nearFlat and (wNorm < STILL_W): + if holdStart <= 0.0: holdStart = t_ms + if not isStill and (t_ms - holdStart) > 350.0: + ref.update({"roll": Rf, "pitch": Pf, "yaw": Yf}) + 
isStill = True + + R = _wrap180(Rf - ref["roll"]) + P = _wrap180(Pf - ref["pitch"]) + Y = _wrap180(Yf - ref["yaw"]) + + now_wall = time.monotonic() + if last_wall <= 0.0 or (now_wall - last_wall) >= 0.100: + last_wall = now_wall + item = { + "ts": t_ms/1000.0, + "angles": {"roll": R, "pitch": P, "yaw": Y}, + "accel": {"x": ax, "y": ay, "z": az}, + "gyro": {"x": wx, "y": wy, "z": wz}, + } + imu_store.append(item) + +# ---------- UDP 接收 IMU 并转发 ---------- +class UDPProto(asyncio.DatagramProtocol): + def connection_made(self, transport): + print(f"[UDP] listening on {UDP_IP}:{UDP_PORT}") + def datagram_received(self, data, addr): + try: + s = data.decode('utf-8', errors='ignore').strip() + d = json.loads(s) + if 'ts' not in d and 'timestamp_ms' in d: + d['ts'] = d.pop('timestamp_ms') + process_imu_and_maybe_store(d) + asyncio.create_task(imu_broadcast(json.dumps(d))) + except Exception: + pass + + + + + + + +# --- 导出接口(可选) --- +def get_last_frames(): + return last_frames + +def get_camera_ws(): + return esp32_camera_ws + +if __name__ == "__main__": + import signal + import logging + + # Day 13: 抑制 Ctrl+C 时的 asyncio CancelledError 日志 + logging.getLogger("uvicorn.error").setLevel(logging.CRITICAL) + + # Day 13: 移除重复的信号处理器,模块级别已经处理了 + # 信号处理在模块顶部已注册 + + # Day 20: Numba JIT 预热,避免首次调用时的编译延迟 + try: + from numba_utils import warmup as numba_warmup + numba_warmup() + except ImportError: + print("[启动] Numba 未安装,跳过预热") + + uvicorn.run( + app, host="0.0.0.0", port=8081, + log_level="warning", access_log=False, + loop="asyncio", workers=1, reload=False + ) diff --git a/asr_core.py b/asr_core.py new file mode 100644 index 0000000..b3cdcdd --- /dev/null +++ b/asr_core.py @@ -0,0 +1,221 @@ +# asr_core.py +# -*- coding: utf-8 -*- +import os, json, asyncio +from typing import Any, Dict, List, Optional, Callable, Tuple + +ASR_DEBUG_RAW = os.getenv("ASR_DEBUG_RAW", "0") == "1" + +def _shorten(s: str, limit: int = 200) -> str: + if not s: + return "" + return s if len(s) <= limit else 
(s[:limit] + "…") + +def _safe_to_dict(x: Any) -> Dict[str, Any]: + if isinstance(x, dict): return x + for attr in ("to_dict", "model_dump", "__dict__"): + try: + v = getattr(x, attr, None) + except Exception: + v = None + if callable(v): + try: + d = v() + if isinstance(d, dict): return d + except Exception: + pass + elif isinstance(v, dict): + return v + try: + s = str(x) + if s and s.lstrip().startswith("{") and s.rstrip().endswith("}"): + return json.loads(s) + except Exception: + pass + return {"_raw": str(x)} + +def _extract_sentence(event_obj: Any) -> Tuple[Optional[str], Optional[bool]]: + d = _safe_to_dict(event_obj) + cands: List[Dict[str, Any]] = [d] + for k in ("output", "data", "result"): + v = d.get(k) + if isinstance(v, dict): + cands.append(v) + for obj in cands: + sent = obj.get("sentence") + if isinstance(sent, dict): + text = sent.get("text") + is_end = sent.get("sentence_end") + if is_end is not None: + is_end = bool(is_end) + return text, is_end + for obj in cands: + if "text" in obj and isinstance(obj.get("text"), str): + return obj.get("text"), None + return None, None + +# ====== 仅热词触发的“全清零复位”配置 ====== +INTERRUPT_KEYWORDS = set( + os.getenv("INTERRUPT_KEYWORDS", "停下,别说了,停止").split(",") +) + +# Day 21: 导航控制白名单 - 这些命令包含热词但不应触发重置 +# 例如"停止导航"包含"停止",但不应该触发 full_system_reset +NAV_CONTROL_WHITELIST = [ + "停止导航", "结束导航", "停止检测", "停止红绿灯", + "开始导航", "盲道导航", "开始过马路", "过马路结束", + "帮我导航", "帮我过马路" +] + + +def _normalize_cn(s: str) -> str: + try: + import unicodedata + s = "".join(" " if unicodedata.category(ch) == "Zs" else ch for ch in s) + s = s.strip().lower() + except Exception: + s = (s or "").strip().lower() + return s + +# ============ ASR 全局总闸 ============ +_current_recognition: Optional[object] = None +_rec_lock = asyncio.Lock() + +async def set_current_recognition(r): + global _current_recognition + async with _rec_lock: + _current_recognition = r + +async def stop_current_recognition(): + global _current_recognition + async with _rec_lock: + r = 
_current_recognition + _current_recognition = None + if r: + try: + r.stop() # DashScope SDK 的实时识别停止 + except Exception: + pass + +# ============ ASR 回调 ============ +class ASRCallback: + """ + 设计目标: + 1) “停下 / 别说了 …”等热词一出现 → 立刻全清零复位(恢复到刚启动后的状态)。 + 2) 除此之外【不接受打断】;AI 正在播报时,用户说话只做展示,不触发新一轮。 + 3) 不再用 partial 叠加字符串;partial 只用于 UI 临时展示;只有 final sentence 用于驱动 AI。 + """ + + def __init__( + self, + on_sdk_error: Callable[[str], None], + post: Callable[[asyncio.Future], None], + ui_broadcast_partial, + ui_broadcast_final, + is_playing_now_fn: Callable[[], bool], + start_ai_with_text_fn, # async (text) + full_system_reset_fn, # async (reason) + interrupt_lock: asyncio.Lock, + ): + self._on_sdk_error = on_sdk_error + self._post = post + self._last_partial_for_ui: str = "" # 只用于 UI 展示 + self._last_final_text: str = "" # 以句末 final 为准 + self._hot_interrupted: bool = False # 本句是否因热词触发过复位(防抖) + + self._ui_partial = ui_broadcast_partial + self._ui_final = ui_broadcast_final + self._is_playing = is_playing_now_fn + self._start_ai = start_ai_with_text_fn + self._full_reset = full_system_reset_fn + self._interrupt_lock = interrupt_lock + + def on_open(self): pass + def on_close(self): pass + def on_complete(self): pass + + def on_error(self, err): + try: + self._post(self._ui_partial("")) + self._on_sdk_error(str(err)) + except Exception: + pass + + def on_result(self, result): self._handle(result) + def on_event(self, event): self._handle(event) + + def _has_hotword(self, text: str) -> bool: + """Day 21 修复: 检查是否包含热词,但排除导航控制命令""" + t = _normalize_cn(text) + if not t: return False + + # Day 21: 先检查是否是导航控制命令(白名单),如果是则不触发热词 + for nav_cmd in NAV_CONTROL_WHITELIST: + if _normalize_cn(nav_cmd) in t: + return False # 导航命令不触发热词重置 + + # 检查热词 + for w in INTERRUPT_KEYWORDS: + if w and _normalize_cn(w) in t: + return True + return False + + def _handle(self, event: Any): + if ASR_DEBUG_RAW: + try: + rawd = _safe_to_dict(event) + print("[ASR EVENT RAW]", json.dumps(rawd, ensure_ascii=False), 
class AudioCompressor:
    """Audio codecs for 16-bit mono PCM: G.711 mu-law and IMA ADPCM.

    Used to shrink pre-recorded prompt audio before it is cached/sent.
    All methods are stateless static helpers.
    """

    @staticmethod
    def pcm16_to_ulaw(pcm_data: bytes) -> bytes:
        """Compress 16-bit PCM to 8-bit mu-law (50%: 16 bit -> 8 bit)."""
        samples = np.frombuffer(pcm_data, dtype=np.int16)
        return bytes(AudioCompressor._linear_to_ulaw(int(s)) for s in samples)

    @staticmethod
    def ulaw_to_pcm16(ulaw_data: bytes) -> bytes:
        """Expand 8-bit mu-law back to 16-bit PCM."""
        pcm = [AudioCompressor._ulaw_to_linear(b) for b in ulaw_data]
        return np.array(pcm, dtype=np.int16).tobytes()

    @staticmethod
    def _linear_to_ulaw(sample: int) -> int:
        """Encode one 16-bit linear sample as a G.711 mu-law byte.

        BUGFIX: the previous version scanned the exponent from the LOW bit
        upward and its decoder never subtracted the bias, so a round trip of
        silence (0) produced ~16896 instead of 0. This follows the classic
        CCITT/Sun reference algorithm.
        """
        ULAW_BIAS = 0x84
        ULAW_CLIP = 32635

        sign = 0
        if sample < 0:
            sign = 0x80
            sample = -sample
        if sample > ULAW_CLIP:
            sample = ULAW_CLIP
        sample += ULAW_BIAS

        # Segment = position of the highest set bit between 7 and 14.
        exponent = 7
        mask = 0x4000
        while exponent > 0 and not (sample & mask):
            exponent -= 1
            mask >>= 1

        mantissa = (sample >> (exponent + 3)) & 0x0F
        return ~(sign | (exponent << 4) | mantissa) & 0xFF

    @staticmethod
    def _ulaw_to_linear(ulawbyte: int) -> int:
        """Decode one mu-law byte back to a 16-bit linear sample."""
        ULAW_BIAS = 0x84
        ulawbyte = ~ulawbyte & 0xFF
        sign = ulawbyte & 0x80
        exponent = (ulawbyte >> 4) & 0x07
        mantissa = ulawbyte & 0x0F
        # Subtracting the bias here is what makes decode(encode(0)) == 0.
        sample = (((mantissa << 3) + ULAW_BIAS) << exponent) - ULAW_BIAS
        return -sample if sign else sample

    # IMA ADPCM step-size table (89 entries) and index-adjust table.
    # Class-level so encoder and decoder are guaranteed to share one copy.
    _STEP_TABLE = (
        7, 8, 9, 10, 11, 12, 13, 14, 16, 17,
        19, 21, 23, 25, 28, 31, 34, 37, 41, 45,
        50, 55, 60, 66, 73, 80, 88, 97, 107, 118,
        130, 143, 157, 173, 190, 209, 230, 253, 279, 307,
        337, 371, 408, 449, 494, 544, 598, 658, 724, 796,
        876, 963, 1060, 1166, 1282, 1411, 1552, 1707, 1878, 2066,
        2272, 2499, 2749, 3024, 3327, 3660, 4026, 4428, 4871, 5358,
        5894, 6484, 7132, 7845, 8630, 9493, 10442, 11487, 12635, 13899,
        15289, 16818, 18500, 20350, 22385, 24623, 27086, 29794, 32767,
    )
    _INDEX_TABLE = (-1, -1, -1, -1, 2, 4, 6, 8)

    @staticmethod
    def _adpcm_step(nibble: int, predicted: int, step_index: int):
        """Apply one ADPCM nibble to the predictor state.

        Shared by encoder and decoder so both stay in lockstep.
        Returns the new (predicted, step_index) pair.
        """
        step = AudioCompressor._STEP_TABLE[step_index]
        diff = step >> 3
        if nibble & 4:
            diff += step
        if nibble & 2:
            diff += step >> 1
        if nibble & 1:
            diff += step >> 2
        if nibble & 8:
            predicted -= diff
        else:
            predicted += diff
        # Clamp to the int16 range.
        predicted = max(-32768, min(32767, predicted))
        step_index += AudioCompressor._INDEX_TABLE[nibble & 7]
        step_index = max(0, min(88, step_index))
        return predicted, step_index

    @staticmethod
    def pcm16_to_adpcm(pcm_data: bytes) -> bytes:
        """Compress 16-bit PCM to 4-bit IMA ADPCM (75%: 16 bit -> 4 bit).

        Wire format: 3-byte little-endian header '<hB' carrying the initial
        predictor value and step index, then two 4-bit samples per byte
        (low nibble first). Keeps good voice quality at a quarter the size.
        """
        samples = np.frombuffer(pcm_data, dtype=np.int16)
        step_table = AudioCompressor._STEP_TABLE

        predicted = 0
        step_index = 0
        # Header describes the INITIAL decoder state (always 0, 0 here).
        header = struct.pack('<hB', predicted, step_index)

        out = bytearray()
        for i in range(0, len(samples), 2):
            byte = 0
            for j in range(2):
                if i + j >= len(samples):
                    break
                # int() avoids numpy int16 wraparound in the subtraction.
                sample = int(samples[i + j])
                diff = sample - predicted
                step = step_table[step_index]

                # Quantize the difference to a sign + 3 magnitude bits.
                nibble = 0
                if diff < 0:
                    nibble = 8
                    diff = -diff
                if diff >= step:
                    nibble |= 4
                    diff -= step
                step >>= 1
                if diff >= step:
                    nibble |= 2
                    diff -= step
                step >>= 1
                if diff >= step:
                    nibble |= 1

                # Track the decoder's reconstruction exactly.
                predicted, step_index = AudioCompressor._adpcm_step(
                    nibble, predicted, step_index)

                if j == 0:
                    byte = nibble
                else:
                    byte |= nibble << 4
            out.append(byte)

        return header + bytes(out)

    @staticmethod
    def adpcm_to_pcm16(adpcm_data: bytes) -> bytes:
        """Expand IMA ADPCM (as produced by pcm16_to_adpcm) to 16-bit PCM.

        Returns b'' when the input is too short to hold the 3-byte header.
        Note: output always has an even sample count (two nibbles per byte).
        """
        if len(adpcm_data) < 3:
            return b''

        predicted, step_index = struct.unpack('<hB', adpcm_data[:3])
        step_index = max(0, min(88, step_index))  # defend against bad input

        pcm = []
        for byte in adpcm_data[3:]:
            for shift in (0, 4):  # low nibble first, matching the encoder
                nibble = (byte >> shift) & 0x0F
                predicted, step_index = AudioCompressor._adpcm_step(
                    nibble, predicted, step_index)
                pcm.append(predicted)

        return np.array(pcm, dtype=np.int16).tobytes()

    @staticmethod
    def downsample_pcm16(pcm_data: bytes, from_rate: int = 16000, to_rate: int = 8000) -> bytes:
        """Optionally reduce the sample rate (16 kHz -> 8 kHz halves the data)."""
        if from_rate == to_rate:
            return pcm_data

        samples = np.frombuffer(pcm_data, dtype=np.int16)
        if from_rate == 16000 and to_rate == 8000:
            # Cheap decimation: keep every other sample.
            out = samples[::2]
        else:
            # Generic ratio: linear interpolation (scipy-free resampling).
            ratio = to_rate / from_rate
            new_len = int(len(samples) * ratio)
            out = np.interp(
                np.linspace(0, len(samples) - 1, new_len),
                np.arange(len(samples)),
                samples,
            ).astype(np.int16)
        return out.tobytes()
class CompressedAudioCache:
    """In-memory cache of audio files stored in a compressed wire format.

    Payload layout for compressed entries: a 1-byte type tag (0x01 mu-law,
    0x02 ADPCM) plus a big-endian uint32 of the original PCM byte length,
    followed by the codec output. "none" entries are raw PCM.
    """

    def __init__(self, compression_type: str = "adpcm", use_downsample: bool = False):
        """
        compression_type: "none", "ulaw", "adpcm"
        """
        self.compression_type = compression_type
        self.use_downsample = use_downsample
        self._cache = {}            # filepath -> compressed payload
        self._original_sizes = {}   # filepath -> raw PCM byte count

    def load_and_compress(self, filepath: str) -> Optional[bytes]:
        """Load a WAV file, normalize to 16 kHz mono PCM16, compress, cache."""
        cached = self._cache.get(filepath)
        if cached is not None:
            return cached

        try:
            with wave.open(filepath, 'rb') as wav:
                channels = wav.getnchannels()
                sampwidth = wav.getsampwidth()
                framerate = wav.getframerate()

                if channels != 1:
                    logger.warning(f"{filepath} 不是单声道")
                if sampwidth != 2:
                    logger.warning(f"{filepath} 不是16位音频")

                frames = wav.readframes(wav.getnframes())

                # Stereo -> mono (left channel).
                if channels == 2:
                    import audioop
                    frames = audioop.tomono(frames, sampwidth, 1, 0)

                # Always convert to 16 kHz to match the client player.
                if framerate != 16000:
                    import audioop
                    frames, _ = audioop.ratecv(frames, sampwidth, 1, framerate, 16000, None)
                    framerate = 16000

            # Size AFTER normalization, used for the compression stats.
            self._original_sizes[filepath] = len(frames)

            if self.compression_type == "ulaw":
                # Tag 0x01 = mu-law.
                payload = struct.pack('!BI', 0x01, len(frames)) + AudioCompressor.pcm16_to_ulaw(frames)
            elif self.compression_type == "adpcm":
                # Tag 0x02 = ADPCM.
                payload = struct.pack('!BI', 0x02, len(frames)) + AudioCompressor.pcm16_to_adpcm(frames)
            else:
                payload = frames

            self._cache[filepath] = payload

            ratio = len(payload) / self._original_sizes[filepath]
            logger.info(f"[压缩] {os.path.basename(filepath)}: "
                        f"{self._original_sizes[filepath]} -> {len(payload)} bytes "
                        f"({ratio:.1%})")
            return payload

        except Exception as e:
            logger.error(f"压缩音频失败 {filepath}: {e}")
            return None

    def decompress(self, compressed_data: bytes) -> Optional[bytes]:
        """Decode a cached payload back to raw PCM16.

        Untagged / too-short data is returned unchanged (passthrough).
        """
        if not compressed_data or len(compressed_data) < 5:
            return compressed_data

        try:
            tag = compressed_data[0]
            body = compressed_data[5:]  # skip the 5-byte tag+length header
            if tag == 0x01:
                return AudioCompressor.ulaw_to_pcm16(body)
            if tag == 0x02:
                return AudioCompressor.adpcm_to_pcm16(body)
            return compressed_data  # not one of ours: assume raw PCM
        except Exception as e:
            logger.error(f"解压音频失败: {e}")
            return compressed_data

    def get_compression_stats(self) -> dict:
        """Aggregate size/ratio statistics over everything cached so far."""
        total_original = sum(self._original_sizes.values())
        total_compressed = sum(len(data) for data in self._cache.values())
        return {
            "files_cached": len(self._cache),
            "total_original_size": total_original,
            "total_compressed_size": total_compressed,
            "compression_ratio": total_compressed / total_original if total_original > 0 else 0,
            "bytes_saved": total_original - total_compressed,
        }


# Global compressed-audio cache instance.
# Default is ADPCM (better quality, ~75% reduction); override with the
# AIGLASS_COMPRESS_TYPE environment variable: none | ulaw | adpcm.
import os
compression_type = os.getenv("AIGLASS_COMPRESS_TYPE", "adpcm").lower()
if compression_type not in ("none", "ulaw", "adpcm"):
    compression_type = "adpcm"
compressed_audio_cache = CompressedAudioCache(compression_type=compression_type, use_downsample=False)
def load_wav_file(filepath):
    """Load a WAV file and return its PCM payload, normalized to 16 kHz mono.

    Results are memoized in _audio_cache. With AIGLASS_COMPRESS_AUDIO=1
    (the default) the compressed cache is used, so the stored payload is the
    tagged/compressed form rather than raw PCM.
    """
    cached = _audio_cache.get(filepath)
    if cached is not None:
        return cached

    # Preferred path: compressed cache (payload carries a 5-byte tag header).
    if os.getenv("AIGLASS_COMPRESS_AUDIO", "1") == "1":
        payload = compressed_audio_cache.load_and_compress(filepath)
        if payload:
            _audio_cache[filepath] = payload
            return payload

    # Fallback: raw, uncompressed loading.
    try:
        with wave.open(filepath, 'rb') as wav:
            channels = wav.getnchannels()
            sampwidth = wav.getsampwidth()
            framerate = wav.getframerate()

            if channels != 1:
                print(f"[AUDIO] 警告: {filepath} 不是单声道,将只使用第一个声道")
            if sampwidth != 2:
                print(f"[AUDIO] 警告: {filepath} 不是16位音频")

            frames = wav.readframes(wav.getnframes())

            # Stereo -> mono: keep the left channel only.
            if channels == 2:
                import audioop
                frames = audioop.tomono(frames, sampwidth, 1, 0)

            # Normalize to 16 kHz; ratecv keeps pitch and duration intact.
            if framerate != 16000:
                import audioop
                frames, _ = audioop.ratecv(frames, sampwidth, 1, framerate, 16000, None)
                print(f"[AUDIO] 重采样: {filepath} {framerate}Hz -> 16000Hz")

            _audio_cache[filepath] = frames
            return frames

    except Exception as e:
        print(f"[AUDIO] 加载音频文件失败 {filepath}: {e}")
        return None

def _merge_voice_map():
    """Merge entries from voice/map.zh-CN.json into AUDIO_MAP.

    Each map entry lists candidate files; the first one that exists on disk
    wins. Missing files are logged and skipped.
    """
    try:
        if not os.path.exists(VOICE_MAP_FILE):
            print(f"[AUDIO] 未找到映射文件: {VOICE_MAP_FILE}")
            return
        with open(VOICE_MAP_FILE, "r", encoding="utf-8") as f:
            mapping = json.load(f)

        added = 0
        for text, info in (mapping or {}).items():
            files = (info or {}).get("files") or []
            if not files:
                continue
            candidate = os.path.join(VOICE_DIR, files[0])
            if os.path.exists(candidate):
                AUDIO_MAP[text] = candidate
                added += 1
            else:
                print(f"[AUDIO] 映射文件缺失: {candidate}")
        print(f"[AUDIO] 已合并 voice 映射 {added} 条")
    except Exception as e:
        print(f"[AUDIO] 读取 voice 映射失败: {e}")

def preload_all_audio():
    """Pre-load every mapped audio file into the in-memory cache."""
    print("[AUDIO] 开始预加载音频文件...")
    loaded_count = 0
    # NOTE: per-clip speed-up (1.3x for crosswalk prompts) is disabled for
    # now — it would need a speed-aware cache key before re-enabling.
    for _key, filepath in AUDIO_MAP.items():
        if os.path.exists(filepath) and load_wav_file(filepath):
            loaded_count += 1
    print(f"[AUDIO] 预加载完成,共加载 {loaded_count} 个音频文件")
async def _broadcast_audio_optimized(pcm_data: bytes):
    """Broadcast one clip; the transport layer handles the 20 ms pacing.

    Wraps the PCM with a short lead-in/lead-out of silence (a longer lead-in
    after >3 s of idle so the speaker path can warm up) and flags the player
    as busy for the duration.
    """
    global _last_play_ts, _is_playing
    try:
        with _playing_lock:
            _is_playing = True

        # pcm_data is already decompressed 16-bit PCM at this point.
        now = time.monotonic()
        idle_sec = now - (_last_play_ts or now)
        lead_ms = 160 if idle_sec > 3.0 else 60
        tail_ms = 40

        bytes_per_ms = 16000 * 2 // 1000  # 16 kHz, 16-bit mono
        framed = (b'\x00' * (lead_ms * bytes_per_ms)
                  + pcm_data
                  + b'\x00' * (tail_ms * bytes_per_ms))

        # Recording is done once inside broadcast_pcm16_realtime — not here,
        # to avoid duplicating the capture.
        await broadcast_pcm16_realtime(framed)

        _last_play_ts = time.monotonic()
    except Exception as e:
        print(f"[AUDIO] 广播音频失败: {e}")
    finally:
        with _playing_lock:
            _is_playing = False

def initialize_audio_system():
    """One-time setup: merge the voice map, preload clips, start the worker."""
    global _initialized, _worker_thread, _last_play_ts

    if _initialized:
        return

    # Merge the voice map first so preload sees the full AUDIO_MAP.
    _merge_voice_map()
    preload_all_audio()

    _worker_thread = threading.Thread(target=_audio_worker, daemon=True)
    _worker_thread.start()
    _initialized = True
    _last_play_ts = 0.0

    # Report compression savings when the compressed cache is active.
    if os.getenv("AIGLASS_COMPRESS_AUDIO", "1") == "1":
        stats = compressed_audio_cache.get_compression_stats()
        print(f"[AUDIO] 音频压缩统计:")
        print(f" - 文件数: {stats['files_cached']}")
        print(f" - 原始大小: {stats['total_original_size'] / 1024:.1f} KB")
        print(f" - 压缩后: {stats['total_compressed_size'] / 1024:.1f} KB")
        print(f" - 压缩率: {stats['compression_ratio']:.1%}")
        print(f" - 节省: {stats['bytes_saved'] / 1024:.1f} KB")

    print("[AUDIO] 音频系统初始化完成(预加载+工作线程)")
def play_audio_threadsafe(audio_key):
    """Queue a pre-cached clip for playback on the ESP32 speaker (thread-safe).

    Real-time policy: at most one pending clip is kept so speech stays
    current; anything older is discarded.

    BUGFIX: the queue is now drained in place with get_nowait() instead of
    being rebound to a brand-new PriorityQueue. Rebinding left the worker
    thread blocked forever in get() on the OLD queue object (which never
    receives another item), wedging all subsequent playback.
    """
    global _audio_priority

    if not _initialized:
        initialize_audio_system()

    if audio_key not in AUDIO_MAP:
        print(f"[AUDIO] 未知的音频键: {audio_key}")
        return

    filepath = AUDIO_MAP[audio_key]
    pcm_data = _audio_cache.get(filepath)
    if pcm_data is None:
        print(f"[AUDIO] 音频未在缓存中: {audio_key}")
        return

    # Tagged payloads (0x01 = mu-law, 0x02 = ADPCM) must be decoded first.
    if pcm_data and len(pcm_data) > 5 and pcm_data[0] in (0x01, 0x02):
        pcm_data = compressed_audio_cache.decompress(pcm_data)
        if not pcm_data:
            print(f"[AUDIO] 解压失败: {audio_key}")
            return

    queue_size = _audio_queue.qsize()
    with _playing_lock:
        currently_playing = _is_playing

    # Keep the backlog minimal: idle -> 0 pending; playing -> at most 1.
    if queue_size > 0 and not currently_playing:
        print(f"[AUDIO] 清空队列(当前{queue_size}个),播放最新语音")
        _drain_audio_queue()
    elif queue_size > 1 and currently_playing:
        print(f"[AUDIO] 队列积压({queue_size}个),清空以保持实时")
        _drain_audio_queue()

    try:
        # Monotonically increasing priority keeps FIFO order in the queue.
        _audio_priority += 1
        _audio_queue.put_nowait((_audio_priority, pcm_data))
        if queue_size >= 1:
            print(f"[AUDIO] 播放队列当前大小: {queue_size + 1}")
    except queue.Full:
        # Queue full: drop the clip rather than block the caller.
        print(f"[AUDIO] 队列满,丢弃: {audio_key}")


def _drain_audio_queue():
    """Discard every pending clip WITHOUT rebinding the queue object,
    so a worker blocked in get() keeps observing the same queue."""
    while True:
        try:
            _audio_queue.get_nowait()
        except queue.Empty:
            return


# Global speech throttling — Day 22: cooldown lowered from 3.0 s.
_last_voice_time = 0
_last_voice_text = ""
_voice_cooldown = 1.5  # identical text repeats at most every 1.5 s

# Speech priorities (higher = more urgent).
VOICE_PRIORITY = {
    'obstacle': 100,   # obstacles — highest priority
    'direction': 50,   # turn / lateral-move commands — medium
    'straight': 10,    # keep going straight — lowest
    'other': 30        # everything else — default
}

def play_voice_text(text: str):
    """Play the pre-recorded clip matching a Chinese prompt string.

    Matching order:
      1. the exact text;
      2. the text with trailing punctuation added or stripped (。.!!??);
      3. the generic obstacle fallback for "前方有…注意避让" prompts;
      4. punctuation-less variants for "请向…平移/微调/转动"-style prompts.
    Repeats of the same text inside the cooldown window are silently dropped.
    """
    global _last_voice_time, _last_voice_text

    if not text:
        return
    if not _initialized:
        initialize_audio_system()

    # Global throttle: same text within the cooldown is skipped silently.
    now = time.time()
    if text == _last_voice_text and now - _last_voice_time < _voice_cooldown:
        return

    stripped = text.strip()
    candidates = [stripped]
    if stripped[-1:] not in ("。", "!", "!", "?", "?", "."):
        candidates.append(stripped + "。")
    else:
        bare = stripped.rstrip("。.!!??")
        if bare and bare != stripped:
            candidates.append(bare)

    for key in candidates:
        if key in AUDIO_MAP:
            play_audio_threadsafe(key)
            _last_voice_text = text
            _last_voice_time = now
            return

    # Fallback for "前方有…注意避让" style prompts that missed exact match.
    if ("前方有" in stripped) and ("注意避让" in stripped):
        fallback = "前方有障碍物,注意避让。"
        if fallback in AUDIO_MAP:
            play_audio_threadsafe(fallback)
            _last_voice_text = text
            _last_voice_time = now
            return

    # Common variants for "请向…平移/微调/转动"-type prompts.
    base = stripped.rstrip("。.!!??")
    for key in (base, base + "。"):
        if key in AUDIO_MAP:
            play_audio_threadsafe(key)
            _last_voice_text = text
            _last_voice_time = now
            return

    # No match — log it to ease debugging of the voice map.
    print(f"[AUDIO] 未找到匹配语音: {text}")

# Backwards-compatible alias for the old entry point.
play_audio_on_esp32 = play_audio_threadsafe
===== +# 当 WebSocket 断开时,缓存 TTS 音频,等待重连后发送 +TTS_BUFFER_MAX_SECONDS = 30 # 最多缓存 30 秒音频 +TTS_BUFFER_MAX_BYTES = 16000 * 2 * TTS_BUFFER_MAX_SECONDS # 16kHz * 2 bytes * 30s = ~960KB +tts_audio_buffer: deque = deque() # 每个元素是 (timestamp, pcm16k_bytes) +tts_buffer_total_bytes = 0 + +# Day 13: TTS 专用 WebSocket 引用 +# 在 AI 处理开始前保存,避免被 ws_audio 的 finally 块清空 +tts_websocket = None + +def set_tts_websocket(ws): + """保存 TTS 发送专用的 WebSocket 引用""" + global tts_websocket + tts_websocket = ws + +def get_tts_websocket(): + """获取 TTS WebSocket(优先使用保存的引用,其次尝试全局变量)""" + global tts_websocket + if tts_websocket is not None: + return tts_websocket + # Day 15 修复:避免 import app_main,因为会触发模块顶层代码重新执行 + # 改为通过 sys.modules 获取已加载的模块引用 + try: + import sys + if 'app_main' in sys.modules: + return sys.modules['app_main'].esp32_audio_ws + except: + pass + return None + + +# ===== AI 播放任务总闸 ===== +current_ai_task: Optional[asyncio.Task] = None + +async def cancel_current_ai(): + """取消当前大模型语音任务,并等待其退出。""" + global current_ai_task + task = current_ai_task + current_ai_task = None + if task and not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + except Exception: + pass + +def is_playing_now() -> bool: + t = current_ai_task + return (t is not None) and (not t.done()) + +# ===== /stream.wav 连接管理 ===== +@dataclass(frozen=True) +class StreamClient: + q: asyncio.Queue + abort_event: asyncio.Event + +stream_clients: "Set[StreamClient]" = set() +STREAM_QUEUE_MAX = 96 # 小缓冲,避免积压 + +def _wav_header_unknown_size(sr=16000, ch=1, sw=2) -> bytes: + import struct + byte_rate = sr * ch * sw + block_align = ch * sw + data_size = 0x7FFFFFF0 + riff_size = 36 + data_size + return struct.pack( + "<4sI4s4sIHHIIHH4sI", + b"RIFF", riff_size, b"WAVE", + b"fmt ", 16, + 1, ch, sr, byte_rate, block_align, sw * 8, + b"data", data_size + ) + +async def hard_reset_audio(reason: str = ""): + """ + **一键清场**:取消当前AI任务。 + 注意:不再断开 HTTP /stream.wav 连接,因为 Avaota F1 使用这个通道播放 TTS。 + """ + # Day 14: 
async def flush_tts_buffer(ws) -> int:
    """Day 13: send every buffered TTS chunk to *ws*.

    Returns the number of bytes actually delivered. The buffer is emptied
    up-front so concurrent producers start fresh even if a send fails.
    """
    global tts_audio_buffer, tts_buffer_total_bytes
    from starlette.websockets import WebSocketState

    if not tts_audio_buffer:
        return 0

    total_sent = 0
    pending = list(tts_audio_buffer)
    tts_audio_buffer.clear()
    tts_buffer_total_bytes = 0

    try:
        for _, chunk in pending:
            # Stop early if the socket dropped mid-flush.
            if hasattr(ws, 'client_state') and ws.client_state != WebSocketState.CONNECTED:
                print(f"[TTS->WS] ⚠️ WebSocket disconnected while flushing buffer")
                break
            await ws.send_bytes(chunk)
            total_sent += len(chunk)

        if total_sent > 0:
            duration = total_sent / (16000 * 2)
            print(f"[TTS->WS] 📤 Flushed {total_sent} bytes ({duration:.1f}s) of cached TTS audio")
    except Exception as e:
        print(f"[TTS->WS] ❌ Error flushing buffer: {e}")

    return total_sent

async def broadcast_pcm16_realtime(pcm16: bytes):
    """Deliver one 16 kHz PCM16 chunk: WebSocket immediately, HTTP clients
    via a background 20 ms pacing task (Day 14: pacing must not stall WS).

    BUGFIX: removed the leftover `import audioop` — the input is already
    16 kHz (Day 21) so it was unused, and on Python 3.13+ (audioop removed
    from the stdlib) the ImportError was swallowed by the blanket `except`,
    silently dropping every WebSocket TTS chunk.
    """
    # Record the chunk as a whole before fan-out (best-effort, never fatal).
    try:
        import sync_recorder
        sync_recorder.record_audio(pcm16, text="[Omni对话]")
    except Exception:
        pass

    global tts_audio_buffer, tts_buffer_total_bytes
    import time as _time

    try:
        # Day 13: prefer the saved TTS WebSocket reference so ws_audio's
        # finally block can't null it out mid-stream.
        ws = get_tts_websocket()
        pcm16k = pcm16  # already 16 kHz; no resampling needed (Day 21)

        sent_ok = False
        if ws is not None:
            try:
                # Day 13 fix: don't gate on client_state — it can be stale
                # and would wrongly divert audio into the buffer.
                # Drain anything buffered while disconnected first, so
                # chunks arrive in order.
                while tts_audio_buffer:
                    _, buffered_audio = tts_audio_buffer.popleft()
                    tts_buffer_total_bytes -= len(buffered_audio)
                    await ws.send_bytes(buffered_audio)
                    if not getattr(broadcast_pcm16_realtime, '_flush_logged', False):
                        print(f"[TTS->WS] 📤 Flushing buffered TTS audio...")
                        broadcast_pcm16_realtime._flush_logged = True

                await ws.send_bytes(pcm16k)
                sent_ok = True

                if len(pcm16k) > 320:
                    print(f"[TTS->WS] 📤 Sent {len(pcm16k)} bytes (16kHz) to Avaota")

                # Reset one-shot warning flags after a successful send.
                broadcast_pcm16_realtime._ws_warned = False
                broadcast_pcm16_realtime._buffer_warned = False
                broadcast_pcm16_realtime._flush_logged = False
            except Exception as send_err:
                if not getattr(broadcast_pcm16_realtime, '_send_err_warned', False):
                    print(f"[TTS->WS] ❌ Send error: {send_err}, will buffer")
                    broadcast_pcm16_realtime._send_err_warned = True

        if not sent_ok:
            # Buffer for replay on reconnect, dropping the oldest on overflow.
            tts_audio_buffer.append((_time.time(), pcm16k))
            tts_buffer_total_bytes += len(pcm16k)
            while tts_buffer_total_bytes > TTS_BUFFER_MAX_BYTES and tts_audio_buffer:
                _, old_audio = tts_audio_buffer.popleft()
                tts_buffer_total_bytes -= len(old_audio)
            if not getattr(broadcast_pcm16_realtime, '_buffer_warned', False):
                buffer_secs = tts_buffer_total_bytes / (16000 * 2)
                print(f"[TTS->WS] 📦 Buffering TTS audio ({buffer_secs:.1f}s cached), will send when reconnected")
                broadcast_pcm16_realtime._buffer_warned = True

    except Exception:
        pass  # fan-out must never crash the caller

    # Day 14: HTTP 20 ms pacing runs as its own task so it cannot block the
    # WebSocket path — the next Omni chunk can be processed immediately.
    if stream_clients:
        asyncio.create_task(_http_pacing_broadcast(pcm16))
async def _http_pacing_broadcast(pcm16: bytes):
    """Day 14: feed *pcm16* to HTTP stream clients in 20 ms slices.

    Runs as an independent background task; previously this pacing loop was
    inline in broadcast_pcm16_realtime and blocked the WebSocket sends.
    """
    loop = asyncio.get_event_loop()
    next_tick = loop.time()
    offset = 0
    while offset < len(pcm16):
        take = min(BYTES_PER_20MS_16K, len(pcm16) - offset)
        piece = pcm16[offset:offset + take]

        # Day 14 debugging: confirm the client set on the first slice only.
        if len(stream_clients) > 0 and offset == 0:
            print(f"[TTS->HTTP] 📤 Sending to {len(stream_clients)} HTTP stream client(s)")

        dropped: List[StreamClient] = []
        for client in list(stream_clients):
            if client.abort_event.is_set():
                dropped.append(client)
                continue
            try:
                if client.q.full():
                    # Queue full: discard the oldest slice to stay real-time.
                    try:
                        client.q.get_nowait()
                    except Exception:
                        pass
                client.q.put_nowait(piece)
            except Exception:
                dropped.append(client)
        for client in dropped:
            try:
                stream_clients.discard(client)
            except Exception:
                pass

        # Sleep to the next 20 ms boundary; resynchronize if we fell behind.
        next_tick += 0.020
        now = loop.time()
        if now < next_tick:
            await asyncio.sleep(next_tick - now)
        else:
            next_tick = now
        offset += take

# ===== FastAPI route registrar =====
def register_stream_route(app):
    @app.get("/stream.wav")
    async def stream_wav(_: Request):
        # Enforce (near-)single-listener: kick every previous connection
        # before accepting the new one.
        for old in list(stream_clients):
            try:
                old.abort_event.set()
            except Exception:
                pass
        stream_clients.clear()

        q: asyncio.Queue[bytes | None] = asyncio.Queue(maxsize=STREAM_QUEUE_MAX)
        abort_event = asyncio.Event()
        client = StreamClient(q=q, abort_event=abort_event)
        stream_clients.add(client)

        async def gen():
            # Endless-WAV header first, then raw PCM slices as they arrive.
            yield _wav_header_unknown_size(STREAM_SR, STREAM_CH, STREAM_SW)
            try:
                while not abort_event.is_set():
                    try:
                        chunk = await asyncio.wait_for(q.get(), timeout=0.5)
                    except asyncio.TimeoutError:
                        continue  # re-check the abort flag periodically
                    if abort_event.is_set() or chunk is None:
                        break
                    if chunk:
                        yield chunk
            finally:
                stream_clients.discard(client)

        return StreamingResponse(gen(), media_type="audio/wav")
# Callback that pushes JPEG frames to the front-end viewer;
# registered by app_main.py at startup.
_sender_lock = threading.Lock()
_sender_cb = None

# Callback that pushes UI text to the front-end;
# registered by app_main.py at startup.
_ui_sender_lock = threading.Lock()
_ui_sender_cb = None

def set_sender(cb):
    """Register cb(jpeg_bytes) -> None (called by app_main.py)."""
    global _sender_cb
    with _sender_lock:
        _sender_cb = cb

def set_ui_sender(cb):
    """Register cb(text: str) -> None (called by app_main.py)."""
    global _ui_sender_cb
    with _ui_sender_lock:
        _ui_sender_cb = cb

def push_raw_jpeg(jpeg_bytes: bytes):
    """Called by app_main.py whenever a /ws/camera frame arrives."""
    if not jpeg_bytes:
        return
    with _cond:
        _frames.append((time.time(), jpeg_bytes))
        _cond.notify_all()

def wait_raw_bgr(timeout_sec: float = 0.5):
    """Block until the newest frame decodes to BGR; None on timeout.

    Called by the YOLO/MediaPipe worker scripts.
    """
    deadline = time.time() + timeout_sec
    while time.time() < deadline:
        with _cond:
            newest = _frames[-1] if _frames else None
        if newest is None:
            time.sleep(0.01)
            continue
        _, jpeg = newest
        buf = np.frombuffer(jpeg, dtype=np.uint8)
        bgr = cv2.imdecode(buf, cv2.IMREAD_COLOR)
        if bgr is not None:
            # Mirroring at the source is intentionally disabled:
            # bgr = cv2.flip(bgr, 1)
            return bgr
        time.sleep(0.01)  # decode failed; retry shortly
    return None

def send_vis_bgr(bgr, quality: int = 80):
    """Encode a processed BGR frame as JPEG and push it to the viewer.

    Called by the YOLO/MediaPipe worker scripts; no image enhancement here.
    """
    if bgr is None:
        return
    ok, enc = cv2.imencode(".jpg", bgr, [int(cv2.IMWRITE_JPEG_QUALITY), int(quality)])
    if not ok:
        return
    with _sender_lock:
        cb = _sender_cb
        if cb:
            try:
                cb(enc.tobytes())
            except Exception:
                pass  # a broken sink must not kill the vision loop

def send_ui_final(text: str):
    """Push one line of UI text as a final answer (thread-safe callback)."""
    if not text:
        return
    with _ui_sender_lock:
        cb = _ui_sender_cb
        if cb:
            try:
                cb(str(text))
            except Exception:
                pass  # a broken sink must not kill the caller
# -*- coding: utf-8 -*-
"""
Crosswalk awareness monitor.

Area-based crosswalk detection with voice guidance only: this module never
switches workflow states, it just decides what (and when) to announce.
"""
import logging
import time
from collections import deque
from typing import Any, Dict, Optional

import numpy as np

logger = logging.getLogger(__name__)


class CrosswalkAwarenessMonitor:
    """Crosswalk awareness monitor - pure voice-guidance module."""

    def __init__(self):
        # Fixed area-ratio anchors for the four announcement stages.
        self.THRESHOLDS = {
            'discover': 0.01,     # 1%  - first sighting
            'approaching': 0.08,  # 8%  - getting closer
            'near': 0.18,         # 18% - very close
            'arrival': 0.25,      # 25% - arrived (safe to cross)
        }

        # Thresholds already announced (prevents duplicate broadcasts).
        self.broadcasted_thresholds = set()

        # Rolling area history (most recent 30 frames).
        self.area_history = deque(maxlen=30)

        # Timing bookkeeping.
        self.last_broadcast_time = 0
        self.arrival_first_broadcast_time = 0

        # State flags.
        self.in_arrival_state = False   # currently in the "safe to cross" state
        self.last_position_zone = None  # position zone announced last time

        # Repeat intervals in seconds (smaller = more frequent).
        # Tuned: original 10/5/8 s divided by 1.5 for 1.5x announcement rate.
        self.REPEAT_INTERVALS = {
            'approaching': 6.7,
            'near': 3.3,
            'arrival': 5.3,
        }

        # Overlap above this ratio counts as occlusion.
        self.OCCLUSION_THRESHOLD = 0.30

    def process_frame(self, crosswalk_mask, blind_path_mask=None) -> Optional[Dict[str, Any]]:
        """
        Process one frame of crosswalk detection.

        Returns a dict with keys 'voice_text', 'priority', 'should_broadcast',
        'area' and 'position' when something should be announced, else None.
        """
        # No crosswalk visible: clear per-crosswalk state.
        if crosswalk_mask is None:
            self._reset_if_needed()
            return None

        # 1. Area ratio of the crosswalk within the frame.
        covered = np.sum(crosswalk_mask > 0)
        area_ratio = covered / crosswalk_mask.size

        # 2. Mask centroid (normalized to [0, 1]).
        ys, xs = np.where(crosswalk_mask > 0)
        if len(ys) == 0:
            return None
        center_x_ratio = np.mean(xs) / crosswalk_mask.shape[1]
        center_y_ratio = np.mean(ys) / crosswalk_mask.shape[0]

        # 3. Record history.
        now = time.time()
        self.area_history.append({
            'area': area_ratio,
            'center_x': center_x_ratio,
            'center_y': center_y_ratio,
            'time': now,
        })

        # 4. Occlusion check against the blind-path mask.
        occluded = self._check_occlusion(crosswalk_mask, blind_path_mask)

        # 5. Decide the current stage and build the voice payload.
        return self._generate_guidance(area_ratio, center_x_ratio, center_y_ratio,
                                       occluded, now)

    def _check_occlusion(self, crosswalk_mask, blind_path_mask) -> bool:
        """True when the blind path overlaps the crosswalk beyond the threshold."""
        if blind_path_mask is None:
            return False
        cross = crosswalk_mask > 0
        blind = blind_path_mask > 0
        overlap = np.logical_and(cross, blind)
        ratio = np.sum(overlap) / max(np.sum(cross), 1)
        return ratio > self.OCCLUSION_THRESHOLD

    def _get_position_description(self, center_x_ratio) -> str:
        """Three-way horizontal position label for the centroid."""
        if center_x_ratio < 0.40:
            return "在画面左侧"
        elif center_x_ratio < 0.60:
            return "在画面中间"
        else:
            return "在画面右侧"

    @staticmethod
    def _payload(voice_text, priority, area, position):
        """Build the broadcast payload returned to the caller."""
        return {
            'voice_text': voice_text,
            'priority': priority,
            'should_broadcast': True,
            'area': area,
            'position': position,
        }

    def _generate_guidance(self, area_ratio, center_x_ratio, center_y_ratio,
                           has_occlusion, current_time) -> Optional[Dict[str, Any]]:
        """Map the current area/position to a voice announcement, or None."""
        # Ignore jittery frames.
        if not self._is_area_stable(area_ratio):
            return None

        position_desc = self._get_position_description(center_x_ratio)
        th = self.THRESHOLDS

        # Stage 1: discovered (0.01 <= area < 0.08) - announced once.
        if th['discover'] <= area_ratio < th['approaching']:
            if th['discover'] not in self.broadcasted_thresholds:
                self.broadcasted_thresholds.add(th['discover'])
                # Priority 55 outranks blind-path direction commands (50).
                return self._payload(f"远处发现斑马线,{position_desc}", 55,
                                     area_ratio, position_desc)

        # Stage 2: approaching (0.08 <= area < 0.18) - repeats on interval or zone change.
        elif th['approaching'] <= area_ratio < th['near']:
            first = th['approaching'] not in self.broadcasted_thresholds
            due = (current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['approaching']
                   or position_desc != self.last_position_zone)
            if first or due:
                self.broadcasted_thresholds.add(th['approaching'])
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return self._payload(f"正在靠近斑马线,{position_desc}", 55,
                                     area_ratio, position_desc)

        # Stage 3: very close (0.18 <= area < 0.25) - repeats on interval or zone change.
        elif th['near'] <= area_ratio < th['arrival']:
            first = th['near'] not in self.broadcasted_thresholds
            due = (current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['near']
                   or position_desc != self.last_position_zone)
            if first or due:
                self.broadcasted_thresholds.add(th['near'])
                self.last_broadcast_time = current_time
                self.last_position_zone = position_desc
                return self._payload(f"接近斑马线,{position_desc}", 60,
                                     area_ratio, position_desc)

        # Stage 4: arrival (area >= 0.25, only when not occluded).
        elif area_ratio >= th['arrival']:
            if has_occlusion:
                # Occluded: stay silent rather than telling the user to cross.
                logger.info(f"[斑马线] 面积达到{area_ratio:.2f}但被遮挡,暂不提示过马路")
                return None
            if not self.in_arrival_state:
                self.in_arrival_state = True
                self.arrival_first_broadcast_time = current_time
                self.last_broadcast_time = current_time
                logger.info(f"[斑马线] 到达状态:area={area_ratio:.2f}, 无遮挡")
                return self._payload("斑马线到了可以过马路", 80, area_ratio, '到达')
            if current_time - self.last_broadcast_time >= self.REPEAT_INTERVALS['arrival']:
                self.last_broadcast_time = current_time
                return self._payload("斑马线到了可以过马路", 80, area_ratio, '到达')
            if current_time - self.arrival_first_broadcast_time > 30.0:
                # Auto-exit the arrival state after 30 seconds.
                logger.info("[斑马线] 到达状态超时30秒,自动退出")
                self.in_arrival_state = False
                return None

        # Shrinking back out of the arrival band: leave the arrival state.
        if self.in_arrival_state and area_ratio < 0.20:
            logger.info(f"[斑马线] 面积降至{area_ratio:.2f},退出到达状态")
            self.in_arrival_state = False
            # Allow the arrival announcement to fire again later.
            self.broadcasted_thresholds.discard(th['arrival'])

        return None

    def _is_area_stable(self, area_ratio, stability_frames=5) -> bool:
        """False when any recent frame deviates from the current area by >20% (jitter)."""
        if len(self.area_history) < stability_frames:
            return True  # warm-up phase: treat as stable
        recent = [h['area'] for h in list(self.area_history)[-stability_frames:]]
        denom = max(area_ratio, 0.001)
        return all(abs(a - area_ratio) / denom <= 0.20 for a in recent)

    def _reset_if_needed(self):
        """Reset per-crosswalk state when the crosswalk disappears."""
        if len(self.area_history) > 0:
            logger.info("[斑马线] 斑马线消失,重置状态")
        self.broadcasted_thresholds.clear()
        self.area_history.clear()
        self.in_arrival_state = False
        self.last_position_zone = None

    def reset(self):
        """Full reset of every piece of monitor state."""
        self.broadcasted_thresholds.clear()
        self.area_history.clear()
        self.in_arrival_state = False
        self.last_broadcast_time = 0
        self.arrival_first_broadcast_time = 0
        self.last_position_zone = None
        logger.info("[斑马线] 感知监控器已重置")

    def is_in_arrival_state(self) -> bool:
        """True while in the arrival state (callers may pause blind-path voice)."""
        return self.in_arrival_state

    def get_current_area(self) -> float:
        """Most recently recorded area ratio, or 0.0 with no history."""
        if len(self.area_history) > 0:
            return self.area_history[-1]['area']
        return 0.0

    def get_visualization_data(self, crosswalk_mask, area_ratio, center_x_ratio,
                               center_y_ratio, has_occlusion) -> Dict[str, Any]:
        """Return everything an external renderer needs to draw the overlay."""
        if crosswalk_mask is None:
            return {}

        # Stage label + overlay color (orange throughout, alpha varies per stage).
        if area_ratio >= self.THRESHOLDS['arrival']:
            stage, stage_color = "到达", "rgba(255, 165, 0, 0.5)"
        elif area_ratio >= self.THRESHOLDS['near']:
            stage, stage_color = "接近", "rgba(255, 165, 0, 0.45)"
        elif area_ratio >= self.THRESHOLDS['approaching']:
            stage, stage_color = "靠近", "rgba(255, 165, 0, 0.40)"
        else:
            stage, stage_color = "发现", "rgba(255, 165, 0, 0.35)"

        position = self._get_position_description(center_x_ratio)

        return {
            'area_ratio': area_ratio,
            'stage': stage,
            'stage_color': stage_color,
            'position': position.replace("在画面", ""),  # strip the "在画面" prefix
            'center_x_ratio': center_x_ratio,
            'center_y_ratio': center_y_ratio,
            'has_occlusion': has_occlusion,
            'in_arrival': self.in_arrival_state,
        }


# Helper
def split_combined_voice(combined_text: str) -> list:
    """
    Split a combined announcement on the fullwidth comma.

    e.g. "远处发现斑马线,在画面左侧" -> ["远处发现斑马线", "在画面左侧"]
    """
    if ',' not in combined_text:
        return [combined_text]
    return [part.strip() for part in combined_text.split(',') if part.strip()]
+ dockerfile: Dockerfile + container_name: aiglass + restart: unless-stopped + + # 端口映射 + ports: + - "8081:8081" # Web 服务 + - "12345:12345/udp" # IMU UDP + + # 环境变量(从 .env 文件读取) + environment: + - DASHSCOPE_API_KEY=${DASHSCOPE_API_KEY} + - BLIND_PATH_MODEL=${BLIND_PATH_MODEL:-model/yolo-seg.pt} + - OBSTACLE_MODEL=${OBSTACLE_MODEL:-model/yoloe-11l-seg.pt} + - YOLOE_MODEL_PATH=${YOLOE_MODEL_PATH:-model/yoloe-11l-seg.pt} + - ENABLE_TTS=${ENABLE_TTS:-true} + - TTS_INTERVAL_SEC=${TTS_INTERVAL_SEC:-1.0} + - LOG_LEVEL=${LOG_LEVEL:-INFO} + + # 卷挂载 + volumes: + - ./model:/app/model:ro # 模型文件(只读) + - ./recordings:/app/recordings # 录制文件 + - ./music:/app/music:ro # 音频文件(只读) + - ./voice:/app/voice:ro # 语音文件(只读) + - ./static:/app/static:ro # 静态文件(只读) + - ./templates:/app/templates:ro # 模板文件(只读) + + # GPU 支持(需要 nvidia-docker) + deploy: + resources: + reservations: + devices: + - driver: nvidia + count: 1 + capabilities: [gpu] + + # 依赖服务(可选,如需添加数据库等) + # depends_on: + # - redis + + # 网络模式 + network_mode: bridge + + # 健康检查 + healthcheck: + test: ["CMD", "curl", "-f", "http://localhost:8081/api/health"] + interval: 30s + timeout: 10s + retries: 3 + start_period: 40s + + # 可选:添加 Redis 用于缓存 + # redis: + # image: redis:7-alpine + # container_name: aiglass-redis + # restart: unless-stopped + # ports: + # - "6379:6379" + # volumes: + # - redis-data:/data + +# 可选:数据卷 +# volumes: +# redis-data: + +# 可选:自定义网络 +# networks: +# aiglass-network: +# driver: bridge + diff --git a/edge_tts_client.py b/edge_tts_client.py new file mode 100644 index 0000000..d3b870f --- /dev/null +++ b/edge_tts_client.py @@ -0,0 +1,202 @@ +# edge_tts_client.py +# -*- coding: utf-8 -*- +""" +EdgeTTS 流式语音合成客户端 - Day 21 + +特点: +- 完全免费 +- 流式输出(边合成边播放) +- 低延迟 +""" + +import os +import asyncio +import edge_tts +from typing import AsyncGenerator, Optional + +# 默认语音 +DEFAULT_VOICE = os.getenv("EDGE_TTS_VOICE", "zh-CN-XiaoxiaoNeural") + +# 语速调整 ("+0%", "+10%", "-10%" 等) +DEFAULT_RATE = os.getenv("EDGE_TTS_RATE", "+0%") + +# 
# Volume adjustment ("+0%", "+10%", "-10%", ...)
DEFAULT_VOLUME = os.getenv("EDGE_TTS_VOLUME", "+0%")


async def text_to_speech_stream(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    volume: str = DEFAULT_VOLUME,
) -> AsyncGenerator[bytes, None]:
    """
    Streaming text-to-speech.

    Args:
        text: text to synthesize (empty or whitespace-only yields nothing)
        voice: Edge TTS voice name
        rate: speaking-rate adjustment
        volume: volume adjustment

    Yields:
        MP3 audio chunks as they are produced.
    """
    if not text or not text.strip():
        return
    try:
        communicate = edge_tts.Communicate(
            text=text,
            voice=voice,
            rate=rate,
            volume=volume,
        )
        async for chunk in communicate.stream():
            if chunk["type"] == "audio":
                yield chunk["data"]
    except Exception as e:
        print(f"[EdgeTTS] 合成失败: {e}")


async def text_to_speech(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    volume: str = DEFAULT_VOLUME,
) -> bytes:
    """
    Full (non-streaming) text-to-speech.

    Args:
        text: text to synthesize
        voice: Edge TTS voice name
        rate: speaking-rate adjustment
        volume: volume adjustment

    Returns:
        The complete MP3 payload (empty bytes for empty input).
    """
    chunks = [c async for c in text_to_speech_stream(text, voice, rate, volume)]
    return b"".join(chunks)


async def text_to_speech_pcm(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    target_sample_rate: int = 16000,
) -> bytes:
    """
    Synthesize text straight to PCM16 mono audio (for direct playback).

    Args:
        text: text to synthesize
        voice: Edge TTS voice name
        rate: speaking-rate adjustment
        target_sample_rate: output sample rate

    Returns:
        Raw PCM16 data, or empty bytes on failure.
    """
    import io
    from pydub import AudioSegment

    mp3_data = await text_to_speech(text, voice, rate)
    if not mp3_data:
        return b""
    try:
        # Decode the MP3, then normalize to mono 16-bit at the target rate.
        audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))
        audio = audio.set_frame_rate(target_sample_rate).set_channels(1).set_sample_width(2)
        return audio.raw_data
    except Exception as e:
        print(f"[EdgeTTS] PCM 转换失败: {e}")
        return b""


async def text_to_speech_pcm_stream(
    text: str,
    voice: str = DEFAULT_VOICE,
    rate: str = DEFAULT_RATE,
    target_sample_rate: int = 16000,
) -> AsyncGenerator[bytes, None]:
    """
    Streaming text-to-PCM16.

    MP3 decoding needs whole segments, so the text is split at punctuation and
    each piece is synthesized separately.

    Args:
        text: text to synthesize
        voice: Edge TTS voice name
        rate: speaking-rate adjustment
        target_sample_rate: output sample rate

    Yields:
        PCM16 chunks, one per punctuation-delimited segment.
    """
    import io
    from pydub import AudioSegment

    # Split the text at (fullwidth and ASCII) punctuation marks.
    punctuation = "。,!?;:,.!?;:"
    segments = []
    pending = ""
    for ch in text:
        pending += ch
        if ch in punctuation:
            segments.append(pending.strip())
            pending = ""
    if pending.strip():
        segments.append(pending.strip())

    # Synthesize segment by segment.
    for segment in segments:
        if not segment:
            continue
        try:
            mp3_data = await text_to_speech(segment, voice, rate)
            if mp3_data:
                audio = AudioSegment.from_mp3(io.BytesIO(mp3_data))
                audio = audio.set_frame_rate(target_sample_rate).set_channels(1).set_sample_width(2)
                yield audio.raw_data
        except Exception as e:
            print(f"[EdgeTTS] 分段合成失败: {e}")


# Commonly-used Chinese voices.
CHINESE_VOICES = [
    "zh-CN-XiaoxiaoNeural",  # female, natural
    "zh-CN-YunxiNeural",     # male, natural
    "zh-CN-XiaoyiNeural",    # female, lively
    "zh-CN-YunjianNeural",   # male, newsreader
    "zh-CN-XiaochenNeural",  # female, gentle
]


async def list_voices() -> list:
    """List every available Chinese ("zh-*") voice."""
    voices = await edge_tts.list_voices()
    return [v for v in voices if v["Locale"].startswith("zh")]
def get_system_prompt() -> str:
    """Build the system prompt; regenerated per call so the timestamp is current."""
    now = datetime.now()
    current_time = now.strftime("%H:%M")
    current_date = now.strftime("%Y年%m月%d日")
    current_weekday = WEEKDAY_MAP[now.weekday()]

    return f"""你是一个视障辅助AI助手,安装在智能导盲眼镜上。
当前时间:{current_time}
今天日期:{current_date} {current_weekday}

请用极简短的语言回答,每次回答不超过2-3句话。
避免冗长解释,只提供最关键的信息。
语气友好但简洁。"""


# Shared client instance and rolling conversation history.
_client = None
_conversation_history = []
MAX_HISTORY_TURNS = 5  # keep only the most recent 5 turns


def _get_client() -> ZhipuAiClient:
    """Lazily create and cache the ZhipuAI client."""
    global _client
    if _client is None:
        _client = ZhipuAiClient(api_key=API_KEY)
    return _client


def clear_conversation_history():
    """Drop the whole conversation history."""
    global _conversation_history
    _conversation_history = []
    print("[GLM] 对话历史已清除")


def _build_user_content(user_message, image_base64):
    """Wrap the user message, attaching the image for multimodal turns."""
    if not image_base64:
        return user_message
    return [
        {"type": "text", "text": user_message},
        {"type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_base64}"}},
    ]


def _trim_history():
    """Clamp history to MAX_HISTORY_TURNS turns (1 user + 1 assistant each)."""
    global _conversation_history
    limit = MAX_HISTORY_TURNS * 2
    if len(_conversation_history) > limit:
        _conversation_history = _conversation_history[-limit:]


async def chat(user_message: str, image_base64: Optional[str] = None) -> str:
    """
    Chat with GLM-4.6v-Flash, keeping short-term conversational context.

    Args:
        user_message: user utterance
        image_base64: optional Base64-encoded JPEG for multimodal turns

    Returns:
        The assistant reply, or a fixed apology string on failure.
    """
    global _conversation_history
    client = _get_client()

    _conversation_history.append(
        {"role": "user", "content": _build_user_content(user_message, image_base64)})
    _trim_history()

    # Regenerate the system prompt on every call so the timestamp stays fresh.
    messages = [{"role": "system", "content": get_system_prompt()}] + _conversation_history

    # Day 22: retry with exponential backoff on rate limiting.
    retry_delay = 1  # initial delay in seconds
    for attempt in range(3):
        try:
            # The thinking parameter is required even for the vision model
            # (per the official docs); disabled here to reduce latency.
            response = await asyncio.to_thread(
                client.chat.completions.create,
                model=MODEL,
                messages=messages,
                thinking={"type": "disabled"},
            )
            if response.choices and len(response.choices) > 0:
                ai_reply = response.choices[0].message.content.strip()
                _conversation_history.append({"role": "assistant", "content": ai_reply})
                print(f"[GLM] 回复: {ai_reply[:50]}..." if len(ai_reply) > 50 else f"[GLM] 回复: {ai_reply}")
                return ai_reply
            return ""
        except Exception as e:
            error_str = str(e)
            # Rate-limit errors (HTTP 429 or provider code 1305).
            rate_limited = "429" in error_str or "1305" in error_str or "请求过多" in error_str
            if rate_limited and attempt < 2:
                print(f"[GLM] 速率限制,{retry_delay}秒后重试... (尝试 {attempt + 1}/3)")
                await asyncio.sleep(retry_delay)
                retry_delay *= 2  # exponential backoff
                continue
            print(f"[GLM] 调用失败: {e}")
            import traceback
            traceback.print_exc()
            break

    # Every attempt failed: drop the dangling user message.
    if _conversation_history and _conversation_history[-1]["role"] == "user":
        _conversation_history.pop()
    return "抱歉,我暂时无法回答。"


async def chat_stream(user_message: str, image_base64: Optional[str] = None) -> AsyncGenerator[str, None]:
    """
    Streaming variant of chat(): yields reply fragments as they arrive.

    Args:
        user_message: user utterance
        image_base64: optional Base64-encoded JPEG for multimodal turns

    Yields:
        Fragments of the assistant reply.
    """
    global _conversation_history
    client = _get_client()

    _conversation_history.append(
        {"role": "user", "content": _build_user_content(user_message, image_base64)})
    _trim_history()

    messages = [{"role": "system", "content": get_system_prompt()}] + _conversation_history
    full_response = ""

    try:
        # Streaming call; thinking is required and disabled (see chat()).
        response = await asyncio.to_thread(
            client.chat.completions.create,
            model=MODEL,
            messages=messages,
            thinking={"type": "disabled"},
            stream=True,
        )
        for chunk in response:
            piece = chunk.choices[0].delta.content
            if piece:
                full_response += piece
                yield piece

        # Persist the assembled reply in the history.
        if full_response:
            _conversation_history.append({"role": "assistant", "content": full_response})
            print(f"[GLM] 流式完成: {full_response[:50]}..." if len(full_response) > 50 else f"[GLM] 流式完成: {full_response}")

    except Exception as e:
        print(f"[GLM] 流式调用失败: {e}")
        import traceback
        traceback.print_exc()
        # Drop the user message that was just queued.
        if _conversation_history and _conversation_history[-1]["role"] == "user":
            _conversation_history.pop()
        yield "抱歉,我暂时无法回答。"
logger = logging.getLogger(__name__)

# Lazily-created module singletons shared by all ParallelDetector instances.
_cuda_streams = None
_parallel_executor = None


def _init_cuda_streams():
    """Create two CUDA streams on first use; None/[] means streams are unavailable."""
    global _cuda_streams
    if _cuda_streams is None and torch.cuda.is_available():
        try:
            _cuda_streams = [torch.cuda.Stream() for _ in range(2)]
            logger.info("[GPU_PARALLEL] 已创建 2 个 CUDA Stream")
        except Exception as e:
            logger.warning(f"[GPU_PARALLEL] 创建 CUDA Stream 失败: {e}")
            _cuda_streams = []
    return _cuda_streams


def _init_parallel_executor():
    """Create the shared 2-worker thread pool used for parallel inference."""
    global _parallel_executor
    if _parallel_executor is None:
        _parallel_executor = ThreadPoolExecutor(max_workers=2, thread_name_prefix="gpu_post")
        logger.info("[GPU_PARALLEL] 已创建 GPU 后处理线程池")
    return _parallel_executor


class ParallelDetector:
    """
    Runs blind-path segmentation and obstacle detection concurrently.

    Usage:
        detector = ParallelDetector(yolo_model, obstacle_detector)
        blind_mask, cross_mask, obstacles = detector.detect_all(image, path_mask)
    """

    def __init__(self, yolo_model, obstacle_detector):
        self.yolo_model = yolo_model
        self.obstacle_detector = obstacle_detector

        # Inference parameters (overridable via environment variables).
        self.imgsz = int(os.getenv("AIGLASS_YOLO_IMGSZ", "480"))
        self.use_half = os.getenv("AIGLASS_YOLO_HALF", "1") == "1"
        self.blind_conf_threshold = 0.20
        self.cross_conf_threshold = 0.30

        # Warm up the shared singletons.
        _init_cuda_streams()
        _init_parallel_executor()

        logger.info(f"[GPU_PARALLEL] ParallelDetector 初始化完成: imgsz={self.imgsz}, half={self.use_half}")

    def detect_all(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray] = None
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """
        Run every detector on one frame.

        Args:
            image: BGR image
            path_mask: optional blind-path mask used to filter obstacles

        Returns:
            (blind_path_mask, crosswalk_mask, obstacles)
        """
        streams = _init_cuda_streams()
        if streams and len(streams) >= 2:
            return self._detect_with_streams(image, path_mask, streams)
        # No usable CUDA streams: run the detectors sequentially.
        return self._detect_serial(image, path_mask)

    def _detect_with_streams(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray],
        streams: List["torch.cuda.Stream"]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """Day 21: run both detectors in the thread pool (CUDA streams alone were
        ineffective). FIX: the original let concurrent.futures.TimeoutError from
        as_completed(timeout=2.0) propagate and crash the frame; now the partial
        results gathered within the budget are returned instead."""
        from concurrent.futures import as_completed, TimeoutError as FuturesTimeoutError

        blind_mask = None
        cross_mask = None
        obstacles = []

        executor = _init_parallel_executor()

        def task_blind_path():
            # Blind-path + crosswalk segmentation (classes 0/1) in one YOLO pass.
            if self.yolo_model is None:
                return None, None
            try:
                results = self.yolo_model.predict(
                    image,
                    verbose=False,
                    conf=min(self.blind_conf_threshold, self.cross_conf_threshold),
                    classes=[0, 1],  # 0=crosswalk, 1=blind_path
                    imgsz=self.imgsz,
                    half=self.use_half
                )
                if results and results[0] and results[0].masks is not None:
                    return self._parse_seg_results(results[0], image.shape)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 盲道检测失败: {e}")
            return None, None

        def task_obstacles():
            if self.obstacle_detector is None:
                return []
            try:
                return self.obstacle_detector.detect(image, path_mask=path_mask)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 障碍物检测失败: {e}")
                return []

        # Submit both tasks in parallel.
        futures = {
            executor.submit(task_blind_path): 'blind',
            executor.submit(task_obstacles): 'obstacle'
        }

        try:
            for future in as_completed(futures, timeout=2.0):
                task_type = futures[future]
                try:
                    result = future.result()
                    if task_type == 'blind':
                        blind_mask, cross_mask = result
                    else:
                        obstacles = result
                except Exception as e:
                    logger.error(f"[GPU_PARALLEL] {task_type}任务异常: {e}")
        except FuturesTimeoutError:
            # Budget exceeded: keep whatever finished; unfinished tasks keep defaults.
            logger.error("[GPU_PARALLEL] 并行检测超时(2.0s),返回已完成的部分结果")

        return blind_mask, cross_mask, obstacles

    def _detect_serial(
        self,
        image: np.ndarray,
        path_mask: Optional[np.ndarray]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
        """Sequential detection (fallback when no CUDA streams are available)."""
        blind_mask = None
        cross_mask = None
        obstacles = []

        # Blind-path / crosswalk segmentation.
        if self.yolo_model is not None:
            try:
                results = self.yolo_model.predict(
                    image,
                    verbose=False,
                    conf=min(self.blind_conf_threshold, self.cross_conf_threshold),
                    classes=[0, 1],
                    imgsz=self.imgsz,
                    half=self.use_half
                )
                if results and results[0] and results[0].masks is not None:
                    blind_mask, cross_mask = self._parse_seg_results(results[0], image.shape)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 盲道检测失败: {e}")

        # Obstacle detection.
        if self.obstacle_detector is not None:
            try:
                obstacles = self.obstacle_detector.detect(image, path_mask=path_mask)
            except Exception as e:
                logger.error(f"[GPU_PARALLEL] 障碍物检测失败: {e}")

        return blind_mask, cross_mask, obstacles

    def _parse_seg_results(
        self,
        result,
        image_shape: Tuple[int, int, int]
    ) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]:
        """Split a YOLO segmentation result into blind-path and crosswalk masks.

        Instances below the per-class confidence threshold are dropped; multiple
        instances of a class are OR-merged into a single mask.
        """
        blind_mask = None
        cross_mask = None

        h, w = image_shape[:2]

        if result.masks is None or result.boxes is None:
            return None, None

        for mask_tensor, conf_tensor, cls_tensor in zip(
            result.masks.data, result.boxes.conf, result.boxes.cls
        ):
            class_id = int(cls_tensor.item())
            confidence = float(conf_tensor.item())

            # Per-class confidence filtering (1=blind path, 0=crosswalk).
            if class_id == 1 and confidence < self.blind_conf_threshold:
                continue
            if class_id == 0 and confidence < self.cross_conf_threshold:
                continue

            current_mask = self._tensor_to_mask(mask_tensor, w, h)

            if class_id == 1:  # blind path
                blind_mask = current_mask if blind_mask is None else np.bitwise_or(blind_mask, current_mask)
            elif class_id == 0:  # crosswalk
                cross_mask = current_mask if cross_mask is None else np.bitwise_or(cross_mask, current_mask)

        return blind_mask, cross_mask

    def _tensor_to_mask(
        self,
        mask_tensor: torch.Tensor,
        out_w: int,
        out_h: int
    ) -> np.ndarray:
        """Convert a PyTorch mask tensor to a binary uint8 (0/255) NumPy mask."""
        import cv2

        mask_np = mask_tensor.cpu().numpy() if mask_tensor.is_cuda else mask_tensor.numpy()

        # Resize to the frame size if the model output is smaller.
        if mask_np.shape[0] != out_h or mask_np.shape[1] != out_w:
            mask_np = cv2.resize(mask_np, (out_w, out_h), interpolation=cv2.INTER_NEAREST)

        # Binarize to 0/255.
        return (mask_np > 0.5).astype(np.uint8) * 255


def detect_all_parallel(
    yolo_model,
    obstacle_detector,
    image: np.ndarray,
    path_mask: Optional[np.ndarray] = None
) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], List[Any]]:
    """
    Convenience wrapper: run every detector in parallel on one frame.

    Intended as a drop-in replacement for the serial detection calls in
    workflow_blindpath.py.
    """
    detector = ParallelDetector(yolo_model, obstacle_detector)
    return detector.detect_all(image, path_mask)
b/model/SenseVoiceSmall/chn_jpn_yue_eng_ko_spectok.bpe.model differ diff --git a/model/SenseVoiceSmall/config.yaml b/model/SenseVoiceSmall/config.yaml new file mode 100644 index 0000000..b1dd72d --- /dev/null +++ b/model/SenseVoiceSmall/config.yaml @@ -0,0 +1,97 @@ +encoder: SenseVoiceEncoderSmall +encoder_conf: + output_size: 512 + attention_heads: 4 + linear_units: 2048 + num_blocks: 50 + tp_blocks: 20 + dropout_rate: 0.1 + positional_dropout_rate: 0.1 + attention_dropout_rate: 0.1 + input_layer: pe + pos_enc_class: SinusoidalPositionEncoder + normalize_before: true + kernel_size: 11 + sanm_shfit: 0 + selfattention_layer_type: sanm + + +model: SenseVoiceSmall +model_conf: + length_normalized_loss: true + sos: 1 + eos: 2 + ignore_id: -1 + +tokenizer: SentencepiecesTokenizer +tokenizer_conf: + bpemodel: null + unk_symbol: + split_with_space: true + +frontend: WavFrontend +frontend_conf: + fs: 16000 + window: hamming + n_mels: 80 + frame_length: 25 + frame_shift: 10 + lfr_m: 7 + lfr_n: 6 + cmvn_file: null + + +dataset: SenseVoiceCTCDataset +dataset_conf: + index_ds: IndexDSJsonl + batch_sampler: EspnetStyleBatchSampler + data_split_num: 32 + batch_type: token + batch_size: 14000 + max_token_length: 2000 + min_token_length: 60 + max_source_length: 2000 + min_source_length: 60 + max_target_length: 200 + min_target_length: 0 + shuffle: true + num_workers: 4 + sos: ${model_conf.sos} + eos: ${model_conf.eos} + IndexDSJsonl: IndexDSJsonl + retry: 20 + +train_conf: + accum_grad: 1 + grad_clip: 5 + max_epoch: 20 + keep_nbest_models: 10 + avg_nbest_model: 10 + log_interval: 100 + resume: true + validate_interval: 10000 + save_checkpoint_interval: 10000 + +optim: adamw +optim_conf: + lr: 0.00002 +scheduler: warmuplr +scheduler_conf: + warmup_steps: 25000 + +specaug: SpecAugLFR +specaug_conf: + apply_time_warp: false + time_warp_window: 5 + time_warp_mode: bicubic + apply_freq_mask: true + freq_mask_width_range: + - 0 + - 30 + lfr_rate: 6 + num_freq_mask: 1 + apply_time_mask: 
# model_utils.py - Day 20 TensorRT model loading helpers
"""
Prefer a TensorRT .engine file when present; fall back to the .pt file.
"""

import os


def get_best_model_path(pt_path: str) -> str:
    """
    Pick the best model file for a given .pt path:
    1. Prefer a sibling .engine file (TensorRT acceleration).
    2. Fall back to the original .pt file when no .engine exists.

    Args:
        pt_path: original .pt model path (e.g. 'model/yolo-seg.pt')

    Returns:
        Path to the best available model (.engine or .pt).
    """
    if not pt_path.endswith('.pt'):
        return pt_path

    # FIX: str.replace('.pt', '.engine') rewrote the FIRST '.pt' occurrence,
    # corrupting paths such as 'ckpt.pt_models/best.pt'. Swap only the suffix.
    engine_path = pt_path[:-len('.pt')] + '.engine'

    if os.path.exists(engine_path):
        print(f"[MODEL] 🚀 使用 TensorRT 加速: {engine_path}")
        return engine_path

    print(f"[MODEL] 使用 PyTorch 模型: {pt_path}")
    return pt_path


def is_tensorrt_engine(model_path: str) -> bool:
    """True when the path points at a TensorRT engine (.engine or .trt)."""
    return model_path.endswith('.engine') or model_path.endswith('.trt')
# ==========================================================
# 1. Global device & concurrency control (managed in one place)
# ==========================================================
DEVICE = os.getenv("AIGLASS_DEVICE", "cuda:0")
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    logger.warning(f"AIGLASS_DEVICE={DEVICE} 但未检测到 CUDA,将回退到 CPU")
    DEVICE = "cpu"
IS_CUDA = DEVICE.startswith("cuda")

# AMP (automatic mixed precision) policy: bf16 / fp16 / off.
# NOTE(review): mirrors the original conditional expression, in which
# AMP_POLICY == "bf16" selects torch.bfloat16 even when IS_CUDA is False;
# gpu_infer_slot() re-checks IS_CUDA before using it — confirm intended.
AMP_POLICY = os.getenv("AIGLASS_AMP", "bf16").lower()
if AMP_POLICY == "bf16":
    AMP_DTYPE = torch.bfloat16
elif IS_CUDA:
    AMP_DTYPE = torch.float16 if AMP_POLICY == "fp16" else None
else:
    AMP_DTYPE = None

# 🔥 Core: the single GPU concurrency semaphore shared by every workflow.
GPU_SLOTS = int(os.getenv("AIGLASS_GPU_SLOTS", "2"))
gpu_semaphore = Semaphore(GPU_SLOTS)


@contextmanager
def gpu_infer_slot():
    """
    Unified inference context every workflow should use when calling a model:
    GPU slot throttling + torch.inference_mode() + AMP autocast (when enabled).
    """
    with gpu_semaphore:
        if IS_CUDA and AMP_POLICY != "off" and AMP_DTYPE is not None:
            with torch.inference_mode(), torch.amp.autocast('cuda', dtype=AMP_DTYPE):
                yield
        else:
            with torch.inference_mode():
                yield


# cuDNN autotuner speedup (best effort; safe to skip when unavailable).
try:
    if IS_CUDA:
        torch.backends.cudnn.benchmark = True
except Exception:
    pass

# ==========================================================
# 2. Global model instances (all start as None)
# ==========================================================

# --- Crossing-workflow models (wrapped in client classes) ---
crosswalk_detector_client: CrosswalkDetector = None
coco_client: COCOClient = None
# ObstacleDetectorClient doubles as the generic obstacle detector for all scenes.
obstacle_detector_client: ObstacleDetectorClient = None

# --- Blind-path workflow models (raw Ultralytics classes) ---
# Used for segmentation and path planning; detection logic differs from the
# crossing scenario.
blindpath_seg_model: YOLO = None
# Obstacle detection reuses obstacle_detector_client; the YOLOE text features
# are kept in a separate reference.
blindpath_whitelist_embeddings = None

# Global load-state flag.
models_are_loaded = False


# ==========================================================
# 3. Unified model loader (called once by celery.py at startup)
# ==========================================================
def init_all_models():
    """
    Called exactly once when a Celery worker process starts.
    Loads every model required by the workflows into the globals above.
    """
    global models_are_loaded
    if models_are_loaded:
        return

    logger.info(f"========= 🚀 开始全局模型预加载 (目标设备: {DEVICE}) =========")

    try:
        # --- [1] Generic obstacle detector (ObstacleDetectorClient) ---
        global obstacle_detector_client
        logger.info("[1/4] 正在加载通用障碍物检测模型 (ObstacleDetectorClient)...")
        # Day 20: prefer the TensorRT engine when one exists.
        obs_model_path = get_best_model_path('model/yoloe-11l-seg.pt')
        obstacle_detector_client = ObstacleDetectorClient(model_path=obs_model_path)
        # TensorRT engines must not be moved with .to().
        if not is_tensorrt_engine(obs_model_path):
            if hasattr(obstacle_detector_client, 'model') and obstacle_detector_client.model is not None:
                obstacle_detector_client.model.to(DEVICE)
        logger.info("...通用障碍物检测模型加载成功。")

        # --- [2] Crossing-specific model clients ---
        global crosswalk_detector_client, coco_client
        logger.info("[2/4] 正在加载过马路分割模型 (CrosswalkDetector)...")
        crosswalk_model_path = get_best_model_path('model/yolo-seg.pt')
        crosswalk_detector_client = CrosswalkDetector(model_path=crosswalk_model_path)
        if not is_tensorrt_engine(crosswalk_model_path):
            if hasattr(crosswalk_detector_client, 'model') and crosswalk_detector_client.model is not None:
                crosswalk_detector_client.model.to(DEVICE)
        logger.info("...过马路分割模型加载成功。")

        logger.info("[3/4] 正在加载通用感知模型 (COCOClient)...")
        coco_client = COCOClient(model_path='model/yolov8l-world.pt')
        # Move its inner YOLO model to the target device.
        if hasattr(coco_client, 'model') and coco_client.model is not None:
            coco_client.model.to(DEVICE)
        logger.info("...通用感知模型加载成功。")

        # --- [4] Blind-path-only segmentation model ---
        global blindpath_seg_model, blindpath_whitelist_embeddings
        logger.info("[4/4] 正在加载盲道专用分割模型 (YOLO)...")
        blindpath_model_path = get_best_model_path('model/yolo-seg.pt')
        blindpath_seg_model = YOLO(blindpath_model_path)
        # TensorRT engines need neither .to() nor .fuse().
        if not is_tensorrt_engine(blindpath_model_path):
            blindpath_seg_model.to(DEVICE)
            blindpath_seg_model.fuse()
        logger.info("...盲道专用分割模型加载成功。")

        # Link the YOLOE text-feature reference for the blind-path workflow.
        if obstacle_detector_client:
            blindpath_whitelist_embeddings = obstacle_detector_client.whitelist_embeddings
            logger.info("...已为盲道工作流链接障碍物模型特征。")

        # Everything loaded.
        models_are_loaded = True
        logger.info("========= ✅ 所有模型已成功预加载。Worker准备就绪! =========")

    except Exception as e:
        logger.error(f"模型预加载过程中发生严重错误: {e}", exc_info=True)
        # Re-raise so the Celery worker fails to start: a worker without its
        # models is useless, and surfacing the problem early is the right call.
        raise
b/music/converted_找到啦.wav differ diff --git a/music/converted_拿到啦.wav b/music/converted_拿到啦.wav new file mode 100644 index 0000000..a17b653 Binary files /dev/null and b/music/converted_拿到啦.wav differ diff --git a/music/converted_音频1.WAV b/music/converted_音频1.WAV new file mode 100644 index 0000000..fbe4c0e Binary files /dev/null and b/music/converted_音频1.WAV differ diff --git a/music/converted_音频2.WAV b/music/converted_音频2.WAV new file mode 100644 index 0000000..1a5d453 Binary files /dev/null and b/music/converted_音频2.WAV differ diff --git a/music/converted_音频3.WAV b/music/converted_音频3.WAV new file mode 100644 index 0000000..4ee77ed Binary files /dev/null and b/music/converted_音频3.WAV differ diff --git a/music/converted_音频4.WAV b/music/converted_音频4.WAV new file mode 100644 index 0000000..4654407 Binary files /dev/null and b/music/converted_音频4.WAV differ diff --git a/music/converted_音频5.WAV b/music/converted_音频5.WAV new file mode 100644 index 0000000..9d8c70b Binary files /dev/null and b/music/converted_音频5.WAV differ diff --git a/music/converted_音频6.WAV b/music/converted_音频6.WAV new file mode 100644 index 0000000..a847d4e Binary files /dev/null and b/music/converted_音频6.WAV differ diff --git a/music/converted_音频7.WAV b/music/converted_音频7.WAV new file mode 100644 index 0000000..dcf9c83 Binary files /dev/null and b/music/converted_音频7.WAV differ diff --git a/music/converted_音频8.WAV b/music/converted_音频8.WAV new file mode 100644 index 0000000..9d436a0 Binary files /dev/null and b/music/converted_音频8.WAV differ diff --git a/music/converted_音频9.WAV b/music/converted_音频9.WAV new file mode 100644 index 0000000..2e4c8f8 Binary files /dev/null and b/music/converted_音频9.WAV differ diff --git a/music/向上.txt b/music/向上.txt new file mode 100644 index 0000000..9b463d6 --- /dev/null +++ b/music/向上.txt @@ -0,0 +1 @@ +向上。如果还有其他想法,你可以随时告诉我哦。 \ No newline at end of file diff --git a/music/向上.wav b/music/向上.wav new file mode 100644 index 0000000..8e3b097 Binary files /dev/null and 
b/music/向上.wav differ diff --git a/music/向下.txt b/music/向下.txt new file mode 100644 index 0000000..f1df867 --- /dev/null +++ b/music/向下.txt @@ -0,0 +1 @@ +向下。如果还有啥想法,你可以再跟我说哦。 \ No newline at end of file diff --git a/music/向下.wav b/music/向下.wav new file mode 100644 index 0000000..1caa62a Binary files /dev/null and b/music/向下.wav differ diff --git a/music/向前.txt b/music/向前.txt new file mode 100644 index 0000000..f4aac71 --- /dev/null +++ b/music/向前.txt @@ -0,0 +1 @@ +向前。如果还有啥想法,你可以再跟我说哦。 \ No newline at end of file diff --git a/music/向前.wav b/music/向前.wav new file mode 100644 index 0000000..a11a8a3 Binary files /dev/null and b/music/向前.wav differ diff --git a/music/向右.txt b/music/向右.txt new file mode 100644 index 0000000..ef5e13d --- /dev/null +++ b/music/向右.txt @@ -0,0 +1 @@ +向右。如果还有啥想法,你可以再跟我说哦。 \ No newline at end of file diff --git a/music/向右.wav b/music/向右.wav new file mode 100644 index 0000000..5289914 Binary files /dev/null and b/music/向右.wav differ diff --git a/music/向后.txt b/music/向后.txt new file mode 100644 index 0000000..bf66582 --- /dev/null +++ b/music/向后.txt @@ -0,0 +1 @@ +向后。如果还有啥想法,你可以再跟我说哦。 \ No newline at end of file diff --git a/music/向后.wav b/music/向后.wav new file mode 100644 index 0000000..5d12772 Binary files /dev/null and b/music/向后.wav differ diff --git a/music/向左.txt b/music/向左.txt new file mode 100644 index 0000000..b988b32 --- /dev/null +++ b/music/向左.txt @@ -0,0 +1 @@ +向左。如果还有其他想法,你可以随时告诉我哦。 \ No newline at end of file diff --git a/music/向左.wav b/music/向左.wav new file mode 100644 index 0000000..c97dea1 Binary files /dev/null and b/music/向左.wav differ diff --git a/music/在画面中间.WAV b/music/在画面中间.WAV new file mode 100644 index 0000000..68befb0 Binary files /dev/null and b/music/在画面中间.WAV differ diff --git a/music/在画面中间.txt b/music/在画面中间.txt new file mode 100644 index 0000000..061d5aa --- /dev/null +++ b/music/在画面中间.txt @@ -0,0 +1 @@ +“在画面中间” \ No newline at end of file diff --git a/music/在画面中间_24k.wav b/music/在画面中间_24k.wav new file mode 
100644 index 0000000..853a937 Binary files /dev/null and b/music/在画面中间_24k.wav differ diff --git a/music/在画面右侧.WAV b/music/在画面右侧.WAV new file mode 100644 index 0000000..e0d175b Binary files /dev/null and b/music/在画面右侧.WAV differ diff --git a/music/在画面右侧.txt b/music/在画面右侧.txt new file mode 100644 index 0000000..43fcc5d --- /dev/null +++ b/music/在画面右侧.txt @@ -0,0 +1 @@ +“在画面右侧” \ No newline at end of file diff --git a/music/在画面右侧_24k.wav b/music/在画面右侧_24k.wav new file mode 100644 index 0000000..d83a129 Binary files /dev/null and b/music/在画面右侧_24k.wav differ diff --git a/music/在画面左侧.WAV b/music/在画面左侧.WAV new file mode 100644 index 0000000..2e10c3e Binary files /dev/null and b/music/在画面左侧.WAV differ diff --git a/music/在画面左侧.txt b/music/在画面左侧.txt new file mode 100644 index 0000000..4bd9477 --- /dev/null +++ b/music/在画面左侧.txt @@ -0,0 +1 @@ +“在画面左侧” \ No newline at end of file diff --git a/music/在画面左侧_24k.wav b/music/在画面左侧_24k.wav new file mode 100644 index 0000000..5dea225 Binary files /dev/null and b/music/在画面左侧_24k.wav differ diff --git a/music/已对中.txt b/music/已对中.txt new file mode 100644 index 0000000..d8d0bd6 --- /dev/null +++ b/music/已对中.txt @@ -0,0 +1 @@ +已对正!如果还有其他想法或者问题,你可以随时告诉我哦。 \ No newline at end of file diff --git a/music/已对中.wav b/music/已对中.wav new file mode 100644 index 0000000..e713fa2 Binary files /dev/null and b/music/已对中.wav differ diff --git a/music/找到啦.txt b/music/找到啦.txt new file mode 100644 index 0000000..1f2c468 --- /dev/null +++ b/music/找到啦.txt @@ -0,0 +1 @@ +找到了! 
\ No newline at end of file diff --git a/music/找到啦.wav b/music/找到啦.wav new file mode 100644 index 0000000..93ad225 Binary files /dev/null and b/music/找到啦.wav differ diff --git a/music/拿到啦.txt b/music/拿到啦.txt new file mode 100644 index 0000000..3c48626 --- /dev/null +++ b/music/拿到啦.txt @@ -0,0 +1 @@ +拿到了!如果还有啥问题,你可以再跟我说哦。 \ No newline at end of file diff --git a/music/拿到啦.wav b/music/拿到啦.wav new file mode 100644 index 0000000..9d6aaa0 Binary files /dev/null and b/music/拿到啦.wav differ diff --git a/music/接近斑马线.WAV b/music/接近斑马线.WAV new file mode 100644 index 0000000..314ee6a Binary files /dev/null and b/music/接近斑马线.WAV differ diff --git a/music/接近斑马线.txt b/music/接近斑马线.txt new file mode 100644 index 0000000..08bb25d --- /dev/null +++ b/music/接近斑马线.txt @@ -0,0 +1 @@ +“接近斑马线” \ No newline at end of file diff --git a/music/接近斑马线_24k.wav b/music/接近斑马线_24k.wav new file mode 100644 index 0000000..b128006 Binary files /dev/null and b/music/接近斑马线_24k.wav differ diff --git a/music/斑马线到了可以过马路.WAV b/music/斑马线到了可以过马路.WAV new file mode 100644 index 0000000..5d3d661 Binary files /dev/null and b/music/斑马线到了可以过马路.WAV differ diff --git a/music/斑马线到了可以过马路.txt b/music/斑马线到了可以过马路.txt new file mode 100644 index 0000000..2a77950 --- /dev/null +++ b/music/斑马线到了可以过马路.txt @@ -0,0 +1 @@ +“斑马线到了可以过马路”。如果还有类似的问题或者其他想法,你可以随时告诉我哦。 \ No newline at end of file diff --git a/music/斑马线到了可以过马路_24k.wav b/music/斑马线到了可以过马路_24k.wav new file mode 100644 index 0000000..0809f65 Binary files /dev/null and b/music/斑马线到了可以过马路_24k.wav differ diff --git a/music/正在靠近斑马线.WAV b/music/正在靠近斑马线.WAV new file mode 100644 index 0000000..cf0ec57 Binary files /dev/null and b/music/正在靠近斑马线.WAV differ diff --git a/music/正在靠近斑马线.txt b/music/正在靠近斑马线.txt new file mode 100644 index 0000000..ed31e7a --- /dev/null +++ b/music/正在靠近斑马线.txt @@ -0,0 +1 @@ +“正在靠近斑马线” \ No newline at end of file diff --git a/music/正在靠近斑马线_24k.wav b/music/正在靠近斑马线_24k.wav new file mode 100644 index 0000000..017ee91 Binary files /dev/null and 
b/music/正在靠近斑马线_24k.wav differ diff --git a/music/红灯.WAV b/music/红灯.WAV new file mode 100644 index 0000000..f53ff7f Binary files /dev/null and b/music/红灯.WAV differ diff --git a/music/绿灯.WAV b/music/绿灯.WAV new file mode 100644 index 0000000..1a5725c Binary files /dev/null and b/music/绿灯.WAV differ diff --git a/music/远处发现斑马线.WAV b/music/远处发现斑马线.WAV new file mode 100644 index 0000000..49488e3 Binary files /dev/null and b/music/远处发现斑马线.WAV differ diff --git a/music/远处发现斑马线.txt b/music/远处发现斑马线.txt new file mode 100644 index 0000000..975bed2 --- /dev/null +++ b/music/远处发现斑马线.txt @@ -0,0 +1 @@ +“远处发现斑马线” \ No newline at end of file diff --git a/music/远处发现斑马线_24k.wav b/music/远处发现斑马线_24k.wav new file mode 100644 index 0000000..3a39e4f Binary files /dev/null and b/music/远处发现斑马线_24k.wav differ diff --git a/music/音频1.WAV b/music/音频1.WAV new file mode 100644 index 0000000..fcb8e04 Binary files /dev/null and b/music/音频1.WAV differ diff --git a/music/音频2.WAV b/music/音频2.WAV new file mode 100644 index 0000000..c173480 Binary files /dev/null and b/music/音频2.WAV differ diff --git a/music/音频3.WAV b/music/音频3.WAV new file mode 100644 index 0000000..612f34b Binary files /dev/null and b/music/音频3.WAV differ diff --git a/music/音频4.WAV b/music/音频4.WAV new file mode 100644 index 0000000..4df2cf3 Binary files /dev/null and b/music/音频4.WAV differ diff --git a/music/音频5.WAV b/music/音频5.WAV new file mode 100644 index 0000000..5672b13 Binary files /dev/null and b/music/音频5.WAV differ diff --git a/music/音频6.WAV b/music/音频6.WAV new file mode 100644 index 0000000..8ff8ce0 Binary files /dev/null and b/music/音频6.WAV differ diff --git a/music/音频7.WAV b/music/音频7.WAV new file mode 100644 index 0000000..b4d350d Binary files /dev/null and b/music/音频7.WAV differ diff --git a/music/音频8.WAV b/music/音频8.WAV new file mode 100644 index 0000000..de6b1ff Binary files /dev/null and b/music/音频8.WAV differ diff --git a/music/音频9.WAV b/music/音频9.WAV new file mode 100644 index 0000000..1dd502e Binary files 
/dev/null and b/music/音频9.WAV differ diff --git a/music/黄灯.WAV b/music/黄灯.WAV new file mode 100644 index 0000000..e4051d3 Binary files /dev/null and b/music/黄灯.WAV differ diff --git a/navigation_master.py b/navigation_master.py new file mode 100644 index 0000000..76737fb --- /dev/null +++ b/navigation_master.py @@ -0,0 +1,700 @@ +# navigation_master.py +# -*- coding: utf-8 -*- +import time +import math +import cv2 +import numpy as np +from dataclasses import dataclass +from typing import Optional, Dict, Any, Deque, List, Tuple +from collections import deque + +# 工作流导入(与现有文件解耦) +from workflow_blindpath import BlindPathNavigator, ProcessingResult as BlindResult +from workflow_crossstreet import CrossStreetNavigator, CrossStreetResult as CrossResult + +# ========== 状态常量 ========== +IDLE = "IDLE" # 空闲/未启用 +CHAT = "CHAT" # 对话模式(不进行导航,只返回原始画面) +BLINDPATH_NAV = "BLINDPATH_NAV" # 正在走盲道(复用 BlindPathNavigator) +SEEKING_CROSSWALK = "SEEKING_CROSSWALK"# 盲道阶段发现斑马线,正对准/靠近 +WAIT_TRAFFIC_LIGHT = "WAIT_TRAFFIC_LIGHT" # 到达斑马线后等待交通灯(可选/占位) +CROSSING = "CROSSING" # 正在过马路(复用 CrossStreetNavigator) +SEEKING_NEXT_BLINDPATH = "SEEKING_NEXT_BLINDPATH" # 过完马路后寻找下一段盲道入口(上盲道) +RECOVERY = "RECOVERY" # 兜底/恢复(感知暂时丢失时) +TRAFFIC_LIGHT_DETECTION = "TRAFFIC_LIGHT_DETECTION" # 红绿灯检测模式 +ITEM_SEARCH = "ITEM_SEARCH" # 找物品模式(暂停导航,由yolomedia处理画面) + +# ========== 返回结构 ========== +@dataclass +class OrchestratorResult: + annotated_image: Optional[np.ndarray] + guidance_text: str + state: str + extras: Dict[str, Any] + +# ========== 实用:信号平滑/多数表决 ========== +class MajorityFilter: + def __init__(self, size: int = 8): + self.buf: Deque[str] = deque(maxlen=size) + + def push(self, v: str): + self.buf.append(v) + + def majority(self) -> str: + if not self.buf: + return "unknown" + cnt = {} + for v in self.buf: + cnt[v] = cnt.get(v, 0) + 1 + # 稳健排序:unknown 权重最低 + items = sorted(cnt.items(), key=lambda x: (0 if x[0]=="unknown" else 1, x[1]), reverse=True) + return items[0][0] + + def history(self) -> List[str]: + 
return list(self.buf) + + def clear(self): + self.buf.clear() + +# ========== 红绿灯识别 ========== +class TrafficLightDetector: + """ + 红绿灯识别器: + 1) 优先尝试 yoloe_backend 风格的检测(如可用); + 2) 回退:无模型时,使用 HSV 颜色启发式在上半屏寻找亮红/黄/绿的“灯团”。 + 输出:('red'|'green'|'yellow'|'unknown', meta) + """ + def __init__(self): + self.has_backend = False + self.backend = None + try: + # 尝试动态导入(根据你本地 yoloe_backend 的接口调整) + import yoloe_backend as _yeb # noqa + self.backend = _yeb + self.has_backend = True + except Exception: + self.has_backend = False + self.backend = None + + def _try_backend(self, bgr: np.ndarray) -> Tuple[str, Dict[str, Any]]: + """ + 尝试调用 yoloe_backend 风格的接口。由于各项目实现不同,这里做“宽容地调用”: + - 优先尝试 backend.detect(image, target_classes=['traffic light']) + - 次选 backend.infer_image(image) 后在结果中过滤 'traffic light' + - 以上都失败则返回 unknown + 预期结果条目应含 bbox 或 mask,可自行扩展“颜色判定”逻辑(ROI 取样 HSV) + """ + if not self.has_backend or self.backend is None: + return "unknown", {"reason": "backend_not_available"} + + res = None + try: + if hasattr(self.backend, "detect"): + # 假定 detect 返回 [{'name': 'traffic light', 'box':[x1,y1,x2,y2], ...}, ...] + res = self.backend.detect(bgr, target_classes=["traffic light"]) + elif hasattr(self.backend, "infer_image"): + # 假定 infer_image 返回 [{'label': 'traffic light', 'bbox': [x1,y1,x2,y2], ...}, ...] 
+ res = self.backend.infer_image(bgr) + else: + return "unknown", {"reason": "backend_no_suitable_api"} + except Exception as e: + return "unknown", {"reason": f"backend_failed:{e}"} + + if not res or len(res) == 0: + return "unknown", {"reason": "no_detection"} + + # 拿到最大框作为主灯,做 HSV 颜色判断 + H, W = bgr.shape[:2] + best = None + best_area = 0 + boxes = [] + for item in res: + # 统一盒字段 + if "box" in item and isinstance(item["box"], (list, tuple)) and len(item["box"]) == 4: + x1, y1, x2, y2 = item["box"] + elif "bbox" in item and isinstance(item["bbox"], (list, tuple)) and len(item["bbox"]) == 4: + x1, y1, x2, y2 = item["bbox"] + else: + continue + x1 = int(max(0, min(W-1, x1))); x2 = int(max(0, min(W-1, x2))) + y1 = int(max(0, min(H-1, y1))); y2 = int(max(0, min(H-1, y2))) + if x2 <= x1 or y2 <= y1: + continue + area = (x2 - x1) * (y2 - y1) + boxes.append((x1, y1, x2, y2, area)) + if area > best_area: + best_area = area + best = (x1, y1, x2, y2) + + if best is None: + return "unknown", {"reason": "no_valid_bbox", "raw": len(res)} + + x1, y1, x2, y2 = best + roi = bgr[y1:y2, x1:x2] + color = self._classify_color_hsv(roi) + return color, {"bbox": best, "count": len(res), "boxes": boxes} + + def _classify_color_hsv(self, roi_bgr: np.ndarray) -> str: + """对 ROI 做 HSV 基于阈值的红/黄/绿简单判定;取面积最大的主色。""" + if roi_bgr is None or roi_bgr.size == 0: + return "unknown" + hsv = cv2.cvtColor(roi_bgr, cv2.COLOR_BGR2HSV) + + # 红色范围(两段) + lower_red1 = np.array([0, 80, 120]); upper_red1 = np.array([10, 255, 255]) + lower_red2 = np.array([160, 80, 120]); upper_red2 = np.array([180, 255, 255]) + mask_r1 = cv2.inRange(hsv, lower_red1, upper_red1) + mask_r2 = cv2.inRange(hsv, lower_red2, upper_red2) + mask_red = cv2.bitwise_or(mask_r1, mask_r2) + + # 绿色 + lower_green = np.array([40, 60, 120]); upper_green = np.array([90, 255, 255]) + mask_green = cv2.inRange(hsv, lower_green, upper_green) + + # 黄色 + lower_yellow = np.array([18, 80, 150]); upper_yellow = np.array([35, 255, 255]) + mask_yellow = 
cv2.inRange(hsv, lower_yellow, upper_yellow) + + # 面积阈值(相对 ROI) + total = roi_bgr.shape[0] * roi_bgr.shape[1] + 1e-6 + r_ratio = float(np.count_nonzero(mask_red)) / total + g_ratio = float(np.count_nonzero(mask_green)) / total + y_ratio = float(np.count_nonzero(mask_yellow)) / total + + # 简单抑制“脏背景导致的弱响应” + thr = 0.03 + candidates = [] + if r_ratio > thr: candidates.append(("red", r_ratio)) + if g_ratio > thr: candidates.append(("green", g_ratio)) + if y_ratio > thr: candidates.append(("yellow", y_ratio)) + if not candidates: + return "unknown" + candidates.sort(key=lambda x: x[1], reverse=True) + return candidates[0][0] + + def detect(self, bgr: np.ndarray) -> Tuple[str, Dict[str, Any]]: + """ + 总入口:先尝试后端;失败则在上半屏自行找“亮色灯团”(无需框)。 + """ + # 1) 尝试后端 + if self.has_backend: + color, meta = self._try_backend(bgr) + if color != "unknown": + return color, {"method": "backend", **meta} + + # 2) 回退:上半屏 HSV 聚类 + 连通域,选最大“灯团”判色 + H, W = bgr.shape[:2] + roi = bgr[:int(H * 0.5), :] + hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV) + + # 高亮阈值(抑制暗部/车灯) + v = hsv[:, :, 2] + bright = (v > 140).astype(np.uint8) * 255 + + # 粗分颜色 + col = self._classify_color_hsv(roi) + return col, {"method": "fallback", "note": "no_backend", "bright_ratio": float(np.mean(bright > 0))} + +# ========== 视觉辅助工具 ========== +def _color_bgr(name: str) -> Tuple[int, int, int]: + if name == "red": return (0, 0, 255) + if name == "green": return (0, 255, 0) + if name == "yellow": return (0, 255, 255) + if name == "blue": return (255, 0, 0) + if name == "orange": return (0, 165, 255) + if name == "cyan": return (255, 255, 0) + if name == "magenta": return (255, 0, 255) + if name == "gray": return (128, 128, 128) + if name == "white": return (255, 255, 255) + return (200, 200, 200) + +def _put_text(img, text, org, color=(255,255,255), scale=0.7, thick=2, outline=True): + if outline: + for dx in (-1,0,1): + for dy in (-1,0,1): + if dx==0 and dy==0: continue + cv2.putText(img, text, (org[0]+dx, org[1]+dy), 
cv2.FONT_HERSHEY_SIMPLEX, scale, (0,0,0), thick+1) + cv2.putText(img, text, org, cv2.FONT_HERSHEY_SIMPLEX, scale, color, thick) + +def _draw_badge(img, text, pos=(10, 28), fg="white", bg="blue"): + color_fg = _color_bgr(fg); color_bg = _color_bgr(bg) + (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, 0.6, 2) + x, y = pos + pad = 6 + cv2.rectangle(img, (x-4, y-th-pad), (x+tw+8, y+pad//2), color_bg, -1) + _put_text(img, text, (x, y), color=color_fg, scale=0.6, thick=2, outline=False) + +def _draw_state_panel(img, kv: Dict[str, Any], pos=(10, 60)): + x, y = pos + line_h = 22 + for i, (k, v) in enumerate(kv.items()): + _put_text(img, f"{k}: {v}", (x, y + i*line_h), color=(255,255,255), scale=0.6, thick=2) + +def _draw_frame_border(img, color=(0,255,0), thickness=3): + h, w = img.shape[:2] + cv2.rectangle(img, (0,0), (w-1, h-1), color, thickness) + +def _draw_progress_bar(img, ratio: float, pos=(10, 90), size=(180, 10), color="cyan"): + ratio = max(0.0, min(1.0, float(ratio))) + x, y = pos + w, h = size + cv2.rectangle(img, (x, y), (x+w, y+h), (80,80,80), 1) + cv2.rectangle(img, (x+1, y+1), (x+1+int((w-2)*ratio), y+h-1), _color_bgr(color), -1) + +# ========== 统领器 ========== +class NavigationMaster: + def __init__(self, + blind_nav: BlindPathNavigator, + cross_nav: CrossStreetNavigator, + *, + min_tts_interval: float = 1.2): + self.blind = blind_nav + self.cross = cross_nav + self.state = IDLE + self.last_guidance_ts = 0.0 + self.min_tts_interval = min_tts_interval + + # 防抖/稳定计数 + self.cnt_crosswalk_seen = 0 # 盲道侧看见斑马线(approaching/ready) + self.cnt_align_ready = 0 # 斑马线 ready + 对准达标 + self.cnt_cross_end = 0 # 过马路结束条件累计 + self.cnt_lost = 0 # 感知丢失累计(进入 RECOVERY) + + # 冷却期避免状态抖动 + self.cooldown_until = 0.0 + + # 紧急恢复目标 + self.prev_target_state = BLINDPATH_NAV + + # 交通灯 + self.tld = TrafficLightDetector() + self.tl_major = MajorityFilter(size=8) + self.tl_last_color = "unknown" + + # 参数(可按现场再调) + self.FRAMES_CROSS_SEEN = 8 + self.FRAMES_ALIGN_READY = 12 + 
self.FRAMES_CROSS_END = 12 + self.FRAMES_NEXT_BLIND_OK = 8 + self.FRAMES_LOST_MAX = 45 + + self.ANGLE_ALIGN_THR_DEG = 12.0 + self.OFFSET_ALIGN_THR = 0.15 + + self.COOLDOWN_SEC = 0.6 + + # 找物品状态管理 + self.prev_nav_state_before_search = None # 找物品前的导航状态,用于恢复 + + # ----- 外部交互 ----- + def get_state(self) -> str: + return self.state + + def start_blind_path_navigation(self): + """启动盲道导航模式""" + self.state = BLINDPATH_NAV + self.cooldown_until = time.time() + self.COOLDOWN_SEC + if self.blind: + self.blind.reset() + + def stop_navigation(self): + """停止导航,回到对话模式""" + self.state = CHAT + self.cooldown_until = time.time() + self.COOLDOWN_SEC + if self.blind: + self.blind.reset() + + def start_crossing(self): + """启动过马路模式""" + self.state = CROSSING + self.cooldown_until = time.time() + self.COOLDOWN_SEC + if self.cross: + self.cross.reset() + + def start_traffic_light_detection(self): + """启动红绿灯检测模式""" + self.state = TRAFFIC_LIGHT_DETECTION + self.cooldown_until = time.time() + self.COOLDOWN_SEC + + def is_in_navigation_mode(self): + """检查是否在导航模式(非对话模式)""" + return self.state not in ["CHAT", "IDLE", "TRAFFIC_LIGHT_DETECTION", "ITEM_SEARCH"] + + def start_item_search(self): + """启动找物品模式,暂停当前导航""" + # 保存当前导航状态(如果在导航中) + if self.state in [BLINDPATH_NAV, SEEKING_CROSSWALK, WAIT_TRAFFIC_LIGHT, CROSSING, SEEKING_NEXT_BLINDPATH]: + self.prev_nav_state_before_search = self.state + print(f"[NAV MASTER] 暂停导航状态 {self.state},切换到找物品模式") + else: + self.prev_nav_state_before_search = None + + self.state = ITEM_SEARCH + self.cooldown_until = time.time() + self.COOLDOWN_SEC + + def stop_item_search(self, restore_nav: bool = True): + """停止找物品模式""" + # 如果需要恢复之前的导航状态 + if restore_nav and self.prev_nav_state_before_search: + self.state = self.prev_nav_state_before_search + print(f"[NAV MASTER] 找物品结束,恢复到导航状态 {self.state}") + self.prev_nav_state_before_search = None + else: + # 否则回到对话模式 + self.state = CHAT + print(f"[NAV MASTER] 找物品结束,回到对话模式") + + self.cooldown_until = time.time() + 
self.COOLDOWN_SEC + + def force_state(self, s: str): + self.state = s + self.cooldown_until = time.time() + self.COOLDOWN_SEC + + def on_voice_command(self, text: str): + t = (text or "").strip() + if "开始过马路" in t: + # 直接进入等待/或立即过马路(低速环境可直过) + if self.state in (BLINDPATH_NAV, SEEKING_CROSSWALK, WAIT_TRAFFIC_LIGHT, IDLE, RECOVERY, SEEKING_NEXT_BLINDPATH): + self.state = WAIT_TRAFFIC_LIGHT + self.cooldown_until = time.time() + self.COOLDOWN_SEC + elif "立即通过" in t or "现在通过" in t: + self.state = CROSSING + self.cooldown_until = time.time() + self.COOLDOWN_SEC + elif "停止" in t or "结束" in t: + self.state = IDLE + elif "继续" in t: + if self.state == IDLE: + self.state = BLINDPATH_NAV + + def reset(self): + self.state = IDLE + self.cnt_crosswalk_seen = 0 + self.cnt_align_ready = 0 + self.cnt_cross_end = 0 + self.cnt_lost = 0 + self.tl_major.clear() + self.tl_last_color = "unknown" + self.prev_target_state = BLINDPATH_NAV + self._last_wait_light_announce = 0 # 重置等待绿灯播报时间 + try: + self.blind.reset() + except Exception: + pass + try: + self.cross.reset() + except Exception: + pass + + # ----- 内部工具 ----- + def _say(self, now: float, text: str) -> str: + if not text: + return "" + if now - self.last_guidance_ts >= self.min_tts_interval: + self.last_guidance_ts = now + return text + return "" + + def _draw_tl_status(self, img: np.ndarray, color: str, meta: Dict[str, Any]): + if img is None: + return + color_bgr = _color_bgr(color) + # 角标与文本 + cv2.circle(img, (24, 24), 10, color_bgr, -1) + _put_text(img, f"信号灯: {color}", (40, 30), color=color_bgr, scale=0.6, thick=2, outline=False) + # 画 bbox(若有) + if meta and "bbox" in meta: + x1, y1, x2, y2 = meta["bbox"] + cv2.rectangle(img, (int(x1), int(y1)), (int(x2), int(y2)), color_bgr, 2) + + # 多数表决历史(最近8帧) + hist = self.tl_major.history() + if hist: + x0, y0 = 10, 50 + r = 6 + gap = 16 + for i, hcol in enumerate(hist[-12:]): + cv2.circle(img, (x0 + i*gap, y0), r, _color_bgr(hcol), -1) + _put_text(img, "信号历史", (x0, y0+20), 
color=(255,255,255), scale=0.5, thick=1) + + # ----- 主循环 ----- + def process_frame(self, bgr: np.ndarray) -> OrchestratorResult: + now = time.time() + + # 【修改】IDLE状态默认进入CHAT模式,而不是自动开始导航 + if self.state == IDLE: + self.state = CHAT + self.cooldown_until = now + self.COOLDOWN_SEC + + # 【新增】CHAT模式:只返回原始画面,不进行导航 + if self.state == CHAT: + return OrchestratorResult( + annotated_image=bgr, + guidance_text="", + state="CHAT", + extras={"mode": "对话模式"} + ) + + # 【新增】红绿灯检测模式:只返回原始画面,由红绿灯模块处理 + if self.state == TRAFFIC_LIGHT_DETECTION: + return OrchestratorResult( + annotated_image=bgr, + guidance_text="", + state="TRAFFIC_LIGHT_DETECTION", + extras={"mode": "红绿灯检测模式"} + ) + + # 【新增】找物品模式:只返回原始画面,由yolomedia处理 + if self.state == ITEM_SEARCH: + return OrchestratorResult( + annotated_image=bgr, + guidance_text="", + state="ITEM_SEARCH", + extras={"mode": "找物品模式", "prev_nav_state": self.prev_nav_state_before_search} + ) + + # 冷却期内允许继续输出画面,但避免"瞬时切换" + in_cooldown = now < self.cooldown_until + + # 各状态处理 + if self.state in (BLINDPATH_NAV, SEEKING_CROSSWALK, SEEKING_NEXT_BLINDPATH, RECOVERY): + # —— 盲道侧 —— 统一调用盲道导航器 + try: + bres: BlindResult = self.blind.process_frame(bgr) + except Exception as e: + # 异常 → 进入恢复态 + self.state = RECOVERY + self.cnt_lost += 5 + ann_err = bgr.copy() + # 【移除】所有可视化干扰 + # _draw_badge(ann_err, "NAV ERROR", (10, 28), fg="white", bg="red") + # _put_text(ann_err, str(e), (10, 56), color=(255,255,255), scale=0.55) + return OrchestratorResult(ann_err, self._say(now, ""), self.state, {"error": str(e)}) + + ann = bres.annotated_image if bres.annotated_image is not None else bgr.copy() + say = bres.guidance_text or "" + + state_info = bres.state_info or {} + cross_stage = state_info.get("crosswalk_stage", "not_detected") + blind_state = state_info.get("state", "UNKNOWN") + # 可选字段(若工作流未来补充) + angle = float(state_info.get("last_angle", 0.0)) + center_x_ratio = float(state_info.get("last_center_x_ratio", 0.5)) + + # —— 盲道 → 发现斑马线(approaching/ready) + if self.state == 
BLINDPATH_NAV: + if cross_stage in ("approaching", "ready"): + self.cnt_crosswalk_seen += 1 + else: + self.cnt_crosswalk_seen = max(0, self.cnt_crosswalk_seen - 1) + + if self.cnt_crosswalk_seen >= self.FRAMES_CROSS_SEEN and not in_cooldown: + self.state = SEEKING_CROSSWALK + self.cooldown_until = now + self.COOLDOWN_SEC + say = "正在接近斑马线,为您对准方向。" + + # 【移除】所有可视化干扰 + # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="blue") + # _draw_state_panel(ann, { + # "盲道状态": blind_state, + # "斑马线阶段": cross_stage, + # "靠近计数": self.cnt_crosswalk_seen, + # }, pos=(10, 60)) + # _draw_progress_bar(ann, max(0.0, min(1.0, self.cnt_crosswalk_seen / max(1, self.FRAMES_CROSS_SEEN))), pos=(10, 120), size=(180, 10), color="cyan") + # _draw_frame_border(ann, color=_color_bgr("blue"), thickness=3) + + # —— 对准阶段:同时利用 blind 内部 crosswalk_tracker 的角度与偏移(若提供) + elif self.state == SEEKING_CROSSWALK: + aligned = (abs(angle) <= self.ANGLE_ALIGN_THR_DEG and abs(center_x_ratio - 0.5) <= self.OFFSET_ALIGN_THR) + if cross_stage == "ready" and aligned: + self.cnt_align_ready += 1 + else: + self.cnt_align_ready = max(0, self.cnt_align_ready - 1) + + if self.cnt_align_ready >= self.FRAMES_ALIGN_READY and not in_cooldown: + self.state = WAIT_TRAFFIC_LIGHT + self.cooldown_until = now + self.COOLDOWN_SEC + say = "已到达斑马线,请等待红绿灯。" + + # 【移除】所有可视化干扰 + # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="orange") + # panel = { + # "阶段": cross_stage, + # "对准计数": self.cnt_align_ready, + # } + # if "last_angle" in state_info: + # panel["角度(°)"] = f"{angle:.1f}" + # if "last_center_x_ratio" in state_info: + # panel["偏移"] = f"{(center_x_ratio-0.5):+.2f}" + # _draw_state_panel(ann, panel, pos=(10, 60)) + # _draw_progress_bar(ann, max(0.0, min(1.0, self.cnt_align_ready / max(1, self.FRAMES_ALIGN_READY))), pos=(10, 120), size=(220, 10), color="yellow") + # _draw_frame_border(ann, color=_color_bgr("orange"), thickness=3) + + # —— 过马路后寻找下一段盲道(上盲道流程) + elif self.state == 
SEEKING_NEXT_BLINDPATH: + if blind_state == "NAVIGATING": + self.cnt_cross_end += 1 + else: + self.cnt_cross_end = max(0, self.cnt_cross_end - 1) + if self.cnt_cross_end >= self.FRAMES_NEXT_BLIND_OK and not in_cooldown: + self.state = BLINDPATH_NAV + self.cooldown_until = now + self.COOLDOWN_SEC + say = "方向正确,请继续前进。" + + # 【移除】所有可视化干扰 + # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="green") + # _draw_state_panel(ann, { + # "盲道状态": blind_state, + # "回归计数": self.cnt_cross_end + # }, pos=(10, 60)) + # _draw_progress_bar(ann, max(0.0, min(1.0, self.cnt_cross_end / max(1, self.FRAMES_NEXT_BLIND_OK))), pos=(10, 120), size=(200, 10), color="green") + # _draw_frame_border(ann, color=_color_bgr("green"), thickness=3) + + # —— 恢复态:一旦盲道恢复可用则回盲道 + elif self.state == RECOVERY: + if blind_state in ("ONBOARDING", "NAVIGATING"): + self.state = BLINDPATH_NAV + self.cooldown_until = now + self.COOLDOWN_SEC + say = "" + else: + say = "" + # 【移除】所有可视化干扰 + # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="red") + # _draw_state_panel(ann, { + # "提示": "请缓慢环顾/抬头/降低手机角度", + # "丢失计数": self.cnt_lost + # }, pos=(10, 60)) + # _draw_frame_border(ann, color=_color_bgr("red"), thickness=3) + + # 丢失计数(兜底) + if blind_state == "UNKNOWN" and cross_stage == "not_detected": + self.cnt_lost += 1 + else: + self.cnt_lost = max(0, self.cnt_lost - 2) + if self.cnt_lost >= self.FRAMES_LOST_MAX and self.state != RECOVERY: + self.prev_target_state = self.state + self.state = RECOVERY + self.cooldown_until = now + self.COOLDOWN_SEC + say = "环境复杂,进入恢复模式。" + + # 【移除】冷却进度条 + # if in_cooldown: + # remain = max(0.0, self.cooldown_until - now) + # ratio = 1.0 - min(1.0, remain / self.COOLDOWN_SEC) + # _draw_progress_bar(ann, ratio, pos=(10, 140), size=(160, 8), color="gray") + + return OrchestratorResult(ann, self._say(now, say), self.state, {"source": "blind", "cross_stage": cross_stage, "blind_state": blind_state}) + + if self.state == WAIT_TRAFFIC_LIGHT: + ann = bgr.copy() 
+ # 红绿灯识别(多数表决+冷却) + color, meta = self.tld.detect(bgr) + self.tl_major.push(color) + major = self.tl_major.majority() + self.tl_last_color = major + + # 【移除】所有可视化干扰 + # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="magenta") + # self._draw_tl_status(ann, major, meta) + # _draw_state_panel(ann, { + # "提示": "请等待绿灯或语音确认"立即通过"", + # "冷却": f"{max(0.0, self.cooldown_until - now):.1f}s" + # }, pos=(10, 80)) + # _draw_frame_border(ann, color=_color_bgr("magenta"), thickness=3) + + say = "" + if major == "green" and not in_cooldown: + self.state = CROSSING + self.cooldown_until = now + self.COOLDOWN_SEC + say = "绿灯稳定,开始通行。" + else: + # 只在刚进入状态或每隔一段时间才播报 + if not hasattr(self, '_last_wait_light_announce'): + self._last_wait_light_announce = 0 + if now - self._last_wait_light_announce > 5.0: # 5秒播报一次 + say = "正在等待绿灯…" + self._last_wait_light_announce = now + + + + # 【移除】冷却进度 + # if in_cooldown: + # remain = max(0.0, self.cooldown_until - now) + # ratio = 1.0 - min(1.0, remain / self.COOLDOWN_SEC) + # _draw_progress_bar(ann, ratio, pos=(10, 140), size=(160, 8), color="gray") + + return OrchestratorResult(ann, self._say(now, say), self.state, {"traffic_light": major}) + + if self.state == CROSSING: + try: + cres: CrossResult = self.cross.process_frame(bgr) + except Exception as e: + # 异常 → 恢复 + self.state = RECOVERY + ann_err = bgr.copy() + # 【移除】所有可视化干扰 + # _draw_badge(ann_err, "CROSS ERROR", (10, 28), fg="white", bg="red") + # _put_text(ann_err, str(e), (10, 56), color=(255,255,255), scale=0.55) + return OrchestratorResult(ann_err, self._say(now, ""), self.state, {"error": str(e)}) + + ann = cres.annotated_image if cres.annotated_image is not None else bgr.copy() + say = cres.guidance_text or "" + + # 新增:检查是否检测到盲道 + blind_path_detected = getattr(cres, 'blind_path_detected', False) + blind_path_guidance = getattr(cres, 'blind_path_guidance', "") + + # 如果检测到盲道且需要引导,优先处理盲道引导 + if blind_path_detected and blind_path_guidance: + # 如果应该切换到盲道导航(盲道很近),直接切换状态 + if 
hasattr(cres, "should_switch_to_blindpath") and cres.should_switch_to_blindpath: + if not in_cooldown: + self.state = BLINDPATH_NAV + self.cooldown_until = now + self.COOLDOWN_SEC + say = "已到盲道跟前,切换到盲道导航。" # 使用现有语音文件 + self.cnt_cross_end = 0 # 重置计数器 + # 重置盲道导航器状态 + if hasattr(self.blind, 'reset'): + self.blind.reset() + else: + # 盲道较远,继续过马路但给出盲道引导 + # say 已经在 cres.guidance_text 中包含了盲道引导信息 + pass + + # 原有的结束条件:连续多帧"寻找斑马线" + end_hint = False + if "寻找斑马线" in (say or ""): + end_hint = True + # 注意:不再单纯因为 should_switch_to_blindpath 就结束过马路 + # if hasattr(cres, "should_switch_to_blindpath") and cres.should_switch_to_blindpath: + # end_hint = True + + self.cnt_cross_end = self.cnt_cross_end + 1 if end_hint else max(0, self.cnt_cross_end - 1) + + if self.cnt_cross_end >= self.FRAMES_CROSS_END and not in_cooldown: + self.state = SEEKING_NEXT_BLINDPATH + self.cooldown_until = now + self.COOLDOWN_SEC + say = "过马路结束,准备上人行道。" + + # 【移除】所有可视化干扰 + # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="cyan") + # _draw_state_panel(ann, { + # "结束计数": self.cnt_cross_end, + # "冷却": f"{max(0.0, self.cooldown_until - now):.1f}s" + # }, pos=(10, 60)) + # _draw_progress_bar(ann, max(0.0, min(1.0, self.cnt_cross_end / max(1, self.FRAMES_CROSS_END))), pos=(10, 120), size=(220, 10), color="cyan") + # _draw_frame_border(ann, color=_color_bgr("cyan"), thickness=3) + # if in_cooldown: + # remain = max(0.0, self.cooldown_until - now) + # ratio = 1.0 - min(1.0, remain / self.COOLDOWN_SEC) + # _draw_progress_bar(ann, ratio, pos=(10, 140), size=(160, 8), color="gray") + + return OrchestratorResult(ann, self._say(now, say), self.state, {"source": "cross", "end_cnt": self.cnt_cross_end}) + + # 兜底 + ann = bgr.copy() + # 【移除】所有可视化干扰 + # _draw_badge(ann, f"STATE: {self.state}", (10, 28), fg="white", bg="gray") + # _draw_frame_border(ann, color=_color_bgr("gray"), thickness=2) + return OrchestratorResult(ann, "", self.state, {}) + + diff --git a/numba_utils.py b/numba_utils.py new file mode 
# numba_utils.py - Day 20 Numba multi-core acceleration utilities
"""
Accelerate numpy-dense mask operations with Numba JIT compilation.

When Numba is available the hot loops are compiled to native code (and,
where provably safe, parallelized across cores to bypass the GIL);
otherwise the module falls back to equivalent pure-numpy implementations
that return identical results.
"""

import numpy as np

try:
    from numba import jit, prange
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False
    print("[NUMBA] Numba 未安装,使用 numpy 回退实现")


if NUMBA_AVAILABLE:
    @jit(nopython=True, parallel=True, cache=True)
    def count_mask_pixels_numba(mask: np.ndarray) -> int:
        """Count non-zero pixels in a 2-D mask (parallel sum reduction)."""
        # `count += 1` is a Numba-recognized prange reduction -> race-free.
        count = 0
        h, w = mask.shape
        for i in prange(h):
            for j in range(w):
                if mask[i, j] > 0:
                    count += 1
        return count

    # FIX: this kernel previously used @jit(parallel=True) with prange, but
    # the bounding-box updates (`if j < min_x: min_x = j`, ...) are plain
    # conditional assignments, not one of Numba's supported prange reduction
    # operators (+=, *=, min()/max() calls).  Under prange those writes race
    # across threads and can silently return a wrong bbox.  The loop is now
    # compiled serially; it is still a single fused native pass.
    @jit(nopython=True, cache=True)
    def compute_mask_stats_numba(mask: np.ndarray) -> tuple:
        """
        Compute mask statistics in one pass.
        返回: (area, center_x, center_y, min_x, max_x, min_y, max_y)
        """
        h, w = mask.shape
        count = 0
        sum_x = 0.0
        sum_y = 0.0
        min_x = w
        max_x = 0
        min_y = h
        max_y = 0

        for i in range(h):
            for j in range(w):
                if mask[i, j] > 0:
                    count += 1
                    sum_x += j
                    sum_y += i
                    if j < min_x:
                        min_x = j
                    if j > max_x:
                        max_x = j
                    if i < min_y:
                        min_y = i
                    if i > max_y:
                        max_y = i

        if count == 0:
            # FIX: return an all-zero bbox for an empty mask so the numba
            # path agrees with the numpy fallback (which returns (0,0,0,0));
            # previously this path returned (w, 0, h, 0).
            return (0, 0.0, 0.0, 0, 0, 0, 0)

        return (count, sum_x / count, sum_y / count, min_x, max_x, min_y, max_y)

    @jit(nopython=True, parallel=True, cache=True)
    def bitwise_and_count_numba(mask1: np.ndarray, mask2: np.ndarray) -> int:
        """Count pixels where both masks are non-zero (parallel reduction)."""
        h, w = mask1.shape
        count = 0
        for i in prange(h):
            for j in range(w):
                if mask1[i, j] > 0 and mask2[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def resize_mask_nearest_numba(mask: np.ndarray, new_h: int, new_w: int) -> np.ndarray:
        """
        Nearest-neighbour resize of a mask (parallel over output rows).

        Safe under prange: each iteration writes a distinct output row.
        Precondition: new_h > 0 and new_w > 0.
        """
        old_h, old_w = mask.shape
        result = np.zeros((new_h, new_w), dtype=np.uint8)

        scale_y = old_h / new_h
        scale_x = old_w / new_w

        for i in prange(new_h):
            for j in range(new_w):
                src_y = int(i * scale_y)
                src_x = int(j * scale_x)
                if src_y >= old_h:
                    src_y = old_h - 1
                if src_x >= old_w:
                    src_x = old_w - 1
                result[i, j] = mask[src_y, src_x]

        return result


# --- Public API: dispatch to Numba or numpy depending on availability ---

def count_mask_pixels(mask: np.ndarray) -> int:
    """Count non-zero pixels in `mask`."""
    if NUMBA_AVAILABLE:
        return count_mask_pixels_numba(mask)
    return int(np.sum(mask > 0))


def compute_mask_stats(mask: np.ndarray) -> dict:
    """
    Compute mask statistics.

    Returns:
        {'area': int, 'center_x': float, 'center_y': float,
         'bbox': (x1, y1, x2, y2)}  -- all zeros for an empty mask.
    """
    if NUMBA_AVAILABLE:
        area, cx, cy, min_x, max_x, min_y, max_y = compute_mask_stats_numba(mask)
        return {
            'area': int(area),
            'center_x': float(cx),
            'center_y': float(cy),
            'bbox': (int(min_x), int(min_y), int(max_x), int(max_y))
        }
    # numpy fallback
    y_coords, x_coords = np.where(mask > 0)
    if len(y_coords) == 0:
        return {'area': 0, 'center_x': 0, 'center_y': 0, 'bbox': (0, 0, 0, 0)}
    return {
        'area': len(y_coords),
        'center_x': float(np.mean(x_coords)),
        'center_y': float(np.mean(y_coords)),
        'bbox': (int(np.min(x_coords)), int(np.min(y_coords)),
                 int(np.max(x_coords)), int(np.max(y_coords)))
    }


def bitwise_and_count(mask1: np.ndarray, mask2: np.ndarray) -> int:
    """Count intersection pixels of two masks.

    NOTE(review): the numba path casts via astype(np.uint8), which wraps
    values > 255; masks are assumed to be uint8 0/255 — confirm at callers.
    """
    if NUMBA_AVAILABLE:
        return bitwise_and_count_numba(mask1.astype(np.uint8), mask2.astype(np.uint8))
    return int(np.sum(np.bitwise_and(mask1, mask2) > 0))


def warmup():
    """Pre-trigger JIT compilation so the first real call is not slow."""
    if NUMBA_AVAILABLE:
        dummy = np.zeros((10, 10), dtype=np.uint8)
        dummy[5, 5] = 255
        count_mask_pixels_numba(dummy)
        compute_mask_stats_numba(dummy)
        bitwise_and_count_numba(dummy, dummy)
        print("[NUMBA] JIT 编译预热完成,已启用多核加速")


if __name__ == "__main__":
    # Micro-benchmark: numpy vs numba implementations.
    import time

    test_mask = np.zeros((480, 640), dtype=np.uint8)
    test_mask[100:300, 200:400] = 255

    start = time.perf_counter()
    for _ in range(100):
        np.sum(test_mask > 0)
    numpy_time = (time.perf_counter() - start) * 1000

    if NUMBA_AVAILABLE:
        count_mask_pixels_numba(test_mask)  # warm up JIT outside the timer

        start = time.perf_counter()
        for _ in range(100):
            count_mask_pixels_numba(test_mask)
        numba_time = (time.perf_counter() - start) * 1000

        print(f"numpy: {numpy_time:.2f}ms / 100 次")
        print(f"numba: {numba_time:.2f}ms / 100 次")
        if numba_time > 0:  # FIX: guard zero-division on very fast runs
            print(f"加速比: {numpy_time / numba_time:.1f}x")
numpy 版本 + start = time.perf_counter() + for _ in range(100): + np.sum(test_mask > 0) + numpy_time = (time.perf_counter() - start) * 1000 + + # 测试 numba 版本 + if NUMBA_AVAILABLE: + # 预热 + count_mask_pixels_numba(test_mask) + + start = time.perf_counter() + for _ in range(100): + count_mask_pixels_numba(test_mask) + numba_time = (time.perf_counter() - start) * 1000 + + print(f"numpy: {numpy_time:.2f}ms / 100 次") + print(f"numba: {numba_time:.2f}ms / 100 次") + print(f"加速比: {numpy_time / numba_time:.1f}x") diff --git a/obstacle_detector_client.py b/obstacle_detector_client.py new file mode 100644 index 0000000..487fddd --- /dev/null +++ b/obstacle_detector_client.py @@ -0,0 +1,244 @@ +# app/cloud/obstacle_detector_client.py (新文件) +import logging +import os +import cv2 +import numpy as np +import torch +from threading import Semaphore +from contextlib import contextmanager +from ultralytics import YOLOE +from typing import List, Dict, Any + +# Day 20: Numba 多核加速 +try: + from numba_utils import count_mask_pixels, compute_mask_stats, bitwise_and_count, warmup as numba_warmup + NUMBA_ENABLED = True +except ImportError: + NUMBA_ENABLED = False + +logger = logging.getLogger(__name__) + +# --- GPU/CPU & AMP 配置 (从 blindpath 工作流迁移而来,保持一致) --- +DEVICE = os.getenv("AIGLASS_DEVICE", "cuda:0") +if DEVICE.startswith("cuda") and not torch.cuda.is_available(): + logger.warning(f"AIGLASS_DEVICE={DEVICE} 但未检测到 CUDA,将回退到 CPU") + DEVICE = "cpu" +IS_CUDA = DEVICE.startswith("cuda") + +AMP_POLICY = os.getenv("AIGLASS_AMP", "fp16").lower() +if AMP_POLICY not in ("bf16", "fp16", "off"): + AMP_POLICY = "fp16" +AMP_DTYPE = torch.bfloat16 if AMP_POLICY == "bf16" else (torch.float16 if AMP_POLICY == "fp16" else None) + +# --- GPU 并发限流 (从 blindpath 工作流迁移而来,保持一致) --- +# Day 20: 增加默认槽位从 2 到 4,RTX 3090 可以处理更多并发 +GPU_SLOTS = int(os.getenv("AIGLASS_GPU_SLOTS", "4")) +_gpu_slots = Semaphore(GPU_SLOTS) + +try: + torch.backends.cudnn.benchmark = True +except Exception: + pass + + +@contextmanager +def 
@contextmanager
def gpu_infer_slot():
    """Acquire a GPU slot and enter the inference/AMP context.

    Bundles three concerns:
      * `_gpu_slots` semaphore caps concurrent GPU inference tasks;
      * `torch.inference_mode()` disables autograd bookkeeping;
      * AMP autocast is enabled when on CUDA and the AMP policy is not 'off'.
    """
    with _gpu_slots:
        if IS_CUDA and AMP_POLICY != "off":
            # Modern API: torch.amp.autocast(device_type='cuda', dtype=...)
            with torch.inference_mode(), torch.amp.autocast(device_type='cuda', dtype=AMP_DTYPE):
                yield
        else:
            with torch.inference_mode():
                yield


class ObstacleDetectorClient:
    """Open-vocabulary obstacle detector backed by a YOLOE segmentation model.

    Supports both the regular PyTorch checkpoint (with text-prompt whitelist)
    and a TensorRT engine (COCO classes + post-hoc whitelist filtering).
    """

    def __init__(self, model_path: str = 'model/yoloe-11l-seg.pt'):
        """Load the model and, when supported, precompute whitelist text embeddings.

        Raises:
            Exception: re-raised after logging if model load / embedding
            precomputation fails.
        """
        self.model = None
        self.whitelist_embeddings = None
        # Text prompts used for open-vocabulary detection (non-TensorRT path).
        self.WHITELIST_CLASSES = [
            'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'animal', 'scooter', 'stroller', 'dog',
            'pole', 'post', 'column', 'pillar', 'stanchion', 'bollard', 'utility pole',
            'telegraph pole', 'light pole', 'street pole', 'signpost', 'support post',
            'vertical post', 'bench', 'chair', 'potted plant', 'hydrant', 'cone', 'stone', 'box'
        ]
        # COCO class whitelist - used for post-filtering in TensorRT mode,
        # where text-prompt embeddings are unavailable.
        self.COCO_WHITELIST = {
            'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck',          # traffic
            'dog', 'cat', 'horse', 'cow', 'sheep',                             # animals
            'bench', 'chair', 'potted plant', 'fire hydrant', 'stop sign',     # street furniture
            'parking meter', 'suitcase', 'backpack', 'umbrella', 'handbag',    # objects
            'sports ball', 'skateboard', 'surfboard', 'bottle', 'cup',         # misc obstacles
        }
        try:
            # Day 20: prefer a TensorRT engine when one is available.
            try:
                from model_utils import get_best_model_path, is_tensorrt_engine
                model_path = get_best_model_path(model_path)
            except ImportError:
                def is_tensorrt_engine(p): return p.endswith('.engine')

            logger.info(f"正在加载 YOLOE 障碍物模型: {model_path}")
            self.model = YOLOE(model_path)

            # Day 20: TensorRT engines support neither .to()/.fuse() nor
            # get_text_pe(), so whitelist embeddings are skipped there.
            if is_tensorrt_engine(model_path):
                logger.info(f"TensorRT 引擎已加载,跳过 .to() 和 .fuse()")
                self.whitelist_embeddings = None
                logger.info("TensorRT 模式:跳过白名单特征预计算")
            else:
                self.model.to(DEVICE)
                self.model.fuse()
                logger.info(f"YOLOE 障碍物模型加载成功,使用设备: {DEVICE}")

                logger.info("正在为 YOLOE 预计算白名单文本特征...")
                if IS_CUDA and AMP_DTYPE is not None:
                    with torch.inference_mode(), torch.amp.autocast(device_type='cuda', dtype=AMP_DTYPE):
                        self.whitelist_embeddings = self.model.get_text_pe(self.WHITELIST_CLASSES)
                else:
                    self.whitelist_embeddings = self.model.get_text_pe(self.WHITELIST_CLASSES)
                logger.info("YOLOE 特征预计算完成。")
        except Exception as e:
            logger.error(f"YOLOE 模型加载或特征计算失败: {e}", exc_info=True)
            raise

    # FIX: this helper was declared as an instance method without `self`,
    # so any call through an instance raised TypeError (extra positional
    # argument).  It takes no instance state, hence @staticmethod.
    @staticmethod
    def tensor_to_numpy_mask(mask_tensor):
        """Convert a mask tensor of any float dtype to a uint8 binary mask.

        bfloat16/float16 are upcast first because numpy cannot represent
        bfloat16.  Probability masks (max <= 1.0) are thresholded at 0.5
        and scaled to 0/255; other masks are cast to uint8 as-is.
        """
        if mask_tensor.dtype in (torch.bfloat16, torch.float16):
            mask_tensor = mask_tensor.float()

        mask = mask_tensor.cpu().numpy()

        if mask.max() <= 1.0:
            mask = (mask > 0.5).astype(np.uint8) * 255
        else:
            mask = mask.astype(np.uint8)

        return mask

    def detect(self, image: np.ndarray, path_mask: np.ndarray = None) -> List[Dict[str, Any]]:
        """Detect whitelist obstacles; optionally filter to those on `path_mask`.

        Args:
            image: BGR frame (H, W, 3).
            path_mask: optional uint8 mask; when given, only obstacles with
                sufficient overlap with the path are kept.

        Returns:
            List of dicts: name, mask, area, area_ratio, center_x/center_y,
            bottom_y_ratio.  Empty list when the model is unavailable or
            nothing passes the filters.
        """
        if self.model is None:
            return []

        H, W = image.shape[:2]

        # TensorRT mode has no embeddings -> skip set_classes; the engine
        # then detects with its default (COCO) classes.
        if self.whitelist_embeddings is not None:
            try:
                self.model.set_classes(self.WHITELIST_CLASSES, self.whitelist_embeddings)
            except Exception as e:
                logger.error(f"设置 YOLOE 提示词失败: {e}")
                return []

        conf_thr = float(os.getenv("AIGLASS_OBS_CONF", "0.25"))
        # Day 22: tunable input size (default lowered from 640) + FP16.
        imgsz = int(os.getenv("AIGLASS_OBS_IMGSZ", "480"))
        use_half = os.getenv("AIGLASS_OBS_HALF", "1") == "1"

        with gpu_infer_slot():
            results = self.model.predict(
                image,
                verbose=False,
                conf=conf_thr,
                imgsz=imgsz,     # smaller input for speed
                half=use_half    # FP16 half-precision
            )

        if not (results and results[0].masks):
            return []

        # --- filtering / post-processing (kept in sync with blindpath flow) ---
        final_obstacles = []
        num_masks = len(results[0].masks.data)
        num_boxes = len(results[0].boxes.cls) if getattr(results[0].boxes, "cls", None) is not None else 0

        for i, mask_tensor in enumerate(results[0].masks.data):
            if i >= num_boxes:
                continue

            # numpy cannot hold BFloat16 -> upcast before .numpy().
            if mask_tensor.dtype == torch.bfloat16:
                mask_tensor = mask_tensor.float()
            mask = mask_tensor.cpu().numpy()

            # Probability mask (values in [0,1]) vs already-binary mask.
            if mask.max() <= 1.0:
                mask = (mask > 0.5).astype(np.uint8) * 255
            else:
                mask = mask.astype(np.uint8)

            mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST)

            # Day 20: Numba-accelerated statistics when available.
            if NUMBA_ENABLED:
                stats = compute_mask_stats(mask)
                area = stats['area']
                center_x = stats['center_x']
                center_y = stats['center_y']
                min_y, max_y = stats['bbox'][1], stats['bbox'][3]
            else:
                area = int(np.sum(mask > 0))
                y_coords, x_coords = np.where(mask > 0)
                if len(y_coords) == 0:
                    continue
                center_x = float(np.mean(x_coords))
                center_y = float(np.mean(y_coords))
                min_y, max_y = int(np.min(y_coords)), int(np.max(y_coords))

            # Size filter: very large blobs (e.g. the whole ground plane)
            # are usually misdetections.
            if (area / (H * W)) > 0.7:
                continue
            if area == 0:
                continue

            # Spatial filter: with a path mask, keep only on-path obstacles.
            if path_mask is not None:
                if NUMBA_ENABLED:
                    intersection_area = bitwise_and_count(mask, path_mask)
                else:
                    intersection_area = int(np.sum(cv2.bitwise_and(mask, path_mask) > 0))
                # Require sufficient overlap with the walkable path.
                if intersection_area < 100 or (intersection_area / area) < 0.01:
                    continue

            cls_id = int(results[0].boxes.cls[i])
            class_names_map = results[0].names
            class_name = "Unknown"
            if isinstance(class_names_map, dict):
                class_name = class_names_map.get(cls_id, "Unknown")
            elif isinstance(class_names_map, list) and 0 <= cls_id < len(class_names_map):
                class_name = class_names_map[cls_id]

            # TensorRT mode: filter by the COCO whitelist, since text
            # prompts could not be applied at inference time.
            if self.whitelist_embeddings is None:
                if class_name.lower().strip() not in self.COCO_WHITELIST:
                    continue

            final_obstacles.append({
                'name': class_name.strip(),
                'mask': mask,
                'area': area,
                'area_ratio': area / (H * W),
                'center_x': center_x,
                'center_y': center_y,
                'bottom_y_ratio': max_y / H
            })

        return final_obstacles
# omni_client.py
# -*- coding: utf-8 -*-
import os, base64, asyncio, threading
from typing import AsyncGenerator, Dict, Any, List, Optional, Tuple

from openai import OpenAI

# ===== OpenAI-compatible endpoint (DashScope compatible mode) =====
# SECURITY FIX: the API key used to be hard-coded here as the getenv()
# fallback.  That committed the secret to version control AND made the
# `if not API_KEY` guard below dead code (it could never be falsy).
# The key must now come from the environment; the previously committed
# key should be revoked/rotated immediately.
API_KEY = os.getenv("DASHSCOPE_API_KEY")
if not API_KEY:
    raise RuntimeError("未设置 DASHSCOPE_API_KEY")

QWEN_MODEL = "qwen-omni-turbo"

# Compatible-mode client shared by all calls in this module.
oai_client = OpenAI(
    api_key=API_KEY,
    base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
)

class OmniStreamPiece:
    """One streamed increment: text and/or audio (either may be None)."""
    def __init__(self, text_delta: Optional[str] = None, audio_b64: Optional[str] = None):
        self.text_delta = text_delta
        self.audio_b64 = audio_b64

async def stream_chat(
    content_list: List[Dict[str, Any]],
    voice: str = "Cherry",
    audio_format: str = "wav",
) -> AsyncGenerator[OmniStreamPiece, None]:
    """
    Run one streaming Omni-Turbo ChatCompletions round.

    Args:
        content_list: multimodal OpenAI chat `content` (image_url/text parts).
        voice: TTS voice for the audio modality.
        audio_format: container format of streamed audio chunks.

    Yields:
        OmniStreamPiece carrying a text delta and/or a base64 audio chunk.

    Day 13 fix: the blocking SDK call runs in a worker thread and feeds an
    asyncio.Queue, so the event loop is never blocked.
    """
    # Queue bridges the worker thread and the async consumer.
    queue: asyncio.Queue = asyncio.Queue()
    loop = asyncio.get_running_loop()

    def _sync_stream():
        """Run the blocking API call in a dedicated thread."""
        try:
            # Day 21: system prompt keeps answers terse for the glasses UX.
            system_prompt = """你是一个视障辅助AI助手,安装在智能导盲眼镜上。
请用极简短的语言回答,每次回答不超过2-3句话。
避免冗长解释,只提供最关键的信息。
语气友好但简洁。"""

            messages = [
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": content_list}
            ]

            completion = oai_client.chat.completions.create(
                model=QWEN_MODEL,
                messages=messages,
                modalities=["text", "audio"],
                audio={"voice": voice, "format": audio_format},
                stream=True,
                stream_options={"include_usage": True},
            )

            for chunk in completion:
                text_delta: Optional[str] = None
                audio_b64: Optional[str] = None

                if getattr(chunk, "choices", None):
                    c0 = chunk.choices[0]
                    delta = getattr(c0, "delta", None)
                    # Text increment.
                    if delta and getattr(delta, "content", None):
                        piece = delta.content
                        if piece:
                            text_delta = piece
                    # Audio chunk: may arrive on delta.audio...
                    if delta and getattr(delta, "audio", None):
                        aud = delta.audio
                        audio_b64 = aud.get("data") if isinstance(aud, dict) else getattr(aud, "data", None)
                    # ...or on message.audio depending on server behavior.
                    if audio_b64 is None:
                        msg = getattr(c0, "message", None)
                        if msg and getattr(msg, "audio", None):
                            ma = msg.audio
                            audio_b64 = ma.get("data") if isinstance(ma, dict) else getattr(ma, "data", None)

                if (text_delta is not None) or (audio_b64 is not None):
                    # Thread-safe handoff into the event loop.
                    loop.call_soon_threadsafe(
                        queue.put_nowait,
                        OmniStreamPiece(text_delta=text_delta, audio_b64=audio_b64)
                    )
        except Exception as e:
            # Propagate the exception to the async consumer.
            loop.call_soon_threadsafe(queue.put_nowait, e)
        finally:
            # End-of-stream sentinel.
            loop.call_soon_threadsafe(queue.put_nowait, None)

    # Launch the blocking producer.
    thread = threading.Thread(target=_sync_stream, daemon=True)
    thread.start()

    # Consume the queue asynchronously.
    while True:
        item = await queue.get()
        if item is None:
            break  # stream finished
        if isinstance(item, Exception):
            raise item
        yield item
# qwen_extractor.py
# -*- coding: utf-8 -*-
from typing import List, Tuple
import os

# -- Local-first mapping (extend/rename freely) --
LOCAL_CN2EN = {
    "红牛": "Red_Bull",
    "ad钙奶": "AD_milk",
    "ad 钙奶": "AD_milk",
    "ad": "AD_milk",
    "钙奶": "AD_milk",
    "矿泉水": "bottle",
    "水瓶": "bottle",
    "可乐": "coke",
    "雪碧": "sprite",
}

def _make_client():
    """Build an OpenAI-compatible client for the DashScope endpoint.

    SECURITY FIX: the API key was hard-coded here; it now comes from the
    environment (DASHSCOPE_API_KEY) and the previously committed key must
    be rotated.  The `openai` import is deferred so the cheap local-mapping
    path of extract_english_label() works even without the SDK installed.
    """
    from openai import OpenAI
    base_url = os.getenv("DASHSCOPE_COMPAT_BASE", "https://dashscope.aliyuncs.com/compatible-mode/v1")
    api_key = os.getenv("DASHSCOPE_API_KEY")
    return OpenAI(api_key=api_key, base_url=base_url)

PROMPT_SYS = (
    "You are a label normalizer. Convert the given Chinese object "
    "description into a short, lowercase English YOLO/vision class name "
    "(1~3 words). If multiple are given, return the single most likely one. "
    "Output ONLY the label, no punctuation."
)

def extract_english_label(query_cn: str) -> Tuple[str, str]:
    """
    Map a Chinese object description to an English class label.

    返回 (label_en, source);source ∈ {'local', 'qwen', 'fallback'}

    Resolution order: exact local mapping -> substring local mapping ->
    Qwen normalization -> 'bottle' fallback on any failure.
    """
    q = (query_cn or "").strip().lower()
    if q in LOCAL_CN2EN:
        return LOCAL_CN2EN[q], "local"

    # Substring rule strips leading modifier words ("来瓶矿泉水" -> "矿泉水").
    # NOTE(review): very short keys like "ad" can over-match inside longer
    # ASCII strings — confirm the key list against real queries.
    for k, v in LOCAL_CN2EN.items():
        if k in q:
            return v, "local"

    # Fall back to Qwen Turbo (Chat Completions compatible endpoint).
    try:
        client = _make_client()
        msgs = [
            {"role": "system", "content": PROMPT_SYS},
            {"role": "user", "content": query_cn.strip()},
        ]
        rsp = client.chat.completions.create(
            model=os.getenv("QWEN_MODEL", "qwen-turbo"),
            messages=msgs,
            stream=False
        )
        label = (rsp.choices[0].message.content or "").strip()
        # Light cleanup: drop punctuation, collapse full-width spaces.
        label = label.replace(".", "").replace(",", "").replace(" ", " ").strip()
        # Fallback to 'bottle' if the model returned nothing usable.
        return (label or "bottle"), "qwen"
    except Exception:
        return "bottle", "fallback"


# qwenturbo_template.py — streaming demo for qwen-turbo with thinking mode.
def _qwenturbo_demo():
    """One-shot streaming chat demo (requires DASHSCOPE_API_KEY).

    FIX: this demo previously ran at import time with a hard-coded API key;
    it is now guarded behind __main__ and reads the key from the environment.
    """
    from openai import OpenAI
    client = OpenAI(
        api_key=os.getenv("DASHSCOPE_API_KEY"),
        base_url="https://dashscope.aliyuncs.com/compatible-mode/v1",
    )

    messages = [{"role": "user", "content": "你是谁"}]
    completion = client.chat.completions.create(
        model="qwen-turbo",  # swap for another thinking-capable model as needed
        messages=messages,
        extra_body={"enable_thinking": True},
        stream=True
    )
    is_answering = False  # whether the answer phase has started
    print("\n" + "=" * 20 + "思考过程" + "=" * 20)
    for chunk in completion:
        delta = chunk.choices[0].delta
        if hasattr(delta, "reasoning_content") and delta.reasoning_content is not None:
            if not is_answering:
                print(delta.reasoning_content, end="", flush=True)
        if hasattr(delta, "content") and delta.content:
            if not is_answering:
                print("\n" + "=" * 20 + "完整回复" + "=" * 20)
                is_answering = True
            print(delta.content, end="", flush=True)


if __name__ == "__main__":
    _qwenturbo_demo()
a/requirements.txt b/requirements.txt new file mode 100644 index 0000000..263f1de --- /dev/null +++ b/requirements.txt @@ -0,0 +1,71 @@ +# AI Glass System - Python Dependencies +# Python 3.9 - 3.11 supported + +# Core Web Framework +fastapi==0.104.1 +uvicorn[standard]==0.24.0 +websockets==12.0 +python-multipart==0.0.6 +starlette==0.27.0 + +# Computer Vision & Deep Learning +opencv-python==4.8.1.78 +numpy==1.24.3 +Pillow==10.1.0 +ultralytics==8.3.200 +torch==2.0.1 +torchvision==0.15.2 + +# MediaPipe (Hand Detection) +mediapipe==0.10.8 + +# Audio Processing +pyaudio==0.2.14 +pydub==0.25.1 +pygame==2.5.2 + +# Aliyun DashScope SDK (ASR & Qwen-Omni) - 旧管道 +dashscope==1.14.1 +openai==1.3.5 # For DashScope compatibility mode + +# Day 21: 新 AI 管道 (SenseVoice + GLM + EdgeTTS) +# torchaudio 需与 torch 版本匹配,使用以下命令安装: +# pip install torchaudio==2.0.2+cu118 --index-url https://download.pytorch.org/whl/cu118 +funasr>=1.2.0 # SenseVoice 本地 ASR +edge-tts>=6.1.0 # 免费 TTS + +# Environment & Configuration +python-dotenv==1.0.0 + +# Utilities +opencv-contrib-python==4.8.1.78 # Extended OpenCV modules + +# Optional Performance Optimizations +# Uncomment if needed: +# onnxruntime-gpu==1.16.3 # For ONNX model acceleration +# tensorrt==8.6.1 # For TensorRT optimization +PyTurboJPEG>=1.7.0 # Day 19: 2-3x faster JPEG encode/decode than cv2 +# NOTE: PyTurboJPEG requires system library: sudo apt-get install libturbojpeg (Ubuntu) +numba>=0.58.0 # Day 20: JIT compilation for multi-core parallel numpy operations + +# Development & Testing (Optional) +# pytest==7.4.3 +# pytest-asyncio==0.21.1 +# black==23.11.0 +# flake8==6.1.0 + +# Platform-specific dependencies +# Windows: +# - PyAudio requires separate installation: https://www.lfd.uci.edu/~gohlke/pythonlibs/#pyaudio +# Linux: +# - Install PyAudio dependencies: sudo apt-get install portaudio19-dev python3-pyaudio +# - Install OpenCV dependencies: sudo apt-get install libgl1-mesa-glx +# macOS: +# - Install PyAudio: brew install portaudio && pip 
# sensevoice_asr.py
# -*- coding: utf-8 -*-
"""
SenseVoice local ASR module - Day 21
参考 xiaozhi-esp32-server 的非流式实现

Characteristics:
- Non-streaming recognition (waits for the utterance to finish);
- Built-in VAD segmentation;
- Whole-sentence output (no word-by-word dribble).
"""

import os
import time
import asyncio
from typing import Optional, Tuple

import numpy as np
import torch


def _resolve_model_path() -> str:
    """Resolve the model source: env var > bundled local dir > hub identifier.

    FunASR needs a *directory* (containing config.yaml and model.pt) for a
    local model; otherwise a model identifier triggers a download to ~/.cache.
    """
    # FIX: the original called os.getenv twice and tested exists("") — the
    # resolution is now a single readable helper.
    env_path = os.getenv("SENSEVOICE_MODEL_PATH", "")
    if env_path and os.path.exists(env_path):
        return env_path
    local_dir = os.path.join(os.path.dirname(__file__), "model", "SenseVoiceSmall")
    if os.path.isdir(local_dir) and os.path.exists(os.path.join(local_dir, "model.pt")):
        return local_dir
    return "iic/SenseVoiceSmall"


MODEL_PATH = _resolve_model_path()

# GPU device — same convention as the main program: the server selects the
# physical GPU via CUDA_VISIBLE_DEVICES, so in-process it is always cuda:0.
if torch.cuda.is_available():
    DEVICE = os.getenv("SENSEVOICE_DEVICE", os.getenv("AIGLASS_DEVICE", "cuda:0"))
else:
    DEVICE = "cpu"

# Global model instance, created lazily under the lock.
_model: Optional["AutoModel"] = None
_model_lock = asyncio.Lock()


def _load_model():
    """Load the SenseVoice model (blocking); idempotent via the module cache."""
    global _model
    if _model is not None:
        return _model

    # funasr is heavy; imported lazily so this module imports cheaply.
    from funasr import AutoModel

    print(f"[SenseVoice] 正在加载模型: {MODEL_PATH}")
    print(f"[SenseVoice] 使用设备: {DEVICE}")

    start_time = time.time()

    _model = AutoModel(
        model=MODEL_PATH,
        vad_kwargs={"max_single_segment_time": 30000},  # VAD max 30 s
        disable_update=True,
        hub="hf",  # same source as xiaozhi-esp32-server
        device=DEVICE,
    )

    print(f"[SenseVoice] 模型加载完成,耗时: {time.time() - start_time:.2f}s")
    return _model


async def init_sensevoice():
    """Asynchronously initialize SenseVoice (call at server startup)."""
    async with _model_lock:
        await asyncio.to_thread(_load_model)
    print("[SenseVoice] 初始化完成")


async def recognize(pcm_data: bytes, sample_rate: int = 16000) -> str:
    """
    Recognize raw PCM audio.

    Args:
        pcm_data: PCM 16-bit mono audio bytes.
        sample_rate: sampling rate (default 16000).

    Returns:
        Recognized text, or "" on empty/too-short input or failure.
    """
    # FIX: validate the input BEFORE touching the model — the original
    # triggered a full (slow) model load even for empty/too-short audio.
    if not pcm_data or len(pcm_data) < 640:  # at least 20 ms @ 16 kHz/16-bit
        return ""

    if _model is None:
        await init_sensevoice()

    try:
        from funasr.utils.postprocess_utils import rich_transcription_postprocess

        start_time = time.time()

        # Run inference in a worker thread to keep the event loop free.
        # Day 22 fix: language pinned to "zh" — "auto" occasionally
        # misclassified speech as Korean etc.
        result = await asyncio.to_thread(
            _model.generate,
            input=pcm_data,
            cache={},
            language="zh",
            use_itn=True,
            batch_size_s=60,
        )

        if result and len(result) > 0 and "text" in result[0]:
            text = await asyncio.to_thread(
                rich_transcription_postprocess,
                result[0]["text"]
            )
            elapsed = time.time() - start_time
            print(f"[SenseVoice] 识别耗时: {elapsed:.3f}s | 结果: {text}")
            return text.strip()
        print("[SenseVoice] 识别结果为空")
        return ""

    except Exception as e:
        print(f"[SenseVoice] 识别失败: {e}")
        import traceback
        traceback.print_exc()
        return ""


async def recognize_from_file(file_path: str) -> str:
    """
    Recognize audio from a file path.

    Args:
        file_path: path to the audio file.

    Returns:
        Recognized text, or "" on failure.
    """
    if _model is None:
        await init_sensevoice()

    try:
        from funasr.utils.postprocess_utils import rich_transcription_postprocess

        start_time = time.time()

        result = await asyncio.to_thread(
            _model.generate,
            input=file_path,
            cache={},
            language="zh",  # Day 22 fix: pin to Chinese
            use_itn=True,
            batch_size_s=60,
        )

        if result and len(result) > 0 and "text" in result[0]:
            text = await asyncio.to_thread(
                rich_transcription_postprocess,
                result[0]["text"]
            )
            elapsed = time.time() - start_time
            print(f"[SenseVoice] 文件识别耗时: {elapsed:.3f}s | 结果: {text}")
            return text.strip()
        return ""

    except Exception as e:
        print(f"[SenseVoice] 文件识别失败: {e}")
        return ""
"""
Silero VAD 服务器端语音活动检测
参考 xiaozhi-esp32-server 实现
"""
import torch
import numpy as np
import os
import collections  # Day 23: for the VAD lookback buffer
import time

# Lazily-loaded Silero VAD model, shared module-wide.
_vad_model = None
_model_loaded = False

def get_vad_model():
    """Return the Silero VAD model, loading it on first call.

    Resolution order: bundled local copy -> torch.hub cache -> network
    download.  Returns None on failure (callers fall back to energy VAD);
    `_model_loaded` is set either way so a failing load is not retried on
    every call.
    """
    global _vad_model, _model_loaded

    if _model_loaded:
        return _vad_model

    try:
        # Prefer the locally bundled repo copy.
        model_dir = os.path.join(os.path.dirname(__file__), "model", "snakers4_silero-vad")
        if os.path.exists(model_dir):
            print(f"[VAD] 从本地加载 Silero VAD: {model_dir}")
            _vad_model, _ = torch.hub.load(
                repo_or_dir=model_dir,
                source="local",
                model="silero_vad",
                force_reload=False,
            )
        else:
            # Prefer the torch.hub cache to avoid a GitHub update check.
            cache_dir = os.path.expanduser("~/.cache/torch/hub/snakers4_silero-vad_master")
            if os.path.exists(cache_dir):
                print(f"[VAD] 使用 torch hub 缓存: {cache_dir}")
                _vad_model, _ = torch.hub.load(
                    repo_or_dir=cache_dir,
                    source="local",
                    model="silero_vad",
                    force_reload=False,
                )
            else:
                # No cache: download from the network.
                print("[VAD] 从 torch.hub 下载 Silero VAD...")
                _vad_model, _ = torch.hub.load(
                    repo_or_dir='snakers4/silero-vad',
                    model='silero_vad',
                    force_reload=False,
                )

        _model_loaded = True
        print("[VAD] Silero VAD 模型加载成功")
        return _vad_model
    except Exception as e:
        print(f"[VAD] Silero VAD 加载失败: {e}")
        _model_loaded = True  # avoid retrying a failing load on every call
        return None


class SileroVAD:
    """
    Server-side Silero VAD: detects speech start/end on streamed PCM audio.
    (Chunk processing and the energy fallback are implemented further down
    in this file.)
    """

    def __init__(self,
                 threshold: float = 0.5,       # Day 23: lowered again (was 0.7)
                 threshold_low: float = 0.3,   # Day 23: lowered again (was 0.4)
                 min_silence_ms: int = 800,    # Day 23: lengthened (was 600)
                 min_speech_ms: int = 300,     # Day 23: lowered (was 500)
                 sample_rate: int = 16000):
        """
        Initialize the VAD.

        Args:
            threshold: speech probability above this -> voiced frame.
            threshold_low: probability below this -> silent frame
                (between the two thresholds the previous state is kept).
            min_silence_ms: silence longer than this ends the utterance.
            min_speech_ms: utterances shorter than this are discarded.
            sample_rate: PCM sampling rate.
        """
        self.model = get_vad_model()
        self.threshold = threshold
        self.threshold_low = threshold_low
        self.min_silence_ms = min_silence_ms
        self.min_speech_ms = min_speech_ms
        self.sample_rate = sample_rate

        # Streaming state.
        self.audio_buffer = bytearray()
        self.is_speaking = False
        self.last_speech_time = 0
        self.speech_start_time = 0
        self.speech_audio = bytearray()  # accumulated utterance audio

        # TTS playback state — VAD is paused while TTS plays.
        self.tts_playing = False
        self.tts_end_time = 0          # when TTS finished (ms)
        self.tts_cooldown_ms = 500     # wait 500 ms after TTS before detecting

        # Sliding window of per-chunk voice decisions.
        self.voice_window = []
        self.window_size = 5       # window length
        self.frame_threshold = 3   # voiced frames required to start speech

        # Day 23: pre-speech lookback buffer fixes clipped word onsets.
        # ~300 ms lookback (each chunk is 32 ms) -> 10 chunks.
        self.pre_speech_buffer = collections.deque(maxlen=10)

        print(f"[VAD] 初始化: threshold={threshold}, threshold_low={threshold_low}, "
              f"min_silence_ms={min_silence_ms}, min_speech_ms={min_speech_ms}")

    def reset(self):
        """Reset all streaming state (buffers, flags, model RNN state)."""
        self.audio_buffer.clear()
        self.speech_audio.clear()
        self.is_speaking = False
        self.last_speech_time = 0
        self.speech_start_time = 0
        self.voice_window.clear()
        # FIX: the lookback buffer was not cleared before, so stale audio
        # from a previous utterance could be prepended to the next one.
        self.pre_speech_buffer.clear()
        self.tts_playing = False
        self.tts_end_time = 0
        if self.model:
            self.model.reset_states()

    def set_tts_playing(self, playing: bool):
        """Update TTS playback state; aborts any in-progress recording."""
        self.tts_playing = playing
        if not playing:
            # TTS finished: start the cooldown clock.
            self.tts_end_time = time.time() * 1000
            print("[VAD] TTS 结束,等待冷却期...")
        else:
            print("[VAD] TTS 开始播放,暂停 VAD 检测")
            # If recording, drop the partial utterance (it would contain TTS).
            if self.is_speaking:
                self.is_speaking = False
                self.speech_audio.clear()
                self.voice_window.clear()
                # pre_speech_buffer is always created in __init__, so the
                # old hasattr() guard was redundant.
                self.pre_speech_buffer.clear()
                print("[VAD] TTS 播放打断语音录制")
print("[VAD] TTS 播放打断语音录制") + + def process(self, audio_bytes: bytes) -> dict: + """ + 处理音频数据 + + Args: + audio_bytes: PCM 16-bit 音频数据 + + Returns: + dict: { + 'speech_started': bool, # 语音刚刚开始 + 'speech_ended': bool, # 语音刚刚结束 + 'is_speaking': bool, # 当前是否在说话 + 'speech_audio': bytes, # 如果语音结束,返回完整语音音频 + } + """ + result = { + 'speech_started': False, + 'speech_ended': False, + 'is_speaking': self.is_speaking, + 'speech_audio': None, + } + + if self.model is None: + # 没有模型,使用简单能量检测 + return self._fallback_energy_vad(audio_bytes, result) + + # TTS 播放期间,跳过 VAD 检测 + current_time = time.time() * 1000 + if self.tts_playing: + return result + + # TTS 刚结束,等待冷却期 + if self.tts_end_time > 0 and (current_time - self.tts_end_time) < self.tts_cooldown_ms: + return result + + # 将音频添加到缓冲区 + self.audio_buffer.extend(audio_bytes) + + # Silero VAD 需要 512 采样点 (32ms @ 16kHz) + chunk_size = 512 * 2 # 512 samples * 2 bytes + + while len(self.audio_buffer) >= chunk_size: + chunk = self.audio_buffer[:chunk_size] + self.audio_buffer = self.audio_buffer[chunk_size:] + + # 转换为模型需要的格式 + audio_int16 = np.frombuffer(chunk, dtype=np.int16) + audio_float32 = audio_int16.astype(np.float32) / 32768.0 + audio_tensor = torch.from_numpy(audio_float32) + + # 检测语音概率 + with torch.no_grad(): + speech_prob = self.model(audio_tensor, self.sample_rate).item() + # Day 23: Debug logging to diagnose low volume/mic issues + if speech_prob > 0.3: + print(f"[VAD DEBUG] Prob: {speech_prob:.3f}") + + # 双阈值判断 + if speech_prob >= self.threshold: + is_voice = True + elif speech_prob <= self.threshold_low: + is_voice = False + else: + is_voice = self.is_speaking # 保持当前状态 + + # 更新滑动窗口 + self.voice_window.append(is_voice) + if len(self.voice_window) > self.window_size: + self.voice_window.pop(0) + + # 判断是否有语音 + voice_count = self.voice_window.count(True) + has_voice = voice_count >= self.frame_threshold + + # Maintain lookback buffer (always add current chunk) + self.pre_speech_buffer.append(chunk) + + current_time = 
time.time() * 1000 # 毫秒 + + if has_voice: + if not self.is_speaking: + # 语音开始 + self.is_speaking = True + self.speech_start_time = current_time + self.speech_audio.clear() + result['speech_started'] = True + result['speech_started'] = True + print("[VAD] 🎤 Speech started") + + # Day 23: Prepend lookback buffer to recover the start of speech + if self.pre_speech_buffer: + for prev_chunk in self.pre_speech_buffer: + self.speech_audio.extend(prev_chunk) + print(f"[VAD] Recovered {len(self.pre_speech_buffer)} chunks ({len(self.pre_speech_buffer)*32}ms) from history") + + self.last_speech_time = current_time + self.speech_audio.extend(chunk) + + elif self.is_speaking: + # 仍在收集音频(可能是短暂停顿) + self.speech_audio.extend(chunk) + + # 检查是否静默时间过长 + silence_duration = current_time - self.last_speech_time + speech_duration = current_time - self.speech_start_time + + if silence_duration >= self.min_silence_ms: + # 语音结束 + self.is_speaking = False + + # 检查语音是否足够长 + if speech_duration >= self.min_speech_ms: + result['speech_ended'] = True + result['speech_audio'] = bytes(self.speech_audio) + print(f"[VAD] 🔇 Speech ended, duration={speech_duration:.0f}ms, " + f"audio_size={len(self.speech_audio)} bytes") + else: + print(f"[VAD] 语音太短 ({speech_duration:.0f}ms), 忽略") + + self.speech_audio.clear() + + result['is_speaking'] = self.is_speaking + return result + + def _fallback_energy_vad(self, audio_bytes: bytes, result: dict) -> dict: + """简单能量检测(作为备用)""" + # 计算 RMS 能量 + audio_int16 = np.frombuffer(audio_bytes, dtype=np.int16) + rms = np.sqrt(np.mean(audio_int16.astype(np.float32) ** 2)) + + # 简单阈值 + threshold = 500 + is_voice = rms > threshold + + current_time = time.time() * 1000 + + if is_voice: + if not self.is_speaking: + self.is_speaking = True + self.speech_start_time = current_time + self.speech_audio.clear() + result['speech_started'] = True + + self.last_speech_time = current_time + self.speech_audio.extend(audio_bytes) + + elif self.is_speaking: + 
self.speech_audio.extend(audio_bytes) + + silence_duration = current_time - self.last_speech_time + if silence_duration >= self.min_silence_ms: + self.is_speaking = False + speech_duration = current_time - self.speech_start_time + + if speech_duration >= self.min_speech_ms: + result['speech_ended'] = True + result['speech_audio'] = bytes(self.speech_audio) + + self.speech_audio.clear() + + result['is_speaking'] = self.is_speaking + return result + + +# 全局 VAD 实例 +_global_vad = None + +def get_server_vad() -> SileroVAD: + """获取全局 VAD 实例""" + global _global_vad + if _global_vad is None: + _global_vad = SileroVAD() + return _global_vad + + +def reset_server_vad(): + """重置全局 VAD 状态""" + global _global_vad + if _global_vad: + _global_vad.reset() diff --git a/setup.bat b/setup.bat new file mode 100644 index 0000000..1b72cf6 --- /dev/null +++ b/setup.bat @@ -0,0 +1,157 @@ +@echo off +REM AI Glass System - Windows 快速安装脚本 + +echo ========================================== +echo AI Glass System - 自动安装脚本 +echo ========================================== +echo. + +REM 检查 Python +echo 正在检查 Python... +python --version >nul 2>&1 +if errorlevel 1 ( + echo [错误] 未找到 Python + echo 请从 https://www.python.org/downloads/ 下载并安装 Python 3.9-3.11 + pause + exit /b 1 +) + +python --version +echo [成功] Python 已安装 + +REM 检查 CUDA +echo. +echo 正在检查 CUDA... +nvidia-smi >nul 2>&1 +if errorlevel 1 ( + echo [警告] 未检测到 NVIDIA GPU,将使用 CPU 模式(速度较慢) + set HAS_GPU=0 +) else ( + echo [成功] 检测到 NVIDIA GPU + nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader + set HAS_GPU=1 +) + +REM 创建虚拟环境 +echo. +echo 正在创建虚拟环境... +if exist venv ( + echo [警告] 虚拟环境已存在 + set /p RECREATE="是否删除并重新创建? (y/n): " + if /i "%RECREATE%"=="y" ( + rmdir /s /q venv + python -m venv venv + echo [成功] 虚拟环境已重新创建 + ) +) else ( + python -m venv venv + echo [成功] 虚拟环境已创建 +) + +REM 激活虚拟环境 +echo. +echo 正在激活虚拟环境... +call venv\Scripts\activate.bat + +REM 升级 pip +echo. +echo 正在升级 pip... 
+python -m pip install --upgrade pip -q +echo [成功] pip 已升级 + +REM 安装 PyTorch +echo. +echo 正在安装 PyTorch... +if %HAS_GPU%==1 ( + echo 安装 GPU 版本 PyTorch ^(CUDA 11.8^)... + pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 -q +) else ( + echo 安装 CPU 版本 PyTorch... + pip install torch torchvision -q +) +echo [成功] PyTorch 已安装 + +REM 验证 PyTorch +echo. +echo 验证 PyTorch 安装... +python -c "import torch; print(f'PyTorch 版本: {torch.__version__}'); print(f'CUDA 可用: {torch.cuda.is_available()}')" + +REM 安装 PyAudio +echo. +echo 正在安装 PyAudio... +echo [警告] PyAudio 在 Windows 上可能需要手动安装 +echo 如果自动安装失败,请从以下地址下载 wheel 文件: +echo https://www.lfd.uci.edu/~gohlke/pythonlibs/#pyaudio +echo. +pip install pyaudio -q +if errorlevel 1 ( + echo [警告] PyAudio 自动安装失败,请手动安装 +) else ( + echo [成功] PyAudio 已安装 +) + +REM 安装其他依赖 +echo. +echo 正在安装 Python 依赖... +pip install -r requirements.txt -q +echo [成功] Python 依赖已安装 + +REM 创建 .env 文件 +echo. +if not exist .env ( + echo 正在创建 .env 配置文件... + copy .env.example .env >nul + echo [成功] .env 文件已创建 + echo [提示] 请编辑 .env 文件,填入您的 DASHSCOPE_API_KEY +) else ( + echo [跳过] .env 文件已存在 +) + +REM 创建必要的目录 +echo. +echo 正在创建目录结构... +if not exist recordings mkdir recordings +if not exist model mkdir model +if not exist music mkdir music +if not exist voice mkdir voice +echo [成功] 目录结构已创建 + +REM 检查模型文件 +echo. +echo 正在检查模型文件... +set MISSING=0 +if exist model\yolo-seg.pt (echo [成功] yolo-seg.pt) else (echo [缺失] yolo-seg.pt & set MISSING=1) +if exist model\yoloe-11l-seg.pt (echo [成功] yoloe-11l-seg.pt) else (echo [缺失] yoloe-11l-seg.pt & set MISSING=1) +if exist model\shoppingbest5.pt (echo [成功] shoppingbest5.pt) else (echo [缺失] shoppingbest5.pt & set MISSING=1) +if exist model\trafficlight.pt (echo [成功] trafficlight.pt) else (echo [缺失] trafficlight.pt & set MISSING=1) +if exist model\hand_landmarker.task (echo [成功] hand_landmarker.task) else (echo [缺失] hand_landmarker.task & set MISSING=1) + +if %MISSING%==1 ( + echo. 
+ echo [警告] 部分模型文件缺失,请将模型文件放入 model\ 目录 +) + +REM 完成 +echo. +echo ========================================== +echo [成功] 安装完成! +echo ========================================== +echo. +echo 下一步: +echo 1. 编辑 .env 文件,填入您的 API 密钥: +echo notepad .env +echo. +echo 2. 确保所有模型文件已放入 model\ 目录 +echo. +echo 3. 启动系统: +echo venv\Scripts\activate +echo python app_main.py +echo. +echo 4. 访问 http://localhost:8081 +echo. +echo [提示] 每次使用前请激活虚拟环境: +echo venv\Scripts\activate +echo. + +pause + diff --git a/setup.sh b/setup.sh new file mode 100644 index 0000000..37626aa --- /dev/null +++ b/setup.sh @@ -0,0 +1,192 @@ +#!/bin/bash +# AI Glass System - Linux/macOS 快速安装脚本 + +set -e # 遇到错误立即退出 + +echo "==========================================" +echo " AI Glass System - 自动安装脚本" +echo "==========================================" +echo "" + +# 颜色定义 +RED='\033[0;31m' +GREEN='\033[0;32m' +YELLOW='\033[1;33m' +NC='\033[0m' # No Color + +# 检查 Python 版本 +echo "正在检查 Python 版本..." +if ! command -v python3 &> /dev/null; then + echo -e "${RED}错误: 未找到 Python 3${NC}" + echo "请先安装 Python 3.9-3.11" + exit 1 +fi + +PYTHON_VERSION=$(python3 -c 'import sys; print(".".join(map(str, sys.version_info[:2])))') +echo -e "${GREEN}✓ 找到 Python $PYTHON_VERSION${NC}" + +# 检查 Python 版本是否在支持范围内 +PYTHON_MAJOR=$(echo $PYTHON_VERSION | cut -d. -f1) +PYTHON_MINOR=$(echo $PYTHON_VERSION | cut -d. -f2) + +if [ "$PYTHON_MAJOR" -ne 3 ] || [ "$PYTHON_MINOR" -lt 9 ] || [ "$PYTHON_MINOR" -gt 11 ]; then + echo -e "${YELLOW}警告: Python 版本 $PYTHON_VERSION 可能不受支持${NC}" + echo "推荐使用 Python 3.9-3.11" + read -p "是否继续? (y/n) " -n 1 -r + echo + if [[ ! $REPLY =~ ^[Yy]$ ]]; then + exit 1 + fi +fi + +# 检查 CUDA(可选) +echo "" +echo "正在检查 CUDA..." +if command -v nvidia-smi &> /dev/null; then + echo -e "${GREEN}✓ 检测到 NVIDIA GPU${NC}" + nvidia-smi --query-gpu=name,driver_version,memory.total --format=csv,noheader + HAS_GPU=true +else + echo -e "${YELLOW}! 
未检测到 NVIDIA GPU,将使用 CPU 模式(速度较慢)${NC}" + HAS_GPU=false +fi + +# 创建虚拟环境 +echo "" +echo "正在创建虚拟环境..." +if [ -d "venv" ]; then + echo -e "${YELLOW}虚拟环境已存在${NC}" + read -p "是否删除并重新创建? (y/n) " -n 1 -r + echo + if [[ $REPLY =~ ^[Yy]$ ]]; then + rm -rf venv + python3 -m venv venv + echo -e "${GREEN}✓ 虚拟环境已重新创建${NC}" + fi +else + python3 -m venv venv + echo -e "${GREEN}✓ 虚拟环境已创建${NC}" +fi + +# 激活虚拟环境 +echo "正在激活虚拟环境..." +source venv/bin/activate + +# 升级 pip +echo "" +echo "正在升级 pip..." +pip install --upgrade pip -q +echo -e "${GREEN}✓ pip 已升级${NC}" + +# 安装 PyTorch +echo "" +echo "正在安装 PyTorch..." +if [ "$HAS_GPU" = true ]; then + echo "安装 GPU 版本 PyTorch (CUDA 11.8)..." + pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 --index-url https://download.pytorch.org/whl/cu118 -q +else + echo "安装 CPU 版本 PyTorch..." + pip install torch torchvision -q +fi +echo -e "${GREEN}✓ PyTorch 已安装${NC}" + +# 验证 PyTorch +echo "验证 PyTorch 安装..." +python3 -c "import torch; print(f'PyTorch 版本: {torch.__version__}'); print(f'CUDA 可用: {torch.cuda.is_available()}')" + +# 安装系统依赖(Linux) +if [[ "$OSTYPE" == "linux-gnu"* ]]; then + echo "" + echo "正在检查系统依赖..." + + # 检测发行版 + if [ -f /etc/os-release ]; then + . /etc/os-release + OS=$ID + else + OS="unknown" + fi + + if [ "$OS" = "ubuntu" ] || [ "$OS" = "debian" ]; then + echo "检测到 Ubuntu/Debian 系统" + echo "可能需要 sudo 权限来安装系统依赖..." + sudo apt-get update -qq + sudo apt-get install -y -qq portaudio19-dev libgl1-mesa-glx libglib2.0-0 + echo -e "${GREEN}✓ 系统依赖已安装${NC}" + else + echo -e "${YELLOW}! 未知的 Linux 发行版,请手动安装依赖${NC}" + echo " 需要: portaudio19-dev, libgl1-mesa-glx, libglib2.0-0" + fi +fi + +# 安装 Python 依赖 +echo "" +echo "正在安装 Python 依赖..." +pip install -r requirements.txt -q +echo -e "${GREEN}✓ Python 依赖已安装${NC}" + +# 创建 .env 文件 +echo "" +if [ ! -f ".env" ]; then + echo "正在创建 .env 配置文件..." 
+ cp .env.example .env + echo -e "${GREEN}✓ .env 文件已创建${NC}" + echo -e "${YELLOW}请编辑 .env 文件,填入您的 DASHSCOPE_API_KEY${NC}" +else + echo -e "${YELLOW}.env 文件已存在,跳过${NC}" +fi + +# 创建必要的目录 +echo "" +echo "正在创建目录结构..." +mkdir -p recordings model music voice +echo -e "${GREEN}✓ 目录结构已创建${NC}" + +# 检查模型文件 +echo "" +echo "正在检查模型文件..." +MODELS=("yolo-seg.pt" "yoloe-11l-seg.pt" "shoppingbest5.pt" "trafficlight.pt" "hand_landmarker.task") +MISSING_MODELS=() + +for model in "${MODELS[@]}"; do + if [ -f "model/$model" ]; then + echo -e "${GREEN}✓ $model${NC}" + else + echo -e "${RED}✗ $model (缺失)${NC}" + MISSING_MODELS+=("$model") + fi +done + +if [ ${#MISSING_MODELS[@]} -gt 0 ]; then + echo "" + echo -e "${YELLOW}警告: 缺少以下模型文件:${NC}" + for model in "${MISSING_MODELS[@]}"; do + echo " - $model" + done + echo "请将模型文件放入 model/ 目录" +fi + +# 完成 +echo "" +echo "==========================================" +echo -e "${GREEN}安装完成!${NC}" +echo "==========================================" +echo "" +echo "下一步:" +echo "1. 编辑 .env 文件,填入您的 API 密钥:" +echo " nano .env" +echo "" +echo "2. 确保所有模型文件已放入 model/ 目录" +echo "" +echo "3. 启动系统:" +echo " source venv/bin/activate" +echo " python app_main.py" +echo "" +echo "4. 
访问 http://localhost:8081" +echo "" + +# 提示激活虚拟环境 +echo -e "${YELLOW}注意: 每次使用前请激活虚拟环境:${NC}" +echo " source venv/bin/activate" +echo "" + diff --git a/static/favicon.png b/static/favicon.png new file mode 100644 index 0000000..14ad066 Binary files /dev/null and b/static/favicon.png differ diff --git a/static/main.js b/static/main.js new file mode 100644 index 0000000..df26478 --- /dev/null +++ b/static/main.js @@ -0,0 +1,883 @@ +// static/main.js + +// ================= 摄像头 + ASR ================= +(() => { + const $camStatus = document.getElementById('camStatus'); + const $asrStatus = document.getElementById('asrStatus'); + const $partial = document.getElementById('partial'); + const $finalList = document.getElementById('finalList'); + const $btnClear = document.getElementById('btnClear'); + const $btnRe = document.getElementById('btnReconnect'); + const $fps = document.getElementById('fps'); + const canvas = document.getElementById('canvas'); + const ctx = canvas.getContext('2d'); + + // === 获取/创建聊天容器(关键补丁) === + let chatContainer = document.getElementById('chatContainer'); + + function ensureChatContainer() { + // 已缓存且仍在文档中 + if (chatContainer && document.body.contains(chatContainer)) return chatContainer; + + // 重新获取,防热更新或 DOM 移动 + chatContainer = document.getElementById('chatContainer'); + if (!chatContainer) { + chatContainer = document.createElement('div'); + chatContainer.id = 'chatContainer'; + + // 优先挂到 finalList 的父容器;否则挂到 partial 的父容器;再否则挂到 body 兜底 + if ($finalList && $finalList.parentElement) { + // 隐藏原来的 finalList + $finalList.style.display = 'none'; + // 将聊天容器挂载到 finals div 内 + $finalList.parentElement.appendChild(chatContainer); + console.log('[chat] 创建并挂载 #chatContainer 到 finalList 区域'); + } else if ($partial && $partial.parentElement) { + $partial.parentElement.appendChild(chatContainer); + console.log('[chat] 创建并挂载 #chatContainer 到 partial 区域'); + } else { + document.body.appendChild(chatContainer); + console.warn('[chat] 未找到合适锚点,已挂到 '); + } + } + 
return chatContainer; + } + + // === 注入聊天样式(左右两侧气泡 + 时间戳,增加权重)=== + (function injectChatStyles() { + if (document.getElementById('chat-style-injected')) return; + const s = document.createElement('style'); + s.id = 'chat-style-injected'; + s.textContent = ` + #chatContainer{ + position: relative !important; + overflow-y: auto !important; + flex: 1 !important; /* 改为使用 flex: 1 占满剩余空间 */ + min-height: 0 !important; /* 确保 flex 子元素能正确收缩 */ + padding: 12px 12px 4px !important; + background: #0b1020 !important; + border: 1px solid #1d2438 !important; + border-radius: 10px !important; + margin-top: 12px !important; + } + + /* 自定义滚动条样式 */ + #chatContainer::-webkit-scrollbar { + width: 8px !important; + } + + #chatContainer::-webkit-scrollbar-track { + background: #0d1420 !important; + border-radius: 4px !important; + } + + #chatContainer::-webkit-scrollbar-thumb { + background: #2a3446 !important; + border-radius: 4px !important; + transition: background 0.2s !important; + } + + #chatContainer::-webkit-scrollbar-thumb:hover { + background: #3a4556 !important; + } + + /* Firefox 滚动条 */ + #chatContainer { + scrollbar-width: thin !important; + scrollbar-color: #2a3446 #0d1420 !important; + } + .timestamp{ + text-align:center !important; + font-size:12px !important; + color:#8a93a5 !important; + margin:10px 0 !important; + user-select:none !important; + } + .message{ + display:flex !important; + gap:8px !important; + margin:6px 0 !important; + align-items:flex-end !important; + } + .message.ai{ justify-content:flex-start !important; } + .message.user{ justify-content:flex-end !important; } + + .avatar{ + width:28px !important; height:28px !important; border-radius:50% !important; + background:#232a3d !important; flex:0 0 28px !important; + display:flex !important; align-items:center !important; justify-content:center !important; + color:#9fb0c3 !important; font-size:12px !important; user-select:none !important; + border:1px solid #29314a !important; + } + .message.user .avatar{ 
display:none !important; } + + .bubble{ + max-width: 72% !important; + padding:10px 12px !important; + line-height:1.45 !important; + border-radius:14px !important; + word-break:break-word !important; + white-space:pre-wrap !important; + border:1px solid transparent !important; + box-shadow:0 2px 8px rgba(0,0,0,0.15) !important; + font-size:14px !important; + } + .message.ai .bubble{ + background:#111a2e !important; + color:#e6edf3 !important; + border-color:#1e2740 !important; + border-top-left-radius:6px !important; + } + .message.user .bubble{ + background:#2a6df4 !important; + color:#fff !important; + border-color:#2a6df4 !important; + border-top-right-radius:6px !important; + } + `; + document.head.appendChild(s); + })(); + + // 聊天消息管理 + let lastTimestamp = 0; + const TIMESTAMP_INTERVAL = 5 * 60 * 1000; // 5分钟 + + function shouldShowTimestamp() { + const now = Date.now(); + if (now - lastTimestamp > TIMESTAMP_INTERVAL) { + lastTimestamp = now; + return true; + } + return false; + } + + function formatTime(timestamp = Date.now()) { + const date = new Date(timestamp); + const hours = date.getHours().toString().padStart(2, '0'); + const minutes = date.getMinutes().toString().padStart(2, '0'); + return `${hours}:${minutes}`; + } + + function addTimestamp() { + const container = ensureChatContainer(); + const timestampDiv = document.createElement('div'); + timestampDiv.className = 'timestamp'; + timestampDiv.textContent = formatTime(); + container.appendChild(timestampDiv); + } + + function addMessage(text, isUser = false) { + // 时间戳 + if (shouldShowTimestamp()) addTimestamp(); + + const container = ensureChatContainer(); + + // 行容器 + const messageDiv = document.createElement('div'); + messageDiv.className = `message ${isUser ? 'user' : 'ai'}`; + + // 左侧头像(AI) + const avatar = document.createElement('div'); + avatar.className = 'avatar'; + avatar.textContent = isUser ? 
'' : 'AI'; + + // 气泡 + const bubbleDiv = document.createElement('div'); + bubbleDiv.className = 'bubble'; + bubbleDiv.textContent = text; + + if (isUser) { + // 右侧:气泡在右 + messageDiv.appendChild(bubbleDiv); + } else { + // 左侧:头像 + 气泡 + messageDiv.appendChild(avatar); + messageDiv.appendChild(bubbleDiv); + } + + container.appendChild(messageDiv); + + // 滚动到底部 + container.scrollTop = container.scrollHeight; + } + + // Day 20: 更新 badge 样式,支持 connecting 状态动画 + function setBadge(el, status, text) { + el.textContent = text; + // status: 'ok', 'err', 'connecting' + if (status === true) status = 'ok'; + if (status === false) status = 'err'; + el.className = 'badge ' + (status || ''); + } + + function navLabelAndText(raw) { + // 去掉前缀 “[导航] ” + const t = raw.startsWith('[导航]') ? raw.substring(4).trim() : raw; + // 粗略判断:含“斑马线/绿灯/红灯/黄灯/过马路”归为斑马线导航,否则盲道导航 + const crossHints = ['斑马线', '绿灯', '红灯', '黄灯', '过马路']; + const isCross = crossHints.some(k => t.includes(k)); + const label = isCross ? '【斑马线导航】' : '【盲道导航】'; + return { label, text: `${label} ${t}` }; + } + + // 改进的 fitCanvas: 支持移动端尺寸计算 + function fitCanvas() { + const rect = canvas.getBoundingClientRect(); + // 使用容器实际宽高,添加最小值保护 + const w = Math.max(320, Math.floor(rect.width) || 320); + let h = Math.floor(rect.height) || 0; + // 如果容器高度太小或为0,使用4:3比例回退 + if (h < 100) { + h = Math.max(240, Math.floor(w * 3 / 4)); + } + if (canvas.width !== w || canvas.height !== h) { + canvas.width = w; + canvas.height = h; + console.log('[Canvas] 尺寸调整:', w, 'x', h); + } + } + window.addEventListener('resize', fitCanvas); + // 延迟初始化,确保布局完成 + setTimeout(fitCanvas, 100); + fitCanvas(); + + let wsCam, wsUI, frames = 0, fpsTimer = 0; + + function drawBlob(buf) { + const blob = new Blob([buf], { type: 'image/jpeg' }); + if ('createImageBitmap' in window) { + createImageBitmap(blob).then(bmp => { + fitCanvas(); + ctx.drawImage(bmp, 0, 0, canvas.width, canvas.height); + }).catch(() => { }); + } else { + const img = new Image(); + img.onload = () => { 
fitCanvas(); ctx.drawImage(img, 0, 0, canvas.width, canvas.height); URL.revokeObjectURL(img.src); }; + img.src = URL.createObjectURL(blob); + } + frames++; + const now = performance.now(); + if (!fpsTimer) fpsTimer = now; + if (now - fpsTimer >= 1000) { + $fps.textContent = 'FPS: ' + frames; + frames = 0; fpsTimer = now; + } + } + + function connectCamera() { + try { if (wsCam) wsCam.close(); } catch (e) { } + const proto = location.protocol === 'https:' ? 'wss' : 'ws'; + wsCam = new WebSocket(`${proto}://${location.host}/ws/viewer`); + setBadge($camStatus, 'connecting', '📷 连接中...'); + wsCam.binaryType = 'arraybuffer'; + wsCam.onopen = () => setBadge($camStatus, 'ok', '📷 已连接'); + wsCam.onclose = () => setBadge($camStatus, 'err', '📷 已断开'); + wsCam.onerror = () => setBadge($camStatus, 'err', '📷 连接错误'); + wsCam.onmessage = (ev) => drawBlob(ev.data); + } + + function connectASR() { + try { if (wsUI) wsUI.close(); } catch (e) { } + const proto = location.protocol === 'https:' ? 'wss' : 'ws'; + wsUI = new WebSocket(`${proto}://${location.host}/ws_ui`); + setBadge($asrStatus, 'connecting', '🎤 连接中...'); + wsUI.onopen = () => setBadge($asrStatus, 'ok', '🎤 已连接'); + wsUI.onclose = () => setBadge($asrStatus, 'err', '🎤 已断开'); + wsUI.onerror = () => setBadge($asrStatus, 'err', '🎤 连接错误'); + wsUI.onmessage = (ev) => { + const s = ev.data || ''; + if (s.startsWith('INIT:')) { + try { + const data = JSON.parse(s.slice(5)); + $partial.textContent = data.partial || '(等待音频…)'; + + // 初始化时加载历史消息(识别 [AI] 与 [导航]) + if (data.finals && data.finals.length > 0) { + data.finals.forEach(text => { + if (text.startsWith('[AI]')) { + addMessage(text.substring(4).trim(), false); + } else if (text.startsWith('[导航]')) { + const { text: show } = navLabelAndText(text); + addMessage(show, false); + } else { + addMessage(text, true); + } + }); + } + } catch (e) { } + return; + } + if (s.startsWith('PARTIAL:')) { + $partial.textContent = s.slice(8); + return; + } + if (s.startsWith('FINAL:')) { + const 
text = s.slice(6); + if (text.startsWith('[AI]')) { + addMessage(text.substring(4).trim(), false); + } else if (text.startsWith('[导航]')) { + const { text: show } = navLabelAndText(text); + addMessage(show, false); // 左侧 AI + } else { + addMessage(text, true); // 其它仍按右侧 + } + $partial.textContent = '(等待音频…)'; + return; + } + } + } + + $btnClear.onclick = () => { + const container = ensureChatContainer(); + // 清空聊天记录 + const messages = container.querySelectorAll('.message, .timestamp'); + messages.forEach(msg => msg.remove()); + lastTimestamp = 0; // 重置时间戳计数 + }; + $btnRe.onclick = () => { connectCamera(); connectASR(); }; + + connectCamera(); + connectASR(); +})(); + + +// ================= IMU 3D(无虚线框、无滚动条、上下对齐、自适应) ================= +import * as THREE from 'three'; +import { GLTFLoader } from 'https://unpkg.com/three@0.155.0/examples/jsm/loaders/GLTFLoader.js'; + +// Day 20: IMU 浮窗折叠功能 - 修复:兼容模块延迟加载 +// Day 23: 移动端优化 - 默认折叠 +function initImuToggle() { + const imuFloat = document.getElementById('imuFloat'); + const imuToggle = document.getElementById('imuToggle'); + console.log('[IMU] 初始化折叠功能, imuFloat:', !!imuFloat, 'imuToggle:', !!imuToggle); + + if (imuFloat && imuToggle) { + // 检测移动端 - 默认折叠 + const isMobile = window.innerWidth < 1100; + if (isMobile) { + imuFloat.classList.add('collapsed'); + imuFloat.classList.remove('expanded'); + imuToggle.textContent = '+'; + console.log('[IMU] 移动端检测,默认折叠'); + } else { + imuFloat.classList.add('expanded'); + } + + imuToggle.onclick = function (e) { + e.preventDefault(); + e.stopPropagation(); + const isCollapsed = imuFloat.classList.toggle('collapsed'); + imuFloat.classList.toggle('expanded', !isCollapsed); + this.textContent = isCollapsed ? 
'+' : '−'; + console.log('[IMU] 折叠状态:', isCollapsed); + }; + console.log('[IMU] 折叠按钮事件已绑定'); + } +} + +// 确保 DOM 加载后执行(兼容模块延迟加载) +if (document.readyState === 'loading') { + document.addEventListener('DOMContentLoaded', initImuToggle); +} else { + // DOM 已加载完成,直接执行 + initImuToggle(); +} + +(() => { + const container = document.getElementById('imu_view'); // 左侧3D容器 + const hud = document.getElementById('imu_hud'); // 右侧IMU容器 + + // 左右窗口统一半透明底色 + if (container) container.style.background = 'rgba(0,0,0,0.2)'; + if (hud) { + // 关键:右侧容器作为定位参考,同时禁止滚动、清理边框 + Object.assign(hud.style, { + position: 'relative', + overflow: 'hidden', + border: 'none', + outline: 'none', + background: 'rgba(0,0,0,0.2)', // 右侧也给统一底色(整块),干净无额外面板底色 + borderRadius: '10px' + }); + } + + // —— 彻底去掉“虚线框”和一切边框/阴影(含可能的外层壳)—— + (function killFraming() { + const s = document.createElement('style'); + s.textContent = ` + #imu_view, #imu_hud, #data-panel, #imu_dock, + .imu-card, .imu-wrap, .panel, .card, .window { + border: none !important; + outline: none !important; + box-shadow: none !important; + background-image: none !important; + } + /* 兜底:清除任何内联 dashed/ dotted */ + [style*="dashed"], [style*="dotted"] { + border-style: none !important; + outline: none !important; + } + `; + document.head.appendChild(s); + + // 同时清理父级(最多向上两层)里的边框与滚动条,避免外层虚线框和滚动条 + [container, hud].forEach(el => { + let p = el ? 
el.parentElement : null; + for (let i = 0; i < 2 && p; i++, p = p.parentElement) { + p.style.border = 'none'; + p.style.outline = 'none'; + p.style.boxShadow = 'none'; + p.style.overflow = 'hidden'; + p.style.backgroundImage = 'none'; + } + }); + })(); + + // 右侧:不再额外创建 dock 背板(直接用 hud 当整块背景) + // 数据面板只负责显示文字,不再自带背景与边框 + + // three.js 渲染器 + const renderer = new THREE.WebGLRenderer({ antialias: true, alpha: true }); + const scene = new THREE.Scene(); + const camera = new THREE.PerspectiveCamera(70, 1, 0.1, 1000); + + // 画质相关 + renderer.shadowMap.enabled = true; + renderer.shadowMap.type = THREE.PCFSoftShadowMap; + renderer.outputColorSpace = THREE.SRGBColorSpace; + renderer.toneMapping = THREE.ACESFilmicToneMapping; + renderer.toneMappingExposure = 1.0; + renderer.setClearColor(0x000000, 0); // 透明背景 + + // ——— 核心:左右窗口“上下齐+自适应等比” ——— + let syncRaf = 0; + function syncHeights() { + if (!container || !hud) return; + const w = container.clientWidth || 300; + + // 恢复合理高度设置 + const padding = 20; + const contentH = (document.getElementById('data-panel')?.offsetHeight || 0) + padding; + let targetH = Math.max(180, contentH); // 最小高度 180px + + hud.style.height = `${targetH}px`; + hud.style.maxHeight = 'none'; + hud.style.overflow = 'hidden'; + + container.style.height = `${targetH}px`; + renderer.setSize(w, targetH); + camera.aspect = w / targetH; + camera.updateProjectionMatrix(); + } + + function requestSync() { + cancelAnimationFrame(syncRaf); + syncRaf = requestAnimationFrame(syncHeights); + } + + // 初次与窗口变化时,同步左右高度 + requestSync(); + window.addEventListener('resize', requestSync); + + // 数据变化时也同步(放在 updateDataPanel 内) + function updateDataPanel(roll, pitch, yaw, gx, gy, gz, ax, ay, az) { + document.getElementById('panel-roll').textContent = roll.toFixed(1) + '°'; + document.getElementById('panel-pitch').textContent = pitch.toFixed(1) + '°'; + document.getElementById('panel-yaw').textContent = yaw.toFixed(1) + '°'; + document.getElementById('panel-gx').textContent = 
gx.toFixed(1); + document.getElementById('panel-gy').textContent = gy.toFixed(1); + document.getElementById('panel-gz').textContent = gz.toFixed(1); + document.getElementById('panel-ax').textContent = ax.toFixed(2); + document.getElementById('panel-ay').textContent = ay.toFixed(2); + document.getElementById('panel-az').textContent = az.toFixed(2); + + requestSync(); // 数据刷新后同步高度 + } + + + container.appendChild(renderer.domElement); + + // ========== 场景 ========== + const group = new THREE.Group(); + scene.add(group); + + const axesHelper = new THREE.AxesHelper(4); + scene.add(axesHelper); + + function createAxisLabel(text, position, color) { + const c = document.createElement('canvas'); + const ctx = c.getContext('2d'); + c.width = 128; c.height = 64; + ctx.fillStyle = color; + ctx.font = 'Bold 24px Arial'; + ctx.textAlign = 'center'; + ctx.textBaseline = 'middle'; + ctx.fillText(text, 64, 32); + const tex = new THREE.CanvasTexture(c); + const mat = new THREE.SpriteMaterial({ map: tex }); + const spr = new THREE.Sprite(mat); + spr.position.copy(position); + spr.scale.set(0.8, 0.4, 1); + return spr; + } + scene.add(createAxisLabel('X', new THREE.Vector3(4.5, 0, 0), '#ff0000')); + scene.add(createAxisLabel('Y', new THREE.Vector3(0, 4.5, 0), '#00ff00')); + scene.add(createAxisLabel('Z', new THREE.Vector3(0, 0, 4.5), '#0000ff')); + + function createScale() { + const g = new THREE.Group(); + for (let i = 1; i <= 4; i++) { + const geo = new THREE.SphereGeometry(0.05, 8, 6); + const mk = (c) => new THREE.Mesh(geo, new THREE.MeshBasicMaterial({ color: c })); + const mx = mk(0xff4444); mx.position.set(i, 0, 0); g.add(mx); + const my = mk(0x44ff44); my.position.set(0, i, 0); g.add(my); + const mz = mk(0x4444ff); mz.position.set(0, 0, i); g.add(mz); + } + return g; + } + scene.add(createScale()); + + function createDirectionLabels() { + [ + { t: '前', p: new THREE.Vector3(0, 0, 5), c: '#00ffff' }, + { t: '后', p: new THREE.Vector3(0, 0, -5), c: '#00ffff' }, + { t: '左', p: new 
THREE.Vector3(-5, 0, 0), c: '#ffff00' }, + { t: '右', p: new THREE.Vector3(5, 0, 0), c: '#ffff00' }, + { t: '上', p: new THREE.Vector3(0, 5, 0), c: '#ff00ff' }, + { t: '下', p: new THREE.Vector3(0, -5, 0), c: '#ff00ff' }, + ].forEach(d => scene.add(createAxisLabel(d.t, d.p, d.c))); + } + createDirectionLabels(); + + camera.position.set(4, 4, 6); + camera.lookAt(0, 0, 0); + + // ========== 右侧 IMU 数据展示 ========== + function createDataPanel() { + const panel = document.createElement('div'); + panel.id = 'data-panel'; + panel.style.cssText = ` + position: relative; + background: transparent; + border: none; + padding: 10px; + width: 100%; + color: #e6edf3; + font-family: 'Consolas','Monaco',monospace; + font-size: 11px; + box-shadow: none; + overflow: hidden; + `; + panel.innerHTML = ` +
<div class="imu-title" style="font-weight:bold;margin-bottom:6px;">IMU 实时数据</div>
+      <div class="imu-grid" style="display:flex;flex-wrap:wrap;gap:8px;">
+        <div class="imu-cell">
+          <span class="imu-label">翻滚角 (Roll)</span>
+          <span class="imu-value">
+            <span id="panel-roll">0.0°</span>
+          </span>
+          <span class="imu-label">俯仰角 (Pitch)</span>
+          <span class="imu-value">
+            <span id="panel-pitch">0.0°</span>
+          </span>
+        </div>
+        <div class="imu-cell">
+          <span class="imu-label">偏航角 (Yaw)</span>
+          <span class="imu-value">
+            <span id="panel-yaw">0.0°</span>
+          </span>
+        </div>
+        <div class="imu-cell">
+          <span class="imu-label">角速度 (°/s)</span>
+          <span class="imu-row">
+            <span>
+              gX:<span id="panel-gx">0.0</span>
+            </span><span>
+              gY:<span id="panel-gy">0.0</span>
+            </span><span>
+              gZ:<span id="panel-gz">0.0</span>
+            </span>
+          </span>
+        </div>
+        <div class="imu-cell">
+          <span class="imu-label">加速度 (m/s²)</span>
+          <span class="imu-row">
+            <span>
+              aX:<span id="panel-ax">0.0</span>
+            </span><span>
+              aY:<span id="panel-ay">0.0</span>
+            </span><span>
+              aZ:<span id="panel-az">0.0</span>
+            </span></span>
+        </div></div>
+ `; + hud.appendChild(panel); + return panel; + } + const dataPanel = createDataPanel(); + + function updateDataPanel(roll, pitch, yaw, gx, gy, gz, ax, ay, az) { + document.getElementById('panel-roll').textContent = roll.toFixed(1) + '°'; + document.getElementById('panel-pitch').textContent = pitch.toFixed(1) + '°'; + document.getElementById('panel-yaw').textContent = yaw.toFixed(1) + '°'; + document.getElementById('panel-gx').textContent = gx.toFixed(1); + document.getElementById('panel-gy').textContent = gy.toFixed(1); + document.getElementById('panel-gz').textContent = gz.toFixed(1); + document.getElementById('panel-ax').textContent = ax.toFixed(2); + document.getElementById('panel-ay').textContent = ay.toFixed(2); + document.getElementById('panel-az').textContent = az.toFixed(2); + } + + // ========== 灯光 ========== + const ambientLight = new THREE.AmbientLight(0x404080, 0.3); + scene.add(ambientLight); + + const mainLight = new THREE.DirectionalLight(0x00aaff, 1.2); + mainLight.position.set(5, 8, 5); + mainLight.castShadow = true; + mainLight.shadow.mapSize.width = 2048; + mainLight.shadow.mapSize.height = 2048; + mainLight.shadow.camera.near = 0.5; + mainLight.shadow.camera.far = 50; + scene.add(mainLight); + + const fillLight = new THREE.DirectionalLight(0xff6633, 0.8); + fillLight.position.set(-5, 3, -3); + scene.add(fillLight); + + const rimLight = new THREE.DirectionalLight(0x66ffff, 0.6); + rimLight.position.set(0, -5, 8); + scene.add(rimLight); + + const pointLight1 = new THREE.PointLight(0x00ff88, 0.5, 20); + pointLight1.position.set(3, 2, 4); + scene.add(pointLight1); + + const pointLight2 = new THREE.PointLight(0xff3388, 0.4, 15); + pointLight2.position.set(-3, -2, 2); + scene.add(pointLight2); + + const spotLight = new THREE.SpotLight(0xffffff, 1.0, 30, Math.PI / 6, 0.3, 1); + spotLight.position.set(0, 10, 8); + spotLight.target.position.set(0, 0, 0); + spotLight.castShadow = true; + scene.add(spotLight); + scene.add(spotLight.target); + + let 
lightTime = 0; + function updateLighting() { + lightTime += 0.01; + mainLight.intensity = 1.2 + Math.sin(lightTime * 2) * 0.2; + pointLight1.intensity = 0.5 + Math.sin(lightTime * 3) * 0.2; + pointLight2.intensity = 0.4 + Math.cos(lightTime * 2.5) * 0.2; + const hue = (Math.sin(lightTime * 0.5) + 1) * 0.3; + rimLight.color.setHSL(0.5 + hue, 1.0, 0.7); + } + + // ========== 模型 ========== + let glassModel = null; + const loader = new GLTFLoader(); + loader.load( + '/static/models/aiglass.glb', + (gltf) => { + glassModel = gltf.scene; + glassModel.scale.set(2, 2, 2); + glassModel.position.set(0, 0, 0); + glassModel.traverse((child) => { + if (child.isMesh) { + child.castShadow = true; + child.receiveShadow = true; + if (child.material) { + if (child.material.transparent || child.material.opacity < 1) { + child.material.envMapIntensity = 1.5; + child.material.roughness = 0.1; + child.material.metalness = 0.8; + } + } + } + }); + group.add(glassModel); + }, + undefined, + (error) => { + console.error('GLB加载失败:', error); + const fallbackCube = new THREE.Mesh( + new THREE.BoxGeometry(2, 2, 2), + new THREE.MeshStandardMaterial({ color: 0x00aaff, metalness: 0.7, roughness: 0.3, envMapIntensity: 1.0 }) + ); + fallbackCube.castShadow = true; + fallbackCube.receiveShadow = true; + group.add(fallbackCube); + } + ); + + // 渲染循环 + (function animate() { + requestAnimationFrame(animate); + updateLighting(); + renderer.render(scene, camera); + })(); + + // ===== IMU 数学与数据通道(原逻辑保持) ===== + // 安装补偿 + const MOUNT_RX = 0, MOUNT_RY = -90, MOUNT_RZ = 0; + const qMount = new THREE.Quaternion() + .multiply(new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(0, 1, 0), THREE.MathUtils.degToRad(MOUNT_RY))) + .multiply(new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(0, 0, 1), THREE.MathUtils.degToRad(MOUNT_RZ))) + .multiply(new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(1, 0, 0), THREE.MathUtils.degToRad(MOUNT_RX))); + + const FOLLOW = 0.85; + const $ = id => 
document.getElementById(id); + const updateSlider = (idBase, v) => { const sl = $(`${idBase}_sl`), tv = $(`${idBase}_val`); if (sl) { const min = +sl.min, max = +sl.max; sl.value = Math.max(min, Math.min(max, v)); } if (tv) tv.textContent = (typeof v === 'number' ? v.toFixed(2) : '-'); }; + + let MED_N = Number($('medn').value); + $('medn').onchange = e => MED_N = Number(e.target.value); + + let STILL_W = Number($('still_w').value); + $('still_w').onchange = e => STILL_W = Number(e.target.value); + + let ANG_EMA = Number($('ang_ema').value); + $('ang_ema').onchange = e => ANG_EMA = Number(e.target.value); + + let GRAV_BETA = Number($('grav_beta').value); + $('grav_beta').onchange = e => GRAV_BETA = Number(e.target.value); + + let YAW_DB = Number($('yaw_db').value); + $('yaw_db').onchange = e => YAW_DB = Number(e.target.value); + + let YAW_LEAK = Number($('yaw_leak').value); + $('yaw_leak').onchange = e => YAW_LEAK = Number(e.target.value); + + let autoRezero = true; + $('auto_rezero').onchange = e => { autoRezero = e.target.checked; }; + + let autoBias = true; + $('auto_bias').onchange = e => { autoBias = e.target.checked; }; + + let useProj = true; + $('use_proj').onchange = e => { useProj = e.target.checked; }; + + let freezeStill = true; + $('freeze_still').onchange = e => { freezeStill = e.target.checked; }; + + const mkMed = () => ({ buf: [], push(v) { this.buf.push(v); if (this.buf.length > MED_N) this.buf.shift(); const arr = [...this.buf].sort((a, b) => a - b); const m = arr[Math.floor(arr.length / 2)]; return { median: m, valid: this.buf.length === MED_N }; } }); + const fx = mkMed(), fy = mkMed(), fz = mkMed(); + const gx = mkMed(), gy = mkMed(), gz = mkMed(); + + const rad2deg = r => r * 180 / Math.PI; + const wrap180 = a => { a %= 360; if (a >= 180) a -= 360; if (a < -180) a += 360; return a; }; + + let lastTS = 0; + let yaw = 0; + let ref = { roll: 0, pitch: 0, yaw: 0 }; + let holdStart = 0, isStill = false; + + let gLP = { x: 0, y: 0, z: 0 }; + const 
G = 9.807, A_TOL = 0.08 * G; + + let gOff = { x: 0, y: 0, z: 0 }; + const BIAS_ALPHA = 0.002; + + let Rf = 0, Pf = 0, Yf = 0; + + document.getElementById('btn_zero').onclick = () => { ref = { roll: Rf, pitch: Pf, yaw: Yf }; }; + document.getElementById('btn_reset').onclick = () => { ref = { roll: 0, pitch: 0, yaw: 0 }; yaw = 0; Rf = 0; Pf = 0; Yf = 0; }; + document.getElementById('btn_bias_now').onclick = () => { gOff = { ...lastGy }; }; + + let lastGy = { x: 0, y: 0, z: 0 }; + + const imu_ws_state = document.getElementById('imu_ws_state'); + function setImuBadge(ok, text) { + imu_ws_state.textContent = text; + imu_ws_state.className = 'badge ' + (ok ? 'ok' : 'err'); + } + + const ws = new WebSocket((location.protocol === 'https:' ? 'wss://' : 'ws://') + location.host + '/ws'); + setImuBadge(false, 'connecting…'); + ws.onopen = () => setImuBadge(true, 'connected'); + ws.onclose = () => setImuBadge(false, 'disconnected'); + ws.onerror = () => setImuBadge(false, 'error'); + ws.onmessage = (ev) => { + try { + const d = JSON.parse(ev.data); + const t = (typeof d.ts === 'number') ? d.ts : performance.now(); + let dt = (!lastTS || (t - lastTS) <= 0 || (t - lastTS) > 300) ? 
0.02 : (t - lastTS) / 1000; + lastTS = t; + + let ax = Number(d?.accel?.x) || 0, ay = Number(d?.accel?.y) || 0, az = Number(d?.accel?.z) || 0; + let wx = Number(d?.gyro?.x) || 0, wy = Number(d?.gyro?.y) || 0, wz = Number(d?.gyro?.z) || 0; + + const fxr = fx.push(ax), fyr = fy.push(ay), fzr = fz.push(az); + const gxr = gx.push(wx), gyr = gy.push(wy), gzr = gz.push(wz); + if (fxr.valid) { ax = fxr.median; ay = fyr.median; az = fzr.median; } + if (gxr.valid) { wx = gxr.median; wy = gyr.median; wz = gzr.median; } + + lastGy = { x: wx, y: wy, z: wz }; + + gLP.x = GRAV_BETA * gLP.x + (1 - GRAV_BETA) * ax; + gLP.y = GRAV_BETA * gLP.y + (1 - GRAV_BETA) * ay; + gLP.z = GRAV_BETA * gLP.z + (1 - GRAV_BETA) * az; + const gmag = Math.hypot(gLP.x, gLP.y, gLP.z) || 1; + const gHat = { x: gLP.x / gmag, y: gLP.y / gmag, z: gLP.z / gmag }; + + const roll = rad2deg(Math.atan2(az, ay)); + const pitch = rad2deg(Math.atan2(-ax, ay)); + + const aNorm = Math.hypot(ax, ay, az); + const wNorm = Math.hypot(wx, wy, wz); + const nearFlat = Math.abs(roll) < 2.0 && Math.abs(pitch) < 2.0; + const stillCond = (Math.abs(aNorm - G) < A_TOL) && (wNorm < STILL_W); + + if (stillCond) { + if (!holdStart) holdStart = t; + if (!isStill && (t - holdStart) > 350) isStill = true; + if (autoBias) { + gOff.x = (1 - BIAS_ALPHA) * gOff.x + BIAS_ALPHA * wx; + gOff.y = (1 - BIAS_ALPHA) * gOff.y + BIAS_ALPHA * wy; + gOff.z = (1 - BIAS_ALPHA) * gOff.z + BIAS_ALPHA * wz; + } + } else { holdStart = 0; isStill = false; } + + let yawdot = useProj + ? 
((wx - gOff.x) * gHat.x + (wy - gOff.y) * gHat.y + (wz - gOff.z) * gHat.z) + : (wy - gOff.y); + + if (Math.abs(yawdot) < YAW_DB) yawdot = 0; + if (freezeStill && stillCond) yawdot = 0; + + yaw = wrap180(yaw + yawdot * dt); + + if (YAW_LEAK > 0 && nearFlat && stillCond && Math.abs(yaw) > 0) { + const step = YAW_LEAK * dt * Math.sign(-yaw); + if (Math.abs(yaw) <= Math.abs(step)) yaw = 0; else yaw += step; + } + + const alpha = ANG_EMA; + Rf = alpha * roll + (1 - alpha) * Rf; + Pf = alpha * pitch + (1 - alpha) * Pf; + Yf = alpha * yaw + (1 - alpha) * Yf; + + if (autoRezero && nearFlat && wNorm < STILL_W) { + if (!holdStart) holdStart = t; + if (!isStill && (t - holdStart) > 350) { + ref = { roll: Rf, pitch: Pf, yaw: Yf }; + isStill = true; + } + } + + const R = wrap180(Rf - ref.roll); + const P = wrap180(Pf - ref.pitch); + const Y = wrap180(Yf - ref.yaw); + + const qBody = new THREE.Quaternion() + .multiply(new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(0, 1, 0), THREE.MathUtils.degToRad(Y))) + .multiply(new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(0, 0, 1), THREE.MathUtils.degToRad(P))) + .multiply(new THREE.Quaternion().setFromAxisAngle(new THREE.Vector3(1, 0, 0), THREE.MathUtils.degToRad(R))); + const q = qMount.clone().multiply(qBody); + + if (FOLLOW >= 0.999) group.setRotationFromQuaternion(q); + else group.quaternion.slerp(q, FOLLOW); + + updateSlider('roll', R); + updateSlider('pitch', P); + updateSlider('yaw', Y); + + updateSlider('gx', wx); updateSlider('gy', wy); updateSlider('gz', wz); + updateSlider('ax', ax); updateSlider('ay', ay); updateSlider('az', az); + + // 更新右侧数据 + updateDataPanel(R, P, Y, wx, wy, wz, ax, ay, az); + } catch (e) { } + }; + + // 初次与窗口改变时,保持左右上下对齐 + window.addEventListener('resize', resize); + resize(); +})(); diff --git a/static/models/aiglass.glb b/static/models/aiglass.glb new file mode 100644 index 0000000..c96e88d Binary files /dev/null and b/static/models/aiglass.glb differ diff --git a/static/vision.css 
b/static/vision.css new file mode 100644 index 0000000..87dfebc --- /dev/null +++ b/static/vision.css @@ -0,0 +1,195 @@ +/* 科技感配色方案 */ +:root { + --tech-bg: #0a0e1b; + --tech-card: #111827; + --tech-border: #1e293b; + --tech-primary: #3b82f6; + --tech-secondary: #8b5cf6; + --tech-accent: #06b6d4; + --tech-success: #10b981; + --tech-warning: #f59e0b; + --tech-text: #e0e7ff; + --tech-muted: #94a3b8; + --glow-primary: 0 0 30px rgba(59, 130, 246, 0.5); + --glow-secondary: 0 0 30px rgba(139, 92, 246, 0.5); +} + +/* 视觉识别画布容器 */ +.vision-container { + position: relative; + background: var(--tech-bg); + border: 1px solid var(--tech-border); + border-radius: 16px; + overflow: hidden; + box-shadow: 0 20px 40px rgba(0, 0, 0, 0.6); +} + +.vision-canvas { + width: 100%; + height: auto; + display: block; +} + +/* 覆盖层 */ +.vision-overlay { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + pointer-events: none; +} + +/* HUD样式 */ +.hud-element { + position: absolute; + color: var(--tech-text); + font-family: 'Inter', 'Noto Sans SC', sans-serif; + text-shadow: 0 0 10px rgba(59, 130, 246, 0.8); +} + +/* 状态指示器 */ +.status-indicator { + position: absolute; + top: 20px; + left: 20px; + padding: 12px 24px; + background: rgba(17, 24, 39, 0.9); + border: 1px solid var(--tech-primary); + border-radius: 8px; + backdrop-filter: blur(10px); + box-shadow: var(--glow-primary); +} + +.status-main { + font-size: 18px; + font-weight: 600; + color: var(--tech-primary); + margin-bottom: 4px; +} + +.status-sub { + font-size: 12px; + color: var(--tech-muted); + text-transform: uppercase; + letter-spacing: 1px; +} + +/* 进度条 */ +.progress-container { + position: absolute; + bottom: 40px; + left: 20px; + width: 300px; +} + +.progress-item { + margin-bottom: 20px; +} + +.progress-label { + display: flex; + justify-content: space-between; + margin-bottom: 8px; +} + +.progress-label-text { + font-size: 14px; + font-weight: 500; +} + +.progress-label-sub { + font-size: 11px; + color: 
var(--tech-muted); + margin-left: 8px; +} + +.progress-bar { + height: 8px; + background: rgba(30, 41, 59, 0.8); + border-radius: 4px; + overflow: hidden; + position: relative; +} + +.progress-fill { + height: 100%; + background: linear-gradient(90deg, var(--tech-primary), var(--tech-accent)); + border-radius: 4px; + transition: width 0.3s ease; + box-shadow: 0 0 20px rgba(59, 130, 246, 0.6); +} + +/* 手部追踪样式 */ +.hand-skeleton { + stroke: var(--tech-accent); + stroke-width: 2; + fill: none; + filter: drop-shadow(0 0 6px rgba(6, 182, 212, 0.8)); +} + +.hand-joint { + fill: var(--tech-accent); + filter: drop-shadow(0 0 8px rgba(6, 182, 212, 1)); +} + +/* 目标锁定样式 */ +.target-lock { + stroke: var(--tech-success); + stroke-width: 3; + fill: none; + stroke-dasharray: 10 5; + animation: rotate 20s linear infinite; + filter: drop-shadow(0 0 10px rgba(16, 185, 129, 0.8)); +} + +@keyframes rotate { + from { transform: rotate(0deg); } + to { transform: rotate(360deg); } +} + +/* 闪烁动画 */ +.flash-overlay { + position: absolute; + top: 0; + left: 0; + width: 100%; + height: 100%; + background: radial-gradient(circle, transparent 0%, rgba(139, 92, 246, 0.3) 100%); + animation: flash-pulse 1s ease-in-out; +} + +@keyframes flash-pulse { + 0%, 100% { opacity: 0; } + 50% { opacity: 1; } +} + +/* 数据显示面板 */ +.data-panel { + position: absolute; + top: 20px; + right: 20px; + background: rgba(17, 24, 39, 0.9); + border: 1px solid var(--tech-secondary); + border-radius: 8px; + padding: 16px; + backdrop-filter: blur(10px); + min-width: 200px; +} + +.data-item { + display: flex; + justify-content: space-between; + margin-bottom: 12px; + font-size: 14px; +} + +.data-label { + color: var(--tech-muted); +} + +.data-value { + color: var(--tech-text); + font-weight: 600; + font-family: 'JetBrains Mono', monospace; +} \ No newline at end of file diff --git a/static/vision.js b/static/vision.js new file mode 100644 index 0000000..b0209fe --- /dev/null +++ b/static/vision.js @@ -0,0 +1,291 @@ +// 
科技感视觉识别系统 +class VisionSystem { + constructor(canvasId) { + this.canvas = document.getElementById(canvasId); + this.ctx = this.canvas.getContext('2d'); + this.overlay = document.createElement('div'); + this.overlay.className = 'vision-overlay'; + this.canvas.parentElement.appendChild(this.overlay); + + // 状态 + this.mode = 'SEGMENT'; + this.fps = 0; + this.detectedObjects = []; + this.handData = null; + this.trackingData = null; + + // 初始化UI元素 + this.initUI(); + + // 连接WebSocket + this.connectVisionWS(); + } + + initUI() { + // 状态指示器 + this.statusElement = this.createStatusIndicator(); + this.overlay.appendChild(this.statusElement); + + // 进度条 + this.progressElement = this.createProgressBars(); + this.overlay.appendChild(this.progressElement); + + // 数据面板 + this.dataPanel = this.createDataPanel(); + this.overlay.appendChild(this.dataPanel); + } + + createStatusIndicator() { + const status = document.createElement('div'); + status.className = 'status-indicator'; + status.innerHTML = ` +
系统就绪 System Ready
+
等待目标 Waiting for Target
+ `; + return status; + } + + createProgressBars() { + const container = document.createElement('div'); + container.className = 'progress-container'; + container.innerHTML = ` +
+
+ 对齐度 Alignment + 0% +
+
+
+
+
+
+
+ 距离匹配 Distance Match + 0% +
+
+
+
+
+ `; + return container; + } + + createDataPanel() { + const panel = document.createElement('div'); + panel.className = 'data-panel'; + panel.innerHTML = ` +
+ FPS + -- +
+
+ 模式 Mode + 检测 +
+
+ 目标数 Objects + 0 +
+
+ 握持分 Grasp + 0.00 +
+ `; + return panel; + } + + connectVisionWS() { + const proto = location.protocol === 'https:' ? 'wss' : 'ws'; + this.ws = new WebSocket(`${proto}://${location.host}/ws/viewer`); // 改为 /ws/viewer + + this.ws.onopen = () => { + console.log('[Vision] WebSocket connected'); + // ... rest of the code + }; + + this.ws.onmessage = (event) => { + // 处理二进制图像数据 + if (event.data instanceof Blob) { + // 创建图像URL并显示 + const url = URL.createObjectURL(event.data); + const img = new Image(); + img.onload = () => { + this.ctx.drawImage(img, 0, 0, this.canvas.width, this.canvas.height); + URL.revokeObjectURL(url); + }; + img.src = url; + } + }; + + this.ws.onerror = () => { + console.error('Vision WebSocket error'); + }; + } + + updateVisualization(data) { + // 更新状态 + this.mode = data.mode || 'SEGMENT'; + this.fps = data.fps || 0; + + // 更新UI + this.updateStatus(data); + this.updateProgress(data); + this.updateDataPanel(data); + + // 绘制可视化 + if (data.frame) { + this.drawFrame(data.frame); + } + + if (data.hand) { + this.drawHand(data.hand); + } + + if (data.objects) { + this.drawObjects(data.objects); + } + + if (data.tracking) { + this.drawTracking(data.tracking); + } + } + + updateStatus(data) { + const statusMain = this.statusElement.querySelector('.status-main'); + const statusSub = this.statusElement.querySelector('.status-sub:last-child'); + + switch(this.mode) { + case 'SEGMENT': + statusMain.innerHTML = '目标检测中 Detecting'; + statusSub.textContent = data.message || '扫描环境 Scanning Environment'; + break; + case 'FLASH': + statusMain.innerHTML = '锁定中 Locking'; + statusSub.textContent = '准备追踪 Preparing to Track'; + break; + case 'TRACK': + statusMain.innerHTML = '追踪中 Tracking'; + statusSub.textContent = '保持对准 Maintain Alignment'; + break; + } + } + + updateProgress(data) { + if (data.alignScore !== undefined) { + const alignPercent = Math.round(data.alignScore * 100); + document.getElementById('align-progress').style.width = `${alignPercent}%`; + 
this.progressElement.querySelector('.progress-value').textContent = `${alignPercent}%`; + } + + if (data.distanceScore !== undefined) { + const distPercent = Math.round(data.distanceScore * 100); + document.getElementById('distance-progress').style.width = `${distPercent}%`; + this.progressElement.querySelectorAll('.progress-value')[1].textContent = `${distPercent}%`; + } + } + + updateDataPanel(data) { + document.getElementById('fps-value').textContent = Math.round(this.fps); + document.getElementById('mode-value').textContent = this.getModeText(this.mode); + document.getElementById('objects-value').textContent = data.objectCount || 0; + document.getElementById('grasp-value').textContent = (data.graspScore || 0).toFixed(2); + } + + getModeText(mode) { + const modeMap = { + 'SEGMENT': '检测 Detect', + 'FLASH': '锁定 Lock', + 'TRACK': '追踪 Track' + }; + return modeMap[mode] || mode; + } + + drawFrame(frameData) { + // 绘制基础图像 + const img = new Image(); + img.onload = () => { + this.canvas.width = img.width; + this.canvas.height = img.height; + this.ctx.drawImage(img, 0, 0); + }; + img.src = 'data:image/jpeg;base64,' + frameData; + } + + drawHand(handData) { + // 使用SVG绘制手部骨骼 + const svg = document.createElementNS('http://www.w3.org/2000/svg', 'svg'); + svg.style.position = 'absolute'; + svg.style.top = '0'; + svg.style.left = '0'; + svg.style.width = '100%'; + svg.style.height = '100%'; + svg.style.pointerEvents = 'none'; + + // 绘制连接线 + handData.connections.forEach(conn => { + const line = document.createElementNS('http://www.w3.org/2000/svg', 'line'); + line.setAttribute('x1', conn.start.x); + line.setAttribute('y1', conn.start.y); + line.setAttribute('x2', conn.end.x); + line.setAttribute('y2', conn.end.y); + line.setAttribute('class', 'hand-skeleton'); + svg.appendChild(line); + }); + + // 绘制关节点 + handData.landmarks.forEach(point => { + const circle = document.createElementNS('http://www.w3.org/2000/svg', 'circle'); + circle.setAttribute('cx', point.x); + 
circle.setAttribute('cy', point.y); + circle.setAttribute('r', '3'); + circle.setAttribute('class', 'hand-joint'); + svg.appendChild(circle); + }); + + // 添加到覆盖层 + const oldSvg = this.overlay.querySelector('svg'); + if (oldSvg) oldSvg.remove(); + this.overlay.appendChild(svg); + } + + drawObjects(objects) { + // 绘制检测到的物体 + objects.forEach((obj, index) => { + if (obj.isTarget) { + // 目标物体用特殊样式 + this.drawTargetObject(obj); + } else { + // 其他物体用普通样式 + this.drawNormalObject(obj); + } + }); + } + + drawTargetObject(obj) { + // 创建目标锁定效果 + const target = document.createElement('div'); + target.className = 'target-lock'; + target.style.position = 'absolute'; + target.style.left = `${obj.x}px`; + target.style.top = `${obj.y}px`; + target.style.width = `${obj.width}px`; + target.style.height = `${obj.height}px`; + + // 添加锁定动画 + const svg = ` + + + + `; + target.innerHTML = svg; + + this.overlay.appendChild(target); + } +} + +// 初始化 +document.addEventListener('DOMContentLoaded', () => { + const visionSystem = new VisionSystem('vision-canvas'); +}); \ No newline at end of file diff --git a/static/vision_renderer.js b/static/vision_renderer.js new file mode 100644 index 0000000..f4041ba --- /dev/null +++ b/static/vision_renderer.js @@ -0,0 +1,443 @@ +// vision_renderer.js - 前端可视化渲染器 + +class VisionRenderer { + constructor(canvasId) { + this.canvas = document.getElementById(canvasId); + this.ctx = this.canvas.getContext('2d'); + this.ws = null; + this.currentData = null; + + // UI配色方案 + this.colors = { + primaryBlue: '#00C8FF', + secondaryPurple: '#9664FF', + accentCyan: '#00FFFF', + white: '#FFFFFF', + lightGray: '#C8C8C8', + darkBg: 'rgba(40, 40, 40, 0.8)', + success: '#7FFF00', + warning: '#FFA500', + error: '#FF7272', + glassBg: 'rgba(20, 20, 20, 0.3)', + }; + + // 动画状态 + this.animations = { + flashAlpha: 0, + messageAlpha: 1, + progressAnimations: {} + }; + + this.setupCanvas(); + this.connect(); + this.startRenderLoop(); + } + + setupCanvas() { + // 设置画布大小 + const 
resizeCanvas = () => { + const rect = this.canvas.getBoundingClientRect(); + this.canvas.width = rect.width; + this.canvas.height = rect.height; + }; + + resizeCanvas(); + window.addEventListener('resize', resizeCanvas); + } + + connect() { + const proto = location.protocol === 'https:' ? 'wss' : 'ws'; + this.ws = new WebSocket(`${proto}://${location.host}/ws/vision_data`); + + this.ws.onopen = () => { + console.log('[VisionRenderer] Connected'); + this.updateConnectionStatus(true); + }; + + this.ws.onclose = () => { + console.log('[VisionRenderer] Disconnected'); + this.updateConnectionStatus(false); + // 自动重连 + setTimeout(() => this.connect(), 2000); + }; + + this.ws.onmessage = (event) => { + try { + this.currentData = JSON.parse(event.data); + } catch (e) { + console.error('[VisionRenderer] Parse error:', e); + } + }; + } + + updateConnectionStatus(connected) { + const badge = document.getElementById('visionStatus'); + if (badge) { + badge.textContent = connected ? 'Vision: connected' : 'Vision: disconnected'; + badge.className = 'badge ' + (connected ? 
'ok' : 'err'); + } + } + + startRenderLoop() { + const render = () => { + this.clearCanvas(); + + if (this.currentData) { + this.renderFrame(this.currentData); + } + + requestAnimationFrame(render); + }; + + render(); + } + + clearCanvas() { + this.ctx.clearRect(0, 0, this.canvas.width, this.canvas.height); + } + + renderFrame(data) { + const ctx = this.ctx; + const W = this.canvas.width; + const H = this.canvas.height; + + // 渲染手部骨骼 + if (data.hand_detected && data.hand_landmarks) { + this.drawHandSkeleton(data.hand_landmarks); + + // 手部边界框 + if (data.hand_box) { + this.drawBox(data.hand_box, this.colors.accentCyan, 1); + } + + // 握持评分 + this.drawTextWithBg( + `握持评分 Grasp Score: ${data.grasp_score.toFixed(2)}`, + 10, 60, 18, this.colors.accentCyan + ); + } + + // 渲染检测到的物体 + if (data.mode === 'SEGMENT' && data.objects) { + data.objects.forEach((obj, index) => { + const isSelected = index === data.selected_object_index; + const color = isSelected ? this.colors.success : this.colors.primaryBlue; + + // 绘制轮廓 + if (obj.contour) { + this.drawContour(obj.contour, color, isSelected ? 
3 : 2); + } + + // 选中物体的标记 + if (isSelected && obj.center) { + this.drawTargetMarker(obj.center.x, obj.center.y); + } + }); + + // 倒计时 + if (data.countdown !== null) { + this.drawCountdown(data.countdown); + } + } + + // 闪烁动画 + if (data.mode === 'FLASH' && data.flash_progress !== null) { + this.renderFlashAnimation(data.flash_progress); + } + + // 追踪模式 + if (data.mode === 'TRACK') { + // 追踪多边形 + if (data.tracking_polygon) { + this.drawPolygon(data.tracking_polygon, this.colors.success, 2); + } + + // 中心点 + if (data.tracking_center) { + this.drawCircle(data.tracking_center.x, data.tracking_center.y, 6, this.colors.success); + } + + // 对齐箭头 + if (data.hand_center && data.tracking_center) { + this.drawMeasureArrow( + data.hand_center, + data.tracking_center + ); + } + + // 面积比和引导 + if (data.area_ratio !== null) { + this.drawAreaRatio(data.area_ratio, data.guidance); + } + } + + // 进度条 + this.drawTechProgressBars(data.align_score, data.range_score); + + // FPS + this.drawFPS(data.fps); + + // 状态消息 + if (data.status_message) { + this.drawStatusMessage(data.status_message); + } + } + + drawHandSkeleton(landmarks) { + const ctx = this.ctx; + const color = this.colors.secondaryPurple; + + // MediaPipe手部连接 + const connections = [ + [0, 1], [1, 2], [2, 3], [3, 4], // 拇指 + [0, 5], [5, 6], [6, 7], [7, 8], // 食指 + [0, 9], [9, 10], [10, 11], [11, 12], // 中指 + [0, 13], [13, 14], [14, 15], [15, 16], // 无名指 + [0, 17], [17, 18], [18, 19], [19, 20], // 小指 + [5, 9], [9, 13], [13, 17] // 掌心 + ]; + + // 绘制连接线 + ctx.strokeStyle = color; + ctx.lineWidth = 2; + connections.forEach(([i, j]) => { + if (landmarks[i] && landmarks[j]) { + ctx.beginPath(); + ctx.moveTo(landmarks[i].x, landmarks[i].y); + ctx.lineTo(landmarks[j].x, landmarks[j].y); + ctx.stroke(); + } + }); + + // 绘制关键点 + landmarks.forEach(point => { + this.drawCircle(point.x, point.y, 3, color, true); + }); + } + + drawTextWithBg(text, x, y, fontSize = 18, color = this.colors.white, bgColor = this.colors.glassBg) { + const ctx = 
this.ctx; + const padding = 10; + + ctx.font = `${fontSize}px Arial, "Microsoft YaHei"`; + const metrics = ctx.measureText(text); + const textWidth = metrics.width; + const textHeight = fontSize; + + // 绘制背景 + ctx.fillStyle = bgColor; + ctx.fillRect(x - padding, y - textHeight - padding, + textWidth + 2 * padding, textHeight + 2 * padding); + + // 绘制边框 + ctx.strokeStyle = this.colors.primaryBlue; + ctx.lineWidth = 1; + ctx.strokeRect(x - padding, y - textHeight - padding, + textWidth + 2 * padding, textHeight + 2 * padding); + + // 绘制文字 + ctx.fillStyle = color; + ctx.fillText(text, x, y); + } + + drawCountdown(seconds) { + const text = `检测到物体 Object detected, ${seconds.toFixed(1)}s`; + const x = 10; + const y = 100; + this.drawTextWithBg(text, x, y, 22, this.colors.warning); + } + + renderFlashAnimation(progress) { + const ctx = this.ctx; + const W = this.canvas.width; + const H = this.canvas.height; + + // 计算闪烁透明度 + const cycleProgress = progress * 2; + const alpha = 0.3 + 0.3 * Math.sin(cycleProgress * Math.PI); + + // 全屏闪烁效果 + ctx.fillStyle = this.colors.accentCyan + Math.floor(alpha * 255).toString(16).padStart(2, '0'); + ctx.fillRect(0, 0, W, H); + + // 锁定文字 + this.drawTextWithBg('正在锁定目标... 
Locking target...', + W/2 - 150, H/2, 24, this.colors.accentCyan); + } + + drawTechProgressBars(alignScore, rangeScore) { + const W = this.canvas.width; + const H = this.canvas.height; + const barW = W * 0.3; + const barH = 8; + const gap = 20; + const x0 = 20; + const y0 = H - 2 * barH - gap - 60; + + // 对齐进度条 + this.drawProgressBar(x0, y0, barW, barH, alignScore, + '对齐 Align', this.colors.primaryBlue); + + // 距离进度条 + this.drawProgressBar(x0, y0 + barH + gap, barW, barH, rangeScore, + '距离(≈1) Distance(≈1)', this.colors.accentCyan); + } + + drawProgressBar(x, y, width, height, value, label, color) { + const ctx = this.ctx; + + // 背景 + ctx.fillStyle = this.colors.darkBg; + ctx.fillRect(x, y, width, height); + + // 边框 + ctx.strokeStyle = color; + ctx.lineWidth = 1; + ctx.strokeRect(x, y, width, height); + + // 填充(渐变) + const fillWidth = width * Math.max(0, Math.min(1, value)); + if (fillWidth > 0) { + const gradient = ctx.createLinearGradient(x, y, x + fillWidth, y); + gradient.addColorStop(0, this.colors.secondaryPurple); + gradient.addColorStop(1, color); + ctx.fillStyle = gradient; + ctx.fillRect(x, y, fillWidth, height); + } + + // 标签 + this.drawTextWithBg(label, x, y - 10, 14, color); + } + + drawCircle(x, y, radius, color, fill = true) { + const ctx = this.ctx; + ctx.beginPath(); + ctx.arc(x, y, radius, 0, 2 * Math.PI); + if (fill) { + ctx.fillStyle = color; + ctx.fill(); + } else { + ctx.strokeStyle = color; + ctx.lineWidth = 2; + ctx.stroke(); + } + } + + drawBox(box, color, lineWidth = 2) { + const ctx = this.ctx; + ctx.strokeStyle = color; + ctx.lineWidth = lineWidth; + ctx.strokeRect(box.x, box.y, box.width, box.height); + } + + drawContour(points, color, lineWidth = 2) { + if (!points || points.length < 3) return; + + const ctx = this.ctx; + ctx.strokeStyle = color; + ctx.lineWidth = lineWidth; + ctx.beginPath(); + ctx.moveTo(points[0].x, points[0].y); + for (let i = 1; i < points.length; i++) { + ctx.lineTo(points[i].x, points[i].y); + } + 
ctx.closePath(); + ctx.stroke(); + } + + drawPolygon(points, color, lineWidth = 2) { + this.drawContour(points, color, lineWidth); + } + + drawTargetMarker(x, y) { + // 双圆圈标记 + this.drawCircle(x, y, 8, this.colors.success, false); + this.drawCircle(x, y, 12, this.colors.success, false); + this.drawTextWithBg('目标 Target', x + 15, y - 5, 16, this.colors.success); + } + + drawMeasureArrow(p1, p2) { + const ctx = this.ctx; + const dx = p2.x - p1.x; + const dy = p2.y - p1.y; + const distance = Math.sqrt(dx * dx + dy * dy); + + // 绘制线 + ctx.strokeStyle = this.colors.white; + ctx.lineWidth = 2; + ctx.setLineDash([5, 5]); + ctx.beginPath(); + ctx.moveTo(p1.x, p1.y); + ctx.lineTo(p2.x, p2.y); + ctx.stroke(); + ctx.setLineDash([]); + + // 绘制箭头 + const angle = Math.atan2(dy, dx); + const arrowLength = 15; + const arrowAngle = Math.PI / 6; + + ctx.beginPath(); + ctx.moveTo(p2.x, p2.y); + ctx.lineTo( + p2.x - arrowLength * Math.cos(angle - arrowAngle), + p2.y - arrowLength * Math.sin(angle - arrowAngle) + ); + ctx.moveTo(p2.x, p2.y); + ctx.lineTo( + p2.x - arrowLength * Math.cos(angle + arrowAngle), + p2.y - arrowLength * Math.sin(angle + arrowAngle) + ); + ctx.stroke(); + + // 显示距离 + const midX = (p1.x + p2.x) / 2; + const midY = (p1.y + p2.y) / 2; + ctx.fillStyle = this.colors.white; + ctx.font = '14px Arial'; + ctx.fillText(`${distance.toFixed(0)}px`, midX + 10, midY - 10); + } + + drawAreaRatio(ratio, guidance) { + const y = 120; + const text = `面积比 Area Ratio: ${ratio.toFixed(2)}`; + this.drawTextWithBg(text, 10, y, 18, this.colors.lightGray); + + if (guidance) { + const guidanceText = { + 'forward': '向前靠近 Move Forward', + 'backward': '后退 Move Back', + 'maintain': '保持 Maintain' + }; + const guidanceColor = guidance === 'maintain' ? 
this.colors.success : this.colors.warning; + this.drawTextWithBg(guidanceText[guidance] || guidance, + 10, y + 40, 20, guidanceColor); + } + } + + drawFPS(fps) { + const W = this.canvas.width; + const text = `FPS: ${fps.toFixed(1)}`; + this.drawTextWithBg(text, W - 120, 30, 16, this.colors.accentCyan); + } + + drawStatusMessage(message) { + const W = this.canvas.width; + const H = this.canvas.height; + + // 根据消息类型选择颜色 + let color = this.colors.white; + if (message.includes('追踪丢失') || message.includes('lost')) { + color = this.colors.error; + } else if (message.includes('刷新') || message.includes('refreshed')) { + color = this.colors.success; + } + + this.drawTextWithBg(message, W/2 - 200, H - 50, 20, color); + } +} + +// 初始化渲染器 +document.addEventListener('DOMContentLoaded', () => { + window.visionRenderer = new VisionRenderer('canvas'); +}); \ No newline at end of file diff --git a/static/visualizer.js b/static/visualizer.js new file mode 100644 index 0000000..ac96992 --- /dev/null +++ b/static/visualizer.js @@ -0,0 +1,546 @@ +// static/visualizer.js +class TechVisualizer { + constructor(canvasId) { + this.canvas = document.getElementById(canvasId); + this.ctx = this.canvas.getContext('2d'); + this.ws = null; + this.data = {}; + + // 科技感配色方案 + this.colors = { + primary: '#00D9FF', // 青蓝色 + secondary: '#FF00FF', // 品红/紫色 + accent: '#00FF88', // 青绿色 + warning: '#FFAA00', // 橙色 + background: '#000000', // 黑色 + surface: '#0A0A0A', // 深灰 + text: '#FFFFFF', // 白色 + textMuted: '#888888', // 灰色 + grid: '#1A1A1A', // 网格色 + glow: '#00D9FF55' // 发光效果 + }; + + // 字体设置 + this.fonts = { + title: 'bold 24px "Orbitron", "Microsoft YaHei", sans-serif', + subtitle: 'bold 18px "Rajdhani", "Microsoft YaHei", sans-serif', + body: '16px "Roboto", "Microsoft YaHei", sans-serif', + small: '12px "Roboto", "Microsoft YaHei", sans-serif', + tiny: '10px "Roboto", sans-serif' + }; + + this.setupCanvas(); + this.connectWebSocket(); + } + + setupCanvas() { + // 设置画布大小 + const resizeCanvas = () => 
{ + const rect = this.canvas.getBoundingClientRect(); + this.canvas.width = rect.width * window.devicePixelRatio; + this.canvas.height = rect.height * window.devicePixelRatio; + this.ctx.scale(window.devicePixelRatio, window.devicePixelRatio); + }; + + window.addEventListener('resize', resizeCanvas); + resizeCanvas(); + } + + connectWebSocket() { + const wsUrl = `ws://${window.location.host}/ws/visualizer`; + this.ws = new WebSocket(wsUrl); + + this.ws.onmessage = (event) => { + try { + this.data = JSON.parse(event.data); + this.render(); + } catch (e) { + console.error('Failed to parse visualization data:', e); + } + }; + + this.ws.onclose = () => { + setTimeout(() => this.connectWebSocket(), 1000); + }; + } + + render() { + const ctx = this.ctx; + const width = this.canvas.width / window.devicePixelRatio; + const height = this.canvas.height / window.devicePixelRatio; + + // 清空画布 + ctx.fillStyle = this.colors.background; + ctx.fillRect(0, 0, width, height); + + // 绘制网格背景 + this.drawGrid(width, height); + + // 绘制HUD边框 + this.drawHUD(width, height); + + // 根据模式绘制内容 + if (this.data.mode === 'SEGMENT') { + this.drawSegmentMode(width, height); + } else if (this.data.mode === 'FLASH') { + this.drawFlashMode(width, height); + } else if (this.data.mode === 'TRACK') { + this.drawTrackMode(width, height); + } + + // 绘制手部骨骼 + if (this.data.hand) { + this.drawHand(this.data.hand, width, height); + } + + // 绘制FPS和状态信息 + this.drawStats(width, height); + } + + drawGrid(width, height) { + const ctx = this.ctx; + ctx.strokeStyle = this.colors.grid; + ctx.lineWidth = 0.5; + + const gridSize = 50; + for (let x = 0; x < width; x += gridSize) { + ctx.beginPath(); + ctx.moveTo(x, 0); + ctx.lineTo(x, height); + ctx.stroke(); + } + for (let y = 0; y < height; y += gridSize) { + ctx.beginPath(); + ctx.moveTo(0, y); + ctx.lineTo(width, y); + ctx.stroke(); + } + } + + drawHUD(width, height) { + const ctx = this.ctx; + const margin = 20; + + // 四角装饰 + ctx.strokeStyle = this.colors.primary; + 
ctx.lineWidth = 2; + + const cornerSize = 40; + // 左上角 + ctx.beginPath(); + ctx.moveTo(margin, margin + cornerSize); + ctx.lineTo(margin, margin); + ctx.lineTo(margin + cornerSize, margin); + ctx.stroke(); + + // 右上角 + ctx.beginPath(); + ctx.moveTo(width - margin - cornerSize, margin); + ctx.lineTo(width - margin, margin); + ctx.lineTo(width - margin, margin + cornerSize); + ctx.stroke(); + + // 左下角 + ctx.beginPath(); + ctx.moveTo(margin, height - margin - cornerSize); + ctx.lineTo(margin, height - margin); + ctx.lineTo(margin + cornerSize, height - margin); + ctx.stroke(); + + // 右下角 + ctx.beginPath(); + ctx.moveTo(width - margin - cornerSize, height - margin); + ctx.lineTo(width - margin, height - margin); + ctx.lineTo(width - margin, height - margin - cornerSize); + ctx.stroke(); + } + + drawSegmentMode(width, height) { + const ctx = this.ctx; + + // 绘制检测到的物体 + if (this.data.segments) { + this.data.segments.forEach((seg, index) => { + if (seg.contour && seg.contour.length > 0) { + // 绘制轮廓 + ctx.beginPath(); + ctx.strokeStyle = seg.is_target ? this.colors.primary : this.colors.secondary; + ctx.lineWidth = seg.is_target ? 
3 : 2; + + // 添加发光效果 + if (seg.is_target) { + ctx.shadowColor = this.colors.primary; + ctx.shadowBlur = 10; + } + + const points = this.scalePoints(seg.contour, width, height); + ctx.moveTo(points[0][0], points[0][1]); + points.forEach(p => ctx.lineTo(p[0], p[1])); + ctx.closePath(); + ctx.stroke(); + + ctx.shadowBlur = 0; + + // 如果是目标,绘制中心标记 + if (seg.is_target) { + const center = this.getContourCenter(points); + this.drawTargetMarker(center[0], center[1]); + + // 绘制面积信息 + ctx.font = this.fonts.small; + ctx.fillStyle = this.colors.primary; + ctx.fillText(`Area: ${seg.area}`, center[0] + 20, center[1] - 20); + } + } + }); + } + + // 绘制状态文字 + if (this.data.auto_lock && this.data.auto_lock.active) { + this.drawStatusText( + `目标锁定中 Locking Target`, + `${this.data.auto_lock.remaining.toFixed(1)}s`, + width / 2, + 100, + this.colors.warning + ); + } else { + this.drawStatusText( + '扫描中 Scanning', + '等待检测目标 Waiting for target', + width / 2, + 100, + this.colors.primary + ); + } + } + + drawFlashMode(width, height) { + const ctx = this.ctx; + + if (this.data.flash && this.data.flash.mask_contour) { + const progress = this.data.flash.progress || 0; + const alpha = 0.3 + 0.4 * (0.5 * (1 + Math.sin(progress * 2 * Math.PI - Math.PI/2))); + + // 绘制闪烁轮廓 + ctx.beginPath(); + ctx.strokeStyle = this.colors.accent; + ctx.lineWidth = 4; + ctx.globalAlpha = alpha; + + const points = this.scalePoints(this.data.flash.mask_contour, width, height); + ctx.moveTo(points[0][0], points[0][1]); + points.forEach(p => ctx.lineTo(p[0], p[1])); + ctx.closePath(); + + // 填充 + ctx.fillStyle = this.colors.accent + '33'; + ctx.fill(); + ctx.stroke(); + + ctx.globalAlpha = 1; + + // 绘制锁定动画 + const center = this.getContourCenter(points); + this.drawLockAnimation(center[0], center[1], progress); + } + + this.drawStatusText( + '正在锁定目标 Locking Target', + '准备追踪 Preparing to track', + width / 2, + 100, + this.colors.accent + ); + } + + drawTrackMode(width, height) { + const ctx = this.ctx; + const tracking 
= this.data.tracking; + + if (!tracking) return; + + // 绘制追踪多边形 + if (tracking.polygon && tracking.polygon.length > 0) { + ctx.beginPath(); + ctx.strokeStyle = this.colors.accent; + ctx.lineWidth = 3; + ctx.shadowColor = this.colors.accent; + ctx.shadowBlur = 15; + + const points = this.scalePoints(tracking.polygon, width, height); + ctx.moveTo(points[0][0], points[0][1]); + points.forEach(p => ctx.lineTo(p[0], p[1])); + ctx.closePath(); + ctx.stroke(); + + ctx.shadowBlur = 0; + + // 绘制中心点 + if (tracking.center) { + const center = this.scalePoint(tracking.center, width, height); + ctx.fillStyle = this.colors.accent; + ctx.beginPath(); + ctx.arc(center[0], center[1], 6, 0, Math.PI * 2); + ctx.fill(); + } + } + + // 绘制进度条 + this.drawProgressBars(tracking, width, height); + + // 绘制引导文字 + if (tracking.guidance) { + const guidanceText = { + '向前靠近': 'Move Closer', + '后退': 'Move Back', + '保持': 'Hold Position' + }; + + this.drawStatusText( + tracking.guidance, + guidanceText[tracking.guidance] || '', + width / 2, + height - 100, + this.colors.warning + ); + } + + // 如果触发了重新锁定 + if (tracking.relock_triggered) { + this.drawStatusText( + '已根据周边检测刷新追踪', + 'Tracking refreshed by peripheral detection', + width / 2, + 170, + this.colors.accent + ); + } + } + + drawHand(handData, width, height) { + const ctx = this.ctx; + + if (!handData.landmarks) return; + + // 缩放坐标 + const landmarks = handData.landmarks.map(p => + this.scalePoint([p[0], p[1]], width, height) + ); + + // 绘制手部连接线 + ctx.strokeStyle = this.colors.secondary; + ctx.lineWidth = 2; + ctx.globalAlpha = 0.8; + + // MediaPipe手部连接定义 + const connections = [ + [0, 1], [1, 2], [2, 3], [3, 4], // 拇指 + [0, 5], [5, 6], [6, 7], [7, 8], // 食指 + [5, 9], [9, 10], [10, 11], [11, 12], // 中指 + [9, 13], [13, 14], [14, 15], [15, 16], // 无名指 + [13, 17], [17, 18], [18, 19], [19, 20], // 小指 + [0, 17] // 手腕连接 + ]; + + connections.forEach(([start, end]) => { + ctx.beginPath(); + ctx.moveTo(landmarks[start][0], landmarks[start][1]); + 
ctx.lineTo(landmarks[end][0], landmarks[end][1]); + ctx.stroke(); + }); + + // 绘制关键点 + landmarks.forEach((point, i) => { + ctx.fillStyle = this.colors.secondary; + ctx.beginPath(); + ctx.arc(point[0], point[1], 3, 0, Math.PI * 2); + ctx.fill(); + }); + + ctx.globalAlpha = 1; + + // 绘制握持评分 + if (handData.grasp_score !== undefined) { + ctx.font = this.fonts.body; + ctx.fillStyle = this.colors.text; + ctx.fillText( + `握持评分 Grasp Score: ${handData.grasp_score.toFixed(2)}`, + 20, + 80 + ); + } + } + + drawProgressBars(tracking, width, height) { + const ctx = this.ctx; + const barWidth = width * 0.25; + const barHeight = 12; + const x = 20; + const y = height - 80; + + // 对齐进度条 + this.drawProgressBar( + x, y - 30, + barWidth, barHeight, + tracking.align_score || 0, + '对齐 Alignment', + this.colors.primary + ); + + // 距离进度条 + this.drawProgressBar( + x, y, + barWidth, barHeight, + tracking.range_score || 0, + `距离 Distance (≈1)`, + this.colors.secondary + ); + + // 显示比率 + if (tracking.ratio !== null && tracking.ratio !== undefined) { + ctx.font = this.fonts.small; + ctx.fillStyle = this.colors.text; + ctx.fillText( + `面积比 Ratio: ${tracking.ratio.toFixed(2)}`, + x + barWidth + 20, + y + 8 + ); + } + } + + drawProgressBar(x, y, width, height, value, label, color) { + const ctx = this.ctx; + + // 背景 + ctx.fillStyle = this.colors.surface; + ctx.fillRect(x, y, width, height); + + // 边框 + ctx.strokeStyle = color + '44'; + ctx.lineWidth = 1; + ctx.strokeRect(x, y, width, height); + + // 填充 + const fillWidth = width * Math.max(0, Math.min(1, value)); + const gradient = ctx.createLinearGradient(x, y, x + fillWidth, y); + gradient.addColorStop(0, color + 'AA'); + gradient.addColorStop(1, color); + ctx.fillStyle = gradient; + ctx.fillRect(x, y, fillWidth, height); + + // 标签 + ctx.font = this.fonts.small; + ctx.fillStyle = this.colors.textMuted; + ctx.fillText(label, x, y - 5); + } + + drawStats(width, height) { + const ctx = this.ctx; + + // FPS显示 + ctx.font = this.fonts.body; + 
ctx.fillStyle = this.colors.accent; + ctx.fillText(`FPS: ${(this.data.fps || 0).toFixed(1)}`, 20, 40); + + // 模式显示 + const modeText = { + 'SEGMENT': '分割模式 Segmentation', + 'FLASH': '锁定模式 Locking', + 'TRACK': '追踪模式 Tracking' + }; + + ctx.fillStyle = this.colors.text; + ctx.fillText(modeText[this.data.mode] || this.data.mode, width - 200, 40); + } + + // 辅助函数 + scalePoint(point, width, height) { + if (!this.data.frame_size) return [0, 0]; + return [ + point[0] * width / this.data.frame_size.width, + point[1] * height / this.data.frame_size.height + ]; + } + + scalePoints(points, width, height) { + return points.map(p => this.scalePoint(p, width, height)); + } + + getContourCenter(points) { + const sum = points.reduce((acc, p) => [acc[0] + p[0], acc[1] + p[1]], [0, 0]); + return [sum[0] / points.length, sum[1] / points.length]; + } + + drawTargetMarker(x, y) { + const ctx = this.ctx; + ctx.strokeStyle = this.colors.primary; + ctx.lineWidth = 2; + + // 十字准星 + const size = 20; + ctx.beginPath(); + ctx.moveTo(x - size, y); + ctx.lineTo(x - size/2, y); + ctx.moveTo(x + size/2, y); + ctx.lineTo(x + size, y); + ctx.moveTo(x, y - size); + ctx.lineTo(x, y - size/2); + ctx.moveTo(x, y + size/2); + ctx.lineTo(x, y + size); + ctx.stroke(); + + // 圆圈 + ctx.beginPath(); + ctx.arc(x, y, 10, 0, Math.PI * 2); + ctx.stroke(); + } + + drawLockAnimation(x, y, progress) { + const ctx = this.ctx; + const radius = 30 + 10 * Math.sin(progress * Math.PI * 2); + + ctx.strokeStyle = this.colors.accent; + ctx.lineWidth = 3; + ctx.globalAlpha = 0.8; + + // 旋转的锁定环 + ctx.save(); + ctx.translate(x, y); + ctx.rotate(progress * Math.PI * 2); + + // 绘制4个弧形 + for (let i = 0; i < 4; i++) { + ctx.beginPath(); + ctx.arc(0, 0, radius, i * Math.PI/2 + 0.1, i * Math.PI/2 + Math.PI/2 - 0.1); + ctx.stroke(); + } + + ctx.restore(); + ctx.globalAlpha = 1; + } + + drawStatusText(mainText, subText, x, y, color) { + const ctx = this.ctx; + + // 主文字(中文) + ctx.font = this.fonts.subtitle; + ctx.fillStyle = color; + 
ctx.textAlign = 'center'; + ctx.fillText(mainText, x, y); + + // 副文字(英文) + if (subText) { + ctx.font = this.fonts.small; + ctx.fillStyle = this.colors.textMuted; + ctx.fillText(subText, x, y + 20); + } + + ctx.textAlign = 'left'; + } +} + +// 初始化 +window.addEventListener('DOMContentLoaded', () => { + window.visualizer = new TechVisualizer('tech-canvas'); +}); \ No newline at end of file diff --git a/sync_recorder.py b/sync_recorder.py new file mode 100644 index 0000000..e455188 --- /dev/null +++ b/sync_recorder.py @@ -0,0 +1,322 @@ +# sync_recorder.py +# 同步录制ESP32视频流和音频指令 +# 自动确保视频和音频时间轴对齐 + +import os +import cv2 +import wave +import numpy as np +import threading +import time +from datetime import datetime +from collections import deque +import struct + +class SyncRecorder: + """同步录制器 - 视频+音频时间对齐""" + + def __init__(self, output_dir="recordings", fps=15.0): + """ + 初始化录制器 + :param output_dir: 输出目录 + :param fps: 视频帧率(默认15fps) + """ + self.output_dir = output_dir + self.fps = fps + self.frame_duration = 1.0 / fps # 每帧时长(秒) + + # 创建输出目录 + os.makedirs(output_dir, exist_ok=True) + + # 录制状态 + self.is_recording = False + self.start_time = None + + # 视频写入器 + self.video_writer = None + self.video_path = None + self.last_frame = None + self.frame_count = 0 + + # 音频写入器 + self.audio_writer = None + self.audio_path = None + self.audio_buffer = bytearray() + self.last_audio_time = 0.0 + + # 音频参数(ESP32标准:16kHz, 16bit, Mono) + self.sample_rate = 16000 + self.sample_width = 2 # 16bit = 2 bytes + self.channels = 1 + + # 线程安全 + self.lock = threading.Lock() + + # 性能监控 + self.frames_written = 0 + self.audio_bytes_written = 0 + self.last_log_time = time.time() + + print(f"[RECORDER] 录制器初始化完成 - FPS={fps}, 输出目录={output_dir}") + + def start_recording(self): + """开始新的录制会话""" + if self.is_recording: + print("[RECORDER] 警告:已经在录制中") + return False + + # 生成文件名(时间戳) + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S") + self.video_path = os.path.join(self.output_dir, f"video_{timestamp}.avi") 
+ self.audio_path = os.path.join(self.output_dir, f"audio_{timestamp}.wav") + + # 重置状态 + self.start_time = time.time() + self.last_audio_time = 0.0 + self.frame_count = 0 + self.frames_written = 0 + self.audio_bytes_written = 0 + self.audio_buffer.clear() + self.last_frame = None + + # 初始化音频文件 + try: + self.audio_writer = wave.open(self.audio_path, 'wb') + self.audio_writer.setnchannels(self.channels) + self.audio_writer.setsampwidth(self.sample_width) + self.audio_writer.setframerate(self.sample_rate) + except Exception as e: + print(f"[RECORDER] 音频文件初始化失败: {e}") + return False + + self.is_recording = True + print(f"[RECORDER] 开始录制") + print(f" 视频: {self.video_path}") + print(f" 音频: {self.audio_path}") + return True + + def add_frame(self, jpeg_data: bytes): + """ + 添加一帧视频(原始JPEG数据) + :param jpeg_data: JPEG格式的图像数据 + """ + if not self.is_recording: + return + + try: + with self.lock: + # 解码JPEG + arr = np.frombuffer(jpeg_data, dtype=np.uint8) + frame = cv2.imdecode(arr, cv2.IMREAD_COLOR) + + if frame is None: + print(f"[RECORDER] 警告:帧解码失败") + return + + # 首帧:初始化视频写入器 + if self.video_writer is None: + height, width = frame.shape[:2] + # 使用XVID编码器(Windows兼容性好) + fourcc = cv2.VideoWriter_fourcc(*'XVID') + self.video_writer = cv2.VideoWriter( + self.video_path, + fourcc, + self.fps, + (width, height) + ) + + if not self.video_writer.isOpened(): + print(f"[RECORDER] 错误:视频写入器初始化失败") + self.is_recording = False + return + + print(f"[RECORDER] 视频写入器初始化:{width}x{height} @ {self.fps}fps") + + # 写入帧 + self.video_writer.write(frame) + self.frame_count += 1 + self.frames_written += 1 + self.last_frame = frame + + # 计算当前视频时长(秒) + current_video_time = self.frame_count * self.frame_duration + + # 音频同步:填充静音到视频时长 + self._sync_audio_to_video(current_video_time) + + # 性能日志(每10秒) + now = time.time() + if now - self.last_log_time > 10.0: + elapsed = now - self.start_time + avg_fps = self.frames_written / elapsed if elapsed > 0 else 0 + audio_duration = self.audio_bytes_written / 
(self.sample_rate * self.sample_width) + print(f"[RECORDER] 录制中 - 帧数={self.frames_written}, " + f"实际FPS={avg_fps:.1f}, " + f"视频时长={current_video_time:.1f}s, " + f"音频时长={audio_duration:.1f}s") + self.last_log_time = now + + except Exception as e: + print(f"[RECORDER] 添加帧失败: {e}") + import traceback + traceback.print_exc() + + def add_audio(self, pcm_data: bytes, text: str = ""): + """ + 添加音频数据(PCM 16bit) + :param pcm_data: PCM格式音频数据 + :param text: 语音文本(用于日志) + """ + if not self.is_recording: + return + + try: + with self.lock: + # 当前视频时长 + current_video_time = self.frame_count * self.frame_duration + + # 在添加音频前,先填充静音到视频时长 + self._sync_audio_to_video(current_video_time) + + # 写入实际音频 + self.audio_writer.writeframes(pcm_data) + audio_duration = len(pcm_data) / (self.sample_rate * self.sample_width) + self.last_audio_time = current_video_time + audio_duration + self.audio_bytes_written += len(pcm_data) + + if text: + print(f"[RECORDER] 录制语音: {text[:30]}... (时间={current_video_time:.2f}s, 时长={audio_duration:.2f}s)") + + except Exception as e: + print(f"[RECORDER] 添加音频失败: {e}") + + def _sync_audio_to_video(self, video_time: float): + """ + 同步音频到视频时长(填充静音) + :param video_time: 当前视频时长(秒) + """ + # 计算需要填充的静音时长 + silence_duration = video_time - self.last_audio_time + + if silence_duration > 0.01: # 大于10ms才填充 + # 生成静音数据 + silence_samples = int(silence_duration * self.sample_rate) + silence_bytes = silence_samples * self.sample_width + silence_data = b'\x00' * silence_bytes + + # 写入静音 + self.audio_writer.writeframes(silence_data) + self.audio_bytes_written += len(silence_data) + self.last_audio_time = video_time + + def stop_recording(self): + """停止录制并保存文件""" + if not self.is_recording: + return + + print("[RECORDER] 正在保存录制文件...") + self.is_recording = False + + with self.lock: + # 最后一次音频同步 + try: + if self.frame_count > 0: + final_video_time = self.frame_count * self.frame_duration + self._sync_audio_to_video(final_video_time) + except Exception as e: + print(f"[RECORDER] 
最终音频同步失败: {e}") + + # 关闭视频写入器(关键步骤) + if self.video_writer is not None: + try: + print("[RECORDER] 正在关闭视频写入器...") + self.video_writer.release() + print("[RECORDER] 视频写入器已关闭") + except Exception as e: + print(f"[RECORDER] 关闭视频写入器失败: {e}") + finally: + self.video_writer = None + + # 关闭音频写入器 + if self.audio_writer is not None: + try: + print("[RECORDER] 正在关闭音频写入器...") + self.audio_writer.close() + print("[RECORDER] 音频写入器已关闭") + except Exception as e: + print(f"[RECORDER] 关闭音频写入器失败: {e}") + finally: + self.audio_writer = None + + # 统计信息 + try: + elapsed = time.time() - self.start_time if self.start_time else 0 + video_duration = self.frame_count * self.frame_duration + audio_duration = self.audio_bytes_written / (self.sample_rate * self.sample_width) + + print(f"\n{'='*60}") + print(f"[RECORDER] 录制完成") + print(f"{'='*60}") + print(f" 总耗时: {elapsed:.1f}秒") + print(f"\n 视频: {self.video_path}") + print(f" - 帧数: {self.frames_written}") + print(f" - 时长: {video_duration:.2f}秒") + if elapsed > 0: + print(f" - 平均FPS: {self.frames_written/elapsed:.1f}") + print(f"\n 音频: {self.audio_path}") + print(f" - 数据量: {self.audio_bytes_written/1024:.1f} KB") + print(f" - 时长: {audio_duration:.2f}秒") + print(f"\n 时间差: {abs(video_duration - audio_duration):.3f}秒") + + # 验证文件 + if os.path.exists(self.video_path): + video_size = os.path.getsize(self.video_path) / 1024 / 1024 + print(f" 视频文件大小: {video_size:.2f} MB ✓") + else: + print(f" ⚠ 警告:视频文件未生成") + + if os.path.exists(self.audio_path): + audio_size = os.path.getsize(self.audio_path) / 1024 + print(f" 音频文件大小: {audio_size:.2f} KB ✓") + else: + print(f" ⚠ 警告:音频文件未生成") + + print(f"{'='*60}\n") + except Exception as e: + print(f"[RECORDER] 显示统计信息失败: {e}") + + +# 全局录制器实例 +_global_recorder = None +_recorder_lock = threading.Lock() + +def get_recorder(): + """获取全局录制器实例""" + global _global_recorder + with _recorder_lock: + if _global_recorder is None: + _global_recorder = SyncRecorder() + return _global_recorder + +def start_recording(): + 
"""启动录制""" + recorder = get_recorder() + return recorder.start_recording() + +def stop_recording(): + """停止录制""" + recorder = get_recorder() + recorder.stop_recording() + +def record_frame(jpeg_data: bytes): + """记录一帧(供外部调用)""" + recorder = get_recorder() + if recorder.is_recording: + recorder.add_frame(jpeg_data) + +def record_audio(pcm_data: bytes, text: str = ""): + """记录音频(供外部调用)""" + recorder = get_recorder() + if recorder.is_recording: + recorder.add_audio(pcm_data, text) + diff --git a/templates/index.html b/templates/index.html new file mode 100644 index 0000000..ea4c564 --- /dev/null +++ b/templates/index.html @@ -0,0 +1,715 @@ + + + + + + + NaviGlass 导盲系统可视化 + + + + + + + + + +
+ +
+
+ + +
+ +
📊 IMU 姿态可视化 +
+
+
+
+
+ +
+ + + +
+ + + +
+ + +
+ + + + + + + + + + + + + + + + + + +
+ + + + + + + + + \ No newline at end of file diff --git a/trafficlight_detection.py b/trafficlight_detection.py new file mode 100644 index 0000000..f2da50e --- /dev/null +++ b/trafficlight_detection.py @@ -0,0 +1,627 @@ +# -*- coding: utf-8 -*- +""" +红绿灯检测模块 - 独立工作流版本 +基于YOLO模型实时检测红绿灯状态,并通过语音反馈 +可以通过语音命令"检测红绿灯"、"停止检测"来控制 +""" + +import os +import time +import threading +import cv2 +import numpy as np +from ultralytics import YOLO +import bridge_io +from audio_player import play_voice_text # 使用统一的语音播放接口 +import logging + +# Day 20: TensorRT 模型加载工具 +from model_utils import get_best_model_path + +logger = logging.getLogger(__name__) + +# ========= 配置参数 ========= +# Day 20: 优先使用 TensorRT 引擎 +YOLO_MODEL_PATH = get_best_model_path(os.path.join(os.path.dirname(__file__), "model", "trafficlight.pt")) + +# ========= 显示参数 ========= +CONF_THRESHOLD = 0.25 # 置信度阈值 +FONT_SIZE = 20 +STROKE_WIDTH = 3 + +# ========= 语音播报参数 ========= +TTS_INTERVAL_SEC = 2.0 # 语音播报间隔(避免频繁播报) +ENABLE_TTS = False # 【禁用】红绿灯检测模块不播报,由 workflow_crossstreet.py 统一处理 + +# ========= 线程控制 ========= +_detection_thread = None +_stop_event = None +_detection_running = False + +# ========= 单帧处理模式(新增)========= +_model = None # 全局模型实例 +_last_tts_ts = 0.0 +_last_detected_light = None +_detection_history = [] + +# ========= 前端配色(BGR) ========= +FRONTEND_COLORS = { + "text": (230, 237, 243), # 白色文字 + "red": (0, 0, 255), # 红色 + "yellow": (0, 255, 255), # 黄色 + "green": (0, 255, 0), # 绿色 + "muted": (159, 176, 195), # 灰色 +} + +# 红绿灯状态到颜色的映射 +LIGHT_COLORS = { + "stop": FRONTEND_COLORS["red"], + "countdown_go": FRONTEND_COLORS["yellow"], + "go": FRONTEND_COLORS["green"], +} + +# 【修正】红绿灯状态到中文的映射 +# 只包含真正的红绿灯类别,排除斑马线(crossing)和空白 +LIGHT_NAMES = { + "stop": "红灯", # 机动车红灯 + "go": "绿灯", # 机动车绿灯 + "countdown_go": "黄灯", # 绿灯倒计时(用黄灯提示) + "countdown_stop": "红灯", # 红灯倒计时 +} + +# 红绿灯状态到语音文件的映射 +LIGHT_VOICE_MAP = { + "stop": "红灯", # → voice/红灯.WAV + "go": "绿灯", # → voice/绿灯.WAV + "countdown_go": "黄灯", # → voice/黄灯.WAV(绿灯倒计时用黄灯提示) + 
"countdown_stop": "红灯", # → voice/红灯.WAV +} + +# 需要过滤的类别(不检测、不显示) +FILTERED_CLASSES = { + "crossing", # 斑马线(不需要) + "blank", # 空白 + "countdown_blank" # 倒计时空白 +} + +# UI文本管理 +_UI_LINE = 0 +_UI_H = 0 +_UI_TR_LINE = 0 +_UI_TOP_MARGIN = 12 +_UI_RIGHT_MARGIN = 12 +UNIFIED_FONT_PX = 12 + +def ui_reset_overlay(img_h: int): + """每帧调用一次,重置叠加行计数""" + global _UI_LINE, _UI_H, _UI_TR_LINE + _UI_LINE = 0 + _UI_TR_LINE = 0 + _UI_H = int(img_h) + +def _ui_next_y_top(font_size: int) -> int: + """返回右上角下一行的y坐标""" + global _UI_TR_LINE + line_gap = max(4, int(font_size * 0.25)) + y_top = _UI_TOP_MARGIN + (_UI_TR_LINE * (font_size + line_gap)) + _UI_TR_LINE += 1 + return y_top + +# ======== 中文文本绘制 ======== +_PIL_OK = False +_FONT_PATH = None + +def _init_font(): + global _PIL_OK, _FONT_PATH + try: + from PIL import ImageFont + _PIL_OK = True + except Exception: + _PIL_OK = False + return + candidates = [ + # Linux 中文字体路径 (Ubuntu/Debian) + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf", + ] + for p in candidates: + if os.path.exists(p): + _FONT_PATH = p + return + _PIL_OK = False + +_init_font() + +def draw_text_cn(img_bgr, text, xy, font_size=20, color=(255,255,255), ui_hint=True): + """统一的中文文本绘制""" + color = (255, 255, 255) + font_size = int(UNIFIED_FONT_PX) + + H, W = img_bgr.shape[:2] + y_top = _ui_next_y_top(font_size) if ui_hint else xy[1] + tw = th = 0 + font_obj = None + + if _PIL_OK and _FONT_PATH: + try: + from PIL import Image, ImageDraw, ImageFont + font_obj = ImageFont.truetype(_FONT_PATH, font_size) + bbox = ImageDraw.Draw(Image.new('RGB', (1,1))).textbbox((0,0), text, font=font_obj) + tw = max(1, bbox[2] - bbox[0]) + th = max(1, bbox[3] - bbox[1]) + except Exception: + pass + + if _PIL_OK and _FONT_PATH and font_obj is not None: + try: + from 
PIL import Image, ImageDraw + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(img_rgb) + draw = ImageDraw.Draw(pil_img) + if ui_hint: + x = max(8, W - _UI_RIGHT_MARGIN - tw) + y = y_top + else: + x = xy[0] + y = xy[1] + draw.text((x, y), text, fill=color, font=font_obj) + img_bgr[:] = cv2.cvtColor(np.asarray(pil_img), cv2.COLOR_RGB2BGR) + return + except Exception: + pass + + # OpenCV 回退 + if tw <= 0 or th <= 0: + scale = font_size/24.0 + (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, scale, 2) + if ui_hint: + x = max(8, W - _UI_RIGHT_MARGIN - int(tw)) + y_baseline = int(y_top + th) + else: + x = xy[0] + y_baseline = xy[1] + int(th) + cv2.putText(img_bgr, text, (x, y_baseline), cv2.FONT_HERSHEY_SIMPLEX, font_size/24.0, color, 2, cv2.LINE_AA) + +def main(headless: bool = True, stop_event=None): + """ + 红绿灯检测主函数 + + 参数: + headless: 是否无头模式(不显示OpenCV窗口) + stop_event: threading.Event,用于停止检测 + """ + + print("[TRAFFIC] 加载 YOLO 红绿灯检测模型...") + try: + model = YOLO(YOLO_MODEL_PATH) + print(f"[TRAFFIC] 模型加载成功: {YOLO_MODEL_PATH}") + except Exception as e: + print(f"[TRAFFIC] 模型加载失败: {e}") + return + + # 获取类别名称 + class_names = model.names if hasattr(model, 'names') else {} + print(f"[TRAFFIC] 模型类别: {class_names}") + + # 状态跟踪 + last_tts_ts = 0.0 + last_detected_light = None + fps_hist = [] + + # 【优化】状态稳定性判断 - 使用多数表决而非连续帧 + detection_history = [] # 保存最近N帧的检测结果 + HISTORY_SIZE = 5 # 保存最近5帧 + MAJORITY_THRESHOLD = 3 # 5帧中至少3帧相同才认为稳定 + + # 【新增】帧统计 + frame_count = 0 + frame_received_count = 0 + frame_none_count = 0 + last_frame_log_time = time.time() + + print("[TRAFFIC] 等待 ESP32 画面...") + + try: + while True: + # 检查停止事件 + if stop_event and stop_event.is_set(): + print("[TRAFFIC] 停止事件触发,退出检测") + break + + # 【优化】从bridge_io获取原始BGR帧 - 增加超时时间 + frame = bridge_io.wait_raw_bgr(timeout_sec=2.0) # 从0.5秒增加到2秒 + + frame_count += 1 + + if frame is None: + frame_none_count += 1 + # 每3秒打印一次帧统计 + current_time = time.time() + if current_time - 
last_frame_log_time > 3.0: + print(f"[TRAFFIC] 帧统计: 总={frame_count}, 收到={frame_received_count}, " + f"丢失={frame_none_count}, 丢失率={frame_none_count/frame_count*100:.1f}%") + last_frame_log_time = current_time + + if headless: + cv2.waitKey(1) + continue + + frame_received_count += 1 + + # 重置UI叠加 + H, W = frame.shape[:2] + ui_reset_overlay(H) + + vis = frame.copy() + t_now = time.time() + + # 【优化】YOLO推理 - 添加计时 + inference_start = time.time() + results = model(frame, conf=CONF_THRESHOLD, verbose=False) + inference_time = (time.time() - inference_start) * 1000 + + # 监控推理时间 + if inference_time > 100: + print(f"[TRAFFIC] WARNING: 推理耗时 {inference_time:.0f}ms") + + # 处理检测结果 + detected_light = None + max_conf = 0.0 + + if results and len(results) > 0: + r = results[0] + if r.boxes is not None and len(r.boxes) > 0: + # 【过滤】遍历所有检测框,找到置信度最高的红绿灯(排除斑马线) + for box in r.boxes: + cls_id = int(box.cls[0]) + conf = float(box.conf[0]) + class_name = class_names.get(cls_id, f"class_{cls_id}") + class_name_lower = class_name.lower() + + # 跳过不需要的类别 + if class_name_lower in FILTERED_CLASSES: + continue + + if conf > max_conf: + max_conf = conf + detected_light = class_name_lower + + # 【过滤】绘制检测框(只绘制红绿灯) + for box in r.boxes: + cls_id = int(box.cls[0]) + conf = float(box.conf[0]) + class_name = class_names.get(cls_id, f"class_{cls_id}") + class_name_lower = class_name.lower() + + # 跳过不需要的类别 + if class_name_lower in FILTERED_CLASSES: + continue + + # 获取边界框坐标 + x1, y1, x2, y2 = map(int, box.xyxy[0]) + + # 确定颜色 + color = LIGHT_COLORS.get(class_name_lower, FRONTEND_COLORS["text"]) + + # 绘制边界框 + cv2.rectangle(vis, (x1, y1), (x2, y2), color, STROKE_WIDTH) + + # 绘制中文标签(使用PIL) + label = f"{LIGHT_NAMES.get(class_name.lower(), class_name)}: {conf:.2f}" + + if _PIL_OK and _FONT_PATH: + try: + from PIL import Image, ImageDraw, ImageFont + # 使用较大的字体绘制标签 + font_obj = ImageFont.truetype(_FONT_PATH, 20) + # 转换为PIL图像 + img_rgb = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(img_rgb) + 
draw = ImageDraw.Draw(pil_img) + + # 计算文本尺寸 + bbox = draw.textbbox((0, 0), label, font=font_obj) + text_w = bbox[2] - bbox[0] + text_h = bbox[3] - bbox[1] + + # 标签位置 + label_y = max(y1 - text_h - 8, text_h) + + # 绘制背景矩形 + bg_x1 = x1 + bg_y1 = label_y - text_h - 4 + bg_x2 = x1 + text_w + 8 + bg_y2 = label_y + 4 + cv2.rectangle(vis, (bg_x1, bg_y1), (bg_x2, bg_y2), color, -1) + + # 重新转换(因为矩形是用OpenCV画的) + img_rgb = cv2.cvtColor(vis, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(img_rgb) + draw = ImageDraw.Draw(pil_img) + + # 【删除】绘制文字 + # draw.text((x1 + 4, label_y - text_h), label, fill=(0, 0, 0), font=font_obj) + + # 转换回OpenCV格式 + vis[:] = cv2.cvtColor(np.asarray(pil_img), cv2.COLOR_RGB2BGR) + except Exception as e: + # 【删除】PIL失败时的文本标签 + pass + else: + # 【删除】文本标签 + pass + + # 【优化】状态稳定性判断:使用多数表决而非连续帧 + detection_history.append(detected_light) + if len(detection_history) > HISTORY_SIZE: + detection_history.pop(0) + + # 判断状态是否稳定(多数表决) + stable_light = None + if len(detection_history) >= MAJORITY_THRESHOLD: + # 统计最近N帧中每个状态出现的次数 + valid_detections = [d for d in detection_history if d and d in LIGHT_NAMES] + if len(valid_detections) >= MAJORITY_THRESHOLD: + # 找出现次数最多的状态 + from collections import Counter + counter = Counter(valid_detections) + most_common = counter.most_common(1) + if most_common and most_common[0][1] >= MAJORITY_THRESHOLD: + stable_light = most_common[0][0] + # 打印调试信息 + if frame_received_count % 30 == 0: + print(f"[TRAFFIC] 检测历史: {detection_history[-5:]}, 稳定状态: {stable_light}") + + # 【禁用语音播报】只检测不播报,由调用者(workflow_crossstreet.py)统一处理语音 + # 只更新状态跟踪 + if stable_light: + # 状态改变时记录(但不播报) + if stable_light != last_detected_light: + last_detected_light = stable_light + print(f"[TRAFFIC] 检测到稳定状态改变: {LIGHT_NAMES[stable_light]}(不播报)") + last_tts_ts = t_now + # 超过间隔时间,更新时间戳(但不播报) + elif (t_now - last_tts_ts) > TTS_INTERVAL_SEC: + print(f"[TRAFFIC] 稳定状态持续: {LIGHT_NAMES[stable_light]}(不播报)") + last_tts_ts = t_now + + # 【删除】显示当前检测状态 + # if detected_light and 
detected_light in LIGHT_NAMES: + # status_text = f"检测: {LIGHT_NAMES[detected_light]} ({max_conf:.2f})" + # color = LIGHT_COLORS[detected_light] + # else: + # status_text = "检测: 无" + # color = FRONTEND_COLORS["muted"] + # draw_text_cn(vis, status_text, (10, 40), font_size=18, color=color) + + # 【删除】显示稳定状态 + # if stable_light: + # stable_text = f"稳定状态: {LIGHT_NAMES[stable_light]}" + # stable_color = LIGHT_COLORS[stable_light] + # else: + # stable_text = f"稳定状态: 等待中 ({len(detection_history)}/{HISTORY_SIZE})" + # stable_color = FRONTEND_COLORS["muted"] + # draw_text_cn(vis, stable_text, (10, 60), font_size=18, color=stable_color) + + # 【删除】FPS计算和显示 + # fps_hist.append(t_now) + # if len(fps_hist) > 30: + # fps_hist.pop(0) + # fps = 0.0 if len(fps_hist) < 2 else (len(fps_hist)-1)/(fps_hist[-1]-fps_hist[0]) + # draw_text_cn(vis, f"FPS: {fps:.1f}", (10, 20), font_size=16, color=FRONTEND_COLORS["text"]) + + # 发送可视化结果到前端 + bridge_io.send_vis_bgr(vis) + + # 非headless模式下显示窗口 + if not headless: + cv2.imshow("Traffic Light Detection", vis) + key = cv2.waitKey(1) & 0xFF + if key in (27, ord('q')): + break + else: + cv2.waitKey(1) + + except Exception as e: + print(f"[TRAFFIC] 检测过程出错: {e}") + finally: + if not headless: + cv2.destroyAllWindows() + print("[TRAFFIC] 红绿灯检测已停止") + + +def start_detection(): + """启动红绿灯检测(在后台线程中运行)""" + global _detection_thread, _stop_event, _detection_running + + if _detection_running: + print("[TRAFFIC] 红绿灯检测已在运行中") + return False + + _stop_event = threading.Event() + _detection_thread = threading.Thread( + target=main, + args=(True, _stop_event), # headless=True, stop_event + daemon=True, + name="TrafficLightDetection" + ) + _detection_thread.start() + _detection_running = True + print("[TRAFFIC] 红绿灯检测已启动(后台线程)") + return True + +def stop_detection(): + """停止红绿灯检测""" + global _detection_thread, _stop_event, _detection_running + + if not _detection_running: + print("[TRAFFIC] 红绿灯检测未运行") + return False + + print("[TRAFFIC] 正在停止红绿灯检测...") + if 
_stop_event: + _stop_event.set() + + if _detection_thread: + _detection_thread.join(timeout=2.0) + _detection_thread = None + + _stop_event = None + _detection_running = False + print("[TRAFFIC] 红绿灯检测已停止") + return True + +def is_detection_running(): + """检查红绿灯检测是否正在运行""" + return _detection_running + +def init_model(): + """初始化YOLO模型(单帧处理模式)""" + global _model + if _model is not None: + print("[TRAFFIC] 模型已加载") + return True + + try: + print("[TRAFFIC] 加载 YOLO 红绿灯检测模型...") + _model = YOLO(YOLO_MODEL_PATH) + print(f"[TRAFFIC] 模型加载成功: {YOLO_MODEL_PATH}") + class_names = _model.names if hasattr(_model, 'names') else {} + print(f"[TRAFFIC] 模型类别: {class_names}") + return True + except Exception as e: + print(f"[TRAFFIC] 模型加载失败: {e}") + _model = None + return False + +def process_single_frame(image: np.ndarray, ui_broadcast_callback=None) -> dict: + """ + 处理单帧图像(主线程模式,避免掉帧) + 参数: + image: 输入图像 + ui_broadcast_callback: 前端广播回调函数(用于显示红绿灯状态) + 返回:{'vis_image': 可视化图像, 'detected_light': 检测到的灯, 'stable_light': 稳定状态} + """ + global _model, _last_tts_ts, _last_detected_light, _detection_history + + if _model is None: + if not init_model(): + return {'vis_image': image, 'detected_light': None, 'stable_light': None} + + vis = image.copy() + t_now = time.time() + + # YOLO推理 + results = _model(image, conf=CONF_THRESHOLD, verbose=False) + + # 处理检测结果 + detected_light = None + max_conf = 0.0 + class_names = _model.names if hasattr(_model, 'names') else {} + + if results and len(results) > 0: + r = results[0] + if r.boxes is not None and len(r.boxes) > 0: + # 遍历所有检测框,找到置信度最高的红绿灯(过滤掉crossing等) + for box in r.boxes: + cls_id = int(box.cls[0]) + conf = float(box.conf[0]) + class_name = class_names.get(cls_id, f"class_{cls_id}") + class_name_lower = class_name.lower() + + # 【过滤】跳过不需要的类别(斑马线、空白等) + if class_name_lower in FILTERED_CLASSES: + continue + + if conf > max_conf: + max_conf = conf + detected_light = class_name_lower + + # 绘制检测框(只绘制红绿灯,不绘制斑马线) + for box in r.boxes: + cls_id = 
int(box.cls[0]) + conf = float(box.conf[0]) + class_name = class_names.get(cls_id, f"class_{cls_id}") + class_name_lower = class_name.lower() + + # 【过滤】跳过不需要的类别 + if class_name_lower in FILTERED_CLASSES: + continue + + # 获取边界框坐标 + x1, y1, x2, y2 = map(int, box.xyxy[0]) + + # 确定颜色 + color = LIGHT_COLORS.get(class_name_lower, FRONTEND_COLORS["text"]) + + # 绘制边界框 + cv2.rectangle(vis, (x1, y1), (x2, y2), color, STROKE_WIDTH) + + # 【放宽】状态稳定性判断(多数表决) - 降低要求 + _detection_history.append(detected_light) + if len(_detection_history) > 5: + _detection_history.pop(0) + + stable_light = None + if len(_detection_history) >= 2: # 从3帧降低到2帧 + from collections import Counter + valid_detections = [d for d in _detection_history if d and d in LIGHT_NAMES] + if len(valid_detections) >= 2: # 从3帧降低到2帧 + counter = Counter(valid_detections) + most_common = counter.most_common(1) + if most_common and most_common[0][1] >= 2: # 从3次降低到2次 + stable_light = most_common[0][0] + + # 【调试】打印检测结果(已禁用) + # print(f"[TRAFFIC-DEBUG] detected={detected_light}, stable={stable_light}, history={_detection_history}") + + # 【禁用语音播报】只检测不播报,由 workflow_crossstreet.py 统一处理语音 + # 只更新状态跟踪,不调用 play_voice_text + if stable_light: + # 更新状态跟踪(用于检测状态变化) + if stable_light != _last_detected_light: + _last_detected_light = stable_light + print(f"[TRAFFIC] 检测到稳定状态改变: {LIGHT_NAMES[stable_light]}(不播报)") + _last_tts_ts = t_now + elif (t_now - _last_tts_ts) > TTS_INTERVAL_SEC: + # 超过间隔时间,更新时间戳(但不播报) + print(f"[TRAFFIC] 稳定状态持续: {LIGHT_NAMES[stable_light]}(不播报)") + _last_tts_ts = t_now + + # 【删除】状态文本显示 + # if detected_light and detected_light in LIGHT_NAMES: + # status_text = f"{LIGHT_NAMES[detected_light]} ({max_conf:.2f})" + # else: + # status_text = "无检测" + # + # if stable_light: + # stable_text = f"稳定: {LIGHT_NAMES[stable_light]}" + # else: + # stable_text = f"等待稳定 ({len(_detection_history)}/5)" + # + # # 添加简单的文本显示 + # cv2.putText(vis, status_text, (10, 30), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2) + # cv2.putText(vis, 
stable_text, (10, 70), cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2) + + return { + 'vis_image': vis, + 'detected_light': detected_light, + 'stable_light': stable_light + } + +def reset_detection_state(): + """重置检测状态""" + global _last_tts_ts, _last_detected_light, _detection_history + _last_tts_ts = 0.0 + _last_detected_light = None + _detection_history = [] + print("[TRAFFIC] 检测状态已重置") + +if __name__ == "__main__": + main(headless=False) + + + diff --git a/utils.py b/utils.py new file mode 100644 index 0000000..9fef671 --- /dev/null +++ b/utils.py @@ -0,0 +1,307 @@ +# utils.py +# -*- coding: utf-8 -*- +import cv2 +import numpy as np +import logging + +logger = logging.getLogger(__name__) + +# 物品名称映射 +ITEM_TO_CLASS_MAP = { + "红牛": "Red_Bull", + "AD钙奶": "AD_milk", + "ad钙奶": "AD_milk", + "钙奶": "AD_milk", +} + +# 英文类别名到中文的映射 +_OBSTACLE_NAME_CN = { + 'person': '人', + 'bicycle': '自行车', + 'car': '车', + 'motorcycle': '摩托车', + 'bus': '公交车', + 'truck': '卡车', + 'animal': '动物', + 'scooter': '电瓶车', + 'stroller': '婴儿车', + 'dog': '狗', +} + +# 动态类别名称列表 +DYNAMIC_CLASS_NAMES = {'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'animal', 'dog'} + +def extract_english_label(item_cn: str) -> tuple: + """ + 提取中文物品名称对应的英文标签 + :param item_cn: 中文物品名称 + :return: (英文标签, 来源) + """ + # 先查找本地映射 + if item_cn in ITEM_TO_CLASS_MAP: + return ITEM_TO_CLASS_MAP[item_cn], "local" + + # 如果没有找到,返回原始名称 + return item_cn, "direct" + +def _to_cn_obstacle(name: str) -> str: + """ + 将英文障碍物名称转换为中文 + :param name: 英文名称 + :return: 中文名称 + """ + try: + key = (name or '').strip().lower() + return _OBSTACLE_NAME_CN.get(key, '障碍物') + except Exception: + return '障碍物' + +def estimate_global_affine(prev_gray, curr_gray, mask=None): + """ + 估计两帧之间的全局仿射变换 + :param prev_gray: 前一帧灰度图 + :param curr_gray: 当前帧灰度图 + :param mask: 可选的掩码,只在掩码区域内计算 + :return: (仿射矩阵, 内点数) + """ + try: + # 提取特征点 + detector = cv2.ORB_create(nfeatures=500) + kp1, des1 = detector.detectAndCompute(prev_gray, mask) + kp2, des2 = 
detector.detectAndCompute(curr_gray, mask) + + if des1 is None or des2 is None or len(kp1) < 10 or len(kp2) < 10: + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), 0 + + # 匹配特征点 + matcher = cv2.BFMatcher(cv2.NORM_HAMMING, crossCheck=True) + matches = matcher.match(des1, des2) + + if len(matches) < 4: + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), 0 + + # 提取匹配的点对 + src_pts = np.float32([kp1[m.queryIdx].pt for m in matches]).reshape(-1, 1, 2) + dst_pts = np.float32([kp2[m.trainIdx].pt for m in matches]).reshape(-1, 1, 2) + + # 使用RANSAC估计仿射变换 + M, inliers = cv2.estimateAffinePartial2D(src_pts, dst_pts, method=cv2.RANSAC, + ransacReprojThreshold=3.0) + + if M is None: + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), 0 + + inlier_count = np.sum(inliers) if inliers is not None else 0 + return M, inlier_count + + except Exception as e: + logger.warning(f"estimate_global_affine failed: {e}") + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32), 0 + +def warp_mask(mask, M, output_shape): + """ + 使用仿射变换对掩码进行变换 + :param mask: 输入掩码 + :param M: 2x3的仿射变换矩阵 + :param output_shape: 输出形状 (width, height) + :return: 变换后的掩码 + """ + try: + if mask is None or M is None: + return None + + W, H = output_shape + warped = cv2.warpAffine(mask, M, (W, H), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0) + return warped + + except Exception as e: + logger.warning(f"warp_mask failed: {e}") + return None + +def estimate_translation_flow(prev_gray, curr_gray, mask=None): + """ + 估计两帧之间的平移光流 + :param prev_gray: 前一帧灰度图 + :param curr_gray: 当前帧灰度图 + :param mask: 可选的掩码 + :return: (中位光流幅度, 平移矩阵) + """ + try: + # 计算稀疏光流 + corners = cv2.goodFeaturesToTrack(prev_gray, maxCorners=100, + qualityLevel=0.3, minDistance=7, + mask=mask) + + if corners is None or len(corners) < 10: + return 0.0, np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32) + + # 计算光流 + next_pts, status, _ = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, + corners, 
None) + + # 筛选有效点 + valid_old = corners[status == 1] + valid_new = next_pts[status == 1] + + if len(valid_old) < 5: + return 0.0, np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32) + + # 计算位移 + flow_vectors = valid_new - valid_old + flow_magnitudes = np.linalg.norm(flow_vectors, axis=1) + median_flow = np.median(flow_magnitudes) + + # 估计平均平移 + mean_translation = np.mean(flow_vectors, axis=0) + M = np.array([[1, 0, mean_translation[0]], + [0, 1, mean_translation[1]]], dtype=np.float32) + + return median_flow, M + + except Exception as e: + logger.warning(f"estimate_translation_flow failed: {e}") + return 0.0, np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32) + +def is_stationary_frame(prev_gray, curr_gray, mask=None, threshold=0.35): + """ + 判断用户是否静止 + :param prev_gray: 前一帧灰度图 + :param curr_gray: 当前帧灰度图 + :param mask: 可选的掩码 + :param threshold: 静止判定阈值 + :return: True表示静止,False表示运动 + """ + try: + median_flow, _ = estimate_translation_flow(prev_gray, curr_gray, mask) + return median_flow < threshold + except: + return False + +def compute_approach_metrics(prev_obstacles, curr_obstacles, M, H, W): + """ + 计算障碍物的接近度量 + :param prev_obstacles: 前一帧障碍物列表 + :param curr_obstacles: 当前帧障碍物列表 + :param M: 仿射变换矩阵 + :param H: 图像高度 + :param W: 图像宽度 + :return: 接近度量列表 + """ + metrics = [] + + for curr_obs in curr_obstacles: + # 寻找最佳匹配的前一帧障碍物 + best_match = None + best_iou = 0.0 + + curr_mask = curr_obs.get('mask') + if curr_mask is None: + metrics.append(None) + continue + + for prev_obs in prev_obstacles: + prev_mask = prev_obs.get('mask') + if prev_mask is None: + continue + + # 将前一帧掩码变换到当前帧 + warped_prev = warp_mask(prev_mask, M, (W, H)) + if warped_prev is None: + continue + + # 计算IoU + intersection = np.logical_and(curr_mask > 0, warped_prev > 0).sum() + union = np.logical_or(curr_mask > 0, warped_prev > 0).sum() + iou = intersection / union if union > 0 else 0.0 + + if iou > best_iou: + best_iou = iou + best_match = prev_obs + + if best_match is None: + metrics.append(None) + 
continue + + # 计算度量 + curr_area = curr_obs.get('area', 0) + prev_area = best_match.get('area', 0) + area_growth = (curr_area - prev_area) / prev_area if prev_area > 0 else 0.0 + + curr_bottom_y = curr_obs.get('bottom_y_ratio', 0) + prev_bottom_y = best_match.get('bottom_y_ratio', 0) + v_forward = curr_bottom_y - prev_bottom_y + + metrics.append({ + 'area_growth': area_growth, + 'v_forward': v_forward, + 'iou': best_iou + }) + + return metrics + +def compute_risk_scores(obstacles, prev_obstacles, M, path_mask, image_shape, + stop_th=0.6, avoid_th=0.56): + """ + 计算障碍物的风险评分 + :param obstacles: 当前障碍物列表 + :param prev_obstacles: 前一帧障碍物列表 + :param M: 仿射变换矩阵 + :param path_mask: 路径掩码 + :param image_shape: 图像形状 + :param stop_th: 停止阈值 + :param avoid_th: 避让阈值 + :return: (评分后的障碍物列表, 是否需要停止, 是否需要避让, 可视化元素) + """ + H, W = image_shape[:2] + has_stop = False + has_avoid = False + risk_vis = [] + + # 计算接近度量 + metrics = compute_approach_metrics(prev_obstacles, obstacles, M, H, W) + + for obs, met in zip(obstacles, metrics): + risk_score = 0.0 + + if met is not None: + # 基于接近速度和面积增长计算风险 + if met['v_forward'] > 0.004: # 向下移动 + risk_score += 0.3 + if met['area_growth'] > 0.01: # 面积增长 + risk_score += 0.3 + + # 基于距离的风险 + bottom_y = obs.get('bottom_y_ratio', 0) + area_ratio = obs.get('area_ratio', 0) + + if bottom_y > 0.8 or area_ratio > 0.15: + risk_score += 0.3 + + # 动态物体额外风险 + name_lower = str(obs.get('name', '')).lower() + if name_lower in DYNAMIC_CLASS_NAMES: + risk_score *= 1.2 + + obs['risk_score'] = risk_score + + # 更新标志 + if risk_score >= stop_th: + has_stop = True + elif risk_score >= avoid_th: + has_avoid = True + + # 添加风险可视化 + if risk_score > 0.3: + risk_color = "rgba(255, 0, 0, 0.3)" if risk_score >= stop_th else "rgba(255, 165, 0, 0.3)" + risk_vis.append({ + "type": "risk_indicator", + "score": risk_score, + "color": risk_color, + "position": [int(obs.get('center_x', W/2)), int(obs.get('center_y', H/2))] + }) + + return obstacles, has_stop, has_avoid, risk_vis + diff --git 
a/voice/map.zh-CN.json b/voice/map.zh-CN.json new file mode 100644 index 0000000..8f2b4d1 --- /dev/null +++ b/voice/map.zh-CN.json @@ -0,0 +1,714 @@ +{ + "丢失路径,重新搜索。": { + "files": [ + "丢失路径,重新搜索。.wav" + ], + "duration_ms": 1653 + }, + "保持直行": { + "files": [ + "保持直行.wav" + ], + "duration_ms": 587 + }, + "保持直行,靠近盲道。": { + "files": [ + "保持直行,靠近盲道。.wav" + ], + "duration_ms": 1493 + }, + "切换到盲道导航。": { + "files": [ + "切换到盲道导航。.wav" + ], + "duration_ms": 1013 + }, + "到达转弯处,向右平移。": { + "files": [ + "到达转弯处,向右平移。.wav" + ], + "duration_ms": 1973 + }, + "到达转弯处,向左平移。": { + "files": [ + "到达转弯处,向左平移。.wav" + ], + "duration_ms": 1680 + }, + "前方有人,停一下。": { + "files": [ + "前方有人,停一下。.wav" + ], + "duration_ms": 1547 + }, + "前方有人,注意避让。": { + "files": [ + "前方有人,注意避让。.wav" + ], + "duration_ms": 1467 + }, + "前方有公交车,停一下。": { + "files": [ + "前方有公交车,停一下。.wav" + ], + "duration_ms": 1573 + }, + "前方有动物,停一下。": { + "files": [ + "前方有动物,停一下。.wav" + ], + "duration_ms": 1333 + }, + "前方有卡车,停一下。": { + "files": [ + "前方有卡车,停一下。.wav" + ], + "duration_ms": 1440 + }, + "前方有婴儿车,停一下。": { + "files": [ + "前方有婴儿车,停一下。.wav" + ], + "duration_ms": 1573 + }, + "前方有左转弯,继续直行。": { + "files": [ + "前方有左转弯,继续直行。.wav" + ], + "duration_ms": 1920 + }, + "前方有摩托车,停一下。": { + "files": [ + "前方有摩托车,停一下。.wav" + ], + "duration_ms": 1520 + }, + "前方有狗,停一下。": { + "files": [ + "前方有狗,停一下。.wav" + ], + "duration_ms": 1520 + }, + "前方有电瓶车,停一下。": { + "files": [ + "前方有电瓶车,停一下。.wav" + ], + "duration_ms": 1653 + }, + "前方有自行车,停一下。": { + "files": [ + "前方有自行车,停一下。.wav" + ], + "duration_ms": 1467 + }, + "前方有车,停一下。": { + "files": [ + "前方有车,停一下。.wav" + ], + "duration_ms": 804 + }, + "前方有车,注意避让。": { + "files": [ + "前方有车,注意避让。.wav" + ], + "duration_ms": 1520 + }, + "前方有障碍物,停一下。": { + "files": [ + "前方有障碍物,停一下。.wav" + ], + "duration_ms": 1573 + }, + "前方有障碍物,注意避让。": { + "files": [ + "前方有障碍物,注意避让。.wav" + ], + "duration_ms": 1893 + }, + "发现斑马线,对准方向。": { + "files": [ + "发现斑马线,对准方向。.wav" + ], + "duration_ms": 1413 + }, + "右侧有人,停一下。": { + "files": [ + 
"右侧有人,停一下。.wav" + ], + "duration_ms": 1600 + }, + "右侧有公交车,停一下。": { + "files": [ + "右侧有公交车,停一下。.wav" + ], + "duration_ms": 1467 + }, + "右侧有动物,停一下。": { + "files": [ + "右侧有动物,停一下。.wav" + ], + "duration_ms": 1467 + }, + "右侧有卡车,停一下。": { + "files": [ + "右侧有卡车,停一下。.wav" + ], + "duration_ms": 1413 + }, + "右侧有婴儿车,停一下。": { + "files": [ + "右侧有婴儿车,停一下。.wav" + ], + "duration_ms": 1440 + }, + "右侧有摩托车,停一下。": { + "files": [ + "右侧有摩托车,停一下。.wav" + ], + "duration_ms": 1467 + }, + "右侧有狗,停一下。": { + "files": [ + "右侧有狗,停一下。.wav" + ], + "duration_ms": 1467 + }, + "右侧有电瓶车,停一下。": { + "files": [ + "右侧有电瓶车,停一下。.wav" + ], + "duration_ms": 1657 + }, + "右侧有自行车,停一下。": { + "files": [ + "右侧有自行车,停一下。.wav" + ], + "duration_ms": 1680 + }, + "右侧有车,停一下。": { + "files": [ + "右侧有车,停一下。.wav" + ], + "duration_ms": 1440 + }, + "右侧有障碍物,停一下。": { + "files": [ + "右侧有障碍物,停一下。.wav" + ], + "duration_ms": 1573 + }, + "右移": { + "files": [ + "右移.wav" + ], + "duration_ms": 378 + }, + "右转": { + "files": [ + "右转.wav" + ], + "duration_ms": 507 + }, + "右转一点": { + "files": [ + "右转一点.wav" + ], + "duration_ms": 560 + }, + "向前直行几步越过障碍物。然后说‘好了’。": { + "files": [ + "向前直行几步越过障碍物。然后说‘好了’。.wav" + ], + "duration_ms": 2667 + }, + "向右平移": { + "files": [ + "向右平移.wav" + ], + "duration_ms": 720 + }, + "向右平移,对准盲道": { + "files": [ + "向右平移,对准盲道.wav" + ], + "duration_ms": 1733 + }, + "向左平移": { + "files": [ + "向左平移.wav" + ], + "duration_ms": 720 + }, + "向左平移,对准盲道": { + "files": [ + "向左平移,对准盲道.wav" + ], + "duration_ms": 1573 + }, + "启动过马路模式失败,请稍后重试。": { + "files": [ + "启动过马路模式失败,请稍后重试。.wav" + ], + "duration_ms": 2187 + }, + "好的,请停下侧移。": { + "files": [ + "好的,请停下侧移。.wav" + ], + "duration_ms": 1658 + }, + "寻物任务完成。": { + "files": [ + "寻物任务完成。.wav" + ], + "duration_ms": 827 + }, + "导航已被取消。": { + "files": [ + "导航已被取消。.wav" + ], + "duration_ms": 827 + }, + "左侧有人,停一下。": { + "files": [ + "左侧有人,停一下。.wav" + ], + "duration_ms": 1013 + }, + "左侧有公交车,停一下。": { + "files": [ + "左侧有公交车,停一下。.wav" + ], + "duration_ms": 1493 + }, + "左侧有动物,停一下。": { + "files": [ + 
"左侧有动物,停一下。.wav" + ], + "duration_ms": 1360 + }, + "左侧有卡车,停一下。": { + "files": [ + "左侧有卡车,停一下。.wav" + ], + "duration_ms": 1467 + }, + "左侧有婴儿车,停一下。": { + "files": [ + "左侧有婴儿车,停一下。.wav" + ], + "duration_ms": 1600 + }, + "左侧有摩托车,停一下。": { + "files": [ + "左侧有摩托车,停一下。.wav" + ], + "duration_ms": 1520 + }, + "左侧有狗,停一下。": { + "files": [ + "左侧有狗,停一下。.wav" + ], + "duration_ms": 1627 + }, + "左侧有电瓶车,停一下。": { + "files": [ + "左侧有电瓶车,停一下。.wav" + ], + "duration_ms": 1733 + }, + "左侧有自行车,停一下。": { + "files": [ + "左侧有自行车,停一下。.wav" + ], + "duration_ms": 1520 + }, + "左侧有车,停一下。": { + "files": [ + "左侧有车,停一下。.wav" + ], + "duration_ms": 1467 + }, + "左侧有障碍物,停一下。": { + "files": [ + "左侧有障碍物,停一下。.wav" + ], + "duration_ms": 1627 + }, + "左移": { + "files": [ + "左移.wav" + ], + "duration_ms": 378 + }, + "左转": { + "files": [ + "左转.wav" + ], + "duration_ms": 378 + }, + "左转一点": { + "files": [ + "左转一点.wav" + ], + "duration_ms": 613 + }, + "已停止导航。": { + "files": [ + "已停止导航。.wav" + ], + "duration_ms": 827 + }, + "已到盲道跟前,切换到盲道导航。": { + "files": [ + "已到盲道跟前,切换到盲道导航。.wav" + ], + "duration_ms": 2080 + }, + "已到达目标前方,请注意。": { + "files": [ + "已到达目标前方,请注意。.wav" + ], + "duration_ms": 1760 + }, + "已到达目标,引导结束。": { + "files": [ + "已到达目标,引导结束。.wav" + ], + "duration_ms": 1733 + }, + "已回到盲道。": { + "files": [ + "已回到盲道。.wav" + ], + "duration_ms": 773 + }, + "已对准, 准备切换过马路模式。": { + "files": [ + "已对准, 准备切换过马路模式。.wav" + ], + "duration_ms": 1893 + }, + "已对准新路径,请向前直行。": { + "files": [ + "已对准新路径,请向前直行。.wav" + ], + "duration_ms": 1973 + }, + "引导超时,自动结束。": { + "files": [ + "引导超时,自动结束。.wav" + ], + "duration_ms": 1657 + }, + "收到,准备回归盲道。": { + "files": [ + "收到,准备回归盲道。.wav" + ], + "duration_ms": 1547 + }, + "斑马线已对准,继续前行。": { + "files": [ + "斑马线已对准,继续前行。.wav" + ], + "duration_ms": 1733 + }, + "方向已对正!现在校准位置。": { + "files": [ + "方向已对正!现在校准位置。.wav" + ], + "duration_ms": 1760 + }, + "方向正确,请直行。": { + "files": [ + "方向正确,请直行。.wav" + ], + "duration_ms": 1573 + }, + "方向正确,请继续前进。": { + "files": [ + "方向正确,请继续前进。.wav" + ], + "duration_ms": 1867 + }, 
+ "校准完成!您已在盲道上,开始前行。": { + "files": [ + "校准完成!您已在盲道上,开始前行。.wav" + ], + "duration_ms": 2613 + }, + "检测到已移动,开始对准新方向。": { + "files": [ + "检测到已移动,开始对准新方向。.wav" + ], + "duration_ms": 2000 + }, + "正在接近斑马线,为您对准方向。": { + "files": [ + "正在接近斑马线,为您对准方向。.wav" + ], + "duration_ms": 2000 + }, + "正在等待绿灯…": { + "files": [ + "正在等待绿灯….wav" + ], + "duration_ms": 640 + }, + "没看到盲道,请向右侧小幅移动。": { + "files": [ + "没看到盲道,请向右侧小幅移动。.wav" + ], + "duration_ms": 2240 + }, + "没看到盲道,请向左侧小幅移动。": { + "files": [ + "没看到盲道,请向左侧小幅移动。.wav" + ], + "duration_ms": 2267 + }, + "目标在您的左前方,请右转一点。": { + "files": [ + "目标在您的左前方,请右转一点。.wav" + ], + "duration_ms": 1947 + }, + "目标在您的左前方,请左转一点。": { + "files": [ + "目标在您的左前方,请左转一点。.wav" + ], + "duration_ms": 2000 + }, + "目标就在前方,请慢慢靠近。": { + "files": [ + "目标就在前方,请慢慢靠近。.wav" + ], + "duration_ms": 2160 + }, + "目标消失,请原地小幅转动。": { + "files": [ + "目标消失,请原地小幅转动。.wav" + ], + "duration_ms": 2080 + }, + "目标消失,请原地等待。": { + "files": [ + "目标消失,请原地等待。.wav" + ], + "duration_ms": 1733 + }, + "盲道已接近,开始对准盲道。": { + "files": [ + "盲道已接近,开始对准盲道。.wav" + ], + "duration_ms": 1973 + }, + "稍微向右调整,继续前进。": { + "files": [ + "稍微向右调整,继续前进。.wav" + ], + "duration_ms": 1840 + }, + "稍微向左调整,继续前进。": { + "files": [ + "稍微向左调整,继续前进。.wav" + ], + "duration_ms": 2053 + }, + "绿灯稳定,开始通行。": { + "files": [ + "绿灯稳定,开始通行。.wav" + ], + "duration_ms": 1547 + }, + "绿灯快没了": { + "files": [ + "绿灯快没了.wav" + ], + "duration_ms": 1200 + }, + "开始通行": { + "files": [ + "绿灯稳定,开始通行。.wav" + ], + "duration_ms": 1547 + }, + "斑马线已在跟前,进入红绿灯判定模式": { + "files": [ + "正在等待绿灯….wav" + ], + "duration_ms": 640 + }, + "请向右平移。": { + "files": [ + "请向右平移。.wav" + ], + "duration_ms": 747 + }, + "请向右微调,对准盲道。": { + "files": [ + "请向右微调,对准盲道。.wav" + ], + "duration_ms": 1573 + }, + "请向右转动。": { + "files": [ + "请向右转动。.wav" + ], + "duration_ms": 773 + }, + "请向左平移。": { + "files": [ + "请向左平移。.wav" + ], + "duration_ms": 933 + }, + "请向左微调,对准盲道。": { + "files": [ + "请向左微调,对准盲道。.wav" + ], + "duration_ms": 1573 + }, + "请向左转动。": { + "files": [ + "请向左转动。.wav" + ], + 
"duration_ms": 667 + }, + "请继续向右平移。": { + "files": [ + "请继续向右平移。.wav" + ], + "duration_ms": 1413 + }, + "请继续向左平移。": { + "files": [ + "请继续向左平移。.wav" + ], + "duration_ms": 1360 + }, + "请问完成了吗?": { + "files": [ + "请问完成了吗?.wav" + ], + "duration_ms": 773 + }, + "路径太远,请继续靠近": { + "files": [ + "路径太远,请继续靠近.wav" + ], + "duration_ms": 1520 + }, + "路径被挡住,请向右侧平移。": { + "files": [ + "路径被挡住,请向右侧平移。.wav" + ], + "duration_ms": 2000 + }, + "路径被挡住,请向左侧平移。": { + "files": [ + "路径被挡住,请向左侧平移。.wav" + ], + "duration_ms": 1760 + }, + "过马路模式已启动。": { + "files": [ + "过马路模式已启动。.wav" + ], + "duration_ms": 1360 + }, + "过马路结束,准备上人行道。": { + "files": [ + "过马路结束,准备上人行道。.wav" + ], + "duration_ms": 2053 + }, + "远处发现斑马线,继续直行。": { + "files": [ + "远处发现斑马线,继续直行。.wav" + ], + "duration_ms": 1920 + }, + "远处有盲道,继续前行。": { + "files": [ + "远处有盲道,继续前行。.wav" + ], + "duration_ms": 1733 + }, + + "避让完成,已回到盲道。": { + "files": [ + "避让完成,已回到盲道。.wav" + ], + "duration_ms": 1600 + }, + + "前方有右转弯,继续直行。": { + "files": [ + "前方有右转弯,继续直行。.wav" + ], + "duration_ms": 1657 + }, + "红灯": { + "files": [ + "红灯.WAV" + ], + "duration_ms": 400 + }, + "绿灯": { + "files": [ + "绿灯.WAV" + ], + "duration_ms": 400 + }, + "黄灯": { + "files": [ + "黄灯.WAV" + ], + "duration_ms": 400 + }, + "远处发现斑马线": { + "files": [ + "../music/远处发现斑马线.WAV" + ], + "duration_ms": 1600 + }, + "正在靠近斑马线": { + "files": [ + "../music/正在靠近斑马线.WAV" + ], + "duration_ms": 1600 + }, + "接近斑马线": { + "files": [ + "../music/接近斑马线.WAV" + ], + "duration_ms": 1200 + }, + "斑马线到了可以过马路": { + "files": [ + "../music/斑马线到了可以过马路.WAV" + ], + "duration_ms": 2000 + }, + "在画面左侧": { + "files": [ + "../music/在画面左侧.WAV" + ], + "duration_ms": 1200 + }, + "在画面中间": { + "files": [ + "../music/在画面中间.WAV" + ], + "duration_ms": 1200 + }, + "在画面右侧": { + "files": [ + "../music/在画面右侧.WAV" + ], + "duration_ms": 1200 + } +} + + diff --git a/voice/丢失路径,重新搜索。.wav b/voice/丢失路径,重新搜索。.wav new file mode 100644 index 0000000..a682114 Binary files /dev/null and b/voice/丢失路径,重新搜索。.wav differ diff --git a/voice/保持直行.wav 
b/voice/保持直行.wav new file mode 100644 index 0000000..5ad19fb Binary files /dev/null and b/voice/保持直行.wav differ diff --git a/voice/保持直行,靠近盲道。.wav b/voice/保持直行,靠近盲道。.wav new file mode 100644 index 0000000..0becc06 Binary files /dev/null and b/voice/保持直行,靠近盲道。.wav differ diff --git a/voice/切换到盲道导航。.wav b/voice/切换到盲道导航。.wav new file mode 100644 index 0000000..eb75924 Binary files /dev/null and b/voice/切换到盲道导航。.wav differ diff --git a/voice/到达转弯处,向右平移。.wav b/voice/到达转弯处,向右平移。.wav new file mode 100644 index 0000000..a45c286 Binary files /dev/null and b/voice/到达转弯处,向右平移。.wav differ diff --git a/voice/到达转弯处,向左平移。.wav b/voice/到达转弯处,向左平移。.wav new file mode 100644 index 0000000..f057229 Binary files /dev/null and b/voice/到达转弯处,向左平移。.wav differ diff --git a/voice/前方有人,停一下。.wav b/voice/前方有人,停一下。.wav new file mode 100644 index 0000000..02ef0f4 Binary files /dev/null and b/voice/前方有人,停一下。.wav differ diff --git a/voice/前方有人,注意避让。.wav b/voice/前方有人,注意避让。.wav new file mode 100644 index 0000000..76f2237 Binary files /dev/null and b/voice/前方有人,注意避让。.wav differ diff --git a/voice/前方有公交车,停一下。.wav b/voice/前方有公交车,停一下。.wav new file mode 100644 index 0000000..84448f6 Binary files /dev/null and b/voice/前方有公交车,停一下。.wav differ diff --git a/voice/前方有动物,停一下。.wav b/voice/前方有动物,停一下。.wav new file mode 100644 index 0000000..02e35b3 Binary files /dev/null and b/voice/前方有动物,停一下。.wav differ diff --git a/voice/前方有卡车,停一下。.wav b/voice/前方有卡车,停一下。.wav new file mode 100644 index 0000000..680f7e5 Binary files /dev/null and b/voice/前方有卡车,停一下。.wav differ diff --git a/voice/前方有右转弯,继续直行。.wav b/voice/前方有右转弯,继续直行。.wav new file mode 100644 index 0000000..f007c04 Binary files /dev/null and b/voice/前方有右转弯,继续直行。.wav differ diff --git a/voice/前方有婴儿车,停一下。.wav b/voice/前方有婴儿车,停一下。.wav new file mode 100644 index 0000000..0d7b019 Binary files /dev/null and b/voice/前方有婴儿车,停一下。.wav differ diff --git a/voice/前方有左转弯,继续直行。.wav b/voice/前方有左转弯,继续直行。.wav new file mode 100644 index 0000000..e4b4aea Binary files /dev/null and 
b/voice/前方有左转弯,继续直行。.wav differ diff --git a/voice/前方有摩托车,停一下。.wav b/voice/前方有摩托车,停一下。.wav new file mode 100644 index 0000000..954e8ff Binary files /dev/null and b/voice/前方有摩托车,停一下。.wav differ diff --git a/voice/前方有狗,停一下。.wav b/voice/前方有狗,停一下。.wav new file mode 100644 index 0000000..4d8d29d Binary files /dev/null and b/voice/前方有狗,停一下。.wav differ diff --git a/voice/前方有电瓶车,停一下。.wav b/voice/前方有电瓶车,停一下。.wav new file mode 100644 index 0000000..713ba64 Binary files /dev/null and b/voice/前方有电瓶车,停一下。.wav differ diff --git a/voice/前方有自行车,停一下。.wav b/voice/前方有自行车,停一下。.wav new file mode 100644 index 0000000..6dfabac Binary files /dev/null and b/voice/前方有自行车,停一下。.wav differ diff --git a/voice/前方有车,停一下。.wav b/voice/前方有车,停一下。.wav new file mode 100644 index 0000000..5270b48 Binary files /dev/null and b/voice/前方有车,停一下。.wav differ diff --git a/voice/前方有车,注意避让。.wav b/voice/前方有车,注意避让。.wav new file mode 100644 index 0000000..036e945 Binary files /dev/null and b/voice/前方有车,注意避让。.wav differ diff --git a/voice/前方有障碍物,停一下。.wav b/voice/前方有障碍物,停一下。.wav new file mode 100644 index 0000000..13f0adc Binary files /dev/null and b/voice/前方有障碍物,停一下。.wav differ diff --git a/voice/前方有障碍物,注意避让。.wav b/voice/前方有障碍物,注意避让。.wav new file mode 100644 index 0000000..09d0ea7 Binary files /dev/null and b/voice/前方有障碍物,注意避让。.wav differ diff --git a/voice/发现斑马线,对准方向。.wav b/voice/发现斑马线,对准方向。.wav new file mode 100644 index 0000000..22a2072 Binary files /dev/null and b/voice/发现斑马线,对准方向。.wav differ diff --git a/voice/右侧有人,停一下。.wav b/voice/右侧有人,停一下。.wav new file mode 100644 index 0000000..af80608 Binary files /dev/null and b/voice/右侧有人,停一下。.wav differ diff --git a/voice/右侧有公交车,停一下。.wav b/voice/右侧有公交车,停一下。.wav new file mode 100644 index 0000000..f3fae0a Binary files /dev/null and b/voice/右侧有公交车,停一下。.wav differ diff --git a/voice/右侧有动物,停一下。.wav b/voice/右侧有动物,停一下。.wav new file mode 100644 index 0000000..eaf2995 Binary files /dev/null and b/voice/右侧有动物,停一下。.wav differ diff --git a/voice/右侧有卡车,停一下。.wav b/voice/右侧有卡车,停一下。.wav 
new file mode 100644 index 0000000..110dc5b Binary files /dev/null and b/voice/右侧有卡车,停一下。.wav differ diff --git a/voice/右侧有婴儿车,停一下。.wav b/voice/右侧有婴儿车,停一下。.wav new file mode 100644 index 0000000..5670c33 Binary files /dev/null and b/voice/右侧有婴儿车,停一下。.wav differ diff --git a/voice/右侧有摩托车,停一下。.wav b/voice/右侧有摩托车,停一下。.wav new file mode 100644 index 0000000..4da98bc Binary files /dev/null and b/voice/右侧有摩托车,停一下。.wav differ diff --git a/voice/右侧有狗,停一下。.wav b/voice/右侧有狗,停一下。.wav new file mode 100644 index 0000000..0f32e13 Binary files /dev/null and b/voice/右侧有狗,停一下。.wav differ diff --git a/voice/右侧有电瓶车,停一下。.wav b/voice/右侧有电瓶车,停一下。.wav new file mode 100644 index 0000000..18edc78 Binary files /dev/null and b/voice/右侧有电瓶车,停一下。.wav differ diff --git a/voice/右侧有自行车,停一下。.wav b/voice/右侧有自行车,停一下。.wav new file mode 100644 index 0000000..c40102c Binary files /dev/null and b/voice/右侧有自行车,停一下。.wav differ diff --git a/voice/右侧有车,停一下。.wav b/voice/右侧有车,停一下。.wav new file mode 100644 index 0000000..c64bc94 Binary files /dev/null and b/voice/右侧有车,停一下。.wav differ diff --git a/voice/右侧有障碍物,停一下。.wav b/voice/右侧有障碍物,停一下。.wav new file mode 100644 index 0000000..f882471 Binary files /dev/null and b/voice/右侧有障碍物,停一下。.wav differ diff --git a/voice/右移.wav b/voice/右移.wav new file mode 100644 index 0000000..65fc54a Binary files /dev/null and b/voice/右移.wav differ diff --git a/voice/右转.wav b/voice/右转.wav new file mode 100644 index 0000000..b3b6be9 Binary files /dev/null and b/voice/右转.wav differ diff --git a/voice/右转一点.wav b/voice/右转一点.wav new file mode 100644 index 0000000..4e4386c Binary files /dev/null and b/voice/右转一点.wav differ diff --git a/voice/向前直行几步越过障碍物。然后说‘好了’。.wav b/voice/向前直行几步越过障碍物。然后说‘好了’。.wav new file mode 100644 index 0000000..ee9623a Binary files /dev/null and b/voice/向前直行几步越过障碍物。然后说‘好了’。.wav differ diff --git a/voice/向右平移.wav b/voice/向右平移.wav new file mode 100644 index 0000000..bb56764 Binary files /dev/null and b/voice/向右平移.wav differ diff --git a/voice/向右平移,对准盲道.wav 
b/voice/向右平移,对准盲道.wav new file mode 100644 index 0000000..04ff64b Binary files /dev/null and b/voice/向右平移,对准盲道.wav differ diff --git a/voice/向左平移.wav b/voice/向左平移.wav new file mode 100644 index 0000000..4b3156a Binary files /dev/null and b/voice/向左平移.wav differ diff --git a/voice/向左平移,对准盲道.wav b/voice/向左平移,对准盲道.wav new file mode 100644 index 0000000..346b297 Binary files /dev/null and b/voice/向左平移,对准盲道.wav differ diff --git a/voice/启动过马路模式失败,请稍后重试。.wav b/voice/启动过马路模式失败,请稍后重试。.wav new file mode 100644 index 0000000..464a08c Binary files /dev/null and b/voice/启动过马路模式失败,请稍后重试。.wav differ diff --git a/voice/好的,请停下侧移。.wav b/voice/好的,请停下侧移。.wav new file mode 100644 index 0000000..625573c Binary files /dev/null and b/voice/好的,请停下侧移。.wav differ diff --git a/voice/寻物任务完成。.wav b/voice/寻物任务完成。.wav new file mode 100644 index 0000000..184c2d9 Binary files /dev/null and b/voice/寻物任务完成。.wav differ diff --git a/voice/导航已被取消。.wav b/voice/导航已被取消。.wav new file mode 100644 index 0000000..ddf6773 Binary files /dev/null and b/voice/导航已被取消。.wav differ diff --git a/voice/左侧有人,停一下。.wav b/voice/左侧有人,停一下。.wav new file mode 100644 index 0000000..653e845 Binary files /dev/null and b/voice/左侧有人,停一下。.wav differ diff --git a/voice/左侧有公交车,停一下。.wav b/voice/左侧有公交车,停一下。.wav new file mode 100644 index 0000000..7a3d637 Binary files /dev/null and b/voice/左侧有公交车,停一下。.wav differ diff --git a/voice/左侧有动物,停一下。.wav b/voice/左侧有动物,停一下。.wav new file mode 100644 index 0000000..7099472 Binary files /dev/null and b/voice/左侧有动物,停一下。.wav differ diff --git a/voice/左侧有卡车,停一下。.wav b/voice/左侧有卡车,停一下。.wav new file mode 100644 index 0000000..cb5dfcc Binary files /dev/null and b/voice/左侧有卡车,停一下。.wav differ diff --git a/voice/左侧有婴儿车,停一下。.wav b/voice/左侧有婴儿车,停一下。.wav new file mode 100644 index 0000000..c26ccb8 Binary files /dev/null and b/voice/左侧有婴儿车,停一下。.wav differ diff --git a/voice/左侧有摩托车,停一下。.wav b/voice/左侧有摩托车,停一下。.wav new file mode 100644 index 0000000..49ec8cd Binary files /dev/null and b/voice/左侧有摩托车,停一下。.wav differ 
diff --git a/voice/左侧有狗,停一下。.wav b/voice/左侧有狗,停一下。.wav new file mode 100644 index 0000000..e53ec61 Binary files /dev/null and b/voice/左侧有狗,停一下。.wav differ diff --git a/voice/左侧有电瓶车,停一下。.wav b/voice/左侧有电瓶车,停一下。.wav new file mode 100644 index 0000000..13b6ede Binary files /dev/null and b/voice/左侧有电瓶车,停一下。.wav differ diff --git a/voice/左侧有自行车,停一下。.wav b/voice/左侧有自行车,停一下。.wav new file mode 100644 index 0000000..66b681e Binary files /dev/null and b/voice/左侧有自行车,停一下。.wav differ diff --git a/voice/左侧有车,停一下。.wav b/voice/左侧有车,停一下。.wav new file mode 100644 index 0000000..cd72629 Binary files /dev/null and b/voice/左侧有车,停一下。.wav differ diff --git a/voice/左侧有障碍物,停一下。.wav b/voice/左侧有障碍物,停一下。.wav new file mode 100644 index 0000000..f86af6b Binary files /dev/null and b/voice/左侧有障碍物,停一下。.wav differ diff --git a/voice/左移.wav b/voice/左移.wav new file mode 100644 index 0000000..b7bd8de Binary files /dev/null and b/voice/左移.wav differ diff --git a/voice/左转.wav b/voice/左转.wav new file mode 100644 index 0000000..10e6f4a Binary files /dev/null and b/voice/左转.wav differ diff --git a/voice/左转一点.wav b/voice/左转一点.wav new file mode 100644 index 0000000..6a8a27a Binary files /dev/null and b/voice/左转一点.wav differ diff --git a/voice/已停止导航。.wav b/voice/已停止导航。.wav new file mode 100644 index 0000000..fda991f Binary files /dev/null and b/voice/已停止导航。.wav differ diff --git a/voice/已到盲道跟前,切换到盲道导航。.wav b/voice/已到盲道跟前,切换到盲道导航。.wav new file mode 100644 index 0000000..402756e Binary files /dev/null and b/voice/已到盲道跟前,切换到盲道导航。.wav differ diff --git a/voice/已到达目标前方,请注意。.wav b/voice/已到达目标前方,请注意。.wav new file mode 100644 index 0000000..8ae1ef6 Binary files /dev/null and b/voice/已到达目标前方,请注意。.wav differ diff --git a/voice/已到达目标,引导结束。.wav b/voice/已到达目标,引导结束。.wav new file mode 100644 index 0000000..107ac0d Binary files /dev/null and b/voice/已到达目标,引导结束。.wav differ diff --git a/voice/已回到盲道。.wav b/voice/已回到盲道。.wav new file mode 100644 index 0000000..333ab65 Binary files /dev/null and b/voice/已回到盲道。.wav differ diff 
--git a/voice/已对准, 准备切换过马路模式。.wav b/voice/已对准, 准备切换过马路模式。.wav new file mode 100644 index 0000000..3f23ee2 Binary files /dev/null and b/voice/已对准, 准备切换过马路模式。.wav differ diff --git a/voice/已对准新路径,请向前直行。.wav b/voice/已对准新路径,请向前直行。.wav new file mode 100644 index 0000000..4f6e867 Binary files /dev/null and b/voice/已对准新路径,请向前直行。.wav differ diff --git a/voice/引导超时,自动结束。.wav b/voice/引导超时,自动结束。.wav new file mode 100644 index 0000000..f808ec8 Binary files /dev/null and b/voice/引导超时,自动结束。.wav differ diff --git a/voice/收到,准备回归盲道。.wav b/voice/收到,准备回归盲道。.wav new file mode 100644 index 0000000..f6801b6 Binary files /dev/null and b/voice/收到,准备回归盲道。.wav differ diff --git a/voice/斑马线已对准,继续前行。.wav b/voice/斑马线已对准,继续前行。.wav new file mode 100644 index 0000000..2f566a0 Binary files /dev/null and b/voice/斑马线已对准,继续前行。.wav differ diff --git a/voice/方向已对正!现在校准位置。.wav b/voice/方向已对正!现在校准位置。.wav new file mode 100644 index 0000000..19d638f Binary files /dev/null and b/voice/方向已对正!现在校准位置。.wav differ diff --git a/voice/方向正确,请直行。.wav b/voice/方向正确,请直行。.wav new file mode 100644 index 0000000..fddd765 Binary files /dev/null and b/voice/方向正确,请直行。.wav differ diff --git a/voice/方向正确,请继续前进。.wav b/voice/方向正确,请继续前进。.wav new file mode 100644 index 0000000..a2f9510 Binary files /dev/null and b/voice/方向正确,请继续前进。.wav differ diff --git a/voice/校准完成!您已在盲道上,开始前行。.wav b/voice/校准完成!您已在盲道上,开始前行。.wav new file mode 100644 index 0000000..9b16fd3 Binary files /dev/null and b/voice/校准完成!您已在盲道上,开始前行。.wav differ diff --git a/voice/检测到已移动,开始对准新方向。.wav b/voice/检测到已移动,开始对准新方向。.wav new file mode 100644 index 0000000..9d64665 Binary files /dev/null and b/voice/检测到已移动,开始对准新方向。.wav differ diff --git a/voice/正在接近斑马线,为您对准方向。.wav b/voice/正在接近斑马线,为您对准方向。.wav new file mode 100644 index 0000000..f984c60 Binary files /dev/null and b/voice/正在接近斑马线,为您对准方向。.wav differ diff --git a/voice/正在等待绿灯….wav b/voice/正在等待绿灯….wav new file mode 100644 index 0000000..6ebf2d0 Binary files /dev/null and b/voice/正在等待绿灯….wav differ diff --git 
a/voice/没看到盲道,请向右侧小幅移动。.wav b/voice/没看到盲道,请向右侧小幅移动。.wav new file mode 100644 index 0000000..2b3e2dd Binary files /dev/null and b/voice/没看到盲道,请向右侧小幅移动。.wav differ diff --git a/voice/没看到盲道,请向左侧小幅移动。.wav b/voice/没看到盲道,请向左侧小幅移动。.wav new file mode 100644 index 0000000..1f66f4e Binary files /dev/null and b/voice/没看到盲道,请向左侧小幅移动。.wav differ diff --git a/voice/目标在您的左前方,请右转一点。.wav b/voice/目标在您的左前方,请右转一点。.wav new file mode 100644 index 0000000..e7497fe Binary files /dev/null and b/voice/目标在您的左前方,请右转一点。.wav differ diff --git a/voice/目标在您的左前方,请左转一点。.wav b/voice/目标在您的左前方,请左转一点。.wav new file mode 100644 index 0000000..79509f4 Binary files /dev/null and b/voice/目标在您的左前方,请左转一点。.wav differ diff --git a/voice/目标就在前方,请慢慢靠近。.wav b/voice/目标就在前方,请慢慢靠近。.wav new file mode 100644 index 0000000..b31ca15 Binary files /dev/null and b/voice/目标就在前方,请慢慢靠近。.wav differ diff --git a/voice/目标消失,请原地小幅转动。.wav b/voice/目标消失,请原地小幅转动。.wav new file mode 100644 index 0000000..2dac313 Binary files /dev/null and b/voice/目标消失,请原地小幅转动。.wav differ diff --git a/voice/目标消失,请原地等待。.wav b/voice/目标消失,请原地等待。.wav new file mode 100644 index 0000000..8db3ce9 Binary files /dev/null and b/voice/目标消失,请原地等待。.wav differ diff --git a/voice/盲道已接近,开始对准盲道。.wav b/voice/盲道已接近,开始对准盲道。.wav new file mode 100644 index 0000000..d5d4779 Binary files /dev/null and b/voice/盲道已接近,开始对准盲道。.wav differ diff --git a/voice/稍微向右调整,继续前进。.wav b/voice/稍微向右调整,继续前进。.wav new file mode 100644 index 0000000..db0706b Binary files /dev/null and b/voice/稍微向右调整,继续前进。.wav differ diff --git a/voice/稍微向左调整,继续前进。.wav b/voice/稍微向左调整,继续前进。.wav new file mode 100644 index 0000000..92aecc0 Binary files /dev/null and b/voice/稍微向左调整,继续前进。.wav differ diff --git a/voice/红灯.WAV b/voice/红灯.WAV new file mode 100644 index 0000000..59c34f9 Binary files /dev/null and b/voice/红灯.WAV differ diff --git a/voice/红灯_原始.WAV b/voice/红灯_原始.WAV new file mode 100644 index 0000000..f53ff7f Binary files /dev/null and b/voice/红灯_原始.WAV differ diff --git a/voice/绿灯.WAV b/voice/绿灯.WAV new file 
mode 100644 index 0000000..75981ec Binary files /dev/null and b/voice/绿灯.WAV differ diff --git a/voice/绿灯_原始.WAV b/voice/绿灯_原始.WAV new file mode 100644 index 0000000..1a5725c Binary files /dev/null and b/voice/绿灯_原始.WAV differ diff --git a/voice/绿灯快没了.wav b/voice/绿灯快没了.wav new file mode 100644 index 0000000..4d63baa Binary files /dev/null and b/voice/绿灯快没了.wav differ diff --git a/voice/绿灯稳定,开始通行。.wav b/voice/绿灯稳定,开始通行。.wav new file mode 100644 index 0000000..6c175de Binary files /dev/null and b/voice/绿灯稳定,开始通行。.wav differ diff --git a/voice/请向右平移。.wav b/voice/请向右平移。.wav new file mode 100644 index 0000000..48cb226 Binary files /dev/null and b/voice/请向右平移。.wav differ diff --git a/voice/请向右微调,对准盲道。.wav b/voice/请向右微调,对准盲道。.wav new file mode 100644 index 0000000..eb197ec Binary files /dev/null and b/voice/请向右微调,对准盲道。.wav differ diff --git a/voice/请向右转动。.wav b/voice/请向右转动。.wav new file mode 100644 index 0000000..5a4b8fb Binary files /dev/null and b/voice/请向右转动。.wav differ diff --git a/voice/请向左平移。.wav b/voice/请向左平移。.wav new file mode 100644 index 0000000..60fe6e6 Binary files /dev/null and b/voice/请向左平移。.wav differ diff --git a/voice/请向左微调,对准盲道。.wav b/voice/请向左微调,对准盲道。.wav new file mode 100644 index 0000000..54c5cb3 Binary files /dev/null and b/voice/请向左微调,对准盲道。.wav differ diff --git a/voice/请向左转动。.wav b/voice/请向左转动。.wav new file mode 100644 index 0000000..381281e Binary files /dev/null and b/voice/请向左转动。.wav differ diff --git a/voice/请继续向右平移。.wav b/voice/请继续向右平移。.wav new file mode 100644 index 0000000..03f0bf8 Binary files /dev/null and b/voice/请继续向右平移。.wav differ diff --git a/voice/请继续向左平移。.wav b/voice/请继续向左平移。.wav new file mode 100644 index 0000000..946fc82 Binary files /dev/null and b/voice/请继续向左平移。.wav differ diff --git a/voice/请问完成了吗?.wav b/voice/请问完成了吗?.wav new file mode 100644 index 0000000..b6ce528 Binary files /dev/null and b/voice/请问完成了吗?.wav differ diff --git a/voice/路径太远,请继续靠近.wav b/voice/路径太远,请继续靠近.wav new file mode 100644 index 0000000..db0947c Binary 
files /dev/null and b/voice/路径太远,请继续靠近.wav differ diff --git a/voice/路径被挡住,请向右侧平移。.wav b/voice/路径被挡住,请向右侧平移。.wav new file mode 100644 index 0000000..0296ec4 Binary files /dev/null and b/voice/路径被挡住,请向右侧平移。.wav differ diff --git a/voice/路径被挡住,请向左侧平移。.wav b/voice/路径被挡住,请向左侧平移。.wav new file mode 100644 index 0000000..d6bd1f0 Binary files /dev/null and b/voice/路径被挡住,请向左侧平移。.wav differ diff --git a/voice/过马路模式已启动。.wav b/voice/过马路模式已启动。.wav new file mode 100644 index 0000000..3ef4eff Binary files /dev/null and b/voice/过马路模式已启动。.wav differ diff --git a/voice/过马路结束,准备上人行道。.wav b/voice/过马路结束,准备上人行道。.wav new file mode 100644 index 0000000..f665f3d Binary files /dev/null and b/voice/过马路结束,准备上人行道。.wav differ diff --git a/voice/远处发现斑马线,继续直行。.wav b/voice/远处发现斑马线,继续直行。.wav new file mode 100644 index 0000000..f5c9a06 Binary files /dev/null and b/voice/远处发现斑马线,继续直行。.wav differ diff --git a/voice/远处有盲道,继续前行。.wav b/voice/远处有盲道,继续前行。.wav new file mode 100644 index 0000000..d461e89 Binary files /dev/null and b/voice/远处有盲道,继续前行。.wav differ diff --git a/voice/避让完成,已回到盲道。.wav b/voice/避让完成,已回到盲道。.wav new file mode 100644 index 0000000..2048fed Binary files /dev/null and b/voice/避让完成,已回到盲道。.wav differ diff --git a/voice/黄灯.WAV b/voice/黄灯.WAV new file mode 100644 index 0000000..c44a249 Binary files /dev/null and b/voice/黄灯.WAV differ diff --git a/voice/黄灯_原始.WAV b/voice/黄灯_原始.WAV new file mode 100644 index 0000000..e4051d3 Binary files /dev/null and b/voice/黄灯_原始.WAV differ diff --git a/workflow_blindpath.py b/workflow_blindpath.py new file mode 100644 index 0000000..ef4a98b --- /dev/null +++ b/workflow_blindpath.py @@ -0,0 +1,3235 @@ +# workflow_blindpath.py +# -*- coding: utf-8 -*- +""" +盲道导航工作流 - 纯净版 +移除了所有 Redis、Celery 依赖,可以直接集成到任何 Python 应用中 +""" +import os +import time +import cv2 +import numpy as np +import logging +from typing import Dict, List, Optional, Tuple, Any +from dataclasses import dataclass +from collections import deque +import torch # 添加这行 +from obstacle_detector_client 
import ObstacleDetectorClient +# 【移除】从这里播放音频会导致线程池中asyncio无法工作 +# from audio_player import play_voice_text +# 语音由 app_main.py 统一处理 +from crosswalk_awareness import CrosswalkAwarenessMonitor, split_combined_voice # 斑马线感知 +# 尝试导入 Pillow,用于中文显示 +try: + from PIL import Image, ImageDraw, ImageFont + PIL_AVAILABLE = True +except ImportError: + PIL_AVAILABLE = False + Image, ImageDraw, ImageFont = None, None, None + +logger = logging.getLogger(__name__) + +# ========== 状态常量定义 ========== +STATE_ONBOARDING = "ONBOARDING" +STATE_NAVIGATING = "NAVIGATING" +STATE_MANEUVERING_TURN = "MANEUVERING_TURN" +STATE_AVOIDING_OBSTACLE = "AVOIDING_OBSTACLE" +STATE_LOCKING_ON = "LOCKING_ON" + +# ONBOARDING子步骤 +ONBOARDING_STEP_ROTATION = "ROTATION" +ONBOARDING_STEP_TRANSLATION = "TRANSLATION" + +# 转向子步骤 +MANEUVER_STEP_1_ISSUE_COMMAND = "ISSUE_COMMAND" +MANEUVER_STEP_2_WAIT_FOR_SHIFT = "WAIT_FOR_SHIFT" +MANEUVER_STEP_3_ALIGN_ON_NEW_PATH = "ALIGN_ON_NEW_PATH" + +# 颜色定义 (BGR格式) +VIS_COLORS = { + "blind_path": (0, 255, 0), # 绿色 + "obstacle": (0, 0, 255), # 红色 + "crosswalk": (0, 165, 255), # 橙色 + "centerline": (0, 255, 255), # 黄色 + "target_point": (255, 0, 0), # 蓝色 + "turn_point": (128, 0, 128), # 紫色 + "pulse_effect": (100, 100, 255) # 淡红色 +} + +# 障碍物名称映射 +_OBSTACLE_NAME_CN = { + 'person': '人', + 'bicycle': '自行车', + 'car': '车', + 'motorcycle': '摩托车', + 'bus': '公交车', + 'truck': '卡车', + 'animal': '动物', + 'scooter': '电瓶车', + 'stroller': '婴儿车', + 'dog': '狗', +} + +# 动态类别名称列表 +DYNAMIC_CLASS_NAMES = {'person', 'bicycle', 'car', 'motorcycle', 'bus', 'truck', 'animal', 'dog'} + +@dataclass +class ProcessingResult: + """处理结果数据类""" + guidance_text: str # 语音引导文本 + visualizations: List[Dict[str, Any]] # 可视化元素列表 + annotated_image: Optional[np.ndarray] = None # 标注后的图像 + state_info: Dict[str, Any] = None # 状态信息 + + def __post_init__(self): + if self.state_info is None: + self.state_info = {} + + +class BlindPathNavigator: + """盲道导航处理器 - 无外部依赖版本""" + + def __init__(self, yolo_model=None, 
obstacle_detector=None): + """ + 初始化导航器 + :param yolo_model: YOLO分割模型(可选) + :param obstacle_detector: 障碍物检测器(可选) + """ + self.yolo_model = yolo_model + self.obstacle_detector = obstacle_detector + + # 状态变量 + self.current_state = STATE_ONBOARDING + self.onboarding_step = ONBOARDING_STEP_ROTATION + self.maneuver_step = MANEUVER_STEP_1_ISSUE_COMMAND + self.maneuver_target_info = None + + + # 光流追踪参数 + self.lk_params = dict( + winSize=(15, 15), + maxLevel=2, + criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 10, 0.03) + ) + + # 特征检测参数 + self.feature_params = dict( + maxCorners=100, + qualityLevel=0.05, + minDistance=10, + blockSize=7, + useHarrisDetector=False, + k=0.04 + ) + + # 光流追踪点缓存 + self.flow_points = {} # {mask_type: points} + self.flow_grace = {} # {mask_type: grace_count} + self.FLOW_GRACE_MAX = 3 # 【修改】从8帧降低到3帧,快速清除光流遗留 + + # 中心线平滑缓存 + self.centerline_history = [] # 历史中心线数据 + self.centerline_history_max = 5 # 保留最近5帧用于平滑 + + # 多项式系数平滑缓存 + self.poly_coeffs_history = [] # 历史多项式系数 + self.poly_coeffs_history_max = 8 # 保留最近8帧系数用于平滑 + + # 转弯检测追踪器 + self.turn_detection_tracker = { + 'direction': None, + 'consecutive_hits': 0, + 'last_seen_frame': 0, + 'corner_info': None + } + + # 转弯冷却 + self.turn_cooldown_frames = 0 + self.TURN_COOLDOWN_DURATION = 50 + + # 避障相关 + self.avoidance_plan = None + self.avoidance_step_index = 0 + self.lock_on_data = None + + # 斑马线追踪 + self.crosswalk_tracker = { + 'stage': 'not_detected', + 'consecutive_frames': 0, + 'last_area_ratio': 0.0, + 'last_bottom_y_ratio': 0.0, + 'last_center_x_ratio': 0.5, + 'position_announced': False, + 'alignment_status': 'not_aligned', + 'last_seen_frame': 0, + 'last_angle': 0.0 + } + + # 帧计数器 + self.frame_counter = 0 + + # 直行提示配置 - 支持环境变量 + self.guide_interval = float(os.getenv("AIGLASS_STRAIGHT_INTERVAL", "4.0")) # 播报间隔(秒) + self.last_guide_time = 0.0 + self.straight_continuous_mode = os.getenv("AIGLASS_STRAIGHT_CONTINUOUS", "1") == "1" # 持续播报模式 + self.straight_repeat_limit = 
int(os.getenv("AIGLASS_STRAIGHT_LIMIT", "2")) # 限制模式下的最大次数 + self.straight_repeat_count = 0 + + # 【新增】方向指令持续播报配置 + self.direction_interval = float(os.getenv("AIGLASS_DIRECTION_INTERVAL", "3.0")) # 方向指令间隔(秒) + self.last_direction_time = 0.0 + self.last_direction_message = "" + + # 打印配置信息 + logger.info(f"[BlindPath] 直行播报配置: 间隔={self.guide_interval}秒, " + f"持续模式={self.straight_continuous_mode}, " + f"限制次数={self.straight_repeat_limit}") + logger.info(f"[BlindPath] 方向播报配置: 间隔={self.direction_interval}秒") + + # 缓存变量 + self.prev_gray = None + self.prev_blind_path_mask = None + self.prev_crosswalk_mask = None + self.prev_obstacle_cache = [] + self.last_guidance_message = "" + self.last_detected_obstacles = [] + self.last_obstacle_detection_frame = 0 + self.last_any_speech_time = 0 + + # 斑马线准备状态标志 + self.crosswalk_ready_announced = False + self.crosswalk_ready_time = 0 + + # 障碍物语音待播报 + self.pending_obstacle_voice = None + + # 红绿灯检测 + self.traffic_light_detector = None + self.init_traffic_light_detector() + self.traffic_light_history = deque(maxlen=8) # 用于多数表决 + self.last_traffic_light_state = "unknown" + self.green_light_announced = False + + # 阈值设置 + self.CLASS_CONF_THRESHOLDS = { + 1: 0.20, # blind_path + 0: 0.30 # crosswalk + } + + # 导航阈值 + # 导航阈值 + self.ONBOARDING_ALIGN_THRESHOLD_RATIO = 0.1 + self.VP_FIT_ERROR_THRESHOLD = 8.0 + + self.ONBOARDING_ORIENTATION_THRESHOLD_RAD = np.deg2rad(10) + self.ONBOARDING_CENTER_OFFSET_THRESHOLD_RATIO = 0.15 + self.NAV_ORIENTATION_THRESHOLD_RAD = np.deg2rad(10) + self.NAV_CENTER_OFFSET_THRESHOLD_RATIO = 0.15 + self.CURVATURE_PROXY_THRESHOLD = 5e-5 + + # 斑马线切换阈值 + self.CROSSWALK_SWITCH_AREA_RATIO = 0.22 + self.CROSSWALK_SWITCH_BOTTOM_RATIO = 0.9 + self.CROSSWALK_SWITCH_CONSECUTIVE_FRAMES = 10 + + # 障碍物检测间隔 + # 障碍物检测优化参数 - Day 22 优化: 增加间隔减少卡顿 + self.OBSTACLE_DETECTION_INTERVAL = int(os.getenv("AIGLASS_OBS_INTERVAL", "18")) # 从15帧增加到18帧 + self.OBSTACLE_CACHE_DURATION_FRAMES = int(os.getenv("AIGLASS_OBS_CACHE_FRAMES", "20")) # Day 21: 
缓存20帧减少GPU负载 + + # 障碍物播报管理 + self.last_obstacle_speech = "" + self.last_obstacle_speech_time = 0 + self.obstacle_speech_cooldown = 5.0 # 相同障碍物3秒内不重复播报 + + # 掩码稳定化参数(已禁用光流外推,这些参数不再使用) + self.MASK_STAB_MIN_AREA = int(os.getenv("AIGLASS_MASK_MIN_AREA", "1500")) + self.MASK_STAB_KERNEL = int(os.getenv("AIGLASS_MASK_MORPH", "3")) + self.MASK_MISS_TTL = 0 # 【修改为0】禁用光流外推,完全实时 + self.blind_miss_ttl = 0 + self.cross_miss_ttl = 0 + + # 光流跟踪参数 + self.flow_iou_threshold = 0.3 # IoU低于此值时重新初始化光流点 + + # 【新增】盲道YOLO检测间隔 - Day 22 优化: 增加间隔减少卡顿 + self.BLINDPATH_DETECTION_INTERVAL = int(os.getenv("AIGLASS_BLINDPATH_INTERVAL", "10")) # 从8帧增加到10帧 + self.last_blindpath_detection_frame = 0 + self.last_blindpath_mask = None + self.last_crosswalk_mask = None + + # 【新增】斑马线感知监控器 + self.crosswalk_monitor = CrosswalkAwarenessMonitor() + logger.info("[BlindPath] 斑马线感知监控器已初始化") + logger.info(f"[BlindPath] 盲道检测间隔: 每{self.BLINDPATH_DETECTION_INTERVAL}帧") + + def init_traffic_light_detector(self): + """初始化红绿灯检测器""" + try: + # 首先尝试使用 YOLO 模型检测红绿灯 + self.traffic_light_yolo = None + # 如果你有专门的红绿灯模型,在这里加载 + # self.traffic_light_yolo = YOLO('path/to/traffic_light_model.pt') + except Exception as e: + logger.info(f"未加载红绿灯YOLO模型: {e}") + + def detect_traffic_light(self, image: np.ndarray) -> str: + """检测红绿灯状态 + 返回: 'red', 'green', 'yellow', 'unknown' + """ + # 模拟模式(用于测试) + if os.getenv("AIGLASS_SIMULATE_TRAFFIC_LIGHT", "0") == "1": + # 根据帧数模拟红绿灯变化 + cycle = (self.frame_counter // 100) % 3 + if cycle == 0: + return "red" + elif cycle == 1: + return "yellow" + else: + return "green" + + # 如果有 YOLO 模型,优先使用 + if self.traffic_light_yolo: + try: + results = self.traffic_light_yolo.predict(image, verbose=False, conf=0.3) + # TODO: 解析 YOLO 结果,判断红绿灯颜色 + pass + except: + pass + + # 使用 HSV 颜色检测作为后备方案 + return self._detect_traffic_light_by_color(image) + + def _detect_traffic_light_by_color(self, image: np.ndarray) -> str: + """基于 HSV 颜色空间检测红绿灯""" + h, w = image.shape[:2] + # 检测图像上半部分和中间部分(红绿灯可能在不同高度) + roi = 
image[:int(h * 0.7), :] # 扩大检测范围到70% + hsv = cv2.cvtColor(roi, cv2.COLOR_BGR2HSV) + + # 提高亮度的图像用于检测(有助于检测较暗的红绿灯) + hsv_bright = hsv.copy() + hsv_bright[:, :, 2] = cv2.add(hsv_bright[:, :, 2], 30) # 增加亮度 + + # 定义颜色范围(优化后的参数) + # 红色(两个范围,因为红色在 HSV 中跨越 0 度) + lower_red1 = np.array([0, 120, 100]) + upper_red1 = np.array([10, 255, 255]) + lower_red2 = np.array([170, 120, 100]) + upper_red2 = np.array([180, 255, 255]) + + # 绿色(调整为更宽的范围以适应不同灯光) + lower_green = np.array([40, 60, 60]) + upper_green = np.array([90, 255, 255]) + + # 黄色 + lower_yellow = np.array([15, 100, 100]) + upper_yellow = np.array([40, 255, 255]) + + # 创建掩码(同时在原图和增亮图上检测) + mask_red1 = cv2.inRange(hsv, lower_red1, upper_red1) + mask_red2 = cv2.inRange(hsv, lower_red2, upper_red2) + mask_red1_bright = cv2.inRange(hsv_bright, lower_red1, upper_red1) + mask_red2_bright = cv2.inRange(hsv_bright, lower_red2, upper_red2) + mask_red = cv2.bitwise_or(cv2.bitwise_or(mask_red1, mask_red2), + cv2.bitwise_or(mask_red1_bright, mask_red2_bright)) + + mask_green = cv2.bitwise_or(cv2.inRange(hsv, lower_green, upper_green), + cv2.inRange(hsv_bright, lower_green, upper_green)) + mask_yellow = cv2.bitwise_or(cv2.inRange(hsv, lower_yellow, upper_yellow), + cv2.inRange(hsv_bright, lower_yellow, upper_yellow)) + + # 形态学操作去噪 + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) + mask_red = cv2.morphologyEx(mask_red, cv2.MORPH_OPEN, kernel) + mask_green = cv2.morphologyEx(mask_green, cv2.MORPH_OPEN, kernel) + mask_yellow = cv2.morphologyEx(mask_yellow, cv2.MORPH_OPEN, kernel) + + # 计算每种颜色的面积 + area_red = cv2.countNonZero(mask_red) + area_green = cv2.countNonZero(mask_green) + area_yellow = cv2.countNonZero(mask_yellow) + + # 设置最小面积阈值(降低阈值使检测更敏感) + min_area = 30 # 进一步降低阈值 + + # 添加更详细的调试信息 + if hasattr(self, 'frame_counter') and self.frame_counter % 30 == 0: + logger.info(f"[HSV检测] 红:{area_red}, 绿:{area_green}, 黄:{area_yellow}") + # 保存调试图像 + if os.getenv("AIGLASS_DEBUG_TRAFFIC_LIGHT", "0") == "1": + debug_dir = 
"traffic_light_debug" + os.makedirs(debug_dir, exist_ok=True) + cv2.imwrite(f"{debug_dir}/frame_{self.frame_counter}_roi.jpg", roi) + cv2.imwrite(f"{debug_dir}/frame_{self.frame_counter}_red.jpg", mask_red) + cv2.imwrite(f"{debug_dir}/frame_{self.frame_counter}_green.jpg", mask_green) + cv2.imwrite(f"{debug_dir}/frame_{self.frame_counter}_yellow.jpg", mask_yellow) + + # 判断颜色(优先级:绿 > 红 > 黄) + if area_green > min_area and area_green > area_red * 0.8: # 绿灯优先 + return "green" + elif area_red > min_area and area_red > area_green: + return "red" + elif area_yellow > min_area: + return "yellow" + else: + return "unknown" + + def _get_voice_priority(self, guidance_text): + """获取语音指令的优先级 + 优先级:障碍物(100) > 转向/平移(50) > 保持直行(10) + """ + if not guidance_text: + return 0 + + # 障碍物播报 - 最高优先级 + obstacle_keywords = ['前方有', '左侧有', '右侧有', '停一下', '注意避让', '障碍物'] + for keyword in obstacle_keywords: + if keyword in guidance_text: + return 100 + + # 转向和平移 - 中等优先级 + direction_keywords = ['左转', '右转', '左移', '右移', '向左', '向右', '平移', '微调'] + for keyword in direction_keywords: + if keyword in guidance_text: + return 50 + + # 保持直行 - 最低优先级 + if '保持直行' in guidance_text or '继续前进' in guidance_text or '方向正确' in guidance_text: + return 10 + + # 其他指令 - 默认中等优先级 + return 30 + + def process_frame(self, image: np.ndarray) -> ProcessingResult: + """ + 处理单帧图像 + :param image: BGR格式的图像 + :return: 处理结果 + """ + # 【Day 15 性能诊断】帧处理计时 + import time as perf_time + frame_start_time = perf_time.perf_counter() + timing_log = {} + + self.frame_counter += 1 + + # 更新冷却期 + if self.turn_cooldown_frames > 0: + self.turn_cooldown_frames -= 1 + + image_height, image_width = image.shape[:2] + image_center_x = image_width / 2 + + # 转换为灰度图 + t0 = perf_time.perf_counter() + curr_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + timing_log['grayscale'] = (perf_time.perf_counter() - t0) * 1000 + + # 可视化元素列表 + frame_visualizations = [] + guidance_text = "" + + # Day 20 优化: 并行检测 - 盲道和障碍物同时检测 + # 统一检测间隔,减少总体延迟 + UNIFIED_DETECTION_INTERVAL 
= int(os.getenv("AIGLASS_UNIFIED_INTERVAL", "20")) # Day 21: 从10帧增加到20帧减少卡顿 + + t0 = perf_time.perf_counter() + if self.frame_counter % UNIFIED_DETECTION_INTERVAL == 0: + # 尝试使用 GPU 并行检测 + try: + from gpu_parallel import ParallelDetector + if not hasattr(self, '_parallel_detector'): + self._parallel_detector = ParallelDetector(self.yolo_model, self.obstacle_detector) + logger.info("[Day20] GPU 并行检测器已初始化") + + # 并行执行盲道检测和障碍物检测 + blind_path_mask, crosswalk_mask, detected_obstacles = self._parallel_detector.detect_all( + image, self.last_blindpath_mask + ) + + self.last_blindpath_mask = blind_path_mask + self.last_crosswalk_mask = crosswalk_mask + self.last_detected_obstacles = detected_obstacles + self.last_obstacle_detection_frame = self.frame_counter + + timing_log['yolo'] = (perf_time.perf_counter() - t0) * 1000 + timing_log['obstacle'] = 0 # 并行执行,计入 yolo 时间 + + except ImportError: + # 回退到串行检测 + logger.warning("[Day20] gpu_parallel 模块不可用,使用串行检测") + blind_path_mask, crosswalk_mask = self._detect_path_and_crosswalk(image) + self.last_blindpath_mask = blind_path_mask + self.last_crosswalk_mask = crosswalk_mask + timing_log['yolo'] = (perf_time.perf_counter() - t0) * 1000 + + t0 = perf_time.perf_counter() + detected_obstacles = self._detect_obstacles(image, blind_path_mask) + self.last_detected_obstacles = detected_obstacles + self.last_obstacle_detection_frame = self.frame_counter + timing_log['obstacle'] = (perf_time.perf_counter() - t0) * 1000 + else: + # 使用缓存 + blind_path_mask = self.last_blindpath_mask + crosswalk_mask = self.last_crosswalk_mask + timing_log['yolo'] = 0 + + if self.frame_counter - self.last_obstacle_detection_frame < self.OBSTACLE_CACHE_DURATION_FRAMES: + detected_obstacles = self.last_detected_obstacles + else: + detected_obstacles = [] + timing_log['obstacle'] = 0 # 未执行检测 + + # 添加所有障碍物的可视化(不只是近距离的) + t0 = perf_time.perf_counter() + for i, obs in enumerate(detected_obstacles): + self._add_obstacle_visualization(obs, frame_visualizations) + 
timing_log['obstacle_viz'] = (perf_time.perf_counter() - t0) * 1000 + + # 【新增】检查近距离障碍物并设置语音 + self._check_and_set_obstacle_voice(detected_obstacles) + + # 【新增】斑马线感知处理 + # 【Day 15 优化】减少每帧日志输出,只在每 30 帧输出一次 + if crosswalk_mask is not None and self.frame_counter % 30 == 0: + cross_pixels = np.sum(crosswalk_mask > 0) + if cross_pixels > 0: + logger.info(f"[斑马线] monitor: pixels={cross_pixels}, area={cross_pixels/crosswalk_mask.size*100:.2f}%") + elif crosswalk_mask is None and self.frame_counter % 30 == 0: + logger.info(f"[斑马线] crosswalk_mask为None") + + crosswalk_guidance = self.crosswalk_monitor.process_frame(crosswalk_mask, blind_path_mask) + if crosswalk_guidance: + logger.info(f"[斑马线感知] 检测结果: area={crosswalk_guidance.get('area', 0):.3f}, " + f"should_broadcast={crosswalk_guidance.get('should_broadcast', False)}, " + f"voice={crosswalk_guidance.get('voice_text', 'None')}") + if crosswalk_guidance and crosswalk_guidance['should_broadcast']: + # 将斑马线语音加入待播报列表(通过pending机制) + if not hasattr(self, 'pending_crosswalk_voice'): + self.pending_crosswalk_voice = None + self.pending_crosswalk_voice = crosswalk_guidance + logger.info(f"[斑马线语音] 已设置待播报语音: {crosswalk_guidance['voice_text']}, 优先级{crosswalk_guidance['priority']}") + + # 【新增】添加斑马线可视化 + if crosswalk_mask is not None: + # 计算可视化数据 + total_pixels = crosswalk_mask.size + crosswalk_pixels = np.sum(crosswalk_mask > 0) + area_ratio = crosswalk_pixels / total_pixels + + y_coords, x_coords = np.where(crosswalk_mask > 0) + if len(y_coords) > 0: + center_x_ratio = np.mean(x_coords) / crosswalk_mask.shape[1] + center_y_ratio = np.mean(y_coords) / crosswalk_mask.shape[0] + has_occlusion = self.crosswalk_monitor._check_occlusion(crosswalk_mask, blind_path_mask) + + # 获取可视化数据 + viz_data = self.crosswalk_monitor.get_visualization_data( + crosswalk_mask, area_ratio, center_x_ratio, center_y_ratio, has_occlusion + ) + + # 添加斑马线mask可视化 + self._add_mask_visualization(crosswalk_mask, frame_visualizations, + "crosswalk_mask", 
viz_data['stage_color']) + + # 添加斑马线检测信息可视化 + self._add_crosswalk_info_visualization(viz_data, image_height, image_width, + frame_visualizations) + + # 【已禁用】4. 更新斑马线追踪器 - 盲道导航不再跳转到斑马线 + # self._update_crosswalk_tracker(crosswalk_mask, image_height, image_width) + + # 5. 添加路径可视化 + # 【恢复】盲道mask可视化 + self._add_mask_visualization(blind_path_mask, frame_visualizations, "blind_path_mask", "rgba(0, 255, 0, 0.4)") + # 【斑马线可视化由crosswalk_monitor处理,不在这里添加】 + + + # 【已禁用】5. 根据状态执行不同的导航逻辑 - 盲道导航不再处理斑马线 + current_stage = 'not_detected' # 固定为不检测斑马线 + # current_stage = self.crosswalk_tracker['stage'] # 已禁用 + + # 直接进行盲道导航,不检查斑马线状态 + if False: # current_stage == 'ready': + # 检查是否已经播报过准备提示 + if not hasattr(self, 'crosswalk_ready_announced'): + self.crosswalk_ready_announced = False + self.crosswalk_ready_time = 0 + + current_time = time.time() + + # 检测红绿灯 + traffic_light_color = self.detect_traffic_light(image) + self.traffic_light_history.append(traffic_light_color) + + # 调试信息 + if self.frame_counter % 30 == 0: # 每30帧打印一次 + logger.info(f"[红绿灯检测] 当前颜色: {traffic_light_color}, 历史: {list(self.traffic_light_history)}") + + # 多数表决,获得稳定的红绿灯状态 + if len(self.traffic_light_history) >= 3: + color_counts = {} + for color in self.traffic_light_history: + color_counts[color] = color_counts.get(color, 0) + 1 + # 获取出现次数最多的颜色 + stable_color = max(color_counts.items(), key=lambda x: x[1])[0] + else: + stable_color = "unknown" + + # 添加红绿灯状态可视化 + self._add_traffic_light_visualization( + stable_color, frame_visualizations, image_height, image_width + ) + + # 决定语音播报 + if not self.crosswalk_ready_announced: + guidance_text = "已对准, 准备切换过马路模式。" + self.crosswalk_ready_announced = True + self.crosswalk_ready_time = current_time + elif stable_color == "green" and not self.green_light_announced: + guidance_text = "绿灯稳定,开始通行。" + self.green_light_announced = True + elif stable_color == "red": + # 红灯时定期提醒 + if current_time - self.crosswalk_ready_time > 5.0: + guidance_text = "正在等待绿灯…" + self.crosswalk_ready_time = 
current_time + else: + guidance_text = "" + else: + guidance_text = "" + + # 添加状态信息 + frame_visualizations.append({ + "type": "data_panel", + "data": { + "状态": "等待过马路", + "红绿灯": stable_color, + "检测历史": len(self.traffic_light_history) + }, + "position": (25, image_height - 120) + }) + + elif False: # current_stage == 'approaching': + guidance_text = self._handle_crosswalk_approaching( + frame_visualizations, image_height, image_width, image + ) + + # elif current_stage in ['far', 'not_detected']: + else: # 总是执行盲道导航 + # 【已禁用】斑马线提示 + # if current_stage == 'far' and not self.crosswalk_tracker['position_announced']: + # guidance_text = "远处发现斑马线,继续直行。" + # self.crosswalk_tracker['position_announced'] = True + + if blind_path_mask is None: + guidance_text = "" + # 【移除左上角文字,改为右上角数据面板】 + frame_visualizations.append({ + "type": "data_panel", + "data": { + "状态": "等待盲道识别" + }, + "position": (image_width - 180, 20) + }) + else: + guidance_text = self._execute_state_machine( + blind_path_mask, image, frame_visualizations, + image_height, image_width, curr_gray + ) + + # 6. 更新缓存 + self.prev_gray = curr_gray + if blind_path_mask is not None: + self.prev_blind_path_mask = blind_path_mask.copy() + if crosswalk_mask is not None: + self.prev_crosswalk_mask = crosswalk_mask.copy() + + # 【改进】语音优先级管理系统 + current_time = time.time() + + # 收集所有可能的语音指令 + voice_candidates = [] + + # 1. 添加主要导航语音 + if guidance_text: + voice_candidates.append({ + 'text': guidance_text, + 'priority': self._get_voice_priority(guidance_text), + 'source': 'navigation' + }) + + # 2. 
检查是否有障碍物语音(独立检查,确保最高优先级) + if hasattr(self, 'pending_obstacle_voice'): + if self.pending_obstacle_voice: + voice_candidates.append({ + 'text': self.pending_obstacle_voice, + 'priority': 100, # 障碍物始终最高优先级 + 'source': 'obstacle' + }) + self.pending_obstacle_voice = None # 清除已处理的障碍物语音 + + # 【新增】检查是否有斑马线语音 + if hasattr(self, 'pending_crosswalk_voice'): + if self.pending_crosswalk_voice: + voice_candidates.append({ + 'text': self.pending_crosswalk_voice['voice_text'], + 'priority': self.pending_crosswalk_voice['priority'], + 'source': 'crosswalk' + }) + self.pending_crosswalk_voice = None # 清除已处理的斑马线语音 + + # 3. 选择优先级最高的语音 + if voice_candidates: + # 按优先级排序,取最高的 + voice_candidates.sort(key=lambda x: x['priority'], reverse=True) + selected_voice = voice_candidates[0] + final_guidance_text = selected_voice['text'] + + # 全局播报冷却(避免任何语音重叠)- Day 22 优化: 降低冷却 + MIN_SPEECH_INTERVAL = 0.8 # 任意两条语音间隔至少0.8秒 (从1.2降低) + if hasattr(self, 'last_any_speech_time'): + if current_time - self.last_any_speech_time < MIN_SPEECH_INTERVAL: + final_guidance_text = "" # 太快了,跳过这次播报 + + # 特殊处理保持直行的节流 + if final_guidance_text == "保持直行": + if self.straight_continuous_mode: + # 持续播报模式:只检查时间间隔 + if current_time - self.last_guide_time >= self.guide_interval: + self.last_guide_time = current_time + self.straight_repeat_count += 1 + self.last_any_speech_time = current_time + else: + final_guidance_text = "" + else: + # 原有的限制模式 + if (current_time - self.last_guide_time >= self.guide_interval) and \ + (self.straight_repeat_count < self.straight_repeat_limit): + self.last_guide_time = current_time + self.straight_repeat_count += 1 + self.last_any_speech_time = current_time + else: + final_guidance_text = "" + elif final_guidance_text and selected_voice['source'] != 'obstacle': + # 【修改】非直行、非障碍物指令 - 支持方向指令持续播报 + # 判断是否是方向指令 + direction_keywords = ["左转", "右转", "左移", "右移", "向左", "向右", "平移", "微调"] + is_direction = any(keyword in final_guidance_text for keyword in direction_keywords) + + if is_direction: + # 
方向指令:支持持续播报 + if final_guidance_text == self.last_direction_message: + # 同一个方向指令,检查时间间隔 + if current_time - self.last_direction_time >= self.direction_interval: + self.last_direction_time = current_time + self.last_any_speech_time = current_time + self.straight_repeat_count = 0 + else: + final_guidance_text = "" # 时间间隔不够,跳过 + else: + # 新的方向指令,立即播报 + self.last_direction_message = final_guidance_text + self.last_direction_time = current_time + self.last_any_speech_time = current_time + self.straight_repeat_count = 0 + else: + # 其他指令:只播报一次 + if final_guidance_text != self.last_guidance_message: + self.last_guidance_message = final_guidance_text + self.straight_repeat_count = 0 + self.last_any_speech_time = current_time + else: + final_guidance_text = "" + elif final_guidance_text and selected_voice['source'] == 'obstacle': + # 障碍物语音总是播报 + self.last_any_speech_time = current_time + elif final_guidance_text and selected_voice['source'] == 'crosswalk': + # 斑马线语音总是播报(不受重复检查限制) + self.last_any_speech_time = current_time + + # 播报选中的语音 + if final_guidance_text: + try: + # 【优化】组合语音只播第一部分,避免队列积压 + if selected_voice.get('source') == 'crosswalk' and ',' in final_guidance_text: + voice_parts = split_combined_voice(final_guidance_text) + logger.info(f"[斑马线语音] 组合播报检测到{len(voice_parts)}部分,只播第一部分保持实时") + # 只播放第一部分,后续部分丢弃以保持实时性 + if voice_parts: + # 【移除】play_voice_text(voice_parts[0]) - 由app_main统一处理 + final_guidance_text = voice_parts[0] # 只保留第一部分 + logger.info(f"[语音待播] 优先级{selected_voice['priority']}: {voice_parts[0]}") + else: + # 【移除】play_voice_text(final_guidance_text) - 由app_main统一处理 + logger.info(f"[语音待播] 优先级{selected_voice['priority']}: {final_guidance_text}") + except Exception as e: + logger.error(f"[语音播报] 播放失败: {e}") + else: + final_guidance_text = "" + + # 7. 
生成标注图像 + # Day 20 优化:移除 image.copy(),直接在原图上绘制(输入图像是临时的) + t0 = perf_time.perf_counter() + + if frame_visualizations: + annotated_image = self._draw_visualizations(image, frame_visualizations) + else: + annotated_image = image + + # 添加底部指令按钮(显示当前实际播报的语音) + current_instruction = final_guidance_text if final_guidance_text else "等待中..." + annotated_image = self._draw_command_button(annotated_image, current_instruction) + timing_log['visualization'] = (perf_time.perf_counter() - t0) * 1000 + + # 【Day 15 性能诊断】每 30 帧输出一次详细性能报告 + frame_total_time = (perf_time.perf_counter() - frame_start_time) * 1000 + if self.frame_counter % 30 == 0: + logger.info(f"[PERF] Frame={self.frame_counter} 总耗时={frame_total_time:.1f}ms | " + f"YOLO={timing_log.get('yolo', 0):.1f}ms, " + f"障碍物={timing_log.get('obstacle', 0):.1f}ms, " + f"可视化={timing_log.get('visualization', 0):.1f}ms, " + f"灰度={timing_log.get('grayscale', 0):.1f}ms") + + # 8. 返回结果 + # 【修改】返回 final_guidance_text(经过节流的),由 app_main 统一播放 + return ProcessingResult( + guidance_text=final_guidance_text, + visualizations=frame_visualizations, + annotated_image=annotated_image, + state_info={ + "state": self.current_state, + "crosswalk_stage": current_stage, + "frame_count": self.frame_counter + } + ) + + def _detect_path_and_crosswalk(self, image: np.ndarray) -> Tuple[Optional[np.ndarray], Optional[np.ndarray]]: + """检测盲道和斑马线 - Day 22 性能优化版本""" + if self.yolo_model is None: + # 【新增】没有模型时返回模拟数据用于测试 + logger.warning("YOLO模型未加载,返回模拟数据") + h, w = image.shape[:2] + # 创建一个模拟的盲道掩码(垂直居中的条带) + blind_path_mask = np.zeros((h, w), dtype=np.uint8) + # 在图像中央创建一个宽度为图像宽度20%的垂直条带 + strip_width = int(w * 0.2) + strip_left = (w - strip_width) // 2 + blind_path_mask[int(h*0.3):, strip_left:strip_left+strip_width] = 255 + return blind_path_mask, None + + blind_path_mask = None + crosswalk_mask = None + + try: + # Day 22 优化: 动态调整输入分辨率以提升性能 + # 可通过环境变量 AIGLASS_YOLO_IMGSZ 配置,默认480(从640降低) + imgsz = int(os.getenv("AIGLASS_YOLO_IMGSZ", "480")) + + min_conf = 
min(self.CLASS_CONF_THRESHOLDS.values()) + + # Day 22 优化: 使用 half 精度加速(如果GPU支持) + use_half = os.getenv("AIGLASS_YOLO_HALF", "1") == "1" + + results = self.yolo_model.predict( + image, + verbose=False, + conf=min_conf, + classes=[0, 1], + imgsz=imgsz, # 使用较小的输入尺寸 + half=use_half # FP16 半精度加速 + ) + + if (results and results[0] and results[0].masks is not None and + results[0].boxes is not None and len(results[0].masks.data) > 0): + + for mask_tensor, conf_tensor, cls_tensor in zip( + results[0].masks.data, results[0].boxes.conf, results[0].boxes.cls + ): + class_id = int(cls_tensor.item()) + confidence = float(conf_tensor.item()) + threshold = self.CLASS_CONF_THRESHOLDS.get(class_id, 1.0) + + if confidence >= threshold: + current_mask = self._tensor_to_mask(mask_tensor, image.shape[1], image.shape[0]) + + if class_id == 1: # 盲道 + if blind_path_mask is None: + blind_path_mask = current_mask + else: + blind_path_mask = cv2.bitwise_or(blind_path_mask, current_mask) + elif class_id == 0: # 斑马线 + if crosswalk_mask is None: + crosswalk_mask = current_mask + else: + crosswalk_mask = cv2.bitwise_or(crosswalk_mask, current_mask) + except Exception as e: + logger.error(f"YOLO检测失败: {e}") + # 【新增】检测失败时也返回模拟数据 + h, w = image.shape[:2] + blind_path_mask = np.zeros((h, w), dtype=np.uint8) + strip_width = int(w * 0.2) + strip_left = (w - strip_width) // 2 + blind_path_mask[int(h*0.3):, strip_left:strip_left+strip_width] = 255 + + return blind_path_mask, crosswalk_mask + + def _tensor_to_mask(self, mask_tensor, out_w: int, out_h: int, binarize: bool = True) -> np.ndarray: + """将张量掩码转换为numpy数组""" + try: + import torch + + if not isinstance(mask_tensor, torch.Tensor): + arr = np.asarray(mask_tensor) + if arr.dtype != np.uint8: + arr = (arr > 0.5).astype(np.uint8) * 255 if binarize else (arr * 255.0).astype(np.uint8) + mask_u8 = arr + else: + if mask_tensor.dtype in (torch.bfloat16, torch.float16): + mask_tensor = mask_tensor.to(torch.float32) + + if mask_tensor.ndim > 2: + mask_tensor 
= mask_tensor.squeeze() + + if binarize: + mask_tensor = (mask_tensor > 0.5).to(torch.uint8).mul_(255) + mask_u8 = mask_tensor.cpu().numpy() + else: + mask_u8 = (mask_tensor.mul(255).clamp_(0, 255).to(torch.uint8)).cpu().numpy() + + if mask_u8.ndim == 3: + mask_u8 = mask_u8.squeeze(-1) + + if mask_u8.shape[1] != out_w or mask_u8.shape[0] != out_h: + mask_u8 = cv2.resize(mask_u8, (out_w, out_h), interpolation=cv2.INTER_NEAREST) + + return mask_u8 + except ImportError: + # 如果没有torch,返回空掩码 + return np.zeros((out_h, out_w), dtype=np.uint8) + + def _stabilize_mask(self, prev_gray, curr_gray, raw_mask, prev_stable_mask, mask_type): + """稳定化掩码 - 使用 Lucas-Kanade 光流""" + if mask_type == 'blind_path': + ttl = self.blind_miss_ttl + min_area = self.MASK_STAB_MIN_AREA + else: # crosswalk + ttl = self.cross_miss_ttl + min_area = self.MASK_STAB_MIN_AREA + + # 调用新的光流稳定化方法 + stable_mask = self._stabilize_seg_mask( + prev_gray, curr_gray, raw_mask, prev_stable_mask, + (curr_gray.shape[1], curr_gray.shape[0]) if curr_gray is not None else (640, 480), + min_area_px=min_area, + morph_kernel=self.MASK_STAB_KERNEL, + mask_type=mask_type + ) + + if stable_mask is not None: + # 重置TTL + if mask_type == 'blind_path': + self.blind_miss_ttl = self.MASK_MISS_TTL + else: + self.cross_miss_ttl = self.MASK_MISS_TTL + return stable_mask + else: + # 减少TTL + if mask_type == 'blind_path': + self.blind_miss_ttl = max(0, self.blind_miss_ttl - 1) + else: + self.cross_miss_ttl = max(0, self.cross_miss_ttl - 1) + return None + + def _stabilize_seg_mask(self, prev_gray, curr_gray, curr_mask, prev_stable_mask, + image_wh, min_area_px=1500, morph_kernel=3, iou_high_thr=0.4, mask_type='', + fast_clear=True): + """使用 Lucas-Kanade 光流的掩码稳定化实现""" + W, H = image_wh + + def _binarize(mask): + if mask is None: + return None + if mask.dtype != np.uint8: + mask = mask.astype(np.uint8) + mask = (mask > 0).astype(np.uint8) * 255 + return mask + + def _morph_smooth(mask, kernel_size): + if mask is None: + return None + k 
= cv2.getStructuringElement(cv2.MORPH_ELLIPSE, + (max(1, kernel_size), max(1, kernel_size))) + sm = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, k, iterations=1) + sm = cv2.morphologyEx(sm, cv2.MORPH_OPEN, k, iterations=1) + return sm + + curr_mask_b = _binarize(curr_mask) + prev_mask_b = _binarize(prev_stable_mask) + + # 如果没有历史数据,直接返回当前掩码 + if prev_mask_b is None or prev_gray is None or curr_gray is None: + return _morph_smooth(curr_mask_b, morph_kernel) if curr_mask_b is not None else None + + # 当前帧有检测结果 + if curr_mask_b is not None and np.sum(curr_mask_b > 0) >= min_area_px: + # 计算与上一帧的IoU + if prev_mask_b is not None: + inter = np.logical_and(curr_mask_b > 0, prev_mask_b > 0).sum() + union = np.logical_or(curr_mask_b > 0, prev_mask_b > 0).sum() + iou = float(inter) / float(union) if union > 0 else 0.0 + + # IoU足够高,说明检测稳定,直接使用当前结果 + if iou >= iou_high_thr: + return _morph_smooth(curr_mask_b, morph_kernel) + + # IoU较低但仍有重叠,进行加权融合 + elif iou > 0.1: + # 使用光流预测的掩码 + flow_mask = self._predict_mask_with_flow(prev_mask_b, prev_gray, curr_gray) + if flow_mask is not None: + # 根据IoU动态调整权重 + # IoU越低,越依赖光流;IoU越高,越依赖当前检测 + w_curr = min(0.9, 0.4 + iou) # IoU=0.1时w_curr=0.5, IoU=0.5时w_curr=0.9 + w_flow = 1.0 - w_curr + + fused = (w_curr * curr_mask_b.astype(np.float32) + + w_flow * flow_mask.astype(np.float32)) + fused_bin = (fused >= 128).astype(np.uint8) * 255 + + # 重新初始化光流点(如果IoU过低) + if iou < self.flow_iou_threshold: + self.flow_points['blind_path'] = None + + return _morph_smooth(fused_bin, morph_kernel) + + # 没有历史或IoU太低,使用当前检测 + return _morph_smooth(curr_mask_b, morph_kernel) + + # 当前帧没有检测结果,尝试使用光流外推 + else: + # 获取对应的TTL + if mask_type == 'blind_path': + ttl = self.blind_miss_ttl + else: + ttl = self.cross_miss_ttl + + # 【修改】当前帧无检测结果,快速清除 + if fast_clear and ttl <= 1: + # TTL耗尽,立即返回None,不使用光流 + return None + + if prev_mask_b is not None and np.sum(prev_mask_b > 0) >= min_area_px and ttl > 0: + # 使用光流预测 + flow_mask = self._predict_mask_with_flow(prev_mask_b, prev_gray, 
curr_gray) + if flow_mask is not None and np.sum(flow_mask > 0) >= min_area_px * 0.5: + return _morph_smooth(flow_mask, morph_kernel) + + # 光流失败或超过TTL + return None + + def _predict_mask_with_flow(self, prev_mask, prev_gray, curr_gray): + """使用Lucas-Kanade光流预测掩码位置(改进版)""" + try: + # 方法1:尝试使用凸包方法(参考yolomedia) + if hasattr(self, 'flow_points') and 'blind_path' in self.flow_points: + p0 = self.flow_points['blind_path'] + if p0 is not None and len(p0) >= 5: + # 计算光流 + p1, st, err = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, p0, None, **self.lk_params) + + if p1 is not None and st is not None: + good_new = p1[st == 1] + if len(good_new) >= 5: + # 更新光流点 + self.flow_points['blind_path'] = good_new.reshape(-1, 1, 2) + + # 生成凸包掩码 + hull = cv2.convexHull(good_new.reshape(-1, 1, 2)) + poly = hull.reshape(-1, 2) + + if len(poly) >= 3: + H, W = curr_gray.shape[:2] + flow_mask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(flow_mask, [poly.astype(np.int32)], 255) + return flow_mask + + # 方法2:边缘特征点方法(原有方法,作为备选) + edge_mask = self._get_edge_mask(prev_mask, offset=10) + + # 检测特征点 + p0 = cv2.goodFeaturesToTrack(prev_gray, mask=edge_mask, **self.feature_params) + if p0 is None or len(p0) < 8: + return None + + # 保存特征点供下次使用 + self.flow_points['blind_path'] = p0 + + # 计算光流 + p1, st, err = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, p0, None, **self.lk_params) + + if p1 is None or st is None: + return None + + # 只保留成功追踪的点 + good_new = p1[st == 1] + good_old = p0[st == 1] + + if len(good_new) < 5: + return None + + # 估计变换矩阵(使用RANSAC提高鲁棒性) + M, inliers = cv2.estimateAffinePartial2D(good_old, good_new, method=cv2.RANSAC, ransacReprojThreshold=5.0) + + if M is None: + return None + + # 应用变换 + H, W = curr_gray.shape[:2] + flow_mask = cv2.warpAffine(prev_mask, M, (W, H), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0) + + return flow_mask + + except Exception as e: + logger.debug(f"光流预测失败: {e}") + return None + + + def _get_edge_mask(self, mask, 
offset=10): + """获取掩码的内边缘区域,用于特征点检测""" + if mask is None: + return None + + # 腐蚀得到内部掩码 + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (offset*2, offset*2)) + inner = cv2.erode(mask, kernel, iterations=1) + + # 边缘 = 原始 - 内部 + edge = cv2.subtract(mask, inner) + + # 稍微膨胀边缘区域 + kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + edge = cv2.dilate(edge, kernel_small, iterations=1) + + return edge + + def _smooth_centerline(self, centerline_data): + """平滑中心线数据,减少抖动""" + if centerline_data is None or len(centerline_data) < 5: + return centerline_data + + # 保存到历史记录 + self.centerline_history.append(centerline_data.copy()) + if len(self.centerline_history) > self.centerline_history_max: + self.centerline_history.pop(0) + + # 如果历史记录不足,返回轻度平滑的当前帧数据 + if len(self.centerline_history) < 3: + # 对当前帧进行空间平滑 + smoothed_data = centerline_data.copy() + # 使用滑动窗口平均 + window_size = 5 + for i in range(len(smoothed_data)): + start_idx = max(0, i - window_size // 2) + end_idx = min(len(smoothed_data), i + window_size // 2 + 1) + window = smoothed_data[start_idx:end_idx] + if len(window) > 0: + smoothed_data[i, 1] = np.mean(window[:, 1]) # 平滑x坐标 + smoothed_data[i, 2] = np.mean(window[:, 2]) # 平滑宽度 + return smoothed_data + + # 时间平滑:使用历史帧的加权平均 + smoothed_data = centerline_data.copy() + + # 为每个y坐标找到历史帧中对应的数据 + for i, (y, x, width) in enumerate(centerline_data): + x_values = [x] + width_values = [width] + weights = [1.0] # 当前帧权重最高 + + # 从历史帧中查找相近y坐标的数据 + for hist_idx, hist_data in enumerate(self.centerline_history[-3:-1]): # 使用最近的2帧历史 + # 找到最接近的y坐标 + y_diffs = np.abs(hist_data[:, 0] - y) + if len(y_diffs) > 0: + closest_idx = np.argmin(y_diffs) + if y_diffs[closest_idx] < 10: # y坐标差异小于10像素 + x_values.append(hist_data[closest_idx, 1]) + width_values.append(hist_data[closest_idx, 2]) + # 历史帧权重递减 + weights.append(0.5 ** (len(self.centerline_history) - hist_idx - 1)) + + # 加权平均 + if len(x_values) > 1: + weights = np.array(weights) + weights = weights / np.sum(weights) + 
smoothed_data[i, 1] = np.sum(np.array(x_values) * weights) + smoothed_data[i, 2] = np.sum(np.array(width_values) * weights) + + # 空间平滑:对结果再进行一次滑动窗口平均 + window_size = 3 + final_data = smoothed_data.copy() + for i in range(len(final_data)): + start_idx = max(0, i - window_size // 2) + end_idx = min(len(final_data), i + window_size // 2 + 1) + window = smoothed_data[start_idx:end_idx] + if len(window) > 0: + final_data[i, 1] = np.mean(window[:, 1]) + final_data[i, 2] = np.mean(window[:, 2]) + + return final_data + + def _estimate_affine(self, prev_gray, curr_gray, mask=None): + """使用光流估计仿射变换(备用方法)""" + try: + # 提取特征点 + if mask is not None: + p0 = cv2.goodFeaturesToTrack(prev_gray, mask=mask, **self.feature_params) + else: + p0 = cv2.goodFeaturesToTrack(prev_gray, **self.feature_params) + + if p0 is None or len(p0) < 4: + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32) + + # 计算光流 + p1, st, err = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, p0, None, **self.lk_params) + + if p1 is None or st is None: + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32) + + # 只保留好的点 + good_new = p1[st == 1].reshape(-1, 2) + good_old = p0[st == 1].reshape(-1, 2) + + if len(good_new) < 4: + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32) + + # 估计仿射变换 + M, _ = cv2.estimateAffinePartial2D(good_old, good_new, method=cv2.RANSAC) + + if M is None: + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32) + + return M + + except Exception as e: + logger.debug(f"仿射估计失败: {e}") + return np.array([[1, 0, 0], [0, 1, 0]], dtype=np.float32) + + def _warp_mask(self, mask, M, output_shape): + """应用仿射变换""" + try: + W, H = output_shape + warped = cv2.warpAffine(mask, M, (W, H), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0) + return warped + except: + return None + + def _add_mask_visualization(self, mask, visualizations, viz_type, color, add_outline=True): + """添加掩码可视化(增加描边)""" + if mask is None: + return + + try: + contours, _ = 
cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if contours: + main_contour = max(contours, key=cv2.contourArea) + points = main_contour.squeeze(1)[::5].tolist() + + # 添加填充 + visualizations.append({ + "type": viz_type, + "points": points, + "color": color + }) + + # 添加描边(盲道不添加描边) + if add_outline and viz_type != "blind_path_mask": + visualizations.append({ + "type": "outline", + "points": points, + "color": "rgba(255, 255, 255, 0.8)", # 白色描边 + "thickness": 3 + }) + except: + pass + + + def _update_crosswalk_tracker(self, crosswalk_mask, image_height, image_width): + """更新斑马线追踪器""" + if crosswalk_mask is not None: + self.crosswalk_tracker['consecutive_frames'] += 1 + self.crosswalk_tracker['last_seen_frame'] = self.frame_counter + + # 计算关键指标 + total_area = image_height * image_width + area_ratio = np.sum(crosswalk_mask > 0) / total_area + y_coords, x_coords = np.where(crosswalk_mask > 0) + + if len(y_coords) > 0: + bottom_y_ratio = np.max(y_coords) / image_height + center_x_ratio = np.mean(x_coords) / image_width + + self.crosswalk_tracker['last_area_ratio'] = area_ratio + self.crosswalk_tracker['last_bottom_y_ratio'] = bottom_y_ratio + self.crosswalk_tracker['last_center_x_ratio'] = center_x_ratio + + # 计算角度 + try: + contours, _ = cv2.findContours(crosswalk_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if contours: + main_contour = max(contours, key=cv2.contourArea) + rect = cv2.minAreaRect(main_contour) + angle = rect[-1] + w, h = rect[1] + if w < h: + angle += 90 + self.crosswalk_tracker['last_angle'] = angle + except: + self.crosswalk_tracker['last_angle'] = 0.0 + + # 状态切换 + is_ready_to_switch = ( + area_ratio >= self.CROSSWALK_SWITCH_AREA_RATIO and + bottom_y_ratio >= self.CROSSWALK_SWITCH_BOTTOM_RATIO or + (self.crosswalk_tracker['consecutive_frames'] >= self.CROSSWALK_SWITCH_CONSECUTIVE_FRAMES + and area_ratio > 0.18) + ) + + if is_ready_to_switch and self.crosswalk_tracker['alignment_status'] == 'aligned': + if 
self.crosswalk_tracker['stage'] != 'ready': + self.crosswalk_tracker['stage'] = 'ready' + elif area_ratio > 0.07 or bottom_y_ratio > 0.75: + if self.crosswalk_tracker['stage'] in ['far', 'not_detected']: + self.crosswalk_tracker['stage'] = 'approaching' + elif area_ratio > 0.01: + if self.crosswalk_tracker['stage'] == 'not_detected': + self.crosswalk_tracker['stage'] = 'far' + else: + # 丢失检测 + if self.frame_counter - self.crosswalk_tracker['last_seen_frame'] > 15: + self.crosswalk_tracker['stage'] = 'not_detected' + self.crosswalk_tracker['consecutive_frames'] = 0 + self.crosswalk_tracker['position_announced'] = False + self.crosswalk_tracker['alignment_status'] = 'not_aligned' + # 重置准备状态标志 + if hasattr(self, 'crosswalk_ready_announced'): + self.crosswalk_ready_announced = False + self.crosswalk_ready_time = 0 + if hasattr(self, 'traffic_light_history'): + self.traffic_light_history.clear() + self.green_light_announced = False + + def _handle_crosswalk_approaching(self, frame_visualizations, image_height, image_width, image): + """处理接近斑马线的情况""" + # 障碍物检测 + if self.obstacle_detector and self.frame_counter % self.OBSTACLE_DETECTION_INTERVAL == 0: + detected_obstacles = self._detect_obstacles(image) + self.last_detected_obstacles = detected_obstacles + self.last_obstacle_detection_frame = self.frame_counter + + # 添加障碍物可视化 + for obs in self.last_detected_obstacles: + self._add_obstacle_visualization(obs, frame_visualizations) + + # 优先检查近距离障碍物(提高阈值,只有非常近才报警) + NEAR_DISTANCE_Y_THRESHOLD = 0.75 # 提高到0.75 + NEAR_DISTANCE_AREA_THRESHOLD = 0.12 # 提高到0.12 + near_obstacles = [ + obs for obs in self.last_detected_obstacles + if (obs.get('bottom_y_ratio', 0) > NEAR_DISTANCE_Y_THRESHOLD or + obs.get('area_ratio', 0) > NEAR_DISTANCE_AREA_THRESHOLD) + ] + + # 如果有近距离障碍物,应用相同的播报逻辑 + if near_obstacles: + main_obstacle = near_obstacles[0] + obstacle_name = main_obstacle.get('name', '') + current_time = time.time() + + # 检查是否需要播报(避免重复) + should_announce = False + if obstacle_name != 
self.last_obstacle_speech: + should_announce = True + self.last_obstacle_speech = obstacle_name + self.last_obstacle_speech_time = current_time + elif current_time - self.last_obstacle_speech_time > self.obstacle_speech_cooldown: + should_announce = True + self.last_obstacle_speech_time = current_time + + if should_announce: + return self._speech_for_obstacle(obstacle_name) + else: + # 没有障碍物时清空记录 + self.last_obstacle_speech = "" + + # 对准逻辑 + if self.crosswalk_tracker['alignment_status'] == 'not_aligned': + guidance_text = "正在接近斑马线,为您对准方向。" + self.crosswalk_tracker['alignment_status'] = 'aligning' + else: + angle = self.crosswalk_tracker['last_angle'] + center_x_ratio = self.crosswalk_tracker['last_center_x_ratio'] + + ANGLE_ALIGN_THRESHOLD = 15 + POSITION_ALIGN_THRESHOLD = 0.25 + + if abs(angle) > ANGLE_ALIGN_THRESHOLD: + guidance_text = "右转" if angle < 0 else "左转" + elif abs(center_x_ratio - 0.5) > (POSITION_ALIGN_THRESHOLD / 2): + guidance_text = "右移" if center_x_ratio < 0.5 else "左移" + else: + self.crosswalk_tracker['alignment_status'] = 'aligned' + guidance_text = "斑马线已对准,继续前行。" + + # 添加数据面板 + data_for_panel = { + "状态": "对准斑马线", + "引导": guidance_text, + "角度": f"{self.crosswalk_tracker['last_angle']:.1f}°", + "偏移": f"{(self.crosswalk_tracker['last_center_x_ratio'] - 0.5):.2f}" + } + frame_visualizations.append({ + "type": "data_panel", + "data": data_for_panel, + "position": (25, image_height - 75) + }) + + return guidance_text + + def _execute_state_machine(self, mask, image, frame_visualizations, + image_height, image_width, curr_gray): + """执行状态机逻辑""" + if self.current_state == STATE_ONBOARDING: + return self._handle_onboarding(mask, image, frame_visualizations, + image_height, image_width) + elif self.current_state == STATE_NAVIGATING: + return self._handle_navigating(mask, image, frame_visualizations, + image_height, image_width, curr_gray) + elif self.current_state == STATE_MANEUVERING_TURN: + return self._handle_maneuvering_turn(mask, image, 
frame_visualizations, + image_height, image_width) + elif self.current_state == STATE_LOCKING_ON: + return self._handle_locking_on(frame_visualizations) + elif self.current_state == STATE_AVOIDING_OBSTACLE: + return self._handle_avoiding_obstacle(mask, image, frame_visualizations, + image_height, image_width) + + return "" + + def _handle_onboarding(self, mask, image, frame_visualizations, image_height, image_width): + """处理上盲道状态""" + image_center_x = image_width / 2 + vp_features = self._get_vanishing_point_features(mask) + + if vp_features and vp_features['fit_error'] < self.VP_FIT_ERROR_THRESHOLD: + # 使用灭点法 + VP, L_center = vp_features["VP"], vp_features["L_center"] + + if self.onboarding_step == ONBOARDING_STEP_ROTATION: + if abs(VP[0] - image_center_x) < (image_width * self.ONBOARDING_ALIGN_THRESHOLD_RATIO): + guidance_text = "方向已对正!现在校准位置。" + self.onboarding_step = ONBOARDING_STEP_TRANSLATION + else: + guidance_text = "请向左转动。" if VP[0] < image_center_x else "请向右转动。" + + angle_error_px = VP[0] - image_center_x + self._add_data_panel(frame_visualizations, { + "状态": "上盲道 (方向)", + "引导": guidance_text, + "角度": f"{angle_error_px:.1f}px", + "偏移": "待校准" + }, (25, image_height - 75)) + + elif self.onboarding_step == ONBOARDING_STEP_TRANSLATION: + L_center_bottom_x = self._calculate_line_x_at_y(L_center, image_height - 1) + + if L_center_bottom_x: + center_offset_pixels = L_center_bottom_x - image_center_x + center_offset_ratio = abs(center_offset_pixels) / image_width + + if center_offset_ratio < self.ONBOARDING_CENTER_OFFSET_THRESHOLD_RATIO: + guidance_text = "校准完成!您已在盲道上,开始前行。" + self.current_state = STATE_NAVIGATING + else: + guidance_text = "请向左平移。" if L_center_bottom_x < image_center_x else "请向右平移。" + + self._add_data_panel(frame_visualizations, { + "状态": "上盲道 (位置)", + "引导": guidance_text, + "角度": "已对准", + "偏移": f"{center_offset_ratio * 100:.1f}%" + }, (25, image_height - 75)) + else: + guidance_text = "请向前移动,让盲道更清晰。" + else: + # 使用像素域方法 + pixel_features = 
self._get_pixel_domain_features(mask, image.shape) + if not pixel_features: + return "" + self._add_navigation_info_visualization(pixel_features, image_height, image_width, frame_visualizations) + guidance_text = self._handle_pixel_domain_onboarding( + pixel_features, image_height, image_width, frame_visualizations + ) + + return guidance_text + + def _handle_navigating(self, mask, image, frame_visualizations, + image_height, image_width, curr_gray): + """处理常规导航状态""" + image_center_x = image_width / 2 + + # 提取路径特征 + features = self._get_pixel_domain_features(mask, image.shape) + if not features: + return "路径特征提取失败" + self._add_navigation_info_visualization(features, image_height, image_width, frame_visualizations) + + # 转弯检测 + if self.turn_cooldown_frames == 0: + corner_info = self._detect_sharp_corner(features['centerline_data']) + if corner_info: + self._update_turn_tracker(corner_info) + + if self.turn_detection_tracker['consecutive_hits'] >= 3: + stable_corner_info = self.turn_detection_tracker['corner_info'] + corner_y = stable_corner_info['corner_point_pixel'][1] + turn_trigger_y_threshold = image_height * 0.65 + + if corner_y > turn_trigger_y_threshold: + # 触发转弯 + direction_text = '右' if self.turn_detection_tracker['direction'] == 'right' else '左' + self.current_state = STATE_MANEUVERING_TURN + self.maneuver_target_info = stable_corner_info + self.maneuver_step = MANEUVER_STEP_1_ISSUE_COMMAND + self._reset_turn_tracker() + # 不再播报"到达转弯处",直接返回空字符串,让后续逻辑处理 + return "" + else: + # 不再预告转弯,继续常规导航 + pass + + # 优先级1:障碍物检测(最高优先级) + obstacles = self._check_obstacles(image, mask, frame_visualizations) + if obstacles: + # 获取主要障碍物 + main_obstacle = obstacles[0] + obstacle_name = main_obstacle.get('name', '') + current_time = time.time() + + # 检查是否需要播报(避免重复) + should_announce = False + if obstacle_name != self.last_obstacle_speech: + # 不同障碍物,立即播报 + should_announce = True + self.last_obstacle_speech = obstacle_name + self.last_obstacle_speech_time = current_time + elif 
current_time - self.last_obstacle_speech_time > self.obstacle_speech_cooldown: + # 同一障碍物但超过冷却时间,再次播报 + should_announce = True + self.last_obstacle_speech_time = current_time + + if should_announce: + # 不进入完整的避障流程,只是警告 + # 设置待播报的障碍物语音,而不是直接返回 + self.pending_obstacle_voice = self._speech_for_obstacle(obstacle_name) + # 如果不需要播报,继续常规导航 + else: + # 没有障碍物,清空记录 + self.last_obstacle_speech = "" + self.pending_obstacle_voice = None + + # 优先级2:常规导航(左移/右移/左转/右转 > 直行) + return self._generate_navigation_guidance( + features, image_height, image_width, frame_visualizations + ) + + def _handle_maneuvering_turn(self, mask, image, frame_visualizations, + image_height, image_width): + """处理转弯状态""" + features = self._get_pixel_domain_features(mask, image.shape) + if not features: + return "丢失路径,重新搜索。" + self._add_navigation_info_visualization(features, image_height, image_width, frame_visualizations) + if self.maneuver_step == MANEUVER_STEP_1_ISSUE_COMMAND: + direction_text = '右' if self.maneuver_target_info['direction'] == 'right' else '左' + guidance_text = f"请向{direction_text}平移。" + + poly_func = features['poly_func'] + y_check = image_height * 0.7 + self.maneuver_target_info['old_path_center_x'] = poly_func(y_check) + + self.maneuver_step = MANEUVER_STEP_2_WAIT_FOR_SHIFT + + self._add_data_panel(frame_visualizations, { + "状态": "处理转弯", + "引导": guidance_text, + "步骤": "发出指令", + "方向": direction_text + }, (25, image_height - 75)) + + return guidance_text + + elif self.maneuver_step == MANEUVER_STEP_2_WAIT_FOR_SHIFT: + old_path_x = self.maneuver_target_info.get('old_path_center_x') + if old_path_x is None: + self.maneuver_step = MANEUVER_STEP_1_ISSUE_COMMAND + return "" + + poly_func = features['poly_func'] + y_check = image_height * 0.7 + current_path_x = poly_func(y_check) + shift_distance = abs(current_path_x - old_path_x) + + centerline_data = features['centerline_data'] + width_at_check_y = self._get_width_at_y(centerline_data, y_check) + + if shift_distance > (width_at_check_y * 
0.5): + guidance_text = "检测到已移动,开始对准新方向。" + self.maneuver_step = MANEUVER_STEP_3_ALIGN_ON_NEW_PATH + else: + direction_text = '右' if self.maneuver_target_info['direction'] == 'right' else '左' + guidance_text = f"请继续向{direction_text}平移。" + + self._add_data_panel(frame_visualizations, { + "状态": "处理转弯", + "引导": guidance_text, + "步骤": "等待平移", + "偏移量": f"{shift_distance:.1f}px" + }, (25, image_height - 75)) + + return guidance_text + + elif self.maneuver_step == MANEUVER_STEP_3_ALIGN_ON_NEW_PATH: + poly_func = features['poly_func'] + y_check = image_height * 0.5 + current_path_x_at_center = poly_func(y_check) + + pixel_error = current_path_x_at_center - image_width / 2 + center_offset_ratio = abs(pixel_error) / image_width + + if center_offset_ratio < self.NAV_CENTER_OFFSET_THRESHOLD_RATIO: + guidance_text = "已对准新路径,请向前直行。" + self.current_state = STATE_NAVIGATING + self.maneuver_target_info = None + self.turn_cooldown_frames = self.TURN_COOLDOWN_DURATION + else: + move_direction = "右" if pixel_error > 0 else "左" + guidance_text = f"请向{move_direction}微调,对准盲道。" + + self._add_data_panel(frame_visualizations, { + "状态": "处理转弯", + "引导": guidance_text, + "步骤": "对准新路径", + "误差": f"{center_offset_ratio * 100:.1f}%" + }, (25, image_height - 75)) + + return guidance_text + + def _handle_locking_on(self, frame_visualizations): + """处理锁定状态""" + if not self.lock_on_data: + self.current_state = STATE_NAVIGATING + return "" + + main_obstacle = self.lock_on_data['main_obstacle'] + + # 添加脉冲特效 + self._add_obstacle_visualization(main_obstacle, frame_visualizations, pulse_effect=True) + + # 检查时间 + if time.time() - self.lock_on_data['start_time'] > 0.7: + self.avoidance_plan = self.lock_on_data['avoidance_plan'] + self.avoidance_step_index = 0 + self.current_state = STATE_AVOIDING_OBSTACLE + self.lock_on_data = None + + return "" + + def _handle_avoiding_obstacle(self, mask, image, frame_visualizations, + image_height, image_width): + """处理避障状态""" + if not self.avoidance_plan or 
self.avoidance_step_index >= len(self.avoidance_plan): + self.current_state = STATE_NAVIGATING + self.avoidance_plan = None + return "避让完成,已回到盲道。" + + step = self.avoidance_plan[self.avoidance_step_index] + + if step['type'] == 'sidestep_clear': + direction = step['direction'] + + if self.obstacle_detector: + final_obstacles = self._detect_obstacles(image, mask) + else: + final_obstacles = [] + + if final_obstacles: + guidance_text = f"路径被挡住,请向{'右' if direction == 'right' else '左'}侧平移。" + else: + guidance_text = "好的,请停下侧移。" + self.avoidance_step_index += 1 + + self._add_data_panel(frame_visualizations, { + "状态": "避障中", + "引导": guidance_text, + "步骤": "侧向移出", + "方向": direction + }, (25, image_height - 75)) + + return guidance_text + + elif step['type'] == 'forward_pass': + # 简化处理,直接进入下一步 + self.avoidance_step_index += 1 + return "向前直行几步越过障碍物。然后说‘好了’。" + + elif step['type'] == 'sidestep_return': + direction = step['direction'] + features = self._get_pixel_domain_features(mask, image.shape) + + if not features: + return f"没看到盲道,请向{'右' if direction == 'right' else '左'}侧小幅移动。" + + poly_func = features['poly_func'] + y_target = image_height * 0.5 + x_target = poly_func(y_target) + + center_offset_pixels = x_target - image_width / 2 + center_offset_ratio = abs(center_offset_pixels) / image_width + + if center_offset_ratio < self.NAV_CENTER_OFFSET_THRESHOLD_RATIO: + guidance_text = "已回到盲道。" + self.avoidance_step_index += 1 + else: + guidance_text = "向右平移,对准盲道" if center_offset_pixels > 0 else "向左平移,对准盲道" + + self._add_data_panel(frame_visualizations, { + "状态": "避障中", + "引导": guidance_text, + "步骤": "回归盲道", + "偏移": f"{center_offset_ratio * 100:.1f}%" + }, (25, image_height - 75)) + + return guidance_text + + # ========== 辅助方法 ========== + + def _get_vanishing_point_features(self, mask): + """提取灭点特征""" + try: + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if not contours: + return None + main_contour = max(contours, key=cv2.contourArea) + 
if cv2.contourArea(main_contour) < 5000: + return None + + rect = cv2.minAreaRect(main_contour) + center, _, angle = rect + angle_rad = np.deg2rad(angle) + R = np.array([[np.cos(angle_rad), -np.sin(angle_rad)], + [np.sin(angle_rad), np.cos(angle_rad)]]) + points_transformed = np.dot(main_contour.squeeze(1) - center, R) + left_points = main_contour.squeeze(1)[points_transformed[:, 0] < 0] + right_points = main_contour.squeeze(1)[points_transformed[:, 0] >= 0] + + if len(left_points) < 20 or len(right_points) < 20: + return None + + [vx_l, vy_l, x_l, y_l] = cv2.fitLine(left_points, cv2.DIST_L2, 0, 0.01, 0.01) + [vx_r, vy_r, x_r, y_r] = cv2.fitLine(right_points, cv2.DIST_L2, 0, 0.01, 0.01) + + a1, b1, c1 = vy_l, -vx_l, vx_l * y_l - vy_l * x_l + a2, b2, c2 = vy_r, -vx_r, vx_r * y_r - vy_r * x_r + determinant = a1 * b2 - a2 * b1 + + if abs(determinant) < 1e-6: + return None + + vp_x = (b1 * c2 - b2 * c1) / determinant + vp_y = (a2 * c1 - a1 * c2) / determinant + L_center = ((vx_l + vx_r) / 2, (vy_l + vy_r) / 2, (x_l + x_r) / 2, (y_l + y_r) / 2) + + total_dist = 0 + for pt in left_points: + total_dist += abs((pt[0] - x_l) * vy_l - (pt[1] - y_l) * vx_l) + for pt in right_points: + total_dist += abs((pt[0] - x_r) * vx_r - (pt[1] - y_r) * vy_r) + fit_error = total_dist / (len(left_points) + len(right_points)) + + return {"VP": (vp_x, vp_y), "L_center": L_center, "fit_error": fit_error} + except: + return None + + def _get_pixel_domain_features(self, mask, image_shape): + """提取像素域特征""" + try: + height, width = image_shape[:2] + + centerline_data = [] + for y in range(height - 1, int(height * 0.3), -5): + row = mask[y, :] + x_pixels = np.where(row > 0)[0] + if x_pixels.size > 10: + x_min, x_max = x_pixels[0], x_pixels[-1] + path_width = x_max - x_min + center_x = (x_min + x_max) / 2 + centerline_data.append([y, center_x, path_width]) + + if len(centerline_data) < 20: + return None + + data = np.array(centerline_data) + + # 应用中心线平滑 + data = self._smooth_centerline(data) + + # 
检测急转弯 + sharp_turn_index = self._find_sharp_turn(data) + if sharp_turn_index is not None: + cutoff_index = int(sharp_turn_index * 0.6) + if cutoff_index >= 10: + data = data[:cutoff_index] + + y_coords, x_coords, widths = data[:, 0], data[:, 1], data[:, 2] + weights = widths + + # 原始多项式拟合 + coeffs_raw = np.polyfit(y_coords, x_coords, 2, w=weights) + + # 【新增】对多项式系数进行时间平滑 + self.poly_coeffs_history.append(coeffs_raw.copy()) + if len(self.poly_coeffs_history) > self.poly_coeffs_history_max: + self.poly_coeffs_history.pop(0) + + # 使用指数加权移动平均平滑系数 + if len(self.poly_coeffs_history) >= 3: + # 权重:最近的帧权重更高 + weights_time = np.array([0.7 ** (len(self.poly_coeffs_history) - i - 1) + for i in range(len(self.poly_coeffs_history))]) + weights_time = weights_time / np.sum(weights_time) + + # 加权平均系数 + coeffs = np.zeros_like(coeffs_raw) + for i, hist_coeffs in enumerate(self.poly_coeffs_history): + coeffs += hist_coeffs * weights_time[i] + else: + coeffs = coeffs_raw + + poly_func = np.poly1d(coeffs) + + curvature_proxy = abs(coeffs[0]) + tangent_slope = 2 * coeffs[0] * height + coeffs[1] + tangent_angle_rad = np.arctan(tangent_slope) + + return { + "poly_func": poly_func, + "curvature_proxy": curvature_proxy, + "tangent_angle_rad": tangent_angle_rad, + "centerline_data": np.array(centerline_data) + } + except Exception as e: + logger.warning(f"Pixel domain feature calculation failed: {e}") + return None + + def _find_sharp_turn(self, data): + """查找急转弯点""" + window_size = 5 + angle_threshold = 30 + + for i in range(len(data) - 2 * window_size): + front_window = data[i:i + window_size] + back_window = data[i + window_size:i + 2 * window_size] + + front_dir = [front_window[-1, 1] - front_window[0, 1], + front_window[-1, 0] - front_window[0, 0]] + back_dir = [back_window[-1, 1] - back_window[0, 1], + back_window[-1, 0] - back_window[0, 0]] + + angle1 = np.arctan2(front_dir[1], front_dir[0]) + angle2 = np.arctan2(back_dir[1], back_dir[0]) + angle_diff = abs(np.degrees(angle2 - angle1)) 
+ + if angle_diff > 180: + angle_diff = 360 - angle_diff + + if angle_diff > angle_threshold: + return i + window_size + + return None + + def _detect_sharp_corner(self, centerline_data, angle_threshold_deg=45): + """检测急转弯""" + try: + if len(centerline_data) < 15: + return None + points_in_range = np.array(centerline_data) + num_points = len(points_in_range) + + window_size = max(5, int(num_points * 0.15)) + best_turn_info = None + max_angle_diff = 0 + + for i in range(0, num_points - 2 * window_size, 2): + front_segment = points_in_range[i:i + window_size] + back_segment = points_in_range[i + window_size:i + 2 * window_size] + + if len(front_segment) < 3 or len(back_segment) < 3: + continue + + front_y = front_segment[:, 0] + front_x = front_segment[:, 1] + front_coeffs = np.polyfit(front_y, front_x, 1) + front_slope = front_coeffs[0] + + back_y = back_segment[:, 0] + back_x = back_segment[:, 1] + back_coeffs = np.polyfit(back_y, back_x, 1) + back_slope = back_coeffs[0] + + front_angle = np.arctan(front_slope) + back_angle = np.arctan(back_slope) + + angle_diff_rad = back_angle - front_angle + angle_diff_deg = abs(np.degrees(angle_diff_rad)) + + if angle_diff_deg > max_angle_diff and angle_diff_deg > angle_threshold_deg: + max_angle_diff = angle_diff_deg + corner_point_idx = i + window_size + corner_point = points_in_range[corner_point_idx] + + direction = "right" if angle_diff_rad > 0 else "left" + + post_turn_segment = points_in_range[ + corner_point_idx:min(corner_point_idx + window_size * 2, num_points)] + if len(post_turn_segment) > 0: + post_turn_center_x = np.mean(post_turn_segment[:, 1]) + else: + post_turn_center_x = corner_point[1] + + best_turn_info = { + "corner_point_pixel": (corner_point[1], corner_point[0]), + "turn_angle": max_angle_diff, + "direction": direction, + "post_turn_center_x": post_turn_center_x, + "corner_point_idx": corner_point_idx + } + + return best_turn_info + + except Exception as e: + logger.warning(f"Corner detection error: 
{e}")
                return None

    def _update_turn_tracker(self, corner_info):
        """Update the turn tracker with the latest corner detection.

        Consecutive detections in the same direction increment
        ``consecutive_hits``; a direction change restarts the streak at 1.
        """
        detected_direction = corner_info['direction']

        if detected_direction == self.turn_detection_tracker['direction']:
            self.turn_detection_tracker['consecutive_hits'] += 1
        else:
            # Direction flipped: start a new streak.
            self.turn_detection_tracker['direction'] = detected_direction
            self.turn_detection_tracker['consecutive_hits'] = 1

        self.turn_detection_tracker['last_seen_frame'] = self.frame_counter
        self.turn_detection_tracker['corner_info'] = corner_info

    def _reset_turn_tracker(self):
        """Reset the turn tracker to its empty/initial state."""
        self.turn_detection_tracker = {
            'direction': None,
            'consecutive_hits': 0,
            'last_seen_frame': 0,
            'corner_info': None
        }

    def _calculate_line_x_at_y(self, line_params, y_target):
        """Return the x coordinate where line (vx, vy, x0, y0) crosses
        y == y_target, or None for a (near-)horizontal line."""
        vx, vy, x0, y0 = line_params
        if abs(vy) < 1e-6:
            return None
        t = (y_target - y0) / vy
        x = x0 + t * vx
        return x

    def _get_width_at_y(self, centerline_data, y_target):
        """Return the path width sampled at the row closest to y_target."""
        ys = centerline_data[:, 0]
        ws = centerline_data[:, 2]
        idx = np.abs(ys - y_target).argmin()
        return ws[idx]

    def _detect_obstacles(self, image, path_mask=None):
        """Run the obstacle detector on *image* (Day 20 performance-tuned).

        Returns a list of obstacle dicts from ``self.obstacle_detector``;
        fields that downstream code expects but the detector may omit
        (box_coords, y_position_ratio, label, center, confidence) are
        backfilled with defaults.
        """
        # Day 20: verbose per-detection logging removed; only DEBUG detail.

        if self.obstacle_detector is None:
            return []

        # Log the whitelist size once, on the very first call.
        if not hasattr(self, '_classes_printed'):
            self._classes_printed = True
            if hasattr(self.obstacle_detector, 'WHITELIST_CLASSES'):
                logger.info(f"[障碍物检测] 白名单类别数: {len(self.obstacle_detector.WHITELIST_CLASSES)}")

        try:
            detected_obstacles = self.obstacle_detector.detect(image, path_mask=path_mask)

            # Backfill fields later code relies on but the detector may omit.
            H, W = image.shape[:2]
            for obj in detected_obstacles:
                if 'mask' in obj and obj['mask'] is not None:
                    y_coords, x_coords = np.where(obj['mask'] > 0)
                    if len(y_coords) > 0 and len(x_coords) > 0:
                        x1, y1 = int(np.min(x_coords)), int(np.min(y_coords))
                        x2, y2 = int(np.max(x_coords)), int(np.max(y_coords))
                        obj['box_coords'] = (x1, y1, x2, y2)

                if
'y_position_ratio' not in obj:
                    obj['y_position_ratio'] = obj.get('center_y', 0) / H
                if 'label' not in obj:
                    obj['label'] = obj.get('name', 'unknown')
                if 'center' not in obj:
                    obj['center'] = (obj.get('center_x', 0), obj.get('center_y', 0))
                if 'confidence' not in obj:
                    obj['confidence'] = 0.5

            # Day 20: single summary line, throttled to one log every 30 frames.
            if detected_obstacles and self.frame_counter % 30 == 0:
                names = [o.get('name', '?') for o in detected_obstacles[:3]]
                logger.info(f"[障碍物] 检测到 {len(detected_obstacles)} 个: {names}")

            return detected_obstacles

        except Exception as e:
            logger.error(f"[障碍物检测] 失败: {e}")
            import traceback
            traceback.print_exc()
            return []

    def _check_and_set_obstacle_voice(self, obstacles):
        """Pick the dominant near obstacle and queue a voice announcement.

        Sets or clears ``self.pending_obstacle_voice`` and tracks the last
        announced obstacle/time so repeats are suppressed within the cooldown.
        """
        if not obstacles:
            self.last_obstacle_speech = ""
            self.pending_obstacle_voice = None
            return

        # Near-obstacle thresholds (raised so only very close objects alert):
        # obstacle bottom below 75% of the frame height, or area above 12%.
        NEAR_DISTANCE_Y_THRESHOLD = 0.75
        NEAR_DISTANCE_AREA_THRESHOLD = 0.12

        near_obstacles = []
        for obs in obstacles:
            if (obs.get('bottom_y_ratio', 0) > NEAR_DISTANCE_Y_THRESHOLD or
                    obs.get('area_ratio', 0) > NEAR_DISTANCE_AREA_THRESHOLD):
                near_obstacles.append(obs)

        if near_obstacles:
            # Dominant obstacle = the one with the largest area ratio.
            main_obstacle = max(near_obstacles, key=lambda x: x.get('area_ratio', 0))
            obstacle_name = main_obstacle.get('name', '')
            current_time = time.time()

            # Announce when the obstacle changed, or the cooldown has elapsed.
            should_announce = False
            if obstacle_name != self.last_obstacle_speech:
                # Different obstacle: announce immediately.
                should_announce = True
                self.last_obstacle_speech = obstacle_name
                self.last_obstacle_speech_time = current_time
            elif current_time - self.last_obstacle_speech_time > self.obstacle_speech_cooldown:
                # Same obstacle, but past the cooldown: announce again.
                should_announce = True
                self.last_obstacle_speech_time = current_time

            if should_announce:
                self.pending_obstacle_voice = self._speech_for_obstacle(obstacle_name)
        else:
            # No near obstacle: clear the announcement state.
            self.last_obstacle_speech = ""
            self.pending_obstacle_voice = None

    def
_check_obstacles(self, image, mask, frame_visualizations): + """检查并处理障碍物""" + # 使用缓存策略 + if self.frame_counter % self.OBSTACLE_DETECTION_INTERVAL == 0: + final_obstacles = self._detect_obstacles(image, mask) + # 【新增】稳定化障碍物,避免重复叠加 + if hasattr(self, 'prev_gray') and self.prev_gray is not None: + curr_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + final_obstacles = self._stabilize_obstacle_list( + final_obstacles, + self.last_detected_obstacles, + self.prev_gray, + curr_gray, + image.shape[:2] + ) + self.last_detected_obstacles = final_obstacles + self.last_obstacle_detection_frame = self.frame_counter + else: + if self.frame_counter - self.last_obstacle_detection_frame < self.OBSTACLE_CACHE_DURATION_FRAMES: + final_obstacles = self.last_detected_obstacles + else: + final_obstacles = [] + + # 添加可视化 + for obs in final_obstacles: + self._add_obstacle_visualization(obs, frame_visualizations) + + # 筛选近距离障碍物(提高阈值,只有非常近才报警) + NEAR_DISTANCE_Y_THRESHOLD = 0.75 # 提高到0.75,障碍物底部必须在画面下方75%以下 + NEAR_DISTANCE_AREA_THRESHOLD = 0.12 # 提高到0.12,障碍物必须占画面12%以上 + + near_obstacles = [ + obs for obs in final_obstacles + if (obs.get('bottom_y_ratio', 0) > NEAR_DISTANCE_Y_THRESHOLD or + obs.get('area_ratio', 0) > NEAR_DISTANCE_AREA_THRESHOLD) + ] + + return near_obstacles + + def _plan_avoidance(self, obstacle_info, image_width): + """规划避障路径""" + obstacle_center_x = obstacle_info['center_x'] + image_center_x = image_width / 2 + + if obstacle_center_x < image_center_x: + turn_direction = 'right' + else: + turn_direction = 'left' + + plan = [ + {'type': 'sidestep_clear', 'direction': turn_direction}, + {'type': 'forward_pass'}, + {'type': 'sidestep_return', 'direction': 'left' if turn_direction == 'right' else 'right'} + ] + return plan + + def _generate_navigation_guidance(self, features, image_height, image_width, frame_visualizations): + """生成导航指引""" + poly_func = features['poly_func'] + is_curve = features['curvature_proxy'] > self.CURVATURE_PROXY_THRESHOLD + lookahead_ratio = 0.6 if 
is_curve else 0.4 + y_target = image_height * lookahead_ratio + x_target = poly_func(y_target) + + # 添加中心线可视化 + plot_y = np.arange(int(image_height * 0.3), image_height, 5).astype(int) + plot_x = poly_func(plot_y).astype(int) + centerline_points = np.vstack((plot_x, plot_y)).T.tolist() + frame_visualizations.append({ + "type": "polyline", + "points": centerline_points, + "color": "yellow", + "width": 2 + }) + + # 添加目标点 + frame_visualizations.append({ + "type": "circle", + "center": [int(x_target), int(y_target)], + "radius": 10, + "color": "red" + }) + + # 计算导航指令(优先级:转向/平移 > 直行) + center_offset_pixels = x_target - image_width / 2 + center_offset_ratio = abs(center_offset_pixels) / image_width + orientation_error_rad = features['tangent_angle_rad'] + + # 先检查是否需要转向(左转/右转) + if orientation_error_rad > self.NAV_ORIENTATION_THRESHOLD_RAD: + guidance_text = "左转" + elif orientation_error_rad < -self.NAV_ORIENTATION_THRESHOLD_RAD: + guidance_text = "右转" + # 再检查是否需要平移(左移/右移) + elif center_offset_ratio > self.NAV_CENTER_OFFSET_THRESHOLD_RATIO: + guidance_text = "右移" if center_offset_pixels > 0 else "左移" + # 最后才是直行 + else: + guidance_text = "保持直行" + + # 添加数据面板 + self._add_data_panel(frame_visualizations, { + "状态": "常规导航", + "引导": guidance_text, + "朝向": f"{np.degrees(orientation_error_rad):.1f}°", + "偏移": f"{center_offset_ratio * 100:.1f}%" + }, (25, image_height - 75)) + + return guidance_text + + def _handle_pixel_domain_onboarding(self, pixel_features, image_height, image_width, frame_visualizations): + """处理像素域的上盲道引导""" + image_center_x = image_width / 2 + orientation_error_rad = pixel_features['tangent_angle_rad'] + poly_func = pixel_features['poly_func'] + + y_bottom = image_height - 1 + x_target_bottom = poly_func(y_bottom) + center_offset_pixels = x_target_bottom - image_center_x + center_offset_ratio = abs(center_offset_pixels) / image_width + + if self.onboarding_step == ONBOARDING_STEP_ROTATION: + if abs(orientation_error_rad) < 
self.ONBOARDING_ORIENTATION_THRESHOLD_RAD: + guidance_text = "方向已对正!现在校准位置。" + self.onboarding_step = ONBOARDING_STEP_TRANSLATION + else: + guidance_text = "请向左转动。" if orientation_error_rad > 0.1 else "请向右转动。" + + self._add_data_panel(frame_visualizations, { + "状态": "上盲道 (方向)", + "引导": guidance_text, + "角度": f"{np.degrees(orientation_error_rad):.1f}°", + "偏移": "待校准" + }, (25, image_height - 75)) + self._add_navigation_info_visualization(pixel_features, image_height, image_width, frame_visualizations) + + return guidance_text + + elif self.onboarding_step == ONBOARDING_STEP_TRANSLATION: + if center_offset_ratio < self.ONBOARDING_CENTER_OFFSET_THRESHOLD_RATIO: + guidance_text = "校准完成!您已在盲道上,开始前行。" + self.current_state = STATE_NAVIGATING + else: + guidance_text = "请向右平移。" if center_offset_pixels > 0 else "请向左平移。" + + self._add_data_panel(frame_visualizations, { + "状态": "上盲道 (位置)", + "引导": guidance_text, + "角度": "已对准", + "偏移": f"{center_offset_ratio * 100:.1f}%" + }, (25, image_height - 75)) + + return guidance_text + + def _add_obstacle_visualization(self, obstacle, visualizations, pulse_effect=False): + """添加障碍物可视化(简化版:仅边框,近红远黄)""" + try: + # 计算障碍物危险等级 + bottom_y_ratio = obstacle.get('bottom_y_ratio', 0) + area_ratio = obstacle.get('area_ratio', 0) + + # 判断是否为近距离障碍物 + is_near = bottom_y_ratio > 0.7 or area_ratio > 0.1 # 近距离障碍物 + + # 添加 mask 边框可视化(如果有) + if 'mask' in obstacle and obstacle['mask'] is not None: + mask = obstacle['mask'] + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + + if contours: + # 找到最大的轮廓 + max_contour = max(contours, key=cv2.contourArea) + points = max_contour.squeeze(1)[::5].tolist() + + # 根据距离选择边框颜色:近距离红色,远距离黄色 + if is_near: + outline_color = "rgba(255, 0, 0, 1.0)" # 红色 + thickness = 3 + else: + outline_color = "rgba(255, 255, 0, 0.8)" # 黄色 + thickness = 2 + + # 只添加边框,不添加填充和文字 + visualizations.append({ + "type": "outline", + "points": points, + "color": outline_color, + "thickness": thickness + }) + except 
Exception as e: + logger.error(f"[_add_obstacle_visualization] 添加障碍物可视化失败: {e}") + + def _add_navigation_info_visualization(self, features, image_height, image_width, frame_visualizations): + """添加导航计算信息的可视化""" + if not features: + return + + try: + # 获取计算结果 + poly_func = features.get('poly_func') + curvature_proxy = features.get('curvature_proxy', 0) + tangent_angle_rad = features.get('tangent_angle_rad', 0) + tangent_angle_deg = np.degrees(tangent_angle_rad) + + # 绘制切线方向 + if poly_func: + # 在画面底部计算切线 + y_bottom = image_height - 50 + x_bottom = poly_func(y_bottom) + + # 计算切线的终点 + tangent_length = 100 + dx = tangent_length * np.cos(tangent_angle_rad) + dy = tangent_length * np.sin(tangent_angle_rad) + + # 【新增】绘制基准虚线(垂直向上) + baseline_length = 80 + frame_visualizations.append({ + "type": "dashed_line", + "start": [int(x_bottom), int(y_bottom)], + "end": [int(x_bottom), int(y_bottom - baseline_length)], + "color": "rgba(255, 255, 255, 0.6)", # 白色虚线 + "thickness": 2 + }) + + # 添加切线可视化 + frame_visualizations.append({ + "type": "arrow", + "start": [int(x_bottom), int(y_bottom)], + "end": [int(x_bottom + dx), int(y_bottom - dy)], # 注意Y轴方向 + "color": "rgba(0, 255, 255, 0.8)", # 青色 + "thickness": 3, + "tip_length": 0.3 + }) + + # 【新增】绘制夹角弧线标识 + arc_radius = 40 + # 基准线角度是-90度(向上),切线角度是tangent_angle_deg + # OpenCV中角度是从右侧水平线逆时针测量 + start_angle = -90 # 基准线(垂直向上) + end_angle = -90 + tangent_angle_deg # 切线角度 + frame_visualizations.append({ + "type": "angle_arc", + "center": [int(x_bottom), int(y_bottom)], + "radius": arc_radius, + "start_angle": start_angle, + "end_angle": end_angle, + "color": "rgba(255, 200, 0, 0.8)", # 橙黄色 + "thickness": 2 + }) + + # 添加角度文字(文字大小减半) + frame_visualizations.append({ + "type": "text_with_bg", + "text": f"角度: {tangent_angle_deg:.1f}°", + "position": [int(x_bottom + 10), int(y_bottom - 30)], + "font_scale": 0.3, # 从0.6减半到0.3 + "color": "rgba(255, 255, 255, 1.0)", + "bg_color": "rgba(0, 0, 0, 0.7)" + }) + + # 添加曲率信息(文字大小减半) + if curvature_proxy > 
0.00001: + curve_text = "弯道" if curvature_proxy > 0.00005 else "缓弯" + frame_visualizations.append({ + "type": "text_with_bg", + "text": f"{curve_text}: {curvature_proxy:.2e}", + "position": [20, 100], + "font_scale": 0.25, # 从0.5减半到0.25 + "color": "rgba(255, 255, 0, 1.0)", + "bg_color": "rgba(0, 0, 0, 0.7)" + }) + + # 显示中心线数据点 + if 'centerline_data' in features: + centerline_data = features['centerline_data'] + # 在画面中部显示路径宽度 + mid_idx = len(centerline_data) // 2 + if mid_idx < len(centerline_data): + y, x, width = centerline_data[mid_idx] + # 绘制宽度指示线(改为双向箭头) + frame_visualizations.append({ + "type": "double_arrow", # 新增双向箭头类型 + "start": [int(x - width/2), int(y)], + "end": [int(x + width/2), int(y)], + "color": "rgba(0, 255, 0, 0.8)", + "thickness": 2, + "tip_length": 0.15 + }) + # 添加宽度文字(文字大小减半) + frame_visualizations.append({ + "type": "text_with_bg", + "text": f"宽度: {width:.0f}px", + "position": [int(x - 30), int(y - 10)], + "font_scale": 0.25, # 从0.5减半到0.25 + "color": "rgba(255, 255, 255, 1.0)", + "bg_color": "rgba(0, 0, 0, 0.7)" + }) + except Exception as e: + logger.error(f"添加导航信息可视化失败: {e}") + + def _add_data_panel(self, visualizations, data, position): + """添加数据面板""" + visualizations.append({ + "type": "data_panel", + "data": data, + "position": position + }) + + def _add_crosswalk_info_visualization(self, viz_data, image_height, image_width, visualizations): + """添加斑马线检测信息的精美可视化""" + try: + # 1. 
绘制斑马线中心点标识(大十字) + center_x = int(viz_data['center_x_ratio'] * image_width) + center_y = int(viz_data['center_y_ratio'] * image_height) + + cross_size = 20 if viz_data['in_arrival'] else 15 # 减小尺寸 + cross_color = "rgba(255, 100, 0, 1.0)" if viz_data['in_arrival'] else "rgba(0, 200, 255, 0.8)" + + # 水平线 + visualizations.append({ + "type": "line", + "start": [center_x - cross_size, center_y], + "end": [center_x + cross_size, center_y], + "color": cross_color, + "thickness": 2 # 减细 + }) + # 垂直线 + visualizations.append({ + "type": "line", + "start": [center_x, center_y - cross_size], + "end": [center_x, center_y + cross_size], + "color": cross_color, + "thickness": 2 # 减细 + }) + + # 2. 绘制指向斑马线的箭头(从画面中心指向斑马线中心) + screen_center_x = image_width // 2 + screen_center_y = image_height // 2 + + # 只在斑马线不在画面中心时绘制箭头 + distance = np.sqrt((center_x - screen_center_x)**2 + (center_y - screen_center_y)**2) + if distance > 80: # 提高到80像素才画箭头(减少干扰) + visualizations.append({ + "type": "arrow", + "start": [screen_center_x, screen_center_y], + "end": [center_x, center_y], + "color": "rgba(255, 150, 0, 0.6)", # 降低透明度 + "thickness": 2, # 减细 + "tip_length": 0.15 # 减小箭头 + }) + + # 3. 添加信息面板(右上角) + panel_x = image_width - 180 + panel_y = 20 + + # 准备面板数据 + panel_data = { + "斑马线": viz_data['stage'], + "面积": f"{viz_data['area_ratio']*100:.1f}%", + "方位": viz_data['position'], + } + + if viz_data['has_occlusion']: + panel_data["状态"] = "被遮挡" + elif viz_data['in_arrival']: + panel_data["状态"] = "可过马路" + + visualizations.append({ + "type": "data_panel", + "data": panel_data, + "position": (panel_x, panel_y) + }) + + # 4. 
添加面积进度条(视觉化面积大小) + bar_width = 150 + bar_height = 20 + bar_x = image_width - bar_width - 20 + bar_y = panel_y + 90 + + # 背景框 + visualizations.append({ + "type": "rectangle", + "top_left": (bar_x, bar_y), + "bottom_right": (bar_x + bar_width, bar_y + bar_height), + "color": "rgba(50, 50, 50, 0.7)", + "filled": True + }) + + # 进度填充(0-100%,但最多显示到arrival阈值0.25对应100%) + progress = min(viz_data['area_ratio'] / 0.25, 1.0) + fill_width = int(bar_width * progress) + + # 根据阶段选择颜色 + if viz_data['in_arrival']: + fill_color = "rgba(0, 255, 100, 0.8)" # 绿色(可过马路) + elif viz_data['area_ratio'] >= 0.18: + fill_color = "rgba(255, 200, 0, 0.8)" # 黄色(接近) + elif viz_data['area_ratio'] >= 0.08: + fill_color = "rgba(0, 200, 255, 0.8)" # 青色(靠近) + else: + fill_color = "rgba(100, 150, 255, 0.8)" # 蓝色(发现) + + visualizations.append({ + "type": "rectangle", + "top_left": (bar_x + 2, bar_y + 2), + "bottom_right": (bar_x + fill_width - 2, bar_y + bar_height - 2), + "color": fill_color, + "filled": True + }) + + # 进度条标签(使用中文文本,字体减小) + visualizations.append({ + "type": "text_with_bg", + "text": f"接近度: {int(progress * 100)}%", + "position": [bar_x, bar_y - 18], + "font_scale": 0.25, # 减小字体 + "color": "rgba(255, 255, 255, 1.0)", + "bg_color": "rgba(0, 0, 0, 0.7)" + }) + + except Exception as e: + logger.error(f"添加斑马线可视化失败: {e}") + + def _add_traffic_light_visualization(self, color, visualizations, image_height, image_width): + """添加红绿灯状态可视化""" + # 在右上角绘制红绿灯指示器 + x = image_width - 100 + y = 50 + + # 背景框 + visualizations.append({ + "type": "rectangle", + "top_left": (x - 40, y - 40), + "bottom_right": (x + 40, y + 100), + "color": "rgba(0, 0, 0, 0.5)", + "filled": True + }) + + # 三个圆形灯 + colors = { + "red": [(255, 0, 0), (50, 0, 0), (50, 0, 0)], + "yellow": [(50, 50, 0), (255, 255, 0), (50, 50, 0)], + "green": [(0, 50, 0), (0, 50, 0), (0, 255, 0)], + "unknown": [(50, 50, 50), (50, 50, 50), (50, 50, 50)] + } + + light_colors = colors.get(color, colors["unknown"]) + positions = [y - 20, y + 20, y + 60] 
+ + for i, (pos_y, light_color) in enumerate(zip(positions, light_colors)): + # 外圈 + visualizations.append({ + "type": "circle", + "center": [x, pos_y], + "radius": 18, + "color": f"rgba(100, 100, 100, 1.0)", + "thickness": 2 + }) + # 内圈(灯的颜色) + visualizations.append({ + "type": "circle", + "center": [x, pos_y], + "radius": 15, + "color": f"rgba({light_color[0]}, {light_color[1]}, {light_color[2]}, 1.0)", + "filled": True + }) + + # 标签 + visualizations.append({ + "type": "text_with_bg", + "text": f"信号灯: {color}", + "position": [x - 35, y + 90], + "font_scale": 0.5, + "color": "rgba(255, 255, 255, 1.0)", + "bg_color": "rgba(0, 0, 0, 0.7)" + }) + + def _to_cn_obstacle(self, name: str) -> str: + """转换障碍物名称为中文""" + try: + key = (name or '').strip().lower() + return _OBSTACLE_NAME_CN.get(key, '障碍物') + except: + return '障碍物' + + def _speech_for_obstacle(self, name: str) -> str: + k = (name or '').strip().lower() + if k == 'person': return "前方有人,注意避让。" + if k == 'car': return "前方有车,注意避让。" + if k == 'bicycle': return "前方有自行车,停一下。" + if k == 'motorcycle': return "前方有摩托车,停一下。" + if k == 'bus': return "前方有公交车,停一下。" + if k == 'truck': return "前方有卡车,停一下。" + if k == 'scooter': return "前方有电瓶车,停一下。" + if k == 'stroller': return "前方有婴儿车,停一下。" + if k == 'dog': return "前方有狗,停一下。" + if k == 'animal': return "前方有动物,停一下。" + return "前方有障碍物,注意避让。" + + def _draw_command_button(self, image, text): + """绘制底部中央的指令按钮(与斑马线模式统一)""" + try: + H, W = image.shape[:2] + full_text = f"当前指令:{text if text else '—'}" + + # 按钮参数 + font_px = 14 + pad_x, pad_y = 14, 8 + bottom_margin = 28 + + # 计算文字尺寸 + if PIL_AVAILABLE: + try: + from PIL import Image as PILImage, ImageDraw, ImageFont + # 尝试加载中文字体 + font = None + for font_path in [ + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + ]: + if os.path.exists(font_path): + try: + font = ImageFont.truetype(font_path, font_px) + break + except: + continue + if font: + bbox = ImageDraw.Draw(PILImage.new('RGB', 
(1, 1))).textbbox((0, 0), full_text, font=font) + tw = max(1, bbox[2] - bbox[0]) + th = max(1, bbox[3] - bbox[1]) + else: + scale = font_px / 24.0 + (tw, th), _ = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, scale, 1) + except: + scale = font_px / 24.0 + (tw, th), _ = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, scale, 1) + else: + scale = font_px / 24.0 + (tw, th), _ = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, scale, 1) + + # 计算按钮位置(底部居中) + bw = tw + pad_x * 2 + bh = th + pad_y * 2 + radius = max(10, bh // 2) + + cx = W // 2 + left = max(8, cx - bw // 2) + top = H - bottom_margin - bh + right = min(W - 8, left + bw) + bottom = top + bh + + # 绘制半透明圆角背景 + overlay = image.copy() + bg_color = (26, 32, 41) # 深色背景 + border_color = (60, 76, 102) # 边框 + + # 圆角矩形(中间+两个圆) + cv2.rectangle(overlay, (left + radius, top), (right - radius, bottom), bg_color, -1) + cv2.circle(overlay, (left + radius, (top + bottom) // 2), radius, bg_color, -1) + cv2.circle(overlay, (right - radius, (top + bottom) // 2), radius, bg_color, -1) + + # 混合半透明 + cv2.addWeighted(overlay, 0.75, image, 0.25, 0, image) + + # 绘制边框 + cv2.rectangle(image, (left + radius, top), (right - radius, bottom), border_color, 1) + cv2.circle(image, (left + radius, (top + bottom) // 2), radius, border_color, 1) + cv2.circle(image, (right - radius, (top + bottom) // 2), radius, border_color, 1) + + # 绘制文字 + text_x = left + pad_x + text_y = top + pad_y + th + + if PIL_AVAILABLE and 'font' in locals() and font: + # 使用PIL绘制中文 + pil_img = PILImage.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_img) + draw.text((text_x, top + pad_y), full_text, font=font, fill=(255, 255, 255)) + image = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) + else: + # 使用OpenCV绘制 + cv2.putText(image, full_text, (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, scale, (255, 255, 255), 1) + + return image + except Exception as e: + logger.error(f"绘制指令按钮失败: {e}") + return image + + def 
_parse_color(self, color_str): + """解析颜色字符串,返回BGR格式""" + try: + if color_str.startswith('rgba('): + values = color_str[5:-1].split(',') + r, g, b = int(values[0]), int(values[1]), int(values[2]) + return (b, g, r) # OpenCV 使用 BGR 格式 + elif color_str == 'yellow': + return (0, 255, 255) + elif color_str == 'red': + return (0, 0, 255) + else: + return (0, 0, 255) # 默认红色 + except: + return (0, 0, 255) + + def _draw_data_panel_no_bg(self, image, data, position=(15, 15)): + """绘制数据面板(无黑底版本)""" + if not PIL_AVAILABLE: + return image + + try: + pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_img, "RGBA") + + env_scale = float(os.getenv("AIGLASS_PANEL_SCALE", "0.7")) + base_font_size = max(10, int(round(14 * env_scale))) + + # 尝试多种字体,确保中文显示 + font = None + font_paths = [ + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + ] + + for font_path in font_paths: + try: + if os.path.exists(font_path): + font = ImageFont.truetype(font_path, base_font_size) + break + except: + continue + + if font is None: + font = ImageFont.load_default() + + # 绘制文本,使用描边效果 + y_offset = position[1] + for key, value in data.items(): + text = f"{key}: {value}" + + # 绘制黑色描边(8个方向) + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx != 0 or dy != 0: + draw.text((position[0] + dx, y_offset + dy), text, + font=font, fill=(0, 0, 0, 255)) + + # 绘制白色文字 + draw.text((position[0], y_offset), text, font=font, fill=(255, 255, 255, 255)) + y_offset += base_font_size + 5 + + return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) + + except Exception as e: + logger.warning(f"绘制数据面板失败: {e}") + return image + + + def _draw_visualizations(self, image, viz_elements): + """增强的可视化绘制方法""" + if not viz_elements: + return image + + # 获取当前时间用于动画效果 + current_time = time.time() + + # 分离不同类型的元素 + panel_elements = [v for v in viz_elements if v.get("type") == "data_panel"] + standard_elements = [v for v in viz_elements if 
v.get("type") != "data_panel"] + + # 第一遍:绘制填充(Day 20 优化:用轮廓替代半透明填充,大幅提升性能) + for element in standard_elements: + elem_type = element.get("type") + + if elem_type in ['blind_path_mask', 'obstacle_mask', 'crosswalk_mask']: + points = np.array(element.get("points", []), dtype=np.int32) + if points.size > 0: + color = self._parse_color(element.get("color", "rgba(255, 255, 255, 0.5)")) + + # Day 20 性能优化:只绘制轮廓,不做半透明填充 + # 原因:逐像素混合 (~200-300ms) 改为轮廓绘制 (~5-10ms) + + # 根据类型选择轮廓样式 + if elem_type == 'blind_path_mask': + thickness = 3 # 盲道用粗线 + elif elem_type == 'obstacle_mask': + thickness = 2 # 障碍物用中等线 + else: + thickness = 2 # 斑马线用中等线 + + # 直接用 points 绘制轮廓(快速) + cv2.polylines(image, [points], isClosed=True, color=color, thickness=thickness) + + # 第二遍:绘制轮廓和其他元素 + for element in standard_elements: + elem_type = element.get("type") + + # 【新增】绘制直线 + if elem_type == 'line': + start = tuple(element.get("start", (0, 0))) + end = tuple(element.get("end", (100, 100))) + color = self._parse_color(element.get("color", "rgba(255, 255, 255, 1.0)")) + thickness = element.get("thickness", 2) + cv2.line(image, start, end, color, thickness) + + # 绘制轮廓/描边 + elif elem_type == 'outline': + points = np.array(element.get("points", []), dtype=np.int32) + if points.size > 0: + color = self._parse_color(element.get("color", "rgba(255, 255, 255, 1.0)")) + thickness = element.get("thickness", 3) + cv2.polylines(image, [points], isClosed=True, color=color, thickness=thickness) + + # 绘制折线 + elif elem_type == 'polyline': + points = np.array(element.get("points", []), dtype=np.int32) + if points.size > 0: + color = self._parse_color(element.get("color", "rgba(255, 255, 0, 1.0)")) + thickness = element.get("width", 2) + cv2.polylines(image, [points], isClosed=False, color=color, thickness=thickness) + + # 绘制圆形 + elif elem_type == 'circle': + center = tuple(element.get("center", (0, 0))) + radius = element.get("radius", 10) + color = self._parse_color(element.get("color", "rgba(255, 0, 0, 1.0)")) + 
thickness = element.get("thickness", -1 if element.get("filled", True) else 2) + cv2.circle(image, center, radius, color, thickness) + + # 绘制矩形 + elif elem_type == 'rectangle': + top_left = tuple(element.get("top_left", (0, 0))) + bottom_right = tuple(element.get("bottom_right", (100, 100))) + color = self._parse_color(element.get("color", "rgba(0, 0, 0, 0.5)")) + thickness = -1 if element.get("filled", True) else 2 + cv2.rectangle(image, top_left, bottom_right, color, thickness) + + # 绘制箭头 + elif elem_type == 'arrow': + start = tuple(element.get("start", (0, 0))) + end = tuple(element.get("end", (100, 100))) + color = self._parse_color(element.get("color", "rgba(0, 255, 255, 1.0)")) + thickness = element.get("thickness", 2) + tip_length = element.get("tip_length", 0.3) + cv2.arrowedLine(image, start, end, color, thickness, tipLength=tip_length) + + # 【新增】绘制双向箭头 + elif elem_type == 'double_arrow': + start = tuple(element.get("start", (0, 0))) + end = tuple(element.get("end", (100, 100))) + color = self._parse_color(element.get("color", "rgba(0, 255, 0, 0.8)")) + thickness = element.get("thickness", 2) + tip_length = element.get("tip_length", 0.15) + # 绘制中间的线 + cv2.line(image, start, end, color, thickness) + # 绘制两端的箭头 + # 计算箭头方向向量 + dx = end[0] - start[0] + dy = end[1] - start[1] + length = np.sqrt(dx*dx + dy*dy) + if length > 0: + # 单位方向向量 + ux, uy = dx/length, dy/length + # 箭头长度 + arrow_len = length * tip_length + # 左端箭头 + tip1_x = int(start[0] + arrow_len * ux) + tip1_y = int(start[1] + arrow_len * uy) + # 绘制左端箭头(指向左) + angle = np.arctan2(dy, dx) + arrow_angle = 30 * np.pi / 180 # 箭头角度 + p1 = (int(start[0] + arrow_len * np.cos(angle - arrow_angle)), + int(start[1] + arrow_len * np.sin(angle - arrow_angle))) + p2 = (int(start[0] + arrow_len * np.cos(angle + arrow_angle)), + int(start[1] + arrow_len * np.sin(angle + arrow_angle))) + cv2.line(image, start, p1, color, thickness) + cv2.line(image, start, p2, color, thickness) + # 右端箭头(指向右) + p3 = (int(end[0] - 
arrow_len * np.cos(angle - arrow_angle)), + int(end[1] - arrow_len * np.sin(angle - arrow_angle))) + p4 = (int(end[0] - arrow_len * np.cos(angle + arrow_angle)), + int(end[1] - arrow_len * np.sin(angle + arrow_angle))) + cv2.line(image, end, p3, color, thickness) + cv2.line(image, end, p4, color, thickness) + + # 【新增】绘制虚线 + elif elem_type == 'dashed_line': + start = np.array(element.get("start", (0, 0))) + end = np.array(element.get("end", (100, 100))) + color = self._parse_color(element.get("color", "rgba(255, 255, 255, 0.6)")) + thickness = element.get("thickness", 2) + dash_length = 10 + gap_length = 5 + # 计算总长度和方向 + total_vec = end - start + total_len = np.linalg.norm(total_vec) + if total_len > 0: + unit_vec = total_vec / total_len + # 绘制虚线段 + current_len = 0 + while current_len < total_len: + seg_start = start + unit_vec * current_len + seg_end = start + unit_vec * min(current_len + dash_length, total_len) + cv2.line(image, tuple(seg_start.astype(int)), tuple(seg_end.astype(int)), color, thickness) + current_len += dash_length + gap_length + + # 【新增】绘制角度弧线 + elif elem_type == 'angle_arc': + center = tuple(element.get("center", (100, 100))) + radius = element.get("radius", 40) + start_angle = element.get("start_angle", -90) + end_angle = element.get("end_angle", 0) + color = self._parse_color(element.get("color", "rgba(255, 200, 0, 0.8)")) + thickness = element.get("thickness", 2) + # OpenCV的ellipse函数:startAngle和endAngle是从右侧水平线开始顺时针测量 + # 需要转换:我们的角度是从右侧水平线逆时针(数学标准) + # OpenCV需要的是从右侧水平线顺时针 + cv2_start = -end_angle # 转换为OpenCV格式 + cv2_end = -start_angle + # 确保角度范围正确 + if cv2_start > cv2_end: + cv2_start, cv2_end = cv2_end, cv2_start + cv2.ellipse(image, center, (radius, radius), 0, cv2_start, cv2_end, color, thickness) + + # 【修改】绘制带背景的文本(使用中文支持) + elif elem_type == 'text_with_bg': + text = element.get("text", "") + pos = element.get("position", [10, 30]) + font_scale = element.get("font_scale", 0.6) + color = self._parse_color(element.get("color", "rgba(255, 
255, 255, 1.0)")) + + # 使用新的中文文本绘制函数 + image = self._draw_chinese_text(image, text, tuple(pos), + font_scale=font_scale, + color=color, + stroke_color=(0, 0, 0), + stroke_width=1) + + # 绘制警告图标 + elif elem_type == 'warning_icon': + pos = element.get("position", (100, 100)) + level = element.get("level", "info") + text = element.get("text", "") + flash = element.get("flash", False) + + # 根据级别选择颜色 + if level == "danger": + icon_color = (0, 0, 255) # 红色 + text_color = (255, 255, 255) + elif level == "warning": + icon_color = (0, 165, 255) # 橙色 + text_color = (255, 255, 255) + else: + icon_color = (0, 255, 255) # 黄色 + text_color = (0, 0, 0) + + # 闪烁效果 + if flash: + alpha = 0.5 + 0.5 * np.sin(current_time * 4 * np.pi) + icon_color = tuple(int(c * alpha) for c in icon_color) + + # 绘制三角形警告图标 + triangle = np.array([ + [pos[0], pos[1] - 20], + [pos[0] - 15, pos[1]], + [pos[0] + 15, pos[1]] + ], np.int32) + cv2.fillPoly(image, [triangle], icon_color) + cv2.polylines(image, [triangle], True, (255, 255, 255), 2) + + # 绘制感叹号 + cv2.putText(image, "!", (pos[0] - 5, pos[1] - 5), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) + + # 绘制文本标签(使用中文支持) + if text: + font_scale = 0.5 + # 使用新的中文文本绘制函数 + text_pos = (pos[0] - 20, pos[1] + 20) # 简化位置计算 + image = self._draw_chinese_text(image, text, text_pos, + font_scale=font_scale, + color=text_color, + stroke_color=(0, 0, 0), + stroke_width=1) + + # 普通文本 + elif elem_type == 'text': + text = element.get("text", "") + pos = tuple(element.get("pos", (10, 30))) + # 使用中文文本绘制函数 + image = self._draw_chinese_text(image, text, pos, + font_scale=0.7, + color=(255, 255, 255), + stroke_color=(0, 0, 0), + stroke_width=1) + + # 【修改】绘制数据面板(使用无黑底版本) + if PIL_AVAILABLE: + for panel in panel_elements: + image = self._draw_data_panel_no_bg(image, panel["data"], panel["position"]) + else: + # 如果没有PIL,也使用描边效果 + for panel in panel_elements: + y_offset = panel["position"][1] + for key, value in panel["data"].items(): + text = f"{key}: {value}" + # 绘制文字描边 + 
for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx != 0 or dy != 0: + cv2.putText(image, text, (panel["position"][0] + dx, y_offset + dy), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3) + # 绘制白色文字 + cv2.putText(image, text, (panel["position"][0], y_offset), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1) + y_offset += 25 + + return image + + + + def _draw_chinese_text(self, image, text, position, font_scale=0.6, color=(255, 255, 255), + stroke_color=(0, 0, 0), stroke_width=1): + """绘制中文文本,使用微软雅黑字体,白字黑边""" + if not PIL_AVAILABLE: + # 如果没有PIL,回退到cv2.putText(会显示问号) + cv2.putText(image, text, position, cv2.FONT_HERSHEY_SIMPLEX, + font_scale, color, 2) + return image + + try: + # 转换为PIL图像 + pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_img) + + # 计算字体大小(基于font_scale) + base_size = 24 # 基准字体大小 + font_size = int(base_size * font_scale / 0.6) + + # 尝试加载微软雅黑字体 + font = None + font_paths = [ + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + ] + + for font_path in font_paths: + if os.path.exists(font_path): + try: + font = ImageFont.truetype(font_path, font_size) + break + except: + continue + + if font is None: + font = ImageFont.load_default() + + # 将OpenCV的BGR颜色转换为RGB + rgb_color = (color[2], color[1], color[0]) + rgb_stroke = (stroke_color[2], stroke_color[1], stroke_color[0]) + + # 绘制文本(带描边效果) + x, y = position + # 绘制描边 + draw.text((x, y), text, font=font, fill=rgb_stroke, + stroke_width=stroke_width, stroke_fill=rgb_stroke) + # 绘制主文本 + draw.text((x, y), text, font=font, fill=rgb_color) + + # 转换回OpenCV格式 + return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) + + except Exception as e: + logger.warning(f"绘制中文文本失败: {e}") + # 回退到cv2.putText + cv2.putText(image, text, position, cv2.FONT_HERSHEY_SIMPLEX, + font_scale, color, 2) + return image + + def _draw_data_panel(self, image, data, position=(15, 15)): + """绘制数据面板(需要Pillow)""" + if not PIL_AVAILABLE: + return 
image + + try: + pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_img, "RGBA") + + env_scale = float(os.getenv("AIGLASS_PANEL_SCALE", "0.65")) + base_font_size = max(8, int(round(16 * env_scale))) + padding = max(4, int(round(8 * env_scale))) + + # 尝试加载微软雅黑字体 + font = None + font_paths = [ + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + ] + + for font_path in font_paths: + if os.path.exists(font_path): + try: + font = ImageFont.truetype(font_path, base_font_size) + break + except: + continue + + if font is None: + font = ImageFont.load_default() + + text_lines = [f"{key}: {value}" for key, value in data.items()] + text_to_draw = "\n".join(text_lines) + + bbox = draw.textbbox(position, text_to_draw, font=font) + text_w, text_h = bbox[2] - bbox[0], bbox[3] - bbox[1] + + bg_rect = [ + (position[0] - padding, position[1] - padding), + (position[0] + text_w + padding, position[1] + text_h + padding) + ] + draw.rectangle(bg_rect, fill=(0, 0, 0, 128)) + draw.text(position, text_to_draw, font=font, fill=(255, 255, 255, 255)) + + return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) + + except Exception: + return image + + def reset(self): + """重置导航器状态""" + self.current_state = STATE_ONBOARDING + self.onboarding_step = ONBOARDING_STEP_ROTATION + self.maneuver_step = MANEUVER_STEP_1_ISSUE_COMMAND + self.maneuver_target_info = None + self.turn_detection_tracker = { + 'direction': None, + 'consecutive_hits': 0, + 'last_seen_frame': 0, + 'corner_info': None + } + self.turn_cooldown_frames = 0 + self.avoidance_plan = None + self.avoidance_step_index = 0 + self.lock_on_data = None + + # 重置光流和平滑相关 + self.flow_points = {} + self.flow_grace = {} + self.centerline_history = [] + self.blind_miss_ttl = 0 + self.cross_miss_ttl = 0 + + # 重置语音相关 + self.pending_obstacle_voice = None + self.last_obstacle_speech = "" + self.last_obstacle_speech_time = 0 + + # 重置多项式系数历史 + 
self.poly_coeffs_history = [] + self.crosswalk_tracker = { + 'stage': 'not_detected', + 'consecutive_frames': 0, + 'last_area_ratio': 0.0, + 'last_bottom_y_ratio': 0.0, + 'last_center_x_ratio': 0.5, + 'position_announced': False, + 'alignment_status': 'not_aligned', + 'last_seen_frame': 0, + 'last_angle': 0.0 + } + self.frame_counter = 0 + self.prev_gray = None + self.prev_blind_path_mask = None + self.prev_crosswalk_mask = None + self.prev_obstacle_cache = [] + self.last_guidance_message = "" + self.last_detected_obstacles = [] + self.last_obstacle_detection_frame = 0 + self.last_obstacle_speech = "" + self.last_obstacle_speech_time = 0 + self.last_any_speech_time = 0 + self.crosswalk_ready_announced = False + self.crosswalk_ready_time = 0 + self.traffic_light_history.clear() + self.last_traffic_light_state = "unknown" + self.green_light_announced = False + + def _stabilize_obstacle_list(self, obstacles, prev_obstacles, prev_gray, curr_gray, + image_shape, threshold=0.5): + """稳定障碍物检测结果,避免重复叠加""" + if not obstacles or prev_gray is None or curr_gray is None: + return obstacles + + H, W = image_shape + stabilized = [] + used_prev = set() # 记录已使用的历史障碍物 + + # 对每个当前检测到的障碍物 + for curr_obs in obstacles: + if 'mask' not in curr_obs or curr_obs['mask'] is None: + stabilized.append(curr_obs) + continue + + curr_mask = curr_obs['mask'] + best_match = None + best_iou = 0 + best_idx = -1 + + # 寻找最佳匹配的历史障碍物 + if prev_obstacles: + for idx, prev_obs in enumerate(prev_obstacles): + if idx in used_prev or 'mask' not in prev_obs: + continue + + # 使用光流预测历史障碍物的新位置 + flow_mask = self._predict_mask_with_flow(prev_obs['mask'], prev_gray, curr_gray) + if flow_mask is None: + flow_mask = prev_obs['mask'] + + # 计算IoU + inter = np.logical_and(curr_mask > 0, flow_mask > 0).sum() + union = np.logical_or(curr_mask > 0, flow_mask > 0).sum() + iou = float(inter) / float(union) if union > 0 else 0.0 + + if iou > best_iou and iou > threshold: + best_iou = iou + best_match = flow_mask + best_idx = 
idx + + # 如果找到匹配,融合结果 + if best_match is not None and best_idx >= 0: + used_prev.add(best_idx) + # 融合当前检测和光流预测,提高稳定性 + fused_mask = ((0.8 * curr_mask + 0.2 * best_match) > 128).astype(np.uint8) * 255 + curr_obs['mask'] = fused_mask + # 更新派生属性 + self._update_obstacle_properties(curr_obs, H, W) + + stabilized.append(curr_obs) + + return stabilized + + def _speech_for_obstacle(self, name: str) -> str: + k = (name or '').strip().lower() + if k == 'person': return "前方有人,注意避让。" + if k == 'car': return "前方有车,注意避让。" + if k == 'bicycle': return "前方有自行车,停一下。" + if k == 'motorcycle': return "前方有摩托车,停一下。" + if k == 'bus': return "前方有公交车,停一下。" + if k == 'truck': return "前方有卡车,停一下。" + if k == 'scooter': return "前方有电瓶车,停一下。" + if k == 'stroller': return "前方有婴儿车,停一下。" + if k == 'dog': return "前方有狗,停一下。" + if k == 'animal': return "前方有动物,停一下。" + return "前方有障碍物,注意避让。" + + def _update_obstacle_properties(self, obs, H, W): + """更新障碍物的派生属性""" + if 'mask' not in obs or obs['mask'] is None: + return + + mask = obs['mask'] + y_coords, x_coords = np.where(mask > 0) + + if len(y_coords) > 0: + obs['area'] = len(y_coords) + obs['center_x'] = float(np.mean(x_coords)) + obs['center_y'] = float(np.mean(y_coords)) + obs['y_position_ratio'] = obs['center_y'] / H + obs['area_ratio'] = obs['area'] / (H * W) + obs['bottom_y_ratio'] = np.max(y_coords) / H + + # 更新边界框 + x1, y1 = int(np.min(x_coords)), int(np.min(y_coords)) + x2, y2 = int(np.max(x_coords)), int(np.max(y_coords)) + obs['box_coords'] = (x1, y1, x2, y2) \ No newline at end of file diff --git a/workflow_crossstreet.py b/workflow_crossstreet.py new file mode 100644 index 0000000..c9fb3c6 --- /dev/null +++ b/workflow_crossstreet.py @@ -0,0 +1,1832 @@ +# -*- coding: utf-8 -*- +""" +过马路工作流(简化版 - 仅斑马线检测,但保留导航功能) +- 直连版本,无 Celery/Redis +- 仅检测斑马线,无交通灯检测 +- 保留斑马线导航功能(角度、偏移计算) +- 保留可视化(引导线、目标点等) +- 每帧都进行分割;若该帧分割失败,则用上一帧从掩码打点的光流特征点追踪,重建掩码保持位置,直到下一次分割检出 +""" +import torch +import os +import time +import logging +import numpy as np +import cv2 +from 
dataclasses import dataclass +from typing import Optional, List, Dict, Any +# 【移除】from audio_player import play_voice_text - 不在工作流内部播放音频 + +# 可选:用于更精致的数据面板(与 blindpath 一致) +try: + from PIL import Image, ImageDraw, ImageFont + PIL_AVAILABLE = True +except ImportError: + PIL_AVAILABLE = False + Image, ImageDraw, ImageFont = None, None, None + +# 可选:自动启用障碍物检测(与 blindpath 一致) +try: + from obstacle_detector_client import ObstacleDetectorClient +except Exception: + ObstacleDetectorClient = None + +# 红绿灯检测模块 +try: + import trafficlight_detection + TRAFFIC_LIGHT_AVAILABLE = True +except Exception: + TRAFFIC_LIGHT_AVAILABLE = False + trafficlight_detection = None + +# Day 20: TensorRT 模型加载工具 +try: + from model_utils import get_best_model_path +except ImportError: + def get_best_model_path(path): return path + +logger = logging.getLogger(__name__) + +# ========== 状态常量 ========== +STATE_SEEKING = "SEEKING_CROSSWALK" # 寻找并对准远处的斑马线 +STATE_WAIT_LIGHT = "WAIT_TRAFFIC_LIGHT" # 等待红绿灯判定 +STATE_CROSSING = "CROSSING" # 正在过马路 + +# ========== 配置参数 ========== +CROSSWALK_MIN_CONF = float(os.getenv('CROSSWALK_MIN_CONF', '0.3')) +CROSSWALK_MIN_AREA = int(os.getenv('CROSSWALK_MIN_AREA', '5000')) +BLIND_MIN_CONF = float(os.getenv('BLIND_MIN_CONF', '0.34')) # 盲道最低置信度(更高,防误判) +ANGLE_THRESH_DEG = float(os.getenv('CROSSWALK_ANGLE_THRESH_DEG', '5.0')) # 默认阈值略放宽 +OFFSET_THRESH = float(os.getenv('CROSSWALK_OFFSET_THRESH', '0.08')) # 默认阈值略放宽 + +# 远距离对准阈值(更宽松,避免过于敏感) +SEEKING_ANGLE_THRESH_DEG = 15.0 # 远距离角度阈值(更宽松) +SEEKING_OFFSET_THRESH = 0.20 # 远距离偏移阈值(更宽松) + +# 远距离对准阈值(判定"很近"的条件,更严格) +CROSSWALK_NEAR_AREA_RATIO = 0.30 # 斑马线占画面30%认为"很近"(提高) +CROSSWALK_NEAR_BOTTOM_RATIO = 0.80 # 斑马线底部超过画面80%认为"很近"(提高) +CROSSWALK_NEAR_MIN_HEIGHT_RATIO = 0.35 # 斑马线高度占画面35%以上(新增条件) + +# 红绿灯判定参数 +GREEN_LIGHT_STABLE_FRAMES = 5 # 绿灯稳定帧数 + +# 类别ID绑定(与训练集对应) +CW_ID = int(os.getenv("AIGLASS_SEG_CW_ID", "0")) # 斑马线 +BP_ID = int(os.getenv("AIGLASS_SEG_BP_ID", "1")) # 盲道 + +# 斑马线与盲道的同义名集合 +_CW = {'zebra_crossing', 'zebra 
crossing', 'zebra', 'crosswalk', 'road_crossing', 'road crossing'} +_BP = {'blind_path', 'tactile_paving', 'tactile paving', 'blind path'} + +# 盲道"真伪判定"阈值 +BP_VALID_IOU_THR = 0.40 # 与斑马线 IoU 超过此值,判为"混淆",不当盲道 + +# 追踪/打点参数 +INNER_OFFSET_PX_LOCK = 5 +EDGE_DILATE_PX = 2 +LK_PARAMS = dict( + winSize=(21, 21), + maxLevel=3, + criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 12, 0.03) +) +FEATURE_PARAMS = dict( + maxCorners=600, + qualityLevel=0.001, + minDistance=5, + blockSize=7 +) + +# 时序平滑与保活 +MASK_EMA_ALPHA = 0.6 # EMA 平滑权重 +TRACK_MIN_POINTS = 30 # 追踪最少特征点阈值 +TRACK_RESEED_EVERY = 12 # 每隔 N 帧在成功分割时重播种一次特征点 + +# 可视化颜色(BGR) +VIS_COLORS = { + "crosswalk": (0, 165, 255), # 橙色 + "centerline": (255, 255, 0), # 青色 - 引导中心线 + "target_point": (255, 0, 255), # 粉色 - 引导目标点 + "hint": (0, 255, 255), # 黄色 + "stripes": (0, 128, 255), # 橙蓝 - 条纹线段 + "heading": (0, 0, 255), # 红色 - 方向箭头 +} + +@dataclass +class CrossStreetResult: + """过马路导航结果""" + annotated_image: Optional[np.ndarray] = None + guidance_text: str = "" + visualizations: List[Dict[str, Any]] = None + should_switch_to_blindpath: bool = False + + def __post_init__(self): + if self.visualizations is None: + self.visualizations = [] + +# ========== 辅助函数 ========== +def _score_of(d) -> float: + """兼容不同检测结构,取出置信度;取不到就给 0.0(保守)""" + for k in ("conf", "confidence", "score", "prob"): + v = getattr(d, k, None) + if v is not None: + try: + return float(v) + except Exception: + break + return 0.0 + +def _norm_name(s: str) -> str: + """标准化名称""" + return str(s).lower().replace('_', ' ').strip() + +def _in_set(name: str, pool: set) -> bool: + """检查名称是否在集合中""" + return _norm_name(name) in {_norm_name(x) for x in pool} + +def _mask_iou(a: np.ndarray, b: np.ndarray) -> float: + """计算两个mask的IoU""" + if a is None or b is None: + return 0.0 + ai = a > 0 + bi = b > 0 + inter = np.logical_and(ai, bi).sum() + union = np.logical_or(ai, bi).sum() + return float(inter) / float(union + 1e-6) + +def _looks_like_blind_path(bp_mask: np.ndarray, 
cw_mask: np.ndarray, H: int, W: int) -> bool: + """几何+互斥检查,过滤'横条纹/路牙'伪盲道""" + if bp_mask is None: + return False + ys, xs = np.where(bp_mask > 0) + if xs.size < 80: # 太小的片段直接丢 + return False + + # 计算主轴角度 + pts = np.stack([xs.astype(np.float32), ys.astype(np.float32)], axis=1) + mean = pts.mean(axis=0) + cov = np.cov((pts - mean).T) + eigvals, eigvecs = np.linalg.eig(cov) + v = eigvecs[:, np.argmax(eigvals)] + angle_deg = np.degrees(np.arctan2(v[1], v[0])) + if angle_deg > 90: angle_deg -= 180 + if angle_deg < -90: angle_deg += 180 + + h = (ys.max() - ys.min() + 1) + w = (xs.max() - xs.min() + 1) + aspect = h / float(w + 1e-6) # 期望盲道"更竖一些" + iou_cw = _mask_iou(bp_mask, cw_mask) + + # 1) 横向条纹过滤(放宽到 20°,给远端/轻微倾斜更多空间) + if abs(angle_deg) <= 20.0: + return False + # 2) 形状过滤(放宽到 0.52) + if aspect < 0.52: + return False + # 3) 与斑马线高度重叠 + if iou_cw >= BP_VALID_IOU_THR: + return False + # 4) 底边窄条(疑似路牙)过滤 + bottom = bp_mask[int(0.88 * H):, :] + if bottom.sum() > 0: + bottom_share = bottom.sum() / float((bp_mask > 0).sum() + 1e-6) + if bottom_share > 0.50 and (w / float(W)) < 0.35: + return False + return True + +def _cls_of(d): + """提取检测对象的类别ID""" + for k in ("cls", "class_id", "category_id"): + v = getattr(d, k, None) + if v is not None: + try: + return int(v) + except Exception: + pass + return None + +class CrossStreetNavigator: + """简化版过马路导航器 - 仅斑马线检测但保留导航(每帧分割 + 失败用光流保活)""" + + def __init__(self, seg_model=None, coco_model=None, obs_model=None, device_id: str = "esp32"): + self.seg_model = seg_model + self.device_id = device_id + self.frame_counter = 0 + self.last_guidance = "" + self.crosswalk_detected = False + self.last_guide_time = 0 + self.guide_interval = 3.0 # 语音引导间隔(秒) + + # —— 状态机 —— + self.state = STATE_SEEKING # 当前状态 + self.green_light_counter = 0 # 绿灯稳定帧计数 + self.last_traffic_light = None # 上一帧检测到的红绿灯 + self.last_seeking_guidance = "" # 上一次SEEKING状态的引导文本(用于节流) + self.last_waiting_light_time = 0 # 上次播报"正在等待绿灯"的时间 + self.crossing_end_announced = False # 
是否已播报"过马路结束"(CROSSING状态用) + self.last_crosswalk_seen_time = 0 # 上次检测到斑马线的时间 + self.last_blindpath_announce_time = 0 # 上次播报盲道提示的时间(用于节流重复播报) + + # —— 时序/追踪状态 —— + self.prev_mask = None # 上一帧稳定后的二值掩码 + self.prev_mask_float = None # 掩码 EMA 浮点缓冲 + self.prev_mask_ts = 0.0 # 最近一次掩码更新时间 + self.old_gray = None # 上一帧灰度图(供 LK) + self.p0 = None # 上一帧特征点(N,1,2) + self.last_seed_frame = 0 # 上次播种特征点的帧号 + + # —— 避障(与 blindpath 一致) —— + self.obstacle_detector = obs_model + self.prev_gray = None + self.last_detected_obstacles = [] + self.last_obstacle_detection_frame = 0 + self.OBSTACLE_DETECTION_INTERVAL = int(os.getenv("AIGLASS_OBS_INTERVAL", "15")) + self.OBSTACLE_CACHE_DURATION_FRAMES = int(os.getenv("AIGLASS_OBS_CACHE_FRAMES", "0")) + + # 【新增】斑马线检测间隔配置 + self.CROSSWALK_DETECTION_INTERVAL = int(os.getenv("AIGLASS_CROSSWALK_INTERVAL", "4")) # 每4帧检测一次 + self.last_crosswalk_detection_frame = 0 + self.last_detected_crosswalk_mask = None + self.last_detected_blindpath_mask = None + + # 自动启用障碍物检测(若未传入 obs_model) + if self.obstacle_detector is None and os.getenv("AIGLASS_OBS_AUTO", "1") != "0": + try: + if ObstacleDetectorClient is not None: + model_path = os.getenv("AIGLASS_OBS_MODEL", "model/yoloe-11l-seg.pt") + # Day 20: 优先使用 TensorRT 引擎 + model_path = get_best_model_path(model_path) + self.obstacle_detector = ObstacleDetectorClient(model_path) + logger.info("[CROSS_STREET] 障碍物检测器已自动加载") + else: + logger.warning("[CROSS_STREET] 未找到 ObstacleDetectorClient,跳过自动加载") + except Exception as e: + logger.warning(f"[CROSS_STREET] 自动加载障碍物检测器失败: {e}") + + # 如果模型有 predict 方法但没有 detect 方法,进行包装 + if self.seg_model and hasattr(self.seg_model, 'predict') and not hasattr(self.seg_model, 'detect'): + logger.info("[CROSS_STREET] 包装 YOLO 模型") + self.seg_model = YOLOModelWrapper(self.seg_model) + + # 【新增】打印检测间隔配置 + logger.info(f"[CROSS_STREET] 斑马线检测间隔: 每{self.CROSSWALK_DETECTION_INTERVAL}帧") + + # 确保模型在 GPU 上 + # Day 20: TensorRT 引擎不需要 .to() + if self.seg_model and torch.cuda.is_available(): + try: + # 
检查是否是 TensorRT 引擎 + model_path = getattr(self.seg_model, 'ckpt_path', '') or '' + if not model_path.endswith('.engine'): + if hasattr(self.seg_model, 'model') and hasattr(self.seg_model.model, 'to'): + self.seg_model.model.to('cuda') + elif hasattr(self.seg_model, 'to'): + self.seg_model.to('cuda') + logger.info("[CROSS_STREET] 模型已移至 GPU") + else: + logger.info("[CROSS_STREET] TensorRT 引擎已加载,跳过 .to()") + except Exception as e: + logger.warning(f"[CROSS_STREET] 无法将模型移至 GPU: {e}") + + def reset(self): + """重置状态""" + self.frame_counter = 0 + self.last_guidance = "" + self.crosswalk_detected = False + self.last_guide_time = 0 + # 状态机 + self.state = STATE_SEEKING + self.green_light_counter = 0 + self.last_traffic_light = None + self.last_seeking_guidance = "" + self.last_waiting_light_time = 0 + self.crossing_end_announced = False + self.last_crosswalk_seen_time = 0 + self.last_blindpath_announce_time = 0 + # 追踪 + self.prev_mask = None + self.prev_mask_float = None + self.prev_mask_ts = 0.0 + self.old_gray = None + self.p0 = None + self.last_seed_frame = 0 + # 避障缓存 + self.prev_gray = None + self.last_detected_obstacles = [] + self.last_obstacle_detection_frame = 0 + # 重置红绿灯检测状态 + if TRAFFIC_LIGHT_AVAILABLE and trafficlight_detection: + trafficlight_detection.reset_detection_state() + logger.info("[CROSS_STREET] 导航器已重置") + + # —— 打点/追踪辅助 —— + @staticmethod + def _inner_offset_edge(mask_bin: np.ndarray, offset_px=5, edge_dilate_px=2) -> np.ndarray: + """对二值掩码做内收后提边缘,便于在目标内部打光流特征点""" + if offset_px > 0: + k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*offset_px+1, 2*offset_px+1)) + eroded = cv2.erode(mask_bin.astype(np.uint8), k, iterations=1) + else: + eroded = mask_bin.astype(np.uint8) + edges = cv2.Canny(eroded*255, 50, 150) + if edge_dilate_px > 0: + k2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*edge_dilate_px+1, 2*edge_dilate_px+1)) + edges = cv2.dilate(edges, k2, iterations=1) + return edges # uint8 0/255 + + @staticmethod + def 
_hull_mask_from_points(points: np.ndarray, shape_hw: tuple) -> Optional[np.ndarray]: + """从一组点的凸包生成二值掩码""" + if points is None or len(points) < 3: + return None + H, W = shape_hw + pts = points.reshape(-1, 2).astype(np.float32) + hull = cv2.convexHull(pts.reshape(-1,1,2)) + poly = hull.reshape(-1, 2).astype(np.int32) + mask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(mask, [poly], 1) + return mask + + def _seed_points_from_mask(self, gray: np.ndarray, mask_bin: np.ndarray) -> Optional[np.ndarray]: + """基于掩码的内收边界,播种 LK 光流特征点""" + edge_mask = self._inner_offset_edge(mask_bin, offset_px=INNER_OFFSET_PX_LOCK, edge_dilate_px=EDGE_DILATE_PX) + try: + pts = cv2.goodFeaturesToTrack(gray, mask=edge_mask, **FEATURE_PARAMS) + return pts + except Exception as e: + logger.warning(f"[CROSS_STREET] goodFeaturesToTrack 失败: {e}") + return None + + @staticmethod + def _ensure_binary_mask(mask: np.ndarray, shape_hw: tuple) -> np.ndarray: + """阈值化并调整尺寸到图像大小,返回二值 0/1 uint8""" + H, W = shape_hw + if mask.dtype != np.uint8: + mask = (mask > 0.5).astype(np.uint8) + if mask.shape[:2] != (H, W): + mask = cv2.resize(mask, (W, H), interpolation=cv2.INTER_NEAREST) + return (mask > 0).astype(np.uint8) + + def _postprocess_mask(self, mask_bin: np.ndarray) -> np.ndarray: + """形态学净化 + 移除小碎片,缓解毛边与噪点""" + try: + m = (mask_bin > 0).astype(np.uint8) + H, W = m.shape[:2] + # 轻度开闭操作,去毛刺并填补细小空洞 + k_open = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + k_close = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (7, 7)) + m = cv2.morphologyEx(m, cv2.MORPH_OPEN, k_open, iterations=1) + m = cv2.morphologyEx(m, cv2.MORPH_CLOSE, k_close, iterations=1) + # 移除过小连通域 + num_labels, labels, stats, _ = cv2.connectedComponentsWithStats(m, connectivity=8) + if num_labels > 1: + areas = stats[1:, cv2.CC_STAT_AREA] + keep_area = max(int(0.003 * H * W), 1500) # 约 0.3% 画面或 1500 px + keep_labels = np.where(areas >= keep_area)[0] + 1 + m2 = np.zeros_like(m) + for lbl in keep_labels: + m2[labels == lbl] = 1 + if 
m2.sum() > 0: + m = m2 + return (m > 0).astype(np.uint8) + except Exception: + return (mask_bin > 0).astype(np.uint8) + + @staticmethod + def _largest_contour(mask_bin: np.ndarray): + cts, _ = cv2.findContours((mask_bin>0).astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if not cts: + return None + return max(cts, key=cv2.contourArea) + + + def _mask_center(self, mask: np.ndarray): + """用图像矩计算掩码质心;失败返回 None""" + M = cv2.moments((mask > 0).astype(np.uint8)) + if abs(M["m00"]) < 1e-6: + return None + cx = int(M["m10"] / M["m00"]) + cy = int(M["m01"] / M["m00"]) + return (cx, cy) + + def _is_crosswalk_near(self, mask: np.ndarray, h: int, w: int) -> bool: + """判断斑马线是否"很近"(到用户跟前)- 更严格的判定条件""" + if mask is None: + return False + area = int(mask.sum()) + area_ratio = float(area) / float(h * w) + + # 获取底部位置和高度 + ys = np.where(mask > 0)[0] + if ys.size == 0: + return False + top_y = int(ys.min()) + bottom_y = int(ys.max()) + mask_height = bottom_y - top_y + 1 + height_ratio = float(mask_height) / float(h) + bottom_ratio = float(bottom_y) / float(h) + + # 需要同时满足多个条件(AND逻辑,更严格): + # 1. 面积足够大 + # 2. 底部位置足够低 + # 3. 高度占比足够大(防止只是因为抬头导致的误判) + is_near = (area_ratio >= CROSSWALK_NEAR_AREA_RATIO and + bottom_ratio >= CROSSWALK_NEAR_BOTTOM_RATIO and + height_ratio >= CROSSWALK_NEAR_MIN_HEIGHT_RATIO) + return is_near + + def _is_crosswalk_almost_done(self, mask: np.ndarray, h: int, w: int) -> bool: + """判断斑马线是否"快消失"(斑马线在画面底部且面积很小)- 更严格的判定""" + if mask is None: + return False + area = int(mask.sum()) + area_ratio = float(area) / float(h * w) + + ys = np.where(mask > 0)[0] + if ys.size == 0: + return False + + # 计算斑马线的顶部和底部位置 + top_y = int(ys.min()) + bottom_y = int(ys.max()) + + top_ratio = float(top_y) / float(h) + bottom_ratio = float(bottom_y) / float(h) + + # 更严格的判断条件(避免过早触发): + # 1. 顶部已经过了画面70%(>0.7),说明斑马线主要在画面最下方 + # 2. 底部接近画面底部(>0.85) + # 3. 
面积很小(<0.08),说明快消失了 + is_almost_done = (top_ratio > 0.7 and bottom_ratio > 0.85 and area_ratio < 0.08) + return is_almost_done + + def _compute_远_distance_alignment(self, mask: np.ndarray, h: int, w: int) -> tuple: + """计算远距离对准的角度和偏移(基于mask几何,不依赖条纹)""" + ys, xs = np.where(mask > 0) + if xs.size < 50: + return 0.0, 0.0 + + # 使用PCA计算主方向 + pts = np.stack([xs.astype(np.float32), ys.astype(np.float32)], axis=1) + mean = pts.mean(axis=0) + cov = np.cov((pts - mean).T) + eigvals, eigvecs = np.linalg.eig(cov) + v = eigvecs[:, np.argmax(eigvals)] + + # 计算角度(相对水平) + angle = np.degrees(np.arctan2(v[1], v[0])) + if angle > 90: angle -= 180 + if angle < -90: angle += 180 + + # 计算水平偏移(质心相对画面中心) + cx = float(mean[0]) + offset = (cx - (w / 2.0)) / max(1.0, w / 2.0) + + return float(angle), float(offset) + + def _draw_line_vertical_angle(self, image, center, angle_deg, length_ratio=0.7, color=(255, 255, 0), thickness=3): + """ + 以“竖直方向”为0°基准,angle_deg>0 表示左偏,<0 表示右偏。 + 在 center 处画一条通过点的直线。 + """ + H, W = image.shape[:2] + half_len = int(0.5 * length_ratio * min(H, W)) + rad = np.radians(angle_deg) + # 竖直基准: 向上的单位向量(0, -1) + # 旋转 angle 后的方向向量 = (sin, -cos) + vx = np.sin(rad); + vy = -np.cos(rad) + x0, y0 = center + p1 = (int(x0 - vx * half_len), int(y0 - vy * half_len)) + p2 = (int(x0 + vx * half_len), int(y0 + vy * half_len)) + cv2.line(image, p1, p2, color, thickness) + + def _draw_dashed_line_vertical_angle(self, image, center, angle_deg, length_ratio=0.7, + dash=12, gap=8, color=(255, 255, 255), thickness=2): + """同样以竖直为0°,画 through center 的虚线。""" + H, W = image.shape[:2] + half_len = int(0.5 * length_ratio * min(H, W)) + rad = np.radians(angle_deg) + vx = np.sin(rad); + vy = -np.cos(rad) + x0, y0 = center + x1, y1 = int(x0 - vx * half_len), int(y0 - vy * half_len) + x2, y2 = int(x0 + vx * half_len), int(y0 + vy * half_len) + + # 沿整条线分段画虚线 + total_len = int(np.hypot(x2 - x1, y2 - y1)) + if total_len <= 0: return + dx = (x2 - x1) / total_len + dy = (y2 - y1) / total_len + s = 0 + 
while s < total_len: + e = min(s + dash, total_len) + xa, ya = int(x1 + dx * s), int(y1 + dy * s) + xb, yb = int(x1 + dx * e), int(y1 + dy * e) + cv2.line(image, (xa, ya), (xb, yb), color, thickness) + s += (dash + gap) + + def _offset_from_centerline(self, center_pt, angle_vertical_deg, width, height, y_ratio=0.75) -> float: + """ + 基于“青色法线中央直线”计算左右偏移: + - angle_vertical_deg: 以“竖直方向为0°”的角(与 _draw_line_vertical_angle 相同坐标系) + - center_pt: 掩码质心 (cx, cy) + - y_ratio: 预瞄行高度(相对图像高度的比例),默认0.75(底部偏下更稳定) + 返回:归一化偏移(右为正,左为负),与原 offset 含义一致。 + """ + if center_pt is None: + return 0.0 + x0, y0 = center_pt + rad = np.radians(angle_vertical_deg) + # 与 _draw_line_vertical_angle 完全一致的方向向量定义 + vx = np.sin(rad) + vy = -np.cos(rad) + + # 取预瞄行的 y + y_target = float(int(height * y_ratio)) + + # 若法线几乎水平(极少出现),避免除0 + if abs(vy) < 1e-6: + x_at = float(x0) + else: + t = (y_target - float(y0)) / vy + x_at = float(x0) + t * vx + + x_at = float(np.clip(x_at, 0, width - 1)) + # 与旧 offset 定义一致:相对画面中心的归一化水平偏移(右正左负) + return float((x_at - (width / 2.0)) / max(1.0, width / 2.0)) + + def _compute_angle_and_offset(self, mask: np.ndarray) -> tuple: + """计算斑马线的角度和偏移(PCA 回退用)""" + H, W = mask.shape[:2] + ys, xs = np.where(mask > 0) + if xs.size < 50: + return 0.0, 0.0 + + # 使用PCA计算主方向 + pts = np.stack([xs.astype(np.float32), ys.astype(np.float32)], axis=1) + mean = pts.mean(axis=0) + cov = np.cov((pts - mean).T) + eigvals, eigvecs = np.linalg.eig(cov) + v = eigvecs[:, np.argmax(eigvals)] + + # 计算角度 + angle = np.degrees(np.arctan2(v[1], v[0])) + if angle > 90: angle -= 180 + if angle < -90: angle += 180 + + # 计算水平偏移 + cx = float(mean[0]) + offset = (cx - (W / 2.0)) / max(1.0, W / 2.0) + + return float(angle), float(offset) + + def _estimate_angle_by_stripes(self, mask: np.ndarray, gray: np.ndarray) -> Optional[Dict[str, Any]]: + """ + 基于掩码内条纹(霍夫线)估计角度和可视化(放宽参数 + 鲁棒聚类): + 返回 dict: { + 'angle_deg': float, # 相对竖直方向偏角([-45,45]),正=左偏,负=右偏 + 'lines': List[(x1,y1,x2,y2)], # 选中的条纹线段(图像坐标) + 'confidence': 
float, # [0,1] 加权圆均值合力强度 + 'count': int # 线段数量 + } + """ + try: + H, W = mask.shape[:2] + roi_top = int(0.45 * H) # 关注下半部分,稳定性更好 + m_roi = (mask[roi_top:H, :] > 0).astype(np.uint8) + g_roi = gray[roi_top:H, :] + + # 放宽边缘阈值 + g_blur = cv2.GaussianBlur(g_roi, (5, 5), 0) + edges = cv2.Canny(g_blur, 50, 150) + edges = cv2.bitwise_and(edges, edges, mask=m_roi * 255) + + # 放宽霍夫参数 + lines = cv2.HoughLinesP( + edges, + rho=1, + theta=np.pi / 180, + threshold=max(30, int(0.03 * W)), + minLineLength=int(0.15 * W), + maxLineGap=20 + ) + if lines is None: + return None + + angles, weights = [], [] + all_lines = [] + for x1, y1, x2, y2 in lines.reshape(-1, 4): + dx, dy = x2 - x1, y2 - y1 + length = float(np.hypot(dx, dy)) + if length < 8: + continue + ang = float(np.degrees(np.arctan2(dy, dx))) # 相对 x 轴 + if ang > 90: ang -= 180 + if ang < -90: ang += 180 + # 放宽角度接受范围 + if abs(ang) > 65: + continue + # 底部越近权重越大 + ymid = (y1 + y2) * 0.5 + roi_top + w = length * (0.5 + 0.5 * (ymid / max(1.0, H))) + angles.append(ang) + weights.append(w) + all_lines.append((int(x1), int(y1 + roi_top), int(x2), int(y2 + roi_top))) + + if len(angles) < 5: + return None + + # 角度鲁棒聚类:加权中位数 + MAD 剔除离群 + angs = np.array(angles, dtype=np.float32) + wts = np.array(weights, dtype=np.float32) + + # 加权中位数 + sort_idx = np.argsort(angs) + angs_sorted = angs[sort_idx] + wts_sorted = wts[sort_idx] + cum = np.cumsum(wts_sorted) + med_idx = np.searchsorted(cum, cum[-1] * 0.5) + med = float(angs_sorted[min(max(med_idx, 0), len(angs_sorted) - 1)]) + + # MAD(围绕中位数的绝对偏差中位数),阈值更宽 + dev = np.abs(angs - med) + mad = float(np.median(dev) + 1e-6) + deg_thr = max(12.0, 2.8 * mad) # 适度放宽 + keep = dev <= deg_thr + + if keep.sum() >= 3: + angs_keep = angs[keep] + wts_keep = wts[keep] + lines_keep = [all_lines[i] for i, k in enumerate(keep) if k] + else: + angs_keep = angs + wts_keep = wts + lines_keep = all_lines + + # 加权圆均值 + ang_rad = np.radians(angs_keep) + C = float(np.sum(wts_keep * np.cos(ang_rad))) + S = 
float(np.sum(wts_keep * np.sin(ang_rad))) + norm = float(np.sum(wts_keep) + 1e-6) + if abs(C) < 1e-6 and abs(S) < 1e-6: + return None + mean = float(np.degrees(np.arctan2(S, C))) + confidence = float(np.hypot(C, S) / norm) + + return { + "angle_deg": mean, + "lines": lines_keep, + "confidence": confidence, + "count": len(lines_keep), + } + except Exception: + return None + + def _get_crosswalk_guidance_features(self, mask: np.ndarray, image_shape: tuple) -> dict: + """计算斑马线引导特征(鲁棒中心线 + 目标点 + 角度/偏移)""" + try: + height, width = image_shape[:2] + min_run_px = max(12, int(width * 0.02)) + centerline_rows = [] + + # 自底向上扫描,按最大连续区段取左右边界的中点,忽略零散噪点 + for y in range(height - 1, int(height * 0.4), -5): + row = mask[y, :] + xs = np.where(row > 0)[0] + if xs.size <= min_run_px: + continue + splits = np.where(np.diff(xs) > 1)[0] + 1 + segments = np.split(xs, splits) if xs.size else [] + if not segments: + continue + seg = max(segments, key=lambda s: (s[-1] - s[0] + 1)) + if seg.size == 0 or (seg[-1] - seg[0] + 1) < min_run_px: + continue + center_x = 0.5 * (seg[0] + seg[-1]) + centerline_rows.append([y, center_x]) + + if len(centerline_rows) < 10: + return None + + data = np.array(centerline_rows, dtype=np.float32) + y_coords, x_coords = data[:, 0], data[:, 1] + + # 初始加权(底部更重要) + w_base = y_coords / float(height) + coeffs = np.polyfit(y_coords, x_coords, 2, w=w_base) + poly = np.poly1d(coeffs) + + # 一次鲁棒再加权(抑制弯折/异常点) + res = x_coords - poly(y_coords) + mad = np.median(np.abs(res - np.median(res))) + 1e-6 + c = 2.5 * mad + w_robust = 1.0 / (1.0 + (res / c) ** 2) + w_total = w_base * w_robust + coeffs = np.polyfit(y_coords, x_coords, 2, w=w_total) + poly = np.poly1d(coeffs) + + # 目标点与绘制点 + lookahead_y = int(height * 0.6) + target_x = float(poly(lookahead_y)) + plot_y = np.arange(int(height * 0.4), height, 5).astype(int) + plot_x = poly(plot_y).astype(int) + centerline_points = np.vstack((plot_x, plot_y)).T.tolist() + + # 角度(基于 x(y) 的导数)与水平偏移 + dpoly = np.polyder(poly) + dx_dy = 
float(dpoly(lookahead_y)) + angle_deg = float(np.degrees(np.arctan(dx_dy))) + offset = float((target_x - (width / 2.0)) / max(1.0, width / 2.0)) + + # 截断目标点范围 + tx = int(np.clip(target_x, 0, width - 1)) + return { + "target_point": (tx, lookahead_y), + "centerline_points": centerline_points, + "angle_deg": angle_deg, + "offset": offset, + } + except Exception: + return None + + # —— 障碍物:光流辅助方法(与 blindpath 一致) —— + def _get_edge_mask(self, mask, offset=10): + """获取掩码的内边缘区域,用于特征点检测""" + if mask is None: + return None + kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (offset*2, offset*2)) + inner = cv2.erode(mask, kernel, iterations=1) + edge = cv2.subtract(mask, inner) + kernel_small = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3)) + edge = cv2.dilate(edge, kernel_small, iterations=1) + return edge + + def _predict_mask_with_flow(self, prev_mask, prev_gray, curr_gray): + """使用 Lucas-Kanade 光流预测掩码位置(与 blindpath 一致)""" + try: + edge_mask = self._get_edge_mask(prev_mask, offset=10) + p0 = cv2.goodFeaturesToTrack(prev_gray, mask=edge_mask, **FEATURE_PARAMS) + if p0 is None or len(p0) < 8: + return None + p1, st, err = cv2.calcOpticalFlowPyrLK(prev_gray, curr_gray, p0, None, **LK_PARAMS) + if p1 is None or st is None: + return None + good_new = p1[st == 1] + good_old = p0[st == 1] + if len(good_new) < 5: + return None + M, inliers = cv2.estimateAffinePartial2D(good_old, good_new, method=cv2.RANSAC, ransacReprojThreshold=5.0) + if M is None: + return None + H, W = curr_gray.shape[:2] + flow_mask = cv2.warpAffine(prev_mask, M, (W, H), + flags=cv2.INTER_NEAREST, + borderMode=cv2.BORDER_CONSTANT, + borderValue=0) + return flow_mask + except Exception: + return None + + # —— 障碍物:检测与可视化(与 blindpath 一致) —— + def _detect_obstacles(self, image, path_mask=None): + """检测障碍物,调用 ObstacleDetectorClient.detect(与 blindpath 同步)""" + logger.info(f"[_detect_obstacles] 开始执行,Frame={self.frame_counter}, obstacle_detector={'已加载' if self.obstacle_detector else '未加载'}") + if 
self.obstacle_detector is None: + logger.warning("[_detect_obstacles] 障碍物检测器未加载!") + return [] + + try: + logger.info(f"[_detect_obstacles] 调用ObstacleDetectorClient.detect()... image.shape={image.shape}") + detected_obstacles = self.obstacle_detector.detect(image, path_mask=path_mask) + logger.info(f"[_detect_obstacles] 返回 {len(detected_obstacles)} 个物体") + + # 补充派生字段 + H, W = image.shape[:2] + for i, obj in enumerate(detected_obstacles): + if 'mask' in obj and obj['mask'] is not None: + y_coords, x_coords = np.where(obj['mask'] > 0) + if len(y_coords) > 0 and len(x_coords) > 0: + x1, y1 = int(np.min(x_coords)), int(np.min(y_coords)) + x2, y2 = int(np.max(x_coords)), int(np.max(y_coords)) + obj['box_coords'] = (x1, y1, x2, y2) + if 'y_position_ratio' not in obj: + obj['y_position_ratio'] = obj.get('center_y', 0) / H + if 'label' not in obj: + obj['label'] = obj.get('name', 'unknown') + if 'center' not in obj: + obj['center'] = (obj.get('center_x', 0), obj.get('center_y', 0)) + if 'confidence' not in obj: + obj['confidence'] = 0.5 + return detected_obstacles + except Exception as e: + logger.error(f"[_detect_obstacles] 障碍物检测失败: {e}", exc_info=True) + return [] + + def _stabilize_obstacle_list(self, obstacles, prev_obstacles, prev_gray, curr_gray, image_shape, threshold=0.5): + """稳定障碍物检测结果,避免重复叠加(与 blindpath 一致)""" + if not obstacles or prev_gray is None or curr_gray is None: + return obstacles + + H, W = image_shape + stabilized = [] + used_prev = set() + for curr_obs in obstacles: + if 'mask' not in curr_obs or curr_obs['mask'] is None: + stabilized.append(curr_obs) + continue + curr_mask = curr_obs['mask'] + best_match = None + best_iou = 0 + best_idx = -1 + + if prev_obstacles: + for idx, prev_obs in enumerate(prev_obstacles): + if idx in used_prev or 'mask' not in prev_obs: + continue + flow_mask = self._predict_mask_with_flow(prev_obs['mask'], prev_gray, curr_gray) + if flow_mask is None: + flow_mask = prev_obs['mask'] + inter = np.logical_and(curr_mask > 0, 
flow_mask > 0).sum() + union = np.logical_or(curr_mask > 0, flow_mask > 0).sum() + iou = float(inter) / float(union) if union > 0 else 0.0 + if iou > best_iou and iou > threshold: + best_iou = iou + best_match = flow_mask + best_idx = idx + + if best_match is not None and best_idx >= 0: + used_prev.add(best_idx) + fused_mask = ((0.8 * curr_mask + 0.2 * best_match) > 128).astype(np.uint8) * 255 + curr_obs['mask'] = fused_mask + self._update_obstacle_properties(curr_obs, H, W) + stabilized.append(curr_obs) + return stabilized + + def _update_obstacle_properties(self, obs, H, W): + """更新障碍物的派生属性""" + if 'mask' not in obs or obs['mask'] is None: + return + mask = obs['mask'] + y_coords, x_coords = np.where(mask > 0) + if len(y_coords) > 0: + obs['area'] = int(len(y_coords)) + obs['center_x'] = float(np.mean(x_coords)) + obs['center_y'] = float(np.mean(y_coords)) + obs['y_position_ratio'] = obs['center_y'] / H + obs['area_ratio'] = obs['area'] / float(H * W) + obs['bottom_y_ratio'] = np.max(y_coords) / float(H) + x1, y1 = int(np.min(x_coords)), int(np.min(y_coords)) + x2, y2 = int(np.max(x_coords)), int(np.max(y_coords)) + obs['box_coords'] = (x1, y1, x2, y2) + + # —— 可视化通用方法(与 blindpath 一致) —— + def _parse_color(self, color_str): + """解析颜色字符串,返回BGR格式""" + try: + if isinstance(color_str, tuple) and len(color_str) == 3: + return color_str + if color_str.startswith('rgba('): + values = color_str[5:-1].split(',') + r, g, b = int(values[0]), int(values[1]), int(values[2]) + return (b, g, r) # OpenCV: BGR + elif color_str == 'yellow': + return (0, 255, 255) + elif color_str == 'red': + return (0, 0, 255) + else: + return (0, 0, 255) + except: + return (0, 0, 255) + + def _add_obstacle_visualization(self, obstacle, visualizations, pulse_effect=False): + """添加障碍物可视化(简化版:仅边框,近红远黄)""" + try: + bottom_y_ratio = obstacle.get('bottom_y_ratio', 0) + area_ratio = obstacle.get('area_ratio', 0) + is_near = bottom_y_ratio > 0.7 or area_ratio > 0.1 + + if 'mask' in obstacle and 
obstacle['mask'] is not None: + mask = obstacle['mask'] + contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + if contours: + max_contour = max(contours, key=cv2.contourArea) + points = max_contour.squeeze(1)[::5].tolist() + + # 根据距离选择边框颜色:近距离红色,远距离黄色 + if is_near: + outline_color = "rgba(255, 0, 0, 1.0)" # 红色 + thickness = 3 + else: + outline_color = "rgba(255, 255, 0, 0.8)" # 黄色 + thickness = 2 + + # 只添加边框,不添加填充和文字 + visualizations.append({ + "type": "outline", + "points": points, + "color": outline_color, + "thickness": thickness + }) + except Exception as e: + logger.error(f"[_add_obstacle_visualization] 添加障碍物可视化失败: {e}") + + def _draw_command_button(self, image, text): + """绘制底部中央的指令按钮(类似yolomedia风格)""" + try: + H, W = image.shape[:2] + full_text = f"当前指令:{text if text else '—'}" + + # 按钮参数 + font_px = 14 + pad_x, pad_y = 14, 8 + bottom_margin = 28 + + # 计算文字尺寸 + if PIL_AVAILABLE: + try: + from PIL import Image as PILImage, ImageDraw, ImageFont + # 尝试加载中文字体 + font = None + for font_path in [ + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + ]: + if os.path.exists(font_path): + try: + font = ImageFont.truetype(font_path, font_px) + break + except: + continue + if font: + bbox = ImageDraw.Draw(PILImage.new('RGB', (1, 1))).textbbox((0, 0), full_text, font=font) + tw = max(1, bbox[2] - bbox[0]) + th = max(1, bbox[3] - bbox[1]) + else: + scale = font_px / 24.0 + (tw, th), _ = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, scale, 1) + except: + scale = font_px / 24.0 + (tw, th), _ = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, scale, 1) + else: + scale = font_px / 24.0 + (tw, th), _ = cv2.getTextSize(full_text, cv2.FONT_HERSHEY_SIMPLEX, scale, 1) + + # 计算按钮位置(底部居中) + bw = tw + pad_x * 2 + bh = th + pad_y * 2 + radius = max(10, bh // 2) + + cx = W // 2 + left = max(8, cx - bw // 2) + top = H - bottom_margin - bh + right = min(W - 8, left + bw) + bottom = top + bh + + # 
绘制半透明圆角背景 + overlay = image.copy() + bg_color = (26, 32, 41) # 深色背景 + border_color = (60, 76, 102) # 边框 + + # 圆角矩形(中间+两个圆) + cv2.rectangle(overlay, (left + radius, top), (right - radius, bottom), bg_color, -1) + cv2.circle(overlay, (left + radius, (top + bottom) // 2), radius, bg_color, -1) + cv2.circle(overlay, (right - radius, (top + bottom) // 2), radius, bg_color, -1) + + # 混合半透明 + cv2.addWeighted(overlay, 0.75, image, 0.25, 0, image) + + # 绘制边框 + cv2.rectangle(image, (left + radius, top), (right - radius, bottom), border_color, 1) + cv2.circle(image, (left + radius, (top + bottom) // 2), radius, border_color, 1) + cv2.circle(image, (right - radius, (top + bottom) // 2), radius, border_color, 1) + + # 绘制文字 + text_x = left + pad_x + text_y = top + pad_y + th + + if PIL_AVAILABLE and font: + # 使用PIL绘制中文 + pil_img = PILImage.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_img) + draw.text((text_x, top + pad_y), full_text, font=font, fill=(255, 255, 255)) + image = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) + else: + # 使用OpenCV绘制 + cv2.putText(image, full_text, (text_x, text_y), + cv2.FONT_HERSHEY_SIMPLEX, scale, (255, 255, 255), 1) + + return image + except Exception as e: + logger.error(f"绘制指令按钮失败: {e}") + return image + + def _draw_data_panel_no_bg(self, image, data, position=(15, 15)): + """绘制数据面板(无黑底,描边文字),与 blindpath 一致""" + if not PIL_AVAILABLE: + return image + try: + pil_img = Image.fromarray(cv2.cvtColor(image, cv2.COLOR_BGR2RGB)) + draw = ImageDraw.Draw(pil_img, "RGBA") + env_scale = float(os.getenv("AIGLASS_PANEL_SCALE", "0.7")) + base_font_size = max(10, int(round(14 * env_scale))) + font = None + font_paths = [ + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + ] + for font_path in font_paths: + try: + if os.path.exists(font_path): + font = ImageFont.truetype(font_path, base_font_size) + break + except: + continue + if font is None: + font = 
ImageFont.load_default() + + y_offset = position[1] + for key, value in data.items(): + text = f"{key}: {value}" + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx != 0 or dy != 0: + draw.text((position[0] + dx, y_offset + dy), text, + font=font, fill=(0, 0, 0, 255)) + draw.text((position[0], y_offset), text, font=font, fill=(255, 255, 255, 255)) + y_offset += base_font_size + 5 + return cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR) + except Exception as e: + logger.warning(f"绘制数据面板失败: {e}") + return image + + def _draw_visualizations(self, image, viz_elements): + """增强的可视化绘制方法(与 blindpath 一致)""" + if not viz_elements: + return image + current_time = time.time() + panel_elements = [v for v in viz_elements if v.get("type") == "data_panel"] + standard_elements = [v for v in viz_elements if v.get("type") != "data_panel"] + + # 第一遍:半透明填充 + for element in standard_elements: + elem_type = element.get("type") + if elem_type in ['blind_path_mask', 'obstacle_mask', 'crosswalk_mask']: + points = np.array(element.get("points", []), dtype=np.int32) + if points.size > 0: + color = self._parse_color(element.get("color", "rgba(255, 255, 255, 0.5)")) + if element.get("effect") == "pulse": + pulse_speed = element.get("pulse_speed", 1.0) + alpha = 0.3 + 0.3 * np.sin(current_time * pulse_speed * 2 * np.pi) + else: + alpha = 0.4 + x, y, w, h = cv2.boundingRect(points) + x = max(0, x); y = max(0, y) + w = min(w, image.shape[1] - x) + h = min(h, image.shape[0] - y) + if w > 0 and h > 0: + binary_mask = np.zeros((h, w), dtype=np.uint8) + local_points = points - np.array([x, y]) + cv2.fillPoly(binary_mask, [local_points], 255) + local_region = image[y:y+h, x:x+w].copy() + color_overlay = np.zeros((h, w, 3), dtype=np.uint8) + color_overlay[:] = color + for c in range(3): + local_region[:, :, c] = np.where( + binary_mask > 0, + (1 - alpha) * local_region[:, :, c] + alpha * color_overlay[:, :, c], + local_region[:, :, c] + ) + image[y:y+h, x:x+w] = local_region + + # 第二遍:轮廓和元素 + 
for element in standard_elements: + elem_type = element.get("type") + if elem_type == 'outline': + points = np.array(element.get("points", []), dtype=np.int32) + if points.size > 0: + color = self._parse_color(element.get("color", "rgba(255, 255, 255, 1.0)")) + thickness = element.get("thickness", 3) + cv2.polylines(image, [points], isClosed=True, color=color, thickness=thickness) + elif elem_type == 'polyline': + points = np.array(element.get("points", []), dtype=np.int32) + if points.size > 0: + color = self._parse_color(element.get("color", "rgba(255, 255, 0, 1.0)")) + thickness = element.get("width", 2) + cv2.polylines(image, [points], isClosed=False, color=color, thickness=thickness) + elif elem_type == 'circle': + center = tuple(element.get("center", (0, 0))) + radius = element.get("radius", 10) + color = self._parse_color(element.get("color", "rgba(255, 0, 0, 1.0)")) + thickness = -1 if element.get("filled", True) else 2 + cv2.circle(image, center, radius, color, thickness) + elif elem_type == 'arrow': + start = tuple(element.get("start", (0, 0))) + end = tuple(element.get("end", (100, 100))) + color = self._parse_color(element.get("color", "rgba(0, 255, 255, 1.0)")) + thickness = element.get("thickness", 2) + tip_length = element.get("tip_length", 0.3) + cv2.arrowedLine(image, start, end, color, thickness, tipLength=tip_length) + elif elem_type == 'text_with_bg': + text = element.get("text", "") + pos = element.get("position", [10, 30]) + font_scale = element.get("font_scale", 0.6) + color = self._parse_color(element.get("color", "rgba(255, 255, 255, 1.0)")) + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx != 0 or dy != 0: + cv2.putText(image, text, (pos[0] + dx, pos[1] + dy), + cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), 3) + cv2.putText(image, text, tuple(pos), cv2.FONT_HERSHEY_SIMPLEX, font_scale, color, 2) + elif elem_type == 'warning_icon': + pos = element.get("position", (100, 100)) + level = element.get("level", "info") + text = 
element.get("text", "") + flash = element.get("flash", False) + if level == "danger": + icon_color = (0, 0, 255) + text_color = (255, 255, 255) + elif level == "warning": + icon_color = (0, 165, 255) + text_color = (255, 255, 255) + else: + icon_color = (0, 255, 255) + text_color = (0, 0, 0) + if flash: + alpha = 0.5 + 0.5 * np.sin(current_time * 4 * np.pi) + icon_color = tuple(int(c * alpha) for c in icon_color) + triangle = np.array([ + [pos[0], pos[1] - 20], + [pos[0] - 15, pos[1]], + [pos[0] + 15, pos[1]] + ], np.int32) + cv2.fillPoly(image, [triangle], icon_color) + cv2.polylines(image, [triangle], True, (255, 255, 255), 2) + cv2.putText(image, "!", (pos[0] - 5, pos[1] - 5), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 2) + if text: + font_scale = 0.5 + (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1) + text_pos = (pos[0] - tw // 2, pos[1] + 20) + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx != 0 or dy != 0: + cv2.putText(image, text, (text_pos[0] + dx, text_pos[1] + dy), + cv2.FONT_HERSHEY_SIMPLEX, font_scale, (0, 0, 0), 2) + cv2.putText(image, text, text_pos, cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, 1) + elif elem_type == 'text': + text = element.get("text", "") + pos = tuple(element.get("pos", (10, 30))) + cv2.putText(image, text, pos, cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) + + # 数据面板 + if PIL_AVAILABLE: + for panel in panel_elements: + image = self._draw_data_panel_no_bg(image, panel["data"], panel["position"]) + else: + for panel in panel_elements: + y_offset = panel["position"][1] + for key, value in panel["data"].items(): + text = f"{key}: {value}" + for dx in [-1, 0, 1]: + for dy in [-1, 0, 1]: + if dx != 0 or dy != 0: + cv2.putText(image, text, (panel["position"][0] + dx, y_offset + dy), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 3) + cv2.putText(image, text, (panel["position"][0], y_offset), + cv2.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1) + y_offset += 25 + return image + + def 
_speech_for_obstacle(self, name: str) -> str: + """生成障碍物语音提示""" + k = (name or '').strip().lower() + if k == 'person': return "前方有人,注意避让。" + if k == 'car': return "前方有车,注意避让。" + if k == 'bicycle': return "前方有自行车,停一下。" + if k == 'motorcycle': return "前方有摩托车,停一下。" + if k == 'bus': return "前方有公交车,停一下。" + if k == 'truck': return "前方有卡车,停一下。" + if k == 'scooter': return "前方有电瓶车,停一下。" + if k == 'stroller': return "前方有婴儿车,停一下。" + if k == 'dog': return "前方有狗,停一下。" + if k == 'animal': return "前方有动物,停一下。" + return "前方有障碍物,注意避让。" + + def process_frame(self, bgr_image: np.ndarray) -> CrossStreetResult: + """处理单帧图像(每帧分割;若失败,用光流追踪上一帧掩码保持可视化与导航)""" + self.frame_counter += 1 + current_time = time.time() + + try: + annotated = bgr_image.copy() + h, w = bgr_image.shape[:2] + frame_visualizations = [] + + # 当前灰度图供 LK 与避障稳定使用 + gray = cv2.cvtColor(bgr_image, cv2.COLOR_BGR2GRAY) + + # ========== 1) 间隔执行分割(每4帧检测一次) ========== + crosswalk_mask = None + blindpath_mask = None + det_area = 0 + + # 【新增】检测间隔逻辑 + if self.seg_model and self.frame_counter % self.CROSSWALK_DETECTION_INTERVAL == 0: + # 执行新的检测 + # 使用较低的基础阈值获取所有候选 + base_thr = min(CROSSWALK_MIN_CONF, BLIND_MIN_CONF) + detections = self.seg_model.detect(bgr_image, confidence_threshold=base_thr) or [] + + # 按类别ID和名称分拣 + raw_cw, raw_bp = [], [] + for det in detections: + if not hasattr(det, 'mask') or det.mask is None: + continue + + cid = _cls_of(det) + name = str(getattr(det, "name", "")).lower() + + # 斑马线:ID匹配或名称匹配 + if (cid == CW_ID) or _in_set(name, _CW): + raw_cw.append(det) + # 盲道:ID匹配或名称匹配 + elif (cid == BP_ID) or _in_set(name, _BP): + raw_bp.append(det) + + # 二次阈值过滤 + cw_list = [d for d in raw_cw if _score_of(d) >= CROSSWALK_MIN_CONF] + bp_list = [d for d in raw_bp if _score_of(d) >= BLIND_MIN_CONF] + + # 合并斑马线mask + if cw_list: + cw_masks = [] + for det in cw_list: + mask = det.mask + if mask.shape != (h, w): + mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST) + mask_bin = (mask > 0.5).astype(np.uint8) + 
cw_masks.append(mask_bin) + if cw_masks: + crosswalk_mask = np.maximum.reduce(cw_masks) + det_area = int(crosswalk_mask.sum()) + if det_area < CROSSWALK_MIN_AREA: + crosswalk_mask = None + det_area = 0 + + # 合并盲道mask + if bp_list: + bp_masks = [] + for det in bp_list: + mask = det.mask + if mask.shape != (h, w): + mask = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST) + mask_bin = (mask > 0.5).astype(np.uint8) + bp_masks.append(mask_bin) + if bp_masks: + blindpath_mask = np.maximum.reduce(bp_masks) + + # 去交叠:从斑马线mask中移除盲道区域 + if crosswalk_mask is not None and blindpath_mask is not None: + crosswalk_mask = crosswalk_mask.copy() + crosswalk_mask[blindpath_mask > 0] = 0 + + # 盲道真伪判定 + if blindpath_mask is not None: + if not _looks_like_blind_path(blindpath_mask, crosswalk_mask, h, w): + blindpath_mask = None + + # 【新增】保存检测结果到缓存 + self.last_detected_crosswalk_mask = crosswalk_mask + self.last_detected_blindpath_mask = blindpath_mask + self.last_crosswalk_detection_frame = self.frame_counter + + else: + # 【新增】使用缓存的检测结果 + crosswalk_mask = self.last_detected_crosswalk_mask + blindpath_mask = self.last_detected_blindpath_mask + + # ========== 2) 分割失败 → 用上一帧特征点光流追踪重建 ========== + used_tracking = False + if crosswalk_mask is None: + if self.old_gray is not None and self.p0 is not None and len(self.p0) >= TRACK_MIN_POINTS: + try: + p1, st, err = cv2.calcOpticalFlowPyrLK(self.old_gray, gray, self.p0, None, **LK_PARAMS) + if p1 is not None and st is not None: + good_new = p1[st == 1] + if good_new is not None and len(good_new) >= TRACK_MIN_POINTS: + tracked_mask = self._hull_mask_from_points(good_new, (h, w)) + if tracked_mask is not None and int(tracked_mask.sum()) >= (0.3 * (self.prev_mask.sum() if self.prev_mask is not None else 1)): + crosswalk_mask = tracked_mask + used_tracking = True + self.p0 = good_new.reshape(-1, 1, 2) + else: + self.p0 = None + self.old_gray = None + except Exception as e: + logger.warning(f"[CROSS_STREET] LK 光流失败: {e}") + self.p0 = None + 
self.old_gray = None + + # ========== 3) EMA 平滑(减少抖动) + 形态学净化 ========== + if crosswalk_mask is not None: + m = crosswalk_mask.astype(np.float32) + if self.prev_mask_float is not None and self.prev_mask_float.shape == m.shape: + self.prev_mask_float = MASK_EMA_ALPHA * m + (1.0 - MASK_EMA_ALPHA) * self.prev_mask_float + else: + self.prev_mask_float = m + crosswalk_mask = (self.prev_mask_float > 0.5).astype(np.uint8) + crosswalk_mask = self._postprocess_mask(crosswalk_mask) + self.prev_mask = crosswalk_mask + self.prev_mask_ts = current_time + + # ========== 4) 若分割成功(或追踪成功)→ 播种/更新特征点 ========== + if crosswalk_mask is not None: + need_seed = (self.p0 is None or len(self.p0) < TRACK_MIN_POINTS or + (self.frame_counter - self.last_seed_frame) >= TRACK_RESEED_EVERY) + if need_seed: + pts = self._seed_points_from_mask(gray, crosswalk_mask) + if pts is not None and len(pts) >= TRACK_MIN_POINTS: + self.p0 = pts + self.old_gray = gray.copy() + self.last_seed_frame = self.frame_counter + else: + self.old_gray = gray.copy() + else: + self.crosswalk_detected = False + self.p0 = None + self.old_gray = None + + # ========== 4.5) 障碍物检测与可视化(与 blindpath 一致) ========== + # 使用 crosswalk_mask 作为 path_mask,若无则全局检测 + detected_obstacles = [] + if self.obstacle_detector is not None: + if self.frame_counter % self.OBSTACLE_DETECTION_INTERVAL == 0: + detected_obstacles = self._detect_obstacles(bgr_image, path_mask=crosswalk_mask) + # 稳定化 + if self.prev_gray is not None: + detected_obstacles = self._stabilize_obstacle_list( + detected_obstacles, + self.last_detected_obstacles, + self.prev_gray, + gray, + bgr_image.shape[:2] + ) + self.last_detected_obstacles = detected_obstacles + self.last_obstacle_detection_frame = self.frame_counter + else: + if self.frame_counter - self.last_obstacle_detection_frame < self.OBSTACLE_CACHE_DURATION_FRAMES: + detected_obstacles = self.last_detected_obstacles + else: + detected_obstacles = [] + # 可视化所有障碍物 + for obs in detected_obstacles: + 
self._add_obstacle_visualization(obs, frame_visualizations) + + # ========== 5) 状态机 + 可视化与导航指令 ========== + guidance_text = "" + + # 先绘制盲道(绿色mask,无黑底) + if blindpath_mask is not None: + # 只在掩码区域混合绿色,避免黑底 + mask_area = (blindpath_mask > 0).astype(bool) + green_color = np.array([0, 255, 0], dtype=np.float32) # BGR + # 在掩码区域内混合颜色 + for c in range(3): + annotated[:, :, c] = np.where( + mask_area, + (annotated[:, :, c] * 0.7 + green_color[c] * 0.3).astype(np.uint8), + annotated[:, :, c] + ) + # 绘制盲道边框 + bp_ct = self._largest_contour(blindpath_mask) + if bp_ct is not None: + cv2.drawContours(annotated, [bp_ct], -1, (0, 255, 0), 2) + + # 绘制斑马线(橙色mask,无描边,与盲道模式颜色一致) + if crosswalk_mask is not None: + self.crosswalk_detected = True + # 使用与盲道模式相同的橙色:BGR(0, 165, 255),只在掩码区域混合 + mask_area = (crosswalk_mask > 0).astype(bool) + orange_color = np.array([0, 165, 255], dtype=np.float32) # BGR + # 在掩码区域内混合颜色 + for c in range(3): + annotated[:, :, c] = np.where( + mask_area, + (annotated[:, :, c] * 0.7 + orange_color[c] * 0.3).astype(np.uint8), + annotated[:, :, c] + ) + + # ===== 状态机逻辑 ===== + if self.state == STATE_SEEKING: + # 阶段1:寻找并对准远处的斑马线 + if crosswalk_mask is not None: + is_near = self._is_crosswalk_near(crosswalk_mask, h, w) + + if is_near: + # 斑马线已到跟前,切换到红绿灯判定 + self.state = STATE_WAIT_LIGHT + guidance_text = "斑马线已在跟前,进入红绿灯判定模式" + self.last_seeking_guidance = "" # 重置节流状态 + else: + # 远距离对准引导(使用更宽松的阈值) + angle, offset = self._compute_远_distance_alignment(crosswalk_mask, h, w) + + # 优先角度,其次方位(使用SEEKING专用的宽松阈值) + if abs(angle) >= SEEKING_ANGLE_THRESH_DEG: + direction = "左转一点" if angle > 0 else "右转一点" + elif abs(offset) >= SEEKING_OFFSET_THRESH: + direction = "向右平移" if offset > 0 else "向左平移" + else: + direction = "保持直行" + + # 【移除左上角文字,改为右上角数据面板】 + # 添加右上角数据面板 + frame_visualizations.append({ + "type": "data_panel", + "data": { + "状态": "对准斑马线", + "角度": f"{angle:.1f}°", + "偏移": f"{offset:.2f}" + }, + "position": (w - 180, 20) + }) + + # 节流:只有当引导文本改变或超过时间间隔时才播报 + if current_time - 
self.last_guide_time > self.guide_interval: + if direction != self.last_seeking_guidance: + guidance_text = direction + self.last_seeking_guidance = direction + elif current_time - self.last_guide_time > self.guide_interval * 2: + # 超过2倍间隔,重复播报 + guidance_text = direction + else: + # 【移除左上角文字,改为右上角数据面板】 + frame_visualizations.append({ + "type": "data_panel", + "data": { + "状态": "寻找斑马线" + }, + "position": (w - 180, 20) + }) + self.last_seeking_guidance = "" # 没有斑马线时重置 + + elif self.state == STATE_WAIT_LIGHT: + # 阶段2:红绿灯判定 + # 【移除左上角文字,稍后添加右上角数据面板】 + + if TRAFFIC_LIGHT_AVAILABLE and trafficlight_detection: + try: + # 传入annotated(已包含斑马线和盲道),红绿灯检测在此基础上添加检测框 + result = trafficlight_detection.process_single_frame(annotated) + + # 可视化红绿灯检测结果(绘制检测框) + if result and 'vis_image' in result: + vis_img = result['vis_image'] + if vis_img is not None: + # 将红绿灯检测的可视化结果(带斑马线、盲道和检测框)更新到annotated + annotated = vis_img + + if result and 'stable_light' in result: + stable_light = result['stable_light'] + + if stable_light == 'go': + self.green_light_counter += 1 + # 【移除左上角文字,改为右上角数据面板】 + frame_visualizations.append({ + "type": "data_panel", + "data": { + "状态": "红绿灯判定", + "检测": f"绿灯 {self.green_light_counter}/{GREEN_LIGHT_STABLE_FRAMES}" + }, + "position": (w - 180, 20) + }) + + if self.green_light_counter >= GREEN_LIGHT_STABLE_FRAMES: + self.state = STATE_CROSSING + guidance_text = "绿灯稳定,开始通行。" + self.green_light_counter = 0 + self.crossing_end_announced = False # 重置过马路结束标志 + self.last_crosswalk_seen_time = current_time # 初始化斑马线检测时间 + self.last_blindpath_announce_time = 0 # 重置盲道播报时间 + else: + # 检测到绿灯但还不稳定,节流播报 + if current_time - self.last_waiting_light_time > 3.0: + guidance_text = "正在等待绿灯…" + self.last_waiting_light_time = current_time + else: + self.green_light_counter = 0 + if stable_light in ['stop', 'countdown_stop']: + # 【移除左上角文字,改为右上角数据面板】 + frame_visualizations.append({ + "type": "data_panel", + "data": { + "状态": "红绿灯判定", + "检测": "红灯,请等待" + }, + "position": (w - 180, 20) + }) 
+ # 红灯状态播报(节流) + if current_time - self.last_waiting_light_time > 3.0: + guidance_text = "正在等待绿灯…" + self.last_waiting_light_time = current_time + else: + # 其他状态(黄灯或未检测到),节流播报 + if current_time - self.last_waiting_light_time > 3.0: + guidance_text = "正在等待绿灯…" + self.last_waiting_light_time = current_time + else: + # 没有检测到稳定的红绿灯,节流播报 + if current_time - self.last_waiting_light_time > 3.0: + guidance_text = "正在等待绿灯…" + self.last_waiting_light_time = current_time + except Exception as e: + logger.warning(f"[CROSS_STREET] 红绿灯检测失败: {e}") + if current_time - self.last_waiting_light_time > 3.0: + guidance_text = "正在等待绿灯…" + self.last_waiting_light_time = current_time + else: + # 无红绿灯模块,直接切换 + # 【移除左上角文字,改为右上角数据面板】 + frame_visualizations.append({ + "type": "data_panel", + "data": { + "状态": "红绿灯判定", + "检测": "模块未加载" + }, + "position": (w - 180, 20) + }) + if current_time - self.last_guide_time > 2.0: + self.state = STATE_CROSSING + guidance_text = "开始通行" + self.crossing_end_announced = False # 重置过马路结束标志 + self.last_crosswalk_seen_time = current_time # 初始化斑马线检测时间 + self.last_blindpath_announce_time = 0 # 重置盲道播报时间 + + elif self.state == STATE_CROSSING: + # 阶段3:过马路引导(原有逻辑) + + # 【新增】实时红绿灯检测(在CROSSING状态中) + traffic_light_warning = None # 用于存储红绿灯警告信息 + if TRAFFIC_LIGHT_AVAILABLE and trafficlight_detection: + try: + # 传入annotated(已包含斑马线和盲道),红绿灯检测在此基础上添加检测框 + result = trafficlight_detection.process_single_frame(annotated) + + # 将红绿灯检测的可视化结果(带斑马线、盲道和检测框)更新到annotated + if result and 'vis_image' in result: + vis_img = result['vis_image'] + if vis_img is not None: + # 将红绿灯检测框叠加到annotated上(保留斑马线和盲道) + annotated = vis_img + + # 检查稳定状态,如果是绿灯倒计时,播报警告 + if result and 'stable_light' in result: + stable_light = result['stable_light'] + if stable_light == 'countdown_go': + # 绿灯倒计时,播报警告(节流) + if current_time - self.last_guide_time > 2.0: + traffic_light_warning = "绿灯快没了" + except Exception as e: + logger.warning(f"[CROSS_STREET] CROSSING状态红绿灯检测失败: {e}") + + if crosswalk_mask is not None: + # 
更新斑马线检测时间 + self.last_crosswalk_seen_time = current_time + + # 检测到斑马线:如果之前误播报了结束,现在重置标志回到正常流程 + area = int(crosswalk_mask.sum()) + area_ratio = float(area) / float(h * w) + # 如果斑马线面积还比较大(>0.1),说明还在过马路中,重置结束标志 + if area_ratio > 0.1 and self.crossing_end_announced: + self.crossing_end_announced = False + self.blindpath_announced = False + logger.info("[CROSS_STREET] 检测到斑马线,重置结束标志,回到正常过马路流程") + + # 【移除左上角文字,改为右上角数据面板】 + panel_data = { + "状态": "正在过马路", + "面积": f"{area_ratio:.2f}" + } + if self.crossing_end_announced: + panel_data["提示"] = "已播报结束" + frame_visualizations.append({ + "type": "data_panel", + "data": panel_data, + "position": (w - 180, 20) + }) + + # 使用"斑马线横纹法线的中央直线"来推导偏移(offset 初值仍给 0,后面根据青色法线更新) + angle_deg, offset = 0.0, 0.0 + + # 角度:优先使用条纹霍夫线估计;失败回退 PCA + angle_source = "条纹" + stripes = self._estimate_angle_by_stripes(crosswalk_mask, gray) + if stripes and ("angle_deg" in stripes): + angle_deg = -float(stripes["angle_deg"]) + for (x1, y1, x2, y2) in stripes.get("lines", []): + cv2.line(annotated, (x1, y1), (x2, y2), VIS_COLORS["stripes"], 2) + # 可视化方向箭头(底部中心,表示偏角相对竖直) + cx, cy = int(w * 0.5), int(h * 0.85) + length = int(60) + rad = np.radians(angle_deg) + dx = int(length * np.sin(rad)) + dy = int(length * np.cos(rad)) + cv2.arrowedLine(annotated, (cx, cy), (cx + dx, cy - dy), VIS_COLORS["heading"], 3, tipLength=0.25) + else: + angle_source = "PCA" + angle_deg, _ = self._compute_angle_and_offset(crosswalk_mask) + + + # === 基于掩码质心 + 条纹法线,绘制"青色法线中央直线" & "白色虚线(与红箭头同向)" === + # === 过中心的两条参考线:青色=法线、白色虚线=与红箭头同向 === + center_pt = self._mask_center(crosswalk_mask) + if center_pt is not None and stripes and ("angle_deg" in stripes): + # 1) 青色法线:使用"条纹均值角"作为【法线相对竖直】角,保证与橙色条纹垂直 + angle_blue = float(stripes["angle_deg"]) # ← 关键:不要再取负,不要再加减 90° + self._draw_line_vertical_angle(annotated, center_pt, angle_blue, + length_ratio=0.7, + color=VIS_COLORS["centerline"], thickness=3) + + # 2) 白色虚线:过质心的"画面竖直(0°)"——代表用户假定行走朝向 + angle_white = 0.0 + 
self._draw_dashed_line_vertical_angle(annotated, center_pt, angle_white, + length_ratio=0.7, + dash=12, gap=8, color=(255, 255, 255), thickness=2) + + # 3) 角差显示(可选):青色 vs 白虚线 + diff = angle_blue - 0.0 # = angle_blue + diff = (diff + 180.0) % 360.0 - 180.0 # wrap 到 [-180,180] + cv2.putText(annotated, f"{abs(diff):.1f}°", + (min(center_pt[0] + 12, w - 110), max(center_pt[1] - 12, 30)), + cv2.FONT_HERSHEY_SIMPLEX, 0.7, (255, 255, 255), 2) + + # === 用青色法线中央直线 计算"左右偏移" === + try: + # 注意:_offset_from_centerline 的角度坐标系与 _draw_line_vertical_angle 一致(竖直为0°) + offset_new = self._offset_from_centerline(center_pt, angle_blue, w, h, y_ratio=0.75) + offset = float(offset_new) + except Exception: + # 兜底:若计算异常,保持原 offset(默认为0) + pass + + # 导航方向(基础) + if abs(angle_deg) >= ANGLE_THRESH_DEG: + direction = "左转一点" if angle_deg > 0 else "右转一点" + elif abs(offset) >= OFFSET_THRESH: + direction = "向右平移" if offset > 0 else "向左平移" + else: + direction = "保持直行" + + # 障碍物引导优先级(近距离优先覆盖方向提示) + obstacle_override = None + if detected_obstacles: + NEAR_Y = 0.7 + NEAR_AREA = 0.1 + near_list = [o for o in detected_obstacles if (o.get('bottom_y_ratio', 0) > NEAR_Y or o.get('area_ratio', 0) > NEAR_AREA)] + if near_list: + name = (near_list[0].get('name') or '障碍物') + obstacle_override = self._speech_for_obstacle(name) + + # 【移除左上角调试信息,改为右上角数据面板】 + # 更新右上角数据面板(合并到已有的面板数据中) + src_text = "分割" if not used_tracking else "追踪" + # 数据面板在前面已经添加了,这里只记录调试数据 + # 稍后会统一添加完整的数据面板 + + # 语音输出(节流) + if current_time - self.last_guide_time > self.guide_interval: + # 检查是否快走完斑马线 + is_almost_done = self._is_crosswalk_almost_done(crosswalk_mask, h, w) + + # 调试信息:显示判定条件 + if self.frame_counter % 30 == 0: + ys = np.where(crosswalk_mask > 0)[0] + if ys.size > 0: + top_y, bottom_y = int(ys.min()), int(ys.max()) + logger.info(f"[CROSS_STREET] area_ratio={area_ratio:.3f}, top_ratio={top_y/h:.3f}, bottom_ratio={bottom_y/h:.3f}, almost_done={is_almost_done}") + + # 优先级1:红绿灯警告(绿灯倒计时) + if traffic_light_warning: + guidance_text = 
traffic_light_warning + self.last_guide_time = current_time + # 优先级2:过马路结束提示(斑马线快消失) + elif is_almost_done and not self.crossing_end_announced: + guidance_text = "过马路结束,准备上人行道。" + self.crossing_end_announced = True + self.last_guide_time = current_time + # 优先级3:盲道提示(过马路结束后检测到盲道,可重复播报但节流4秒) + elif self.crossing_end_announced and blindpath_mask is not None: + if current_time - self.last_blindpath_announce_time > 4.0: + guidance_text = "远处有盲道,继续前行。" + self.last_blindpath_announce_time = current_time + self.last_guide_time = current_time + # 优先级4:障碍物 + elif obstacle_override: + guidance_text = obstacle_override + self.last_guide_time = current_time + # 优先级5:方向引导 + else: + guidance_text = direction + self.last_guide_time = current_time + else: + # CROSSING 阶段但没有检测到斑马线 + no_crosswalk_duration = current_time - self.last_crosswalk_seen_time + # 【移除左上角文字,改为右上角数据面板】 + frame_visualizations.append({ + "type": "data_panel", + "data": { + "状态": "正在过马路", + "斑马线": f"未检测到 ({no_crosswalk_duration:.1f}s)" + }, + "position": (w - 180, 20) + }) + + # 连续超过10秒没有斑马线,才播报"过马路结束" + if no_crosswalk_duration > 10.0: + if not self.crossing_end_announced: + if current_time - self.last_guide_time > self.guide_interval: + # 优先级1:红绿灯警告 + if traffic_light_warning: + guidance_text = traffic_light_warning + self.last_guide_time = current_time + # 优先级2:过马路结束 + else: + guidance_text = "过马路结束,准备上人行道。" + self.crossing_end_announced = True + self.last_guide_time = current_time + # 播报结束后,检测到盲道则重复播报(节流4秒) + elif blindpath_mask is not None: + if current_time - self.last_blindpath_announce_time > 4.0: + guidance_text = "远处有盲道,继续前行。" + self.last_blindpath_announce_time = current_time + self.last_guide_time = current_time + + # 【移除帧信息】 + # 添加底部指令按钮(显示当前状态或引导内容) + if guidance_text: + current_instruction = guidance_text + elif self.state == STATE_SEEKING: + current_instruction = self.last_seeking_guidance if self.last_seeking_guidance else "寻找斑马线..." 
+ elif self.state == STATE_WAIT_LIGHT: + current_instruction = "等待绿灯..." + elif self.state == STATE_CROSSING: + current_instruction = "过马路中..." + else: + current_instruction = "等待中..." + annotated = self._draw_command_button(annotated, current_instruction) + + # 统一渲染障碍物等可视化图层(blindpath 风格) + if frame_visualizations: + annotated = self._draw_visualizations(annotated, frame_visualizations) + + # 【修改】不在工作流内部播放音频,由app_main统一处理 + # 直接返回guidance_text给上层调用者(app_main)来播放 + + # 更新 prev_gray(供障碍物稳定化使用) + self.prev_gray = gray + + return CrossStreetResult( + annotated_image=annotated, + guidance_text=guidance_text, + visualizations=frame_visualizations, + should_switch_to_blindpath=False + ) + + except Exception as e: + logger.error(f"[CROSS_STREET] 处理帧时出错: {e}", exc_info=True) + return CrossStreetResult( + annotated_image=bgr_image, + guidance_text="", + visualizations=[], + should_switch_to_blindpath=False + ) + +class YOLOModelWrapper: + """YOLO 模型包装器,将 predict 方法适配为 detect""" + + def __init__(self, yolo_model): + self.model = yolo_model + + def detect(self, image, confidence_threshold=0.25): + """使用 predict 方法并转换为 detect 格式""" + try: + results = self.model.predict(image, conf=confidence_threshold, verbose=False) + detections = [] + if results and len(results) > 0: + result = results[0] + if hasattr(result, 'masks') and result.masks is not None: + for i, mask in enumerate(result.masks.data): + if hasattr(result, 'boxes') and result.boxes is not None: + cls = int(result.boxes.cls[i].cpu().numpy()) + conf = float(result.boxes.conf[i].cpu().numpy()) + class Detection: + def __init__(self): + self.cls = cls + self.conf = conf + self.mask = mask.cpu().numpy() + detections.append(Detection()) + return detections + except Exception as e: + logger.error(f"[YOLO Wrapper] 检测错误: {e}") + return [] \ No newline at end of file diff --git a/yoloe_backend.py b/yoloe_backend.py new file mode 100644 index 0000000..223c10c --- /dev/null +++ b/yoloe_backend.py @@ -0,0 +1,97 @@ +# 
yoloe_backend.py +# -*- coding: utf-8 -*- +from typing import List, Dict, Any, Optional, Tuple, Union +import os +import cv2 +import numpy as np + +# Day 20: TensorRT 模型加载工具 +from model_utils import get_best_model_path + +# 兼容 YOLOE / YOLO +try: + from ultralytics import YOLOE as _MODEL +except Exception: + from ultralytics import YOLO as _MODEL + +# Day 20: 优先使用 TensorRT 引擎 +DEFAULT_MODEL_PATH = get_best_model_path(os.getenv("YOLOE_MODEL_PATH", "model/yoloe-11l-seg.pt")) +TRACKER_CFG = os.getenv("YOLO_TRACKER_YAML", "bytetrack.yaml") + +class YoloEBackend: + def __init__(self, model_path: Optional[str] = None, device: Optional[Union[str, int]] = None): + actual_path = model_path or DEFAULT_MODEL_PATH + self.model = _MODEL(actual_path) + # Day 20: TensorRT 引擎不需要 .to() + from model_utils import is_tensorrt_engine + if not is_tensorrt_engine(actual_path): + self.model.to("cuda") + self.device = device + + def set_text_classes(self, names: List[str]): + # YOLOE 文本提示:与你模板一致 + # Day 20: TensorRT 引擎不支持 get_text_pe + if hasattr(self.model, 'get_text_pe'): + self.model.set_classes(names, self.model.get_text_pe(names)) + else: + print(f"[YOLOE] TensorRT 模式:跳过 set_text_classes") + + def segment(self, + frame_bgr: np.ndarray, + conf: float = 0.20, + iou: float = 0.45, + imgsz: int = None, # Day 20: 改为 None,从环境变量读取 + persist: bool = True + ) -> Dict[str, Any]: + """ + 返回: + dict{ + 'masks': List[np.uint8(H,W)], # 0/1 mask + 'boxes': List[Tuple[x1,y1,x2,y2]], + 'cls_ids': List[int], + 'names': List[str], + 'ids': List[Optional[int]] + } + """ + # Day 20: 使用环境变量,保持与 TensorRT 导出尺寸一致 + if imgsz is None: + imgsz = int(os.getenv("AIGLASS_YOLO_IMGSZ", "480")) + + r = self.model.track( + frame_bgr, + conf=conf, iou=iou, imgsz=imgsz, + persist=persist, tracker=TRACKER_CFG, verbose=False + )[0] + + out = {"masks": [], "boxes": [], "cls_ids": [], "names": [], "ids": []} + masks_obj = getattr(r, "masks", None) + boxes_obj = getattr(r, "boxes", None) + + if masks_obj is None or 
getattr(masks_obj, "data", None) is None: + return out + + mask_arr = masks_obj.data.cpu().numpy() # [N, h, w], float(0..1) + H, W = frame_bgr.shape[:2] + id2name = r.names if hasattr(r, "names") else {} + N = mask_arr.shape[0] + + if boxes_obj is not None: + xyxy = boxes_obj.xyxy.cpu().numpy() + cls = boxes_obj.cls.cpu().tolist() + tids = boxes_obj.id.int().cpu().tolist() if boxes_obj.id is not None else [None]*N + else: + xyxy = [None]*N + cls = [0]*N + tids = [None]*N + + for i in range(N): + bin_mask = (mask_arr[i] > 0.5).astype(np.uint8) + if bin_mask.shape[:2] != (H, W): + bin_mask = cv2.resize(bin_mask, (W, H), interpolation=cv2.INTER_NEAREST) + out["masks"].append(bin_mask) + out["boxes"].append(tuple(xyxy[i]) if xyxy[i] is not None else None) + cid = int(cls[i]) if cls is not None else 0 + out["cls_ids"].append(cid) + out["names"].append(id2name.get(cid, str(cid))) + out["ids"].append(int(tids[i]) if tids[i] is not None else None) + return out diff --git a/yolomedia.py b/yolomedia.py new file mode 100644 index 0000000..a49854f --- /dev/null +++ b/yolomedia.py @@ -0,0 +1,1567 @@ +# -*- coding: utf-8 -*- +""" +YOLOv8 单类分割 + MediaPipe Hand Landmarker + 光流追踪(多边形) +更新点(本版重点): +- 左下角第二个进度条"距离(≈1)" 已完全替换为:ratio = 物体面积 / 手面积 的"接近 1 程度"可视化 + -> range_score = 1 - clamp(|ratio - 1| / RATIO_TOL, 0..1) + -> 画面同时显示 ratio 数值;ratio<1 提示"向前靠近",ratio>1 提示"后退",在 [1±RATIO_TOL] 内为"保持" +其他特性: +- Enter 锁定:在分割掩码"内收 5px"的内边界上取光流点 +- TRACK 期间:监控当前多边形外扩 40px 周边区域的分割,命中即重锁 +- 成功判定:放宽"握持(Grasp)"启发式(拿瓶子无需特别紧) +- 手骨架单色渲染;测距箭头(端点定位线 + 箭头 + 像素值) +- 中文绘制优先 Pillow + 系统中文字体(避免问号) +""" + +import os +import time +import threading +import math +import cv2 +import numpy as np +import mediapipe as mp +from mediapipe.framework.formats import landmark_pb2 +from ultralytics import YOLO +from ultralytics.utils.plotting import Colors +import bridge_io +import pygame # 用于播放本地音频文件 + +from audio_player import play_audio_threadsafe +PERF_DEBUG = False # 打印调试信息(False 关闭) +HAND_DOWNSCALE = 0.8 # 
HandLandmarker 的输入缩放 0.5=长宽各减半(≈1/4 像素量) +HAND_FPS_DIV = 1 # 人手每 2 帧跑一次(1=每帧;2=隔帧;3=每3帧) + + +# === 前端风格配色(BGR) + UI叠加管理(左下角按行堆叠) === +FRONTEND_COLORS = { + "text": (230, 237, 243), # --text: #e6edf3 + "muted": (159, 176, 195), # --muted: #9fb0c3 + "ok": (126, 231, 135), # --ok: #7ee787 + "err": (128, 128, 255), # --err: #ff8080 (BGR) + "accent": (251, 218, 97), # #61dafb 近似的强调色(BGR 取近似亮色) +} + +# 底部指令按钮文本 +CURRENT_COMMAND_TEXT = "—" + +_UI_LINE = 0 +_UI_H = 0 +_UI_TR_LINE = 0 # 右上角逐行叠放计数 +_UI_TOP_MARGIN = 12 +_UI_RIGHT_MARGIN = 12 +UNIFIED_FONT_PX = 12 # 统一字号 + + +def ui_reset_overlay(img_h: int): + """每帧调用一次,重置叠加行计数(改为右上角布局)。""" + global _UI_LINE, _UI_H, _UI_TR_LINE + _UI_LINE = 0 + _UI_TR_LINE = 0 + _UI_H = int(img_h) + + +def _ui_next_y_top(font_size: int) -> int: + """返回右上角下一行的y(顶部对齐),并推进行计数。""" + global _UI_TR_LINE + line_gap = max(4, int(font_size * 0.25)) + y_top = _UI_TOP_MARGIN + (_UI_TR_LINE * (font_size + line_gap)) + _UI_TR_LINE += 1 + return y_top + + +def set_current_command(text: str): + global CURRENT_COMMAND_TEXT + try: + CURRENT_COMMAND_TEXT = str(text) if text else "—" + except Exception: + CURRENT_COMMAND_TEXT = "—" + + +def draw_command_pill(img_bgr: np.ndarray, label: str): + """统一改为右上角白色文案。不再绘制底部圆角按钮。""" + text_prefix = "当前指令:" + full_text = f"{text_prefix}{label if label else '—'}" + # 直接用统一文本渲染 + draw_text_cn(img_bgr, full_text, (0, 0), font_size=UNIFIED_FONT_PX, color=(255,255,255), ui_hint=True) + +try: + from yoloe_backend import YoloEBackend + _YOLOE_READY = True +except Exception as e: + _YOLOE_READY = False + print(f"[DETECTOR] YOLOE backend not ready: {e}", flush=True) + +# ========= 路径参数(按需修改)========= +YOLO_MODEL_PATH = 'model/shoppingbest5.pt' +HAND_TASK_PATH = 'model/hand_landmarker.task' + +# ========= 摄像头 ========= +CAM_INDEX = 0 +INPUT_W, INPUT_H = 600, 480 + +# ========= 分割显示 ========= +STROKE_WIDTH = 5 # 增加描边宽度,让黄框和绿框更粗 +MASK_ALPHA = 0.45 +CONF_THRESHOLD = 0.20 + +# —— 单 prompt 识别(只显示一个类)—— +PROMPT_NAME = "AD_milk" 
+PROMPT_STRICT = True + +# ========= 对齐条参数 ========= +ALIGN_LOOSE_PCT = 0.12 # 归一化距离阈(相对画面对角线) + +# ========= 距离条参数(本版采用"ratio≈1"为目标)========= +RATIO_IDEAL = 1.0 # 理想值:物体面积/手面积 ≈ 1 +RATIO_TOL = 0.25 # 容许偏离:±25% 内认为距离合适 + +# ========= 语音播报 ========= +TTS_INTERVAL_SEC = 1.0 +ENABLE_TTS = True + +# ========= 光流(LK)与特征点 ========= +LK_PARAMS = dict(winSize=(21, 21), + maxLevel=3, + criteria=(cv2.TERM_CRITERIA_EPS | cv2.TERM_CRITERIA_COUNT, 12, 0.03)) +FEATURE_PARAMS = dict(maxCorners=600, + qualityLevel=0.001, + minDistance=5, + blockSize=7) + +# ========= 关键参数:内收与周边监控 ========= +INNER_OFFSET_PX_LOCK = 5 # Enter 锁定:掩码腐蚀像素,保证点在物体内部 +EDGE_DILATE_PX = 2 # 取内边界后小膨胀,利于提点 +PERI_MONITOR_PX = 40 # TRACK:监控多边形外扩 40px 的周边带 +PERI_CHECK_EVERY = 5 # 每隔 N 帧做一次周边分割检查,改为每帧 + +# ========= 轮廓精度参数 ========= +CONTOUR_EPSILON_FACTOR = 0.002 # Douglas-Peucker算法的精度因子,越小越精细 +TRACK_EPSILON_FACTOR = 0.003 # 追踪模式下的轮廓精度因子 + +# ========= YOLO实时矫正参数 ========= +YOLO_CORRECTION_IOU_THRESHOLD = 0.2 # IoU阈值,越低越积极矫正 +YOLO_CORRECTION_CONF_THRESHOLD = 0.15 # 置信度阈值,越低检测越敏感 + +# ========= 方向引导音频路径 ========= +AUDIO_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "music") # 相对路径 +AUDIO_FILES = { + "向上": os.path.join(AUDIO_DIR, "向上.wav"), + "向下": os.path.join(AUDIO_DIR, "向下.wav"), + "向左": os.path.join(AUDIO_DIR, "向左.wav"), + "向右": os.path.join(AUDIO_DIR, "向右.wav"), + "向前": os.path.join(AUDIO_DIR, "向前.wav"), + "后退": os.path.join(AUDIO_DIR, "向后.wav"), + "OK": os.path.join(AUDIO_DIR, "已对中.wav"), +} +GUIDANCE_INTERVAL_SEC = 1.5 # 引导播报间隔 + +# 初始化pygame音频 +pygame.mixer.init() + +# ========= 窗口 ========= +WINDOW = "YOLO Seg + Flow Polygon (Peri-Relock) (Grab Guidance)" + +# ======== MediaPipe 别名 ======== +BaseOptions = mp.tasks.BaseOptions +VisionRunningMode = mp.tasks.vision.RunningMode +HandLandmarker = mp.tasks.vision.HandLandmarker +HandLandmarkerOptions = mp.tasks.vision.HandLandmarkerOptions +HAND_CONNECTIONS = mp.solutions.hands.HAND_CONNECTIONS + +# ======== HandLandmarker 回调缓存 ======== 
+_last_result = None # (result, timestamp_ms) + +def on_result(result: mp.tasks.vision.HandLandmarkerResult, + output_image: mp.Image, timestamp_ms: int): + global _last_result + _last_result = (result, timestamp_ms) + +def _to_proto(hand_lms) -> landmark_pb2.NormalizedLandmarkList: + proto = landmark_pb2.NormalizedLandmarkList() + proto.landmark.extend([ + landmark_pb2.NormalizedLandmark(x=p.x, y=p.y, z=p.z) for p in hand_lms + ]) + return proto + +# —— 手骨架单色渲染 —— # +def draw_hands_mono(img_bgr, hand_lms, color=(0, 255, 255), r=2, t=2): + mp_drawing = mp.solutions.drawing_utils + landmark_spec = mp_drawing.DrawingSpec(color=color, thickness=-1, circle_radius=r) + connection_spec = mp_drawing.DrawingSpec(color=color, thickness=t, circle_radius=r) + if hasattr(hand_lms, "landmark"): + proto = hand_lms + else: + proto = _to_proto(hand_lms) + mp_drawing.draw_landmarks( + img_bgr, + landmark_list=proto, + connections=HAND_CONNECTIONS, + landmark_drawing_spec=landmark_spec, + connection_drawing_spec=connection_spec, + ) + +def norm_name(s: str) -> str: + return "".join(str(s).lower().split()) + +# ======== TTS(pyttsx3)======== +class Speaker: + def __init__(self, enable=True): + self.enable = enable + self._engine = None + self._lock = threading.Lock() + if enable: + try: + import pyttsx3 + self._engine = pyttsx3.init() + self._engine.setProperty('rate', 190) + self._engine.setProperty('volume', 1.0) + except Exception: + self._engine = None + self.enable = False + + def say_async(self, text: str): + if not self.enable or not text: + return + def _run(): + try: + with self._lock: + self._engine.stop() + self._engine.say(text) + self._engine.iterate() + t0 = time.time() + while self._engine.isBusy() and (time.time() - t0) < 1.2: + self._engine.iterate() + time.sleep(0.01) + except Exception: + pass + threading.Thread(target=_run, daemon=True).start() + +# ======== 中文文本绘制(优先 Pillow)======== +_PIL_OK = False +_FONT_PATH = None +def _init_font(): + global _PIL_OK, 
_FONT_PATH + try: + from PIL import ImageFont # noqa + _PIL_OK = True + except Exception: + _PIL_OK = False + return + candidates = [ + # Linux 中文字体路径 (Ubuntu/Debian) + "/usr/share/fonts/truetype/wqy/wqy-zenhei.ttc", + "/usr/share/fonts/truetype/wqy/wqy-microhei.ttc", + "/usr/share/fonts/opentype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/noto/NotoSansCJK-Regular.ttc", + "/usr/share/fonts/truetype/droid/DroidSansFallbackFull.ttf", + ] + for p in candidates: + if os.path.exists(p): + _FONT_PATH = p + return + _PIL_OK = False +_init_font() + +def draw_text_cn(img_bgr, text, xy, font_size=20, color=(255,255,255), stroke=None, ui_hint=True): + """ + 统一的文本绘制: + - 默认采用前端风格:小字体、左下角按行堆叠(ui_hint=True)。 + - 若 ui_hint=False 则按传入 xy 精确定位(用于贴近目标的小标注)。 + """ + # 统一样式:微软雅黑 + 固定字号 + 纯白 + color = (255, 255, 255) + font_size = int(UNIFIED_FONT_PX) + + H, W = img_bgr.shape[:2] + # 右上角堆叠布局:计算y顶边,并按文本宽度右对齐 + y_top = _ui_next_y_top(font_size) if ui_hint else _ui_next_y_top(font_size) + # 先估算文本尺寸 + tw = th = 0 + font_obj = None + + if _PIL_OK and _FONT_PATH: + try: + from PIL import Image, ImageDraw, ImageFont + font_obj = ImageFont.truetype(_FONT_PATH, font_size) + # 计算文本尺寸 + bbox = ImageDraw.Draw(Image.new('RGB', (1,1))).textbbox((0,0), text, font=font_obj) + tw = max(1, bbox[2] - bbox[0]) + th = max(1, bbox[3] - bbox[1]) + except Exception: + pass + if _PIL_OK and _FONT_PATH and font_obj is not None: + try: + from PIL import Image, ImageDraw + img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB) + pil_img = Image.fromarray(img_rgb) + draw = ImageDraw.Draw(pil_img) + x = max(8, W - _UI_RIGHT_MARGIN - tw) + y = y_top + draw.text((x, y), text, fill=(255,255,255), font=font_obj) + img_bgr[:] = cv2.cvtColor(np.asarray(pil_img), cv2.COLOR_RGB2BGR) + return + except Exception: + pass + # OpenCV 回退:估算尺寸并右对齐 + if tw <= 0 or th <= 0: + scale = font_size/24.0 + (tw, th), _ = cv2.getTextSize(text, cv2.FONT_HERSHEY_SIMPLEX, scale, 2) + x = max(8, W - _UI_RIGHT_MARGIN - int(tw)) + 
y_baseline = int(y_top + th) + cv2.putText(img_bgr, text, (x, y_baseline), cv2.FONT_HERSHEY_SIMPLEX, font_size/24.0, color, 2, cv2.LINE_AA) + +# ======== 工具函数 ======== +def clamp01(x): return max(0.0, min(1.0, x)) + +def draw_progress_bars(vis, align_score, range_score): + """第一条=对齐,第二条=距离(≈1),对应 ratio 与 1 的接近程度""" + H, W = vis.shape[:2] + bar_w = int(W * 0.28) + bar_h = 12 + gap = 8 + x0 = 12 + y0 = H - 2*bar_h - gap - 12 + # 背景 + cv2.rectangle(vis, (x0, y0), (x0 + bar_w, y0 + bar_h), (50, 50, 50), -1) + cv2.rectangle(vis, (x0, y0 + bar_h + gap), (x0 + bar_w, y0 + 2*bar_h + gap), (50, 50, 50), -1) + # 填充 + cv2.rectangle(vis, (x0, y0), (x0 + int(bar_w * clamp01(align_score)), y0 + bar_h), (0, 220, 0), -1) + cv2.rectangle(vis, (x0, y0 + bar_h + gap), (x0 + int(bar_w * clamp01(range_score)), y0 + 2*bar_h + gap), (0, 180, 255), -1) + draw_text_cn(vis, "对齐", (x0, y0 - 18), font_size=18, color=(180,180,180)) + draw_text_cn(vis, "距离(≈1)", (x0, y0 + bar_h + gap - 18), font_size=18, color=(180,180,180)) + +def polygon_center_and_area(poly): + if poly is None or len(poly) < 3: + return None, 0.0 + poly = np.array(poly, dtype=np.float32) + M = cv2.moments(poly) + if abs(M["m00"]) < 1e-6: + c = np.mean(poly, axis=0) + return (float(c[0]), float(c[1])), 0.0 + cx = float(M["m10"] / M["m00"]) + cy = float(M["m01"] / M["m00"]) + area = float(cv2.contourArea(poly.astype(np.int32))) + return (cx, cy), area + +def hand_bbox_and_area(lms, W, H): + xs = [int(p.x * W) for p in lms] + ys = [int(p.y * H) for p in lms] + if not xs or not ys: + return None, 0.0 + x0, y0, x1, y1 = min(xs), min(ys), max(xs), max(ys) + w = max(1, x1 - x0) + h = max(1, y1 - y0) + area = float(w * h) + return (x0, y0, w, h), area + +# ======== 手势:握持(Grasp) 识别(放宽版启发式)======== +THUMB_INDEX_CLOSE = 0.34 # 放宽 +FINGERTIP_NEAR = 0.44 # 放宽 +MIN_CURLED_COUNT = 1 # 放宽 + +def detect_grasp(hand_lms, W, H): + box, _ = hand_bbox_and_area(hand_lms, W, H) + if not box: + return False, 0.0 + x0, y0, w0, h0 = box + hand_diag = 
float(np.hypot(w0, h0)) + 1e-6 + palm_idx = [0, 5, 9, 13, 17] + px = np.mean([hand_lms[i].x * W for i in palm_idx]) + py = np.mean([hand_lms[i].y * H for i in palm_idx]) + palm = np.array([px, py], dtype=np.float32) + t4 = np.array([hand_lms[4].x * W, hand_lms[4].y * H], dtype=np.float32) + t8 = np.array([hand_lms[8].x * W, hand_lms[8].y * H], dtype=np.float32) + thumb_index_dist = float(np.linalg.norm(t4 - t8)) / hand_diag + tips = [12, 16, 20] + dists = [] + for i in tips: + ti = np.array([hand_lms[i].x * W, hand_lms[i].y * H], dtype=np.float32) + dists.append(float(np.linalg.norm(ti - palm)) / hand_diag) + curled_cnt = sum(1 for d in dists if d < FINGERTIP_NEAR) + cond1 = (thumb_index_dist < THUMB_INDEX_CLOSE) + cond2 = (curled_cnt >= MIN_CURLED_COUNT) + score = 0.5 * (1.0 - min(thumb_index_dist / THUMB_INDEX_CLOSE, 1.0)) + \ + 0.5 * min(curled_cnt / 3.0, 1.0) + return (cond1 and cond2), score + +# ======== 内收后的边界提点 ======== +def inner_offset_edge(mask_bin, offset_px=5, edge_dilate_px=2): + if offset_px > 0: + k = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*offset_px+1, 2*offset_px+1)) + eroded = cv2.erode(mask_bin.astype(np.uint8), k, iterations=1) + else: + eroded = mask_bin.astype(np.uint8) + edges = cv2.Canny(eroded*255, 50, 150) + if edge_dilate_px > 0: + k2 = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (2*edge_dilate_px+1, 2*edge_dilate_px+1)) + edges = cv2.dilate(edges, k2, iterations=1) + return edges # uint8 0/255 + +# ======== YOLO 分割:全帧或 ROI 内选择最佳 mask ======== +def find_best_mask(frame_bgr, yolo, W, H, target_cls_id, conf_thr=0.10, roi_rect=None): + results = yolo(frame_bgr, verbose=False) + best_mask = None + best_score = 0.0 + if results and results[0].masks is not None: + r0 = results[0] + for mask_t, conf_t, cls_t in zip(r0.masks.data, r0.boxes.conf, r0.boxes.cls): + cls_id = int(cls_t.item()) + conf_value = float(conf_t.item()) + if target_cls_id is not None and cls_id != target_cls_id: + continue + if conf_value < conf_thr: + continue + 
mask_np = mask_t.detach().cpu().numpy() + mask_rz = cv2.resize(mask_np, (W, H), interpolation=cv2.INTER_LINEAR) + mask_bin = (mask_rz > 0.5).astype(np.uint8) + + if roi_rect is not None: + x0, y0, x1, y1 = roi_rect + x0, y0 = max(0, x0), max(0, y0) + x1, y1 = min(W-1, x1), min(H-1, y1) + roi = np.zeros_like(mask_bin, dtype=np.uint8) + roi[y0:y1+1, x0:x1+1] = 1 + overlap = (mask_bin & roi).sum() + score = float(overlap) + else: + score = float(mask_bin.sum()) + + if score > best_score: + best_score = score + best_mask = mask_bin + return best_mask + +# ======== 工程化:测距箭头(端点定位线 + 箭头 + 像素值)======== +def draw_measure_arrow(img, p1, p2, txt=None): + p1 = (int(p1[0]), int(p1[1])) + p2 = (int(p2[0]), int(p2[1])) + # 端点定位线 + def end_cap(pt, size=8, color=(255,255,255), t=1): + x, y = pt + cv2.line(img, (x - size, y), (x + size, y), color, t, cv2.LINE_AA) + cv2.line(img, (x, y - size), (x, y + size), color, t, cv2.LINE_AA) + end_cap(p1, size=7, color=(255,255,255), t=1) + end_cap(p2, size=7, color=(255,255,255), t=1) + # 箭头 + cv2.arrowedLine(img, p1, p2, (255,255,255), 2, cv2.LINE_AA, tipLength=0.18) + # 文本 + if txt is None: + d = int(np.hypot(p2[0]-p1[0], p2[1]-p1[1])) + txt = f"{d}px" + mid = ((p1[0]+p2[0])//2, (p1[1]+p2[1])//2) + font = cv2.FONT_HERSHEY_SIMPLEX + fs, th = 0.6, 2 + (tw, th_text), _ = cv2.getTextSize(txt, font, fs, th) + pad = 4 + x0 = mid[0] - tw//2 - pad + y0 = mid[1] - th_text - 6 + x1 = mid[0] + tw//2 + pad + y1 = mid[1] + 6 + cv2.rectangle(img, (x0, y0), (x1, y1), (32,32,32), -1) + cv2.putText(img, txt, (x0+pad, y1-6), font, fs, (255,255,255), th, cv2.LINE_AA) + +# 添加绘制虚线的函数 +def draw_dashed_line(img, pt1, pt2, color=(255, 255, 255), thickness=2, dash_length=10, gap_length=5): + """绘制虚线""" + pt1 = np.array(pt1, dtype=np.float32) + pt2 = np.array(pt2, dtype=np.float32) + line_vec = pt2 - pt1 + line_len = np.linalg.norm(line_vec) + if line_len < 1: + return + + line_vec = line_vec / line_len # 单位向量 + + # 绘制虚线段 + current_pos = 0 + while current_pos < 
line_len: + start_pos = current_pos + end_pos = min(current_pos + dash_length, line_len) + + start_pt = pt1 + line_vec * start_pos + end_pt = pt1 + line_vec * end_pos + + cv2.line(img, tuple(start_pt.astype(int)), tuple(end_pt.astype(int)), color, thickness) + + current_pos += dash_length + gap_length + +# 添加绘制手部轮廓的函数 +def draw_hand_contour(img, hand_lms, W, H, color=(255, 255, 255), thickness=1): + """绘制手部landmarks的凸包轮廓""" + # 获取所有手部关键点 + points = [] + for lm in hand_lms: + x = int(lm.x * W) + y = int(lm.y * H) + points.append([x, y]) + + if len(points) > 3: + points = np.array(points, dtype=np.int32) + # 计算凸包 + hull = cv2.convexHull(points) + # 绘制凸包轮廓 + cv2.polylines(img, [hull], True, color, thickness) + +# 检测手和物体是否接触 +def check_hand_object_contact(hand_box, poly, overlap_threshold=0.15): + """ + 检测手的边界框和物体多边形是否有重叠 + 返回: (是否接触, 重叠比例) + """ + if hand_box is None or poly is None or len(poly) < 3: + return False, 0.0 + + # 获取手的边界框 + hx, hy, hw, hh = hand_box + hand_rect = np.array([ + [hx, hy], + [hx + hw, hy], + [hx + hw, hy + hh], + [hx, hy + hh] + ], dtype=np.int32) + + # 创建掩码来计算重叠 + H = int(max(hy + hh, np.max(poly[:, 1])) + 10) + W = int(max(hx + hw, np.max(poly[:, 0])) + 10) + + hand_mask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(hand_mask, [hand_rect], 1) + + obj_mask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(obj_mask, [poly.astype(np.int32)], 1) + + # 计算重叠 + intersection = np.logical_and(hand_mask, obj_mask).sum() + hand_area = hand_mask.sum() + + # 重叠比例(相对于手的面积) + overlap_ratio = intersection / max(1.0, hand_area) + + return overlap_ratio > overlap_threshold, overlap_ratio + +# 添加方向判断函数 +def get_guidance_direction(hand_center, object_center, hand_area, object_area, hand_box=None, poly=None): + """ + 根据手心和物体中心位置,以及面积比,返回引导方向 + 返回: (方向文字, 是否需要前后调整) + """ + if hand_center is None or object_center is None: + return None, None + + # 首先检查手和物体是否接触 + is_touching = False + overlap_ratio = 0.0 + if hand_box is not None and poly is not None: + 
is_touching, overlap_ratio = check_hand_object_contact(hand_box, poly, overlap_threshold=0.1) + + hx, hy = hand_center + ox, oy = object_center + + # 计算水平和垂直偏差 + dx = ox - hx # 正数表示物体在右边 + dy = oy - hy # 正数表示物体在下边 + + # 如果手和物体已经接触,直接返回"向前" + if is_touching: + return "向前", f"接触度: {overlap_ratio:.1%}" + + # 如果没有接触,引导上下左右 + # 判断主要方向 + h_threshold = 30 # 水平偏差阈值(像素) + v_threshold = 30 # 垂直偏差阈值(像素) + + h_dir = None + v_dir = None + + # 水平方向 + if abs(dx) > h_threshold: + h_dir = "向右" if dx > 0 else "向左" + + # 垂直方向 + if abs(dy) > v_threshold: + v_dir = "向下" if dy > 0 else "向上" + + # 选择偏移最大的方向 + if abs(dx) > abs(dy) and h_dir: + # 水平偏移更大 + return h_dir, v_dir + elif v_dir: + # 垂直偏移更大或相等 + return v_dir, h_dir + else: + # 已经在中心附近但还没接触,提示靠近 + distance = np.sqrt(dx**2 + dy**2) + if distance < 50: # 很近但还没接触 + return "向前", "请缓慢靠近" + else: + return "保持", None + +# 播放音频的函数 +def play_guidance_audio(direction): + """播放方向引导音频""" + # 直接调用新的音频播放函数 + play_audio_threadsafe(direction) + # 同步更新底部按钮的指令文本 + try: + if isinstance(direction, str) and direction.strip(): + set_current_command(direction.strip()) + except Exception: + pass + +# 添加居中判断函数 +def get_center_guidance(object_center, frame_center, threshold=30): + """ + 判断物体是否在画面中心,返回引导方向 + 返回: (方向文字, 是否已居中) + """ + if object_center is None: + return None, False + + ox, oy = object_center + cx, cy = frame_center + + dx = cx - ox # 正数表示需要向右移动 + dy = cy - oy # 正数表示需要向下移动 + + # 判断是否已经居中 + distance = np.sqrt(dx**2 + dy**2) + if distance < threshold: + return "已居中", True + + # 判断主要方向(对调左右和上下) + if abs(dx) > abs(dy): + return "向左" if dx > 0 else "向右", False # 对调了 + else: + return "向上" if dy > 0 else "向下", False # 对调了 + +def main(headless: bool = False, prompt_name: str = None, stop_event=None): + + # OpenCV 优化 + try: + import cv2 + cv2.setUseOptimized(True) + cv2.setNumThreads(2) # 视 CPU 核心数而定;树莓派类设备可设 1 + except Exception: + pass + + + + + # 如果传入了 prompt_name,使用它替换全局的 PROMPT_NAME + global PROMPT_NAME + if prompt_name: + PROMPT_NAME = prompt_name 
+ print(f"[YOLOMEDIA] Using dynamic prompt: {PROMPT_NAME}") + + speaker = Speaker(ENABLE_TTS) + last_tts_ts = 0.0 + MODE = "SEGMENT" # 模式:SEGMENT -> FLASH -> CENTER_GUIDE -> TRACK + colors = Colors() + + FRAME_IDX = 0 + last_mask = None # 上一帧"目标掩膜"(用于 IoU 降噪) + flow_mask = None # 光流外推得到的掩膜(你现有代码里会更新它) + flow_grace = 0 # YOLOE 丢检后,允许光流顶住的计数 + last_seen_ts = 0.0 # 最近一次 YOLOE 成功检测的时间戳 + locked_id = None # (可选)若你在 tracker 里记录了 id,可在下面选择相同 id + # 刷新/容错参数(可按需微调) + REDETECT_EVERY = 5 # 每 5 帧强制"信任 YOLOE 一次" + FLOW_GRACE_MAX = 8 # YOLOE 连续丢检时,光流最多顶 8 帧 + IOU_MIN_KEEP = 0.20 # 新/旧掩膜 IoU 太低时,用平滑合成,避免闪烁 + + + + print("[INIT] 加载 YOLO 模型...") + # NOTE: shoppingbest 不再用于找东西流程;如其他模式仍需,可保留 yolo = YOLO(...) 但不在本流程使用 + # yolo = YOLO(YOLO_MODEL_PATH) + + # —— 直接启用 YOLOE 文本提示后端(不再先查 shoppingbest)—— + use_yoloe = False + yoloe_backend = None + if _YOLOE_READY: + try: + yoloe_backend = YoloEBackend() # 可用 YOLOE_MODEL_PATH 环境变量指定模型 + yoloe_backend.set_text_classes([PROMPT_NAME]) # 文本类别 + use_yoloe = True + print(f"[DETECTOR] YOLOE text-prompt backend enabled for: {PROMPT_NAME}", flush=True) + except Exception as e: + print(f"[DETECTOR] YOLOE init failed: {e}", flush=True) + else: + print("[DETECTOR] YOLOE backend not ready (import failed)", flush=True) + + # 类名映射(YOLOE 模式下简化) + if use_yoloe: + # YOLOE 模式下,只有一个目标类 + id_to_name = {0: PROMPT_NAME} + name_to_id = {norm_name(PROMPT_NAME): 0} + target_cls_id = 0 + else: + # 如果将来需要支持传统 YOLO,可以在这里初始化 + id_to_name = {} + name_to_id = {} + target_cls_id = None + + # 目标类已在上面的 YOLOE 模式中设置 + + print(f"[CLASS] target id={target_cls_id}, name={id_to_name.get(target_cls_id, 'N/A')}") + print(f"[阈值] conf >= {CONF_THRESHOLD:.2f}") + + # Hand Landmarker + print("[INIT] 初始化 Hand Landmarker...") + base = BaseOptions(model_asset_path=HAND_TASK_PATH) + hand_options = HandLandmarkerOptions( + base_options=base, + running_mode=VisionRunningMode.LIVE_STREAM, + num_hands=1, + min_hand_detection_confidence=0.40, + min_hand_presence_confidence=0.50, + 
min_tracking_confidence=0.70, + result_callback=on_result + ) + landmarker = HandLandmarker.create_from_options(hand_options) + + W = None + H = None + print("[Bridge] 等待 ESP32 画面 ...") + + # [headless] 仅在非 headless 时创建窗口(原逻辑保留,外层加判断) + if not headless: + cv2.namedWindow(WINDOW, cv2.WINDOW_NORMAL) + + # 光流缓存 + old_gray = None + p0 = None + lock_edge_debug = None # 调试可视化:内边界 + track_frame_count = 0 # 控制周边监控频率 + last_poly_box = None # 当前多边形外接矩形 + + fps_hist = [] + + # 添加自动锁定相关变量 + auto_lock_start_time = None # 开始检测到物体的时间 + auto_lock_delay = 1.0 # 1秒后自动锁定 + last_detected_mask = None # 最后检测到的mask + + # 添加闪烁动画相关变量 + flash_start_time = None # 闪烁开始时间 + flash_duration = 1.0 # 闪烁持续时间(秒) + flash_frequency = 1 # 闪烁频率(Hz) - 只闪一次 + flash_mask = None # 用于闪烁的mask + flash_color = (0, 255, 255) # 闪烁颜色(黄色) + + # 添加引导相关变量 + last_guidance_time = 0 + last_guidance_direction = None + + # 添加居中引导相关变量 + center_guide_mask = None # 用于居中引导的mask + center_guide_start = None # 居中引导开始时间 + center_threshold = 30 # 居中判定阈值(像素) + last_center_guide_time = 0 # 上次居中引导语音时间 + center_reached = False # 是否已经到达中心 + + # 添加抓取跟踪相关变量 + grasp_tracking_frames = [] # 存储最近的手和物体位置 + grasp_tracking_duration = 1.0 # 需要持续1秒 + grasp_movement_threshold = 10 # 最小移动像素阈值(提高阈值) + grasp_detected = False # 是否已经检测到抓取 + grasp_start_time = None # 开始检测到协同移动的时间 + + # 背景参考点(用于检测相机移动) - 移到这里初始化 + background_points = None + old_background_gray = None + + try: + while True: + # 检查停止事件 + if stop_event and stop_event.is_set(): + print("[YOLOMEDIA] Stop event detected, exiting...") + break + + frame = bridge_io.wait_raw_bgr(timeout_sec=0.5) + if frame is None: + # 没取到帧就继续等(ESP32还没连上或暂时无新帧) + # [headless] 给出 1ms 让出调度,避免空转 + if headless: + cv2.waitKey(1) + continue + + # 每帧重置 UI 文字叠加到左下角 + H, W = frame.shape[:2] + ui_reset_overlay(H) + + vis = frame.copy() + t_now = time.time() + + # 抽帧 + 降采样(人手识别) + if FRAME_IDX % HAND_FPS_DIV == 0: + rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) + if HAND_DOWNSCALE and HAND_DOWNSCALE != 1.0: + small = 
cv2.resize(rgb, None, fx=HAND_DOWNSCALE, fy=HAND_DOWNSCALE, interpolation=cv2.INTER_AREA) + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=small) + else: + mp_image = mp.Image(image_format=mp.ImageFormat.SRGB, data=rgb) + landmarker.detect_async(mp_image, int(t_now * 1000)) + # 否则跳过,复用上一次 _last_result;Landmarker 会自己做 tracking + + + # 取手心、手框、握持(放宽版) + hand_center = None + hand_area = None + hand_box = None + grasp_now = False + grasp_score = 0.0 + if _last_result is not None: + res, _ = _last_result + if res.hand_landmarks and len(res.hand_landmarks) > 0: + l0 = res.hand_landmarks[0] + + # 绘制手部骨骼 + draw_hands_mono(vis, l0, color=(0, 255, 255), r=2, t=2) + + # 绘制手部轮廓(替代矩形框) + draw_hand_contour(vis, l0, W, H, color=(255, 255, 255), thickness=1) + + xs = [p.x * W for p in l0] + ys = [p.y * H for p in l0] + hand_center = (float(sum(xs)/len(xs)), float(sum(ys)/len(ys))) + hand_box, hand_area = hand_bbox_and_area(l0, W, H) + # 注释掉矩形框绘制 + # if hand_box: + # x0, y0, w0, h0 = hand_box + # cv2.rectangle(vis, (x0, y0), (x0+w0, y0+h0), (0,255,255), 1) + grasp_now, grasp_score = detect_grasp(l0, W, H) + draw_text_cn(vis, f"握持评分: {grasp_score:.2f}", (10, 70), font_size=18, color=(0, 180, 255)) + + + if MODE == "SEGMENT": + # —— 仅 YOLOE:每帧文本提示分割 + 取最大目标(删掉 shoppingbest 与重复 YOLOE 段)—— + FRAME_IDX += 1 + candidate_masks = [] + detected_object = False + + if use_yoloe and yoloe_backend is not None: + # 每帧都跑;persist=True 便于维持目标 ID + det = yoloe_backend.segment(frame, conf=0.20, iou=0.45, persist=True) + H, W = frame.shape[:2] + + # 选一个掩膜:优先与 locked_id 相同;否则面积最大 + chosen_idx = None + if det["masks"]: + if locked_id is not None and det["ids"] and (locked_id in det["ids"]): + chosen_idx = det["ids"].index(locked_id) + else: + areas = [int(m.sum()) for m in det["masks"]] + chosen_idx = int(np.argmax(areas)) + + if chosen_idx is not None: + m = det["masks"][chosen_idx] + if m.shape[:2] != (H, W): + m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) + + mask_bin = (m > 
0).astype(np.uint8) + candidate_masks.append({ + "mask": mask_bin, + "area": int(mask_bin.sum()), + "name": PROMPT_NAME, + "cls_id": 0, + "conf": 0.99, + }) + detected_object = True + + # 简单可视化(半透明叠层 + 轮廓),不影响你后面的逻辑 + colored = np.zeros_like(frame, dtype=np.uint8) + colored[mask_bin == 1] = (0, 255, 255) + vis = cv2.addWeighted(vis, 1.0, colored, MASK_ALPHA, 0) + contours, _ = cv2.findContours(mask_bin, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if contours: + # 选择最大轮廓并进行适度平滑 + largest_contour = max(contours, key=cv2.contourArea) + # 使用Douglas-Peucker算法适度简化,保持更多细节 + epsilon = CONTOUR_EPSILON_FACTOR * cv2.arcLength(largest_contour, True) # 更小的epsilon保留更多细节 + smoothed_contour = cv2.approxPolyDP(largest_contour, epsilon, True) + cv2.drawContours(vis, [smoothed_contour], -1, (0, 255, 255), STROKE_WIDTH) + + # 记录 id,减少目标跳变 + if det["ids"] and len(det["ids"]) > chosen_idx and det["ids"][chosen_idx] is not None: + locked_id = int(det["ids"][chosen_idx]) + + else: + # YOLOE 未就绪:提示并保持原画面(不阻塞前端) + draw_text_cn(vis, "YOLOE 未就绪,显示原始画面", (10, 100), font_size=22, color=(0, 215, 255)) + + # 选择面积最大的mask ←—— 这一行下面开始保留你的原代码 + + # 选择面积最大的mask + if candidate_masks: + # 按面积降序排序 + candidate_masks.sort(key=lambda x: x['area'], reverse=True) + largest_mask_info = candidate_masks[0] + last_detected_mask = largest_mask_info['mask'] + + # 可选:在最大的物体上添加特殊标记 + contours, _ = cv2.findContours(last_detected_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if contours: + # 找到最大轮廓的中心 + M = cv2.moments(contours[0]) + if M["m00"] != 0: + cx = int(M["m10"] / M["m00"]) + cy = int(M["m01"] / M["m00"]) + # 在最大物体中心画一个圆圈标记 + cv2.circle(vis, (cx, cy), 8, (0, 255, 0), 2) + cv2.circle(vis, (cx, cy), 12, (0, 255, 0), 1) + # 目标标签:保持就地标注 + draw_text_cn(vis, "目标", (cx + 15, cy - 5), font_size=16, color=FRONTEND_COLORS["ok"], ui_hint=False) + + # 显示检测信息 + if len(candidate_masks) > 1: + draw_text_cn(vis, f"检测到{len(candidate_masks)}个物体,选择最大的(面积: {largest_mask_info['area']})", + (10, H - 30), font_size=16, 
color=(255, 255, 0)) + + # 自动锁定逻辑 + if detected_object and last_detected_mask is not None: + if auto_lock_start_time is None: + auto_lock_start_time = t_now + print(f"[AUTO] 检测到物体,选择最大的(面积: {np.sum(last_detected_mask)}),开始倒计时...") + #play_guidance_audio("检测到物体") # 添加这行 + + elapsed = t_now - auto_lock_start_time + remaining = auto_lock_delay - elapsed + + if remaining > 0: + # 显示倒计时(移动到左下角,前端风格) + draw_text_cn(vis, f"检测到物体,{remaining:.1f}秒后自动锁定", (10, 100), font_size=16, color=FRONTEND_COLORS["text"], stroke=(0,0,0)) + + # 绘制锁定框 - 使用虚线框表示正在准备锁定 + if last_detected_mask is not None: + contours, _ = cv2.findContours(last_detected_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if contours: + # 找到最大轮廓 + largest_contour = max(contours, key=cv2.contourArea) + # 简化轮廓 + epsilon = CONTOUR_EPSILON_FACTOR * cv2.arcLength(largest_contour, True) + smoothed_contour = cv2.approxPolyDP(largest_contour, epsilon, True) + + # 根据倒计时进度改变颜色亮度 + progress = 1.0 - (remaining / auto_lock_delay) + color_intensity = int(100 + 155 * progress) # 从100到255 + lock_color = (0, color_intensity, color_intensity) # 黄色渐亮 + + # 绘制虚线轮廓 + pts = smoothed_contour.reshape(-1, 2) + for i in range(len(pts)): + pt1 = tuple(pts[i]) + pt2 = tuple(pts[(i + 1) % len(pts)]) + # 使用虚线效果(通过绘制短线段) + draw_dashed_line(vis, pt1, pt2, color=lock_color, thickness=3, + dash_length=15, gap_length=8) + else: + # 进入闪烁模式 + print("[AUTO] 进入闪烁动画模式") + MODE = "FLASH" + flash_start_time = t_now + flash_mask = last_detected_mask.copy() + auto_lock_start_time = None + play_guidance_audio("检测到物体") + else: + # 没有检测到物体,重置计时器 + if auto_lock_start_time is not None: + print("[AUTO] 物体丢失,重置倒计时") + auto_lock_start_time = None + last_detected_mask = None + draw_text_cn(vis, "分割中... 
等待检测到物体", (10, 100), font_size=16, color=FRONTEND_COLORS["muted"]) + + elif MODE == "FLASH": + # 闪烁动画模式 + if flash_start_time is not None and flash_mask is not None: + elapsed = t_now - flash_start_time + + if elapsed < flash_duration: + # 计算渐入渐出效果 + # 前0.3秒渐入,中间0.4秒保持,后0.3秒渐出 + if elapsed < 0.3: + # 渐入阶段 + alpha = elapsed / 0.3 * 0.8 # 0到0.8 + elif elapsed < 0.7: + # 保持阶段 + alpha = 0.8 + else: + # 渐出阶段 + alpha = (1.0 - elapsed) / 0.3 * 0.8 # 0.8到0 + + # 绘制闪烁的mask + colored = np.zeros_like(frame, dtype=np.uint8) + colored[flash_mask == 1] = flash_color + vis = cv2.addWeighted(vis, 1.0 - alpha, colored, alpha, 0) + + # 绘制轮廓(固定粗细,颜色渐变) + contours, _ = cv2.findContours(flash_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if contours: + # 轮廓颜色也跟随alpha变化 + contour_color = tuple(int(c * (0.5 + alpha * 0.5)) for c in flash_color) + cv2.drawContours(vis, contours, -1, contour_color, STROKE_WIDTH + 1) + + # 显示提示文字(左下角) + draw_text_cn(vis, "正在锁定目标...", (10, 100), font_size=18, color=FRONTEND_COLORS["accent"]) + else: + # 闪烁结束,初始化光流追踪并进入居中引导模式 + print("[AUTO] 闪烁结束,初始化光流追踪") + edge_mask = inner_offset_edge(flash_mask, offset_px=INNER_OFFSET_PX_LOCK, edge_dilate_px=EDGE_DILATE_PX) + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + pts = cv2.goodFeaturesToTrack(gray, mask=edge_mask, **FEATURE_PARAMS) + + if pts is not None and len(pts) >= 8: + p0 = pts + old_gray = gray + MODE = "CENTER_GUIDE" + lock_edge_debug = edge_mask.copy() + track_frame_count = 0 + center_guide_start = t_now + center_reached = False + flash_start_time = None + flash_mask = None + last_detected_mask = None + print(f"[LOCK] 内边界特征点数={len(p0)} → CENTER_GUIDE") + else: + print("[LOCK] 内边界特征点不足,返回检测模式") + MODE = "SEGMENT" + flash_start_time = None + flash_mask = None + last_detected_mask = None + + elif MODE == "CENTER_GUIDE": + # 居中引导模式(使用光流追踪) + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + poly_center = None + poly_area = 0.0 + + if old_gray is not None and p0 is not None and len(p0) >= 5: + # 光流追踪 + 
p1, st, err = cv2.calcOpticalFlowPyrLK(old_gray, gray, p0, None, **LK_PARAMS) + if p1 is not None and st is not None: + good_new = p1[st == 1] + if len(good_new) >= 5: + p0 = good_new.reshape(-1, 1, 2) + hull = cv2.convexHull(good_new.reshape(-1,1,2)) + poly = hull.reshape(-1, 2) + + if len(poly) >= 3: + H, W = frame.shape[:2] + + # 把当前光流多边形 rasterize 成掩膜(便于与 YOLOE 掩膜做 IoU) + poly_mask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(poly_mask, [poly.astype(np.int32)], 1) + + # 降频:每3帧用 YOLOE 重新检测,其余帧依赖光流维持 + need_reseed = False + new_det_mask = None + + if use_yoloe and yoloe_backend is not None and (FRAME_IDX % 3 == 0): + # 添加调试信息 + if FRAME_IDX % 30 == 0: # 每30帧打印一次 + print(f"[YOLOE] 实时检测第 {FRAME_IDX} 帧") + det = yoloe_backend.segment(frame, conf=0.20, iou=0.45, persist=True) + if det["masks"]: + # 取面积最大的那个 + areas = [int(m.sum()) for m in det["masks"]] + j = int(np.argmax(areas)) + m = det["masks"][j] + if m.shape[:2] != (H, W): + m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) + new_det_mask = (m > 0).astype(np.uint8) + + # 和当前光流多边形的 IoU + inter = np.logical_and(new_det_mask, poly_mask).sum() + union = np.logical_or(new_det_mask, poly_mask).sum() + 1e-6 + iou = inter / union + + # IoU 太低,说明漂了:用 YOLOE 的掩膜重播种光流 + # 降低阈值,让 YOLOE 更容易更新光流 + if iou < 0.5: # 从 IOU_MIN_KEEP (0.20) 提高到 0.5 + need_reseed = True + # 用新掩膜的「内边界特征点」播种 + edge_mask = inner_offset_edge(new_det_mask, offset_px=INNER_OFFSET_PX_LOCK, edge_dilate_px=EDGE_DILATE_PX) + gray2 = gray # 本帧灰度图已在上面算过 + pts = cv2.goodFeaturesToTrack(gray2, mask=edge_mask, **FEATURE_PARAMS) + if pts is not None and len(pts) >= 8: + p0 = pts + old_gray = gray2 + # 更新 last_mask,便于下游逻辑一致 + last_mask = new_det_mask.copy() + last_seen_ts = time.time() + flow_grace = 0 + print("[RESEED] YOLOE 低 IoU 触发重播种(已更新光流特征点)") + + # 如果这帧没重播种,但 YOLOE 有结果且与 poly 很接近,可以做一次"平滑融合",抑制抖动 + if (not need_reseed) and (new_det_mask is not None): + inter = np.logical_and(new_det_mask, poly_mask).sum() + union = np.logical_or(new_det_mask, 
poly_mask).sum() + 1e-6 + iou = inter / union + # 降低融合阈值,让 YOLOE 结果更容易被采用 + if iou < 0.95: # 从 0.90 提高到 0.95 + # 增加 YOLOE 的权重,让实时检测更明显 + poly_mask = ((0.8 * new_det_mask + 0.2 * poly_mask) > 0.5).astype(np.uint8) + # 用更新后的 poly_mask 回写到可视化与引导的后续变量(如果你下游用的是 last_detected_mask/last_mask) + last_mask = poly_mask.copy() + # 更新多边形轮廓,让可视化实时更新 + contours, _ = cv2.findContours(poly_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if contours: + # 找到最大轮廓 + largest_contour = max(contours, key=cv2.contourArea) + # 使用精细的轮廓处理,保留更多细节 + epsilon = TRACK_EPSILON_FACTOR * cv2.arcLength(largest_contour, True) + poly = cv2.approxPolyDP(largest_contour, epsilon, True).reshape(-1, 2) + # 注释掉凸包处理,保留原始轮廓细节 + # hull = cv2.convexHull(poly.reshape(-1,1,2)) + # poly = hull.reshape(-1, 2) + # 重新计算特征点 + edge_mask = inner_offset_edge(poly_mask, offset_px=INNER_OFFSET_PX_LOCK, edge_dilate_px=EDGE_DILATE_PX) + pts = cv2.goodFeaturesToTrack(gray, mask=edge_mask, **FEATURE_PARAMS) + if pts is not None and len(pts) >= 5: + p0 = pts + + # 绘制追踪的多边形 - 使用更粗的线条 + cv2.polylines(vis, [poly.astype(np.int32)], isClosed=True, color=(0,255,255), thickness=STROKE_WIDTH) + + # 计算多边形中心 + poly_center, poly_area = polygon_center_and_area(poly) + + if poly_center: + object_center = (int(poly_center[0]), int(poly_center[1])) + + # 画面中心 + frame_center = (W // 2, H // 2) + + # 绘制物品中心点 + cv2.circle(vis, object_center, 8, (0, 255, 0), -1) + cv2.circle(vis, object_center, 12, (0, 255, 0), 2) + + # 绘制画面中心十字 + cv2.line(vis, (frame_center[0] - 20, frame_center[1]), + (frame_center[0] + 20, frame_center[1]), (255, 255, 255), 2) + cv2.line(vis, (frame_center[0], frame_center[1] - 20), + (frame_center[0], frame_center[1] + 20), (255, 255, 255), 2) + + # 绘制引导虚线 + draw_dashed_line(vis, object_center, frame_center, + color=(255, 255, 0), thickness=2, + dash_length=10, gap_length=5) + + # 获取引导方向 + direction, is_centered = get_center_guidance(object_center, frame_center, center_threshold) + + if not center_reached: + if is_centered: 
+ # 到达中心,播放OK音效 + center_reached = True + last_center_guide_time = t_now + play_guidance_audio("OK") + try: + bridge_io.send_ui_final("✓ 物品已居中!") + except Exception: + pass + draw_text_cn(vis, "✓ 物品已居中!", (10, 60), font_size=18, color=FRONTEND_COLORS["ok"]) + else: + # 显示引导文字 + msg = f"请将物品移到画面中心: {direction}" + try: + # 节流:每次语音播报也推一次final + if t_now - last_center_guide_time > GUIDANCE_INTERVAL_SEC: + bridge_io.send_ui_final(msg) + except Exception: + pass + draw_text_cn(vis, msg, + (10, 40), font_size=18, color=FRONTEND_COLORS["text"]) + + # 显示距离信息 + dx = frame_center[0] - object_center[0] + dy = frame_center[1] - object_center[1] + distance = int(np.sqrt(dx**2 + dy**2)) + draw_text_cn(vis, f"距离: {distance}px", + (10, 60), font_size=16, color=FRONTEND_COLORS["muted"]) + + # 播放语音引导 + if t_now - last_center_guide_time > GUIDANCE_INTERVAL_SEC: + play_guidance_audio(direction) + last_center_guide_time = t_now + else: + # 已经居中,显示成功信息 + try: + bridge_io.send_ui_final("✓ 物品已成功移到中心!") + except Exception: + pass + draw_text_cn(vis, "✓ 物品已成功移到中心!", + (10, 60), font_size=18, color=FRONTEND_COLORS["ok"]) + + # 等待1秒后进入手部追踪模式 + if t_now - last_center_guide_time > 1.0: + print("[CENTER] 进入手部追踪模式") + try: + bridge_io.send_ui_final("进入手部追踪模式") + except Exception: + pass + MODE = "TRACK" + # 保持当前的光流追踪状态 + else: + # 多边形中心计算失败,显示警告 + draw_text_cn(vis, "正在追踪物体...", (10, 100), font_size=20, color=(255, 255, 0)) + else: + # 光流点数不足,尝试重新检测 + MODE = "SEGMENT" + old_gray = None + p0 = None + print("[CENTER] 光流追踪失败,返回检测模式") + + old_gray = gray + + else: # MODE == "TRACK" + # 手部追踪模式(原有逻辑保持不变) + align_score = 0.0 + range_score = 0.0 + ratio = None + + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + track_frame_count += 1 + + relock_done = False + poly_center = None + poly_area = 0.0 + + # 初始化camera_movement为默认值 + camera_movement = np.array([0.0, 0.0]) + + # 初始化或更新背景参考点(在物体多边形外部取点) + if background_points is None or track_frame_count % 30 == 0: + # 在画面四角取一些背景特征点 + mask_for_bg = np.ones((H, W), 
dtype=np.uint8) * 255 + if last_poly_box: + x, y, w, h = last_poly_box + # 扩大区域,排除物体和手 + expand = 100 + x1 = max(0, x - expand) + y1 = max(0, y - expand) + x2 = min(W, x + w + expand) + y2 = min(H, y + h + expand) + mask_for_bg[y1:y2, x1:x2] = 0 + + # 在背景区域提取特征点 + try: + bg_pts = cv2.goodFeaturesToTrack(gray, maxCorners=20, + qualityLevel=0.1, + minDistance=30, + mask=mask_for_bg) + if bg_pts is not None and len(bg_pts) >= 5: + background_points = bg_pts + old_background_gray = gray.copy() + except Exception as e: + #print(f"[TRACK] 背景特征点提取失败: {e}") + background_points = None + + # 计算背景移动(相机移动) + if old_background_gray is not None and background_points is not None and len(background_points) > 0: + try: + bg_p1, bg_st, _ = cv2.calcOpticalFlowPyrLK( + old_background_gray, gray, background_points, None, **LK_PARAMS + ) + if bg_p1 is not None and bg_st is not None: + good_bg_old = background_points[bg_st == 1] + good_bg_new = bg_p1[bg_st == 1] + if len(good_bg_new) >= 3 and len(good_bg_old) >= 3: + # 计算背景的平均移动 + bg_movement = np.mean(good_bg_new - good_bg_old, axis=0) + camera_movement = bg_movement.reshape(2) + background_points = good_bg_new.reshape(-1, 1, 2) + old_background_gray = gray.copy() + except Exception as e: + print(f"[TRACK] 背景光流计算失败: {e}") + camera_movement = np.array([0.0, 0.0]) + + if old_gray is not None and p0 is not None and len(p0) >= 5: + p1, st, err = cv2.calcOpticalFlowPyrLK(old_gray, gray, p0, None, **LK_PARAMS) + if p1 is not None and st is not None: + good_new = p1[st == 1] + if len(good_new) >= 5: + p0 = good_new.reshape(-1, 1, 2) + hull = cv2.convexHull(good_new.reshape(-1,1,2)) + poly = hull.reshape(-1, 2) + + if len(poly) >= 3: + # 统一的 YOLOE 实时检测和校正(每帧) + latest_det_mask = None + if use_yoloe and yoloe_backend is not None: + # 添加调试信息 + if track_frame_count % 30 == 0: # 每30帧打印一次 + print(f"[YOLOE] TRACK模式实时检测第 {track_frame_count} 帧") + + # YOLOE 实时检测(统一调用,避免重复) + det = yoloe_backend.segment(frame, conf=YOLO_CORRECTION_CONF_THRESHOLD, 
iou=0.45, persist=True) + if det["masks"]: + # 取面积最大的那个 + areas = [int(m.sum()) for m in det["masks"]] + j = int(np.argmax(areas)) + m = det["masks"][j] + if m.shape[:2] != (H, W): + m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) + latest_det_mask = (m > 0).astype(np.uint8) + + # 和当前光流多边形的 IoU + poly_mask = np.zeros((H, W), dtype=np.uint8) + cv2.fillPoly(poly_mask, [poly.astype(np.int32)], 1) + inter = np.logical_and(latest_det_mask, poly_mask).sum() + union = np.logical_or(latest_det_mask, poly_mask).sum() + 1e-6 + iou = inter / union + + # 降低IoU阈值,更积极地校正 + if iou > YOLO_CORRECTION_IOU_THRESHOLD: # 使用可配置阈值 + # 用 YOLOE 结果更新多边形 + contours, _ = cv2.findContours(latest_det_mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + if contours: + largest_contour = max(contours, key=cv2.contourArea) + # 使用更精细的轮廓处理,减少过度简化 + epsilon = TRACK_EPSILON_FACTOR * cv2.arcLength(largest_contour, True) + poly = cv2.approxPolyDP(largest_contour, epsilon, True).reshape(-1, 2) + + # 更新光流特征点 + edge_mask = inner_offset_edge(latest_det_mask, offset_px=INNER_OFFSET_PX_LOCK, edge_dilate_px=EDGE_DILATE_PX) + pts = cv2.goodFeaturesToTrack(gray, mask=edge_mask, **FEATURE_PARAMS) + if pts is not None and len(pts) >= 5: + p0 = pts + #print(f"[TRACK] YOLOE 实时校正,IoU: {iou:.3f}") + + # 检查是否接触,决定轮廓颜色 + is_touching = False + overlap_ratio = 0.0 + if hand_box is not None and poly is not None: + is_touching, overlap_ratio = check_hand_object_contact(hand_box, poly, overlap_threshold=0.1) + + # 绘制多边形(可能已被 YOLOE 更新)- 使用更粗的线条 + if is_touching: + # 接触时用亮绿色,并添加发光效果 + poly_color = (0, 255, 127) + # 绘制一个更粗的外层轮廓作为发光效果 + cv2.polylines(vis, [poly.astype(np.int32)], isClosed=True, + color=(127, 255, 127), thickness=STROKE_WIDTH + 4) + # 添加半透明的填充效果 + overlay = vis.copy() + cv2.fillPoly(overlay, [poly.astype(np.int32)], (0, 255, 0)) + cv2.addWeighted(overlay, 0.15, vis, 0.85, 0, vis) + else: + # 未接触时用普通绿色 + poly_color = (0, 255, 0) + cv2.polylines(vis, [poly.astype(np.int32)], isClosed=True, color=poly_color, 
thickness=STROKE_WIDTH) + # 多边形质心与面积 + poly_center, poly_area = polygon_center_and_area(poly) + if poly_center: + pc = (int(poly_center[0]), int(poly_center[1])) + cv2.circle(vis, pc, 6, (0,255,0), -1) + + # 多边形外接矩形(用于周边监控) + x, y, w, h = cv2.boundingRect(poly.astype(np.int32)) + last_poly_box = (x, y, w, h) + + # ====== 对齐分数(第一条)====== + if hand_center and poly_center: + hc = np.array(hand_center, dtype=np.float32) + oc = np.array(poly_center, dtype=np.float32) + dist = float(np.linalg.norm(oc - hc)) + diag = float(np.linalg.norm([W, H])) + align_score = 1.0 - min(dist/(ALIGN_LOOSE_PCT*diag + 1e-6), 1.0) + + # 绘制虚线引导(替代原来的实线箭头) + draw_dashed_line(vis, (hc[0], hc[1]), (oc[0], oc[1]), + color=(255, 255, 0), thickness=2, + dash_length=15, gap_length=10) + + # 方向引导 + direction, secondary = get_guidance_direction( + hand_center, poly_center, hand_area, poly_area, + hand_box, poly + ) + + if direction and direction != "保持": + # 根据是否接触显示不同颜色 + if direction == "向前": + # 手已经接触物体,用绿色显示 + guide_color = (0, 255, 0) # 绿色 + draw_text_cn(vis, f"引导: {direction} - 伸手抓取", (W//2 - 80, 40), + font_size=24, color=guide_color, stroke=(0, 0, 0)) + else: + # 还未接触,用黄色显示 + guide_color = (0, 255, 255) # 黄色 + draw_text_cn(vis, f"引导: {direction}", (W//2 - 60, 40), + font_size=24, color=guide_color, stroke=(0, 0, 0)) + + # 显示次要信息(接触度或其他方向) + if secondary: + if isinstance(secondary, str): + # 接触度信息 + draw_text_cn(vis, secondary, (W//2 - 60, 70), + font_size=18, color=(0, 255, 0)) + else: + # 其他方向信息 + draw_text_cn(vis, f"(或 {secondary})", (W//2 - 60, 70), + font_size=18, color=(200, 200, 200)) + + # 播放语音引导 - 确保每个方向都会播放 + if t_now - last_guidance_time > GUIDANCE_INTERVAL_SEC: + # 检查方向是否改变,或者时间间隔足够 + if direction != last_guidance_direction or t_now - last_guidance_time > GUIDANCE_INTERVAL_SEC * 2: + play_guidance_audio(direction) + last_guidance_direction = direction + last_guidance_time = t_now + print(f"[GUIDE] 播放引导音频: {direction}") + else: + align_score = 0.0 + + # 显示接触状态 + is_touching, 
overlap_ratio = check_hand_object_contact(hand_box, poly, overlap_threshold=0.1) + if is_touching: + draw_text_cn(vis, f"状态: 已接触 ({overlap_ratio:.1%})", (10, 95), + font_size=16, color=(0, 255, 0)) + else: + # 计算手和物体的距离 + if hand_center and poly_center: + distance = np.sqrt((hand_center[0] - poly_center[0])**2 + + (hand_center[1] - poly_center[1])**2) + draw_text_cn(vis, f"距离: {distance:.0f}px", (10, 95), + font_size=16, color=FRONTEND_COLORS["muted"]) + + # 成功条件:握持(放宽) + if (_last_result and _last_result[0].hand_landmarks and len(_last_result[0].hand_landmarks) > 0): + l0 = _last_result[0].hand_landmarks[0] + grasp_now, grasp_score = detect_grasp(l0, W, H) + else: + grasp_now, grasp_score = False, 0.0 + + # guidance_msg 相关代码已经集成到上面的引导逻辑中 + + # ===== 周边监控 & 重新锁定(复用YOLO结果)===== + if (track_frame_count % PERI_CHECK_EVERY == 0) and (last_poly_box is not None) and (latest_det_mask is not None): + # 直接使用刚才的YOLO检测结果,避免重复调用 + px, py, pw, ph = last_poly_box + x0 = max(0, px - PERI_MONITOR_PX) + y0 = max(0, py - PERI_MONITOR_PX) + x1 = min(W - 1, px + pw + PERI_MONITOR_PX) + y1 = min(H - 1, py + ph + PERI_MONITOR_PX) + + # 检查周边区域是否有更好的检测结果 + peri_area = latest_det_mask[y0:y1, x0:x1].sum() + total_area = latest_det_mask.sum() + + # 如果周边区域有显著检测结果,重新锁定 + if peri_area > total_area * 0.1: # 周边有10%以上的检测面积 + edge_mask = inner_offset_edge(latest_det_mask, offset_px=INNER_OFFSET_PX_LOCK, edge_dilate_px=EDGE_DILATE_PX) + pts = cv2.goodFeaturesToTrack(gray, mask=edge_mask, **FEATURE_PARAMS) + if pts is not None and len(pts) >= 8: + p0 = pts + old_gray = gray + lock_edge_debug = edge_mask.copy() + #print(f"[PERI] 周边重锁定,特征点数={len(p0)}") + else: + MODE = "SEGMENT"; old_gray = None; p0 = None; lock_edge_debug = None + else: + MODE = "SEGMENT"; old_gray = None; p0 = None; lock_edge_debug = None + else: + MODE = "SEGMENT"; old_gray = None; p0 = None; lock_edge_debug = None + else: + MODE = "SEGMENT"; old_gray = None; p0 = None; lock_edge_debug = None + + + + if MODE == "SEGMENT": + 
draw_text_cn(vis, "追踪丢失 → 正在重新识别。按 Enter 重新锁定", (10, 100), font_size=22, color=(0,0,255)) + + old_gray = gray + + # FPS(移动到左下角样式) + if 'fps_hist' not in locals(): + fps_hist = [] + fps_hist.append(t_now) + if len(fps_hist) > 30: + fps_hist.pop(0) + fps = 0.0 if len(fps_hist) < 2 else (len(fps_hist)-1)/(fps_hist[-1]-fps_hist[0]) + draw_text_cn(vis, f"FPS: {fps:.1f}", (10, 40), font_size=16, color=FRONTEND_COLORS["ok"]) + + # 右下角显示"内边界/最近一次锁定"的调试图 + if lock_edge_debug is not None: + # 极小缩放并放在右下角 + small = cv2.resize(lock_edge_debug, (0,0), fx=0.22, fy=0.22, interpolation=cv2.INTER_NEAREST) + sh, sw = small.shape[:2] + small_bgr = cv2.cvtColor(small, cv2.COLOR_GRAY2BGR) + # 右下角位置,留 10-12px 边距 + x1 = max(8, W - sw - 12) + y1 = max(8, H - sh - 12) + y2 = y1 + sh + x2 = x1 + sw + vis[y1:y2, x1:x2] = small_bgr + # 标签置于图上方紧贴,使用更小字号 + #draw_text_cn(vis, "内边界", (x1, y1 - 8), font_size=12, color=FRONTEND_COLORS["muted"], ui_hint=False) + + # 底部中间的"当前指令"按钮(始终绘制,文案随音频同步) + draw_command_pill(vis, CURRENT_COMMAND_TEXT) + + # 展示(无论 headless 与否,都会推给前端) + bridge_io.send_vis_bgr(vis) + + # [headless] 只有非 headless 时才弹窗与键盘交互;headless 下用 waitKey(1) 让出调度 + if not headless: + cv2.imshow(WINDOW, vis) + key = cv2.waitKey(1) & 0xFF + if key in (27, ord('q')): + break + elif key == ord('r'): + MODE = "SEGMENT"; old_gray = None; p0 = None; lock_edge_debug = None + elif key == 13: # Enter:从 SEGMENT 锁定并开始 TRACK(内收 5px) + if MODE == "SEGMENT": + # 使用 YOLOE 进行手动锁定 + if use_yoloe and yoloe_backend is not None: + det = yoloe_backend.segment(frame, conf=CONF_THRESHOLD, iou=0.45, persist=True) + if det["masks"]: + # 取面积最大的那个 + areas = [int(m.sum()) for m in det["masks"]] + j = int(np.argmax(areas)) + m = det["masks"][j] + if m.shape[:2] != (H, W): + m = cv2.resize(m, (W, H), interpolation=cv2.INTER_NEAREST) + best_mask = (m > 0.5).astype(np.uint8) + else: + best_mask = None + else: + best_mask = None + if best_mask is not None: + edge_mask = inner_offset_edge(best_mask, offset_px=INNER_OFFSET_PX_LOCK, 
edge_dilate_px=EDGE_DILATE_PX) + gray = cv2.cvtColor(frame, cv2.COLOR_BGR2GRAY) + pts = cv2.goodFeaturesToTrack(gray, mask=edge_mask, **FEATURE_PARAMS) + if pts is not None and len(pts) >= 8: + p0 = pts + old_gray = gray + MODE = "TRACK" + lock_edge_debug = edge_mask.copy() + track_frame_count = 0 + print(f"[LOCK] 内边界特征点数={len(p0)} → TRACK") + else: + print("[LOCK] 内边界特征点不足,请调整画面后重试。") + else: + print("[LOCK] 当前帧未找到有效分割,请重试。") + else: + # headless 下也调用一次 waitKey(1),让 OpenCV 的计时器/回调得到机会,且避免 CPU 忙等 + cv2.waitKey(1) + + # 在 headless 模式下检查停止事件 + if stop_event and stop_event.is_set(): + print("[YOLOMEDIA] Received stop signal in headless mode") + break + + finally: + try: + landmarker.close() + except Exception: + pass + #cap.release() + # [headless] 仅在非 headless 时销毁窗口 + if not headless: + cv2.destroyAllWindows() + + +if __name__ == "__main__": + main()