# numba_utils.py - Day 20: Numba multi-core acceleration helpers
"""
Numba JIT-compiled helpers for numpy-heavy mask operations.

The kernels are compiled in nopython mode with ``parallel=True`` so that the
outer row loop runs on several cores without holding the Python GIL.  When
Numba cannot be imported, the public wrappers fall back to plain numpy.
"""

import numpy as np

try:
    from numba import jit, prange
    NUMBA_AVAILABLE = True
except ImportError:
    NUMBA_AVAILABLE = False
    print("[NUMBA] Numba 未安装,使用 numpy 回退实现")


if NUMBA_AVAILABLE:

    @jit(nopython=True, parallel=True, cache=True)
    def count_mask_pixels_numba(mask: np.ndarray) -> int:
        """Count the elements of a 2-D ``mask`` that are greater than zero."""
        h, w = mask.shape
        count = 0
        for i in prange(h):
            for j in range(w):
                if mask[i, j] > 0:
                    # ``+=`` on a scalar is a reduction Numba parallelises safely.
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def compute_mask_stats_numba(mask: np.ndarray) -> tuple:
        """
        Compute area, centroid and bounding box of a 2-D ``mask``.

        Returns:
            (area, center_x, center_y, min_x, max_x, min_y, max_y)

            For a mask without any positive element the tuple is
            ``(0, 0.0, 0.0, w, 0, h, 0)`` where ``h, w = mask.shape``.
        """
        h, w = mask.shape

        # One slot per row.  Every prange iteration writes only to its own
        # index, so there is no shared scalar that is conditionally assigned
        # from several threads (the ``if j < min_x: min_x = j`` pattern is not
        # a reduction Numba recognises and is therefore racy in a prange loop).
        row_count = np.empty(h, dtype=np.int64)
        row_sum_x = np.empty(h, dtype=np.int64)
        row_min_x = np.empty(h, dtype=np.int64)
        row_max_x = np.empty(h, dtype=np.int64)

        for i in prange(h):
            n = 0
            sx = 0
            lo = w
            hi = 0
            for j in range(w):
                if mask[i, j] > 0:
                    n += 1
                    sx += j
                    if j < lo:
                        lo = j
                    # j only increases, so the latest hit is the row maximum.
                    hi = j
            row_count[i] = n
            row_sum_x[i] = sx
            row_min_x[i] = lo
            row_max_x[i] = hi

        # Serial combination of the per-row partial results.
        count = 0
        sum_x = 0.0
        sum_y = 0.0
        min_x = w
        max_x = 0
        min_y = h
        max_y = 0
        for r in range(h):
            hits = row_count[r]
            if hits > 0:
                count += hits
                sum_x += row_sum_x[r]
                sum_y += r * hits
                if row_min_x[r] < min_x:
                    min_x = row_min_x[r]
                if row_max_x[r] > max_x:
                    max_x = row_max_x[r]
                if r < min_y:
                    min_y = r
                # r only increases, so the latest non-empty row is the maximum.
                max_y = r

        if count > 0:
            center_x = sum_x / count
            center_y = sum_y / count
        else:
            center_x = 0.0
            center_y = 0.0

        return (count, center_x, center_y, min_x, max_x, min_y, max_y)

    @jit(nopython=True, parallel=True, cache=True)
    def bitwise_and_count_numba(mask1: np.ndarray, mask2: np.ndarray) -> int:
        """
        Count positions where both 2-D masks are greater than zero.

        The loop bounds come from ``mask1.shape``; ``mask2`` is indexed with
        the same indices, so callers must pass arrays of identical shape.
        """
        h, w = mask1.shape
        count = 0
        for i in prange(h):
            for j in range(w):
                if mask1[i, j] > 0 and mask2[i, j] > 0:
                    count += 1
        return count

    @jit(nopython=True, parallel=True, cache=True)
    def resize_mask_nearest_numba(mask: np.ndarray, new_h: int, new_w: int) -> np.ndarray:
        """
        Resize a 2-D ``mask`` to ``(new_h, new_w)`` with nearest-neighbour sampling.

        This is a simplified implementation: the source index is
        ``int(dst_index * old_size / new_size)`` clamped to the last valid
        index.  The result is always a ``uint8`` array.
        """
        old_h, old_w = mask.shape
        result = np.zeros((new_h, new_w), dtype=np.uint8)

        scale_y = old_h / new_h
        scale_x = old_w / new_w

        for i in prange(new_h):
            src_y = min(int(i * scale_y), old_h - 1)
            for j in range(new_w):
                src_x = min(int(j * scale_x), old_w - 1)
                result[i, j] = mask[src_y, src_x]

        return result


