Create RetinaNet object detection guide #2069

Open
Wants to merge 11 commits into base: master
Conversation

@sachinprasadhs (Collaborator) commented Mar 24, 2025

Summary:
This pull request adds comprehensive documentation on how to use an object detection model from Keras Hub, with RetinaNet as the example.

This tutorial includes:

  • Loading and preprocessing a training dataset
  • Demonstrating inference using pretrained weights (see the sketch after this list)
  • Fine-tuning a pretrained object detector
  • Custom training of an object detector
  • Summary
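
To make the inference step concrete, here is a minimal sketch of what loading a pretrained RetinaNet from Keras Hub and running prediction can look like. The class and preset names below are assumptions based on recent Keras Hub releases and may differ from what the guide ultimately uses.

import numpy as np
import keras_hub

# Assumed preset name; check keras_hub.models.RetinaNetObjectDetector.presets
# in your installed version for the names that are actually available.
detector = keras_hub.models.RetinaNetObjectDetector.from_preset(
    "retinanet_resnet50_fpn_coco"
)

# A single dummy RGB image; the guide loads real images instead.
images = np.random.uniform(0, 255, size=(1, 640, 640, 3)).astype("float32")

# predict() typically returns a dict with boxes, labels, and confidence scores.
predictions = detector.predict(images)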

@divyashreepathihalli (Collaborator) left a comment

Thanks for the PR @sachinprasadhs. Since this is a guide, we should have more detailed explanations that would help users understand our framework better. Let's add more of that.

Thanks!

@mattdangerw (Member) left a comment

Thanks! All minor comments. Also, heads up: I added keras-team/keras-hub#2219 for the COCO id-to-name mappings. Whenever it is released we can use it.

@divyashreepathihalli (Collaborator) commented

/gemini review

@gemini-code-assist (bot) left a comment

Code Review

This pull request introduces a guide for object detection with RetinaNet using Keras Hub. I have identified a critical issue in the data loading logic, along with suggestions to improve code quality and clarity.

    cache_dir=data_dir,
    extract=True,
)
data_dir = os.path.join(get_data, extracted_dir)

Severity: high

The keras.utils.get_file function returns the path to the downloaded file, not the directory where files are extracted. Use os.path.dirname to get the directory of the downloaded file.

Suggested change:
- data_dir = os.path.join(get_data, extracted_dir)
+ data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
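
For readers following along, a hedged sketch of the fix in context is shown below; VOC_2007_URL mirrors the constant used in the guide, and the exact return value of keras.utils.get_file with extract=True can differ between Keras versions, so it is worth confirming against the version the guide pins.

import os
import keras

# Download and extract the archive; per the review comment, get_file is assumed
# to return the path of the downloaded archive file here.
archive_path = keras.utils.get_file(origin=VOC_2007_URL, extract=True)

# The extracted contents live next to the archive, so take its parent directory
# and join it with the expected VOCdevkit subfolder.
data_dir = os.path.join(os.path.dirname(archive_path), "VOCdevkit", "VOC2007")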

Comment on lines +349 to +354
year="2007",
split="trainval",
data_dir="./",
voc_url=VOC_2007_URL,
):
extracted_dir = os.path.join("VOCdevkit", f"VOC{year}")

Severity: medium

The image_file_paths variable is defined but not used. Remove it to avoid confusion.

Suggested change:
    year="2007",
    split="trainval",
    data_dir="./",
    voc_url=VOC_2007_URL,
):
    extracted_dir = os.path.join("VOCdevkit", f"VOC{year}")
    # Parallel process all the images.
    annotation_file_paths = tf.io.gfile.glob(
        os.path.join(data_dir, "Annotations", "*.xml")
    )

Comment on lines +524 to +547
def decode_tfds(record):
    """Decodes a standard TFDS object detection record.

    Args:
        record: A dictionary representing a single TFDS record.

    Returns:
        A dictionary with "images" and "bounding_boxes".
    """
    image = record["image"]
    image_shape = tf.shape(image)
    height, width = image_shape[0], image_shape[1]
    boxes = keras.utils.bounding_boxes.convert_format(
        record["objects"]["bbox"],
        source="rel_yxyx",
        target=bbox_format,
        height=height,
        width=width,
    )
    labels = record["objects"]["label"]

    bounding_boxes = {"boxes": boxes, "labels": labels}

    return {"images": image, "bounding_boxes": bounding_boxes}

Severity: medium

The function decode_tfds is defined but never used. Remove it to improve maintainability.
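
If the guide chooses to keep decode_tfds instead of deleting it, the typical wiring would be to map it over the TFDS split before resizing; the sketch below uses hypothetical dataset and variable names.

import tensorflow as tf
import tensorflow_datasets as tfds

# Hypothetical dataset/split; the guide's actual variables may differ.
train_ds = tfds.load("voc/2007", split="train")
train_ds = train_ds.map(decode_tfds, num_parallel_calls=tf.data.AUTOTUNE)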

Comment on lines +550 to +565
def preprocess_tfds(ds):
    """Preprocesses a TFDS dataset for object detection.

    Args:
        ds: The TFDS dataset.
        resizing: A resizing function.
        max_box_layer: A max box processing function.
        batch_size: The batch size.

    Returns:
        A preprocessed TFDS dataset.
    """
    ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

Severity: medium

The resizing and max_box_layer functions and the batch_size variable are used from the global scope but are documented as arguments. Explicitly pass them as arguments to increase reusability.

Suggested change:

Current:

def preprocess_tfds(ds):
    """Preprocesses a TFDS dataset for object detection.

    Args:
        ds: The TFDS dataset.
        resizing: A resizing function.
        max_box_layer: A max box processing function.
        batch_size: The batch size.

    Returns:
        A preprocessed TFDS dataset.
    """
    ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds

Suggested:

def preprocess_tfds(ds, resizing, max_box_layer, batch_size):
    """Preprocesses a TFDS dataset for object detection.

    Args:
        ds: The TFDS dataset.
        resizing: A resizing function.
        max_box_layer: A max box processing function.
        batch_size: The batch size.

    Returns:
        A preprocessed TFDS dataset.
    """
    ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
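
As a usage note for the suggested signature, callers would then pass the preprocessing pieces explicitly; the variable names below mirror the globals from the guide and are assumptions.

# resizing, max_box_layer, and batch_size are defined earlier in the guide;
# passing them explicitly keeps preprocess_tfds reusable across datasets.
train_ds = preprocess_tfds(train_ds, resizing, max_box_layer, batch_size)
eval_ds = preprocess_tfds(eval_ds, resizing, max_box_layer, batch_size)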

        update_freq=1,
    ),
    keras.callbacks.ModelCheckpoint(
        ckpt_path + "/{epoch:04d}-{val_loss:.2f}.weights.h5",

Severity: medium

Use os.path.join for constructing file paths to ensure platform independence.

Suggested change:
- ckpt_path + "/{epoch:04d}-{val_loss:.2f}.weights.h5",
+ os.path.join(ckpt_path, "{epoch:04d}-{val_loss:.2f}.weights.h5"),
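
For illustration, the callback with the suggested path construction would look roughly like this; the save_weights_only and monitor settings are assumptions about what the guide configures.

import os
import keras

callbacks = [
    keras.callbacks.ModelCheckpoint(
        os.path.join(ckpt_path, "{epoch:04d}-{val_loss:.2f}.weights.h5"),
        save_weights_only=True,  # needed for a ".weights.h5" filepath in Keras 3
        monitor="val_loss",
    ),
]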

Comment on lines +77 to +81
keras-nlp 0.19.0 requires keras-hub==0.19.0, but you have keras-hub 0.20.0 which is incompatible.

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1746815719.896182 8973 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746815719.902635 8973 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered

Severity: medium

The output block contains a package incompatibility warning and some CUDA/cuDNN errors. Either resolve the dependency conflict, explain why these messages appear and whether they can be safely ignored, or remove this output block from the guide to avoid confusion.
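
One possible way to address this, sketched under the assumption that a compatible release pair exists at publish time, is to upgrade the conflicting packages and silence TensorFlow's C++ log chatter before it is imported:

# Keep keras-nlp and keras-hub on matching releases so pip stops warning;
# the exact versions to pin depend on what is current when the guide is built.
# !pip install -q --upgrade keras-hub keras-nlp

import os

# This may hide the cuDNN/cuBLAS "factory already registered" messages, since
# it raises the minimum C++ log level before tensorflow is imported.
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"

import tensorflow as tf  # noqa: E402  (import after setting the env var)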
