src/drawing.py: 126 changes (114 additions, 12 deletions)

@@ -15,23 +15,34 @@ def generate_unique_bgr_colors(n: int) -> list:
return colors
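
Note: the body of `generate_unique_bgr_colors` is collapsed in this diff. For orientation only, a typical implementation spaces hues evenly in HSV and converts to OpenCV's BGR ordering; this sketch is hypothetical, not the PR's code:

```python
import colorsys

def example_unique_bgr_colors(n: int) -> list:
    """Illustrative stand-in: evenly spaced hues, converted to BGR for OpenCV."""
    colors = []
    for i in range(n):
        r, g, b = colorsys.hsv_to_rgb(i / max(n, 1), 1.0, 1.0)
        colors.append((int(b * 255), int(g * 255), int(r * 255)))  # OpenCV uses BGR order
    return colors
```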

def draw_iqs(detectors: list[Detector], iqs: list[ImageQuery], frame: np.ndarray) -> None:
counting_iqs = []
for detector, counting_iq in zip(detectors, iqs):
# Separate IQs by type
bbox_iqs = [] # COUNT and BOUNDING_BOX modes
banner_iqs = [] # BINARY and MULTI_CLASS modes

for detector, iq in zip(detectors, iqs):
detector_mode = detector.mode
if detector_mode == ModeEnum.COUNT:
counting_iqs.append(counting_iq)
if detector_mode in (ModeEnum.COUNT, ModeEnum.BOUNDING_BOX):
bbox_iqs.append((detector, iq))
elif detector_mode in (ModeEnum.BINARY, ModeEnum.MULTI_CLASS):
banner_iqs.append((detector, iq))
else:
raise NotImplementedError(
f'Detector mode {detector_mode} is not yet supported.'
)

unique_bbox_colors = generate_unique_bgr_colors(len(counting_iqs))
for n, counting_iq in enumerate(counting_iqs):
color = unique_bbox_colors[n]
draw_bounding_boxes(counting_iq.rois, frame, color)

class_names = [d.mode_configuration["class_name"] for d in detectors]
draw_class_labels(frame, class_names, unique_bbox_colors)
# Draw bounding boxes for COUNT and BOUNDING_BOX modes
if bbox_iqs:
unique_bbox_colors = generate_unique_bgr_colors(len(bbox_iqs))
for n, (detector, iq) in enumerate(bbox_iqs):
color = unique_bbox_colors[n]
draw_bounding_boxes(iq.rois, frame, color)

class_names = [d.mode_configuration.get("class_name", f"Detector {i+1}") for i, (d, _) in enumerate(bbox_iqs)]
draw_class_labels(frame, class_names, unique_bbox_colors)

# Draw banners for BINARY and MULTI_CLASS modes
if banner_iqs:
draw_banners(banner_iqs, frame)
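
Note: the refactored `draw_iqs` now fans out by detector mode. A minimal usage sketch (the capture loop, `gl` client, and `detectors` list are assumed context, not part of this diff):

```python
# Assumes `gl` (Groundlight client), `detectors`, and an OpenCV capture `cap` exist.
ret, frame = cap.read()
if ret:
    iqs = [gl.submit_image_query(detector=d, image=frame, wait=0.0) for d in detectors]
    draw_iqs(detectors, iqs, frame)  # boxes for COUNT/BOUNDING_BOX, banners for BINARY/MULTI_CLASS
```

The `detectors` and `iqs` lists must stay index-aligned, since the function pairs them with `zip`.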

def draw_bounding_boxes(rois, frame: np.ndarray, color: tuple) -> None:
height, width = frame.shape[:2]
@@ -72,4 +83,95 @@ def draw_class_labels(frame: np.ndarray, class_names: list[str], colors: list[tu
cv2.rectangle(frame, top_left, bottom_right, (255, 255, 255), thickness=-1)

# Text
cv2.putText(frame, label, (x, y), font, font_scale, color, thickness, lineType=cv2.LINE_AA)

def draw_banners(detector_iq_pairs: list[tuple[Detector, ImageQuery]], frame: np.ndarray) -> None:
"""
Draw banners at the top of the frame showing query and result for binary/multiclass detectors.
Each detector gets its own line.
"""
font = cv2.FONT_HERSHEY_SIMPLEX
font_scale = 0.7
thickness = 2
margin = 8
frame_height, frame_width = frame.shape[:2]

y_offset = margin

for detector, iq in detector_iq_pairs:
# Get the result text
result_text = get_result_text(iq)

# Format: "Query: Result"
query_text = detector.query

# Calculate available space for query (leaving room for result)
result_width, _ = cv2.getTextSize(result_text, font, font_scale, thickness)
available_width = frame_width - result_width[0] - 4 * margin

# Clip query text to fit
query_text = clip_text_to_width(query_text, font, font_scale, thickness, available_width)
full_text = f"{query_text}: {result_text}"

# Get text dimensions
(text_width, text_height), baseline = cv2.getTextSize(full_text, font, font_scale, thickness)
banner_height = text_height + 2 * margin

# Draw semi-transparent background
overlay = frame.copy()
cv2.rectangle(overlay, (0, y_offset), (frame_width, y_offset + banner_height), (0, 0, 0), -1)
cv2.addWeighted(overlay, 0.6, frame, 0.4, 0, frame)

# Draw text
text_x = margin
text_y = y_offset + margin + text_height
cv2.putText(frame, full_text, (text_x, text_y), font, font_scale, (255, 255, 255), thickness, lineType=cv2.LINE_AA)

# Move to next line
y_offset += banner_height
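
Note: the translucent banner uses the standard OpenCV overlay pattern, which is to draw opaque shapes on a copy of the frame, then alpha-blend the copy back. A self-contained sketch of just that pattern (the frame size and alpha are arbitrary):

```python
import cv2
import numpy as np

frame = np.zeros((480, 640, 3), dtype=np.uint8)  # stand-in frame
overlay = frame.copy()
cv2.rectangle(overlay, (0, 0), (640, 40), (0, 0, 0), -1)    # opaque banner on the copy
alpha = 0.6                                                 # banner opacity
cv2.addWeighted(overlay, alpha, frame, 1 - alpha, 0, frame) # blend back in place
```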

def get_result_text(iq: ImageQuery) -> str:
"""
Extract the result text from an ImageQuery.
"""
if iq.result is None:
return "Pending"

# Try to get the label (for binary and multiclass)
if hasattr(iq.result, 'label'):
if hasattr(iq.result.label, 'value'):
return str(iq.result.label.value)
return str(iq.result.label)

# Fallback
return "Unknown"

def clip_text_to_width(text: str, font, font_scale: float, thickness: int, max_width: int) -> str:
"""
Clip text to fit within a maximum width, adding ellipsis if needed.
"""
text_width, _ = cv2.getTextSize(text, font, font_scale, thickness)

if text_width[0] <= max_width:
return text

# Shorten the text until it fits, appending an ellipsis
ellipsis = "..."
ellipsis_width, _ = cv2.getTextSize(ellipsis, font, font_scale, thickness)
available_width = max_width - ellipsis_width[0]

if available_width <= 0:
return ellipsis

# Estimate characters that fit
chars_per_pixel = len(text) / text_width[0]
estimated_chars = int(available_width * chars_per_pixel)

# Find the exact length
for length in range(estimated_chars, 0, -1):
test_text = text[:length] + ellipsis
test_width, _ = cv2.getTextSize(test_text, font, font_scale, thickness)
if test_width[0] <= max_width:
return test_text

return ellipsis
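
Note: a usage sketch for `clip_text_to_width` (the import path is an assumption about how this module is exposed):

```python
import cv2
from drawing import clip_text_to_width  # assumes src/drawing.py is importable as `drawing`

font = cv2.FONT_HERSHEY_SIMPLEX
query = "Is there a person standing within two meters of the loading dock door?"
print(clip_text_to_width(query, font, font_scale=0.7, thickness=2, max_width=300))
# Prints a prefix ending in "...": the exact cut depends on glyph widths.
```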
src/s2train.py: 29 changes (24 additions, 5 deletions)

@@ -23,10 +23,29 @@ def pprint_iq(iq: ImageQuery) -> None:
label = '-' if iq.result is None else iq.result.label.value
print(f'Label: {label}')
else:
raise ValueError(
f'Unsupported result type: {type(iq.result)}'
)

# Handle multiclass and bounding box modes
if iq.result is not None:
# Multiclass: show the detected class
if hasattr(iq.result, 'label'):
if hasattr(iq.result.label, 'value'):
print(f'Class: {iq.result.label.value}')
else:
print(f'Class: {iq.result.label}')
# Bounding box: show list of boxes
elif hasattr(iq, 'rois'):
rois = iq.rois if iq.rois is not None else []
print(f'Bounding Boxes: {len(rois)} box(es)')
for i, roi in enumerate(rois):
if hasattr(roi, 'geometry'):
bbox = roi.geometry
print(f' Box {i+1}: left={bbox.left:.3f}, top={bbox.top:.3f}, right={bbox.right:.3f}, bottom={bbox.bottom:.3f}')
else:
print(f' Box {i+1}: {roi}')
else:
print(f'Result type: {type(iq.result).__name__}')
else:
print('Result: None')

confidence = None if iq.result is None else iq.result.confidence
confidence_str = '-' if confidence is None else f'{confidence * 100:.2f}%'
print(f'Confidence: {confidence_str}')
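
Note: a stand-in example of what the bounding-box fallback prints, using hypothetical objects mimicking the SDK shapes and assuming the query falls through to this branch rather than an earlier mode-specific one:

```python
from types import SimpleNamespace

roi = SimpleNamespace(geometry=SimpleNamespace(left=0.12, top=0.34, right=0.41, bottom=0.78))
iq = SimpleNamespace(result=SimpleNamespace(confidence=0.875), rois=[roi])
# pprint_iq(iq) would print, from the branch above:
#   Bounding Boxes: 1 box(es)
#     Box 1: left=0.120, top=0.340, right=0.410, bottom=0.780
#   Confidence: 87.50%
```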
@@ -193,4 +212,4 @@ def submit_to_model_retry(detector, fmd: dict, ask_async: bool, wait: float, hum
# Check if we have submitted the requested number of frames
if num_submitted_frames == num_frames:
print(f'Finished submitting {num_frames} frames to {detector_id}.')
break
src/s3produce.py: 59 changes (47 additions, 12 deletions); file mode 100644 → 100755

@@ -6,6 +6,11 @@
import time
import os
import cv2
import sys
import warnings
import logging
from io import StringIO
from contextlib import contextmanager

from groundlight import Groundlight
from tqdm.auto import tqdm
@@ -18,6 +23,28 @@

from threaded_video_writer import ThreadedVideoWriter

@contextmanager
def suppress_output():
"""Context manager to suppress stdout, stderr, warnings, and logging temporarily."""
old_stdout = sys.stdout
old_stderr = sys.stderr
sys.stdout = StringIO()
sys.stderr = StringIO()

# Temporarily increase logging level to suppress Groundlight SDK logs
old_log_level = logging.root.level
logging.root.setLevel(logging.CRITICAL + 1)

# Also suppress warnings
with warnings.catch_warnings():
warnings.simplefilter("ignore")
try:
yield
finally:
sys.stdout = old_stdout
sys.stderr = old_stderr
logging.root.setLevel(old_log_level)
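
Note: usage is the standard context-manager pattern; anything printed or logged inside the block is swallowed, and the previous streams and log level are restored afterwards:

```python
with suppress_output():
    print("never shown")
    logging.getLogger("groundlight").warning("also hidden")
print("visible again")
```

contextlib's redirect_stdout/redirect_stderr could express the stream redirection more idiomatically, but the manual save/restore here also covers the logging level in one place.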

def infer_and_produce_video(project: ProjectState,
detector_ids: list[str],
frame_stride: int,
@@ -66,11 +93,12 @@ def infer_and_produce_video(project: ProjectState,
web_server = FrameGrabWebServer(f"Producing {filename}...", port=web_preview_port, message=message)

try:
for frame_num in tqdm(range(0, total_frames, frame_stride), "Producing video"):
pbar = tqdm(range(0, total_frames, frame_stride), desc="Producing video")
for frame_num in pbar:
cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
ret, frame = cap.read()
if not ret:
print('Cannot read frame. Exiting...')
tqdm.write('Cannot read frame. Exiting...')
break

# Perform inference
@@ -80,40 +108,47 @@
retries = 0
while True:
try:
iq = gl.submit_image_query(
detector=detector,
image=frame,
human_review=human_review,
wait=0.0,
)
# Suppress Groundlight SDK warnings to avoid breaking progress bar
with suppress_output():
iq = gl.submit_image_query(
detector=detector,
image=frame,
human_review=human_review,
wait=0.0,
)
break
except Exception as e:
retries += 1
if retries == max_retries:
raise RuntimeError(
f'Repeatedly encountered an exception while submitting image queries to {detector.id}.'
)
print(e)
tqdm.write(str(e))
time.sleep(1)

iqs.append(iq)

# Annotate the frame
draw_iqs(detectors, iqs, frame)
# Annotate the frame (suppress warnings here too in case they come from result parsing)
with suppress_output():
draw_iqs(detectors, iqs, frame)

# Show in web browser
web_server.show_image(frame)

# Write the frame
writer.add_frame(frame)
except KeyboardInterrupt:
print('User cancelled video production.')
tqdm.write('User cancelled video production.')
finally:
cap.release()
writer.stop()
print(f'Finished producing video at {filepath}')
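
Note: swapping print for tqdm.write inside the loop is the standard way to log without corrupting an active progress bar: tqdm clears the bar, writes the line, and redraws. A self-contained demonstration:

```python
import time
from tqdm.auto import tqdm

for i in tqdm(range(5), desc="demo"):
    if i == 2:
        tqdm.write("this message does not break the bar")
    time.sleep(0.2)
```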

if __name__ == "__main__":
# Suppress Groundlight SDK warnings globally
logging.getLogger('groundlight').setLevel(logging.ERROR)
warnings.filterwarnings('ignore')

parser = argparse.ArgumentParser()
parser.add_argument("project_dir", type=str, help="Path to the project directory")
parser.add_argument("--detector-ids", type=str, nargs="+", required=True, help="One or more detector IDs to use")