Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 9 additions & 10 deletions jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,25 +44,27 @@ class PartitionedCooTensors {
curr_bucket_id_(0),
merged_(false),
dedup_col_id_(std::numeric_limits<uint32_t>::max()),
dedup_row_id_(std::numeric_limits<uint32_t>::max()),
dedup_bucket_id_(-1) {
dedup_row_id_(std::numeric_limits<uint32_t>::max()) {
coo_tensors_.reserve(reserve_count);
bucket_offsets_.reserve(1 + num_sc_per_device * bucket_count_per_sc_);
bucket_offsets_.push_back(0);
}

// Folds `coo_tensor` into the most recently stored COO element by adding its
// gain to that element's gain. DCHECKs that at least one element is stored
// and that the stored element's row_id and col_id equal those of
// `coo_tensor`.
// NOTE(review): the two signature lines below look like the pre-change and
// post-change forms of the same method from a flattened diff (the `inline`
// one presumably being the newer) — confirm against the real header.
void MergeWithLastCoo(const CooFormat& coo_tensor) {
inline void MergeWithLastCoo(const CooFormat& coo_tensor) {
DCHECK_GT(coo_tensors_.size(), 0);
CooFormat& last = coo_tensors_.back();
DCHECK_EQ(last.row_id, coo_tensor.row_id);
DCHECK_EQ(last.col_id, coo_tensor.col_id);
// Only the gain is accumulated; row_id and col_id of `last` are unchanged.
last.gain += coo_tensor.gain;
}

bool MaybeMerge(int bucket_id, const CooFormat& coo_tensor) {
if (bucket_id == dedup_bucket_id_ && coo_tensor.col_id == dedup_col_id_ &&
coo_tensor.row_id == dedup_row_id_) {
CHECK(!coo_tensors_.empty());
inline bool MaybeMerge(const CooFormat& coo_tensor) {
// If col_id is the same, bucket_id must also be the same.
// For fastest short-circuiting, check row_id first, as it's the last
// component of the sort key and thus most likely to differ between
// consecutive non-identical elements.
if (coo_tensor.row_id == dedup_row_id_ &&
coo_tensor.col_id == dedup_col_id_) {
MergeWithLastCoo(coo_tensor);
return true;
}
Expand All @@ -72,7 +74,6 @@ class PartitionedCooTensors {
// Resets the dedup tracking fields to their initial sentinel values: the
// maximum uint32_t for the column and row ids, and -1 for the bucket id.
void ResetDedupState() {
dedup_col_id_ = std::numeric_limits<uint32_t>::max();
dedup_row_id_ = std::numeric_limits<uint32_t>::max();
// NOTE(review): this page is a flattened diff and dedup_bucket_id_ appears
// to be removed by the change — confirm whether this line still exists.
dedup_bucket_id_ = -1;
}

// Add Coo tensor for given SC and bucket. Similar to std::vector::push_back.
Expand All @@ -81,7 +82,6 @@ class PartitionedCooTensors {
AdvanceBucketOffsets(target_sc_id, target_bucket_id);

coo_tensors_.push_back(coo_tensor);
dedup_bucket_id_ = target_bucket_id;
dedup_col_id_ = coo_tensor.col_id;
dedup_row_id_ = coo_tensor.row_id;
}
Expand Down Expand Up @@ -245,7 +245,6 @@ class PartitionedCooTensors {
// dropped.
uint32_t dedup_col_id_;
uint32_t dedup_row_id_;
int dedup_bucket_id_;
};

} // namespace jax_sc_embedding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -240,16 +240,12 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
// An ID that is a duplicate of a previously non-dropped ID is merged.
// It does not count as a new ID for stats and does not go through dropping
// logic.
if (grouped_coo_tensors.MaybeMerge(bucket_id, coo_tensor)) {
if (grouped_coo_tensors.MaybeMerge(coo_tensor)) {
continue;
}
// If the ID is a duplicate of the last seen ID, it must have been dropped
// (otherwise it would have been merged above), so drop this one too.
bool fully_duplicate = col_id == prev_col_id && row_id == prev_row_id;
if constexpr (kCreateBuckets) {
fully_duplicate = fully_duplicate && bucket_id == prev_bucket_id;
}
if (fully_duplicate) {
if (row_id == prev_row_id && col_id == prev_col_id) {
++stats.dropped_id_count;
continue;
}
Expand Down
Loading