Init: import NaviGlassServer source code

159 models.py Normal file

@@ -0,0 +1,159 @@
# app/models.py
import os
import logging
import torch
from threading import Semaphore
from contextlib import contextmanager
from typing import Optional

from app.cloud.obstacle_detector_client import ObstacleDetectorClient

# ==========================================================
# 0. Import all required model wrapper classes (Clients) and Ultralytics base classes
# ==========================================================
# Wrapper classes used by the street-crossing workflow
from app.cloud.crosswalk_detector_client import CrosswalkDetector
from app.cloud.coco_perception_client import COCOClient

# Day 20: TensorRT model-loading utilities
from model_utils import get_best_model_path, is_tensorrt_engine

# Ultralytics classes used directly by the blind-path workflow
from ultralytics import YOLO, YOLOE

logger = logging.getLogger(__name__)

# ==========================================================
# 1. Global device and concurrency control (managed in one place)
# ==========================================================
DEVICE = os.getenv("AIGLASS_DEVICE", "cuda:0")
if DEVICE.startswith("cuda") and not torch.cuda.is_available():
    logger.warning(f"AIGLASS_DEVICE={DEVICE} but CUDA was not detected; falling back to CPU")
    DEVICE = "cpu"
IS_CUDA = DEVICE.startswith("cuda")

# AMP (automatic mixed precision) configuration
AMP_POLICY = os.getenv("AIGLASS_AMP", "bf16").lower()
# The whole policy expression is guarded by IS_CUDA, so AMP_DTYPE is always None on CPU.
AMP_DTYPE = (
    torch.bfloat16 if AMP_POLICY == "bf16"
    else torch.float16 if AMP_POLICY == "fp16"
    else None
) if IS_CUDA else None
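
# Policy resolution examples (illustrative): AIGLASS_AMP=bf16 with CUDA gives
# torch.bfloat16; fp16 gives torch.float16; "off", any other value, or a
# CPU-only host gives None, in which case gpu_infer_slot() below skips autocast.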

# 🔥 Core: the single global GPU concurrency semaphore, shared by all workflows
GPU_SLOTS = int(os.getenv("AIGLASS_GPU_SLOTS", "2"))
gpu_semaphore = Semaphore(GPU_SLOTS)


# Unified inference context manager; every workflow should use it when calling a model
@contextmanager
def gpu_infer_slot():
    """
    Centralizes GPU concurrency throttling + torch.inference_mode() + AMP autocast.
    """
    with gpu_semaphore:
        if IS_CUDA and AMP_POLICY != "off" and AMP_DTYPE is not None:
            with torch.inference_mode(), torch.amp.autocast('cuda', dtype=AMP_DTYPE):
                yield
        else:
            with torch.inference_mode():
                yield
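

# Hedged usage sketch (hypothetical helper, not part of the original module):
# workflows are expected to route every model call through gpu_infer_slot(), so
# at most GPU_SLOTS inferences hit the GPU at once, each under
# torch.inference_mode() and, on CUDA, the configured autocast dtype.
def _example_infer(frame):
    # Assumes init_all_models() has already populated blindpath_seg_model and
    # that `frame` is any image source accepted by Ultralytics predict().
    with gpu_infer_slot():
        return blindpath_seg_model.predict(frame, verbose=False)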


# cuDNN acceleration tweak
try:
    if IS_CUDA:
        torch.backends.cudnn.benchmark = True
except Exception:
    pass

# ==========================================================
# 2. Global model instances (all initialized to None)
# ==========================================================

# --- Street-crossing workflow models (wrapped in Client classes) ---
crosswalk_detector_client: Optional[CrosswalkDetector] = None
coco_client: Optional[COCOClient] = None
# ObstacleDetectorClient serves as the general obstacle detector for every scene
obstacle_detector_client: Optional[ObstacleDetectorClient] = None

# --- Blind-path workflow models (Ultralytics classes used directly) ---
# These mainly do segmentation and path planning, unlike the detection logic of the crossing scene
blindpath_seg_model: Optional[YOLO] = None
# Obstacle detection reuses obstacle_detector_client, but YOLOE's text embeddings are kept separately
blindpath_whitelist_embeddings = None

# Global load-state flag
models_are_loaded = False

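
# Hedged sketch (helper not present in the original file): every global above
# stays None until init_all_models() runs, so task code can fail fast with a
# guard like this instead of hitting AttributeError on a None client.
def _assert_models_loaded():
    if not models_are_loaded:
        raise RuntimeError("Models not loaded; init_all_models() must run at worker startup")
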

# ==========================================================
# 3. Unified model-loading function (called by celery.py at startup)
# ==========================================================
def init_all_models():
    """
    Called once when a Celery worker process starts.
    Loads every model the workflows need into the globals above.
    """
    global models_are_loaded
    if models_are_loaded:
        return

    logger.info(f"========= 🚀 Starting global model preload (target device: {DEVICE}) =========")

    try:
        # --- [1/4] Load the general obstacle detector (ObstacleDetectorClient) ---
        global obstacle_detector_client
        logger.info("[1/4] Loading the general obstacle detection model (ObstacleDetectorClient)...")
        # Day 20: prefer a TensorRT engine when one is available
        obs_model_path = get_best_model_path('model/yoloe-11l-seg.pt')
        obstacle_detector_client = ObstacleDetectorClient(model_path=obs_model_path)

        # Day 20: TensorRT engines must not be moved with .to()
        if not is_tensorrt_engine(obs_model_path):
            if hasattr(obstacle_detector_client, 'model') and obstacle_detector_client.model is not None:
                obstacle_detector_client.model.to(DEVICE)

        logger.info("...general obstacle detection model loaded.")

        # --- [2/4] Load the street-crossing models (Clients) ---
        global crosswalk_detector_client, coco_client
        logger.info("[2/4] Loading the crosswalk segmentation model (CrosswalkDetector)...")
        # Day 20: prefer a TensorRT engine when one is available
        crosswalk_model_path = get_best_model_path('model/yolo-seg.pt')
        crosswalk_detector_client = CrosswalkDetector(model_path=crosswalk_model_path)
        # Day 20: TensorRT engines must not be moved with .to()
        if not is_tensorrt_engine(crosswalk_model_path):
            if hasattr(crosswalk_detector_client, 'model') and crosswalk_detector_client.model is not None:
                crosswalk_detector_client.model.to(DEVICE)
        logger.info("...crosswalk segmentation model loaded.")

logger.info("[3/4] 正在加载通用感知模型 (COCOClient)...")
|
||||
coco_client = COCOClient(model_path='model/yolov8l-world.pt')
|
||||
# 将其内部的YOLO模型移动到指定设备
|
||||
if hasattr(coco_client, 'model') and coco_client.model is not None:
|
||||
coco_client.model.to(DEVICE)
|
||||
logger.info("...通用感知模型加载成功。")
|
||||
|
||||
        # --- [4/4] Load the blind-path models ---
        global blindpath_seg_model, blindpath_whitelist_embeddings
        logger.info("[4/4] Loading the blind-path segmentation model (YOLO)...")
        # Day 20: prefer a TensorRT engine when one is available
        blindpath_model_path = get_best_model_path('model/yolo-seg.pt')
        blindpath_seg_model = YOLO(blindpath_model_path)
        # Day 20: TensorRT engines support neither .to() nor .fuse()
        if not is_tensorrt_engine(blindpath_model_path):
            blindpath_seg_model.to(DEVICE)
            blindpath_seg_model.fuse()
        logger.info("...blind-path segmentation model loaded.")

        # Keep a reference to the YOLOE text embeddings the blind-path workflow needs
        if obstacle_detector_client:
            blindpath_whitelist_embeddings = obstacle_detector_client.whitelist_embeddings
            logger.info("...linked the obstacle model's embeddings for the blind-path workflow.")

        # All models loaded
        models_are_loaded = True
        logger.info("========= ✅ All models preloaded. Worker is ready! =========")

    except Exception as e:
        logger.error(f"Fatal error during model preload: {e}", exc_info=True)
        # Re-raise so the Celery worker fails to start. That is the right behavior:
        # a worker without models is useless, and surfacing the problem early is better.
        raise
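
# Startup wiring: a minimal sketch of what app/celery.py presumably does,
# assuming the standard Celery worker_process_init signal (the repo's actual
# hook may differ):
#
#     from celery.signals import worker_process_init
#     from app import models
#
#     @worker_process_init.connect
#     def _preload_models(**kwargs):
#         models.init_all_models()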