
Replace jax.sharding.use_mesh with jax.set_mesh. jax.set_mesh can act as a global setter or a context manager. #71

Open · wants to merge 1 commit into base: main
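For context on the API this PR adopts: per the title, `jax.set_mesh` works both as a global setter and as a context manager. Below is a minimal sketch of the two modes, assuming a JAX release that ships `jax.set_mesh`; the `jax.make_mesh` call and the `data` axis name are illustrative, not part of this PR.

```python
import jax
import jax.numpy as jnp

# Illustrative mesh: one axis named "data" spanning all local devices.
mesh = jax.make_mesh((jax.device_count(),), ("data",))

# As a global setter: the mesh stays active for everything that follows.
jax.set_mesh(mesh)

# As a context manager: active only inside the block; the previous mesh is
# restored on exit.
with jax.set_mesh(mesh):
  out = jax.jit(lambda x: x * 2)(jnp.ones((8,)))
```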
recml/core/ops/hstu_ops.py: 8 additions & 8 deletions
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]

     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
     k_sequence = k_offset + jax.lax.broadcasted_iota(
         jnp.int32, (k_slice.size, bq), 0
     )
-    q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+    q_sequence = q_sequence_ref[:1, :]  # [1, bq]
     q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

     assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(

   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
       )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
       masks.append(q_ids == kv_ids)

   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)

     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )

     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]

     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
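Aside on the hstu_ops.py changes above: Pallas refs support NumPy-style indexing, so `ref[k_slice, :]` expresses the same load as `pl.load(ref, (k_slice, slice(None)))`. Here is a minimal runnable sketch under that assumption; the kernel and shapes are hypothetical, and `interpret=True` lets it run without a TPU.

```python
import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def double_first_row_kernel(x_ref, o_ref):
  # Indexing the ref replaces pl.load(x_ref, (pl.ds(0, 1), slice(None))).
  first_row = x_ref[:1, :]
  # Writing via indexing likewise replaces an explicit pl.store.
  o_ref[:, :] = jnp.broadcast_to(first_row * 2.0, o_ref.shape)

x = jnp.arange(8 * 128, dtype=jnp.float32).reshape(8, 128)
out = pl.pallas_call(
    double_first_row_kernel,
    out_shape=jax.ShapeDtypeStruct((8, 128), jnp.float32),
    interpret=True,  # interpreter mode for illustration, no TPU required
)(x)
assert bool((out[0] == 2 * x[0]).all())
```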
recml/core/training/partitioning.py: 5 additions & 5 deletions
@@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       if abstract_batch is not None:
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)
@@ -117,7 +117,7 @@ def partition_init(
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
@@ -130,15 +130,15 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)

-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
           **jit_kws,
       )

     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         return step_fn(batch, state)

     return _wrapped_step
@@ -217,7 +217,7 @@ def __init__(
   def mesh_context_manager(
       self,
   ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
+    return jax.set_mesh

   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
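For readers following the partitioning.py change: the pattern is to build the jitted step under the mesh and then re-enter the mesh at call time. A standalone sketch of that pattern follows; the mesh, sharding, and toy step function are illustrative stand-ins, not RecML APIs.

```python
import jax
import jax.numpy as jnp
from jax.sharding import NamedSharding, PartitionSpec as P

mesh = jax.make_mesh((jax.device_count(),), ("data",))
data_sharding = NamedSharding(mesh, P("data"))

with jax.set_mesh(mesh):  # mirrors partition_step above
  step_fn = jax.jit(
      lambda batch, state: state + batch.sum(),
      in_shardings=(data_sharding, None),
  )

def wrapped_step(batch, state):
  with jax.set_mesh(mesh):  # mirrors _wrapped_step above
    return step_fn(batch, state)

print(wrapped_step(jnp.ones((jax.device_count(),)), jnp.float32(0.0)))
```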
recml/layers/linen/sparsecore.py: 2 additions & 2 deletions
@@ -362,7 +362,7 @@ class SparsecoreEmbed(nn.Module):
   Attributes:
     sparsecore_config: A sparsecore config specifying how to create the tables.
     mesh: The mesh to use for the embedding layer. If not provided, the global
-      mesh set by `jax.sharding.use_mesh` will be used. If neither is set, an
+      mesh set by `jax.set_mesh` will be used. If neither is set, an
       error will be raised.
   """

@@ -375,7 +375,7 @@ def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh:
     abstract_mesh = jax.sharding.get_abstract_mesh()
     if not abstract_mesh.shape_tuple:
       raise ValueError(
-          'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make'
+          'No abstract mesh shape was set with `jax.set_mesh`. Make'
           ' sure to set the mesh when calling the sparsecore module.'
       )
     return abstract_mesh
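And for the sparsecore.py change: `get_mesh` falls back to the abstract mesh that becomes active once `jax.set_mesh` is entered, using the same `shape_tuple` test as the code above. A small sketch of that check; the mesh construction is illustrative.

```python
import jax

mesh = jax.make_mesh((jax.device_count(),), ("data",))

# With no mesh active, the abstract mesh is empty, so SparsecoreEmbed.get_mesh
# would raise the ValueError shown above.
assert not jax.sharding.get_abstract_mesh().shape_tuple

with jax.set_mesh(mesh):
  # Inside the context the abstract mesh is populated and get_mesh succeeds.
  assert jax.sharding.get_abstract_mesh().shape_tuple
```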