@@ -136,8 +136,8 @@ class SparseMoE(MoE):
136136 # TODO: determine if we get it from external or extract it in MoE class
137137 is_batch_sharded_by_expert: True if batch is sharded over 'expert' dim.
138138 """
139- def_sharding : Sharding
140- fed_sharding : Sharding
139+ edf_sharding : Sharding
140+ efd_sharding : Sharding
141141 num_experts_per_tok : int
142142 #TODO: tile size is (tile_batch_seq, tile_activation_dim, tile_weight_dim,) from MaxText
143143 tile_size : tuple [int , int , int ] = (128 , 64 , 128 )
@@ -155,24 +155,24 @@ def __post_init__(self, rngs: nnx.Rngs):
155155 shape_up = (self .num_local_experts , D , F )
156156 shape_down = (self .num_local_experts , F , D )
157157
158- self .kernel_gating_DEF = create_param (rngs ,
158+ self .kernel_gating_EDF = create_param (rngs ,
159159 shape = shape_gating ,
160160 dtype = self .dtype ,
161- sharding = self .def_sharding ,
161+ sharding = self .edf_sharding ,
162162 random_init = self .random_init )
163- self .kernel_up_proj_DEF = create_param (rngs ,
163+ self .kernel_up_proj_EDF = create_param (rngs ,
164164 shape = shape_up ,
165165 dtype = self .dtype ,
166- sharding = self .def_sharding ,
166+ sharding = self .edf_sharding ,
167167 random_init = self .random_init )
168- self .kernel_down_proj_FED = create_param (rngs ,
168+ self .kernel_down_proj_EFD = create_param (rngs ,
169169 shape = shape_down ,
170170 dtype = self .dtype ,
171- sharding = self .fed_sharding ,
171+ sharding = self .efd_sharding ,
172172 random_init = self .random_init )
173173
174174 # Derive the expert sharding
175- self .expert_axis_name = self .def_sharding [0 ]
175+ self .expert_axis_name = self .edf_sharding [0 ]
176176 if self .expert_axis_name is None :
177177 self .num_expert_parallelism = 1
178178 else :
@@ -597,10 +597,10 @@ def __call__(self, x_TD: Float):
597597 PartitionSpec (* self .activation_ffw_td ), # Sharded x_TD
598598 PartitionSpec (), # Replicated router_weights_TX
599599 PartitionSpec (), # Replicated selected_experts_TX
600- PartitionSpec (* self .def_sharding ), # Sharded gating kernel
601- PartitionSpec (* self .def_sharding ), # Sharded up-projection kernel
600+ PartitionSpec (* self .edf_sharding ), # Sharded gating kernel
601+ PartitionSpec (* self .edf_sharding ), # Sharded up-projection kernel
602602 PartitionSpec (
603- * self .fed_sharding ), # Sharded down-projection kernel
603+ * self .efd_sharding ), # Sharded down-projection kernel
604604 )
605605 out_specs = PartitionSpec (* self .activation_ffw_td )
606606
@@ -616,7 +616,7 @@ def __call__(self, x_TD: Float):
616616 x_TD ,
617617 router_weights_TX ,
618618 selected_experts_TX ,
619- self .kernel_gating_DEF .value ,
620- self .kernel_up_proj_DEF .value ,
621- self .kernel_down_proj_FED .value ,
619+ self .kernel_gating_EDF .value ,
620+ self .kernel_up_proj_EDF .value ,
621+ self .kernel_down_proj_EFD .value ,
622622 )