Files
NaviGlassServer/yoloe_backend.py
2026-01-06 17:15:06 +08:00

100 lines
3.8 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# yoloe_backend.py
# -*- coding: utf-8 -*-
from typing import List, Dict, Any, Optional, Tuple, Union
import os
import cv2
import numpy as np
# Day 20: TensorRT 模型加载工具
from model_utils import get_best_model_path
# 兼容 YOLOE / YOLO
try:
from ultralytics import YOLOE as _MODEL
except Exception:
from ultralytics import YOLO as _MODEL
# Day 20: 优先使用 TensorRT 引擎
# Day 28: 使用基于当前文件的绝对路径
_DEFAULT_YOLOE_PATH = os.path.join(os.path.dirname(os.path.abspath(__file__)), "model", "yoloe-11l-seg.pt")
DEFAULT_MODEL_PATH = get_best_model_path(os.getenv("YOLOE_MODEL_PATH", _DEFAULT_YOLOE_PATH))
TRACKER_CFG = os.getenv("YOLO_TRACKER_YAML", "bytetrack.yaml")
class YoloEBackend:
def __init__(self, model_path: Optional[str] = None, device: Optional[Union[str, int]] = None):
actual_path = model_path or DEFAULT_MODEL_PATH
self.model = _MODEL(actual_path)
# Day 20: TensorRT 引擎不需要 .to()
from model_utils import is_tensorrt_engine
if not is_tensorrt_engine(actual_path):
self.model.to("cuda")
self.device = device
def set_text_classes(self, names: List[str]):
# YOLOE 文本提示:与你模板一致
# Day 20: TensorRT 引擎不支持 get_text_pe
if hasattr(self.model, 'get_text_pe'):
self.model.set_classes(names, self.model.get_text_pe(names))
else:
print(f"[YOLOE] TensorRT 模式:跳过 set_text_classes")
def segment(self,
frame_bgr: np.ndarray,
conf: float = 0.20,
iou: float = 0.45,
imgsz: int = None, # Day 20: 改为 None从环境变量读取
persist: bool = True
) -> Dict[str, Any]:
"""
返回:
dict{
'masks': List[np.uint8(H,W)], # 0/1 mask
'boxes': List[Tuple[x1,y1,x2,y2]],
'cls_ids': List[int],
'names': List[str],
'ids': List[Optional[int]]
}
"""
# Day 20: 使用环境变量,保持与 TensorRT 导出尺寸一致
if imgsz is None:
imgsz = int(os.getenv("AIGLASS_YOLO_IMGSZ", "480"))
r = self.model.track(
frame_bgr,
conf=conf, iou=iou, imgsz=imgsz,
persist=persist, tracker=TRACKER_CFG, verbose=False
)[0]
out = {"masks": [], "boxes": [], "cls_ids": [], "names": [], "ids": []}
masks_obj = getattr(r, "masks", None)
boxes_obj = getattr(r, "boxes", None)
if masks_obj is None or getattr(masks_obj, "data", None) is None:
return out
mask_arr = masks_obj.data.cpu().numpy() # [N, h, w], float(0..1)
H, W = frame_bgr.shape[:2]
id2name = r.names if hasattr(r, "names") else {}
N = mask_arr.shape[0]
if boxes_obj is not None:
xyxy = boxes_obj.xyxy.cpu().numpy()
cls = boxes_obj.cls.cpu().tolist()
tids = boxes_obj.id.int().cpu().tolist() if boxes_obj.id is not None else [None]*N
else:
xyxy = [None]*N
cls = [0]*N
tids = [None]*N
for i in range(N):
bin_mask = (mask_arr[i] > 0.5).astype(np.uint8)
if bin_mask.shape[:2] != (H, W):
bin_mask = cv2.resize(bin_mask, (W, H), interpolation=cv2.INTER_NEAREST)
out["masks"].append(bin_mask)
out["boxes"].append(tuple(xyxy[i]) if xyxy[i] is not None else None)
cid = int(cls[i]) if cls is not None else 0
out["cls_ids"].append(cid)
out["names"].append(id2name.get(cid, str(cid)))
out["ids"].append(int(tids[i]) if tids[i] is not None else None)
return out