From c1454890d3d4fd919367664f951c8631fea71da5 Mon Sep 17 00:00:00 2001
From: mm65x
Date: Sun, 22 Mar 2026 12:44:34 +0000
Subject: [PATCH] add nn.WeightNorm layer

---
 docs/src/python/nn/layers.rst         |   1 +
 python/mlx/nn/layers/__init__.py      |   1 +
 python/mlx/nn/layers/normalization.py |  67 +++++++++++++++++
 python/tests/test_nn.py               | 104 ++++++++++++++++++++++++++
 4 files changed, 173 insertions(+)

diff --git a/docs/src/python/nn/layers.rst b/docs/src/python/nn/layers.rst
index b9544bae51..92f7ff12ef 100644
--- a/docs/src/python/nn/layers.rst
+++ b/docs/src/python/nn/layers.rst
@@ -72,3 +72,4 @@ Layers
    Tanh
    Transformer
    Upsample
+   WeightNorm

diff --git a/python/mlx/nn/layers/__init__.py b/python/mlx/nn/layers/__init__.py
index c2fba58347..5c32b0f44a 100644
--- a/python/mlx/nn/layers/__init__.py
+++ b/python/mlx/nn/layers/__init__.py
@@ -77,6 +77,7 @@
     InstanceNorm,
     LayerNorm,
     RMSNorm,
+    WeightNorm,
 )
 from mlx.nn.layers.pooling import (
     AvgPool1d,

diff --git a/python/mlx/nn/layers/normalization.py b/python/mlx/nn/layers/normalization.py
index bdddb6ccfb..c0653604eb 100644
--- a/python/mlx/nn/layers/normalization.py
+++ b/python/mlx/nn/layers/normalization.py
@@ -361,3 +361,70 @@ def __call__(self, x: mx.array) -> mx.array:
         x = (x - mean) * mx.rsqrt(var + self.eps)
 
         return (self.weight * x + self.bias) if "weight" in self else x
+
+
+class WeightNorm(Module):
+    r"""Applies weight normalization [1] to a parameter of a given module.
+
+    Weight normalization reparameterizes a weight tensor :math:`\mathbf{w}` as
+
+    .. math::
+
+        \mathbf{w} = g \frac{\mathbf{v}}{\|\mathbf{v}\|}
+
+    where :math:`g` is a scalar magnitude and :math:`\mathbf{v}` is the
+    direction vector. The norm is computed over all dimensions except ``dim``.
+
+    On each call, the normalized weight is recomputed from the current
+    ``weight_g`` and ``weight_v`` and injected into the wrapped module
+    before its forward pass.
+
+    [1]: https://arxiv.org/abs/1602.07868
+
+    Args:
+        module (mlx.nn.Module): The module containing the weight to normalize.
+        name (str): The name of the weight parameter to normalize.
+            Default: ``"weight"``.
+        dim (int): The dimension over which to keep independent magnitudes.
+            Default: ``0``.
+
+    Examples:
+        >>> import mlx.core as mx
+        >>> import mlx.nn as nn
+        >>> linear = nn.Linear(8, 16)
+        >>> wn = nn.WeightNorm(linear)
+        >>> x = mx.random.normal((2, 8))
+        >>> wn(x).shape
+        (2, 16)
+    """
+
+    def __init__(self, module: Module, name: str = "weight", dim: int = 0):
+        super().__init__()
+        self.module = module
+        self.name = name
+
+        w = getattr(module, name)
+        self.dim = dim % w.ndim
+        norm_axes = [i for i in range(w.ndim) if i != self.dim]
+        g = mx.sqrt(mx.sum(mx.square(w), axis=norm_axes, keepdims=True))
+
+        self.weight_g = g
+        self.weight_v = w
+        module.freeze(keys=[name], recurse=False)
+
+    def unfreeze(self, *args, **kwargs):
+        super().unfreeze(*args, **kwargs)
+        self.module.freeze(keys=[self.name], recurse=False)
+
+    def _compute_weight(self):
+        v = self.weight_v
+        norm_axes = [i for i in range(v.ndim) if i != self.dim]
+        norm = mx.sqrt(mx.sum(mx.square(v), axis=norm_axes, keepdims=True))
+        return self.weight_g * (v / norm)
+
+    def __call__(self, *args, **kwargs):
+        setattr(self.module, self.name, self._compute_weight())
+        return self.module(*args, **kwargs)
+
+    def _extra_repr(self):
+        return f"name={self.name!r}, dim={self.dim}"

diff --git a/python/tests/test_nn.py b/python/tests/test_nn.py
index 174823f179..384d31a1e9 100644
--- a/python/tests/test_nn.py
+++ b/python/tests/test_nn.py
@@ -785,6 +785,110 @@ def test_batch_norm_stats(self):
         self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)
         self.assertEqual(batch_norm.running_var.shape, running_var.shape)
 
