Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 136 additions & 0 deletions custom_ops/gpu_ops/append_attn/decoder_write_cache_with_rope_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,142 @@ __global__ void append_decode_cache_T_rope_kernel(
}
}

template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_partial_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
// head_size]
T* __restrict__ key_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ value_cache, // [num_blocks, kv_num_heads, block_size,
// head_size // 2]
T* __restrict__ qkv_out,
const int* __restrict__ block_tables, // [bsz, max_blocks_per_seq]
const int* __restrict__ cu_seqlens_q,
const int* __restrict__ seq_lens, // [bsz]
const int* __restrict__ seq_lens_encoder, // [bsz]
const float* __restrict__ cos_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const float* __restrict__ sin_emb, // [2, 1, max_model_len, 1, rotary_dim/2]
const int max_seq_len,
const int max_blocks_per_seq,
const int num_heads,
const int head_size,
const int rotary_dim,
const int block_size,
const uint32_t elem_cnt,
const int kv_num_heads,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadBiasT = AlignedVector<T, VecSize>;
using LoadKVT = AlignedVector<T, VecSize>;
constexpr int HalfVecSize = VecSize / 2;
using LoadEmbT = AlignedVector<float, VecSize>;

LoadT left_vec, right_vec;
LoadBiasT left_bias_vec, right_bias_vec;
LoadKVT left_cache_vec, right_cache_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;

int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int half_head_size = head_size / 2;
const int half_rotary_dim = rotary_dim / 2;
const int64_t hidden_size = (num_heads + 2 * kv_num_heads) * head_size;
const int64_t half_hidden_size = hidden_size / 2;
// const int64_t offset = 2 * hidden_size;

for (int32_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int ori_bi = linear_index / half_hidden_size;
const int bias = linear_index % half_hidden_size;
const int hi = bias / half_head_size; // q + k + v
const int h_bias = bias % half_head_size;
if (hi < num_heads && h_bias >= half_rotary_dim){
continue;
}
if (seq_lens_encoder[ori_bi] > 0) continue;
const int write_seq_id = seq_lens[ori_bi];
if (write_seq_id == 0) continue;
const int start_token_idx = cu_seqlens_q[ori_bi];

const int* block_table_now = nullptr;

block_table_now = block_tables + ori_bi * max_blocks_per_seq;
const int block_idx = block_table_now[write_seq_id / block_size];
const int block_offset = write_seq_id % block_size;
uint32_t ori_idx_left =
start_token_idx * hidden_size + hi * head_size + h_bias;
uint32_t ori_idx_right = ori_idx_left + half_head_size;
if (hi < num_heads){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else if (hi < num_heads + kv_num_heads){
if (h_bias < half_rotary_dim){
ori_idx_right = ori_idx_left + half_rotary_dim;
}else{
ori_idx_left = ori_idx_left + half_rotary_dim;
ori_idx_right = ori_idx_left + half_rotary_dim;
}
}

Load<T, VecSize>(&qkv[ori_idx_left], &left_vec);
Load<T, VecSize>(&qkv[ori_idx_right], &right_vec);

if (hi < num_heads + kv_num_heads) {
// q k rope
const uint32_t emb_idx = write_seq_id * half_rotary_dim + h_bias;
uint32_t new_emb_idx = rope_3d ? emb_idx + ori_bi * max_seq_len * head_size * 2 : emb_idx;
if (h_bias < half_rotary_dim){
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
}
}
#pragma unroll
for (int i = 0; i < VecSize; i++) {
// rope
float input_left = static_cast<float>(left_vec[i]);
float input_right = static_cast<float>(right_vec[i]);
if (hi < num_heads + kv_num_heads && h_bias < half_rotary_dim) {
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_bias_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_bias_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
} else {
left_bias_vec[i] = static_cast<T>(input_left);
right_bias_vec[i] = static_cast<T>(input_right);
}
}
if (hi < num_heads) {
// write q
Store<T, VecSize>(left_bias_vec, &qkv_out[ori_idx_left]);
Store<T, VecSize>(right_bias_vec, &qkv_out[ori_idx_right]);
} else {
// write k/v
const uint32_t kv_head_idx = (hi - num_heads) % kv_num_heads;
uint32_t tgt_idx_left =
block_idx * kv_num_heads * block_size * head_size +
kv_head_idx * block_size * head_size + block_offset * head_size +
h_bias;
uint32_t tgt_idx_right = tgt_idx_left + half_head_size;
if (hi < num_heads + kv_num_heads) {
if (h_bias < half_rotary_dim) {
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}else{
tgt_idx_left = tgt_idx_left + half_rotary_dim;
tgt_idx_right = tgt_idx_left + half_rotary_dim;
}
Store<T, VecSize>(left_bias_vec, &key_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &key_cache[tgt_idx_right]);
} else {
Store<T, VecSize>(left_bias_vec, &value_cache[tgt_idx_left]);
Store<T, VecSize>(right_bias_vec, &value_cache[tgt_idx_right]);
}
}
}
}

