From 44971e82f8ea8e0f6208ea36c3b651c42ea08ee1 Mon Sep 17 00:00:00 2001
From: The JAX SC Authors
Date: Mon, 27 Oct 2025 22:49:10 -0700
Subject: [PATCH] [JAX SC] `required_buffer_size_per_sc` is undercounted when
 minibatching is enabled.

The initial calculation does not account for the alignment padding added at
the end of each minibatch. This can lead to the following sequence of events:

1. A smaller-than-needed buffer size is calculated and reported, and FDO uses
   this incorrect value to reconfigure the buffer for the next run.
2. The buffer inevitably overflows and drops IDs.
3. This can cause a persistent FDO loop.

Now, the required buffer size is recomputed after the minibatches have been
merged and their final memory layout is known.

PiperOrigin-RevId: 824850862
---
 .../lib/core/input_preprocessing.cc | 42 +++++++++++++++++++
 1 file changed, 42 insertions(+)

diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
index 9f1fddc7..5cc62b6f 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
@@ -449,6 +449,48 @@ void FillDeviceBuffersForTable(
       state.partitioned_coo_tensors_per_device[local_device];
   if (options.enable_minibatching && global_minibatching_required) {
     grouped_coo_tensors.Merge(global_minibatching_split);
+
+    // Recompute required buffer size per SC per device using
+    // the minibatch-aware layout. Each minibatch is padded to the TPU
+    // HBM alignment (8), and all minibatches for a given device
+    // are packed into a single buffer. We aggregate per SC the sum
+    // of aligned minibatch sizes to reflect the true required buffer usage
+    // when minibatching is enabled.
+    auto required_row =
+        state.stats_per_host.required_buffer_size.row(local_device);
+    const uint32_t global_sc_count = options.GetNumScs();
+    for (int local_sc_id = 0; local_sc_id < options.num_sc_per_device;
+         ++local_sc_id) {
+      int total_required = 0;
+      for (int minibatch_id = 0;
+           minibatch_id < grouped_coo_tensors.GetNumMinibatches();
+           ++minibatch_id) {
+        const auto span = grouped_coo_tensors(local_sc_id, minibatch_id);
+        // Accumulate per-global-SC partition sizes within this minibatch,
+        // rounding each partition up to alignment to match the fill logic
+        // (which aligns at each partition boundary).
+        int partition_count = 0;
+        uint32_t prev_global_sc = 0;
+        bool have_prev = false;
+        for (const auto& coo : span) {
+          const uint32_t global_sc_id = coo.col_id & (global_sc_count - 1);
+          if (have_prev && global_sc_id != prev_global_sc &&
+              partition_count > 0) {
+            total_required += xla::RoundUpTo(
+                partition_count, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE);
+            partition_count = 0;
+          }
+          prev_global_sc = global_sc_id;
+          have_prev = true;
+          ++partition_count;
+        }
+        if (partition_count > 0) {
+          total_required += xla::RoundUpTo(
+              partition_count, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE);
+        }
+      }
+      required_row[local_sc_id] = total_required;
+    }
   }
 
   const int batch_size_per_sc =
       xla::CeilOfRatio(state.batch_size_for_device,
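
Note on the alignment arithmetic: the undercount comes from ignoring the
padding that the fill logic inserts at each partition boundary. The sketch
below is illustrative only and is not part of the patch or the library's API;
`RoundUpTo` here is a local stand-in for `xla::RoundUpTo`, the alignment value
of 8 is taken from the comment in the patch, and the partition sizes are
made-up example data.

  #include <iostream>
  #include <vector>

  // Round `value` up to the next multiple of `alignment` (alignment > 0).
  int RoundUpTo(int value, int alignment) {
    return ((value + alignment - 1) / alignment) * alignment;
  }

  int main() {
    // TPU_VECTOR_REGISTER_ALIGNMENT_SIZE is 8 per the comment in the patch.
    constexpr int kAlignment = 8;

    // Hypothetical per-global-SC partition sizes for two minibatches on one SC.
    const std::vector<std::vector<int>> minibatches = {{5, 13, 2}, {7, 1}};

    int unaligned_total = 0;  // What a naive element count would report.
    int aligned_total = 0;    // What the alignment-aware computation reports.
    for (const auto& partitions : minibatches) {
      for (int partition_size : partitions) {
        unaligned_total += partition_size;
        // The fill logic pads each partition up to the alignment boundary,
        // so the required buffer size must count the padded size.
        aligned_total += RoundUpTo(partition_size, kAlignment);
      }
    }

    std::cout << "unaligned total: " << unaligned_total << "\n";  // prints 28
    std::cout << "aligned total:   " << aligned_total << "\n";    // prints 48
    return 0;
  }

With these example sizes the padded layout needs 48 elements while a naive
count reports 28; a gap of this kind is what the original calculation missed,
so the buffer size reported to FDO could be too small.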