diff --git a/src/transformers/conversion_mapping.py b/src/transformers/conversion_mapping.py index 159c503f0f56..5968bd08d406 100644 --- a/src/transformers/conversion_mapping.py +++ b/src/transformers/conversion_mapping.py @@ -175,6 +175,8 @@ def _build_checkpoint_conversion_mapping(): mapping["qwen3_vl_moe"] = mapping["qwen2_moe"].copy() mapping["hunyuan_v1_moe"] = mapping["qwen2_moe"].copy() mapping["minimax"] = mapping["mixtral"].copy() + mapping["flex_olmo"] = mapping["qwen2_moe"].copy() + mapping["olmoe"] = mapping["qwen2_moe"].copy() return mapping diff --git a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py index 89230d7a80b2..ad610e374899 100644 --- a/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py +++ b/src/transformers/models/deepseek_v2/modeling_deepseek_v2.py @@ -61,22 +61,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py index 4c56277c69dd..946f07cad901 100644 --- a/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py +++ b/src/transformers/models/deepseek_v3/modeling_deepseek_v3.py @@ -169,22 +169,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = 
current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/dots1/modeling_dots1.py b/src/transformers/models/dots1/modeling_dots1.py index 9092a3533e43..38dd93ba01e6 100644 --- a/src/transformers/models/dots1/modeling_dots1.py +++ b/src/transformers/models/dots1/modeling_dots1.py @@ -327,22 +327,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/flex_olmo/modeling_flex_olmo.py b/src/transformers/models/flex_olmo/modeling_flex_olmo.py index 993a3dae1652..02e7af7e02e9 100644 --- a/src/transformers/models/flex_olmo/modeling_flex_olmo.py +++ b/src/transformers/models/flex_olmo/modeling_flex_olmo.py @@ -313,22 +313,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states @@ -351,8 +350,8 @@ def forward(self, hidden_states): if 
self.norm_topk_prob: router_top_value /= router_top_value.sum(dim=-1, keepdim=True) router_top_value = router_top_value.to(router_logits.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices class FlexOlmoSparseMoeBlock(nn.Module): @@ -364,7 +363,7 @@ def __init__(self, config): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - top_k_weights, top_k_index = self.gate(hidden_states) + _, top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( batch_size, sequence_length, hidden_dim ) diff --git a/src/transformers/models/glm4_moe/modeling_glm4_moe.py b/src/transformers/models/glm4_moe/modeling_glm4_moe.py index e987e3e9e424..ae0dd080e451 100644 --- a/src/transformers/models/glm4_moe/modeling_glm4_moe.py +++ b/src/transformers/models/glm4_moe/modeling_glm4_moe.py @@ -350,22 +350,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py index 9c287a4b432c..18620637b8f4 100644 --- a/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py +++ b/src/transformers/models/glm4v_moe/modeling_glm4v_moe.py @@ -414,22 +414,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] 
gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/gpt_oss/modeling_gpt_oss.py b/src/transformers/models/gpt_oss/modeling_gpt_oss.py index 5e1173d823d0..85ed02e75983 100644 --- a/src/transformers/models/gpt_oss/modeling_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modeling_gpt_oss.py @@ -95,12 +95,11 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig """ batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - num_experts = routing_weights.shape[1] if hidden_states.device.type == "cpu" or self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=num_experts + 1 + router_indices, num_classes=self.num_experts ) # masking is also a class expert_mask = expert_mask.permute(2, 1, 0) # we sum on the top_k and on the sequence length to get which experts @@ -110,10 +109,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig # expert_idx only have 1 element, so we can use scale for fast indexing expert_idx = expert_idx[0] # skip masking index - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] gate, up = gate_up[..., ::2], gate_up[..., 1::2] @@ -122,12 +121,12 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig glu = gate * torch.sigmoid(gate * self.alpha) gated_output = (up + 1) * glu out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, expert_idx, None] + weighted_output = out * routing_weights[token_idx, top_k_pos, None] next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) else: - hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = hidden_states.view(num_experts, -1, self.hidden_size) + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up[..., ::2], gate_up[..., 1::2] gate = gate.clamp(min=None, max=self.limit) @@ -135,8 +134,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) - next_states = next_states * 
routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size) + next_states = ( + next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None] + ) next_states = next_states.sum(dim=0) return next_states @@ -155,8 +156,8 @@ def forward(self, hidden_states): router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices @use_kernel_forward_from_hub("MegaBlocksMoeMLP") @@ -167,7 +168,7 @@ def __init__(self, config): self.experts = GptOssExperts(config) def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) + _, router_scores, router_indices = self.router(hidden_states) routed_out = self.experts(hidden_states, router_indices, router_scores) return routed_out, router_scores diff --git a/src/transformers/models/gpt_oss/modular_gpt_oss.py b/src/transformers/models/gpt_oss/modular_gpt_oss.py index 57acfea8df64..ef474dca22b9 100644 --- a/src/transformers/models/gpt_oss/modular_gpt_oss.py +++ b/src/transformers/models/gpt_oss/modular_gpt_oss.py @@ -93,12 +93,11 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig """ batch_size = hidden_states.shape[0] hidden_states = hidden_states.reshape(-1, self.hidden_size) # (num_tokens, hidden_size) - num_experts = routing_weights.shape[1] if hidden_states.device.type == "cpu" or self.training: next_states = torch.zeros_like(hidden_states, dtype=hidden_states.dtype, device=hidden_states.device) with torch.no_grad(): expert_mask = torch.nn.functional.one_hot( - router_indices, num_classes=num_experts + 1 + router_indices, num_classes=self.num_experts ) # masking is also a class expert_mask = expert_mask.permute(2, 1, 0) # we sum on the top_k and on the sequence length to get which experts @@ -108,10 +107,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig # expert_idx only have 1 element, so we can use scale for fast indexing expert_idx = expert_idx[0] # skip masking index - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue with torch.no_grad(): - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate_up = current_state @ self.gate_up_proj[expert_idx] + self.gate_up_proj_bias[expert_idx] gate, up = gate_up[..., ::2], gate_up[..., 1::2] @@ -120,12 +119,12 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig glu = gate * torch.sigmoid(gate * self.alpha) gated_output = (up + 1) * glu out = gated_output @ self.down_proj[expert_idx] + self.down_proj_bias[expert_idx] - weighted_output = out * routing_weights[token_idx, expert_idx, None] + weighted_output = out * routing_weights[token_idx, top_k_pos, None] next_states.index_add_(0, token_idx, weighted_output.to(hidden_states.dtype)) next_states = next_states.view(batch_size, -1, self.hidden_size) else: - hidden_states = hidden_states.repeat(num_experts, 1) - hidden_states = 
hidden_states.view(num_experts, -1, self.hidden_size) + hidden_states = hidden_states.repeat(self.num_experts, 1) + hidden_states = hidden_states.view(self.num_experts, -1, self.hidden_size) gate_up = torch.bmm(hidden_states, self.gate_up_proj) + self.gate_up_proj_bias[..., None, :] gate, up = gate_up[..., ::2], gate_up[..., 1::2] gate = gate.clamp(min=None, max=self.limit) @@ -133,8 +132,10 @@ def forward(self, hidden_states: torch.Tensor, router_indices=None, routing_weig glu = gate * torch.sigmoid(gate * self.alpha) next_states = torch.bmm(((up + 1) * glu), self.down_proj) next_states = next_states + self.down_proj_bias[..., None, :] - next_states = next_states.view(num_experts, batch_size, -1, self.hidden_size) - next_states = next_states * routing_weights.transpose(0, 1).view(num_experts, batch_size, -1)[..., None] + next_states = next_states.view(self.num_experts, batch_size, -1, self.hidden_size) + next_states = ( + next_states * routing_weights.transpose(0, 1).view(self.num_experts, batch_size, -1)[..., None] + ) next_states = next_states.sum(dim=0) return next_states @@ -153,8 +154,8 @@ def forward(self, hidden_states): router_logits = F.linear(hidden_states, self.weight, self.bias) # (seq_len, num_experts) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value = torch.nn.functional.softmax(router_top_value, dim=1, dtype=router_top_value.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices @use_kernel_forward_from_hub("MegaBlocksMoeMLP") @@ -165,7 +166,7 @@ def __init__(self, config): self.experts = GptOssExperts(config) def forward(self, hidden_states): - router_scores, router_indices = self.router(hidden_states) + _, router_scores, router_indices = self.router(hidden_states) routed_out = self.experts(hidden_states, router_indices, router_scores) return routed_out, router_scores diff --git a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py index dda4366f0d4d..611d51cac094 100644 --- a/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py +++ b/src/transformers/models/hunyuan_v1_moe/modeling_hunyuan_v1_moe.py @@ -263,22 +263,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * 
top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/jamba/modeling_jamba.py b/src/transformers/models/jamba/modeling_jamba.py index 47673a3afab2..fed0a010b468 100755 --- a/src/transformers/models/jamba/modeling_jamba.py +++ b/src/transformers/models/jamba/modeling_jamba.py @@ -613,22 +613,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py b/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py index 93b6ad208abe..2689e4f49571 100644 --- a/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py +++ b/src/transformers/models/lfm2_moe/configuration_lfm2_moe.py @@ -29,66 +29,67 @@ class Lfm2MoeConfig(PreTrainedConfig): Args: - vocab_size (`int`, *optional*, defaults to 65536): - Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the - `inputs_ids` passed when calling [`Lfm2Model`] - hidden_size (`int`, *optional*, defaults to 2048): - Dimension of the hidden representations. - intermediate_size (`int`, *optional*, defaults to 7168): - Dimension of the MLP representations. - moe_intermediate_size (`int`, *optional*, defaults to 1792): - Intermediate size of the routed expert. - num_hidden_layers (`int`, *optional*, defaults to 32): - Number of hidden layers in the Transformer decoder. - pad_token_id (`int`, *optional*, defaults to 0): - Padding token id. - bos_token_id (`int`, *optional*, defaults to 1): - Beginning of stream token id. - eos_token_id (`int`, *optional*, defaults to 2): - End of stream token id. - tie_word_embeddings (`bool`, *optional*, defaults to `True`): - Whether to tie weight embeddings - rope_parameters (`RopeParameters`, *optional*): - Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain - a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE - with longer `max_position_embeddings`. - max_position_embeddings (`int`, *optional*, defaults to 128000): - The maximum sequence length that this model might ever be used with. 
- initializer_range (`float`, *optional*, defaults to 0.02): - The standard deviation of the truncated_normal_initializer for initializing all weight matrices. - use_cache (`bool`, *optional*, defaults to `True`): - Whether or not the model should return the last key/values attentions (not used by all models). Only - relevant if `config.is_decoder=True`. - norm_eps (`float`, *optional*, defaults to 1e-05): - The epsilon used by the rms normalization layers. - num_attention_heads (`int`, *optional*, defaults to 32): - Number of attention heads for each attention layer in the Transformer decoder. - num_key_value_heads (`int`, *optional*, defaults to 8): - This is the number of key_value heads that should be used to implement Grouped Query Attention. If - `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if - `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When - converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed - by meanpooling all the original heads within that group. For more details, check out [this - paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to - `num_attention_heads`. - conv_bias (`bool`, *optional*, defaults to `False`): - Whether to use bias in the conv layers. - conv_L_cache (`int`, *optional*, defaults to 3): - L_cache dim in the conv layers. - num_dense_layers (`int`, *optional*, defaults to 2): - Number of dense Lfm2MoeMLP layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). - num_experts_per_tok (`int`, *optional*, defaults to 4): - Number of selected experts. - num_experts (`int`, *optional*, defaults to 32): - Number of routed experts. - use_expert_bias (`bool`, *optional*, defaults to `True`): - Whether to use the expert bias on the routing weights. - routed_scaling_factor (`float`, *optional*, defaults to 1.0): - Scaling factor for routed experts in MoE models. - norm_topk_prob (`bool`, *optional*, defaults to `True`): - Whether to normalize the topk probabilities. - layer_types (`Optional`, *optional*): - Type of each layers. + vocab_size (`int`, *optional*, defaults to 65536): + Vocabulary size of the LLaMA model. Defines the number of different tokens that can be represented by the + `inputs_ids` passed when calling [`Lfm2Model`] + hidden_size (`int`, *optional*, defaults to 2048): + Dimension of the hidden representations. + intermediate_size (`int`, *optional*, defaults to 7168): + Dimension of the MLP representations. + moe_intermediate_size (`int`, *optional*, defaults to 1792): + Intermediate size of the routed expert. + num_hidden_layers (`int`, *optional*, defaults to 32): + Number of hidden layers in the Transformer decoder. + pad_token_id (`int`, *optional*, defaults to 0): + Padding token id. + bos_token_id (`int`, *optional*, defaults to 1): + Beginning of stream token id. + eos_token_id (`int`, *optional*, defaults to 2): + End of stream token id. + tie_word_embeddings (`bool`, *optional*, defaults to `True`): + Whether to tie weight embeddings + rope_parameters (`RopeParameters`, *optional*): + Dictionary containing the configuration parameters for the RoPE embeddings. The dictionary should contain + a value for `rope_theta` and optionally parameters used for scaling in case you want to use RoPE + with longer `max_position_embeddings`. 
+ max_position_embeddings (`int`, *optional*, defaults to 128000): + The maximum sequence length that this model might ever be used with. + initializer_range (`float`, *optional*, defaults to 0.02): + The standard deviation of the truncated_normal_initializer for initializing all weight matrices. + use_cache (`bool`, *optional*, defaults to `True`): + Whether or not the model should return the last key/values attentions (not used by all models). Only + relevant if `config.is_decoder=True`. + norm_eps (`float`, *optional*, defaults to 1e-05): + The epsilon used by the rms normalization layers. + num_attention_heads (`int`, *optional*, defaults to 32): + Number of attention heads for each attention layer in the Transformer decoder. + num_key_value_heads (`int`, *optional*, defaults to 8): + This is the number of key_value heads that should be used to implement Grouped Query Attention. If + `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if + `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When + converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed + by meanpooling all the original heads within that group. For more details, check out [this + paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to + `num_attention_heads`. + conv_bias (`bool`, *optional*, defaults to `False`): + Whether to use bias in the conv layers. + conv_L_cache (`int`, *optional*, defaults to 3): + L_cache dim in the conv layers. + num_dense_layers (`int`, *optional*, defaults to 2): + Number of dense Lfm2MoeMLP layers in shallow layers(embed->dense->dense->...->dense->moe->moe...->lm_head). + num_experts_per_tok (`int`, *optional*, defaults to 4): + Number of selected experts. + num_experts (`int`, *optional*, defaults to 32): + Number of routed experts. + use_expert_bias (`bool`, *optional*, defaults to `True`): + Whether to use the expert bias on the routing weights. + routed_scaling_factor (`float`, *optional*, defaults to 1.0): + Scaling factor for routed experts in MoE models. + norm_topk_prob (`bool`, *optional*, defaults to `True`): + Whether to normalize the topk probabilities. + layer_types (`Optional`, *optional*): + Type of each layers. + hidden_act (`str`, *optional*, defaults to `"silu"`): The non-linear activation function in the MLP. 
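Note: besides the docstring re-indentation above, the companion `__init__` hunk just below adds `hidden_act` as an explicit constructor argument and stores `initializer_range` on the config. A minimal usage sketch, assuming a transformers build that contains this PR's `Lfm2MoeConfig`:

```python
# Minimal sketch: the config now exposes `hidden_act` (default "silu") and stores
# `initializer_range`, so generic MLP construction and weight-init code can read
# them from the config instead of relying on hardcoded values.
from transformers import Lfm2MoeConfig

config = Lfm2MoeConfig(hidden_act="silu", initializer_range=0.02)
print(config.hidden_act)         # "silu"
print(config.initializer_range)  # 0.02
```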
```python >>> from transformers import Lfm2MoeModel, Lfm2MoeConfig @@ -134,6 +135,7 @@ def __init__( routed_scaling_factor: float = 1.0, norm_topk_prob: bool = True, layer_types: Optional[list[str]] = None, + hidden_act: str = "silu", **kwargs, ): self.vocab_size = vocab_size @@ -162,6 +164,8 @@ def __init__( self.routed_scaling_factor = routed_scaling_factor self.norm_topk_prob = norm_topk_prob self.layer_types = layer_types + self.hidden_act = hidden_act + self.initializer_range = initializer_range self.rope_parameters = rope_parameters tie_word_embeddings = kwargs.get("tie_embedding", tie_word_embeddings) # to fit original config keys diff --git a/src/transformers/models/minimax/modeling_minimax.py b/src/transformers/models/minimax/modeling_minimax.py index 8b8f8c9adb3a..eb069a7a41ab 100644 --- a/src/transformers/models/minimax/modeling_minimax.py +++ b/src/transformers/models/minimax/modeling_minimax.py @@ -469,8 +469,8 @@ def forward(self, hidden_states): router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value /= router_top_value.sum(dim=-1, keepdim=True) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices class MiniMaxExperts(nn.Module): @@ -492,22 +492,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states @@ -526,7 +525,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - top_k_weights, top_k_index = self.gate(hidden_states) + _, top_k_weights, top_k_index = self.gate(hidden_states) hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) return hidden_states diff --git a/src/transformers/models/mixtral/modeling_mixtral.py b/src/transformers/models/mixtral/modeling_mixtral.py index 95b236dadce6..a62c5e8aecf3 
100644 --- a/src/transformers/models/mixtral/modeling_mixtral.py +++ b/src/transformers/models/mixtral/modeling_mixtral.py @@ -74,22 +74,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states @@ -109,8 +108,8 @@ def forward(self, hidden_states): router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value /= router_top_value.sum(dim=-1, keepdim=True) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices class MixtralSparseMoeBlock(nn.Module): @@ -126,7 +125,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - top_k_weights, top_k_index = self.gate(hidden_states) + _, top_k_weights, top_k_index = self.gate(hidden_states) hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) return hidden_states diff --git a/src/transformers/models/mixtral/modular_mixtral.py b/src/transformers/models/mixtral/modular_mixtral.py index fabb84db688e..1796070fe6b6 100644 --- a/src/transformers/models/mixtral/modular_mixtral.py +++ b/src/transformers/models/mixtral/modular_mixtral.py @@ -153,22 +153,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + 
top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states @@ -188,8 +187,8 @@ def forward(self, hidden_states): router_logits = torch.nn.functional.softmax(router_logits.float(), dim=-1) router_top_value, router_indices = torch.topk(router_logits, self.top_k, dim=-1) # (seq_len, top_k) router_top_value /= router_top_value.sum(dim=-1, keepdim=True) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices class MixtralSparseMoeBlock(nn.Module): @@ -205,7 +204,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens if self.training and self.jitter_noise > 0: hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise) hidden_states = hidden_states.view(-1, hidden_states.shape[-1]) - top_k_weights, top_k_index = self.gate(hidden_states) + _, top_k_weights, top_k_index = self.gate(hidden_states) hidden_states = self.experts(hidden_states, top_k_index, top_k_weights) hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim) return hidden_states diff --git a/src/transformers/models/nanochat/modeling_nanochat.py b/src/transformers/models/nanochat/modeling_nanochat.py index 0488405d12e9..ab4ecc138912 100644 --- a/src/transformers/models/nanochat/modeling_nanochat.py +++ b/src/transformers/models/nanochat/modeling_nanochat.py @@ -30,6 +30,7 @@ from ...activations import ACT2FN from ...cache_utils import Cache, DynamicCache from ...generation import GenerationMixin +from ...integrations import use_kernel_func_from_hub from ...masking_utils import create_causal_mask from ...modeling_layers import GradientCheckpointingLayer from ...modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast @@ -121,6 +122,7 @@ def forward(self, x, position_ids): return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype) +@use_kernel_func_from_hub("rotary_pos_emb") def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1): """Applies Rotary Position Embedding to the query and key tensors. 
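Note: the nanochat hunk above wires `apply_rotary_pos_emb` into the kernel hub via the `use_kernel_func_from_hub("rotary_pos_emb")` decorator, and the next hunk stores `self.rotary_fn = apply_rotary_pos_emb` on the attention module so the active implementation is resolved per call. For reference, a pure-PyTorch sketch of the kind of function being decorated (illustrative only — NanoChat's actual rotary math may differ, and the assumption here is that the new function-level hook behaves like the existing `use_kernel_forward_from_hub` class hook, i.e. a hub kernel with the same signature can replace the Python body):

```python
# Illustrative fallback body for a function registered under "rotary_pos_emb":
# standard llama-style RoPE application with the signature used in the hunk above.
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```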
@@ -218,6 +220,7 @@ def __init__(self, config: NanoChatConfig, layer_idx: int): self.o_proj = nn.Linear( config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias ) + self.rotary_fn = apply_rotary_pos_emb self.q_norm = NanoChatRMSNorm(eps=config.rms_norm_eps) self.k_norm = NanoChatRMSNorm(eps=config.rms_norm_eps) diff --git a/src/transformers/models/olmoe/modeling_olmoe.py b/src/transformers/models/olmoe/modeling_olmoe.py index d04cd421d441..633f2365be92 100644 --- a/src/transformers/models/olmoe/modeling_olmoe.py +++ b/src/transformers/models/olmoe/modeling_olmoe.py @@ -317,22 +317,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states @@ -355,8 +354,8 @@ def forward(self, hidden_states): if self.norm_topk_prob: router_top_value /= router_top_value.sum(dim=-1, keepdim=True) router_top_value = router_top_value.to(router_logits.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices class OlmoeSparseMoeBlock(nn.Module): @@ -368,7 +367,7 @@ def __init__(self, config): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - top_k_weights, top_k_index = self.gate(hidden_states) + _, top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( batch_size, sequence_length, hidden_dim ) diff --git a/src/transformers/models/olmoe/modular_olmoe.py b/src/transformers/models/olmoe/modular_olmoe.py index eef444e6f24a..e9399fac1a12 100644 --- a/src/transformers/models/olmoe/modular_olmoe.py +++ b/src/transformers/models/olmoe/modular_olmoe.py @@ -134,7 +134,7 @@ def __init__(self, config): def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) - top_k_weights, top_k_index = self.gate(hidden_states) + _, top_k_weights, top_k_index = self.gate(hidden_states) final_hidden_states = self.experts(hidden_states, top_k_index, top_k_weights).reshape( 
batch_size, sequence_length, hidden_dim ) diff --git a/src/transformers/models/phimoe/modeling_phimoe.py b/src/transformers/models/phimoe/modeling_phimoe.py index fc7c733f3271..22ffb7710cbc 100644 --- a/src/transformers/models/phimoe/modeling_phimoe.py +++ b/src/transformers/models/phimoe/modeling_phimoe.py @@ -346,22 +346,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states diff --git a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py index 389bf016243a..1ed6819b3864 100644 --- a/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modeling_qwen2_moe.py @@ -311,22 +311,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None] + current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None] final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype)) return final_hidden_states @@ -349,8 +348,8 @@ def forward(self, hidden_states): if self.norm_topk_prob: router_top_value /= router_top_value.sum(dim=-1, keepdim=True) router_top_value = router_top_value.to(router_logits.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - 
return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices class Qwen2MoeSparseMoeBlock(nn.Module): @@ -365,7 +364,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) shared_expert_output = self.shared_expert(hidden_states_reshaped) - routing_weights, selected_experts = self.gate(hidden_states_reshaped) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output diff --git a/src/transformers/models/qwen2_moe/modular_qwen2_moe.py b/src/transformers/models/qwen2_moe/modular_qwen2_moe.py index fa33b78c42f5..5e1b26e9a0e7 100644 --- a/src/transformers/models/qwen2_moe/modular_qwen2_moe.py +++ b/src/transformers/models/qwen2_moe/modular_qwen2_moe.py @@ -106,8 +106,8 @@ def forward(self, hidden_states): if self.norm_topk_prob: router_top_value /= router_top_value.sum(dim=-1, keepdim=True) router_top_value = router_top_value.to(router_logits.dtype) - router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value) - return router_scores, router_indices + router_scores = router_top_value + return router_logits, router_scores, router_indices class Qwen2MoeSparseMoeBlock(nn.Module): @@ -122,7 +122,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens batch_size, sequence_length, hidden_dim = hidden_states.shape hidden_states_reshaped = hidden_states.view(-1, hidden_dim) shared_expert_output = self.shared_expert(hidden_states_reshaped) - routing_weights, selected_experts = self.gate(hidden_states_reshaped) + _, routing_weights, selected_experts = self.gate(hidden_states_reshaped) expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights) shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output diff --git a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py index caf7e26a39fe..6c45c96a32d4 100644 --- a/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py +++ b/src/transformers/models/qwen3_moe/modeling_qwen3_moe.py @@ -231,22 +231,21 @@ def forward( top_k_weights: torch.Tensor, ) -> torch.Tensor: final_hidden_states = torch.zeros_like(hidden_states) - num_experts = top_k_weights.shape[1] with torch.no_grad(): - expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1) + expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts) expert_mask = expert_mask.permute(2, 1, 0) expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero() for expert_idx in expert_hit: expert_idx = expert_idx[0] - if expert_idx == num_experts: + if expert_idx == self.num_experts: continue - _, token_idx = torch.where(expert_mask[expert_idx]) + top_k_pos, token_idx = torch.where(expert_mask[expert_idx]) current_state = hidden_states[token_idx] gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1) current_hidden_states = self.act_fn(gate) * up current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx]) - current_hidden_states = current_hidden_states * 
top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
         return final_hidden_states
@@ -269,20 +268,20 @@ def forward(self, hidden_states):
         if self.norm_topk_prob:
             router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
         router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices
 
 
 class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(self, config: Qwen3MoeConfig):
         super().__init__()
         self.experts = Qwen3MoeExperts(config)
-        self.router = Qwen3MoeTopKRouter(config)
+        self.gate = Qwen3MoeTopKRouter(config)
 
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
-        routing_weights, selected_experts = self.router(hidden_states_reshaped)
+        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
         final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@@ -369,7 +368,7 @@ class Qwen3MoePreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _supports_attention_backend = True
     _can_record_outputs = {
-        "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.router", index=0),
+        "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.gate", index=0),
         "hidden_states": Qwen3MoeDecoderLayer,
         "attentions": Qwen3MoeAttention,
     }
diff --git a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
index 6f4d5c53b820..17b3c42f6ccf 100644
--- a/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
+++ b/src/transformers/models/qwen3_moe/modular_qwen3_moe.py
@@ -67,12 +67,12 @@ class Qwen3MoeSparseMoeBlock(nn.Module):
     def __init__(self, config: Qwen3MoeConfig):
         super().__init__()
         self.experts = Qwen3MoeExperts(config)
-        self.router = Qwen3MoeTopKRouter(config)
+        self.gate = Qwen3MoeTopKRouter(config)
 
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
-        routing_weights, selected_experts = self.router(hidden_states_reshaped)
+        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
         final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@@ -87,7 +87,7 @@ class Qwen3MoeDecoderLayer(Qwen2MoeDecoderLayer):
 
 class Qwen3MoePreTrainedModel(MixtralPreTrainedModel):
     _can_record_outputs = {
-        "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.router", index=0),
+        "router_logits": OutputRecorder(Qwen3MoeTopKRouter, layer_name="mlp.gate", index=0),
         "hidden_states": Qwen3MoeDecoderLayer,
         "attentions": Qwen3MoeAttention,
     }
diff --git a/src/transformers/models/qwen3_next/modeling_qwen3_next.py b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
index 96b5e19615ce..29d600ff74d4 100644
--- a/src/transformers/models/qwen3_next/modeling_qwen3_next.py
+++ b/src/transformers/models/qwen3_next/modeling_qwen3_next.py
@@ -840,22 +840,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
         return final_hidden_states
@@ -878,8 +877,8 @@ def forward(self, hidden_states):
         if self.norm_topk_prob:
             router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
         router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices
 
 
 class Qwen3NextSparseMoeBlock(nn.Module):
@@ -894,7 +893,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         shared_expert_output = self.shared_expert(hidden_states_reshaped)
-        routing_weights, selected_experts = self.gate(hidden_states_reshaped)
+        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output
diff --git a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
index 636dab11c9d0..b0bfab4abfb1 100644
--- a/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
+++ b/src/transformers/models/qwen3_omni_moe/modeling_qwen3_omni_moe.py
@@ -1113,7 +1113,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
 
     def fast_pos_embed_interpolate(self, grid_thw):
         grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
-        device = grid_thw.device
+        device = self.pos_embed.weight.device
 
         idx_list = [[] for _ in range(4)]
         weight_list = [[] for _ in range(4)]
@@ -1338,22 +1338,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
         return final_hidden_states
@@ -1376,20 +1375,20 @@ def forward(self, hidden_states):
         if self.norm_topk_prob:
             router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
         router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices
 
 
 class Qwen3OmniMoeThinkerTextSparseMoeBlock(nn.Module):
     def __init__(self, config: Qwen3OmniMoeThinkerConfig):
         super().__init__()
         self.experts = Qwen3OmniMoeThinkerTextExperts(config)
-        self.router = Qwen3OmniMoeThinkerTextTopKRouter(config)
+        self.gate = Qwen3OmniMoeThinkerTextTopKRouter(config)
 
     def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
-        routing_weights, selected_experts = self.router(hidden_states_reshaped)
+        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
         final_hidden_states = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
         return final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
@@ -1599,7 +1598,7 @@ class Qwen3OmniMoeThinkerTextPreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
     _supports_attention_backend = True
     _can_record_outputs = {
-        "router_logits": OutputRecorder(Qwen3OmniMoeThinkerTextTopKRouter, layer_name="mlp.router", index=0),
+        "router_logits": OutputRecorder(Qwen3OmniMoeThinkerTextTopKRouter, layer_name="mlp.gate", index=0),
         "hidden_states": Qwen3OmniMoeThinkerTextDecoderLayer,
         "attentions": Qwen3OmniMoeThinkerTextAttention,
     }
@@ -2767,22 +2766,21 @@ def forward(
         top_k_weights: torch.Tensor,
     ) -> torch.Tensor:
         final_hidden_states = torch.zeros_like(hidden_states)
-        num_experts = top_k_weights.shape[1]
         with torch.no_grad():
-            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=num_experts + 1)
+            expert_mask = torch.nn.functional.one_hot(top_k_index, num_classes=self.num_experts)
             expert_mask = expert_mask.permute(2, 1, 0)
             expert_hit = torch.greater(expert_mask.sum(dim=(-1, -2)), 0).nonzero()
         for expert_idx in expert_hit:
             expert_idx = expert_idx[0]
-            if expert_idx == num_experts:
+            if expert_idx == self.num_experts:
                 continue
-            _, token_idx = torch.where(expert_mask[expert_idx])
+            top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
             current_state = hidden_states[token_idx]
             gate, up = nn.functional.linear(current_state, self.gate_up_proj[expert_idx]).chunk(2, dim=-1)
             current_hidden_states = self.act_fn(gate) * up
             current_hidden_states = nn.functional.linear(current_hidden_states, self.down_proj[expert_idx])
-            current_hidden_states = current_hidden_states * top_k_weights[token_idx, expert_idx, None]
+            current_hidden_states = current_hidden_states * top_k_weights[token_idx, top_k_pos, None]
             final_hidden_states.index_add_(0, token_idx, current_hidden_states.to(final_hidden_states.dtype))
         return final_hidden_states
@@ -2805,8 +2803,8 @@ def forward(self, hidden_states):
         if self.norm_topk_prob:
             router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
         router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices
 
 
 class Qwen3OmniMoeTalkerTextSparseMoeBlock(nn.Module):
@@ -2823,7 +2821,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tens
         batch_size, sequence_length, hidden_dim = hidden_states.shape
         hidden_states_reshaped = hidden_states.view(-1, hidden_dim)
         shared_expert_output = self.shared_expert(hidden_states_reshaped)
-        routing_weights, selected_experts = self.gate(hidden_states_reshaped)
+        _, routing_weights, selected_experts = self.gate(hidden_states_reshaped)
         expert_output = self.experts(hidden_states_reshaped, selected_experts, routing_weights)
         shared_expert_output = F.sigmoid(self.shared_expert_gate(hidden_states_reshaped)) * shared_expert_output
diff --git a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
index 5293f097bd46..5a7d1bbc54d5 100644
--- a/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
+++ b/src/transformers/models/qwen3_vl/modeling_qwen3_vl.py
@@ -673,7 +673,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
 
     def fast_pos_embed_interpolate(self, grid_thw):
         grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
-        device = grid_thw.device
+        device = self.pos_embed.weight.device
 
         idx_list = [[] for _ in range(4)]
         weight_list = [[] for _ in range(4)]
diff --git a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
index 82b385c53744..4bc9c0bc9b1d 100644
--- a/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
+++ b/src/transformers/models/qwen3_vl/modular_qwen3_vl.py
@@ -569,7 +569,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
 
     def fast_pos_embed_interpolate(self, grid_thw):
         grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
-        device = grid_thw.device
+        device = self.pos_embed.weight.device
 
         idx_list = [[] for _ in range(4)]
         weight_list = [[] for _ in range(4)]
diff --git a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
index 28e2c85f156c..dc48878f8b93 100644
--- a/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
+++ b/src/transformers/models/qwen3_vl_moe/modeling_qwen3_vl_moe.py
@@ -385,8 +385,8 @@ def forward(self, hidden_states):
         if self.norm_topk_prob:
             router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
         router_top_value = router_top_value.to(router_logits.dtype)
-        router_scores = torch.zeros_like(router_logits).scatter_(1, router_indices, router_top_value)
-        return router_scores, router_indices
+        router_scores = router_top_value
+        return router_logits, router_scores, router_indices
 
 
 @auto_docstring
@@ -402,7 +402,7 @@ class Qwen3VLMoePreTrainedModel(PreTrainedModel):
     _can_compile_fullgraph = False  # MoE models don't work with torch.compile (`torch.where(condition)` not supported)
    _supports_attention_backend = True
     _can_record_outputs = {
-        "router_logits": OutputRecorder(Qwen3VLMoeTextTopKRouter, layer_name="mlp.router", index=0),
+        "router_logits": OutputRecorder(Qwen3VLMoeTextTopKRouter, layer_name="mlp.gate", index=0),
         "hidden_states": Qwen3VLMoeTextDecoderLayer,
         "attentions": Qwen3VLMoeTextAttention,
     }
@@ -687,7 +687,7 @@ def rot_pos_emb(self, grid_thw: torch.Tensor) -> torch.Tensor:
 
     def fast_pos_embed_interpolate(self, grid_thw):
         grid_ts, grid_hs, grid_ws = grid_thw[:, 0], grid_thw[:, 1], grid_thw[:, 2]
-        device = grid_thw.device
+        device = self.pos_embed.weight.device
 
         idx_list = [[] for _ in range(4)]
         weight_list = [[] for _ in range(4)]
diff --git a/tests/models/olmoe/test_modeling_olmoe.py b/tests/models/olmoe/test_modeling_olmoe.py
index 6e69a5fc8353..8712690e5d17 100644
--- a/tests/models/olmoe/test_modeling_olmoe.py
+++ b/tests/models/olmoe/test_modeling_olmoe.py
@@ -205,21 +205,21 @@ class OlmoeIntegrationTest(unittest.TestCase):
     def test_model_7b_logits(self):
         input_ids = [[1, 306, 4658, 278, 6593, 310, 2834, 338]]
         model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", device_map="auto")
-        out = model(torch.tensor(input_ids)).logits.float()
+        out = model(torch.tensor(input_ids, device=model.device)).logits.float()
         # Expected mean on dim = -1
         EXPECTED_MEAN = torch.tensor([[-1.3814, -3.4450, -2.2990, -1.9542, -2.4387, -2.7941, -2.9312, -2.8309]])
-        torch.testing.assert_close(out.mean(-1), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
+        torch.testing.assert_close(out.mean(-1).cpu(), EXPECTED_MEAN, rtol=1e-2, atol=1e-2)
         # slicing logits[0, 0, 0:30]
         EXPECTED_SLICE = torch.tensor([-2.3874, -2.4076, -2.4995, 4.2278, 1.4004, -0.0252, 0.4189, -2.7560, 0.3531, 1.6678, -0.7941, -1.1818, -0.2920, 0.7131, -1.4173, 1.6723, 0.5406, 0.1345, -0.1800, 0.2304, 1.2791, 0.7489, 0.6341, -0.0151, -1.3693, -1.2532, -2.3921, 0.7376, 1.6876, 0.5483])  # fmt: skip
-        torch.testing.assert_close(out[0, 0, :30], EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
+        torch.testing.assert_close(out[0, 0, :30].cpu(), EXPECTED_SLICE, rtol=1e-2, atol=1e-2)
 
     @slow
     def test_model_7b_greedy_generation(self):
         EXPECTED_TEXT_COMPLETION = """Simply put, the theory of relativity states that \nthe speed of light is the same for all observers, no matter \nhow fast they are moving. This is a very counter-intuitive \nconcept, and it took Einstein a long time to come up with \nthe theory. The theory of relativity is based on two \npostulates"""
         prompt = "Simply put, the theory of relativity states that "
         tokenizer = AutoTokenizer.from_pretrained("allenai/OLMoE-1B-7B-0924", device_map="auto")
-        input_ids = tokenizer.encode(prompt, return_tensors="pt")
         model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", device_map="auto")
+        input_ids = tokenizer.encode(prompt, return_tensors="pt").to(model.device)
 
         # greedy generation outputs
         generated_ids = model.generate(input_ids, max_new_tokens=64, top_p=None, temperature=1, do_sample=False)
diff --git a/tests/test_modeling_common.py b/tests/test_modeling_common.py
index d2b5e0949cac..3fb4e21909b2 100755
--- a/tests/test_modeling_common.py
+++ b/tests/test_modeling_common.py
@@ -2423,7 +2423,9 @@ def test_disk_offload_safetensors(self):
             max_memory = {0: max_size, "cpu": max_size}
             # This doesn't error out as it's in safetensors and doesn't need an offload folder
-            new_model = model_class.from_pretrained(tmp_dir, device_map="auto", max_memory=max_memory)
+            new_model = model_class.from_pretrained(
+                tmp_dir, device_map="auto", max_memory=max_memory, offload_folder=tmp_dir
+            )
 
             self.check_device_map_is_respected(new_model, new_model.hf_device_map)
 
             torch.manual_seed(0)
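For context only, and not part of the patch above: a minimal, self-contained sketch of the routing contract these hunks converge on. The helper name demo_route and the softmax-then-top-k gating details are illustrative assumptions; what the diff itself establishes is that the router (now wired as `mlp.gate`) returns `(router_logits, router_scores, router_indices)` with `router_scores` holding the dense `(num_tokens, top_k)` weights, so the experts layer looks up each token's weight by its slot in the top-k list (`top_k_pos`) rather than by the expert id.

import torch

def demo_route(router_logits: torch.Tensor, top_k: int, norm_topk_prob: bool = True):
    # router_logits: (num_tokens, num_experts); the softmax/top-k gating here is an assumption
    router_top_value, router_indices = torch.topk(torch.softmax(router_logits, dim=-1), top_k, dim=-1)
    if norm_topk_prob:
        router_top_value /= router_top_value.sum(dim=-1, keepdim=True)
    # router_scores is (num_tokens, top_k); it is no longer scattered to (num_tokens, num_experts)
    return router_logits, router_top_value, router_indices

num_experts, top_k = 8, 2
logits = torch.randn(4, num_experts)  # 4 tokens
_, scores, indices = demo_route(logits, top_k)
# expert_mask: (num_experts, top_k, num_tokens), as in the Experts.forward hunks
expert_mask = torch.nn.functional.one_hot(indices, num_classes=num_experts).permute(2, 1, 0)
for expert_idx in expert_mask.sum(dim=(-1, -2)).nonzero()[:, 0]:
    top_k_pos, token_idx = torch.where(expert_mask[expert_idx])
    # weight is indexed by the token's position in its top-k list, not by expert_idx
    per_token_weight = scores[token_idx, top_k_pos, None]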