
Conversation

hertschuh
Contributor

The starting point for these changes was to use `stateless_compute_loss` instead of `compute_loss` for consistency, since we state that everything in `train_step` should be stateless.
- This in turn required plumbing `metrics_variables` and `sample_weight` through.
- That, in turn, required adding `keras.utils.unpack_x_y_sample_weight`.

These changes make the guide slightly longer, but the code is more generic and closer to the Keras default implementation.

Also fixed an issue where `test_step` was not returning the metrics correctly.
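
For reference, the resulting `compute_loss_and_updates` could look roughly like the sketch below. This is a sketch only, assuming the class and argument layout from the guide, and that `stateless_compute_loss` mirrors the other stateless APIs: it takes the variable lists plus `x`, `y`, `y_pred`, and `sample_weight`, and returns the loss together with the updated variables.

```python
import keras


class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        # Stateless forward pass: returns predictions plus updated
        # non-trainable variables instead of mutating model state.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables,
            non_trainable_variables,
            x,
            training=training,
        )
        # Stateless loss computation: threads sample_weight through and
        # returns updated metric variables (e.g. for the loss tracker).
        loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = self.stateless_compute_loss(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x=x,
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            training=training,
        )
        return loss, (y_pred, non_trainable_variables, metrics_variables)
```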
gemini-code-assist
Contributor

Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request significantly enhances the "Customizing `fit()` with JAX" guide by aligning its custom training and evaluation steps with Keras's stateless principles. The changes ensure that loss computation and metric updates consistently handle `sample_weight` and `metrics_variables` in a stateless manner, making the guide's examples more robust and reflective of Keras's internal implementation. Additionally, it rectifies a bug in the `test_step`'s metric return mechanism.

Highlights

  • Stateless Loss Computation: Switched from `compute_loss` to `stateless_compute_loss` within the `compute_loss_and_updates` method to ensure full statelessness in the JAX `train_step`, aligning with Keras's design principles.
  • Sample Weight and Metrics Handling: Integrated `sample_weight` and `metrics_variables` into the `compute_loss_and_updates` and `train_step` signatures and logic, allowing for more comprehensive and consistent metric updates across the custom training loop.
  • Data Unpacking Utility: Introduced `keras.utils.unpack_x_y_sample_weight` for robust and standardized data unpacking in both `train_step` and `test_step`, simplifying data preparation (see the short example after this list).
  • Test Step Metric Correction: Fixed an issue in the `test_step` implementation where metrics were not being returned correctly, ensuring proper evaluation results and consistency with the training step.
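
The data-unpacking utility in the third highlight is easy to demonstrate in isolation; a minimal, self-contained example (the batch values are made up):

```python
import keras

# unpack_x_y_sample_weight normalizes the per-batch data structure:
# (x,), (x, y), and (x, y, sample_weight) all unpack to a 3-tuple,
# with any missing entries returned as None.
data = ([1.0, 2.0], [0, 1])  # an (x, y) batch with no sample weights
x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)
print(x, y, sample_weight)  # [1.0, 2.0] [0, 1] None
```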


gemini-code-assist bot left a comment


Code Review

This pull request does a great job of updating the "Customizing `fit()` with JAX" guide to use `stateless_compute_loss` and correctly handle `sample_weight`. This makes the examples more robust and aligns them better with Keras best practices. The fix in `test_step` to correctly accumulate metrics in the logs dictionary is also a valuable improvement. However, I've identified a critical issue where metrics are being updated twice in the new examples that use `stateless_compute_loss`. This function already updates metrics, so the manual update calls that follow are redundant and lead to incorrect metric reporting. I've added comments with suggestions to remove the redundant code. With these fixes, the guide will be even more accurate and helpful.

Comment on lines 138 to 145
```diff
 if metric.name == "loss":
-    this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+    this_metric_vars = metric.stateless_update_state(
+        this_metric_vars, loss, sample_weight=sample_weight
+    )
 else:
     this_metric_vars = metric.stateless_update_state(
-        this_metric_vars, y, y_pred
+        this_metric_vars, y, y_pred, sample_weight=sample_weight
     )
```


critical

The metrics are being updated twice. `stateless_compute_loss`, called within `grad_fn`, already updates the metrics. These subsequent calls to `metric.stateless_update_state` are redundant and will result in incorrect metric values. You should remove this block. The `metrics_variables` returned from `grad_fn` already contain the updated state, and the logs should be computed from that state.
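
A sketch of what the suggested fix could look like, assuming the surrounding names from the guide (`grad_fn`, `trainable_variables`, `x`, `y`, `sample_weight`) and the stateless metric-reading API `metric.stateless_result`:

```python
# Illustrative sketch, not the guide's verbatim code. grad_fn wraps
# compute_loss_and_updates, so the metrics_variables it returns already
# reflect the update performed inside stateless_compute_loss.
(loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
    trainable_variables,
    non_trainable_variables,
    metrics_variables,
    x,
    y,
    sample_weight,
)

# Build the logs from the already-updated metric state; no second
# stateless_update_state call is needed.
logs = {}
index = 0
for metric in self.metrics:
    this_metric_vars = metrics_variables[index : index + len(metric.variables)]
    logs[metric.name] = metric.stateless_result(this_metric_vars)
    index += len(metric.variables)
```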

Comment on lines 334 to 341
```diff
 if metric.name == "loss":
-    this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+    this_metric_vars = metric.stateless_update_state(
+        this_metric_vars, loss, sample_weight=sample_weight
+    )
 else:
     this_metric_vars = metric.stateless_update_state(
-        this_metric_vars, y, y_pred
+        this_metric_vars, y, y_pred, sample_weight=sample_weight
     )
```


critical

Similar to the `train_step`, the metrics in `test_step` are being updated twice. `stateless_compute_loss` already handles metric updates. The explicit calls to `metric.stateless_update_state` here are redundant and will lead to incorrect evaluation results. Please remove this block.

Comment on lines 177 to 184
" if metric.name == \"loss\":\n",
" this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)\n",
" this_metric_vars = metric.stateless_update_state(\n",
" this_metric_vars, loss, sample_weight=sample_weight\n",
" )\n",
" else:\n",
" this_metric_vars = metric.stateless_update_state(\n",
" this_metric_vars, y, y_pred\n",
" this_metric_vars, y, y_pred, sample_weight=sample_weight\n",
" )\n",


critical

The metrics are being updated twice. `stateless_compute_loss`, called within `grad_fn`, already updates the metrics. These subsequent calls to `metric.stateless_update_state` are redundant and will result in incorrect metric values. You should remove this block.

Comment on lines 415 to 422
" if metric.name == \"loss\":\n",
" this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)\n",
" this_metric_vars = metric.stateless_update_state(\n",
" this_metric_vars, loss, sample_weight=sample_weight\n",
" )\n",
" else:\n",
" this_metric_vars = metric.stateless_update_state(\n",
" this_metric_vars, y, y_pred\n",
" this_metric_vars, y, y_pred, sample_weight=sample_weight\n",
" )\n",


critical

Similar to the `train_step`, the metrics in `test_step` are being updated twice. `stateless_compute_loss` already handles metric updates. The explicit calls to `metric.stateless_update_state` here are redundant and will lead to incorrect evaluation results. Please remove this block.

Comment on lines 143 to 150
```diff
 if metric.name == "loss":
-    this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+    this_metric_vars = metric.stateless_update_state(
+        this_metric_vars, loss, sample_weight=sample_weight
+    )
 else:
     this_metric_vars = metric.stateless_update_state(
-        this_metric_vars, y, y_pred
+        this_metric_vars, y, y_pred, sample_weight=sample_weight
     )
```


medium

The metrics are being updated twice in this code example. `stateless_compute_loss`, called within `grad_fn`, already updates the metrics. These subsequent calls to `metric.stateless_update_state` are redundant and will result in incorrect metric values. You should remove this block to make the example correct.

Comment on lines 388 to 395
```diff
 if metric.name == "loss":
-    this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+    this_metric_vars = metric.stateless_update_state(
+        this_metric_vars, loss, sample_weight=sample_weight
+    )
 else:
     this_metric_vars = metric.stateless_update_state(
-        this_metric_vars, y, y_pred
+        this_metric_vars, y, y_pred, sample_weight=sample_weight
     )
```


medium

Similar to the `train_step` example, the metrics in `test_step` are being updated twice. `stateless_compute_loss` already handles metric updates. The explicit calls to `metric.stateless_update_state` here are redundant and will lead to incorrect evaluation results. Please remove this block to correct the example.

hertschuh requested review from fchollet and removed the review request for fchollet and MarkDaoust on October 7, 2025 at 18:52