代码优化

This commit is contained in:
Kevin Wong
2026-01-05 09:08:40 +08:00
parent 8d725c2723
commit baf9e235a1
13 changed files with 1183 additions and 348 deletions

File diff suppressed because it is too large Load Diff

View File

@@ -63,7 +63,8 @@ INTERRUPT_KEYWORDS = set(
NAV_CONTROL_WHITELIST = [ NAV_CONTROL_WHITELIST = [
"停止导航", "结束导航", "停止检测", "停止红绿灯", "停止导航", "结束导航", "停止检测", "停止红绿灯",
"开始导航", "盲道导航", "开始过马路", "过马路结束", "开始导航", "盲道导航", "开始过马路", "过马路结束",
"帮我导航", "帮我过马路" "帮我导航", "帮我过马路",
"室内导航", "室内导盲", # Day 25: 新增室内导航命令
] ]

View File

@@ -371,9 +371,9 @@ class CompressedAudioCache:
# 打印压缩率 # 打印压缩率
compression_ratio = len(compressed) / self._original_sizes[filepath] compression_ratio = len(compressed) / self._original_sizes[filepath]
logger.info(f"[压缩] {os.path.basename(filepath)}: " # logger.info(f"[压缩] {os.path.basename(filepath)}: "
f"{self._original_sizes[filepath]} -> {len(compressed)} bytes " # f"{self._original_sizes[filepath]} -> {len(compressed)} bytes "
f"({compression_ratio:.1%})") # f"({compression_ratio:.1%})")
return compressed return compressed

View File

@@ -8,6 +8,7 @@ import asyncio
import threading import threading
import queue import queue
import time import time
import hashlib
from audio_stream import broadcast_pcm16_realtime from audio_stream import broadcast_pcm16_realtime
from audio_compressor import compressed_audio_cache, AudioCompressor from audio_compressor import compressed_audio_cache, AudioCompressor
@@ -36,6 +37,9 @@ AUDIO_BASE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "music
VOICE_DIR = os.getenv("VOICE_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "voice")) VOICE_DIR = os.getenv("VOICE_DIR", os.path.join(os.path.dirname(os.path.abspath(__file__)), "voice"))
VOICE_MAP_FILE = os.path.join(VOICE_DIR, "map.zh-CN.json") VOICE_MAP_FILE = os.path.join(VOICE_DIR, "map.zh-CN.json")
# Day 26 优化: EdgeTTS 合成语音磁盘缓存目录
TTS_CACHE_DIR = os.path.join(os.path.dirname(os.path.abspath(__file__)), "voice", "tts_cache")
# 音频文件映射(将合并 voice 映射) # 音频文件映射(将合并 voice 映射)
AUDIO_MAP = { AUDIO_MAP = {
"检测到物体": os.path.join(AUDIO_BASE_DIR, "音频1.wav"), "检测到物体": os.path.join(AUDIO_BASE_DIR, "音频1.wav"),
@@ -100,7 +104,7 @@ def load_wav_file(filepath):
if framerate != 16000: if framerate != 16000:
import audioop import audioop
frames, _ = audioop.ratecv(frames, sampwidth, 1, framerate, 16000, None) frames, _ = audioop.ratecv(frames, sampwidth, 1, framerate, 16000, None)
print(f"[AUDIO] 重采样: {filepath} {framerate}Hz -> 16000Hz") # print(f"[AUDIO] 重采样: {filepath} {framerate}Hz -> 16000Hz")
_audio_cache[filepath] = frames _audio_cache[filepath] = frames
return frames return frames
@@ -129,7 +133,8 @@ def _merge_voice_map():
added += 1 added += 1
else: else:
print(f"[AUDIO] 映射文件缺失: {fpath}") print(f"[AUDIO] 映射文件缺失: {fpath}")
print(f"[AUDIO] 已合并 voice 映射 {added}") if added > 0:
print(f"[AUDIO] 已合并 voice 映射 {added}")
except Exception as e: except Exception as e:
print(f"[AUDIO] 读取 voice 映射失败: {e}") print(f"[AUDIO] 读取 voice 映射失败: {e}")
@@ -250,12 +255,13 @@ def initialize_audio_system():
# 显示压缩统计 # 显示压缩统计
if os.getenv("AIGLASS_COMPRESS_AUDIO", "1") == "1": if os.getenv("AIGLASS_COMPRESS_AUDIO", "1") == "1":
stats = compressed_audio_cache.get_compression_stats() stats = compressed_audio_cache.get_compression_stats()
print(f"[AUDIO] 音频压缩统计:") # print(f"[AUDIO] 音频压缩统计:")
print(f" - 文件数: {stats['files_cached']}") # print(f" - 文件数: {stats['files_cached']}")
print(f" - 原始大小: {stats['total_original_size'] / 1024:.1f} KB") # print(f" - 原始大小: {stats['total_original_size'] / 1024:.1f} KB")
print(f" - 压缩后: {stats['total_compressed_size'] / 1024:.1f} KB") # print(f" - 压缩后: {stats['total_compressed_size'] / 1024:.1f} KB")
print(f" - 压缩率: {stats['compression_ratio']:.1%}") # print(f" - 压缩率: {stats['compression_ratio']:.1%}")
print(f" - 节省: {stats['bytes_saved'] / 1024:.1f} KB") # print(f" - 节省: {stats['bytes_saved'] / 1024:.1f} KB")
pass
print("[AUDIO] 音频系统初始化完成(预加载+工作线程)") print("[AUDIO] 音频系统初始化完成(预加载+工作线程)")
@@ -385,8 +391,73 @@ def play_voice_text(text: str):
_last_voice_time = current_time _last_voice_time = current_time
return return
# 未匹配则输出日志(便于调试) # 未匹配则尝试使用 EdgeTTS 进行流式合成 (Day 26)
print(f"[AUDIO] 未找到匹配语音: {text}") print(f"[AUDIO] 未找到本地语音,尝试 EdgeTTS 合成: {text}")
# 启动后台任务进行合成和播放
# 注意:为了不阻塞主线程,这里使用 create_task
try:
loop = asyncio.get_event_loop()
loop.create_task(_synthesize_and_play_fallback(text))
except RuntimeError:
# 如果当前线程没有 loop (例如在非 async 上下文中),则使用线程
# 但通常 app_main 是 async 的,这里应该没问题
pass
async def _synthesize_and_play_fallback(text: str):
"""(内部) 使用 EdgeTTS 合成并播放,支持磁盘缓存"""
try:
# 动态导入以避免循环依赖
from edge_tts_client import text_to_speech_pcm
global _audio_cache
cache_key = f"tts_fallback:{text}"
# 1. 先检查内存缓存
if cache_key in _audio_cache:
play_audio_threadsafe(cache_key)
return
# 2. Day 26: 检查磁盘缓存
text_hash = hashlib.md5(text.encode('utf-8')).hexdigest()
disk_cache_path = os.path.join(TTS_CACHE_DIR, f"{text_hash}.pcm")
if os.path.exists(disk_cache_path):
# 从磁盘加载
with open(disk_cache_path, 'rb') as f:
pcm_data = f.read()
if pcm_data:
_audio_cache[cache_key] = pcm_data
AUDIO_MAP[cache_key] = cache_key
play_audio_threadsafe(cache_key)
print(f"[AUDIO] EdgeTTS 从磁盘缓存加载: {text[:20]}...")
return
# 3. 合成 (目标 16kHz PCM)
pcm_data = await text_to_speech_pcm(text, target_sample_rate=16000)
if pcm_data:
# 存入内存缓存
_audio_cache[cache_key] = pcm_data
AUDIO_MAP[cache_key] = cache_key
# Day 26: 存入磁盘缓存(异步写入,不阻塞播放)
try:
os.makedirs(TTS_CACHE_DIR, exist_ok=True)
with open(disk_cache_path, 'wb') as f:
f.write(pcm_data)
print(f"[AUDIO] EdgeTTS 已缓存到磁盘: {text[:20]}...")
except Exception as disk_err:
print(f"[AUDIO] 磁盘缓存写入失败: {disk_err}")
# 播放
play_audio_threadsafe(cache_key)
print(f"[AUDIO] EdgeTTS 合成成功: {text}")
else:
print(f"[AUDIO] EdgeTTS 合成返回空: {text}")
except Exception as e:
print(f"[AUDIO] EdgeTTS 回退失败: {e}")
# 兼容旧接口 # 兼容旧接口
play_audio_on_esp32 = play_audio_threadsafe play_audio_on_esp32 = play_audio_threadsafe

View File

@@ -59,6 +59,7 @@ async def text_to_speech_stream(
except Exception as e: except Exception as e:
print(f"[EdgeTTS] 合成失败: {e}") print(f"[EdgeTTS] 合成失败: {e}")
raise e # Day 23: 抛出异常以便上层重试
async def text_to_speech( async def text_to_speech(
@@ -80,9 +81,28 @@ async def text_to_speech(
MP3 音频数据 MP3 音频数据
""" """
audio_chunks = [] audio_chunks = []
async for chunk in text_to_speech_stream(text, voice, rate, volume):
audio_chunks.append(chunk) # Day 23: 添加重试逻辑
return b"".join(audio_chunks) max_retries = 3
for attempt in range(max_retries):
try:
audio_chunks = [] # 清空缓存,重新开始
async for chunk in text_to_speech_stream(text, voice, rate, volume):
audio_chunks.append(chunk)
# 成功,返回完整音频
return b"".join(audio_chunks)
except Exception:
if attempt < max_retries - 1:
wait_time = 0.5 * (2 ** attempt)
print(f"[EdgeTTS] 合成异常,{wait_time}s 后重试 ({attempt+1}/{max_retries})")
await asyncio.sleep(wait_time)
else:
print(f"[EdgeTTS] 重试 {max_retries} 次后仍失败")
return b"" # 最终失败返回空
return b""
async def text_to_speech_pcm( async def text_to_speech_pcm(

View File

@@ -13,10 +13,9 @@ from typing import AsyncGenerator, Optional
from zai import ZhipuAiClient from zai import ZhipuAiClient
# API 配置 # API 配置
API_KEY = os.getenv( API_KEY = os.getenv("GLM_API_KEY")
"GLM_API_KEY", if not API_KEY:
"5915240ea48d4e93b454bc2412d1cc54.e054ej4pPqi9G6rc" raise RuntimeError("未设置 GLM_API_KEY 环境变量,请在 .env 中配置")
)
MODEL = "glm-4.6v-flash" # 升级到 glm-4.6v-flash (支持视觉) MODEL = "glm-4.6v-flash" # 升级到 glm-4.6v-flash (支持视觉)
# 星期映射 # 星期映射
@@ -178,14 +177,35 @@ async def chat_stream(user_message: str, image_base64: Optional[str] = None) ->
try: try:
# 流式调用 # 流式调用
# Day 22: 升级到 glm-4.6v-flash # Day 22: 升级到 glm-4.6v-flash
# 【修正】根据官方文档thinking 参数也是必须的 max_retries = 3
response = await asyncio.to_thread( retry_delay = 1
client.chat.completions.create,
model=MODEL, response = None
messages=messages, for attempt in range(max_retries):
thinking={"type": "disabled"}, try:
stream=True, # 【修正】根据官方文档thinking 参数也是必须的
) response = await asyncio.to_thread(
client.chat.completions.create,
model=MODEL,
messages=messages,
thinking={"type": "disabled"},
stream=True,
)
break # 成功则跳出循环
except Exception as e:
error_str = str(e)
if attempt < max_retries - 1:
if "429" in error_str or "1305" in error_str or "请求过多" in error_str:
print(f"[GLM] (流式) 速率限制,{retry_delay}秒后重试... ({attempt + 1}/{max_retries})")
await asyncio.sleep(retry_delay)
retry_delay *= 2
continue
# 其他网络错误也可以重试
print(f"[GLM] (流式) 连接错误: {e},重试... ({attempt + 1}/{max_retries})")
await asyncio.sleep(retry_delay)
continue
else:
raise e # 最后一次尝试失败,抛出异常
for chunk in response: for chunk in response:
if chunk.choices[0].delta.content: if chunk.choices[0].delta.content:

View File

@@ -23,6 +23,7 @@ SEEKING_NEXT_BLINDPATH = "SEEKING_NEXT_BLINDPATH" # 过完马路后寻找下一
RECOVERY = "RECOVERY" # 兜底/恢复(感知暂时丢失时) RECOVERY = "RECOVERY" # 兜底/恢复(感知暂时丢失时)
TRAFFIC_LIGHT_DETECTION = "TRAFFIC_LIGHT_DETECTION" # 红绿灯检测模式 TRAFFIC_LIGHT_DETECTION = "TRAFFIC_LIGHT_DETECTION" # 红绿灯检测模式
ITEM_SEARCH = "ITEM_SEARCH" # 找物品模式暂停导航由yolomedia处理画面 ITEM_SEARCH = "ITEM_SEARCH" # 找物品模式暂停导航由yolomedia处理画面
INDOOR_NAV = "INDOOR_NAV" # 室内导航模式(使用室内导盲模型)
# ========== 返回结构 ========== # ========== 返回结构 ==========
@dataclass @dataclass
@@ -247,9 +248,11 @@ class NavigationMaster:
blind_nav: BlindPathNavigator, blind_nav: BlindPathNavigator,
cross_nav: CrossStreetNavigator, cross_nav: CrossStreetNavigator,
*, *,
indoor_nav: BlindPathNavigator = None, # 新增:室内导航器
min_tts_interval: float = 1.2): min_tts_interval: float = 1.2):
self.blind = blind_nav self.blind = blind_nav
self.cross = cross_nav self.cross = cross_nav
self.indoor = indoor_nav # 室内导航器(使用室内导盲模型)
self.state = IDLE self.state = IDLE
self.last_guidance_ts = 0.0 self.last_guidance_ts = 0.0
self.min_tts_interval = min_tts_interval self.min_tts_interval = min_tts_interval
@@ -302,7 +305,14 @@ class NavigationMaster:
self.state = CHAT self.state = CHAT
self.cooldown_until = time.time() + self.COOLDOWN_SEC self.cooldown_until = time.time() + self.COOLDOWN_SEC
if self.blind: if self.blind:
self.blind.reset() try: self.blind.reset()
except: pass
if self.cross:
try: self.cross.reset()
except: pass
if self.indoor:
try: self.indoor.reset()
except: pass
def start_crossing(self): def start_crossing(self):
"""启动过马路模式""" """启动过马路模式"""
@@ -316,6 +326,13 @@ class NavigationMaster:
self.state = TRAFFIC_LIGHT_DETECTION self.state = TRAFFIC_LIGHT_DETECTION
self.cooldown_until = time.time() + self.COOLDOWN_SEC self.cooldown_until = time.time() + self.COOLDOWN_SEC
def start_indoor_navigation(self):
"""启动室内导航模式(使用室内导盲模型)"""
self.state = INDOOR_NAV
self.cooldown_until = time.time() + self.COOLDOWN_SEC
if self.blind:
self.blind.reset()
def is_in_navigation_mode(self): def is_in_navigation_mode(self):
"""检查是否在导航模式(非对话模式)""" """检查是否在导航模式(非对话模式)"""
return self.state not in ["CHAT", "IDLE", "TRAFFIC_LIGHT_DETECTION", "ITEM_SEARCH"] return self.state not in ["CHAT", "IDLE", "TRAFFIC_LIGHT_DETECTION", "ITEM_SEARCH"]
@@ -384,6 +401,10 @@ class NavigationMaster:
self.cross.reset() self.cross.reset()
except Exception: except Exception:
pass pass
try:
if self.indoor: self.indoor.reset()
except Exception:
pass
# ----- 内部工具 ----- # ----- 内部工具 -----
def _say(self, now: float, text: str) -> str: def _say(self, now: float, text: str) -> str:
@@ -455,6 +476,25 @@ class NavigationMaster:
# 冷却期内允许继续输出画面,但避免"瞬时切换" # 冷却期内允许继续输出画面,但避免"瞬时切换"
in_cooldown = now < self.cooldown_until in_cooldown = now < self.cooldown_until
# 【新增】室内导航模式:使用室内导盲模型处理帧
# Day 26: 支持 IndoorNavigator 返回的 IndoorResult
if self.state == INDOOR_NAV:
# 优先使用室内导航器,如果没有则 fallback 到盲道导航器
nav = self.indoor if self.indoor else self.blind
try:
result = nav.process_frame(bgr)
except Exception as e:
self.state = RECOVERY
ann_err = bgr.copy()
return OrchestratorResult(ann_err, self._say(now, ""), self.state, {"error": str(e)})
ann = result.annotated_image if result.annotated_image is not None else bgr.copy()
say = result.guidance_text or ""
state_info = result.state_info if hasattr(result, 'state_info') else {}
return OrchestratorResult(ann, self._say(now, say), self.state,
{"source": "indoor", "state_info": state_info})
# 各状态处理 # 各状态处理
if self.state in (BLINDPATH_NAV, SEEKING_CROSSWALK, SEEKING_NEXT_BLINDPATH, RECOVERY): if self.state in (BLINDPATH_NAV, SEEKING_CROSSWALK, SEEKING_NEXT_BLINDPATH, RECOVERY):
# —— 盲道侧 —— 统一调用盲道导航器 # —— 盲道侧 —— 统一调用盲道导航器

98
server_context.py Normal file
View File

@@ -0,0 +1,98 @@
# server_context.py
# -*- coding: utf-8 -*-
import asyncio
from typing import Dict, List, Set, Deque, Optional, Tuple, Any
from collections import deque
from concurrent.futures import ThreadPoolExecutor
from fastapi import WebSocket
class ServerContext:
    """
    Singleton global server context.

    Centralizes WebSocket client registries, media buffers, business-state
    flags and component references that app_main.py previously managed
    through scattered module-level ``global`` variables.
    """
    _instance = None
    # Async lock, mainly to protect critical state transitions.
    # NOTE(review): created at import time — confirm it is only awaited from
    # the single event loop the app runs on.
    _lock = asyncio.Lock()

    def __new__(cls):
        # Classic singleton: create the one instance lazily and mark it
        # uninitialized so __init__ runs its body exactly once.
        if cls._instance is None:
            cls._instance = super(ServerContext, cls).__new__(cls)
            cls._instance._initialized = False
        return cls._instance

    def __init__(self):
        # ServerContext() may be evaluated many times; only the first call
        # past this guard actually performs initialization.
        if self._initialized:
            return
        self._initialized = True
        # ====== 1. WebSocket client management ======
        self.ui_clients: Dict[int, WebSocket] = {}  # keyed by id(ws)
        self.camera_viewers: Set[WebSocket] = set()
        self.imu_ws_clients: Set[WebSocket] = set()
        self.esp32_audio_ws: Optional[WebSocket] = None
        self.esp32_camera_ws: Optional[WebSocket] = None
        # ====== 2. Media data buffers ======
        self.current_partial: str = ""  # in-progress (partial) text
        self.recent_finals: List[str] = []  # recently finalized texts
        # Rolling buffer of the latest frames as (timestamp, bytes) pairs.
        self.last_frames: Deque[Tuple[float, bytes]] = deque(maxlen=10)
        # ====== 3. Business state flags ======
        # Blind-path navigation active?
        self.navigation_active: bool = False
        # Street-crossing navigation active?
        self.cross_street_active: bool = False
        # Omni conversation state
        self.omni_conversation_active: bool = False
        self.omni_previous_nav_state: Optional[str] = None
        # YOLO media stream state
        self.yolomedia_running: bool = False
        self.yolomedia_sending_frames: bool = False
        # ====== 4. Core component references (resources) ======
        # Navigator instances
        self.blind_path_navigator = None
        self.cross_street_navigator = None
        self.indoor_navigator = None
        # Orchestrator
        self.orchestrator = None
        # Model instances
        self.yolo_seg_model = None
        self.obstacle_detector = None
        self.indoor_seg_model = None
        # ====== 5. Async processing resources ======
        # Frame-processing thread pool
        self.frame_processing_executor = ThreadPoolExecutor(max_workers=3, thread_name_prefix="frame_proc")
        # Async frame-processing state
        self.nav_processing_task: Optional[asyncio.Task] = None
        self.nav_last_result_image: Any = None
        self.nav_last_result_jpeg: Optional[bytes] = None
        self.nav_pending_frame: Any = None
        self.nav_processing_lock = asyncio.Lock()
        self.nav_task_start_time: float = 0.0

    def reset_navigation_state(self):
        """Reset all navigation-related state flags.

        Note: this does NOT stop the orchestrator — it only clears flags.
        """
        self.navigation_active = False
        self.cross_street_active = False
        self.omni_conversation_active = False

    def add_ui_client(self, ws: WebSocket):
        # Register a UI WebSocket, keyed by its object identity.
        self.ui_clients[id(ws)] = ws

    def remove_ui_client(self, ws: WebSocket):
        # Safe to call for an unknown ws (pop with default).
        self.ui_clients.pop(id(ws), None)

    def get_ui_client_count(self) -> int:
        return len(self.ui_clients)


# Global access point.
ctx = ServerContext()

View File

@@ -479,18 +479,25 @@ def is_detection_running():
return _detection_running return _detection_running
def init_model(): def init_model():
"""初始化YOLO模型单帧处理模式""" """初始化YOLO模型单帧处理模式
Day 26 优化: 包含预热推理,避免 TensorRT 重复加载
"""
global _model global _model
if _model is not None: if _model is not None:
print("[TRAFFIC] 模型已加载")
return True return True
try: try:
print("[TRAFFIC] 加载 YOLO 红绿灯检测模型...") print("[TRAFFIC] 加载 YOLO 红绿灯检测模型...")
_model = YOLO(YOLO_MODEL_PATH) _model = YOLO(YOLO_MODEL_PATH, task='detect')
print(f"[TRAFFIC] 模型加载成功: {YOLO_MODEL_PATH}") print(f"[TRAFFIC] 模型加载成功: {YOLO_MODEL_PATH}")
class_names = _model.names if hasattr(_model, 'names') else {} class_names = _model.names if hasattr(_model, 'names') else {}
print(f"[TRAFFIC] 模型类别: {class_names}") print(f"[TRAFFIC] 模型类别: {class_names}")
# Day 26 优化: 预热推理,创建 TensorRT 执行上下文(只创建一次)
test_img = np.zeros((640, 640, 3), dtype=np.uint8)
_ = _model(test_img, conf=CONF_THRESHOLD, verbose=False)
print("[TRAFFIC] 模型预热完成")
return True return True
except Exception as e: except Exception as e:
print(f"[TRAFFIC] 模型加载失败: {e}") print(f"[TRAFFIC] 模型加载失败: {e}")

View File

@@ -88,14 +88,16 @@ class ProcessingResult:
class BlindPathNavigator: class BlindPathNavigator:
"""盲道导航处理器 - 无外部依赖版本""" """盲道导航处理器 - 无外部依赖版本"""
def __init__(self, yolo_model=None, obstacle_detector=None): def __init__(self, yolo_model=None, obstacle_detector=None, enable_crosswalk_detection=True):
""" """
初始化导航器 初始化导航器
:param yolo_model: YOLO分割模型可选 :param yolo_model: YOLO分割模型可选
:param obstacle_detector: 障碍物检测器(可选) :param obstacle_detector: 障碍物检测器(可选)
:param enable_crosswalk_detection: 是否启用斑马线检测(室内模式可关闭)
""" """
self.yolo_model = yolo_model self.yolo_model = yolo_model
self.obstacle_detector = obstacle_detector self.obstacle_detector = obstacle_detector
self.enable_crosswalk_detection = enable_crosswalk_detection
# 状态变量 # 状态变量
self.current_state = STATE_ONBOARDING self.current_state = STATE_ONBOARDING
@@ -185,6 +187,10 @@ class BlindPathNavigator:
f"限制次数={self.straight_repeat_limit}") f"限制次数={self.straight_repeat_limit}")
logger.info(f"[BlindPath] 方向播报配置: 间隔={self.direction_interval}") logger.info(f"[BlindPath] 方向播报配置: 间隔={self.direction_interval}")
# Day 26 优化: 可配置日志采样间隔
self.log_interval = int(os.getenv("AIGLASS_LOG_INTERVAL", "30")) # 每 N 帧输出一次日志
logger.info(f"[BlindPath] 日志采样间隔: 每{self.log_interval}")
# 缓存变量 # 缓存变量
self.prev_gray = None self.prev_gray = None
self.prev_blind_path_mask = None self.prev_blind_path_mask = None
@@ -258,8 +264,14 @@ class BlindPathNavigator:
self.last_crosswalk_mask = None self.last_crosswalk_mask = None
# 【新增】斑马线感知监控器 # 【新增】斑马线感知监控器
self.crosswalk_monitor = CrosswalkAwarenessMonitor() # 【新增】斑马线感知监控器
logger.info("[BlindPath] 斑马线感知监控器已初始化") if self.enable_crosswalk_detection:
self.crosswalk_monitor = CrosswalkAwarenessMonitor()
logger.info("[BlindPath] 斑马线感知监控器已初始化")
else:
self.crosswalk_monitor = None
logger.info("[BlindPath] 斑马线感知监控器已禁用 (室内模式)")
logger.info(f"[BlindPath] 盲道检测间隔: 每{self.BLINDPATH_DETECTION_INTERVAL}") logger.info(f"[BlindPath] 盲道检测间隔: 每{self.BLINDPATH_DETECTION_INTERVAL}")
def init_traffic_light_detector(self): def init_traffic_light_detector(self):
@@ -489,16 +501,24 @@ class BlindPathNavigator:
# 【新增】检查近距离障碍物并设置语音 # 【新增】检查近距离障碍物并设置语音
self._check_and_set_obstacle_voice(detected_obstacles) self._check_and_set_obstacle_voice(detected_obstacles)
# 【配置】如果禁用了斑马线检测强制置为None
if not self.enable_crosswalk_detection:
crosswalk_mask = None
# 【新增】斑马线感知处理 # 【新增】斑马线感知处理
# 【Day 15 优化】减少每帧日志输出,只在每 30 帧输出一次 # 【Day 26 优化】使用可配置的日志间隔
if crosswalk_mask is not None and self.frame_counter % 30 == 0: if crosswalk_mask is not None and self.frame_counter % self.log_interval == 0:
cross_pixels = np.sum(crosswalk_mask > 0) cross_pixels = np.sum(crosswalk_mask > 0)
if cross_pixels > 0: if cross_pixels > 0:
logger.info(f"[斑马线] monitor: pixels={cross_pixels}, area={cross_pixels/crosswalk_mask.size*100:.2f}%") logger.info(f"[斑马线] monitor: pixels={cross_pixels}, area={cross_pixels/crosswalk_mask.size*100:.2f}%")
elif crosswalk_mask is None and self.frame_counter % 30 == 0: elif crosswalk_mask is None and self.frame_counter % self.log_interval == 0:
if self.enable_crosswalk_detection:
logger.info(f"[斑马线] crosswalk_mask为None") logger.info(f"[斑马线] crosswalk_mask为None")
crosswalk_guidance = self.crosswalk_monitor.process_frame(crosswalk_mask, blind_path_mask) crosswalk_guidance = None
if self.crosswalk_monitor:
crosswalk_guidance = self.crosswalk_monitor.process_frame(crosswalk_mask, blind_path_mask)
if crosswalk_guidance: if crosswalk_guidance:
logger.info(f"[斑马线感知] 检测结果: area={crosswalk_guidance.get('area', 0):.3f}, " logger.info(f"[斑马线感知] 检测结果: area={crosswalk_guidance.get('area', 0):.3f}, "
f"should_broadcast={crosswalk_guidance.get('should_broadcast', False)}, " f"should_broadcast={crosswalk_guidance.get('should_broadcast', False)}, "
@@ -511,7 +531,7 @@ class BlindPathNavigator:
logger.info(f"[斑马线语音] 已设置待播报语音: {crosswalk_guidance['voice_text']}, 优先级{crosswalk_guidance['priority']}") logger.info(f"[斑马线语音] 已设置待播报语音: {crosswalk_guidance['voice_text']}, 优先级{crosswalk_guidance['priority']}")
# 【新增】添加斑马线可视化 # 【新增】添加斑马线可视化
if crosswalk_mask is not None: if crosswalk_mask is not None and self.crosswalk_monitor:
# 计算可视化数据 # 计算可视化数据
total_pixels = crosswalk_mask.size total_pixels = crosswalk_mask.size
crosswalk_pixels = np.sum(crosswalk_mask > 0) crosswalk_pixels = np.sum(crosswalk_mask > 0)

View File

@@ -272,21 +272,22 @@ class CrossStreetNavigator:
logger.info(f"[CROSS_STREET] 斑马线检测间隔: 每{self.CROSSWALK_DETECTION_INTERVAL}") logger.info(f"[CROSS_STREET] 斑马线检测间隔: 每{self.CROSSWALK_DETECTION_INTERVAL}")
# 确保模型在 GPU 上 # 确保模型在 GPU 上
# Day 20: TensorRT 引擎不需要 .to() # Day 20/26: TensorRT 引擎不需要 .to(),改用 model_utils 检查
if self.seg_model and torch.cuda.is_available(): if self.seg_model and torch.cuda.is_available():
try: try:
# 检查是否是 TensorRT 引擎 # 检查是否是 TensorRT 引擎
from model_utils import is_tensorrt_engine
model_path = getattr(self.seg_model, 'ckpt_path', '') or '' model_path = getattr(self.seg_model, 'ckpt_path', '') or ''
if not model_path.endswith('.engine'): if is_tensorrt_engine(model_path):
if hasattr(self.seg_model, 'model') and hasattr(self.seg_model.model, 'to'): pass # TensorRT 引擎无需 .to(),静默跳过
self.seg_model.model.to('cuda') elif hasattr(self.seg_model, 'model') and hasattr(self.seg_model.model, 'to'):
elif hasattr(self.seg_model, 'to'): self.seg_model.model.to('cuda')
self.seg_model.to('cuda')
logger.info("[CROSS_STREET] 模型已移至 GPU") logger.info("[CROSS_STREET] 模型已移至 GPU")
else: elif hasattr(self.seg_model, 'to'):
logger.info("[CROSS_STREET] TensorRT 引擎已加载,跳过 .to()") self.seg_model.to('cuda')
except Exception as e: logger.info("[CROSS_STREET] 模型已移至 GPU")
logger.warning(f"[CROSS_STREET] 无法将模型移至 GPU: {e}") except Exception:
pass # Day 26: 静默处理,避免启动日志刷屏
def reset(self): def reset(self):
"""重置状态""" """重置状态"""

454
workflow_indoor.py Normal file
View File

@@ -0,0 +1,454 @@
# -*- coding: utf-8 -*-
"""
室内导航工作流 (Indoor Navigation Workflow)
Day 26: 专为室内导盲模型 (yolo11l-seg-indoor14) 设计
类别映射 (14 classes from MIT Indoor):
- 可行走区域: floor(0), corridor(1), sidewalk(2)
- 静态障碍物: chair(3), table(4), sofa_bed(5), cabinet(11), trash_can(12)
- 兴趣点: door(6), elevator(7), stairs(8)
- 边界: wall(9), window(13)
- 动态障碍: person(10)
"""
import logging
import os
import time
from collections import deque
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional

import cv2
import numpy as np
logger = logging.getLogger(__name__)
# ========== Class-id constants ==========
# Walkable surfaces.
WALKABLE_CLASSES = {0, 1, 2}  # floor, corridor, sidewalk
CLASS_FLOOR = 0
CLASS_CORRIDOR = 1
CLASS_SIDEWALK = 2
# Static obstacles.
OBSTACLE_CLASSES = {3, 4, 5, 11, 12}  # chair, table, sofa_bed, cabinet, trash_can
CLASS_CHAIR = 3
CLASS_TABLE = 4
CLASS_SOFA_BED = 5
CLASS_CABINET = 11
CLASS_TRASH_CAN = 12
# Points of interest.
POI_CLASSES = {6, 7, 8}  # door, elevator, stairs
CLASS_DOOR = 6
CLASS_ELEVATOR = 7
CLASS_STAIRS = 8
# Boundaries.
BOUNDARY_CLASSES = {9, 13}  # wall, window
CLASS_WALL = 9
CLASS_WINDOW = 13
# Dynamic obstacle.
CLASS_PERSON = 10

# English class names (model output labels).
CLASS_NAMES = {
    0: 'floor', 1: 'corridor', 2: 'sidewalk',
    3: 'chair', 4: 'table', 5: 'sofa_bed',
    6: 'door', 7: 'elevator', 8: 'stairs',
    9: 'wall', 10: 'person', 11: 'cabinet',
    12: 'trash_can', 13: 'window'
}

# Chinese class names (used in voice guidance).
# Bug fix: the entry for door (class 6) was an empty string, which made
# _check_poi_hint speak broken phrases such as "前方有" with no noun.
CLASS_NAMES_CN = {
    0: '地面', 1: '走廊', 2: '人行道',
    3: '椅子', 4: '桌子', 5: '沙发',
    6: '门', 7: '电梯', 8: '楼梯',
    9: '墙壁', 10: '行人', 11: '柜子',
    12: '垃圾桶', 13: '窗户'
}

# ========== Tunable parameters (env-overridable) ==========
CONF_THRESHOLD = float(os.getenv('INDOOR_CONF_THRESHOLD', '0.25'))
WALKABLE_MIN_AREA = int(os.getenv('INDOOR_WALKABLE_MIN_AREA', '3000'))
OBSTACLE_MIN_AREA = int(os.getenv('INDOOR_OBSTACLE_MIN_AREA', '500'))

# Minimum seconds between voice messages, per message category.
GUIDE_INTERVAL = float(os.getenv('INDOOR_GUIDE_INTERVAL', '3.0'))
DIRECTION_INTERVAL = float(os.getenv('INDOOR_DIRECTION_INTERVAL', '2.5'))
POI_INTERVAL = float(os.getenv('INDOOR_POI_INTERVAL', '5.0'))
OBSTACLE_INTERVAL = float(os.getenv('INDOOR_OBSTACLE_INTERVAL', '2.0'))

# ========== Visualization colors (BGR) ==========
VIS_COLORS = {
    'walkable': (0, 255, 0),      # green  - walkable area
    'obstacle': (0, 0, 255),      # red    - obstacle
    'poi': (255, 255, 0),         # cyan   - point of interest
    'boundary': (128, 128, 128),  # gray   - boundary
    'person': (255, 0, 255),      # magenta - person
    # NOTE(review): original comment said "yellow", but (255, 255, 0) in BGR
    # is cyan; value kept unchanged, comment corrected.
    'centerline': (255, 255, 0),  # cyan   - guide line
}
@dataclass
class IndoorResult:
    """Result of processing one frame of indoor navigation.

    Attributes:
        annotated_image: BGR frame with overlays (None if unavailable).
        guidance_text: voice guidance to speak ("" when nothing to say).
        state_info: diagnostic state snapshot for the orchestrator/UI.
        visualizations: overlay primitives for the web UI.
    """
    annotated_image: Optional[np.ndarray] = None
    guidance_text: str = ""
    # field(default_factory=...) gives each instance its own container and
    # avoids the shared-mutable-default pitfall.  __post_init__ is kept so
    # that callers passing an explicit None still get empty containers,
    # preserving the original contract.
    state_info: Dict[str, Any] = field(default_factory=dict)
    visualizations: List[Dict[str, Any]] = field(default_factory=list)

    def __post_init__(self):
        # Normalize explicit None to empty containers (backward compat).
        if self.state_info is None:
            self.state_info = {}
        if self.visualizations is None:
            self.visualizations = []
class IndoorNavigator:
    """Indoor navigator built for the indoor guide-dog segmentation model.

    Runs YOLO segmentation every ``detection_interval`` frames, caches the
    latest walkable mask / obstacles / POIs between detections, and turns
    them into rate-limited voice guidance plus UI visualization primitives.
    """

    def __init__(self, seg_model=None, device_id: str = "indoor"):
        # Segmentation model exposing .predict() (ultralytics style).  May be
        # None; process_frame then only replays cached results.
        self.seg_model = seg_model
        self.device_id = device_id
        self.frame_counter = 0
        # Voice throttling: last emission time (epoch seconds) per category.
        self.last_guide_time = 0
        self.last_direction_time = 0
        self.last_poi_time = 0
        self.last_obstacle_time = 0
        self.last_guidance_text = ""
        self.last_direction_text = ""
        # Run inference only every N frames to limit GPU load.
        self.detection_interval = int(os.getenv('INDOOR_DETECTION_INTERVAL', '6'))
        self.last_detection_frame = 0
        # Cached detection results reused on non-detection frames.
        self.last_walkable_mask = None
        self.last_obstacles = []
        self.last_pois = []
        # Previous grayscale frame (reserved for optical flow and similar).
        self.prev_gray = None
        # Emit a status log line every N frames.
        self.log_interval = int(os.getenv('AIGLASS_LOG_INTERVAL', '30'))
        logger.info(f"[INDOOR] 室内导航器初始化完成")
        logger.info(f"[INDOOR] 检测间隔: 每{self.detection_interval}帧")
        logger.info(f"[INDOOR] 可行走类别: {[CLASS_NAMES[c] for c in WALKABLE_CLASSES]}")

    def reset(self):
        """Reset all per-session state (counters, throttles, caches)."""
        self.frame_counter = 0
        self.last_guide_time = 0
        self.last_direction_time = 0
        self.last_poi_time = 0
        self.last_obstacle_time = 0
        self.last_guidance_text = ""
        self.last_direction_text = ""
        self.last_walkable_mask = None
        self.last_obstacles = []
        self.last_pois = []
        self.prev_gray = None
        logger.info("[INDOOR] 导航器已重置")

    def process_frame(self, image: np.ndarray) -> IndoorResult:
        """Process a single BGR frame and return guidance + visualizations.

        Detection runs only every ``detection_interval`` frames; in between,
        the cached mask/obstacles/POIs from the last detection are reused.
        """
        self.frame_counter += 1
        h, w = image.shape[:2]
        now = time.time()
        frame_visualizations = []
        guidance_text = ""
        state_info = {}
        # Decide whether this frame runs inference or replays the cache.
        should_detect = (self.frame_counter - self.last_detection_frame) >= self.detection_interval
        if should_detect and self.seg_model is not None:
            self.last_detection_frame = self.frame_counter
            # Run segmentation inference.
            walkable_mask, obstacles, pois = self._detect_all(image)
            # Refresh the cache.
            self.last_walkable_mask = walkable_mask
            self.last_obstacles = obstacles
            self.last_pois = pois
        else:
            # Use cached results.
            walkable_mask = self.last_walkable_mask
            obstacles = self.last_obstacles
            pois = self.last_pois
        # Generate navigation guidance.
        if walkable_mask is not None:
            guidance_text = self._generate_guidance(walkable_mask, obstacles, pois, h, w, now)
            # Add walkable-area overlay.
            self._add_mask_visualization(walkable_mask, frame_visualizations,
                                         "walkable_mask", "rgba(0, 255, 0, 0.3)")
        # Obstacle overlays.
        for obs in obstacles:
            self._add_detection_visualization(obs, frame_visualizations, "obstacle")
        # POI overlays.
        for poi in pois:
            self._add_detection_visualization(poi, frame_visualizations, "poi")
        # Sampled status log.
        # NOTE(review): walkable_mask holds 0/255 values (see _detect_all), so
        # .sum() counts 255 per pixel, not 1 — confirm the reported "area" unit.
        if self.frame_counter % self.log_interval == 0:
            walkable_area = int(walkable_mask.sum()) if walkable_mask is not None else 0
            logger.info(f"[INDOOR] Frame={self.frame_counter} | 可行走面积={walkable_area} | "
                        f"障碍物={len(obstacles)} | 兴趣点={len(pois)}")
        # State snapshot for the orchestrator / UI.
        state_info = {
            'frame': self.frame_counter,
            'walkable_detected': walkable_mask is not None and walkable_mask.sum() > 0,
            'obstacles_count': len(obstacles),
            'pois_count': len(pois),
        }
        # Keep the grayscale frame for future use (e.g. optical flow).
        self.prev_gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
        return IndoorResult(
            annotated_image=image.copy(),
            guidance_text=guidance_text,
            state_info=state_info,
            visualizations=frame_visualizations
        )

    def _detect_all(self, image: np.ndarray):
        """Run segmentation; return (walkable_mask, obstacles, pois).

        walkable_mask is a uint8 HxW image with 0/255 values; obstacles and
        pois are lists of dicts (class ids/names, confidence, binary mask,
        pixel area, centroid).
        """
        h, w = image.shape[:2]
        walkable_mask = np.zeros((h, w), dtype=np.uint8)
        obstacles = []
        pois = []
        try:
            imgsz = int(os.getenv("AIGLASS_YOLO_IMGSZ", "480"))
            use_half = os.getenv("AIGLASS_YOLO_HALF", "1") == "1"
            results = self.seg_model.predict(
                image,
                imgsz=imgsz,
                conf=CONF_THRESHOLD,
                verbose=False,
                half=use_half
            )
            if results and len(results) > 0 and results[0].masks is not None:
                r0 = results[0]
                masks = r0.masks.data.cpu().numpy()
                boxes = r0.boxes
                for i, (mask, cls_id, conf) in enumerate(zip(masks, boxes.cls, boxes.conf)):
                    cls_id = int(cls_id.item())
                    conf_val = float(conf.item())
                    # Resize the mask back to frame resolution.
                    mask_resized = cv2.resize(mask, (w, h), interpolation=cv2.INTER_NEAREST)
                    mask_bin = (mask_resized > 0.5).astype(np.uint8)
                    area = int(mask_bin.sum())
                    if area < 100:  # drop tiny fragments
                        continue
                    # Walkable surface: merge into the combined mask.
                    if cls_id in WALKABLE_CLASSES and area > WALKABLE_MIN_AREA:
                        walkable_mask = cv2.bitwise_or(walkable_mask, mask_bin * 255)
                    # Obstacle (static classes or a person).
                    elif cls_id in OBSTACLE_CLASSES or cls_id == CLASS_PERSON:
                        if area > OBSTACLE_MIN_AREA:
                            obstacles.append({
                                'class_id': cls_id,
                                'class_name': CLASS_NAMES.get(cls_id, 'unknown'),
                                'class_name_cn': CLASS_NAMES_CN.get(cls_id, '未知'),
                                'conf': conf_val,
                                'mask': mask_bin,
                                'area': area,
                                'center': self._mask_center(mask_bin),
                            })
                    # Point of interest (door / elevator / stairs).
                    elif cls_id in POI_CLASSES:
                        pois.append({
                            'class_id': cls_id,
                            'class_name': CLASS_NAMES.get(cls_id, 'unknown'),
                            'class_name_cn': CLASS_NAMES_CN.get(cls_id, '未知'),
                            'conf': conf_val,
                            'mask': mask_bin,
                            'area': area,
                            'center': self._mask_center(mask_bin),
                        })
        except Exception as e:
            # Best-effort: a failed inference yields empty results.
            logger.warning(f"[INDOOR] 检测失败: {e}")
        return walkable_mask, obstacles, pois

    def _mask_center(self, mask: np.ndarray):
        """Return the (cx, cy) centroid of a binary mask, or None if empty."""
        M = cv2.moments(mask)
        if abs(M["m00"]) < 1e-6:
            return None
        cx = int(M["m10"] / M["m00"])
        cy = int(M["m01"] / M["m00"])
        return (cx, cy)

    def _generate_guidance(self, walkable_mask, obstacles, pois, h, w, now):
        """Build the guidance text for this frame, applying per-category
        rate limits.  Priority: obstacle warning > direction > POI hint."""
        guidance_text = ""
        # 1. Direction from walkable-area distribution.
        direction_guidance = self._compute_direction_guidance(walkable_mask, h, w)
        # 2. Obstacle warning, if any is in the frontal zone.
        obstacle_warning = self._check_obstacle_warning(obstacles, walkable_mask, h, w)
        # 3. POI hint (door / elevator / stairs).
        poi_hint = self._check_poi_hint(pois, h, w)
        if obstacle_warning and (now - self.last_obstacle_time) > OBSTACLE_INTERVAL:
            guidance_text = obstacle_warning
            self.last_obstacle_time = now
            self.last_guidance_text = guidance_text
        elif direction_guidance:
            if direction_guidance != self.last_direction_text:
                # New direction: announce (subject to DIRECTION_INTERVAL).
                if (now - self.last_direction_time) > DIRECTION_INTERVAL:
                    guidance_text = direction_guidance
                    self.last_direction_time = now
                    self.last_direction_text = direction_guidance
            elif (now - self.last_guide_time) > GUIDE_INTERVAL:
                # Same direction as before: repeat at a lower rate.
                guidance_text = direction_guidance
                self.last_guide_time = now
        elif poi_hint and (now - self.last_poi_time) > POI_INTERVAL:
            guidance_text = poi_hint
            self.last_poi_time = now
        return guidance_text

    def _compute_direction_guidance(self, walkable_mask, h, w):
        """Derive a left/right/straight hint from the walkable mask.

        NOTE(review): walkable_mask values are 0/255, so the .sum() compared
        against WALKABLE_MIN_AREA / 1000 here is 255x the pixel count used as
        "area" in _detect_all — confirm the thresholds are tuned for this.
        """
        if walkable_mask is None or walkable_mask.sum() < WALKABLE_MIN_AREA:
            return "未检测到可行走区域"
        # Analyze only the lower half (the nearer ground area).
        lower_half = walkable_mask[int(h * 0.5):, :]
        if lower_half.sum() < 1000:
            return "前方可行走区域较小,请小心"
        # Split into left / center / right thirds.
        third = w // 3
        left_area = lower_half[:, :third].sum()
        center_area = lower_half[:, third:2*third].sum()
        right_area = lower_half[:, 2*third:].sum()
        total = left_area + center_area + right_area + 1e-6
        left_ratio = left_area / total
        center_ratio = center_area / total
        right_ratio = right_area / total
        # Direction decision.
        if center_ratio > 0.4:
            return "保持直行"
        elif left_ratio > right_ratio * 1.5:
            return "向左调整"
        elif right_ratio > left_ratio * 1.5:
            return "向右调整"
        else:
            return "保持直行"

    def _check_obstacle_warning(self, obstacles, walkable_mask, h, w):
        """Return a warning string for the first obstacle whose centroid lies
        in the frontal zone (lower-middle of the frame), else None."""
        if not obstacles:
            return None
        # Frontal zone: lower-middle part of the image.
        front_zone_top = int(h * 0.4)
        front_zone_left = int(w * 0.2)
        front_zone_right = int(w * 0.8)
        for obs in obstacles:
            center = obs.get('center')
            if center is None:
                continue
            cx, cy = center
            # Is the centroid inside the frontal zone?
            if front_zone_top < cy < h and front_zone_left < cx < front_zone_right:
                name_cn = obs.get('class_name_cn', '障碍物')
                # Left / center / right wording by horizontal position.
                if cx < w * 0.4:
                    return f"左前方有{name_cn}"
                elif cx > w * 0.6:
                    return f"右前方有{name_cn}"
                else:
                    return f"正前方有{name_cn}"
        return None

    def _check_poi_hint(self, pois, h, w):
        """Return a hint for the first relevant POI, else None.

        Stairs get a safety warning when close (lower half of the frame);
        doors and elevators get a positional hint when reasonably near.
        """
        if not pois:
            return None
        for poi in pois:
            cls_id = poi.get('class_id')
            name_cn = poi.get('class_name_cn', '兴趣点')
            center = poi.get('center')
            if center is None:
                continue
            cx, cy = center
            # Stairs require an explicit warning when near.
            if cls_id == CLASS_STAIRS:
                if cy > h * 0.5:  # relatively close
                    return f"注意前方有{name_cn}"
            # Door / elevator: positional hint.
            elif cls_id in (CLASS_DOOR, CLASS_ELEVATOR):
                if cy > h * 0.3:  # within view
                    position = "左侧" if cx < w * 0.4 else ("右侧" if cx > w * 0.6 else "前方")
                    return f"{position}有{name_cn}"
        return None

    def _add_mask_visualization(self, mask, visualizations, viz_type, color):
        """Append a mask overlay entry (skipped when the mask is empty)."""
        if mask is None or mask.sum() == 0:
            return
        visualizations.append({
            'type': viz_type,
            'mask': mask,
            'color': color
        })

    def _add_detection_visualization(self, detection, visualizations, det_type):
        """Append a point-marker entry for a detection (needs a centroid)."""
        center = detection.get('center')
        if center is None:
            return
        visualizations.append({
            'type': det_type,
            'center': center,
            'class_name': detection.get('class_name', 'unknown'),
            'class_name_cn': detection.get('class_name_cn', '未知'),
            'conf': detection.get('conf', 0),
        })

View File

@@ -24,6 +24,10 @@ from mediapipe.framework.formats import landmark_pb2
from ultralytics import YOLO from ultralytics import YOLO
from ultralytics.utils.plotting import Colors from ultralytics.utils.plotting import Colors
import bridge_io import bridge_io
# Day 26: 抑制 pygame 社区欢迎信息
import os
os.environ['PYGAME_HIDE_SUPPORT_PROMPT'] = "1"
import pygame # 用于播放本地音频文件 import pygame # 用于播放本地音频文件
from audio_player import play_audio_threadsafe from audio_player import play_audio_threadsafe