diff --git a/keras/src/optimizers/muon.py b/keras/src/optimizers/muon.py
index 88d0dde3ee92..aca27a238335 100644
--- a/keras/src/optimizers/muon.py
+++ b/keras/src/optimizers/muon.py
@@ -24,7 +24,7 @@ class Muon(optimizer.Optimizer):
         will be used. This is not configurable.
     - If the argument `exclude_embeddings` (defaults to `True`) is set to
         `True`, the AdamW step will be used.
-    - For any variablewith a name that matches an expression
+    - For any variable with a name that matches an expression
         listed in the argument `exclude_layers` (a list), the AdamW step
         will be used.
     - Any other variable uses the Muon step.
@@ -46,7 +46,7 @@ class Muon(optimizer.Optimizer):
             that takes no arguments and returns the actual value to use. The
             exponential decay rate for the 1st moment estimates. Defaults to
             `0.9`.
-        adam_beta_2: A float value or a constant float tensor, ora callable
+        adam_beta_2: A float value or a constant float tensor, or a callable
             that takes no arguments and returns the actual value to use. The
             exponential decay rate for the 2nd moment estimates. Defaults to
             `0.999`.
@@ -54,9 +54,9 @@ class Muon(optimizer.Optimizer):
             "epsilon hat" in the Kingma and Ba paper (in the formula just
             before Section 2.1), not the epsilon in Algorithm 1 of the
             paper.
-            It be used at Adamw.Defaults to `1e-7`.
+            It be used at Adamw. Defaults to `1e-7`.
         exclude_layers: List of strings, keywords of layer names to exclude.
-            All layers with keywords in their path will use adamw.
+            All layers with keywords in their name will use adamw.
         exclude_embeddings: Boolean value
             If True, embedding layers will use adamw.
         muon_a: Float, parameter a of the muon algorithm.
@@ -134,10 +134,10 @@ def _should_use_adamw(self, variable):
         # any {0,1}-D parameters should all be optimized by adam
         if not 1 < len(variable.shape) < 4:
             return True
-        if self.exclude_embeddings and "embedding" in variable.path.lower():
+        if self.exclude_embeddings and "embedding" in variable.name.lower():
             return True
         for keyword in self.exclude_layers:
-            if re.search(keyword, variable.path):
+            if re.search(keyword, variable.name):
                 return True
         return False
 
@@ -183,7 +183,14 @@ def update_step(self, gradient, variable, learning_rate):
             self._muon_update_step(gradient, variable, learning_rate)
 
     def _muon_update_step(self, gradient, variable, lr):
-        m = self.adam_momentums[variable.path]
+        if variable.name not in self.adam_momentums:
+            self.adam_momentums[variable.name] = (
+                self.add_variable_from_reference(
+                    reference_variable=variable, name="momentum"
+                )
+            )
+
+        m = self.adam_momentums[variable.name]
         self.assign_add(m, ops.add(gradient, m * (self.momentum - 1)))
         shape = variable.shape
         if self.nesterov:
@@ -200,6 +207,19 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
         """Update step given gradient and the associated model variable."""
+        if variable.name not in self.adam_momentums:
+            self.adam_momentums[variable.name] = (
+                self.add_variable_from_reference(
+                    reference_variable=variable, name="momentum"
+                )
+            )
+        if variable.name not in self.adam_velocities:
+            self.adam_velocities[variable.name] = (
+                self.add_variable_from_reference(
+                    reference_variable=variable, name="velocity"
+                )
+            )
+
         lr = ops.cast(learning_rate, variable.dtype)
         gradient = ops.cast(gradient, variable.dtype)
         local_step = ops.cast(self.iterations + 1, variable.dtype)
@@ -210,8 +230,8 @@ def _adamw_update_step(self, gradient, variable, learning_rate):
             ops.cast(self.adam_beta_2, variable.dtype), local_step
         )
-        m = self.adam_momentums[variable.path]
-        v = self.adam_velocities[variable.path]
+        m = self.adam_momentums[variable.name]
+        v = self.adam_velocities[variable.name]
 
         alpha = lr * ops.sqrt(1 - adam_beta_2_power) / (1 - adam_beta_1_power)
 
diff --git a/keras/src/optimizers/muon_test.py b/keras/src/optimizers/muon_test.py
index 9ec85d8985ce..4241539a5acf 100644
--- a/keras/src/optimizers/muon_test.py
+++ b/keras/src/optimizers/muon_test.py
@@ -1,5 +1,7 @@
 import numpy as np
+import pytest
 
+import keras
 from keras.src import backend
 from keras.src import ops
 from keras.src import testing
@@ -81,3 +83,25 @@ def test_clip_value(self):
         grad = [np.array([100.0, 100.0])]
         clipped_grad = optimizer._clip_gradients(grad)
         self.assertAllClose(clipped_grad[0], [1.0, 1.0])
+
+    @pytest.mark.skipif(
+        backend.backend() != "tensorflow", reason="Runs only on TF backend"
+    )
+    def test_exclude_layers_with_variable_name(self):
+        """Ensure `exclude_layers` works with current TensorFlow versions.
+        Uses `variable.name` instead of the deprecated `variable.path`.
+        """
+        optimizer = Muon(learning_rate=0.01, exclude_layers=["last"])
+
+        model = keras.Sequential(
+            [
+                keras.layers.Dense(5, input_shape=(10,)),
+                keras.layers.Dense(1, name="last"),
+            ]
+        )
+
+        x_train = np.random.rand(10, 10).astype(np.float32)
+        y_train = np.random.rand(10, 1).astype(np.float32)
+
+        model.compile(optimizer=optimizer, loss="mse")
+        model.fit(x_train, y_train, epochs=1, batch_size=2, verbose=0)
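
For reference, a minimal sketch (not part of the patch) of how the exclusion routing is expected to behave after this change. It assumes `Muon` is importable from `keras.src.optimizers.muon`, as in the test module above; the toy model, the layer name "last", and the call to the private `_should_use_adamw` helper are purely illustrative.

import keras

# Import path assumed to mirror muon_test.py; adjust if Muon is exported elsewhere.
from keras.src.optimizers.muon import Muon

# Toy model for illustration: the second Dense layer is named "last" so its
# variables can be matched by the `exclude_layers` keyword below.
model = keras.Sequential(
    [
        keras.layers.Dense(5, input_shape=(10,)),
        keras.layers.Dense(1, name="last"),
    ]
)

optimizer = Muon(learning_rate=0.01, exclude_layers=["last"])

# Per the patched `_should_use_adamw`: any 0/1-D variable (e.g. biases), any
# variable whose name contains "embedding" (when `exclude_embeddings=True`),
# and any variable whose name matches an `exclude_layers` pattern take the
# AdamW step; every other variable takes the Muon step.
for variable in model.trainable_variables:
    step = "AdamW" if optimizer._should_use_adamw(variable) else "Muon"
    print(variable.name, "->", step)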