+    def test_weight_norm(self):
+        mx.random.seed(42)
+
+        # Basic: wraps a linear layer
+        linear = nn.Linear(8, 16)
+        wn = nn.WeightNorm(linear)
+        x = mx.random.normal((2, 8))
+        y = wn(x)
+        self.assertEqual(y.shape, (2, 16))
+
+        # weight_g and weight_v should be parameters
+        params = wn.parameters()
+        self.assertIn("weight_g", params)
+        self.assertIn("weight_v", params)
+
+        # After forward pass, the module has the recomputed normalized weight
+        w = linear.weight
+        self.assertEqual(w.shape, (16, 8))
+
+        # Verify the normalization: each row should have magnitude weight_g
+        v = wn.weight_v
+        g = wn.weight_g
+        norm_axes = [i for i in range(v.ndim) if i != 0]
+        v_norm = mx.sqrt(mx.sum(v * v, axis=norm_axes, keepdims=True))
+        expected_w = g * (v / v_norm)
+        self.assertTrue(mx.allclose(w, expected_w, atol=1e-6))
+
+        # Wrapping Conv1d
+        conv = nn.Conv1d(4, 8, kernel_size=3)
+        wn_conv = nn.WeightNorm(conv)
+        x_conv = mx.random.normal((2, 10, 4))
+        y_conv = wn_conv(x_conv)
+        self.assertEqual(y_conv.shape, (2, 8, 8))
+
+        # Verify conv weight_g shape: one magnitude per output channel
+        self.assertEqual(wn_conv.weight_g.shape[0], 8)
+
+        # Wrapping Conv2d
+        conv2d = nn.Conv2d(3, 16, kernel_size=3)
+        wn_conv2d = nn.WeightNorm(conv2d)
+        x_2d = mx.random.normal((1, 8, 8, 3))
+        y_2d = wn_conv2d(x_2d)
+        self.assertEqual(y_2d.shape, (1, 6, 6, 16))
+
+        # Initial forward pass should match unwrapped module with same weights
+        # (since weight_norm initializes g = ||v|| per dim, w = g*v/||v|| = v)
+        linear2 = nn.Linear(4, 6)
+        w_orig = mx.array(linear2.weight)
+        wn2 = nn.WeightNorm(linear2)
+        x2 = mx.random.normal((1, 4))
+        y_wn = wn2(x2)
+        # Manually compute with original weight
+        y_orig = x2 @ w_orig.T + linear2.bias
+        mx.eval(y_wn, y_orig)
+        self.assertTrue(mx.allclose(y_wn, y_orig, atol=1e-5))
+
+        # module.weight should not be trainable (only weight_g and weight_v)
+        wn3 = nn.WeightNorm(nn.Linear(4, 8))
+
+        def collect_keys(d, prefix="", result=None):
+            if result is None:
+                result = set()
+            if isinstance(d, dict):
+                for k, v in d.items():
+                    collect_keys(v, prefix + k + ".", result)
+            elif isinstance(d, mx.array):
+                result.add(prefix[:-1])
+            return result
+
+        tp_keys = collect_keys(wn3.trainable_parameters())
+        self.assertIn("weight_g", tp_keys)
+        self.assertIn("weight_v", tp_keys)
+        self.assertIn("module.bias", tp_keys)
+        self.assertNotIn("module.weight", tp_keys)
+
+        # unfreeze should keep module.weight frozen
+        wn3.unfreeze()
+        tp_keys_after = collect_keys(wn3.trainable_parameters())
+        self.assertNotIn("module.weight", tp_keys_after)
+
+        # module.weight should still appear in parameters() (frozen, not removed)
+        all_keys = collect_keys(wn3.parameters())
+        self.assertIn("module.weight", all_keys)
+
+        # Gradient flow
+        wn4 = nn.WeightNorm(nn.Linear(4, 8))
+        x4 = mx.random.normal((2, 4))
+
+        def loss_fn(model, x):
+            return model(x).sum()
+
+        loss, grads = nn.value_and_grad(wn4, loss_fn)(wn4, x4)
+        mx.eval(loss, grads)
+        self.assertIn("weight_g", grads)
+        self.assertIn("weight_v", grads)
+        self.assertNotIn("weight", grads.get("module", {}))
+
+        # Negative dim
+        linear_neg = nn.Linear(4, 8)
+        wn_neg = nn.WeightNorm(linear_neg, dim=-1)
+        y_neg = wn_neg(mx.random.normal((2, 4)))
+        self.assertEqual(y_neg.shape, (2, 8))
+        self.assertEqual(wn_neg.weight_g.shape, (1, 4))
+
     def test_conv1d(self):
         N = 5
         L = 12