Skip to content

Commit 3ddc44d

Browse files
[JAX SC] Continuation of cl/836767096. Templatize kCreateBuckets for compile-time optimization.
* `10.89%` geomean reduction in wall time with `9.02%` CPU time decrease.

PiperOrigin-RevId: 836943854
1 parent 2580a46 commit 3ddc44d

File tree

1 file changed

+46
-24
lines changed

1 file changed

+46
-24
lines changed

jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h

Lines changed: 46 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,6 @@ struct LocalSparseCoreTensorGroupingContext {
175175
absl::Span<const CooFormat> coo_tensors;
176176
const StackedTableMetadata& stacked_table_metadata;
177177
const PreprocessSparseDenseMatmulInputOptions& options;
178-
const bool create_buckets;
179178
const int32_t local_sc_id;
180179
const int32_t num_sc_bits;
181180

@@ -184,11 +183,12 @@ struct LocalSparseCoreTensorGroupingContext {
184183
MatrixXi& ids_per_sc_partition_per_bucket;
185184
MatrixXi& unique_ids_per_partition_per_bucket;
186185
StatsPerDevice& stats;
186+
// These are only used for id dropping decisions and can be ignored otherwise.
187187
MatrixXi& kept_ids_per_sc_partition_per_bucket;
188188
MatrixXi& kept_unique_ids_per_partition_per_bucket;
189189
};
190190

191-
template <bool kHasVariableWeights>
191+
template <bool kHasVariableWeights, bool kCreateBuckets>
192192
inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
193193
LocalSparseCoreTensorGroupingContext context) {
194194
// Unpack context for readability.
@@ -219,7 +219,8 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
219219
const int num_sc_bits = context.num_sc_bits;
220220
for (const uint64_t key : context.keys) {
221221
// Step 1: Unpack key to get tensor coordinates.
222-
const uint32_t bucket_id = CooFormat::GetBucketIdFromKey(key);
222+
const uint32_t bucket_id =
223+
kCreateBuckets ? CooFormat::GetBucketIdFromKey(key) : 0;
223224
const uint32_t col_id =
224225
absl::rotl(CooFormat::GetRotatedColIdFromKey(key), num_sc_bits);
225226
const uint32_t global_sc_id = col_id & (global_sc_count - 1);
@@ -244,26 +245,29 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
244245
}
245246
// If the ID is a duplicate of the last seen ID, it must have been dropped
246247
// (otherwise it would have been merged above), so drop this one too.
247-
if (bucket_id == prev_bucket_id && col_id == prev_col_id &&
248-
row_id == prev_row_id) {
248+
bool fully_duplicate = col_id == prev_col_id && row_id == prev_row_id;
249+
if constexpr (kCreateBuckets) {
250+
fully_duplicate = fully_duplicate && bucket_id == prev_bucket_id;
251+
}
252+
if (fully_duplicate) {
249253
++stats.dropped_id_count;
250254
continue;
251255
}
252256

253-
// We do NOT drop IDs when minibatching is enabled and we are in the
254-
// first pass (`create_buckets=false`), as we need to detect limit
255-
// overflows to decide if minibatching is required.
256-
const bool can_drop_id =
257-
!options.enable_minibatching || context.create_buckets;
258-
259257
// Step 3: Update observed statistics for the new ID.
260258
// We have a new column if the bucket_id changes (we can't dedupe across
261259
// bucket boundaries) or if the col_id changes within the same bucket. Note
262260
// that multiple col_ids can map to the same bucket.
263-
const bool is_new_col =
264-
(bucket_id != prev_bucket_id || col_id != prev_col_id);
261+
bool is_new_col = col_id != prev_col_id;
262+
if constexpr (kCreateBuckets) {
263+
is_new_col = is_new_col || bucket_id != prev_bucket_id;
264+
}
265265
// Update observed stats. These are never decremented and are used for
266266
// reporting.
267+
// We do NOT drop IDs when minibatching is enabled and we are in the
268+
// first pass (`kCreateBuckets=false`), as we need to detect limit
269+
// overflows to decide if minibatching is required.
270+
const bool can_drop_id = !options.enable_minibatching || kCreateBuckets;
267271
observed_ids(global_sc_id, bucket_id) += 1;
268272
if (is_new_col) {
269273
observed_unique_ids(global_sc_id, bucket_id) += 1;
@@ -312,8 +316,9 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
312316
// NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
313317
// `required_buffer_size_per_sc` because we fill values in a loop to a bigger
314318
// array.
315-
template <bool kHasVariableWeights = true, typename SplitType>
316-
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
319+
template <bool kHasVariableWeights = true, bool kCreateBuckets,
320+
typename SplitType>
321+
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl(
317322
const ExtractedCooTensors& extracted_coo_tensors,
318323
const StackedTableMetadata& stacked_table_metadata,
319324
const PreprocessSparseDenseMatmulInputOptions& options,
@@ -334,20 +339,18 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
334339
// This function can be called in two passes for minibatching. The logic for
335340
// stats collection and ID dropping depends on the pass.
336341
//
337-
// Pass 1: Check if minibatching is required (`create_buckets` is false).
342+
// Pass 1: Check if minibatching is required (`kCreateBuckets` is false).
338343
// - No IDs are dropped.
339344
// - Stats are collected on all observed IDs to compute splits.
340345
//
341-
// Pass 2: Create buckets (`create_buckets` is true).
346+
// Pass 2: Create buckets (`kCreateBuckets` is true).
342347
// - A dummy stats object is used (stats are not re-computed).
343348
// - IDs may be dropped if they exceed capacity.
344-
const bool create_buckets = options.enable_minibatching &&
345-
(std::is_same_v<SplitType, MinibatchingSplit>);
346349

347350
// Partition COO tensors among SparseCores for the local device (based on row
348351
// id).
349352
const int bucket_count =
350-
create_buckets ? CooFormat::kMaxMinibatchingBuckets : 1;
353+
kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1;
351354
PartitionedCooTensors grouped_coo_tensors(
352355
coo_tensors.size(), num_sc_per_device, global_sc_count, bucket_count);
353356

@@ -383,7 +386,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
383386
// local_embedding_id(32-num_scs bits), index(26 bits)].
384387
// Note that this assumes `num_scs` is a power of 2.
385388
keys.push_back(coo_tensor.GetGroupingKey(
386-
num_sc_bits, coo_tensor_index, create_buckets,
389+
num_sc_bits, coo_tensor_index, kCreateBuckets,
387390
options.minibatching_bucketing_hash_fn, kHasVariableWeights));
388391
DCHECK(kHasVariableWeights || coo_tensors[coo_tensor_index].gain == 1.0f)
389392
<< "kHasVariableWeights: " << kHasVariableWeights
@@ -399,7 +402,6 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
399402
.coo_tensors = coo_tensors,
400403
.stacked_table_metadata = stacked_table_metadata,
401404
.options = options,
402-
.create_buckets = create_buckets,
403405
.local_sc_id = local_sc_id,
404406
.num_sc_bits = num_sc_bits,
405407
.grouped_coo_tensors = grouped_coo_tensors,
@@ -414,7 +416,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
414416
};
415417

416418
internal::GroupAndDeduplicateCooTensorsForLocalSparseCore<
417-
kHasVariableWeights>(context);
419+
kHasVariableWeights, kCreateBuckets>(context);
418420

419421
grouped_coo_tensors.FillRemainingScBuckets();
420422

@@ -450,7 +452,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
450452

451453
// Only validate if creating minibatching buckets or when minibatching is
452454
// disabled, not when checking if minibatching is required.
453-
if (!options.enable_minibatching || create_buckets)
455+
if (!options.enable_minibatching || kCreateBuckets)
454456
internal::ValidateMaxIdsOrDie(
455457
observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
456458
max_ids_per_partition, max_unique_ids_per_partition,
@@ -460,6 +462,26 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
460462
return grouped_coo_tensors;
461463
}
462464

465+
template <bool kHasVariableWeights = true, typename SplitType>
466+
PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
467+
const ExtractedCooTensors& extracted_coo_tensors,
468+
const StackedTableMetadata& stacked_table_metadata,
469+
const PreprocessSparseDenseMatmulInputOptions& options,
470+
internal::StatsPerDevice& stats, SplitType& minibatching_split) {
471+
const bool create_buckets =
472+
options.enable_minibatching &&
473+
std::is_same_v<SplitType, MinibatchingSplit>;
474+
if (create_buckets) {
475+
return SortAndGroupCooTensorsPerLocalDeviceImpl<kHasVariableWeights, true>(
476+
extracted_coo_tensors, stacked_table_metadata, options, stats,
477+
minibatching_split);
478+
} else {
479+
return SortAndGroupCooTensorsPerLocalDeviceImpl<kHasVariableWeights, false>(
480+
extracted_coo_tensors, stacked_table_metadata, options, stats,
481+
minibatching_split);
482+
}
483+
}
484+
463485
} // namespace jax_sc_embedding
464486

465487
#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_

0 commit comments

Comments (0)