keras/src/backend/jax/trainer.py (5 changes: 4 additions & 1 deletion)
@@ -105,7 +105,10 @@ def _update_metrics_variables(
            ]
        ) as scope:
            self._loss_tracker.update_state(
-                unscaled_loss, sample_weight=tree.flatten(x)[0].shape[0]
+                unscaled_loss,
+                sample_weight=next(
+                    i for i in tree.flatten(x) if i is not None
+                ).shape[0],
            )
            logs = self.compute_metrics(x, y, y_pred, sample_weight)
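
For context, a minimal sketch of the failure mode the old indexing hits with optional inputs (assuming, as with dm-tree/optree, that keras.tree flattens dict keys in sorted order, so a None-valued entry can land first):

import numpy as np
from keras import tree

# A dict input where the optional entry sorts first ("o" < "x"):
x = {"o": None, "x": np.ones((2, 2))}
flat = tree.flatten(x)  # [None, array([[1., 1.], [1., 1.]])]

# The old code indexed flat[0], which is None here and has no .shape:
#     flat[0].shape[0]  # AttributeError: 'NoneType' object has no attribute 'shape'

# The new code picks the first non-None leaf instead:
batch_size = next(i for i in tree.flatten(x) if i is not None).shape[0]
print(batch_size)  # 2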

keras/src/backend/tensorflow/trainer.py (8 changes: 6 additions & 2 deletions)
@@ -68,7 +68,9 @@ def train_step(self, data):
        )
        self._loss_tracker.update_state(
            loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
Comment on lines +71 to +73
Contributor (severity: medium):
This change correctly handles cases where the first input is None. However, it introduces a risk of a StopIteration error if all inputs in x are None. This can be difficult to debug, especially inside a tf.function.

A more robust approach would be to handle this edge case explicitly, for example by raising a ValueError with a clear message.

Also, this logic is duplicated in test_step. Consider extracting it into a private helper method to improve maintainability and ensure consistency.
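
For illustration, one possible shape for the suggested helper (hypothetical name _first_non_none_input; a sketch of this review suggestion, not code from the PR):

def _first_non_none_input(self, x):
    """Return the first non-None leaf of x, used to infer the batch size."""
    for i in tree.flatten(x):
        if i is not None:
            return i
    raise ValueError(
        "Expected at least one non-None input to infer the batch size "
        f"for loss tracking. Received: x={x}"
    )

Both train_step and test_step could then use sample_weight=tf.shape(self._first_non_none_input(x))[0], keeping the None handling and its error message in one place.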

Collaborator: +1

        )
        if self.optimizer is not None:
            loss = self.optimizer.scale_loss(loss)
@@ -96,7 +98,9 @@ def test_step(self, data):
        )
        self._loss_tracker.update_state(
            loss_module.unscale_loss_for_distribution(loss),
-            sample_weight=tf.shape(tree.flatten(x)[0])[0],
+            sample_weight=tf.shape(
+                next(i for i in tree.flatten(x) if i is not None)
+            )[0],
Comment on lines +101 to +103
Contributor (severity: medium):

As in train_step, this change is vulnerable to a StopIteration error if all inputs are None. Handling that edge case explicitly (for example, with the helper sketched above) would make the code more robust and prevent a hard-to-debug runtime crash.

Collaborator: +1

        )
        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

keras/src/backend/torch/trainer.py (10 changes: 8 additions & 2 deletions)
@@ -54,7 +54,10 @@ def train_step(self, data):
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=True
        )
        self._loss_tracker.update_state(
-            loss, sample_weight=tree.flatten(x)[0].shape[0]
+            loss,
+            sample_weight=next(
+                i for i in tree.flatten(x) if i is not None
+            ).shape[0],
        )
        if self.optimizer is not None:
            loss = self.optimizer.scale_loss(loss)
@@ -90,7 +93,10 @@ def test_step(self, data):
            x=x, y=y, y_pred=y_pred, sample_weight=sample_weight, training=False
        )
        self._loss_tracker.update_state(
-            loss, sample_weight=tree.flatten(x)[0].shape[0]
+            loss,
+            sample_weight=next(
+                i for i in tree.flatten(x) if i is not None
+            ).shape[0],
        )
        return self.compute_metrics(x, y, y_pred, sample_weight=sample_weight)

keras/src/models/model_test.py (32 changes: 16 additions & 16 deletions)
@@ -166,14 +166,14 @@ def __init__(self):
            super().__init__()
            self.dense = layers.Dense(2)

-        def call(self, a, b=None):
-            x = a if b is None else a + b
-            return self.dense(x)
-
-    x1 = Input((2,), name="x1")
-    x2 = Input((2,), name="x2", optional=True)
-    y = OptionalInputLayer()(x1, x2)
-    model = Model({"x1": x1, "x2": x2}, y)
+        def call(self, x, o=None):
+            z = x if o is None else x + o
+            return self.dense(z)
+
+    x = Input((2,), name="x")
+    o = Input((2,), name="o", optional=True)
+    y = OptionalInputLayer()(x, o)
+    model = Model({"x": x, "o": o}, y)
    return model


@@ -1244,27 +1244,27 @@ def test_functional_deeply_nested_outputs_struct_losses(self):
    )
    def test_functional_optional_inputs(self, is_optional_none):
        model = _get_model_optional_inputs()
-        x1 = np.ones((2, 2))
-        x2 = None if is_optional_none else np.ones((2, 2))
+        x = np.ones((2, 2))
+        o = None if is_optional_none else np.ones((2, 2))
        y_true = np.ones((2, 2))

        model.compile(loss="mse", optimizer="adam")
-        model.fit(x={"x1": x1, "x2": x2}, y=y_true)
-        model.evaluate(x={"x1": x1, "x2": x2}, y=y_true)
-        model.predict(x={"x1": x1, "x2": x2})
+        model.fit(x={"x": x, "o": o}, y=y_true)
+        model.evaluate(x={"x": x, "o": o}, y=y_true)
+        model.predict(x={"x": x, "o": o})

    @parameterized.named_parameters(
        ("optional_none", True), ("optional_tensor", False)
    )
    def test_functional_optional_inputs_generator(self, is_optional_none):
        model = _get_model_optional_inputs()
-        x1 = np.ones((2, 2))
-        x2 = None if is_optional_none else np.ones((2, 2))
+        x = np.ones((2, 2))
+        o = None if is_optional_none else np.ones((2, 2))
        y_true = np.ones((2, 2))

        def data_generator(with_y=True):
            for _ in range(4):
-                yield ({"x1": x1, "x2": x2},) + ((y_true,) if with_y else ())
+                yield ({"x": x, "o": o},) + ((y_true,) if with_y else ())

        model.compile(loss="mse", optimizer="adam")
        model.fit(data_generator())