From 86288044bb9b8aa2c0ab6f735bd22348fe488207 Mon Sep 17 00:00:00 2001 From: The JAX SC Authors Date: Tue, 11 Nov 2025 20:08:17 -0800 Subject: [PATCH] Optimize SparseCore input preprocessing by eliminating buffer copies. * Introducing `OutputBufferViews`, allowing the caller to pass views of the destination buffers. * Modifying `CsrArraysPerHost` to optionally use `Eigen::Map` to wrap these buffers when provided. We can avoid populating large CSR arrays in the preprocessing return values and skip the data copy step. PiperOrigin-RevId: 831179831 --- jax_tpu_embedding/sparsecore/lib/core/BUILD | 1 + .../lib/core/input_preprocessing.cc | 54 ++++++++++----- .../lib/core/input_preprocessing_test.cc | 64 ++++++++++++++++++ .../lib/core/input_preprocessing_util.cc | 2 +- .../lib/core/input_preprocessing_util.h | 65 +++++++++++++++---- .../lib/core/input_preprocessing_util_test.cc | 33 ++++++++++ 6 files changed, 191 insertions(+), 28 deletions(-) diff --git a/jax_tpu_embedding/sparsecore/lib/core/BUILD b/jax_tpu_embedding/sparsecore/lib/core/BUILD index d10d9f7d..9eed1092 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/BUILD +++ b/jax_tpu_embedding/sparsecore/lib/core/BUILD @@ -57,6 +57,7 @@ cc_library( ":partitioned_coo_tensors", "@com_google_absl//absl/base:core_headers", "@com_google_absl//absl/base:nullability", + "@com_google_absl//absl/container:flat_hash_map", "@com_google_absl//absl/log", "@com_google_absl//absl/log:check", "@com_google_absl//absl/strings:string_view", diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc index c9ef8fc4..38bccdf1 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc @@ -113,6 +113,26 @@ void CheckDeviceBatchSize(int batch_size_for_device, int num_sc_per_device, batch_size_for_device, stacked_table_name, num_sc_per_device); } +CsrArraysPerHost CreateCsrArraysPerHost( + absl::string_view name, + const PreprocessSparseDenseMatmulInputOptions& options, + int coo_buffer_size_per_device, int row_pointers_size_per_bucket) { + const int row_pointers_dim = row_pointers_size_per_bucket * + (options.enable_minibatching + ? CooFormat::kMaxMinibatchingBuckets + : 1) * + options.num_sc_per_device; + if (options.output_buffers) { + auto it = options.output_buffers->find(name); + if (it != options.output_buffers->end()) { + return CsrArraysPerHost(options.local_device_count, row_pointers_dim, + coo_buffer_size_per_device, it->second); + } + } + return CsrArraysPerHost(options.local_device_count, row_pointers_dim, + coo_buffer_size_per_device); +} + // Holds the state for processing a single stacked table across all local // devices. This includes extracted COO tensors, partitioned COO tensors, // CSR arrays, and statistics. @@ -138,13 +158,9 @@ struct TableState { coo_buffer_size_per_device(ComputeCooBufferSizePerDevice( num_scs, options.num_sc_per_device, metadata, options.batch_number, options.enable_minibatching)), - csr_arrays_per_host(options.local_device_count, - row_pointers_size_per_bucket * - (options.enable_minibatching - ? CooFormat::kMaxMinibatchingBuckets - : 1) * - options.num_sc_per_device, - coo_buffer_size_per_device), + csr_arrays_per_host(CreateCsrArraysPerHost( + name, options, coo_buffer_size_per_device, + row_pointers_size_per_bucket)), stats_per_host(options.local_device_count, options.GetNumScs(), options.num_sc_per_device), batch_size_for_device(0) { @@ -428,15 +444,21 @@ void PopulateOutput(TableState& state, PreprocessSparseDenseMatmulOutput& out, absl::Mutex& output_mutex) { state.stats_per_host.Flatten(); - absl::MutexLock mutex(output_mutex); - out.lhs_row_pointers[state.stacked_table_name] = - std::move(state.csr_arrays_per_host.row_pointers); - out.lhs_embedding_ids[state.stacked_table_name] = - std::move(state.csr_arrays_per_host.embedding_ids); - out.lhs_sample_ids[state.stacked_table_name] = - std::move(state.csr_arrays_per_host.sample_ids); - out.lhs_gains[state.stacked_table_name] = - std::move(state.csr_arrays_per_host.gains); + absl::MutexLock lock(output_mutex); + // If `owns_data` is true, it indicates that the data is owned + // by `CsrArraysPerHost`, so we need to move it to the output. Otherwise, + // the data has already been written directly into the output buffers via + // `views`, and no move is necessary. + if (state.csr_arrays_per_host.owns_data) { + out.lhs_row_pointers[state.stacked_table_name] = + std::move(state.csr_arrays_per_host.row_pointers); + out.lhs_embedding_ids[state.stacked_table_name] = + std::move(state.csr_arrays_per_host.embedding_ids); + out.lhs_sample_ids[state.stacked_table_name] = + std::move(state.csr_arrays_per_host.sample_ids); + out.lhs_gains[state.stacked_table_name] = + std::move(state.csr_arrays_per_host.gains); + } out.stats.max_ids_per_partition[state.stacked_table_name] = std::move(state.stats_per_host.max_ids_per_partition); diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc index 4275305b..a1bb14f8 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc @@ -58,6 +58,7 @@ using ::testing::Each; using ::testing::ElementsAreArray; using ::testing::Eq; using ::testing::Gt; +using ::testing::NanSensitiveFloatEq; using ::testing::SizeIs; std::unique_ptr CreateInputBatchFromSamples( @@ -1388,5 +1389,68 @@ FUZZ_TEST(InputPreprocessingFuzzTest, StatsValidationTest) {FeatureStackingStrategy::kStackThenSplit, FeatureStackingStrategy::kSplitThenStack})); +TEST_F(TableStackingTest, + PreprocessingWithAndWithoutOutputBuffersIsEquivalent) { + PreprocessSparseDenseMatmulInputOptions options{ + .local_device_count = 1, + .global_device_count = 2, + .num_sc_per_device = 4, + .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit}; + absl::flat_hash_map> + stacked_tables({{"table_0", stacked_table_metadata_multi_}}); + + TF_ASSERT_OK_AND_ASSIGN( + PreprocessSparseDenseMatmulOutput output_matrix, + PreprocessSparseDenseMatmulInput(absl::MakeSpan(input_batches_multi_), + stacked_tables, options)); + + const int num_scs = options.GetNumScs(); + const int coo_buffer_size_per_device = + ComputeCooBufferSizePerDevice(num_scs, options.num_sc_per_device, + stacked_table_metadata_multi_, 0, false); + const int row_pointers_size = + std::max(num_scs, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE) * + options.num_sc_per_device; + std::vector row_pointers_data(row_pointers_size, INT_MAX); + std::vector embedding_ids_data(coo_buffer_size_per_device, INT_MAX); + std::vector sample_ids_data(coo_buffer_size_per_device, INT_MAX); + std::vector gains_data(coo_buffer_size_per_device, std::nanf("")); + absl::flat_hash_map output_buffers; + output_buffers["table_0"] = internal::CsrArraysPerDevice{ + .row_pointers = absl::MakeSpan(row_pointers_data), + .embedding_ids = absl::MakeSpan(embedding_ids_data), + .sample_ids = absl::MakeSpan(sample_ids_data), + .gains = absl::MakeSpan(gains_data), + }; + options.output_buffers = &output_buffers; + + TF_ASSERT_OK_AND_ASSIGN( + PreprocessSparseDenseMatmulOutput output_zero_copy, + PreprocessSparseDenseMatmulInput(absl::MakeSpan(input_batches_multi_), + stacked_tables, options)); + options.output_buffers = nullptr; // for next test + + ASSERT_EQ(output_matrix.lhs_row_pointers["table_0"].rows(), 1); + ASSERT_EQ(output_matrix.lhs_embedding_ids["table_0"].rows(), 1); + ASSERT_EQ(output_matrix.lhs_sample_ids["table_0"].rows(), 1); + ASSERT_EQ(output_matrix.lhs_gains["table_0"].rows(), 1); + + EXPECT_THAT(row_pointers_data, + ElementsAreArray(absl::MakeConstSpan( + output_matrix.lhs_row_pointers["table_0"].data(), + row_pointers_size))); + for (int i = 0; i < coo_buffer_size_per_device; ++i) { + if (embedding_ids_data[i] != INT_MAX) { + EXPECT_EQ(embedding_ids_data[i], + output_matrix.lhs_embedding_ids["table_0"].data()[i]); + EXPECT_EQ(sample_ids_data[i], + output_matrix.lhs_sample_ids["table_0"].data()[i]); + EXPECT_THAT( + gains_data[i], + NanSensitiveFloatEq(output_matrix.lhs_gains["table_0"].data()[i])); + } + } +} + } // namespace } // namespace jax_sc_embedding diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc index 56ec238f..72b1a06e 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.cc @@ -91,7 +91,7 @@ bool ValidIndices(int row_index, int coo_offset, int processed, // Pad the row pointers buffer to the end of the buffer. void PadRowPointersBuffer(int& lhs_row_offset, int padding, int row_end, - Eigen::Ref row_pointers) { + absl::Span row_pointers) { while (lhs_row_offset < row_end) { row_pointers[lhs_row_offset++] = padding; } diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h index 4e813054..4719b457 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h @@ -23,6 +23,8 @@ #include "absl/base/attributes.h" // from @com_google_absl #include "absl/base/nullability.h" // from @com_google_absl +#include "absl/container/flat_hash_map.h" // from @com_google_absl +#include "absl/log/check.h" // from @com_google_absl #include "absl/strings/string_view.h" // from @com_google_absl #include "absl/types/span.h" // from @com_google_absl #include "Eigen/Core" // from @eigen_archive @@ -60,10 +62,10 @@ using BlockRow = Eigen::Block, 1, Eigen::Dynamic, Eigen::RowMajor>; namespace internal { struct CsrArraysPerDevice { - BlockRow row_pointers; - BlockRow embedding_ids; - BlockRow sample_ids; - BlockRow gains; + absl::Span row_pointers; + absl::Span embedding_ids; + absl::Span sample_ids; + absl::Span gains; }; struct StatsPerDevice { @@ -81,21 +83,58 @@ struct CsrArraysPerHost { MatrixXi sample_ids; MatrixXf gains; + internal::CsrArraysPerDevice views; + bool owns_data = false; + + const int local_device_count_; + CsrArraysPerHost(int local_device_count, int row_pointers_size_per_device, int coo_buffer_size_per_device) : row_pointers(local_device_count, row_pointers_size_per_device), embedding_ids(local_device_count, coo_buffer_size_per_device), sample_ids(local_device_count, coo_buffer_size_per_device), - gains(local_device_count, coo_buffer_size_per_device) {} + gains(local_device_count, coo_buffer_size_per_device), + owns_data(true), + local_device_count_(local_device_count) { + views = {absl::MakeSpan(row_pointers.data(), row_pointers.size()), + absl::MakeSpan(embedding_ids.data(), embedding_ids.size()), + absl::MakeSpan(sample_ids.data(), sample_ids.size()), + absl::MakeSpan(gains.data(), gains.size())}; + } + + CsrArraysPerHost(int local_device_count, int row_pointers_size_per_device, + int coo_buffer_size_per_device, + internal::CsrArraysPerDevice output_buffers) + : views(output_buffers), + owns_data(false), + local_device_count_(local_device_count) { + CHECK_EQ(output_buffers.row_pointers.size(), + local_device_count * row_pointers_size_per_device); + CHECK_EQ(output_buffers.embedding_ids.size(), + local_device_count * coo_buffer_size_per_device); + CHECK_EQ(output_buffers.sample_ids.size(), + local_device_count * coo_buffer_size_per_device); + CHECK_EQ(output_buffers.gains.size(), + local_device_count * coo_buffer_size_per_device); + } internal::CsrArraysPerDevice GetCsrArraysPerDevice(int local_device_id) ABSL_ATTRIBUTE_LIFETIME_BOUND { - return internal::CsrArraysPerDevice{ - .row_pointers = row_pointers.row(local_device_id), - .embedding_ids = embedding_ids.row(local_device_id), - .sample_ids = sample_ids.row(local_device_id), - .gains = gains.row(local_device_id), - }; + int row_pointers_size_per_device = + views.row_pointers.size() / local_device_count_; + int coo_buffer_size_per_device = + views.embedding_ids.size() / local_device_count_; + return {views.row_pointers.subspan( + local_device_id * row_pointers_size_per_device, + row_pointers_size_per_device), + views.embedding_ids.subspan( + local_device_id * coo_buffer_size_per_device, + coo_buffer_size_per_device), + views.sample_ids.subspan( + local_device_id * coo_buffer_size_per_device, + coo_buffer_size_per_device), + views.gains.subspan(local_device_id * coo_buffer_size_per_device, + coo_buffer_size_per_device)}; } }; @@ -195,6 +234,10 @@ struct PreprocessSparseDenseMatmulInputOptions { // mini-batching to synchronize state across different hosts. AllReduceInterface* absl_nullable all_reduce_interface; + // If provided, CSR data will be written directly to these buffers. + const absl::flat_hash_map* + output_buffers = nullptr; + // Hash function used for creating minibatching buckets. CooFormat::HashFn minibatching_bucketing_hash_fn = HighwayHash; diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc index e0f719f7..5f1f6f07 100644 --- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc +++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util_test.cc @@ -1269,5 +1269,38 @@ TEST(InputPreprocessingUtilTest, ElementsAreArray(expected_sample_ids)); } +TEST(CsrArraysPerHostTest, + CreatingCsrArraysPerHostFromExternalArraysTriggersZeroCopies) { + const int kLocalDeviceCount = 1; + const int kRpSize = 10; + const int kCooSize = 20; + std::vector rp_data(kRpSize, 0); + std::vector eid_data(kCooSize, 0); + std::vector sid_data(kCooSize, 0); + std::vector gains_data(kCooSize, 0.0); + + internal::CsrArraysPerDevice buffers{ + .row_pointers = absl::MakeSpan(rp_data), + .embedding_ids = absl::MakeSpan(eid_data), + .sample_ids = absl::MakeSpan(sid_data), + .gains = absl::MakeSpan(gains_data), + }; + + CsrArraysPerHost csr_arrays_per_host(kLocalDeviceCount, kRpSize, kCooSize, + buffers); + EXPECT_FALSE(csr_arrays_per_host.owns_data); + + internal::CsrArraysPerDevice device_array = + csr_arrays_per_host.GetCsrArraysPerDevice(0); + device_array.row_pointers[0] = 1; + device_array.embedding_ids[0] = 2; + device_array.sample_ids[0] = 3; + device_array.gains[0] = 4.0; + + EXPECT_EQ(rp_data[0], 1); + EXPECT_EQ(eid_data[0], 2); + EXPECT_EQ(sid_data[0], 3); + EXPECT_EQ(gains_data[0], 4.0); +} } // namespace } // namespace jax_sc_embedding