Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
38 commits
Select commit Hold shift + click to select a range
c480438
remove zero_like + scatter
3outeille Nov 27, 2025
c615c47
Merge branch 'main' into fix-moe-v5
3outeille Nov 27, 2025
073326f
fix mixtral moe
3outeille Nov 27, 2025
8ff6c18
Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform…
3outeille Nov 27, 2025
f3457e2
fix other moe models as well
3outeille Nov 27, 2025
16737a4
fix ci
3outeille Nov 27, 2025
01da12d
Merge branch 'main' into fix-moe-v5
3outeille Nov 27, 2025
57541cd
fix modular mixtral
3outeille Nov 27, 2025
b7eb918
Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform…
3outeille Nov 27, 2025
3992748
fix qwen2_moe + qwen3_next
3outeille Nov 28, 2025
15f41b9
fix device mismatch for qwen3_vl_moe to pass tests
3outeille Nov 28, 2025
35e8bf8
fix modular mixtral
3outeille Nov 28, 2025
e6f026f
fix other models
3outeille Nov 28, 2025
14b7ac0
rm slow tokenizers (#40936)
itazap Nov 27, 2025
ec3f555
[loading/saving] Reverse all loading operations when saving (#42396)
Cyrilvallez Nov 27, 2025
326eb75
Fix T5 tests: use generation_config for generation parameters (#42419)
Abdennacer-Badaoui Nov 28, 2025
50cc1e9
Merge branch 'main' into fix-moe-v5
3outeille Nov 28, 2025
8bccd8c
linting
3outeille Nov 28, 2025
74e84d5
Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform…
3outeille Nov 28, 2025
718cc64
more fix to pass the CI tests
3outeille Nov 28, 2025
19db8c9
Merge branch 'main' into fix-moe-v5
3outeille Nov 28, 2025
1100864
fix lfm2 moe
3outeille Nov 28, 2025
7d024b9
Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform…
3outeille Nov 28, 2025
e6f82dc
Merge branch 'main' into fix-moe-v5
3outeille Nov 28, 2025
e982a15
fix docstring
3outeille Nov 28, 2025
98703cc
Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform…
3outeille Nov 28, 2025
84bb660
Merge branch 'main' into fix-moe-v5
3outeille Nov 28, 2025
3b14e7b
fix docstring
3outeille Nov 28, 2025
0ac90c8
Merge branch 'fix-moe-v5' of https://github.com/huggingface/transform…
3outeille Nov 28, 2025
5e4e7de
Merge branch 'main' into fix-moe-v5
3outeille Nov 28, 2025
0399e13
fix qwen like model
3outeille Nov 28, 2025
af29eee
fix flex olmo
3outeille Nov 28, 2025
bf66927
revert lfm2 moe config
3outeille Nov 28, 2025
4d6e993
Merge branch 'main' into fix-moe-v5
3outeille Nov 28, 2025
144ec86
Merge branch 'main' into fix-moe-v5
3outeille Nov 28, 2025
ede2116
make fixup
3outeille Nov 28, 2025
3132b5f
fix docstring
3outeille Nov 28, 2025
2e04f12
fix conversion mapping
3outeille Nov 28, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/transformers/conversion_mapping.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ def _build_checkpoint_conversion_mapping():
mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy()
mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy()
mapping["minimax"] = mapping["mixtral"].copy()
mapping["flex_olmo"] = mapping["qwen2_moe"].copy()
mapping["olmoe"] = mapping["qwen2_moe"].copy()

return mapping

Expand Down
9 changes: 4 additions & 5 deletions src/transformers/models/deepseek_v2/modeling_deepseek_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,22 +61,21 @@ def forward(
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
if expert_idx == self.num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states
Expand Down
9 changes: 4 additions & 5 deletions src/transformers/models/deepseek_v3/modeling_deepseek_v3.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,22 +169,21 @@ def forward(
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
if expert_idx == self.num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states
Expand Down
9 changes: 4 additions & 5 deletions src/transformers/models/dots1/modeling_dots1.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,22 +327,21 @@ def forward(
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
if expert_idx == self.num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states
Expand Down
15 changes: 7 additions & 8 deletions src/transformers/models/flex_olmo/modeling_flex_olmo.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,22 +313,21 @@ def forward(
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
if expert_idx == self.num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states
Expand All @@ -351,8 +350,8 @@ def forward(self, hidden_states):
if self.norm_topk_prob:
router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
router_top_value = router_top_value.to(router_logits.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
router_scores = router_top_value
return router_logits, router_scores, router_indices


class FlexOlmoSparseMoeBlock(nn.Module):
Expand All @@ -364,7 +363,7 @@ def __init__(self, config):
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
batch_size, sequence_length, hidden_dim = hidden_states.shape
hidden_states = hidden_states.view(-1, hidden_dim)
top_k_weights, top_k_index = self.gate(hidden_states)
_, top_k_weights, top_k_index = self.gate(hidden_states)
final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape(
batch_size, sequence_length, hidden_dim
)
Expand Down
9 changes: 4 additions & 5 deletions src/transformers/models/glm4_moe/modeling_glm4_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,22 +350,21 @@ def forward(
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
if expert_idx == self.num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states
Expand Down
9 changes: 4 additions & 5 deletions src/transformers/models/glm4v_moe/modeling_glm4v_moe.py
Original file line number Diff line number Diff line change
Expand Up @@ -414,22 +414,21 @@ def forward(
top_k_weights: torch.Tensor,
) -> torch.Tensor:
final_hidden_states = torch.zeros_like(hidden_states)
num_experts = top_k_weights.shape[1]
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
expert_mask = expert_mask.permute(2, 1, 0)
expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()

for expert_idx in expert_hit:
expert_idx = expert_idx[0]
if expert_idx == num_experts:
if expert_idx == self.num_experts:
continue
_, token_idx = torch.where(expert_mask[expert_idx])
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
current_hidden_states = self.act_fn(gate) * up
current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))

return final_hidden_states
Expand Down
25 changes: 13 additions & 12 deletions src/transformers/models/gpt_oss/modeling_gpt_oss.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,12 +95,11 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
"""
batch_size = hidden_states.shape[0]
hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size)
num_experts = routing_weights.shape[1]
if hidden_states.device.type == "cpu" or self.training:
next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device)
with torch.no_grad():
expert_mask = torch.nn.functional.one_hot(
router_indices, num_classes=num_experts + 1
router_indices, num_classes=self.num_experts
) # masking is also a class
expert_mask = expert_mask.permute(2, 1, 0)
# we sum on the top_k and on the sequence length to get which experts
Expand All @@ -110,10 +109,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
# expert_idx has only 1 element, so we can use a scalar for fast indexing
expert_idx = expert_idx[0]
# skip masking index
if expert_idx == num_experts:
if expert_idx == self.num_experts:
continue
with torch.no_grad():
_, token_idx = torch.where(expert_mask[expert_idx])
top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
current_state = hidden_states[token_idx]
gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx]
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
Expand All @@ -122,21 +121,23 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig
glu = gate * torch.sigmoid(gate * self.alpha)
gated_output = (up + 1) * glu
out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx]
weighted_output = out * routing_weights[token_idx, expert_idx, None]
weighted_output = out * routing_weights[token_idx, top_k_pos, None]
next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype))
next_states = next_states.view(batch_size, -1, self.hidden_size)
else:
hidden_states = hidden_states.repeat(num_experts, 1)
hidden_states = hidden_states.view(num_experts, -1, self.hidden_size)
hidden_states = hidden_states.repeat(self.num_experts, 1)
hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size)
gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :]
gate, up = gate_up[..., ::2], gate_up[..., 1::2]
gate = gate.clamp(min=None, max=self.limit)
up = up.clamp(min=-self.limit, max=self.limit)
glu = gate * torch.sigmoid(gate * self.alpha)
next_states = torch.bmm(((up + 1) * glu), self.down_proj)
next_states = next_states + self.down_proj_bias[..., None, :]
next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size)
next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None]
next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size)
next_states = (
next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None]
)
next_states = next_states.sum(dim=0)
return next_states

Expand All @@ -155,8 +156,8 @@ def forward(self, hidden_states):
router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts)
router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k)
router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype)
router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
return router_scores, router_indices
router_scores = router_top_value
return router_logits, router_scores, router_indices


@use_kernel_forward_from_hub("MegaBlocksMoeMLP")
Expand All @@ -167,7 +168,7 @@ def __init__(self, config):
self.experts = GptOssExperts(config)

def forward(self, hidden_states):
router_scores, router_indices = self.router(hidden_states)
_, router_scores, router_indices = self.router(hidden_states)
routed_out = self.experts(hidden_states, router_indices, router_scores)
return routed_out, router_scores

Expand Down
Loading