38 changes: 29 additions & 9 deletions keras/src/optimizers/muon.py
@@ -24,7 +24,7 @@ class Muon(optimizer.Optimizer):
will be used. This is not configurable.
- If the argument `exclude_embeddings` (defaults to `True`) is set
to `True`, the AdamW step will be used.
- For any variablewith a name that matches an expression
- For any variable with a name that matches an expression
listed in the argument `exclude_layers` (a list), the
AdamW step will be used.
- Any other variable uses the Muon step.
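
As an aside, a minimal constructor sketch illustrating the routing rules above (layer names and hyperparameter values are made up for illustration and are not part of this diff):

import keras
from keras.src.optimizers.muon import Muon  # module changed in this diff

# Per the docstring: layers matching `exclude_layers` and embedding layers
# take the AdamW step; the remaining 2-D kernels take the Muon step.
optimizer = Muon(
    learning_rate=1e-3,
    exclude_layers=["head"],
    exclude_embeddings=True,
)
model = keras.Sequential(
    [
        keras.layers.Embedding(1000, 16),
        keras.layers.Dense(32, activation="relu"),
        keras.layers.Dense(1, name="head"),
    ]
)
model.compile(optimizer=optimizer, loss="mse")
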
@@ -46,17 +46,17 @@ class Muon(optimizer.Optimizer):
that takes no arguments and returns the actual value to use.
The exponential decay rate for the 1st moment estimates. Defaults to
`0.9`.
adam_beta_2: A float value or a constant float tensor, ora callable
adam_beta_2: A float value or a constant float tensor, or a callable
that takes no arguments and returns the actual value to use.
The exponential decay rate for the 2nd moment estimates. Defaults to
`0.999`.
epsilon: A small constant for numerical stability. This is
"epsilon hat" in the Kingma and Ba paper
(in the formula just before Section 2.1),
not the epsilon in Algorithm 1 of the paper.
It be used at Adamw.Defaults to `1e-7`.
It be used at Adamw. Defaults to `1e-7`.
exclude_layers: List of strings, keywords of layer names to exclude.
All layers with keywords in their path will use adamw.
All layers with keywords in their name will use adamw.
exclude_embeddings: Boolean value
If True, embedding layers will use adamw.
muon_a: Float, parameter a of the muon algorithm.
@@ -134,10 +134,10 @@ def _should_use_adamw(self, variable):
# any {0,1}-D parameters should all be optimized by adam
if not 1 < len(variable.shape) < 4:
return True
if self.exclude_embeddings and "embedding" in variable.path.lower():
if self.exclude_embeddings and "embedding" in variable.name.lower():
return True
for keyword in self.exclude_layers:
if re.search(keyword, variable.path):
if re.search(keyword, variable.name):
return True
return False
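
For reference, the shape check above sends everything except 2-D and 3-D variables to AdamW. A quick sketch of the same condition (not part of the diff):

# `not 1 < ndim < 4` is True for 0-D, 1-D, and 4-D+ shapes (AdamW fallback)
# and False only for 2-D and 3-D shapes (Muon step).
for ndim in range(6):
    print(ndim, not 1 < ndim < 4)
# -> 0 True, 1 True, 2 False, 3 False, 4 True, 5 True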

@@ -183,7 +183,14 @@ def update_step(self, gradient, variable, learning_rate):
self._muon_update_step(gradient, variable, learning_rate)

def _muon_update_step(self, gradient, variable, lr):
m = self.adam_momentums[variable.path]
if variable.name not in self.adam_momentums:
self.adam_momentums[variable.name] = (
self.add_variable_from_reference(
reference_variable=variable, name="momentum"
)
)
Comment on lines +186 to +191
Severity: medium

To improve maintainability and reduce code duplication, you could extract this logic for lazily initializing the momentum variable into a helper method. This same logic is repeated in _adamw_update_step.

You could define a new private method like this:

def _maybe_init_momentum(self, variable):
    if variable.name not in self.adam_momentums:
        self.adam_momentums[variable.name] = (
            self.add_variable_from_reference(
                reference_variable=variable, name="momentum"
            )
        )

Then you can replace this block with a single call: self._maybe_init_momentum(variable).

        self._maybe_init_momentum(variable)


m = self.adam_momentums[variable.name]
self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
shape = variable.shape
if self.nesterov:
@@ -200,6 +207,19 @@ def _muon_update_step(self, gradient, variable, lr):

def _adamw_update_step(self, gradient, variable, learning_rate):
"""Update step given gradient and the associated model variable."""
if variable.name not in self.adam_momentums:
self.adam_momentums[variable.name] = (
self.add_variable_from_reference(
reference_variable=variable, name="momentum"
)
)
Comment on lines +210 to +215
Severity: medium

As mentioned in the comment for _muon_update_step, you can use the suggested _maybe_init_momentum helper method here as well to avoid duplicating the momentum initialization logic. This would make the code more concise and easier to maintain.

        self._maybe_init_momentum(variable)

if variable.name not in self.adam_velocities:
self.adam_velocities[variable.name] = (
self.add_variable_from_reference(
reference_variable=variable, name="velocity"
)
)

lr = ops.cast(learning_rate, variable.dtype)
gradient = ops.cast(gradient, variable.dtype)
local_step = ops.cast(self.iterations + 1, variable.dtype)
@@ -210,8 +230,8 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
ops.cast(self.adam_beta_2, variable.dtype), local_step
)

m = self.adam_momentums[variable.path]
v = self.adam_velocities[variable.path]
m = self.adam_momentums[variable.name]
v = self.adam_velocities[variable.name]

alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
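
The `alpha` computed above is the usual Adam bias-corrected step size, lr * sqrt(1 - beta2^t) / (1 - beta1^t). A small numeric sketch (hyperparameter values chosen arbitrarily, not taken from the diff):

import math

lr, beta1, beta2, t = 1e-3, 0.9, 0.999, 1
alpha = lr * math.sqrt(1 - beta2**t) / (1 - beta1**t)
# ~3.16e-4 at t=1; alpha tends toward lr as t grows and both
# bias-correction factors approach 1.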

24 changes: 24 additions & 0 deletions keras/src/optimizers/muon_test.py
@@ -1,5 +1,7 @@
import numpy as np
import pytest

import keras
from keras.src import backend
from keras.src import ops
from keras.src import testing
@@ -81,3 +83,25 @@ def test_clip_value(self):
grad = [np.array([100.0, 100.0])]
clipped_grad = optimizer._clip_gradients(grad)
self.assertAllClose(clipped_grad[0], [1.0, 1.0])

@pytest.mark.skipif(
backend.backend() != "tensorflow", reason="Runs only on TF backend"
)
def test_exclude_layers_with_variable_name(self):
"""Ensure `exclude_layers` works with current TensorFlow versions.
Uses `variable.name` instead of the deprecated `variable.path`.
"""
optimizer = Muon(learning_rate=0.01, exclude_layers=["last"])

model = keras.Sequential(
[
keras.layers.Dense(5, input_shape=(10,)),
keras.layers.Dense(1, name="last"),
]
)

x_train = np.random.rand(10, 10).astype(np.float32)
y_train = np.random.rand(10, 1).astype(np.float32)

model.compile(optimizer=optimizer, loss="mse")
model.fit(x_train, y_train, epochs=1, batch_size=2, verbose=0)
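
A follow-up one could add to this test to see how each variable is routed (a sketch, not part of the PR) would be to query `_should_use_adamw` after training:

# Prints True for variables routed to the AdamW step and False for
# variables routed to the Muon step, based on `variable.name` matching.
for variable in model.trainable_variables:
    print(variable.name, optimizer._should_use_adamw(variable))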