Skip to content

Commit 73ad6e9

Browse files
[JAX SC] perf: Hoist can_drop_id && allow_id_dropping out of keys loop.
* Also move `exceeds_id_limits` calculation inside `perform_id_dropping = can_drop_id && allow_id_dropping`. * Restructure the conditionals a bit (functionally same). PiperOrigin-RevId: 838903206
1 parent 0fc72d1 commit 73ad6e9

File tree

1 file changed

+28
-22
lines changed

1 file changed

+28
-22
lines changed

jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,13 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
209209
stacked_table_metadata.max_ids_per_partition;
210210
const int max_unique_ids_per_partition =
211211
stacked_table_metadata.max_unique_ids_per_partition;
212+
213+
// We do NOT drop IDs when minibatching is enabled and we are in the
214+
// first pass (`kCreateBuckets=false`), as we need to detect limit
215+
// overflows to decide if minibatching is required.
216+
const bool can_drop_id = !options.enable_minibatching || kCreateBuckets;
217+
const bool perform_id_dropping = allow_id_dropping && can_drop_id;
218+
212219
uint32_t prev_col_id = std::numeric_limits<uint32_t>::max();
213220
uint32_t prev_row_id = std::numeric_limits<uint32_t>::max();
214221
uint32_t prev_bucket_id = 0;
@@ -260,41 +267,40 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
260267
}
261268
// Update observed stats. These are never decremented and are used for
262269
// reporting.
263-
// We do NOT drop IDs when minibatching is enabled and we are in the
264-
// first pass (`kCreateBuckets=false`), as we need to detect limit
265-
// overflows to decide if minibatching is required.
266-
const bool can_drop_id = !options.enable_minibatching || kCreateBuckets;
267270
observed_ids(global_sc_id, bucket_id) += 1;
268271
if (is_new_col) {
269272
observed_unique_ids(global_sc_id, bucket_id) += 1;
270-
if (allow_id_dropping && can_drop_id) {
273+
}
274+
275+
// Step 4: Add ID to result or drop it.
276+
if (!perform_id_dropping) {
277+
grouped_coo_tensors.Add(context.local_sc_id, bucket_id, coo_tensor);
278+
} else {
279+
// Check limits.
280+
const bool exceeds_ids_limit =
281+
(kept_ids(global_sc_id, bucket_id) + 1) > max_ids_per_partition;
282+
if (is_new_col) {
271283
dropping_current_unique_col_id =
272284
(kept_unique_ids(global_sc_id, bucket_id) + 1) >
273285
max_unique_ids_per_partition;
274286
}
275-
}
276287

277-
// Step 4: Determine if the ID should be dropped based on capacity limits.
278-
const bool exceeds_ids_limit =
279-
(kept_ids(global_sc_id, bucket_id) + 1) > max_ids_per_partition;
288+
// Drop/Keep ID.
289+
if (exceeds_ids_limit || dropping_current_unique_col_id) {
290+
// Dropped id.
291+
++stats.dropped_id_count;
292+
} else {
293+
grouped_coo_tensors.Add(context.local_sc_id, bucket_id, coo_tensor);
294+
}
280295

281-
// Step 5: Add ID to result or drop it.
282-
if (can_drop_id && allow_id_dropping &&
283-
(exceeds_ids_limit || dropping_current_unique_col_id)) {
284-
// Dropped id.
285-
++stats.dropped_id_count;
286-
} else {
287-
grouped_coo_tensors.Add(context.local_sc_id, bucket_id, coo_tensor);
288296
// Update kept counts.
289-
if (allow_id_dropping && can_drop_id) {
290-
kept_ids(global_sc_id, bucket_id) += 1;
291-
if (is_new_col) {
292-
kept_unique_ids(global_sc_id, bucket_id) += 1;
293-
}
297+
kept_ids(global_sc_id, bucket_id) += 1;
298+
if (is_new_col) {
299+
kept_unique_ids(global_sc_id, bucket_id) += 1;
294300
}
295301
}
296302

297-
// Step 6: Update state for next iteration.
303+
// Step 5: Update state for next iteration.
298304
// This must be done regardless of whether the ID was dropped to ensure
299305
// correct stats collection for subsequent IDs.
300306
prev_col_id = col_id;

0 commit comments

Comments
 (0)