38 lines
1.1 KiB
Python
38 lines
1.1 KiB
Python
# model_utils.py - Day 20 TensorRT 模型加载工具
|
||
"""
|
||
优先加载 TensorRT .engine 文件,不存在时回退到 .pt
|
||
"""
|
||
|
||
import os
|
||
|
||
|
||
def get_best_model_path(pt_path: str) -> str:
|
||
"""
|
||
根据 .pt 路径自动选择最佳模型文件:
|
||
1. 优先加载同目录下的 .engine 文件(TensorRT 加速)
|
||
2. 如果 .engine 不存在,回退到原 .pt 文件
|
||
|
||
参数:
|
||
pt_path: 原始 .pt 模型路径(如 'model/yolo-seg.pt')
|
||
|
||
返回:
|
||
最佳模型路径(.engine 或 .pt)
|
||
"""
|
||
if not pt_path.endswith('.pt'):
|
||
return pt_path
|
||
|
||
engine_path = pt_path.replace('.pt', '.engine')
|
||
|
||
if os.path.exists(engine_path):
|
||
print(f"[MODEL] 🚀 使用 TensorRT 加速: {engine_path}")
|
||
return engine_path
|
||
else:
|
||
print(f"[MODEL] 使用 PyTorch 模型: {pt_path}")
|
||
return pt_path
|
||
|
||
|
||
def is_tensorrt_engine(model_path: str) -> bool:
|
||
"""检查模型路径是否是 TensorRT 引擎(.engine 文件)"""
|
||
return model_path.endswith('.engine') or model_path.endswith('.trt')
|
||
|