Files
NaviGlassServer/numba_utils.py
2025-12-31 15:42:30 +08:00

186 lines
6.0 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# numba_utils.py - Day 20: Numba multi-core acceleration helpers
"""Numba JIT helpers that speed up NumPy-heavy mask operations.

When Numba can be imported, the kernels in this module are compiled with
``parallel=True`` so their ``prange`` loops are spread over several native
threads. When it cannot, the public wrappers fall back to plain NumPy.
"""
import numpy as np

try:
    from numba import jit, prange
except ImportError:
    # Numba is optional: record its absence and keep going with NumPy.
    NUMBA_AVAILABLE = False
    print("[NUMBA] Numba 未安装,使用 numpy 回退实现")
else:
    NUMBA_AVAILABLE = True
if NUMBA_AVAILABLE:
    @jit(nopython=True, parallel=True, cache=True)
    def count_mask_pixels_numba(mask: np.ndarray) -> int:
        """Count the pixels of a 2-D ``mask`` whose value is greater than zero.

        Rows are distributed across threads with ``prange``; ``count += 1`` is
        a reduction Numba recognises, so per-thread partial counts are combined
        safely.
        """
        count = 0
        h, w = mask.shape
        for i in prange(h):
            for j in range(w):
                if mask[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def compute_mask_stats_numba(mask: np.ndarray) -> tuple:
        """Compute area, centroid and bounding box of the positive pixels of ``mask``.

        Returns ``(area, center_x, center_y, min_x, max_x, min_y, max_y)``.
        Bounding-box values are inclusive pixel indices. For a mask with no
        positive pixels every field is zero.
        """
        h, w = mask.shape

        # Per-row partial results. Each prange iteration writes only its own
        # slot, so there is no data race. Updating shared min/max variables
        # with ``if j < min_x: min_x = j`` inside prange is NOT a reduction
        # Numba recognises, and would make the bbox nondeterministic.
        row_count = np.empty(h, dtype=np.int64)
        row_sum_x = np.empty(h, dtype=np.float64)
        row_min_x = np.empty(h, dtype=np.int64)
        row_max_x = np.empty(h, dtype=np.int64)
        for i in prange(h):
            c = 0
            sx = 0.0
            lo = 0
            hi = 0
            for j in range(w):
                if mask[i, j] > 0:
                    c += 1
                    sx += j
                    if c == 1:
                        lo = j  # j increases, so the first hit is the row minimum
                    hi = j      # ... and the last hit is the row maximum
            row_count[i] = c
            row_sum_x[i] = sx
            row_min_x[i] = lo
            row_max_x[i] = hi

        # Serial combine over the h rows (cheap next to the h*w scan above).
        count = 0
        sum_x = 0.0
        sum_y = 0.0
        min_x = w
        max_x = 0
        min_y = h
        max_y = 0
        for r in range(h):
            n_row = row_count[r]
            if n_row == 0:
                continue
            count += n_row
            sum_x += row_sum_x[r]
            sum_y += r * n_row
            if row_min_x[r] < min_x:
                min_x = row_min_x[r]
            if row_max_x[r] > max_x:
                max_x = row_max_x[r]
            if r < min_y:
                min_y = r
            max_y = r

        if count > 0:
            center_x = sum_x / count
            center_y = sum_y / count
        else:
            # Empty mask: report an all-zero bbox instead of the (w, h) seeds.
            center_x = 0.0
            center_y = 0.0
            min_x = 0
            min_y = 0
        return (count, center_x, center_y, min_x, max_x, min_y, max_y)

    @jit(nopython=True, parallel=True, cache=True)
    def bitwise_and_count_numba(mask1: np.ndarray, mask2: np.ndarray) -> int:
        """Count pixels that are greater than zero in both ``mask1`` and ``mask2``.

        Raises ValueError when the masks differ in shape. Numba does not
        bounds-check array indexing unless ``boundscheck`` is enabled, so
        without this check a smaller ``mask2`` would be read out of bounds.
        """
        h, w = mask1.shape
        if mask2.shape[0] != h or mask2.shape[1] != w:
            raise ValueError("mask1 and mask2 must have the same shape")
        count = 0
        for i in prange(h):
            for j in range(w):
                if mask1[i, j] > 0 and mask2[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def resize_mask_nearest_numba(mask: np.ndarray, new_h: int, new_w: int) -> np.ndarray:
        """Nearest-neighbour resize of a 2-D ``mask`` to ``(new_h, new_w)``.

        Destination pixel (i, j) copies source pixel
        ``(int(i * old_h / new_h), int(j * old_w / new_w))``, clamped to the
        last row/column. This is a simplified implementation, sufficient for
        most cases. The result is always uint8; other input dtypes are cast on
        assignment.
        """
        old_h, old_w = mask.shape
        result = np.zeros((new_h, new_w), dtype=np.uint8)
        scale_y = old_h / new_h
        scale_x = old_w / new_w
        for i in prange(new_h):
            for j in range(new_w):
                src_y = int(i * scale_y)
                src_x = int(j * scale_x)
                # Guard against float rounding pushing the index past the edge.
                if src_y >= old_h:
                    src_y = old_h - 1
                if src_x >= old_w:
                    src_x = old_w - 1
                result[i, j] = mask[src_y, src_x]
        return result
# Public API: each wrapper picks the Numba kernel when available, else NumPy.
def count_mask_pixels(mask: np.ndarray) -> int:
    """Return how many pixels of ``mask`` have a value greater than zero.

    Uses the Numba kernel when Numba imported successfully; otherwise counts
    with a vectorised NumPy expression.
    """
    if not NUMBA_AVAILABLE:
        return int(np.count_nonzero(mask > 0))
    return count_mask_pixels_numba(mask)
def compute_mask_stats(mask: np.ndarray) -> dict:
    """Return area, centroid and bounding box of the positive pixels of ``mask``.

    Returns ``{'area': int, 'center_x': float, 'center_y': float,
    'bbox': (x1, y1, x2, y2)}``, where the bbox corners are inclusive pixel
    indices. Both the Numba and NumPy backends return ``area=0``, a 0.0
    centroid and ``bbox=(0, 0, 0, 0)`` for a mask with no positive pixels.
    """
    if NUMBA_AVAILABLE:
        area, cx, cy, min_x, max_x, min_y, max_y = compute_mask_stats_numba(mask)
        if area == 0:
            # Normalise here so the result never depends on whatever seed
            # min/max values the kernel leaves behind for an empty mask.
            return {'area': 0, 'center_x': 0.0, 'center_y': 0.0, 'bbox': (0, 0, 0, 0)}
        return {
            'area': int(area),
            'center_x': float(cx),
            'center_y': float(cy),
            'bbox': (int(min_x), int(min_y), int(max_x), int(max_y)),
        }

    # NumPy fallback.
    y_coords, x_coords = np.nonzero(mask > 0)
    if y_coords.size == 0:
        return {'area': 0, 'center_x': 0.0, 'center_y': 0.0, 'bbox': (0, 0, 0, 0)}
    return {
        'area': int(y_coords.size),
        'center_x': float(np.mean(x_coords)),
        'center_y': float(np.mean(y_coords)),
        'bbox': (int(np.min(x_coords)), int(np.min(y_coords)),
                 int(np.max(x_coords)), int(np.max(y_coords))),
    }
def bitwise_and_count(mask1: np.ndarray, mask2: np.ndarray) -> int:
    """Count pixels that are greater than zero in both ``mask1`` and ``mask2``.

    This is a logical intersection: any positive value counts as "set".
    A plain ``np.bitwise_and`` would miss overlaps between e.g. label values
    1 and 2.

    Raises:
        ValueError: if the two masks do not have the same shape.
    """
    mask1 = np.asarray(mask1)
    mask2 = np.asarray(mask2)
    if mask1.shape != mask2.shape:
        raise ValueError(f"mask shapes differ: {mask1.shape} vs {mask2.shape}")
    if NUMBA_AVAILABLE:
        # Binarise before narrowing to uint8: astype(np.uint8) on the raw values
        # would wrap 256 -> 0 and turn negatives into 255.
        m1 = (mask1 > 0).astype(np.uint8)
        m2 = (mask2 > 0).astype(np.uint8)
        return int(bitwise_and_count_numba(m1, m2))
    return int(np.count_nonzero((mask1 > 0) & (mask2 > 0)))
# Warm up JIT compilation (compiled on first call, cached afterwards).
def warmup():
    """Call every Numba kernel once so later real calls skip compilation latency.

    Does nothing when Numba is unavailable. Only the uint8 2-D specialisation
    is exercised here; calls with other dtypes get their own specialisation
    on first use.
    """
    if NUMBA_AVAILABLE:
        dummy = np.zeros((10, 10), dtype=np.uint8)
        dummy[5, 5] = 255
        count_mask_pixels_numba(dummy)
        compute_mask_stats_numba(dummy)
        bitwise_and_count_numba(dummy, dummy)
        # Previously skipped, leaving the resize kernel to compile on first use.
        resize_mask_nearest_numba(dummy, 5, 5)
        print("[NUMBA] JIT 编译预热完成,已启用多核加速")
if __name__ == "__main__":
    # Quick self-check and timing comparison of the NumPy and Numba counts.
    import time

    # Test data: 480x640 mask with a filled 200x200 square.
    test_mask = np.zeros((480, 640), dtype=np.uint8)
    test_mask[100:300, 200:400] = 255

    # NumPy timing (always reported, even when Numba is missing).
    start = time.perf_counter()
    for _ in range(100):
        np.sum(test_mask > 0)
    numpy_time = (time.perf_counter() - start) * 1000
    print(f"numpy: {numpy_time:.2f}ms / 100 次")

    if NUMBA_AVAILABLE:
        # First call compiles (or loads the cache); keep it out of the timed
        # loop, and use it to check both implementations agree.
        expected = int(np.sum(test_mask > 0))
        got = count_mask_pixels_numba(test_mask)
        if got != expected:
            raise SystemExit(f"count mismatch: numba={got} numpy={expected}")

        start = time.perf_counter()
        for _ in range(100):
            count_mask_pixels_numba(test_mask)
        numba_time = (time.perf_counter() - start) * 1000
        print(f"numba: {numba_time:.2f}ms / 100 次")
        print(f"加速比: {numpy_time / numba_time:.1f}x")