From 7c46853caaab08a3d770d53400d250040f99db53 Mon Sep 17 00:00:00 2001
From: Aditya Gupta
Date: Mon, 1 Dec 2025 15:59:00 -0800
Subject: [PATCH] [JAX SC] Add id_count() to AbstractInputBatch for accurate pre-allocation.

* Also remove redundant stream slice `size()` function.

PiperOrigin-RevId: 838966852
---
 .../sparsecore/lib/core/abstract_input_batch.h |  7 +++++--
 .../sparsecore/lib/core/input_preprocessing.cc |  5 +++++
 .../sparsecore/lib/core/numpy_input_batch.cc   |  9 +--------
 .../sparsecore/lib/core/numpy_input_batch.h    | 18 +++++++++++++++++-
 .../lib/core/process_coo_tensors_impl.h        |  4 ----
 .../lib/core/ragged_tensor_input_batch.h       |  9 +++++++++
 .../lib/core/sparse_coo_input_batch.h          |  5 ++++-
 .../lib/core/sparse_csr_input_stream_impl.h    |  4 ++--
 8 files changed, 43 insertions(+), 18 deletions(-)

diff --git a/jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h b/jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h
index 9ad1470f..d72611fc 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h
@@ -46,8 +46,11 @@ class AbstractInputBatch {
     const RowCombiner combiner ABSL_REQUIRE_EXPLICIT_INIT;
   };
 
-  // Return the batch size or the number of samples in this input batch.
-  virtual ssize_t size() const = 0;
+  // Returns the number of samples (e.g., rows) in this input batch.
+  virtual int64_t size() const = 0;
+
+  // Returns the total number of embedding IDs across all samples.
+  virtual int64_t id_count() const = 0;
 
   // Returns true if the input batch has variable weights.
   virtual bool HasVariableWeights() const { return true; }
diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
index 9352f902..21a73e4c 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
@@ -299,10 +299,14 @@ ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice(
     const int local_device_id,
     const PreprocessSparseDenseMatmulInputOptions& options) {
   int batch_size_for_device = 0;
+  int64_t total_ids_for_device = 0;
   for (const auto& feature_metadata : stacked_table_metadata) {
     batch_size_for_device +=
         input_batches[feature_metadata.feature_index]->size() /
         options.local_device_count;
+    total_ids_for_device +=
+        input_batches[feature_metadata.feature_index]->id_count() /
+        options.local_device_count;
   }
   CheckDeviceBatchSize(batch_size_for_device, options.num_sc_per_device,
                        stacked_table_metadata[0].name);
@@ -328,6 +332,7 @@ ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice(
 
   ExtractedCooTensors extracted_coo_tensors(options.num_sc_per_device,
                                             batch_size_for_device);
+  extracted_coo_tensors.coo_tensors.reserve(total_ids_for_device);
 
   // This slices each feature into `feature_slices` partitions and then
   // interleaves them: (k=num_sc_per_device-1). For stacking strategy
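
Every AbstractInputBatch implementation now reports both its sample count and its total ID count, and the preprocessor above uses the totals to reserve() the COO vector once per device instead of growing it stream by stream. Below is a minimal sketch of a conforming batch, illustrative only and not part of this patch: the class name FixedValencyInputBatch and its members are hypothetical, and it assumes the existing ExtractCooTensorsOptions and ExtractedCooTensors types from abstract_input_batch.h.

#include <cstdint>
#include <utility>
#include <vector>

#include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h"

namespace jax_sc_embedding {

// Hypothetical batch where every sample holds exactly `valency` IDs, so both
// counts are known up front without scanning the data.
class FixedValencyInputBatch : public AbstractInputBatch {
 public:
  FixedValencyInputBatch(std::vector<int> ids, int64_t valency)
      : ids_(std::move(ids)), valency_(valency) {}

  // Number of samples (rows) in this batch.
  int64_t size() const override {
    return static_cast<int64_t>(ids_.size()) / valency_;
  }

  // Total number of embedding IDs; lets the preprocessor reserve() exactly.
  int64_t id_count() const override {
    return static_cast<int64_t>(ids_.size());
  }

  void ExtractCooTensors(const ExtractCooTensorsOptions& options,
                         ExtractedCooTensors& coo_tensors) override {
    // Omitted: append one COO entry per ID, as the concrete batches in this
    // patch do.
  }

 private:
  std::vector<int> ids_;
  int64_t valency_;
};

}  // namespace jax_sc_embedding

Keeping id_count() O(1) matters because it is called once per feature per local device before extraction starts.
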
diff --git a/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.cc b/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.cc
index 550fe4b7..78c8c555 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.cc
+++ b/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.cc
@@ -14,7 +14,6 @@
 #include "jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h"
 
 #include
-#include
 
 #include "absl/log/check.h"  // from @com_google_absl
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
@@ -49,8 +48,6 @@ class NumpyDenseInputBatchStream {
   template <typename U>
   NumpyDenseInputBatchStream(U matrix, int row_start, int row_end) = delete;
 
-  int size() const { return size_; }
-
   int cols() const { return cols_; }
 
   void NextRow() {
@@ -97,9 +94,6 @@ class NumpyRaggedInputBatchStream {
         size_(row_end - row_start),
         row_end_(row_end) {}
 
-  // estimate of total embedding ids (currently a lower bound).
-  int size() const { return size_; }
-
   int cols() const { return row_ref_->shape(0); }
 
   void NextRow() {
@@ -144,8 +138,7 @@ void NumpySparseInputBatch::ExtractCooTensors(
 
   if (feature_.ndim() == 2) {
     py::gil_scoped_acquire _;
-    // I'm not sure but without casting, passing feature_ as `const py::array&`
-    // and using feature_.unchecked_reference seems to give garbage values.
+    // The casted temporary values must outlive the ProcessCooTensors call.
     auto feature_array = feature_.cast>();
     auto weights_array = weights_.cast>();
     py::gil_scoped_release __;
diff --git a/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h b/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h
index 1aff4616..5200e47d 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h
@@ -14,6 +14,8 @@
 #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_
 #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_
 
+#include
+
 #include "absl/log/check.h"  // from @com_google_absl
 #include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
@@ -38,9 +40,22 @@ class NumpySparseInputBatch : public AbstractInputBatch {
         << "Dimension mismatch for features and weights";
     CHECK(feature_.ndim() == 1 || feature_.ndim() == 2)
         << "Only 1D and 2D numpy arrays supported as inputs.";
+
+    if (feature_.ndim() == 1) {
+      // Iterating over every row to sum up the number of IDs negates the
+      // performance benefit of reserving memory for them, so we underestimate
+      // the number of IDs as 1 per sample.
+      id_count_ = feature_.shape(0);
+    } else {
+      id_count_ = feature_.shape(0) * feature_.shape(1);
+    }
   }
 
-  py::ssize_t size() const override { return feature_.shape(0); }
+  // Returns the number of samples in this input batch.
+  int64_t size() const override { return feature_.shape(0); }
+
+  // Returns the total number of embedding IDs across all samples.
+  int64_t id_count() const override { return id_count_; }
 
   void ExtractCooTensors(const ExtractCooTensorsOptions& options,
                          ExtractedCooTensors& coo_tensors) override;
@@ -48,6 +63,7 @@
  private:
   const py::array feature_;
  const py::array weights_;
+  int64_t id_count_;
 };
 
 }  // namespace jax_sc_embedding
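
For the numpy batch above, the estimate is exact for 2-D (dense) features and intentionally a lower bound for 1-D (ragged) ones, trading a slightly short reserve() for not walking every row up front. The rule written out as a small sketch, illustrative only (EstimateIdCount is not part of this patch):

#include <cstdint>

// cols < 0 marks a 1-D ragged feature whose per-row lengths are unknown.
int64_t EstimateIdCount(int64_t rows, int64_t cols) {
  return cols < 0 ? rows          // ragged: lower bound of one ID per sample
                  : rows * cols;  // dense: exact
}

// EstimateIdCount(1024, 16) == 16384  (exact for a [1024 x 16] dense feature)
// EstimateIdCount(1024, -1) == 1024   (lower bound for 1024 ragged samples)
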
diff --git a/jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h b/jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h
index 1295129d..5127c16e 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h
@@ -95,10 +95,6 @@ void ProcessCooTensors(
       extracted_coo_tensors.batch_size_for_device / options.num_sc_per_device;
   CHECK_GT(batch_size_per_sc, 0);
 
-  extracted_coo_tensors.coo_tensors.reserve(values_stream.size());
-
-  DCHECK_EQ(values_stream.size(), weights_stream.size());
-
   for (; values_stream.row() < options.slice_end &&
          weights_stream.row() < options.slice_end;
        values_stream.NextRow(), weights_stream.NextRow()) {
diff --git a/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h b/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h
index b337a27b..27136fac 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h
@@ -88,8 +88,12 @@ class ABSL_ATTRIBUTE_VIEW RaggedTensorInputBatch : public AbstractInputBatch {
         table_name_(table_name),
         max_vocab_id_(max_vocab_id) {}
 
+  // Returns the number of samples in this input batch.
   int64_t size() const override { return row_offsets_.size() - 1; }
 
+  // Returns the total number of embedding IDs across all samples.
+  int64_t id_count() const override { return row_offsets_[size()]; }
+
   bool HasVariableWeights() const override { return false; }
 
   void ExtractCooTensors(const ExtractCooTensorsOptions& options,
@@ -129,7 +133,12 @@ class RaggedTensorInputBatchWithOwnedData : public AbstractInputBatch {
         view_(absl::MakeConstSpan(embedding_ids_),
               absl::MakeConstSpan(row_splits_), table_name, max_vocab_id) {}
 
+  // Returns the number of samples in this input batch.
   int64_t size() const override { return view_.size(); }
+
+  // Returns the total number of embedding IDs across all samples.
+  int64_t id_count() const override { return view_.id_count(); }
+
   bool HasVariableWeights() const override {
     return view_.HasVariableWeights();
   }
diff --git a/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h b/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h
index 1082ad6d..d5f5b6af 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h
@@ -56,9 +56,12 @@ class PySparseCooInputBatch : public AbstractInputBatch {
         << "Need GIL to create references to indices and values.";
   }
 
-  // Returns the number of rows in the current slice.
+  // Returns the number of samples in this input batch.
   int64_t size() const override { return batch_size_; }
 
+  // Returns the total number of embedding IDs across all samples.
+  int64_t id_count() const override { return values_.size(); }
+
   bool HasVariableWeights() const override { return false; }
 
   // Extracts COO tensors for each SparseCore.
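
Both batches above answer id_count() in O(1): the ragged batch reads the trailing row offset, which already holds the running total of IDs, and the COO batch reports the length of its values array. A worked example of the row-offset arithmetic, illustrative only (RaggedSize and RaggedIdCount are not part of this patch):

#include <cstdint>
#include <vector>

// row_offsets has one entry per sample plus a trailing total, e.g.
// {0, 2, 5, 7} describes three samples holding 2, 3 and 2 embedding IDs.
int64_t RaggedSize(const std::vector<int64_t>& row_offsets) {
  return static_cast<int64_t>(row_offsets.size()) - 1;
}

int64_t RaggedIdCount(const std::vector<int64_t>& row_offsets) {
  return row_offsets.back();  // same as row_offsets[RaggedSize(row_offsets)]
}

// RaggedSize({0, 2, 5, 7}) == 3 and RaggedIdCount({0, 2, 5, 7}) == 7.
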
diff --git a/jax_tpu_embedding/sparsecore/lib/core/sparse_csr_input_stream_impl.h b/jax_tpu_embedding/sparsecore/lib/core/sparse_csr_input_stream_impl.h
index 9b9fb6d7..b7274b5b 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/sparse_csr_input_stream_impl.h
+++ b/jax_tpu_embedding/sparsecore/lib/core/sparse_csr_input_stream_impl.h
@@ -55,6 +55,7 @@ class ABSL_ATTRIBUTE_VIEW SparseCsrInputBatchStream {
       T max_vocab_id = std::numeric_limits<T>::max())
       : values_ref_(values),
         row_pointers_(row_pointers),
+        row_start_(row_start),
         curr_row_(row_start),
         row_end_(row_end),
         curr_idx_(row_pointers[row_start]),
@@ -66,8 +67,6 @@
                    : row_pointers_[curr_row_ + 1] - row_pointers_[curr_row_];
   }
 
-  int size() const { return row_pointers_[row_end_] - row_pointers_[0]; }
-
   // Returns number of values in current row.
   int cols() const { return curr_row_cols_; }
 
@@ -99,6 +98,7 @@
  private:
   ValuesView values_ref_;
   RowPointersView row_pointers_;
+  int row_start_;
   int curr_row_;
   int row_end_;
   int curr_idx_;
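
The removed stream size() computed row_pointers_[row_end_] - row_pointers_[0], which equals the slice's own ID count only when the slice starts at row 0; with pre-allocation now driven by AbstractInputBatch::id_count() for the whole batch, the per-slice figure is no longer needed. Should it ever be, the usual CSR arithmetic is sketched below, illustrative only (CsrSliceIdCount is not part of this patch):

#include <cstdint>
#include <vector>

// row_pointers[i] is the cumulative number of IDs before row i, so the IDs in
// the half-open row range [row_start, row_end) number:
int64_t CsrSliceIdCount(const std::vector<int64_t>& row_pointers,
                        int row_start, int row_end) {
  return row_pointers[row_end] - row_pointers[row_start];
}

// With row_pointers {0, 4, 6, 9}, the slice of rows [1, 3) holds
// 9 - 4 == 5 IDs, whereas measuring from row_pointers[0] would report 9.
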