From 57b7fb36ce645c1abd794fe3de0a4c8b36b60a7a Mon Sep 17 00:00:00 2001
From: bzgoogle
Date: Fri, 7 Nov 2025 18:27:40 +0000
Subject: [PATCH] [GPT-OSS] initial commit to use MoE kernel

---
 tpu_inference/kernels/fused_moe/v1/kernel.py | 49 ++++++++------
 tpu_inference/layers/jax/moe/gpt_oss_moe.py  | 69 ++++++++++----------
 tpu_inference/models/jax/gpt_oss.py          | 19 ++++--
 3 files changed, 77 insertions(+), 60 deletions(-)

diff --git a/tpu_inference/kernels/fused_moe/v1/kernel.py b/tpu_inference/kernels/fused_moe/v1/kernel.py
index 8df0e1ab3..e9e01d77f 100644
--- a/tpu_inference/kernels/fused_moe/v1/kernel.py
+++ b/tpu_inference/kernels/fused_moe/v1/kernel.py
@@ -7,6 +7,7 @@
 from jax import lax
 from jax._src import dtypes
 from jax.experimental import pallas as pl
+from jax.experimental import shard_map
 from jax.experimental.pallas import tpu as pltpu
 
 P = jax.sharding.PartitionSpec
@@ -144,7 +145,7 @@ def _fused_ep_moe_kernel(
     a2a_acc_sem,
     *,
     top_k: int,
-    ep_name: str,
+    ep_axis_name: str,
     # Kernel tuning params.
     bt: int,  # Block size of local_num_tokens.
     bf: int,  # Block size of intermediate_size.
@@ -155,8 +156,8 @@ def _fused_ep_moe_kernel(
     bd1c: int,  # Compute size of block hidden_size.
     bd2c: int,  # Compute size of block hidden_size.
 ):
-    my_id = lax.axis_index(ep_name)
-    num_devices = lax.axis_size(ep_name)
+    my_id = lax.axis_index(ep_axis_name)
+    num_devices = lax.axis_size(ep_axis_name)
     local_num_tokens = tokens_hbm.shape[0]
     local_num_experts, intermediate_size, hidden_size = w2_hbm.shape
     # num_experts = local_num_experts * num_devices
@@ -186,13 +187,13 @@ def sync_barrier():
         barrier_sem = pltpu.get_barrier_semaphore()
         pltpu.semaphore_signal(
             barrier_sem,
-            device_id=right_id,
-            device_id_type=pltpu.DeviceIdType.LOGICAL,
+            device_id=(0, right_id),
+            device_id_type=pltpu.DeviceIdType.MESH,
         )
         pltpu.semaphore_wait(barrier_sem, 1)
 
     def start_fetch_b_gating(bt_id, priority=0):
-        is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
+        is_valid = jnp.logical_and(bt_id >= 0, bt_id < num_bt)
         sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
         bt_sem_id = (bt_id + 2) % 2
         b_gating_sem = local_sems.at[bt_sem_id, 0]
@@ -276,7 +277,7 @@ def _all_reduce_metadata(
             dst_ref=d2e_count_vmem.at[row_id],
             send_sem=send_sem,
             recv_sem=recv_sem,
-            device_id=(right_id, ),
+            device_id=(0, right_id),
             device_id_type=pltpu.DeviceIdType.MESH,
         ).wait()
         row_id = (row_id + num_devices - 1) % num_devices
@@ -358,7 +359,10 @@ def start_a2a_scatter(bt_id, e_sem_id, local_e_id):
                           pl.ds(start, remote_sz)],
             send_sem=send_sems.at[e_sem_id],
             recv_sem=recv_sems.at[e_sem_id],
-            device_id=(recv_id, ),
+            device_id=(
+                0,
+                recv_id,
+            ),
         ).start()
 
         a2a_s_sends_x2_smem[e_sem_id] = send_sz
@@ -402,7 +406,7 @@ def start_a2a_gather(bt_id, e_sem_id, local_e_id):
             dst_ref=a2a_g_hbm.at[my_e_id, pl.ds(0, remote_sz)],
             send_sem=send_sems.at[e_sem_id],
             recv_sem=a2a_gather_sem,
-            device_id=(recv_id, ),
+            device_id=(0, recv_id),
         ).start()
 
         start += sz
@@ -412,7 +416,7 @@ def wait_a2a_gather_send(bt_id, e_sem_id, local_e_id):
         sz = expert_sizes_x2_smem[bt_sem_id, 0, my_e_id]
         local_sz = d2e_count_x2_smem[bt_sem_id, my_id, 0, my_e_id]
         remote_sz = sz - local_sz
-        is_valid = jnp.logical_and(0 <= local_e_id, local_e_id
-                                   < local_num_experts)
+        is_valid = jnp.logical_and(local_e_id >= 0, local_e_id
+                                   < local_num_experts)
         remote_sz = lax.select(is_valid, remote_sz, 0)
         pltpu.make_async_copy(
@@ -731,7 +735,7 @@ def start_send_bo(bt_id, priority=0):
         ).start(priority=priority)
 
     def wait_send_bo(bt_id):
-        is_valid = jnp.logical_and(0 <= bt_id, bt_id < num_bt)
+        is_valid = jnp.logical_and(bt_id >= 0, bt_id < num_bt)
         sz = pl.multiple_of(lax.select(is_valid, bt, 0), bt)
         bt_sem_id = (bt_id + 2) % 2
         b_output_sem = local_sems.at[bt_sem_id, 4]
@@ -831,6 +835,7 @@ def _():
         "bfc",
         "bd1c",
         "bd2c",
+        "ep_axis_name",
     ],
 )
 def fused_ep_moe(
@@ -850,12 +855,14 @@ def fused_ep_moe(
     bfc: int,
     bd1c: int,
     bd2c: int,
+    ep_axis_name: str = 'model',
 ):
-    if len(mesh.axis_names) != 1:
-        raise ValueError("Mesh must have only one axis")
+    # Assert all other axes have length of 1
+    assert len(mesh.shape) == 2, "Expect 2D mesh in tpu-inference"
+    assert 'data' in mesh.shape and mesh.shape['data'] == 1, \
+        "Expect data axis size of 1 in tpu-inference"
 
-    ep_name = mesh.axis_names[0]
-    ep_size = mesh.axis_sizes[0]
+    ep_size = mesh.shape[ep_axis_name]
     num_devices = ep_size
 
     num_tokens, actual_hidden_size = tokens.shape
@@ -907,7 +914,7 @@ def fused_ep_moe(
         functools.partial(
             _fused_ep_moe_kernel,
             top_k=top_k,
-            ep_name=ep_name,
+            ep_axis_name=ep_axis_name,
             bt=bt,
             bf=bf,
             bd1=bd1,
@@ -999,11 +1006,13 @@ def fused_ep_moe(
     ))
 
     @jax.jit
-    @jax.shard_map(
+    @functools.partial(
+        shard_map.shard_map,
         mesh=mesh,
-        in_specs=(P(ep_name), P(ep_name), P(ep_name), P(ep_name), P()),
-        out_specs=P(ep_name),
-        check_vma=False,
+        in_specs=(P(ep_axis_name), P(ep_axis_name), P(ep_axis_name),
+                  P(ep_axis_name), P()),
+        out_specs=P(ep_axis_name),
+        check_rep=False,
     )
     def kernel(tokens, w1, w2, gating_output, a2a_g_hbm_scratch):
         return fused_moe(
diff --git a/tpu_inference/layers/jax/moe/gpt_oss_moe.py b/tpu_inference/layers/jax/moe/gpt_oss_moe.py
index 4fe26dda3..8f7c7488d 100644
--- a/tpu_inference/layers/jax/moe/gpt_oss_moe.py
+++ b/tpu_inference/layers/jax/moe/gpt_oss_moe.py
@@ -4,8 +4,10 @@
 import jax.numpy as jnp
 from flax import nnx
 from flax.typing import Sharding
+from jax.sharding import Mesh
 from jaxtyping import Float
 
+from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe
 from tpu_inference.layers.jax.base import create_param
 from tpu_inference.layers.jax.layers import FlaxUtils
 from tpu_inference.layers.jax.moe.moe import Router
@@ -46,13 +48,7 @@ def __call__(self, x_TD: Float):
 
         router_logits_TE += self.bias_E.value
 
-        weights_TX, selected_experts_TX = jax.lax.top_k(
-            router_logits_TE, self.num_experts_per_tok)
-
-        normalized_weights_TX = jax.nn.softmax(weights_TX.astype(self.dtype),
-                                               axis=-1)
-
-        return normalized_weights_TX, selected_experts_TX
+        return router_logits_TE
 
 
 def _swiglu(x: Float, alpha: Float, limit: Float) -> Float:
@@ -90,37 +86,44 @@ class GptOssMoE(nnx.Module):
 
     random_init: bool = False
 
+    mesh: Mesh
+
     def __call__(self, x_TD: Float) -> Float:
         """Performs the forward pass for the GPT-OSS MoE layer."""
         x_TD = jnp.asarray(x_TD, self.dtype)
         x_TD = nnx.with_sharding_constraint(x_TD, self.activation_ffw_td)
 
-        weights_TX, indices_TX = self.router(x_TD)
-
-        # First MLP layer (up-projection)
-        with jax.named_scope("MLP #1"):
-            up_proj_TEF2 = jnp.einsum('TD,EDF -> TEF', x_TD,
-                                      self.mlp1_weight_EDF2.value)
-            up_proj_TEF2 += self.mlp1_bias_EF2.value
-
-            fuse_TEF = _swiglu(up_proj_TEF2,
-                               alpha=self.swiglu_alpha,
-                               limit=self.swiglu_limit)
-
-        # Second MLP layer (down-projection)
-        with jax.named_scope("MLP #2"):
-            down_proj_TED = jnp.einsum('TEF,EFD -> TED', fuse_TEF,
-                                       self.mlp2_weight_EFD.value)
-            down_proj_TED += self.mlp2_bias_ED.value
-
-        # Weighted sum of expert outputs
-        with jax.named_scope("sum"):
-            indices_for_gather = indices_TX[..., None]
-            gathered_down_proj_TED = jnp.take_along_axis(down_proj_TED,
-                                                         indices_for_gather,
-                                                         axis=1)
-            output_TD = jnp.einsum('TXD,TX -> TD', gathered_down_proj_TED,
-                                   weights_TX)
+        router_logits_TE = self.router(x_TD)
+
+        block_size = {
+            "bt": 32,
+            "bf": 512,
+            "bd1": 512,
+            "bd2": 512,
+            "btc": 32,
+            "bfc": 256,
+            "bd1c": 256,
+            "bd2c": 256,
+        }
+        ep_axis_name = self.efd_sharding[0]
+        # TODO: Currently, we must reshape the tensors to fit the MoE kernel's
+        # required shape. We will eliminate this step and load the tensors in
+        # their desired final shape once the weight loading process (with fp4
+        # support) is finalized.
+        mlp1_weight_E2DF = jnp.swapaxes(
+            jnp.reshape(self.mlp1_weight_EDF2.value,
+                        (self.num_local_experts, self.hidden_size, 2,
+                         self.intermediate_size_moe)), 1, 2)
+        output_TD = fused_ep_moe(
+            mesh=self.mesh,
+            tokens=x_TD,
+            w1=mlp1_weight_E2DF,
+            w2=self.mlp2_weight_EFD.value,
+            gating_output=router_logits_TE,
+            top_k=self.router.num_experts_per_tok,
+            ep_axis_name=ep_axis_name,
+            **block_size,
+        )
 
         return output_TD.astype(self.dtype)
diff --git a/tpu_inference/models/jax/gpt_oss.py b/tpu_inference/models/jax/gpt_oss.py
index a6adbc71f..86e71632a 100644
--- a/tpu_inference/models/jax/gpt_oss.py
+++ b/tpu_inference/models/jax/gpt_oss.py
@@ -1,6 +1,5 @@
 import re
 from dataclasses import dataclass
-from typing import List, Optional, Tuple
 
 import jax
 import jax.numpy as jnp
@@ -9,17 +8,22 @@
 from flax.typing import PRNGKey
 from jax.sharding import Mesh, NamedSharding
 from jax.sharding import PartitionSpec as P
-from vllm.config import VllmConfig
 
 from tpu_inference.layers.jax.attention.gpt_oss_attention import (
-    AttentionMetadata, GptOssAttention)
+    AttentionMetadata,
+    GptOssAttention,
+)
 from tpu_inference.layers.jax.constants import KVCacheType
 from tpu_inference.layers.jax.layers import Embedder, LMhead, RMSNorm
 from tpu_inference.layers.jax.moe.gpt_oss_moe import GptOssMoE, GptOssRouter
 from tpu_inference.layers.jax.transformer_block import TransformerBlock
 from tpu_inference.logger import init_logger
 from tpu_inference.models.jax.utils.weight_utils import (
-    get_param, model_weights_generator, print_param_info)
+    get_param,
+    model_weights_generator,
+    print_param_info,
+)
+from vllm.config import VllmConfig
 
 logger = init_logger(__name__)
@@ -136,6 +140,7 @@ def __init__(self,
                 edf_sharding=('model', None, None),
                 efd_sharding=('model', None, None),
                 ed_sharding=('model', None),
+                mesh=self.mesh,
             )
 
             block = TransformerBlock(
@@ -180,7 +185,7 @@ def __init__(self,
     def apply(self, variables, *args, **kwargs):
         return self.__call__(*args, **kwargs)
 
-    def load_weights(self, rng: PRNGKey, cache_dir: Optional[str] = None):
+    def load_weights(self, rng: PRNGKey, cache_dir: str | None = None):
         """Loads and transforms all weights from a checkpoint"""
 
         self.rng = nnx.Rngs(rng)
@@ -328,11 +333,11 @@ def get_slice(index):
 
     def __call__(
         self,
-        kv_caches: List[jax.Array],
+        kv_caches: list[jax.Array],
         input_ids: jax.Array,
         attention_metadata: AttentionMetadata,
         *args,
-    ) -> Tuple[List[KVCacheType], jax.Array, List[jax.Array]]:
+    ) -> tuple[list[KVCacheType], jax.Array, list[jax.Array]]:
         is_prefill = False
 
         x = self.embedder.encode(input_ids)
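
Usage note (a sketch, not part of the patch): the call below mirrors the
fused_ep_moe invocation added in GptOssMoE.__call__ and the new 2D-mesh
contract asserted in fused_ep_moe (a 'data' axis of size 1 plus an
expert-parallel axis, 'model' by default). Array sizes, dtypes, and the device
layout are illustrative assumptions, not values taken from this commit; the
block sizes copy the block_size dict added above.

    import jax
    import jax.numpy as jnp
    from jax.experimental import mesh_utils
    from jax.sharding import Mesh

    from tpu_inference.kernels.fused_moe.v1.kernel import fused_ep_moe

    # Illustrative sizes: E experts, hidden D, intermediate F, T tokens.
    E, D, F, T, TOP_K = 32, 2880, 2880, 128, 4

    # 2D mesh expected by the kernel: ('data', 'model') with data axis of size 1.
    devices = mesh_utils.create_device_mesh((1, jax.device_count()))
    mesh = Mesh(devices, axis_names=("data", "model"))

    tokens = jnp.zeros((T, D), dtype=jnp.bfloat16)        # x_TD
    w1 = jnp.zeros((E, 2, D, F), dtype=jnp.bfloat16)      # mlp1_weight_E2DF
    w2 = jnp.zeros((E, F, D), dtype=jnp.bfloat16)         # mlp2_weight_EFD
    gating_output = jnp.zeros((T, E), dtype=jnp.float32)  # router_logits_TE

    output_TD = fused_ep_moe(
        mesh=mesh,
        tokens=tokens,
        w1=w1,
        w2=w2,
        gating_output=gating_output,
        top_k=TOP_K,
        ep_axis_name="model",
        bt=32, bf=512, bd1=512, bd2=512,
        btc=32, bfc=256, bd1c=256, bd2c=256,
    )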