diff --git a/src/paddlefleet/models/backends.py b/src/paddlefleet/models/backends.py index a8ae893d9..a7b0cce3c 100644 --- a/src/paddlefleet/models/backends.py +++ b/src/paddlefleet/models/backends.py @@ -33,6 +33,8 @@ ) from paddlefleet.transformer.mlp import MLPSublayersSpec +from ..spec_utils import LayerSpec + # from paddlefleet.transformer.moe.experts import GroupedMLP, SequentialMLP # HACK(Guoxia Wang): need remove later @@ -118,14 +120,23 @@ def column_parallel_layer_norm_linear(self) -> type | None: """Which layer for sequential layernorm and linear""" return None - def layer_norm(self, rms_norm: bool = False, for_qk: bool = False) -> type: + def layer_norm( + self, + rms_norm: bool = False, + for_qk: bool = False, + fused: bool = True, + eps: float = 1e-5, + ) -> type: """Which module to use for layer norm""" if rms_norm: # Matching get_gpt_layer_local_spec. # Why does the global need to be updated? global LNImpl LNImpl = WrappedPaddleNorm - return LNImpl + return LayerSpec( + layer=LNImpl, + extra_kwargs={"eps": eps}, + ) def core_attention(self) -> type: """Which layer to use for attention""" diff --git a/src/paddlefleet/models/common/embeddings/rotary_pos_embedding.py b/src/paddlefleet/models/common/embeddings/rotary_pos_embedding.py index 76b70f364..4617eeb70 100644 --- a/src/paddlefleet/models/common/embeddings/rotary_pos_embedding.py +++ b/src/paddlefleet/models/common/embeddings/rotary_pos_embedding.py @@ -159,9 +159,9 @@ def get_freqs_non_repeated( return freqs - def get_cos_sin( - self, max_seq_len: int, offset: int = 0 - ) -> (Tensor, Tensor): + def get_cos_sin(self, max_seq_len: int, offset: int = 0) -> tuple( + Tensor, Tensor + ): """Cosine and sine values for RoPE are precomputed for all positions up to the maximum sequence length""" freqs = self.get_freqs_non_repeated(max_seq_len, offset) diff --git a/src/paddlefleet/models/gpt/gpt_embedding.py b/src/paddlefleet/models/gpt/gpt_embedding.py index 11f17427c..d28093a48 100644 --- a/src/paddlefleet/models/gpt/gpt_embedding.py +++ b/src/paddlefleet/models/gpt/gpt_embedding.py @@ -23,9 +23,6 @@ from paddlefleet.pipeline_parallel import ScheduleNode from paddlefleet.spec_utils import LayerSpec, build_layer -from paddlefleet.tensor_parallel.mappings import ( - scatter_to_sequence_parallel_region, -) from paddlefleet.transformer.layer import FleetLayer if TYPE_CHECKING: @@ -115,12 +112,6 @@ def forward( attn_mask_startend_row_indices = dict_args.get( "attn_mask_startend_row_indices", None ) - deepstack_image_embeds = dict_args.get("deepstack_image_embeds", None) - deepstack_video_embeds = dict_args.get("deepstack_video_embeds", None) - visual_pos_masks = None - # Deepstack - deepstack_visual_embeds = None - visual_pos_mask = None mtp_emb_res = None if input_ids is None: assert dict_args["decoder_input"] is not None, ( @@ -225,8 +216,6 @@ def forward( image_embeds.astype(decoder_input.dtype).reshape([-1]), ) # scatter bwd is a simple gather — no sparse atomics decoder_input = image_src_flat.reshape(decoder_input.shape) - visual_pos_masks = image_mask[..., 0] - deepstack_visual_embeds = deepstack_image_embeds if video_embeds is not None: _, video_mask = self.get_placeholder_mask( @@ -247,60 +236,7 @@ def forward( video_embeds.astype(decoder_input.dtype).reshape([-1]), ) decoder_input = video_src_flat.reshape(decoder_input.shape) - visual_pos_masks = video_mask[..., 0] - deepstack_visual_embeds = deepstack_video_embeds - if image_embeds is not None and video_embeds is not None: - image_mask = image_mask[..., 0] # [B, S] bool - 
video_mask = video_mask[..., 0] # [B, S] bool - visual_pos_masks = image_mask | video_mask - deepstack_visual_embeds = [] - for img_embed, vid_embed in zip( - deepstack_image_embeds, deepstack_video_embeds - ): - # Build embed_joint [N_visual, H] without boolean-index - # scatter. Use dense mask arithmetic instead. - # img_embed : [N_img, H] - # vid_embed : [N_vid, H] - # visual_pos_masks: [B, S] bool, N_visual True entries - # img_mask_in_visual[i] = True iff visual position i is image - # Computed as: image_mask flattened, keep only visual positions, - # expressed as a dense [N_visual] float mask — no indexing. - h = img_embed.shape[-1] - n_visual = int(visual_pos_masks.sum()) - # visual_pos_flat: [B*S] bool - visual_pos_flat = visual_pos_masks.reshape([-1]) - image_mask_flat = image_mask.reshape([-1]) # [B*S] bool - video_mask_flat = video_mask.reshape([-1]) # [B*S] bool - # Dense [B*S] float masks, then compress to [N_visual] via - # paddle.masked_select (forward: gather, backward: scatter_add - # — but scalar backward is efficient, no sparse atomics) - img_mask_in_vis_f = paddle.masked_select( - image_mask_flat.astype(img_embed.dtype), - visual_pos_flat, - ).unsqueeze(-1) # [N_visual, 1] - vid_mask_in_vis_f = paddle.masked_select( - video_mask_flat.astype(vid_embed.dtype), - visual_pos_flat, - ).unsqueeze(-1) # [N_visual, 1] - embed_joint = ( - img_embed.reshape([n_visual, h]) * img_mask_in_vis_f - + vid_embed.reshape([n_visual, h]) - * vid_mask_in_vis_f - ) - deepstack_visual_embeds.append(embed_joint) - # Scatter decoder_input to SP format [S/tp, B, H] after multimodal - # token replacement, since LanguageModelEmbedding's internal scatter - # was disabled to allow image/video embedding insertion first. - if self.sequence_parallel: - decoder_input = decoder_input.transpose( - [1, 0, 2] - ).contiguous() - decoder_input = scatter_to_sequence_parallel_region( - decoder_input, group=self.embedding.tp_group - ) - if self.config.clone_scatter_output_in_embedding: - decoder_input = decoder_input.clone() # Rotary positional embeddings (embedding is None for PP intermediate devices) rotary_pos_emb = None rotary_pos_cos = None @@ -352,8 +288,6 @@ def forward( "rotary_pos_cos": rotary_pos_cos, "rotary_pos_sin": rotary_pos_sin, "position_ids": position_ids, - "deepstack_visual_emb": deepstack_visual_embeds, - "visual_pos_masks": visual_pos_masks, } if mtp_emb_res is not None: assert ( diff --git a/src/paddlefleet/models/gpt/gpt_layer_specs.py b/src/paddlefleet/models/gpt/gpt_layer_specs.py index 00d717d79..ca74ac04c 100644 --- a/src/paddlefleet/models/gpt/gpt_layer_specs.py +++ b/src/paddlefleet/models/gpt/gpt_layer_specs.py @@ -204,10 +204,17 @@ def get_gpt_layer_local_spec( backend = LocalSpecProvider() # Adjust for RMS norm. 
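+    # Use the model's configured norm epsilon for both the layer-norm and the
+    # QK-norm specs; fall back to 1e-5 when no config is supplied.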
+ norm_eps = config.rms_norm_eps if config is not None else 1e-5 if normalization == "RMSNorm": - layer_norm = backend.layer_norm(rms_norm=True, for_qk=False) + layer_norm = backend.layer_norm( + rms_norm=True, for_qk=False, eps=norm_eps + ) + qk_norm = backend.layer_norm(rms_norm=True, for_qk=True, eps=norm_eps) else: - layer_norm = backend.layer_norm(rms_norm=False, for_qk=False) + layer_norm = backend.layer_norm( + rms_norm=False, for_qk=False, eps=norm_eps + ) + qk_norm = backend.layer_norm(rms_norm=False, for_qk=True, eps=norm_eps) mlp = get_mlp_layer_spec_for_backend( backend=backend, @@ -223,7 +230,9 @@ def get_gpt_layer_local_spec( norm=layer_norm, ), ) - transformer_cls = getattr(config, "specific_layer", TransformerLayer) + transformer_cls = getattr( + config, "specific_transformer_layer", TransformerLayer + ) if paddle.distributed.is_initialized(): use_overlap = fleet.fleet._user_defined_strategy.hybrid_configs[ "pp_configs" @@ -490,6 +499,7 @@ def get_gpt_spec( language_embedding=language_embedding_spec, rope_embedding=rope_embedding_spec, ) + embedding_cls = getattr(config, "specific_embedding", GPTEmbedding) # Build block_attn_res spec for GPTLMHead lm_head_block_attn_res = IdentityOp @@ -514,10 +524,11 @@ def get_gpt_spec( extra_kwargs={ "config": config, "tie_word_embeddings": tie_word_embeddings, + "modal": "language_model" if config.multimodal_embedding else None, }, sublayers_spec=GPTSublayersSpec( embedding=LayerSpec( - layer=GPTEmbedding, + layer=embedding_cls, sublayers_spec=embedding_spec, extra_kwargs=embedding_extra_kwargs, ), diff --git a/src/paddlefleet/models/gpt/gpt_model.py b/src/paddlefleet/models/gpt/gpt_model.py index 9f16e3acc..d13acb97c 100644 --- a/src/paddlefleet/models/gpt/gpt_model.py +++ b/src/paddlefleet/models/gpt/gpt_model.py @@ -17,15 +17,6 @@ from dataclasses import dataclass from typing import TYPE_CHECKING -from paddlefleet.pipeline_parallel import ( - LayerDesc, - PipelineLayer, - SharedLayerDesc, -) -from paddlefleet.pipeline_parallel.pp_utils.utils import ( - dict_to_tuple_helper, -) - if TYPE_CHECKING: from paddlefleet.spec_utils import LayerSpec @@ -33,96 +24,16 @@ from paddlefleet.models.gpt.gpt_embedding import GPTEmbedding from paddlefleet.models.gpt.lm_head import GPTLMHead -from paddlefleet.pipeline_parallel import ScheduleChunk -from paddlefleet.transformer.transformer_layer import ( - TransformerLayer, - TransformerLayerNode, - TransformerLayerOverlappedScheduleNode, +from paddlefleet.transformer.transformer_encoder import TransformerEncoder + +from ...pipeline_parallel import ( + LayerDesc, + SharedLayerDesc, ) logger = logging.getLogger(__name__) -def build_overlapped_nodes(forward_chunk, backward_chunk): - """Build overlapped nodes for TransformerLayer.""" - overlap_element_class = TransformerLayerNode - forward_decoder_layer_num = 0 - backward_decoder_layer_num = 0 - - assert isinstance(forward_chunk, ScheduleChunk) and isinstance( - backward_chunk, ScheduleChunk - ) - for n in forward_chunk.nodes: - if isinstance(n, overlap_element_class): - forward_decoder_layer_num += 1 - for n in reversed(backward_chunk.nodes): - if isinstance(n, overlap_element_class): - backward_decoder_layer_num += 1 - - overlap_layers_num = min( - forward_decoder_layer_num, backward_decoder_layer_num - ) - - # construct forward pre- and post-chunks - forward_pre_layers = [] - forward_post_layers = [] - forward_overlap_layers = [] - is_pre = True - for n in forward_chunk.nodes: - if not isinstance(n, overlap_element_class): - if is_pre: - 
forward_pre_layers.append(n) - else: - forward_post_layers.append(n) - else: - is_pre = False - if len(forward_overlap_layers) == overlap_layers_num: - forward_post_layers.append(n) - else: - forward_overlap_layers.append(n) - - forward_pre_node = ScheduleChunk(forward_pre_layers) - forward_post_node = ScheduleChunk(forward_post_layers) - - # construct backward pre- and post-chunks - backward_pre_layers = [] - backward_post_layers = [] - backward_overlap_layers = [] - is_pre = True - for n in reversed(backward_chunk.nodes): - if not isinstance(n, overlap_element_class): - if is_pre: - backward_pre_layers.append(n) - else: - backward_post_layers.append(n) - else: - is_pre = False - if len(backward_overlap_layers) == overlap_layers_num: - backward_post_layers.append(n) - else: - backward_overlap_layers.append(n) - - backward_pre_node = ScheduleChunk(list(reversed(backward_pre_layers))) - backward_post_node = ScheduleChunk(list(reversed(backward_post_layers))) - - # construct overlap chunk - overlap_node = ScheduleChunk( - [ - TransformerLayerOverlappedScheduleNode(forward_node, backward_node) - for forward_node, backward_node in zip( - forward_overlap_layers, backward_overlap_layers - ) - ] - ) - return ( - forward_pre_node, - backward_pre_node, - overlap_node, - forward_post_node, - backward_post_node, - ) - - @dataclass class GPTSublayersSpec: """p @@ -139,7 +50,7 @@ class GPTSublayersSpec: lm_head: LayerSpec | None = None -class GPTModel(PipelineLayer): +class GPTModel(TransformerEncoder): """GPT Transformer language model. Args: @@ -152,6 +63,7 @@ def __init__( **kwargs, ) -> None: self.config = kwargs["config"] + self.modal = kwargs.pop("modal", None) tie_word_embeddings = ( kwargs["tie_word_embeddings"] and self.config.pipeline_model_parallel_size > 1 @@ -176,7 +88,7 @@ def __init__( else fleet.get_hybrid_communicate_group().topology() ) - super().__init__( + super(TransformerEncoder, self).__init__( layers=self.layers, topology=topology, num_virtual_pipeline_stages=self.config.virtual_pipeline_model_parallel_size, @@ -191,33 +103,10 @@ def __init__( if isinstance(layer, GPTLMHead): layer.weight = shared_embed_weight - def _get_weight_only_params(self): - """Get all parameters marked with is_weight_only_mtp flag.""" - return [ - param - for param in self.state_dict().values() - if getattr(param, "is_weight_only_mtp", False) - ] - - def offload_weight_only_params(self): - """Offload all weight-only MTP parameters to CPU pinned memory.""" - for param in self._get_weight_only_params(): - if param.place.is_gpu_place(): - cpu_param = param.pin_memory() - cpu_param._share_buffer_to(param) - - def reload_weight_only_params(self): - """Reload weight-only MTP parameters from CPU pinned memory back to GPU.""" - for param in self._get_weight_only_params(): - if not param.place.is_gpu_place(): - gpu_param = param.cuda() - gpu_param._share_buffer_to(param) - def get_layer_desc_list(self, spec, tie_word_embeddings): layers = [] - model_type = getattr(self.config, "model_type", "") - if "qwen3_vl" in model_type or "qwen3_5" in model_type: - name_prefix = "model.language_model" + if self.modal: + name_prefix = f"model.{self.modal}" else: name_prefix = "model" if tie_word_embeddings: @@ -282,398 +171,3 @@ def get_layer_desc_list(self, spec, tie_word_embeddings): ) return layers - - def overlapped_forward_backward( - self, - forward_chunk, - forward_inputs, - forward_loss_fn_node, - backward_chunk, - backward_loss_fn_node, - backward_input_grads, - scaler, - p2p_async_handle, - ): - if 
backward_loss_fn_node is not None: - if scaler: - backward_input_grads = backward_loss_fn_node.backward( - scaler=scaler - ) - else: - backward_input_grads = backward_loss_fn_node.backward() - - ( - forward_pre_node, - backward_pre_node, - overlap_node, - forward_post_node, - backward_post_node, - ) = build_overlapped_nodes(forward_chunk, backward_chunk) - - if len(overlap_node.nodes) > 0: - assert not any( - isinstance(node, TransformerLayerNode) - for node in overlap_node.nodes - ) - # origin assert, why ? - # assert not any( - # isinstance(node, TransformerLayerNode) - # for node in forward_post_node.nodes - # ) - # assert not any( - # isinstance(node, TransformerLayerNode) - # for node in backward_post_node.nodes - # ) - - if p2p_async_handle is not None: - p2p_async_handle.forward_handle_wait() - p2p_async_handle.backward_handle_wait() - - forward_inputs = forward_pre_node.forward(forward_inputs) - backward_input_grads = backward_pre_node.backward(backward_input_grads) - - for i, node in enumerate(overlap_node.nodes): - forward_inputs, backward_input_grads = node.forward_backward( - forward_inputs, - backward_input_grads, - # split_bw=(i == len(overlap_node.nodes) - 1), - ) - - forward_inputs = forward_post_node.forward(forward_inputs) - backward_input_grads = backward_post_node.backward(backward_input_grads) - - # forward_inputs = forward_chunk.forward(forward_inputs) - - if p2p_async_handle is not None: - forward_inputs = dict_to_tuple_helper(forward_inputs) - p2p_async_handle.forward_async_comm(forward_inputs) - p2p_async_handle.backward_async_comm(backward_input_grads) - - # backward_input_grads = backward_chunk.backward(backward_input_grads) - - # used for bw split - # if len(overlap_node.nodes) > 0: - # WeightGradStore.pop() - # assert WeightGradStore.funcs_queue.empty() - - if forward_loss_fn_node is not None: - forward_loss = forward_loss_fn_node.forward(forward_inputs) - else: - forward_loss = None - - return forward_inputs, forward_loss, backward_input_grads - - def get_hardware_flops(self): - return 989e3 - - def add_sequential_layer(self, layers, layer_desc, name_prefix=""): - """ - Add a sequential layer to the network with specified description and name prefix. - - Args: - layers (list): List to store layer descriptions. Each element should be a dict - with keys "layer" (LayerDesc) and "name_prefix" (str). - layer_desc (LayerDesc|SharedLayerDesc): Layer description object containing - layer self.configuration. - name_prefix (str, optional): Prefix for layer names in the pipeline. - Defaults to empty string. - - Returns: - None: The layer description is appended to the input layers list. - """ - layers.append({"layer": layer_desc, "name_prefix": name_prefix}) - - def get_sequential_layers(self): - """ - Get all layers in the sequential network. - - Returns: - List[paddle.nn.Layer]: List containing all layers. - """ - return [x["layer"] for x in self._sequential_layers] - - def get_sequential_name_prefixes(self): - """ - Retrieve name prefixes for all parallel layers in the sequential network. - - Returns: - Dict[str, str]: A dictionary mapping layer indices (as strings) to their - corresponding name prefixes. The indices represent the position of - each layer in the sequential order. - """ - return { - str(index): x["name_prefix"] - for index, x in enumerate(self._sequential_layers) - } - - def get_shardlayer_prefix(self, name_splited): - """_summary_ - This function retrieves the prefix of a shared layer. The process involves: - 1. 
Identifying all key names of shared layers, like 'shared_weight01', 'shared_weight02', etc. - 2. For instance, given name_splited = ['shared_layers', 'shared_weight01', 'weight'], - the 'shared_layer_key' would be name_splited[1], which is 'shared_weight01'. - 3. By traversing through all layers, the function checks if the specified - shared_layer is present in the current stage. If found, it returns the corresponding prefix. - - Note: For retrieving all SharedLayer instances in Paddle, you can refer to the following Paddle code. - https://github.com/PaddlePaddle/Paddle/blob/2cf724d055679a1a0e48766dfb1708b920273078/python/paddle/distributed/fleet/meta_parallel/parallel_layers/pp_layers.py#L460-L513 - Args: - name_splited (_type_): _description_ - - Returns: - _type_: _description_ - """ - shared_layer_names = { - s.layer_name for s in self.layers if isinstance(s, SharedLayerDesc) - } - assert name_splited[1] in shared_layer_names, ( - f"The shared layer name {name_splited[1]} must be in prefixes!" - ) - shared_layer_key = name_splited[1] - for idx, layer in enumerate(self.layers): - if ( - isinstance(layer, SharedLayerDesc) - and layer.layer_name == shared_layer_key - ): - if self.get_stage_from_index(idx) == self._stage_id: - return self.get_sequential_name_prefixes()[str(idx)] - - # the prefix must be in the current stage, else raise error - raise ValueError( - f"The shared layer {shared_layer_key} must be in the current stage!" - ) - - def _set_pipeline_name_mapping(self, mappings=None): - """ - Set the name mapping for pipeline. - - Args: - mappings (dict, optional): Dictionary storing name mapping relationships. Default is None, meaning no mapping operation. - - Returns: - dict: Returns the updated or existing mapping relationship. - - """ - if mappings is not None: - self._pipeline_name_mapping = mappings - else: - single_to_pp_mapping = {} - pp_to_single_mapping = {} - - state_dict_keys = list(super().state_dict().keys()) - first_key = "" - for k in state_dict_keys: - if "shared_layers" not in k: - first_key = k - break - first_key = first_key.split(".") - # if use virtual pp_degree, the prefix is like 0.0.xxx - # else it will be like 0.xxx - use_virtual_pp_degree = ( - first_key[0].isdigit() and first_key[1].isdigit() - ) - - prefixes = self.get_sequential_name_prefixes() - for k in state_dict_keys: - name_splited = k.split(".") - if use_virtual_pp_degree: - if name_splited[0].isdigit(): - if name_splited[1].isdigit(): - idx = str( - int(name_splited[0]) + int(name_splited[1]) - ) - single_name = [prefixes[idx]] - single_name.extend(name_splited[2:]) - else: - single_name = [prefixes[str(len(prefixes) - 1)]] - single_name.extend(name_splited[2:]) - logger.warning( - f"Please check! 
we treat this key as last layer, get {k}, \ - set origin name as {'.'.join(single_name)}" - ) - elif name_splited[0] == "shared_layers": - single_name = [self.get_shardlayer_prefix(name_splited)] - single_name.extend(name_splited[2:]) - else: - single_to_pp_mapping[k] = k - pp_to_single_mapping[k] = k - continue - else: - idx = name_splited[0] - # for normal pp layer - if idx.isdigit(): - # allow empty prefix - single_name = ( - [] if prefixes[idx] == "" else [prefixes[idx]] - ) - single_name.extend(name_splited[1:]) - elif idx == "shared_layers": - single_name = [self.get_shardlayer_prefix(name_splited)] - single_name.extend(name_splited[2:]) - else: - single_to_pp_mapping[k] = k - pp_to_single_mapping[k] = k - continue - - single_to_pp_mapping[".".join(single_name)] = k - pp_to_single_mapping[k] = ".".join(single_name) - - self._pipeline_name_mapping = single_to_pp_mapping - self._pp_to_single_mapping = pp_to_single_mapping - - return self._pipeline_name_mapping - - def state_dict(self, *args, **kwargs): - """ - Return a dictionary with Pipeline Stage mapping. - Args: - *args (tuple): Variable argument list passed to parent method. - **kwargs (dict): Optional keyword arguments passed to parent method. - Returns: - dict: Dictionary containing Pipeline Stage mapping. - """ - state_dict = super().state_dict(*args, **kwargs) - - model_type = getattr(self.config, "model_type", "") - if "qwen3_vl" in model_type or "qwen3_5" in model_type: - name_prefix = "model.language_model." - else: - name_prefix = "" - if self._pipeline_name_mapping is None: - self._set_pipeline_name_mapping() - # assert len(self._pipeline_name_mapping) > 0, "The pipeline stage must have parameters!" - for k in list(state_dict.keys()): - v = state_dict.pop(k) - if name_prefix and k.startswith(name_prefix): - k = k[len(name_prefix) :] - if k not in self._pp_to_single_mapping: - state_dict[k] = v - continue - v.key = self._pp_to_single_mapping[k] - state_dict[self._pp_to_single_mapping[k]] = v - return state_dict - - def set_state_dict(self, state_dict, *args, **kwargs): - if self._pipeline_name_mapping is None: - self._set_pipeline_name_mapping() - assert len(self._pipeline_name_mapping) > 0, ( - "The pipeline stage must have parameters!" - ) - - for k in list(state_dict.keys()): - v = state_dict.pop(k) - if k not in self._pipeline_name_mapping: - continue - state_dict[self._pipeline_name_mapping[k]] = v - - ret = super().set_state_dict(state_dict, *args, **kwargs) - return ret - - def _check_shared_model_state(self): - if self._pipeline_name_mapping is None: - self._set_pipeline_name_mapping() - - super_state_dict = super().state_dict() - structure_name_to_tensor = {} - for k, v in super_state_dict.items(): - k = self._pp_to_single_mapping[k] - if k not in structure_name_to_tensor: - structure_name_to_tensor[k] = v - else: - old_v = structure_name_to_tensor[k] - assert old_v is v, ( - f"Shared tensor with different structure name: {k}" - ) - - missing_shared_keys = {} - for k, v in self._pp_to_single_mapping.items(): - mapped_k = self._pipeline_name_mapping[v] - if k != mapped_k: - missing_shared_keys[k] = mapped_k - return missing_shared_keys - - def sharded_state_dict(self, *args, **kwargs): - """ - sharded_state_dict method for PipelinePretrainedModel. - - Remaps parameter keys according to the pipeline stage mapping, and converts expert indices from local to global. 
- """ - sharded_state_dict = super().sharded_state_dict(*args, **kwargs) - if self._pipeline_name_mapping is None: - self._set_pipeline_name_mapping() - - model_type = getattr(self.config, "model_type", "") - if "qwen3_vl" in model_type or "qwen3_5" in model_type: - name_prefix = "model.language_model." - else: - name_prefix = "" - - for k in list(sharded_state_dict.keys()): - v = sharded_state_dict.pop(k) - # remove name_prefix - if name_prefix and k.startswith(name_prefix): - k = k[len(name_prefix) :] - if k not in self._pp_to_single_mapping: - sharded_state_dict[k] = v - continue - v.key = self._pp_to_single_mapping[k] - sharded_state_dict[self._pp_to_single_mapping[k]] = v - - def increment_expert_number(s, increment): - import re - - def replace(match): - original_number = int(match.group(0)) - new_number = original_number + increment - return str(new_number) - - return re.sub(r"(?<=experts\.)\d+", replace, s) - - renamed_sharded_state_dict = {} - for k, v in sharded_state_dict.items(): - global_expert_id_offset = getattr( - v, "global_expert_id_offset", None - ) - layer_cnt = getattr(v, "layer_cnt", None) - if global_expert_id_offset is not None: - new_key = increment_expert_number(k, global_expert_id_offset) - v.key = new_key - delattr(v, "global_expert_id_offset") - renamed_sharded_state_dict[new_key] = v - elif layer_cnt is not None: - new_key = k + "_layer_" + str(layer_cnt) - v.key = new_key - delattr(v, "layer_cnt") - renamed_sharded_state_dict[new_key] = v - else: - renamed_sharded_state_dict[k] = v - - return renamed_sharded_state_dict - - def fp8_quant_weight(self, batch_mode=False, quant_transpose=True): - if self._num_virtual_pipeline_stages > 1: - for idx, chunk in enumerate(self._model_chunks): - for idx, layer in enumerate(chunk): - if isinstance(layer, TransformerLayer): - layer.fp8_quant_weight( - batch_mode=batch_mode, - quant_transpose=quant_transpose, - ) - else: - for idx, layer in enumerate(self.run_function): - if isinstance(layer, TransformerLayer): - layer.fp8_quant_weight( - batch_mode=batch_mode, quant_transpose=quant_transpose - ) - - def use_fp8(self): - if self._num_virtual_pipeline_stages > 1: - for idx, chunk in enumerate(self._model_chunks): - for idx, layer in enumerate(chunk): - if isinstance(layer, TransformerLayer) and layer.use_fp8(): - return True - else: - for idx, layer in enumerate(self.run_function): - if isinstance(layer, TransformerLayer) and layer.use_fp8(): - return True - return False diff --git a/src/paddlefleet/models/qwen3_vl/__init__.py b/src/paddlefleet/models/qwen3_vl/__init__.py new file mode 100644 index 000000000..736407af7 --- /dev/null +++ b/src/paddlefleet/models/qwen3_vl/__init__.py @@ -0,0 +1,28 @@ +# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+from .embedding import Qwen3VLTextEmbedding +from .qwen3_vl_model import ( + Qwen3VLModelDist, + Qwen3VLTextTransformerLayer, + Qwen3VLVisionModel, + Qwen3VLVisionTransformerLayer, +) + +__all__ = [ + "Qwen3VLTextEmbedding", + "Qwen3VLModelDist", + "Qwen3VLTextTransformerLayer", + "Qwen3VLVisionModel", + "Qwen3VLVisionTransformerLayer", +] diff --git a/src/paddlefleet/models/qwen3_vl/embedding.py b/src/paddlefleet/models/qwen3_vl/embedding.py index ebe85a567..ad8074cfc 100644 --- a/src/paddlefleet/models/qwen3_vl/embedding.py +++ b/src/paddlefleet/models/qwen3_vl/embedding.py @@ -15,12 +15,27 @@ import paddle from paddle import nn +from paddle.distributed.fleet.utils.sequence_parallel_utils import ( + ScatterOp, +) from paddle.nn import functional as F from ...packed_seq_params import PackedSeqParams from ...spec_utils import LayerSpec, build_layer +from ...tensor_parallel.mappings import ( + scatter_to_sequence_parallel_region, +) from ...transformer import TransformerConfig from ...transformer.layer import FleetLayer +from ..gpt.gpt_embedding import GPTEmbedding + + +def safe_repeat_interleave_values(values, repeats): + max_repeats = paddle.max(repeats) + mask = paddle.arange(max_repeats).unsqueeze(0) < repeats.unsqueeze(1) + expanded_values = values.unsqueeze(1).expand([values.shape[0], max_repeats]) + result = paddle.masked_select(expanded_values, mask) + return result @dataclass @@ -28,6 +43,22 @@ class VisionEmbeddingSpec: rope_embedding: LayerSpec = None +class VisionRotaryEmbedding(nn.Layer): + inv_freq: paddle.Tensor + + def __init__(self, dim: int, theta: float = 10000.0) -> None: + super().__init__() + inv_freq = 1.0 / ( + theta ** (paddle.arange(0, dim, 2, dtype=paddle.float32) / dim) + ) + self.register_buffer("inv_freq", inv_freq, persistable=False) + + def forward(self, seqlen: int) -> paddle.Tensor: + seq = paddle.arange(seqlen, dtype=self.inv_freq.dtype) + freqs = paddle.outer(seq, self.inv_freq) + return freqs + + class VisionEmbedding(FleetLayer): def __init__( self, @@ -68,193 +99,242 @@ def __init__( sublayers_spec.rope_embedding, ) - def rot_pos_emb(self, grid_thw): - pos_ids = [] - for t, h, w in grid_thw: - t, h, w = int(t), int(h), int(w) - hpos_ids = paddle.arange(h).unsqueeze(1).expand([-1, w]) - hpos_ids = hpos_ids.reshape( - [ - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ] - ) - hpos_ids = hpos_ids.transpose(perm=[0, 2, 1, 3]) - hpos_ids = hpos_ids.flatten() - - wpos_ids = paddle.arange(w).unsqueeze(0).expand([h, -1]) - wpos_ids = wpos_ids.reshape( - [ - h // self.spatial_merge_size, - self.spatial_merge_size, - w // self.spatial_merge_size, - self.spatial_merge_size, - ] - ) - wpos_ids = wpos_ids.transpose([0, 2, 1, 3]) - wpos_ids = wpos_ids.flatten() - pos_ids.append( - paddle.stack(x=[hpos_ids, wpos_ids], axis=-1).tile( - repeat_times=[t, 1] - ) - ) - pos_ids = paddle.concat(x=pos_ids, axis=0) - max_grid_size = int(grid_thw[:, 1:].max()) - # Get raw freqs [max_grid_size, head_dim//2] and index with 2D pos_ids - freqs = self.rotary_pos_emb.get_freqs_non_repeated(max_grid_size) - # pos_ids: [seq_len, 2], freqs: [max_grid_size, head_dim//2] - # Index freqs with each position dim: freqs[pos_ids] -> [seq_len, 2, head_dim//2] - rotary_pos_emb = freqs[pos_ids].flatten(start_axis=1) - # rotary_pos_emb: [seq_len, head_dim] (2 * head_dim//2) - # Duplicate to match expected format [seq_len, 1, 1, head_dim] - rotary_pos_emb = rotary_pos_emb[None, :, None, :] - return rotary_pos_emb - - def 
fast_pos_embed_interpolate(self, grid_thw): - grid_ts, grid_hs, grid_ws = ( - grid_thw[:, 0], - grid_thw[:, 1], - grid_thw[:, 2], + def _build_token_image_mapping(self, grid_thw): + """Build token-to-image mapping, shared by rot_pos_emb and fast_pos_embed_interpolate""" + heights = grid_thw[:, 1] + widths = grid_thw[:, 2] + frames = grid_thw[:, 0] + + num_tokens = frames * heights * widths # [N] + + total_tokens = num_tokens.sum().item() # 1 D2H + max_hw = paddle.max(paddle.maximum(heights, widths)).item() # 1 D2H + + # token-to-image mapping: image_id[j] = i, where cu_tokens[i] <= j < cu_tokens[i+1] + cu_tokens = paddle.concat( + [paddle.zeros([1], dtype="int64"), num_tokens.cumsum(0)] ) - device = paddle.get_device() + global_idx = paddle.arange(total_tokens, dtype="int64") + image_id = ( + global_idx.unsqueeze(-1) >= cu_tokens[:-1].unsqueeze(0) + ).astype("int64").sum(-1) - 1 - idx_list = [[] for _ in range(4)] - weight_list = [[] for _ in range(4)] + local_idx = global_idx - cu_tokens[image_id] - for t, h, w in zip(grid_ts, grid_hs, grid_ws): - t, h, w = int(t), int(h), int(w) - h_idxs = paddle.linspace(0, self.num_grid_per_side - 1, h) - w_idxs = paddle.linspace(0, self.num_grid_per_side - 1, w) + # frame-local index + token_hw = (heights * widths)[image_id] + frame_local_idx = local_idx % token_hw - h_idxs_floor = h_idxs.int() - w_idxs_floor = w_idxs.int() - h_idxs_ceil = (h_idxs.int() + 1).clip( - max=self.num_grid_per_side - 1 - ) - w_idxs_ceil = (w_idxs.int() + 1).clip( - max=self.num_grid_per_side - 1 + return image_id, frame_local_idx, total_tokens, max_hw + + def rot_pos_emb( + self, + grid_thw, + image_id=None, + frame_local_idx=None, + total_tokens=None, + max_hw=None, + ): + m = self.spatial_merge_size + widths = grid_thw[:, 2] + merged_w = widths // m + + if image_id is None: + image_id, frame_local_idx, total_tokens, max_hw = ( + self._build_token_image_mapping(grid_thw) ) - dh = h_idxs - h_idxs_floor.astype("float32") - dw = w_idxs - w_idxs_floor.astype("float32") + freq_table = self.rotary_pos_emb(max_hw) - base_h = h_idxs_floor * self.num_grid_per_side - base_h_ceil = h_idxs_ceil * self.num_grid_per_side + token_mw = merged_w[image_id] # [total_tokens] - indices = [ - (base_h[None].T + w_idxs_floor[None]).flatten(), - (base_h[None].T + w_idxs_ceil[None]).flatten(), - (base_h_ceil[None].T + w_idxs_floor[None]).flatten(), - (base_h_ceil[None].T + w_idxs_ceil[None]).flatten(), - ] + # Decompose linear index to coordinates: layout [merged_h, merged_w, m, m] + mm = m * m + mw_mm = token_mw * mm + block_row = frame_local_idx // mw_mm + r1 = frame_local_idx % mw_mm + block_col = r1 // mm + r2 = r1 % mm + intra_row = r2 // m + intra_col = r2 % m + + row_idx = block_row * m + intra_row + col_idx = block_col * m + intra_col - weights = [ - ((1 - dh)[None].T * (1 - dw)[None]).flatten(), - ((1 - dh)[None].T * dw[None]).flatten(), - (dh[None].T * (1 - dw)[None]).flatten(), - (dh[None].T * dw[None]).flatten(), + pos_ids = paddle.stack([row_idx, col_idx], axis=-1) # [total_tokens, 2] + + embeddings = freq_table[pos_ids] + embeddings = embeddings.flatten(start_axis=1) + return embeddings + + def fast_pos_embed_interpolate( + self, + grid_thw, + image_id=None, + frame_local_idx=None, + total_tokens=None, + max_hw=None, + ): + N = self.num_grid_per_side + m = self.spatial_merge_size + heights = grid_thw[:, 1] + widths = grid_thw[:, 2] + merged_w = widths // m + + if image_id is None: + image_id, frame_local_idx, total_tokens, max_hw = ( + self._build_token_image_mapping(grid_thw) + ) + + 
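+        # merged_w = widths // m; indexing it with image_id gives every token the
+        # merged width of its own image, which the decomposition below needs in
+        # order to invert the [merged_h, merged_w, m, m] token ordering.
+        # Illustrative shape check (assumption: a single 1x4x4 grid):
+        #   grid_thw = paddle.to_tensor([[1, 4, 4]])
+        #   image_id, frame_local_idx, total_tokens, max_hw = (
+        #       self._build_token_image_mapping(grid_thw)
+        #   )
+        #   -> total_tokens == 16, max_hw == 4, frame_local_idx == arange(16)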
token_mw = merged_w[image_id] + + # Decompose linear index to coordinates (same layout as rot_pos_emb) + mm = m * m + mw_mm = token_mw * mm + block_row = frame_local_idx // mw_mm + r1 = frame_local_idx % mw_mm + block_col = r1 // mm + r2 = r1 % mm + intra_row = r2 // m + intra_col = r2 % m + + # Pixel coordinates + j_h = (block_row * m + intra_row).astype("float32") + j_w = (block_col * m + intra_col).astype("float32") + + # Bilinear interpolation: h_idx = j_h * (N-1) / (h-1) + token_h = heights[image_id].astype("float32") + token_w = widths[image_id].astype("float32") + h_denom = (token_h - 1).clip(min=1.0) + w_denom = (token_w - 1).clip(min=1.0) + h_idx = j_h * (N - 1) / h_denom + w_idx = j_w * (N - 1) / w_denom + + h_floor = h_idx.astype("int32") + w_floor = w_idx.astype("int32") + h_ceil = (h_floor + 1).clip(max=N - 1) + w_ceil = (w_floor + 1).clip(max=N - 1) + + dh = h_idx - h_floor.astype("float32") + dw = w_idx - w_floor.astype("float32") + + base_h = h_floor * N + base_h_ceil = h_ceil * N + + idx_tensor = paddle.stack( + [ + (base_h + w_floor).astype("int64"), + (base_h + w_ceil).astype("int64"), + (base_h_ceil + w_floor).astype("int64"), + (base_h_ceil + w_ceil).astype("int64"), ] + ) # [4, total_tokens] - for i in range(4): - idx_list[i].extend(indices[i].tolist()) - weight_list[i].extend(weights[i].tolist()) + weight_tensor = paddle.stack( + [(1 - dh) * (1 - dw), (1 - dh) * dw, dh * (1 - dw), dh * dw] + ).astype(self.pos_embed.weight.dtype) # [4, total_tokens] - idx_tensor = paddle.to_tensor(idx_list, dtype="int64") - weight_tensor = paddle.to_tensor( - weight_list, dtype=self.pos_embed.weight.dtype - ) pos_embeds = self.pos_embed(idx_tensor) * weight_tensor[:, :, None] patch_pos_embeds = ( pos_embeds[0] + pos_embeds[1] + pos_embeds[2] + pos_embeds[3] ) - - patch_pos_embeds = paddle.split( - patch_pos_embeds, - [int(h) * int(w) for h, w in zip(grid_hs, grid_ws)], - ) - - patch_pos_embeds_permute = [] - merge_size = self.spatial_merge_size - for pos_embed, t, h, w in zip( - patch_pos_embeds, grid_ts, grid_hs, grid_ws - ): - pos_embed = pos_embed.tile([int(t), 1]) - pos_embed = ( - pos_embed.reshape( - [ - int(t), - int(h) // merge_size, - merge_size, - int(w) // merge_size, - merge_size, - -1, - ] - ) - .transpose([0, 1, 3, 2, 4, 5]) - .flatten(0, 4) - ) - patch_pos_embeds_permute.append(pos_embed) - patch_pos_embeds = paddle.concat(patch_pos_embeds_permute) + # Already in (block_h, block_w, intra_h, intra_w) order, no merge_reshape needed return patch_pos_embeds def get_packed_seq_params( self, grid_thw: paddle.Tensor, ): - seqlens = paddle.repeat_interleave( + seqlens = safe_repeat_interleave_values( grid_thw[:, 1] * grid_thw[:, 2], grid_thw[:, 0] ) - cu_seqlens = seqlens.cumsum(axis=0).astype("int32") - cu_seqlens = F.pad(cu_seqlens.unsqueeze(0), [1, 0], value=0).squeeze(0) + cu_seqlens = seqlens.cumsum(axis=0, dtype=paddle.int32) + cu_seqlens = F.pad(cu_seqlens, (1, 0), value=0).contiguous() + cu_seqlens = cu_seqlens.squeeze().contiguous() max_seqlen = seqlens.max().item() - total_seqlen = cu_seqlens[-1].item() return PackedSeqParams( cu_seqlens_q=cu_seqlens, cu_seqlens_kv=cu_seqlens, max_seqlen_q=max_seqlen, max_seqlen_kv=max_seqlen, - total_seqlen_q=total_seqlen, - total_seqlen_kv=total_seqlen, qkv_format="thd", ) - def forward(self, dict_args: dict): + def forward(self, dict_args: dict) -> paddle.Tensor: pixel_values = dict_args["pixel_values"] grid_thw = dict_args["grid_thw"] - pixel_values = pixel_values.reshape( - [ - -1, - self.in_channels, - self.temporal_patch_size, 
- self.patch_size, - self.patch_size, - ] + # Pathed embedding + hidden_states = pixel_values.reshape( + -1, + self.in_channels, + self.temporal_patch_size, + self.patch_size, + self.patch_size, ) + hidden_states = self.patch_embed(hidden_states).view(-1, self.embed_dim) - hidden_states = ( - self.patch_embed(pixel_values) - .flatten(2) - .transpose([0, 2, 1]) - .reshape([-1, self.embed_dim]) + # Share token-to-image mapping to avoid redundant computation + image_id, frame_local_idx, total_tokens, max_hw = ( + self._build_token_image_mapping(grid_thw) + ) + + pos_embeds = self.fast_pos_embed_interpolate( + grid_thw, + image_id=image_id, + frame_local_idx=frame_local_idx, + total_tokens=total_tokens, + max_hw=max_hw, ) - pos_embeds = self.fast_pos_embed_interpolate(grid_thw) hidden_states = hidden_states + pos_embeds - seq_len, _ = hidden_states.shape + seq_len, _ = hidden_states.size() hidden_states = hidden_states.reshape([seq_len, -1]) hidden_states = hidden_states.unsqueeze(0) - rotary_pos_emb = self.rot_pos_emb(grid_thw) - rotary_pos_cos = paddle.cos(rotary_pos_emb) - rotary_pos_sin = paddle.sin(rotary_pos_emb) + rotary_pos_emb = self.rot_pos_emb( + grid_thw, + image_id=image_id, + frame_local_idx=frame_local_idx, + total_tokens=total_tokens, + max_hw=max_hw, + ) + rotary_pos_emb = rotary_pos_emb.reshape(seq_len, -1) + rotary_pos_emb = paddle.cat((rotary_pos_emb, rotary_pos_emb), axis=-1) + # Cast freqs to float32 and compute cos/sin inside auto_cast(False) to match the + # precision of _apply_rotary_pos_emb_bshd_fp32, which computes cos/sin on the same + # bf16 freqs but under auto_cast(False) using a float32 kernel. + with paddle.amp.auto_cast(False): + _freqs_f32 = rotary_pos_emb.astype("float32") + rotary_pos_cos = paddle.cos(_freqs_f32) + rotary_pos_sin = paddle.sin(_freqs_f32) + rotary_pos_emb = rotary_pos_emb[:, None, None, :] + rotary_pos_emb = rotary_pos_emb.transpose([1, 0, 2, 3]) packed_seq_params = self.get_packed_seq_params(grid_thw) + # Pre-compute attn_mask_startend_row_indices once for all ViT layers + cu_seqlens = packed_seq_params.cu_seqlens_kv + lengths = cu_seqlens[1:] - cu_seqlens[:-1] + indices_per_segment = paddle.stack( + [ + cu_seqlens[1:], # col 0: lower_start = end_i + paddle.full_like( + cu_seqlens[1:], seq_len + ), # col 1: lower_end = total_seq + paddle.zeros_like(cu_seqlens[:-1]), # col 2: upper_start = 0 + cu_seqlens[:-1], # col 3: upper_end = start_i + ], + axis=1, + ) # [num_segments, 4] + attn_mask_startend_row_indices = ( + paddle.repeat_interleave(indices_per_segment, lengths, axis=0) + .unsqueeze(0) + .unsqueeze(0) + ) # [1, 1, seq_len, 4] + preproc_output = { "hidden_states": hidden_states, "attention_mask": dict_args.get("attention_mask", None), @@ -262,6 +342,267 @@ def forward(self, dict_args: dict): "rotary_pos_cos": rotary_pos_cos, "rotary_pos_sin": rotary_pos_sin, "packed_seq_params": packed_seq_params, + "attn_mask_startend_row_indices": attn_mask_startend_row_indices, } return preproc_output + + +class Qwen3VLTextEmbedding(GPTEmbedding): + def forward( + self, + dict_args: dict, + decoder_input: paddle.Tensor = None, + packed_seq_params: PackedSeqParams = None, + ): + input_ids = dict_args["input_ids"] + position_ids = dict_args.get("position_ids", None) + position_ids = ( + position_ids.to("gpu") if position_ids is not None else None + ) + attention_mask = dict_args.get("attention_mask", None) + attn_mask_startend_row_indices = dict_args.get( + "attn_mask_startend_row_indices", None + ) + deepstack_image_embeds = 
dict_args.get("deepstack_image_embeds", None) + deepstack_video_embeds = dict_args.get("deepstack_video_embeds", None) + # Deepstack + deepstack_visual_embeds = None + visual_pos_masks = None + mtp_emb_res = None + if decoder_input is None: + decoder_input = self.embedding( + input_ids=input_ids, + position_ids=None + if self.multimodal_embedding + else position_ids, + ) + if ( + self.config.num_nextn_predict_layers is not None + and self.config.num_nextn_predict_layers > 0 + ): + assert not self.multimodal_embedding, ( + "MTP not support mm for now." + ) + inputs_embeds_extra = decoder_input[ + :, -self.config.num_nextn_predict_layers :, : + ] # [B, S, H] + inputs_embeds = decoder_input[ + :, : -self.config.num_nextn_predict_layers, : + ] + inputs_embeds_ori = inputs_embeds + batch_size, seq_length, hidden_size = inputs_embeds.shape + + if self.sequence_parallel: + inputs_embeds = inputs_embeds.reshape( + [-1, inputs_embeds.shape[-1]] + ) + inputs_embeds = ScatterOp.apply(inputs_embeds) + inputs_embeds = ( + inputs_embeds.reshape([batch_size, -1, hidden_size]) + .permute(1, 0, 2) + .contiguous() + ) # change to [S, B, H] + mtp_emb_res = [inputs_embeds] + for depth in range(self.config.num_nextn_predict_layers): + inputs_embeds_mtp = paddle.concat( + [ + inputs_embeds_ori[:, (depth + 1) :, :], + inputs_embeds_extra[:, : (depth + 1), :], + ], + axis=1, + ) + if self.sequence_parallel: + inputs_embeds_mtp = inputs_embeds_mtp.reshape( + [-1, inputs_embeds_mtp.shape[-1]] + ) + inputs_embeds_mtp = ScatterOp.apply(inputs_embeds_mtp) + inputs_embeds_mtp = ( + inputs_embeds_mtp.reshape( + [batch_size, -1, hidden_size] + ) + .permute(1, 0, 2) + .contiguous() + ) # change to [S, B, H] + mtp_emb_res.append(inputs_embeds_mtp) + + if self.multimodal_embedding: + image_embeds = dict_args.get("image_embeds", None) + video_embeds = dict_args.get("video_embeds", None) + if image_embeds is not None: + image_mask, _ = self.get_placeholder_mask( + input_ids, + inputs_embeds=decoder_input, + image_features=image_embeds, + ) + # Replace masked_scatter with arithmetic blend to avoid + # IndexingBackwardKernel (sparse scatter) in the backward pass. + # image_mask : [B, S, H] bool + # image_embeds: [N_img, H] (N_img = number of image tokens) + # Expand image_embeds into the full [B, S, H] space by: + # 1. flatten decoder_input and image_mask to 1-D + # 2. use paddle.scatter (dense backward = gather) to place + # image_embeds values at the True positions + # 3. blend with original decoder_input via mask arithmetic + # + # Optimization: reuse decoder_input's flattened buffer as the + # scatter base (scaled by (1-mask)) to avoid a separate + # paddle.zeros([n_total]) allocation (~192 MB bf16 tensor). + image_mask_f = image_mask.astype( + decoder_input.dtype + ) # [B,S,H] float + flat_indices = paddle.nonzero( + image_mask.reshape([-1]) + ).squeeze( + -1 + ) # [N_img*H] int64 — dense nonzero, no scatter bwd + # Scale the base tensor by (1 - mask) in-place before scatter + # so that visual positions are zero — no extra zeros allocation. 
+ base_flat = (decoder_input * (1.0 - image_mask_f)).reshape( + [-1] + ) + image_src_flat = paddle.scatter( + base_flat, + flat_indices, + image_embeds.astype(decoder_input.dtype).reshape([-1]), + ) # scatter bwd is a simple gather — no sparse atomics + decoder_input = image_src_flat.reshape(decoder_input.shape) + visual_pos_masks = image_mask[..., 0] + deepstack_visual_embeds = deepstack_image_embeds + + if video_embeds is not None: + _, video_mask = self.get_placeholder_mask( + input_ids, + inputs_embeds=decoder_input, + video_features=video_embeds, + ) + video_mask_f = video_mask.astype(decoder_input.dtype) + flat_indices = paddle.nonzero( + video_mask.reshape([-1]) + ).squeeze(-1) + base_flat = (decoder_input * (1.0 - video_mask_f)).reshape( + [-1] + ) + video_src_flat = paddle.scatter( + base_flat, + flat_indices, + video_embeds.astype(decoder_input.dtype).reshape([-1]), + ) + decoder_input = video_src_flat.reshape(decoder_input.shape) + visual_pos_masks = video_mask[..., 0] + deepstack_visual_embeds = deepstack_video_embeds + + if image_embeds is not None and video_embeds is not None: + image_mask = image_mask[..., 0] # [B, S] bool + video_mask = video_mask[..., 0] # [B, S] bool + visual_pos_masks = image_mask | video_mask + deepstack_visual_embeds = [] + for img_embed, vid_embed in zip( + deepstack_image_embeds, deepstack_video_embeds + ): + # Build embed_joint [N_visual, H] without boolean-index + # scatter. Use dense mask arithmetic instead. + # img_embed : [N_img, H] + # vid_embed : [N_vid, H] + # visual_pos_masks: [B, S] bool, N_visual True entries + # img_mask_in_visual[i] = True iff visual position i is image + # Computed as: image_mask flattened, keep only visual positions, + # expressed as a dense [N_visual] float mask — no indexing. + h = img_embed.shape[-1] + n_visual = int(visual_pos_masks.sum()) + # visual_pos_flat: [B*S] bool + visual_pos_flat = visual_pos_masks.reshape([-1]) + image_mask_flat = image_mask.reshape([-1]) # [B*S] bool + video_mask_flat = video_mask.reshape([-1]) # [B*S] bool + # Dense [B*S] float masks, then compress to [N_visual] via + # paddle.masked_select (forward: gather, backward: scatter_add + # — but scalar backward is efficient, no sparse atomics) + img_mask_in_vis_f = paddle.masked_select( + image_mask_flat.astype(img_embed.dtype), + visual_pos_flat, + ).unsqueeze(-1) # [N_visual, 1] + vid_mask_in_vis_f = paddle.masked_select( + video_mask_flat.astype(vid_embed.dtype), + visual_pos_flat, + ).unsqueeze(-1) # [N_visual, 1] + embed_joint = ( + img_embed.reshape([n_visual, h]) * img_mask_in_vis_f + + vid_embed.reshape([n_visual, h]) + * vid_mask_in_vis_f + ) + deepstack_visual_embeds.append(embed_joint) + # Scatter decoder_input to SP format [S/tp, B, H] after multimodal + # token replacement, since LanguageModelEmbedding's internal scatter + # was disabled to allow image/video embedding insertion first. 
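+            # Layout: [B, S, H] -> transpose -> [S, B, H] -> scatter along the
+            # sequence dim -> [S/tp, B, H].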
+ if self.sequence_parallel: + decoder_input = decoder_input.transpose( + [1, 0, 2] + ).contiguous() + decoder_input = scatter_to_sequence_parallel_region( + decoder_input, group=self.embedding.tp_group + ) + if self.config.clone_scatter_output_in_embedding: + decoder_input = decoder_input.clone() + # Rotary positional embeddings (embedding is None for PP intermediate devices) + rotary_pos_emb = None + rotary_pos_cos = None + rotary_pos_sin = None + + if ( + self.position_embedding_type == "rope" + and self.rotary_pos_emb is not None + ): + rope_base = decoder_input if mtp_emb_res is None else mtp_emb_res[0] + rotary_seq_len = self.rotary_pos_emb.get_rotary_seq_len( + rope_base, self.config, packed_seq_params + ) + rotary_pos_emb = self.rotary_pos_emb( + rotary_seq_len, + packed_seq=packed_seq_params is not None + and packed_seq_params.qkv_format == "thd", + position_ids=position_ids, + ) + elif ( + self.position_embedding_type == "mrope" + and self.rotary_pos_emb is not None + ): + rotary_pos_emb = self.rotary_pos_emb( + position_ids, self.mrope_section + ) + + if rotary_pos_emb is not None: + if self.config.apply_rope_fusion: + rotary_pos_cos = paddle.cos(rotary_pos_emb) + rotary_pos_sin = paddle.sin(rotary_pos_emb) + if self.config.sequence_parallel: + rotary_pos_emb = rotary_pos_emb.transpose( + [1, 0, 2, 3] + ).contiguous() + + preproc_output = { + "hidden_states": decoder_input, + "attention_mask": attention_mask, + "attn_mask_startend_row_indices": attn_mask_startend_row_indices, + "rotary_pos_emb": rotary_pos_emb, + "rotary_pos_cos": rotary_pos_cos, + "rotary_pos_sin": rotary_pos_sin, + "position_ids": position_ids, + "deepstack_visual_emb": deepstack_visual_embeds, + "visual_pos_masks": visual_pos_masks, + } + if mtp_emb_res is not None: + assert ( + self.config.num_nextn_predict_layers is not None + and self.config.num_nextn_predict_layers > 0 + ) + assert len(mtp_emb_res) == self.config.num_nextn_predict_layers + 1 + hidden_states_concat = paddle.concat(mtp_emb_res) + preproc_output["hidden_states"] = hidden_states_concat + + for key in list(preproc_output.keys()): + if preproc_output[key] is None: + preproc_output.pop(key) + return preproc_output + + +__all__ = ["Qwen3VLTextEmbedding"] diff --git a/src/paddlefleet/models/qwen3_vl/layer_specs.py b/src/paddlefleet/models/qwen3_vl/layer_specs.py index 362b0d32f..dfd210da4 100644 --- a/src/paddlefleet/models/qwen3_vl/layer_specs.py +++ b/src/paddlefleet/models/qwen3_vl/layer_specs.py @@ -16,12 +16,16 @@ from ...fusions.fused_bias_dropout import get_bias_dropout_add from ...spec_utils import LayerSpec from ...transformer.attention import SelfAttention, SelfAttentionSublayersSpec +from ...transformer.enums import AttnMaskType from ...transformer.identity_op import IdentityOp from ...transformer.transformer_config import TransformerConfig from ..backends import LocalSpecProvider -from ..common.embeddings.rotary_pos_embedding import RotaryEmbedding from ..gpt.gpt_layer_specs import get_mlp_layer_spec_for_backend -from .embedding import VisionEmbedding, VisionEmbeddingSpec +from .embedding import ( + VisionEmbedding, + VisionEmbeddingSpec, + VisionRotaryEmbedding, +) from .patch_merger import Qwen3VLVisionPatchMergerSpec, Qwen3VLVisionPathMerger from .qwen3_vl_model import ( Qwen3VLVisionModel, @@ -38,8 +42,12 @@ def get_qwen3_vl_vision_layer_local_spec( append_deepstack: bool = False, ) -> LayerSpec: backend = LocalSpecProvider() - layer_norm = backend.layer_norm(rms_norm=False, for_qk=False) - qk_norm = 
backend.layer_norm(rms_norm=False, for_qk=True) + layer_norm = backend.layer_norm( + rms_norm=False, for_qk=False, fused=False, eps=config.rms_norm_eps + ) + qk_norm = backend.layer_norm( + rms_norm=False, for_qk=True, fused=False, eps=config.rms_norm_eps + ) mlp = get_mlp_layer_spec_for_backend( backend=backend, ) @@ -47,8 +55,8 @@ def get_qwen3_vl_vision_layer_local_spec( merger_spec = LayerSpec( layer=Qwen3VLVisionPathMerger, sublayers_spec=Qwen3VLVisionPatchMergerSpec( - backend.layer_norm( - rms_norm=(config.normalization == "RMSNorm"), for_qk=False + norm=backend.layer_norm( + rms_norm=False, for_qk=False, fused=False, eps=1e-6 ) ), extra_kwargs={"config": config, "use_postshuffle_norm": True}, @@ -66,6 +74,7 @@ def get_qwen3_vl_vision_layer_local_spec( q_norm=qk_norm if use_qk_norm else IdentityOp, k_norm=qk_norm if use_qk_norm else IdentityOp, ), + extra_kwargs={"attn_mask_type": AttnMaskType.no_mask}, ), self_attn_bda=get_bias_dropout_add, post_attention_layernorm=layer_norm, @@ -97,8 +106,8 @@ def get_qwen3vl_vision_encoder_layers_spec( use_qk_norm=config.use_qk_norm, ) layer_specs = [] - append_deepstack = False for layer_number in range(config.num_hidden_layers): + append_deepstack = False real_layer_number = layer_number + config.num_empty_layers_add_in_head if layer_number in config.deepstack_visual_indexes: append_deepstack = True @@ -117,31 +126,26 @@ def get_qwen3_vl_vision_spec( transformer_layers_spec: list[LayerSpec], head_empty_layers_spec: list[LayerSpec] | None = None, tail_empty_layer_spec: list[LayerSpec] | None = None, - rotary_percent: float = 1.0, - rotary_base: int = 10000, - rope_scaling: bool = False, + rotary_base: int = 10000.0, ): backend = LocalSpecProvider() embedding_extra_kwargs = {"config": config} rotary_emb_extra_kwargs = { - "head_dim": config.head_dim // 2, - "rotary_base": rotary_base, - "rope_scaling": rope_scaling, - "rotary_percent": rotary_percent, + "dim": config.head_dim // 2, + "theta": rotary_base, } embedding_spec = VisionEmbeddingSpec( rope_embedding=LayerSpec( - layer=RotaryEmbedding, + layer=VisionRotaryEmbedding, extra_kwargs=rotary_emb_extra_kwargs, ) ) - merger_norm = backend.layer_norm( - rms_norm=(config.normalization == "RMSNorm"), for_qk=False - ) merger_spec = LayerSpec( layer=Qwen3VLVisionPathMerger, sublayers_spec=Qwen3VLVisionPatchMergerSpec( - norm=merger_norm, + norm=backend.layer_norm( + rms_norm=False, for_qk=False, fused=False, eps=1e-6 + ) ), extra_kwargs={ "config": config, diff --git a/src/paddlefleet/models/qwen3_vl/patch_merger.py b/src/paddlefleet/models/qwen3_vl/patch_merger.py index d3dea3922..93b442210 100644 --- a/src/paddlefleet/models/qwen3_vl/patch_merger.py +++ b/src/paddlefleet/models/qwen3_vl/patch_merger.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+from collections import OrderedDict from dataclasses import dataclass from paddle import nn @@ -65,9 +66,8 @@ def __init__( ) ) - def forward(self, x): - if isinstance(x, dict): - x = x["hidden_states"].squeeze(0) + def forward(self, dict_args): + x = dict_args["hidden_states"].squeeze(0) if self.use_postshuffle_norm: x = self.norm(x.reshape([-1, self.hidden_size])) x = x.reshape([-1, self.hidden_size]) @@ -78,4 +78,7 @@ def forward(self, x): x, output_bias = self.mlp(x) if output_bias is not None: x += output_bias - return x, None + rst = OrderedDict() + rst = {"hidden_states": x} + rst = {**dict_args, **rst} + return rst diff --git a/src/paddlefleet/models/qwen3_vl/qwen3_vl_builders.py b/src/paddlefleet/models/qwen3_vl/qwen3_vl_builders.py index 562e88ce1..efb32eedf 100644 --- a/src/paddlefleet/models/qwen3_vl/qwen3_vl_builders.py +++ b/src/paddlefleet/models/qwen3_vl/qwen3_vl_builders.py @@ -42,8 +42,6 @@ def qwen3_vl_vision_builder(config, **kwargs): transformer_layers_spec=transformer_layer_specs, tail_empty_layer_spec=tail_empty_layers_spec, rotary_base=config.rope_theta, - rotary_percent=config.rotary_percent, - rope_scaling=config.rope_scaling, ) return build_layer(res_spec, **kwargs) diff --git a/src/paddlefleet/models/qwen3_vl/qwen3_vl_model.py b/src/paddlefleet/models/qwen3_vl/qwen3_vl_model.py index 6b869bd40..df2ca841e 100644 --- a/src/paddlefleet/models/qwen3_vl/qwen3_vl_model.py +++ b/src/paddlefleet/models/qwen3_vl/qwen3_vl_model.py @@ -21,6 +21,8 @@ from ...pipeline_parallel import LayerDesc from ...process_groups_config import ProcessGroupCollection from ...spec_utils import LayerSpec, build_layer +from ...transformer.enums import ModelType +from ...transformer.layer import FleetLayer from ...transformer.transformer_config import TransformerConfig from ...transformer.transformer_encoder import TransformerEncoder from ...transformer.transformer_layer import ( @@ -29,6 +31,15 @@ ) +def get_image_sequence_length( + img_h, img_w, patch_dim, add_class_token, class_token_len +): + num_patches_per_dim_h = img_h // patch_dim + num_patches_per_dim_w = img_w // patch_dim + num_patches = num_patches_per_dim_h * num_patches_per_dim_w + return num_patches + (class_token_len if add_class_token else 0) + + @dataclass class Qwen3VLVisionSublayersSpec: """ @@ -107,6 +118,7 @@ def forward( # runners in the cuda graph manager dict_args.pop("dynamic_inference_decode_only", None) dict_args.pop("position_ids", None) + deepstack_features_list = dict_args.pop("deepstack_features_list", None) if self.full_recompute: hidden_states = dict_args["hidden_states"] attention_mask = dict_args.get("attention_mask", None) @@ -167,21 +179,168 @@ def forward( else: outputs = self._forward_impl(**dict_args) - if len(outputs) == 3: + context, deepstack_feature = None, None + hidden_states = outputs[0] + if len(outputs) > 1: + deepstack_feature = outputs[-1] + if len(outputs) == 3: + context = outputs[1] + + rst = OrderedDict() + rst = {"hidden_states": hidden_states} + if context is not None: + rst["context"] = context + if deepstack_features_list is None: + deepstack_features_list = [] + if deepstack_feature is not None: + deepstack_features_list.append(deepstack_feature) + rst["deepstack_features_list"] = deepstack_features_list + rst = {**dict_args, **rst} + return rst + + def _forward_impl( + self, + hidden_states: paddle.Tensor, + attention_mask: paddle.Tensor = None, + attn_mask_startend_row_indices: paddle.Tensor = None, + context: paddle.Tensor = None, + context_mask: paddle.Tensor = None, + 
rotary_pos_emb: paddle.Tensor = None, + rotary_pos_cos: paddle.Tensor = None, + rotary_pos_sin: paddle.Tensor = None, + attention_bias: paddle.Tensor = None, + packed_seq_params: PackedSeqParams = None, + ): + hidden_states, context = self._forward_attention( + hidden_states=hidden_states, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb, + rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + + hidden_states = self._forward_mlp(hidden_states) + + deepstack_feature = None + if self.deepstack_merger is not None: + deepstack_feature = self.deepstack_merger( + {"hidden_states": hidden_states} + )["hidden_states"] + + res = (hidden_states,) + if context is not None: + res += (context,) + if deepstack_feature is not None: + res += (deepstack_feature,) + return res + + +class Qwen3VLTextTransformerLayer(TransformerLayer): + """Qwen3VL text model for adapt deepstack process""" + + def forward( + self, + dict_args: dict, + ): + """ + Perform a forward pass through the transformer layer. + + This method calls the core computation of a transformer layer, including + self-attention, cross-attention (if applicable), and feed-forward operations. + """ + # Remove 'dynamic_inference_decode_only' from kwargs if present + # this is only used to uniquely identify decode and non-decode cuda graph + # runners in the cuda graph manager + dict_args.pop("dynamic_inference_decode_only", None) + dict_args.pop("position_ids", None) + deepstack_visual_emb = dict_args.get("deepstack_visual_emb", None) + visual_pos_masks = dict_args.get("visual_pos_masks", None) + + if self.full_recompute: + hidden_states = dict_args["hidden_states"] + attention_mask = dict_args.get("attention_mask", None) + attn_mask_startend_row_indices = dict_args.get( + "attn_mask_startend_row_indices", None + ) + context = dict_args.get("context", None) + context_mask = dict_args.get("context_mask", None) + rotary_pos_emb = dict_args.get("rotary_pos_emb", None) + rotary_pos_cos = dict_args.get("rotary_pos_cos", None) + rotary_pos_sin = dict_args.get("rotary_pos_sin", None) + attention_bias = dict_args.get("attention_bias", None) + packed_seq_params = dict_args.get("packed_seq_params", None) + + assert (rotary_pos_sin is None) == (rotary_pos_cos is None) + + if rotary_pos_cos is not None and rotary_pos_sin is not None: + rotary_pos_cos = rotary_pos_cos.clone() + rotary_pos_sin = rotary_pos_sin.clone() + if self.config.apply_rope_fusion: + rotary_pos_cos = rotary_pos_cos[0, ...] + rotary_pos_sin = rotary_pos_sin[0, ...] + if rotary_pos_cos.ndim == 2: + rotary_pos_cos = rotary_pos_cos.reshape( + [ + 1, + rotary_pos_cos.shape[0], + 1, + rotary_pos_cos.shape[1], + ] + ) + rotary_pos_sin = rotary_pos_sin.reshape( + [ + 1, + rotary_pos_sin.shape[0], + 1, + rotary_pos_sin.shape[1], + ] + ) + + outputs = recompute( + self._forward_impl, + hidden_states=hidden_states, + attention_mask=attention_mask, + attn_mask_startend_row_indices=attn_mask_startend_row_indices.clone() # Clone is necessary! + if attn_mask_startend_row_indices is not None + else None, + context=context, + context_mask=context_mask, + rotary_pos_emb=rotary_pos_emb.clone() + if rotary_pos_emb is not None + else None, # Clone is necessary! 
+ rotary_pos_cos=rotary_pos_cos, + rotary_pos_sin=rotary_pos_sin, + attention_bias=attention_bias, + packed_seq_params=packed_seq_params, + ) + else: + outputs = self._forward_impl(**dict_args) + + if isinstance(outputs, tuple): output, context = outputs[0], outputs[1] else: output, context = outputs, None - deepstack_feature = outputs[-1] + # Apply deepstack visual embedding outside of recompute to avoid issues + # with recompute not properly handling list-of-tensors (deepstack_visual_emb) + if deepstack_visual_emb and self.layer_number in range( + len(deepstack_visual_emb) + ): + output = self._deepstack_process( + hidden_states=output, + visual_embeds=deepstack_visual_emb[self.layer_number], + visual_pos_masks=visual_pos_masks, + ) rst = OrderedDict() rst = {"hidden_states": output} if context is not None: rst["context"] = context - if "deepstack_feature_lists" not in rst: - rst["deepstack_feature_lists"] = [] - if deepstack_feature is not None: - rst["deepstack_feature_lists"].append(deepstack_feature) rst = {**dict_args, **rst} return rst @@ -197,6 +356,7 @@ def _forward_impl( rotary_pos_sin: paddle.Tensor = None, attention_bias: paddle.Tensor = None, packed_seq_params: PackedSeqParams = None, + **kwargs, ): hidden_states, context = self._forward_attention( hidden_states=hidden_states, @@ -211,11 +371,613 @@ def _forward_impl( packed_seq_params=packed_seq_params, ) hidden_states = self._forward_mlp(hidden_states) + if context is not None: + return hidden_states, context + return hidden_states - deepstack_feature = None - if self.deepstack_merger is not None: - deepstack_feature = self.deepstack_merger(hidden_states) + def _deepstack_process( + self, + hidden_states: paddle.Tensor, + visual_pos_masks: paddle.Tensor, + visual_embeds: paddle.Tensor, + ): + # SP layout is [S/tp, B, H] (seq-first); transpose to [B, S/tp, H] so that + # flatten(0,1) produces batch-first [B*S/tp, H], consistent with visual_pos_masks [B, S]. + _sp_transposed = False + if ( + getattr(self.config, "sequence_parallel", False) + and hidden_states.ndim == 3 + ): + hidden_states = hidden_states.transpose( + [1, 0, 2] + ) # [S/tp,B,H] -> [B,S/tp,H] + _sp_transposed = True + # Save original_shape AFTER the SP transpose so that reshape restores the + # batch-first [B, S/tp, H] form (needed for the final back-transpose). + original_shape = hidden_states.shape + if hidden_states.ndim > 2: + hidden_states = hidden_states.flatten(start_axis=0, stop_axis=1) - if context is not None: - return hidden_states, context, deepstack_feature - return hidden_states, deepstack_feature + visual_embeds = visual_embeds.to( + hidden_states.device, hidden_states.dtype + ) + + # Sequence Parallelism (SP) row slicing. + # visual_pos_masks is [B, S] (full sequence), hidden_states is [B*S/tp, H] + # (batch-major after transpose+flatten). We must slice along the S dimension + # (dim=1) to match the batch-major layout, NOT flatten-then-chunk which + # breaks when B > 1. 
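+ # Worked example (illustrative numbers only): with B=2, S=8, tp=4, each rank
+ # holds hidden_states of shape [2, 2, H] (rows s in [2*rank, 2*rank+2) of every
+ # sample), so the mask must be sliced as visual_pos_masks[:, 2*rank:2*rank+2];
+ # chunking the flattened [16] mask instead would hand rank 0 only sample-0 positions.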
+ if visual_pos_masks.ndim > 1 and visual_pos_masks.shape[ + 1 + ] > hidden_states.shape[0] // max(visual_pos_masks.shape[0], 1): + # visual_pos_masks: [B, S], hidden_states: [B*S/tp, H] + try: + from paddle.distributed.fleet import ( + get_hybrid_communicate_group, + ) + + hcg = get_hybrid_communicate_group() + mp_rank = hcg.get_model_parallel_rank() + mp_size = hcg.get_model_parallel_world_size() + except (ImportError, AttributeError): + batch_size = visual_pos_masks.shape[0] + full_seq_len = visual_pos_masks.shape[1] + mp_size = (batch_size * full_seq_len) // hidden_states.shape[0] + mp_rank = paddle.distributed.get_rank() % mp_size + + full_seq_len = visual_pos_masks.shape[1] + chunk_s = full_seq_len // mp_size + start_s = mp_rank * chunk_s + + # Slice along S dimension: [B, S] -> [B, S/tp] + local_mask = visual_pos_masks[:, start_s : start_s + chunk_s] + batch_size = visual_pos_masks.shape[0] + + # Gather per-sample visual_embeds. + # visual_embeds is ordered as [sample0_all_vis, sample1_all_vis, ...]. + # Each rank only needs the visual tokens that fall within its local + # sequence chunk [start_s, start_s+chunk_s) for each sample. + per_sample_total = paddle.cast(visual_pos_masks, "int32").sum( + axis=1 + ) # [B] + per_sample_pre = ( + paddle.cast(visual_pos_masks[:, :start_s], "int32").sum(axis=1) + if start_s > 0 + else paddle.zeros([batch_size], dtype="int32") + ) # [B] + per_sample_local = paddle.cast(local_mask, "int32").sum( + axis=1 + ) # [B] + + gather_indices = [] + cumulative_total = 0 + for i in range(batch_size): + total_i = int(per_sample_total[i].item()) + pre_i = int(per_sample_pre[i].item()) + count_i = int(per_sample_local[i].item()) + if count_i > 0: + gather_indices.append( + paddle.arange( + cumulative_total + pre_i, + cumulative_total + pre_i + count_i, + ) + ) + cumulative_total += total_i + + if gather_indices: + gather_indices = paddle.concat(gather_indices) + visual_embeds = visual_embeds[gather_indices] + else: + visual_embeds = visual_embeds[:0] # empty + + # Flatten local mask to [B*S/tp] matching hidden_states batch-major layout + visual_pos_masks = local_mask.flatten() + elif visual_pos_masks.ndim > 1: + visual_pos_masks = visual_pos_masks.flatten() + + # If TP is enabled, hidden_states has shape [..., Hidden_Dim / TP_Size], + # but visual_embeds usually has full [Hidden_Dim]. We need to slice visual_embeds column-wise. + if hidden_states.shape[-1] != visual_embeds.shape[-1]: + try: + from paddle.distributed.fleet import ( + get_hybrid_communicate_group, + ) + + hcg = get_hybrid_communicate_group() + tp_rank = hcg.get_model_parallel_rank() + tp_size = hcg.get_model_parallel_world_size() + except (ImportError, AttributeError): + # Fallback simple estimation + tp_size = visual_embeds.shape[-1] // hidden_states.shape[-1] + tp_rank = paddle.distributed.get_rank() % tp_size + + if tp_size > 1: + embed_dim = visual_embeds.shape[-1] + slice_width = embed_dim // tp_size + start_col = tp_rank * slice_width + end_col = start_col + slice_width + visual_embeds = visual_embeds[:, start_col:end_col] + + hidden_states = hidden_states.clone() + update_indices = paddle.nonzero(visual_pos_masks) + # Under SP, visual tokens are unevenly distributed across ranks. After row-slicing + # visual_pos_masks and visual_embeds to the local sequence chunk, some ranks may + # have zero visual tokens (local_visual_count == 0), producing visual_embeds with + # shape [0, H]. 
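+ # (For example, a single small image whose tokens all fall inside rank 0's
+ # sequence chunk leaves every other rank with an empty [0, H] visual_embeds.)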
Guard against passing an empty updates tensor to scatter_nd_add, + # whose behavior is undefined / backend-dependent in that case. + if visual_embeds.shape[0] > 0: + hidden_states = paddle.scatter_nd_add( + hidden_states, update_indices, visual_embeds + ) + + # [Supplement 3] Restore original shape [B*S, D] -> [B, S, D] if necessary + if len(original_shape) > 2: + hidden_states = hidden_states.reshape(original_shape) + if _sp_transposed: + hidden_states = hidden_states.transpose( + [1, 0, 2] + ) # [B,S/tp,H] -> [S/tp,B,H] + + return hidden_states + + +class Qwen3VLModelDist(FleetLayer): + """Qwen3VL Model Base Model Class.""" + + def __init__( + self, + config: TransformerConfig, + pre_process: bool = True, + post_process: bool = True, + add_encoder: bool = True, + add_decoder: bool = True, + drop_vision_class_token: bool = False, + vp_stage: int | None = None, + model_version: str | None = None, + criterion=False, + ) -> None: + super().__init__(config=config) + + language_transformer_config = config.text_config + vision_transformer_config = config.vision_config + self.model_version = ( + vision_transformer_config.model_version + if model_version is None + else model_version + ) + self._language_max_sequence_length = ( + language_transformer_config.max_sequence_length + ) + assert self.model_version is not None + + self.config = config + self.pre_process = pre_process + self.post_process = post_process + self.add_encoder = add_encoder + self.add_decoder = add_decoder + self.vp_stage = vp_stage + + self.encoder_hidden_state = None + self.vision_model = None + self.language_model = None + self.image_token_index = config.image_token_id + self.video_token_index = config.video_token_id + + self.sequence_parallel_lm = ( + language_transformer_config.sequence_parallel + ) + self.tp_comm_overlap_lm = language_transformer_config.tp_comm_overlap + self.context_parallel_lm = ( + language_transformer_config.context_parallel_size + ) + assert not (self.context_parallel_lm > 1), ( + f"qwenvl donnot support context parallel {self.context_parallel_lm}" + ) + self.share_embeddings_and_output_weights = False + self.rope_deltas = None + + if self.add_decoder: + self.language_model = language_transformer_config.provide( + pre_process=pre_process, + post_process=post_process, + vp_stage=vp_stage, + ) + self._language_is_pipeline_parallel = ( + language_transformer_config.pipeline_model_parallel_size > 1 + ) + + if self.add_encoder: + self.vision_model = vision_transformer_config.provide() + self._drop_vision_class_token = drop_vision_class_token + + self.model_type = ModelType.encoder_or_decoder + + self._img_seq_len = get_image_sequence_length( + img_h=vision_transformer_config.img_h, + img_w=vision_transformer_config.img_w, + patch_dim=vision_transformer_config.patch_size, + add_class_token=not drop_vision_class_token, + class_token_len=vision_transformer_config.class_token_len, + ) + self.criterion = criterion + + def get_rope_index( + self, + input_ids: paddle.LongTensor | None = None, + image_grid_thw: paddle.LongTensor | None = None, + video_grid_thw: paddle.LongTensor | None = None, + attention_mask: paddle.Tensor | None = None, + ) -> tuple[paddle.Tensor, paddle.Tensor]: + if video_grid_thw is not None: + video_grid_thw = paddle.repeat_interleave( + video_grid_thw, video_grid_thw[:, 0], dim=0 + ) + video_grid_thw[:, 0] = 1 + + spatial_merge_size = self.config.vision_config.spatial_merge_size + # TODO when implemented data file. 
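+ # get_rope_index builds mrope position ids with three channels (t, h, w): text
+ # tokens advance all three channels together, while image/video tokens take
+ # their temporal/height/width grid coordinates (height/width grids divided by
+ # spatial_merge_size), offset by the running text position. The hard-coded
+ # 151652 below is assumed to be the <|vision_start|> token id of this tokenizer.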
+ image_token_id = self.image_token_index + video_token_id = self.video_token_index + vision_start_token_id = 151652 + mrope_position_deltas = [] + if input_ids is not None and ( + image_grid_thw is not None or video_grid_thw is not None + ): + total_input_ids = input_ids + if attention_mask is None: + attention_mask = paddle.ones_like(total_input_ids) + position_ids = paddle.ones( + [3, input_ids.shape[0], input_ids.shape[1]], + dtype=input_ids.dtype, + ) + image_index, video_index = 0, 0 + for i, input_ids in enumerate(total_input_ids): + input_ids = input_ids[attention_mask[i] == 1] + image_nums, video_nums = 0, 0 + vision_start_indices = paddle.argwhere( + input_ids == vision_start_token_id + ).squeeze(1) + vision_tokens = input_ids[vision_start_indices + 1] + image_nums = (vision_tokens == image_token_id).sum() + video_nums = (vision_tokens == video_token_id).sum() + input_tokens = input_ids.tolist() + llm_pos_ids_list: list = [] + st = 0 + remain_images, remain_videos = image_nums, video_nums + for _ in range(image_nums + video_nums): + if image_token_id in input_tokens and remain_images > 0: + ed_image = input_tokens.index(image_token_id, st) + else: + ed_image = len(input_tokens) + 1 + if video_token_id in input_tokens and remain_videos > 0: + ed_video = input_tokens.index(video_token_id, st) + else: + ed_video = len(input_tokens) + 1 + if ed_image < ed_video: + t, h, w = ( + image_grid_thw[image_index][0], + image_grid_thw[image_index][1], + image_grid_thw[image_index][2], + ) + image_index += 1 + remain_images -= 1 + ed = ed_image + + else: + t, h, w = ( + video_grid_thw[video_index][0], + video_grid_thw[video_index][1], + video_grid_thw[video_index][2], + ) + video_index += 1 + remain_videos -= 1 + ed = ed_video + llm_grid_t, llm_grid_h, llm_grid_w = ( + t.item(), + h.item() // spatial_merge_size, + w.item() // spatial_merge_size, + ) + text_len = ed - st + + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if llm_pos_ids_list + else 0 + ) + llm_pos_ids_list.append( + paddle.arange(text_len).view(1, -1).expand(3, -1) + + st_idx + ) + + t_index = ( + paddle.arange(llm_grid_t) + .view(-1, 1) + .expand(-1, llm_grid_h * llm_grid_w) + .flatten() + ) + h_index = ( + paddle.arange(llm_grid_h) + .view(1, -1, 1) + .expand(llm_grid_t, -1, llm_grid_w) + .flatten() + ) + w_index = ( + paddle.arange(llm_grid_w) + .view(1, 1, -1) + .expand(llm_grid_t, llm_grid_h, -1) + .flatten() + ) + llm_pos_ids_list.append( + paddle.stack([t_index, h_index, w_index]) + + text_len + + st_idx + ) + st = ed + llm_grid_t * llm_grid_h * llm_grid_w + + if st < len(input_tokens): + st_idx = ( + llm_pos_ids_list[-1].max() + 1 + if len(llm_pos_ids_list) > 0 + else 0 + ) + text_len = len(input_tokens) - st + llm_pos_ids_list.append( + paddle.arange(text_len).view(1, -1).expand(3, -1) + + st_idx + ) + + llm_positions = paddle.cat(llm_pos_ids_list, dim=1).reshape( + 3, -1 + ) + position_ids[..., i, attention_mask[i] == 1] = llm_positions + mrope_position_deltas.append( + llm_positions.max() + 1 - len(total_input_ids[i]) + ) + mrope_position_deltas = paddle.to_tensor( + mrope_position_deltas + ).unsqueeze(1) + return position_ids, mrope_position_deltas + else: + if attention_mask is not None: + position_ids = attention_mask.long().cumsum(-1) - 1 + position_ids.masked_fill_(attention_mask == 0, 1) + position_ids = ( + position_ids.unsqueeze(0) + .expand(3, -1, -1) + .to(attention_mask.device) + ) + max_position_ids = position_ids.max(0, keepdim=False)[0].max( + -1, keepdim=True + )[0] + mrope_position_deltas = ( + 
max_position_ids + 1 - attention_mask.shape[-1] + ) + else: + position_ids = ( + paddle.arange(input_ids.shape[1]) + .view(1, 1, -1) + .expand(3, input_ids.shape[0], -1) + ) + mrope_position_deltas = paddle.zeros( + [input_ids.shape[0], 1], + dtype=input_ids.dtype, + ) + return position_ids, mrope_position_deltas + + def get_video_features( + self, + pixel_values_videos: paddle.FloatTensor, + video_grid_thw: paddle.LongTensor | None = None, + ): + return self.get_image_features(pixel_values_videos, video_grid_thw) + + def get_image_features( + self, + pixel_values: paddle.FloatTensor, + image_grid_thw: paddle.LongTensor | None = None, + ): + dict_args = { + "pixel_values": pixel_values, + "grid_thw": image_grid_thw, + } + vision_output = self.vision_model(dict_args) + image_embeds, deepstack_image_embeds = ( + vision_output["hidden_states"], + vision_output["deepstack_features_list"], + ) + split_sizes = ( + image_grid_thw.prod(-1) + // self.config.vision_config.spatial_merge_size**2 + ).tolist() + image_embeds = paddle.split(image_embeds, split_sizes) + return image_embeds, deepstack_image_embeds + + def forward( + self, + input_ids: paddle.LongTensor = None, + attention_mask: paddle.Tensor | None = None, + position_ids: paddle.LongTensor | None = None, + loss_mask: paddle.Tensor | None = None, + labels: paddle.Tensor | None = None, + inference_params=None, + pixel_values: paddle.Tensor | None = None, + pixel_values_videos=None, + image_grid_thw=None, + video_grid_thw=None, + runtime_gather_output: bool | None = None, + cache_position: paddle.Tensor | None = None, + attn_mask_startend_row_indices: paddle.Tensor | None = None, + **kwargs, + ) -> paddle.Tensor: + assert loss_mask is None, "loss_mask is not supported yet" + ( + image_embeds, + video_embeds, + deepstack_image_embeds, + deepstack_video_embeds, + ) = (None for _ in range(4)) + if self.add_encoder and pixel_values is not None: + # Handle list[paddle.Tensor] input (from RL training pipeline) + if isinstance(pixel_values, list): + # Filter out None and concatenate tensors + tensor_list = [ + elem for elem in pixel_values if elem is not None + ] + if tensor_list: + pixel_values = paddle.concat(tensor_list, axis=0) + else: + pixel_values = None + if pixel_values is not None: + pixel_values = pixel_values.to( + self.vision_model.parameters()[0].dtype + ) + # Handle list[paddle.Tensor] for image_grid_thw + if image_grid_thw is not None: + if isinstance(image_grid_thw, list): + tensor_list = [ + elem for elem in image_grid_thw if elem is not None + ] + if tensor_list: + image_grid_thw = paddle.concat(tensor_list, axis=0) + else: + image_grid_thw = None + if self.config.freeze_vision_model: + with paddle.no_grad(): + image_embeds, deepstack_image_embeds = ( + self.get_image_features( + pixel_values, image_grid_thw + ) + ) + else: + image_embeds, deepstack_image_embeds = ( + self.get_image_features(pixel_values, image_grid_thw) + ) + image_embeds = paddle.cat(image_embeds, dim=0) + + if self.add_encoder and pixel_values_videos is not None: + # Handle list[paddle.Tensor] input (from RL training pipeline) + if isinstance(pixel_values_videos, list): + # Filter out None and concatenate tensors + tensor_list = [ + elem for elem in pixel_values_videos if elem is not None + ] + if tensor_list: + pixel_values_videos = paddle.concat(tensor_list, axis=0) + else: + pixel_values_videos = None + if pixel_values_videos is not None: + pixel_values_videos = pixel_values_videos.to( + self.vision_model.parameters()[0].dtype + ) + # Handle 
list[paddle.Tensor] for video_grid_thw + if video_grid_thw is not None: + if isinstance(video_grid_thw, list): + tensor_list = [ + elem for elem in video_grid_thw if elem is not None + ] + if tensor_list: + video_grid_thw = paddle.concat(tensor_list, axis=0) + else: + video_grid_thw = None + if self.config.freeze_vision_model: + with paddle.no_grad(): + video_embeds, deepstack_video_embeds = ( + self.get_video_features( + pixel_values_videos, video_grid_thw + ) + ) + else: + video_embeds, deepstack_video_embeds = ( + self.get_video_features( + pixel_values_videos, video_grid_thw + ) + ) + video_embeds = paddle.concat(video_embeds, axis=0) + +
if position_ids is None: + if ( + self.rope_deltas is None + or cache_position is None + or cache_position[0] == 0 + ): + position_ids, rope_deltas = self.get_rope_index( + input_ids, + image_grid_thw, + video_grid_thw, + attention_mask=attention_mask, + ) + self.rope_deltas = rope_deltas + else: + batch_size, seq_length = input_ids.shape + position_ids = paddle.arange(seq_length) + position_ids = position_ids.view(1, 1, -1).expand( + 3, batch_size, -1 + ) + if cache_position is not None: + delta = cache_position[0] + self.rope_deltas + else: + delta = paddle.zeros((batch_size, seq_length)) + delta = delta.repeat_interleave( + batch_size // delta.shape[0], axis=1 + ) + position_ids = position_ids + delta + else: + # Handle position_ids with mrope format [batch_size, seq_len, 3] -> [3, batch_size, seq_len] + if position_ids.ndim == 3 and position_ids.shape[-1] == 3: + position_ids = position_ids.transpose([2, 0, 1]) + elif position_ids.shape == input_ids.shape: + position_ids = position_ids.expand(3, position_ids.shape[0], -1) + +
input_dict = { + "input_ids": input_ids, + "position_ids": position_ids, + "attention_mask": None, + "attn_mask_startend_row_indices": attn_mask_startend_row_indices, + "decoder_input": None, + "image_embeds": image_embeds, + "video_embeds": video_embeds, + "labels": labels, + "deepstack_image_embeds": deepstack_image_embeds, + "deepstack_video_embeds": deepstack_video_embeds, + "runtime_gather_output": runtime_gather_output, + } + output = self.language_model(input_dict) + + if labels is None: + return output + elif self.criterion is not None: + return self.criterion(output, labels) + else: + return output + +
def set_input_tensor(self, input_tensor) -> None: + """Set model chunk input tensor.""" + # This is usually handled in schedules.py but some inference code still + # gives us non-lists or None + if not isinstance(input_tensor, list): + input_tensor = [input_tensor] + assert len(input_tensor) == 1, ( + "input_tensor should only be length 1 for Qwen3VL" + ) + + if self.add_encoder and self.add_decoder: + self.vision_model.set_input_tensor(input_tensor[0]) + elif self.add_encoder: + self.vision_model.set_input_tensor(input_tensor[0]) + elif self.pre_process: + self.encoder_hidden_state = input_tensor[0] + else: + self.language_model.set_input_tensor(input_tensor[0]) + + # def get_input_embeddings(self): + # return self.language_model.get_input_embeddings() + + +__all__ = [ + "Qwen3VLTextTransformerLayer", + "Qwen3VLVisionModel", + "Qwen3VLVisionTransformerLayer", + "Qwen3VLModelDist", +] diff --git a/src/paddlefleet/models/qwen3_vl/qwen3_vl_provider.py b/src/paddlefleet/models/qwen3_vl/qwen3_vl_provider.py deleted file mode 100644 index 38113f65d..000000000 --- a/src/paddlefleet/models/qwen3_vl/qwen3_vl_provider.py +++
/dev/null @@ -1,93 +0,0 @@ -# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import contextlib -from collections.abc import Callable -from dataclasses import dataclass -from functools import partial - -import paddle -from paddle.nn import functional as F - -from ...spec_utils import LayerSpec -from ...transformer import TransformerConfig -from .qwen3_vl_builders import qwen3_vl_vision_builder -from .qwen3_vl_model import Qwen3VLVisionModel, Qwen3VLVisionTransformerLayer - - -@dataclass -class Qwen3VLVisionProvider(TransformerConfig): - patch_size: int = 16 - use_bias: bool = True - add_qkv_bias: bool = True - num_position_embeddings: int = 2304 - embed_dim: int = (1152,) - hidden_size: int = 1152 - out_hidden_size: int = 4096 - in_channels: int = 3 - spatial_merge_size: int = 2 - spatial_patch_size: int = 16 - temporal_patch_size: int = 2 - hidden_dropout_prob: float = 0.0 - attention_dropout: float = 0.0 - intermediate_size: int = 4304 - initializer_range: float = 0.02 - gated_linear_unit: bool = False - activation_func: Callable = F.gelu - layernorm_zero_centered_gamma: bool = False - apply_query_key_layer_scaling: bool = False - persist_layer_norm: bool = True - bias_activation_fusion: bool = False - bias_dropout_fusion: bool = False - attention_softmax_in_fp32: bool = True - normalization: str = "LayerNorm" - apply_rope_fusion: bool = True - rms_norm_eps: float = 1e-6 - transformer_layer_spec: LayerSpec = Qwen3VLVisionTransformerLayer - model_version: str = "qwen3_vl" - img_h: int = 336 - img_w: int = 336 - add_class_token: bool = False - class_token_len: int = 1 - high_precision_rope: bool = True - rotary_percent: float = 1.0 - transform_rules = { - "dtype": "params_dtype", - "num_heads": "num_attention_heads", - "depth": "num_hidden_layers", - "initializer_range": "init_method_std", - } - - def provide(self) -> "Qwen3VLVisionModel": - pp_size = self.pipeline_model_parallel_size - - is_pipeline_asymmetric = getattr( - self, "account_for_embedding_in_pipeline_split", False - ) or getattr(self, "account_for_loss_in_pipeline_split", False) - is_pipeline_asymmetric |= ( - getattr(self, "num_empty_layers_add_in_head", None) - or getattr(self, "num_empty_layers_add_in_tail", None) - ) is not None - - # Initialize model as meta data instead of allocating data on a device - model_init_device_context = contextlib.nullcontext - if self.init_model_with_meta_device: - model_init_device_context = partial(paddle.device, device="meta") - - with model_init_device_context(): - res_model = qwen3_vl_vision_builder( - self, - seg_method="layer:TransformerLayer|EmptyLayer", - num_stages=pp_size, - ) - return res_model diff --git a/src/paddlefleet/pipeline_parallel/pp_layers.py b/src/paddlefleet/pipeline_parallel/pp_layers.py index 70f4c62c9..e9b41eec9 100755 --- a/src/paddlefleet/pipeline_parallel/pp_layers.py +++ b/src/paddlefleet/pipeline_parallel/pp_layers.py @@ -339,6 +339,7 @@ def __init__( world_size = dist.get_world_size() self.global_rank 
= dist.get_rank() + self.global_rank = 0 if self._topo: if hasattr(self._topo, "_parent_hcg"): diff --git a/src/paddlefleet/transformer/transformer_encoder.py b/src/paddlefleet/transformer/transformer_encoder.py index 39b4c26b2..6b07a19c6 100644 --- a/src/paddlefleet/transformer/transformer_encoder.py +++ b/src/paddlefleet/transformer/transformer_encoder.py @@ -185,6 +185,28 @@ def get_encoder_layer_desc_list(self, layers, spec, name_prefix): ) i += 1 + def _get_weight_only_params(self): + """Get all parameters marked with is_weight_only_mtp flag.""" + return [ + param + for param in self.state_dict().values() + if getattr(param, "is_weight_only_mtp", False) + ] + + def offload_weight_only_params(self): + """Offload all weight-only MTP parameters to CPU pinned memory.""" + for param in self._get_weight_only_params(): + if param.place.is_gpu_place(): + cpu_param = param.pin_memory() + cpu_param._share_buffer_to(param) + + def reload_weight_only_params(self): + """Reload weight-only MTP parameters from CPU pinned memory back to GPU.""" + for param in self._get_weight_only_params(): + if not param.place.is_gpu_place(): + gpu_param = param.cuda() + gpu_param._share_buffer_to(param) + def overlapped_forward_backward( self, forward_chunk, diff --git a/third_party/DeepEP b/third_party/DeepEP index 3fed158b4..226cf5aae 160000 --- a/third_party/DeepEP +++ b/third_party/DeepEP @@ -1 +1 @@ -Subproject commit 3fed158b4751a4a89c92902aeb20bb8c8384f1e7 +Subproject commit 226cf5aae50bf88e28292483cbbb3ffd5f258da9 diff --git a/third_party/quack b/third_party/quack index c81d790dd..7ef82a904 160000 --- a/third_party/quack +++ b/third_party/quack @@ -1 +1 @@ -Subproject commit c81d790dd90f304fe9fc5cd250e06eddbef5b90d +Subproject commit 7ef82a90403f4a407f82d522d658d2b7e87ef733 diff --git a/third_party/sonic-moe b/third_party/sonic-moe index d241602f0..935608e44 160000 --- a/third_party/sonic-moe +++ b/third_party/sonic-moe @@ -1 +1 @@ -Subproject commit d241602f07fb67a48462606371b2de708f2143af +Subproject commit 935608e445dfefcc562ffb6bec2b27b23222ebc7
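Note on the offload/reload helpers added to transformer_encoder.py above: they swap a parameter's storage in place via _share_buffer_to, so existing references to the parameter object remain valid while its data moves between pinned host memory and the GPU. Below is a minimal sketch of the same round-trip on a standalone parameter (illustrative only: it assumes a CUDA build of Paddle; the parameter w and the flag assignment are hypothetical, while pin_memory, cuda, place.is_gpu_place, and the private _share_buffer_to call mirror the hunk above).

import paddle

# A hypothetical standalone parameter, tagged the way _get_weight_only_params()
# looks for it (is_weight_only_mtp is the flag checked in the hunk above).
w = paddle.create_parameter(shape=[1024, 1024], dtype="float32")
w.is_weight_only_mtp = True

# Offload: copy into CPU pinned memory, then re-point the parameter's storage
# at that buffer so references to w stay valid.
if w.place.is_gpu_place():
    w_pinned = w.pin_memory()
    w_pinned._share_buffer_to(w)  # w now holds the pinned host buffer

# ... phases of the step that do not need the weight-only MTP weights run here ...

# Reload: copy back to the GPU and re-point the parameter at the device buffer.
if not w.place.is_gpu_place():
    w_gpu = w.cuda()
    w_gpu._share_buffer_to(w)  # w is resident on the GPU again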