From bfd042f049b32194c2c4b287f88c73da69996e39 Mon Sep 17 00:00:00 2001 From: Daniel Suo Date: Thu, 26 Mar 2026 11:18:11 -0700 Subject: [PATCH] [pmap] In-line definitions of `jax.device_put_sharded` and `jax.device_put_replicated`. Both `jax.device_put_sharded` and `jax.device_put_replicated` were deprecated in JAX v0.8.1 in November 2025. We in-line their definitions using public JAX APIs taking the `jax_pmap_shmap_merge=True` branch, which was made the default in JAX v0.8.0 in October 2025. Please see the below for more information: - JAX CHANGELOG: https://docs.jax.dev/en/latest/changelog.html - Migrating from `jax.pmap`: https://docs.jax.dev/en/latest/migrate_pmap.html PiperOrigin-RevId: 889947371 --- clrs/_src/baselines.py | 16 ++++++++++++++-- 1 file changed, 14 insertions(+), 2 deletions(-) diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index e295de67..226889c1 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -125,15 +125,27 @@ def _maybe_put_replicated(tree): if jax.local_device_count() == 1: return jax.device_put(tree) else: - return jax.device_put_replicated(tree, jax.local_devices()) + devices = jax.local_devices() + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + + def _replicate(x): + if isinstance(x, jax.Array): + return jax.device_put(jnp.stack([x] * len(devices)), sharding) + return jax.device_put(np.stack([x] * len(devices)), sharding) + + return jax.tree_util.tree_map(_replicate, tree) def _maybe_pmap_rng_key(rng_key: _Array): n_devices = jax.local_device_count() if n_devices == 1: return rng_key + devices = jax.local_devices() pmap_rng_keys = jax.random.split(rng_key, n_devices) - return jax.device_put_sharded(list(pmap_rng_keys), jax.local_devices()) + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + return jax.device_put(jnp.stack(list(pmap_rng_keys)), sharding) class BaselineModel(model.Model):