Skip to content

Commit 9ec2ba5

Browse files
[JAX SC] Update benchmarks to use enable_minibatching.
PiperOrigin-RevId: 838215624
1 parent aa320e1 commit 9ec2ba5

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

jax_tpu_embedding/sparsecore/lib/core/extract_sort_and_group_benchmark.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,7 +184,7 @@ void BM_ExtractCooTensors(benchmark::State& state) {
184184
.global_device_count = kGlobalDeviceCount,
185185
.num_sc_per_device = kNumScPerDevice,
186186
.allow_id_dropping = false,
187-
.feature_stacking_strategy = FeatureStackingStrategy::kSplitThenStack,
187+
.enable_minibatching = true,
188188
};
189189

190190
for (auto s : state) {
@@ -223,6 +223,7 @@ void BM_SortAndGroup_Phase1(benchmark::State& state) {
223223
.global_device_count = kGlobalDeviceCount,
224224
.num_sc_per_device = kNumScPerDevice,
225225
.allow_id_dropping = false,
226+
.enable_minibatching = true,
226227
};
227228
bool minibatching_required = false;
228229
StatsPerHost stats_per_host(

0 commit comments

Comments
 (0)