From 5444ec693811f33ed41a2c1401d78c1e4c6ee4e2 Mon Sep 17 00:00:00 2001 From: Beatrice Bevilacqua Date: Fri, 25 Jul 2025 16:35:15 -0700 Subject: [PATCH] Simplify _create_sentences and _create_batches logic --- src/cell_load/data_modules/samplers.py | 56 +++++++++++--------------- 1 file changed, 23 insertions(+), 33 deletions(-) diff --git a/src/cell_load/data_modules/samplers.py b/src/cell_load/data_modules/samplers.py index a5f1243..64163b9 100644 --- a/src/cell_load/data_modules/samplers.py +++ b/src/cell_load/data_modules/samplers.py @@ -101,7 +101,7 @@ def __init__( f"Sampler created with {len(self.batches)} batches in {end_time - start_time:.2f} seconds." ) - def _create_batches(self) -> list[list[int]]: + def _create_batches(self) -> list[np.ndarray]: """ Combines existing batches into meta-batches of size batch_size * cell_sentence_len, sampling with replacement if needed to reach cell_sentence_len. @@ -115,46 +115,33 @@ def _create_batches(self) -> list[list[int]]: else: rank_sentences = self.sentences - all_batches = [] - current_batch = [] - - num_full = 0 + # If a sentence is smaller than self.cell_sentence_len, + # upsample it drawing from the sentence itself + upsampled = [] num_partial = 0 for sentence in rank_sentences: - # If batch is smaller than cell_sentence_len, sample with replacement if len(sentence) < self.cell_sentence_len and not self.test: # during inference, don't sample by replacement - new_sentence = np.random.choice( + sentence = np.random.choice( sentence, size=self.cell_sentence_len, replace=True ).tolist() num_partial += 1 - else: - new_sentence = copy.deepcopy(sentence) - assert len(new_sentence) == self.cell_sentence_len or self.test - num_full += 1 - - sentence_len = len(new_sentence) if self.test else self.cell_sentence_len + upsampled.append(sentence) - if len(current_batch) + len(new_sentence) <= self.batch_size * sentence_len: - current_batch.extend(new_sentence) - else: - if current_batch: # Add the completed meta-batch - all_batches.append(current_batch) - current_batch = new_sentence + # Split the sentences into batches + all_batches = [] + for i in range(0, len(upsampled), self.batch_size): + all_batches.append(np.concat(upsampled[i : i + self.batch_size])) if self.distributed: logger.info( - f"Rank {self.rank}: Of {len(rank_sentences)} sentences, {num_full} were full and {num_partial} were partial." + f"Rank {self.rank}: Of {len(rank_sentences)} sentences, {len(rank_sentences) - num_partial} were full and {num_partial} were partial." ) else: logger.info( - f"Of all batches, {num_full} were full and {num_partial} were partial." + f"Of all batches, {len(rank_sentences) - num_partial} were full and {num_partial} were partial." ) - # Add the last meta-batch if it exists - if current_batch: - all_batches.append(current_batch) - return all_batches def _get_rank_sentences(self) -> list[list[int]]: @@ -190,7 +177,7 @@ def _get_rank_sentences(self) -> list[list[int]]: return rank_sentences - def _process_subset(self, global_offset: int, subset: Subset) -> list[list[int]]: + def _process_subset(self, global_offset: int, subset: Subset) -> list[np.ndarray]: """ Process a single subset to create batches based on H5 codes. @@ -241,16 +228,19 @@ def _process_subset(self, global_offset: int, subset: Subset) -> list[list[int]] group_indices = global_indices[mask] np.random.shuffle(group_indices) - # Split the group indices into batches. - for i in range(0, len(group_indices), self.cell_sentence_len): - sentence = group_indices[i : i + self.cell_sentence_len].tolist() - if len(sentence) < self.cell_sentence_len and self.drop_last: - continue - subset_batches.append(sentence) + sentences = np.array_split( + group_indices, + np.arange( + self.cell_sentence_len, len(group_indices), self.cell_sentence_len + ), + ) + if sentences[-1].shape[0] < self.cell_sentence_len and self.drop_last: + sentences.pop() + subset_batches.extend(sentences) return subset_batches - def _create_sentences(self) -> list[list[int]]: + def _create_sentences(self) -> list[np.ndarray]: """ Process each subset sequentially (across all datasets) and combine the batches. """