From 7d85054a128ce2ecd0b2f90095562845ac111c8d Mon Sep 17 00:00:00 2001
From: avi
Date: Wed, 22 Oct 2025 18:37:04 +0000
Subject: [PATCH 1/2] updates to support multiclass and bounding box for training and rendering

---
 src/drawing.py   | 126 ++++++++++++++++++++++++++++++++++++++++++-----
 src/s2train.py   |  29 +++++++++--
 src/s3produce.py |  59 +++++++++++++++++-----
 3 files changed, 185 insertions(+), 29 deletions(-)

diff --git a/src/drawing.py b/src/drawing.py
index 27d87dd..804198d 100644
--- a/src/drawing.py
+++ b/src/drawing.py
@@ -15,23 +15,34 @@ def generate_unique_bgr_colors(n: int) -> list:
     return colors

 def draw_iqs(detectors: list[Detector], iqs: list[ImageQuery], frame: np.ndarray) -> None:
-    counting_iqs = []
-    for detector, counting_iq in zip(detectors, iqs):
+    # Separate IQs by type
+    bbox_iqs = []  # COUNT and BOUNDING_BOX modes
+    banner_iqs = []  # BINARY and MULTI_CLASS modes
+
+    for detector, iq in zip(detectors, iqs):
         detector_mode = detector.mode
-        if detector_mode == ModeEnum.COUNT:
-            counting_iqs.append(counting_iq)
+        if detector_mode in (ModeEnum.COUNT, ModeEnum.BOUNDING_BOX):
+            bbox_iqs.append((detector, iq))
+        elif detector_mode in (ModeEnum.BINARY, ModeEnum.MULTI_CLASS):
+            banner_iqs.append((detector, iq))
         else:
             raise NotImplementedError(
                 f'Detector mode {detector_mode} is not yet supported.'
             )
-
-    unique_bbox_colors = generate_unique_bgr_colors(len(counting_iqs))
-    for n, counting_iq in enumerate(counting_iqs):
-        color = unique_bbox_colors[n]
-        draw_bounding_boxes(counting_iq.rois, frame, color)

-    class_names = [d.mode_configuration["class_name"] for d in detectors]
-    draw_class_labels(frame, class_names, unique_bbox_colors)
+    # Draw bounding boxes for COUNT and BOUNDING_BOX modes
+    if bbox_iqs:
+        unique_bbox_colors = generate_unique_bgr_colors(len(bbox_iqs))
+        for n, (detector, iq) in enumerate(bbox_iqs):
+            color = unique_bbox_colors[n]
+            draw_bounding_boxes(iq.rois, frame, color)
+
+        class_names = [d.mode_configuration.get("class_name", f"Detector {i+1}") for i, (d, _) in enumerate(bbox_iqs)]
+        draw_class_labels(frame, class_names, unique_bbox_colors)
+
+    # Draw banners for BINARY and MULTI_CLASS modes
+    if banner_iqs:
+        draw_banners(banner_iqs, frame)

 def draw_bounding_boxes(rois, frame: np.ndarray, color: tuple) -> None:
     height, width = frame.shape[:2]
