# model_utils.py - Day 20 TensorRT 模型加载工具 """ 优先加载 TensorRT .engine 文件,不存在时回退到 .pt """ import os def get_best_model_path(pt_path: str) -> str: """ 根据 .pt 路径自动选择最佳模型文件: 1. 优先加载同目录下的 .engine 文件(TensorRT 加速) 2. 如果 .engine 不存在,回退到原 .pt 文件 参数: pt_path: 原始 .pt 模型路径(如 'model/yolo-seg.pt') 返回: 最佳模型路径(.engine 或 .pt) """ if not pt_path.endswith('.pt'): return pt_path engine_path = pt_path.replace('.pt', '.engine') if os.path.exists(engine_path): print(f"[MODEL] 🚀 使用 TensorRT 加速: {engine_path}") return engine_path else: print(f"[MODEL] 使用 PyTorch 模型: {pt_path}") return pt_path def is_tensorrt_engine(model_path: str) -> bool: """检查模型路径是否是 TensorRT 引擎(.engine 文件)""" return model_path.endswith('.engine') or model_path.endswith('.trt')