15 changes: 13 additions & 2 deletions src/paddlefleet/models/backends.py
@@ -33,6 +33,8 @@
)
from paddlefleet.transformer.mlp import MLPSublayersSpec

from ..spec_utils import LayerSpec


# from paddlefleet.transformer.moe.experts import GroupedMLP, SequentialMLP
# HACK(Guoxia Wang): need remove later
@@ -118,14 +120,23 @@ def column_parallel_layer_norm_linear(self) -> type | None:
"""Which layer for sequential layernorm and linear"""
return None

def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> type:
def layer_norm(
self,
rms_norm: bool = False,
for_qk: bool = False,
fused: bool = True,
eps: float = 1e-5,
) -> type:
"""Which module to use for layer norm"""
if rms_norm:
# Matching get_gpt_layer_local_spec.
# Why does the global need to be updated?
global LNImpl
LNImpl = WrappedPaddleNorm
return LNImpl
return LayerSpec(
layer=LNImpl,
extra_kwargs={"eps": eps},
)

def core_attention(self) -> type:
"""Which layer to use for attention"""
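Note: layer_norm() now returns a LayerSpec instead of a bare class, so call sites must materialize the norm through the spec machinery rather than instantiating the return value directly. A minimal sketch of the intended consumption, assuming build_layer from paddlefleet.spec_utils instantiates spec.layer with any positional args plus spec.extra_kwargs (its exact signature is not shown in this diff):

    from paddlefleet.spec_utils import build_layer

    # Hypothetical call site; `backend` is a LocalSpecProvider as in
    # get_gpt_layer_local_spec below.
    spec = backend.layer_norm(rms_norm=True, eps=1e-6)
    # Assumption: build_layer(spec, *args) returns
    # spec.layer(*args, **spec.extra_kwargs),
    # i.e. WrappedPaddleNorm(hidden_size, eps=1e-6) here.
    norm = build_layer(spec, hidden_size)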
@@ -159,9 +159,9 @@ def get_freqs_non_repeated(

return freqs

def get_cos_sin(
self, max_seq_len: int, offset: int = 0
) -> (Tensor, Tensor):
def get_cos_sin(
    self, max_seq_len: int, offset: int = 0
) -> tuple[Tensor, Tensor]:
"""Cosine and sine values for RoPE are precomputed for all positions up to the maximum
sequence length"""
freqs = self.get_freqs_non_repeated(max_seq_len, offset)
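Note: the annotation fix matters because tuple(Tensor, Tensor) is a call, not a type, and raises a TypeError when the method is defined; tuple[Tensor, Tensor] is the correct generic form. For context, a self-contained sketch of the precomputation get_cos_sin performs, using the standard RoPE inverse-frequency formulation (names and the base value are illustrative, not the module's actual attributes):

    import paddle
    from paddle import Tensor

    def rope_cos_sin(dim: int, max_seq_len: int, offset: int = 0,
                     base: float = 10000.0) -> tuple[Tensor, Tensor]:
        # inv_freq[j] = base ** (-2j / dim), one frequency per rotated pair
        inv_freq = 1.0 / (
            base ** (paddle.arange(0, dim, 2, dtype="float32") / dim)
        )
        positions = paddle.arange(offset, offset + max_seq_len, dtype="float32")
        # freqs[p, j] = positions[p] * inv_freq[j], shape [max_seq_len, dim // 2]
        freqs = positions.unsqueeze(1) * inv_freq.unsqueeze(0)
        return freqs.cos(), freqs.sin()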
66 changes: 0 additions & 66 deletions src/paddlefleet/models/gpt/gpt_embedding.py
@@ -23,9 +23,6 @@

from paddlefleet.pipeline_parallel import ScheduleNode
from paddlefleet.spec_utils import LayerSpec, build_layer
from paddlefleet.tensor_parallel.mappings import (
scatter_to_sequence_parallel_region,
)
from paddlefleet.transformer.layer import FleetLayer

if TYPE_CHECKING:
@@ -115,12 +112,6 @@ def forward(
attn_mask_startend_row_indices = dict_args.get(
"attn_mask_startend_row_indices", None
)
deepstack_image_embeds = dict_args.get("deepstack_image_embeds", None)
deepstack_video_embeds = dict_args.get("deepstack_video_embeds", None)
visual_pos_masks = None
# Deepstack
deepstack_visual_embeds = None
visual_pos_mask = None
mtp_emb_res = None
if input_ids is None:
assert dict_args["decoder_input"] is not None, (
@@ -225,8 +216,6 @@ def forward(
image_embeds.astype(decoder_input.dtype).reshape([-1]),
) # scatter bwd is a simple gather — no sparse atomics
decoder_input = image_src_flat.reshape(decoder_input.shape)
visual_pos_masks = image_mask[..., 0]
deepstack_visual_embeds = deepstack_image_embeds

if video_embeds is not None:
_, video_mask = self.get_placeholder_mask(
@@ -247,60 +236,7 @@
video_embeds.astype(decoder_input.dtype).reshape([-1]),
)
decoder_input = video_src_flat.reshape(decoder_input.shape)
visual_pos_masks = video_mask[..., 0]
deepstack_visual_embeds = deepstack_video_embeds

if image_embeds is not None and video_embeds is not None:
image_mask = image_mask[..., 0] # [B, S] bool
video_mask = video_mask[..., 0] # [B, S] bool
visual_pos_masks = image_mask | video_mask
deepstack_visual_embeds = []
for img_embed, vid_embed in zip(
deepstack_image_embeds, deepstack_video_embeds
):
# Build embed_joint [N_visual, H] without boolean-index
# scatter. Use dense mask arithmetic instead.
# img_embed : [N_img, H]
# vid_embed : [N_vid, H]
# visual_pos_masks: [B, S] bool, N_visual True entries
# img_mask_in_visual[i] = True iff visual position i is image
# Computed as: image_mask flattened, keep only visual positions,
# expressed as a dense [N_visual] float mask — no indexing.
h = img_embed.shape[-1]
n_visual = int(visual_pos_masks.sum())
# visual_pos_flat: [B*S] bool
visual_pos_flat = visual_pos_masks.reshape([-1])
image_mask_flat = image_mask.reshape([-1]) # [B*S] bool
video_mask_flat = video_mask.reshape([-1]) # [B*S] bool
# Dense [B*S] float masks, then compress to [N_visual] via
# paddle.masked_select (forward: gather, backward: scatter_add
# — but scalar backward is efficient, no sparse atomics)
img_mask_in_vis_f = paddle.masked_select(
image_mask_flat.astype(img_embed.dtype),
visual_pos_flat,
).unsqueeze(-1) # [N_visual, 1]
vid_mask_in_vis_f = paddle.masked_select(
video_mask_flat.astype(vid_embed.dtype),
visual_pos_flat,
).unsqueeze(-1) # [N_visual, 1]
embed_joint = (
img_embed.reshape([n_visual, h]) * img_mask_in_vis_f
+ vid_embed.reshape([n_visual, h])
* vid_mask_in_vis_f
)
deepstack_visual_embeds.append(embed_joint)
# Scatter decoder_input to SP format [S/tp, B, H] after multimodal
# token replacement, since LanguageModelEmbedding's internal scatter
# was disabled to allow image/video embedding insertion first.
if self.sequence_parallel:
decoder_input = decoder_input.transpose(
[1, 0, 2]
).contiguous()
decoder_input = scatter_to_sequence_parallel_region(
decoder_input, group=self.embedding.tp_group
)
if self.config.clone_scatter_output_in_embedding:
decoder_input = decoder_input.clone()
# Rotary positional embeddings (embedding is None for PP intermediate devices)
rotary_pos_emb = None
rotary_pos_cos = None
@@ -352,8 +288,6 @@ def forward(
"rotary_pos_cos": rotary_pos_cos,
"rotary_pos_sin": rotary_pos_sin,
"position_ids": position_ids,
"deepstack_visual_emb": deepstack_visual_embeds,
"visual_pos_masks": visual_pos_masks,
}
if mtp_emb_res is not None:
assert (
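Note: the 66 deleted lines above remove the deepstack visual-embedding path, including the dense-mask trick for merging image and video embeddings at joint visual positions without boolean-index scatter. A standalone sketch of that trick with illustrative shapes, assuming (as the deleted code did) that both per-modality tensors already span all N_visual rows:

    import paddle

    # Toy example: flattened sequence of length 6 with 2 image tokens
    # and 1 video token, hidden size 4.
    image_mask_flat = paddle.to_tensor([True, False, True, False, False, False])
    video_mask_flat = paddle.to_tensor([False, False, False, True, False, False])
    visual_pos_flat = paddle.logical_or(image_mask_flat, video_mask_flat)

    n_visual = int(visual_pos_flat.sum())    # 3 visual positions
    img_embed = paddle.randn([n_visual, 4])  # rows aligned with visual positions
    vid_embed = paddle.randn([n_visual, 4])

    # Compress the [B*S] bool masks to [N_visual, 1] float masks. masked_select's
    # backward is a plain scatter of scalars, so no sparse atomics are needed.
    img_mask_in_vis = paddle.masked_select(
        image_mask_flat.astype(img_embed.dtype), visual_pos_flat
    ).unsqueeze(-1)
    vid_mask_in_vis = paddle.masked_select(
        video_mask_flat.astype(vid_embed.dtype), visual_pos_flat
    ).unsqueeze(-1)

    # Each visual row takes the image embedding where the position held an
    # image token and the video embedding where it held a video token.
    embed_joint = img_embed * img_mask_in_vis + vid_embed * vid_mask_in_vis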
19 changes: 15 additions & 4 deletions src/paddlefleet/models/gpt/gpt_layer_specs.py
@@ -204,10 +204,17 @@ def get_gpt_layer_local_spec(

backend = LocalSpecProvider()
# Adjust for RMS norm.
norm_eps = config.rms_norm_eps if config is not None else 1e-5
if normalization == "RMSNorm":
layer_norm = backend.layer_norm(rms_norm=True, for_qk=False)
layer_norm = backend.layer_norm(
rms_norm=True, for_qk=False, eps=norm_eps
)
qk_norm = backend.layer_norm(rms_norm=True, for_qk=True, eps=norm_eps)
else:
layer_norm = backend.layer_norm(rms_norm=False, for_qk=False)
layer_norm = backend.layer_norm(
rms_norm=False, for_qk=False, eps=norm_eps
)
qk_norm = backend.layer_norm(rms_norm=False, for_qk=True, eps=norm_eps)

mlp = get_mlp_layer_spec_for_backend(
backend=backend,
@@ -223,7 +230,9 @@
norm=layer_norm,
),
)
transformer_cls = getattr(config, "specific_layer", TransformerLayer)
transformer_cls = getattr(
config, "specific_transformer_layer", TransformerLayer
)
if paddle.distributed.is_initialized():
use_overlap = fleet.fleet._user_defined_strategy.hybrid_configs[
"pp_configs"
@@ -490,6 +499,7 @@ def get_gpt_spec(
language_embedding=language_embedding_spec,
rope_embedding=rope_embedding_spec,
)
embedding_cls = getattr(config, "specific_embedding", GPTEmbedding)

# Build block_attn_res spec for GPTLMHead
lm_head_block_attn_res = IdentityOp
@@ -514,10 +524,11 @@
extra_kwargs={
"config": config,
"tie_word_embeddings": tie_word_embeddings,
"modal": "language_model" if config.multimodal_embedding else None,
},
sublayers_spec=GPTSublayersSpec(
embedding=LayerSpec(
layer=GPTEmbedding,
layer=embedding_cls,
sublayers_spec=embedding_spec,
extra_kwargs=embedding_extra_kwargs,
),
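Note: this file now reads two optional hooks off the config, specific_transformer_layer and specific_embedding, falling back to TransformerLayer and GPTEmbedding when the attributes are absent. A minimal sketch of overriding them; the import paths and get_gpt_spec arguments here are assumptions, since this diff only shows the getattr sites:

    from paddlefleet.models.gpt.gpt_layer_specs import get_gpt_spec

    class MyTransformerLayer(TransformerLayer):
        """Hypothetical subclass with custom per-layer behavior."""

    class MyEmbedding(GPTEmbedding):
        """Hypothetical multimodal-aware embedding."""

    config.specific_transformer_layer = MyTransformerLayer  # picked up via getattr
    config.specific_embedding = MyEmbedding
    spec = get_gpt_spec(config=config, tie_word_embeddings=False)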