From eafd9a71960bb57e3ce9dcbae40d772e3ab85f17 Mon Sep 17 00:00:00 2001 From: changyixing <2679896825@qq.com> Date: Thu, 12 Mar 2026 15:33:37 +0800 Subject: [PATCH] Optimize: Kimi version - enterprise-grade refactor with robust error handling and UI decoupling --- python-package/insightface/app/common.py | 101 +++- .../insightface/app/face_analysis.py | 514 +++++++++++++++--- python-package/insightface/utils/__init__.py | 1 + .../insightface/utils/visualization.py | 238 ++++++++ 4 files changed, 774 insertions(+), 80 deletions(-) create mode 100644 python-package/insightface/utils/visualization.py diff --git a/python-package/insightface/app/common.py b/python-package/insightface/app/common.py index 82ca987ae..927e51032 100644 --- a/python-package/insightface/app/common.py +++ b/python-package/insightface/app/common.py @@ -1,22 +1,57 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : Face data structure + +from typing import Any, Dict, Optional, List, Union, Tuple import numpy as np from numpy.linalg import norm as l2norm -#from easydict import EasyDict + +__all__ = ['Face'] + class Face(dict): + """Face data container that stores all face-related information. + + This class extends dict to provide both dictionary-like and attribute-like + access to face data. It can store various face attributes such as: + - bbox: Bounding box coordinates [x1, y1, x2, y2] + - det_score: Detection confidence score + - kps: Facial keypoints + - embedding: Face feature vector + - gender: Gender prediction (0=female, 1=male) + - age: Age prediction + - pose: Head pose angles [pitch, yaw, roll] + - Various landmark predictions + + Attributes are dynamically set and accessed. Missing attributes return None + instead of raising AttributeError. + + Example: + >>> face = Face(bbox=np.array([100, 100, 200, 200]), det_score=0.95) + >>> print(face.bbox) # Attribute access + >>> print(face['bbox']) # Dictionary access + >>> face.gender = 1 # Set attribute + >>> print(face.sex) # Computed property + """ - def __init__(self, d=None, **kwargs): + def __init__(self, d: Optional[Dict[str, Any]] = None, **kwargs): + """Initialize a Face object. + + Args: + d: Dictionary of initial face attributes. + **kwargs: Additional face attributes as keyword arguments. + """ if d is None: d = {} if kwargs: d.update(**kwargs) for k, v in d.items(): setattr(self, k, v) - # Class attributes - #for k in self.__class__.__dict__.keys(): - # if not (k.startswith('__') and k.endswith('__')) and not k in ('update', 'pop'): - # setattr(self, k, getattr(self, k)) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any) -> None: + """Set attribute with automatic conversion of nested dicts to Face objects.""" if isinstance(value, (list, tuple)): value = [self.__class__(x) if isinstance(x, dict) else x for x in value] @@ -27,23 +62,63 @@ def __setattr__(self, name, value): __setitem__ = __setattr__ - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: + """Get attribute, return None if not found.""" return None @property - def embedding_norm(self): + def embedding_norm(self) -> Optional[float]: + """Compute L2 norm of the face embedding. + + Returns: + L2 norm of embedding, or None if embedding is not set. + """ if self.embedding is None: return None return l2norm(self.embedding) @property - def normed_embedding(self): + def normed_embedding(self) -> Optional[np.ndarray]: + """Get L2-normalized embedding vector. + + Returns: + Normalized embedding vector, or None if embedding is not set. + """ if self.embedding is None: return None - return self.embedding / self.embedding_norm + norm_val = self.embedding_norm + if norm_val is None or norm_val == 0: + return None + return self.embedding / norm_val @property - def sex(self): + def sex(self) -> Optional[str]: + """Get gender as string representation. + + Returns: + 'M' for male (gender=1), 'F' for female (gender=0), or None if not set. + """ if self.gender is None: return None - return 'M' if self.gender==1 else 'F' + return 'M' if self.gender == 1 else 'F' + + def to_dict(self) -> Dict[str, Any]: + """Convert Face object to a regular dictionary. + + Returns: + Dictionary containing all face attributes. + """ + result = {} + for key, value in self.items(): + if isinstance(value, np.ndarray): + result[key] = value.copy() + elif isinstance(value, Face): + result[key] = value.to_dict() + elif isinstance(value, list): + result[key] = [ + v.to_dict() if isinstance(v, Face) else v + for v in value + ] + else: + result[key] = value + return result diff --git a/python-package/insightface/app/face_analysis.py b/python-package/insightface/app/face_analysis.py index a9112b14a..776b9466f 100644 --- a/python-package/insightface/app/face_analysis.py +++ b/python-package/insightface/app/face_analysis.py @@ -2,13 +2,15 @@ # @Organization : insightface.ai # @Author : Jia Guo # @Time : 2021-05-04 -# @Function : - +# @Function : Face Analysis Pipeline from __future__ import division import glob +import json +import os import os.path as osp +from typing import List, Optional, Tuple, Union, Dict, Any import numpy as np import onnxruntime @@ -16,94 +18,472 @@ from ..model_zoo import model_zoo from ..utils import DEFAULT_MP_NAME, ensure_available +from ..utils.visualization import FaceVisualizer from .common import Face __all__ = ['FaceAnalysis'] + +# Type aliases for better readability +ImageArray = np.ndarray # Expected: uint8 array with shape (H, W, 3), BGR format, values in [0, 255] +BoundingBox = np.ndarray # Shape: (4,), format: [x1, y1, x2, y2] +Keypoints = np.ndarray # Shape: (N, 2), format: [[x, y], ...] + + +def _get_onnx_model_info(onnx_file: str) -> Optional[Dict[str, Any]]: + """Get model information from ONNX file without creating InferenceSession. + + This function reads ONNX file metadata directly to determine model type + and extract configuration, avoiding the overhead of creating an InferenceSession. + + Args: + onnx_file: Path to the ONNX model file. + + Returns: + Dictionary containing model info, or None if model is not recognized. + """ + try: + # Try to use onnx to read model metadata (lightweight) + import onnx + model = onnx.load(onnx_file, load_external_data=False) + graph = model.graph + + # Get input shape from the first input + if len(graph.input) == 0: + return None + + input_tensor = graph.input[0] + input_shape = [] + for dim in input_tensor.type.tensor_type.shape.dim: + if dim.HasField('dim_value'): + input_shape.append(dim.dim_value) + else: + input_shape.append(None) + + # Get number of outputs + num_outputs = len(graph.output) + num_inputs = len(graph.input) + + # Determine model type based on input shape and output count + # This logic mirrors the ModelRouter in model_zoo.py + taskname = None + input_mean = 127.5 + input_std = 128.0 + + # Check for RetinaFace/SCRFD (detection models have many outputs) + if num_outputs >= 5: + taskname = 'detection' + # Detection models use 127.5/128.0 by default + + # Check input shape for other model types + elif len(input_shape) >= 4: + h, w = input_shape[2], input_shape[3] + + if h == 192 and w == 192: + taskname = 'landmark' + # Landmark model + elif h == 96 and w == 96: + taskname = 'genderage' + # Attribute model + elif num_inputs == 2 and h == 128 and w == 128: + taskname = 'inswapper' + # INSwapper model + elif h == w and h >= 112 and h % 16 == 0: + taskname = 'recognition' + # Recognition model - check for mxnet style + # Check first few nodes for Sub/Mul (mxnet style) + find_sub = False + find_mul = False + for nid, node in enumerate(graph.node[:8]): + if node.name.startswith('Sub') or node.name.startswith('_minus'): + find_sub = True + if node.name.startswith('Mul') or node.name.startswith('_mul'): + find_mul = True + + if find_sub and find_mul: + input_mean = 0.0 + input_std = 1.0 + + if taskname is None: + return None + + return { + 'taskname': taskname, + 'input_shape': input_shape, + 'input_mean': input_mean, + 'input_std': input_std, + 'num_outputs': num_outputs, + 'num_inputs': num_inputs, + } + + except ImportError: + # Fallback: onnx not available, return minimal info + # This should not happen in normal usage + return None + except Exception as e: + # If ONNX parsing fails, return None + return None + + class FaceAnalysis: - def __init__(self, name=DEFAULT_MP_NAME, root='~/.insightface', allowed_modules=None, **kwargs): + """Face analysis pipeline that combines detection, alignment, and feature extraction. + + This class provides a high-level interface for face analysis tasks including: + - Face detection + - Facial landmark detection + - Face recognition/embedding extraction + - Attribute prediction (gender, age) + + The class uses lazy loading for models, meaning models are only loaded into memory + when they are first needed, reducing initial memory footprint. + + Attributes: + model_dir: Directory containing ONNX model files. + det_thresh: Detection threshold for face detection. + det_size: Input size for face detection model (width, height). + + Example: + >>> import cv2 + >>> from insightface.app import FaceAnalysis + >>> + >>> # Initialize the app + >>> app = FaceAnalysis(name='buffalo_l') + >>> app.prepare(ctx_id=0, det_size=(640, 640)) + >>> + >>> # Process an image + >>> img = cv2.imread('image.jpg') # BGR format, uint8, [0, 255] + >>> faces = app.get(img) + >>> + >>> # Access results + >>> for face in faces: + ... print(f"BBox: {face.bbox}") + ... print(f"Detection score: {face.det_score}") + ... print(f"Embedding shape: {face.embedding.shape if face.embedding is not None else None}") + """ + + def __init__( + self, + name: str = DEFAULT_MP_NAME, + root: str = '~/.insightface', + allowed_modules: Optional[List[str]] = None, + **kwargs + ): + """Initialize the FaceAnalysis pipeline. + + Args: + name: Name of the model package to use. Default is DEFAULT_MP_NAME. + root: Root directory for model storage. Default is '~/.insightface'. + allowed_modules: List of module names to load. If None, loads all available modules. + Common values: 'detection', 'recognition', 'genderage', 'landmark_2d_106', etc. + **kwargs: Additional arguments passed to model initialization. + """ onnxruntime.set_default_logger_severity(3) - self.models = {} + + # Store model configurations for lazy loading + self._model_configs: Dict[str, Dict[str, Any]] = {} + self._loaded_models: Dict[str, Any] = {} + self._model_kwargs = kwargs + self.model_dir = ensure_available('models', name, root=root) + + # Scan and register available models onnx_files = glob.glob(osp.join(self.model_dir, '*.onnx')) onnx_files = sorted(onnx_files) + for onnx_file in onnx_files: - model = model_zoo.get_model(onnx_file, **kwargs) - if model is None: + model_info = self._get_model_info(onnx_file) + if model_info is None: print('model not recognized:', onnx_file) - elif allowed_modules is not None and model.taskname not in allowed_modules: - print('model ignore:', onnx_file, model.taskname) - del model - elif model.taskname not in self.models and (allowed_modules is None or model.taskname in allowed_modules): - print('find model:', onnx_file, model.taskname, model.input_shape, model.input_mean, model.input_std) - self.models[model.taskname] = model + continue + + taskname = model_info['taskname'] + + if allowed_modules is not None and taskname not in allowed_modules: + print('model ignore:', onnx_file, taskname) + continue + + if taskname in self._model_configs: + print('duplicated model task type, ignore:', onnx_file, taskname) + continue + + print('find model:', onnx_file, taskname, + model_info.get('input_shape'), + model_info.get('input_mean'), + model_info.get('input_std')) + + self._model_configs[taskname] = { + 'onnx_file': onnx_file, + 'kwargs': kwargs + } + + # Ensure detection model is available + if 'detection' not in self._model_configs: + raise ValueError("No detection model found. FaceAnalysis requires a detection model.") + + def _get_model_info(self, onnx_file: str) -> Optional[Dict[str, Any]]: + """Get model information without fully loading the model. + + This method uses a fast path to read ONNX metadata without creating + an InferenceSession, significantly improving initialization speed. + + Args: + onnx_file: Path to the ONNX model file. + + Returns: + Dictionary containing model info, or None if model is not recognized. + """ + # First try the fast path using onnx directly + info = _get_onnx_model_info(onnx_file) + if info is not None: + return info + + # Fallback: use model_zoo if onnx parsing fails + # This is slower but more robust + try: + model = model_zoo.get_model(onnx_file, **self._model_kwargs) + if model is None: + return None + + # Extract relevant info and delete model + info = { + 'taskname': model.taskname, + 'input_shape': getattr(model, 'input_shape', None), + 'input_mean': getattr(model, 'input_mean', None), + 'input_std': getattr(model, 'input_std', None), + } + del model + return info + except Exception as e: + print(f'Error loading model info for {onnx_file}: {e}') + return None + + def _load_model(self, taskname: str) -> Any: + """Lazy load a model by taskname. + + Args: + taskname: Name of the task/model to load. + + Returns: + The loaded model instance. + + Raises: + KeyError: If the taskname is not registered. + """ + if taskname in self._loaded_models: + return self._loaded_models[taskname] + + if taskname not in self._model_configs: + raise KeyError(f"Model for task '{taskname}' not found") + + config = self._model_configs[taskname] + model = model_zoo.get_model(config['onnx_file'], **config['kwargs']) + + # Prepare model if ctx_id was set + if hasattr(self, '_ctx_id'): + if taskname == 'detection': + model.prepare( + self._ctx_id, + input_size=getattr(self, 'det_size', None), + det_thresh=getattr(self, 'det_thresh', 0.5) + ) else: - print('duplicated model task type, ignore:', onnx_file, model.taskname) - del model - assert 'detection' in self.models - self.det_model = self.models['detection'] - - - def prepare(self, ctx_id, det_thresh=0.5, det_size=(640, 640)): + model.prepare(self._ctx_id) + + self._loaded_models[taskname] = model + return model + + @property + def det_model(self) -> Any: + """Get the detection model (lazy loaded).""" + return self._load_model('detection') + + @property + def models(self) -> Dict[str, Any]: + """Get all loaded models.""" + # Ensure detection model is loaded + self._load_model('detection') + return self._loaded_models + + def prepare( + self, + ctx_id: int, + det_thresh: float = 0.5, + det_size: Tuple[int, int] = (640, 640) + ) -> None: + """Prepare the pipeline with execution context and detection parameters. + + Args: + ctx_id: Execution provider ID. Use >= 0 for GPU, -1 for CPU. + det_thresh: Detection threshold. Faces with scores below this value are filtered out. + Default is 0.5. + det_size: Input size for detection model as (width, height). Default is (640, 640). + Larger sizes detect smaller faces but are slower. + """ self.det_thresh = det_thresh - assert det_size is not None - print('set det-size:', det_size) self.det_size = det_size - for taskname, model in self.models.items(): - if taskname=='detection': + self._ctx_id = ctx_id + + assert det_size is not None, "det_size cannot be None" + print('set det-size:', det_size) + + # Prepare already loaded models + for taskname, model in self._loaded_models.items(): + if taskname == 'detection': model.prepare(ctx_id, input_size=det_size, det_thresh=det_thresh) else: model.prepare(ctx_id) - - def get(self, img, max_num=0, det_metric='default'): - bboxes, kpss = self.det_model.detect(img, - max_num=max_num, - metric=det_metric) - if bboxes.shape[0] == 0: + + def get( + self, + img: ImageArray, + max_num: int = 0, + det_metric: str = 'default' + ) -> List[Face]: + """Detect and analyze faces in an image. + + Args: + img: Input image as a numpy array. + - Format: BGR (OpenCV default) + - Shape: (height, width, 3) + - Dtype: uint8 + - Value range: [0, 255] + max_num: Maximum number of faces to return. If 0, returns all detected faces. + Default is 0. + det_metric: Metric for selecting faces when max_num is specified. + - 'default': Uses area minus distance from center (prefers larger, centered faces) + - 'max': Uses only face area + Default is 'default'. + + Returns: + List of Face objects containing detection and analysis results. + Each Face object may contain: + - bbox: Bounding box [x1, y1, x2, y2] + - det_score: Detection confidence score + - kps: Facial keypoints + - embedding: Face feature vector + - gender: Gender prediction (0=female, 1=male) + - age: Age prediction + - Various landmark predictions depending on loaded models + + Raises: + ValueError: If img is not a valid numpy array. + """ + # Validate input + if not isinstance(img, np.ndarray): + raise ValueError(f"img must be a numpy array, got {type(img)}") + + if img.ndim != 3: + raise ValueError(f"img must have 3 dimensions (H, W, C), got shape {img.shape}") + + if img.shape[2] != 3: + raise ValueError(f"img must have 3 channels (BGR), got shape {img.shape}") + + # Detect faces + bboxes, kpss = self.det_model.detect( + img, + max_num=max_num, + metric=det_metric + ) + + # Handle empty detection + if bboxes is None or bboxes.shape[0] == 0: return [] + + # Process each detected face ret = [] for i in range(bboxes.shape[0]): - bbox = bboxes[i, 0:4] - det_score = bboxes[i, 4] + # Safely extract bounding box with bounds checking + bbox = self._safe_extract_bbox(bboxes, i) + det_score = self._safe_extract_score(bboxes, i) + + # Safely extract keypoints if available kps = None - if kpss is not None: + if kpss is not None and i < kpss.shape[0]: kps = kpss[i] + + # Create Face object face = Face(bbox=bbox, kps=kps, det_score=det_score) - for taskname, model in self.models.items(): - if taskname=='detection': + + # Apply other models (lazy loaded) + for taskname in self._model_configs: + if taskname == 'detection': continue - model.get(img, face) + try: + model = self._load_model(taskname) + model.get(img, face) + except Exception as e: + # Log error but continue processing other models + print(f"Error applying model '{taskname}': {e}") + ret.append(face) + return ret - - def draw_on(self, img, faces): - import cv2 - dimg = img.copy() - for i in range(len(faces)): - face = faces[i] - box = face.bbox.astype(int) - color = (0, 0, 255) - cv2.rectangle(dimg, (box[0], box[1]), (box[2], box[3]), color, 2) - if face.kps is not None: - kps = face.kps.astype(int) - #print(landmark.shape) - for l in range(kps.shape[0]): - color = (0, 0, 255) - if l == 0 or l == 3: - color = (0, 255, 0) - cv2.circle(dimg, (kps[l][0], kps[l][1]), 1, color, - 2) - if face.gender is not None and face.age is not None: - cv2.putText(dimg,'%s,%d'%(face.sex,face.age), (box[0]-1, box[1]-4),cv2.FONT_HERSHEY_COMPLEX,0.7,(0,255,0),1) - - #for key, value in face.items(): - # if key.startswith('landmark_3d'): - # print(key, value.shape) - # print(value[0:10,:]) - # lmk = np.round(value).astype(int) - # for l in range(lmk.shape[0]): - # color = (255, 0, 0) - # cv2.circle(dimg, (lmk[l][0], lmk[l][1]), 1, color, - # 2) - return dimg - + + def _safe_extract_bbox(self, bboxes: np.ndarray, index: int) -> Optional[np.ndarray]: + """Safely extract bounding box from detection results. + + Args: + bboxes: Array of bounding boxes with shape (N, 5+) where each row is [x1, y1, x2, y2, score, ...] + index: Index of the bounding box to extract. + + Returns: + Bounding box array [x1, y1, x2, y2] or None if extraction fails. + """ + try: + if index < 0 or index >= bboxes.shape[0]: + return None + if bboxes.shape[1] < 4: + return None + return bboxes[index, :4].copy() + except (IndexError, ValueError): + return None + + def _safe_extract_score(self, bboxes: np.ndarray, index: int) -> Optional[float]: + """Safely extract detection score from detection results. + + Args: + bboxes: Array of bounding boxes with shape (N, 5+) where each row is [x1, y1, x2, y2, score, ...] + index: Index of the bounding box to extract. + + Returns: + Detection score as float or None if extraction fails. + """ + try: + if index < 0 or index >= bboxes.shape[0]: + return None + if bboxes.shape[1] < 5: + return None + return float(bboxes[index, 4]) + except (IndexError, ValueError, TypeError): + return None + + def draw_on( + self, + img: ImageArray, + faces: List[Face], + **kwargs + ) -> np.ndarray: + """Draw face analysis results on an image. + + .. deprecated:: + This method is deprecated. Use insightface.utils.visualization.draw_on() + or insightface.utils.visualization.FaceVisualizer instead. + + Args: + img: Input image in BGR format with shape (H, W, 3). + Expected dtype is uint8 with values in range [0, 255]. + faces: List of Face objects to draw. + **kwargs: Additional arguments passed to FaceVisualizer. + + Returns: + A copy of the input image with drawings. + """ + import warnings + warnings.warn( + "FaceAnalysis.draw_on() is deprecated. " + "Use insightface.utils.visualization.draw_on() or FaceVisualizer instead.", + DeprecationWarning, + stacklevel=2 + ) + visualizer = FaceVisualizer(**kwargs) + return visualizer.draw_on(img, faces) diff --git a/python-package/insightface/utils/__init__.py b/python-package/insightface/utils/__init__.py index d3ccb4e65..5a778f73a 100644 --- a/python-package/insightface/utils/__init__.py +++ b/python-package/insightface/utils/__init__.py @@ -10,6 +10,7 @@ from .filesystem import get_model_dir from .filesystem import makedirs, try_import_dali from .constant import * +from .visualization import FaceVisualizer, draw_on #from .bbox import bbox_iou #from .block import recursive_visit, set_lr_mult, freeze_bn #from .lr_scheduler import LRSequential, LRScheduler diff --git a/python-package/insightface/utils/visualization.py b/python-package/insightface/utils/visualization.py new file mode 100644 index 000000000..3dea9af53 --- /dev/null +++ b/python-package/insightface/utils/visualization.py @@ -0,0 +1,238 @@ +# -*- coding: utf-8 -*- +# @Organization : insightface.ai +# @Author : Jia Guo +# @Time : 2021-05-04 +# @Function : Visualization utilities for face analysis results + +from typing import List, Optional, Tuple, Union +import numpy as np + +try: + import cv2 + CV2_AVAILABLE = True +except ImportError: + CV2_AVAILABLE = False + + +class FaceVisualizer: + """Visualizer for face detection and analysis results. + + This class provides methods to draw face detection results on images, + including bounding boxes, keypoints, and attributes. + + Attributes: + bbox_color: Color for bounding boxes in BGR format. Default is red (0, 0, 255). + keypoint_color: Default color for keypoints in BGR format. Default is red (0, 0, 255). + keypoint_special_color: Special color for specific keypoints in BGR format. Default is green (0, 255, 0). + text_color: Color for text annotations in BGR format. Default is green (0, 255, 0). + bbox_thickness: Thickness of bounding box lines. Default is 2. + keypoint_thickness: Thickness of keypoint circles. Default is 2. + keypoint_radius: Radius of keypoint circles. Default is 1. + font: Font type for text. Default is FONT_HERSHEY_COMPLEX. + font_scale: Scale factor for font. Default is 0.7. + font_thickness: Thickness of font. Default is 1. + """ + + def __init__( + self, + bbox_color: Tuple[int, int, int] = (0, 0, 255), + keypoint_color: Tuple[int, int, int] = (0, 0, 255), + keypoint_special_color: Tuple[int, int, int] = (0, 255, 0), + text_color: Tuple[int, int, int] = (0, 255, 0), + bbox_thickness: int = 2, + keypoint_thickness: int = 2, + keypoint_radius: int = 1, + font: int = None, # Will use FONT_HERSHEY_COMPLEX as default + font_scale: float = 0.7, + font_thickness: int = 1 + ): + if not CV2_AVAILABLE: + raise ImportError("OpenCV (cv2) is required for visualization. " + "Please install it with: pip install opencv-python") + + self.bbox_color = bbox_color + self.keypoint_color = keypoint_color + self.keypoint_special_color = keypoint_special_color + self.text_color = text_color + self.bbox_thickness = bbox_thickness + self.keypoint_thickness = keypoint_thickness + self.keypoint_radius = keypoint_radius + self.font = font if font is not None else cv2.FONT_HERSHEY_COMPLEX + self.font_scale = font_scale + self.font_thickness = font_thickness + + def draw_on( + self, + img: np.ndarray, + faces: List, + draw_bbox: bool = True, + draw_keypoints: bool = True, + draw_attributes: bool = True + ) -> np.ndarray: + """Draw face analysis results on an image. + + Args: + img: Input image in BGR format with shape (H, W, 3). + Expected dtype is uint8 with values in range [0, 255]. + faces: List of Face objects to draw. + draw_bbox: Whether to draw bounding boxes. Default is True. + draw_keypoints: Whether to draw facial keypoints. Default is True. + draw_attributes: Whether to draw attributes (gender/age). Default is True. + + Returns: + A copy of the input image with drawings. + + Raises: + ValueError: If img is not a valid numpy array. + ImportError: If OpenCV is not available. + """ + if not CV2_AVAILABLE: + raise ImportError("OpenCV (cv2) is required for visualization.") + + if not isinstance(img, np.ndarray): + raise ValueError(f"img must be a numpy array, got {type(img)}") + + if img.ndim != 3 or img.shape[2] != 3: + raise ValueError(f"img must have shape (H, W, 3), got {img.shape}") + + dimg = img.copy() + + for face in faces: + if face is None: + continue + + if draw_bbox and face.bbox is not None: + dimg = self._draw_bbox(dimg, face) + + if draw_keypoints and face.kps is not None: + dimg = self._draw_keypoints(dimg, face) + + if draw_attributes: + dimg = self._draw_attributes(dimg, face) + + return dimg + + def _draw_bbox(self, img: np.ndarray, face) -> np.ndarray: + """Draw bounding box for a face. + + Args: + img: Image to draw on. + face: Face object containing bbox. + + Returns: + Image with bounding box drawn. + """ + if face.bbox is None or len(face.bbox) < 4: + return img + + try: + box = face.bbox.astype(int) + # Ensure box coordinates are valid + if len(box) >= 4: + cv2.rectangle( + img, + (int(box[0]), int(box[1])), + (int(box[2]), int(box[3])), + self.bbox_color, + self.bbox_thickness + ) + except (ValueError, TypeError, IndexError): + # Skip drawing if bbox is invalid + pass + + return img + + def _draw_keypoints(self, img: np.ndarray, face) -> np.ndarray: + """Draw facial keypoints for a face. + + Args: + img: Image to draw on. + face: Face object containing kps. + + Returns: + Image with keypoints drawn. + """ + if face.kps is None: + return img + + try: + kps = face.kps.astype(int) + if kps.ndim != 2 or kps.shape[1] < 2: + return img + + for idx in range(kps.shape[0]): + # Use special color for specific keypoints (indices 0 and 3) + color = self.keypoint_special_color if idx in (0, 3) else self.keypoint_color + cv2.circle( + img, + (int(kps[idx][0]), int(kps[idx][1])), + self.keypoint_radius, + color, + self.keypoint_thickness + ) + except (ValueError, TypeError, IndexError): + # Skip drawing if keypoints are invalid + pass + + return img + + def _draw_attributes(self, img: np.ndarray, face) -> np.ndarray: + """Draw face attributes (gender and age) for a face. + + Args: + img: Image to draw on. + face: Face object containing gender and age. + + Returns: + Image with attributes drawn. + """ + if face.gender is None or face.age is None or face.bbox is None: + return img + + try: + box = face.bbox.astype(int) + if len(box) < 4: + return img + + sex_label = face.sex if face.sex is not None else 'N/A' + text = f"{sex_label},{face.age}" + + # Position text above the bounding box + text_pos = (int(box[0]) - 1, int(box[1]) - 4) + + cv2.putText( + img, + text, + text_pos, + self.font, + self.font_scale, + self.text_color, + self.font_thickness + ) + except (ValueError, TypeError, IndexError): + # Skip drawing if attributes are invalid + pass + + return img + + +def draw_on( + img: np.ndarray, + faces: List, + **kwargs +) -> np.ndarray: + """Convenience function to draw face analysis results on an image. + + This is a wrapper around FaceVisualizer for quick usage. + + Args: + img: Input image in BGR format with shape (H, W, 3). + Expected dtype is uint8 with values in range [0, 255]. + faces: List of Face objects to draw. + **kwargs: Additional arguments passed to FaceVisualizer. + + Returns: + A copy of the input image with drawings. + """ + visualizer = FaceVisualizer(**kwargs) + return visualizer.draw_on(img, faces)