From 96045a7a75204ad3e70f443b518bb8833491bafb Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Thu, 31 Jul 2025 18:51:26 -0700 Subject: [PATCH] [pallas] Do not use `pl.{load,store}` which will soon be removed from Pallas Each backend now has its own `load` and `store` function with backend-specific signatures. Prefer using those to target backend-specific feature or load/store via indexing. PiperOrigin-RevId: 789549961 --- recml/core/ops/hstu_ops.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/recml/core/ops/hstu_ops.py b/recml/core/ops/hstu_ops.py index 3a8df11..59fd7bd 100644 --- a/recml/core/ops/hstu_ops.py +++ b/recml/core/ops/hstu_ops.py @@ -125,9 +125,9 @@ def _apply_mask( masks = [] if mask_ref is not None: if k_in_lanes: - mask = pl.load(mask_ref, (slice(None), k_slice)) + mask = mask_ref[:, k_slice] else: - mask = pl.load(mask_ref, (k_slice, slice(None))) + mask = mask_ref[k_slice, :] snm = jnp.where(should_not_mask, 1, 0) masks.append(jnp.bitwise_or(mask, jnp.broadcast_to(snm, mask.shape)) != 0) @@ -156,7 +156,7 @@ def _apply_mask( k_sequence = k_offset + jax.lax.broadcasted_iota( jnp.int32, (k_slice.size, bq), 0 ) - q_sequence = pl.load(q_sequence_ref, (pl.ds(1), slice(None))) # [1, bq] + q_sequence = q_sequence_ref[:1, :] # [1, bq] q_sequence = jnp.broadcast_to(q_sequence, (k_slice.size, bq)) assert q_sequence.shape == k_sequence.shape @@ -170,7 +170,7 @@ def _apply_mask( if q_segment_ids_ref is not None: if k_in_lanes: - kv_ids = pl.load(kv_segment_ids_ref, (pl.ds(1), k_slice)) # [1, k_slice] + kv_ids = kv_segment_ids_ref[:1, k_slice] # [1, k_slice] repeats, rem = divmod(kv_ids.shape[1], NUM_LANES) if rem: raise NotImplementedError(f"block_kv must be a multiple of {NUM_LANES}") @@ -181,9 +181,9 @@ def _apply_mask( if rem: raise NotImplementedError(f"block_q must be a multiple of {NUM_LANES}") kv_ids = pltpu.repeat( - pl.load(kv_segment_ids_ref, (k_slice, slice(None))), repeats, axis=1 + kv_segment_ids_ref[k_slice, :], repeats, axis=1 ) # [k_slice, bq] - q_ids = pl.load(q_segment_ids_ref, (pl.ds(1), slice(None))) # [1, bq] + q_ids = q_segment_ids_ref[:1, :] # [1, bq] masks.append(q_ids == kv_ids) if masks: @@ -228,7 +228,7 @@ def body(kv_compute_index, _): slice_k = pl.ds(kv_compute_index * bkv_compute, bkv_compute) q = q_ref[...] - k = pl.load(k_ref, (slice_k, slice(None))) + k = k_ref[slice_k, :] qk = jax.lax.dot_general( q, k, NT_DIM_NUMBERS, preferred_element_type=jnp.float32 ) @@ -256,7 +256,7 @@ def body(kv_compute_index, _): ) sv_dims = NN_DIM_NUMBERS - v = pl.load(v_ref, (slice_k, slice(None))) + v = v_ref[slice_k, :] to_float32 = lambda x: x.astype(jnp.float32) v = to_float32(v)