
feat: Drop unsupported weights by default during model load #1033

Open
spicyneuron wants to merge 2 commits into ml-explore:main from spicyneuron:drop-unknown-weights

Conversation

@spicyneuron
Contributor

@spicyneuron spicyneuron commented Mar 21, 2026

This makes mlx-lm drop unsupported weights by default when loading models that contain extra parameters (for example vision_tower weights from multimodal checkpoints).

Problem

mlx-lm doesn't support vision, but models converted with mlx-vlm still contain a perfectly usable text model.

Currently, attempting to load such models results in a hard crash:

  File "~/.local/share/uv/tools/mlx-lm/lib/python3.13/site-packages/mlx/nn/layers/base.py", line 185, in load_weights
    raise ValueError(
        f"Received {num_extra} parameters not in model: \n{extras}."
    )
ValueError: Received 333 parameters not in model:
language_model.vision_tower.blocks.0.attn.proj.bias,
language_model.vision_tower.blocks.0.attn.proj.weight,
...

That's especially painful right now because MLX quantizations on Hugging Face are still relatively sparse, and sometimes the only available checkpoint is a vision model. Even when one does exist, keeping two copies of the same model just to use both mlx-lm and mlx-vlm wastes a lot of storage.

Solution

This PR makes load_model filter out weights that don't exist on the instantiated MLX model before loading. Unsupported extra weights are skipped and logged, while real incompatibilities like missing supported weights still fail as before.

The main change is in mlx_lm/utils.py.
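For illustration, here is a minimal sketch of the filtering idea (the function name and structure are mine, not the actual diff in mlx_lm/utils.py; in mlx-lm the known keys would come from `{k for k, _ in tree_flatten(model.parameters())}`):

```python
import logging

def filter_unsupported_weights(known_keys, weights):
    """Split checkpoint weights into (kept, dropped) against the model's keys.

    Only extra keys are dropped; missing *supported* weights are untouched
    here and still fail later in load_weights(..., strict=True).
    """
    kept = {k: v for k, v in weights.items() if k in known_keys}
    dropped = sorted(set(weights) - set(known_keys))
    if dropped:
        logging.warning(
            "Dropping %d unsupported weight(s), e.g. %s",
            len(dropped), dropped[0],
        )
    return kept, dropped

# Example: a vision tower key is dropped, text-model keys are kept
known = {"model.embed_tokens.weight", "lm_head.weight"}
weights = {
    "model.embed_tokens.weight": 0,
    "lm_head.weight": 1,
    "vision_tower.blocks.0.attn.proj.weight": 2,
}
kept, dropped = filter_unsupported_weights(known, weights)
# dropped == ["vision_tower.blocks.0.attn.proj.weight"]
```

Because strict loading still runs on the filtered dict, a checkpoint that is genuinely missing supported weights fails exactly as before.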

Alternatives

  1. I initially implemented this as an opt-in --drop-unknown-weights flag, but after testing locally I couldn't really find a case where I would want this disabled. The change felt narrow in scope and general enough in benefit that it made more sense as the default behavior. Easy to revert if folks feel differently.

  2. Add a --disable-strict-load style flag that toggles strict=False. I think that goes too far, since it would also hide incompatibilities. The current approach still preserves hard failures for missing supported weights.

@spicyneuron spicyneuron changed the title feat: Add CLI flag to drop unsupported weights during model load feat: Add support for dropping unsupported weights during model load Mar 21, 2026
@spicyneuron spicyneuron changed the title feat: Add support for dropping unsupported weights during model load feat: Add option to drop unsupported weights during model load Mar 21, 2026
@spicyneuron spicyneuron changed the title feat: Add option to drop unsupported weights during model load feat: Drop unsupported weights by default during model load Mar 23, 2026
@Thump604

Heads up on a use case this would break: models with extra weight keys that are loaded by separate builders.

MTP (multi-token prediction) weights in PR #990 are stored as mtp.* keys alongside the base model weights. They're not in the base model's parameter tree because they're loaded separately by a dedicated builder (text_model_from_vlm.py in our case, mtp_generate.py in #990). With this PR, tree_flatten(model.parameters()) wouldn't include mtp.*, so they'd all get dropped before load_weights() runs.

Same pattern applies to any pipeline where extra weights are co-located in the safetensors but consumed by a secondary model (adapter weights, reward heads, draft model weights, etc.).

A couple of options that would preserve the cleanup benefit without breaking these cases:

  • A keep_prefixes parameter: load_model(..., keep_prefixes=["mtp."])
  • Only drop + warn (current behavior) but behind a flag like --strict-weights that defaults to off
  • Log the unknown keys as a warning but don't drop them

The warning-without-dropping approach is probably the safest default. Users who want strict behavior can opt in.
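To make the first option concrete, here is a sketch of what a hypothetical keep_prefixes escape hatch could look like (names are illustrative; nothing like this exists in mlx-lm today):

```python
def filter_with_keep_prefixes(known_keys, weights, keep_prefixes=()):
    """Drop keys absent from the model, except those matching keep_prefixes.

    keep_prefixes is the hypothetical parameter discussed above: prefixes
    whose weights are consumed by a secondary loader (e.g. "mtp.") and
    must survive filtering even though they aren't in the parameter tree.
    """
    def is_kept(k):
        return k in known_keys or any(k.startswith(p) for p in keep_prefixes)
    kept = {k: v for k, v in weights.items() if is_kept(k)}
    dropped = sorted(k for k in weights if k not in kept)
    return kept, dropped

# Example: mtp.* survives for a downstream builder, vision keys are dropped
known = {"model.layers.0.mlp.w1"}
weights = {"model.layers.0.mlp.w1": 0, "mtp.head.weight": 1,
           "vision_tower.proj.weight": 2}
kept, dropped = filter_with_keep_prefixes(known, weights,
                                          keep_prefixes=("mtp.",))
# dropped == ["vision_tower.proj.weight"]; "mtp.head.weight" is kept
```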

@angeloskath
Member

Interesting. However, we do have strict=True. Passing strict=False would allow loading and ignoring the unmatched weights.

Why not simply "expose" this parameter to the CLIs?
