Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions jax_tpu_embedding/sparsecore/lib/core/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ cc_library(
":partitioned_coo_tensors",
"@com_google_absl//absl/base:core_headers",
"@com_google_absl//absl/base:nullability",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log",
"@com_google_absl//absl/log:check",
"@com_google_absl//absl/strings:string_view",
Expand Down
54 changes: 38 additions & 16 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,26 @@ void CheckDeviceBatchSize(int batch_size_for_device, int num_sc_per_device,
batch_size_for_device, stacked_table_name, num_sc_per_device);
}

// Builds the per-host CSR arrays for the stacked table `name`. If the caller
// supplied external output buffers for this table in `options.output_buffers`,
// the returned CsrArraysPerHost wraps them (zero-copy); otherwise it allocates
// its own backing storage.
CsrArraysPerHost CreateCsrArraysPerHost(
    absl::string_view name,
    const PreprocessSparseDenseMatmulInputOptions& options,
    int coo_buffer_size_per_device, int row_pointers_size_per_bucket) {
  // With minibatching enabled, every SparseCore keeps one row-pointer region
  // per minibatching bucket.
  const int num_buckets =
      options.enable_minibatching ? CooFormat::kMaxMinibatchingBuckets : 1;
  const int row_pointers_dim =
      row_pointers_size_per_bucket * num_buckets * options.num_sc_per_device;
  const internal::CsrArraysPerDevice* external_buffers = nullptr;
  if (options.output_buffers != nullptr) {
    if (auto it = options.output_buffers->find(name);
        it != options.output_buffers->end()) {
      external_buffers = &it->second;
    }
  }
  if (external_buffers != nullptr) {
    return CsrArraysPerHost(options.local_device_count, row_pointers_dim,
                            coo_buffer_size_per_device, *external_buffers);
  }
  return CsrArraysPerHost(options.local_device_count, row_pointers_dim,
                          coo_buffer_size_per_device);
}

