diff --git a/python/mlx/nn/layers/convolution.py b/python/mlx/nn/layers/convolution.py index 88b97add0d..2109e84f71 100644 --- a/python/mlx/nn/layers/convolution.py +++ b/python/mlx/nn/layers/convolution.py @@ -66,7 +66,7 @@ def __init__( def _extra_repr(self): return ( - f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"{self.weight.shape[-1] * self.groups}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " f"groups={self.groups}, " @@ -146,7 +146,7 @@ def __init__( def _extra_repr(self): return ( - f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"{self.weight.shape[-1] * self.groups}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:3]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " f"groups={self.groups}, " @@ -219,7 +219,7 @@ def __init__( def _extra_repr(self): return ( - f"{self.weight.shape[-1]}, {self.weight.shape[0]}, " + f"{self.weight.shape[-1] * self.groups}, {self.weight.shape[0]}, " f"kernel_size={self.weight.shape[1:4]}, stride={self.stride}, " f"padding={self.padding}, dilation={self.dilation}, " f"bias={'bias' in self}"