@@ -209,6 +209,13 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
209209 stacked_table_metadata.max_ids_per_partition ;
210210 const int max_unique_ids_per_partition =
211211 stacked_table_metadata.max_unique_ids_per_partition ;
212+
213+ // We do NOT drop IDs when minibatching is enabled and we are in the
214+ // first pass (`kCreateBuckets=false`), as we need to detect limit
215+ // overflows to decide if minibatching is required.
216+ const bool can_drop_id = !options.enable_minibatching || kCreateBuckets ;
217+ const bool perform_id_dropping = allow_id_dropping && can_drop_id;
218+
212219 uint32_t prev_col_id = std::numeric_limits<uint32_t >::max ();
213220 uint32_t prev_row_id = std::numeric_limits<uint32_t >::max ();
214221 uint32_t prev_bucket_id = 0 ;
@@ -260,41 +267,40 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
260267 }
261268 // Update observed stats. These are never decremented and are used for
262269 // reporting.
263- // We do NOT drop IDs when minibatching is enabled and we are in the
264- // first pass (`kCreateBuckets=false`), as we need to detect limit
265- // overflows to decide if minibatching is required.
266- const bool can_drop_id = !options.enable_minibatching || kCreateBuckets ;
267270 observed_ids (global_sc_id, bucket_id) += 1 ;
268271 if (is_new_col) {
269272 observed_unique_ids (global_sc_id, bucket_id) += 1 ;
270- if (allow_id_dropping && can_drop_id) {
273+ }
274+
275+ // Step 4: Add ID to result or drop it.
276+ if (!perform_id_dropping) {
277+ grouped_coo_tensors.Add (context.local_sc_id , bucket_id, coo_tensor);
278+ } else {
279+ // Check limits.
280+ const bool exceeds_ids_limit =
281+ (kept_ids (global_sc_id, bucket_id) + 1 ) > max_ids_per_partition;
282+ if (is_new_col) {
271283 dropping_current_unique_col_id =
272284 (kept_unique_ids (global_sc_id, bucket_id) + 1 ) >
273285 max_unique_ids_per_partition;
274286 }
275- }
276287
277- // Step 4: Determine if the ID should be dropped based on capacity limits.
278- const bool exceeds_ids_limit =
279- (kept_ids (global_sc_id, bucket_id) + 1 ) > max_ids_per_partition;
288+ // Drop/Keep ID.
289+ if (exceeds_ids_limit || dropping_current_unique_col_id) {
290+ // Dropped id.
291+ ++stats.dropped_id_count ;
292+ } else {
293+ grouped_coo_tensors.Add (context.local_sc_id , bucket_id, coo_tensor);
294+ }
280295
281- // Step 5: Add ID to result or drop it.
282- if (can_drop_id && allow_id_dropping &&
283- (exceeds_ids_limit || dropping_current_unique_col_id)) {
284- // Dropped id.
285- ++stats.dropped_id_count ;
286- } else {
287- grouped_coo_tensors.Add (context.local_sc_id , bucket_id, coo_tensor);
288296 // Update kept counts.
289- if (allow_id_dropping && can_drop_id) {
290- kept_ids (global_sc_id, bucket_id) += 1 ;
291- if (is_new_col) {
292- kept_unique_ids (global_sc_id, bucket_id) += 1 ;
293- }
297+ kept_ids (global_sc_id, bucket_id) += 1 ;
298+ if (is_new_col) {
299+ kept_unique_ids (global_sc_id, bucket_id) += 1 ;
294300 }
295301 }
296302
297- // Step 6 : Update state for next iteration.
303+ // Step 5 : Update state for next iteration.
298304 // This must be done regardless of whether the ID was dropped to ensure
299305 // correct stats collection for subsequent IDs.
300306 prev_col_id = col_id;
0 commit comments