template <typename T, int VecSize = 1>
__global__ void append_decode_cache_T_neox_rope_kernel(
const T* __restrict__ qkv, // [bsz, num_heads + 2 * kv_num_heads,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
const int num_heads,
const int kv_num_heads,
const int dim_head,
const int rotary_dim,
const int block_size,
const int bsz,
const cudaStream_t& stream,
Expand Down Expand Up @@ -137,7 +138,29 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
kv_num_heads,
rope_3d);
} else {
append_decode_cache_T_neox_rope_kernel<T, PackSize>
if (rotary_dim < dim_head){
append_decode_cache_T_neox_partial_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
qkv_out,
block_tables,
cu_seqlens_q,
seq_lens,
seq_lens_encoder,
cos_emb,
sin_emb,
max_seq_len,
max_blocks_per_seq,
num_heads,
dim_head,
rotary_dim,
block_size,
elem_nums,
kv_num_heads,
rope_3d);
}else{
append_decode_cache_T_neox_rope_kernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(reinterpret_cast<const T*>(qkv),
key_cache,
value_cache,
Expand All @@ -157,6 +180,7 @@ void append_decode_cache_rope(const QKV_TYPE* qkv,
elem_nums,
kv_num_heads,
rope_3d);
}
}
} else {
if (qkv_out_scales) {
Expand Down Expand Up @@ -534,11 +558,20 @@ void DecoderWriteCacheWithRoPEKernel(
const float* cos_emb =
rotary_embs ? rotary_embs.get().data<float>() : nullptr;
const float* sin_emb;
int rotary_dim = dim_head;
if (rotary_embs) {
sin_emb =
use_neox_rotary_style
? rotary_embs.get().data<float>() + max_seq_len * dim_head
: rotary_embs.get().data<float>() + max_seq_len * dim_head / 2;
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < dim_head){
if (!use_neox_rotary_style || qkv_out_scales || q_norm_weight || k_norm_weight|| cache_quant_type_str != "none"){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports neox_rotary_style=True, qkv_out_scales is None, q_norm_weight/k_norm_weight) is None, and cache_quant_type_str is 'none'."));
}
sin_emb = rotary_embs.get().data<float>() + max_seq_len * rotary_dim / 2;
}
}

if (q_norm_weight && k_norm_weight) {
Expand Down Expand Up @@ -599,6 +632,7 @@ void DecoderWriteCacheWithRoPEKernel(
num_heads,
kv_num_heads,
dim_head,
rotary_dim,
block_size,
bsz,
stream,
Expand Down
103 changes: 102 additions & 1 deletion custom_ops/gpu_ops/append_attn/encoder_write_cache_with_rope_impl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -900,6 +900,74 @@ __global__ void GQANeoxVariableLengthRotaryKernel(
}
}

template <typename T, int VecSize = 1>
__global__ void GQANeoxVariableLengthPartialRotaryKernel(
const T *qkv,
const float *cos_emb,
const float *sin_emb,
const int *batch_id_per_token,
const int *cu_seqlens_q,
const int *seq_lens,
const int *seq_lens_decoder,
const float *qkv_out_scales,
const T *qkv_biases,
T *qkv_out,
const int64_t elem_cnt,
const int q_num_head,
const int kv_num_head,
const int seq_len,
const int head_dim,
const int rotary_dim,
const bool rope_3d) {
using LoadT = AlignedVector<T, VecSize>;
using LoadEmbT = AlignedVector<float, VecSize>;
LoadT left_vec;
LoadT right_vec;
LoadEmbT cos_emb_vec;
LoadEmbT sin_emb_vec;
int64_t global_thread_idx = blockDim.x * blockIdx.x + threadIdx.x;
const int rotary_dim_half = rotary_dim / 2;
const int offset = (q_num_head + kv_num_head) * rotary_dim_half;
for (int64_t linear_index = global_thread_idx * VecSize,
step = gridDim.x * blockDim.x * VecSize;
linear_index < elem_cnt;
linear_index += step) {
const int token_idx = linear_index / offset;
const int ori_bi = batch_id_per_token[token_idx];
if (seq_lens && seq_lens[ori_bi] == 0) continue;
const int bias = linear_index % offset;
const int hi = bias / rotary_dim_half;
const int h_bias = bias % rotary_dim_half;

const int ori_seq_id = (token_idx - cu_seqlens_q[ori_bi]) + seq_lens_decoder[ori_bi];

const int emb_idx = ori_seq_id * rotary_dim_half + h_bias;
int64_t new_emb_idx = rope_3d ? emb_idx + ori_bi * head_dim * seq_len * 2 : emb_idx;
const int base_idx_left =
token_idx * (q_num_head + 2 * kv_num_head) * head_dim + hi * head_dim +
h_bias;
const int base_idx_right = base_idx_left + rotary_dim_half;

Load<T, VecSize>(&qkv[base_idx_left], &left_vec);
Load<T, VecSize>(&qkv[base_idx_right], &right_vec);
Load<float, VecSize>(&cos_emb[new_emb_idx], &cos_emb_vec);
Load<float, VecSize>(&sin_emb[new_emb_idx], &sin_emb_vec);
#pragma unroll
for (int i = 0; i < VecSize; i++) {
const float input_left = static_cast<float>(left_vec[i]);
const float input_right = static_cast<float>(right_vec[i]);
const float cos_tmp = cos_emb_vec[i];
const float sin_tmp = sin_emb_vec[i];
left_vec[i] =
static_cast<T>(input_left * cos_tmp - input_right * sin_tmp);
right_vec[i] =
static_cast<T>(input_right * cos_tmp + input_left * sin_tmp);
}
Store<T, VecSize>(left_vec, &qkv_out[base_idx_left]);
Store<T, VecSize>(right_vec, &qkv_out[base_idx_right]);
}
}

template <typename T, int VecSize = 1>
__global__ void cache_kernel(
const T *__restrict__ qkv, // [num_tokens, num_heads + 2 * kv_num_heads,
Expand Down Expand Up @@ -1755,6 +1823,7 @@ void gqa_rotary_qk_variable(
const int seq_len,
const int input_output_len,
const int dim_head,
const int rotary_dim,
const cudaStream_t &stream,
bool use_neox_style = false,
bool rope_3d = false) {
Expand Down Expand Up @@ -1835,7 +1904,38 @@ void gqa_rotary_qk_variable(
dim_head,
rope_3d);
} else {
GQANeoxVariableLengthRotaryKernel<T, PackSize>
if (rotary_dim < dim_head){
PD_CHECK((rotary_dim / 2) % PackSize == 0);
elem_nums =
qkv_out_scales
? token_num * (num_heads + 2 * kv_num_heads) * rotary_dim
: token_num * (num_heads + kv_num_heads) * rotary_dim; // for all q k v
if (use_neox_style) {
elem_nums /= 2;
}
const int pack_num_new = elem_nums / PackSize;
GetNumBlocks<128>(pack_num_new, &grid_size);
GQANeoxVariableLengthPartialRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
rotary_emb + input_output_len * rotary_dim / 2,
batch_id_per_token,
cu_seqlens_q,
seq_lens,
seq_lens_decoder,
qkv_out_scales,
qkv_bias,
qkv_out,
elem_nums,
num_heads,
kv_num_heads,
seq_len,
dim_head,
rotary_dim,
rope_3d);
}else{
GQANeoxVariableLengthRotaryKernel<T, PackSize>
<<<grid_size, blocksize, 0, stream>>>(
reinterpret_cast<const T *>(qkv_input),
cos_emb,
Expand All @@ -1853,6 +1953,7 @@ void gqa_rotary_qk_variable(
seq_len,
dim_head,
rope_3d);
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,19 @@ void EncoderWriteCacheWithRopeKernel(
auto kv_num_heads = meta_data.kv_num_heads;
auto head_dim = meta_data.head_dims;
bool is_scale_channel_wise = false;
int rotary_dim = head_dim;
if (cache_k_scale && cache_k_scale.get().dims()[0] == head_dim * kv_num_heads) {
is_scale_channel_wise = true;
}
if (rotary_embs){
rotary_dim = rotary_embs.get().dims()[rotary_embs.get().dims().size()-1] * 2;
if(rotary_dim < head_dim){
if (!use_neox_style || q_norm_weight || k_norm_weight || num_heads == kv_num_heads || is_scale_channel_wise){
PADDLE_THROW(phi::errors::Fatal(
"partial_rotary_factor < 1.0 only supports use_neox_rotary_style=True, q_norm_weight/k_norm_weight) is None, GQA and is_scale_channel_wise=false."));
}
}
}

if (q_norm_weight && k_norm_weight) {
if (num_heads != kv_num_heads && !is_scale_channel_wise && !use_neox_style) {
Expand Down Expand Up @@ -125,6 +135,7 @@ void EncoderWriteCacheWithRopeKernel(
max_seq_len,
rope_3d ? rotary_embs.get().dims()[3] : rotary_embs.get().dims()[2],
head_dim,
rotary_dim,
stream,
use_neox_style,
rope_3d);
Expand Down
3 changes: 2 additions & 1 deletion fastdeploy/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def __init__(
self.eos_tokens_lens: int = 2
self.lm_head_fp32: bool = False
self.model_format = "auto"
self.partial_rotary_factor: float = 1.0
for key, value in args.items():
if hasattr(self, key):
setattr(self, key, value)
Expand Down Expand Up @@ -396,7 +397,7 @@ def __init__(
# model for mtp/eagle/draft_model
self.model: Optional[str] = None
# quantization of model
self.quantization: Optional[str] = None
self.quantization: Optional[Dict[str, Any]] = None
# allocate more blocks to prevent mtp from finishing the block earlier than the main model
# Fixed now
self.num_gpu_block_expand_ratio: Optional[float] = 1
Expand Down
Loading
Loading