Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions docs/src/python/nn/layers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,4 @@ Layers
Tanh
Transformer
Upsample
WeightNorm
1 change: 1 addition & 0 deletions python/mlx/nn/layers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@
InstanceNorm,
LayerNorm,
RMSNorm,
WeightNorm,
)
from mlx.nn.layers.pooling import (
AvgPool1d,
Expand Down
67 changes: 67 additions & 0 deletions python/mlx/nn/layers/normalization.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,3 +361,70 @@ def __call__(self, x: mx.array) -> mx.array:

x = (x - mean) * mx.rsqrt(var + self.eps)
return (self.weight * x + self.bias) if "weight" in self else x


class WeightNorm(Module):
    r"""Applies weight normalization [1] to a parameter of a given module.

    Weight normalization reparameterizes a weight tensor :math:`\mathbf{w}` as

    .. math::

        \mathbf{w} = g \frac{\mathbf{v}}{\|\mathbf{v}\|}

    where :math:`g` is a scalar magnitude and :math:`\mathbf{v}` is the
    direction vector. The norm is computed over all dimensions except ``dim``.

    On each call, the normalized weight is recomputed from the current
    ``weight_g`` and ``weight_v`` and injected into the wrapped module
    before its forward pass.

    [1]: https://arxiv.org/abs/1602.07868

    Args:
        module (mlx.nn.Module): The module containing the weight to normalize.
        name (str): The name of the weight parameter to normalize.
            Default: ``"weight"``.
        dim (int): The dimension over which to keep independent magnitudes.
            May be negative, in which case it counts from the last axis.
            Default: ``0``.

    Examples:
        >>> import mlx.core as mx
        >>> import mlx.nn as nn
        >>> linear = nn.Linear(8, 16)
        >>> wn = nn.WeightNorm(linear)
        >>> x = mx.random.normal((2, 8))
        >>> wn(x).shape
        (2, 16)
    """

    def __init__(self, module: Module, name: str = "weight", dim: int = 0):
        super().__init__()
        self.module = module
        self.name = name

        w = getattr(module, name)
        # Normalize a possibly negative ``dim`` to a valid axis index.
        self.dim = dim % w.ndim
        norm_axes = [i for i in range(w.ndim) if i != self.dim]
        # Initialize g = ||v|| so that the initial normalized weight equals
        # the original weight and the module's output is unchanged.
        g = mx.sqrt(mx.sum(mx.square(w), axis=norm_axes, keepdims=True))

        self.weight_g = g
        self.weight_v = w
        # Freeze the original weight: only weight_g and weight_v are trained.
        module.freeze(keys=[name], recurse=False)

    def unfreeze(self, *args, **kwargs):
        # Keep the wrapped module's original weight frozen even after a
        # blanket unfreeze; it is fully determined by weight_g and weight_v.
        super().unfreeze(*args, **kwargs)
        self.module.freeze(keys=[self.name], recurse=False)

    def _compute_weight(self):
        # Recompute w = g * v / ||v|| from the current parameters.
        v = self.weight_v
        norm_axes = [i for i in range(v.ndim) if i != self.dim]
        norm = mx.sqrt(mx.sum(mx.square(v), axis=norm_axes, keepdims=True))
        return self.weight_g * (v / norm)

    def __call__(self, *args, **kwargs):
        # Inject the freshly normalized weight, then delegate the forward
        # pass to the wrapped module.
        setattr(self.module, self.name, self._compute_weight())
        return self.module(*args, **kwargs)

    def _extra_repr(self):
        return f"name={self.name!r}, dim={self.dim}"
104 changes: 104 additions & 0 deletions python/tests/test_nn.py
Original file line number Diff line number Diff line change
Expand Up @@ -785,6 +785,110 @@ def test_batch_norm_stats(self):
self.assertEqual(batch_norm.running_mean.shape, running_mean.shape)
self.assertEqual(batch_norm.running_var.shape, running_var.shape)

def test_weight_norm(self):
    mx.random.seed(42)

    # Wrap a linear layer and check the output shape.
    base = nn.Linear(8, 16)
    wrapped = nn.WeightNorm(base)
    out = wrapped(mx.random.normal((2, 8)))
    self.assertEqual(out.shape, (2, 16))

    # The reparameterized tensors are exposed as parameters.
    wrapped_params = wrapped.parameters()
    self.assertIn("weight_g", wrapped_params)
    self.assertIn("weight_v", wrapped_params)

    # After the forward pass the wrapped module holds the normalized weight.
    normalized = base.weight
    self.assertEqual(normalized.shape, (16, 8))

    # Check w == g * v / ||v|| with the norm over every axis except 0.
    v = wrapped.weight_v
    g = wrapped.weight_g
    axes = list(range(1, v.ndim))
    denom = mx.sqrt(mx.sum(v * v, axis=axes, keepdims=True))
    self.assertTrue(mx.allclose(normalized, g * (v / denom), atol=1e-6))

    # Conv1d: output shape plus one magnitude per output channel.
    conv1d = nn.Conv1d(4, 8, kernel_size=3)
    wn1d = nn.WeightNorm(conv1d)
    out1d = wn1d(mx.random.normal((2, 10, 4)))
    self.assertEqual(out1d.shape, (2, 8, 8))
    self.assertEqual(wn1d.weight_g.shape[0], 8)

    # Conv2d.
    conv2d = nn.Conv2d(3, 16, kernel_size=3)
    wn2d = nn.WeightNorm(conv2d)
    out2d = wn2d(mx.random.normal((1, 8, 8, 3)))
    self.assertEqual(out2d.shape, (1, 6, 6, 16))

    # Because g is initialized to ||v||, w = g*v/||v|| = v, so the first
    # forward pass must match the unwrapped module exactly.
    plain = nn.Linear(4, 6)
    saved_w = mx.array(plain.weight)
    plain_wn = nn.WeightNorm(plain)
    xin = mx.random.normal((1, 4))
    got = plain_wn(xin)
    # Recompute the output by hand with the original weight.
    want = xin @ saved_w.T + plain.bias
    mx.eval(got, want)
    self.assertTrue(mx.allclose(got, want, atol=1e-5))

    # Only weight_g/weight_v (plus the wrapped bias) should be trainable.
    frozen_check = nn.WeightNorm(nn.Linear(4, 8))

    def flatten_keys(tree):
        # Iteratively collect the dotted path of every array leaf.
        keys = set()
        stack = [("", tree)]
        while stack:
            prefix, node = stack.pop()
            if isinstance(node, dict):
                for k, sub in node.items():
                    stack.append((f"{prefix}{k}.", sub))
            elif isinstance(node, mx.array):
                keys.add(prefix[:-1])
        return keys

    trainable = flatten_keys(frozen_check.trainable_parameters())
    self.assertIn("weight_g", trainable)
    self.assertIn("weight_v", trainable)
    self.assertIn("module.bias", trainable)
    self.assertNotIn("module.weight", trainable)

    # unfreeze() must re-freeze the wrapped module's weight.
    frozen_check.unfreeze()
    trainable_after = flatten_keys(frozen_check.trainable_parameters())
    self.assertNotIn("module.weight", trainable_after)

    # The original weight stays present in parameters() (frozen, not removed).
    self.assertIn("module.weight", flatten_keys(frozen_check.parameters()))

    # Gradients flow to weight_g and weight_v but not to the frozen weight.
    grad_model = nn.WeightNorm(nn.Linear(4, 8))
    batch = mx.random.normal((2, 4))

    def loss_fn(model, inputs):
        return model(inputs).sum()

    loss, grads = nn.value_and_grad(grad_model, loss_fn)(grad_model, batch)
    mx.eval(loss, grads)
    self.assertIn("weight_g", grads)
    self.assertIn("weight_v", grads)
    self.assertNotIn("weight", grads.get("module", {}))

    # A negative dim wraps around to a valid axis.
    neg_base = nn.Linear(4, 8)
    neg = nn.WeightNorm(neg_base, dim=-1)
    self.assertEqual(neg(mx.random.normal((2, 4))).shape, (2, 8))
    self.assertEqual(neg.weight_g.shape, (1, 4))

def test_conv1d(self):
N = 5
L = 12
Expand Down