diff --git a/keras/src/backend/jax/trainer.py b/keras/src/backend/jax/trainer.py
index e5c77b7f526f..5f01505c2d47 100644
--- a/keras/src/backend/jax/trainer.py
+++ b/keras/src/backend/jax/trainer.py
@@ -105,7 +105,10 @@ def _update_metrics_variables(
             ]
         ) as scope:
             self._loss_tracker.update_state(
-                unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
+                unscaled_loss,
+                sample_weight=next(
+                    i for i in tree.flatten(x) if i is not None
+                ).shape[0],
             )
             logs = self.compute_metrics(x, y, y_pred, sample_weight)
 
diff --git a/keras/src/backend/tensorflow/trainer.py b/keras/src/backend/tensorflow/trainer.py
index c223deff7e05..cd6410999dd2 100644
--- a/keras/src/backend/tensorflow/trainer.py
+++ b/keras/src/backend/tensorflow/trainer.py
@@ -68,7 +68,9 @@ def train_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
         )
         if self.optimizer is not None:
             loss = self.optimizer.scale_loss(loss)
@@ -96,7 +98,9 @@ def test_step(self, data):
         )
         self._loss_tracker.update_state(
             loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
diff --git a/keras/src/backend/torch/trainer.py b/keras/src/backend/torch/trainer.py
index a021cca29b60..ad68c2f3a7ec 100644
--- a/keras/src/backend/torch/trainer.py
+++ b/keras/src/backend/torch/trainer.py
@@ -54,7 +54,10 @@ def train_step(self, data):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tree.flatten(x)[0].shape[0]
+            loss,
+            sample_weight=next(
+                i for i in tree.flatten(x) if i is not None
+            ).shape[0],
         )
         if self.optimizer is not None:
             loss = self.optimizer.scale_loss(loss)
@@ -90,7 +93,10 @@ def test_step(self, data):
             x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
         )
         self._loss_tracker.update_state(
-            loss, sample_weight=tree.flatten(x)[0].shape[0]
+            loss,
+            sample_weight=next(
+                i for i in tree.flatten(x) if i is not None
+            ).shape[0],
         )
         return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)
 
diff --git a/keras/src/models/model_test.py b/keras/src/models/model_test.py
index 0fea1336db67..1c479fef9ba3 100644
--- a/keras/src/models/model_test.py
+++ b/keras/src/models/model_test.py
@@ -166,14 +166,14 @@ def __init__(self):
             super().__init__()
             self.dense = layers.Dense(2)
 
-        def call(self, a, b=None):
-            x = a if b is None else a + b
-            return self.dense(x)
-
-    x1 = Input((2,), name="x1")
-    x2 = Input((2,), name="x2", optional=True)
-    y = OptionalInputLayer()(x1, x2)
-    model = Model({"x1": x1, "x2": x2}, y)
+        def call(self, x, o=None):
+            z = x if o is None else x + o
+            return self.dense(z)
+
+    x = Input((2,), name="x")
+    o = Input((2,), name="o", optional=True)
+    y = OptionalInputLayer()(x, o)
+    model = Model({"x": x, "o": o}, y)
     return model
 
 
@@ -1244,27 +1244,27 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
     )
     def test_functional_optional_inputs(self, is_optional_none):
         model = _get_model_optional_inputs()
-        x1 = np.ones((2, 2))
-        x2 = None if is_optional_none else np.ones((2, 2))
+        x = np.ones((2, 2))
+        o = None if is_optional_none else np.ones((2, 2))
         y_true = np.ones((2, 2))
 
         model.compile(loss="mse", optimizer="adam")
-        model.fit(x={"x1": x1, "x2": x2}, y=y_true)
-        model.evaluate(x={"x1": x1, "x2": x2}, y=y_true)
-        model.predict(x={"x1": x1, "x2": x2})
+        model.fit(x={"x": x, "o": o}, y=y_true)
+        model.evaluate(x={"x": x, "o": o}, y=y_true)
+        model.predict(x={"x": x, "o": o})
 
     @parameterized.named_parameters(
         ("optional_none", True), ("optional_tensor", False)
     )
     def test_functional_optional_inputs_generator(self, is_optional_none):
         model = _get_model_optional_inputs()
-        x1 = np.ones((2, 2))
-        x2 = None if is_optional_none else np.ones((2, 2))
+        x = np.ones((2, 2))
+        o = None if is_optional_none else np.ones((2, 2))
         y_true = np.ones((2, 2))
 
         def data_generator(with_y=True):
             for _ in range(4):
-                yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ())
+                yield ({"x": x, "o": o},) + ((y_true,) if with_y else ())
 
         model.compile(loss="mse", optimizer="adam")
         model.fit(data_generator())
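Illustration (not part of the patch): tree.flatten yields dict leaves in sorted-key order, so when an optional input's name sorts first (as in the renamed test, where "o" precedes "x"), tree.flatten(x)[0] can be None and .shape[0] raises. The next(...) expression instead takes the first non-None leaf. The flatten helper below is a hypothetical stand-in for keras.src.tree.flatten:

import numpy as np

def flatten(structure):
    # Hypothetical stand-in for keras.src.tree.flatten: like optree,
    # it returns dict leaves in sorted-key order.
    return [structure[k] for k in sorted(structure)]

x = {"x": np.ones((2, 2)), "o": None}  # optional input "o" was not passed

# Old pattern: "o" sorts before "x", so the first leaf is None.
assert flatten(x)[0] is None  # .shape[0] here would raise AttributeError

# New pattern: the first non-None leaf supplies the batch size.
batch_size = next(i for i in flatten(x) if i is not None).shape[0]
assert batch_size == 2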