diff --git a/jax_tpu_embedding/sparsecore/lib/core/BUILD b/jax_tpu_embedding/sparsecore/lib/core/BUILD
index fbaad01e..6c4b750b 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/BUILD
+++ b/jax_tpu_embedding/sparsecore/lib/core/BUILD
@@ -120,6 +120,7 @@ cc_test(
         "@com_google_fuzztest//fuzztest",
         "@com_google_googletest//:gtest_main",
         "@eigen_archive//:eigen3",
+        "@xla//xla:util",
     ],
 )
 
@@ -139,6 +140,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
diff --git a/jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc b/jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc
index fec01bff..81a12a7a 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc
+++ b/jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc
@@ -85,15 +85,17 @@ std::vector<int> GenerateEmbeddingIdsForRow(absl::BitGen& gen, int vocab_size) {
   return ids_out;
 }
 
-ExtractedCooTensors GenerateSkewedCooTensors(int num_sc_per_device,
-                                             int batch_size_per_sc,
-                                             int vocab_size) {
+std::vector<ExtractedSparseCoreTensors> GenerateSkewedCooTensors(
+    int num_sc_per_device, int batch_size_per_sc, int vocab_size) {
   const int batch_size_for_device = num_sc_per_device * batch_size_per_sc;
   absl::BitGen gen(std::seed_seq{kSeed});  // seed for reproducibility
 
-  ExtractedCooTensors extracted_coo_tensors(num_sc_per_device,
-                                            batch_size_for_device);
+  std::vector<ExtractedSparseCoreTensors> extracted_sc_tensors(
+      num_sc_per_device);
+  for (int i = 0; i < num_sc_per_device; ++i) {
+    extracted_sc_tensors[i].batch_size_for_device = batch_size_for_device;
+  }
 
   // For each sample in the batch:
   // 1. Draw a sample size from a Lognormal distribution.
@@ -104,13 +106,12 @@ ExtractedCooTensors GenerateSkewedCooTensors(int num_sc_per_device, std::vector embedding_ids = GenerateEmbeddingIdsForRow(gen, vocab_size); int sc_id = row / batch_size_per_sc; - extracted_coo_tensors.coo_tensors_per_sc[sc_id] += embedding_ids.size(); for (int embedding_id : embedding_ids) { - extracted_coo_tensors.coo_tensors.push_back( + extracted_sc_tensors[sc_id].coo_tensors.push_back( CooFormat(row, embedding_id, 1.0)); } } - return extracted_coo_tensors; + return extracted_sc_tensors; } std::vector> @@ -173,10 +174,13 @@ void BM_ExtractCooTensors(benchmark::State& state) { .feature_stacking_strategy = FeatureStackingStrategy::kSplitThenStack, }; + std::vector extracted_tensors_per_sc( + kNumScPerDevice); + for (auto s : state) { - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( + internal::ExtractCooTensorsForLocalDevice( stacked_table_metadata, absl::MakeSpan(input_batches), - /*local_device_id=*/0, options); + /*local_device=*/0, options, absl::MakeSpan(extracted_tensors_per_sc)); } } BENCHMARK(BM_ExtractCooTensors) @@ -188,7 +192,7 @@ BENCHMARK(BM_ExtractCooTensors) ->UseRealTime(); void BM_SortAndGroup_Phase1(benchmark::State& state) { - ExtractedCooTensors extracted_coo_tensors = + std::vector per_sc_tensors = GenerateSkewedCooTensors(kNumScPerDevice, kBatchSizePerSc, kVocabSize); // Set to INT_MAX to avoid ID dropping and observe the actual statistics of @@ -206,18 +210,26 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) { .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, }; - bool minibatching_required = false; - StatsPerHost stats_per_host( - /*local_device_count=*/1, - /*global_sc_count=*/kNumScPerDevice * kGlobalDeviceCount, - /*num_sc_per_device=*/kNumScPerDevice); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); if (state.thread_index() == 0) { - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_required); + bool minibatching_required = false; + StatsPerHost stats_per_host( + /*local_device_count=*/1, + /*global_sc_count=*/kNumScPerDevice * kGlobalDeviceCount, + /*num_sc_per_device=*/kNumScPerDevice); + internal::StatsPerDevice stats_per_device = + stats_per_host.GetStatsPerDevice(0); + int dropped_id_count = 0; + for (int sc_id = 0; sc_id < kNumScPerDevice; ++sc_id) { + PerSparseCoreGroupedData result = + internal::SortAndGroupCooTensorsForSingleSparseCore< + /*kHasVariableWeights=*/false>(per_sc_tensors[sc_id], + /*local_device_id=*/0, sc_id, + options, stacked_table_metadata, + minibatching_required); + internal::AggregatePerSparseCoreStats(result, sc_id, stats_per_device, + dropped_id_count); + } LogStats(stats_per_device.max_ids_per_partition, "Max ids per partition across all global SCs"); LogStats(stats_per_device.max_unique_ids_per_partition, @@ -225,9 +237,14 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) { } for (auto s : state) { - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_required); + bool minibatching_required = false; + for (int sc_id = 0; sc_id < kNumScPerDevice; ++sc_id) { + internal::SortAndGroupCooTensorsForSingleSparseCore< + /*kHasVariableWeights=*/false>(per_sc_tensors[sc_id], + /*local_device_id=*/0, sc_id, options, + stacked_table_metadata, + minibatching_required); + } } } BENCHMARK(BM_SortAndGroup_Phase1) diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc 
b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc index 9352f902..dca075d4 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc @@ -48,63 +48,6 @@ namespace jax_sc_embedding { namespace { -// Extract the COO tensors for a single feature slice. -void ExtractCooTensorsForSingleFeatureSlice( - const StackedTableMetadata& metadata, - absl::Span> input_batches, - const int local_device_id, const int feature_slice_id, - const int feature_slices_per_device, - const PreprocessSparseDenseMatmulInputOptions& options, - ExtractedCooTensors& extracted_coo_tensors) { - const int feature_index = metadata.feature_index; - const std::unique_ptr& curr_batch = - input_batches[feature_index]; - const int num_samples = curr_batch->size(); - - const int batch_size_per_slice = xla::CeilOfRatio( - extracted_coo_tensors.batch_size_for_device, feature_slices_per_device); - - CHECK_GT(feature_slices_per_device, 0); - CHECK_GT(options.global_device_count, 0); - CHECK_GT(options.local_device_count, 0); - - const int row_offset_per_slice = - metadata.row_offset / - (options.global_device_count * feature_slices_per_device); - const int row_offset = - feature_slice_id * batch_size_per_slice + row_offset_per_slice; - const int col_offset = metadata.col_offset; - const int col_shift = metadata.col_shift; - - const int num_samples_per_split = - num_samples / (options.local_device_count * feature_slices_per_device); - const int start_index = - (local_device_id * feature_slices_per_device + feature_slice_id) * - num_samples_per_split; - int end_index = std::min(num_samples, start_index + num_samples_per_split); - - // In the case of feature stacking, we need to group all the COO tensors - // at this stage (i.e., before the sorting later on). 
- VLOG(2) << absl::StrFormat( - "Extracting COO Tensor from feature #%d from row %d to %d " - "(local_device_id = %d, feature_slice_id = %d, row_offset = %d, " - "batch_size_per_slice = %d)", - feature_index, start_index, end_index, local_device_id, feature_slice_id, - row_offset, batch_size_per_slice); - curr_batch->ExtractCooTensors( - { - .slice_start = start_index, - .slice_end = end_index, - .row_offset = row_offset, - .col_offset = col_offset, - .col_shift = col_shift, - .num_sc_per_device = options.num_sc_per_device, - .num_scs = options.GetNumScs(), - .combiner = metadata.row_combiner, - }, - extracted_coo_tensors); -} - void CheckDeviceBatchSize(int batch_size_for_device, int num_sc_per_device, absl::string_view stacked_table_name) { CHECK_EQ(batch_size_for_device % num_sc_per_device, 0) << absl::StrFormat( @@ -113,6 +56,119 @@ void CheckDeviceBatchSize(int batch_size_for_device, int num_sc_per_device, batch_size_for_device, stacked_table_name, num_sc_per_device); } +} // namespace + +namespace internal { + +ExtractedSparseCoreTensors ExtractCooTensorsForSparseCore( + absl::Span stacked_table_metadata, + absl::Span> input_batches, + int local_device_id, int local_sc_id, + const PreprocessSparseDenseMatmulInputOptions& options) { + int batch_size_for_device = 0; + for (const auto& feature_metadata : stacked_table_metadata) { + batch_size_for_device += + input_batches[feature_metadata.feature_index]->size() / + options.local_device_count; + } + CheckDeviceBatchSize(batch_size_for_device, options.num_sc_per_device, + stacked_table_metadata[0].name); + + int feature_slices_per_device; + int sc_slice_id; + switch (options.feature_stacking_strategy) { + case FeatureStackingStrategy::kStackThenSplit: + feature_slices_per_device = 1; + sc_slice_id = 0; + break; + case FeatureStackingStrategy::kSplitThenStack: + feature_slices_per_device = options.num_sc_per_device; + sc_slice_id = local_sc_id; + break; + default: + LOG(FATAL) << "Unsupported feature stacking strategy: " + << static_cast(options.feature_stacking_strategy); + break; + } + + CHECK_GE(batch_size_for_device, + feature_slices_per_device * stacked_table_metadata.size()) + << "Batch size must be greater or equal to the number of " + "features stacked together (per feature slice)."; + + ExtractedCooTensors extracted_coo_tensors(options.num_sc_per_device, + batch_size_for_device); + + for (const auto& metadata : stacked_table_metadata) { + const int feature_index = metadata.feature_index; + const std::unique_ptr& curr_batch = + input_batches[feature_index]; + const int num_samples = curr_batch->size(); + + const int batch_size_per_slice = xla::CeilOfRatio( + extracted_coo_tensors.batch_size_for_device, feature_slices_per_device); + + CHECK_GT(feature_slices_per_device, 0); + CHECK_GT(options.global_device_count, 0); + CHECK_GT(options.local_device_count, 0); + + const int row_offset_per_slice = + metadata.row_offset / + (options.global_device_count * feature_slices_per_device); + const int row_offset = + sc_slice_id * batch_size_per_slice + row_offset_per_slice; + const int col_offset = metadata.col_offset; + const int col_shift = metadata.col_shift; + + const int num_samples_per_split = + num_samples / (options.local_device_count * feature_slices_per_device); + const int start_index = + (local_device_id * feature_slices_per_device + sc_slice_id) * + num_samples_per_split; + int end_index = std::min(num_samples, start_index + num_samples_per_split); + + // In the case of feature stacking, we need to group all the COO tensors + // 
at this stage (i.e., before the sorting later on). + VLOG(2) << absl::StrFormat( + "Extracting COO Tensor from feature #%d from row %d to %d " + "(local_device_id = %d, feature_slice_id = %d, row_offset = %d, " + "batch_size_per_slice = %d)", + feature_index, start_index, end_index, local_device_id, sc_slice_id, + row_offset, batch_size_per_slice); + curr_batch->ExtractCooTensors( + { + .slice_start = start_index, + .slice_end = end_index, + .row_offset = row_offset, + .col_offset = col_offset, + .col_shift = col_shift, + .num_sc_per_device = options.num_sc_per_device, + .num_scs = options.GetNumScs(), + .combiner = metadata.row_combiner, + }, + extracted_coo_tensors); + } + + return {.coo_tensors = std::move( + extracted_coo_tensors.per_sc_tensors[local_sc_id].coo_tensors), + .batch_size_for_device = batch_size_for_device}; +} + +void ExtractCooTensorsForLocalDevice( + absl::Span stacked_table_metadata, + absl::Span> input_batches, + int local_device, const PreprocessSparseDenseMatmulInputOptions& options, + absl::Span extracted_sc_tensors) { + for (int sc_id = 0; sc_id < options.num_sc_per_device; ++sc_id) { + extracted_sc_tensors[sc_id] = ExtractCooTensorsForSparseCore( + stacked_table_metadata, input_batches, local_device, sc_id, options); + } +} + +} // namespace internal + +namespace { + // We consider a stack to have variable weights if any feature in the stack // has explicitly variable weights or if any feature uses a row combiner // other than 'sum' (e.g., 'mean' or 'sqrtn'). @@ -147,8 +203,8 @@ struct TableState { int batch_size_for_device; bool table_minibatching_required = false; MinibatchingSplit table_minibatching_split = 0; - std::vector extracted_coo_tensors_per_device; - std::vector partitioned_coo_tensors_per_device; + std::vector sc_results_per_device; + std::vector extracted_tensors_per_sc; std::vector dropped_id_count_per_device; TableState(absl::string_view name, @@ -172,27 +228,34 @@ struct TableState { stats_per_host(options.local_device_count, options.GetNumScs(), options.num_sc_per_device), batch_size_for_device(0) { - extracted_coo_tensors_per_device.resize(options.local_device_count); - partitioned_coo_tensors_per_device.resize(options.local_device_count); + sc_results_per_device.resize(options.local_device_count * + options.num_sc_per_device); + extracted_tensors_per_sc.resize(options.local_device_count * + options.num_sc_per_device); dropped_id_count_per_device.resize(options.local_device_count, 0); } }; template -void SortAndGroupCooTensorsForTableState( - TableState& state, int local_device, - const PreprocessSparseDenseMatmulInputOptions& options, - internal::StatsPerDevice& stats, SplitType& split) { - if (state.has_variable_weights) { - state.partitioned_coo_tensors_per_device[local_device] = - SortAndGroupCooTensorsPerLocalDevice( - state.extracted_coo_tensors_per_device[local_device], - state.stacked_table_metadata[0], options, stats, split); - } else { - state.partitioned_coo_tensors_per_device[local_device] = - SortAndGroupCooTensorsPerLocalDevice( - state.extracted_coo_tensors_per_device[local_device], - state.stacked_table_metadata[0], options, stats, split); +void SortAndGroupCooTensorsForLocalDevice( + absl::Span stacked_table_metadata, + int local_device, const PreprocessSparseDenseMatmulInputOptions& options, + bool has_variable_weights, + absl::Span extracted_tensors_for_device, + SplitType& split, + absl::Span sc_results_for_device) { + for (int sc_id = 0; sc_id < options.num_sc_per_device; ++sc_id) { + if (has_variable_weights) { + 
sc_results_for_device[sc_id] = + internal::SortAndGroupCooTensorsForSingleSparseCore( + extracted_tensors_for_device[sc_id], local_device, sc_id, options, + stacked_table_metadata[0], split); + } else { + sc_results_for_device[sc_id] = + internal::SortAndGroupCooTensorsForSingleSparseCore( + extracted_tensors_for_device[sc_id], local_device, sc_id, options, + stacked_table_metadata[0], split); + } } } @@ -213,22 +276,35 @@ void ExtractSortAndGroupCooTensorsForTable( }); for (int local_device = 0; local_device < options.local_device_count; ++local_device) { - PreprocessingThreadPool()->Schedule( - [&, local_device, &state = state, input_batches] { - state.extracted_coo_tensors_per_device[local_device] = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - state.stacked_table_metadata, input_batches, local_device, - options); - - internal::StatsPerDevice stats_per_device = - state.stats_per_host.GetStatsPerDevice(local_device); - SortAndGroupCooTensorsForTableState( - state, local_device, options, stats_per_device, + for (int sc_id = 0; sc_id < options.num_sc_per_device; ++sc_id) { + PreprocessingThreadPool()->Schedule([&state, input_batches, &options, + &counter, local_device, sc_id] { + auto& extracted_sc_tensor = + state.extracted_tensors_per_sc[local_device * + options.num_sc_per_device + + sc_id]; + auto& sc_result = + state.sc_results_per_device[local_device * + options.num_sc_per_device + + sc_id]; + extracted_sc_tensor = internal::ExtractCooTensorsForSparseCore( + state.stacked_table_metadata, input_batches, local_device, sc_id, + options); + if (state.has_variable_weights) { + sc_result = internal::SortAndGroupCooTensorsForSingleSparseCore( + extracted_sc_tensor, local_device, sc_id, options, + state.stacked_table_metadata[0], state.table_minibatching_required); - state.dropped_id_count_per_device[local_device] = - stats_per_device.dropped_id_count; - counter.DecrementCount(); - }); + } else { + sc_result = + internal::SortAndGroupCooTensorsForSingleSparseCore( + extracted_sc_tensor, local_device, sc_id, options, + state.stacked_table_metadata[0], + state.table_minibatching_required); + } + counter.DecrementCount(); + }); + } } } @@ -237,9 +313,9 @@ void PostProcessTableState(TableState& state) { absl::c_accumulate(state.dropped_id_count_per_device, 0LL); state.batch_size_for_device = - state.extracted_coo_tensors_per_device[0].batch_size_for_device; - for (const auto& extracted_coo : state.extracted_coo_tensors_per_device) { - DCHECK_EQ(state.batch_size_for_device, extracted_coo.batch_size_for_device); + state.sc_results_per_device[0].batch_size_for_device; + for (const auto& result : state.sc_results_per_device) { + DCHECK_EQ(state.batch_size_for_device, result.batch_size_for_device); } } @@ -249,32 +325,38 @@ void PostProcessTableState(TableState& state) { // `state`: The TableState holding the COO tensors and statistics. // `options`: Preprocessing options. 
void CreateMinibatchingBucketsForTable( - TableState& state, const PreprocessSparseDenseMatmulInputOptions& options, + TableState& state, + const PreprocessSparseDenseMatmulInputOptions& options, absl::BlockingCounter& counter) { tsl::profiler::TraceMe traceme([&] { return tsl::profiler::TraceMeEncode( - absl::StrCat("InputPreprocessingTable-CreateMinibatchingBuckets-", + absl::StrCat("InputPreprocessingTable-ExtractSortGroup-", state.stacked_table_name), + {{"batch_number", options.batch_number}}); }); state.stats_per_host.dropped_id_count = 0; for (int local_device = 0; local_device < options.local_device_count; ++local_device) { - PreprocessingThreadPool()->Schedule([&, local_device, &state = state] { - // Note: We create a dummy stats object here because we don't want to - // overwrite the stats from the first pass, which are authoritative. - // The only stat we care about from this second pass is the number of - // dropped IDs. - StatsPerHost dummy_stats_host( - /*local_device_count=*/1, options.GetNumScs(), - options.num_sc_per_device); - internal::StatsPerDevice dummy_stats = - dummy_stats_host.GetStatsPerDevice(0); - SortAndGroupCooTensorsForTableState(state, local_device, options, - dummy_stats, - state.table_minibatching_split); - state.dropped_id_count_per_device[local_device] = - dummy_stats.dropped_id_count; + auto extracted_sc_tensors_for_device = + absl::MakeSpan(state.extracted_tensors_per_sc) + .subspan(local_device * options.num_sc_per_device, + options.num_sc_per_device); + auto sc_results_for_device = + absl::MakeSpan(state.sc_results_per_device) + .subspan(local_device * options.num_sc_per_device, + options.num_sc_per_device); + PreprocessingThreadPool()->Schedule([&, local_device, + extracted_sc_tensors_for_device, + sc_results_for_device] { + SortAndGroupCooTensorsForLocalDevice( + state.stacked_table_metadata, local_device, options, + state.has_variable_weights, extracted_sc_tensors_for_device, + state.table_minibatching_split, sc_results_for_device); + for (int sc_id = 0; sc_id < options.num_sc_per_device; ++sc_id) { + state.dropped_id_count_per_device[local_device] += + sc_results_for_device[sc_id].dropped_id_count; + } counter.DecrementCount(); }); } @@ -292,60 +374,6 @@ inline MinibatchingSplit Deserialize(uint64_t value) { inline bool Serialize(bool value) { return value; } inline bool Deserialize(bool value) { return value; } -// Extract the COO tensors for all features. 
-ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice( - const absl::Span stacked_table_metadata, - absl::Span> input_batches, - const int local_device_id, - const PreprocessSparseDenseMatmulInputOptions& options) { - int batch_size_for_device = 0; - for (const auto& feature_metadata : stacked_table_metadata) { - batch_size_for_device += - input_batches[feature_metadata.feature_index]->size() / - options.local_device_count; - } - CheckDeviceBatchSize(batch_size_for_device, options.num_sc_per_device, - stacked_table_metadata[0].name); - - int feature_slices_per_device; - switch (options.feature_stacking_strategy) { - case FeatureStackingStrategy::kStackThenSplit: - feature_slices_per_device = 1; - break; - case FeatureStackingStrategy::kSplitThenStack: - feature_slices_per_device = options.num_sc_per_device; - break; - default: - LOG(FATAL) << "Unsupported feature stacking strategy: " - << static_cast(options.feature_stacking_strategy); - break; - } - - CHECK_GE(batch_size_for_device, - feature_slices_per_device * stacked_table_metadata.size()) - << "Batch size must be greater or equal to the number of " - "features stacked together (per feature slice)."; - - ExtractedCooTensors extracted_coo_tensors(options.num_sc_per_device, - batch_size_for_device); - - // This slices each feature into `feature_slices` partitions and then - // interleaves them: (k=num_sc_per_device-1). For stacking strategy - // SC0: F1_1, F2_1, ... Fn_1, // <- batch_size_per_slice - // SC1: F1_2, F2_2, ... Fn_2, // <- batch_size_per_slice - // ... // <- batch_size_per_slice - // SCk: F1_k, F2_k, ..., Fn_k // <- batch_size_per_slice - for (int feature_slice_id = 0; feature_slice_id < feature_slices_per_device; - ++feature_slice_id) { - for (const auto& feature_metadata : stacked_table_metadata) { - ExtractCooTensorsForSingleFeatureSlice( - feature_metadata, input_batches, local_device_id, feature_slice_id, - feature_slices_per_device, options, extracted_coo_tensors); - } - } - return extracted_coo_tensors; -} - } // namespace internal namespace { @@ -513,10 +541,13 @@ void FillDeviceBuffersForTable( row_pointers_size_per_bucket, global_minibatching_required, global_minibatching_split] { - PartitionedCooTensors& grouped_coo_tensors = - state.partitioned_coo_tensors_per_device[local_device]; + auto sc_results = absl::MakeSpan(state.sc_results_per_device) + .subspan(local_device * options.num_sc_per_device, + options.num_sc_per_device); if (options.enable_minibatching && global_minibatching_required) { - grouped_coo_tensors.Merge(global_minibatching_split); + for (auto& result : sc_results) { + result.grouped_tensors.Merge(global_minibatching_split); + } } const int batch_size_per_sc = xla::CeilOfRatio( @@ -526,7 +557,7 @@ void FillDeviceBuffersForTable( internal::CsrArraysPerDevice csr_arrays_per_device = state.csr_arrays_per_host.GetCsrArraysPerDevice(local_device); int table_dropped_ids = 0; - FillLocalDeviceBuffer(grouped_coo_tensors, row_pointers_size_per_bucket, + FillLocalDeviceBuffer(sc_results, row_pointers_size_per_bucket, coo_buffer_size_per_sc, batch_size_per_sc, options, csr_arrays_per_device, table_dropped_ids); state.dropped_id_count_per_device[local_device] = table_dropped_ids; @@ -591,7 +622,8 @@ PreprocessSparseDenseMatmulInput( {{"batch_number", options.batch_number}}); }); absl::BlockingCounter counter(table_states.size() * - options.local_device_count); + options.local_device_count * + options.num_sc_per_device); for (auto& state : table_states) { 
ExtractSortAndGroupCooTensorsForTable(state, input_batches, options, counter); @@ -600,6 +632,19 @@ PreprocessSparseDenseMatmulInput( // Post-process results after all threads are done. for (auto& state : table_states) { + for (int local_device = 0; local_device < options.local_device_count; + ++local_device) { + internal::StatsPerDevice stats_per_device = + state.stats_per_host.GetStatsPerDevice(local_device); + for (int sc_id = 0; sc_id < options.num_sc_per_device; ++sc_id) { + internal::AggregatePerSparseCoreStats( + state.sc_results_per_device[local_device * + options.num_sc_per_device + + sc_id], + sc_id, stats_per_device, + state.dropped_id_count_per_device[local_device]); + } + } PostProcessTableState(state); } } diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h index a20728d1..6e23e37f 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.h @@ -19,6 +19,8 @@ #include #include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/status/statusor.h" // from @com_google_absl +#include "absl/strings/string_view.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl #include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h" #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h" @@ -44,11 +46,13 @@ struct SparseDenseMatmulInputStats { }; namespace internal { -ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice( + +void ExtractCooTensorsForLocalDevice( absl::Span stacked_table_metadata, absl::Span> input_batches, - int local_device_id, - const PreprocessSparseDenseMatmulInputOptions& options); + int local_device, const PreprocessSparseDenseMatmulInputOptions& options, + absl::Span extracted_sc_tensors); + } // namespace internal struct PreprocessSparseDenseMatmulOutput { diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc index 9451a75f..1c13ac86 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc @@ -63,96 +63,6 @@ using ::testing::SizeIs; using InputBatch = ::jax_sc_embedding::RaggedTensorInputBatchWithOwnedData; -std::unique_ptr CreateInputBatchFromSamples( - absl::Span> samples) { - std::vector values; - std::vector row_splits; - row_splits.push_back(0); - for (const auto& sample_ids : samples) { - for (const auto& id : sample_ids) { - values.push_back(id); - } - row_splits.push_back(values.size()); - } - return std::make_unique(std::move(values), std::move(row_splits)); -} - -class TableStackingTest : public ::testing::Test { - protected: - InputBatch input_a_{{5, 18, // - 18, 0, // - 0, 20, // - 18, 0, // - 18, 0, // - 0, 20, // - 5, 18, // - 18, 0}, - {0, 2, 4, 6, 8, 10, 12, 14, 16}}; - InputBatch input_b_{{2, // - 10, // - 1, // - 9, // - 3, // - 7, // - 4, // - 8}, - {0, 1, 2, 3, 4, 5, 6, 7, 8}}; - InputBatch input_c_{{1, // - 2, 2, // - 3, 3, 3, // - 4, 4, 4, 4, // - 5, 5, 5, 5, 5, // - 6, 6, 6, 6, 6, 6, // - 7, 7, 7, 7, 7, 7, 7, // - 8, 8, 8, 8, 8, 8, 8, 8}, - {0, 1, 3, 6, 10, 15, 21, 28, 36}}; - InputBatch input_d_{ - {9, 9, 9, 9, 9, 9, 9, 9, 9, // - 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, // - 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, // - 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, // - 13, 13, 13, 13, 13, 13, 13, 13, 
13, 13, 13, 13, 13, // - 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, // - 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, // - 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16}, - {0, 9, 19, 30, 42, 55, 69, 84, 100}}; - - std::vector stacked_table_metadata_multi_{ - StackedTableMetadata( - /*name=*/"table_0", - /*feature_index=*/0, /*max_ids_per_partition=*/16, - /*max_unique_ids_per_partition=*/16, /*row_offset=*/0, - /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/16), - StackedTableMetadata( - /*name=*/"table_1", - /*feature_index=*/1, /*max_ids_per_partition=*/16, - /*max_unique_ids_per_partition=*/16, /*row_offset=*/16, - /*col_offset=*/32, /*col_shift=*/0, /*batch_size=*/16)}; - - std::vector stacked_table_metadata_single_{ - StackedTableMetadata( - /*name=*/"table_0", - /*feature_index=*/0, /*max_ids_per_partition=*/16, - /*max_unique_ids_per_partition=*/16, /*row_offset=*/0, - /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/8), - StackedTableMetadata( - /*name=*/"table_1", - /*feature_index=*/1, /*max_ids_per_partition=*/16, - /*max_unique_ids_per_partition=*/16, /*row_offset=*/8, - /*col_offset=*/32, /*col_shift=*/0, /*batch_size=*/8)}; - - std::vector> input_batches_multi_, - input_batches_single_; - - void SetUp() override { - input_batches_multi_.push_back(std::make_unique(input_a_)); - input_batches_multi_.push_back(std::make_unique(input_b_)); - - input_batches_single_.push_back(std::make_unique(input_c_)); - input_batches_single_.push_back(std::make_unique(input_d_)); - } -}; - namespace testing_utils { // Wrapper around an AllReduceInterface that causes a test failure if @@ -217,244 +127,21 @@ class AllReduceSyncKeyCollector : public AllReduceInterface { std::vector sync_keys_ ABSL_GUARDED_BY(mutex_); }; -} // namespace testing_utils - -TEST_F(TableStackingTest, MultiProcessStackingStackThenSplit) { - PreprocessSparseDenseMatmulInputOptions options{ - .local_device_count = 1, - .global_device_count = 2, - .num_sc_per_device = 4, - .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit}; - - ExtractedCooTensors extracted_coo_tensors = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - stacked_table_metadata_multi_, absl::MakeSpan(input_batches_multi_), - /*local_device_id=*/0, options); - - EXPECT_EQ(extracted_coo_tensors.batch_size_for_device, 16); - ASSERT_THAT(extracted_coo_tensors.coo_tensors, SizeIs(24)); - // This results in an uneven ID distribution - 8, 8, 4, 4 - - std::vector expected_coo_tensors; - // Feature 0, slice 0 - // SC 0 (4 rows, 8 ids) - expected_coo_tensors.push_back(CooFormat(0, 5, 1)); - expected_coo_tensors.push_back(CooFormat(0, 18, 1)); - expected_coo_tensors.push_back(CooFormat(1, 18, 1)); - expected_coo_tensors.push_back(CooFormat(1, 0, 1)); - expected_coo_tensors.push_back(CooFormat(2, 0, 1)); - expected_coo_tensors.push_back(CooFormat(2, 20, 1)); - expected_coo_tensors.push_back(CooFormat(3, 18, 1)); - expected_coo_tensors.push_back(CooFormat(3, 0, 1)); - // SC 1 (4 rows, 8 ids) - expected_coo_tensors.push_back(CooFormat(4, 18, 1)); - expected_coo_tensors.push_back(CooFormat(4, 0, 1)); - expected_coo_tensors.push_back(CooFormat(5, 0, 1)); - expected_coo_tensors.push_back(CooFormat(5, 20, 1)); - expected_coo_tensors.push_back(CooFormat(6, 5, 1)); - expected_coo_tensors.push_back(CooFormat(6, 18, 1)); - expected_coo_tensors.push_back(CooFormat(7, 18, 1)); - expected_coo_tensors.push_back(CooFormat(7, 0, 1)); - - // Feature 1, slice 0 - // SC 2 (4 rows, 4 ids) - 
expected_coo_tensors.push_back(CooFormat(8, 34, 1)); - expected_coo_tensors.push_back(CooFormat(9, 42, 1)); - expected_coo_tensors.push_back(CooFormat(10, 33, 1)); - expected_coo_tensors.push_back(CooFormat(11, 41, 1)); - // SC 3 (4 rows, 4 ids) - expected_coo_tensors.push_back(CooFormat(12, 35, 1)); - expected_coo_tensors.push_back(CooFormat(13, 39, 1)); - expected_coo_tensors.push_back(CooFormat(14, 36, 1)); - expected_coo_tensors.push_back(CooFormat(15, 40, 1)); - - EXPECT_THAT(extracted_coo_tensors.coo_tensors, - ElementsAreArray(expected_coo_tensors)); -} - -TEST_F(TableStackingTest, MultiProcessStackingSplitThenStack) { - PreprocessSparseDenseMatmulInputOptions options{ - .local_device_count = 1, - .global_device_count = 2, - .num_sc_per_device = 4, - .feature_stacking_strategy = FeatureStackingStrategy::kSplitThenStack}; - - ExtractedCooTensors extracted_coo_tensors = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - stacked_table_metadata_multi_, absl::MakeSpan(input_batches_multi_), - /*local_device_id=*/0, options); - - EXPECT_EQ(extracted_coo_tensors.batch_size_for_device, 16); - ASSERT_THAT(extracted_coo_tensors.coo_tensors, SizeIs(24)); - // This results in a more even distribution (actually ideal) - 6,6,6,6 - - std::vector expected_coo_tensors; - - // SC 0 (4 rows, 6 ids) - // Feature 0, slice 0 - expected_coo_tensors.push_back(CooFormat(0, 5, 1)); - expected_coo_tensors.push_back(CooFormat(0, 18, 1)); - expected_coo_tensors.push_back(CooFormat(1, 18, 1)); - expected_coo_tensors.push_back(CooFormat(1, 0, 1)); - - // Feature 1, slice 0 - expected_coo_tensors.push_back(CooFormat(2, 34, 1)); - expected_coo_tensors.push_back(CooFormat(3, 42, 1)); - - // SC 1 (4 rows, 6 ids) - // Feature 0, slice 1 - expected_coo_tensors.push_back(CooFormat(4, 0, 1)); - expected_coo_tensors.push_back(CooFormat(4, 20, 1)); - expected_coo_tensors.push_back(CooFormat(5, 18, 1)); - expected_coo_tensors.push_back(CooFormat(5, 0, 1)); - - // Feature 1, slice 1 - expected_coo_tensors.push_back(CooFormat(6, 33, 1)); - expected_coo_tensors.push_back(CooFormat(7, 41, 1)); - - // SC 2 (4 rows, 6 ids) - // Feature 0, slice 2 - expected_coo_tensors.push_back(CooFormat(8, 18, 1)); - expected_coo_tensors.push_back(CooFormat(8, 0, 1)); - expected_coo_tensors.push_back(CooFormat(9, 0, 1)); - expected_coo_tensors.push_back(CooFormat(9, 20, 1)); - - // Feature 1, slice 2 - expected_coo_tensors.push_back(CooFormat(10, 35, 1)); - expected_coo_tensors.push_back(CooFormat(11, 39, 1)); - // SC 3 (4 rows, 6 ids) - // Feature 0, slice 3 - expected_coo_tensors.push_back(CooFormat(12, 5, 1)); - expected_coo_tensors.push_back(CooFormat(12, 18, 1)); - expected_coo_tensors.push_back(CooFormat(13, 18, 1)); - expected_coo_tensors.push_back(CooFormat(13, 0, 1)); - - // Feature 1, slice 3 - expected_coo_tensors.push_back(CooFormat(14, 36, 1)); - expected_coo_tensors.push_back(CooFormat(15, 40, 1)); - - EXPECT_THAT(extracted_coo_tensors.coo_tensors, - ElementsAreArray(expected_coo_tensors)); -} - -TEST_F(TableStackingTest, SingleProcessSingleDeviceSplitThenStack) { - PreprocessSparseDenseMatmulInputOptions options{ - .local_device_count = 1, - .global_device_count = 1, - .num_sc_per_device = 4, - .feature_stacking_strategy = FeatureStackingStrategy::kSplitThenStack}; - - ExtractedCooTensors extracted_coo_tensors = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - stacked_table_metadata_single_, absl::MakeSpan(input_batches_single_), - /*local_device_id=*/0, options); - - 
EXPECT_EQ(extracted_coo_tensors.batch_size_for_device, 16); - ASSERT_THAT(extracted_coo_tensors.coo_tensors, SizeIs(16 * 17 / 2)); - - const int batch_size_per_sc = - extracted_coo_tensors.batch_size_for_device / options.num_sc_per_device; - std::vector ids_per_sc(options.num_sc_per_device, 0); - for (const auto& coo_tensor : extracted_coo_tensors.coo_tensors) { - ids_per_sc[coo_tensor.row_id / batch_size_per_sc]++; - } - - std::vector expected_ids_per_sc = {1 + 2 + 9 + 10, 3 + 4 + 11 + 12, - 5 + 6 + 13 + 14, 7 + 8 + 15 + 16}; - - EXPECT_EQ(ids_per_sc, expected_ids_per_sc); -} - -TEST_F(TableStackingTest, SingleProcessSingleDeviceStackThenSplit) { - PreprocessSparseDenseMatmulInputOptions options{ - .local_device_count = 1, - .global_device_count = 1, - .num_sc_per_device = 4, - .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit}; - - ExtractedCooTensors extracted_coo_tensors = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - stacked_table_metadata_single_, absl::MakeSpan(input_batches_single_), - /*local_device_id=*/0, options); - - EXPECT_EQ(extracted_coo_tensors.batch_size_for_device, 16); - ASSERT_THAT(extracted_coo_tensors.coo_tensors, SizeIs(16 * 17 / 2)); - - const int batch_size_per_sc = - extracted_coo_tensors.batch_size_for_device / options.num_sc_per_device; - std::vector ids_per_sc(options.num_sc_per_device, 0); - for (const auto& coo_tensor : extracted_coo_tensors.coo_tensors) { - ids_per_sc[coo_tensor.row_id / batch_size_per_sc]++; - } - - std::vector expected_ids_per_sc = {1 + 2 + 3 + 4, // - 5 + 6 + 7 + 8, // - 9 + 10 + 11 + 12, // - 13 + 14 + 15 + 16}; - - EXPECT_EQ(ids_per_sc, expected_ids_per_sc); -} - -TEST_F(TableStackingTest, MultiChipSplitThenStack) { - PreprocessSparseDenseMatmulInputOptions options{ - .local_device_count = 2, - .global_device_count = 2, - .num_sc_per_device = 4, - .feature_stacking_strategy = FeatureStackingStrategy::kSplitThenStack}; - - std::vector expected_ids_per_sc[] = {{1 + 9, 2 + 10, 3 + 11, 4 + 12}, - {5 + 13, 6 + 14, 7 + 15, 8 + 16}}; - - for (int local_device_id = 0; local_device_id < options.local_device_count; - ++local_device_id) { - ExtractedCooTensors extracted_coo_tensors = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - stacked_table_metadata_single_, - absl::MakeSpan(input_batches_single_), local_device_id, options); - EXPECT_EQ(extracted_coo_tensors.batch_size_for_device, 8); - - const int batch_size_per_sc = - extracted_coo_tensors.batch_size_for_device / options.num_sc_per_device; - std::vector ids_per_sc(options.num_sc_per_device, 0); - for (const auto& coo_tensor : extracted_coo_tensors.coo_tensors) { - ids_per_sc[coo_tensor.row_id / batch_size_per_sc]++; +std::unique_ptr CreateInputBatchFromSamples( + absl::Span> samples) { + std::vector values; + std::vector row_splits; + row_splits.push_back(0); + for (const auto& sample_ids : samples) { + for (const auto& id : sample_ids) { + values.push_back(id); } - - EXPECT_EQ(ids_per_sc, expected_ids_per_sc[local_device_id]) - << "local_device_id: " << local_device_id; + row_splits.push_back(values.size()); } + return std::make_unique(std::move(values), std::move(row_splits)); } -TEST_F(TableStackingTest, MultiChipStackThenSplit) { - PreprocessSparseDenseMatmulInputOptions options{ - .local_device_count = 2, - .global_device_count = 2, - .num_sc_per_device = 4, - .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit}; - - std::vector expected_ids_per_sc[] = {{1 + 2, 3 + 4, 9 + 10, 11 + 12}, - {5 + 6, 7 + 8, 13 + 14, 15 + 
16}}; - - for (int local_device_id = 0; local_device_id < options.local_device_count; - ++local_device_id) { - ExtractedCooTensors extracted_coo_tensors = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - stacked_table_metadata_single_, - absl::MakeSpan(input_batches_single_), local_device_id, options); - - EXPECT_EQ(extracted_coo_tensors.batch_size_for_device, 8); - - const int batch_size_per_sc = - extracted_coo_tensors.batch_size_for_device / options.num_sc_per_device; - std::vector ids_per_sc(options.num_sc_per_device, 0); - for (const auto& coo_tensor : extracted_coo_tensors.coo_tensors) { - ids_per_sc[coo_tensor.row_id / batch_size_per_sc]++; - } - - EXPECT_EQ(ids_per_sc, expected_ids_per_sc[local_device_id]) - << "local_device_id: " << local_device_id; - } -} +} // namespace testing_utils TEST(InputPreprocessingUtilTest, MergeStats) { SparseDenseMatmulInputStats stats1; @@ -514,27 +201,6 @@ TEST(InputPreprocessingUtilTest, MergeStats) { ElementsAreArray({300, 301, 302, 303})); } -TEST_F(TableStackingTest, CooTensorsPerScCalculation) { - PreprocessSparseDenseMatmulInputOptions options{ - .local_device_count = 1, - .global_device_count = 1, - .num_sc_per_device = 4, - .feature_stacking_strategy = FeatureStackingStrategy::kSplitThenStack}; - - ExtractedCooTensors extracted_coo_tensors = - internal::ExtractCooTensorsForAllFeaturesPerLocalDevice( - stacked_table_metadata_single_, absl::MakeSpan(input_batches_single_), - /*local_device_id=*/0, options); - - EXPECT_EQ(extracted_coo_tensors.batch_size_for_device, 16); - ASSERT_THAT(extracted_coo_tensors.coo_tensors, SizeIs(16 * 17 / 2)); - - std::vector expected_coo_tensors_per_sc = { - 1 + 2 + 9 + 10, 3 + 4 + 11 + 12, 5 + 6 + 13 + 14, 7 + 8 + 15 + 16}; - EXPECT_EQ(extracted_coo_tensors.coo_tensors_per_sc, - expected_coo_tensors_per_sc); -} - class MinibatchingTest : public testing::TestWithParam { protected: bool IsMinibatchingEnabled() const { return GetParam(); } @@ -1018,7 +684,8 @@ void RunPreprocessingOutputIsValidTest( for (const auto& sample : table_samples) { total_input_ids += sample.size(); } - input_batches.push_back(CreateInputBatchFromSamples(table_samples)); + input_batches.push_back( + testing_utils::CreateInputBatchFromSamples(table_samples)); } const int kGlobalDeviceCount = global_device_count; @@ -1216,7 +883,7 @@ void StatsValidationTest(std::vector> samples, int num_sc_per_device, int global_device_count, FeatureStackingStrategy feature_stacking_strategy) { std::vector> input_batches; - input_batches.push_back(CreateInputBatchFromSamples(samples)); + input_batches.push_back(testing_utils::CreateInputBatchFromSamples(samples)); absl::flat_hash_map> stacked_tables; diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc index 56ec238f..88047104 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc @@ -204,6 +204,24 @@ int FillBufferSegment(const BufferFillingOptions& options, } // namespace +namespace internal { +void AggregatePerSparseCoreStats(const PerSparseCoreGroupedData& result, + int sc_id, + internal::StatsPerDevice& stats_per_device, + int& dropped_id_count) { + stats_per_device.max_ids_per_partition = + stats_per_device.max_ids_per_partition.cwiseMax( + result.ids_per_sc_partition_per_bucket.rowwise().sum().transpose()); + stats_per_device.max_unique_ids_per_partition = + 
          stats_per_device.max_unique_ids_per_partition.cwiseMax(
+              result.unique_ids_per_partition_per_bucket.rowwise()
+                  .sum()
+                  .transpose());
+  stats_per_device.required_buffer_size[sc_id] = result.required_buffer_size;
+  dropped_id_count += result.dropped_id_count;
+}
+}  // namespace internal
+
 RowCombiner GetRowCombiner(absl::string_view combiner) {
   if (combiner == "sum") {
     return RowCombiner::kSum;
@@ -310,7 +328,7 @@ std::optional SuggestedCooBufferSizeForStackedTables(
 // We use output buffers `row_pointers`, `embedding_ids`, `sample_ids`, and
 // `gains` because we fill values in a loop to a bigger array.
 void FillLocalDeviceBuffer(
-    const PartitionedCooTensors& grouped_coo_tensors,
+    absl::Span sc_grouped_data,
     const int row_pointers_size_per_bucket, const int coo_buffer_size_per_sc,
     const int batch_size_per_sc,
     const PreprocessSparseDenseMatmulInputOptions& options,
@@ -318,12 +336,15 @@ void FillLocalDeviceBuffer(
     int& dropped_id_count_static_bound) {
   tsl::profiler::TraceMe t("FillLocalDeviceBuffer");
   const int num_sc_per_device = options.num_sc_per_device;
+  CHECK_EQ(num_sc_per_device, sc_grouped_data.size());
   const int num_scs = options.GetNumScs();
   const int coo_buffer_size = coo_buffer_size_per_sc * num_sc_per_device;
   DCHECK_GT(batch_size_per_sc, 0);
   int coo_begin = 0;
   int lhs_row_begin = 0;
   for (int local_sc_id = 0; local_sc_id < num_sc_per_device; ++local_sc_id) {
+    const auto& grouped_coo_tensors =
+        sc_grouped_data[local_sc_id].grouped_tensors;
     for (int minibatch_id = 0;
          minibatch_id < grouped_coo_tensors.GetNumMinibatches();
          ++minibatch_id) {
@@ -336,7 +357,7 @@ void FillLocalDeviceBuffer(
       coo_begin = FillBufferSegment(
           {
               .local_sc_id = local_sc_id,
-              .coo_tensors = grouped_coo_tensors(local_sc_id, minibatch_id),
+              .coo_tensors = grouped_coo_tensors(0, minibatch_id),
               .lhs_row_begin = lhs_row_begin,
               .lhs_row_end = lhs_row_end,
               .coo_begin = coo_begin,
diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h
index 4e813054..f107cb73 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h
@@ -212,18 +212,61 @@ enum class RowCombiner {
 
 RowCombiner GetRowCombiner(absl::string_view combiner);
 
-struct ExtractedCooTensors {
+// Holds stats and grouped tensors for a single SparseCore.
+struct PerSparseCoreGroupedData;
+
+namespace internal {
+void AggregatePerSparseCoreStats(const PerSparseCoreGroupedData& result,
+                                 int sc_id,
+                                 internal::StatsPerDevice& stats_per_device,
+                                 int& dropped_id_count);
+}
+
+// Holds stats and grouped tensors for a single SparseCore.
+struct PerSparseCoreGroupedData {
+  // Tensors partitioned by bucket and global SC, for a single local SC.
+  // This will be configured as if num_sc_per_device=1.
+  // TODO(b/444292437): remove num_sc_per_device=1.
+  PartitionedCooTensors grouped_tensors;
+
+  // Observed IDs per global SC partition per bucket for this local SC.
+  MatrixXi ids_per_sc_partition_per_bucket;
+  // Observed unique IDs per global SC partition per bucket for this local SC.
+  MatrixXi unique_ids_per_partition_per_bucket;
+  // Total required buffer size for this SC, summed across partitions.
+  int required_buffer_size = 0;
+  // Number of IDs dropped for this SC.
+  int dropped_id_count = 0;
+  // Batch size for the device this SparseCore belongs to.
+  int batch_size_for_device = 0;
+
+  PerSparseCoreGroupedData() = default;
+  PerSparseCoreGroupedData(int coo_count, int global_sc_count, int bucket_count,
+                           int batch_size_for_device)
+      : grouped_tensors(coo_count, /*num_sc_per_device=*/1, global_sc_count,
+                        bucket_count),
+        ids_per_sc_partition_per_bucket(
+            MatrixXi::Zero(global_sc_count, bucket_count)),
+        unique_ids_per_partition_per_bucket(
+            MatrixXi::Zero(global_sc_count, bucket_count)),
+        batch_size_for_device(batch_size_for_device) {}
+};
+
+// Holds extracted tensors for one SparseCore before sorting and grouping.
+struct ExtractedSparseCoreTensors {
   std::vector<CooFormat> coo_tensors;
+  int batch_size_for_device = 0;
+};
+
+struct ExtractedCooTensors {
+  std::vector<ExtractedSparseCoreTensors> per_sc_tensors;
   // Number of samples these coo_tensors are extracted from.
   int batch_size_for_device;
-  // Count coo tensors per SC for efficient allocation of vector for sorting and
-  // grouping them. Might be lower after deduplication.
-  std::vector<int> coo_tensors_per_sc;
 
   ExtractedCooTensors() : ExtractedCooTensors(0, 0) {}
   ExtractedCooTensors(int num_sc_per_device, int batch_size_for_device)
-      : batch_size_for_device(batch_size_for_device),
-        coo_tensors_per_sc(num_sc_per_device, 0) {}
+      : per_sc_tensors(num_sc_per_device),
+        batch_size_for_device(batch_size_for_device) {}
 };
 
 struct StackedTableMetadata {
@@ -287,7 +330,7 @@ std::optional SuggestedCooBufferSizeForStackedTables(
     absl::Span stacked_table_metadata);
 
 void FillLocalDeviceBuffer(
-    const PartitionedCooTensors& grouped_coo_tensors,
+    absl::Span sc_grouped_data,
     int row_pointers_size_per_bucket, int coo_buffer_size_per_sc,
     int batch_size_per_sc,
     const PreprocessSparseDenseMatmulInputOptions& options,
diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc
index e0f719f7..7db6e0fc 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc
@@ -31,6 +31,8 @@
 #include "jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h"
 
 namespace jax_sc_embedding {
+namespace internal {
+
 namespace {
 
 using ::testing::_;
@@ -139,34 +141,44 @@ TEST(InputPreprocessingUtilTest, ComputeCooBufferSize) {
 
 TEST(SortAndGroupTest, Base) {
   std::vector<CooFormat> coo_formats;
-
   for (int row = 0; row < 8; ++row) {
     coo_formats.push_back(CooFormat(row, 0, 1.0));
     coo_formats.push_back(CooFormat(row, 1, 1.0));
     coo_formats.push_back(CooFormat(row, 2, 1.0));
     coo_formats.push_back(CooFormat(row, 3, 1.0));
   }
-  ExtractedCooTensors extracted_coo_tensors(4, 8);
-  extracted_coo_tensors.coo_tensors = coo_formats;
+
+  const int kNumScPerDevice = 4;
+  const int kBatchSizePerSc = 2;
+  ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice];
+  for (int i = 0; i < kNumScPerDevice; ++i) {
+    sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc;
+    for (int j = 0; j < kBatchSizePerSc; ++j) {
+      for (int k = 0; k < 4; ++k) {
+        sc_tensors[i].coo_tensors.push_back(
+            coo_formats[i * kBatchSizePerSc * 4 + j * 4 + k]);
+      }
+    }
+  }
+
   StackedTableMetadata stacked_table_metadata(
       "stacked_table",
       /*feature_index=*/0, /*max_ids_per_partition=*/32,
       /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0,
       /*col_shift=*/0, /*batch_size=*/0);
   PreprocessSparseDenseMatmulInputOptions options = {
-      .local_device_count = 4,
+      .local_device_count = 1,
       .global_device_count = 1,
-      .num_sc_per_device = 4,
+      .num_sc_per_device = kNumScPerDevice,
.allow_id_dropping = false, - }; - MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); + }; + bool minibatching_required = false; + + std::vector results; + for (int i = 0; i < kNumScPerDevice; ++i) { + results.push_back(SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required)); + } std::vector expected_sc_0; expected_sc_0.push_back(CooFormat(0, 0, 1.0)); @@ -208,57 +220,71 @@ TEST(SortAndGroupTest, Base) { expected_sc_3.push_back(CooFormat(6, 3, 1.0)); expected_sc_3.push_back(CooFormat(7, 3, 1.0)); - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/0, /*bucket_id=*/0), + EXPECT_THAT(results[0].grouped_tensors(0, 0), ElementsAreArray(expected_sc_0)); - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/1, /*bucket_id=*/0), + EXPECT_THAT(results[1].grouped_tensors(0, 0), ElementsAreArray(expected_sc_1)); - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/2, /*bucket_id=*/0), + EXPECT_THAT(results[2].grouped_tensors(0, 0), ElementsAreArray(expected_sc_2)); - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/3, /*bucket_id=*/0), + EXPECT_THAT(results[3].grouped_tensors(0, 0), ElementsAreArray(expected_sc_3)); - EXPECT_EQ(stats_per_device.dropped_id_count, 0); - EXPECT_THAT(stats_per_device.max_ids_per_partition, + + EXPECT_EQ(results[0].dropped_id_count, 0); + EXPECT_EQ(results[1].dropped_id_count, 0); + EXPECT_EQ(results[2].dropped_id_count, 0); + EXPECT_EQ(results[3].dropped_id_count, 0); + + EXPECT_THAT(results[0].ids_per_sc_partition_per_bucket.rowwise().sum(), ElementsAreArray({2, 2, 2, 2})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, + EXPECT_THAT(results[0].unique_ids_per_partition_per_bucket.rowwise().sum(), ElementsAreArray({1, 1, 1, 1})); - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({32, 32, 32, 32})); + EXPECT_EQ(results[0].required_buffer_size, 32); } TEST(SortAndGroupTest, TwoScs) { std::vector coo_formats; - for (int row = 0; row < 8; ++row) { coo_formats.push_back(CooFormat(row, 0, 1.0)); coo_formats.push_back(CooFormat(row, 1, 1.0)); coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(2, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + + const int kNumScPerDevice = 2; + const int kBatchSizePerSc = 4; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 4; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 4 + j * 4 + k]); + } + } + } + StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 2, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 2, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, - }; - MinibatchingSplit 
minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/2, - /*num_sc_per_device=*/2); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - - EXPECT_EQ(minibatching_split, 0); - - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/0, /*bucket_id=*/0), + }; + bool minibatching_required = false; + + std::vector results; + for (int i = 0; i < kNumScPerDevice; ++i) { + results.push_back(SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required)); + } + + EXPECT_EQ(minibatching_required, 0); + + EXPECT_THAT(results[0].grouped_tensors(0, 0), ElementsAre(CooFormat(0, 0, 1.0), CooFormat(1, 0, 1.0), CooFormat(2, 0, 1.0), CooFormat(3, 0, 1.0), CooFormat(0, 2, 1.0), CooFormat(1, 2, 1.0), @@ -267,7 +293,7 @@ TEST(SortAndGroupTest, TwoScs) { CooFormat(2, 1, 1.0), CooFormat(3, 1, 1.0), CooFormat(0, 3, 1.0), CooFormat(1, 3, 1.0), CooFormat(2, 3, 1.0), CooFormat(3, 3, 1.0))); - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/1, /*bucket_id=*/0), + EXPECT_THAT(results[1].grouped_tensors(0, 0), ElementsAre(CooFormat(4, 0, 1.0), CooFormat(5, 0, 1.0), CooFormat(6, 0, 1.0), CooFormat(7, 0, 1.0), CooFormat(4, 2, 1.0), CooFormat(5, 2, 1.0), @@ -276,16 +302,37 @@ TEST(SortAndGroupTest, TwoScs) { CooFormat(6, 1, 1.0), CooFormat(7, 1, 1.0), CooFormat(4, 3, 1.0), CooFormat(5, 3, 1.0), CooFormat(6, 3, 1.0), CooFormat(7, 3, 1.0))); - EXPECT_EQ(stats_per_device.dropped_id_count, 0); - EXPECT_THAT(stats_per_device.max_ids_per_partition, ElementsAreArray({8, 8})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, + EXPECT_EQ(results[0].dropped_id_count, 0); + EXPECT_EQ(results[1].dropped_id_count, 0); + EXPECT_THAT(results[0].ids_per_sc_partition_per_bucket.rowwise().sum(), + ElementsAreArray({8, 8})); + EXPECT_THAT(results[0].unique_ids_per_partition_per_bucket.rowwise().sum(), ElementsAreArray({2, 2})); - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({16, 16})); + EXPECT_EQ(results[0].required_buffer_size, 16); + EXPECT_EQ(results[1].required_buffer_size, 16); } TEST(SortAndGroupTest, VerifyIdLimitations1) { std::vector coo_formats; + for (int row = 0; row < 8; ++row) { + coo_formats.push_back(CooFormat(row, 0, 1.0)); + coo_formats.push_back(CooFormat(row, 1, 1.0)); + coo_formats.push_back(CooFormat(row, 2, 1.0)); + coo_formats.push_back(CooFormat(row, 3, 1.0)); + } + + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 2; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 4; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 4 + j * 4 + k]); + } + } + } // With 8 samples, each sample has 4 ids [0, 1, 2, 3] // Each sparsecore serves 1 row of data. @@ -297,46 +344,53 @@ TEST(SortAndGroupTest, VerifyIdLimitations1) { // [max_unique_ids_per_partition == 1] // For each sparsecore, it receives the data of at most "1" row ID from each // sparsecore. 
- for (int row = 0; row < 8; ++row) { - coo_formats.push_back(CooFormat(row, 0, 1.0)); - coo_formats.push_back(CooFormat(row, 1, 1.0)); - coo_formats.push_back(CooFormat(row, 2, 1.0)); - coo_formats.push_back(CooFormat(row, 3, 1.0)); - } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/2, /*max_unique_ids_per_partition=*/1, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 4, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, - }; - MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - - EXPECT_EQ(minibatching_split, 0); - EXPECT_THAT(stats_per_device.dropped_id_count, 0); - EXPECT_THAT(stats_per_device.max_ids_per_partition, - ElementsAreArray({2, 2, 2, 2})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, - ElementsAreArray({1, 1, 1, 1})); - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({32, 32, 32, 32})); + }; + bool minibatching_required = false; + + for (int i = 0; i < kNumScPerDevice; ++i) { + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required); + EXPECT_EQ(result.dropped_id_count, 0); + EXPECT_THAT(result.ids_per_sc_partition_per_bucket.rowwise().sum(), + ElementsAreArray({2, 2, 2, 2})); + EXPECT_THAT(result.unique_ids_per_partition_per_bucket.rowwise().sum(), + ElementsAreArray({1, 1, 1, 1})); + EXPECT_EQ(result.required_buffer_size, 32); + } } TEST(SortAndGroupTest, VerifyIdLimitations2) { std::vector coo_formats; + for (int row = 0; row < 16; ++row) { + coo_formats.push_back(CooFormat(row, 0, 1.0)); + coo_formats.push_back(CooFormat(row, 1, 1.0)); + coo_formats.push_back(CooFormat(row, 2, 1.0)); + coo_formats.push_back(CooFormat(row, 3, 1.0)); + } + + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 4; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 4; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 4 + j * 4 + k]); + } + } + } // With 16 samples, each sample has 4 ids [0, 1, 2, 3] // Each sparsecore serves 1 row of data. @@ -348,42 +402,30 @@ TEST(SortAndGroupTest, VerifyIdLimitations2) { // [max_unique_ids_per_partition == 1] // For each sparsecore, it receives the data of at most "1" row ID from each // sparsecore. 
- for (int row = 0; row < 16; ++row) { - coo_formats.push_back(CooFormat(row, 0, 1.0)); - coo_formats.push_back(CooFormat(row, 1, 1.0)); - coo_formats.push_back(CooFormat(row, 2, 1.0)); - coo_formats.push_back(CooFormat(row, 3, 1.0)); - } - ExtractedCooTensors extracted_coo_tensors(4, 16); - extracted_coo_tensors.coo_tensors = coo_formats; StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/4, /*max_unique_ids_per_partition=*/1, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 4, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, - }; - MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - - EXPECT_EQ(minibatching_split, 0); - EXPECT_THAT(stats_per_device.dropped_id_count, 0); - EXPECT_THAT(stats_per_device.max_ids_per_partition, - ElementsAreArray({4, 4, 4, 4})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, - ElementsAreArray({1, 1, 1, 1})); - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({32, 32, 32, 32})); + }; + bool minibatching_required = false; + + for (int i = 0; i < kNumScPerDevice; ++i) { + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required); + EXPECT_EQ(result.dropped_id_count, 0); + EXPECT_THAT(result.ids_per_sc_partition_per_bucket.rowwise().sum(), + ElementsAreArray({4, 4, 4, 4})); + EXPECT_THAT(result.unique_ids_per_partition_per_bucket.rowwise().sum(), + ElementsAreArray({1, 1, 1, 1})); + EXPECT_EQ(result.required_buffer_size, 32); + } } TEST(SortAndGroupTest, VerifyIdLimitations3) { @@ -410,37 +452,46 @@ TEST(SortAndGroupTest, VerifyIdLimitations3) { coo_formats.push_back(CooFormat(row, 6, 1.0)); coo_formats.push_back(CooFormat(row, 7, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 16); - extracted_coo_tensors.coo_tensors = coo_formats; + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 4; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 8; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 8 + j * 8 + k]); + } + } + } + StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/8, /*max_unique_ids_per_partition=*/2, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 4, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, - }; - MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - internal::StatsPerDevice stats_per_device = - 
stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - - EXPECT_EQ(minibatching_split, 0); - EXPECT_THAT(stats_per_device.dropped_id_count, 0); - EXPECT_THAT(stats_per_device.max_ids_per_partition, - ElementsAreArray({8, 8, 8, 8})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, - ElementsAreArray({2, 2, 2, 2})); - // 4 partitions of size 8 with 2 elements each - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({32, 32, 32, 32})); + }; + bool minibatching_required = false; + + for (int i = 0; i < kNumScPerDevice; ++i) { + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required); + + EXPECT_EQ(minibatching_required, 0); + EXPECT_EQ(result.dropped_id_count, 0); + EXPECT_THAT(result.ids_per_sc_partition_per_bucket.rowwise().sum(), + ElementsAreArray({8, 8, 8, 8})); + EXPECT_THAT(result.unique_ids_per_partition_per_bucket.rowwise().sum(), + ElementsAreArray({2, 2, 2, 2})); + // 4 partitions of size 8 with 2 elements each + EXPECT_EQ(result.required_buffer_size, 32); + } } TEST(SortAndGroupTest, VerifyIdLimitations4) { @@ -467,37 +518,44 @@ TEST(SortAndGroupTest, VerifyIdLimitations4) { coo_formats.push_back(CooFormat(row, 6, 1.0)); coo_formats.push_back(CooFormat(row, 7, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 128); - extracted_coo_tensors.coo_tensors = coo_formats; + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 32; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 8; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 8 + j * 8 + k]); + } + } + } StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/64, /*max_unique_ids_per_partition=*/2, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 4, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, - }; - MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - - EXPECT_EQ(minibatching_split, 0); - EXPECT_THAT(stats_per_device.dropped_id_count, 0); - EXPECT_THAT(stats_per_device.max_ids_per_partition, - ElementsAreArray({64, 64, 64, 64})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, - ElementsAreArray({2, 2, 2, 2})); - // 8 partitions of size 256 with 32 elements each - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({256, 256, 256, 256})); + }; + bool minibatching_required = false; + + for (int i = 0; i < kNumScPerDevice; ++i) { + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, 
stacked_table_metadata, + minibatching_required); + EXPECT_EQ(minibatching_required, 0); + EXPECT_EQ(result.dropped_id_count, 0); + EXPECT_THAT(result.ids_per_sc_partition_per_bucket.rowwise().sum(), + ElementsAreArray({64, 64, 64, 64})); + EXPECT_THAT(result.unique_ids_per_partition_per_bucket.rowwise().sum(), + ElementsAreArray({2, 2, 2, 2})); + // 8 partitions of size 256 with 32 elements each + EXPECT_EQ(result.required_buffer_size, 256); + } } TEST(SortAndGroupTest, VerifyIdLimitations5) { @@ -519,37 +577,44 @@ TEST(SortAndGroupTest, VerifyIdLimitations5) { coo_formats.push_back(CooFormat(row, 8, 1.0)); coo_formats.push_back(CooFormat(row, 16, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 128); - extracted_coo_tensors.coo_tensors = coo_formats; + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 32; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 4; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 4 + j * 4 + k]); + } + } + } StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/128, /*max_unique_ids_per_partition=*/4, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 4, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, - }; - MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - - EXPECT_EQ(minibatching_split, 0); - EXPECT_THAT(stats_per_device.dropped_id_count, 0); - EXPECT_THAT(stats_per_device.max_ids_per_partition, - ElementsAreArray({128, 0, 0, 0})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, - ElementsAreArray({4, 0, 0, 0})); - // 1 partition of size 128 with 128 elements - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({128, 128, 128, 128})); + }; + bool minibatching_required = false; + + for (int i = 0; i < kNumScPerDevice; ++i) { + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required); + EXPECT_EQ(minibatching_required, 0); + EXPECT_EQ(result.dropped_id_count, 0); + EXPECT_THAT(result.ids_per_sc_partition_per_bucket.rowwise().sum(), + ElementsAreArray({128, 0, 0, 0})); + EXPECT_THAT(result.unique_ids_per_partition_per_bucket.rowwise().sum(), + ElementsAreArray({4, 0, 0, 0})); + // 1 partition of size 128 with 128 elements + EXPECT_EQ(result.required_buffer_size, 128); + } } TEST(SortAndGroupTest, VerifyIdLimitations6) { @@ -571,36 +636,41 @@ TEST(SortAndGroupTest, VerifyIdLimitations6) { for (int row = 0; row < 128; ++row) { coo_formats.push_back(CooFormat(row, row * 4, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 128); - extracted_coo_tensors.coo_tensors = coo_formats; + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 32; + ExtractedSparseCoreTensors 
sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + sc_tensors[i].coo_tensors.push_back(coo_formats[i * kBatchSizePerSc + j]); + } + } StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 4, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, - }; - MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - auto stats_per_device = stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - - EXPECT_EQ(minibatching_split, 0); - EXPECT_THAT(stats_per_device.dropped_id_count, 0); - EXPECT_THAT(stats_per_device.max_ids_per_partition, - ElementsAreArray({32, 0, 0, 0})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, - ElementsAreArray({32, 0, 0, 0})); - // 1 partition of size 32 with 32 elements - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({32, 32, 32, 32})); + }; + bool minibatching_required = false; + + for (int i = 0; i < kNumScPerDevice; ++i) { + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required); + EXPECT_EQ(minibatching_required, 0); + EXPECT_EQ(result.dropped_id_count, 0); + EXPECT_THAT(result.ids_per_sc_partition_per_bucket.rowwise().sum(), + ElementsAreArray({32, 0, 0, 0})); + EXPECT_THAT(result.unique_ids_per_partition_per_bucket.rowwise().sum(), + ElementsAreArray({32, 0, 0, 0})); + // 1 partition of size 32 with 32 elements + EXPECT_EQ(result.required_buffer_size, 32); + } } TEST(SortAndGroupTest, IdDropping) { @@ -624,88 +694,50 @@ TEST(SortAndGroupTest, IdDropping) { } // Force dropping of IDs here with max_ids_per_partition == 2 // The later 2 samples for each sparsecore will be dropped. 
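Why it is the later samples that are dropped (consistent with the grouping-key layout noted further down, where the original tensor index occupies the low bits): ids are kept in ascending sorted-key order, so within each partition the ids from lower-indexed rows fill the max_ids_per_partition budget of 2 first. Every one of the 4 rows in a SparseCore's local batch contributes one id to each partition, so rows 0 and 1 of each local batch are kept and rows 2 and 3 overflow the budget.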
- ExtractedCooTensors extracted_coo_tensors(4, 16); - extracted_coo_tensors.coo_tensors = coo_formats; + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 4; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 4; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 4 + j * 4 + k]); + } + } + } StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/2, /*max_unique_ids_per_partition=*/1, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 4, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, - .allow_id_dropping = true, - }; - bool minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - auto stats_per_device = stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - EXPECT_THAT(stats_per_device.dropped_id_count, 32); - EXPECT_THAT(stats_per_device.max_ids_per_partition, - ElementsAreArray({4, 4, 4, 4})); - EXPECT_THAT(stats_per_device.max_unique_ids_per_partition, - ElementsAreArray({1, 1, 1, 1})); - // 4 partition of size 8 with 4 element each - EXPECT_THAT(stats_per_device.required_buffer_size, - ElementsAreArray({32, 32, 32, 32})); - - // Note that sample 2, 3, 6, 7, 10, 11, 14, 15 are dropped. - // It's unclear how embedding activations will be constructed without these - // samples at this unit test level. 
- std::vector expected_sc_0; - expected_sc_0.push_back(CooFormat(0, 0, 1.0)); - expected_sc_0.push_back(CooFormat(1, 0, 1.0)); - expected_sc_0.push_back(CooFormat(0, 1, 1.0)); - expected_sc_0.push_back(CooFormat(1, 1, 1.0)); - expected_sc_0.push_back(CooFormat(0, 2, 1.0)); - expected_sc_0.push_back(CooFormat(1, 2, 1.0)); - expected_sc_0.push_back(CooFormat(0, 3, 1.0)); - expected_sc_0.push_back(CooFormat(1, 3, 1.0)); - - std::vector expected_sc_1; - expected_sc_1.push_back(CooFormat(4, 0, 1.0)); - expected_sc_1.push_back(CooFormat(5, 0, 1.0)); - expected_sc_1.push_back(CooFormat(4, 1, 1.0)); - expected_sc_1.push_back(CooFormat(5, 1, 1.0)); - expected_sc_1.push_back(CooFormat(4, 2, 1.0)); - expected_sc_1.push_back(CooFormat(5, 2, 1.0)); - expected_sc_1.push_back(CooFormat(4, 3, 1.0)); - expected_sc_1.push_back(CooFormat(5, 3, 1.0)); - - std::vector expected_sc_2; - expected_sc_2.push_back(CooFormat(8, 0, 1.0)); - expected_sc_2.push_back(CooFormat(9, 0, 1.0)); - expected_sc_2.push_back(CooFormat(8, 1, 1.0)); - expected_sc_2.push_back(CooFormat(9, 1, 1.0)); - expected_sc_2.push_back(CooFormat(8, 2, 1.0)); - expected_sc_2.push_back(CooFormat(9, 2, 1.0)); - expected_sc_2.push_back(CooFormat(8, 3, 1.0)); - expected_sc_2.push_back(CooFormat(9, 3, 1.0)); + .num_sc_per_device = kNumScPerDevice, + .allow_id_dropping = true}; + bool minibatching_required = false; - std::vector expected_sc_3; - expected_sc_3.push_back(CooFormat(12, 0, 1.0)); - expected_sc_3.push_back(CooFormat(13, 0, 1.0)); - expected_sc_3.push_back(CooFormat(12, 1, 1.0)); - expected_sc_3.push_back(CooFormat(13, 1, 1.0)); - expected_sc_3.push_back(CooFormat(12, 2, 1.0)); - expected_sc_3.push_back(CooFormat(13, 2, 1.0)); - expected_sc_3.push_back(CooFormat(12, 3, 1.0)); - expected_sc_3.push_back(CooFormat(13, 3, 1.0)); - - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/0, /*bucket_id=*/0), - ElementsAreArray(expected_sc_0)); - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/1, /*bucket_id=*/0), - ElementsAreArray(expected_sc_1)); - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/2, /*bucket_id=*/0), - ElementsAreArray(expected_sc_2)); - EXPECT_THAT(coo_tensors_by_id(/*local_sc_id=*/3, /*bucket_id=*/0), - ElementsAreArray(expected_sc_3)); + int total_dropped_count = 0; + for (int i = 0; i < kNumScPerDevice; ++i) { + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required); + total_dropped_count += result.dropped_id_count; + EXPECT_THAT(result.ids_per_sc_partition_per_bucket.rowwise().sum(), + ElementsAreArray({4, 4, 4, 4})); + EXPECT_THAT(result.unique_ids_per_partition_per_bucket.rowwise().sum(), + ElementsAreArray({1, 1, 1, 1})); + // 4 partition of size 8 with 4 element each + EXPECT_EQ(result.required_buffer_size, 32); + } + EXPECT_EQ(total_dropped_count, 32); } +// Note that sample 2, 3, 6, 7, 10, 11, 14, 15 are dropped. +// It's unclear how embedding activations will be constructed without these +// samples at this unit test level. 
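A quick check of the expected totals, derived from the test's own constants: the device batch has 16 samples with 4 ids each, i.e. 64 ids. With max_ids_per_partition == 2, each SparseCore keeps 2 ids in each of its 4 partitions, i.e. 8 of the 16 ids it observes, so 8 ids are dropped per SparseCore and 8 * 4 = 32 for the device, matching EXPECT_EQ(total_dropped_count, 32). Note that ids_per_sc_partition_per_bucket still reports the observed counts {4, 4, 4, 4} per partition, not the post-drop counts.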
+ TEST(InputPreprocessingUtilTest, FillBuffer) { std::vector coo_formats; @@ -715,35 +747,43 @@ TEST(InputPreprocessingUtilTest, FillBuffer) { coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 2; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 4; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 4 + j * 4 + k]); + } + } + } StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, /*col_shift=*/0, /*batch_size=*/0); PreprocessSparseDenseMatmulInputOptions options = { - .local_device_count = 4, + .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, - .allow_id_dropping = false, - }; - MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - auto stats_per_device = stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); + .num_sc_per_device = kNumScPerDevice, + .allow_id_dropping = false}; + bool minibatching_required = false; + + std::vector results; + for (int i = 0; i < kNumScPerDevice; ++i) { + results.push_back(SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_required)); + } - EXPECT_EQ(minibatching_split, 0); + EXPECT_EQ(minibatching_required, 0); CsrArraysPerHost csr_arrays_per_host(1, 8 * 4, 40 * 4); - internal::CsrArraysPerDevice csr_array = - csr_arrays_per_host.GetCsrArraysPerDevice(0); + CsrArraysPerDevice csr_array = csr_arrays_per_host.GetCsrArraysPerDevice(0); int dropped_static_bound = 0; - FillLocalDeviceBuffer(coo_tensors_by_id, - /*row_pointers_size_per_sc=*/8, + FillLocalDeviceBuffer(results, + /*row_pointers_size_per_bucket=*/8, /*coo_buffer_size_per_sc=*/40, /*batch_size_per_sc=*/2, options, csr_array, dropped_static_bound); @@ -831,15 +871,26 @@ TEST(InputPreprocessingUtilTest, FillBuffer) { TEST(InputPreprocessingUtilTest, FillBufferMinibatchingSingleMinibatch) { std::vector coo_formats; - for (int row = 0; row < 8; ++row) { coo_formats.push_back(CooFormat(row, 0, 1.0)); coo_formats.push_back(CooFormat(row, 1, 1.0)); coo_formats.push_back(CooFormat(row, 2, 1.0)); coo_formats.push_back(CooFormat(row, 3, 1.0)); } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 2; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 4; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 4 + j * 4 + k]); + } + } + } + StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, 
/*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, @@ -848,30 +899,31 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingSingleMinibatch) { PreprocessSparseDenseMatmulInputOptions options = { .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, .enable_minibatching = true, .minibatching_bucketing_hash_fn = hash_fn}; MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/4, - /*num_sc_per_device=*/4); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - coo_tensors_by_id.MergeAll(); + std::vector results; + for (int i = 0; i < kNumScPerDevice; ++i) { + results.push_back(SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_split)); + } + + for (int i = 0; i < kNumScPerDevice; ++i) { + results[i].grouped_tensors.MergeAll(); + } CsrArraysPerHost csr_arrays_per_host(1, 8 * 4, 40 * 4); - internal::CsrArraysPerDevice csr_array = - csr_arrays_per_host.GetCsrArraysPerDevice(0); - FillLocalDeviceBuffer(coo_tensors_by_id, - /*row_pointers_size_per_sc=*/8, + CsrArraysPerDevice csr_array = csr_arrays_per_host.GetCsrArraysPerDevice(0); + int dropped_static_bound = 0; + FillLocalDeviceBuffer(results, + /*row_pointers_size_per_bucket=*/8, /*coo_buffer_size_per_sc=*/40, /*batch_size_per_sc=*/2, options, csr_array, - stats_per_device.dropped_id_count); + dropped_static_bound); std::array expected_row_pointers = { 2, 10, 18, 26, 32, 32, 32, 32, // MB0 @@ -950,7 +1002,7 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingSingleMinibatch) { _, _, _, _, _, _, _, _, // _, _, _, _, _, _, _, _, // _, _, _, _, _, _, _, _)); - EXPECT_EQ(stats_per_device.dropped_id_count, 0); + EXPECT_EQ(dropped_static_bound, 0); } TEST(InputPreprocessingUtilTest, FillBufferMinibatchingFourMinibatches) { @@ -961,8 +1013,18 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingFourMinibatches) { coo_formats.push_back(CooFormat(row, col, 1.0)); } } - ExtractedCooTensors extracted_coo_tensors(4, 8); - extracted_coo_tensors.coo_tensors = coo_formats; + const int kNumScPerDevice = 4; + const int kBatchSizePerSc = 2; + ExtractedSparseCoreTensors sc_tensors[kNumScPerDevice]; + for (int i = 0; i < kNumScPerDevice; ++i) { + sc_tensors[i].batch_size_for_device = kNumScPerDevice * kBatchSizePerSc; + for (int j = 0; j < kBatchSizePerSc; ++j) { + for (int k = 0; k < 64; ++k) { + sc_tensors[i].coo_tensors.push_back( + coo_formats[i * kBatchSizePerSc * 64 + j * 64 + k]); + } + } + } StackedTableMetadata stacked_table_metadata( "stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, /*max_unique_ids_per_partition=*/32, /*row_offset=*/0, /*col_offset=*/0, @@ -971,20 +1033,22 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingFourMinibatches) { PreprocessSparseDenseMatmulInputOptions options = { .local_device_count = 1, .global_device_count = 1, - .num_sc_per_device = 4, + .num_sc_per_device = kNumScPerDevice, .allow_id_dropping = false, .enable_minibatching = true, - .minibatching_bucketing_hash_fn = hash_fn}; + .minibatching_bucketing_hash_fn = hash_fn, + }; MinibatchingSplit minibatching_split = 0; StatsPerHost stats_per_host(/*local_device_count=*/1, 
/*num_partitions=*/4, /*num_sc_per_device=*/4); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors coo_tensors_by_id = - SortAndGroupCooTensorsPerLocalDevice( - extracted_coo_tensors, stacked_table_metadata, options, - stats_per_device, minibatching_split); - EXPECT_EQ(stats_per_device.dropped_id_count, 0); + StatsPerDevice stats_per_device = stats_per_host.GetStatsPerDevice(0); + std::vector results; + for (int i = 0; i < kNumScPerDevice; ++i) { + results.push_back(SortAndGroupCooTensorsForSingleSparseCore( + sc_tensors[i], 0, i, options, stacked_table_metadata, + minibatching_split)); + } + PartitionedCooTensors& coo_tensors_by_id = results[0].grouped_tensors; // 4 Minibatches of bucket sizes [16,8,24,16] // 62: 32 @@ -995,11 +1059,11 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingFourMinibatches) { minibatching_split.set(61); minibatching_split.set(60); minibatching_split.set(57); - coo_tensors_by_id.Merge(minibatching_split); - EXPECT_EQ(coo_tensors_by_id.GetNumMinibatches(), 4); - for (int i = 0; i < 4; i++) { - for (int j = 0; j < coo_tensors_by_id.GetNumMinibatches(); j++) { - auto coo_tensors = coo_tensors_by_id(i, j); + for (int i = 0; i < kNumScPerDevice; ++i) { + results[i].grouped_tensors.Merge(minibatching_split); + EXPECT_EQ(results[i].grouped_tensors.GetNumMinibatches(), 4); + for (int j = 0; j < results[i].grouped_tensors.GetNumMinibatches(); j++) { + auto coo_tensors = results[i].grouped_tensors(0, j); EXPECT_TRUE(absl::c_is_sorted( coo_tensors, [&](const CooFormat& coo1, const CooFormat& coo2) { // Sorted by global SC ID. @@ -1021,10 +1085,9 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingFourMinibatches) { const int num_devices = 1; CsrArraysPerHost csr_arrays_per_host = CsrArraysPerHost( num_devices, row_pointers_size, coo_buffer_size_per_sc * 4); - internal::CsrArraysPerDevice csr_array = - csr_arrays_per_host.GetCsrArraysPerDevice(0); + CsrArraysPerDevice csr_array = csr_arrays_per_host.GetCsrArraysPerDevice(0); - FillLocalDeviceBuffer(coo_tensors_by_id, + FillLocalDeviceBuffer(results, /*row_pointers_size_per_bucket=*/8, coo_buffer_size_per_sc, /*batch_size_per_sc=*/2, options, csr_array, @@ -1126,15 +1189,12 @@ TEST(InputPreprocessingUtilTest, FillBufferMinibatchingFourMinibatches) { TEST(InputPreprocessingUtilTest, FillBufferStaticBoundCountsOneDropNoMinibatching) { - std::vector coo_formats; - coo_formats.emplace_back(/*row=*/0, /*col=*/0, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/0, /*col=*/1, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/1, /*col=*/0, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/1, /*col=*/1, /*gain=*/1.0); - - ExtractedCooTensors extracted(/*num_sc_per_device=*/1, - /*batch_size_for_device=*/4); - extracted.coo_tensors = coo_formats; + ExtractedSparseCoreTensors extracted; + extracted.batch_size_for_device = 4; + extracted.coo_tensors.emplace_back(/*row=*/0, /*col=*/0, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/0, /*col=*/1, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/1, /*col=*/0, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/1, /*col=*/1, /*gain=*/1.0); StackedTableMetadata meta("stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, @@ -1147,16 +1207,13 @@ TEST(InputPreprocessingUtilTest, .global_device_count = 1, .num_sc_per_device = 1, .allow_id_dropping = false, - }; + }; bool minibatching_required = false; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/1, - 
/*num_sc_per_device=*/1); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors grouped = SortAndGroupCooTensorsPerLocalDevice( - extracted, meta, opts, stats_per_device, minibatching_required); - int dropped_sort = stats_per_device.dropped_id_count; + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + extracted, 0, 0, opts, meta, minibatching_required); + int dropped_sort = result.dropped_id_count; const int row_ptrs_size_per_bucket = 4; const int coo_buffer_size_per_sc = 3; @@ -1164,11 +1221,11 @@ TEST(InputPreprocessingUtilTest, CsrArraysPerHost csr_arrays_per_host(1, row_ptrs_size_per_bucket, coo_buffer_size_per_sc); - internal::CsrArraysPerDevice csr_arrays = - csr_arrays_per_host.GetCsrArraysPerDevice(0); + CsrArraysPerDevice csr_arrays = csr_arrays_per_host.GetCsrArraysPerDevice(0); + std::vector results = {result}; int dropped_static = 0; - FillLocalDeviceBuffer(grouped, row_ptrs_size_per_bucket, + FillLocalDeviceBuffer(results, row_ptrs_size_per_bucket, coo_buffer_size_per_sc, batch_size_per_sc, opts, csr_arrays, /*dropped_id_count_static_bound=*/dropped_static); @@ -1179,21 +1236,18 @@ TEST(InputPreprocessingUtilTest, TEST(InputPreprocessingUtilTest, FillBufferStaticBoundCountsDropsWithMinibatching) { - std::vector coo_formats; // 4 samples, 2 ids each. Total 8 ids. // col=0 will be in minibatch 0, col=1 in minibatch 1. - coo_formats.emplace_back(/*row=*/0, /*col=*/0, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/0, /*col=*/1, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/1, /*col=*/0, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/1, /*col=*/1, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/2, /*col=*/0, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/2, /*col=*/1, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/3, /*col=*/0, /*gain=*/1.0); - coo_formats.emplace_back(/*row=*/3, /*col=*/1, /*gain=*/1.0); - - ExtractedCooTensors extracted(/*num_sc_per_device=*/1, - /*batch_size_for_device=*/4); - extracted.coo_tensors = coo_formats; + ExtractedSparseCoreTensors extracted; + extracted.batch_size_for_device = 4; + extracted.coo_tensors.emplace_back(/*row=*/0, /*col=*/0, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/0, /*col=*/1, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/1, /*col=*/0, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/1, /*col=*/1, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/2, /*col=*/0, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/2, /*col=*/1, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/3, /*col=*/0, /*gain=*/1.0); + extracted.coo_tensors.emplace_back(/*row=*/3, /*col=*/1, /*gain=*/1.0); StackedTableMetadata meta("stacked_table", /*feature_index=*/0, /*max_ids_per_partition=*/32, @@ -1209,39 +1263,36 @@ TEST(InputPreprocessingUtilTest, .allow_id_dropping = false, .enable_minibatching = true, .minibatching_bucketing_hash_fn = hash_fn, - }; + }; MinibatchingSplit minibatching_split = 0; - StatsPerHost stats_per_host(/*local_device_count=*/1, /*num_partitions=*/1, - /*num_sc_per_device=*/1); - internal::StatsPerDevice stats_per_device = - stats_per_host.GetStatsPerDevice(0); - PartitionedCooTensors grouped = SortAndGroupCooTensorsPerLocalDevice( - extracted, meta, opts, stats_per_device, minibatching_split); + PerSparseCoreGroupedData result = + SortAndGroupCooTensorsForSingleSparseCore( + extracted, 0, 0, opts, meta, minibatching_split); // Create 2 minibatches by splitting 
based on bucket ID. minibatching_split.set(0); - grouped.Merge(minibatching_split); - ASSERT_EQ(grouped.GetNumMinibatches(), 2); + result.grouped_tensors.Merge(minibatching_split); + ASSERT_EQ(result.grouped_tensors.GetNumMinibatches(), 2); static constexpr int row_ptrs_size_per_bucket = 4; // Buffer size is 6, but there are 8 total IDs (4 in each minibatch). static constexpr int coo_buffer_size_per_sc = 6; static constexpr int batch_size_per_sc = 4; + std::vector results = {result}; CsrArraysPerHost csr_arrays_per_host( - 1, row_ptrs_size_per_bucket * grouped.GetNumMinibatches(), + 1, row_ptrs_size_per_bucket * result.grouped_tensors.GetNumMinibatches(), coo_buffer_size_per_sc); - internal::CsrArraysPerDevice csr_arrays = - csr_arrays_per_host.GetCsrArraysPerDevice(0); + CsrArraysPerDevice csr_arrays = csr_arrays_per_host.GetCsrArraysPerDevice(0); int dropped_static = 0; - FillLocalDeviceBuffer(grouped, row_ptrs_size_per_bucket, + FillLocalDeviceBuffer(results, row_ptrs_size_per_bucket, coo_buffer_size_per_sc, batch_size_per_sc, opts, csr_arrays, /*dropped_id_count_static_bound=*/dropped_static); - EXPECT_EQ(stats_per_device.dropped_id_count, 0); + EXPECT_EQ(result.dropped_id_count, 0); // Minibatch 0 has 4 IDs, Minibatch 1 has 4 IDs. Buffer size is 6. // Minibatch 0 is fully written (4 IDs). // Although there are 2 slots left, the entire Minibatch 1 (4 IDs) @@ -1270,4 +1321,5 @@ TEST(InputPreprocessingUtilTest, } } // namespace +} // namespace internal } // namespace jax_sc_embedding diff --git a/jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h b/jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h index f5d7f4fd..5afdfa4c 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h +++ b/jax_tpu_embedding/sparsecore/lib/core/partitioned_coo_tensors.h @@ -33,6 +33,8 @@ namespace jax_sc_embedding { class PartitionedCooTensors { public: PartitionedCooTensors() : PartitionedCooTensors(0, 0, 0, 1) {} + // TODO(b/444292437): Remove `num_sc_per_device` and make this class a + // per-SparseCore class. 
PartitionedCooTensors(int reserve_count, int num_sc_per_device, uint32_t global_sc_count, int bucket_count_per_sc = 1) : coo_tensors_(), @@ -192,7 +194,6 @@ class PartitionedCooTensors { bucket_offsets_[dest_pos + minibatches - 1] = bucket_offsets_[start_pos + bucket_count_per_sc_ - 1]; } - bucket_offsets_.resize(1 + num_sc_per_device_ * minibatches); DCHECK(absl::c_is_sorted(bucket_offsets_)); @@ -213,7 +214,6 @@ class PartitionedCooTensors { while (curr_sc_id_ < target_sc_id || curr_bucket_id_ < target_bucket_id) { DCHECK_LT(curr_sc_id_, num_sc_per_device_); - DCHECK_LT(curr_bucket_id_, bucket_count_per_sc_); bucket_offsets_.push_back(coo_tensors_.size()); if (curr_bucket_id_ == bucket_count_per_sc_ - 1) { ++curr_sc_id_; diff --git a/jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h b/jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h index 1295129d..cda62c16 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h +++ b/jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h @@ -95,7 +95,10 @@ void ProcessCooTensors( extracted_coo_tensors.batch_size_for_device / options.num_sc_per_device; CHECK_GT(batch_size_per_sc, 0); - extracted_coo_tensors.coo_tensors.reserve(values_stream.size()); + for (int sc_id = 0; sc_id < options.num_sc_per_device; ++sc_id) { + extracted_coo_tensors.per_sc_tensors[sc_id].coo_tensors.reserve( + batch_size_per_sc); + } DCHECK_EQ(values_stream.size(), weights_stream.size()); @@ -113,8 +116,7 @@ void ProcessCooTensors( ComputeWeightDivisor(options.combiner, weights_stream); const int num_cols = values_stream.cols(); - extracted_coo_tensors.coo_tensors_per_sc[sample_id / batch_size_per_sc] += - num_cols; + const int local_sc_id = sample_id / batch_size_per_sc; for (weights_stream.SeekCol(0); values_stream.col() < num_cols; values_stream.NextCol(), weights_stream.NextCol()) { @@ -123,9 +125,10 @@ void ProcessCooTensors( DCHECK_GE(embedding_id, 0); DCHECK_LT(sample_id, batch_size_per_sc * options.num_sc_per_device); - extracted_coo_tensors.coo_tensors.emplace_back( - sample_id, embedding_id, gain, options.col_shift, options.col_offset, - num_scs_mod); + extracted_coo_tensors.per_sc_tensors[local_sc_id] + .coo_tensors.emplace_back(sample_id, embedding_id, gain, + options.col_shift, options.col_offset, + num_scs_mod); } } } diff --git a/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch_test.cc b/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch_test.cc index 9b95b3d2..afeb7027 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch_test.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch_test.cc @@ -27,7 +27,16 @@ namespace jax_sc_embedding { namespace { -using ::testing::ElementsAre; +using ::testing::UnorderedElementsAre; + +std::vector FlattenCooTensors(const ExtractedCooTensors& extracted) { + std::vector result; + for (const auto& sc_tensors : extracted.per_sc_tensors) { + result.insert(result.end(), sc_tensors.coo_tensors.begin(), + sc_tensors.coo_tensors.end()); + } + return result; +} TEST(RaggedTensorInputBatchTest, SliceTestWithSumCombiner) { // Input: (0, 0), (0, 1), (0, 2), (1, 0), (2, 0), (3, 2) @@ -76,16 +85,16 @@ TEST(RaggedTensorInputBatchTest, SliceTestWithSumCombiner) { }, extracted_3); - EXPECT_THAT(extracted_1.coo_tensors, - ElementsAre(CooFormat(0, 0, 1.0), CooFormat(0, 1, 1.0), - CooFormat(0, 2, 1.0), CooFormat(1, 0, 1.0), - CooFormat(2, 0, 1.0), CooFormat(3, 2, 1.0))); - EXPECT_THAT(extracted_2.coo_tensors, - 
ElementsAre(CooFormat(0, 0, 1.0), CooFormat(1, 0, 1.0))); + EXPECT_THAT(FlattenCooTensors(extracted_1), + UnorderedElementsAre(CooFormat(0, 0, 1.0), CooFormat(0, 1, 1.0), + CooFormat(0, 2, 1.0), CooFormat(1, 0, 1.0), + CooFormat(2, 0, 1.0), CooFormat(3, 2, 1.0))); + EXPECT_THAT(FlattenCooTensors(extracted_2), + UnorderedElementsAre(CooFormat(0, 0, 1.0), CooFormat(1, 0, 1.0))); - EXPECT_THAT( - extracted_3.coo_tensors, - ElementsAre(CooFormat(16, 0 + 8, 1.0), CooFormat(17, 2 + 8, 1.0))); + EXPECT_THAT(FlattenCooTensors(extracted_3), + UnorderedElementsAre(CooFormat(16, 0 + 8, 1.0), + CooFormat(17, 2 + 8, 1.0))); } TEST(RaggedTensorInputBatchTest, SliceTestWithMeanCombiner) { @@ -106,10 +115,11 @@ TEST(RaggedTensorInputBatchTest, SliceTestWithMeanCombiner) { .combiner = RowCombiner::kMean, }, extracted); - EXPECT_THAT(extracted.coo_tensors, - ElementsAre(CooFormat(0, 0, 1.0 / 3), CooFormat(0, 1, 1.0 / 3), - CooFormat(0, 2, 1.0 / 3), CooFormat(1, 0, 1.0), - CooFormat(2, 0, 1.0), CooFormat(3, 2, 1.0))); + EXPECT_THAT( + FlattenCooTensors(extracted), + UnorderedElementsAre(CooFormat(0, 0, 1.0 / 3), CooFormat(0, 1, 1.0 / 3), + CooFormat(0, 2, 1.0 / 3), CooFormat(1, 0, 1.0), + CooFormat(2, 0, 1.0), CooFormat(3, 2, 1.0))); } TEST(RaggedTensorInputBatchTest, SliceTestWithSqrtnCombiner) { @@ -128,16 +138,15 @@ TEST(RaggedTensorInputBatchTest, SliceTestWithSqrtnCombiner) { .combiner = RowCombiner::kSqrtn, }, extracted); - EXPECT_THAT( - extracted.coo_tensors, - ElementsAre(CooFormat(0, 0, 1.0 / std::sqrt(3)), - CooFormat(0, 1, 1.0 / std::sqrt(3)), - CooFormat(0, 2, 1.0 / std::sqrt(3)), CooFormat(1, 0, 1.0), - CooFormat(2, 0, 1.0), CooFormat(3, 2, 1.0))); + EXPECT_THAT(FlattenCooTensors(extracted), + UnorderedElementsAre(CooFormat(0, 0, 1.0 / std::sqrt(3)), + CooFormat(0, 1, 1.0 / std::sqrt(3)), + CooFormat(0, 2, 1.0 / std::sqrt(3)), + CooFormat(1, 0, 1.0), CooFormat(2, 0, 1.0), + CooFormat(3, 2, 1.0))); } -TEST(RaggedTensorInputBatchTest, - FixedValencyRowOffsetsCooExtractionIsCorrect) { +TEST(RaggedTensorInputBatchTest, FixedValencyRowOffsetsCooExtractionIsCorrect) { std::vector embedding_ids = {0, 1, 0, 2, 0, 3, 0, 4}; int batch_size = 4; int valency = 2; @@ -158,11 +167,12 @@ TEST(RaggedTensorInputBatchTest, .combiner = RowCombiner::kSum, }, extracted); - EXPECT_THAT(extracted.coo_tensors, - ElementsAre(CooFormat(0, 0, 1.0), CooFormat(0, 1, 1.0), // Row 0 - CooFormat(1, 0, 1.0), CooFormat(1, 2, 1.0), // Row 1 - CooFormat(2, 0, 1.0), CooFormat(2, 3, 1.0), // Row 2 - CooFormat(3, 0, 1.0), CooFormat(3, 4, 1.0)) // Row 3 + EXPECT_THAT( + FlattenCooTensors(extracted), + UnorderedElementsAre(CooFormat(0, 0, 1.0), CooFormat(0, 1, 1.0), // Row 0 + CooFormat(1, 0, 1.0), CooFormat(1, 2, 1.0), // Row 1 + CooFormat(2, 0, 1.0), CooFormat(2, 3, 1.0), // Row 2 + CooFormat(3, 0, 1.0), CooFormat(3, 4, 1.0)) // Row 3 ); } diff --git a/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h b/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h index f8f89ea6..5192e58e 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h +++ b/jax_tpu_embedding/sparsecore/lib/core/sort_and_group_coo_tensors_impl.h @@ -16,6 +16,7 @@ #include #include +#include #include #include "absl/log/check.h" // from @com_google_absl @@ -165,8 +166,7 @@ inline void LogSparseCoreStats( << local_sc_id << ": Total number of ids processed: " << keys_size << ", total after deduplication: " << ids_per_sc_partition_per_bucket.sum() - << ", total after drop id: " - << 
grouped_coo_tensors.Size(local_sc_id); + << ", total after drop id: " << grouped_coo_tensors.Size(0); } } @@ -309,179 +309,130 @@ inline void GroupAndDeduplicateCooTensorsForLocalSparseCore( } // namespace internal -// Sorts and groups the provided COO tensors in this hierarchy: Local SC -> -// Minibatching Bucket -> Global SC. -// -// -// NOTE: We use output buffers `max_ids_per_sc`, `max_unique_ids_per_sc`, and -// `required_buffer_size_per_sc` because we fill values in a loop to a bigger -// array. -template -PartitionedCooTensors SortAndGroupCooTensorsPerLocalDeviceImpl( - const ExtractedCooTensors& extracted_coo_tensors, - const StackedTableMetadata& stacked_table_metadata, +namespace internal { + +template +PerSparseCoreGroupedData SortAndGroupCooTensorsForSingleSparseCoreImpl( + const ExtractedSparseCoreTensors& sc_tensors_for_device, + int local_device_id, int local_sc_id, const PreprocessSparseDenseMatmulInputOptions& options, - internal::StatsPerDevice& stats, SplitType& minibatching_split) { - tsl::profiler::TraceMe t("SortAndGroupCooTensors"); - absl::Span coo_tensors = extracted_coo_tensors.coo_tensors; - const int num_sc_per_device = options.num_sc_per_device; - bool allow_id_dropping = options.allow_id_dropping; - const int batch_size_per_sc = xla::CeilOfRatio( - extracted_coo_tensors.batch_size_for_device, options.num_sc_per_device); - const uint32_t global_sc_count = options.GetNumScs(); - const int num_sc_bits = absl::bit_width(global_sc_count - 1); - const int max_ids_per_partition = - stacked_table_metadata.max_ids_per_partition; - const int max_unique_ids_per_partition = - stacked_table_metadata.max_unique_ids_per_partition; - const absl::string_view stacked_table_name = stacked_table_metadata.name; - // This function can be called in two passes for minibatching. The logic for - // stats collection and ID dropping depends on the pass. - // - // Pass 1: Check if minibatching is required (`kCreateBuckets` is false). - // - No IDs are dropped. - // - Stats are collected on all observed IDs to compute splits. - // - // Pass 2: Create buckets (`kCreateBuckets` is true). - // - A dummy stats object is used (stats are not re-computed). - // - IDs may be dropped if they exceed capacity. - - // Partition COO tensors among SparseCores for the local device (based on row - // id). + const StackedTableMetadata& first_metadata, SplitType& minibatching_split) { + std::vector sc_tensors_vec = + std::move(sc_tensors_for_device.coo_tensors); + const int bucket_count = kCreateBuckets ? CooFormat::kMaxMinibatchingBuckets : 1; - PartitionedCooTensors grouped_coo_tensors( - coo_tensors.size(), num_sc_per_device, global_sc_count, bucket_count); - - uint32_t coo_tensor_index = 0; - - // Loop over scs for this device. - for (int32_t local_sc_id = 0; local_sc_id < options.num_sc_per_device; - ++local_sc_id) { - grouped_coo_tensors.ResetDedupState(); - // These counters track the number of IDs that are actually kept (not - // dropped) for each partition and bucket for this device. 
- MatrixXi kept_ids_per_sc_partition_per_bucket = - MatrixXi::Zero(global_sc_count, bucket_count); - MatrixXi kept_unique_ids_per_partition_per_bucket = - MatrixXi::Zero(global_sc_count, bucket_count); - MatrixXi ids_per_sc_partition_per_bucket = - MatrixXi::Zero(global_sc_count, bucket_count); - MatrixXi unique_ids_per_partition_per_bucket = - MatrixXi::Zero(global_sc_count, bucket_count); - std::vector keys; - const int expected_keys_size = - extracted_coo_tensors.coo_tensors_per_sc[local_sc_id]; - keys.reserve(expected_keys_size); - internal::ValidateKeyCapacity(local_sc_id, expected_keys_size); - // We take the advantage of the fact that the row_ids are already sorted - // within each batch. - for (; coo_tensor_index < coo_tensors.size() && - coo_tensors[coo_tensor_index].row_id < - (local_sc_id + 1) * batch_size_per_sc; - coo_tensor_index++) { - const CooFormat& coo_tensor = coo_tensors[coo_tensor_index]; - // The key here is [bucket_id(6 bits), global_sc_id(num_scs bits), - // local_embedding_id(32-num_scs bits), index(26 bits)]. - // Note that this assumes `num_scs` is a power of 2. - keys.push_back(coo_tensor.GetGroupingKey( - num_sc_bits, coo_tensor_index, kCreateBuckets, - options.minibatching_bucketing_hash_fn, kHasVariableWeights)); - DCHECK(kHasVariableWeights || coo_tensors[coo_tensor_index].gain == 1.0f) - << "kHasVariableWeights: " << kHasVariableWeights - << ", coo: " << coo_tensor; - } + const uint32_t global_sc_count = options.GetNumScs(); - // The expected allocation size may be uninitialized. - DCHECK(expected_keys_size == 0 || keys.size() == expected_keys_size); - hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending()); - - const internal::LocalSparseCoreTensorGroupingContext context = { - .keys = keys, - .coo_tensors = coo_tensors, - .stacked_table_metadata = stacked_table_metadata, - .options = options, - .local_sc_id = local_sc_id, - .num_sc_bits = num_sc_bits, - .grouped_coo_tensors = grouped_coo_tensors, - .ids_per_sc_partition_per_bucket = ids_per_sc_partition_per_bucket, - .unique_ids_per_partition_per_bucket = - unique_ids_per_partition_per_bucket, - .stats = stats, - .kept_ids_per_sc_partition_per_bucket = - kept_ids_per_sc_partition_per_bucket, - .kept_unique_ids_per_partition_per_bucket = - kept_unique_ids_per_partition_per_bucket, - }; - - internal::GroupAndDeduplicateCooTensorsForLocalSparseCore< - kHasVariableWeights, kCreateBuckets>(context); - - grouped_coo_tensors.FillRemainingScBuckets(); - - // Update global max using this device's values. 
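For readers following the refactor: the sort step relies on the packed grouping key described in the removed comment above (bucket id, then global SC id, then local embedding id, then the original index in the low bits), which is why a single ascending VQSort lands duplicates of the same embedding id adjacently within each partition. A rough sketch of that packing, with a hypothetical helper name and the bit widths taken from that comment (it assumes the global SC count is a power of two):

#include <cstdint>

// Hypothetical illustration of the 64-bit grouping key layout:
// [bucket_id (6 bits) | global_sc_id (num_sc_bits) |
//  local_embedding_id (32 - num_sc_bits) | original index (26 bits)].
uint64_t PackGroupingKey(uint64_t bucket_id, uint64_t global_sc_id,
                         uint64_t local_embedding_id, uint64_t index,
                         int num_sc_bits) {
  const int kIndexBits = 26;
  const int embedding_bits = 32 - num_sc_bits;
  return (bucket_id << (num_sc_bits + embedding_bits + kIndexBits)) |
         (global_sc_id << (embedding_bits + kIndexBits)) |
         (local_embedding_id << kIndexBits) | index;
}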
- internal::UpdateMaxIdsPerPartition(stats.max_ids_per_partition, - ids_per_sc_partition_per_bucket); - internal::UpdateMaxIdsPerPartition(stats.max_unique_ids_per_partition, - unique_ids_per_partition_per_bucket); - auto partition_sizes = - ids_per_sc_partition_per_bucket.rowwise().sum().array(); - stats.required_buffer_size[local_sc_id] += - partition_sizes - .unaryExpr([](int val) { - return xla::RoundUpTo(val, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE); - }) - .sum(); - - internal::LogSparseCoreStats( - local_sc_id, stacked_table_name, ids_per_sc_partition_per_bucket, - unique_ids_per_partition_per_bucket, keys.size(), grouped_coo_tensors); - - const int32_t observed_max_ids_per_bucket = - ids_per_sc_partition_per_bucket.maxCoeff(); - const int32_t observed_max_unique_ids_per_bucket = - unique_ids_per_partition_per_bucket.maxCoeff(); - - if (options.enable_minibatching) { - internal::UpdateMinibatchingSplit( - ids_per_sc_partition_per_bucket, unique_ids_per_partition_per_bucket, - global_sc_count, max_ids_per_partition, max_unique_ids_per_partition, - minibatching_split); - } + PerSparseCoreGroupedData result(sc_tensors_vec.size(), global_sc_count, + bucket_count, + sc_tensors_for_device.batch_size_for_device); + + const int num_sc_bits = absl::bit_width(global_sc_count - 1); + std::vector keys; + keys.reserve(sc_tensors_vec.size()); + for (int i = 0; i < sc_tensors_vec.size(); ++i) { + keys.push_back(sc_tensors_vec[i].GetGroupingKey( + num_sc_bits, i, kCreateBuckets, options.minibatching_bucketing_hash_fn, + kHasVariableWeights)); + } + + hwy::VQSort(keys.data(), keys.size(), hwy::SortAscending()); + + StatsPerHost stats_host(/*local_device_count=*/1, global_sc_count, + /*num_sc_per_device=*/1); + StatsPerDevice stats = stats_host.GetStatsPerDevice(0); + + MatrixXi kept_ids_per_sc_partition_per_bucket = + MatrixXi::Zero(global_sc_count, bucket_count); + MatrixXi kept_unique_ids_per_partition_per_bucket = + MatrixXi::Zero(global_sc_count, bucket_count); + + const internal::LocalSparseCoreTensorGroupingContext context = { + .keys = keys, + .coo_tensors = sc_tensors_vec, + .stacked_table_metadata = first_metadata, + .options = options, + .local_sc_id = 0, + .num_sc_bits = num_sc_bits, + .grouped_coo_tensors = result.grouped_tensors, + .ids_per_sc_partition_per_bucket = result.ids_per_sc_partition_per_bucket, + .unique_ids_per_partition_per_bucket = + result.unique_ids_per_partition_per_bucket, + .stats = stats, + .kept_ids_per_sc_partition_per_bucket = + kept_ids_per_sc_partition_per_bucket, + .kept_unique_ids_per_partition_per_bucket = + kept_unique_ids_per_partition_per_bucket, + }; + + internal::GroupAndDeduplicateCooTensorsForLocalSparseCore( + context); + + result.grouped_tensors.FillRemainingScBuckets(); + result.dropped_id_count = stats.dropped_id_count; + + internal::LogSparseCoreStats(local_sc_id, first_metadata.name, + result.ids_per_sc_partition_per_bucket, + result.unique_ids_per_partition_per_bucket, + keys.size(), result.grouped_tensors); + + auto partition_sizes = + result.ids_per_sc_partition_per_bucket.rowwise().sum().array(); + result.required_buffer_size = + partition_sizes + .unaryExpr([](int val) { + return xla::RoundUpTo(val, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE); + }) + .sum(); + + const int32_t observed_max_ids_per_bucket = + result.ids_per_sc_partition_per_bucket.maxCoeff(); + const int32_t observed_max_unique_ids_per_bucket = + result.unique_ids_per_partition_per_bucket.maxCoeff(); + + if (options.enable_minibatching) { + internal::UpdateMinibatchingSplit( + 
result.ids_per_sc_partition_per_bucket, + result.unique_ids_per_partition_per_bucket, global_sc_count, + first_metadata.max_ids_per_partition, + first_metadata.max_unique_ids_per_partition, minibatching_split); + } - // Only validate if creating minibatching buckets or when minibatching is - // disabled, not when checking if minibatching is required. - if (!options.enable_minibatching || kCreateBuckets) - internal::ValidateMaxIdsOrDie( - observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket, - max_ids_per_partition, max_unique_ids_per_partition, - stacked_table_name, allow_id_dropping); - } // end local_sc_id loop + if (!options.enable_minibatching || kCreateBuckets) + internal::ValidateMaxIdsOrDie( + observed_max_ids_per_bucket, observed_max_unique_ids_per_bucket, + first_metadata.max_ids_per_partition, + first_metadata.max_unique_ids_per_partition, first_metadata.name, + options.allow_id_dropping); - return grouped_coo_tensors; + return result; } -template -PartitionedCooTensors SortAndGroupCooTensorsPerLocalDevice( - const ExtractedCooTensors& extracted_coo_tensors, - const StackedTableMetadata& stacked_table_metadata, +template +PerSparseCoreGroupedData SortAndGroupCooTensorsForSingleSparseCore( + const ExtractedSparseCoreTensors& sc_tensors_for_device, + int local_device_id, int local_sc_id, const PreprocessSparseDenseMatmulInputOptions& options, - internal::StatsPerDevice& stats, SplitType& minibatching_split) { - const bool create_buckets = - options.enable_minibatching && - std::is_same_v; + const StackedTableMetadata& first_metadata, SplitType& minibatching_split) { + const bool create_buckets = options.enable_minibatching && + std::is_same_v; if (create_buckets) { - return SortAndGroupCooTensorsPerLocalDeviceImpl( - extracted_coo_tensors, stacked_table_metadata, options, stats, - minibatching_split); + return SortAndGroupCooTensorsForSingleSparseCoreImpl( + sc_tensors_for_device, local_device_id, local_sc_id, options, + first_metadata, minibatching_split); } else { - return SortAndGroupCooTensorsPerLocalDeviceImpl( - extracted_coo_tensors, stacked_table_metadata, options, stats, - minibatching_split); + return SortAndGroupCooTensorsForSingleSparseCoreImpl( + sc_tensors_for_device, local_device_id, local_sc_id, options, + first_metadata, minibatching_split); } } +} // namespace internal + } // namespace jax_sc_embedding -#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_ +#endif // JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_SORT_AND_GROUP_COO_TENSORS_IMPL_H_ \ No newline at end of file diff --git a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py index b7ac4e53..81f080e6 100644 --- a/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py +++ b/jax_tpu_embedding/sparsecore/lib/nn/tests/preprocess_input_benchmarks.py @@ -134,17 +134,16 @@ def generate_samples_for_feature_spec(feature_specs, num_samples, ragged=False): all_features.append(features) all_feature_weights.append(feature_weights) else: - features = [] - feature_weights = [] - for _ in range(num_samples): - num_ids = np.random.randint(1, 32) - ids = np.random.randint( - table_spec.vocabulary_size, - size=(num_ids,), - dtype=np.int32, - ) - features.append(ids) - feature_weights.append(np.ones((num_ids,), dtype=np.float32)) + counts = np.random.randint(1, 32, size=num_samples) + total_ids = np.sum(counts) + ids_flat = np.random.randint( + table_spec.vocabulary_size, + 
size=(total_ids,), + dtype=np.int32, + ) + split_indices = np.cumsum(counts)[:-1] + features = np.split(ids_flat, split_indices) + feature_weights = [np.ones((c,), dtype=np.float32) for c in counts] all_features.append(np.array(features, dtype=object)) all_feature_weights.append(np.array(feature_weights, dtype=object)) return all_features, all_feature_weights @@ -160,18 +159,19 @@ def generate_sparse_coo_inputs_for_feature_spec( for feature_spec in feature_specs: table_spec = feature_spec.table_spec - indices_tensors = [] - values_tensors = [] - for i in range(num_samples): - num_ids = np.random.randint(1, 32) - for j in range(num_ids): - indices_tensors.append([i, j]) - for _ in range(num_ids): - values_tensors.append(np.random.randint(table_spec.vocabulary_size)) - all_indices_tensors.append(np.array(indices_tensors, dtype=np.int64)) - all_values_tensors.append(np.array(values_tensors, dtype=np.int32)) + counts = np.random.randint(1, 32, size=num_samples) + total_ids = np.sum(counts) + values = np.random.randint( + table_spec.vocabulary_size, size=total_ids, dtype=np.int32 + ) + row_indices = np.repeat(np.arange(num_samples), counts) + col_indices = np.concatenate([np.arange(c) for c in counts]) + indices = np.stack([row_indices, col_indices], axis=1) + + all_indices_tensors.append(indices.astype(np.int64)) + all_values_tensors.append(values) all_dense_shape_tensors.append( - np.array([num_samples, vocab_size], dtype=np.int64) + np.array([num_samples, np.max(counts)], dtype=np.int64) ) return all_indices_tensors, all_values_tensors, all_dense_shape_tensors
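As a worked illustration of the vectorized sample generation above (plain NumPy semantics, not code from the patch): if counts were [3, 1, 2], then np.cumsum(counts)[:-1] is [3, 4], np.split on a flat array of 6 ids yields rows of lengths 3, 1, and 2, np.repeat(np.arange(3), counts) gives row indices [0, 0, 0, 1, 2, 2], and np.concatenate([np.arange(c) for c in counts]) gives column indices [0, 1, 2, 0, 0, 1]. The dense shape's second dimension correspondingly becomes max(counts) = 3 rather than the vocabulary size.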