From f7933ecc9f9fda99d5878f87149401b4732de08b Mon Sep 17 00:00:00 2001
From: RecML authors
Date: Thu, 10 Jul 2025 10:03:52 -0700
Subject: [PATCH] [Efficient LMs] Add compressor and decompressor for the
 research work of go/context_compression.

This commit adds functions for compressing and decompressing input and
output tensors.

Reverts changelist 793734230

PiperOrigin-RevId: 781578682
---
 recml/core/ops/hstu_ops.py          | 16 ++++++++--------
 recml/core/training/jax_trainer.py  |  4 +++-
 recml/core/training/partitioning.py | 10 +++++-----
 recml/layers/linen/sparsecore.py    |  7 ++++---
 4 files changed, 20 insertions(+), 17 deletions(-)

diff --git a/recml/core/ops/hstu_ops.py b/recml/core/ops/hstu_ops.py
index 3a8df11..59fd7bd 100644
--- a/recml/core/ops/hstu_ops.py
+++ b/recml/core/ops/hstu_ops.py
@@ -125,9 +125,9 @@ def _apply_mask(
   masks = []
   if mask_ref is not None:
     if k_in_lanes:
-      mask = pl.load(mask_ref, (slice(None), k_slice))
+      mask = mask_ref[:, k_slice]
     else:
-      mask = pl.load(mask_ref, (k_slice, slice(None)))
+      mask = mask_ref[k_slice, :]

     snm = jnp.where(should_not_mask, 1, 0)
     masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0)
@@ -156,7 +156,7 @@ def _apply_mask(
       k_sequence = k_offset + jax.lax.broadcasted_iota(
           jnp.int32, (k_slice.size, bq), 0
       )
-      q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_sequence = q_sequence_ref[:1, :]  # [1, bq]
       q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq))

       assert q_sequence.shape == k_sequence.shape
@@ -170,7 +170,7 @@ def _apply_mask(

   if q_segment_ids_ref is not None:
     if k_in_lanes:
-      kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice))  # [1, k_slice]
+      kv_ids = kv_segment_ids_ref[:1, k_slice]  # [1, k_slice]
       repeats, rem = divmod(kv_ids.shape[1], NUM_LANES)
       if rem:
         raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}")
@@ -181,9 +181,9 @@ def _apply_mask(
       if rem:
         raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}")
       kv_ids = pltpu.repeat(
-          pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1
+          kv_segment_ids_ref[k_slice, :], repeats, axis=1
       )  # [k_slice, bq]
-      q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None)))  # [1, bq]
+      q_ids = q_segment_ids_ref[:1, :]  # [1, bq]
       masks.append(q_ids == kv_ids)

   if masks:
@@ -228,7 +228,7 @@ def body(kv_compute_index, _):
     slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute)
     q = q_ref[...]
-    k = pl.load(k_ref, (slice_k, slice(None)))
+    k = k_ref[slice_k, :]
     qk = jax.lax.dot_general(
         q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32
     )
@@ -256,7 +256,7 @@ def body(kv_compute_index, _):
     )

     sv_dims = NN_DIM_NUMBERS
-    v = pl.load(v_ref, (slice_k, slice(None)))
+    v = v_ref[slice_k, :]

     to_float32 = lambda x: x.astype(jnp.float32)
     v = to_float32(v)
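Note on the hstu_ops.py hunks above: newer Pallas releases support NumPy-style
indexing directly on refs, so ref[slice_k, :] reads the same block as
pl.load(ref, (slice_k, slice(None))). A minimal sketch of the two spellings
side by side; the scale_kernel below is a hypothetical example, not code from
this patch, and interpret mode is enabled so it runs without a TPU:

import jax
import jax.numpy as jnp
from jax.experimental import pallas as pl

def scale_kernel(x_ref, o_ref):
  # Both reads fetch an 8-row block from the input ref.
  top = x_ref[0:8, :]                                  # NumPy-style indexing
  bottom = pl.load(x_ref, (pl.ds(8, 8), slice(None)))  # older pl.load spelling
  o_ref[0:8, :] = top * 2.0
  o_ref[8:16, :] = bottom * 2.0

x = jnp.arange(16 * 128, dtype=jnp.float32).reshape(16, 128)
y = pl.pallas_call(
    scale_kernel,
    out_shape=jax.ShapeDtypeStruct(x.shape, x.dtype),
    interpret=True,  # interpret mode so the sketch runs without an accelerator
)(x)

Both spellings read the same block; the direct indexing form is simply the
newer spelling this patch migrates to.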
diff --git a/recml/core/training/jax_trainer.py b/recml/core/training/jax_trainer.py
index 481bde3..6c5aa64 100644
--- a/recml/core/training/jax_trainer.py
+++ b/recml/core/training/jax_trainer.py
@@ -699,6 +699,9 @@ def train(self, task: JaxTask) -> core.Logs:
       )
       metrics[core.TRAIN_LOG_DIRNAME] = train_metrics

+      if jax.process_index() == 0:
+        task.export_model(state, self._model_dir)
+
       self._maybe_save_checkpoint(curr_step, state, metrics=metrics)
       step = curr_step + 1

@@ -706,7 +709,6 @@ def train(self, task: JaxTask) -> core.Logs:

     if jax.process_index() == 0:
       self._write_marker_file()
-      task.export_model(state, self._model_dir)

     self.checkpoint_manager.close()
     del self.checkpoint_manager

diff --git a/recml/core/training/partitioning.py b/recml/core/training/partitioning.py
index 4dc3b76..eabce4a 100644
--- a/recml/core/training/partitioning.py
+++ b/recml/core/training/partitioning.py
@@ -107,7 +107,7 @@ def _shard(x: np.ndarray) -> jax.Array:
   def partition_init(
       self, init_fn: CreateStateFn, *, abstract_batch: PyTree | None = None
   ) -> CreateStateFn:
-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       if abstract_batch is not None:
         abstract_state = jax.eval_shape(init_fn, abstract_batch)
         specs = nn.get_partition_spec(abstract_state)
@@ -117,7 +117,7 @@ def partition_init(
       init_fn = jax.jit(init_fn, out_shardings=self.state_sharding)

     def _wrapped_init(batch: PyTree) -> State:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         state = init_fn(batch)
         state = _maybe_unbox_state(state)
       return state
@@ -130,7 +130,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       jit_kws["out_shardings"] = (self.state_sharding, None)
       jit_kws["donate_argnums"] = (1,)

-    with jax.sharding.use_mesh(self.mesh):
+    with jax.set_mesh(self.mesh):
       step_fn = jax.jit(
           fn,
           in_shardings=(self.data_sharding, self.state_sharding),
@@ -138,7 +138,7 @@ def partition_step(self, fn: StepFn, *, training: bool = False) -> StepFn:
       )

     def _wrapped_step(batch: PyTree, state: State) -> Any:
-      with jax.sharding.use_mesh(self.mesh):
+      with jax.set_mesh(self.mesh):
         return step_fn(batch, state)

     return _wrapped_step
@@ -217,7 +217,7 @@ def __init__(
   def mesh_context_manager(
       self,
   ) -> Callable[[jax.sharding.Mesh], ContextManager[None]]:
-    return jax.sharding.use_mesh
+    return jax.set_mesh

   def shard_inputs(self, inputs: PyTree) -> PyTree:
     def _shard(x: np.ndarray) -> jax.Array:
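Note on the partitioning.py hunks above: jax.sharding.use_mesh was renamed to
jax.set_mesh in recent JAX releases, and the patch keeps using it the same
way, as a context manager that sets the ambient mesh for jit and sharding
resolution. A minimal sketch, assuming a JAX release that ships jax.set_mesh
and jax.make_mesh (the axis name "data" is illustrative):

import jax
import jax.numpy as jnp
from jax.sharding import PartitionSpec as P

# Build a one-axis mesh over all local devices.
mesh = jax.make_mesh((jax.device_count(),), ("data",))

# Formerly: with jax.sharding.use_mesh(mesh): ...
with jax.set_mesh(mesh):
  # With an ambient mesh set, a bare PartitionSpec is enough to shard.
  x = jax.device_put(jnp.arange(8.0), P("data"))
  y = jax.jit(lambda a: a * 2.0)(x)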
diff --git a/recml/layers/linen/sparsecore.py b/recml/layers/linen/sparsecore.py
index a908ab8..3849425 100644
--- a/recml/layers/linen/sparsecore.py
+++ b/recml/layers/linen/sparsecore.py
@@ -334,7 +334,7 @@ def _to_np(x: Any) -> np.ndarray:
         weights[key] = np.reshape(weights[key], (-1, 1))

     self._batch_number += 1
-    csr_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
+    preprocessed_inputs, _ = embedding.preprocess_sparse_dense_matmul_input(
         features=features,
         features_weights=weights,
         feature_specs=self.sparsecore_config.feature_specs,
@@ -345,6 +345,7 @@ def _to_np(x: Any) -> np.ndarray:
         allow_id_dropping=self.sparsecore_config.allow_id_dropping,
         batch_number=self._batch_number,
     )
+    csr_inputs = preprocessed_inputs.sparse_dense_matmul_input

     processed_inputs = {
         k: v for k, v in inputs.items() if k not in sparse_features
@@ -362,7 +363,7 @@ class SparsecoreEmbed(nn.Module):

  Attributes:
    sparsecore_config: A sparsecore config specifying how to create the tables.
    mesh: The mesh to use for the embedding layer. If not provided, the global
-      mesh set by `jax.sharding.use_mesh` will be used. If neither is set, an
+      mesh set by `jax.set_mesh` will be used. If neither is set, an
      error will be raised.
  """
@@ -375,7 +376,7 @@ def get_mesh(self) -> jax.sharding.Mesh | jax.sharding.AbstractMesh:
    abstract_mesh = jax.sharding.get_abstract_mesh()
    if not abstract_mesh.shape_tuple:
      raise ValueError(
-          'No abstract mesh shape was set with `jax.sharding.use_mesh`. Make'
+          'No abstract mesh shape was set with `jax.set_mesh`. Make'
          ' sure to set the mesh when calling the sparsecore module.'
      )
    return abstract_mesh
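Note on the sparsecore.py hunks above: embedding.preprocess_sparse_dense_matmul_input
now returns a container object rather than the CSR inputs themselves, so
callers unwrap its sparse_dense_matmul_input field. A self-contained sketch of
that calling-convention shift; the PreprocessedInputs class and preprocess
function below are illustrative stand-ins, not the real embedding API:

import dataclasses
from typing import Any

@dataclasses.dataclass
class PreprocessedInputs:
  # Stand-in for the container returned by the new API; only the
  # `sparse_dense_matmul_input` field name is taken from the patch.
  sparse_dense_matmul_input: dict[str, Any]

def preprocess(features: dict[str, Any]) -> tuple[PreprocessedInputs, None]:
  # Placeholder for the real CSR conversion done by the embedding library.
  csr = {name: ("csr", vals) for name, vals in features.items()}
  return PreprocessedInputs(sparse_dense_matmul_input=csr), None

# Old convention: the CSR inputs came back directly.
#   csr_inputs, _ = preprocess(features)
# New convention: unwrap the container first, as the patch does.
preprocessed_inputs, _ = preprocess({"item_id": [1, 2, 3]})
csr_inputs = preprocessed_inputs.sparse_dense_matmul_input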