diff --git a/src/mcore_bridge/config/model_config.py b/src/mcore_bridge/config/model_config.py
index ca21567..253d12b 100644
--- a/src/mcore_bridge/config/model_config.py
+++ b/src/mcore_bridge/config/model_config.py
@@ -205,9 +205,7 @@ class ModelConfig(TransformerConfig):
 
     # visual
     hf_config: Optional[PretrainedConfig] = None
-    vit_gradient_checkpointing: Optional[bool] = None
     vit_attn_impl: Optional[str] = None  # e.g. 'flash_attention_2'
-    vit_gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
 
     # Override
     perform_initialization: bool = False
diff --git a/src/mcore_bridge/model/gpt_model.py b/src/mcore_bridge/model/gpt_model.py
index 6df275c..55443b7 100644
--- a/src/mcore_bridge/model/gpt_model.py
+++ b/src/mcore_bridge/model/gpt_model.py
@@ -322,6 +322,11 @@ def forward(
         assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
         decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
 
+    mtp_decoder_input = decoder_input
+    if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
+        input_tensor = self.get_input_tensor()
+        input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0)
+        self.set_input_tensor(input_tensor)
     # Run decoder.
     hidden_states = self.decoder(
         hidden_states=decoder_input,
@@ -346,7 +351,7 @@ def forward(
         rotary_pos_cos=rotary_pos_cos,
         rotary_pos_sin=rotary_pos_sin,
         loss_mask=loss_mask,
-        decoder_input=decoder_input,
+        decoder_input=mtp_decoder_input,
         attention_mask=attention_mask,
         inference_params=inference_params,
         packed_seq_params=packed_seq_params,
@@ -381,7 +386,10 @@ def _postprocess(
     the output layer, and computes language model loss when labels are provided.
     """
     if not self.post_process:
-        return hidden_states
+        if self.config.is_multimodal and self.config.mtp_num_layers:
+            return torch.concat([hidden_states, decoder_input], dim=0)
+        else:
+            return hidden_states
     labels = labels if self.config.task_type == 'causal_lm' else None
     in_inference_mode = inference_context is not None and not self.training
     if in_inference_mode:
@@ -395,6 +403,10 @@ def _postprocess(
         input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
 
     if self.mtp_process:
+        if self.config.is_multimodal:
+            embedding_ = (self.embedding, decoder_input)
+        else:
+            embedding_ = self.embedding
         hidden_states = self.mtp(
             input_ids=input_ids,
             position_ids=position_ids,
@@ -406,7 +418,7 @@ def _postprocess(
             rotary_pos_sin=rotary_pos_sin,
             packed_seq_params=packed_seq_params,
             sequence_len_offset=sequence_len_offset,
-            embedding=self.embedding,
+            embedding=embedding_,
             **(extra_block_kwargs or {}),
         )
         mtp_labels = labels.clone()
diff --git a/src/mcore_bridge/patcher.py b/src/mcore_bridge/patcher.py
index c43141d..b9d184c 100644
--- a/src/mcore_bridge/patcher.py
+++ b/src/mcore_bridge/patcher.py
@@ -21,7 +21,7 @@
 from packaging import version
 from peft.tuners.tuners_utils import BaseTuner
 from torch import nn
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 from mcore_bridge.utils import get_logger, is_flash_attn_3_available
 
@@ -471,6 +471,59 @@ def forward(
 
     MultiTokenPredictionLayer.forward = forward
 
+    def _get_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        embedding: Callable,
+        hidden_states: torch.Tensor,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+    ):
+        from megatron.core.transformer.multi_token_prediction import roll_tensor
+        from megatron.core.utils import make_viewless_tensor
+
+        # Calc logits for the current Multi-Token Prediction (MTP) layers.
+        input_ids, _ = roll_tensor(
+            input_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        position_ids, _ = roll_tensor(
+            position_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        # embedding
+        if isinstance(embedding, tuple):
+            embedding, decoder_input = embedding
+        else:
+            decoder_input = None
+        if decoder_input is None:
+            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+        else:
+            enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
+            if enable_sp:
+                decoder_input = gather_from_sequence_parallel_region(decoder_input)
+            decoder_input, _ = roll_tensor(
+                decoder_input.transpose(0, 2),
+                shifts=-1,
+                dims=-1,
+                cp_group=self.cp_group,
+                packed_seq_params=packed_seq_params,
+            )
+            decoder_input = decoder_input.transpose(0, 2).contiguous()
+            if enable_sp:
+                decoder_input = scatter_to_sequence_parallel_region(decoder_input)
+        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+
+        return input_ids, position_ids, decoder_input, hidden_states
+
+    MultiTokenPredictionLayer._get_embeddings = _get_embeddings
+
 
 def _patch_peft_ModulesToSaveWrapper():
     if version.parse(peft.__version__) >= version.parse('0.16'):
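
For reference, a minimal standalone sketch (not part of the patch) of the transpose/roll/transpose step that the new _get_embeddings applies to a precomputed multimodal decoder_input. It assumes the usual [seq, batch, hidden] layout, and torch.roll stands in for Megatron's roll_tensor, which additionally handles context-parallel groups and packed sequences.

    # Illustrative sketch only: shift a [seq, batch, hidden] decoder_input by one
    # token along the sequence axis, as _get_embeddings does before reusing it for
    # the MTP layer (torch.roll used here in place of Megatron's roll_tensor).
    import torch

    seq, batch, hidden = 4, 2, 8
    decoder_input = torch.randn(seq, batch, hidden)

    # roll_tensor shifts along the last dim, so the sequence axis is moved last,
    # rolled by -1, then moved back.
    rolled = torch.roll(decoder_input.transpose(0, 2), shifts=-1, dims=-1)
    rolled = rolled.transpose(0, 2).contiguous()

    # Position i now holds the embedding that belonged to token i + 1, matching
    # the input_ids / position_ids that were rolled the same way.
    assert torch.equal(rolled[:-1], decoder_input[1:])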