Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions jax_tpu_embedding/sparsecore/lib/core/input_preprocessing.cc
Original file line number Diff line number Diff line change
Expand Up @@ -449,6 +449,48 @@ void FillDeviceBuffersForTable(
state.partitioned_coo_tensors_per_device[local_device];
if (options.enable_minibatching && global_minibatching_required) {
grouped_coo_tensors.Merge(global_minibatching_split);

// Recompute required buffer size per SC per device using
// the minibatch-aware layout. Each minibatch is padded to the TPU
// HBM alignment (8), and all minibatches for a given device
// are packed into a single buffer. We aggregate per SC the sum
// of aligned minibatch sizes to reflect the true required buffer usage
// when minibatching is enabled.
auto required_row =
state.stats_per_host.required_buffer_size.row(local_device);
const uint32_t global_sc_count = options.GetNumScs();
for (int local_sc_id = 0; local_sc_id < options.num_sc_per_device;
++local_sc_id) {
int total_required = 0;
for (int minibatch_id = 0;
minibatch_id < grouped_coo_tensors.GetNumMinibatches();
++minibatch_id) {
const auto span = grouped_coo_tensors(local_sc_id, minibatch_id);
// Accumulate per-global-SC partition sizes within this minibatch,
// rounding each partition up to alignment to match the fill logic
// (which aligns at each partition boundary).
int partition_count = 0;
uint32_t prev_global_sc = 0;
bool have_prev = false;
for (const auto& coo : span) {
const uint32_t global_sc_id = coo.col_id & (global_sc_count - 1);
if (have_prev && global_sc_id != prev_global_sc &&
partition_count > 0) {
total_required += xla::RoundUpTo(
partition_count, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE);
partition_count = 0;
}
prev_global_sc = global_sc_id;
have_prev = true;
++partition_count;
}
if (partition_count > 0) {
total_required += xla::RoundUpTo(
partition_count, TPU_VECTOR_REGISTER_ALIGNMENT_SIZE);
}
}
required_row[local_sc_id] = total_required;
}
}

const int batch_size_per_sc = xla::CeilOfRatio(state.batch_size_for_device,
Expand Down