Skip to content

Commit aa320e1

Browse files
[JAX SC] Add combiner to sort benchmarks.
PiperOrigin-RevId: 838857288
1 parent ca34f3e commit aa320e1

File tree

1 file changed

+22
-1
lines changed

1 file changed

+22
-1
lines changed

jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
#include <memory>
1919
#include <optional>
2020
#include <random>
21+
#include <string>
2122
#include <utility>
2223
#include <vector>
2324

@@ -38,6 +39,18 @@
3839
namespace jax_sc_embedding {
3940

4041
namespace {
42+
43+
std::string CombinerToString(RowCombiner combiner) {
44+
switch (combiner) {
45+
case RowCombiner::kSum:
46+
return "kSum";
47+
case RowCombiner::kMean:
48+
return "kMean";
49+
case RowCombiner::kSqrtn:
50+
return "kSqrtn";
51+
}
52+
}
53+
4154
template <typename Derived>
4255
void LogStats(const Eigen::MatrixBase<Derived>& data, absl::string_view name) {
4356
if (data.size() == 0) {
@@ -147,6 +160,7 @@ GenerateSkewedRaggedTensorInputBatches(int num_sc_per_device,
147160
void BM_ExtractCooTensors(benchmark::State& state) {
148161
const int num_features = state.range(0);
149162
const RowCombiner combiner = static_cast<RowCombiner>(state.range(1));
163+
state.SetLabel(CombinerToString(combiner));
150164
std::vector<std::unique_ptr<AbstractInputBatch>> input_batches =
151165
GenerateSkewedRaggedTensorInputBatches(kNumScPerDevice, kBatchSizePerSc,
152166
kVocabSize, num_features);
@@ -188,8 +202,10 @@ BENCHMARK(BM_ExtractCooTensors)
188202
->UseRealTime();
189203

190204
void BM_SortAndGroup_Phase1(benchmark::State& state) {
205+
const RowCombiner combiner = static_cast<RowCombiner>(state.range(0));
191206
ExtractedCooTensors extracted_coo_tensors =
192207
GenerateSkewedCooTensors(kNumScPerDevice, kBatchSizePerSc, kVocabSize);
208+
state.SetLabel(CombinerToString(combiner));
193209

194210
// Set to INT_MAX to avoid ID dropping and observe the actual statistics of
195211
// the generated data. This doesn't affect performance of grouping itself.
@@ -199,7 +215,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
199215
/*max_unique_ids_per_partition=*/std::numeric_limits<int>::max(),
200216
/*row_offset=*/0,
201217
/*col_offset=*/0,
202-
/*col_shift=*/0, /*batch_size=*/0);
218+
/*col_shift=*/0, /*batch_size=*/0,
219+
/*suggested_coo_buffer_size_per_device=*/std::nullopt,
220+
/*row_combiner=*/combiner);
203221
PreprocessSparseDenseMatmulInputOptions options = {
204222
.local_device_count = 1,
205223
.global_device_count = kGlobalDeviceCount,
@@ -231,6 +249,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
231249
}
232250
}
233251
BENCHMARK(BM_SortAndGroup_Phase1)
252+
->Arg(0) // kSum
253+
->Arg(1) // kMean
254+
->Arg(2) // kSqrtn
234255
->Threads(8)
235256
->UseRealTime();
236257

0 commit comments

Comments
 (0)