@@ -93,12 +93,11 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
9393 """
9494 batch_size = hidden_states .shape [0 ]
9595 hidden_states = hidden_states .reshape (- 1 , self .hidden_size ) # (num_tokens, hidden_size)
96- num_experts = routing_weights .shape [1 ]
9796 if hidden_states .device .type == "cpu" or self .training :
9897 next_states = torch .zeros_like (hidden_states , dtype = hidden_states .dtype , device = hidden_states .device )
9998 with torch .no_grad ():
10099 expert_mask = torch .nn .functional .one_hot (
101- router_indices , num_classes = num_experts + 1
100+ router_indices , num_classes = self . num_experts
102101 ) # masking is also a class
103102 expert_mask = expert_mask .permute (2 , 1 , 0 )
104103 # we sum on the top_k and on the sequence length to get which experts
@@ -108,10 +107,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
108107 # expert_idx only have 1 element, so we can use scale for fast indexing
109108 expert_idx = expert_idx [0 ]
110109 # skip masking index
111- if expert_idx == num_experts :
110+ if expert_idx == self . num_experts :
112111 continue
113112 with torch .no_grad ():
114- _ , token_idx = torch .where (expert_mask [expert_idx ])
113+ top_k_pos , token_idx = torch .where (expert_mask [expert_idx ])
115114 current_state = hidden_states [token_idx ]
116115 gate_up = current_state @ self .gate_up_proj [expert_idx ] + self .gate_up_proj_bias [expert_idx ]
117116 gate , up = gate_up [..., ::2 ], gate_up [..., 1 ::2 ]
@@ -120,21 +119,21 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
120119 glu = gate * torch .sigmoid (gate * self .alpha )
121120 gated_output = (up + 1 ) * glu
122121 out = gated_output @ self .down_proj [expert_idx ] + self .down_proj_bias [expert_idx ]
123- weighted_output = out * routing_weights [token_idx , expert_idx , None ]
122+ weighted_output = out * routing_weights [token_idx , top_k_pos , None ]
124123 next_states .index_add_ (0 , token_idx , weighted_output .to (hidden_states .dtype ))
125124 next_states = next_states .view (batch_size , - 1 , self .hidden_size )
126125 else :
127- hidden_states = hidden_states .repeat (num_experts , 1 )
128- hidden_states = hidden_states .view (num_experts , - 1 , self .hidden_size )
126+ hidden_states = hidden_states .repeat (self . num_experts , 1 )
127+ hidden_states = hidden_states .view (self . num_experts , - 1 , self .hidden_size )
129128 gate_up = torch .bmm (hidden_states , self .gate_up_proj ) + self .gate_up_proj_bias [..., None , :]
130129 gate , up = gate_up [..., ::2 ], gate_up [..., 1 ::2 ]
131130 gate = gate .clamp (min = None , max = self .limit )
132131 up = up .clamp (min = - self .limit , max = self .limit )
133132 glu = gate * torch .sigmoid (gate * self .alpha )
134133 next_states = torch .bmm (((up + 1 ) * glu ), self .down_proj )
135134 next_states = next_states + self .down_proj_bias [..., None , :]
136- next_states = next_states .view (num_experts , batch_size , - 1 , self .hidden_size )
137- next_states = next_states * routing_weights .transpose (0 , 1 ).view (num_experts , batch_size , - 1 )[..., None ]
135+ next_states = next_states .view (self . num_experts , batch_size , - 1 , self .hidden_size )
136+ next_states = next_states * routing_weights .transpose (0 , 1 ).view (self . num_experts , batch_size , - 1 )[..., None ]
138137 next_states = next_states .sum (dim = 0 )
139138 return next_states
140139
@@ -154,7 +153,7 @@ def forward(self, hidden_states):
154153 router_top_value , router_indices = torch .topk (router_logits , self .top_k , dim = - 1 ) # (seq_len, top_k)
155154 router_top_value = torch .nn .functional .softmax (router_top_value , dim = 1 , dtype = router_top_value .dtype )
156155 router_scores = router_top_value
157- return router_scores , router_indices
156+ return router_logits , router_scores , router_indices
158157
159158
160159@use_kernel_forward_from_hub ("MegaBlocksMoeMLP" )
@@ -165,7 +164,7 @@ def __init__(self, config):
165164 self .experts = GptOssExperts (config )
166165
def forward(self, hidden_states):
    """Run the router, then the experts, on ``hidden_states``.

    Args:
        hidden_states: Input forwarded unchanged to both ``self.router`` and
            ``self.experts``.
            NOTE(review): presumably a (batch, seq, hidden) tensor -- confirm
            against the caller.

    Returns:
        A 2-tuple ``(routed_out, router_scores)``: the output of
        ``self.experts`` and the scores produced by ``self.router``.
    """
    # The router returns (router_logits, router_scores, router_indices).
    # The raw logits are not needed here, so they are discarded; unpacking
    # only two values would raise ValueError against that 3-tuple.
    _, router_scores, router_indices = self.router(hidden_states)
    # Experts receive the indices before the scores (positional order matters).
    routed_out = self.experts(hidden_states, router_indices, router_scores)
    return routed_out, router_scores
171170
0 commit comments