Skip to content

Conversation

@TanmayThakur2209
Copy link

Motivation

scale_by_backtracking_linesearch accepts extra keyword arguments that are
forwarded to value_fn. Currently, any unused keyword arguments are silently
discarded, which can lead to subtle bugs (e.g. in Flax training loops) and makes
misconfiguration hard to diagnose.

What this PR does

  • Raises a clear error when unused keyword arguments are passed to the
    backtracking linesearch update.
  • Adds a test documenting the expected behavior.

Why this is safe

  • Does not change the public API.
  • Only affects cases that previously resulted in silent argument dropping.
  • Preserves JIT compatibility and existing behavior for valid calls.

Tests added

  • test_linesearch_raises_on_unused_extra_args

@selamw1
Copy link
Collaborator

selamw1 commented Jan 12, 2026

Thanks! I have a few suggestions to wrap this up:

  • Could you update the docstring for update_fn to explicitly warn users about the new validation? Here is a suggestion for the extra_args description:
**extra_args: additional keyword arguments, if the function needs
        additional arguments such as input data, they should be put there.
        Arguments not accepted by ``value_fn`` will raise a ``TypeError``, see
        the example in the docstring of the transform.
  • I noticed that scale_by_zoom_linesearch (line 1511) has a similar structure where it accepts **extra_args. For consistency and safety, could you apply this same validation logic there as well? This ensures users get the same "unexpected argument" error regardless of which linesearch they use.
  • Once you have made the final changes, could you please squash this into a single commit? (See JAX Contributing - Single change commits).

@TanmayThakur2209
Copy link
Author

Thanks for the suggestions!

This all makes sense — I’ll update the docstring to document the new validation, apply the same unused-argument check to scale_by_zoom_linesearch for consistency, and squash the changes into a single commit.

I’ll follow up shortly with the updates.

@TanmayThakur2209 TanmayThakur2209 force-pushed the fix-linesearch-extra-args branch 3 times, most recently from 3a12fd7 to 37abe97 Compare January 13, 2026 09:43
- Raise TypeError on unused keyword arguments
- Apply validation consistently to backtracking and zoom linesearch
- Update docstrings to document new behavior
- Add tests for unused extra_args errors
@TanmayThakur2209 TanmayThakur2209 force-pushed the fix-linesearch-extra-args branch from 37abe97 to 901a341 Compare January 13, 2026 09:54
@TanmayThakur2209
Copy link
Author

I’ve made the requested updates

  • Updated the update_fn docstring to explicitly document the new validation behavior for extra_args, including the fact that unused keyword arguments now raise a TypeError.

  • Applied the same unused-extra_args validation logic to scale_by_zoom_linesearch for consistency and safety, so both linesearch implementations now behave identically.

  • Added targeted tests covering the unused extra_args error for both backtracking and zoom linesearches.

I’ve squashed the changes into a single commit in line with the JAX contributing guidelines.

Please let me know if you’d like any further tweaks or clarifications!

@selamw1
Copy link
Collaborator

selamw1 commented Jan 13, 2026

LGTM, thank you!

@TanmayThakur2209
Copy link
Author

Thank you so much for the review and guidance — I really appreciate it!

Copy link
Collaborator

@vroulet vroulet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the catch

@vroulet
Copy link
Collaborator

vroulet commented Jan 22, 2026

And thanks @selamw1 for the careful review!

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants