1818#include < memory>
1919#include < optional>
2020#include < random>
21+ #include < string>
2122#include < utility>
2223#include < vector>
2324
3839namespace jax_sc_embedding {
3940
4041namespace {
42+
43+ std::string CombinerToString (RowCombiner combiner) {
44+ switch (combiner) {
45+ case RowCombiner::kSum :
46+ return " kSum" ;
47+ case RowCombiner::kMean :
48+ return " kMean" ;
49+ case RowCombiner::kSqrtn :
50+ return " kSqrtn" ;
51+ }
52+ }
53+
4154template <typename Derived>
4255void LogStats (const Eigen::MatrixBase<Derived>& data, absl::string_view name) {
4356 if (data.size () == 0 ) {
@@ -147,6 +160,7 @@ GenerateSkewedRaggedTensorInputBatches(int num_sc_per_device,
147160void BM_ExtractCooTensors (benchmark::State& state) {
148161 const int num_features = state.range (0 );
149162 const RowCombiner combiner = static_cast <RowCombiner>(state.range (1 ));
163+ state.SetLabel (CombinerToString (combiner));
150164 std::vector<std::unique_ptr<AbstractInputBatch>> input_batches =
151165 GenerateSkewedRaggedTensorInputBatches (kNumScPerDevice , kBatchSizePerSc ,
152166 kVocabSize , num_features);
@@ -188,8 +202,10 @@ BENCHMARK(BM_ExtractCooTensors)
188202 ->UseRealTime();
189203
190204void BM_SortAndGroup_Phase1 (benchmark::State& state) {
205+ const RowCombiner combiner = static_cast <RowCombiner>(state.range (0 ));
191206 ExtractedCooTensors extracted_coo_tensors =
192207 GenerateSkewedCooTensors (kNumScPerDevice , kBatchSizePerSc , kVocabSize );
208+ state.SetLabel (CombinerToString (combiner));
193209
194210 // Set to INT_MAX to avoid ID dropping and observe the actual statistics of
195211 // the generated data. This doesn't affect performance of grouping itself.
@@ -199,7 +215,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
199215 /* max_unique_ids_per_partition=*/ std::numeric_limits<int >::max (),
200216 /* row_offset=*/ 0 ,
201217 /* col_offset=*/ 0 ,
202- /* col_shift=*/ 0 , /* batch_size=*/ 0 );
218+ /* col_shift=*/ 0 , /* batch_size=*/ 0 ,
219+ /* suggested_coo_buffer_size_per_device=*/ std::nullopt ,
220+ /* row_combiner=*/ combiner);
203221 PreprocessSparseDenseMatmulInputOptions options = {
204222 .local_device_count = 1 ,
205223 .global_device_count = kGlobalDeviceCount ,
@@ -231,6 +249,9 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
231249 }
232250}
// Registers BM_SortAndGroup_Phase1 with the integer arguments 0, 1 and 2. The
// benchmark body casts state.range(0) to RowCombiner, so each argument selects
// the combiner under test; the trailing comments name the intended enumerator.
// NOTE(review): this assumes kSum/kMean/kSqrtn have the underlying values
// 0/1/2 -- confirm against the RowCombiner definition.
// Every registered variant is configured with Threads(8) and UseRealTime().
233251BENCHMARK (BM_SortAndGroup_Phase1)
252+ ->Arg (0 ) // kSum
253+ ->Arg(1 ) // kMean
254+ ->Arg(2 ) // kSqrtn
234255 ->Threads(8 )
235256 ->UseRealTime();
236257
0 commit comments