From 6945c9e22613816bf16c069b16ae75fd4ce169d2 Mon Sep 17 00:00:00 2001
From: Yash Katariya
Date: Wed, 6 Aug 2025 16:35:21 -0700
Subject: [PATCH] Replace `jax.sharding.use_mesh` with `jax.set_mesh`.

`jax.set_mesh` can act as a global setter or a context manager.

PiperOrigin-RevId: 791891966
---
 recml/core/ops/hstu_ops.py          | 16 ++++++++--------
 recml/core/training/partitioning.py | 10 +++++-----
 recml/layers/linen/sparsecore.py    |  4 ++--
 3 files changed, 15 insertions(+), 15 deletions(-)

diff --git a/recml/core/ops/hstu_ops.py b/recml/core/ops/hstu_ops.py
index 3a8df11..59fd7bd 100644
--- a/recml/core/ops/hstu_ops.py
+++ b/recml/core/ops/hstu_ops.py
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]
 
     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
       k_sequence = k_offset + jax.lax.broadcasted_iota(
           jnp.int32, (k_slice.size, bq), 0
       )
-      q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_sequence = q_sequence_ref[:1, :]  # [1, bq]
       q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))
 
     assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(
 
   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
       )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
     masks.append(q_ids == kv_ids)
 
   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )
     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]
 
     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
diff --git a/recml/core/training/partitioning.py b/recml/core/training/partitioning.py
index 4dc3b76..eabce4a 100644
--- a/recml/core/training/partitioning.py
+++ b/recml/core/training/partitioning.py
@@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       if abstract_batch is not None:
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)
@@ -117,7 +117,7 @@ def partition_init(
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)
 
     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
         return state
@@ -130,7 +130,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)
 
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
@@ -138,7 +138,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       )
 
     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         return step_fn(batch, state)
 
     return _wrapped_step
@@ -217,7 +217,7 @@ def __init__(
   def mesh_context_manager(
       self,
   ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
+    return jax.set_mesh
 
   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py
index a908ab8..5fa38d4 100644
--- a/recml/layers/linen/sparsecore.py
+++ b/recml/layers/linen/sparsecore.py
@@ -362,7 +362,7 @@ class SparsecoreEmbed(nn.Module):
   Attributes:
     sparsecore_config: A sparsecore config specifying how to create the tables.
     mesh: The mesh to use for the embedding layer. If not provided, the global
-      mesh set by `jax.sharding.use_mesh` will be used. If neither is set, an
+      mesh set by `jax.set_mesh` will be used. If neither is set, an
       error will be raised.
   """
 
@@ -375,7 +375,7 @@ def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh:
     abstract_mesh = jax.sharding.get_abstract_mesh()
     if not abstract_mesh.shape_tuple:
       raise ValueError(
-          'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make'
+          'No abstract mesh shape was set with `jax.set_mesh`. Make'
           ' sure to set the mesh when calling the sparsecore module.'
       )
     return abstract_mesh
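
Usage note (not part of the patch): the commit message says `jax.set_mesh` works both as a global setter and as a context manager. Below is a minimal sketch of both usages under stated assumptions; the single-axis mesh named 'data', the array sizes, and the `jax.device_put` call are illustrative and not taken from this repository.

# Minimal sketch, assuming a 1-D mesh over all local devices with an
# illustrative axis name 'data'. Not part of the patch above.
import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

mesh = Mesh(jax.devices(), ('data',))
sharding = NamedSharding(mesh, P('data'))

# As a context manager: the mesh is only in effect inside the block,
# mirroring the `with jax.set_mesh(self.mesh):` usage in partitioning.py.
with jax.set_mesh(mesh):
  x = jax.device_put(jnp.arange(jax.device_count(), dtype=jnp.float32), sharding)

# As a global setter: the mesh stays in effect for subsequent calls.
jax.set_mesh(mesh)
y = jax.device_put(jnp.arange(jax.device_count(), dtype=jnp.float32), sharding)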