diff --git a/mlx_lm/models/lfm2.py b/mlx_lm/models/lfm2.py index c7f742af9..07bebdea1 100644 --- a/mlx_lm/models/lfm2.py +++ b/mlx_lm/models/lfm2.py @@ -298,6 +298,8 @@ def __call__( def sanitize(self, weights): sanitized_weights = {} for name, param in weights.items(): + if name == "lm_head.weight": + continue if "conv.weight" in name: if param.shape[-1] > param.shape[1]: param = param.transpose(0, 2, 1)