# Public interface: pick the Numba or the numpy implementation.

def count_mask_pixels(mask: np.ndarray) -> int:
    """Return the number of elements in ``mask`` that are greater than zero."""
    mask = np.asarray(mask)
    # The Numba kernel unpacks ``h, w = mask.shape`` and so only handles 2-D input.
    if NUMBA_AVAILABLE and mask.ndim == 2:
        return int(count_mask_pixels_numba(mask))
    return int(np.count_nonzero(mask > 0))


def compute_mask_stats(mask: np.ndarray) -> dict:
    """
    Compute area, centroid and bounding box of a 2-D mask.

    Returns:
        {'area': int, 'center_x': float, 'center_y': float,
         'bbox': (x1, y1, x2, y2)}

        ``bbox`` holds the inclusive min/max coordinates of the positive
        elements.  A mask with no positive element yields area 0, centre
        (0.0, 0.0) and bbox (0, 0, 0, 0) on both code paths.
    """
    mask = np.asarray(mask)
    empty_stats = {'area': 0, 'center_x': 0.0, 'center_y': 0.0, 'bbox': (0, 0, 0, 0)}

    if NUMBA_AVAILABLE and mask.ndim == 2:
        area, cx, cy, min_x, max_x, min_y, max_y = compute_mask_stats_numba(mask)
        if area == 0:
            # The kernel reports sentinel bounds (w, 0, h, 0) for an empty
            # mask; normalise them so both paths agree.
            return empty_stats
        return {
            'area': int(area),
            'center_x': float(cx),
            'center_y': float(cy),
            'bbox': (int(min_x), int(min_y), int(max_x), int(max_y)),
        }

    # numpy fallback
    y_coords, x_coords = np.where(mask > 0)
    if len(y_coords) == 0:
        return empty_stats
    return {
        'area': len(y_coords),
        'center_x': float(np.mean(x_coords)),
        'center_y': float(np.mean(y_coords)),
        'bbox': (int(np.min(x_coords)), int(np.min(y_coords)),
                 int(np.max(x_coords)), int(np.max(y_coords))),
    }


def bitwise_and_count(mask1: np.ndarray, mask2: np.ndarray) -> int:
    """
    Return the number of positions where both masks are greater than zero.

    Both masks are binarised with ``> 0`` before they are combined, so the
    result does not depend on the particular non-zero values (e.g. 1 vs. 2,
    or values that would wrap when cast to ``uint8``).
    """
    fg1 = np.asarray(mask1) > 0
    fg2 = np.asarray(mask2) > 0

    # The Numba kernel does no bounds checking, so only hand it 2-D arrays of
    # identical shape; everything else goes through numpy, which broadcasts or
    # raises ValueError as usual.
    if NUMBA_AVAILABLE and fg1.ndim == 2 and fg1.shape == fg2.shape:
        return int(bitwise_and_count_numba(fg1.astype(np.uint8), fg2.astype(np.uint8)))
    return int(np.count_nonzero(fg1 & fg2))


# JIT warm-up (compilation happens on the first call, later calls use the cache)

def warmup():
    """Call the Numba kernels once on a tiny mask so later calls skip compilation."""
    if NUMBA_AVAILABLE:
        dummy = np.zeros((10, 10), dtype=np.uint8)
        dummy[5, 5] = 255
        count_mask_pixels_numba(dummy)
        compute_mask_stats_numba(dummy)
        bitwise_and_count_numba(dummy, dummy)
        print("[NUMBA] JIT 编译预热完成,已启用多核加速")


def _benchmark():
    """Time 100 pixel counts with numpy and, when available, with Numba."""
    import time

    # Test data: a 480x640 mask with a filled 200x200 rectangle.
    test_mask = np.zeros((480, 640), dtype=np.uint8)
    test_mask[100:300, 200:400] = 255

    # numpy version
    start = time.perf_counter()
    for _ in range(100):
        np.sum(test_mask > 0)
    numpy_time = (time.perf_counter() - start) * 1000
    print(f"numpy: {numpy_time:.2f}ms / 100 次")

    # numba version
    if NUMBA_AVAILABLE:
        # Warm up so compilation time is not part of the measurement.
        count_mask_pixels_numba(test_mask)

        start = time.perf_counter()
        for _ in range(100):
            count_mask_pixels_numba(test_mask)
        numba_time = (time.perf_counter() - start) * 1000

        print(f"numba: {numba_time:.2f}ms / 100 次")
        if numba_time > 0:
            print(f"加速比: {numpy_time / numba_time:.1f}x")


if __name__ == "__main__":
    _benchmark()