diff --git a/src/slicegpt/data_utils.py b/src/slicegpt/data_utils.py index 09b85386..56f34e42 100755 --- a/src/slicegpt/data_utils.py +++ b/src/slicegpt/data_utils.py @@ -150,13 +150,25 @@ def prepare_dataloader( start_idx = torch.randint(0, len(indices), (1,)).item() idx = start_idx tokens = [] - while len(tokens) < max_seqlen and idx < len(indices): + while len(tokens) < max_seqlen: item = data_list[indices[idx]] sep = "" if not tokens else "\n\n" tokens += tokenizer.tokenize(sep + item) - idx += 1 - - indices = indices[:start_idx] + indices[idx:] # remove the used indices + idx = (idx + 1) % len(indices) + if idx == start_idx: + # In this case, idx has wrapped around and caught up with start_idx. + # There is no more data left to continue. + break + + if idx <= start_idx: + # We wrapped around and used the indices in the range + # [start_idx:end] and [0:idx) + # Remaining indices are: + indices = indices[idx:start_idx] + else: + # We used the indices in the range [start_idx:idx) + # Remaining indices are: + indices = indices[:start_idx] + indices[idx:] if len(tokens) >= max_seqlen: tokens = tokens[:max_seqlen] # truncate to max_seqlen