diff --git a/clrs/_src/baselines.py b/clrs/_src/baselines.py index e295de67..226889c1 100644 --- a/clrs/_src/baselines.py +++ b/clrs/_src/baselines.py @@ -125,15 +125,27 @@ def _maybe_put_replicated(tree): if jax.local_device_count() == 1: return jax.device_put(tree) else: - return jax.device_put_replicated(tree, jax.local_devices()) + devices = jax.local_devices() + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + + def _replicate(x): + if isinstance(x, jax.Array): + return jax.device_put(jnp.stack([x] * len(devices)), sharding) + return jax.device_put(np.stack([x] * len(devices)), sharding) + + return jax.tree_util.tree_map(_replicate, tree) def _maybe_pmap_rng_key(rng_key: _Array): n_devices = jax.local_device_count() if n_devices == 1: return rng_key + devices = jax.local_devices() pmap_rng_keys = jax.random.split(rng_key, n_devices) - return jax.device_put_sharded(list(pmap_rng_keys), jax.local_devices()) + mesh = jax.sharding.Mesh(np.array(devices), ('_device_put_sharded',)) + sharding = jax.NamedSharding(mesh, jax.P('_device_put_sharded')) + return jax.device_put(jnp.stack(list(pmap_rng_keys)), sharding) class BaselineModel(model.Model):