Files
NaviGlassServer/model_utils.py
2025-12-31 15:42:30 +08:00

38 lines
1.1 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# model_utils.py - Day 20 TensorRT 模型加载工具
"""
优先加载 TensorRT .engine 文件,不存在时回退到 .pt
"""
import os
def get_best_model_path(pt_path: str) -> str:
"""
根据 .pt 路径自动选择最佳模型文件:
1. 优先加载同目录下的 .engine 文件TensorRT 加速)
2. 如果 .engine 不存在,回退到原 .pt 文件
参数:
pt_path: 原始 .pt 模型路径(如 'model/yolo-seg.pt'
返回:
最佳模型路径(.engine 或 .pt
"""
if not pt_path.endswith('.pt'):
return pt_path
engine_path = pt_path.replace('.pt', '.engine')
if os.path.exists(engine_path):
print(f"[MODEL] 🚀 使用 TensorRT 加速: {engine_path}")
return engine_path
else:
print(f"[MODEL] 使用 PyTorch 模型: {pt_path}")
return pt_path
def is_tensorrt_engine(model_path: str) -> bool:
"""检查模型路径是否是 TensorRT 引擎(.engine 文件)"""
return model_path.endswith('.engine') or model_path.endswith('.trt')