 from vllm.platforms import current_platform
 from vllm.utils import direct_register_custom_op
 
+from .interfaces import MixtureOfExperts
 from .transformers import (
     TransformersBase,
     TransformersForCausalLM,
@@ -116,17 +117,41 @@ def transformers_moe_forward_fake(
 )
 
 
-class TransformersMoEBase(TransformersBase):
+class TransformersMoEBase(TransformersBase, MixtureOfExperts):
     def __init__(self, *, vllm_config, prefix=""):
         self.check_version("4.57.0.dev0", "MoE models support")
+        self.ep_group = get_ep_group()
         super().__init__(vllm_config=vllm_config, prefix=prefix)
 
-        if self.parallel_config.enable_eplb:
-            raise NotImplementedError(
-                "Transformers backend does not support expert parallel load "
-                "balancing yet."
+    def set_eplb_state(
+        self,
+        expert_load_view: torch.Tensor,
+        logical_to_physical_map: torch.Tensor,
+        logical_replica_count: torch.Tensor,
+    ):
+        for moe_layer_idx, mlp_layer in enumerate(self.mlp_layers):
+            mlp_layer.experts.set_eplb_state(
+                moe_layer_idx=moe_layer_idx,
+                expert_load_view=expert_load_view,
+                logical_to_physical_map=logical_to_physical_map,
+                logical_replica_count=logical_replica_count,
             )
 
+    def update_physical_experts_metadata(
+        self,
+        num_physical_experts: int,
+        num_local_physical_experts: int,
+    ):
+        assert self.num_local_physical_experts == num_local_physical_experts
+        self.num_physical_experts = num_physical_experts
+        self.num_local_physical_experts = num_local_physical_experts
+        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
+        for mlp in self.mlp_layers:
+            mlp.n_local_physical_experts = num_local_physical_experts
+            mlp.n_physical_experts = num_physical_experts
+            mlp.n_redundant_experts = self.num_redundant_experts
+            mlp.experts.update_expert_map()
+
     def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
         """
         Params for weights, fp8 weight scales, fp8 activation scales
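
For orientation, the attributes and methods the class must now provide come from the MixtureOfExperts mixin imported above. The sketch below is an approximation reconstructed from how this diff uses the mixin; the authoritative definition lives in .interfaces and may differ in detail.

# Hedged sketch only: approximate surface of the MixtureOfExperts mixin as
# exercised by this diff. Attribute and method names are taken from the diff;
# the Protocol itself is illustrative, not vLLM's real definition.
from typing import Protocol

import torch


class MixtureOfExpertsLike(Protocol):
    expert_weights: list                 # per-layer expert weights
    num_moe_layers: int
    num_expert_groups: int
    num_logical_experts: int             # experts defined by the checkpoint
    num_physical_experts: int            # logical + redundant replicas
    num_local_physical_experts: int      # physical experts on this EP rank
    num_redundant_experts: int

    def set_eplb_state(
        self,
        expert_load_view: torch.Tensor,
        logical_to_physical_map: torch.Tensor,
        logical_replica_count: torch.Tensor,
    ) -> None: ...

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None: ...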
@@ -138,15 +163,17 @@ def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
             ("w1", "w2", "w3"),  # Granite, Mixtral, Phi MoE style
             ("linear", "linear_1", "linear_v"),  # Grok1 style
         ]
+        num_experts = self.model_config.get_num_experts()
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
         expert_mapping = []
         for gate_proj, down_proj, up_proj in ckpt_names:
             expert_mapping.extend(
                 FusedMoE.make_expert_params_mapping(
                     ckpt_gate_proj_name=gate_proj,
                     ckpt_down_proj_name=down_proj,
                     ckpt_up_proj_name=up_proj,
-                    num_experts=self.model_config.get_num_experts(),
-                    num_redundant_experts=0,  # TODO: enable EPLB
+                    num_experts=num_experts,
+                    num_redundant_experts=num_redundant_experts,
                 )
             )
         return expert_mapping
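
The mapping returned above is consumed by AutoWeightsLoader; each entry is a (param_name, weight_name, expert_id, shard_id) 4-tuple per the return annotation. A minimal, hypothetical consumer, assuming only that tuple structure:

# Hedged sketch (not from this diff): iterate the expert mapping the way a
# weights loader would. `model` stands for any constructed Transformers-backend
# MoE model; the field meanings are assumptions based on the 4-tuple annotation.
def describe_expert_mapping(model) -> None:
    for param_name, weight_name, expert_id, shard_id in model.get_expert_mapping():
        # param_name:  prefix of the fused expert parameter on the vLLM module
        # weight_name: prefix of the per-expert tensor name in the checkpoint
        # expert_id:   physical expert index (covers redundant replicas when
        #              num_redundant_experts > 0)
        # shard_id:    which shard of the fused weight the tensor fills
        print(param_name, weight_name, expert_id, shard_id)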
@@ -167,12 +194,15 @@ def recursive_replace(self):
 
         # If there are shared experts, the results are
         # reduced after mlp.forward() not inside FusedMoE
-        num_experts_shared = getattr_iter(
+        num_shared_experts = getattr_iter(
            text_config,
-            ["num_experts_shared", "n_shared_experts", "moe_num_shared_experts"],
+            [
+                "n_shared_experts",  # DeepSeek, Dots, GLM
+                "moe_num_shared_experts",  # Aria, Ernie
+            ],
             0,
         )
-        reduce_results = num_experts_shared == 0
+        reduce_results = num_shared_experts == 0
 
         def add_all_reduce(mlp: nn.Module):
             """Adds an all-reduce to the output of `mlp.forward()`."""
@@ -207,13 +237,23 @@ def forward(self, *args, **kwargs):
         # Expert mapping for `AutoWeightsLoader`
         expert_mapping = self.get_expert_mapping()
 
-        # Configs
-        parallel_config = self.parallel_config
-        eplb_config = parallel_config.eplb_config
-
         # Expert parallel load balancing kwargs
-        enable_eplb = parallel_config.enable_eplb
-        num_redundant_experts = eplb_config.num_redundant_experts
+        enable_eplb = self.parallel_config.enable_eplb
+        num_redundant_experts = self.parallel_config.eplb_config.num_redundant_experts
+
+        # MixtureOfExperts mixin settings
+        ep_size = self.ep_group.world_size
+
+        self.mlp_layers = []  # Used for MixtureOfExperts methods
+        self.expert_weights = []
+        self.num_moe_layers = 0
+        self.num_expert_groups = 1 if num_expert_group is None else num_expert_group
+        self.num_logical_experts = num_experts
+        self.num_physical_experts = num_experts + num_redundant_experts
+        self.num_local_physical_experts = self.num_physical_experts // ep_size
+        self.num_routed_experts = num_experts
+        self.num_shared_experts = num_shared_experts
+        self.num_redundant_experts = num_redundant_experts
 
         # Recursively fuse MoE layers
         def _recursive_replace(module: nn.Module, prefix: str):
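
The physical/local expert counts set above follow the redundant-expert arithmetic used by EPLB; a small worked example with made-up numbers:

# Illustrative numbers only, not taken from this diff.
num_experts = 64             # logical (routed) experts in the checkpoint
num_redundant_experts = 4    # extra replicas configured for EPLB
ep_size = 4                  # expert-parallel world size

num_physical_experts = num_experts + num_redundant_experts    # 68
num_local_physical_experts = num_physical_experts // ep_size  # 17 per EP rank
assert num_physical_experts - num_experts == num_redundant_experts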
@@ -235,6 +275,9 @@ def _recursive_replace(module: nn.Module, prefix: str):
                 for mlp_param_name, _ in mlp.named_parameters():
                     if "shared_expert" in mlp_param_name:
                         reduce_results = False
+                        # If the config does not specify num_shared_experts, but
+                        # the model has shared experts, we assume there is one.
+                        self.num_shared_experts = 1
                         break
                 # Replace experts module with FusedMoE
                 fused_experts = TransformersFusedMoE(
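
The fallback above covers models whose parameters contain shared experts that the config never declares; the config probe itself (earlier hunk) goes through getattr_iter. A minimal sketch of that helper's assumed behaviour, since its definition is not part of this diff:

# Hedged sketch of getattr_iter's assumed semantics: return the first
# attribute found among `names`, else the default. The real helper is defined
# elsewhere in the Transformers backend and may differ in detail.
def getattr_iter(obj, names, default):
    for name in names:
        if hasattr(obj, name):
            return getattr(obj, name)
    return default


# A DeepSeek-style config exposing n_shared_experts (or an Ernie-style one
# exposing moe_num_shared_experts) yields a non-zero count, so reduce_results
# becomes False and the all-reduce is deferred until after the shared experts
# run; a config with neither attribute yields 0 and FusedMoE reduces internally.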
@@ -258,6 +301,10 @@ def _recursive_replace(module: nn.Module, prefix: str):
                 )
                 mlp.experts = fused_experts
                 log_replacement(qual_name, experts, fused_experts)
+                # Update MixtureOfExperts mixin state
+                self.mlp_layers.append(mlp)
+                self.expert_weights.append(fused_experts.get_expert_weights())
+                self.num_moe_layers += 1
                 # If results are not all-reduced in FusedMoE, ensure they
                 # are all-reduced at the end of mlp.forward() if tensor
                 # parallel or expert parallel is enabled