Create RetinaNet object detection guide #2069
base: master
Thanks for the PR, @sachinprasadhs. Since this is a guide, we should have more detailed explanations that would help users understand our framework better. Let's add more of that.
Thanks!
Thanks! All minor comments. Also, heads up: I added keras-team/keras-hub#2219 for COCO id-to-name mappings. Whenever it is released we can use it.
/gemini review
Code Review
This pull request introduces a guide for object detection with RetinaNet using Keras Hub. I have identified a critical issue in the data loading logic, along with suggestions to improve code quality and clarity.
        cache_dir=data_dir,
        extract=True,
    )
    data_dir = os.path.join(get_data, extracted_dir)
The `keras.utils.get_file` function returns the path to the downloaded file, not the directory where files are extracted. Use `os.path.dirname` to get the directory of the downloaded file.
Suggested change:
-    data_dir = os.path.join(get_data, extracted_dir)
+    data_dir = os.path.join(os.path.dirname(get_data), extracted_dir)
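The difference between the two joins can be seen without downloading anything. The following is a minimal sketch, assuming (as the comment above states) that the download call returns a file path; `resolve_extracted_dir` and the archive path are hypothetical names used only for illustration. Whether `keras.utils.get_file` returns a file or a directory when `extract=True` may depend on the installed Keras version, so the helper checks for both cases rather than assuming one.

```python
import os


def resolve_extracted_dir(downloaded_path, extracted_dir):
    """Join `extracted_dir` onto the directory that holds the download.

    `downloaded_path` stands in for the value returned by the download
    call. If it is already a directory it is used as the base; otherwise
    (for example an archive file) its parent directory is used.
    """
    if os.path.isdir(downloaded_path):
        base_dir = downloaded_path
    else:
        base_dir = os.path.dirname(downloaded_path)
    return os.path.join(base_dir, extracted_dir)


# Hypothetical archive path, used only to show the difference.
archive_path = os.path.join("datasets", "VOCtrainval.tar")
extracted_dir = os.path.join("VOCdevkit", "VOC2007")

# Joining onto the archive path nests the folder under the .tar file name.
buggy_dir = os.path.join(archive_path, extracted_dir)
# Joining onto the archive's parent directory gives a usable location.
fixed_dir = resolve_extracted_dir(archive_path, extracted_dir)
```

Here `buggy_dir` contains the archive file name as a path component, while `fixed_dir` points at a sibling of the archive.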
    year="2007",
    split="trainval",
    data_dir="./",
    voc_url=VOC_2007_URL,
):
    extracted_dir = os.path.join("VOCdevkit", f"VOC{year}")
The `image_file_paths` variable is defined but not used. Remove it to avoid confusion.
Suggested change:
-    year="2007",
-    split="trainval",
-    data_dir="./",
-    voc_url=VOC_2007_URL,
-):
-    extracted_dir = os.path.join("VOCdevkit", f"VOC{year}")
+    # Parallel process all the images.
+    annotation_file_paths = tf.io.gfile.glob(
+        os.path.join(data_dir, "Annotations", "*.xml")
+    )
def decode_tfds(record):
    """Decodes a standard TFDS object detection record.

    Args:
        record: A dictionary representing a single TFDS record.

    Returns:
        A dictionary with "images" and "bounding_boxes".
    """
    image = record["image"]
    image_shape = tf.shape(image)
    height, width = image_shape[0], image_shape[1]
    boxes = keras.utils.bounding_boxes.convert_format(
        record["objects"]["bbox"],
        source="rel_yxyx",
        target=bbox_format,
        height=height,
        width=width,
    )
    labels = record["objects"]["label"]

    bounding_boxes = {"boxes": boxes, "labels": labels}

    return {"images": image, "bounding_boxes": bounding_boxes}
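To make the box conversion in `decode_tfds` concrete, here is a minimal plain-Python sketch of the arithmetic for one box, assuming the target format is `"xyxy"` (the value of `bbox_format` is not shown in this excerpt). It illustrates the coordinate math only and is not the Keras implementation; `rel_yxyx_to_xyxy` is a hypothetical helper name.

```python
def rel_yxyx_to_xyxy(box, height, width):
    """Convert one box from relative (y_min, x_min, y_max, x_max),
    with values in [0, 1], to absolute (x_min, y_min, x_max, y_max)
    in pixels."""
    y_min, x_min, y_max, x_max = box
    return [
        x_min * width,   # x coordinates scale with the image width
        y_min * height,  # y coordinates scale with the image height
        x_max * width,
        y_max * height,
    ]


# A box covering the vertical middle half of a 200 x 400 (H x W) image.
pixel_box = rel_yxyx_to_xyxy([0.25, 0.125, 0.75, 0.5], height=200, width=400)
# pixel_box == [50.0, 50.0, 200.0, 150.0]
```

Note that the conversion both reorders the coordinates (y-first to x-first) and rescales them, which is why the image height and width are passed alongside the boxes.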
def preprocess_tfds(ds):
    """Preprocesses a TFDS dataset for object detection.

    Args:
        ds: The TFDS dataset.
        resizing: A resizing function.
        max_box_layer: A max box processing function.
        batch_size: The batch size.

    Returns:
        A preprocessed TFDS dataset.
    """
    ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)
    ds = ds.batch(batch_size, drop_remainder=True)
    return ds
The `resizing` and `max_box_layer` functions and the `batch_size` variable are used from the global scope, but are documented as arguments. Explicitly pass them as arguments to increase reusability.
Suggested change:
-def preprocess_tfds(ds):
-    """Preprocesses a TFDS dataset for object detection.
-    Args:
-        ds: The TFDS dataset.
-        resizing: A resizing function.
-        max_box_layer: A max box processing function.
-        batch_size: The batch size.
-    Returns:
-        A preprocessed TFDS dataset.
-    """
-    ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
-    ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)
-    ds = ds.batch(batch_size, drop_remainder=True)
-    return ds
+def preprocess_tfds(ds, resizing, max_box_layer, batch_size):
+    """Preprocesses a TFDS dataset for object detection.
+    Args:
+        ds: The TFDS dataset.
+        resizing: A resizing function.
+        max_box_layer: A max box processing function.
+        batch_size: The batch size.
+    Returns:
+        A preprocessed TFDS dataset.
+    """
+    ds = ds.map(resizing, num_parallel_calls=tf.data.AUTOTUNE)
+    ds = ds.map(max_box_layer, num_parallel_calls=tf.data.AUTOTUNE)
+    ds = ds.batch(batch_size, drop_remainder=True)
+    return ds
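The benefit of passing the dependencies explicitly can be shown with a plain-Python analogy that needs no TensorFlow. This is only a sketch of the same shape of pipeline (map, map, batch with the remainder dropped) over a list, not `tf.data` itself; `preprocess`, `resize_fn`, and `pad_fn` are hypothetical names.

```python
def preprocess(records, resize_fn, pad_fn, batch_size):
    """Apply two per-record steps, then group into full batches.

    Every dependency arrives as an argument, so the function can be
    reused or tested with stand-in callables and no global state.
    """
    processed = [pad_fn(resize_fn(record)) for record in records]
    # Mirror drop_remainder=True: discard a trailing partial batch.
    full_length = len(processed) - len(processed) % batch_size
    return [
        processed[start:start + batch_size]
        for start in range(0, full_length, batch_size)
    ]


# Stand-in callables make the behaviour easy to check in isolation.
batches = preprocess(
    range(5),
    resize_fn=lambda x: x * 2,
    pad_fn=lambda x: x + 1,
    batch_size=2,
)
# batches == [[1, 3], [5, 7]]; the fifth record is dropped.
```

With the original global-scope version, a check like this would require defining module-level `resizing`, `max_box_layer`, and `batch_size` before the function could be called.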
            update_freq=1,
        ),
        keras.callbacks.ModelCheckpoint(
            ckpt_path + "/{epoch:04d}-{val_loss:.2f}.weights.h5",
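The checkpoint file name above is a Python format template. As a minimal sketch of how such a template expands, assuming the placeholders are filled via `str.format` with the epoch number and the logged metrics, the following uses a hypothetical `ckpt_path` and made-up metric values. It also uses `os.path.join` instead of string concatenation, which is an optional stylistic choice and not something the review requested.

```python
import os

ckpt_path = "checkpoints"  # hypothetical directory, for illustration only

# Same placeholders as the template in the diff above.
template = os.path.join(ckpt_path, "{epoch:04d}-{val_loss:.2f}.weights.h5")

# {epoch:04d} zero-pads to four digits; {val_loss:.2f} keeps two decimals.
example_name = template.format(epoch=3, val_loss=0.4567)
# The file name part is "0003-0.46.weights.h5".
```

Because the template references `val_loss`, it can only be filled when a validation loss is actually available, which presupposes that validation data is supplied during training.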
keras-nlp 0.19.0 requires keras-hub==0.19.0, but you have keras-hub 0.20.0 which is incompatible.

WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
E0000 00:00:1746815719.896182 8973 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1746815719.902635 8973 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
Summary:
This Pull Request introduces comprehensive documentation detailing how to use an object detection model from Keras Hub, using RetinaNet as an example.
This tutorial includes: