@@ -175,7 +175,6 @@ struct LocalSparseCoreTensorGroupingContext {
175175 absl::Span<const CooFormat> coo_tensors;
176176 const StackedTableMetadata& stacked_table_metadata;
177177 const PreprocessSparseDenseMatmulInputOptions& options;
178- const bool create_buckets;
179178 const int32_t local_sc_id;
180179 const int32_t num_sc_bits;
181180
@@ -184,11 +183,12 @@ struct LocalSparseCoreTensorGroupingContext {
184183 MatrixXi& ids_per_sc_partition_per_bucket;
185184 MatrixXi& unique_ids_per_partition_per_bucket;
186185 StatsPerDevice& stats;
186+ // These are only used for id dropping decisions and can be ignored otherwise.
187187 MatrixXi& kept_ids_per_sc_partition_per_bucket;
188188 MatrixXi& kept_unique_ids_per_partition_per_bucket;
189189};
190190
191- template <bool kHasVariableWeights >
191+ template <bool kHasVariableWeights , bool kCreateBuckets >
192192inline void GroupAndDeduplicateCooTensorsForLocalSparseCore (
193193 LocalSparseCoreTensorGroupingContext context) {
194194 // Unpack context for readability.
@@ -219,7 +219,8 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
219219 const int num_sc_bits = context.num_sc_bits ;
220220 for (const uint64_t key : context.keys ) {
221221 // Step 1: Unpack key to get tensor coordinates.
222- const uint32_t bucket_id = CooFormat::GetBucketIdFromKey (key);
222+ const uint32_t bucket_id =
223+ kCreateBuckets ? CooFormat::GetBucketIdFromKey (key) : 0 ;
223224 const uint32_t col_id =
224225 absl::rotl (CooFormat::GetRotatedColIdFromKey (key), num_sc_bits);
225226 const uint32_t global_sc_id = col_id & (global_sc_count - 1 );
@@ -244,26 +245,29 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
244245 }
245246 // If the ID is a duplicate of the last seen ID, it must have been dropped
246247 // (otherwise it would have been merged above), so drop this one too.
247- if (bucket_id == prev_bucket_id && col_id == prev_col_id &&
248- row_id == prev_row_id) {
248+ bool fully_duplicate = col_id == prev_col_id && row_id == prev_row_id;
249+ if constexpr (kCreateBuckets ) {
250+ fully_duplicate = fully_duplicate && bucket_id == prev_bucket_id;
251+ }
252+ if (fully_duplicate) {
249253 ++stats.dropped_id_count ;
250254 continue ;
251255 }
252256
253- // We do NOT drop IDs when minibatching is enabled and we are in the
254- // first pass (`create_buckets=false`), as we need to detect limit
255- // overflows to decide if minibatching is required.
256- const bool can_drop_id =
257- !options.enable_minibatching || context.create_buckets ;
258-
259257 // Step 3: Update observed statistics for the new ID.
260258 // We have a new column if the bucket_id changes (we can't dedupe across
261259 // bucket boundaries) or if the col_id changes within the same bucket. Note
262260 // that multiple col_ids can map to the same bucket.
263- const bool is_new_col =
264- (bucket_id != prev_bucket_id || col_id != prev_col_id);
261+ bool is_new_col = col_id != prev_col_id;
262+ if constexpr (kCreateBuckets ) {
263+ is_new_col = is_new_col || bucket_id != prev_bucket_id;
264+ }
265265 // Update observed stats. These are never decremented and are used for
266266 // reporting.
267+ // We do NOT drop IDs when minibatching is enabled and we are in the
268+ // first pass (`kCreateBuckets=false`), as we need to detect limit
269+ // overflows to decide if minibatching is required.
270+ const bool can_drop_id = !options.enable_minibatching || kCreateBuckets ;
267271 observed_ids (global_sc_id, bucket_id) += 1 ;
268272 if (is_new_col) {
269273 observed_unique_ids (global_sc_id, bucket_id) += 1 ;
@@ -312,8 +316,9 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore(
312316// NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and
313317// `required_buffer_size_per_sc` because we fill values in a loop to a bigger
314318// array.
315- template <bool kHasVariableWeights = true , typename SplitType>
316- PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice (
319+ template <bool kHasVariableWeights = true , bool kCreateBuckets ,
320+ typename SplitType>
321+ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl (
317322 const ExtractedCooTensors& extracted_coo_tensors,
318323 const StackedTableMetadata& stacked_table_metadata,
319324 const PreprocessSparseDenseMatmulInputOptions& options,
@@ -334,20 +339,18 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
334339 // This function can be called in two passes for minibatching. The logic for
335340 // stats collection and ID dropping depends on the pass.
336341 //
337- // Pass 1: Check if minibatching is required (`create_buckets ` is false).
342+ // Pass 1: Check if minibatching is required (`kCreateBuckets ` is false).
338343 // - No IDs are dropped.
339344 // - Stats are collected on all observed IDs to compute splits.
340345 //
341- // Pass 2: Create buckets (`create_buckets ` is true).
346+ // Pass 2: Create buckets (`kCreateBuckets ` is true).
342347 // - A dummy stats object is used (stats are not re-computed).
343348 // - IDs may be dropped if they exceed capacity.
344- const bool create_buckets = options.enable_minibatching &&
345- (std::is_same_v<SplitType, MinibatchingSplit>);
346349
347350 // Partition COO tensors among SparseCores for the local device (based on row
348351 // id).
349352 const int bucket_count =
350- create_buckets ? CooFormat::kMaxMinibatchingBuckets : 1 ;
353+ kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1 ;
351354 PartitionedCooTensors grouped_coo_tensors (
352355 coo_tensors.size (), num_sc_per_device, global_sc_count, bucket_count);
353356
@@ -383,7 +386,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
383386 // local_embedding_id(32-num_scs bits), index(26 bits)].
384387 // Note that this assumes `num_scs` is a power of 2.
385388 keys.push_back (coo_tensor.GetGroupingKey (
386- num_sc_bits, coo_tensor_index, create_buckets ,
389+ num_sc_bits, coo_tensor_index, kCreateBuckets ,
387390 options.minibatching_bucketing_hash_fn , kHasVariableWeights ));
388391 DCHECK (kHasVariableWeights || coo_tensors[coo_tensor_index].gain == 1 .0f )
389392 << " kHasVariableWeights: " << kHasVariableWeights
@@ -399,7 +402,6 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
399402 .coo_tensors = coo_tensors,
400403 .stacked_table_metadata = stacked_table_metadata,
401404 .options = options,
402- .create_buckets = create_buckets,
403405 .local_sc_id = local_sc_id,
404406 .num_sc_bits = num_sc_bits,
405407 .grouped_coo_tensors = grouped_coo_tensors,
@@ -414,7 +416,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
414416 };
415417
416418 internal::GroupAndDeduplicateCooTensorsForLocalSparseCore<
417- kHasVariableWeights >(context);
419+ kHasVariableWeights , kCreateBuckets >(context);
418420
419421 grouped_coo_tensors.FillRemainingScBuckets ();
420422
@@ -450,7 +452,7 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
450452
451453 // Only validate if creating minibatching buckets or when minibatching is
452454 // disabled, not when checking if minibatching is required.
453- if (!options.enable_minibatching || create_buckets )
455+ if (!options.enable_minibatching || kCreateBuckets )
454456 internal::ValidateMaxIdsOrDie (
455457 observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket,
456458 max_ids_per_partition, max_unique_ids_per_partition,
@@ -460,6 +462,26 @@ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice(
460462 return grouped_coo_tensors;
461463}
462464
465+ template <bool kHasVariableWeights = true , typename SplitType>
466+ PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice (
467+ const ExtractedCooTensors& extracted_coo_tensors,
468+ const StackedTableMetadata& stacked_table_metadata,
469+ const PreprocessSparseDenseMatmulInputOptions& options,
470+ internal::StatsPerDevice& stats, SplitType& minibatching_split) {
471+ const bool create_buckets =
472+ options.enable_minibatching &&
473+ std::is_same_v<SplitType, MinibatchingSplit>;
474+ if (create_buckets) {
475+ return SortAndGroupCooTensorsPerLocalDeviceImpl<kHasVariableWeights , true >(
476+ extracted_coo_tensors, stacked_table_metadata, options, stats,
477+ minibatching_split);
478+ } else {
479+ return SortAndGroupCooTensorsPerLocalDeviceImpl<kHasVariableWeights , false >(
480+ extracted_coo_tensors, stacked_table_metadata, options, stats,
481+ minibatching_split);
482+ }
483+ }
484+
463485} // namespace jax_sc_embedding
464486
465487#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_
0 commit comments