2 changes: 0 additions & 2 deletions src/mcore_bridge/config/model_config.py
@@ -205,9 +205,7 @@ class ModelConfig(TransformerConfig):

     # visual
     hf_config: Optional[PretrainedConfig] = None
-    vit_gradient_checkpointing: Optional[bool] = None
     vit_attn_impl: Optional[str] = None  # e.g. 'flash_attention_2'
-    vit_gradient_checkpointing_kwargs: Optional[Union[dict, str]] = None
 
     # Override
     perform_initialization: bool = False
18 changes: 15 additions & 3 deletions src/mcore_bridge/model/gpt_model.py
@@ -322,6 +322,11 @@ def forward(
         assert position_ids.shape[0] == 1, f'position_ids.shape: {position_ids.shape}'
         decoder_rotary_pos_emb = rotary_pos_emb[position_ids[0]]
 
+        mtp_decoder_input = decoder_input
+        if self.config.is_multimodal and self.config.mtp_num_layers and decoder_input is None:
+            input_tensor = self.get_input_tensor()
+            input_tensor, mtp_decoder_input = input_tensor.chunk(2, dim=0)
+            self.set_input_tensor(input_tensor)
         # Run decoder.
         hidden_states = self.decoder(
             hidden_states=decoder_input,
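A note on the `chunk(2, dim=0)` handshake above: on pipeline stages after the first, `decoder_input` is `None` and activations arrive through `set_input_tensor()`. The previous stage packs two tensors into that single channel by concatenating along dim 0 (see the `_postprocess` change below), and this stage splits them apart again. A minimal sketch of the round trip, with made-up shapes (Megatron activations are `[seq, batch, hidden]`):

```python
import torch

# Hypothetical shapes; Megatron activations are [seq, batch, hidden].
seq, batch, hidden = 1024, 2, 4096
hidden_states = torch.randn(seq, batch, hidden)      # decoder output of stage N
mtp_decoder_input = torch.randn(seq, batch, hidden)  # multimodal embeddings

# Sending side (_postprocess on a stage without post_process):
stacked = torch.concat([hidden_states, mtp_decoder_input], dim=0)  # [2*seq, batch, hidden]

# Receiving side (forward on stage N+1): undo the concat.
recv_hidden, recv_mtp = stacked.chunk(2, dim=0)
assert torch.equal(recv_hidden, hidden_states)
assert torch.equal(recv_mtp, mtp_decoder_input)
```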
@@ -346,7 +351,7 @@ def forward(
             rotary_pos_cos=rotary_pos_cos,
             rotary_pos_sin=rotary_pos_sin,
             loss_mask=loss_mask,
-            decoder_input=decoder_input,
+            decoder_input=mtp_decoder_input,
             attention_mask=attention_mask,
             inference_params=inference_params,
             packed_seq_params=packed_seq_params,
@@ -381,7 +386,10 @@ def _postprocess(
         the output layer, and computes language model loss when labels are provided.
         """
         if not self.post_process:
-            return hidden_states
+            if self.config.is_multimodal and self.config.mtp_num_layers:
+                return torch.concat([hidden_states, decoder_input], dim=0)
+            else:
+                return hidden_states
         labels = labels if self.config.task_type == 'causal_lm' else None
         in_inference_mode = inference_context is not None and not self.training
         if in_inference_mode:
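This early return is the sending half of the handshake sketched above: a stage that is neither first nor last forwards `hidden_states` and the multimodal `decoder_input` through the pipeline as one concatenated tensor, which the next stage's `forward` splits apart again with `chunk(2, dim=0)`.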
@@ -395,6 +403,10 @@ def _postprocess(
             input_ids = split_cp_inputs(input_ids, getattr(packed_seq_params, 'cu_seqlens_q', None), 1)
 
         if self.mtp_process:
+            if self.config.is_multimodal:
+                embedding_ = (self.embedding, decoder_input)
+            else:
+                embedding_ = self.embedding
             hidden_states = self.mtp(
                 input_ids=input_ids,
                 position_ids=position_ids,
@@ -406,7 +418,7 @@ def _postprocess(
                 rotary_pos_sin=rotary_pos_sin,
                 packed_seq_params=packed_seq_params,
                 sequence_len_offset=sequence_len_offset,
-                embedding=self.embedding,
+                embedding=embedding_,
                 **(extra_block_kwargs or {}),
             )
             mtp_labels = labels.clone()
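Why a tuple? For a multimodal model, the decoder input at image positions comes from the vision encoder and cannot be recreated from `input_ids` alone, so `_postprocess` hands the MTP block both the embedding callable and the precomputed `decoder_input`. A sketch of the contract the patched `_get_embeddings` (in `patcher.py` below) expects; `resolve_mtp_embeddings` is an illustrative name, not part of the PR:

```python
from typing import Callable, Optional, Tuple, Union
import torch

# Either a plain embedding callable (text-only path), or a
# (callable, precomputed_embeddings) pair (multimodal path).
EmbeddingArg = Union[Callable, Tuple[Callable, torch.Tensor]]

def resolve_mtp_embeddings(embedding: EmbeddingArg,
                           input_ids: torch.Tensor,
                           position_ids: torch.Tensor) -> torch.Tensor:
    decoder_input: Optional[torch.Tensor] = None
    if isinstance(embedding, tuple):
        embedding, decoder_input = embedding  # reuse upstream embeddings
    if decoder_input is None:
        # Text-only path: re-embedding the shifted ids is exact.
        decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
    return decoder_input
```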
55 changes: 54 additions & 1 deletion src/mcore_bridge/patcher.py
@@ -21,7 +21,7 @@
 from packaging import version
 from peft.tuners.tuners_utils import BaseTuner
 from torch import nn
-from typing import List, Optional, Tuple
+from typing import Callable, List, Optional, Tuple
 
 from mcore_bridge.utils import get_logger, is_flash_attn_3_available
 
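The added `Callable` import supports the type hint on the new `_get_embeddings` helper below, where `embedding` may be either a plain callable or a `(callable, tensor)` pair.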
@@ -471,6 +471,59 @@ def forward(

     MultiTokenPredictionLayer.forward = forward
 
+    def _get_embeddings(
+        self,
+        input_ids: torch.Tensor,
+        position_ids: torch.Tensor,
+        embedding: Callable,
+        hidden_states: torch.Tensor,
+        packed_seq_params: Optional[PackedSeqParams] = None,
+    ):
+        from megatron.core.transformer.multi_token_prediction import roll_tensor
+        from megatron.core.utils import make_viewless_tensor
+
+        # Shift input_ids and position_ids left by one for the current Multi-Token Prediction (MTP) layer.
+        input_ids, _ = roll_tensor(
+            input_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        position_ids, _ = roll_tensor(
+            position_ids,
+            shifts=-1,
+            dims=-1,
+            cp_group=self.cp_group,
+            packed_seq_params=packed_seq_params,
+        )
+        # Embed the shifted ids, or reuse precomputed multimodal embeddings.
+        if isinstance(embedding, tuple):
+            embedding, decoder_input = embedding
+        else:
+            decoder_input = None
+        if decoder_input is None:
+            decoder_input = embedding(input_ids=input_ids, position_ids=position_ids)
+        else:
+            enable_sp = self.config.sequence_parallel and self.config.tensor_model_parallel_size > 1
+            if enable_sp:
+                decoder_input = gather_from_sequence_parallel_region(decoder_input)
+            decoder_input, _ = roll_tensor(
+                decoder_input.transpose(0, 2),
+                shifts=-1,
+                dims=-1,
+                cp_group=self.cp_group,
+                packed_seq_params=packed_seq_params,
+            )
+            decoder_input = decoder_input.transpose(0, 2).contiguous()
+            if enable_sp:
+                decoder_input = scatter_to_sequence_parallel_region(decoder_input)
+        hidden_states = make_viewless_tensor(inp=hidden_states, requires_grad=True, keep_graph=True)
+
+        return input_ids, position_ids, decoder_input, hidden_states
+
+    MultiTokenPredictionLayer._get_embeddings = _get_embeddings
+
 
 def _patch_peft_ModulesToSaveWrapper():
     if version.parse(peft.__version__) >= version.parse('0.16'):
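For intuition about the repeated `roll_tensor(..., shifts=-1, dims=-1)` calls: MTP layer k is trained to predict token t+k+1 at position t, so ids, positions, and (here) the reused embeddings are each shifted one step left per layer. A simplified stand-in using `torch.roll`, ignoring the `cp_group` and `packed_seq_params` handling of the real Megatron helper (which also takes care of the wrapped-around final position):

```python
import torch

def roll_left(x: torch.Tensor, dim: int = -1) -> torch.Tensor:
    # Shift everything one step toward the front along `dim`; the element
    # that wraps around to the end is meaningless for next-token prediction
    # and must be masked out of the loss.
    return torch.roll(x, shifts=-1, dims=dim)

input_ids = torch.tensor([[10, 11, 12, 13]])
print(roll_left(input_ids))  # tensor([[11, 12, 13, 10]])

# Embeddings are [seq, batch, hidden]; transposing to [hidden, batch, seq]
# lets the same dims=-1 roll operate on the sequence axis, mirroring the
# transpose(0, 2) dance in _get_embeddings above.
emb = torch.randn(4, 2, 8)
rolled = roll_left(emb.transpose(0, 2)).transpose(0, 2).contiguous()
```

The sequence-parallel gather/scatter pair around that roll exists because under sequence parallelism each tensor-parallel rank holds only a slice of the sequence; the shift has to see the full sequence before the tensor is scattered back.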