56 changes: 23 additions & 33 deletions src/cell_load/data_modules/samplers.py
@@ -101,7 +101,7 @@ def __init__(
f"Sampler created with {len(self.batches)} batches in {end_time - start_time:.2f} seconds."
)

def _create_batches(self) -> list[list[int]]:
def _create_batches(self) -> list[np.ndarray]:
"""
Combines sentences into meta-batches of size batch_size * cell_sentence_len,
sampling with replacement if needed to bring each sentence up to cell_sentence_len.
@@ -115,46 +115,33 @@ def _create_batches(self) -> list[list[int]]:
else:
rank_sentences = self.sentences

all_batches = []
current_batch = []

num_full = 0
# If a sentence is smaller than self.cell_sentence_len,
# upsample it by drawing from the sentence itself
upsampled = []
num_partial = 0
for sentence in rank_sentences:
# If a sentence is smaller than cell_sentence_len, sample with replacement
if len(sentence) < self.cell_sentence_len and not self.test:
# during inference, don't sample with replacement
new_sentence = np.random.choice(
sentence = np.random.choice(
sentence, size=self.cell_sentence_len, replace=True
).tolist()
num_partial += 1
else:
new_sentence = copy.deepcopy(sentence)
assert len(new_sentence) == self.cell_sentence_len or self.test
num_full += 1

sentence_len = len(new_sentence) if self.test else self.cell_sentence_len
upsampled.append(sentence)

if len(current_batch) + len(new_sentence) <= self.batch_size * sentence_len:
current_batch.extend(new_sentence)
else:
if current_batch: # Add the completed meta-batch
all_batches.append(current_batch)
current_batch = new_sentence
# Split the sentences into batches
all_batches = []
for i in range(0, len(upsampled), self.batch_size):
all_batches.append(np.concatenate(upsampled[i : i + self.batch_size]))

if self.distributed:
logger.info(
f"Rank {self.rank}: Of {len(rank_sentences)} sentences, {num_full} were full and {num_partial} were partial."
f"Rank {self.rank}: Of {len(rank_sentences)} sentences, {len(rank_sentences) - num_partial} were full and {num_partial} were partial."
)
else:
logger.info(
f"Of all batches, {num_full} were full and {num_partial} were partial."
f"Of all batches, {len(rank_sentences) - num_partial} were full and {num_partial} were partial."
)

# Add the last meta-batch if it exists
if current_batch:
all_batches.append(current_batch)

return all_batches

def _get_rank_sentences(self) -> list[list[int]]:
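For reference, the new _create_batches flow reduces to two passes: upsample short sentences, then concatenate fixed-size groups of sentences into flat meta-batches. A minimal standalone sketch with made-up sentence data and parameter values (the real sampler also handles the self.test and distributed paths):

import numpy as np

cell_sentence_len, batch_size = 4, 2
sentences = [np.array([0, 1, 2, 3]), np.array([4, 5]), np.array([6, 7, 8, 9])]

# Pass 1: upsample any short sentence with replacement so every
# sentence holds exactly cell_sentence_len indices.
upsampled = [
    s if len(s) >= cell_sentence_len
    else np.random.choice(s, size=cell_sentence_len, replace=True)
    for s in sentences
]

# Pass 2: concatenate batch_size sentences at a time into flat
# meta-batches of batch_size * cell_sentence_len indices.
meta_batches = [
    np.concatenate(upsampled[i : i + batch_size])
    for i in range(0, len(upsampled), batch_size)
]
print([b.shape for b in meta_batches])  # [(8,), (4,)] -- trailing partial meta-batch is kept

Note that np.concat exists only as an alias of np.concatenate from NumPy 2.0 onward; np.concatenate is the portable spelling.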
@@ -190,7 +177,7 @@ def _get_rank_sentences(self) -> list[list[int]]:

return rank_sentences

def _process_subset(self, global_offset: int, subset: Subset) -> list[list[int]]:
def _process_subset(self, global_offset: int, subset: Subset) -> list[np.ndarray]:
"""
Process a single subset to create batches based on H5 codes.

@@ -241,16 +228,19 @@ def _process_subset(self, global_offset: int, subset: Subset) -> list[list[int]]:
group_indices = global_indices[mask]
np.random.shuffle(group_indices)

# Split the group indices into batches.
for i in range(0, len(group_indices), self.cell_sentence_len):
sentence = group_indices[i : i + self.cell_sentence_len].tolist()
if len(sentence) < self.cell_sentence_len and self.drop_last:
continue
subset_batches.append(sentence)
sentences = np.array_split(
group_indices,
np.arange(
self.cell_sentence_len, len(group_indices), self.cell_sentence_len
),
)
if sentences[-1].shape[0] < self.cell_sentence_len and self.drop_last:
sentences.pop()
subset_batches.extend(sentences)

return subset_batches

def _create_sentences(self) -> list[list[int]]:
def _create_sentences(self) -> list[np.ndarray]:
"""
Process each subset sequentially (across all datasets) and combine the batches.
"""