From f65be607786eb68003e089adc494a047ece8d5a3 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 6 Apr 2026 02:40:37 +0800 Subject: [PATCH 1/5] support GDN CP --- .../model/modules/gated_delta_net.py | 71 +++++++++++++++++-- 1 file changed, 65 insertions(+), 6 deletions(-) diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index 5503966..9be39ac 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -60,7 +60,8 @@ def forward( inference_context = deprecate_inference_params(inference_context, inference_params) seq_len, batch, _ = hidden_states.shape - seq_len = seq_len * self.sp_size + cp_size = self.config.context_parallel_size + seq_len = seq_len * self.sp_size * cp_size if inference_context is not None: assert ( @@ -75,12 +76,31 @@ def forward( qkvzba, _ = self.in_proj(hidden_states) nvtx_range_pop(suffix='in_proj') + if cp_size > 1: + from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp + + # CP All to All: CP to HP + qkvzba = tensor_a2a_cp2hp( + qkvzba, + seq_dim=0, + head_dim=-1, + cp_group=self.pg_collection.cp, + split_sections=[ + self.qk_dim_local_tp, + self.qk_dim_local_tp, + self.v_dim_local_tp, + self.v_dim_local_tp, + self.num_value_heads // self.tp_size, + self.num_value_heads // self.tp_size, + ], + ) + # Transpose: s b x --> b s x # From sbhd to bshd format qkvzba = qkvzba.transpose(0, 1) # Split, reorder, and reshape the tensor into q, k, v, gate, beta, alpha - num_key_heads_per_device = self.num_key_heads // self.tp_size + num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size qkvzba = qkvzba.view(qkvzba.shape[:-1] + (num_key_heads_per_device, qkvzba.shape[-1] // num_key_heads_per_device)) qkv, gate, beta, alpha = torch.split( @@ -100,17 +120,49 @@ def forward( # Convolution on qkv nvtx_range_push(suffix='conv1d') + if cp_size > 1: + qkv_channels_split_sections = 
[ + self.qk_dim_local_tp, + self.qk_dim_local_tp, + self.v_dim_local_tp, + ] + conv1d_weight = get_parameter_local_cp( + self.conv1d.weight, + dim=0, + cp_group=self.pg_collection.cp, + split_sections=qkv_channels_split_sections, + ) + conv1d_bias = ( + get_parameter_local_cp( + self.conv1d.bias, + dim=0, + cp_group=self.pg_collection.cp, + split_sections=qkv_channels_split_sections, + ) if self.conv_bias else None) + else: + conv1d_weight = self.conv1d.weight + conv1d_bias = self.conv1d.bias + if (causal_conv1d is None) or self.config.deterministic_mode: assert cu_seqlens is None, 'Packed sequences are not supported when fla is not available.' qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s - qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len]) + conv_out = F.conv1d( + input=qkv, + weight=conv1d_weight, + bias=conv1d_bias, + stride=self.conv1d.stride, + padding=self.conv1d.padding, + dilation=self.conv1d.dilation, + groups=self.conv_dim_local_tp // cp_size if cp_size > 1 else None, + ) + qkv = self.act_fn(conv_out[..., :seq_len]) qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d else: assert self.activation in ['silu', 'swish'] qkv = causal_conv1d( x=qkv, - weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w - bias=self.conv1d.bias, + weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w + bias=conv1d_bias, activation=self.activation, cu_seqlens=cu_seqlens, )[0] @@ -143,7 +195,12 @@ def forward( # Calculate g and beta nvtx_range_push(suffix='g_and_beta') - g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32 + if cp_size > 1: + A_log_local_cp = get_parameter_local_cp(self.A_log, dim=0, cp_group=self.pg_collection.cp) + dt_bias_local_cp = get_parameter_local_cp(self.dt_bias, dim=0, cp_group=self.pg_collection.cp) + else: + A_log_local_cp, dt_bias_local_cp = A_log, self.dt_bias + g = -A_log_local_cp.exp() * F.softplus(alpha.float() + dt_bias_local_cp) # In fp32 beta = beta.sigmoid() nvtx_range_pop(suffix='g_and_beta') @@ -183,6 +240,8 
@@ def forward( # From bshd back to sbhd format norm_out = norm_out.reshape(batch, seq_len, -1) norm_out = norm_out.transpose(0, 1).contiguous() + if cp_size > 1: + norm_out = tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp) # Output projection nvtx_range_push(suffix='out_proj') From 3c55565b2e2896a771e5afd8e87dea922376b2b6 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 6 Apr 2026 15:08:48 +0800 Subject: [PATCH 2/5] update --- .../model/modules/gated_delta_net.py | 55 ++----------------- 1 file changed, 5 insertions(+), 50 deletions(-) diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index 9be39ac..d47f254 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -77,7 +77,7 @@ def forward( nvtx_range_pop(suffix='in_proj') if cp_size > 1: - from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp + from megatron.core.ssm.gated_delta_net import tensor_a2a_cp2hp, tensor_a2a_hp2cp # CP All to All: CP to HP qkvzba = tensor_a2a_cp2hp( @@ -85,14 +85,6 @@ def forward( seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp, - split_sections=[ - self.qk_dim_local_tp, - self.qk_dim_local_tp, - self.v_dim_local_tp, - self.v_dim_local_tp, - self.num_value_heads // self.tp_size, - self.num_value_heads // self.tp_size, - ], ) # Transpose: s b x --> b s x @@ -120,49 +112,17 @@ def forward( # Convolution on qkv nvtx_range_push(suffix='conv1d') - if cp_size > 1: - qkv_channels_split_sections = [ - self.qk_dim_local_tp, - self.qk_dim_local_tp, - self.v_dim_local_tp, - ] - conv1d_weight = get_parameter_local_cp( - self.conv1d.weight, - dim=0, - cp_group=self.pg_collection.cp, - split_sections=qkv_channels_split_sections, - ) - conv1d_bias = ( - get_parameter_local_cp( - self.conv1d.bias, - dim=0, - cp_group=self.pg_collection.cp, - split_sections=qkv_channels_split_sections, - ) if 
self.conv_bias else None) - else: - conv1d_weight = self.conv1d.weight - conv1d_bias = self.conv1d.bias - if (causal_conv1d is None) or self.config.deterministic_mode: assert cu_seqlens is None, 'Packed sequences are not supported when fla is not available.' qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s - conv_out = F.conv1d( - input=qkv, - weight=conv1d_weight, - bias=conv1d_bias, - stride=self.conv1d.stride, - padding=self.conv1d.padding, - dilation=self.conv1d.dilation, - groups=self.conv_dim_local_tp // cp_size if cp_size > 1 else None, - ) - qkv = self.act_fn(conv_out[..., :seq_len]) + qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len]) qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d else: assert self.activation in ['silu', 'swish'] qkv = causal_conv1d( x=qkv, - weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w - bias=conv1d_bias, + weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w + bias=self.conv1d.bias, activation=self.activation, cu_seqlens=cu_seqlens, )[0] @@ -195,12 +155,7 @@ def forward( # Calculate g and beta nvtx_range_push(suffix='g_and_beta') - if cp_size > 1: - A_log_local_cp = get_parameter_local_cp(self.A_log, dim=0, cp_group=self.pg_collection.cp) - dt_bias_local_cp = get_parameter_local_cp(self.dt_bias, dim=0, cp_group=self.pg_collection.cp) - else: - A_log_local_cp, dt_bias_local_cp = A_log, self.dt_bias - g = -A_log_local_cp.exp() * F.softplus(alpha.float() + dt_bias_local_cp) # In fp32 + g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32 beta = beta.sigmoid() nvtx_range_pop(suffix='g_and_beta') From 17497a45a49b1f5f6b3cb2b08222b5ccacef3745 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 6 Apr 2026 15:16:12 +0800 Subject: [PATCH 3/5] update --- .../model/modules/gated_delta_net.py | 40 ++++++++++++++++--- 1 file changed, 35 insertions(+), 5 deletions(-) diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index d47f254..70ad276 
100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -77,7 +77,7 @@ def forward( nvtx_range_pop(suffix='in_proj') if cp_size > 1: - from megatron.core.ssm.gated_delta_net import tensor_a2a_cp2hp, tensor_a2a_hp2cp + from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp, tensor_a2a_hp2cp # CP All to All: CP to HP qkvzba = tensor_a2a_cp2hp( @@ -112,17 +112,42 @@ def forward( # Convolution on qkv nvtx_range_push(suffix='conv1d') + if cp_size > 1: + conv1d_weight = get_parameter_local_cp( + self.conv1d.weight, + dim=0, + cp_group=self.pg_collection.cp, + ) + conv1d_bias = ( + get_parameter_local_cp( + self.conv1d.bias, + dim=0, + cp_group=self.pg_collection.cp, + ) if self.conv_bias else None) + else: + conv1d_weight = self.conv1d.weight + conv1d_bias = self.conv1d.bias + if (causal_conv1d is None) or self.config.deterministic_mode: assert cu_seqlens is None, 'Packed sequences are not supported when fla is not available.' 
qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s - qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len]) + conv_out = F.conv1d( + input=qkv, + weight=conv1d_weight, + bias=conv1d_bias, + stride=self.conv1d.stride, + padding=self.conv1d.padding, + dilation=self.conv1d.dilation, + groups=self.conv_dim_local_tp // cp_size if cp_size > 1 else None, + ) + qkv = self.act_fn(conv_out[..., :seq_len]) qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d else: assert self.activation in ['silu', 'swish'] qkv = causal_conv1d( x=qkv, - weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w - bias=self.conv1d.bias, + weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w + bias=conv1d_bias, activation=self.activation, cu_seqlens=cu_seqlens, )[0] @@ -155,7 +180,12 @@ def forward( # Calculate g and beta nvtx_range_push(suffix='g_and_beta') - g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32 + if cp_size > 1: + A_log_local_cp = get_parameter_local_cp(self.A_log, dim=0, cp_group=self.pg_collection.cp) + dt_bias_local_cp = get_parameter_local_cp(self.dt_bias, dim=0, cp_group=self.pg_collection.cp) + else: + A_log_local_cp, dt_bias_local_cp = self.A_log, self.dt_bias + g = -A_log_local_cp.exp() * F.softplus(alpha.float() + dt_bias_local_cp) # In fp32 beta = beta.sigmoid() nvtx_range_pop(suffix='g_and_beta') From 89b6d59ece08d71868cd51ca1bf7625bbd6db378 Mon Sep 17 00:00:00 2001 From: Jintao Huang Date: Mon, 6 Apr 2026 15:51:38 +0800 Subject: [PATCH 4/5] update --- .../model/modules/gated_delta_net.py | 49 +++++++++++++++---- 1 file changed, 40 insertions(+), 9 deletions(-) diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py index 70ad276..a2fda79 100644 --- a/src/mcore_bridge/model/modules/gated_delta_net.py +++ b/src/mcore_bridge/model/modules/gated_delta_net.py @@ -21,6 +21,17 @@ _GatedDeltaNet = object +def _unpack_sequence(x, cu_seqlens, dim=1): + unpacked_x = [] + num_seqs = cu_seqlens.shape[0] 
- 1
+    for i in range(num_seqs):
+        idx_start = cu_seqlens[i].item()
+        idx_end = cu_seqlens[i + 1].item()
+        chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
+        unpacked_x.append(x[tuple(chunked_index)])
+    return unpacked_x
+
+
 class GatedDeltaNet(_GatedDeltaNet):
 
     def forward(
@@ -78,14 +89,26 @@ def forward(
 
         if cp_size > 1:
             from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp, tensor_a2a_hp2cp
-
-            # CP All to All: CP to HP
-            qkvzba = tensor_a2a_cp2hp(
-                qkvzba,
-                seq_dim=0,
-                head_dim=-1,
-                cp_group=self.pg_collection.cp,
-            )
+            if cu_seqlens is not None:
+                unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // cp_size, dim=0)
+                outputs = []
+                for qkvzba_i in unpacked_qkvzba:
+                    qkvzba_i = tensor_a2a_cp2hp(
+                        qkvzba_i,
+                        seq_dim=0,
+                        head_dim=-1,
+                        cp_group=self.pg_collection.cp,
+                    )
+                    outputs.append(qkvzba_i)
+                qkvzba = torch.cat(outputs, dim=0)
+            else:
+                # CP All to All: CP to HP
+                qkvzba = tensor_a2a_cp2hp(
+                    qkvzba,
+                    seq_dim=0,
+                    head_dim=-1,
+                    cp_group=self.pg_collection.cp,
+                )
 
         # Transpose: s b x --> b s x
         # From sbhd to bshd format
@@ -226,7 +249,15 @@ def forward(
         norm_out = norm_out.reshape(batch, seq_len, -1)
         norm_out = norm_out.transpose(0, 1).contiguous()
         if cp_size > 1:
-            norm_out = tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)
+            if cu_seqlens is not None:
+                unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens, dim=0)
+                outputs = []
+                for norm_out_i in unpacked_norm_out:
+                    norm_out_i = tensor_a2a_hp2cp(norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)
+                    outputs.append(norm_out_i)
+                norm_out = torch.cat(outputs, dim=0)
+            else:
+                norm_out = tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)
 
         # Output projection
         nvtx_range_push(suffix='out_proj')

From 86144c08db6ba16eaf9f5b3e2fdaa37f007b063e Mon Sep 17 00:00:00 2001
From: Jintao Huang
Date: Mon, 6 Apr 2026 16:02:49 +0800
Subject: [PATCH 5/5] fix

---
 src/mcore_bridge/model/modules/gated_delta_net.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/mcore_bridge/model/modules/gated_delta_net.py b/src/mcore_bridge/model/modules/gated_delta_net.py
index a2fda79..8d0b4ea 100644
--- a/src/mcore_bridge/model/modules/gated_delta_net.py
+++ b/src/mcore_bridge/model/modules/gated_delta_net.py
@@ -161,7 +161,7 @@ def forward(
                 stride=self.conv1d.stride,
                 padding=self.conv1d.padding,
                 dilation=self.conv1d.dilation,
-                groups=self.conv_dim_local_tp // cp_size if cp_size > 1 else None,
+                groups=self.conv_dim_local_tp // cp_size if cp_size > 1 else self.conv_dim_local_tp,
             )
             qkv = self.act_fn(conv_out[..., :seq_len])
             qkv = qkv.transpose(1, 2)  # b, d, s -> b, s, d