diff --git a/jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h b/jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h index f5d7f4fd..c6209789 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h +++ b/jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h @@ -44,14 +44,13 @@ class PartitionedCooTensors { curr_bucket_id_(0), merged_(false), dedup_col_id_(std::numeric_limits::max()), - dedup_row_id_(std::numeric_limits::max()), - dedup_bucket_id_(-1) { + dedup_row_id_(std::numeric_limits::max()) { coo_tensors_.reserve(reserve_count); bucket_offsets_.reserve(1 + num_sc_per_device * bucket_count_per_sc_); bucket_offsets_.push_back(0); } - void MergeWithLastCoo(const CooFormat& coo_tensor) { + inline void MergeWithLastCoo(const CooFormat& coo_tensor) { DCHECK_GT(coo_tensors_.size(), 0); CooFormat& last = coo_tensors_.back(); DCHECK_EQ(last.row_id, coo_tensor.row_id); @@ -59,10 +58,13 @@ class PartitionedCooTensors { last.gain += coo_tensor.gain; } - bool MaybeMerge(int bucket_id, const CooFormat& coo_tensor) { - if (bucket_id == dedup_bucket_id_ && coo_tensor.col_id == dedup_col_id_ && - coo_tensor.row_id == dedup_row_id_) { - CHECK(!coo_tensors_.empty()); + inline bool MaybeMerge(const CooFormat& coo_tensor) { + // If col_id is the same, bucket_id must also be the same. + // For fastest short-circuiting, check row_id first, as it's the last + // component of the sort key and thus most likely to differ between + // consecutive non-identical elements. + if (coo_tensor.row_id == dedup_row_id_ && + coo_tensor.col_id == dedup_col_id_) { MergeWithLastCoo(coo_tensor); return true; } @@ -72,7 +74,6 @@ class PartitionedCooTensors { void ResetDedupState() { dedup_col_id_ = std::numeric_limits::max(); dedup_row_id_ = std::numeric_limits::max(); - dedup_bucket_id_ = -1; } // Add Coo tensor for given SC and bucket. Similar to std::vector::push_back. @@ -81,7 +82,6 @@ class PartitionedCooTensors { AdvanceBucketOffsets(target_sc_id, target_bucket_id); coo_tensors_.push_back(coo_tensor); - dedup_bucket_id_ = target_bucket_id; dedup_col_id_ = coo_tensor.col_id; dedup_row_id_ = coo_tensor.row_id; } @@ -245,7 +245,6 @@ class PartitionedCooTensors { // dropped. uint32_t dedup_col_id_; uint32_t dedup_row_id_; - int dedup_bucket_id_; }; } // namespace jax_sc_embedding diff --git a/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h b/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h index f8f89ea6..cb9a1457 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h +++ b/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h @@ -240,16 +240,12 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore( // An ID that is a duplicate of a previously non-dropped ID is merged. // It does not count as a new ID for stats and does not go through dropping // logic. - if (grouped_coo_tensors.MaybeMerge(bucket_id, coo_tensor)) { + if (grouped_coo_tensors.MaybeMerge(coo_tensor)) { continue; } // If the ID is a duplicate of the last seen ID, it must have been dropped // (otherwise it would have been merged above), so drop this one too. - bool fully_duplicate = col_id == prev_col_id && row_id == prev_row_id; - if constexpr (kCreateBuckets) { - fully_duplicate = fully_duplicate && bucket_id == prev_bucket_id; - } - if (fully_duplicate) { + if (row_id == prev_row_id && col_id == prev_col_id) { ++stats.dropped_id_count; continue; }