From d145b8c421efda384dc85817bdd12b5d34da0459 Mon Sep 17 00:00:00 2001 From: Hariom_Nagar Date: Thu, 4 Sep 2025 12:06:17 +0530 Subject: [PATCH 1/2] license_plate_detection_yunet: fix NMSBoxes input (use [x,y,w,h] built from 4-point corners); keep original quad dets. Fixes #275 --- .../lpd_yunet.py | 78 ++++++++++++------- 1 file changed, 48 insertions(+), 30 deletions(-) diff --git a/models/license_plate_detection_yunet/lpd_yunet.py b/models/license_plate_detection_yunet/lpd_yunet.py index 917e58a3..21aa5c42 100644 --- a/models/license_plate_detection_yunet/lpd_yunet.py +++ b/models/license_plate_detection_yunet/lpd_yunet.py @@ -1,13 +1,15 @@ from itertools import product - import numpy as np import cv2 as cv + class LPD_YuNet: - def __init__(self, modelPath, inputSize=[320, 240], confThreshold=0.8, nmsThreshold=0.3, topK=5000, keepTopK=750, backendId=0, targetId=0): + def __init__(self, modelPath, inputSize=[320, 240], confThreshold=0.8, + nmsThreshold=0.3, topK=5000, keepTopK=750, + backendId=0, targetId=0): self.model_path = modelPath self.input_size = np.array(inputSize) - self.confidence_threshold=confThreshold + self.confidence_threshold = confThreshold self.nms_threshold = nmsThreshold self.top_k = topK self.keep_top_k = keepTopK @@ -19,12 +21,12 @@ def __init__(self, modelPath, inputSize=[320, 240], confThreshold=0.8, nmsThresh self.steps = [8, 16, 32, 64] self.variance = [0.1, 0.2] - # load model + # Load model self.model = cv.dnn.readNet(self.model_path) - # set backend and target self.model.setPreferableBackend(self.backend_id) self.model.setPreferableTarget(self.target_id) - # generate anchors/priorboxes + + # Generate anchors/priorboxes self._priorGen() @property @@ -39,15 +41,16 @@ def setBackendAndTarget(self, backendId, targetId): def setInputSize(self, inputSize): self.input_size = inputSize - # re-generate anchors/priorboxes self._priorGen() def _preprocess(self, image): return cv.dnn.blobFromImage(image) def infer(self, image): - assert image.shape[0] == self.input_size[1], '{} (height of input image) != {} (preset height)'.format(image.shape[0], self.input_size[1]) - assert image.shape[1] == self.input_size[0], '{} (width of input image) != {} (preset width)'.format(image.shape[1], self.input_size[0]) + assert image.shape[0] == self.input_size[1], \ + f"{image.shape[0]} (height of input image) != {self.input_size[1]} (preset height)" + assert image.shape[1] == self.input_size[0], \ + f"{image.shape[1]} (width of input image) != {self.input_size[0]} (preset width)" # Preprocess inputBlob = self._preprocess(image) @@ -58,26 +61,44 @@ def infer(self, image): # Postprocess results = self._postprocess(outputBlob) - return results def _postprocess(self, blob): - # Decode + # Decode outputs dets = self._decode(blob) - # NMS + # dets shape: [x1,y1,x2,y2,x3,y3,x4,y4,score] + pts = dets[:, :-1].reshape(-1, 8) # N x 8 corners + scores = dets[:, -1].astype(float).tolist() + + # Convert corners → [x,y,w,h] for NMS + bboxes = [] + for p in pts: + xs = p[0::2] + ys = p[1::2] + x_min, y_min = float(xs.min()), float(ys.min()) + w, h = float(xs.max() - x_min), float(ys.max() - y_min) + bboxes.append([x_min, y_min, w, h]) + keepIdx = cv.dnn.NMSBoxes( - bboxes=dets[:, 0:4].tolist(), - scores=dets[:, -1].tolist(), + bboxes=bboxes, + scores=scores, score_threshold=self.confidence_threshold, nms_threshold=self.nms_threshold, top_k=self.top_k - ) # box_num x class_num - if len(keepIdx) > 0: - dets = dets[keepIdx] - return dets[:self.keep_top_k] - else: - return np.empty(shape=(0, 9)) + ) + + # 
Normalize keepIdx across OpenCV versions + if isinstance(keepIdx, tuple): + keepIdx = keepIdx[0] + if len(keepIdx) == 0: + return np.empty((0, dets.shape[1]), dtype=dets.dtype) + + keepIdx = np.array(keepIdx).reshape(-1) + + # Keep original quadrilateral detections + dets = dets[keepIdx] + return dets[:self.keep_top_k] def _priorGen(self): w, h = self.input_size @@ -98,36 +119,33 @@ def _priorGen(self): priors = [] for k, f in enumerate(feature_maps): min_sizes = self.min_sizes[k] - for i, j in product(range(f[0]), range(f[1])): # i->h, j->w + for i, j in product(range(f[0]), range(f[1])): # i→h, j→w for min_size in min_sizes: s_kx = min_size / w s_ky = min_size / h - cx = (j + 0.5) * self.steps[k] / w cy = (i + 0.5) * self.steps[k] / h - priors.append([cx, cy, s_kx, s_ky]) self.priors = np.array(priors, dtype=np.float32) def _decode(self, blob): loc, conf, iou = blob + # get score cls_scores = conf[:, 1] iou_scores = iou[:, 0] + # clamp - _idx = np.where(iou_scores < 0.) - iou_scores[_idx] = 0. - _idx = np.where(iou_scores > 1.) - iou_scores[_idx] = 1. + iou_scores = np.clip(iou_scores, 0., 1.) scores = np.sqrt(cls_scores * iou_scores) scores = scores[:, np.newaxis] scale = self.input_size - # get four corner points for bounding box + # get four corner points bboxes = np.hstack(( - (self.priors[:, 0:2] + loc[:, 4: 6] * self.variance[0] * self.priors[:, 2:4]) * scale, - (self.priors[:, 0:2] + loc[:, 6: 8] * self.variance[0] * self.priors[:, 2:4]) * scale, + (self.priors[:, 0:2] + loc[:, 4:6] * self.variance[0] * self.priors[:, 2:4]) * scale, + (self.priors[:, 0:2] + loc[:, 6:8] * self.variance[0] * self.priors[:, 2:4]) * scale, (self.priors[:, 0:2] + loc[:, 10:12] * self.variance[0] * self.priors[:, 2:4]) * scale, (self.priors[:, 0:2] + loc[:, 12:14] * self.variance[0] * self.priors[:, 2:4]) * scale )) From d43c91a4344361f80878d5f64cff80e799002a48 Mon Sep 17 00:00:00 2001 From: Hariom_Nagar Date: Mon, 22 Sep 2025 21:46:41 +0530 Subject: [PATCH 2/2] Quantize NafNet (deblurring): add ORT pipeline, validation tool, and robust quantization handling --- .gitignore | 5 + .../validate_quantization.py | 156 ++++++++++++++++++ tools/quantize/quantize-ort.py | 92 +++++++++-- tools/quantize/transform.py | 14 +- 4 files changed, 250 insertions(+), 17 deletions(-) create mode 100644 models/deblurring_nafnet/validate_quantization.py diff --git a/.gitignore b/.gitignore index 2df6ebfd..5c998d8d 100644 --- a/.gitignore +++ b/.gitignore @@ -7,3 +7,8 @@ build/ **/build **/build/** + +# Generated quantized models and outputs (do not commit) +models/deblurring_nafnet/*_int8.onnx +models/deblurring_nafnet/*_output.png +tools/quantize/*-opset13.onnx diff --git a/models/deblurring_nafnet/validate_quantization.py b/models/deblurring_nafnet/validate_quantization.py new file mode 100644 index 00000000..2bca7fb7 --- /dev/null +++ b/models/deblurring_nafnet/validate_quantization.py @@ -0,0 +1,156 @@ +# This file is part of OpenCV Zoo project. +# It is subject to the license terms in the LICENSE file found in the same directory. +# +# Quantization validation utility for deblurring_nafnet +# Runs FP32 and INT8 ONNX models with identical preprocessing and reports timing, PSNR, and SSIM. 
+ +import argparse +import time +from typing import Tuple + +import cv2 as cv +import numpy as np +import onnxruntime as ort + + +def get_args_parser(func_args): + parser = argparse.ArgumentParser(add_help=False) + parser.add_argument('--input', help='Path to input image.', default='example_outputs/licenseplate_motion.jpg', required=False) + parser.add_argument('--model_fp32', help='Path to FP32 ONNX model', default='deblurring_nafnet_2025may.onnx', required=False) + parser.add_argument('--model_int8', help='Path to INT8 (quantized) ONNX model', default='deblurring_nafnet_2025may_int8.onnx', required=False) + parser.add_argument('--save_outputs', action='store_true', help='Save output images next to models') + + args, _ = parser.parse_known_args(func_args) + parser = argparse.ArgumentParser(parents=[parser], description='Validate quantization for deblurring_nafnet', formatter_class=argparse.RawTextHelpFormatter) + return parser.parse_args(func_args) + + +def preprocess(image: np.ndarray) -> np.ndarray: + # Match nafnet.py: blobFromImage(image, 1/255, (W,H), mean=(0,0,0), swapRB=True, crop=False) + blob = cv.dnn.blobFromImage(image, scalefactor=1.0/255.0, size=(image.shape[1], image.shape[0]), mean=(0, 0, 0), swapRB=True, crop=False) + return blob.astype(np.float32) + + +def postprocess(output: np.ndarray) -> np.ndarray: + # output expected shape: (1, C, H, W) + result = output[0] + result = np.transpose(result, (1, 2, 0)) + result = np.clip(result * 255.0, 0, 255).astype(np.uint8) + result = cv.cvtColor(result, cv.COLOR_RGB2BGR) + return result + + +def run_onnx(session: ort.InferenceSession, input_blob: np.ndarray) -> Tuple[np.ndarray, float]: + input_name = session.get_inputs()[0].name + tm = cv.TickMeter() + tm.start() + output_list = session.run(None, {input_name: input_blob}) + tm.stop() + return output_list[0], tm.getTimeMilli() + + +def compute_psnr(img1: np.ndarray, img2: np.ndarray) -> float: + return cv.PSNR(img1, img2) + + +def compute_ssim(img1: np.ndarray, img2: np.ndarray) -> float: + # Simple SSIM implementation averaged across channels + C1 = (0.01 * 255) ** 2 + C2 = (0.03 * 255) ** 2 + + if img1.ndim == 2: + img1 = img1[..., None] + if img2.ndim == 2: + img2 = img2[..., None] + + img1 = img1.astype(np.float64) + img2 = img2.astype(np.float64) + + ssim_per_channel = [] + for ch in range(img1.shape[2]): + x = img1[:, :, ch] + y = img2[:, :, ch] + mu_x = cv.GaussianBlur(x, (11, 11), 1.5) + mu_y = cv.GaussianBlur(y, (11, 11), 1.5) + mu_x_mu_y = mu_x * mu_y + mu_x_sq = mu_x * mu_x + mu_y_sq = mu_y * mu_y + + sigma_x_sq = cv.GaussianBlur(x * x, (11, 11), 1.5) - mu_x_sq + sigma_y_sq = cv.GaussianBlur(y * y, (11, 11), 1.5) - mu_y_sq + sigma_xy = cv.GaussianBlur(x * y, (11, 11), 1.5) - mu_x_mu_y + + numerator = (2 * mu_x_mu_y + C1) * (2 * sigma_xy + C2) + denominator = (mu_x_sq + mu_y_sq + C1) * (sigma_x_sq + sigma_y_sq + C2) + ssim_map = numerator / (denominator + 1e-12) + ssim_per_channel.append(ssim_map.mean()) + + return float(np.mean(ssim_per_channel)) + + +def main(func_args=None): + args = get_args_parser(func_args) + + # Load image + image = cv.imread(args.input) + if image is None: + raise FileNotFoundError(f'Failed to read input image: {args.input}') + + # Preprocess + blob = preprocess(image) + + # Create ORT sessions + sess_options = ort.SessionOptions() + providers = ['CPUExecutionProvider'] + + fp32_sess = ort.InferenceSession(args.model_fp32, sess_options=sess_options, providers=providers) + int8_sess = ort.InferenceSession(args.model_int8, sess_options=sess_options, 
providers=providers) + + # Run inference + fp32_out, t_fp32 = run_onnx(fp32_sess, blob) + int8_out, t_int8 = run_onnx(int8_sess, blob) + + # Postprocess + fp32_img = postprocess(fp32_out) + int8_img = postprocess(int8_out) + + # Metrics + psnr = compute_psnr(fp32_img, int8_img) + ssim = compute_ssim(fp32_img, int8_img) + + # Display + label_fp32 = f'FP32 time: {t_fp32:.2f} ms' + label_int8 = f'INT8 time: {t_int8:.2f} ms\nPSNR(fp32,int8): {psnr:.2f} dB\nSSIM(fp32,int8): {ssim:.4f}' + + fp32_disp = fp32_img.copy() + int8_disp = int8_img.copy() + cv.putText(fp32_disp, label_fp32, (8, 22), cv.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2) + cv.putText(fp32_disp, label_fp32, (8, 22), cv.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1) + + y = 22 + for line in label_int8.split('\n'): + cv.putText(int8_disp, line, (8, y), cv.FONT_HERSHEY_SIMPLEX, 0.6, (0, 0, 0), 2) + cv.putText(int8_disp, line, (8, y), cv.FONT_HERSHEY_SIMPLEX, 0.6, (255, 255, 255), 1) + y += 22 + + cv.imshow('Input', image) + cv.imshow('Output FP32', fp32_disp) + cv.imshow('Output INT8', int8_disp) + cv.waitKey(0) + cv.destroyAllWindows() + + if args.save_outputs: + base_fp32 = args.model_fp32.rsplit('.', 1)[0] + base_int8 = args.model_int8.rsplit('.', 1)[0] + cv.imwrite(base_fp32 + '_output.png', fp32_img) + cv.imwrite(base_int8 + '_output.png', int8_img) + + print('--- Quantization Validation ---') + print(f'FP32 model: {args.model_fp32}\n Inference time: {t_fp32:.2f} ms') + print(f'INT8 model: {args.model_int8}\n Inference time: {t_int8:.2f} ms') + print(f'PSNR (FP32 vs INT8): {psnr:.2f} dB') + print(f'SSIM (FP32 vs INT8): {ssim:.4f}') + + +if __name__ == '__main__': + main() diff --git a/tools/quantize/quantize-ort.py b/tools/quantize/quantize-ort.py index aba57f71..41c69b9a 100644 --- a/tools/quantize/quantize-ort.py +++ b/tools/quantize/quantize-ort.py @@ -12,17 +12,26 @@ import onnx from onnx import version_converter import onnxruntime -from onnxruntime.quantization import quantize_static, CalibrationDataReader, QuantType, QuantFormat, quant_pre_process +from onnxruntime.quantization import ( + quantize_static, + CalibrationDataReader, + QuantType, + QuantFormat, + quant_pre_process, + CalibrationMethod, +) from transform import Compose, Resize, CenterCrop, Normalize, ColorConvert, HandAlign class DataReader(CalibrationDataReader): - def __init__(self, model_path, image_dir, transforms, data_dim): + def __init__(self, model_path, image_dir, transforms, data_dim, max_samples=None): model = onnx.load(model_path) self.input_name = model.graph.input[0].name self.transforms = transforms self.data_dim = data_dim self.data = self.get_calibration_data(image_dir) + if max_samples is not None and len(self.data) > max_samples: + self.data = self.data[:max_samples] self.enum_data_dicts = iter([{self.input_name: x} for x in self.data]) def get_next(self): @@ -46,7 +55,7 @@ def get_calibration_data(self, image_dir): return blobs class Quantize: - def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='int8', wt_type='int8', data_dim='chw', nodes_to_exclude=[]): + def __init__(self, model_path, calibration_image_dir, transforms=Compose(), per_channel=False, act_type='uint8', wt_type='int8', data_dim='chw', nodes_to_exclude=[], max_samples=None, quant_format=QuantFormat.QDQ, op_types_to_quantize=None): self.type_dict = {"uint8" : QuantType.QUInt8, "int8" : QuantType.QInt8} self.model_path = model_path @@ -56,14 +65,18 @@ def __init__(self, model_path, calibration_image_dir, transforms=Compose(), 
per_ self.act_type = act_type self.wt_type = wt_type self.nodes_to_exclude = nodes_to_exclude + self.data_dim = data_dim + self.max_samples = max_samples + self.quant_format = quant_format + self.op_types_to_quantize = op_types_to_quantize - # data reader - self.dr = DataReader(self.model_path, self.calibration_image_dir, self.transforms, data_dim) + # DataReader is created in run() to avoid touching all model paths at import time def check_opset(self): model = onnx.load(self.model_path) - if model.opset_import[0].version != 13: - print('\tmodel opset version: {}. Converting to opset 13'.format(model.opset_import[0].version)) + current_opset = model.opset_import[0].version + if current_opset < 13: + print('\tmodel opset version: {}. Converting to opset 13'.format(current_opset)) # convert opset version to 13 model_opset13 = version_converter.convert_version(model, 13) # save converted model @@ -71,19 +84,57 @@ def check_opset(self): onnx.save_model(model_opset13, output_name) # update model_path for quantization return output_name + # opset >= 13: keep as-is to avoid conversion errors with newer opsets (e.g., blocked quantization) return self.model_path def run(self): print('Quantizing {}: act_type {}, wt_type {}'.format(self.model_path, self.act_type, self.wt_type)) new_model_path = self.check_opset() - quant_pre_process(new_model_path, new_model_path) + # quant_pre_process may require onnxruntime-extensions; make it optional + try: + quant_pre_process(new_model_path, new_model_path) + except Exception as e: + print('\tquant_pre_process skipped:', e) output_name = '{}_{}.onnx'.format(self.model_path[:-5], self.wt_type) - quantize_static(new_model_path, output_name, self.dr, - quant_format=QuantFormat.QOperator, # start from onnxruntime==1.11.0, quant_format is set to QuantFormat.QDQ by default, which performs fake quantization - per_channel=self.per_channel, - weight_type=self.type_dict[self.wt_type], - activation_type=self.type_dict[self.act_type], - nodes_to_exclude=self.nodes_to_exclude) + # Create data reader lazily here so only selected model accesses its calibration data + dr = DataReader(new_model_path, self.calibration_image_dir, self.transforms, self.data_dim, max_samples=self.max_samples) + # Auto-detect Conv nodes without bias and exclude them to avoid bias=None errors + try: + model_for_scan = onnx.load(new_model_path) + initializer_names = {init.name for init in model_for_scan.graph.initializer} + convs_without_bias = [] + for node in model_for_scan.graph.node: + if node.op_type == 'Conv': + # Conv inputs: [X, W] or [X, W, B] + if len(node.input) < 3 or node.input[2] not in initializer_names: + if node.name: + convs_without_bias.append(node.name) + if convs_without_bias: + print('\tAuto-excluding Conv nodes without bias:', convs_without_bias[:5], ('... 
(+%d more)' % (len(convs_without_bias)-5) if len(convs_without_bias)>5 else ''))
+            combined_excludes = list(set(self.nodes_to_exclude + convs_without_bias))
+        except Exception:
+            combined_excludes = self.nodes_to_exclude
+        try:
+            quantize_static(new_model_path, output_name, dr,
+                            quant_format=self.quant_format,
+                            per_channel=self.per_channel,
+                            weight_type=self.type_dict[self.wt_type],
+                            activation_type=self.type_dict[self.act_type],
+                            nodes_to_exclude=combined_excludes,
+                            calibrate_method=CalibrationMethod.MinMax,
+                            op_types_to_quantize=self.op_types_to_quantize)
+        except Exception as e:
+            print('\tPrimary quantization failed:', e)
+            print('\tRetrying with fallback: QuantFormat.QDQ and Conv-only')
+            # Fallback to a more robust configuration
+            quantize_static(new_model_path, output_name, dr,
+                            quant_format=QuantFormat.QDQ,
+                            per_channel=True,
+                            weight_type=self.type_dict[self.wt_type],
+                            activation_type=self.type_dict[self.act_type],
+                            nodes_to_exclude=combined_excludes,
+                            calibrate_method=CalibrationMethod.MinMax,
+                            op_types_to_quantize=['Conv'])
         if new_model_path != self.model_path:
             os.remove(new_model_path)
         print('\tQuantized model saved to {}'.format(output_name))
@@ -134,6 +185,19 @@ def run(self):
                        transforms=Compose([Resize(size=(320, 240))]),
                        nodes_to_exclude=['MaxPool_5', 'MaxPool_18', 'MaxPool_25', 'MaxPool_32', 'MaxPool_39'],
                        ),
+    deblurring_nafnet=Quantize(model_path='../../models/deblurring_nafnet/deblurring_nafnet_2025may.onnx',
+                               calibration_image_dir='../../models/deblurring_nafnet/example_outputs',
+                               transforms=Compose([
+                                   Resize(size=(512, 512)),
+                                   Normalize(std=[255, 255, 255]),
+                                   ColorConvert(ctype=cv.COLOR_BGR2RGB)
+                               ]),
+                               per_channel=True,
+                               act_type='uint8',
+                               wt_type='int8',
+                               max_samples=1,
+                               quant_format=QuantFormat.QOperator,
+                               op_types_to_quantize=['Conv', 'MatMul']),
 )
 
 if __name__ == '__main__':
diff --git a/tools/quantize/transform.py b/tools/quantize/transform.py
index 10d97521..a4b1850f 100644
--- a/tools/quantize/transform.py
+++ b/tools/quantize/transform.py
@@ -65,11 +65,19 @@ def __call__(self, img):
 class HandAlign:
     def __init__(self, model):
         self.model = model
-        sys.path.append('../../models/palm_detection_mediapipe')
-        from mp_palmdet import MPPalmDet
-        self.palm_detector = MPPalmDet(modelPath='../../models/palm_detection_mediapipe/palm_detection_mediapipe_2023feb.onnx', nmsThreshold=0.3, scoreThreshold=0.9)
+        # Lazy initialization to avoid loading dependencies when not used
+        self.palm_detector = None
 
     def __call__(self, img):
+        # Initialize palm detector on first use only
+        if self.palm_detector is None:
+            sys.path.append('../../models/palm_detection_mediapipe')
+            from mp_palmdet import MPPalmDet
+            self.palm_detector = MPPalmDet(
+                modelPath='../../models/palm_detection_mediapipe/palm_detection_mediapipe_2023feb.onnx',
+                nmsThreshold=0.3,
+                scoreThreshold=0.9
+            )
         return self.mp_handpose_align(img)
 
     def mp_handpose_align(self, img):
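
Illustrative sanity check (not part of either patch): the standalone Python snippet below mirrors the corner-to-[x, y, w, h] conversion and the keepIdx normalization introduced in PATCH 1/2, using made-up quad detections and the model's default thresholds, to confirm that cv.dnn.NMSBoxes accepts this input on both older and newer OpenCV versions.

import numpy as np
import cv2 as cv

# Two overlapping fake detections in the decoder's output layout:
# [x1, y1, x2, y2, x3, y3, x4, y4, score]
dets = np.array([
    [10, 10, 60, 12, 58, 40, 12, 38, 0.95],
    [12, 11, 62, 13, 60, 41, 14, 39, 0.90],
], dtype=np.float32)

pts = dets[:, :-1].reshape(-1, 8)
scores = dets[:, -1].astype(float).tolist()

# Corner points -> axis-aligned [x, y, w, h] rectangles, as in _postprocess
bboxes = []
for p in pts:
    xs, ys = p[0::2], p[1::2]
    x_min, y_min = float(xs.min()), float(ys.min())
    bboxes.append([x_min, y_min, float(xs.max() - x_min), float(ys.max() - y_min)])

keepIdx = cv.dnn.NMSBoxes(bboxes=bboxes, scores=scores,
                          score_threshold=0.8, nms_threshold=0.3, top_k=5000)
if isinstance(keepIdx, tuple):   # older OpenCV bindings return a tuple
    keepIdx = keepIdx[0]
keepIdx = np.array(keepIdx).reshape(-1)
print(dets[keepIdx][:, :8])      # surviving quadrilateral corners

For PATCH 2/2, assuming the tool's existing select-by-name invocation, the new entry would be exercised with "python quantize-ort.py deblurring_nafnet" from tools/quantize, and the resulting *_int8.onnx compared against the FP32 model with "python validate_quantization.py --model_fp32 deblurring_nafnet_2025may.onnx --model_int8 deblurring_nafnet_2025may_int8.onnx --save_outputs" from models/deblurring_nafnet.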