diff --git a/tests/pytorch/test_fused_rope.py b/tests/pytorch/test_fused_rope.py index aaf2eca2d3..9e4ddbdad1 100644 --- a/tests/pytorch/test_fused_rope.py +++ b/tests/pytorch/test_fused_rope.py @@ -58,10 +58,6 @@ def test_fused_rope( # are with the maximum length of the rope embeddings. pytest.skip("Skipping test with margin=0 and start_positions=True") - if start_positions == True and cp_size > 1: - # `start_positions` is only supported for `cp_size=1` and inference. - pytest.skip("Skipping test with cp_size>1 and start_positions=True") - device = torch.device("cuda:0") batch_size, head_num = 2, 64 t = torch.rand( @@ -102,11 +98,8 @@ def test_fused_rope( cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) - - if not isinstance(start_positions, torch.Tensor): - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() - + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() t.grad = None # fused @@ -121,17 +114,12 @@ def test_fused_rope( cp_rank=cp_rank, ) loss_fused = loss_func(output_fused) - - if not isinstance(start_positions, torch.Tensor): - loss_fused.backward() - grad_fused = t.grad.detach().clone() + loss_fused.backward() + grad_fused = t.grad.detach().clone() t.grad = None torch.testing.assert_close(output_fused, output_unfused) - - if not isinstance(start_positions, torch.Tensor): - torch.testing.assert_close(grad_fused, grad_unfused) - + torch.testing.assert_close(grad_fused, grad_unfused) assert output_fused.is_contiguous() @@ -156,10 +144,6 @@ def test_fused_rope_thd( margin: int, ) -> None: - if start_positions == True and cp_size > 1: - # `start_positions` is only supported for `cp_size=1` and inference. - pytest.skip("Skipping test with cp_size>1 and start_positions=True") - device = torch.device("cuda:0") batch_size, head_num = 2, 64 cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048] @@ -214,10 +198,8 @@ def test_fused_rope_thd( cp_rank=cp_rank, ).to(dtype) loss_unfused = loss_func(output_unfused) - - if not isinstance(start_positions, torch.Tensor): - loss_unfused.backward() - grad_unfused = t.grad.detach().clone() + loss_unfused.backward() + grad_unfused = t.grad.detach().clone() t.grad = None # fused @@ -233,18 +215,142 @@ def test_fused_rope_thd( cp_rank=cp_rank, ) loss_fused = loss_func(output_fused) - - if not isinstance(start_positions, torch.Tensor): - loss_fused.backward() - grad_fused = t.grad.detach().clone() + loss_fused.backward() + grad_fused = t.grad.detach().clone() t.grad = None torch.testing.assert_close(output_fused, output_unfused) + torch.testing.assert_close(grad_fused, grad_unfused) + assert output_fused.is_contiguous() - if not isinstance(start_positions, torch.Tensor): - torch.testing.assert_close(grad_fused, grad_unfused) - assert output_fused.is_contiguous() +@pytest.mark.parametrize("start_positions", [False, True]) +@pytest.mark.parametrize("dtype", [torch.bfloat16]) +@pytest.mark.parametrize("hidden_size", [128, 256]) +@pytest.mark.parametrize("rotary_percent", [1.0]) +@pytest.mark.parametrize("loss_func", [_overlapping_grad]) +@pytest.mark.parametrize("cp_size", [2]) +@pytest.mark.parametrize("interleaved", [False, True]) +def test_unfused_rope_thd_vs_bshd( + dtype: torch.dtype, + hidden_size: int, + rotary_percent: float, + loss_func: Callable, + cp_size: int, + interleaved: bool, + start_positions: bool, +) -> None: + """ + This is just a sanity check to ensure that the unfused RoPE in THD/SBHD/BSHD + formats are the same. 
+ """ + device = torch.device("cuda:0") + seqlen, max_seqlen = 16, 2048 + batch_size, head_num = 4, 256 + + # NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and + # that causes unexpected issues. + seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32) + + cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to( + device=device, dtype=torch.int32 + ) + + # Create a tensor in THD format + thd = torch.rand( + (cu_seqlens[-1] // cp_size, head_num, hidden_size), + dtype=dtype, + device=device, + ) + thd.requires_grad = True + + # Clone the tensor to create a tensor in BSHD format + bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach() + bshd = bshd.to(dtype=dtype, device=device) + bshd.requires_grad = True + + # Clone the tensor to create a tensor in SBHD format + sbhd = bshd.transpose(1, 0).clone().detach() + sbhd = sbhd.to(dtype=dtype, device=device) + sbhd.requires_grad = True + + rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved) + emb = rotary_pos_emb(max_seqlen) + assert emb.is_contiguous() + + start_positions = cu_seqlens[:-1] if start_positions else None + + for cp_rank in range(cp_size): + # unfused bshd + output_unfused_bshd = apply_rotary_pos_emb( + bshd.float(), + emb, + start_positions=start_positions, + interleaved=interleaved, + fused=False, + tensor_format="bshd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + loss_unfused_bshd = loss_func(output_unfused_bshd) + loss_unfused_bshd.backward() + grad_unfused_bshd = bshd.grad.detach().clone() + bshd.grad = None + + # unfused sbhd + output_unfused_sbhd = apply_rotary_pos_emb( + sbhd.float(), + emb, + start_positions=start_positions, + interleaved=interleaved, + fused=False, + tensor_format="sbhd", + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + loss_unfused_sbhd = loss_func(output_unfused_sbhd) + loss_unfused_sbhd.backward() + grad_unfused_sbhd = sbhd.grad.detach().clone() + sbhd.grad = None + + # unfused thd + output_unfused_thd = apply_rotary_pos_emb( + thd.float(), + emb, + start_positions=start_positions, + tensor_format="thd", + interleaved=interleaved, + fused=False, + cu_seqlens=cu_seqlens, + cp_size=cp_size, + cp_rank=cp_rank, + ).to(dtype) + + loss_unfused_thd = loss_func(output_unfused_thd) + loss_unfused_thd.backward() + grad_unfused_thd = thd.grad.detach().clone() + thd.grad = None + + torch.testing.assert_close( + output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd + ) + torch.testing.assert_close( + output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape), + output_unfused_thd, + ) + torch.testing.assert_close( + grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd + ) + torch.testing.assert_close( + grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd + ) + + assert output_unfused_thd.is_contiguous() + assert output_unfused_bshd.is_contiguous() + assert output_unfused_sbhd.is_contiguous() @pytest.mark.parametrize("start_positions", [True, False]) diff --git a/transformer_engine/common/fused_rope/fused_rope.cu b/transformer_engine/common/fused_rope/fused_rope.cu index ccd0bc44c5..597a5d3c29 100644 --- a/transformer_engine/common/fused_rope/fused_rope.cu +++ b/transformer_engine/common/fused_rope/fused_rope.cu @@ -155,18 +155,18 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq 
cur_seqlens = s; } - int s_id_for_freqs; + // Offset the RoPE embedding by start_positions if provided. + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + // If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order. if (cp_size > 1) { assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { - s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs = - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } - } else { - int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; - s_id_for_freqs = s_id + begin_offset; } fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, @@ -175,11 +175,11 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq template __global__ void fused_rope_backward_kernel( - const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst, - const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h, - const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h, - const int o_stride_d) { + const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions, + scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s, + const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b, + const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b, + const int o_stride_h, const int o_stride_d) { int s_id = blockIdx.x, b_id = blockIdx.y; int offset_block, offset_block_dst; int cur_seqlens; @@ -197,17 +197,18 @@ __global__ void fused_rope_backward_kernel( cur_seqlens = s; } - int s_id_for_freqs; + // Offset the RoPE embedding by start_positions if provided. + int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id]; + int s_id_for_freqs = s_id + begin_offset; + + // If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order. 
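The forward and backward kernels now share the same frequency-index computation: apply the per-sequence `start_positions` offset first, then remap the local token index according to the dual-chunk order used by context parallelism. As a reference, a minimal Python sketch of that mapping (the helper name `rope_freq_index` and its scalar arguments are illustrative, not part of the CUDA code):

    def rope_freq_index(s_id, cur_seqlen, cp_size, cp_rank, start_position=0):
        """Map a local token index on this CP rank to an index into the freqs table.

        Mirrors the kernel logic: add the per-sequence start offset, then, for
        cp_size > 1, shift by this rank's position in the dual-chunk
        (load-balanced) split of the full sequence.
        """
        idx = s_id + start_position
        if cp_size > 1:
            assert cur_seqlen % 2 == 0
            half = cur_seqlen // 2
            if s_id < half:
                # First chunk of this rank starts at cp_rank * half in the full sequence.
                idx += cp_rank * half
            else:
                # Second chunk sits mirrored near the end of the full sequence.
                idx += cur_seqlen * cp_size - (cp_rank + 1) * half - half
        return idx

For example, with cp_size=2 and a per-rank sequence of 4 tokens (and no start offset), rank 0 reads freqs rows [0, 1, 6, 7] and rank 1 reads rows [2, 3, 4, 5], which is the dual-chunk layout produced by load-balanced context parallelism.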
if (cp_size > 1) { assert(cur_seqlens % 2 == 0); if (s_id < cur_seqlens / 2) { - s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2; + s_id_for_freqs += cp_rank * cur_seqlens / 2; } else { - s_id_for_freqs = - cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2; + s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2; } - } else { - s_id_for_freqs = s_id; } fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block, @@ -495,12 +496,12 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c template void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens, - const float *freqs, scalar_t *input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const float *freqs, const int *start_positions, + scalar_t *input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { int warps_per_block = h < 16 ? 4 : 8; dim3 blocks(s, b); dim3 threads(THREADS_PER_WARP, warps_per_block); @@ -521,9 +522,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se const int o_stride_d = 1; fused_rope_backward_kernel<<>>( - output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2, - stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h, - o_stride_d); + output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank, + s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, + o_stride_h, o_stride_d); NVTE_CHECK_CUDA(cudaGetLastError()); } @@ -590,16 +591,18 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten } void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs, - Tensor *input_grads, const NVTE_QKV_Format qkv_format, - const bool interleaved, const int cp_size, const int cp_rank, const int s, - const int b, const int h, const int d, const int d2, - const int stride_s_or_t, const int stride_b, const int stride_h, - const int stride_d, cudaStream_t stream) { + const Tensor &start_positions, Tensor *input_grads, + const NVTE_QKV_Format qkv_format, const bool interleaved, + const int cp_size, const int cp_rank, const int s, const int b, + const int h, const int d, const int d2, const int stride_s_or_t, + const int stride_b, const int stride_h, const int stride_d, + cudaStream_t stream) { TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT( output_grads.data.dtype, scalar_t, fused_rope_backward_launcher(reinterpret_cast(output_grads.data.dptr), reinterpret_cast(cu_seqlens.data.dptr), reinterpret_cast(freqs.data.dptr), + reinterpret_cast(start_positions.data.dptr), reinterpret_cast(input_grads->data.dptr), qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);); @@ -663,18 +666,18 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens } void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor 
cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream) { + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream) { NVTE_API_CALL(nvte_fused_rope_backward); using namespace transformer_engine; fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens), - *convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads), - qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t, - stride_b, stride_h, stride_d, stream); + *convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions), + convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size, + cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream); } void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs, diff --git a/transformer_engine/common/include/transformer_engine/fused_rope.h b/transformer_engine/common/include/transformer_engine/fused_rope.h index 610868f932..19047f463b 100644 --- a/transformer_engine/common/include/transformer_engine/fused_rope.h +++ b/transformer_engine/common/include/transformer_engine/fused_rope.h @@ -51,6 +51,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens * \param[in] cu_seqlens The cumulative sum of sequence lengths tensor. * (Required for the thd format, empty tensor for other formats) * \param[in] freqs The freqs tensor. + * \param[in] start_positions The beginning offsets for applying RoPE embeddings. * \param[out] input_grads Input gradient tensor to calculate. * \param[in] qkv_format QKV format. * \param[in] interleaved Whether to use interleaved rotary position embedding. @@ -68,12 +69,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens * \param[in] stream CUDA stream used for the operation. */ void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens, - const NVTETensor freqs, NVTETensor input_grads, - const NVTE_QKV_Format qkv_format, const bool interleaved, - const int cp_size, const int cp_rank, const int s, const int b, - const int h, const int d, const int d2, const int stride_s_or_t, - const int stride_b, const int stride_h, const int stride_d, - cudaStream_t stream); + const NVTETensor freqs, const NVTETensor start_positions, + NVTETensor input_grads, const NVTE_QKV_Format qkv_format, + const bool interleaved, const int cp_size, const int cp_rank, + const int s, const int b, const int h, const int d, const int d2, + const int stride_s_or_t, const int stride_b, const int stride_h, + const int stride_d, cudaStream_t stream); /*! \brief Apply rotary positional embedding to the combined QKV input tensor. 
* diff --git a/transformer_engine/pytorch/attention/rope.py b/transformer_engine/pytorch/attention/rope.py index cc23d65a3e..0e1222c22f 100644 --- a/transformer_engine/pytorch/attention/rope.py +++ b/transformer_engine/pytorch/attention/rope.py @@ -149,7 +149,7 @@ def forward( cp_size, cp_rank, ) - ctx.save_for_backward(freqs, cu_seqlens) + ctx.save_for_backward(freqs, cu_seqlens, start_positions) ctx.tensor_format = tensor_format ctx.cp_size = cp_size ctx.cp_rank = cp_rank @@ -160,10 +160,11 @@ def forward( @staticmethod def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: """Fused RoPE backward.""" - freqs, cu_seqlens = ctx.saved_tensors + freqs, cu_seqlens, start_positions = ctx.saved_tensors grad_input = tex.fused_rope_backward( grad_output, freqs, + start_positions, QKVFormat[ctx.tensor_format], ctx.interleaved, cu_seqlens, @@ -171,7 +172,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.cp_rank, ) - return grad_input, None, None, None, None, None, None, None + return grad_input, None, None, None, None, None, None, None, None class FusedQKVRoPEFunc(torch.autograd.Function): @@ -278,7 +279,6 @@ def _rotate_half(x: torch.Tensor, interleaved: bool) -> torch.Tensor: def _apply_rotary_pos_emb_base( t: torch.Tensor, freqs: torch.Tensor, - start_positions: torch.Tensor = None, tensor_format: str = "sbhd", interleaved: bool = False, ) -> torch.Tensor: @@ -291,45 +291,19 @@ def _apply_rotary_pos_emb_base( Input tensor of shape `[s, b, h, d]` or `[b, s, h, d]`, on which rotary positional embedding will be applied. freqs: torch.Tensor - Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` and dtype 'float', - with `s2 >= s` and `d2 <= d`. - start_positions: torch.Tensor, default = None. - Tokens in a sequence `i` should be applied with position encoding offset by - `start_positions[i]`. If `start_positions=None`, there's no offset. + Rotary positional embedding tensor of shape `[s2, 1, 1, d2]` or `[s2, b, 1, d2]` + and dtype 'float', with `s2 >= s` and `d2 <= d`. tensor_format: {'sbhd', 'bshd'}, default = 'sbhd' Should be `bshd` if `t` is of shape `[bs, seq, ...]`, or `sbhd` if `t` is of shape `[seq, bs, ...]`. interleaved: bool, default = False Whether to use interleaved rotary position embedding. """ - max_seq_len = freqs.shape[0] - cur_seq_len = t.shape[1] if tensor_format == "bshd" else t.shape[0] - - # In case `start_positions` are provided, create a staggered `freqs` tensor - # offset by the values in `start_positions`. - # `start_positions` is only supported for `cp_size=1` and inference. - if start_positions is not None: - max_offset = torch.max(start_positions) - assert ( - max_offset + cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only suppported up to {max_seq_len} sequence length!" - - # Stack staggered rope embeddings along the batch dimension - freqs = torch.concatenate([freqs[i : i + cur_seq_len] for i in start_positions], dim=1) - - # Note that from this point, `freqs` has a shape `(s,b,1,d)`. - - # Only apply the rotary embeddings up to the sequence length of the running - # input. - assert ( - cur_seq_len <= max_seq_len - ), f"Rotary Embeddings only supported up to {max_seq_len} sequence length!" 
- freqs = freqs[:cur_seq_len] - # [seq, 1, 1, dim] -> [1, seq, 1, dim] or # [seq, b, 1, dim] -> [b, seq, 1, dim] if tensor_format == "bshd": freqs = freqs.transpose(0, 1) + # cos/sin first then dtype conversion for better precision cos_ = torch.cos(freqs).to(t.dtype) sin_ = torch.sin(freqs).to(t.dtype) @@ -366,7 +340,7 @@ def _get_freqs_on_this_cp_rank( ) # cp_size == 1 - return freqs + return freqs[:seqlen] def apply_rotary_pos_emb( @@ -388,13 +362,13 @@ def apply_rotary_pos_emb( Training: qkv_formats: "thd", "bshd", "sbhd" context parallel: yes - start_positions: no + start_positions: yes interleaving: yes Inference: qkv_formats: "thd", "bshd", "sbhd" context parallelism: no start_positions: yes - interleaving: yes + interleaving: yes Parameters ---------- @@ -423,22 +397,17 @@ def apply_rotary_pos_emb( cp_rank: int, default = 0. Context parallel rank. Only valid when `tensor_format` is 'thd' and `fused` is True. """ - - # `start_positions` is only supported for `cp_size=1` and inference. - assert not ( - cp_size > 1 and start_positions is not None - ), """start_positions != None with CP SIZE > 1 is not supported!""" - assert ( tensor_format != "thd" or cu_seqlens is not None ), "cu_seqlens must not be None when tensor_format is 'thd'." + # Fused apply rope logic for THD/BSHD/SBHD formats if fused: return FusedRoPEFunc.apply( t, freqs, start_positions, tensor_format, interleaved, cu_seqlens, cp_size, cp_rank ) - # Unfused THD format + # Unfused apply rope logic for THD format if tensor_format == "thd": cu_seqlens = cu_seqlens // cp_size seqlens = (cu_seqlens[1:] - cu_seqlens[:-1]).tolist() @@ -447,15 +416,18 @@ def apply_rotary_pos_emb( # `s1hd` tensors (for each sequence) and applies rotary embedding to # those sequences individually. # Note that if `start_positions` is not `None`, then for each sequence, - # it's corresponding rope offset is also supplied from `start_positions` - # individually. + # the freqs supplied are offset by the corresponding `start_positions` value. return torch.cat( [ _apply_rotary_pos_emb_base( x.unsqueeze(1), - _get_freqs_on_this_cp_rank(freqs, x.size(0), cp_size, cp_rank), - start_positions=( - start_positions[idx : idx + 1] if start_positions is not None else None + _get_freqs_on_this_cp_rank( + ( + freqs[start_positions[idx] :] if start_positions is not None else freqs + ), # offset the freqs + x.size(0), + cp_size, + cp_rank, ), interleaved=interleaved, ) @@ -463,17 +435,28 @@ def apply_rotary_pos_emb( ] ).squeeze(1) - # Unfused SBHD/BSHD format + # Unfused apply rope logic for SBHD/BSHD format follows ... + if tensor_format == "sbhd": seqlen = t.size(0) elif tensor_format == "bshd": seqlen = t.size(1) else: raise ValueError(f"Unsupported tensor_format: {tensor_format}.") + + if start_positions is not None: + max_offset = torch.max(start_positions) + assert ( + max_offset + seqlen * cp_size <= freqs.shape[0] + ), f"Rotary Embeddings only suppported up to {freqs.shape[0]} sequence length!" + + # Stack staggered rope embeddings along the batch dimension + freqs = torch.concatenate([freqs[i : i + seqlen * cp_size] for i in start_positions], dim=1) + # Note that from this point, `freqs` has a shape `(s,b,1,d)`. 
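+        # Worked illustration (hypothetical values, not part of the function): with
+        # start_positions = [0, 3], seqlen * cp_size = 4, and freqs of shape
+        # (s2, 1, 1, d2), the concatenation builds a per-batch table in which
+        # freqs[:, 0] holds rows freqs[0:4] and freqs[:, 1] holds rows freqs[3:7]
+        # of the original embeddings, so each sequence is rotated starting from its
+        # own offset before _get_freqs_on_this_cp_rank slices out this rank's part.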
+ return _apply_rotary_pos_emb_base( t, _get_freqs_on_this_cp_rank(freqs, seqlen, cp_size, cp_rank), - start_positions, tensor_format, interleaved=interleaved, ) @@ -505,7 +488,7 @@ def apply_fused_qkv_rotary_pos_emb( qkv_formats: "bshd", "sbhd" context parallelism: no start_positions: yes - interleaving: yes + interleaving: yes Parameters ---------- diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index d86a96959c..4a13abb9f7 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -346,6 +346,7 @@ at::Tensor fused_rope_forward(const at::Tensor &input, const at::Tensor &freqs, const int cp_rank); at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const std::optional start_positions, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, const int cp_rank); diff --git a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp index 064da8a670..d1dcf68c3d 100644 --- a/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp +++ b/transformer_engine/pytorch/csrc/extensions/apply_rope.cpp @@ -163,6 +163,7 @@ std::tuple fused_qkv_rope_forward( } at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor &freqs, + const std::optional start_positions, const NVTE_QKV_Format qkv_format, const bool interleaved, const std::optional cu_seqlens, const int cp_size, const int cp_rank) { @@ -180,6 +181,12 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto freqs_cu = makeTransformerEngineTensor(freqs); auto input_grads_cu = makeTransformerEngineTensor(input_grads); + auto start_positions_cu = TensorWrapper(); // empty start_positions tensor + if (start_positions) { + start_positions_cu = makeTransformerEngineTensor(start_positions.value()); + TORCH_CHECK(start_positions_cu.ndim() == 1, "expected 1D tensor"); + } + if (qkv_format == NVTE_QKV_Format::NVTE_THD) { TORCH_CHECK(output_grads.dim() == 3, "expected 3D tensor"); TORCH_CHECK(cu_seqlens.has_value(), "expected cu_seqlens tensor"); @@ -208,8 +215,8 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = makeTransformerEngineTensor(cu_seqlens.value()); nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, - max_s, b, h, d, d2, stride_t, + start_positions_cu.data(), input_grads_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, max_s, b, h, d, d2, stride_t, /*stride_b=*/0, stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; @@ -246,9 +253,9 @@ at::Tensor fused_rope_backward(const at::Tensor &output_grads, const at::Tensor auto cu_seqlens_cu = TensorWrapper(); // empty cu_seqlens tensor nvte_fused_rope_backward(output_grads_cu.data(), cu_seqlens_cu.data(), freqs_cu.data(), - input_grads_cu.data(), qkv_format, interleaved, cp_size, cp_rank, s, b, - h, d, d2, stride_s, stride_b, stride_h, stride_d, - at::cuda::getCurrentCUDAStream()); + start_positions_cu.data(), input_grads_cu.data(), qkv_format, + interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s, stride_b, + stride_h, stride_d, at::cuda::getCurrentCUDAStream()); return input_grads; }
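Taken together, these changes allow `start_positions` to be combined with context parallelism in both the fused and unfused paths, with gradients now flowing through the fused backward kernel as well. A minimal usage sketch modeled on the updated tests; the import path, tensor sizes, and the offset values are assumptions for illustration:

    import torch
    from transformer_engine.pytorch.attention.rope import (
        RotaryPositionEmbedding,
        apply_rotary_pos_emb,
    )

    device = torch.device("cuda:0")
    batch, seqlen, heads, dim = 2, 512, 8, 128   # seqlen is this CP rank's shard
    cp_size, cp_rank = 2, 0

    t = torch.rand(seqlen, batch, heads, dim, device=device, requires_grad=True)
    freqs = RotaryPositionEmbedding(dim)(4096)   # RoPE table covering the full length

    # Per-sequence offsets into the RoPE table (illustrative values).
    start_positions = torch.tensor([0, 64], dtype=torch.int32, device=device)

    out = apply_rotary_pos_emb(
        t,
        freqs,
        start_positions=start_positions,
        tensor_format="sbhd",
        fused=True,       # fused backward now also receives start_positions
        cp_size=cp_size,
        cp_rank=cp_rank,
    )
    out.sum().backward()  # gradients flow through the fused RoPE backward kernel

Before this change, the same call would have tripped the `cp_size > 1 and start_positions is not None` assertion that this diff removes from `apply_rotary_pos_emb`, and the tests skipped that combination entirely.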