diff --git a/imutils/__init__.py b/imutils/__init__.py index e5643ca..665d183 100755 --- a/imutils/__init__.py +++ b/imutils/__init__.py @@ -14,10 +14,37 @@ from .convenience import url_to_image from .convenience import auto_canny from .convenience import grab_contours +from .convenience import find_contours from .convenience import is_cv2 from .convenience import is_cv3 from .convenience import is_cv4 -from .convenience import check_opencv_version +from .convenience import is_cv5 from .convenience import build_montages from .convenience import adjust_brightness_contrast from .meta import find_function + +# import bbox module functions +from .bbox import iou +from .bbox import xywh_to_xyxy +from .bbox import xyxy_to_xywh +from .bbox import center_xywh_to_xyxy +from .bbox import xyxy_to_center_xywh +from .bbox import draw_bbox +from .bbox import resize_bbox +from .bbox import clip_bbox +from .bbox import bbox_area +from .bbox import batch_iou + +# import augmentation module functions +from .augmentation import random_brightness_contrast +from .augmentation import add_gaussian_noise +from .augmentation import add_salt_pepper_noise +from .augmentation import cutout +from .augmentation import random_flip +from .augmentation import random_rotate +from .augmentation import random_crop +from .augmentation import random_blur +from .augmentation import color_jitter +from .augmentation import mixup +from .augmentation import random_perspective +from .augmentation import apply_augmentations diff --git a/imutils/augmentation.py b/imutils/augmentation.py new file mode 100644 index 0000000..8e3b38b --- /dev/null +++ b/imutils/augmentation.py @@ -0,0 +1,326 @@ +# author: PyImageSearch +# website: http://www.pyimagesearch.com + +# import the necessary packages +import numpy as np +import cv2 +import random + +def random_brightness_contrast(image, brightness_range=(-30, 30), contrast_range=(0.8, 1.2)): + """ + 随机调整图像的亮度和对比度 + + :param image: 输入图像 (OpenCV BGR 格式) + :param 
brightness_range: 亮度调整范围 (delta 值) + :param contrast_range: 对比度调整范围 (乘数) + :return: 增强后的图像 + """ + # 随机生成亮度和对比度参数 + brightness = random.uniform(brightness_range[0], brightness_range[1]) + contrast = random.uniform(contrast_range[0], contrast_range[1]) + + # 应用调整 + # 对比度调整: image * contrast + # 亮度调整: + brightness + adjusted = cv2.convertScaleAbs(image, alpha=contrast, beta=brightness) + + return adjusted + + +def add_gaussian_noise(image, mean=0, std=25): + """ + 向图像添加高斯噪声 + + :param image: 输入图像 + :param mean: 噪声均值 + :param std: 噪声标准差 + :return: 添加噪声后的图像 + """ + # 生成高斯噪声 + noise = np.random.normal(mean, std, image.shape).astype(np.float32) + + # 将图像转换为 float 并添加噪声 + noisy_image = image.astype(np.float32) + noise + + # 裁剪到有效范围并转换回 uint8 + noisy_image = np.clip(noisy_image, 0, 255).astype(np.uint8) + + return noisy_image + + +def add_salt_pepper_noise(image, salt_prob=0.01, pepper_prob=0.01): + """ + 向图像添加椒盐噪声 + + :param image: 输入图像 + :param salt_prob: 盐噪声 (白点) 概率 + :param pepper_prob: 椒噪声 (黑点) 概率 + :return: 添加噪声后的图像 + """ + noisy_image = image.copy() + total_pixels = image.shape[0] * image.shape[1] + + # 添加盐噪声 (白点) + num_salt = int(total_pixels * salt_prob) + salt_coords = [ + np.random.randint(0, i, num_salt) + for i in image.shape[:2] + ] + noisy_image[salt_coords[0], salt_coords[1]] = 255 + + # 添加椒噪声 (黑点) + num_pepper = int(total_pixels * pepper_prob) + pepper_coords = [ + np.random.randint(0, i, num_pepper) + for i in image.shape[:2] + ] + noisy_image[pepper_coords[0], pepper_coords[1]] = 0 + + return noisy_image + + +def cutout(image, num_holes=1, max_h_size=8, max_w_size=8, fill_value=0): + """ + Cutout 数据增强 - 在图像上随机遮挡矩形区域 + + 参考论文: "Improved Regularization of Convolutional Neural Networks with Cutout" + + :param image: 输入图像 + :param num_holes: 遮挡区域数量 + :param max_h_size: 遮挡区域最大高度 + :param max_w_size: 遮挡区域最大宽度 + :param fill_value: 遮挡填充值 (默认 0 为黑色) + :return: 应用 Cutout 后的图像 + """ + h, w = image.shape[:2] + result = image.copy() + + for _ in range(num_holes): + # 
随机生成遮挡区域尺寸 + hole_h = random.randint(1, max_h_size) + hole_w = random.randint(1, max_w_size) + + # 随机生成遮挡区域位置 + y1 = random.randint(0, h - hole_h) if h > hole_h else 0 + x1 = random.randint(0, w - hole_w) if w > hole_w else 0 + + y2 = min(y1 + hole_h, h) + x2 = min(x1 + hole_w, w) + + # 应用遮挡 + if len(image.shape) == 3: + result[y1:y2, x1:x2, :] = fill_value + else: + result[y1:y2, x1:x2] = fill_value + + return result + + +def random_flip(image, flip_code=1): + """ + 随机水平或垂直翻转图像 + + :param image: 输入图像 + :param flip_code: 翻转代码 (0=垂直, 1=水平, -1=双向) + :return: 翻转后的图像 (50% 概率) + """ + if random.random() > 0.5: + return cv2.flip(image, flip_code) + return image + + +def random_rotate(image, angle_range=(-15, 15), scale=1.0): + """ + 随机旋转图像 + + :param image: 输入图像 + :param angle_range: 旋转角度范围 + :param scale: 缩放比例 + :return: 旋转后的图像 + """ + angle = random.uniform(angle_range[0], angle_range[1]) + + h, w = image.shape[:2] + center = (w // 2, h // 2) + + # 获取旋转矩阵 + M = cv2.getRotationMatrix2D(center, angle, scale) + + # 应用旋转 + rotated = cv2.warpAffine(image, M, (w, h), borderMode=cv2.BORDER_CONSTANT, + borderValue=(128, 128, 128)) + + return rotated + + +def random_crop(image, crop_ratio=(0.8, 1.0)): + """ + 随机裁剪图像 + + :param image: 输入图像 + :param crop_ratio: 裁剪比例范围 (相对于原图) + :return: 裁剪后的图像 + """ + h, w = image.shape[:2] + + # 随机选择裁剪比例 + ratio = random.uniform(crop_ratio[0], crop_ratio[1]) + + # 计算裁剪尺寸 + new_h = int(h * ratio) + new_w = int(w * ratio) + + # 随机选择裁剪位置 + y1 = random.randint(0, h - new_h) if h > new_h else 0 + x1 = random.randint(0, w - new_w) if w > new_w else 0 + + # 裁剪 + cropped = image[y1:y1+new_h, x1:x1+new_w] + + # 调整回原尺寸 + return cv2.resize(cropped, (w, h)) + + +def random_blur(image, kernel_size_range=(3, 7)): + """ + 随机模糊图像 + + :param image: 输入图像 + :param kernel_size_range: 模糊核大小范围 (奇数) + :return: 模糊后的图像 (50% 概率) + """ + if random.random() > 0.5: + # 确保核大小为奇数 + k = random.randint(kernel_size_range[0] // 2, kernel_size_range[1] // 2) + kernel_size = 2 * k 
+ 1 + + # 随机选择模糊类型 + if random.random() > 0.5: + return cv2.GaussianBlur(image, (kernel_size, kernel_size), 0) + else: + return cv2.medianBlur(image, kernel_size) + return image + + +def color_jitter(image, hue_range=(-10, 10), saturation_range=(-30, 30), value_range=(-30, 30)): + """ + 颜色抖动 - 随机调整色调、饱和度和明度 + + :param image: 输入图像 (BGR 格式) + :param hue_range: 色调调整范围 + :param saturation_range: 饱和度调整范围 + :param value_range: 明度调整范围 + :return: 调整后的图像 + """ + # 转换到 HSV 色彩空间 + hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV).astype(np.float32) + + # 随机调整 + hue_delta = random.uniform(hue_range[0], hue_range[1]) + sat_delta = random.uniform(saturation_range[0], saturation_range[1]) + val_delta = random.uniform(value_range[0], value_range[1]) + + hsv[:, :, 0] = (hsv[:, :, 0] + hue_delta) % 180 + hsv[:, :, 1] = np.clip(hsv[:, :, 1] + sat_delta, 0, 255) + hsv[:, :, 2] = np.clip(hsv[:, :, 2] + val_delta, 0, 255) + + # 转换回 BGR + return cv2.cvtColor(hsv.astype(np.uint8), cv2.COLOR_HSV2BGR) + + +def mixup(image1, image2, alpha=0.4): + """ + Mixup 数据增强 - 将两张图像按一定比例混合 + + 参考论文: "mixup: Beyond Empirical Risk Minimization" + + :param image1: 第一张图像 + :param image2: 第二张图像 + :param alpha: Beta 分布参数 + :return: 混合后的图像 + """ + # 确保尺寸相同 + if image1.shape != image2.shape: + image2 = cv2.resize(image2, (image1.shape[1], image1.shape[0])) + + # 从 Beta 分布采样混合比例 + lam = np.random.beta(alpha, alpha) + + # 混合图像 + mixed = lam * image1.astype(np.float32) + (1 - lam) * image2.astype(np.float32) + + return mixed.astype(np.uint8) + + +def random_perspective(image, distortion_scale=0.2): + """ + 随机透视变换 + + :param image: 输入图像 + :param distortion_scale: 扭曲程度 + :return: 透视变换后的图像 + """ + h, w = image.shape[:2] + + # 定义四个角点 + margin = min(h, w) * distortion_scale + + pts1 = np.float32([ + [0, 0], + [w, 0], + [w, h], + [0, h] + ]) + + # 随机扰动角点 + pts2 = np.float32([ + [random.uniform(0, margin), random.uniform(0, margin)], + [w - random.uniform(0, margin), random.uniform(0, margin)], + [w - random.uniform(0, 
margin), h - random.uniform(0, margin)], + [random.uniform(0, margin), h - random.uniform(0, margin)] + ]) + + # 计算透视变换矩阵 + M = cv2.getPerspectiveTransform(pts1, pts2) + + # 应用变换 + return cv2.warpPerspective(image, M, (w, h), borderMode=cv2.BORDER_CONSTANT, + borderValue=(128, 128, 128)) + + +def apply_augmentations(image, aug_list=None): + """ + 应用一系列数据增强 + + :param image: 输入图像 + :param aug_list: 增强操作列表,如 ['flip', 'rotate', 'brightness', 'cutout'] + :return: 增强后的图像 + """ + if aug_list is None: + aug_list = ['flip', 'brightness', 'cutout'] + + result = image.copy() + + for aug in aug_list: + if aug == 'flip': + result = random_flip(result) + elif aug == 'rotate': + result = random_rotate(result) + elif aug == 'brightness': + result = random_brightness_contrast(result) + elif aug == 'cutout': + result = cutout(result) + elif aug == 'gaussian_noise': + result = add_gaussian_noise(result) + elif aug == 'salt_pepper': + result = add_salt_pepper_noise(result) + elif aug == 'blur': + result = random_blur(result) + elif aug == 'color_jitter': + result = color_jitter(result) + elif aug == 'crop': + result = random_crop(result) + elif aug == 'perspective': + result = random_perspective(result) + + return result diff --git a/imutils/bbox.py b/imutils/bbox.py new file mode 100644 index 0000000..d266a01 --- /dev/null +++ b/imutils/bbox.py @@ -0,0 +1,255 @@ +# author: PyImageSearch +# website: http://www.pyimagesearch.com + +# import the necessary packages +import numpy as np +import cv2 + +def iou(boxA, boxB): + """ + 计算两个边界框的交并比 (Intersection over Union) + + :param boxA: 第一个边界框 [xmin, ymin, xmax, ymax] 或 [x, y, w, h] + :param boxB: 第二个边界框 [xmin, ymin, xmax, ymax] 或 [x, y, w, h] + :return: IoU 值 (0.0 ~ 1.0) + """ + # 确保是 xyxy 格式 + if len(boxA) == 4 and len(boxB) == 4: + # 检查是否是 xywh 格式 (通过判断 xmax > xmin) + if boxA[2] < boxA[0] or boxA[2] > boxA[0] * 10: # 简单启发式判断 + boxA = xywh_to_xyxy(boxA) + if boxB[2] < boxB[0] or boxB[2] > boxB[0] * 10: + boxB = xywh_to_xyxy(boxB) + + # 
计算交集区域 + xA = max(boxA[0], boxB[0]) + yA = max(boxA[1], boxB[1]) + xB = min(boxA[2], boxB[2]) + yB = min(boxA[3], boxB[3]) + + # 计算交集面积 + interArea = max(0, xB - xA) * max(0, yB - yA) + + # 计算两个框的面积 + boxAArea = (boxA[2] - boxA[0]) * (boxA[3] - boxA[1]) + boxBArea = (boxB[2] - boxB[0]) * (boxB[3] - boxB[1]) + + # 计算 IoU + iou = interArea / float(boxAArea + boxBArea - interArea + 1e-6) + + return iou + + +def xywh_to_xyxy(box): + """ + 将 [x, y, w, h] 格式转换为 [xmin, ymin, xmax, ymax] 格式 + + :param box: [x, y, w, h] 格式的边界框 + :return: [xmin, ymin, xmax, ymax] 格式的边界框 + """ + x, y, w, h = box + return [x, y, x + w, y + h] + + +def xyxy_to_xywh(box): + """ + 将 [xmin, ymin, xmax, ymax] 格式转换为 [x, y, w, h] 格式 + + :param box: [xmin, ymin, xmax, ymax] 格式的边界框 + :return: [x, y, w, h] 格式的边界框 + """ + xmin, ymin, xmax, ymax = box + return [xmin, ymin, xmax - xmin, ymax - ymin] + + +def center_xywh_to_xyxy(box): + """ + 将中心点格式 [cx, cy, w, h] 转换为 [xmin, ymin, xmax, ymax] + + :param box: [cx, cy, w, h] 格式的边界框 + :return: [xmin, ymin, xmax, ymax] 格式的边界框 + """ + cx, cy, w, h = box + return [cx - w/2, cy - h/2, cx + w/2, cy + h/2] + + +def xyxy_to_center_xywh(box): + """ + 将 [xmin, ymin, xmax, ymax] 格式转换为中心点格式 [cx, cy, w, h] + + :param box: [xmin, ymin, xmax, ymax] 格式的边界框 + :return: [cx, cy, w, h] 格式的边界框 + """ + xmin, ymin, xmax, ymax = box + w = xmax - xmin + h = ymax - ymin + cx = xmin + w / 2 + cy = ymin + h / 2 + return [cx, cy, w, h] + + +def draw_bbox(image, box, label=None, color=(0, 255, 0), thickness=2, + font_scale=0.6, text_color=None, text_bg_color=None): + """ + 在图像上绘制带标签的边界框 + + :param image: OpenCV 图像 (numpy array) + :param box: 边界框 [xmin, ymin, xmax, ymax] 或 [x, y, w, h] + :param label: 标签文本 (可选) + :param color: 框的颜色 (BGR) + :param thickness: 框线粗细 + :param font_scale: 字体大小 + :param text_color: 文字颜色 (默认与框颜色相同) + :param text_bg_color: 文字背景颜色 (默认与框颜色相同) + :return: 绘制后的图像 + """ + # 确保是 xyxy 格式 + if len(box) == 4: + # 启发式判断是否是 xywh 格式 + if box[2] < box[0] or box[2] > box[0] * 10: 
+ box = xywh_to_xyxy(box) + + xmin, ymin, xmax, ymax = [int(v) for v in box] + + # 绘制边界框 + cv2.rectangle(image, (xmin, ymin), (xmax, ymax), color, thickness) + + # 绘制标签 + if label is not None: + if text_color is None: + text_color = color + if text_bg_color is None: + text_bg_color = color + + # 计算文字大小 + (text_width, text_height), _ = cv2.getTextSize( + label, cv2.FONT_HERSHEY_SIMPLEX, font_scale, 1) + + # 绘制文字背景 + cv2.rectangle(image, + (xmin, ymin - text_height - 10), + (xmin + text_width, ymin), + text_bg_color, -1) + + # 绘制文字 + cv2.putText(image, label, (xmin, ymin - 5), + cv2.FONT_HERSHEY_SIMPLEX, font_scale, text_color, 1) + + return image + + +def resize_bbox(box, original_size, new_size, format='xyxy'): + """ + 当图像缩放时,同步缩放边界框坐标 + + :param box: 原始边界框 + :param original_size: 原始图像尺寸 (width, height) + :param new_size: 新图像尺寸 (width, height) + :param format: 边界框格式 ('xyxy' 或 'xywh') + :return: 缩放后的边界框 + """ + orig_w, orig_h = original_size + new_w, new_h = new_size + + # 计算缩放比例 + scale_x = new_w / float(orig_w) + scale_y = new_h / float(orig_h) + + if format == 'xyxy': + xmin, ymin, xmax, ymax = box + new_box = [ + xmin * scale_x, + ymin * scale_y, + xmax * scale_x, + ymax * scale_y + ] + elif format == 'xywh': + x, y, w, h = box + new_box = [ + x * scale_x, + y * scale_y, + w * scale_x, + h * scale_y + ] + else: + raise ValueError("format must be 'xyxy' or 'xywh'") + + return new_box + + +def clip_bbox(box, image_width, image_height, format='xyxy'): + """ + 将边界框裁剪到图像边界内 + + :param box: 边界框 + :param image_width: 图像宽度 + :param image_height: 图像高度 + :param format: 边界框格式 ('xyxy' 或 'xywh') + :return: 裁剪后的边界框 + """ + if format == 'xyxy': + xmin, ymin, xmax, ymax = box + xmin = max(0, min(xmin, image_width - 1)) + ymin = max(0, min(ymin, image_height - 1)) + xmax = max(0, min(xmax, image_width - 1)) + ymax = max(0, min(ymax, image_height - 1)) + return [xmin, ymin, xmax, ymax] + elif format == 'xywh': + x, y, w, h = box + x = max(0, min(x, image_width - 1)) + y = max(0, 
min(y, image_height - 1)) + w = min(w, image_width - x) + h = min(h, image_height - y) + return [x, y, w, h] + else: + raise ValueError("format must be 'xyxy' or 'xywh'") + + +def bbox_area(box, format='xyxy'): + """ + 计算边界框面积 + + :param box: 边界框 + :param format: 边界框格式 ('xyxy' 或 'xywh') + :return: 面积 + """ + if format == 'xyxy': + xmin, ymin, xmax, ymax = box + return max(0, xmax - xmin) * max(0, ymax - ymin) + elif format == 'xywh': + x, y, w, h = box + return w * h + else: + raise ValueError("format must be 'xyxy' or 'xywh'") + + +def batch_iou(boxesA, boxesB): + """ + 批量计算两组边界框之间的 IoU + + :param boxesA: 第一组边界框,形状为 (N, 4) + :param boxesB: 第二组边界框,形状为 (M, 4) + :return: IoU 矩阵,形状为 (N, M) + """ + boxesA = np.array(boxesA) + boxesB = np.array(boxesB) + + # 计算面积 + areaA = (boxesA[:, 2] - boxesA[:, 0]) * (boxesA[:, 3] - boxesA[:, 1]) + areaB = (boxesB[:, 2] - boxesB[:, 0]) * (boxesB[:, 3] - boxesB[:, 1]) + + # 计算交集 + xmin = np.maximum(boxesA[:, None, 0], boxesB[None, :, 0]) + ymin = np.maximum(boxesA[:, None, 1], boxesB[None, :, 1]) + xmax = np.minimum(boxesA[:, None, 2], boxesB[None, :, 2]) + ymax = np.minimum(boxesA[:, None, 3], boxesB[None, :, 3]) + + inter_w = np.maximum(0, xmax - xmin) + inter_h = np.maximum(0, ymax - ymin) + inter_area = inter_w * inter_h + + # 计算 IoU + union_area = areaA[:, None] + areaB[None, :] - inter_area + iou_matrix = inter_area / (union_area + 1e-6) + + return iou_matrix diff --git a/imutils/convenience.py b/imutils/convenience.py index 9704eb2..e73c4e3 100644 --- a/imutils/convenience.py +++ b/imutils/convenience.py @@ -62,7 +62,18 @@ def rotate_bound(image, angle): # perform the actual rotation and return the image return cv2.warpAffine(image, M, (nW, nH)) -def resize(image, width=None, height=None, inter=cv2.INTER_AREA): +def resize(image, width=None, height=None, inter=cv2.INTER_AREA, bboxes=None): + """ + Resize the image, optionally scaling bounding boxes accordingly. 
+ + :param image: Input image + :param width: Target width (optional) + :param height: Target height (optional) + :param inter: Interpolation method + :param bboxes: List of bounding boxes to resize along with image. + Each box can be [x, y, w, h] or [xmin, ymin, xmax, ymax] + :return: Resized image, or (resized_image, resized_bboxes) if bboxes provided + """ # initialize the dimensions of the image to be resized and # grab the image size dim = None @@ -71,6 +82,8 @@ def resize(image, width=None, height=None, inter=cv2.INTER_AREA): # if both the width and height are None, then return the # original image if width is None and height is None: + if bboxes is not None: + return image, bboxes return image # check to see if the width is None @@ -89,6 +102,27 @@ def resize(image, width=None, height=None, inter=cv2.INTER_AREA): # resize the image resized = cv2.resize(image, dim, interpolation=inter) + + # resize bounding boxes if provided + if bboxes is not None: + scale_x = dim[0] / float(w) + scale_y = dim[1] / float(h) + + resized_bboxes = [] + for box in bboxes: + # Detect format: if box[2] < box[0] * 10, assume xywh format + if box[2] < box[0] * 10: + # xywh format + x, y, bw, bh = box + resized_box = [x * scale_x, y * scale_y, bw * scale_x, bh * scale_y] + else: + # xyxy format + xmin, ymin, xmax, ymax = box + resized_box = [xmin * scale_x, ymin * scale_y, + xmax * scale_x, ymax * scale_y] + resized_bboxes.append(resized_box) + + return resized, resized_bboxes # return the resized image return resized @@ -154,7 +188,7 @@ def auto_canny(image, sigma=0.33): def grab_contours(cnts): # if the length the contours tuple returned by cv2.findContours # is '2' then we are using either OpenCV v2.4, v4-beta, or - # v4-official + # v4-official, or v5.x (OpenCV 5.x maintains the same API as 4.x) if len(cnts) == 2: cnts = cnts[0] @@ -174,6 +208,22 @@ def grab_contours(cnts): # return the actual contours array return cnts +def find_contours(image, mode=cv2.RETR_EXTERNAL, 
method=cv2.CHAIN_APPROX_SIMPLE): + """ + Wrapper around cv2.findContours that handles OpenCV version differences + including OpenCV 2.x, 3.x, 4.x, and 5.x + + :param image: Input binary image + :param mode: Contour retrieval mode + :param method: Contour approximation method + :return: contours (list of numpy arrays) + """ + # OpenCV 5.x and 4.x return (contours, hierarchy) - 2 elements + # OpenCV 3.x returns (image, contours, hierarchy) - 3 elements + # OpenCV 2.x returns (contours, hierarchy) - 2 elements + cnts = cv2.findContours(image.copy(), mode, method) + return grab_contours(cnts) + def is_cv2(or_better=False): # grab the OpenCV major version number major = get_opencv_major_version() @@ -207,6 +257,17 @@ def is_cv4(or_better=False): # otherwise we want to check for *strictly* OpenCV 4 return major == 4 +def is_cv5(or_better=False): + # grab the OpenCV major version number + major = get_opencv_major_version() + + # check to see if we are using *at least* OpenCV 5 + if or_better: + return major >= 5 + + # otherwise we want to check for *strictly* OpenCV 5 + return major == 5 + def get_opencv_major_version(lib=None): # if the supplied library is None, import OpenCV if lib is None: diff --git a/imutils/feature/factories.py b/imutils/feature/factories.py index 12d3719..d454b8a 100755 --- a/imutils/feature/factories.py +++ b/imutils/feature/factories.py @@ -1,4 +1,4 @@ -from ..convenience import is_cv2 +from ..convenience import is_cv2, is_cv3, is_cv4, is_cv5 import cv2 from .dense import DENSE from .gftt import GFTT @@ -26,6 +26,7 @@ def DescriptorMatcher_create(method): return cv2.DescriptorMatcher_create(method) else: + # OpenCV 3.x, 4.x, and 5.x compatible factories try: _DETECTOR_FACTORY = {"BRISK": cv2.BRISK_create, "DENSE": DENSE, @@ -34,26 +35,44 @@ def DescriptorMatcher_create(method): "HARRIS": HARRIS, "MSER": cv2.MSER_create, "ORB": cv2.ORB_create, - "SIFT": cv2.xfeatures2d.SIFT_create, - "SURF": cv2.xfeatures2d.SURF_create, - "STAR": 
cv2.xfeatures2d.StarDetector_create } - - _EXTRACTOR_FACTORY = {"SIFT": cv2.xfeatures2d.SIFT_create, - "ROOTSIFT": RootSIFT, - "SURF": cv2.xfeatures2d.SURF_create, - "BRIEF": cv2.xfeatures2d.BriefDescriptorExtractor_create, + + _EXTRACTOR_FACTORY = {"ROOTSIFT": RootSIFT, "ORB": cv2.ORB_create, "BRISK": cv2.BRISK_create, - "FREAK": cv2.xfeatures2d.FREAK_create } - + _MATCHER_FACTORY = {"BruteForce": cv2.DESCRIPTOR_MATCHER_BRUTEFORCE, "BruteForce-SL2": cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_SL2, "BruteForce-L1": cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_L1, "BruteForce-Hamming": cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING, + "BruteForce-HammingLUT": cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMINGLUT, "FlannBased": cv2.DESCRIPTOR_MATCHER_FLANNBASED } + + # Try to add contrib modules if available (xfeatures2d) + # Note: In OpenCV 5.x, some of these may be in the main module + try: + import cv2.xfeatures2d as xfd + _DETECTOR_FACTORY.update({ + "SIFT": xfd.SIFT_create, + "SURF": xfd.SURF_create, + "STAR": xfd.StarDetector_create + }) + _EXTRACTOR_FACTORY.update({ + "SIFT": xfd.SIFT_create, + "SURF": xfd.SURF_create, + "BRIEF": xfd.BriefDescriptorExtractor_create, + "FREAK": xfd.FREAK_create + }) + except (AttributeError, ImportError): + # xfeatures2d not available, try main cv2 module (OpenCV 5.x may have SIFT here) + try: + if hasattr(cv2, 'SIFT_create'): + _DETECTOR_FACTORY["SIFT"] = cv2.SIFT_create + _EXTRACTOR_FACTORY["SIFT"] = cv2.SIFT_create + except: + pass except AttributeError: _DETECTOR_FACTORY = {"MSER": cv2.MSER_create, @@ -68,6 +87,12 @@ def DescriptorMatcher_create(method): _EXTRACTOR_FACTORY = {"ORB": cv2.ORB_create, "BRISK": cv2.BRISK_create } + + _MATCHER_FACTORY = {"BruteForce": cv2.DESCRIPTOR_MATCHER_BRUTEFORCE, + "BruteForce-L1": cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_L1, + "BruteForce-Hamming": cv2.DESCRIPTOR_MATCHER_BRUTEFORCE_HAMMING, + "FlannBased": cv2.DESCRIPTOR_MATCHER_FLANNBASED + } _CONTRIB_FUNCS = {"SIFT", "ROOTSIFT", "SURF", "STAR", "BRIEF", "FREAK"} diff --git 
a/imutils/video/videostream.py b/imutils/video/videostream.py index 4f9de0b..7418fe2 100644 --- a/imutils/video/videostream.py +++ b/imutils/video/videostream.py @@ -3,7 +3,23 @@ class VideoStream: def __init__(self, src=0, usePiCamera=False, resolution=(320, 240), - framerate=32, **kwargs): + framerate=32, enable_reconnect=True, max_reconnect_attempts=5, + reconnect_delay=1.0, frame_timeout=5.0, + skip_frames_on_reconnect=True, **kwargs): + """ + 初始化视频流 + + :param src: 视频源 (摄像头索引或 RTSP/HTTP URL) + :param usePiCamera: 是否使用树莓派摄像头 + :param resolution: 分辨率 (仅用于 PiCamera) + :param framerate: 帧率 (仅用于 PiCamera) + :param enable_reconnect: 是否启用断线重连 (网络流) + :param max_reconnect_attempts: 最大重连尝试次数 + :param reconnect_delay: 重连间隔(秒) + :param frame_timeout: 帧读取超时时间(秒) + :param skip_frames_on_reconnect: 重连后是否跳过缓冲帧 + :param kwargs: 其他参数 + """ # check to see if the picamera module should be used if usePiCamera: # only import the picamera packages unless we are @@ -20,7 +36,15 @@ def __init__(self, src=0, usePiCamera=False, resolution=(320, 240), # otherwise, we are using OpenCV so initialize the webcam # stream else: - self.stream = WebcamVideoStream(src=src) + self.stream = WebcamVideoStream( + src=src, + enable_reconnect=enable_reconnect, + max_reconnect_attempts=max_reconnect_attempts, + reconnect_delay=reconnect_delay, + frame_timeout=frame_timeout, + skip_frames_on_reconnect=skip_frames_on_reconnect, + **kwargs + ) def start(self): # start the threaded video stream @@ -33,6 +57,18 @@ def update(self): def read(self): # return the current frame return self.stream.read() + + def is_connected(self): + """检查流是否连接正常 (仅适用于 WebcamVideoStream)""" + if hasattr(self.stream, 'is_connected'): + return self.stream.is_connected() + return True + + def get_stats(self): + """获取流统计信息 (仅适用于 WebcamVideoStream)""" + if hasattr(self.stream, 'get_stats'): + return self.stream.get_stats() + return {} def stop(self): # stop the thread and release any resources diff --git 
a/imutils/video/webcamvideostream.py b/imutils/video/webcamvideostream.py
index dbe8751..b5116c3 100644
--- a/imutils/video/webcamvideostream.py
+++ b/imutils/video/webcamvideostream.py
@@ -1,42 +1,244 @@
 # import the necessary packages
-from threading import Thread
+from threading import Thread, Lock
 import cv2
+import time
 
 class WebcamVideoStream:
-    def __init__(self, src=0, name="WebcamVideoStream"):
-        # initialize the video camera stream and read the first frame
-        # from the stream
+    def __init__(self, src=0, name="WebcamVideoStream",
+                 enable_reconnect=True, max_reconnect_attempts=5,
+                 reconnect_delay=1.0, frame_timeout=5.0,
+                 skip_frames_on_reconnect=True, max_frame_lag=0.1):
+        """
+        初始化视频流
+
+        :param src: 视频源 (摄像头索引或 RTSP/HTTP URL)
+        :param name: 线程名称
+        :param enable_reconnect: 是否启用断线重连
+        :param max_reconnect_attempts: 最大重连尝试次数
+        :param reconnect_delay: 重连间隔(秒)
+        :param frame_timeout: 帧读取超时时间(秒)
+        :param skip_frames_on_reconnect: 重连后是否跳过缓冲帧
+        :param max_frame_lag: 最大允许帧延迟(秒),超过则跳过过期帧
+        """
+        # 保存配置参数
+        self.src = src
+        self.enable_reconnect = enable_reconnect
+        self.max_reconnect_attempts = max_reconnect_attempts
+        self.reconnect_delay = reconnect_delay
+        self.frame_timeout = frame_timeout
+        self.skip_frames_on_reconnect = skip_frames_on_reconnect
+        self.max_frame_lag = max_frame_lag
+
+        # 初始化视频捕获
         self.stream = cv2.VideoCapture(src)
+
+        # 尝试设置缓冲区大小以减小延迟
+        self.stream.set(cv2.CAP_PROP_BUFFERSIZE, 1)
+
+        # 读取第一帧
         (self.grabbed, self.frame) = self.stream.read()
 
-        # initialize the thread name
+        # 初始化线程名称
         self.name = name
 
-        # initialize the variable used to indicate if the thread should
-        # be stopped
+        # 初始化停止标志
         self.stopped = False
+
+        # 初始化线程锁
+        self.lock = Lock()
+
+        # 统计信息
+        self.frame_count = 0
+        self.error_count = 0
+        self.last_frame_time = time.time()
+        # NOTE: 连接状态保存在 self.connected 中,不能命名为 self.is_connected:
+        # 同名实例属性会覆盖 is_connected() 方法,导致调用方法时抛出
+        # "TypeError: 'bool' object is not callable"
+        self.connected = self.grabbed
+
+        # 重连相关
+        self.reconnect_attempts = 0
+        self.last_successful_frame = None
+
+        # 丢帧统计
+        self.skipped_frames = 0
+
+        # 获取视频流的FPS用于计算帧延迟
+        self.fps = self.stream.get(cv2.CAP_PROP_FPS)
+        if self.fps <= 0:
+            self.fps = 30  # 默认30fps
 
     def start(self):
-        # start the thread to read frames from the video stream
+        # 启动线程读取视频流
         t = Thread(target=self.update, name=self.name, args=())
         t.daemon = True
         t.start()
         return self
 
+    def _is_network_stream(self):
+        """检查是否是网络流 (RTSP/HTTP/HTTPS)"""
+        if isinstance(self.src, str):
+            return self.src.startswith(('rtsp://', 'http://', 'https://'))
+        return False
+
+    def _reconnect(self):
+        """尝试重新连接视频流"""
+        if not self.enable_reconnect:
+            return False
+
+        # 只在网络流上启用重连
+        if not self._is_network_stream():
+            return False
+
+        # 锁只保护计数器检查; release/sleep/重连等耗时操作必须在锁外执行,
+        # 否则会长时间阻塞 read()
+        with self.lock:
+            if self.reconnect_attempts >= self.max_reconnect_attempts:
+                return False
+
+            self.reconnect_attempts += 1
+
+        # 释放旧连接
+        try:
+            self.stream.release()
+        except Exception:
+            pass
+
+        # 等待后重连
+        time.sleep(self.reconnect_delay)
+
+        try:
+            # 创建新连接
+            new_stream = cv2.VideoCapture(self.src)
+            new_stream.set(cv2.CAP_PROP_BUFFERSIZE, 1)
+
+            # 测试读取
+            (grabbed, frame) = new_stream.read()
+
+            if grabbed:
+                with self.lock:
+                    self.stream = new_stream
+                    self.grabbed = grabbed
+                    self.frame = frame
+                    self.reconnect_attempts = 0
+                    self.connected = True
+                    self.last_frame_time = time.time()
+
+                # 如果需要,跳过缓冲帧
+                if self.skip_frames_on_reconnect:
+                    for _ in range(5):  # 跳过前5帧
+                        new_stream.grab()
+
+                return True
+            else:
+                new_stream.release()
+                return False
+
+        except Exception:
+            return False
+
+    def _skip_lagged_frames(self):
+        """
+        跳过延迟的帧以保持实时性。
+        如果缓冲区中的帧延迟超过 max_frame_lag,则连续调用 grab() 跳过过期帧。
+        """
+        if not self._is_network_stream():
+            return
+
+        try:
+            # 使用 grab 快速跳过帧而不解码
+            skipped = 0
+            max_skip = int(self.fps * self.max_frame_lag * 2)  # 最多跳过的帧数
+
+            while skipped < max_skip:
+                if self.stream.grab():
+                    skipped += 1
+                else:
+                    break
+
+            if skipped > 0:
+                with self.lock:
+                    self.skipped_frames += skipped
+
+        except Exception:
+            pass
+
     def update(self):
-        # keep looping infinitely until the thread is stopped
+        # 持续读取帧直到线程停止
         while True:
-            # if the thread indicator variable is set, stop the thread
+            # 如果线程指示停止,则退出
             if self.stopped:
                 return
-
-            # otherwise, read the next frame from the stream
-            (self.grabbed, self.frame) = self.stream.read()
+
+            try:
+                # 在读取新帧之前,检查并跳过延迟的帧
+                self._skip_lagged_frames()
+
+                # 读取下一帧
+                (grabbed, frame) = self.stream.read()
+
+                need_reconnect = False
+                with self.lock:
+                    if grabbed:
+                        self.grabbed = True
+                        self.frame = frame
+                        self.frame_count += 1
+                        self.last_frame_time = time.time()
+                        self.connected = True
+                        self.reconnect_attempts = 0
+                        self.last_successful_frame = frame.copy()
+                    else:
+                        # 读取失败
+                        self.error_count += 1
+                        self.connected = False
+
+                        # 检查是否需要重连
+                        time_since_last_frame = time.time() - self.last_frame_time
+                        need_reconnect = time_since_last_frame > self.frame_timeout
+
+                if not grabbed:
+                    # 重连必须在锁外进行: _reconnect() 内部会再次获取
+                    # self.lock,threading.Lock 不可重入,在锁内调用会死锁
+                    if need_reconnect and self._reconnect():
+                        continue
+
+                    # 如果无法重连,使用最后一帧
+                    with self.lock:
+                        if self.last_successful_frame is not None:
+                            self.frame = self.last_successful_frame.copy()
+                            self.grabbed = True
+                        else:
+                            self.grabbed = False
+
+            except Exception:
+                self.error_count += 1
+                with self.lock:
+                    self.connected = False
+
+            # 短暂休眠避免 CPU 占用过高
+            time.sleep(0.001)
 
     def read(self):
-        # return the frame most recently read
-        return self.frame
+        # 返回最近读取的帧
+        with self.lock:
+            return self.frame.copy() if self.frame is not None else None
+
+    def is_connected(self):
+        """检查流是否连接正常"""
+        with self.lock:
+            return self.connected
+
+    def get_stats(self):
+        """获取流统计信息"""
+        with self.lock:
+            return {
+                'frame_count': self.frame_count,
+                'error_count': self.error_count,
+                'skipped_frames': self.skipped_frames,
+                'is_connected': self.connected,
+                'reconnect_attempts': self.reconnect_attempts
+            }
 
     def stop(self):
-        # indicate that the thread should be stopped
+        # 指示线程停止
         self.stopped = True
+        # 释放资源
+        self.stream.release()