From 6a9e43deae6f966546166d14eb8d11fd5639d2ac Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Thu, 12 Jun 2025 13:14:52 +0000 Subject: [PATCH 1/9] Add example for MOE in AutoParallel For now assumes it's balanced --- examples/example_moe.py | 146 ++++++++++++++++++++++++++++++++++++++++ 1 file changed, 146 insertions(+) create mode 100644 examples/example_moe.py diff --git a/examples/example_moe.py b/examples/example_moe.py new file mode 100644 index 00000000..debe3c42 --- /dev/null +++ b/examples/example_moe.py @@ -0,0 +1,146 @@ +import torch +from torch import nn +from torch.nn import functional as F + + +class FFN(nn.Module): + def __init__(self, in_channels, inter_channels): + super().__init__() + self.w1 = nn.Linear(in_channels, inter_channels) + self.w2 = nn.Linear(inter_channels, in_channels) + self.w3 = nn.Linear(in_channels, inter_channels) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class MOE(nn.Module): + def __init__(self, in_channels, inter_channels, num_experts): + super().__init__() + self.num_experts = num_experts + self.experts = nn.ModuleList(FFN(in_channels, inter_channels) for _ in range(num_experts)) + + def forward(self, x): + assert x.ndim == 3 + shape = x.shape + x = x.flatten(0, 1) + indices = torch.randint(0, self.num_experts, (x.shape[0],), dtype=torch.int64, device=x.device) + output = torch.zeros_like(x) + for i, expert in enumerate(self.experts): + idx = torch.where(indices == i) + output[idx] += expert(x[idx]) + return output.reshape(shape) + + +class BatchLinear(nn.Module): + def __init__(self, in_channels, out_channels, num_experts): + super().__init__() + self.weight = nn.Parameter(torch.randn(num_experts, out_channels, in_channels)) + self.bias = nn.Parameter(torch.randn(num_experts, out_channels)) + + def forward(self, x): + assert x.ndim == 3 + return x @ self.weight.transpose(-2, -1) + self.bias[:, None, :] + + +class BatchFFN(nn.Module): + def __init__(self, in_channels, inter_channels, num_experts): + super().__init__() + self.w1 = BatchLinear(in_channels, inter_channels, num_experts) + self.w2 = BatchLinear(inter_channels, in_channels, num_experts) + self.w3 = BatchLinear(in_channels, inter_channels, num_experts) + + def forward(self, x): + return self.w2(F.silu(self.w1(x)) * self.w3(x)) + + +class MOEBatched(nn.Module): + def __init__(self, in_channels, inter_channels, num_experts): + super().__init__() + self.num_experts = num_experts + self.experts = BatchFFN(in_channels, inter_channels, num_experts) + + def forward(self, x): + assert x.ndim == 3 + shape = x.shape + x = x.flatten(0, 1) + assert x.shape[0] % self.num_experts == 0 + # force balanced indices + indices = torch.randperm(x.shape[0], dtype=torch.int64, device=x.device) % self.num_experts + # put all tokens corresponding to the same expert together + idxs = indices.argsort() + xs = x[idxs].unflatten(0, (self.num_experts, -1)) + # now experts can be computed as bmm + out = self.experts(xs) + # put tokens back into its original order + out = out.flatten(0, 1) + new_idxs = idxs.argsort() + out = out[new_idxs] + return out.reshape(shape) + +""" + +in_channels = 64 +inter_channels = 128 +num_experts = 8 + +bs = 8 +seqlen = 64 + +x = torch.rand(bs, seqlen, in_channels).cuda() + +m = MOE(in_channels, inter_channels, num_experts).cuda() +m2 = MOEBatched(in_channels, inter_channels, num_experts).cuda() + +o = m(x) +o = m2(x) +""" + +from autoparallel.api import AutoParallel + +from torch.testing._internal.distributed.fake_pg import FakeStore + +world_size 
= 256 + +fake_store = FakeStore() +torch.distributed.init_process_group( + "fake", store=fake_store, rank=0, world_size=world_size +) +# mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) +mesh = torch.distributed.device_mesh.init_device_mesh( + "cuda", + (world_size // 8, 8), + mesh_dim_names=( + "dp", + "tp", + ), +) + + +in_channels = 4096 +inter_channels = 4096 * 2 +num_experts = 8 + +bs = 8 * mesh.shape[0] +seqlen = 2048 + +def input_fn(): + return torch.rand(bs, seqlen, in_channels).cuda() + +def model_fn(): + return MOEBatched(in_channels, inter_channels, num_experts).cuda() + + +autop = AutoParallel(model_fn, input_fn, mesh) +autop.add_parameter_memory_constraint(low=None, high=None) + +from torch.distributed.tensor.placement_types import Replicate, Shard +x_sharding = (Shard(0), Replicate()) + +autop.add_input_constraints([x_sharding]) +autop.add_output_constraints([x_sharding]) + + +sharding_placement = autop.optimize_placement() + +from IPython import embed; embed(); sys.sdf From d1ceb77c5ae74972a3690f9bc3776fad7ee01489 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 25 Jul 2025 09:35:02 +0000 Subject: [PATCH 2/9] Update to latest main --- examples/example_moe.py | 33 ++++++++++++++++++++++++--------- 1 file changed, 24 insertions(+), 9 deletions(-) diff --git a/examples/example_moe.py b/examples/example_moe.py index debe3c42..92356004 100644 --- a/examples/example_moe.py +++ b/examples/example_moe.py @@ -18,13 +18,17 @@ class MOE(nn.Module): def __init__(self, in_channels, inter_channels, num_experts): super().__init__() self.num_experts = num_experts - self.experts = nn.ModuleList(FFN(in_channels, inter_channels) for _ in range(num_experts)) + self.experts = nn.ModuleList( + FFN(in_channels, inter_channels) for _ in range(num_experts) + ) def forward(self, x): assert x.ndim == 3 shape = x.shape x = x.flatten(0, 1) - indices = torch.randint(0, self.num_experts, (x.shape[0],), dtype=torch.int64, device=x.device) + indices = torch.randint( + 0, self.num_experts, (x.shape[0],), dtype=torch.int64, device=x.device + ) output = torch.zeros_like(x) for i, expert in enumerate(self.experts): idx = torch.where(indices == i) @@ -60,13 +64,19 @@ def __init__(self, in_channels, inter_channels, num_experts): self.num_experts = num_experts self.experts = BatchFFN(in_channels, inter_channels, num_experts) + def init_weights(self): + pass + def forward(self, x): assert x.ndim == 3 shape = x.shape x = x.flatten(0, 1) assert x.shape[0] % self.num_experts == 0 # force balanced indices - indices = torch.randperm(x.shape[0], dtype=torch.int64, device=x.device) % self.num_experts + indices = ( + torch.randperm(x.shape[0], dtype=torch.int64, device=x.device) + % self.num_experts + ) # put all tokens corresponding to the same expert together idxs = indices.argsort() xs = x[idxs].unflatten(0, (self.num_experts, -1)) @@ -78,6 +88,7 @@ def forward(self, x): out = out[new_idxs] return out.reshape(shape) + """ in_channels = 64 @@ -96,10 +107,10 @@ def forward(self, x): o = m2(x) """ -from autoparallel.api import AutoParallel - from torch.testing._internal.distributed.fake_pg import FakeStore +from autoparallel.api import AutoParallel + world_size = 256 fake_store = FakeStore() @@ -124,17 +135,23 @@ def forward(self, x): bs = 8 * mesh.shape[0] seqlen = 2048 + def input_fn(): return torch.rand(bs, seqlen, in_channels).cuda() + def model_fn(): - return MOEBatched(in_channels, inter_channels, num_experts).cuda() + return MOEBatched(in_channels, 
inter_channels, num_experts) + +with torch.device("meta"): + model = model_fn() -autop = AutoParallel(model_fn, input_fn, mesh) +autop = AutoParallel(model, input_fn, mesh) autop.add_parameter_memory_constraint(low=None, high=None) from torch.distributed.tensor.placement_types import Replicate, Shard + x_sharding = (Shard(0), Replicate()) autop.add_input_constraints([x_sharding]) @@ -142,5 +159,3 @@ def model_fn(): sharding_placement = autop.optimize_placement() - -from IPython import embed; embed(); sys.sdf From fa295674abd038f1266653ac352cd0ff943f7153 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 25 Jul 2025 12:24:13 +0000 Subject: [PATCH 3/9] Increase number of experts so that they can be added to the dp dim --- examples/example_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/example_moe.py b/examples/example_moe.py index 92356004..5c6676da 100644 --- a/examples/example_moe.py +++ b/examples/example_moe.py @@ -130,7 +130,7 @@ def forward(self, x): in_channels = 4096 inter_channels = 4096 * 2 -num_experts = 8 +num_experts = 64 bs = 8 * mesh.shape[0] seqlen = 2048 From 4be83cc1cb72f6f882144acaecd02c1160ecf63c Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 25 Jul 2025 14:40:38 +0000 Subject: [PATCH 4/9] Add debug example and use values from DeepSeekV3 --- examples/example_moe.py | 36 +++++++++++++++++++++++++++--------- 1 file changed, 27 insertions(+), 9 deletions(-) diff --git a/examples/example_moe.py b/examples/example_moe.py index 5c6676da..cd14ead2 100644 --- a/examples/example_moe.py +++ b/examples/example_moe.py @@ -89,6 +89,23 @@ def forward(self, x): return out.reshape(shape) +class MOEBatchedDebug(nn.Module): + def __init__(self, in_channels, inter_channels, num_experts): + super().__init__() + self.num_experts = num_experts + self.experts = BatchFFN(in_channels, inter_channels, num_experts) + + def init_weights(self): + pass + + def forward(self, x): + assert x.ndim == 3 + shape = x.shape + xs = x.unflatten(1, (self.num_experts, -1)).permute(1, 0, 2, 3).flatten(1, 2) + out = self.experts(xs) + return out.reshape(shape) + + """ in_channels = 64 @@ -111,7 +128,7 @@ def forward(self, x): from autoparallel.api import AutoParallel -world_size = 256 +world_size = 2048 fake_store = FakeStore() torch.distributed.init_process_group( @@ -120,7 +137,7 @@ def forward(self, x): # mesh = torch.distributed.device_mesh.init_device_mesh("cuda", (world_size,), mesh_dim_names=("dp",)) mesh = torch.distributed.device_mesh.init_device_mesh( "cuda", - (world_size // 8, 8), + (world_size // 64, 64), mesh_dim_names=( "dp", "tp", @@ -128,12 +145,12 @@ def forward(self, x): ) -in_channels = 4096 -inter_channels = 4096 * 2 -num_experts = 64 +in_channels = 7168 +inter_channels = 2048 +num_experts = 128 -bs = 8 * mesh.shape[0] -seqlen = 2048 +bs = 8 * mesh.shape[0] * mesh.shape[1] +seqlen = 2048 * 2 def input_fn(): @@ -141,7 +158,8 @@ def input_fn(): def model_fn(): - return MOEBatched(in_channels, inter_channels, num_experts) + # return MOEBatched(in_channels, inter_channels, num_experts) + return MOEBatchedDebug(in_channels, inter_channels, num_experts) with torch.device("meta"): @@ -152,7 +170,7 @@ def model_fn(): from torch.distributed.tensor.placement_types import Replicate, Shard -x_sharding = (Shard(0), Replicate()) +x_sharding = (Shard(0), Shard(0)) autop.add_input_constraints([x_sharding]) autop.add_output_constraints([x_sharding]) From 452106723c8ac52a9a467080e75f7aca3d230ff5 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 25 Jul 
2025 14:55:13 +0000 Subject: [PATCH 5/9] Add top_k experts Everything seems to be working as expected! --- examples/example_moe.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/examples/example_moe.py b/examples/example_moe.py index cd14ead2..da950018 100644 --- a/examples/example_moe.py +++ b/examples/example_moe.py @@ -94,6 +94,7 @@ def __init__(self, in_channels, inter_channels, num_experts): super().__init__() self.num_experts = num_experts self.experts = BatchFFN(in_channels, inter_channels, num_experts) + self.top_k = 4 def init_weights(self): pass @@ -101,8 +102,14 @@ def init_weights(self): def forward(self, x): assert x.ndim == 3 shape = x.shape - xs = x.unflatten(1, (self.num_experts, -1)).permute(1, 0, 2, 3).flatten(1, 2) + xs = ( + x.unflatten(1, (self.num_experts, -1)) + .permute(1, 0, 2, 3) + .repeat(1, 1, self.top_k, 1) + .flatten(1, 2) + ) out = self.experts(xs) + out = out.unflatten(1, (-1, self.top_k)).sum(2) return out.reshape(shape) From 09dd86f3592e4ef7108d575301653232699cbae3 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 25 Jul 2025 16:09:09 +0000 Subject: [PATCH 6/9] Fix meshdim name --- examples/example_moe.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/example_moe.py b/examples/example_moe.py index da950018..3ebf6d61 100644 --- a/examples/example_moe.py +++ b/examples/example_moe.py @@ -147,7 +147,7 @@ def forward(self, x): (world_size // 64, 64), mesh_dim_names=( "dp", - "tp", + "ep", ), ) From 94856a8dd3153a8c87a1364ef210f03b70824c18 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Sat, 27 Sep 2025 15:49:57 +0000 Subject: [PATCH 7/9] Update to latest main --- examples/example_moe.py | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/examples/example_moe.py b/examples/example_moe.py index 3ebf6d61..09a279de 100644 --- a/examples/example_moe.py +++ b/examples/example_moe.py @@ -172,15 +172,18 @@ def model_fn(): with torch.device("meta"): model = model_fn() -autop = AutoParallel(model, input_fn, mesh) -autop.add_parameter_memory_constraint(low=None, high=None) +with AutoParallel(model, input_fn, mesh) as autop: + autop.add_parameter_memory_constraint(low=None, high=None) -from torch.distributed.tensor.placement_types import Replicate, Shard + from torch.distributed.tensor.placement_types import Replicate, Shard -x_sharding = (Shard(0), Shard(0)) + x_sharding = (Shard(0), Shard(0)) -autop.add_input_constraints([x_sharding]) -autop.add_output_constraints([x_sharding]) + autop.add_input_constraints([x_sharding]) + autop.add_output_constraints([x_sharding]) + import time -sharding_placement = autop.optimize_placement() + t = time.time() + sharding_placement = autop.optimize_placement() + print(f"Took {time.time() - t:.2f} s") From b3beb6ba126b590655c81f78f362d6a05cc45739 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Mon, 29 Sep 2025 12:06:57 +0000 Subject: [PATCH 8/9] Add more complete example --- autoparallel/api.py | 2 ++ examples/example_moe.py | 69 +++++++++++++++++++++++++++++------------ 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/autoparallel/api.py b/autoparallel/api.py index e5da5d67..70e65dbe 100644 --- a/autoparallel/api.py +++ b/autoparallel/api.py @@ -60,6 +60,8 @@ def _get_decomp_table(): decomp_table.pop(torch.ops.aten.native_layer_norm.default) decomp_table.pop(torch.ops.aten.embedding_dense_backward.default) decomp_table.pop(torch.ops.aten.native_layer_norm_backward.default) + 
decomp_table.pop(torch.ops.aten._softmax_backward_data.default) + decomp_table.pop(torch.ops.aten._softmax.default) # decompose addmm to allow for TP on mm decomp_table.pop(torch.ops.aten.addmm.default) diff --git a/examples/example_moe.py b/examples/example_moe.py index 09a279de..c1b7f0ba 100644 --- a/examples/example_moe.py +++ b/examples/example_moe.py @@ -62,31 +62,46 @@ class MOEBatched(nn.Module): def __init__(self, in_channels, inter_channels, num_experts): super().__init__() self.num_experts = num_experts + # TODO: need to fix case with bias, as the parameter memory constraint is not satisfied + # because too many GPUs for the number of experts + self.router = nn.Linear(in_channels, num_experts, bias=False) self.experts = BatchFFN(in_channels, inter_channels, num_experts) + self.top_k = 4 def init_weights(self): pass def forward(self, x): assert x.ndim == 3 - shape = x.shape - x = x.flatten(0, 1) - assert x.shape[0] % self.num_experts == 0 - # force balanced indices - indices = ( - torch.randperm(x.shape[0], dtype=torch.int64, device=x.device) - % self.num_experts - ) - # put all tokens corresponding to the same expert together - idxs = indices.argsort() - xs = x[idxs].unflatten(0, (self.num_experts, -1)) - # now experts can be computed as bmm + + # route tokens to experts + scores = self.router(x) + + # select topk experts following some criteria + dim = -1 + scores = F.softmax(scores, dim=dim) + # TODO: this is wrong, we need to do a sinkhorn here to ensure that the tokens are evenly distributed + top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=dim) + idxs = selected_experts_indices.flatten(-2, -1).argsort(dim=-1, stable=True) + top_scores = top_scores.flatten(-2, -1).gather(-1, idxs) + idxs = idxs // self.top_k + + # route tokens for each expert + xs = x.gather(-2, idxs[:, :, None].expand(-1, -1, x.shape[-1])) + xs = xs.unflatten(1, (self.num_experts, -1)) + tokens_per_expert = xs.shape[2] + xs = xs.permute(1, 0, 2, 3).flatten(1, 2) out = self.experts(xs) - # put tokens back into its original order - out = out.flatten(0, 1) - new_idxs = idxs.argsort() - out = out[new_idxs] - return out.reshape(shape) + + out = out.unflatten(1, (-1, tokens_per_expert)) + out = out.permute(1, 0, 2, 3).flatten(1, 2) + + out = out * top_scores[:, :, None] + + # TODO: add shared expert + res = torch.zeros_like(x) + res = res.scatter_add(-2, idxs[:, :, None].expand(-1, -1, x.shape[-1]), out) + return res class MOEBatchedDebug(nn.Module): @@ -165,8 +180,8 @@ def input_fn(): def model_fn(): - # return MOEBatched(in_channels, inter_channels, num_experts) - return MOEBatchedDebug(in_channels, inter_channels, num_experts) + return MOEBatched(in_channels, inter_channels, num_experts) + # return MOEBatchedDebug(in_channels, inter_channels, num_experts) with torch.device("meta"): @@ -187,3 +202,19 @@ def model_fn(): t = time.time() sharding_placement = autop.optimize_placement() print(f"Took {time.time() - t:.2f} s") + parallel_mod = autop.apply_placement(sharding_placement) + +# run weight init on our sharded DTensor params +parallel_mod.to_empty(device="cuda") +parallel_mod.init_weights() + +# now let's run it +x = ( + torch.randn( + (bs // mesh.shape[0] // mesh.shape[1], seqlen, in_channels), + device=torch.device("cuda"), + ), +) +out = parallel_mod(*x) +out.backward(torch.randn_like(out)) +print("All good!") From 6c7ddbbae8fc8447d3bbb78e98fa64d661657c6d Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Wed, 1 Oct 2025 13:43:16 +0000 Subject: [PATCH 9/9] Add balanced 
token->expert selection Very slow, need to try Sinkhorn-Knopp --- examples/example_moe.py | 120 +++++++++++++++++++++++++++++++++++++--- 1 file changed, 113 insertions(+), 7 deletions(-) diff --git a/examples/example_moe.py b/examples/example_moe.py index c1b7f0ba..1f6df71c 100644 --- a/examples/example_moe.py +++ b/examples/example_moe.py @@ -1,6 +1,11 @@ import torch from torch import nn +from torch.distributed.tensor.placement_types import Replicate, Shard from torch.nn import functional as F +from torch.testing._internal.distributed.fake_pg import FakeStore + +from autoparallel.api import AutoParallel +from autoparallel.propagation_rules import register_opschema_rule class FFN(nn.Module): @@ -58,6 +63,107 @@ def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) +def _init_approximate_solution(scores, top_k, variables): + top_idxs = scores.topk(top_k, dim=-1).indices.tolist() + for token in variables.keys(): + for expert in variables[token].keys(): + token_id = int(token.split("_")[1]) + expert_id = int(expert.split("_")[1]) + initial_value = 1 if expert_id in top_idxs[token_id] else 0 + variables[token][expert].setInitialValue(initial_value) + + +def _assign_tokens_per_expert_2d(scores, top_k, init_sol=True, time_limit=1.0): + import pulp + + num_per_expert = scores.shape[0] // scores.shape[1] * top_k + prob = pulp.LpProblem("TokenExpertAssignment", pulp.LpMaximize) + experts = ["expert_{}".format(i) for i in range(scores.shape[1])] + tokens = ["token_{}".format(i) for i in range(scores.shape[0])] + scores_dict = pulp.makeDict([tokens, experts], scores.tolist(), 0) + variables = pulp.LpVariable.dicts("var", (tokens, experts), cat=pulp.LpBinary) + for token in tokens: + prob += pulp.lpSum([variables[token][expert] for expert in experts]) == top_k + for expert in experts: + prob += ( + pulp.lpSum([variables[token][expert] for token in tokens]) == num_per_expert + ) + + if init_sol: + _init_approximate_solution(scores, top_k, variables) + prob += pulp.lpSum( + [ + variables[token][expert] * scores_dict[token][expert] + for token in tokens + for expert in experts + ] + ) + verbose = False + solver = pulp.PULP_CBC_CMD(msg=verbose, warmStart=init_sol, timeLimit=time_limit) + prob.solve(solver) + res = [[variables[token][expert].value() for expert in experts] for token in tokens] + res = [[i for i, v in enumerate(r) if v == 1] for r in res] + return torch.tensor(res, dtype=torch.int32, device=scores.device) + + +@torch.library.custom_op("autoparallel::assign_tokens_to_experts", mutates_args=()) +def assign_tokens_to_experts(scores: torch.Tensor, top_k: int) -> torch.Tensor: + """ + MILP formulation of the token assignment problem, guarantees that each token + is assigned to exactly top_k experts, and every expert is assigned to exactly + the same number of tokens. + + NOTE: This performs a GPU->CPU transfer! Need to implement a working version of + Sinkhorn-Knopp on GPU to avoid this. + + NOTE: The MILP solver is *slow* and can take a long time to converge. 
+ """ + shape = scores.shape[:-1] + scores_flat = scores.flatten(0, -3) + res = [] + for score in scores_flat: + assert score.ndim == 2, f"score must be 2D, got {score.shape}" + res.append(_assign_tokens_per_expert_2d(score, top_k)) + return torch.stack(res, dim=0).reshape(shape + (top_k,)) + + +@assign_tokens_to_experts.register_fake +def _(scores, top_k): + return torch.empty( + tuple(scores.shape[:-1]) + (top_k,), device=scores.device, dtype=torch.int32 + ) + + +@register_opschema_rule(torch.ops.autoparallel.assign_tokens_to_experts.default) +def _(mesh, op_schema): + from torch.distributed.tensor._ops.utils import expand_to_full_mesh_op_strategy + + mat1_strategy = op_schema.args_schema[0] + + assert len(mat1_strategy.shape) == 3 + + single_mesh_dim_strategies = [] + + # placement list stores placements of [output, mat1] + # first we always have replicate all for inputs and output + single_mesh_dim_strategies.append([Replicate(), Replicate()]) + single_mesh_dim_strategies.append([Shard(0), Shard(0)]) + + return expand_to_full_mesh_op_strategy( + mesh, op_schema, single_mesh_dim_strategies, input_index=1 + ) + + +# scores = torch.rand(1, 8192, 128, device="cuda") +# for i in range(0, 32): +# k = 8192 // 128 * 4 +# scores[:, i * k: (i + i) * k, i * 4 : (i + 1) * 4] += 1 +# r = assign_tokens_to_experts(scores.cpu(), 4) +# r = _assign_tokens_per_expert_2d(scores[0], 4) +# r = torch.ops.autoparallel.assign_tokens_to_experts(scores, 4) +# from IPython import embed; embed(); exit() + + class MOEBatched(nn.Module): def __init__(self, in_channels, inter_channels, num_experts): super().__init__() @@ -81,13 +187,18 @@ def forward(self, x): dim = -1 scores = F.softmax(scores, dim=dim) # TODO: this is wrong, we need to do a sinkhorn here to ensure that the tokens are evenly distributed - top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=dim) + # top_scores, selected_experts_indices = torch.topk(scores, k=self.top_k, dim=dim) + selected_experts_indices = torch.ops.autoparallel.assign_tokens_to_experts( + scores, self.top_k + ) + top_scores = scores.gather(dim, selected_experts_indices) idxs = selected_experts_indices.flatten(-2, -1).argsort(dim=-1, stable=True) top_scores = top_scores.flatten(-2, -1).gather(-1, idxs) idxs = idxs // self.top_k # route tokens for each expert xs = x.gather(-2, idxs[:, :, None].expand(-1, -1, x.shape[-1])) + # this assumes the experts are balanced xs = xs.unflatten(1, (self.num_experts, -1)) tokens_per_expert = xs.shape[2] xs = xs.permute(1, 0, 2, 3).flatten(1, 2) @@ -146,9 +257,6 @@ def forward(self, x): o = m2(x) """ -from torch.testing._internal.distributed.fake_pg import FakeStore - -from autoparallel.api import AutoParallel world_size = 2048 @@ -172,7 +280,7 @@ def forward(self, x): num_experts = 128 bs = 8 * mesh.shape[0] * mesh.shape[1] -seqlen = 2048 * 2 +seqlen = 2048 * 2 * 2 def input_fn(): @@ -190,8 +298,6 @@ def model_fn(): with AutoParallel(model, input_fn, mesh) as autop: autop.add_parameter_memory_constraint(low=None, high=None) - from torch.distributed.tensor.placement_types import Replicate, Shard - x_sharding = (Shard(0), Shard(0)) autop.add_input_constraints([x_sharding])