From 641dc17be2c880daf21cbd611b1ef8cae7af75eb Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 2 Apr 2026 19:13:19 +0800
Subject: [PATCH 1/5] compat mcore dev

---
 src/mcore_bridge/model/gpt_model.py | 1 +
 src/mcore_bridge/patcher.py         | 5 ++++-
 src/mcore_bridge/version.py         | 2 +-
 3 files changed, 6 insertions(+), 2 deletions(-)

diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 29db5d0..12cd97d 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -158,6 +158,7 @@ def _apply_rotary_pos_emb_bshd(
     rotary_interleaved: bool = False,
     multi_latent_attention: bool = False,  # not use
     mscale: float = 1.0,
+    **kwargs,
 ) -> torch.Tensor:
     """Apply rotary positional embedding to input tensor T.
 
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 2de8e51..392ff34 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -608,6 +608,7 @@ def _apply_rotary_pos_emb_thd(
     multi_latent_attention: bool = False,
     mscale: float = 1.0,
     cp_group: torch.distributed.ProcessGroup = None,
+    **kwargs,
 ) -> torch.Tensor:
     """A baseline implementation of applying RoPE for `thd` format.
 
@@ -629,7 +630,8 @@
         use_batched_rope = (freqs.dim() >= 1 and freqs.shape[0] == cu_seqlens_for_batched[-1]).item()
     if not use_batched_rope:
         logger.warning_once('Using non-batched RoPE, which may affect performance.')
-        kwargs = {'cp_group': cp_group} if mcore_013 else {}
+        if mcore_013:
+            kwargs['cp_group'] = cp_group
         return _origin_apply_rotary_pos_emb_thd(
             t,
             cu_seqlens,
@@ -646,6 +648,7 @@
             rotary_interleaved=rotary_interleaved,
             multi_latent_attention=multi_latent_attention,
             mscale=mscale,
+            **kwargs,
         ).squeeze(1)
 
     rope_utils._apply_rotary_pos_emb_thd = _apply_rotary_pos_emb_thd
diff --git a/src/mcore_bridge/version.py b/src/mcore_bridge/version.py
index 3594f30..f4f1f4b 100644
--- a/src/mcore_bridge/version.py
+++ b/src/mcore_bridge/version.py
@@ -1,5 +1,5 @@
 # Make sure to modify __release_datetime__ to release time when making official release.
-__version__ = '1.0.1.dev0'
+__version__ = '1.1.0.dev0'
 # default release datetime for branches under active development is set
 # to be a time far-far-away-into-the-future
 __release_datetime__ = '2099-12-31 23:59:59'

From e81832d447e054176383773cfd7f41dfb5caefd9 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Thu, 2 Apr 2026 19:15:40 +0800
Subject: [PATCH 2/5] support multimodal mtp

---
 src/mcore_bridge/model/gpt_model.py |  1 +
 src/mcore_bridge/patcher.py         | 39 ++++++++++++++++++++++++++++-
 2 files changed, 39 insertions(+), 1 deletion(-)

diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 12cd97d..91ccf37 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -405,6 +405,7 @@ def _postprocess(
                 packed_seq_params=packed_seq_params,
                 sequence_len_offset=sequence_len_offset,
                 embedding=self.embedding,
+                decoder_input=decoder_input,
                 **(extra_block_kwargs or {}),
             )
             hidden_states_list = torch.chunk(hidden_states, 1 + self.config.mtp_num_layers, dim=0)
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index 392ff34..c57ee37 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -21,7 +21,7 @@
 from packaging import version
 from peft.tuners.tuners_utils import BaseTuner
 from torch import nn
-from typing import List, Optional, Tuple
+from typing import List, Optional, Tuple, Callable
 
 from mcore_bridge.utils import get_logger, is_flash_attn_3_available
 
@@ -424,6 +424,8 @@ def forward(
             position_ids=position_ids,
             embedding=embedding,
             hidden_states=hidden_states,
+            packed_seq_params=packed_seq_params,
+            decoder_input=decoder_input,
         )
         assert not self.transformer_layer.self_attention.config.apply_rope_fusion
         packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
@@ -471,6 +473,41 @@ def forward(
     MultiTokenPredictionLayer.forward = forward
 
 
+    def _get_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        embedding: Callable,
+        hidden_states: torch.Tensor,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+        decoder_input: Optional[torch.Tensor] = None,
+    ):
+        # Calc logits for the current Multi-Token Prediction (MTP) layers.
+        input_ids, _ = roll_tensor(
+            input_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        position_ids, _ = roll_tensor(
+            position_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        # embedding
+        decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+
+        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+
+        return input_ids, position_ids, decoder_input, hidden_states
+
+
+    MultiTokenPredictionLayer._get_embeddings = _get_embeddings
+
+
 def _patch_peft_ModulesToSaveWrapper():
     if version.parse(peft.__version__) >= version.parse('0.16'):
         from peft.utils import other as peft_module

From 9ac57637f5356e46f049fca638f30b55eef4cad2 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sat, 4 Apr 2026 18:45:45 +0800
Subject: [PATCH 3/5] fix

---
 src/mcore_bridge/patcher.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index e964d34..baccd41 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -425,7 +425,6 @@ def forward(
             embedding=embedding,
             packed_seq_params=packed_seq_params,
             hidden_states=hidden_states,
-            packed_seq_params=packed_seq_params,
             decoder_input=decoder_input,
         )
         assert not self.transformer_layer.self_attention.config.apply_rope_fusion

From d98999470687e1ccdd763910c07afb3c0079db40 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sun, 5 Apr 2026 14:35:16 +0800
Subject: [PATCH 4/5] update

---
 src/mcore_bridge/config/model_config.py |  2 --
 src/mcore_bridge/model/gpt_model.py     |  3 +--
 src/mcore_bridge/patcher.py             | 31 +++++++++++++++++++------
 3 files changed, 25 insertions(+), 11 deletions(-)

diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index ca21567..253d12b 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -205,9 +205,7 @@ class ModelConfig(TransformerConfig):
 
     # visual
     hf_config: Optional[PretrainedConfig] = None
-    vit_gradient_checkpointing: Optional[bool] = None
     vit_attn_impl: Optional[str] = None  # e.g. 'flash_attention_2'
-    vit_gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
 
     # Override
     perform_initialization: bool = False
diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 57c3038..66c37a6 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -406,8 +406,7 @@ def _postprocess(
                 rotary_pos_sin=rotary_pos_sin,
                 packed_seq_params=packed_seq_params,
                 sequence_len_offset=sequence_len_offset,
-                embedding=self.embedding,
-                decoder_input=decoder_input,
+                embedding=(self.embedding, decoder_input),
                 **(extra_block_kwargs or {}),
             )
             mtp_labels = labels.clone()
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index baccd41..b9d184c 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -21,7 +21,7 @@
 from packaging import version
 from peft.tuners.tuners_utils import BaseTuner
 from torch import nn
-from typing import List, Optional, Tuple, Callable
+from typing import Callable, List, Optional, Tuple
 
 from mcore_bridge.utils import get_logger, is_flash_attn_3_available
 
@@ -425,7 +425,6 @@ def forward(
             embedding=embedding,
             packed_seq_params=packed_seq_params,
             hidden_states=hidden_states,
-            decoder_input=decoder_input,
         )
         assert not self.transformer_layer.self_attention.config.apply_rope_fusion
         packed_seq = packed_seq_params is not None and packed_seq_params.qkv_format == 'thd'
@@ -472,7 +471,6 @@ def forward(
 
     MultiTokenPredictionLayer.forward = forward
 
-
     def _get_embeddings(
         self,
         input_ids: torch.Tensor,
@@ -480,8 +478,10 @@ def _get_embeddings(
         embedding: Callable,
         hidden_states: torch.Tensor,
         packed_seq_params: Optional[PackedSeqParams] = None,
-        decoder_input: Optional[torch.Tensor] = None,
     ):
+        from megatron.core.transformer.multi_token_prediction import roll_tensor
+        from megatron.core.utils import make_viewless_tensor
+
         # Calc logits for the current Multi-Token Prediction (MTP) layers.
         input_ids, _ = roll_tensor(
             input_ids,
@@ -498,13 +498,30 @@ def _get_embeddings(
             packed_seq_params=packed_seq_params,
         )
         # embedding
-        decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
-
+        if isinstance(embedding, tuple):
+            embedding, decoder_input = embedding
+        else:
+            decoder_input = None
+        if decoder_input is None:
+            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+        else:
+            enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
+            if enable_sp:
+                decoder_input = gather_from_sequence_parallel_region(decoder_input)
+            decoder_input, _ = roll_tensor(
+                decoder_input.transpose(0, 2),
+                shifts=-1,
+                dims=-1,
+                cp_group=self.cp_group,
+                packed_seq_params=packed_seq_params,
+            )
+            decoder_input = decoder_input.transpose(0, 2).contiguous()
+            if enable_sp:
+                decoder_input = scatter_to_sequence_parallel_region(decoder_input)
         hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
 
         return input_ids, position_ids, decoder_input, hidden_states
 
-
     MultiTokenPredictionLayer._get_embeddings = _get_embeddings

From 4654faaf6700504b39d769f93c1fb8523c2a5313 Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Sun, 5 Apr 2026 15:19:44 +0800
Subject: [PATCH 5/5] update

---
 src/mcore_bridge/model/gpt_model.py | 18 +++++++++++++++---
 1 file changed, 15 insertions(+), 3 deletions(-)

diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 66c37a6..55443b7 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -322,6 +322,11 @@ def forward(
             assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
             decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
 
+        mtp_decoder_input = decoder_input
+        if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
+            input_tensor = self.get_input_tensor()
+            input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0)
+            self.set_input_tensor(input_tensor)
         # Run decoder.
         hidden_states = self.decoder(
             hidden_states=decoder_input,
@@ -346,7 +351,7 @@ def forward(
             rotary_pos_cos=rotary_pos_cos,
             rotary_pos_sin=rotary_pos_sin,
             loss_mask=loss_mask,
-            decoder_input=decoder_input,
+            decoder_input=mtp_decoder_input,
             attention_mask=attention_mask,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
@@ -381,7 +386,10 @@ def _postprocess(
         the output layer, and computes language model loss when labels are provided.
         """
         if not self.post_process:
-            return hidden_states
+            if self.config.is_multimodal and self.config.mtp_num_layers:
+                return torch.concat([hidden_states, decoder_input], dim=0)
+            else:
+                return hidden_states
         labels = labels if self.config.task_type == 'causal_lm' else None
         in_inference_mode = inference_context is not None and not self.training
         if in_inference_mode:
@@ -395,6 +403,10 @@ def _postprocess(
             input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
 
         if self.mtp_process:
+            if self.config.is_multimodal:
+                embedding_ = (self.embedding, decoder_input)
+            else:
+                embedding_ = self.embedding
             hidden_states = self.mtp(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -406,7 +418,7 @@ def _postprocess(
                 rotary_pos_cos=rotary_pos_cos,
                 rotary_pos_sin=rotary_pos_sin,
                 packed_seq_params=packed_seq_params,
                 sequence_len_offset=sequence_len_offset,
-                embedding=(self.embedding, decoder_input),
+                embedding=embedding_,
                 **(extra_block_kwargs or {}),
             )
             mtp_labels = labels.clone()
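
Editorial note, not part of the patch series: patches 2-5 make multi-token prediction (MTP) usable for multimodal models by handing the MTP layer the already-merged decoder_input (token embeddings plus vision features) and shifting it one position to the left with roll_tensor, instead of re-embedding input_ids, which would drop the merged vision features. Below is a minimal standalone sketch of that shift, assuming a dense [seq, batch, hidden] tensor and ignoring the packed-sequence, context-parallel and sequence-parallel handling that roll_tensor and the patched code take care of; shift_left_one is a hypothetical helper, not an API of Megatron-LM or this repository.

    import torch

    def shift_left_one(x: torch.Tensor, seq_dim: int = 0) -> torch.Tensor:
        # Roll the sequence dimension left by one so that position i holds the
        # input of token i + 1, then zero the wrapped-around last position,
        # which has no next token.
        rolled = torch.roll(x, shifts=-1, dims=seq_dim)
        index = [slice(None)] * rolled.dim()
        index[seq_dim] = -1
        rolled[tuple(index)] = 0
        return rolled

    # decoder_input is laid out [seq, batch, hidden]; the MTP layer needs the
    # inputs of the next tokens, hence the shift along the sequence dimension.
    decoder_input = torch.randn(8, 2, 16)
    mtp_decoder_input = shift_left_one(decoder_input, seq_dim=0)

The transpose(0, 2) before and after the roll_tensor call in patch 4 appears to be there because that call rolls along dims=-1, so the sequence axis is temporarily moved to the last dimension; input_ids and position_ids are [batch, seq], so for them dims=-1 already is the sequence axis.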