diff --git a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
index 9f1fddc7..5cc62b6f 100644
--- a/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
+++ b/jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
@@ -449,6 +449,48 @@ void FillDeviceBuffersForTable(
         state.partitioned_coo_tensors_per_device[local_device];
     if (options.enable_minibatching && global_minibatching_required) {
       grouped_coo_tensors.Merge(global_minibatching_split);
+
+      // Recompute required buffer size per SC per device using
+      // the minibatch-aware layout. Each minibatch is padded to the TPU
+      // HBM alignment (8), and all minibatches for a given device
+      // are packed into a single buffer. We aggregate per SC the sum
+      // of aligned minibatch sizes to reflect the true required buffer usage
+      // when minibatching is enabled.
+      auto required_row =
+          state.stats_per_host.required_buffer_size.row(local_device);
+      const uint32_t global_sc_count = options.GetNumScs();
+      for (int local_sc_id = 0; local_sc_id < options.num_sc_per_device;
+           ++local_sc_id) {
+        int total_required = 0;
+        for (int minibatch_id = 0;
+             minibatch_id < grouped_coo_tensors.GetNumMinibatches();
+             ++minibatch_id) {
+          const auto span = grouped_coo_tensors(local_sc_id, minibatch_id);
+          // Accumulate per-global-SC partition sizes within this minibatch,
+          // rounding each partition up to alignment to match the fill logic
+          // (which aligns at each partition boundary).
+          int partition_count = 0;
+          uint32_t prev_global_sc = 0;
+          bool have_prev = false;
+          for (const auto& coo : span) {
+            const uint32_t global_sc_id = coo.col_id & (global_sc_count - 1);
+            if (have_prev && global_sc_id != prev_global_sc &&
+                partition_count > 0) {
+              total_required += xla::RoundUpTo(
+                  partition_count, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE);
+              partition_count = 0;
+            }
+            prev_global_sc = global_sc_id;
+            have_prev = true;
+            ++partition_count;
+          }
+          if (partition_count > 0) {
+            total_required += xla::RoundUpTo(
+                partition_count, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE);
+          }
+        }
+        required_row[local_sc_id] = total_required;
+      }
     }
     const int batch_size_per_sc =
         xla::CeilOfRatio(state.batch_size_for_device,
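
Note: as a quick sanity check on the per-partition alignment arithmetic in the hunk above, the standalone sketch below reproduces the round-up-and-accumulate step in isolation. It is not part of the change itself: RoundUpTo here is a local stand-in for xla::RoundUpTo, the alignment of 8 mirrors TPU_VECTOR_REGISTER_ALIGNMENT_SIZE as described in the hunk's comment, and the partition sizes are made up for illustration.

#include <iostream>
#include <vector>

// Local stand-in for xla::RoundUpTo, using the HBM alignment of 8
// mentioned in the hunk's comment.
constexpr int RoundUpTo(int value, int alignment) {
  return (value + alignment - 1) / alignment * alignment;
}

int main() {
  // Hypothetical per-global-SC partition sizes for one minibatch on one SC.
  const std::vector<int> partition_sizes = {3, 10, 5};
  int total_required = 0;
  for (const int size : partition_sizes) {
    // Each partition is rounded up to the alignment boundary before being
    // accumulated, matching the recomputation in the hunk.
    total_required += RoundUpTo(size, /*alignment=*/8);
  }
  // 3 -> 8, 10 -> 16, 5 -> 8, so the required buffer size is 32.
  std::cout << total_required << "\n";
  return 0;
}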