88 changes: 56 additions & 32 deletions guides/custom_train_step_in_jax.py
@@ -60,17 +60,10 @@
- We implement a fully-stateless `compute_loss_and_updates()` method
to compute the loss as well as the updated values for the non-trainable
variables of the model. Internally, it calls `stateless_call()` and
the built-in `compute_loss()`.
the built-in `stateless_compute_loss()`.
- We implement a fully-stateless `train_step()` method to compute current
metric values (including the loss) as well as updated values for the
trainable variables, the optimizer variables, and the metric variables.

Note that you can also take into account the `sample_weight` argument by:

- Unpacking the data as `x, y, sample_weight = data`
- Passing `sample_weight` to `compute_loss()`
- Passing `sample_weight` alongside `y` and `y_pred`
to metrics in `stateless_update_state()`
"""


@@ -79,8 +72,10 @@ def compute_loss_and_updates(
self,
trainable_variables,
non_trainable_variables,
metrics_variables,
x,
y,
sample_weight,
training=False,
):
y_pred, non_trainable_variables = self.stateless_call(
@@ -89,8 +84,21 @@ def compute_loss_and_updates(
x,
training=training,
)
loss = self.compute_loss(x, y, y_pred)
return loss, (y_pred, non_trainable_variables)
loss, (
trainable_variables,
non_trainable_variables,
metrics_variables,
) = self.stateless_compute_loss(
trainable_variables,
non_trainable_variables,
metrics_variables,
x=x,
y=y,
y_pred=y_pred,
sample_weight=sample_weight,
training=training,
)
return loss, (y_pred, non_trainable_variables, metrics_variables)

def train_step(self, state, data):
(
@@ -99,25 +107,24 @@ def train_step(self, state, data):
optimizer_variables,
metrics_variables,
) = state
x, y = data
x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

# Get the gradient function.
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)

# Compute the gradients.
(loss, (y_pred, non_trainable_variables)), grads = grad_fn(
(loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
trainable_variables,
non_trainable_variables,
metrics_variables,
x,
y,
sample_weight,
training=True,
)

# Update trainable variables and optimizer variables.
(
trainable_variables,
optimizer_variables,
) = self.optimizer.stateless_apply(
trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)

@@ -129,10 +136,12 @@ def train_step(self, state, data):
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
this_metric_vars = metric.stateless_update_state(
this_metric_vars, loss, sample_weight=sample_weight
)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
this_metric_vars, y, y_pred, sample_weight=sample_weight
)
Comment on lines 138 to 145
Contributor

critical

The metrics are being updated twice. stateless_compute_loss, called within grad_fn, already updates the metrics. These subsequent calls to metric.stateless_update_state are redundant and will result in incorrect metric values. You should remove this block. The metrics_variables returned from grad_fn already contain the updated state, and the logs should be computed from that state.

logs[metric.name] = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars
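
To make the review suggestion concrete, here is a minimal sketch (not part of the PR) of how the loop above could be replaced: since `stateless_compute_loss()` inside `grad_fn` already returns updated `metrics_variables`, the logs can be read straight from that state with `stateless_result()`:

# Hypothetical simplification along the lines of the review comment:
# metrics_variables returned by grad_fn already holds the updated state,
# so we only read results from it instead of updating each metric again.
new_metrics_vars = metrics_variables
logs = {}
offset = 0
for metric in self.metrics:
    this_metric_vars = new_metrics_vars[offset : offset + len(metric.variables)]
    logs[metric.name] = metric.stateless_result(this_metric_vars)
    offset += len(metric.variables)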
@@ -186,6 +195,7 @@ def compute_loss_and_updates(
non_trainable_variables,
x,
y,
sample_weight,
training=False,
):
y_pred, non_trainable_variables = self.stateless_call(
@@ -194,7 +204,7 @@ def compute_loss_and_updates(
x,
training=training,
)
loss = self.loss_fn(y, y_pred)
loss = self.loss_fn(y, y_pred, sample_weight=sample_weight)
return loss, (y_pred, non_trainable_variables)
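
Assuming `loss_fn` is a `keras.losses.Loss` instance, as in this guide's setup, its call accepts an optional `sample_weight` argument that weights the per-sample losses before reduction, so forwarding it here is enough to make the loss respect the weights.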

def train_step(self, state, data):
@@ -204,7 +214,7 @@ def train_step(self, state, data):
optimizer_variables,
metrics_variables,
) = state
x, y = data
x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

# Get the gradient function.
grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
@@ -215,14 +225,12 @@ def train_step(self, state, data):
non_trainable_variables,
x,
y,
sample_weight,
training=True,
)

# Update trainable variables and optimizer variables.
(
trainable_variables,
optimizer_variables,
) = self.optimizer.stateless_apply(
trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
optimizer_variables, grads, trainable_variables
)

@@ -231,10 +239,10 @@ def train_step(self, state, data):
mae_metric_vars = metrics_variables[len(self.loss_tracker.variables) :]

loss_tracker_vars = self.loss_tracker.stateless_update_state(
loss_tracker_vars, loss
loss_tracker_vars, loss, sample_weight=sample_weight
)
mae_metric_vars = self.mae_metric.stateless_update_state(
mae_metric_vars, y, y_pred
mae_metric_vars, y, y_pred, sample_weight=sample_weight
)
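
Unlike the first example, this variant computes the loss with `loss_fn` directly rather than through `stateless_compute_loss()`, so these explicit `stateless_update_state()` calls are the only place the metric state is updated and the double-update concern raised above does not apply here.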

logs = {}
@@ -287,7 +295,7 @@ def metrics(self):
class CustomModel(keras.Model):
def test_step(self, state, data):
# Unpack the data.
x, y = data
x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
(
trainable_variables,
non_trainable_variables,
@@ -301,21 +309,37 @@ def test_step(self, state, data):
x,
training=False,
)
loss = self.compute_loss(x, y, y_pred)
loss, (
trainable_variables,
non_trainable_variables,
metrics_variables,
) = self.stateless_compute_loss(
trainable_variables,
non_trainable_variables,
metrics_variables,
x=x,
y=y,
y_pred=y_pred,
sample_weight=sample_weight,
training=False,
)

# Update metrics.
new_metrics_vars = []
logs = {}
for metric in self.metrics:
this_metric_vars = metrics_variables[
len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
]
if metric.name == "loss":
this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
this_metric_vars = metric.stateless_update_state(
this_metric_vars, loss, sample_weight=sample_weight
)
else:
this_metric_vars = metric.stateless_update_state(
this_metric_vars, y, y_pred
this_metric_vars, y, y_pred, sample_weight=sample_weight
)
Comment on lines 334 to 341
Contributor

critical

Similar to the train_step, the metrics in test_step are being updated twice. stateless_compute_loss already handles metric updates. The explicit calls to metric.stateless_update_state here are redundant and will lead to incorrect evaluation results. Please remove this block.

logs = metric.stateless_result(this_metric_vars)
logs[metric.name] = metric.stateless_result(this_metric_vars)
new_metrics_vars += this_metric_vars

# Return metric logs and updated state variables.
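
The same simplification sketched after the `train_step` hunk above applies here as well: read the metric results out of the `metrics_variables` returned by `stateless_compute_loss()` instead of updating each metric a second time.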
@@ -336,7 +360,7 @@ def test_step(self, state, data):
# Evaluate with our custom test_step
x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.evaluate(x, y)
model.evaluate(x, y, return_dict=True)
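
With `return_dict=True`, `evaluate()` returns the logs as a `{metric_name: value}` dict rather than a list of scalars, matching the dictionary built by the custom `test_step()` above.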


"""