diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index 5503966..8d0b4ea 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -21,6 +21,17 @@ _GatedDeltaNet = object +def _unpack_sequence(x, cu_seqlens, dim=1): + unpacked_x = [] + num_seqs = cu_seqlens.shape[0] - 1 + for i in range(num_seqs): + idx_start = cu_seqlens[i].item() + idx_end = cu_seqlens[i + 1].item() + chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)] + unpacked_x.append(x[tuple(chunked_index)]) + return unpacked_x + + class GatedDeltaNet(_GatedDeltaNet): def forward( @@ -60,7 +71,8 @@ def forward( inference_context = deprecate_inference_params(inference_context, inference_params) seq_len, batch, _ = hidden_states.shape - seq_len = seq_len * self.sp_size + cp_size = self.config.context_parallel_size + seq_len = seq_len * self.sp_size * cp_size if inference_context is not None: assert ( @@ -75,12 +87,35 @@ def forward( qkvzba, _ = self.in_proj(hidden_states) nvtx_range_pop(suffix='in_proj') + if cp_size > 1: + from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp, tensor_a2a_hp2cp + if cu_seqlens is not None: + unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // cp_size, dim=0) + outputs = [] + for qkvzba_i in unpacked_qkvzba: + qkvzba_i = tensor_a2a_cp2hp( + qkvzba_i, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + ) + outputs.append(qkvzba_i) + qkvzba = torch.cat(outputs, dim=0) + else: + # CP All to All: CP to HP + qkvzba = tensor_a2a_cp2hp( + qkvzba, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + ) + # Transpose: s b x --> b s x # From sbhd to bshd format qkvzba = qkvzba.transpose(0, 1) # Split, reorder, and reshape the tensor into q, k, v, gate, beta, alpha - num_key_heads_per_device = self.num_key_heads // 
self.tp_size // cp_size qkvzba = qkvzba.view(qkvzba.shape[:-1] + (num_key_heads_per_device, qkvzba.shape[-1] // num_key_heads_per_device)) qkv, gate, beta, alpha = torch.split( @@ -100,17 +135,41 @@ def forward( # Convolution on qkv nvtx_range_push(suffix='conv1d') + if cp_size > 1: + conv1d_weight = get_parameter_local_cp( + self.conv1d.weight, + dim=0, + cp_group=self.pg_collection.cp, + ) + conv1d_bias = ( + get_parameter_local_cp( + self.conv1d.bias, + dim=0, + cp_group=self.pg_collection.cp, + ) if self.conv_bias else None) + else: + conv1d_weight = self.conv1d.weight + conv1d_bias = self.conv1d.bias + if (causal_conv1d is None) or self.config.deterministic_mode: assert cu_seqlens is None, 'Packed sequences are not supported when fla is not available.' qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s - qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len]) + conv_out = F.conv1d( + input=qkv, + weight=conv1d_weight, + bias=conv1d_bias, + stride=self.conv1d.stride, + padding=self.conv1d.padding, + dilation=self.conv1d.dilation, groups=conv1d_weight.shape[0], + ) + qkv = self.act_fn(conv_out[..., :seq_len]) qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d else: assert self.activation in ['silu', 'swish'] qkv = causal_conv1d( x=qkv, - weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w - bias=self.conv1d.bias, + weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w + bias=conv1d_bias, activation=self.activation, cu_seqlens=cu_seqlens, )[0] @@ -143,7 +202,12 @@ def forward( # Calculate g and beta nvtx_range_push(suffix='g_and_beta') - g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32 + if cp_size > 1: + A_log_local_cp = get_parameter_local_cp(self.A_log, dim=0, cp_group=self.pg_collection.cp) + dt_bias_local_cp = get_parameter_local_cp(self.dt_bias, dim=0, cp_group=self.pg_collection.cp) + else: + A_log_local_cp, dt_bias_local_cp = self.A_log, self.dt_bias + g = -A_log_local_cp.exp() * F.softplus(alpha.float() + dt_bias_local_cp) # In fp32 beta = beta.sigmoid() 
nvtx_range_pop(suffix='g_and_beta') @@ -183,6 +247,16 @@ def forward( # From bshd back to sbhd format norm_out = norm_out.reshape(batch, seq_len, -1) norm_out = norm_out.transpose(0, 1).contiguous() + if cp_size > 1: + if cu_seqlens is not None: + unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens, dim=0) + outputs = [] + for norm_out_i in unpacked_norm_out: + norm_out_i = tensor_a2a_hp2cp(norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp) + outputs.append(norm_out_i) + norm_out = torch.cat(outputs, dim=0) + else: + norm_out = tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp) # Output projection nvtx_range_push(suffix='out_proj')