From 882ec0596583637a1bcb6263695df17a7110e6dd Mon Sep 17 00:00:00 2001
From: Henry Jackson-Flux
Date: Wed, 18 Sep 2024 14:52:34 +0100
Subject: [PATCH] Change the behaviour of prepare_dataloader to make sure it
 does not throw away data if it does not need to.

---
 src/slicegpt/data_utils.py | 20 ++++++++++++++++----
 1 file changed, 16 insertions(+), 4 deletions(-)

diff --git a/src/slicegpt/data_utils.py b/src/slicegpt/data_utils.py
index 09b85386..56f34e42 100755
--- a/src/slicegpt/data_utils.py
+++ b/src/slicegpt/data_utils.py
@@ -150,13 +150,25 @@ def prepare_dataloader(
         start_idx = torch.randint(0, len(indices), (1,)).item()
         idx = start_idx
         tokens = []
-        while len(tokens) < max_seqlen and idx < len(indices):
+        while len(tokens) < max_seqlen:
             item = data_list[indices[idx]]
             sep = "" if not tokens else "\n\n"
             tokens += tokenizer.tokenize(sep + item)
-            idx += 1
-
-        indices = indices[:start_idx] + indices[idx:]  # remove the used indices
+            idx = (idx + 1) % len(indices)
+            if idx == start_idx:
+                # In this case, idx has wrapped around and caught up with start_idx.
+                # There is no more data left to continue.
+                break
+
+        if idx <= start_idx:
+            # We wrapped around and used the indices in the range
+            # [start_idx:end] and [0:idx)
+            # Remaining indices are:
+            indices = indices[idx:start_idx]
+        else:
+            # We used the indices in the range [start_idx:idx)
+            # Remaining indices are:
+            indices = indices[:start_idx] + indices[idx:]
 
         if len(tokens) >= max_seqlen:
             tokens = tokens[:max_seqlen]  # truncate to max_seqlen