代码优化
This commit is contained in:
@@ -272,21 +272,22 @@ class CrossStreetNavigator:
|
||||
logger.info(f"[CROSS_STREET] 斑马线检测间隔: 每{self.CROSSWALK_DETECTION_INTERVAL}帧")
|
||||
|
||||
# 确保模型在 GPU 上
|
||||
# Day 20: TensorRT 引擎不需要 .to()
|
||||
# Day 20/26: TensorRT 引擎不需要 .to(),改用 model_utils 检查
|
||||
if self.seg_model and torch.cuda.is_available():
|
||||
try:
|
||||
# 检查是否是 TensorRT 引擎
|
||||
from model_utils import is_tensorrt_engine
|
||||
model_path = getattr(self.seg_model, 'ckpt_path', '') or ''
|
||||
if not model_path.endswith('.engine'):
|
||||
if hasattr(self.seg_model, 'model') and hasattr(self.seg_model.model, 'to'):
|
||||
self.seg_model.model.to('cuda')
|
||||
elif hasattr(self.seg_model, 'to'):
|
||||
self.seg_model.to('cuda')
|
||||
if is_tensorrt_engine(model_path):
|
||||
pass # TensorRT 引擎无需 .to(),静默跳过
|
||||
elif hasattr(self.seg_model, 'model') and hasattr(self.seg_model.model, 'to'):
|
||||
self.seg_model.model.to('cuda')
|
||||
logger.info("[CROSS_STREET] 模型已移至 GPU")
|
||||
else:
|
||||
logger.info("[CROSS_STREET] TensorRT 引擎已加载,跳过 .to()")
|
||||
except Exception as e:
|
||||
logger.warning(f"[CROSS_STREET] 无法将模型移至 GPU: {e}")
|
||||
elif hasattr(self.seg_model, 'to'):
|
||||
self.seg_model.to('cuda')
|
||||
logger.info("[CROSS_STREET] 模型已移至 GPU")
|
||||
except Exception:
|
||||
pass # Day 26: 静默处理,避免启动日志刷屏
|
||||
|
||||
def reset(self):
|
||||
"""重置状态"""
|
||||
|
||||
Reference in New Issue
Block a user