# numba_utils.py - Day 20: Numba multi-core acceleration utilities
"""
Speed up numpy-heavy operations with Numba JIT compilation, sidestepping the
Python GIL to achieve true multi-core parallelism.
"""
|
||
|
||
import numpy as np
|
||
|
||
# Numba is an optional dependency: record whether it could be imported so the
# rest of the module can choose between the JIT kernels and plain numpy.
try:
    from numba import jit, prange
except ImportError:
    NUMBA_AVAILABLE = False
    print("[NUMBA] Numba 未安装,使用 numpy 回退实现")
else:
    NUMBA_AVAILABLE = True
|
||
|
||
|
||
if NUMBA_AVAILABLE:

    @jit(nopython=True, parallel=True, cache=True)
    def count_mask_pixels_numba(mask: np.ndarray) -> int:
        """Count the pixels of a 2-D mask whose value is greater than zero.

        The rows are iterated with ``prange``; ``count`` is only ever updated
        with ``+=``, which Numba merges across threads as a sum reduction.
        """
        h, w = mask.shape
        count = 0
        for i in prange(h):
            for j in range(w):
                if mask[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def compute_mask_stats_numba(mask: np.ndarray) -> tuple:
        """Compute area, centroid and bounding box of the pixels > 0 in a 2-D mask.

        Returns:
            ``(area, center_x, center_y, min_x, max_x, min_y, max_y)`` where
            x is the column index and y is the row index. For a mask without
            any pixel > 0 the area is 0, both centre coordinates are 0.0 and
            the box fields keep their initial sentinels
            (``min_x == w``, ``max_x == 0``, ``min_y == h``, ``max_y == 0``).
        """
        h, w = mask.shape

        # Per-row partial results. A conditional assignment such as
        # ``if j < min_x: min_x = j`` inside ``prange`` is not a reduction
        # that Numba can merge across threads, so sharing scalar min/max
        # variables between iterations is a data race. Instead, every
        # iteration writes only to its own slot ``i`` of these arrays.
        row_count = np.zeros(h, dtype=np.int64)
        row_sum_x = np.zeros(h, dtype=np.float64)
        row_first_x = np.zeros(h, dtype=np.int64)
        row_last_x = np.zeros(h, dtype=np.int64)

        for i in prange(h):
            hits = 0
            col_sum = 0.0
            first_col = w
            last_col = 0
            for j in range(w):
                if mask[i, j] > 0:
                    # j only increases, so the first hit is the row minimum
                    # and the most recent hit is the row maximum.
                    if hits == 0:
                        first_col = j
                    last_col = j
                    hits += 1
                    col_sum += j
            row_count[i] = hits
            row_sum_x[i] = col_sum
            row_first_x[i] = first_col
            row_last_x[i] = last_col

        # Sequential merge of the per-row results (h iterations only).
        count = 0
        sum_x = 0.0
        sum_y = 0.0
        min_x = w
        max_x = 0
        min_y = h
        max_y = 0
        for r in range(h):
            k = row_count[r]
            if k > 0:
                count += k
                sum_x += row_sum_x[r]
                sum_y += r * k
                if row_first_x[r] < min_x:
                    min_x = row_first_x[r]
                if row_last_x[r] > max_x:
                    max_x = row_last_x[r]
                if r < min_y:
                    min_y = r
                # r only increases, so the last non-empty row is the maximum.
                max_y = r

        if count > 0:
            center_x = sum_x / count
            center_y = sum_y / count
        else:
            center_x = 0.0
            center_y = 0.0

        return (count, center_x, center_y, min_x, max_x, min_y, max_y)

    @jit(nopython=True, parallel=True, cache=True)
    def bitwise_and_count_numba(mask1: np.ndarray, mask2: np.ndarray) -> int:
        """Count positions where both 2-D masks hold a value greater than zero.

        The loop bounds come from ``mask1.shape`` and ``mask2`` is read at the
        same indices, so both masks are assumed to have the same shape; no
        shape check is performed here.
        """
        h, w = mask1.shape
        count = 0
        for i in prange(h):
            for j in range(w):
                if mask1[i, j] > 0 and mask2[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def resize_mask_nearest_numba(mask: np.ndarray, new_h: int, new_w: int) -> np.ndarray:
        """Resize a 2-D mask to ``(new_h, new_w)`` by nearest-neighbour sampling.

        Simplified implementation: output pixel ``(i, j)`` copies the source
        pixel at ``(int(i * old_h / new_h), int(j * old_w / new_w))``, clamped
        to the last row/column. The result is always a ``uint8`` array. If
        either the source or the target has a zero-length dimension, the
        (possibly empty) all-zero result is returned.
        """
        old_h, old_w = mask.shape
        result = np.zeros((new_h, new_w), dtype=np.uint8)

        # Degenerate sizes: a zero target dimension would divide by zero when
        # computing the scale, and an empty source would be indexed at -1.
        if new_h == 0 or new_w == 0 or old_h == 0 or old_w == 0:
            return result

        scale_y = old_h / new_h
        scale_x = old_w / new_w

        # Each iteration writes a distinct output row, so rows can be filled
        # independently.
        for i in prange(new_h):
            for j in range(new_w):
                src_y = int(i * scale_y)
                src_x = int(j * scale_x)
                # Clamp in case floating-point rounding lands on the size.
                if src_y >= old_h:
                    src_y = old_h - 1
                if src_x >= old_w:
                    src_x = old_w - 1
                result[i, j] = mask[src_y, src_x]

        return result
|
||
|
||
|
||
# Public API: pick the Numba or the numpy implementation depending on availability
|
||
def count_mask_pixels(mask: np.ndarray) -> int:
    """Return how many pixels of ``mask`` are greater than zero.

    Uses the Numba kernel when Numba was imported successfully and a plain
    numpy count otherwise.
    """
    if not NUMBA_AVAILABLE:
        # numpy fallback
        return int(np.count_nonzero(mask > 0))
    return count_mask_pixels_numba(mask)
|
||
|
||
|
||
def compute_mask_stats(mask: np.ndarray) -> dict:
    """Compute area, centroid and bounding box of the pixels > 0 in a 2-D mask.

    Returns:
        ``{'area': int, 'center_x': float, 'center_y': float,
        'bbox': (x1, y1, x2, y2)}`` where ``bbox`` holds the inclusive
        minimum/maximum column (x) and row (y) indices of the pixels > 0.
        For a mask without any pixel > 0 both backends return area 0,
        centre (0.0, 0.0) and bbox (0, 0, 0, 0).
    """
    empty = {'area': 0, 'center_x': 0.0, 'center_y': 0.0, 'bbox': (0, 0, 0, 0)}

    if NUMBA_AVAILABLE:
        area, cx, cy, min_x, max_x, min_y, max_y = compute_mask_stats_numba(mask)
        if area == 0:
            # The kernel reports its initial sentinels (min > max) for an
            # empty mask; normalise to the same result as the numpy path.
            return empty
        return {
            'area': int(area),
            'center_x': float(cx),
            'center_y': float(cy),
            'bbox': (int(min_x), int(min_y), int(max_x), int(max_y)),
        }

    # numpy fallback
    y_coords, x_coords = np.where(mask > 0)
    if len(y_coords) == 0:
        return empty
    return {
        'area': len(y_coords),
        'center_x': float(np.mean(x_coords)),
        'center_y': float(np.mean(y_coords)),
        'bbox': (int(np.min(x_coords)), int(np.min(y_coords)),
                 int(np.max(x_coords)), int(np.max(y_coords))),
    }
|
||
|
||
|
||
def bitwise_and_count(mask1: np.ndarray, mask2: np.ndarray) -> int:
    """Count the pixels that are set (> 0 after casting to uint8) in both masks.

    Both backends use the same definition of "intersection": a position is
    counted when both masks are non-zero there, regardless of the actual
    non-zero values.

    Raises:
        ValueError: if the two masks do not have the same shape.
    """
    # Cast once up front; ascontiguousarray returns the input unchanged when
    # it already is a C-contiguous uint8 array, so the common case is copy-free.
    a = np.ascontiguousarray(mask1, dtype=np.uint8)
    b = np.ascontiguousarray(mask2, dtype=np.uint8)

    # The Numba kernel indexes mask2 with mask1's bounds and does no bounds
    # checking, so mismatched shapes must be rejected before dispatching.
    if a.shape != b.shape:
        raise ValueError(
            f"mask shapes differ: {a.shape} vs {b.shape}"
        )

    if NUMBA_AVAILABLE:
        return bitwise_and_count_numba(a, b)
    # numpy fallback, same predicate as the kernel
    return int(np.count_nonzero((a > 0) & (b > 0)))
|
||
|
||
|
||
# Warm up JIT compilation (compiled on the first call, served from cache afterwards)
|
||
def warmup():
    """Trigger JIT compilation of every Numba kernel ahead of first real use.

    Each kernel is called once on a small 10x10 uint8 mask so that the
    compilation cost is paid here rather than on the first real call.
    Does nothing when Numba is not available.
    """
    if not NUMBA_AVAILABLE:
        return

    dummy = np.zeros((10, 10), dtype=np.uint8)
    dummy[5, 5] = 255
    count_mask_pixels_numba(dummy)
    compute_mask_stats_numba(dummy)
    bitwise_and_count_numba(dummy, dummy)
    # Also compile the resize kernel; otherwise its first call would still
    # pay the compilation latency after "warmup complete" was reported.
    resize_mask_nearest_numba(dummy, 5, 5)
    print("[NUMBA] JIT 编译预热完成,已启用多核加速")
|
||
|
||
|
||
if __name__ == "__main__":
    # Self-test: time numpy against the Numba kernel on the same mask.
    from time import perf_counter

    # Synthetic 480x640 mask containing one filled rectangle.
    sample = np.zeros((480, 640), dtype=np.uint8)
    sample[100:300, 200:400] = 255

    # numpy baseline: 100 counts.
    t0 = perf_counter()
    for _ in range(100):
        np.sum(sample > 0)
    numpy_time = (perf_counter() - t0) * 1000

    # Numba kernel: 100 counts, only measured when Numba is installed.
    if NUMBA_AVAILABLE:
        # One untimed call first so the timed loop runs already-compiled code.
        count_mask_pixels_numba(sample)

        t0 = perf_counter()
        for _ in range(100):
            count_mask_pixels_numba(sample)
        numba_time = (perf_counter() - t0) * 1000

        print(f"numpy: {numpy_time:.2f}ms / 100 次")
        print(f"numba: {numba_time:.2f}ms / 100 次")
        print(f"加速比: {numpy_time / numba_time:.1f}x")
|