@@ -427,18 +427,42 @@ __device__ inline bool is_finite(const T val) {
427427#endif
428428}
429429
430+ // Scoring function enums
431+ enum ScoringFunc {
432+ SCORING_NONE = 0 , // no activation function
433+ SCORING_SIGMOID = 1 // apply sigmoid
434+ };
435+
436+ // Efficient sigmoid approximation from TensorRT-LLM
437+ __device__ inline float sigmoid_accurate (float x) {
438+ return 0 .5f * tanhf (0 .5f * x) + 0 .5f ;
439+ }
440+
430441template <typename T>
431- __device__ void topk_with_k2 (T* output, T const * input,
442+ __device__ inline T apply_sigmoid (T val) {
443+ float f = cuda_cast<float , T>(val);
444+ return cuda_cast<T, float >(sigmoid_accurate (f));
445+ }
446+
447+ template <typename T>
448+ __device__ void topk_with_k2 (T* output, T const * input, T const * bias,
432449 cg::thread_block_tile<32 > const & tile,
433450 int32_t const lane_id,
434- int const num_experts_per_group) {
451+ int const num_experts_per_group,
452+ int const scoring_func) {
435453 // Get the top2 per thread
436454 T largest = neg_inf<T>();
437455 T second_largest = neg_inf<T>();
438456
439457 if (num_experts_per_group > WARP_SIZE) {
440458 for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
441459 T value = input[i];
460+ // Apply scoring function if needed
461+ if (scoring_func == SCORING_SIGMOID) {
462+ value = apply_sigmoid (value);
463+ }
464+ value = value + bias[i];
465+
442466 if (value > largest) {
443467 second_largest = largest;
444468 largest = value;
@@ -448,7 +472,13 @@ __device__ void topk_with_k2(T* output, T const* input,
448472 }
449473 } else {
450474 for (int i = lane_id; i < num_experts_per_group; i += WARP_SIZE) {
451- largest = input[i];
475+ T value = input[i];
476+ // Apply scoring function if needed
477+ if (scoring_func == SCORING_SIGMOID) {
478+ value = apply_sigmoid (value);
479+ }
480+ value = value + bias[i];
481+ largest = value;
452482 }
453483 }
454484
@@ -472,17 +502,21 @@ __device__ void topk_with_k2(T* output, T const* input,
472502}
473503
474504template <typename T>
475- __global__ void topk_with_k2_kernel (T* output, T* input,
505+ __global__ void topk_with_k2_kernel (T* output, T* input, T const * bias,
476506 int64_t const num_tokens,
477507 int64_t const num_cases,
478508 int64_t const n_group,
479- int64_t const num_experts_per_group) {
509+ int64_t const num_experts_per_group,
510+ int const scoring_func) {
480511 int32_t warp_id = threadIdx .x / WARP_SIZE;
481512 int32_t lane_id = threadIdx .x % WARP_SIZE;
482513
483514 int32_t case_id = blockIdx .x * NUM_WARPS_PER_BLOCK + warp_id;
484515 if (case_id < num_cases) {
485516 input += case_id * num_experts_per_group;
517+ // bias is per expert group, offset to current group
518+ int32_t group_id = case_id % n_group;
519+ T const * group_bias = bias + group_id * num_experts_per_group;
486520 output += case_id;
487521
488522 cg::thread_block block = cg::this_thread_block ();
@@ -491,7 +525,8 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
491525#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
492526 asm volatile (" griddepcontrol.wait;" );
493527#endif
494- topk_with_k2 (output, input, tile, lane_id, num_experts_per_group);
528+ topk_with_k2 (output, input, group_bias, tile, lane_id,
529+ num_experts_per_group, scoring_func);
495530 }
496531#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900))
497532 asm volatile (" griddepcontrol.launch_dependents;" );
@@ -500,16 +535,15 @@ __global__ void topk_with_k2_kernel(T* output, T* input,
500535
501536template <typename T, typename IdxT>
502537__global__ void group_idx_and_topk_idx_kernel (
503- T* scores, T const * group_scores, T * topk_values, IdxT* topk_indices,
504- T* scores_with_bias , int64_t const num_tokens, int64_t const n_group,
538+ T* scores, T const * group_scores, float * topk_values, IdxT* topk_indices,
539+ T const * bias , int64_t const num_tokens, int64_t const n_group,
505540 int64_t const topk_group, int64_t const topk, int64_t const num_experts,
506541 int64_t const num_experts_per_group, bool renormalize,
507- double routed_scaling_factor) {
542+ double routed_scaling_factor, int scoring_func ) {
508543 int32_t warp_id = threadIdx .x / WARP_SIZE;
509544 int32_t lane_id = threadIdx .x % WARP_SIZE;
510545 int32_t case_id =
511546 blockIdx .x * NUM_WARPS_PER_BLOCK + warp_id; // one per token
512- scores_with_bias += case_id * num_experts;
513547 scores += case_id * num_experts;
514548 group_scores += case_id * n_group;
515549 topk_values += case_id * topk;
@@ -577,10 +611,16 @@ __global__ void group_idx_and_topk_idx_kernel(
577611 int32_t offset = i_group * num_experts_per_group;
578612 for (int32_t i = lane_id; i < align_num_experts_per_group;
579613 i += WARP_SIZE) {
580- T candidates = (i < num_experts_per_group) &&
581- is_finite (scores_with_bias[offset + i])
582- ? scores_with_bias[offset + i]
583- : neg_inf<T>();
614+ T candidates = neg_inf<T>();
615+ if (i < num_experts_per_group) {
616+ // Apply scoring function (if any) and add bias
617+ T input = scores[offset + i];
618+ if (is_finite (input)) {
619+ T score = (scoring_func == SCORING_SIGMOID) ? apply_sigmoid (input)
620+ : input;
621+ candidates = score + bias[offset + i];
622+ }
623+ }
584624 queue.add (candidates, offset + i);
585625 }
586626 if (group_scores[i_group] == topk_group_value) {
@@ -602,11 +642,12 @@ __global__ void group_idx_and_topk_idx_kernel(
602642 for (int i = lane_id;
603643 i < warp_topk::round_up_to_multiple_of<WARP_SIZE>(topk);
604644 i += WARP_SIZE) {
605- T value =
606- i < topk
607- ? scores[s_topk_idx[i]]
608- : cuda_cast<T, float >(0 .0f ); // Load the valid value of expert
645+ T value = cuda_cast<T, float >(0 .0f );
609646 if (i < topk) {
647+ // Load the score value (without bias) for normalization
648+ T input = scores[s_topk_idx[i]];
649+ value =
650+ (scoring_func == SCORING_SIGMOID) ? apply_sigmoid (input) : input;
610651 s_topk_value[i] = value;
611652 }
612653 topk_sum +=
@@ -627,12 +668,12 @@ __global__ void group_idx_and_topk_idx_kernel(
627668 value = cuda_cast<float , T>(s_topk_value[i]) * routed_scaling_factor;
628669 }
629670 topk_indices[i] = s_topk_idx[i];
630- topk_values[i] = cuda_cast<T, float >( value) ;
671+ topk_values[i] = value;
631672 }
632673 } else {
633674 for (int i = lane_id; i < topk; i += WARP_SIZE) {
634675 topk_indices[i] = i;
635- topk_values[i] = cuda_cast<T, float >( 1 .0f / topk) ;
676+ topk_values[i] = 1 .0f / topk;
636677 }
637678 }
638679 // Note: when if_proceed_next_topk==false, choose the first 8 experts as the
@@ -644,12 +685,12 @@ __global__ void group_idx_and_topk_idx_kernel(
644685}
645686
646687template <typename T, typename IdxT>
647- void invokeNoAuxTc (T* scores, T* group_scores, T * topk_values,
648- IdxT* topk_indices, T* scores_with_bias ,
649- int64_t const num_tokens , int64_t const num_experts ,
650- int64_t const n_group , int64_t const topk_group ,
651- int64_t const topk, bool const renormalize ,
652- double const routed_scaling_factor , bool enable_pdl = false ,
688+ void invokeNoAuxTc (T* scores, T* group_scores, float * topk_values,
689+ IdxT* topk_indices, T const * bias, int64_t const num_tokens ,
690+ int64_t const num_experts , int64_t const n_group ,
691+ int64_t const topk_group , int64_t const topk ,
692+ bool const renormalize, double const routed_scaling_factor ,
693+ int const scoring_func , bool enable_pdl = false ,
653694 cudaStream_t const stream = 0 ) {
654695 int64_t num_cases = num_tokens * n_group;
655696 int64_t topk_with_k2_num_blocks = (num_cases - 1 ) / NUM_WARPS_PER_BLOCK + 1 ;
@@ -664,8 +705,9 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
664705 attrs[0 ].val .programmaticStreamSerializationAllowed = enable_pdl;
665706 config.numAttrs = 1 ;
666707 config.attrs = attrs;
667- cudaLaunchKernelEx (&config, kernel_instance1, group_scores, scores_with_bias,
668- num_tokens, num_cases, n_group, num_experts / n_group);
708+ cudaLaunchKernelEx (&config, kernel_instance1, group_scores, scores, bias,
709+ num_tokens, num_cases, n_group, num_experts / n_group,
710+ scoring_func);
669711
670712 int64_t topk_with_k_group_num_blocks =
671713 (num_tokens - 1 ) / NUM_WARPS_PER_BLOCK + 1 ;
@@ -682,19 +724,18 @@ void invokeNoAuxTc(T* scores, T* group_scores, T* topk_values,
682724 config.numAttrs = 1 ;
683725 config.attrs = attrs;
684726 cudaLaunchKernelEx (&config, kernel_instance2, scores, group_scores,
685- topk_values, topk_indices, scores_with_bias , num_tokens,
686- n_group, topk_group, topk, num_experts,
687- num_experts / n_group, renormalize, routed_scaling_factor);
727+ topk_values, topk_indices, bias , num_tokens, n_group ,
728+ topk_group, topk, num_experts, num_experts / n_group ,
729+ renormalize, routed_scaling_factor, scoring_func );
688730}
689731
690732#define INSTANTIATE_NOAUX_TC (T, IdxT ) \
691733 template void invokeNoAuxTc<T, IdxT>( \
692- T * scores, T * group_scores, T * topk_values, IdxT * topk_indices, \
693- T * scores_with_bias, int64_t const num_tokens, \
694- int64_t const num_experts, int64_t const n_group, \
695- int64_t const topk_group, int64_t const topk, bool const renormalize, \
696- double const routed_scaling_factor, bool enable_pdl, \
697- cudaStream_t const stream);
734+ T * scores, T * group_scores, float * topk_values, IdxT* topk_indices, \
735+ T const * bias, int64_t const num_tokens, int64_t const num_experts, \
736+ int64_t const n_group, int64_t const topk_group, int64_t const topk, \
737+ bool const renormalize, double const routed_scaling_factor, \
738+ int const scoring_func, bool enable_pdl, cudaStream_t const stream);
698739
699740INSTANTIATE_NOAUX_TC (float , int32_t );
700741INSTANTIATE_NOAUX_TC (half, int32_t );
@@ -703,40 +744,44 @@ INSTANTIATE_NOAUX_TC(__nv_bfloat16, int32_t);
703744} // namespace vllm
704745
705746std::tuple<torch::Tensor, torch::Tensor> grouped_topk (
706- torch::Tensor const & scores, torch::Tensor const & scores_with_bias ,
707- int64_t n_group, int64_t topk_group, int64_t topk, bool renormalize,
708- double routed_scaling_factor ) {
709- auto data_type = scores_with_bias .scalar_type ();
710- auto input_size = scores_with_bias .sizes ();
747+ torch::Tensor const & scores, int64_t n_group, int64_t topk_group ,
748+ int64_t topk, bool renormalize, double routed_scaling_factor ,
749+ torch::Tensor const & bias, int64_t scoring_func = 0 ) {
750+ auto data_type = scores .scalar_type ();
751+ auto input_size = scores .sizes ();
711752 int64_t num_tokens = input_size[0 ];
712753 int64_t num_experts = input_size[1 ];
713- TORCH_CHECK (input_size.size () == 2 , " scores_with_bias must be a 2D Tensor" );
754+ TORCH_CHECK (input_size.size () == 2 , " scores must be a 2D Tensor" );
714755 TORCH_CHECK (num_experts % n_group == 0 ,
715756 " num_experts should be divisible by n_group" );
716757 TORCH_CHECK (n_group <= 32 ,
717758 " n_group should be smaller than or equal to 32 for now" );
718759 TORCH_CHECK (topk <= 32 , " topk should be smaller than or equal to 32 for now" );
760+ TORCH_CHECK (scoring_func == vllm::moe::SCORING_NONE ||
761+ scoring_func == vllm::moe::SCORING_SIGMOID,
762+ " scoring_func must be SCORING_NONE (0) or SCORING_SIGMOID (1)" );
719763
720764 torch::Tensor group_scores = torch::empty (
721765 {num_tokens, n_group}, torch::dtype (data_type).device (torch::kCUDA ));
766+ // Always output float32 for topk_values (eliminates Python-side conversion)
722767 torch::Tensor topk_values = torch::empty (
723- {num_tokens, topk}, torch::dtype (data_type ).device (torch::kCUDA ));
768+ {num_tokens, topk}, torch::dtype (torch:: kFloat32 ).device (torch::kCUDA ));
724769 torch::Tensor topk_indices = torch::empty (
725770 {num_tokens, topk}, torch::dtype (torch::kInt32 ).device (torch::kCUDA ));
726771
727- auto stream = c10::cuda::getCurrentCUDAStream (scores_with_bias .get_device ());
772+ auto stream = c10::cuda::getCurrentCUDAStream (scores .get_device ());
728773
729774 switch (data_type) {
730775 case torch::kFloat16 :
731776 // Handle Float16
732777 vllm::moe::invokeNoAuxTc<half, int32_t >(
733778 reinterpret_cast <half*>(scores.mutable_data_ptr ()),
734779 reinterpret_cast <half*>(group_scores.mutable_data_ptr ()),
735- reinterpret_cast <half *>(topk_values.mutable_data_ptr ()),
780+ reinterpret_cast <float *>(topk_values.mutable_data_ptr ()),
736781 reinterpret_cast <int32_t *>(topk_indices.mutable_data_ptr ()),
737- reinterpret_cast <half*>(scores_with_bias .data_ptr ()), num_tokens,
782+ reinterpret_cast <half const *>(bias .data_ptr ()), num_tokens,
738783 num_experts, n_group, topk_group, topk, renormalize,
739- routed_scaling_factor, false , stream);
784+ routed_scaling_factor, static_cast < int >(scoring_func), false , stream);
740785 break ;
741786 case torch::kFloat32 :
742787 // Handle Float32
@@ -745,20 +790,20 @@ std::tuple<torch::Tensor, torch::Tensor> grouped_topk(
745790 reinterpret_cast <float *>(group_scores.mutable_data_ptr ()),
746791 reinterpret_cast <float *>(topk_values.mutable_data_ptr ()),
747792 reinterpret_cast <int32_t *>(topk_indices.mutable_data_ptr ()),
748- reinterpret_cast <float *>(scores_with_bias .data_ptr ()), num_tokens,
793+ reinterpret_cast <float const *>(bias .data_ptr ()), num_tokens,
749794 num_experts, n_group, topk_group, topk, renormalize,
750- routed_scaling_factor, false , stream);
795+ routed_scaling_factor, static_cast < int >(scoring_func), false , stream);
751796 break ;
752797 case torch::kBFloat16 :
753798 // Handle BFloat16
754799 vllm::moe::invokeNoAuxTc<__nv_bfloat16, int32_t >(
755800 reinterpret_cast <__nv_bfloat16*>(scores.mutable_data_ptr ()),
756801 reinterpret_cast <__nv_bfloat16*>(group_scores.mutable_data_ptr ()),
757- reinterpret_cast <__nv_bfloat16 *>(topk_values.mutable_data_ptr ()),
802+ reinterpret_cast <float *>(topk_values.mutable_data_ptr ()),
758803 reinterpret_cast <int32_t *>(topk_indices.mutable_data_ptr ()),
759- reinterpret_cast <__nv_bfloat16*>(scores_with_bias .data_ptr ()),
760- num_tokens, num_experts, n_group, topk_group, topk, renormalize,
761- routed_scaling_factor, false , stream);
804+ reinterpret_cast <__nv_bfloat16 const *>(bias .data_ptr ()), num_tokens ,
805+ num_experts, n_group, topk_group, topk, renormalize,
806+ routed_scaling_factor, static_cast < int >(scoring_func), false , stream);
762807 break ;
763808 default :
764809 // Handle other data types
0 commit comments