Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/mcore_bridge/config/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,9 +126,11 @@ class ModelConfig(TransformerConfig):
ffn_hidden_size: Optional[int] = None
num_attention_heads: Optional[int] = None
num_query_groups: Optional[int] = None
num_global_query_groups: Optional[int] = None
softmax_type: Literal['vanilla', 'off-by-one', 'learnable'] = 'vanilla'
window_size: Optional[str] = None
window_attn_skip_freq: Optional[str] = None
layer_types: Optional[List[str]] = None
max_position_embeddings: Optional[int] = None

position_embedding_type: Literal['learned_absolute', 'rope', 'mrope', 'none'] = 'rope'
Expand All @@ -153,6 +155,7 @@ class ModelConfig(TransformerConfig):
attention_dropout: float = 0.
hidden_dropout: float = 0.
kv_channels: Optional[int] = None
global_kv_channels: Optional[int] = None
qk_layernorm: bool = False
qk_l2_norm: bool = False
no_rope_freq: Optional[int] = None
Expand Down Expand Up @@ -207,6 +210,14 @@ class ModelConfig(TransformerConfig):
hf_config: Optional[PretrainedConfig] = None
vit_attn_impl: Optional[str] = None # e.g. 'flash_attention_2'

# gemma4
hidden_size_per_layer_input: Optional[int] = None
vocab_size_per_layer_input: Optional[int] = None
num_kv_shared_layers: int = 0
enable_moe_block: bool = False
use_double_wide_mlp: bool = False
top_k_experts: Optional[int] = None

# Override
perform_initialization: bool = False
apply_query_key_layer_scaling: Optional[bool] = None
Expand Down
21 changes: 21 additions & 0 deletions src/mcore_bridge/config/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
'ffn_hidden_size': ['intermediate_size'],
'num_attention_heads': ['num_attention_heads'],
'num_query_groups': ['num_key_value_heads'],
'num_global_query_groups': ['num_global_key_value_heads'],
'max_position_embeddings': ['max_position_embeddings'],
'layernorm_epsilon': ['rms_norm_eps'],
'rotary_base': ['rope_theta'],
Expand All @@ -21,6 +22,7 @@
'add_qkv_bias': ['attention_bias', 'qkv_bias', 'use_bias'],
'add_bias_linear': ['mlp_bias'],
'kv_channels': ['head_dim'],
'global_kv_channels': ['global_head_dim'],
'hf_model_type': ['model_type'],
# moe
'moe_ffn_hidden_size': ['moe_intermediate_size'],
Expand Down Expand Up @@ -60,6 +62,14 @@
'window_size': ['sliding_window'],
'layer_types': ['layer_types'],
'interleave_moe_layer_step': ['interleave_moe_layer_step'],
# gemma4
'hidden_size_per_layer_input': ['hidden_size_per_layer_input'],
'vocab_size_per_layer_input': ['vocab_size_per_layer_input'],
'num_kv_shared_layers': ['num_kv_shared_layers'],
'enable_moe_block': ['enable_moe_block'],
'use_double_wide_mlp': ['use_double_wide_mlp'],
'num_experts': ['num_experts'],
'top_k_experts': ['top_k_experts'],
}


Expand Down Expand Up @@ -112,6 +122,9 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
interleave_moe_layer_step = res.pop('interleave_moe_layer_step', None)
window_size = res.pop('window_size', None)
rope_scaling = res.get('rope_scaling') or {}
if 'full_attention' in rope_scaling:
rope_scaling = rope_scaling['full_attention']
res['rope_scaling'] = rope_scaling
if llm_model_type in {'qwen3', 'qwen3_moe', 'qwen3_next'} or hf_model_type in {
'qwen3_omni_moe', 'qwen3_omni', 'qwen3_vl', 'qwen3_vl_moe', 'qwen3_5', 'qwen3_5_moe'
}:
Expand Down Expand Up @@ -192,7 +205,15 @@ def hf_to_mcore_config(hf_config: PretrainedConfig) -> Dict[str, Any]:
res['moe_layer_freq'] = f"[{','.join(moe_layer_freq)}]"
elif hf_model_type == 'glm4v':
res['rotary_interleaved'] = True
elif hf_model_type == 'gemma4':
res['qk_layernorm'] = True
if layer_types is not None:
res['window_attn_skip_freq'] = f"[{','.join(['1' if lt == 'sliding_attention' else '0' for lt in layer_types])}]"
res['layer_types'] = layer_types

if rope_scaling.get('rope_type') is None and rope_scaling.get('type') is not None:
rope_scaling = {**rope_scaling, 'rope_type': rope_scaling['type']}
res['rope_scaling'] = rope_scaling
if 'partial_rotary_factor' not in res and 'partial_rotary_factor' in rope_scaling:
res['partial_rotary_factor'] = rope_scaling['partial_rotary_factor']
if 'rotary_base' not in res and 'rope_theta' in rope_scaling:
Expand Down
1 change: 1 addition & 0 deletions src/mcore_bridge/model/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ class MLLMModelType:
glm4v_moe = 'glm4v_moe'
kimi_vl = 'kimi_vl'
llama4 = 'llama4'
gemma4 = 'gemma4'


class ModelType(LLMModelType, MLLMModelType):
Expand Down
2 changes: 1 addition & 1 deletion src/mcore_bridge/model/mm_gpts/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright (c) ModelScope Contributors. All rights reserved.
from . import glm, internvl, kimi_vl, llama4, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
from . import gemma4, glm, internvl, kimi_vl, llama4, qwen, qwen3_5, qwen3_5_gdn, qwen3_omni, qwen3_vl
Loading
Loading