Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 80 additions & 6 deletions src/mcore_bridge/model/modules/gated_delta_net.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,17 @@
_GatedDeltaNet = object


def _unpack_sequence(x, cu_seqlens, dim=1):
unpacked_x = []
num_seqs = cu_seqlens.shape[0] - 1
for i in range(num_seqs):
idx_start = cu_seqlens[i].item()
idx_end = cu_seqlens[i + 1].item()
chunked_index = [slice(None)] * dim + [slice(idx_start, idx_end)]
unpacked_x.append(x[tuple(chunked_index)])
return unpacked_x


class GatedDeltaNet(_GatedDeltaNet):

def forward(
Expand Down Expand Up @@ -60,7 +71,8 @@ def forward(
inference_context = deprecate_inference_params(inference_context, inference_params)

seq_len, batch, _ = hidden_states.shape
seq_len = seq_len * self.sp_size
cp_size = self.config.context_parallel_size
seq_len = seq_len * self.sp_size * cp_size

if inference_context is not None:
assert (
Expand All @@ -75,12 +87,35 @@ def forward(
qkvzba, _ = self.in_proj(hidden_states)
nvtx_range_pop(suffix='in_proj')

if cp_size > 1:
from megatron.core.ssm.gated_delta_net import get_parameter_local_cp, tensor_a2a_cp2hp, tensor_a2a_hp2cp
if cu_seqlens is not None:
unpacked_qkvzba = _unpack_sequence(qkvzba, cu_seqlens // self.cp_size, dim=0)
outputs = []
for qkvzba_i in unpacked_qkvzba:
qkvzba_i = tensor_a2a_cp2hp(
qkvzba_i,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
)
outputs.append(qkvzba_i)
qkvzba = torch.cat(outputs, dim=0)
else:
# CP All to All: CP to HP
qkvzba = tensor_a2a_cp2hp(
qkvzba,
seq_dim=0,
head_dim=-1,
cp_group=self.pg_collection.cp,
)

# Transpose: s b x --> b s x
# From sbhd to bshd format
qkvzba = qkvzba.transpose(0, 1)

# Split, reorder, and reshape the tensor into q, k, v, gate, beta, alpha
num_key_heads_per_device = self.num_key_heads // self.tp_size
num_key_heads_per_device = self.num_key_heads // self.tp_size // cp_size
qkvzba = qkvzba.view(qkvzba.shape[:-1]
+ (num_key_heads_per_device, qkvzba.shape[-1] // num_key_heads_per_device))
qkv, gate, beta, alpha = torch.split(
Expand All @@ -100,17 +135,41 @@ def forward(

# Convolution on qkv
nvtx_range_push(suffix='conv1d')
if cp_size > 1:
conv1d_weight = get_parameter_local_cp(
self.conv1d.weight,
dim=0,
cp_group=self.pg_collection.cp,
)
conv1d_bias = (
get_parameter_local_cp(
self.conv1d.bias,
dim=0,
cp_group=self.pg_collection.cp,
) if self.conv_bias else None)
else:
conv1d_weight = self.conv1d.weight
conv1d_bias = self.conv1d.bias

if (causal_conv1d is None) or self.config.deterministic_mode:
assert cu_seqlens is None, 'Packed sequences are not supported when fla is not available.'
qkv = qkv.transpose(1, 2).contiguous() # b, s, d -> b, d, s
qkv = self.act_fn(self.conv1d(qkv)[..., :seq_len])
conv_out = F.conv1d(
input=qkv,
weight=conv1d_weight,
bias=conv1d_bias,
stride=self.conv1d.stride,
padding=self.conv1d.padding,
dilation=self.conv1d.dilation,
)
qkv = self.act_fn(conv_out[..., :seq_len])
qkv = qkv.transpose(1, 2) # b, d, s -> b, s, d
else:
assert self.activation in ['silu', 'swish']
qkv = causal_conv1d(
x=qkv,
weight=self.conv1d.weight.squeeze(1), # d, 1, w -> d, w
bias=self.conv1d.bias,
weight=conv1d_weight.squeeze(1), # d, 1, w -> d, w
bias=conv1d_bias,
activation=self.activation,
cu_seqlens=cu_seqlens,
)[0]
Expand Down Expand Up @@ -143,7 +202,12 @@ def forward(

# Calculate g and beta
nvtx_range_push(suffix='g_and_beta')
g = -self.A_log.exp() * F.softplus(alpha.float() + self.dt_bias) # In fp32
if cp_size > 1:
A_log_local_cp = get_parameter_local_cp(self.A_log, dim=0, cp_group=self.pg_collection.cp)
dt_bias_local_cp = get_parameter_local_cp(self.dt_bias, dim=0, cp_group=self.pg_collection.cp)
else:
A_log_local_cp, dt_bias_local_cp = self.A_log, self.dt_bias
g = -A_log_local_cp.exp() * F.softplus(alpha.float() + dt_bias_local_cp) # In fp32
beta = beta.sigmoid()
nvtx_range_pop(suffix='g_and_beta')

Expand Down Expand Up @@ -183,6 +247,16 @@ def forward(
# From bshd back to sbhd format
norm_out = norm_out.reshape(batch, seq_len, -1)
norm_out = norm_out.transpose(0, 1).contiguous()
if cp_size > 1:
if cu_seqlens is not None:
unpacked_norm_out = _unpack_sequence(norm_out, cu_seqlens, dim=0)
outputs = []
for norm_out_i in unpacked_norm_out:
norm_out_i = tensor_a2a_hp2cp(norm_out_i, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)
outputs.append(norm_out_i)
norm_out = torch.cat(outputs, dim=0)
else:
norm_out = tensor_a2a_hp2cp(norm_out, seq_dim=0, head_dim=-1, cp_group=self.pg_collection.cp)

# Output projection
nvtx_range_push(suffix='out_proj')
Expand Down
Loading