// Holds the state for processing a single stacked table across all local
// devices. This includes extracted COO tensors, partitioned COO tensors,
// CSR arrays, and statistics.
Expand All @@ -138,13 +158,9 @@ struct TableState {
coo_buffer_size_per_device(ComputeCooBufferSizePerDevice(
num_scs, options.num_sc_per_device, metadata, options.batch_number,
options.enable_minibatching)),
csr_arrays_per_host(options.local_device_count,
row_pointers_size_per_bucket *
(options.enable_minibatching
? CooFormat::kMaxMinibatchingBuckets
: 1) *
options.num_sc_per_device,
coo_buffer_size_per_device),
csr_arrays_per_host(CreateCsrArraysPerHost(
name, options, coo_buffer_size_per_device,
row_pointers_size_per_bucket)),
stats_per_host(options.local_device_count, options.GetNumScs(),
options.num_sc_per_device),
batch_size_for_device(0) {
Expand Down Expand Up @@ -428,15 +444,21 @@ void PopulateOutput(TableState& state, PreprocessSparseDenseMatmulOutput& out,
absl::Mutex& output_mutex) {
state.stats_per_host.Flatten();

absl::MutexLock mutex(output_mutex);
out.lhs_row_pointers[state.stacked_table_name] =
std::move(state.csr_arrays_per_host.row_pointers);
out.lhs_embedding_ids[state.stacked_table_name] =
std::move(state.csr_arrays_per_host.embedding_ids);
out.lhs_sample_ids[state.stacked_table_name] =
std::move(state.csr_arrays_per_host.sample_ids);
out.lhs_gains[state.stacked_table_name] =
std::move(state.csr_arrays_per_host.gains);
absl::MutexLock lock(output_mutex);
// If `owns_data` is true, it indicates that the data is owned
// by `CsrArraysPerHost`, so we need to move it to the output. Otherwise,
// the data has already been written directly into the output buffers via
// `views`, and no move is necessary.
if (state.csr_arrays_per_host.owns_data) {
out.lhs_row_pointers[state.stacked_table_name] =
std::move(state.csr_arrays_per_host.row_pointers);
out.lhs_embedding_ids[state.stacked_table_name] =
std::move(state.csr_arrays_per_host.embedding_ids);
out.lhs_sample_ids[state.stacked_table_name] =
std::move(state.csr_arrays_per_host.sample_ids);
out.lhs_gains[state.stacked_table_name] =
std::move(state.csr_arrays_per_host.gains);
}

out.stats.max_ids_per_partition[state.stacked_table_name] =
std::move(state.stats_per_host.max_ids_per_partition);
Expand Down
64 changes: 64 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ using ::testing::Each;
using ::testing::ElementsAreArray;
using ::testing::Eq;
using ::testing::Gt;
using ::testing::NanSensitiveFloatEq;
using ::testing::SizeIs;

std::unique_ptr<AbstractInputBatch> CreateInputBatchFromSamples(
Expand Down Expand Up @@ -1388,5 +1389,68 @@ FUZZ_TEST(InputPreprocessingFuzzTest, StatsValidationTest)
{FeatureStackingStrategy::kStackThenSplit,
FeatureStackingStrategy::kSplitThenStack}));

// Verifies that preprocessing into caller-provided (zero-copy) output buffers
// yields the same CSR data as the default library-allocated path.
TEST_F(TableStackingTest,
       PreprocessingWithAndWithoutOutputBuffersIsEquivalent) {
  PreprocessSparseDenseMatmulInputOptions options{
      .local_device_count = 1,
      .global_device_count = 2,
      .num_sc_per_device = 4,
      .feature_stacking_strategy = FeatureStackingStrategy::kStackThenSplit};
  absl::flat_hash_map<std::string, std::vector<StackedTableMetadata>>
      stacked_tables({{"table_0", stacked_table_metadata_multi_}});

  // Baseline run: the library allocates its own output matrices.
  TF_ASSERT_OK_AND_ASSIGN(
      PreprocessSparseDenseMatmulOutput output_matrix,
      PreprocessSparseDenseMatmulInput(absl::MakeSpan(input_batches_multi_),
                                       stacked_tables, options));

  // Size the external buffers exactly as the library would size its own.
  const int num_scs = options.GetNumScs();
  const int coo_buffer_size_per_device =
      ComputeCooBufferSizePerDevice(num_scs, options.num_sc_per_device,
                                    stacked_table_metadata_multi_, 0, false);
  const int row_pointers_size =
      std::max(num_scs, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE) *
      options.num_sc_per_device;
  // Sentinel-fill so that slots the preprocessor never wrote (padding) are
  // recognizable in the comparison loop below.
  std::vector<int> row_pointers_data(row_pointers_size, INT_MAX);
  std::vector<int> embedding_ids_data(coo_buffer_size_per_device, INT_MAX);
  std::vector<int> sample_ids_data(coo_buffer_size_per_device, INT_MAX);
  std::vector<float> gains_data(coo_buffer_size_per_device, std::nanf(""));
  absl::flat_hash_map<std::string, internal::CsrArraysPerDevice> output_buffers;
  output_buffers["table_0"] = internal::CsrArraysPerDevice{
      .row_pointers = absl::MakeSpan(row_pointers_data),
      .embedding_ids = absl::MakeSpan(embedding_ids_data),
      .sample_ids = absl::MakeSpan(sample_ids_data),
      .gains = absl::MakeSpan(gains_data),
  };
  options.output_buffers = &output_buffers;

  // Zero-copy run: results land directly in the vectors above.
  TF_ASSERT_OK_AND_ASSIGN(
      PreprocessSparseDenseMatmulOutput output_zero_copy,
      PreprocessSparseDenseMatmulInput(absl::MakeSpan(input_batches_multi_),
                                       stacked_tables, options));
  options.output_buffers = nullptr;  // for next test

  // Single local device => one row per output matrix in the baseline.
  ASSERT_EQ(output_matrix.lhs_row_pointers["table_0"].rows(), 1);
  ASSERT_EQ(output_matrix.lhs_embedding_ids["table_0"].rows(), 1);
  ASSERT_EQ(output_matrix.lhs_sample_ids["table_0"].rows(), 1);
  ASSERT_EQ(output_matrix.lhs_gains["table_0"].rows(), 1);

  EXPECT_THAT(row_pointers_data,
              ElementsAreArray(absl::MakeConstSpan(
                  output_matrix.lhs_row_pointers["table_0"].data(),
                  row_pointers_size)));
  const auto* baseline_embedding_ids =
      output_matrix.lhs_embedding_ids["table_0"].data();
  const auto* baseline_sample_ids =
      output_matrix.lhs_sample_ids["table_0"].data();
  const auto* baseline_gains = output_matrix.lhs_gains["table_0"].data();
  for (int idx = 0; idx < coo_buffer_size_per_device; ++idx) {
    // Skip slots still holding the sentinel: they were never written.
    if (embedding_ids_data[idx] == INT_MAX) continue;
    EXPECT_EQ(embedding_ids_data[idx], baseline_embedding_ids[idx]);
    EXPECT_EQ(sample_ids_data[idx], baseline_sample_ids[idx]);
    EXPECT_THAT(gains_data[idx], NanSensitiveFloatEq(baseline_gains[idx]));
  }
}

} // namespace
} // namespace jax_sc_embedding
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ bool ValidIndices(int row_index, int coo_offset, int processed,

// Pad the row pointers buffer to the end of the buffer.
void PadRowPointersBuffer(int& lhs_row_offset, int padding, int row_end,
Eigen::Ref<RowVectorXi> row_pointers) {
absl::Span<int> row_pointers) {
while (lhs_row_offset < row_end) {
row_pointers[lhs_row_offset++] = padding;
}
Expand Down
65 changes: 54 additions & 11 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

#include "absl/base/attributes.h" // from @com_google_absl
#include "absl/base/nullability.h" // from @com_google_absl
#include "absl/container/flat_hash_map.h" // from @com_google_absl
#include "absl/log/check.h" // from @com_google_absl
#include "absl/strings/string_view.h" // from @com_google_absl
#include "absl/types/span.h" // from @com_google_absl
#include "Eigen/Core" // from @eigen_archive
Expand Down Expand Up @@ -60,10 +62,10 @@ using BlockRow = Eigen::Block<MatrixX<T>, 1, Eigen::Dynamic, Eigen::RowMajor>;
namespace internal {

// Non-owning per-device views over the CSR output arrays. Each span covers a
// single local device's slice of the host-wide buffers. (The pasted diff
// retained the pre-change `BlockRow<T>` members alongside the new spans; only
// the post-change span members are kept here.)
struct CsrArraysPerDevice {
  absl::Span<int> row_pointers;
  absl::Span<int> embedding_ids;
  absl::Span<int> sample_ids;
  absl::Span<float> gains;
};

struct StatsPerDevice {
Expand All @@ -81,21 +83,58 @@ struct CsrArraysPerHost {
MatrixXi sample_ids;
MatrixXf gains;

internal::CsrArraysPerDevice views;
bool owns_data = false;

const int local_device_count_;

CsrArraysPerHost(int local_device_count, int row_pointers_size_per_device,
int coo_buffer_size_per_device)
: row_pointers(local_device_count, row_pointers_size_per_device),
embedding_ids(local_device_count, coo_buffer_size_per_device),
sample_ids(local_device_count, coo_buffer_size_per_device),
gains(local_device_count, coo_buffer_size_per_device) {}
gains(local_device_count, coo_buffer_size_per_device),
owns_data(true),
local_device_count_(local_device_count) {
views = {absl::MakeSpan(row_pointers.data(), row_pointers.size()),
absl::MakeSpan(embedding_ids.data(), embedding_ids.size()),
absl::MakeSpan(sample_ids.data(), sample_ids.size()),
absl::MakeSpan(gains.data(), gains.size())};
}

CsrArraysPerHost(int local_device_count, int row_pointers_size_per_device,
int coo_buffer_size_per_device,
internal::CsrArraysPerDevice output_buffers)
: views(output_buffers),
owns_data(false),
local_device_count_(local_device_count) {
CHECK_EQ(output_buffers.row_pointers.size(),
local_device_count * row_pointers_size_per_device);
CHECK_EQ(output_buffers.embedding_ids.size(),
local_device_count * coo_buffer_size_per_device);
CHECK_EQ(output_buffers.sample_ids.size(),
local_device_count * coo_buffer_size_per_device);
CHECK_EQ(output_buffers.gains.size(),
local_device_count * coo_buffer_size_per_device);
}

internal::CsrArraysPerDevice GetCsrArraysPerDevice(int local_device_id)
ABSL_ATTRIBUTE_LIFETIME_BOUND {
return internal::CsrArraysPerDevice{
.row_pointers = row_pointers.row(local_device_id),
.embedding_ids = embedding_ids.row(local_device_id),
.sample_ids = sample_ids.row(local_device_id),
.gains = gains.row(local_device_id),
};
int row_pointers_size_per_device =
views.row_pointers.size() / local_device_count_;
int coo_buffer_size_per_device =
views.embedding_ids.size() / local_device_count_;
return {views.row_pointers.subspan(
local_device_id * row_pointers_size_per_device,
row_pointers_size_per_device),
views.embedding_ids.subspan(
local_device_id * coo_buffer_size_per_device,
coo_buffer_size_per_device),
views.sample_ids.subspan(
local_device_id * coo_buffer_size_per_device,
coo_buffer_size_per_device),
views.gains.subspan(local_device_id * coo_buffer_size_per_device,
coo_buffer_size_per_device)};
}
};

Expand Down Expand Up @@ -195,6 +234,10 @@ struct PreprocessSparseDenseMatmulInputOptions {
// mini-batching to synchronize state across different hosts.
AllReduceInterface* absl_nullable all_reduce_interface;

// If provided, CSR data will be written directly to these buffers.
const absl::flat_hash_map<std::string, internal::CsrArraysPerDevice>*
output_buffers = nullptr;

// Hash function used for creating minibatching buckets.
CooFormat::HashFn minibatching_bucketing_hash_fn = HighwayHash;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1269,5 +1269,38 @@ TEST(InputPreprocessingUtilTest,
ElementsAreArray(expected_sample_ids));
}

// Writing through the views handed out by GetCsrArraysPerDevice() must mutate
// the caller-owned arrays directly, i.e. the zero-copy constructor must not
// allocate or copy.
TEST(CsrArraysPerHostTest,
     CreatingCsrArraysPerHostFromExternalArraysTriggersZeroCopies) {
  constexpr int kLocalDeviceCount = 1;
  constexpr int kRpSize = 10;
  constexpr int kCooSize = 20;
  std::vector<int> row_pointers(kRpSize, 0);
  std::vector<int> embedding_ids(kCooSize, 0);
  std::vector<int> sample_ids(kCooSize, 0);
  std::vector<float> gains(kCooSize, 0.0);

  internal::CsrArraysPerDevice external_buffers{
      .row_pointers = absl::MakeSpan(row_pointers),
      .embedding_ids = absl::MakeSpan(embedding_ids),
      .sample_ids = absl::MakeSpan(sample_ids),
      .gains = absl::MakeSpan(gains),
  };

  CsrArraysPerHost host_arrays(kLocalDeviceCount, kRpSize, kCooSize,
                               external_buffers);
  EXPECT_FALSE(host_arrays.owns_data);

  internal::CsrArraysPerDevice device_view =
      host_arrays.GetCsrArraysPerDevice(0);
  device_view.row_pointers[0] = 1;
  device_view.embedding_ids[0] = 2;
  device_view.sample_ids[0] = 3;
  device_view.gains[0] = 4.0;

  // The writes above must be visible in the original vectors.
  EXPECT_EQ(row_pointers[0], 1);
  EXPECT_EQ(embedding_ids[0], 2);
  EXPECT_EQ(sample_ids[0], 3);
  EXPECT_EQ(gains[0], 4.0);
}
} // namespace
} // namespace jax_sc_embedding
Loading