Updates to "Customizing fit() with JAX" guide. #2188
base: master
Conversation
The starting point for these changes was to use `stateless_compute_loss` instead of `compute_loss` for consistency, since we state that everything in `train_step` should be stateless.

- This in turn required plumbing `metrics_variables` and `sample_weight` through.
- This required adding `keras.utils.unpack_x_y_sample_weight`.

These changes make the guide slightly longer, but the code is more generic and closer to the Keras default implementation.

Also fixed an issue where `test_step` was not returning the metrics correctly.
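For reviewers skimming the thread, here is a minimal sketch of the pattern described above. It is not the guide's exact code: the `stateless_compute_loss` signature used here (the variable triplet plus `x`, `y`, `y_pred`, and `sample_weight`, returning the loss together with updated variable lists) and the placement of metric reporting are assumptions based on the Keras 3 stateless API.

```python
import jax
import keras


class CustomModel(keras.Model):
    def compute_loss_and_updates(
        self,
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
        x,
        y,
        sample_weight,
        training=False,
    ):
        # Stateless forward pass: returns outputs plus updated
        # non-trainable variables instead of mutating them in place.
        y_pred, non_trainable_variables = self.stateless_call(
            trainable_variables, non_trainable_variables, x, training=training
        )
        # Stateless loss computation; threads metrics_variables and
        # sample_weight through and returns updated variable lists.
        loss, (
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
        ) = self.stateless_compute_loss(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x=x,
            y=y,
            y_pred=y_pred,
            sample_weight=sample_weight,
            training=training,
        )
        return loss, (y_pred, non_trainable_variables, metrics_variables)

    def train_step(self, state, data):
        (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        ) = state
        # `data` may be (x,), (x, y), or (x, y, sample_weight);
        # missing elements come back as None.
        x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

        grad_fn = jax.value_and_grad(self.compute_loss_and_updates, has_aux=True)
        (loss, (y_pred, non_trainable_variables, metrics_variables)), grads = grad_fn(
            trainable_variables,
            non_trainable_variables,
            metrics_variables,
            x,
            y,
            sample_weight,
            training=True,
        )
        trainable_variables, optimizer_variables = self.optimizer.stateless_apply(
            optimizer_variables, grads, trainable_variables
        )

        # Metric reporting from metrics_variables is elided here; the review
        # thread below discusses exactly how it should be done.
        state = (
            trainable_variables,
            non_trainable_variables,
            optimizer_variables,
            metrics_variables,
        )
        return {"loss": loss}, state
```

The key point is that `keras.utils.unpack_x_y_sample_weight` normalizes whatever tuple shape the dataset yields, so `sample_weight` can be plumbed through unconditionally.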
Summary of Changes

Hello @hertschuh, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed! This pull request significantly enhances the "Customizing fit() with JAX" guide by aligning its custom training and evaluation steps with Keras's stateless principles. The changes ensure that loss computation and metric updates consistently handle sample weights.
Code Review

This pull request does a great job of updating the "Customizing fit() with JAX" guide to use `stateless_compute_loss` and correctly handle `sample_weight`. This makes the examples more robust and aligns them better with Keras best practices. The fix in `test_step` to correctly accumulate metrics in the `logs` dictionary is also a valuable improvement. However, I've identified a critical issue where metrics are being updated twice in the new examples that use `stateless_compute_loss`. This function already updates metrics, so the manual update calls that follow are redundant and lead to incorrect metric reporting. I've added comments with suggestions to remove the redundant code. With these fixes, the guide will be even more accurate and helpful.
```diff
 if metric.name == "loss":
-    this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+    this_metric_vars = metric.stateless_update_state(
+        this_metric_vars, loss, sample_weight=sample_weight
+    )
 else:
     this_metric_vars = metric.stateless_update_state(
-        this_metric_vars, y, y_pred
+        this_metric_vars, y, y_pred, sample_weight=sample_weight
     )
```
The metrics are being updated twice. `stateless_compute_loss`, called within `grad_fn`, already updates the metrics. These subsequent calls to `metric.stateless_update_state` are redundant and will result in incorrect metric values. You should remove this block. The `metrics_variables` returned from `grad_fn` already contain the updated state, and the `logs` should be computed from that state.
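For concreteness, a sketch of what the reviewer is proposing, under the comment's own assumption that the `metrics_variables` returned from `grad_fn` already contain the updated state; `metric.stateless_result()` reads a metric's value from a list of variable values without mutating anything:

```python
# Inside train_step, after grad_fn has returned updated metrics_variables:
# read each metric's value instead of updating its state a second time.
logs = {}
new_metrics_vars = []
for metric in self.metrics:
    # Each metric owns a contiguous slice of the flat variable list.
    this_metric_vars = metrics_variables[
        len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
    ]
    logs[metric.name] = metric.stateless_result(this_metric_vars)
    new_metrics_vars += this_metric_vars
```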
" if metric.name == \"loss\":\n", | ||
" this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)\n", | ||
" this_metric_vars = metric.stateless_update_state(\n", | ||
" this_metric_vars, loss, sample_weight=sample_weight\n", | ||
" )\n", | ||
" else:\n", | ||
" this_metric_vars = metric.stateless_update_state(\n", | ||
" this_metric_vars, y, y_pred\n", | ||
" this_metric_vars, y, y_pred, sample_weight=sample_weight\n", | ||
" )\n", |
```diff
 if metric.name == "loss":
-    this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+    this_metric_vars = metric.stateless_update_state(
+        this_metric_vars, loss, sample_weight=sample_weight
+    )
 else:
     this_metric_vars = metric.stateless_update_state(
-        this_metric_vars, y, y_pred
+        this_metric_vars, y, y_pred, sample_weight=sample_weight
     )
```
The metrics are being updated twice in this code example. `stateless_compute_loss`, called within `grad_fn`, already updates the metrics. These subsequent calls to `metric.stateless_update_state` are redundant and will result in incorrect metric values. You should remove this block to make the example correct.
```diff
 if metric.name == "loss":
-    this_metric_vars = metric.stateless_update_state(this_metric_vars, loss)
+    this_metric_vars = metric.stateless_update_state(
+        this_metric_vars, loss, sample_weight=sample_weight
+    )
 else:
     this_metric_vars = metric.stateless_update_state(
-        this_metric_vars, y, y_pred
+        this_metric_vars, y, y_pred, sample_weight=sample_weight
     )
```
Similar to the `train_step` example, the metrics in `test_step` are being updated twice. `stateless_compute_loss` already handles metric updates. The explicit calls to `metric.stateless_update_state` here are redundant and will lead to incorrect evaluation results. Please remove this block to correct the example.
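To ground the discussion, here is the overall `test_step` shape the diff implies, as a hedged sketch continuing the hypothetical `CustomModel` from the PR description above. Whether the explicit `stateless_update_state` calls double-count what `stateless_compute_loss` already recorded is exactly the open question in this thread.

```python
def test_step(self, state, data):
    trainable_variables, non_trainable_variables, metrics_variables = state
    x, y, sample_weight = keras.utils.unpack_x_y_sample_weight(data)

    # Stateless forward pass in inference mode.
    y_pred, non_trainable_variables = self.stateless_call(
        trainable_variables, non_trainable_variables, x, training=False
    )
    loss, (
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
    ) = self.stateless_compute_loss(
        trainable_variables,
        non_trainable_variables,
        metrics_variables,
        x=x,
        y=y,
        y_pred=y_pred,
        sample_weight=sample_weight,
        training=False,
    )

    # Accumulate every metric into `logs` so evaluate() reports them all;
    # this is the "test_step was not returning the metrics correctly" fix.
    logs = {}
    new_metrics_vars = []
    for metric in self.metrics:
        this_metric_vars = metrics_variables[
            len(new_metrics_vars) : len(new_metrics_vars) + len(metric.variables)
        ]
        if metric.name == "loss":
            this_metric_vars = metric.stateless_update_state(
                this_metric_vars, loss, sample_weight=sample_weight
            )
        else:
            this_metric_vars = metric.stateless_update_state(
                this_metric_vars, y, y_pred, sample_weight=sample_weight
            )
        logs[metric.name] = metric.stateless_result(this_metric_vars)
        new_metrics_vars += this_metric_vars

    state = (trainable_variables, non_trainable_variables, new_metrics_vars)
    return logs, state
```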