170 changes: 138 additions & 32 deletions tests/pytorch/test_fused_rope.py
@@ -58,10 +58,6 @@ def test_fused_rope(
# are with the maximum length of the rope embeddings.
pytest.skip("Skipping test with margin=0 and start_positions=True")

if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")

device = torch.device("cuda:0")
batch_size, head_num = 2, 64
t = torch.rand(
@@ -102,11 +98,8 @@
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)

if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()

loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None

# fused
@@ -121,17 +114,12 @@
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)

if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None

torch.testing.assert_close(output_fused, output_unfused)

if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)

torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()


@@ -156,10 +144,6 @@ def test_fused_rope_thd(
margin: int,
) -> None:

if start_positions == True and cp_size > 1:
# `start_positions` is only supported for `cp_size=1` and inference.
pytest.skip("Skipping test with cp_size>1 and start_positions=True")

device = torch.device("cuda:0")
batch_size, head_num = 2, 64
cu_seqlens = [0, 400, 542, 711, 727, 752, 1270, 1426, 1450, 1954, 2044, 2048]
@@ -214,10 +198,8 @@ def test_fused_rope_thd(
cp_rank=cp_rank,
).to(dtype)
loss_unfused = loss_func(output_unfused)

if not isinstance(start_positions, torch.Tensor):
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
loss_unfused.backward()
grad_unfused = t.grad.detach().clone()
t.grad = None

# fused
@@ -233,18 +215,142 @@
cp_rank=cp_rank,
)
loss_fused = loss_func(output_fused)

if not isinstance(start_positions, torch.Tensor):
loss_fused.backward()
grad_fused = t.grad.detach().clone()
loss_fused.backward()
grad_fused = t.grad.detach().clone()
t.grad = None

torch.testing.assert_close(output_fused, output_unfused)
torch.testing.assert_close(grad_fused, grad_unfused)
assert output_fused.is_contiguous()

if not isinstance(start_positions, torch.Tensor):
torch.testing.assert_close(grad_fused, grad_unfused)

assert output_fused.is_contiguous()
@pytest.mark.parametrize("start_positions", [False, True])
@pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("hidden_size", [128, 256])
@pytest.mark.parametrize("rotary_percent", [1.0])
@pytest.mark.parametrize("loss_func", [_overlapping_grad])
@pytest.mark.parametrize("cp_size", [2])
@pytest.mark.parametrize("interleaved", [False, True])
def test_unfused_rope_thd_vs_bshd(
dtype: torch.dtype,
hidden_size: int,
rotary_percent: float,
loss_func: Callable,
cp_size: int,
interleaved: bool,
start_positions: bool,
) -> None:
"""
A sanity check that the unfused RoPE implementation produces identical results
in the THD, SBHD, and BSHD tensor formats.
"""
device = torch.device("cuda:0")
seqlen, max_seqlen = 16, 2048
batch_size, head_num = 4, 256

# NOTE: dtype=torch.int32 is important, otherwise the cumsum will be in int64 and
# that causes unexpected issues.
seq_lens = torch.tensor([seqlen for _ in range(batch_size)], dtype=torch.int32)

cu_seqlens = torch.cumsum(torch.cat([torch.zeros(1, dtype=torch.int32), seq_lens]), dim=0).to(
device=device, dtype=torch.int32
)

# Create a tensor in THD format
thd = torch.rand(
(cu_seqlens[-1] // cp_size, head_num, hidden_size),
dtype=dtype,
device=device,
)
thd.requires_grad = True

# Clone the tensor to create a tensor in BSHD format
bshd = thd.view(batch_size, -1, head_num, hidden_size).clone().detach()
bshd = bshd.to(dtype=dtype, device=device)
bshd.requires_grad = True

# Clone the tensor to create a tensor in SBHD format
sbhd = bshd.transpose(1, 0).clone().detach()
sbhd = sbhd.to(dtype=dtype, device=device)
sbhd.requires_grad = True

rotary_pos_emb = RotaryPositionEmbedding(hidden_size, rotary_percent, interleaved=interleaved)
emb = rotary_pos_emb(max_seqlen)
assert emb.is_contiguous()

start_positions = cu_seqlens[:-1] if start_positions else None

for cp_rank in range(cp_size):
# unfused bshd
output_unfused_bshd = apply_rotary_pos_emb(
bshd.float(),
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
tensor_format="bshd",
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)
loss_unfused_bshd = loss_func(output_unfused_bshd)
loss_unfused_bshd.backward()
grad_unfused_bshd = bshd.grad.detach().clone()
bshd.grad = None

# unfused sbhd
output_unfused_sbhd = apply_rotary_pos_emb(
sbhd.float(),
emb,
start_positions=start_positions,
interleaved=interleaved,
fused=False,
tensor_format="sbhd",
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)

loss_unfused_sbhd = loss_func(output_unfused_sbhd)
loss_unfused_sbhd.backward()
grad_unfused_sbhd = sbhd.grad.detach().clone()
sbhd.grad = None

# unfused thd
output_unfused_thd = apply_rotary_pos_emb(
thd.float(),
emb,
start_positions=start_positions,
tensor_format="thd",
interleaved=interleaved,
fused=False,
cu_seqlens=cu_seqlens,
cp_size=cp_size,
cp_rank=cp_rank,
).to(dtype)

loss_unfused_thd = loss_func(output_unfused_thd)
loss_unfused_thd.backward()
grad_unfused_thd = thd.grad.detach().clone()
thd.grad = None

torch.testing.assert_close(
output_unfused_bshd.reshape(*output_unfused_thd.shape), output_unfused_thd
)
torch.testing.assert_close(
output_unfused_sbhd.transpose(1, 0).reshape(*output_unfused_thd.shape),
output_unfused_thd,
)
torch.testing.assert_close(
grad_unfused_bshd.reshape(*grad_unfused_thd.shape), grad_unfused_thd
)
torch.testing.assert_close(
grad_unfused_sbhd.transpose(1, 0).reshape(*grad_unfused_thd.shape), grad_unfused_thd
)

assert output_unfused_thd.is_contiguous()
assert output_unfused_bshd.is_contiguous()
assert output_unfused_sbhd.is_contiguous()


@pytest.mark.parametrize("start_positions", [True, False])
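For context, here is a minimal Python sketch of the position mapping that the tests above exercise and that the CUDA kernel changes below implement: the `start_positions` offset is applied first, and with context parallelism each rank's local tokens map onto its front chunk and its mirrored back chunk of the full sequence. `global_rope_position` is a hypothetical helper written for illustration only; it is not part of Transformer Engine or this test suite.

```python
def global_rope_position(s_id: int, seqlen: int, cp_size: int, cp_rank: int,
                         start_position: int = 0) -> int:
    """Map a local token index on one context-parallel rank to the global RoPE position.

    `seqlen` is the local (per-rank) sequence length; `start_position` is the optional
    per-sequence offset (0 when `start_positions` is not provided).
    """
    pos = s_id + start_position
    if cp_size > 1:
        # Dual-chunk layout: each rank holds chunk `cp_rank` from the front of the
        # full sequence plus the mirrored chunk from the back.
        assert seqlen % 2 == 0
        half = seqlen // 2
        if s_id < half:
            pos += cp_rank * half
        else:
            pos += seqlen * cp_size - (cp_rank + 1) * half - half
    return pos


if __name__ == "__main__":
    # For cp_size=2 and 8 local tokens per rank, the two ranks together should cover
    # every global position 0..15 exactly once.
    seqlen, cp_size = 8, 2
    covered = sorted(
        global_rope_position(s, seqlen, cp_size, r)
        for r in range(cp_size)
        for s in range(seqlen)
    )
    assert covered == list(range(seqlen * cp_size))
```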
85 changes: 44 additions & 41 deletions transformer_engine/common/fused_rope/fused_rope.cu
@@ -155,18 +155,18 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq
cur_seqlens = s;
}

int s_id_for_freqs;
// Offset the RoPE embedding by start_positions if provided.
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
int s_id_for_freqs = s_id + begin_offset;

// If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order.
if (cp_size > 1) {
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
s_id_for_freqs += cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2;
}
} else {
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
s_id_for_freqs = s_id + begin_offset;
}

fused_rope_block_forward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
@@ -175,11 +175,11 @@ __global__ void fused_rope_forward_kernel(const scalar_t *src, const int *cu_seq

template <typename scalar_t>
__global__ void fused_rope_backward_kernel(
const scalar_t *src, const int *cu_seqlens, const float *freqs, scalar_t *dst,
const bool interleaved, const int cp_size, const int cp_rank, const int s, const int h,
const int d, const int d2, const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, const int o_stride_s_or_t, const int o_stride_b, const int o_stride_h,
const int o_stride_d) {
const scalar_t *src, const int *cu_seqlens, const float *freqs, const int *start_positions,
scalar_t *dst, const bool interleaved, const int cp_size, const int cp_rank, const int s,
const int h, const int d, const int d2, const int stride_s_or_t, const int stride_b,
const int stride_h, const int stride_d, const int o_stride_s_or_t, const int o_stride_b,
const int o_stride_h, const int o_stride_d) {
int s_id = blockIdx.x, b_id = blockIdx.y;
int offset_block, offset_block_dst;
int cur_seqlens;
Expand All @@ -197,17 +197,18 @@ __global__ void fused_rope_backward_kernel(
cur_seqlens = s;
}

int s_id_for_freqs;
// Offset the RoPE embedding by start_positions if provided.
int begin_offset = (start_positions == nullptr) ? 0 : start_positions[b_id];
int s_id_for_freqs = s_id + begin_offset;

// If CP_SIZE > 1, offset the RoPE embedding by cp_rank based on the dual-chunk order.
if (cp_size > 1) {
assert(cur_seqlens % 2 == 0);
if (s_id < cur_seqlens / 2) {
s_id_for_freqs = s_id + cp_rank * cur_seqlens / 2;
s_id_for_freqs += cp_rank * cur_seqlens / 2;
} else {
s_id_for_freqs =
cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 + s_id - cur_seqlens / 2;
s_id_for_freqs += cur_seqlens * cp_size - (cp_rank + 1) * cur_seqlens / 2 - cur_seqlens / 2;
}
} else {
s_id_for_freqs = s_id;
}

fused_rope_block_backward(src, freqs, dst, interleaved, s_id_for_freqs, offset_block,
@@ -495,12 +496,12 @@ void fused_rope_forward_launcher(const scalar_t *input, const int *cu_seqlens, c

template <typename scalar_t>
void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_seqlens,
const float *freqs, scalar_t *input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
const float *freqs, const int *start_positions,
scalar_t *input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
int warps_per_block = h < 16 ? 4 : 8;
dim3 blocks(s, b);
dim3 threads(THREADS_PER_WARP, warps_per_block);
Expand All @@ -521,9 +522,9 @@ void fused_rope_backward_launcher(const scalar_t *output_grads, const int *cu_se
const int o_stride_d = 1;

fused_rope_backward_kernel<<<blocks, threads, shared_mem_size, stream>>>(
output_grads, cu_seqlens, freqs, input_grads, interleaved, cp_size, cp_rank, s, h, d, d2,
stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b, o_stride_h,
o_stride_d);
output_grads, cu_seqlens, freqs, start_positions, input_grads, interleaved, cp_size, cp_rank,
s, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, o_stride_s_or_t, o_stride_b,
o_stride_h, o_stride_d);
NVTE_CHECK_CUDA(cudaGetLastError());
}

@@ -590,16 +591,18 @@ void fused_rope_forward(const Tensor &input, const Tensor &cu_seqlens, const Ten
}

void fused_rope_backward(const Tensor &output_grads, const Tensor &cu_seqlens, const Tensor &freqs,
Tensor *input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank, const int s,
const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
const Tensor &start_positions, Tensor *input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(
output_grads.data.dtype, scalar_t,
fused_rope_backward_launcher(reinterpret_cast<const scalar_t *>(output_grads.data.dptr),
reinterpret_cast<const int *>(cu_seqlens.data.dptr),
reinterpret_cast<const float *>(freqs.data.dptr),
reinterpret_cast<const int *>(start_positions.data.dptr),
reinterpret_cast<scalar_t *>(input_grads->data.dptr), qkv_format,
interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream););
Expand Down Expand Up @@ -663,18 +666,18 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
}

void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream) {
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream) {
NVTE_API_CALL(nvte_fused_rope_backward);
using namespace transformer_engine;
fused_rope_backward(*convertNVTETensorCheck(output_grads), *convertNVTETensorCheck(cu_seqlens),
*convertNVTETensorCheck(freqs), convertNVTETensorCheck(input_grads),
qkv_format, interleaved, cp_size, cp_rank, s, b, h, d, d2, stride_s_or_t,
stride_b, stride_h, stride_d, stream);
*convertNVTETensorCheck(freqs), *convertNVTETensorCheck(start_positions),
convertNVTETensorCheck(input_grads), qkv_format, interleaved, cp_size,
cp_rank, s, b, h, d, d2, stride_s_or_t, stride_b, stride_h, stride_d, stream);
}

void nvte_fused_qkv_rope_forward(const NVTETensor qkv_input, const NVTETensor q_freqs,
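Why the backward kernel now also receives `start_positions`: the gradient of a rotary embedding is the transposed (inverse) rotation applied to the upstream gradient at the same position, so the backward pass has to reproduce exactly the same `s_id_for_freqs` offsets as the forward pass, including the `start_positions` and context-parallel offsets. A minimal numerical sketch of that identity for a single 2-D rotation (illustration only, not taken from this diff):

```python
import math

import torch

theta = 0.3
c, s = math.cos(theta), math.sin(theta)
R = torch.tensor([[c, -s], [s, c]])  # the 2-D rotation RoPE applies to each pair
x = torch.randn(2, requires_grad=True)
y = R @ x                            # forward: rotate the input pair
dy = torch.randn(2)                  # some upstream gradient
y.backward(dy)
# backward: the same rotation, transposed, applied at the same position/angle
assert torch.allclose(x.grad, R.transpose(0, 1) @ dy)
```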
@@ -51,6 +51,7 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
* \param[in] cu_seqlens The cumulative sum of sequence lengths tensor.
* (Required for the thd format, empty tensor for other formats)
* \param[in] freqs The freqs tensor.
* \param[in] start_positions The beginning offsets for applying RoPE embeddings.
* \param[out] input_grads Input gradient tensor to calculate.
* \param[in] qkv_format QKV format.
* \param[in] interleaved Whether to use interleaved rotary position embedding.
Expand All @@ -68,12 +69,12 @@ void nvte_fused_rope_forward(const NVTETensor input, const NVTETensor cu_seqlens
* \param[in] stream CUDA stream used for the operation.
*/
void nvte_fused_rope_backward(const NVTETensor output_grads, const NVTETensor cu_seqlens,
const NVTETensor freqs, NVTETensor input_grads,
const NVTE_QKV_Format qkv_format, const bool interleaved,
const int cp_size, const int cp_rank, const int s, const int b,
const int h, const int d, const int d2, const int stride_s_or_t,
const int stride_b, const int stride_h, const int stride_d,
cudaStream_t stream);
const NVTETensor freqs, const NVTETensor start_positions,
NVTETensor input_grads, const NVTE_QKV_Format qkv_format,
const bool interleaved, const int cp_size, const int cp_rank,
const int s, const int b, const int h, const int d, const int d2,
const int stride_s_or_t, const int stride_b, const int stride_h,
const int stride_d, cudaStream_t stream);

/*! \brief Apply rotary positional embedding to the combined QKV input tensor.
*
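Taken together, these changes let the fused RoPE path back-propagate when `start_positions` is supplied, rather than restricting it to inference. A minimal usage sketch under assumptions — the import path, shapes, and dtypes below are illustrative and not taken from this diff:

```python
import torch

# Import path assumed; these are the public helpers exercised by the tests above.
from transformer_engine.pytorch.attention import (
    RotaryPositionEmbedding,
    apply_rotary_pos_emb,
)

seqlen, batch, heads, hidden = 512, 2, 8, 128
t = torch.randn(seqlen, batch, heads, hidden, device="cuda", requires_grad=True)

# Frequencies for up to 2048 positions; one starting offset per sequence in the batch.
emb = RotaryPositionEmbedding(hidden)(2048)
start_positions = torch.tensor([0, 100], dtype=torch.int32, device="cuda")

out = apply_rotary_pos_emb(
    t, emb, tensor_format="sbhd", start_positions=start_positions, fused=True
)
out.sum().backward()  # with this change, backward works even with start_positions
assert t.grad is not None
```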