Fix: #1509 Merge duplicate DoG implementations and add layer-wise support #1518
base: main
Conversation
Can you keep the _dog.py and _dog_test.py files, and keep them in the contrib folder?
Thanks for the feedback, I will make the changes right now.
@vroulet I moved the files as you asked. This should fix the issue.
emilyfertig left a comment
Thanks!
@emilyfertig I made the changes you preferred and added the new scale_by_l_dog function. 137d6a7
emilyfertig left a comment
Thanks! scale_by_l_dog should be the same as scale_by_dog with layer_wise=True, right? Can we remove the layer_wise arg, and make scale_by_l_dog the same as scale_by_dog with layer_wise=True?
Thanks for your feedback @emilyfertig, I did as you asked and made scale_by_l_dog the same as scale_by_dog with layer_wise=True.
Force-pushed from 06259a2 to 0894aeb
emilyfertig left a comment
Thanks! I think the pytest failure is unrelated and should clear up if you rebase.
optax/contrib/_dog.py (Outdated)

  return _scale_by_dog(
      init_step=("heuristic", reps_rel),
      eps=eps,
      layer_wise=True,
Sorry, what I meant is to please get rid of the layer_wise arg everywhere, and make separate implementations of scale_by_dog and scale_by_l_dog. Does that make sense?
@emilyfertig Refactored the DoG optimizer implementation in optax/contrib/_dog.py to separate the global and layer-wise variants:
- Refactored optax/contrib/_dog.py: removed the internal _scale_by_dog helper function and implemented scale_by_dog (global DoG) and scale_by_l_dog (layer-wise DoG) as distinct, standalone functions, removing the layer_wise argument from scale_by_dog to enforce a clear separation of concerns (see the sketch after this list).
- Updated optax/contrib/_dog_test.py: renamed test_dog_layer_wise to test_l_dog_vs_dog to reflect the API change, and updated comments to remove outdated references to the layer_wise argument.
- Verified that all tests pass with pytest optax/contrib/_dog_test.py.
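For readers following the thread, the requested split could look roughly like the sketch below. This is a minimal illustration, not the PR's actual code: the DoGState field names mirror the diffs quoted later in this thread, while the update rule and the heuristic initial distance follow the DoG paper; the exact signatures in the PR may differ.

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import optax


class DoGState(NamedTuple):
  init_params: Any         # x_0, kept to measure the distance travelled
  max_dist: Any            # running max of ||x_t - x_0|| (scalar, or per-leaf tree)
  sum_sq_norm_grads: Any   # running sum of ||g_t||^2 (scalar, or per-leaf tree)


def scale_by_dog(reps_rel: float = 1e-6, eps: float = 1e-8) -> optax.GradientTransformation:
  """Global DoG: a single step size shared by all parameters (sketch)."""

  def init_fn(params):
    # Heuristic initial distance: r_epsilon = reps_rel * (1 + ||x_0||).
    r_epsilon = reps_rel * (1.0 + optax.global_norm(params))
    return DoGState(params, jnp.asarray(r_epsilon), jnp.asarray(0.0))

  def update_fn(updates, state, params):
    diff = jax.tree.map(lambda p, p0: p - p0, params, state.init_params)
    max_dist = jnp.maximum(state.max_dist, optax.global_norm(diff))
    sum_sq = state.sum_sq_norm_grads + optax.global_norm(updates) ** 2
    # eta_t = max_{i<=t} ||x_i - x_0|| / sqrt(sum_{i<=t} ||g_i||^2 + eps)
    eta = max_dist / jnp.sqrt(sum_sq + eps)
    updates = jax.tree.map(lambda g: eta * g, updates)
    return updates, DoGState(state.init_params, max_dist, sum_sq)

  return optax.GradientTransformation(init_fn, update_fn)


def scale_by_l_dog(reps_rel: float = 1e-6, eps: float = 1e-8) -> optax.GradientTransformation:
  """Layer-wise DoG: each leaf tracks its own distance and gradient sum (sketch)."""

  def init_fn(params):
    leaf_norm = lambda p: jnp.linalg.norm(p.ravel())
    return DoGState(
        params,
        jax.tree.map(lambda p: reps_rel * (1.0 + leaf_norm(p)), params),
        jax.tree.map(lambda p: jnp.zeros((), p.dtype), params),
    )

  def update_fn(updates, state, params):
    max_dist = jax.tree.map(
        lambda m, p, p0: jnp.maximum(m, jnp.linalg.norm((p - p0).ravel())),
        state.max_dist, params, state.init_params)
    sum_sq = jax.tree.map(
        lambda s, g: s + jnp.sum(g * g), state.sum_sq_norm_grads, updates)
    updates = jax.tree.map(
        lambda g, m, s: g * m / jnp.sqrt(s + eps), updates, max_dist, sum_sq)
    return updates, DoGState(state.init_params, max_dist, sum_sq)

  return optax.GradientTransformation(init_fn, update_fn)

The only structural difference between the two is whether the distance and gradient statistics are reduced globally (optax.global_norm) or per leaf, which is exactly why a shared layer_wise flag was tempting in the first place.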
optax/contrib/_dog.py (Outdated)

def scale_by_l_dog(
    reps_rel: jax.typing.ArrayLike = 1e-6,
    eps: jax.typing.ArrayLike = 1e-8,
    param_dtype: Optional[jax.typing.DTypeLike] = None,
Please remove the unused param_dtype arg.
optax/contrib/_dog.py (Outdated)

  def init_fn(params: base.Params) -> DoGState:
    params_dtype = optax.tree.dtype(params, "lowest")
    if param_dtype is not None:
Why is this done here but not in scale_by_dog?
        max_dist=jnp.asarray(r_epsilon, dtype=params_dtype),
        sum_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype),
        init_params=optax.tree.cast(params, params_dtype),
        max_dist=max_dist,
Leave this inlined?
Please revert this so it's inlined.
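For context, the inlined form being requested would look roughly like this. A sketch only: optax.tree.dtype and optax.tree.cast follow the diff above, while the r_epsilon heuristic is an assumption taken from the DoG paper.

  def init_fn(params: base.Params) -> DoGState:
    params_dtype = optax.tree.dtype(params, "lowest")
    # Heuristic initial distance (assumed): r_epsilon = reps_rel * (1 + ||x_0||).
    r_epsilon = reps_rel * (1.0 + optax.global_norm(params))
    return DoGState(
        init_params=optax.tree.cast(params, params_dtype),
        max_dist=jnp.asarray(r_epsilon, dtype=params_dtype),  # inlined, no intermediate variable
        sum_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype),
    )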
    with self.assertRaises(AssertionError):
      test_utils.assert_trees_all_close(updates_global, updates_layer)

  def test_legacy_compatibility(self):
This test doesn't have much point if the scale_by_distance_over_gradients implementation is changed to call scale_by_l_dog. Can you revert scale_by_distance_over_gradients to its former implementation and still deprecate it?
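One way to satisfy both constraints is to restore the former body verbatim and only prepend a deprecation warning. A minimal sketch; _original_scale_by_distance_over_gradients is a hypothetical stand-in for the reverted implementation, not a real optax name:

import warnings

def scale_by_distance_over_gradients(*args, **kwargs):
  """Deprecated: prefer scale_by_dog / scale_by_l_dog (sketch)."""
  warnings.warn(
      "scale_by_distance_over_gradients is deprecated; use scale_by_dog or "
      "scale_by_l_dog instead.",
      DeprecationWarning,
      stacklevel=2,
  )
  # Hypothetical name standing in for the reverted former implementation,
  # kept unchanged so the legacy-compatibility test stays meaningful.
  return _original_scale_by_distance_over_gradients(*args, **kwargs)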
@emilyfertig Understood, this is what I am going to do. (Files: optax/contrib/_dog.py, optax/contrib/_dog_test.py)
optax/_src/transform.py (Outdated)

    Ivgi et al, `DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size
    Schedule <https://arxiv.org/pdf/2302.12022.pdf>`_, 2023
  """
  reps_rel = 1e-6 if reps_rel is None else reps_rel
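For reference, the step-size rule from the cited paper, which both implementations are meant to compute (the eps term is the numerical-stability addition from the code, not part of the paper's formula):

$$\eta_t = \frac{\bar r_t}{\sqrt{\sum_{i \le t} \lVert g_i \rVert^2 + \epsilon}}, \qquad \bar r_t = \max\bigl(r_\epsilon,\; \max_{i \le t} \lVert x_i - x_0 \rVert\bigr), \qquad r_\epsilon = \text{reps\_rel} \cdot (1 + \lVert x_0 \rVert).$$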
This is not the original implementation. You appear to be using an LLM. Please check what it outputs before you request a review.
        max_dist=jnp.asarray(r_epsilon, dtype=params_dtype),
        sum_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype),
        init_params=optax.tree.cast(params, params_dtype),
        max_dist=max_dist,
Please revert this so it's inlined.
  def init_fn(params: base.Params) -> DoGState:
    params_dtype = optax.tree.dtype(params, "lowest")

    # r_epsilon is already a tree of scalars
Please remove or clarify this comment.
Fix Issue: #1509

This PR merges the duplicate Distance over Gradients (DoG) implementations found in optax/contrib/_dog.py and optax/_src/transform.py into a single, unified implementation in optax/_src/dog.py.
- Created optax/_src/dog.py, which consolidates DoG and DoWG.
- The new scale_by_dog now supports a layer_wise argument.
- Re-implemented scale_by_distance_over_gradients in optax/_src/transform.py to use the new scale_by_dog with layer_wise=True.
- Deprecated scale_by_distance_over_gradients in favor of scale_by_dog.
- Updated optax/contrib/_dog.py to be a compatibility shim importing from optax/_src/dog.py.
- Added optax/_src/dog_test.py to verify both global and layer-wise behaviors, as well as legacy compatibility. (A usage sketch follows this list.)
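For illustration, end-to-end usage of the consolidated transformation might look like the sketch below, assuming the scale_by_dog / scale_by_l_dog entry points discussed in the review and standard optax plumbing (chaining with optax.scale(-1.0) to turn the scaled gradient into a descent step):

import jax
import jax.numpy as jnp
import optax
from optax import contrib

params = {"w": jnp.ones((3, 2)), "b": jnp.zeros(3)}

def loss(p):
  return jnp.sum(p["w"] ** 2) + jnp.sum(p["b"] ** 2)

# Swap in contrib.scale_by_l_dog() here to get per-layer step sizes.
tx = optax.chain(contrib.scale_by_dog(), optax.scale(-1.0))
state = tx.init(params)

for _ in range(10):
  grads = jax.grad(loss)(params)
  # DoG needs the current params to track the distance from initialization.
  updates, state = tx.update(grads, state, params)
  params = optax.apply_updates(params, updates)

Note that scale_by_dog only rescales the gradient; the sign flip is left to the surrounding chain, mirroring how other optax scale_by_* transformations compose.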