From 13f0c4bdf3c166dad18873c1390a887c950600ee Mon Sep 17 00:00:00 2001
From: Tarjei Mandt
Date: Thu, 26 Mar 2026 16:51:32 +1100
Subject: [PATCH] Add LongCat Next

---
 mlx_lm/models/longcat_next.py | 70 +++++++++++++++++++++++++++++++++++
 tests/test_models.py          | 27 ++++++++++++++
 2 files changed, 97 insertions(+)
 create mode 100644 mlx_lm/models/longcat_next.py

diff --git a/mlx_lm/models/longcat_next.py b/mlx_lm/models/longcat_next.py
new file mode 100644
index 000000000..6e11e30d6
--- /dev/null
+++ b/mlx_lm/models/longcat_next.py
@@ -0,0 +1,70 @@
+# Copyright © 2026 Apple Inc.
+
+from dataclasses import dataclass
+from typing import Optional, Union
+
+import mlx.nn as nn
+
+from .base import BaseModelArgs
+from .longcat_flash_ngram import Model as LongcatFlashNgramLM
+from .longcat_flash_ngram import ModelArgs as TextConfig
+
+
+@dataclass
+class ModelArgs(BaseModelArgs):
+    text_config: Union[TextConfig, dict] = None
+    text_vocab_plus_multimodal_special_token_size: int = 131125
+    model_type: str = "longcat_next"
+
+    def __post_init__(self):
+        if self.text_config is None:
+            raise ValueError("text_config is required")
+        if isinstance(self.text_config, dict):
+            self.text_config = TextConfig.from_dict(self.text_config)
+
+    @classmethod
+    def from_dict(cls, params):
+        text_config = dict(params)
+        # Ngram hashing uses text_vocab_size, not the full multimodal vocab_size
+        if "text_vocab_size" in params:
+            text_config["vocab_size"] = params["text_vocab_size"]
+        return cls(
+            text_config=text_config,
+            **{
+                k: v
+                for k, v in params.items()
+                if k in ("model_type", "text_vocab_plus_multimodal_special_token_size")
+            }
+        )
+
+
+class Model(LongcatFlashNgramLM):
+    def __init__(self, config: ModelArgs):
+        super().__init__(config.text_config)
+        self.lm_head = nn.Linear(
+            config.text_config.hidden_size,
+            config.text_vocab_plus_multimodal_special_token_size,
+            bias=False,
+        )
+
+    def sanitize(self, weights):
+        weights = {
+            k: v
+            for k, v in weights.items()
+            if not k.startswith(
+                (
+                    "visual_head.",
+                    "audio_head.",
+                    "visual_model.",
+                    "image_decoder.",
+                    "image_refiner.",
+                    "model.visual_tokenizer.",
+                    "model.audio_tokenizer.",
+                )
+            )
+        }
+        # Truncate embed_tokens to text vocab (drop multimodal token embeddings)
+        embed_key = "model.embed_tokens.weight"
+        if embed_key in weights:
+            weights[embed_key] = weights[embed_key][: self.args.vocab_size]
+        return super().sanitize(weights)
diff --git a/tests/test_models.py b/tests/test_models.py
index 0b2a963dc..684ab16b2 100644
--- a/tests/test_models.py
+++ b/tests/test_models.py
@@ -1926,6 +1926,33 @@ def test_all_models(self):
                 "mla_scale_kv_lora": True,
                 "attention_bias": False,
             },
+            {
+                "model_type": "longcat_next",
+                "zero_expert_type": "identity",
+                "hidden_size": 128,
+                "ffn_hidden_size": 128,
+                "moe_topk": 2,
+                "expert_ffn_hidden_size": 128,
+                "n_routed_experts": 2,
+                "zero_expert_num": 2,
+                "num_layers": 4,
+                "num_hidden_layers": 4,
+                "vocab_size": 1000,
+                "text_vocab_plus_multimodal_special_token_size": 1000,
+                "max_position_embeddings": 1000,
+                "num_attention_heads": 4,
+                "kv_lora_rank": 16,
+                "q_lora_rank": 16,
+                "qk_rope_head_dim": 8,
+                "qk_nope_head_dim": 8,
+                "v_head_dim": 8,
+                "routed_scaling_factor": 1.0,
+                "rms_norm_eps": 1e-5,
+                "rope_theta": 1000,
+                "mla_scale_q_lora": True,
+                "mla_scale_kv_lora": True,
+                "attention_bias": False,
+            },
             {
                 "model_type": "longcat_flash",
                 "attention_method": "MLA",