Init: 导入NaviGlassServer源码
This commit is contained in:
185
numba_utils.py
Normal file
185
numba_utils.py
Normal file
@@ -0,0 +1,185 @@
|
||||
# numba_utils.py - Day 20 Numba 多核加速工具
|
||||
"""
|
||||
使用 Numba JIT 编译加速 numpy 密集操作,绕过 Python GIL 实现真正多核并行
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
|
||||
try:
|
||||
from numba import jit, prange
|
||||
NUMBA_AVAILABLE = True
|
||||
except ImportError:
|
||||
NUMBA_AVAILABLE = False
|
||||
print("[NUMBA] Numba 未安装,使用 numpy 回退实现")
|
||||
|
||||
|
||||
if NUMBA_AVAILABLE:
|
||||
@jit(nopython=True, parallel=True, cache=True)
|
||||
def count_mask_pixels_numba(mask: np.ndarray) -> int:
|
||||
"""快速计算 mask 中非零像素数量(多核并行)"""
|
||||
count = 0
|
||||
h, w = mask.shape
|
||||
for i in prange(h):
|
||||
for j in range(w):
|
||||
if mask[i, j] > 0:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@jit(nopython=True, parallel=True, cache=True)
|
||||
def compute_mask_stats_numba(mask: np.ndarray) -> tuple:
|
||||
"""
|
||||
快速计算 mask 的统计信息(多核并行)
|
||||
返回: (area, center_x, center_y, min_x, max_x, min_y, max_y)
|
||||
"""
|
||||
h, w = mask.shape
|
||||
count = 0
|
||||
sum_x = 0.0
|
||||
sum_y = 0.0
|
||||
min_x = w
|
||||
max_x = 0
|
||||
min_y = h
|
||||
max_y = 0
|
||||
|
||||
for i in prange(h):
|
||||
for j in range(w):
|
||||
if mask[i, j] > 0:
|
||||
count += 1
|
||||
sum_x += j
|
||||
sum_y += i
|
||||
if j < min_x:
|
||||
min_x = j
|
||||
if j > max_x:
|
||||
max_x = j
|
||||
if i < min_y:
|
||||
min_y = i
|
||||
if i > max_y:
|
||||
max_y = i
|
||||
|
||||
if count > 0:
|
||||
center_x = sum_x / count
|
||||
center_y = sum_y / count
|
||||
else:
|
||||
center_x = 0.0
|
||||
center_y = 0.0
|
||||
|
||||
return (count, center_x, center_y, min_x, max_x, min_y, max_y)
|
||||
|
||||
@jit(nopython=True, parallel=True, cache=True)
|
||||
def bitwise_and_count_numba(mask1: np.ndarray, mask2: np.ndarray) -> int:
|
||||
"""快速计算两个 mask 的交集像素数量(多核并行)"""
|
||||
h, w = mask1.shape
|
||||
count = 0
|
||||
for i in prange(h):
|
||||
for j in range(w):
|
||||
if mask1[i, j] > 0 and mask2[i, j] > 0:
|
||||
count += 1
|
||||
return count
|
||||
|
||||
@jit(nopython=True, parallel=True, cache=True)
|
||||
def resize_mask_nearest_numba(mask: np.ndarray, new_h: int, new_w: int) -> np.ndarray:
|
||||
"""
|
||||
快速最近邻插值缩放 mask(多核并行)
|
||||
注意:这是简化实现,对于大多数情况足够用
|
||||
"""
|
||||
old_h, old_w = mask.shape
|
||||
result = np.zeros((new_h, new_w), dtype=np.uint8)
|
||||
|
||||
scale_y = old_h / new_h
|
||||
scale_x = old_w / new_w
|
||||
|
||||
for i in prange(new_h):
|
||||
for j in range(new_w):
|
||||
src_y = int(i * scale_y)
|
||||
src_x = int(j * scale_x)
|
||||
if src_y >= old_h:
|
||||
src_y = old_h - 1
|
||||
if src_x >= old_w:
|
||||
src_x = old_w - 1
|
||||
result[i, j] = mask[src_y, src_x]
|
||||
|
||||
return result
|
||||
|
||||
|
||||
# 对外接口:根据 Numba 是否可用选择实现
|
||||
def count_mask_pixels(mask: np.ndarray) -> int:
|
||||
"""计算 mask 中非零像素数量"""
|
||||
if NUMBA_AVAILABLE:
|
||||
return count_mask_pixels_numba(mask)
|
||||
else:
|
||||
return int(np.sum(mask > 0))
|
||||
|
||||
|
||||
def compute_mask_stats(mask: np.ndarray) -> dict:
|
||||
"""
|
||||
计算 mask 的统计信息
|
||||
返回: {'area': int, 'center_x': float, 'center_y': float, 'bbox': (x1, y1, x2, y2)}
|
||||
"""
|
||||
if NUMBA_AVAILABLE:
|
||||
area, cx, cy, min_x, max_x, min_y, max_y = compute_mask_stats_numba(mask)
|
||||
return {
|
||||
'area': int(area),
|
||||
'center_x': float(cx),
|
||||
'center_y': float(cy),
|
||||
'bbox': (int(min_x), int(min_y), int(max_x), int(max_y))
|
||||
}
|
||||
else:
|
||||
# numpy 回退
|
||||
y_coords, x_coords = np.where(mask > 0)
|
||||
if len(y_coords) == 0:
|
||||
return {'area': 0, 'center_x': 0, 'center_y': 0, 'bbox': (0, 0, 0, 0)}
|
||||
return {
|
||||
'area': len(y_coords),
|
||||
'center_x': float(np.mean(x_coords)),
|
||||
'center_y': float(np.mean(y_coords)),
|
||||
'bbox': (int(np.min(x_coords)), int(np.min(y_coords)),
|
||||
int(np.max(x_coords)), int(np.max(y_coords)))
|
||||
}
|
||||
|
||||
|
||||
def bitwise_and_count(mask1: np.ndarray, mask2: np.ndarray) -> int:
|
||||
"""计算两个 mask 的交集像素数量"""
|
||||
if NUMBA_AVAILABLE:
|
||||
return bitwise_and_count_numba(mask1.astype(np.uint8), mask2.astype(np.uint8))
|
||||
else:
|
||||
return int(np.sum(np.bitwise_and(mask1, mask2) > 0))
|
||||
|
||||
|
||||
# 预热 JIT 编译(首次调用时编译,之后使用缓存)
|
||||
def warmup():
|
||||
"""预热 Numba JIT 编译,避免首次调用时的延迟"""
|
||||
if NUMBA_AVAILABLE:
|
||||
dummy = np.zeros((10, 10), dtype=np.uint8)
|
||||
dummy[5, 5] = 255
|
||||
count_mask_pixels_numba(dummy)
|
||||
compute_mask_stats_numba(dummy)
|
||||
bitwise_and_count_numba(dummy, dummy)
|
||||
print("[NUMBA] JIT 编译预热完成,已启用多核加速")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# 测试和性能对比
|
||||
import time
|
||||
|
||||
# 创建测试数据
|
||||
test_mask = np.zeros((480, 640), dtype=np.uint8)
|
||||
test_mask[100:300, 200:400] = 255
|
||||
|
||||
# 测试 numpy 版本
|
||||
start = time.perf_counter()
|
||||
for _ in range(100):
|
||||
np.sum(test_mask > 0)
|
||||
numpy_time = (time.perf_counter() - start) * 1000
|
||||
|
||||
# 测试 numba 版本
|
||||
if NUMBA_AVAILABLE:
|
||||
# 预热
|
||||
count_mask_pixels_numba(test_mask)
|
||||
|
||||
start = time.perf_counter()
|
||||
for _ in range(100):
|
||||
count_mask_pixels_numba(test_mask)
|
||||
numba_time = (time.perf_counter() - start) * 1000
|
||||
|
||||
print(f"numpy: {numpy_time:.2f}ms / 100 次")
|
||||
print(f"numba: {numba_time:.2f}ms / 100 次")
|
||||
print(f"加速比: {numpy_time / numba_time:.1f}x")
|
||||
Reference in New Issue
Block a user