
Commit e6f026f

fix other models
1 parent 35e8bf8 commit e6f026f

File tree

18 files changed: +60 -72 lines changed

src/transformers/models/deepseek_v2/modeling_deepseek_v2.py

Lines changed: 4 additions & 5 deletions
@@ -61,22 +61,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
         return final_hidden_states
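Why the indexing change is needed (a sketch inferred from the hunk itself, not from the rest of the file): top_k_index and top_k_weights both have shape (num_tokens, top_k), so the old num_experts = top_k_weights.shape[1] actually captured the top-k width, and top_k_weights[token_idx, expert_idx] indexed that k-wide dimension with an expert id; the one_hot num_classes likewise needs the true expert count, self.num_experts, rather than top_k + 1. After expert_mask.permute(2, 1, 0), torch.where(expert_mask[expert_idx]) returns the slot inside each token's top-k list (top_k_pos) together with the token index, and that slot is the correct second coordinate. A minimal standalone sketch with toy sizes (all values illustrative):

import torch

num_tokens, top_k, num_experts = 4, 2, 8                      # toy sizes
top_k_index = torch.tensor([[3, 5], [5, 0], [3, 7], [1, 5]])  # (num_tokens, top_k)
top_k_weights = torch.rand(num_tokens, top_k)                 # (num_tokens, top_k)

# (num_tokens, top_k, num_experts) -> (num_experts, top_k, num_tokens)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts).permute(2, 1, 0)

expert_idx = 5
# rows of expert_mask[expert_idx] are top-k slots, columns are tokens
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])

# correct: the second dimension of top_k_weights is the top-k slot
weights = top_k_weights[token_idx, top_k_pos, None]
# the old top_k_weights[token_idx, expert_idx, None] indexes the k-wide
# dimension with an expert id, out of range once expert_idx >= top_k

The same one-line pattern repeats in each of the expert modules below.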

src/transformers/models/deepseek_v3/modeling_deepseek_v3.py

Lines changed: 4 additions & 5 deletions
@@ -169,22 +169,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
         return final_hidden_states

src/transformers/models/dots1/modeling_dots1.py

Lines changed: 4 additions & 5 deletions
@@ -325,22 +325,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
         return final_hidden_states

src/transformers/models/glm4_moe/modeling_glm4_moe.py

Lines changed: 4 additions & 5 deletions
@@ -350,22 +350,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
         return final_hidden_states

src/transformers/models/glm4v_moe/modeling_glm4v_moe.py

Lines changed: 4 additions & 5 deletions
@@ -371,22 +371,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
         return final_hidden_states

src/transformers/models/gpt_oss/modular_gpt_oss.py

Lines changed: 10 additions & 11 deletions
@@ -93,12 +93,11 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
         """
         batch_size = hidden_states.shape[0]
         hidden_states = hidden_states.reshape(-1, self.hidden_size)  # (num_tokens, hidden_size)
-        num_experts = routing_weights.shape[1]
         if hidden_states.device.type == "cpu" or self.training:
             next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
             with torch.no_grad():
                 expert_mask = torch.nn.functional.one_hot(
-                    router_indices, num_classes=num_experts + 1
+                    router_indices, num_classes=self.num_experts
                 )  # masking is also a class
                 expert_mask = expert_mask.permute(2, 1, 0)
                 # we sum on the top_k and on the sequence length to get which experts
@@ -108,10 +107,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
                 # expert_idx only have 1 element, so we can use scale for fast indexing
                 expert_idx = expert_idx[0]
                 # skip masking index
-                if expert_idx == num_experts:
+                if expert_idx == self.num_experts:
                     continue
                 with torch.no_grad():
-                    _, token_idx = torch.where(expert_mask[expert_idx])
+                    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
                 current_state = hidden_states[token_idx]
                 gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
                 gate, up = gate_up[..., ::2], gate_up[..., 1::2]
@@ -120,21 +119,21 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weights=None):
                 glu = gate * torch.sigmoid(gate * self.alpha)
                 gated_output = (up + 1) * glu
                 out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
-                weighted_output = out * routing_weights[token_idx, expert_idx, None]
+                weighted_output = out * routing_weights[token_idx, top_k_pos, None]
                 next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
             next_states = next_states.view(batch_size, -1, self.hidden_size)
         else:
-            hidden_states = hidden_states.repeat(num_experts, 1)
-            hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
+            hidden_states = hidden_states.repeat(self.num_experts, 1)
+            hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
             gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
             gate, up = gate_up[..., ::2], gate_up[..., 1::2]
             gate = gate.clamp(min=None, max=self.limit)
             up = up.clamp(min=-self.limit, max=self.limit)
             glu = gate * torch.sigmoid(gate * self.alpha)
             next_states = torch.bmm(((up + 1) * glu), self.down_proj)
             next_states = next_states + self.down_proj_bias[..., None, :]
-            next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
-            next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
+            next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size)
+            next_states = next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
             next_states = next_states.sum(dim=0)
         return next_states
@@ -154,7 +153,7 @@ def forward(self, hidden_states):
         router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1)  # (seq_len, top_k)
         router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
         router_scores = router_top_value
-        return router_scores, router_indices
+        return router_logits, router_scores, router_indices
 
 
 @use_kernel_forward_from_hub("MegaBlocksMoeMLP")
@@ -165,7 +164,7 @@ def __init__(self, config):
         self.experts = GptOssExperts(config)
 
     def forward(self, hidden_states):
-        router_scores, router_indices = self.router(hidden_states)
+        _, router_scores, router_indices = self.router(hidden_states)
         routed_out = self.experts(hidden_states, router_indices, router_scores)
         return routed_out, router_scores
 
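Aside from the same top_k_pos fix on the CPU/training path, the gpt_oss hunks swap the locally derived num_experts for self.num_experts on the batched path, and the router now also returns its raw logits, so GptOssMoe.forward unpacks three values and discards the first. A standalone sketch of the batched weighting, assuming full-width routing weights of shape (num_tokens, num_experts); all sizes and the stand-in expert output are illustrative, not from the repository:

import torch

num_experts, batch_size, seq_len, hidden = 4, 2, 3, 8  # toy sizes
num_tokens = batch_size * seq_len
hidden_states = torch.randn(num_tokens, hidden)
routing_weights = torch.softmax(torch.randn(num_tokens, num_experts), dim=-1)

# run every token through every expert, then weight each expert's output
hidden_rep = hidden_states.repeat(num_experts, 1).view(num_experts, -1, hidden)
next_states = hidden_rep  # stand-in for the per-expert MLP outputs
next_states = next_states.view(num_experts, batch_size, -1, hidden)
weights = routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
out = (next_states * weights).sum(dim=0)  # (batch_size, seq_len, hidden)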

src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py

Lines changed: 4 additions & 5 deletions
@@ -261,22 +261,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
         return final_hidden_states

src/transformers/models/jamba/modeling_jamba.py

Lines changed: 4 additions & 5 deletions
@@ -577,22 +577,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
         return final_hidden_states

src/transformers/models/lfm2_moe/modeling_lfm2_moe.py

Lines changed: 4 additions & 5 deletions
@@ -165,22 +165,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
 
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
 
         return final_hidden_states

src/transformers/models/mixtral/modeling_mixtral.py

Lines changed: 0 additions & 1 deletion
@@ -83,7 +83,6 @@ def forward(
             expert_idx = expert_idx[0]
             if expert_idx == self.num_experts:
                 continue
-
             top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
