Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,11 @@ class AbstractInputBatch {
const RowCombiner combiner ABSL_REQUIRE_EXPLICIT_INIT;
};

// Return the batch size or the number of samples in this input batch.
virtual ssize_t size() const = 0;
// Returns the number of samples (e.g., rows) in this input batch.
virtual int64_t size() const = 0;

// Returns the total number of embedding IDs across all samples.
virtual int64_t id_count() const = 0;

// Returns true if the input batch has variable (per-ID) weights. Defaults to
// true; subclasses whose weights are uniform override this to return false.
virtual bool HasVariableWeights() const { return true; }
Expand Down
5 changes: 5 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,14 @@ ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice(
const int local_device_id,
const PreprocessSparseDenseMatmulInputOptions& options) {
int batch_size_for_device = 0;
int64_t total_ids_for_device = 0;
for (const auto& feature_metadata : stacked_table_metadata) {
batch_size_for_device +=
input_batches[feature_metadata.feature_index]->size() /
options.local_device_count;
total_ids_for_device +=
input_batches[feature_metadata.feature_index]->id_count() /
options.local_device_count;
}
CheckDeviceBatchSize(batch_size_for_device, options.num_sc_per_device,
stacked_table_metadata[0].name);
Expand All @@ -328,6 +332,7 @@ ExtractedCooTensors ExtractCooTensorsForAllFeaturesPerLocalDevice(

ExtractedCooTensors extracted_coo_tensors(options.num_sc_per_device,
batch_size_for_device);
extracted_coo_tensors.coo_tensors.reserve(total_ids_for_device);

// This slices each feature into `feature_slices` partitions and then
// interleaves them: (k=num_sc_per_device-1). For stacking strategy
Expand Down
9 changes: 1 addition & 8 deletions jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#include "jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h"

#include <memory>
#include <vector>

#include "absl/log/check.h" // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
Expand Down Expand Up @@ -49,8 +48,6 @@ class NumpyDenseInputBatchStream {
template <typename U>
NumpyDenseInputBatchStream(U matrix, int row_start, int row_end) = delete;

int size() const { return size_; }

// Returns the number of columns per row of the dense input matrix.
int cols() const { return cols_; }

void NextRow() {
Expand Down Expand Up @@ -97,9 +94,6 @@ class NumpyRaggedInputBatchStream {
size_(row_end - row_start),
row_end_(row_end) {}

// estimate of total embedding ids (currently a lower bound).
int size() const { return size_; }

// Returns the number of IDs in the current row (the row array's length).
int cols() const { return row_ref_->shape(0); }

void NextRow() {
Expand Down Expand Up @@ -144,8 +138,7 @@ void NumpySparseInputBatch::ExtractCooTensors(

if (feature_.ndim() == 2) {
py::gil_scoped_acquire _;
// I'm not sure but without casting, passing feature_ as `const py::array&`
// and using feature_.unchecked_reference<T,2> seems to give garbage values.
// The casted temporary values must outlive the call to ProcessCooTensors.
auto feature_array = feature_.cast<py::array_t<int>>();
auto weights_array = weights_.cast<py::array_t<float>>();
py::gil_scoped_release __;
Expand Down
18 changes: 17 additions & 1 deletion jax_tpu_embedding/sparsecore/lib/core/numpy_input_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#ifndef JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_
#define JAX_TPU_EMBEDDING_SPARSECORE_LIB_CORE_NUMPY_INPUT_BATCH_H_

#include <cstdint>

#include "absl/log/check.h" // from @com_google_absl
#include "jax_tpu_embedding/sparsecore/lib/core/abstract_input_batch.h"
#include "jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h"
Expand All @@ -38,16 +40,30 @@ class NumpySparseInputBatch : public AbstractInputBatch {
<< "Dimension mismatch for features and weights";
CHECK(feature_.ndim() == 1 || feature_.ndim() == 2)
<< "Only 1D and 2D numpy arrays supported as inputs.";

if (feature_.ndim() == 1) {
// Iterating over every row to sum up the number of IDs negates the
// performance benefit of reserving memory for them, so we underestimate
// the number of IDs as 1 per sample.
id_count_ = feature_.shape(0);
} else {
id_count_ = feature_.shape(0) * feature_.shape(1);
}
}

py::ssize_t size() const override { return feature_.shape(0); }
// Returns the number of samples in this input batch: the leading (row)
// dimension of the feature array.
int64_t size() const override { return feature_.shape(0); }

// Returns the total number of embedding IDs across all samples. For 1-D
// (ragged) input this is a lower-bound estimate of one ID per sample; for
// 2-D input it is exact (rows * cols). See the constructor for why.
int64_t id_count() const override { return id_count_; }

void ExtractCooTensors(const ExtractCooTensorsOptions& options,
ExtractedCooTensors& coo_tensors) override;

private:
const py::array feature_;
const py::array weights_;
int64_t id_count_;
};

} // namespace jax_sc_embedding
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -95,10 +95,6 @@ void ProcessCooTensors(
extracted_coo_tensors.batch_size_for_device / options.num_sc_per_device;
CHECK_GT(batch_size_per_sc, 0);

extracted_coo_tensors.coo_tensors.reserve(values_stream.size());

DCHECK_EQ(values_stream.size(), weights_stream.size());

for (; values_stream.row() < options.slice_end &&
weights_stream.row() < options.slice_end;
values_stream.NextRow(), weights_stream.NextRow()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,12 @@ class ABSL_ATTRIBUTE_VIEW RaggedTensorInputBatch : public AbstractInputBatch {
table_name_(table_name),
max_vocab_id_(max_vocab_id) {}

// Returns the number of samples in this input batch. row_offsets_ holds one
// entry per sample plus a final end offset, hence the -1.
int64_t size() const override { return row_offsets_.size() - 1; }

// Returns the total number of embedding IDs across all samples.
// NOTE(review): assumes row_offsets_ starts at 0 (absolute offsets, not a
// slice); otherwise this over-counts — confirm against construction sites.
int64_t id_count() const override { return row_offsets_[size()]; }

// Weights are not variable for this input batch type.
bool HasVariableWeights() const override { return false; }

void ExtractCooTensors(const ExtractCooTensorsOptions& options,
Expand Down Expand Up @@ -129,7 +133,12 @@ class RaggedTensorInputBatchWithOwnedData : public AbstractInputBatch {
view_(absl::MakeConstSpan(embedding_ids_),
absl::MakeConstSpan(row_splits_), table_name, max_vocab_id) {}

// Returns the number of samples in this input batch (delegates to the
// non-owning view built over the owned data).
int64_t size() const override { return view_.size(); }

// Returns the total number of embedding IDs across all samples (delegates to
// the non-owning view built over the owned data).
int64_t id_count() const override { return view_.id_count(); }

// Whether weights vary per ID; forwarded to the non-owning view.
bool HasVariableWeights() const override { return view_.HasVariableWeights(); }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,9 +56,12 @@ class PySparseCooInputBatch : public AbstractInputBatch {
<< "Need GIL to create references to indices and values.";
}

// Returns the number of samples (rows) in this input batch.
int64_t size() const override { return batch_size_; }

// Returns the total number of embedding IDs across all samples — one per
// stored COO value entry.
int64_t id_count() const override { return values_.size(); }

// Weights are not variable for this input batch type.
bool HasVariableWeights() const override { return false; }

// Extracts COO tensors for each SparseCore.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ABSL_ATTRIBUTE_VIEW SparseCsrInputBatchStream {
T max_vocab_id = std::numeric_limits<T>::max())
: values_ref_(values),
row_pointers_(row_pointers),
row_start_(row_start),
curr_row_(row_start),
row_end_(row_end),
curr_idx_(row_pointers[row_start]),
Expand All @@ -66,8 +67,6 @@ class ABSL_ATTRIBUTE_VIEW SparseCsrInputBatchStream {
: row_pointers_[curr_row_ + 1] - row_pointers_[curr_row_];
}

int size() const { return row_pointers_[row_end_] - row_pointers_[0]; }

// Returns the number of values (embedding IDs) stored in the current row.
int cols() const { return curr_row_cols_; }

Expand Down Expand Up @@ -99,6 +98,7 @@ class ABSL_ATTRIBUTE_VIEW SparseCsrInputBatchStream {
private:
ValuesView values_ref_;
RowPointersView row_pointers_;
int row_start_;
int curr_row_;
int row_end_;
int curr_idx_;
Expand Down
Loading