Commit 35e8bf8

fix modular mixtral
Parent: 15f41b9

File tree: 1 file changed (+1 / -1 lines)

src/transformers/models/mixtral/modular_mixtral.py

Lines changed: 1 addition & 1 deletion
@@ -204,7 +204,7 @@ def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         if self.training and self.jitter_noise > 0:
             hidden_states *= torch.empty_like(hidden_states).uniform_(1.0 - self.jitter_noise, 1.0 + self.jitter_noise)
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
-        _,top_k_weights, top_k_index = self.gate(hidden_states)
+        _, top_k_weights, top_k_index = self.gate(hidden_states)
         hidden_states = self.experts(hidden_states, top_k_index, top_k_weights)
         hidden_states = hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         return hidden_states
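
For context, the fixed line unpacks three values from the MoE gate. Below is a minimal, hypothetical sketch of a top-k router with the same return shape (router_logits, top_k_weights, top_k_index); the class name ToyTopKGate and all dimensions are illustrative assumptions, not the repository's actual gate implementation.

import torch
import torch.nn as nn
import torch.nn.functional as F

class ToyTopKGate(nn.Module):
    # Hypothetical router (not the transformers implementation): scores each
    # token against every expert, keeps the top-k experts per token, and
    # renormalizes the kept weights to sum to 1.
    def __init__(self, hidden_dim: int, num_experts: int, top_k: int):
        super().__init__()
        self.proj = nn.Linear(hidden_dim, num_experts, bias=False)
        self.top_k = top_k

    def forward(self, hidden_states: torch.Tensor):
        # hidden_states: (num_tokens, hidden_dim), already flattened
        # as in the view(-1, ...) call shown in the diff above.
        router_logits = self.proj(hidden_states)        # (num_tokens, num_experts)
        probs = F.softmax(router_logits, dim=-1)
        top_k_weights, top_k_index = torch.topk(probs, self.top_k, dim=-1)
        top_k_weights = top_k_weights / top_k_weights.sum(dim=-1, keepdim=True)
        return router_logits, top_k_weights, top_k_index

gate = ToyTopKGate(hidden_dim=16, num_experts=8, top_k=2)
tokens = torch.randn(4, 16)
_, top_k_weights, top_k_index = gate(tokens)  # same three-way unpacking as the fixed line

The leading underscore discards the raw router logits; only the per-token expert weights and indices are passed on to self.experts, matching the call on the line after the fix.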
