
Commit 7c46853

[JAX SC] Add id_count() to AbstractInputBatch for accurate pre-allocation.
* Also remove redundant stream slice `size()` function.

PiperOrigin-RevId: 838966852
Parent: 73ad6e9

8 files changed: +43, -18 lines
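The idea behind the change, as a minimal self-contained sketch (IdBatch, Coo, and ExtractCoo below are hypothetical stand-ins, not the library's types): a batch reports both its sample count and its total embedding-ID count, so the extraction path can reserve the COO buffer once, sized exactly, instead of relying on per-stream size estimates.

#include <cstdint>
#include <vector>

struct Coo { int row; int col; float gain; };

struct IdBatch {
  std::vector<std::vector<int>> rows;  // ragged: one vector of embedding IDs per sample
  int64_t size() const { return static_cast<int64_t>(rows.size()); }
  int64_t id_count() const {
    int64_t total = 0;
    for (const auto& r : rows) total += static_cast<int64_t>(r.size());
    return total;
  }
};

std::vector<Coo> ExtractCoo(const IdBatch& batch) {
  std::vector<Coo> coo;
  // One reserve(), sized by id_count(), instead of growing the vector while
  // appending one Coo entry per embedding ID.
  coo.reserve(batch.id_count());
  for (int64_t row = 0; row < batch.size(); ++row) {
    for (int id : batch.rows[row]) {
      coo.push_back({static_cast<int>(row), id, 1.0f});
    }
  }
  return coo;
}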

jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h

Lines changed: 5 additions & 2 deletions
@@ -46,8 +46,11 @@ class AbstractInputBatch {
     const RowCombiner combiner ABSL_REQUIRE_EXPLICIT_INIT;
   };
 
-  // Return the batch size or the number of samples in this input batch.
-  virtual ssize_t size() const = 0;
+  // Returns the number of samples (e.g., rows) in this input batch.
+  virtual int64_t size() const = 0;
+
+  // Returns the total number of embedding IDs across all samples.
+  virtual int64_t id_count() const = 0;
 
   // Returns true if the input batch has variable weights.
   virtual bool HasVariableWeights() const { return true; }
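As an illustration of the new contract, a hedged, self-contained mirror of the two accessors (InputBatchLike and FixedValencyBatch are hypothetical; the real AbstractInputBatch also declares ExtractCooTensors(), HasVariableWeights(), and other members):

#include <cstdint>

// Self-contained mirror of the relevant part of the interface.
class InputBatchLike {
 public:
  virtual ~InputBatchLike() = default;
  // Number of samples (rows) in the batch.
  virtual int64_t size() const = 0;
  // Total number of embedding IDs across all samples.
  virtual int64_t id_count() const = 0;
};

// Hypothetical batch where every sample carries the same number of IDs.
class FixedValencyBatch : public InputBatchLike {
 public:
  FixedValencyBatch(int64_t samples, int64_t ids_per_sample)
      : samples_(samples), ids_per_sample_(ids_per_sample) {}
  int64_t size() const override { return samples_; }
  int64_t id_count() const override { return samples_ * ids_per_sample_; }

 private:
  int64_t samples_;
  int64_t ids_per_sample_;
};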

jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc

Lines changed: 5 additions & 0 deletions
@@ -299,10 +299,14 @@ ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice(
     const int local_device_id,
     const PreprocessSparseDenseMatmulInputOptions& options) {
   int batch_size_for_device = 0;
+  int64_t total_ids_for_device = 0;
   for (const auto& feature_metadata : stacked_table_metadata) {
     batch_size_for_device +=
         input_batches[feature_metadata.feature_index]->size() /
         options.local_device_count;
+    total_ids_for_device +=
+        input_batches[feature_metadata.feature_index]->id_count() /
+        options.local_device_count;
   }
   CheckDeviceBatchSize(batch_size_for_device, options.num_sc_per_device,
                        stacked_table_metadata[0].name);
@@ -328,6 +332,7 @@ ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice(
 
   ExtractedCooTensors extracted_coo_tensors(options.num_sc_per_device,
                                             batch_size_for_device);
+  extracted_coo_tensors.coo_tensors.reserve(total_ids_for_device);
 
   // This slices each feature into `feature_slices` partitions and then
   // interleaves them: (k=num_sc_per_device-1). For stacking strategy
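The loop above accumulates per-device totals by dividing each feature's host-wide counts by local_device_count; the ID sum then feeds the single coo_tensors.reserve() call, replacing the per-stream reserve() removed from process_coo_tensors_impl.h below. A simplified sketch of that arithmetic, using hypothetical stand-in types (FeatureTotals, PerDeviceTotals) rather than the library's:

#include <cstdint>
#include <vector>

// Hypothetical per-feature totals; the real code reads size() and id_count()
// from AbstractInputBatch instances keyed by feature_index.
struct FeatureTotals {
  int64_t batch_size;  // samples across the whole host
  int64_t id_count;    // embedding IDs across the whole host
};

struct PerDeviceTotals {
  int64_t batch_size = 0;
  int64_t ids = 0;
};

PerDeviceTotals AccumulatePerDevice(const std::vector<FeatureTotals>& features,
                                    int64_t local_device_count) {
  PerDeviceTotals totals;
  for (const auto& f : features) {
    // Each local device handles an equal slice of every feature.
    totals.batch_size += f.batch_size / local_device_count;
    totals.ids += f.id_count / local_device_count;
  }
  return totals;
}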

jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.cc

Lines changed: 1 addition & 8 deletions
@@ -14,7 +14,6 @@
 #include "jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h"
 
 #include <memory>
-#include <vector>
 
 #include "absl/log/check.h" // from @com_google_absl
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
@@ -49,8 +48,6 @@ class NumpyDenseInputBatchStream {
   template <typename U>
   NumpyDenseInputBatchStream(U matrix, int row_start, int row_end) = delete;
 
-  int size() const { return size_; }
-
   int cols() const { return cols_; }
 
   void NextRow() {
@@ -97,9 +94,6 @@ class NumpyRaggedInputBatchStream {
         size_(row_end - row_start),
         row_end_(row_end) {}
 
-  // estimate of total embedding ids (currently a lower bound).
-  int size() const { return size_; }
-
   int cols() const { return row_ref_->shape(0); }
 
   void NextRow() {
@@ -144,8 +138,7 @@ void NumpySparseInputBatch::ExtractCooTensors(
 
   if (feature_.ndim() == 2) {
     py::gil_scoped_acquire _;
-    // I'm not sure but without casting, passing feature_ as `const py::array&`
-    // and using feature_.unchecked_reference<T,2> seems to give garbage values.
+    // The casted temporary values must outlive the ProcessCooTensors.
     auto feature_array = feature_.cast<py::array_t<int>>();
     auto weights_array = weights_.cast<py::array_t<float>>();
     py::gil_scoped_release __;

jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h

Lines changed: 17 additions & 1 deletion
@@ -14,6 +14,8 @@
 #ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_
 #define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_
 
+#include <cstdint>
+
 #include "absl/log/check.h" // from @com_google_absl
 #include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h"
 #include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
@@ -38,16 +40,30 @@ class NumpySparseInputBatch : public AbstractInputBatch {
         << "Dimension mismatch for features and weights";
     CHECK(feature_.ndim() == 1 || feature_.ndim() == 2)
         << "Only 1D and 2D numpy arrays supported as inputs.";
+
+    if (feature_.ndim() == 1) {
+      // Iterating over every row to sum up the number of IDs negates the
+      // performance benefit of reserving memory for them, so we underestimate
+      // the number of IDs as 1 per sample.
+      id_count_ = feature_.shape(0);
+    } else {
+      id_count_ = feature_.shape(0) * feature_.shape(1);
+    }
   }
 
-  py::ssize_t size() const override { return feature_.shape(0); }
+  // Returns the number of samples in this input batch.
+  int64_t size() const override { return feature_.shape(0); }
+
+  // Returns the total number of embedding IDs across all samples.
+  int64_t id_count() const override { return id_count_; }
 
   void ExtractCooTensors(const ExtractCooTensorsOptions& options,
                          ExtractedCooTensors& coo_tensors) override;
 
  private:
   const py::array feature_;
   const py::array weights_;
+  int64_t id_count_;
 };
 
 }  // namespace jax_sc_embedding
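The constructor above deliberately undercounts for 1-D (ragged) numpy inputs, one ID per sample, and is exact for 2-D dense inputs. A small standalone sketch of that estimate (EstimateIdCount is a hypothetical helper; the real code reads the shapes from its py::array member):

#include <cstdint>

// 1-D ragged input: deliberate lower bound of one ID per sample, so the later
// reserve() never over-allocates. 2-D dense input: exact count.
int64_t EstimateIdCount(int ndim, int64_t rows, int64_t cols) {
  if (ndim == 1) {
    return rows;       // lower bound: at least one ID per sample
  }
  return rows * cols;  // exact for a dense [rows, cols] matrix
}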

jax_tpu_embedding/sparsecore/lib/core/process_coo_tensors_impl.h

Lines changed: 0 additions & 4 deletions
@@ -95,10 +95,6 @@ void ProcessCooTensors(
       extracted_coo_tensors.batch_size_for_device / options.num_sc_per_device;
   CHECK_GT(batch_size_per_sc, 0);
 
-  extracted_coo_tensors.coo_tensors.reserve(values_stream.size());
-
-  DCHECK_EQ(values_stream.size(), weights_stream.size());
-
   for (; values_stream.row() < options.slice_end &&
          weights_stream.row() < options.slice_end;
        values_stream.NextRow(), weights_stream.NextRow()) {

jax_tpu_embedding/sparsecore/lib/core/ragged_tensor_input_batch.h

Lines changed: 9 additions & 0 deletions
@@ -88,8 +88,12 @@ class ABSL_ATTRIBUTE_VIEW RaggedTensorInputBatch : public AbstractInputBatch {
         table_name_(table_name),
         max_vocab_id_(max_vocab_id) {}
 
+  // Returns the number of samples in this input batch.
   int64_t size() const override { return row_offsets_.size() - 1; }
 
+  // Returns the total number of embedding IDs across all samples.
+  int64_t id_count() const override { return row_offsets_[size()]; }
+
   bool HasVariableWeights() const override { return false; }
 
   void ExtractCooTensors(const ExtractCooTensorsOptions& options,
@@ -129,7 +133,12 @@ class RaggedTensorInputBatchWithOwnedData : public AbstractInputBatch {
         view_(absl::MakeConstSpan(embedding_ids_),
               absl::MakeConstSpan(row_splits_), table_name, max_vocab_id) {}
 
+  // Returns the number of samples in this input batch.
   int64_t size() const override { return view_.size(); }
+
+  // Returns the total number of embedding IDs across all samples.
+  int64_t id_count() const override { return view_.id_count(); }
+
   bool HasVariableWeights() const override {
     return view_.HasVariableWeights();
   }
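With a CSR-style row-offsets layout, row_offsets_ holds size() + 1 entries and its last entry is one past the final ID, so row_offsets_[size()] is exactly the total ID count. A tiny worked example under that assumption:

#include <cstdint>
#include <vector>

int main() {
  // Three samples holding 2, 3, and 1 embedding IDs respectively.
  std::vector<int64_t> row_offsets = {0, 2, 5, 6};
  int64_t size = static_cast<int64_t>(row_offsets.size()) - 1;  // 3 samples
  int64_t id_count = row_offsets[size];                         // 6 IDs total
  return (size == 3 && id_count == 6) ? 0 : 1;
}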

jax_tpu_embedding/sparsecore/lib/core/sparse_coo_input_batch.h

Lines changed: 4 additions & 1 deletion
@@ -56,9 +56,12 @@ class PySparseCooInputBatch : public AbstractInputBatch {
         << "Need GIL to create references to indices and values.";
   }
 
-  // Returns the number of rows in the current slice.
+  // Returns the number of samples in this input batch.
   int64_t size() const override { return batch_size_; }
 
+  // Returns the total number of embedding IDs across all samples.
+  int64_t id_count() const override { return values_.size(); }
+
   bool HasVariableWeights() const override { return false; }
 
   // Extracts COO tensors for each SparseCore.

jax_tpu_embedding/sparsecore/lib/core/sparse_csr_input_stream_impl.h

Lines changed: 2 additions & 2 deletions
@@ -55,6 +55,7 @@ class ABSL_ATTRIBUTE_VIEW SparseCsrInputBatchStream {
       T max_vocab_id = std::numeric_limits<T>::max())
       : values_ref_(values),
         row_pointers_(row_pointers),
+        row_start_(row_start),
         curr_row_(row_start),
         row_end_(row_end),
         curr_idx_(row_pointers[row_start]),
@@ -66,8 +67,6 @@ class ABSL_ATTRIBUTE_VIEW SparseCsrInputBatchStream {
               : row_pointers_[curr_row_ + 1] - row_pointers_[curr_row_];
   }
 
-  int size() const { return row_pointers_[row_end_] - row_pointers_[0]; }
-
   // Returns number of values in current row.
   int cols() const { return curr_row_cols_; }
 
@@ -99,6 +98,7 @@ class ABSL_ATTRIBUTE_VIEW SparseCsrInputBatchStream {
  private:
   ValuesView values_ref_;
   RowPointersView row_pointers_;
+  int row_start_;
   int curr_row_;
   int row_end_;
   int curr_idx_;
