From 269cc0dc9f0cd368cd1ff683f2a6e12f3d54dfb4 Mon Sep 17 00:00:00 2001 From: Shipra <138140065+Shi-pra-19@users.noreply.github.com> Date: Tue, 7 Oct 2025 17:42:33 +0000 Subject: [PATCH 1/2] migrate retinanet example to keras 3 --- examples/vision/retinanet.py | 245 +++++++++++++++++++---------------- 1 file changed, 132 insertions(+), 113 deletions(-) diff --git a/examples/vision/retinanet.py b/examples/vision/retinanet.py index c27685616f..9061d9e3f3 100644 --- a/examples/vision/retinanet.py +++ b/examples/vision/retinanet.py @@ -29,12 +29,14 @@ import os +import glob import re import zipfile import numpy as np import tensorflow as tf -from tensorflow import keras + +import keras import matplotlib.pyplot as plt import tensorflow_datasets as tfds @@ -50,10 +52,14 @@ url = "https://github.com/srihari-humbarwadi/datasets/releases/download/v0.1.0/data.zip" filename = os.path.join(os.getcwd(), "data.zip") -keras.utils.get_file(filename, url) +file_path = keras.utils.get_file( + fname=os.path.basename(filename), + origin=url, + cache_dir="." +) -with zipfile.ZipFile("data.zip", "r") as z_fp: +with zipfile.ZipFile(file_path, "r") as z_fp: z_fp.extractall("./") @@ -71,6 +77,7 @@ """ + def swap_xy(boxes): """Swaps order the of x and y coordinates of the boxes. @@ -80,7 +87,7 @@ def swap_xy(boxes): Returns: swapped boxes with shape same as that of boxes. """ - return tf.stack([boxes[:, 1], boxes[:, 0], boxes[:, 3], boxes[:, 2]], axis=-1) + return keras.ops.stack([boxes[:, 1], boxes[:, 0], boxes[:, 3], boxes[:, 2]], axis=-1) def convert_to_xywh(boxes): @@ -94,7 +101,7 @@ def convert_to_xywh(boxes): Returns: converted boxes with shape same as that of boxes. """ - return tf.concat( + return keras.ops.concatenate( [(boxes[..., :2] + boxes[..., 2:]) / 2.0, boxes[..., 2:] - boxes[..., :2]], axis=-1, ) @@ -111,12 +118,11 @@ def convert_to_corners(boxes): Returns: converted boxes with shape same as that of boxes. """ - return tf.concat( + return keras.ops.concatenate( [boxes[..., :2] - boxes[..., 2:] / 2.0, boxes[..., :2] + boxes[..., 2:] / 2.0], axis=-1, ) - """ ## Computing pairwise Intersection Over Union (IOU) @@ -143,16 +149,16 @@ def compute_iou(boxes1, boxes2): """ boxes1_corners = convert_to_corners(boxes1) boxes2_corners = convert_to_corners(boxes2) - lu = tf.maximum(boxes1_corners[:, None, :2], boxes2_corners[:, :2]) - rd = tf.minimum(boxes1_corners[:, None, 2:], boxes2_corners[:, 2:]) - intersection = tf.maximum(0.0, rd - lu) + lu = keras.ops.maximum(boxes1_corners[:, None, :2], boxes2_corners[:, :2]) + rd = keras.ops.minimum(boxes1_corners[:, None, 2:], boxes2_corners[:, 2:]) + intersection = keras.ops.maximum(0.0, rd - lu) intersection_area = intersection[:, :, 0] * intersection[:, :, 1] boxes1_area = boxes1[:, 2] * boxes1[:, 3] boxes2_area = boxes2[:, 2] * boxes2[:, 3] - union_area = tf.maximum( + union_area = keras.ops.maximum( boxes1_area[:, None] + boxes2_area - intersection_area, 1e-8 ) - return tf.clip_by_value(intersection_area / union_area, 0.0, 1.0) + return keras.ops.clip(intersection_area / union_area, 0.0, 1.0) def visualize_detections( @@ -195,7 +201,6 @@ def visualize_detections( (at three scales and three ratios). """ - class AnchorBox: """Generates anchor boxes. 
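Note on the `compute_iou` port above: it is the first broadcasting-heavy function moved to `keras.ops`, so a standalone smoke test is a cheap way to confirm the behavior is unchanged. A minimal sketch (not part of the patch; it restates the two helpers so it runs on its own):

import keras


def convert_to_corners(boxes):
    # [x, y, width, height] -> [xmin, ymin, xmax, ymax]
    return keras.ops.concatenate(
        [boxes[..., :2] - boxes[..., 2:] / 2.0, boxes[..., :2] + boxes[..., 2:] / 2.0],
        axis=-1,
    )


def compute_iou(boxes1, boxes2):
    # Pairwise IOU: boxes1 is (N, 4), boxes2 is (M, 4), the result is (N, M).
    boxes1_corners = convert_to_corners(boxes1)
    boxes2_corners = convert_to_corners(boxes2)
    lu = keras.ops.maximum(boxes1_corners[:, None, :2], boxes2_corners[:, :2])
    rd = keras.ops.minimum(boxes1_corners[:, None, 2:], boxes2_corners[:, 2:])
    intersection = keras.ops.maximum(0.0, rd - lu)
    intersection_area = intersection[:, :, 0] * intersection[:, :, 1]
    boxes1_area = boxes1[:, 2] * boxes1[:, 3]
    boxes2_area = boxes2[:, 2] * boxes2[:, 3]
    union_area = keras.ops.maximum(
        boxes1_area[:, None] + boxes2_area - intersection_area, 1e-8
    )
    return keras.ops.clip(intersection_area / union_area, 0.0, 1.0)


# One identical box (IOU 1.0) and one disjoint box (IOU 0.0), in [x, y, w, h]:
a = keras.ops.convert_to_tensor([[5.0, 5.0, 10.0, 10.0]])
b = keras.ops.convert_to_tensor([[5.0, 5.0, 10.0, 10.0], [20.0, 20.0, 2.0, 2.0]])
print(keras.ops.convert_to_numpy(compute_iou(a, b)))  # -> [[1. 0.]]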
@@ -232,14 +237,14 @@ def _compute_dims(self): for area in self._areas: anchor_dims = [] for ratio in self.aspect_ratios: - anchor_height = tf.math.sqrt(area / ratio) + anchor_height = keras.ops.sqrt(area / ratio) anchor_width = area / anchor_height - dims = tf.reshape( - tf.stack([anchor_width, anchor_height], axis=-1), [1, 1, 2] + dims = keras.ops.reshape( + keras.ops.stack([anchor_width, anchor_height], axis=-1), [1, 1, 2] ) for scale in self.scales: anchor_dims.append(scale * dims) - anchor_dims_all.append(tf.stack(anchor_dims, axis=-2)) + anchor_dims_all.append(keras.ops.stack(anchor_dims, axis=-2)) return anchor_dims_all def _get_anchors(self, feature_height, feature_width, level): @@ -255,16 +260,16 @@ def _get_anchors(self, feature_height, feature_width, level): anchor boxes with the shape `(feature_height * feature_width * num_anchors, 4)` """ - rx = tf.range(feature_width, dtype=tf.float32) + 0.5 - ry = tf.range(feature_height, dtype=tf.float32) + 0.5 - centers = tf.stack(tf.meshgrid(rx, ry), axis=-1) * self._strides[level - 3] - centers = tf.expand_dims(centers, axis=-2) - centers = tf.tile(centers, [1, 1, self._num_anchors, 1]) - dims = tf.tile( + rx = keras.ops.arange(feature_width, dtype="float32") + 0.5 + ry = keras.ops.arange(feature_height, dtype="float32") + 0.5 + centers = keras.ops.stack(keras.ops.meshgrid(rx, ry), axis=-1) * self._strides[level - 3] + centers = keras.ops.expand_dims(centers, axis=-2) + centers = keras.ops.tile(centers, [1, 1, self._num_anchors, 1]) + dims = keras.ops.tile( self._anchor_dims[level - 3], [feature_height, feature_width, 1, 1] ) - anchors = tf.concat([centers, dims], axis=-1) - return tf.reshape( + anchors = keras.ops.concatenate([centers, dims], axis=-1) + return keras.ops.reshape( anchors, [feature_height * feature_width * self._num_anchors, 4] ) @@ -281,14 +286,13 @@ def get_anchors(self, image_height, image_width): """ anchors = [ self._get_anchors( - tf.math.ceil(image_height / 2**i), - tf.math.ceil(image_width / 2**i), + keras.ops.ceil(image_height / 2**i), + keras.ops.ceil(image_width / 2**i), i, ) for i in range(3, 8) ] - return tf.concat(anchors, axis=0) - + return keras.ops.concatenate(anchors, axis=0) """ ## Preprocessing data @@ -317,9 +321,9 @@ def random_flip_horizontal(image, boxes): Returns: Randomly flipped image and boxes """ - if tf.random.uniform(()) > 0.5: + if keras.random.uniform(()) > 0.5: image = tf.image.flip_left_right(image) - boxes = tf.stack( + boxes = keras.ops.stack( [1 - boxes[:, 2], boxes[:, 1], 1 - boxes[:, 0], boxes[:, 3]], axis=-1 ) return image, boxes @@ -355,16 +359,16 @@ def resize_and_pad_image( image_shape: Shape of the image before padding. 
       ratio: The scaling factor used to resize the image
     """
-    image_shape = tf.cast(tf.shape(image)[:2], dtype=tf.float32)
+    image_shape = keras.ops.cast(keras.ops.shape(image)[:2], dtype="float32")
     if jitter is not None:
-        min_side = tf.random.uniform((), jitter[0], jitter[1], dtype=tf.float32)
-        ratio = min_side / tf.reduce_min(image_shape)
-    if ratio * tf.reduce_max(image_shape) > max_side:
-        ratio = max_side / tf.reduce_max(image_shape)
+        min_side = keras.random.uniform((), jitter[0], jitter[1], dtype="float32")
+        ratio = min_side / keras.ops.min(image_shape)
+    if ratio * keras.ops.max(image_shape) > max_side:
+        ratio = max_side / keras.ops.max(image_shape)
     image_shape = ratio * image_shape
-    image = tf.image.resize(image, tf.cast(image_shape, dtype=tf.int32))
-    padded_image_shape = tf.cast(
-        tf.math.ceil(image_shape / stride) * stride, dtype=tf.int32
+    image = tf.image.resize(image, keras.ops.cast(image_shape, dtype="int32"))
+    padded_image_shape = keras.ops.cast(
+        keras.ops.ceil(image_shape / stride) * stride, dtype="int32"
     )
     image = tf.image.pad_to_bounding_box(
         image, 0, 0, padded_image_shape[0], padded_image_shape[1]
@@ -387,12 +391,12 @@ def preprocess_data(sample):
     """
     image = sample["image"]
     bbox = swap_xy(sample["objects"]["bbox"])
-    class_id = tf.cast(sample["objects"]["label"], dtype=tf.int32)
+    class_id = keras.ops.cast(sample["objects"]["label"], dtype="int32")
 
     image, bbox = random_flip_horizontal(image, bbox)
     image, image_shape, _ = resize_and_pad_image(image)
 
-    bbox = tf.stack(
+    bbox = keras.ops.stack(
         [
             bbox[:, 0] * image_shape[1],
             bbox[:, 1] * image_shape[0],
@@ -434,8 +438,8 @@ class LabelEncoder:
 
     def __init__(self):
         self._anchor_box = AnchorBox()
-        self._box_variance = tf.convert_to_tensor(
-            [0.1, 0.1, 0.2, 0.2], dtype=tf.float32
+        self._box_variance = keras.ops.convert_to_tensor(
+            [0.1, 0.1, 0.2, 0.2], dtype="float32"
         )
 
     def _match_anchor_boxes(
@@ -472,23 +476,23 @@ def _match_anchor_boxes(
             training
         """
         iou_matrix = compute_iou(anchor_boxes, gt_boxes)
-        max_iou = tf.reduce_max(iou_matrix, axis=1)
-        matched_gt_idx = tf.argmax(iou_matrix, axis=1)
-        positive_mask = tf.greater_equal(max_iou, match_iou)
-        negative_mask = tf.less(max_iou, ignore_iou)
-        ignore_mask = tf.logical_not(tf.logical_or(positive_mask, negative_mask))
+        max_iou = keras.ops.amax(iou_matrix, axis=1)
+        matched_gt_idx = keras.ops.argmax(iou_matrix, axis=1)
+        positive_mask = keras.ops.greater_equal(max_iou, match_iou)
+        negative_mask = keras.ops.less(max_iou, ignore_iou)
+        ignore_mask = keras.ops.logical_not(keras.ops.logical_or(positive_mask, negative_mask))
         return (
             matched_gt_idx,
-            tf.cast(positive_mask, dtype=tf.float32),
-            tf.cast(ignore_mask, dtype=tf.float32),
+            keras.ops.cast(positive_mask, dtype="float32"),
+            keras.ops.cast(ignore_mask, dtype="float32"),
         )
 
     def _compute_box_target(self, anchor_boxes, matched_gt_boxes):
         """Transforms the ground truth boxes into targets for training"""
-        box_target = tf.concat(
+        box_target = keras.ops.concatenate(
             [
                 (matched_gt_boxes[:, :2] - anchor_boxes[:, :2]) / anchor_boxes[:, 2:],
-                tf.math.log(matched_gt_boxes[:, 2:] / anchor_boxes[:, 2:]),
+                keras.ops.log(matched_gt_boxes[:, 2:] / anchor_boxes[:, 2:]),
             ],
             axis=-1,
         )
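The `_box_variance` constant above and the `(offset / anchor_size, log size-ratio)` encoding in `_compute_box_target` are exactly inverted later by `DecodePredictions._decode_box_predictions`. A minimal round-trip sketch with made-up anchor and ground-truth values (not part of the patched file):

import keras

variance = keras.ops.convert_to_tensor([0.1, 0.1, 0.2, 0.2])
anchor = keras.ops.convert_to_tensor([[50.0, 50.0, 20.0, 20.0]])  # [x, y, w, h]
gt = keras.ops.convert_to_tensor([[54.0, 48.0, 24.0, 18.0]])

# Encode, as in LabelEncoder._compute_box_target.
target = keras.ops.concatenate(
    [
        (gt[:, :2] - anchor[:, :2]) / anchor[:, 2:],
        keras.ops.log(gt[:, 2:] / anchor[:, 2:]),
    ],
    axis=-1,
) / variance

# Decode, as in DecodePredictions._decode_box_predictions.
scaled = target * variance
decoded = keras.ops.concatenate(
    [
        scaled[:, :2] * anchor[:, 2:] + anchor[:, :2],
        keras.ops.exp(scaled[:, 2:]) * anchor[:, 2:],
    ],
    axis=-1,
)
print(keras.ops.convert_to_numpy(decoded))  # -> [[54. 48. 24. 18.]]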
@@ -498,31 +502,31 @@ def _compute_box_target(self, anchor_boxes, matched_gt_boxes):
         """Creates box and classification targets for a single sample"""
         anchor_boxes = self._anchor_box.get_anchors(image_shape[1], image_shape[2])
-        cls_ids = tf.cast(cls_ids, dtype=tf.float32)
+        cls_ids = keras.ops.cast(cls_ids, dtype="float32")
         matched_gt_idx, positive_mask, ignore_mask = self._match_anchor_boxes(
             anchor_boxes, gt_boxes
         )
-        matched_gt_boxes = tf.gather(gt_boxes, matched_gt_idx)
+        matched_gt_boxes = keras.ops.take(gt_boxes, matched_gt_idx, axis=0)
         box_target = self._compute_box_target(anchor_boxes, matched_gt_boxes)
-        matched_gt_cls_ids = tf.gather(cls_ids, matched_gt_idx)
-        cls_target = tf.where(
-            tf.not_equal(positive_mask, 1.0), -1.0, matched_gt_cls_ids
+        matched_gt_cls_ids = keras.ops.take(cls_ids, matched_gt_idx, axis=0)
+        cls_target = keras.ops.where(
+            keras.ops.not_equal(positive_mask, 1.0), -1.0, matched_gt_cls_ids
         )
-        cls_target = tf.where(tf.equal(ignore_mask, 1.0), -2.0, cls_target)
-        cls_target = tf.expand_dims(cls_target, axis=-1)
-        label = tf.concat([box_target, cls_target], axis=-1)
+        cls_target = keras.ops.where(keras.ops.equal(ignore_mask, 1.0), -2.0, cls_target)
+        cls_target = keras.ops.expand_dims(cls_target, axis=-1)
+        label = keras.ops.concatenate([box_target, cls_target], axis=-1)
         return label
 
     def encode_batch(self, batch_images, gt_boxes, cls_ids):
         """Creates box and classification targets for a batch"""
-        images_shape = tf.shape(batch_images)
+        images_shape = keras.ops.shape(batch_images)
         batch_size = images_shape[0]
 
         labels = tf.TensorArray(dtype=tf.float32, size=batch_size, dynamic_size=True)
         for i in range(batch_size):
             label = self._encode_sample(images_shape, gt_boxes[i], cls_ids[i])
             labels = labels.write(i, label)
-        batch_images = tf.keras.applications.resnet.preprocess_input(batch_images)
+        batch_images = keras.applications.resnet.preprocess_input(batch_images)
         return batch_images, labels.stack()
 
 
@@ -545,7 +549,7 @@ def get_backbone():
         for layer_name in ["conv3_block4_out", "conv4_block6_out", "conv5_block3_out"]
     ]
     return keras.Model(
-        inputs=[backbone.inputs], outputs=[c3_output, c4_output, c5_output]
+        inputs=backbone.inputs, outputs=[c3_output, c4_output, c5_output]
     )
 
 
@@ -587,7 +591,7 @@ def call(self, images, training=False):
         p4_output = self.conv_c4_3x3(p4_output)
         p5_output = self.conv_c5_3x3(p5_output)
         p6_output = self.conv_c6_3x3(c5_output)
-        p7_output = self.conv_c7_3x3(tf.nn.relu(p6_output))
+        p7_output = self.conv_c7_3x3(keras.ops.relu(p6_output))
         return p3_output, p4_output, p5_output, p6_output, p7_output
 
 
@@ -611,7 +615,7 @@ def build_head(output_filters, bias_init):
         or the box regression head depending on `output_filters`.
""" head = keras.Sequential([keras.Input(shape=[None, None, 256])]) - kernel_init = tf.initializers.RandomNormal(0.0, 0.01) + kernel_init = keras.initializers.RandomNormal(0.0, 0.01) for _ in range(4): head.add( keras.layers.Conv2D(256, 3, padding="same", kernel_initializer=kernel_init) @@ -649,23 +653,23 @@ def __init__(self, num_classes, backbone=None, **kwargs): self.fpn = FeaturePyramid(backbone) self.num_classes = num_classes - prior_probability = tf.constant_initializer(-np.log((1 - 0.01) / 0.01)) + prior_probability = prior_probability = keras.initializers.Constant(-np.log((1 - 0.01) / 0.01)) self.cls_head = build_head(9 * num_classes, prior_probability) self.box_head = build_head(9 * 4, "zeros") def call(self, image, training=False): features = self.fpn(image, training=training) - N = tf.shape(image)[0] + N = keras.ops.shape(image)[0] cls_outputs = [] box_outputs = [] for feature in features: - box_outputs.append(tf.reshape(self.box_head(feature), [N, -1, 4])) + box_outputs.append(keras.ops.reshape(self.box_head(feature), [N, -1, 4])) cls_outputs.append( - tf.reshape(self.cls_head(feature), [N, -1, self.num_classes]) + keras.ops.reshape(self.cls_head(feature), [N, -1, self.num_classes]) ) - cls_outputs = tf.concat(cls_outputs, axis=1) - box_outputs = tf.concat(box_outputs, axis=1) - return tf.concat([box_outputs, cls_outputs], axis=-1) + cls_outputs = keras.ops.concatenate(cls_outputs, axis=1) + box_outputs = keras.ops.concatenate(box_outputs, axis=1) + return keras.ops.concatenate([box_outputs, cls_outputs], axis=-1) """ @@ -673,7 +677,7 @@ def call(self, image, training=False): """ -class DecodePredictions(tf.keras.layers.Layer): +class DecodePredictions(keras.layers.Layer): """A Keras layer that decodes predictions of the RetinaNet model. 
Attributes: @@ -707,16 +711,16 @@ def __init__( self.max_detections = max_detections self._anchor_box = AnchorBox() - self._box_variance = tf.convert_to_tensor( - [0.1, 0.1, 0.2, 0.2], dtype=tf.float32 + self._box_variance = keras.ops.convert_to_tensor( + [0.1, 0.1, 0.2, 0.2], dtype="float32" ) def _decode_box_predictions(self, anchor_boxes, box_predictions): boxes = box_predictions * self._box_variance - boxes = tf.concat( + boxes = keras.ops.concatenate( [ boxes[:, :, :2] * anchor_boxes[:, :, 2:] + anchor_boxes[:, :, :2], - tf.math.exp(boxes[:, :, 2:]) * anchor_boxes[:, :, 2:], + keras.ops.exp(boxes[:, :, 2:]) * anchor_boxes[:, :, 2:], ], axis=-1, ) @@ -724,14 +728,14 @@ def _decode_box_predictions(self, anchor_boxes, box_predictions): return boxes_transformed def call(self, images, predictions): - image_shape = tf.cast(tf.shape(images), dtype=tf.float32) + image_shape = keras.ops.cast(keras.ops.shape(images), dtype="float32") anchor_boxes = self._anchor_box.get_anchors(image_shape[1], image_shape[2]) box_predictions = predictions[:, :, :4] - cls_predictions = tf.nn.sigmoid(predictions[:, :, 4:]) + cls_predictions = keras.ops.sigmoid(predictions[:, :, 4:]) boxes = self._decode_box_predictions(anchor_boxes[None, ...], box_predictions) return tf.image.combined_non_max_suppression( - tf.expand_dims(boxes, axis=2), + keras.ops.expand_dims(boxes, axis=2), cls_predictions, self.max_detections_per_class, self.max_detections, @@ -746,49 +750,53 @@ def call(self, images, predictions): """ -class RetinaNetBoxLoss(tf.losses.Loss): +class RetinaNetBoxLoss(keras.losses.Loss): """Implements Smooth L1 loss""" def __init__(self, delta): - super().__init__(reduction="none", name="RetinaNetBoxLoss") + super().__init__( + reduction="none", name="RetinaNetBoxLoss" + ) self._delta = delta def call(self, y_true, y_pred): difference = y_true - y_pred - absolute_difference = tf.abs(difference) - squared_difference = difference**2 - loss = tf.where( - tf.less(absolute_difference, self._delta), + absolute_difference = keras.ops.abs(difference) + squared_difference = difference ** 2 + loss = keras.ops.where( + keras.ops.less(absolute_difference, self._delta), 0.5 * squared_difference, absolute_difference - 0.5, ) - return tf.reduce_sum(loss, axis=-1) + return keras.ops.sum(loss, axis=-1) -class RetinaNetClassificationLoss(tf.losses.Loss): +class RetinaNetClassificationLoss(keras.losses.Loss): """Implements Focal loss""" def __init__(self, alpha, gamma): - super().__init__(reduction="none", name="RetinaNetClassificationLoss") + super().__init__( + reduction="none", name="RetinaNetClassificationLoss" + ) self._alpha = alpha self._gamma = gamma def call(self, y_true, y_pred): - cross_entropy = tf.nn.sigmoid_cross_entropy_with_logits( - labels=y_true, logits=y_pred + cross_entropy = keras.ops.binary_crossentropy( + y_true, y_pred ) - probs = tf.nn.sigmoid(y_pred) - alpha = tf.where(tf.equal(y_true, 1.0), self._alpha, (1.0 - self._alpha)) - pt = tf.where(tf.equal(y_true, 1.0), probs, 1 - probs) - loss = alpha * tf.pow(1.0 - pt, self._gamma) * cross_entropy - return tf.reduce_sum(loss, axis=-1) + probs = keras.ops.sigmoid(y_pred) + alpha = keras.ops.where(keras.ops.equal(y_true, 1.0), self._alpha, (1.0 - self._alpha)) + pt = keras.ops.where(keras.ops.equal(y_true, 1.0), probs, 1 - probs) + loss = alpha * keras.ops.power(1.0 - pt, self._gamma) * cross_entropy + return keras.ops.sum(loss, axis=-1) -class RetinaNetLoss(tf.losses.Loss): +class RetinaNetLoss(keras.losses.Loss): """Wrapper to combine both the losses""" def 
__init__(self, num_classes=80, alpha=0.25, gamma=2.0, delta=1.0): - super().__init__(reduction="auto", name="RetinaNetLoss") + super().__init__(reduction="sum_over_batch_size", name="RetinaNetLoss") self._clf_loss = RetinaNetClassificationLoss(alpha, gamma) self._box_loss = RetinaNetBoxLoss(delta) self._num_classes = num_classes @@ -798,13 +806,13 @@ def call(self, y_true, y_pred): box_labels = y_true[:, :, :4] box_predictions = y_pred[:, :, :4] cls_labels = tf.one_hot( - tf.cast(y_true[:, :, 4], dtype=tf.int32), + keras.ops.cast(y_true[:, :, 4], dtype="int32"), depth=self._num_classes, dtype=tf.float32, ) cls_predictions = y_pred[:, :, 4:] - positive_mask = tf.cast(tf.greater(y_true[:, :, 4], -1.0), dtype=tf.float32) - ignore_mask = tf.cast(tf.equal(y_true[:, :, 4], -2.0), dtype=tf.float32) + positive_mask = keras.ops.cast(tf.greater(y_true[:, :, 4], -1.0), dtype="float32") + ignore_mask = keras.ops.cast(keras.ops.equal(y_true[:, :, 4], -2.0), dtype="float32") clf_loss = self._clf_loss(cls_labels, cls_predictions) box_loss = self._box_loss(box_labels, box_predictions) clf_loss = tf.where(tf.equal(ignore_mask, 1.0), 0.0, clf_loss) @@ -828,7 +836,7 @@ def call(self, y_true, y_pred): learning_rates = [2.5e-06, 0.000625, 0.00125, 0.0025, 0.00025, 2.5e-05] learning_rate_boundaries = [125, 250, 500, 240000, 360000] -learning_rate_fn = tf.optimizers.schedules.PiecewiseConstantDecay( +learning_rate_fn = keras.optimizers.schedules.PiecewiseConstantDecay( boundaries=learning_rate_boundaries, values=learning_rates ) @@ -840,7 +848,7 @@ def call(self, y_true, y_pred): loss_fn = RetinaNetLoss(num_classes) model = RetinaNet(num_classes, resnet50_backbone) -optimizer = tf.keras.optimizers.legacy.SGD(learning_rate=learning_rate_fn, momentum=0.9) +optimizer = keras.optimizers.SGD(learning_rate=learning_rate_fn, momentum=0.9) model.compile(loss=loss_fn, optimizer=optimizer) """ @@ -848,8 +856,8 @@ def call(self, y_true, y_pred): """ callbacks_list = [ - tf.keras.callbacks.ModelCheckpoint( - filepath=os.path.join(model_dir, "weights" + "_epoch_{epoch}"), + keras.callbacks.ModelCheckpoint( + filepath=os.path.join(model_dir, "weights" + "_epoch_{epoch}.weights.h5"), monitor="loss", save_best_only=False, save_weights_only=True, @@ -932,19 +940,30 @@ def call(self, y_true, y_pred): """ # Change this to `model_dir` when not using the downloaded weights -weights_dir = "data" -latest_checkpoint = tf.train.latest_checkpoint(weights_dir) -model.load_weights(latest_checkpoint) +def get_latest_weights(model_dir): + weight_files = glob.glob(os.path.join(model_dir, "*.weights.h5")) + + if not weight_files: + raise FileNotFoundError(f"No weight files found in {model_dir}") + + latest_weight_file = max(weight_files, key=os.path.getmtime) + + return latest_weight_file + + +model.load_weights(get_latest_weights(model_dir)) + """ ## Building inference model """ -image = tf.keras.Input(shape=[None, None, 3], name="image") +image = keras.Input(shape=[None, None, 3], name="image") predictions = model(image, training=False) detections = DecodePredictions(confidence_threshold=0.5)(image, predictions) -inference_model = tf.keras.Model(inputs=image, outputs=detections) +inference_model = keras.Model(inputs=image, outputs=detections) + """ ## Generating detections @@ -953,15 +972,15 @@ def call(self, y_true, y_pred): def prepare_image(image): image, _, ratio = resize_and_pad_image(image, jitter=None) - image = tf.keras.applications.resnet.preprocess_input(image) - return tf.expand_dims(image, axis=0), ratio + image = 
keras.applications.resnet.preprocess_input(image)
+    return keras.ops.expand_dims(image, axis=0), ratio
 
 
 val_dataset = tfds.load("coco/2017", split="validation", data_dir="data")
 int2str = dataset_info.features["objects"]["label"].int2str
 
 for sample in val_dataset.take(2):
-    image = tf.cast(sample["image"], dtype=tf.float32)
+    image = keras.ops.cast(sample["image"], dtype="float32")
     input_image, ratio = prepare_image(image)
     detections = inference_model.predict(input_image)
     num_detections = detections.valid_detections[0]

From 153e9a885ad9ca35ac276538d88bb1e3652119f5 Mon Sep 17 00:00:00 2001
From: Shipra <138140065+Shi-pra-19@users.noreply.github.com>
Date: Wed, 8 Oct 2025 16:17:25 +0000
Subject: [PATCH 2/2] minor fixes

---
 examples/vision/retinanet.py | 25 ++++++++++++-------------
 1 file changed, 12 insertions(+), 13 deletions(-)

diff --git a/examples/vision/retinanet.py b/examples/vision/retinanet.py
index 9061d9e3f3..600204ad71 100644
--- a/examples/vision/retinanet.py
+++ b/examples/vision/retinanet.py
@@ -653,7 +653,7 @@ def __init__(self, num_classes, backbone=None, **kwargs):
         self.fpn = FeaturePyramid(backbone)
         self.num_classes = num_classes
 
-        prior_probability = prior_probability = keras.initializers.Constant(-np.log((1 - 0.01) / 0.01))
+        prior_probability = keras.initializers.Constant(-np.log((1 - 0.01) / 0.01))
         self.cls_head = build_head(9 * num_classes, prior_probability)
         self.box_head = build_head(9 * 4, "zeros")
 
@@ -783,7 +783,7 @@ def __init__(self, alpha, gamma):
 
     def call(self, y_true, y_pred):
         cross_entropy = keras.ops.binary_crossentropy(
-            y_true, y_pred
+            y_true, y_pred, from_logits=True
         )
         probs = keras.ops.sigmoid(y_pred)
         alpha = keras.ops.where(keras.ops.equal(y_true, 1.0), self._alpha, (1.0 - self._alpha))
@@ -802,24 +802,24 @@ def __init__(self, num_classes=80, alpha=0.25, gamma=2.0, delta=1.0):
         self._num_classes = num_classes
 
     def call(self, y_true, y_pred):
-        y_pred = tf.cast(y_pred, dtype=tf.float32)
+        y_pred = keras.ops.cast(y_pred, dtype="float32")
         box_labels = y_true[:, :, :4]
         box_predictions = y_pred[:, :, :4]
-        cls_labels = tf.one_hot(
+        cls_labels = keras.ops.one_hot(
             keras.ops.cast(y_true[:, :, 4], dtype="int32"),
-            depth=self._num_classes,
-            dtype=tf.float32,
+            num_classes=self._num_classes,
+            dtype="float32",
         )
         cls_predictions = y_pred[:, :, 4:]
-        positive_mask = keras.ops.cast(tf.greater(y_true[:, :, 4], -1.0), dtype="float32")
+        positive_mask = keras.ops.cast(keras.ops.greater(y_true[:, :, 4], -1.0), dtype="float32")
         ignore_mask = keras.ops.cast(keras.ops.equal(y_true[:, :, 4], -2.0), dtype="float32")
         clf_loss = self._clf_loss(cls_labels, cls_predictions)
         box_loss = self._box_loss(box_labels, box_predictions)
-        clf_loss = tf.where(tf.equal(ignore_mask, 1.0), 0.0, clf_loss)
-        box_loss = tf.where(tf.equal(positive_mask, 1.0), box_loss, 0.0)
-        normalizer = tf.reduce_sum(positive_mask, axis=-1)
-        clf_loss = tf.math.divide_no_nan(tf.reduce_sum(clf_loss, axis=-1), normalizer)
-        box_loss = tf.math.divide_no_nan(tf.reduce_sum(box_loss, axis=-1), normalizer)
+        clf_loss = keras.ops.where(keras.ops.equal(ignore_mask, 1.0), 0.0, clf_loss)
+        box_loss = keras.ops.where(keras.ops.equal(positive_mask, 1.0), box_loss, 0.0)
+        normalizer = keras.ops.sum(positive_mask, axis=-1)
+        clf_loss = keras.ops.divide_no_nan(keras.ops.sum(clf_loss, axis=-1), normalizer)
+        box_loss = keras.ops.divide_no_nan(keras.ops.sum(box_loss, axis=-1), normalizer)
         loss = clf_loss + box_loss
         return loss
 
@@ -939,7 +939,6 @@ def call(self, y_true, y_pred):
 ## Loading weights
 """
 
-# Change this to 
`model_dir` when not using the downloaded weights def get_latest_weights(model_dir): weight_files = glob.glob(os.path.join(model_dir, "*.weights.h5"))
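One subtlety behind the `keras.ops.one_hot` change above: `LabelEncoder` stores sentinel class ids in the label tensor (-1.0 for background anchors, -2.0 for ignored anchors), and `RetinaNetLoss.call` relies on negative indices one-hot-encoding to all-zero rows. A quick standalone check of that behavior; it assumes the TensorFlow backend, which this example's `tf.data`/`tf.image` pipeline requires anyway (`tf.one_hot` zero-fills out-of-range indices, and `keras.ops.one_hot` inherits that there):

import keras

# Sentinel ids as produced by LabelEncoder: -1 = background, -2 = ignored.
cls_ids = keras.ops.convert_to_tensor([3, -1, -2], dtype="int32")
print(keras.ops.convert_to_numpy(keras.ops.one_hot(cls_ids, num_classes=5)))
# [[0. 0. 0. 1. 0.]
#  [0. 0. 0. 0. 0.]   <- background: all zeros, a pure negative example
#  [0. 0. 0. 0. 0.]]  <- ignored: all zeros; excluded from the loss via ignore_mask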