2 changes: 2 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/BUILD
@@ -120,6 +120,7 @@ cc_test(
         "@com_google_fuzztest//fuzztest",
         "@com_google_googletest//:gtest_main",
         "@eigen_archive//:eigen3",
+        "@xla//xla:util",
     ],
 )

@@ -139,6 +140,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/log",
         "@com_google_absl//absl/log:check",
+        "@com_google_absl//absl/status:statusor",
         "@com_google_absl//absl/strings",
         "@com_google_absl//absl/strings:str_format",
         "@com_google_absl//absl/strings:string_view",
@@ -85,15 +85,17 @@ std::vector<int> GenerateEmbeddingIdsForRow(absl::BitGen& gen, int vocab_size) {
   return ids_out;
 }
 
-ExtractedCooTensors GenerateSkewedCooTensors(int num_sc_per_device,
-                                             int batch_size_per_sc,
-                                             int vocab_size) {
+std::vector<ExtractedSparseCoreTensors> GenerateSkewedCooTensors(
+    int num_sc_per_device, int batch_size_per_sc, int vocab_size) {
   const int batch_size_for_device = num_sc_per_device * batch_size_per_sc;
 
   absl::BitGen gen(std::seed_seq{kSeed});  // seed for reproducibility
 
-  ExtractedCooTensors extracted_coo_tensors(num_sc_per_device,
-                                            batch_size_for_device);
+  std::vector<ExtractedSparseCoreTensors> extracted_sc_tensors(
+      num_sc_per_device);
+  for (int i = 0; i < num_sc_per_device; ++i) {
+    extracted_sc_tensors[i].batch_size_for_device = batch_size_for_device;
+  }
 
   // For each sample in the batch:
   // 1. Draw a sample size from a Lognormal distribution.
@@ -104,13 +106,12 @@ ExtractedCooTensors GenerateSkewedCooTensors(int num_sc_per_device,
     std::vector<int> embedding_ids =
         GenerateEmbeddingIdsForRow(gen, vocab_size);
     int sc_id = row / batch_size_per_sc;
-    extracted_coo_tensors.coo_tensors_per_sc[sc_id] += embedding_ids.size();
     for (int embedding_id : embedding_ids) {
-      extracted_coo_tensors.coo_tensors.push_back(
+      extracted_sc_tensors[sc_id].coo_tensors.push_back(
           CooFormat(row, embedding_id, 1.0));
     }
   }
-  return extracted_coo_tensors;
+  return extracted_sc_tensors;
 }
 
 std::vector<std::unique_ptr<AbstractInputBatch>>
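
The refactor above replaces the single device-wide `ExtractedCooTensors` (which tracked per-SC counts in a side array) with one `ExtractedSparseCoreTensors` per SparseCore, routing each row to its SparseCore by integer division. Below is a self-contained sketch of that routing plus the lognormal-skewed sample sizes the comment describes; `Coo` and `GenerateSkewed` are illustrative stand-ins, not the library's types:

// Illustrative sketch, assuming stand-in types for CooFormat and
// ExtractedSparseCoreTensors.
#include <algorithm>
#include <cmath>
#include <random>
#include <vector>

#include "absl/random/random.h"

struct Coo { int row; int col; float gain; };

std::vector<std::vector<Coo>> GenerateSkewed(int num_sc, int rows_per_sc,
                                             int vocab_size) {
  absl::BitGen gen;
  // Lognormal sample sizes make the workload skewed: most rows carry a few
  // ids, a few rows carry many.
  std::lognormal_distribution<double> sample_size(/*m=*/1.0, /*s=*/1.0);
  std::vector<std::vector<Coo>> per_sc(num_sc);
  const int batch_size = num_sc * rows_per_sc;
  for (int row = 0; row < batch_size; ++row) {
    const int sc_id = row / rows_per_sc;  // contiguous rows map to one SC
    const int n = std::max(1, static_cast<int>(std::lround(sample_size(gen))));
    for (int i = 0; i < n; ++i) {
      per_sc[sc_id].push_back({row, absl::Uniform(gen, 0, vocab_size), 1.0f});
    }
  }
  return per_sc;
}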
@@ -173,10 +174,13 @@ void BM_ExtractCooTensors(benchmark::State& state) {
       .feature_stacking_strategy = FeatureStackingStrategy::kSplitThenStack,
   };
 
+  std::vector<ExtractedSparseCoreTensors> extracted_tensors_per_sc(
+      kNumScPerDevice);
+
   for (auto s : state) {
-    internal::ExtractCooTensorsForAllFeaturesPerLocalDevice(
+    internal::ExtractCooTensorsForLocalDevice(
         stacked_table_metadata, absl::MakeSpan(input_batches),
-        /*local_device_id=*/0, options);
+        /*local_device=*/0, options, absl::MakeSpan(extracted_tensors_per_sc));
   }
 }
 BENCHMARK(BM_ExtractCooTensors)
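
`ExtractCooTensorsForLocalDevice` now writes into a caller-provided `absl::Span` of per-SparseCore outputs instead of returning a fresh container, which is why the benchmark allocates `extracted_tensors_per_sc` once, outside the timed loop. A generic sketch of this caller-owned-output pattern (all names here are illustrative):

#include <vector>

#include "absl/types/span.h"

struct PerScOutput { std::vector<int> ids; };

// Writing into caller-provided storage lets repeated calls (for example,
// benchmark iterations) reuse the vectors' capacity instead of reallocating.
void ExtractInto(absl::Span<PerScOutput> outputs) {
  for (PerScOutput& out : outputs) {
    out.ids.clear();
    out.ids.push_back(42);  // placeholder for real extraction work
  }
}

int main() {
  std::vector<PerScOutput> outputs(4);  // one slot per SparseCore
  for (int iter = 0; iter < 1000; ++iter) {
    ExtractInto(absl::MakeSpan(outputs));
  }
  return 0;
}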
@@ -188,7 +192,7 @@ BENCHMARK(BM_ExtractCooTensors)
     ->UseRealTime();
 
 void BM_SortAndGroup_Phase1(benchmark::State& state) {
-  ExtractedCooTensors extracted_coo_tensors =
+  std::vector<ExtractedSparseCoreTensors> per_sc_tensors =
       GenerateSkewedCooTensors(kNumScPerDevice, kBatchSizePerSc, kVocabSize);
 
   // Set to INT_MAX to avoid ID dropping and observe the actual statistics of
@@ -206,28 +210,41 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
       .num_sc_per_device = kNumScPerDevice,
       .allow_id_dropping = false,
   };
-  bool minibatching_required = false;
-  StatsPerHost stats_per_host(
-      /*local_device_count=*/1,
-      /*global_sc_count=*/kNumScPerDevice * kGlobalDeviceCount,
-      /*num_sc_per_device=*/kNumScPerDevice);
-  internal::StatsPerDevice stats_per_device =
-      stats_per_host.GetStatsPerDevice(0);
 
   if (state.thread_index() == 0) {
-    SortAndGroupCooTensorsPerLocalDevice</*kHasVariableWeights=*/false>(
-        extracted_coo_tensors, stacked_table_metadata, options,
-        stats_per_device, minibatching_required);
+    bool minibatching_required = false;
+    StatsPerHost stats_per_host(
+        /*local_device_count=*/1,
+        /*global_sc_count=*/kNumScPerDevice * kGlobalDeviceCount,
+        /*num_sc_per_device=*/kNumScPerDevice);
+    internal::StatsPerDevice stats_per_device =
+        stats_per_host.GetStatsPerDevice(0);
+    int dropped_id_count = 0;
+    for (int sc_id = 0; sc_id < kNumScPerDevice; ++sc_id) {
+      PerSparseCoreGroupedData result =
+          internal::SortAndGroupCooTensorsForSingleSparseCore<
+              /*kHasVariableWeights=*/false>(per_sc_tensors[sc_id],
+                                             /*local_device_id=*/0, sc_id,
+                                             options, stacked_table_metadata,
+                                             minibatching_required);
+      internal::AggregatePerSparseCoreStats(result, sc_id, stats_per_device,
+                                            dropped_id_count);
+    }
     LogStats(stats_per_device.max_ids_per_partition,
              "Max ids per partition across all global SCs");
     LogStats(stats_per_device.max_unique_ids_per_partition,
              "Max unique ids per partition across all global SCs");
   }
 
   for (auto s : state) {
-    SortAndGroupCooTensorsPerLocalDevice</*kHasVariableWeights=*/false>(
-        extracted_coo_tensors, stacked_table_metadata, options,
-        stats_per_device, minibatching_required);
+    bool minibatching_required = false;
+    for (int sc_id = 0; sc_id < kNumScPerDevice; ++sc_id) {
+      internal::SortAndGroupCooTensorsForSingleSparseCore<
+          /*kHasVariableWeights=*/false>(per_sc_tensors[sc_id],
+                                         /*local_device_id=*/0, sc_id, options,
+                                         stacked_table_metadata,
+                                         minibatching_required);
+    }
   }
 }
 BENCHMARK(BM_SortAndGroup_Phase1)
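
The old single `SortAndGroupCooTensorsPerLocalDevice` call becomes an explicit per-SparseCore loop; stats now come back per SC and are folded together with `internal::AggregatePerSparseCoreStats`, and that stats pass runs only once, on thread 0, outside the timed loop. A generic sketch of this process-per-shard-then-reduce shape (the types and helpers here are illustrative, not the library's):

#include <algorithm>
#include <vector>

struct ShardStats { int max_ids = 0; int dropped = 0; };

// Hypothetical per-shard worker, standing in for
// SortAndGroupCooTensorsForSingleSparseCore.
ShardStats ProcessShard(const std::vector<int>& shard) {
  return {static_cast<int>(shard.size()), /*dropped=*/0};
}

int main() {
  std::vector<std::vector<int>> shards = {{1, 2, 3}, {4}, {5, 6}};
  ShardStats total;
  for (const auto& shard : shards) {
    ShardStats s = ProcessShard(shard);
    // Explicit reduce step, analogous to AggregatePerSparseCoreStats.
    total.max_ids = std::max(total.max_ids, s.max_ids);
    total.dropped += s.dropped;
  }
  return total.max_ids == 3 ? 0 : 1;  // trivial sanity check
}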