@@ -72,4 +83,95 @@ def draw_class_labels(frame: np.ndarray, class_names: list[str], colors: list[tu
     cv2.rectangle(frame, top_left, bottom_right, (255, 255, 255), thickness=-1)

     # Text
-    cv2.putText(frame, label, (x, y), font, font_scale, color, thickness, lineType=cv2.LINE_AA)
\ No newline at end of file
+    cv2.putText(frame, label, (x, y), font, font_scale, color, thickness, lineType=cv2.LINE_AA)
+
+def draw_banners(detector_iq_pairs: list[tuple[Detector, ImageQuery]], frame: np.ndarray) -> None:
+    """
+    Draw banners at the top of the frame showing query and result for binary/multiclass detectors.
+    Each detector gets its own line.
+ """ + font = cv2.FONT_HERSHEY_SIMPLEX + font_scale = 0.7 + thickness = 2 + margin = 8 + frame_height, frame_width = frame.shape[:2] + + y_offset = margin + + for detector, iq in detector_iq_pairs: + # Get the result text + result_text = get_result_text(iq) + + # Format: "Query: Result" + query_text = detector.query + + # Calculate available space for query (leaving room for result) + result_width, _ = cv2.getTextSize(result_text, font, font_scale, thickness) + available_width = frame_width - result_width[0] - 4 * margin + + # Clip query text to fit + query_text = clip_text_to_width(query_text, font, font_scale, thickness, available_width) + full_text = f"{query_text}: {result_text}" + + # Get text dimensions + (text_width, text_height), baseline = cv2.getTextSize(full_text, font, font_scale, thickness) + banner_height = text_height + 2 * margin + + # Draw semi-transparent background + overlay = frame.copy() + cv2.rectangle(overlay, (0, y_offset), (frame_width, y_offset + banner_height), (0, 0, 0), -1) + cv2.addWeighted(overlay, 0.6, frame, 0.4, 0, frame) + + # Draw text + text_x = margin + text_y = y_offset + margin + text_height + cv2.putText(frame, full_text, (text_x, text_y), font, font_scale, (255, 255, 255), thickness, lineType=cv2.LINE_AA) + + # Move to next line + y_offset += banner_height + +def get_result_text(iq: ImageQuery) -> str: + """ + Extract the result text from an ImageQuery. + """ + if iq.result is None: + return "Pending" + + # Try to get the label (for binary and multiclass) + if hasattr(iq.result, 'label'): + if hasattr(iq.result.label, 'value'): + return str(iq.result.label.value) + return str(iq.result.label) + + # Fallback + return "Unknown" + +def clip_text_to_width(text: str, font, font_scale: float, thickness: int, max_width: int) -> str: + """ + Clip text to fit within a maximum width, adding ellipsis if needed. + """ + text_width, _ = cv2.getTextSize(text, font, font_scale, thickness) + + if text_width[0] <= max_width: + return text + + # Binary search for the right length + ellipsis = "..." 
+    ellipsis_width, _ = cv2.getTextSize(ellipsis, font, font_scale, thickness)
+    available_width = max_width - ellipsis_width[0]
+
+    if available_width <= 0:
+        return ellipsis
+
+    # Estimate characters that fit
+    chars_per_pixel = len(text) / text_width[0]
+    estimated_chars = int(available_width * chars_per_pixel)
+
+    # Find the exact length
+    for length in range(estimated_chars, 0, -1):
+        test_text = text[:length] + ellipsis
+        test_width, _ = cv2.getTextSize(test_text, font, font_scale, thickness)
+        if test_width[0] <= max_width:
+            return test_text
+
+    return ellipsis
\ No newline at end of file
diff --git a/src/s2train.py b/src/s2train.py
index dcd1a80..453c61b 100755
--- a/src/s2train.py
+++ b/src/s2train.py
@@ -23,10 +23,29 @@ def pprint_iq(iq: ImageQuery) -> None:
         label = '-' if iq.result is None else iq.result.label.value
         print(f'Label: {label}')
     else:
-        raise ValueError(
-            f'Unsupported result type: {type(iq.result)}'
-        )
-
+        # Handle multiclass and bounding box modes
+        if iq.result is not None:
+            # Multiclass: show the detected class
+            if hasattr(iq.result, 'label'):
+                if hasattr(iq.result.label, 'value'):
+                    print(f'Class: {iq.result.label.value}')
+                else:
+                    print(f'Class: {iq.result.label}')
+            # Bounding box: show list of boxes
+            elif hasattr(iq, 'rois'):
+                rois = iq.rois if iq.rois is not None else []
+                print(f'Bounding Boxes: {len(rois)} box(es)')
+                for i, roi in enumerate(rois):
+                    if hasattr(roi, 'geometry'):
+                        bbox = roi.geometry
+                        print(f'  Box {i+1}: left={bbox.left:.3f}, top={bbox.top:.3f}, right={bbox.right:.3f}, bottom={bbox.bottom:.3f}')
+                    else:
+                        print(f'  Box {i+1}: {roi}')
+            else:
+                print(f'Result type: {type(iq.result).__name__}')
+        else:
+            print('Result: None')
+
     confidence = None if iq.result is None else iq.result.confidence
     confidence_str = '-' if confidence is None else f'{confidence * 100:.2f}%'
     print(f'Confidence: {confidence_str}')
@@ -193,4 +212,4 @@ def submit_to_model_retry(detector, fmd: dict, ask_async: bool, wait: float, hum
         # Check if we have submitted the requested number of frames
         if num_submitted_frames == num_frames:
             print(f'Finshed submitting {num_frames} frames to {detector_id}.')
-            break
\ No newline at end of file
+            break
diff --git a/src/s3produce.py b/src/s3produce.py
index f2004db..6c34cbd 100644
--- a/src/s3produce.py
+++ b/src/s3produce.py
@@ -6,6 +6,11 @@ import time
 import os

 import cv2
+import sys
+import warnings
+import logging
+from io import StringIO
+from contextlib import contextmanager
 from groundlight import Groundlight
 from tqdm.auto import tqdm

@@ -18,6 +23,28 @@ from threaded_video_writer import ThreadedVideoWriter


+@contextmanager
+def suppress_output():
+    """Context manager to suppress stdout, stderr, warnings, and logging temporarily."""
+    old_stdout = sys.stdout
+    old_stderr = sys.stderr
+    sys.stdout = StringIO()
+    sys.stderr = StringIO()
+
+    # Temporarily increase logging level to suppress Groundlight SDK logs
+    old_log_level = logging.root.level
+    logging.root.setLevel(logging.CRITICAL + 1)
+
+    # Also suppress warnings
+    with warnings.catch_warnings():
+        warnings.simplefilter("ignore")
+        try:
+            yield
+        finally:
+            sys.stdout = old_stdout
+            sys.stderr = old_stderr
+            logging.root.setLevel(old_log_level)
+
 def infer_and_produce_video(project: ProjectState,
                             detector_ids: list[str],
                             frame_stride: int,
@@ -66,11 +93,12 @@
     web_server = FrameGrabWebServer(f"Producing {filename}...", port=web_preview_port, message=message)

     try:
-        for frame_num in tqdm(range(0, total_frames, frame_stride), "Producing video"):
frame_stride), "Producing video"): + pbar = tqdm(range(0, total_frames, frame_stride), desc="Producing video") + for frame_num in pbar: cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num) ret, frame = cap.read() if not ret: - print('Cannot read frame. Exiting...') + tqdm.write('Cannot read frame. Exiting...') break # Perform inference @@ -80,12 +108,14 @@ def infer_and_produce_video(project: ProjectState, retries = 0 while True: try: - iq = gl.submit_image_query( - detector=detector, - image=frame, - human_review=human_review, - wait=0.0, - ) + # Suppress Groundlight SDK warnings to avoid breaking progress bar + with suppress_output(): + iq = gl.submit_image_query( + detector=detector, + image=frame, + human_review=human_review, + wait=0.0, + ) break except Exception as e: retries += 1 @@ -93,13 +123,14 @@ def infer_and_produce_video(project: ProjectState, raise RuntimeError( f'Repeatedly encountered an exception while submitting image queries to {detector.id}.' ) - print(e) + tqdm.write(str(e)) time.sleep(1) iqs.append(iq) - # Annotate the frame - draw_iqs(detectors, iqs, frame) + # Annotate the frame (suppress warnings here too in case they come from result parsing) + with suppress_output(): + draw_iqs(detectors, iqs, frame) # Show in web browser web_server.show_image(frame) @@ -107,13 +138,17 @@ def infer_and_produce_video(project: ProjectState, # Write the frame writer.add_frame(frame) except KeyboardInterrupt: - print('User cancelled video production.') + tqdm.write('User cancelled video production.') finally: cap.release() writer.stop() print(f'Finished producing video at {filepath}') if __name__ == "__main__": + # Suppress Groundlight SDK warnings globally + logging.getLogger('groundlight').setLevel(logging.ERROR) + warnings.filterwarnings('ignore') + parser = argparse.ArgumentParser() parser.add_argument("project_dir", type=str, help="Path to the project directory") parser.add_argument("--detector-ids", type=str, nargs="+", required=True, help="One or more detector IDs to use") From c6cb0f4a9b224dc79c100dd4189850ac4ec52be4 Mon Sep 17 00:00:00 2001 From: none Date: Sun, 2 Nov 2025 02:29:13 +0000 Subject: [PATCH 2/2] updates --- src/s3produce.py | 0 1 file changed, 0 insertions(+), 0 deletions(-) mode change 100644 => 100755 src/s3produce.py diff --git a/src/s3produce.py b/src/s3produce.py old mode 100644 new mode 100755