Conversation

@abheeyeee

Fixes #1509
This PR merges the duplicate Distance over Gradients (DoG) implementations found in optax/contrib/_dog.py and optax/_src/transform.py into a single, unified implementation in optax/_src/dog.py.
- Created optax/_src/dog.py, which consolidates DoG and DoWG.
- The new scale_by_dog now supports a layer_wise argument (see the sketch below).
- Re-implemented scale_by_distance_over_gradients in optax/_src/transform.py to use the new scale_by_dog with layer_wise=True.
- Deprecated scale_by_distance_over_gradients in favor of scale_by_dog.
- Updated optax/contrib/_dog.py to be a compatibility shim importing from optax/_src/dog.py.
- Added optax/_src/dog_test.py to verify both global and layer-wise behaviors, as well as legacy compatibility.
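For context, the step-size rule from Ivgi et al. (2023) that both implementations compute is `eta_t = rbar_t / sqrt(sum_{s<=t} ||g_s||^2)`, where `rbar_t` is the maximum distance from the initial parameters seen so far. A minimal sketch of the quantity the layer_wise flag toggles; names here are illustrative, not the optax internals:

```python
import jax
import jax.numpy as jnp

def dog_step_size(max_dist, sum_sq_norm_grads, eps=1e-8):
  # eta_t = rbar_t / sqrt(sum_s ||g_s||^2); eps guards the first step.
  return max_dist / (jnp.sqrt(sum_sq_norm_grads) + eps)

def global_sq_norm(grads):
  # Global DoG: a single squared norm over the whole gradient tree.
  return sum(jnp.sum(g * g) for g in jax.tree.leaves(grads))

def layer_wise_sq_norms(grads):
  # Layer-wise DoG (L-DoG): one squared norm per leaf/layer.
  return jax.tree.map(lambda g: jnp.sum(g * g), grads)
```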

@vroulet
Collaborator

vroulet commented Dec 2, 2025

Can you keep _dog.py and _dog_test.py files, and keep them in the contrib folder?
This will ease the review process.
Try to keep the PR as concise as possible. (It looks rather good otherwise :) )

@abheeyeee
Author

> Can you keep _dog.py and _dog_test.py files, and keep them in the contrib folder? This will ease the review process. Try to keep the PR as concise as possible. (It looks rather good otherwise :) )

Thanks for the feedback, I will make the changes right now.

@abheeyeee marked this pull request as draft on December 2, 2025 19:29
@abheeyeee marked this pull request as ready for review on December 2, 2025 19:29
@abheeyeee
Author

@vroulet I moved the files as you asked. This should fix the issue.

Collaborator

@emilyfertig left a comment


Thanks!

@abheeyeee
Author

abheeyeee commented Dec 4, 2025

@emilyfertig I made the changes you preferred and added the new scale_by_l_dog function in 137d6a7.
The commit is no longer visible; I mistakenly rebased/reset my branch after making it while trying to fix the merge conflict.
Would really love your feedback, thanks.

Collaborator

@emilyfertig left a comment


Thanks! scale_by_l_dog should be the same as scale_by_dog with layer_wise=True, right? Can we remove the layer_wise arg, and make scale_by_l_dog the same as scale_by_dog with layer_wise=True?

@abheeyeee
Author

> Thanks! scale_by_l_dog should be the same as scale_by_dog with layer_wise=True, right? Can we remove the layer_wise arg, and make scale_by_l_dog the same as scale_by_dog with layer_wise=True?

Thanks for your feedback @emilyfertig, I did as you asked and made scale_by_l_dog the same as scale_by_dog with layer_wise=True.
But I am facing some check failures across pytest versions in the Ubuntu JAX workflows; I tried solving it, but one passes and another fails. Can you help me out here? Otherwise all changes have been made and this should fix the issue.

@abheeyeee requested a review from emilyfertig on December 4, 2025 19:03
@abheeyeee force-pushed the Fix1509 branch 2 times, most recently from 06259a2 to 0894aeb on December 4, 2025 20:03
Collaborator

@emilyfertig left a comment


Thanks! I think the pytest failure is unrelated and should clear up if you rebase.

```python
return _scale_by_dog(
    init_step=("heuristic", reps_rel),
    eps=eps,
    layer_wise=True,
```
Collaborator


Sorry, what I meant is to please get rid of the layer_wise arg everywhere, and make separate implementations of scale_by_dog and scale_by_l_dog. Does that make sense?
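A rough sketch of the two-function design being requested, assuming a simplified dict state (the actual PR keeps the DoGState NamedTuple, the "heuristic" init_step, and the dtype handling discussed below):

```python
import jax
import jax.numpy as jnp
import optax

def sketch_scale_by_dog(r_epsilon=1e-6, eps=1e-8):
  """Global DoG sketch: one max-distance and one squared-norm accumulator."""

  def init_fn(params):
    return {'init_params': params,
            'max_dist': jnp.asarray(r_epsilon),
            'sum_sq': jnp.asarray(0.0)}

  def update_fn(updates, state, params):
    # params are required here to track the distance from the init point.
    diff = jax.tree.map(lambda p, p0: p - p0, params, state['init_params'])
    max_dist = jnp.maximum(state['max_dist'], optax.global_norm(diff))
    sum_sq = state['sum_sq'] + optax.global_norm(updates) ** 2
    eta = max_dist / (jnp.sqrt(sum_sq) + eps)  # one step size for all leaves
    scaled = jax.tree.map(lambda g: eta * g, updates)
    return scaled, {'init_params': state['init_params'],
                    'max_dist': max_dist, 'sum_sq': sum_sq}

  return optax.GradientTransformation(init_fn, update_fn)

def sketch_scale_by_l_dog(r_epsilon=1e-6, eps=1e-8):
  """Layer-wise DoG sketch: the same bookkeeping, kept per leaf."""

  def init_fn(params):
    return {'init_params': params,
            'max_dist': jax.tree.map(lambda _: jnp.asarray(r_epsilon), params),
            'sum_sq': jax.tree.map(lambda _: jnp.asarray(0.0), params)}

  def update_fn(updates, state, params):
    max_dist = jax.tree.map(
        lambda m, p, p0: jnp.maximum(m, jnp.linalg.norm((p - p0).ravel())),
        state['max_dist'], params, state['init_params'])
    sum_sq = jax.tree.map(lambda s, g: s + jnp.sum(g * g),
                          state['sum_sq'], updates)
    scaled = jax.tree.map(lambda g, m, s: g * m / (jnp.sqrt(s) + eps),
                          updates, max_dist, sum_sq)
    return scaled, {'init_params': state['init_params'],
                    'max_dist': max_dist, 'sum_sq': sum_sq}

  return optax.GradientTransformation(init_fn, update_fn)
```

The only difference is where the norms live: one scalar pair for the whole tree versus one pair per leaf, which is why both a shared layer_wise flag and two standalone functions are workable designs.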

Author


@emilyfertig Refactored the DoG optimizer implementation in optax/contrib/_dog.py to separate the global and layer-wise variants:
- Removed the internal _scale_by_dog helper function.
- Implemented scale_by_dog (global DoG) and scale_by_l_dog (layer-wise DoG) as distinct, standalone functions.
- Removed the layer_wise argument from scale_by_dog to enforce a clear separation of concerns.

Updated optax/contrib/_dog_test.py:
- Renamed test_dog_layer_wise to test_l_dog_vs_dog to reflect the API changes (sketched below).
- Updated comments to remove outdated references to the layer_wise argument.

Verified that all tests pass with pytest optax/contrib/_dog_test.py.
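A sketch of what the renamed test plausibly asserts, assuming scale_by_dog and the PR's scale_by_l_dog are both exported from optax.contrib; the exact fixtures in the PR may differ, and chex stands in here for the test_utils helper quoted later in this thread:

```python
from absl.testing import absltest
import chex
import jax.numpy as jnp
from optax import contrib

class DogTest(absltest.TestCase):

  def test_l_dog_vs_dog(self):
    # On a tree with more than one leaf, the global and layer-wise
    # transforms should produce different updates.
    params = {'w': jnp.ones((3,)), 'b': 2.0 * jnp.ones((2,))}
    grads = {'w': 0.1 * jnp.ones((3,)), 'b': 0.3 * jnp.ones((2,))}
    dog, l_dog = contrib.scale_by_dog(), contrib.scale_by_l_dog()
    updates_global, _ = dog.update(grads, dog.init(params), params)
    updates_layer, _ = l_dog.update(grads, l_dog.init(params), params)
    with self.assertRaises(AssertionError):
      chex.assert_trees_all_close(updates_global, updates_layer)

if __name__ == '__main__':
  absltest.main()
```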

@abheeyeee requested a review from emilyfertig on December 5, 2025 04:23
```python
def scale_by_l_dog(
    reps_rel: jax.typing.ArrayLike = 1e-6,
    eps: jax.typing.ArrayLike = 1e-8,
    param_dtype: Optional[jax.typing.DTypeLike] = None,
```
Collaborator


Please remove the unused param_dtype arg.


```python
def init_fn(params: base.Params) -> DoGState:
  params_dtype = optax.tree.dtype(params, "lowest")
  if param_dtype is not None:
```
Collaborator


Why is this done here but not in scale_by_dog?

```python
max_dist=jnp.asarray(r_epsilon, dtype=params_dtype),
sum_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype),
init_params=optax.tree.cast(params, params_dtype),
max_dist=max_dist,
```
Collaborator


Leave this inlined?

Collaborator


Please revert this so it's inlined.

```python
with self.assertRaises(AssertionError):
  test_utils.assert_trees_all_close(updates_global, updates_layer)

def test_legacy_compatibility(self):
```
Collaborator


This test doesn't have much point if the scale_by_distance_over_gradients implementation is changed to call scale_by_l_dog. Can you revert scale_by_distance_over_gradients to its former implementation and still deprecate it?
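One way to satisfy this, as a sketch rather than the PR's code (`_legacy_impl` is a hypothetical name for the restored original body, and optax's actual deprecation wiring may differ):

```python
import warnings

def scale_by_distance_over_gradients(*args, **kwargs):
  # Keep the original standalone implementation so a compatibility test
  # still compares two independent code paths, but warn on use.
  warnings.warn(
      'scale_by_distance_over_gradients is deprecated; use '
      'optax.contrib.scale_by_dog or scale_by_l_dog instead.',
      DeprecationWarning,
      stacklevel=2,
  )
  return _legacy_impl(*args, **kwargs)  # hypothetical: the reverted body
```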

@abheeyeee
Author

abheeyeee commented Dec 5, 2025

@emilyfertig Understood, this is what I am going to do.
Refactor DoG and revert legacy compatibility. Proposed changes:

optax/contrib/_dog.py:
- Remove the param_dtype argument from scale_by_l_dog.
- Remove the param_dtype usage in init_fn inside scale_by_l_dog if it is unused or redundant.
- Address consistency between scale_by_l_dog and scale_by_dog regarding init_fn and dtype casting.
- Revert scale_by_distance_over_gradients to its legacy implementation (a standalone implementation rather than a call into scale_by_l_dog), while still deprecating it in favor of scale_by_dog.

optax/contrib/_dog_test.py:
- Ensure test_legacy_compatibility is meaningful after reverting scale_by_distance_over_gradients.

@abheeyeee requested a review from emilyfertig on December 5, 2025 19:50
```python
  Ivgi et al, `DoG is SGD's Best Friend: A Parameter-Free Dynamic Step Size
  Schedule <https://arxiv.org/pdf/2302.12022.pdf>`_, 2023
  """
  reps_rel = 1e-6 if reps_rel is None else reps_rel
```
Collaborator


This is not the original implementation. You appear to be using an LLM. Please check what it outputs before you request a review.

@abheeyeee requested a review from emilyfertig on December 5, 2025 20:24
```python
max_dist=jnp.asarray(r_epsilon, dtype=params_dtype),
sum_sq_norm_grads=jnp.asarray(0.0, dtype=params_dtype),
init_params=optax.tree.cast(params, params_dtype),
max_dist=max_dist,
```
Collaborator


Please revert this so it's inlined.

```python
def init_fn(params: base.Params) -> DoGState:
  params_dtype = optax.tree.dtype(params, "lowest")

  # r_epsilon is already a tree of scalars
```
Collaborator


Please remove or clarify comment
