Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions mlx_lm/models/longcat_next.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
# Copyright © 2026 Apple Inc.

from dataclasses import dataclass
from typing import Optional, Union

import mlx.nn as nn

from .base import BaseModelArgs
from .longcat_flash_ngram import Model as LongcatFlashNgramLM
from .longcat_flash_ngram import ModelArgs as TextConfig


@dataclass
class ModelArgs(BaseModelArgs):
    """Configuration for the Longcat-Next model.

    The multimodal checkpoint extends the text vocabulary with extra
    special tokens: ``text_vocab_plus_multimodal_special_token_size`` sizes
    the output head, while the nested text config keeps the smaller
    text-only vocabulary (used for n-gram hashing).
    """

    # Nested text-model configuration; a plain dict is converted to a
    # TextConfig instance in __post_init__. Fix: default is None, so the
    # annotation must be Optional (Optional was imported but unused).
    text_config: Optional[Union[TextConfig, dict]] = None
    # Output-head size: text vocabulary plus multimodal special tokens.
    text_vocab_plus_multimodal_special_token_size: int = 131125
    model_type: str = "longcat_next"

    def __post_init__(self):
        # text_config has no usable default; fail fast if it is missing.
        if self.text_config is None:
            raise ValueError("text_config is required")
        if isinstance(self.text_config, dict):
            self.text_config = TextConfig.from_dict(self.text_config)

    @classmethod
    def from_dict(cls, params):
        """Build ModelArgs from a flat checkpoint config dict.

        Every key is forwarded to the text config; only the keys listed in
        ``own_keys`` are also consumed by this wrapper config itself.
        """
        text_config = dict(params)
        # Ngram hashing uses text_vocab_size, not the full multimodal vocab_size
        if "text_vocab_size" in params:
            text_config["vocab_size"] = params["text_vocab_size"]
        own_keys = ("model_type", "text_vocab_plus_multimodal_special_token_size")
        return cls(
            text_config=text_config,
            **{k: v for k, v in params.items() if k in own_keys},
        )


class Model(LongcatFlashNgramLM):
    """Longcat-Next text model: the n-gram text backbone with an output
    head sized for the extended (text + multimodal special token)
    vocabulary."""

    # Checkpoint prefixes belonging to multimodal components that this
    # text-only model does not instantiate and must not load.
    _MULTIMODAL_PREFIXES = (
        "visual_head.",
        "audio_head.",
        "visual_model.",
        "image_decoder.",
        "image_refiner.",
        "model.visual_tokenizer.",
        "model.audio_tokenizer.",
    )

    def __init__(self, config: ModelArgs):
        super().__init__(config.text_config)
        # Replace the parent's head with one covering the full
        # text-plus-multimodal-special-token vocabulary.
        self.lm_head = nn.Linear(
            config.text_config.hidden_size,
            config.text_vocab_plus_multimodal_special_token_size,
            bias=False,
        )

    def sanitize(self, weights):
        """Drop multimodal weights, trim embeddings to the text vocab,
        then defer to the parent's sanitize."""
        kept = {}
        for key, value in weights.items():
            if key.startswith(self._MULTIMODAL_PREFIXES):
                continue
            kept[key] = value
        # The embedding table covers the full multimodal vocabulary in the
        # checkpoint; keep only the leading text-vocab rows.
        embed_key = "model.embed_tokens.weight"
        if embed_key in kept:
            kept[embed_key] = kept[embed_key][: self.args.vocab_size]
        return super().sanitize(kept)
27 changes: 27 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -1926,6 +1926,33 @@ def test_all_models(self):
"mla_scale_kv_lora": True,
"attention_bias": False,
},
{
"model_type": "longcat_next",
"zero_expert_type": "identity",
"hidden_size": 128,
"ffn_hidden_size": 128,
"moe_topk": 2,
"expert_ffn_hidden_size": 128,
"n_routed_experts": 2,
"zero_expert_num": 2,
"num_layers": 4,
"num_hidden_layers": 4,
"vocab_size": 1000,
"text_vocab_plus_multimodal_special_token_size": 1000,
"max_position_embeddings": 1000,
"num_attention_heads": 4,
"kv_lora_rank": 16,
"q_lora_rank": 16,
"qk_rope_head_dim": 8,
"qk_nope_head_dim": 8,
"v_head_dim": 8,
"routed_scaling_factor": 1.0,
"rms_norm_eps": 1e-5,
"rope_theta": 1000,
"mla_scale_q_lora": True,
"mla_scale_kv_lora": True,
"attention_bias": False,
},
{
"model_type": "longcat_flash",
"attention_method": "MLA",
Expand Down