20 changes: 16 additions & 4 deletions src/slicegpt/data_utils.py
@@ -150,13 +150,25 @@ def prepare_dataloader(
             start_idx = torch.randint(0, len(indices), (1,)).item()
             idx = start_idx
             tokens = []
-            while len(tokens) < max_seqlen and idx < len(indices):
+            while len(tokens) < max_seqlen:
                 item = data_list[indices[idx]]
                 sep = "" if not tokens else "\n\n"
                 tokens += tokenizer.tokenize(sep + item)
-                idx += 1
-
-            indices = indices[:start_idx] + indices[idx:]  # remove the used indices
+                idx = (idx + 1) % len(indices)
+                if idx == start_idx:
+                    # In this case, idx has wrapped around and caught up with start_idx.
+                    # There is no more data left to continue.
+                    break
+
+            if idx <= start_idx:
+                # We wrapped around and used the indices in the ranges
+                # [start_idx:len(indices)) and [0:idx).
+                # Remaining indices are:
+                indices = indices[idx:start_idx]
+            else:
+                # We used the indices in the range [start_idx:idx).
+                # Remaining indices are:
+                indices = indices[:start_idx] + indices[idx:]
 
             if len(tokens) >= max_seqlen:
                 tokens = tokens[:max_seqlen]  # truncate to max_seqlen
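To make the new index bookkeeping easier to check, here is a minimal, self-contained sketch of the wrap-around logic with a toy whitespace tokenizer. The helper name consume_tokens, the toy documents, and str.split as the tokenizer are illustrative assumptions, not part of this PR; only the loop and the slicing of the remaining indices mirror the diff above.

# Minimal standalone sketch of the wrap-around index bookkeeping above.
# The function name `consume_tokens`, the toy data and the whitespace
# tokenizer are hypothetical illustrations, not part of the PR.
def consume_tokens(data_list, indices, start_idx, max_seqlen, tokenize):
    idx = start_idx
    tokens = []
    while len(tokens) < max_seqlen:
        item = data_list[indices[idx]]
        sep = "" if not tokens else "\n\n"
        tokens += tokenize(sep + item)
        idx = (idx + 1) % len(indices)
        if idx == start_idx:
            # idx wrapped all the way around: every index has been used.
            break

    if idx <= start_idx:
        # Wrapped: used [start_idx:len(indices)) and [0:idx); the unused
        # indices form the contiguous middle slice.
        remaining = indices[idx:start_idx]
    else:
        # No wrap: used [start_idx:idx); keep everything outside that range.
        remaining = indices[:start_idx] + indices[idx:]

    return tokens[:max_seqlen], remaining

# Toy usage: six one-token documents, whitespace "tokenizer".
data = [f"doc{i}" for i in range(6)]
toks, rest = consume_tokens(data, list(range(6)), start_idx=4, max_seqlen=3, tokenize=str.split)
assert toks == ["doc4", "doc5", "doc0"]  # wrapped past the end of the list
assert rest == [1, 2, 3]                 # indices[idx:start_idx] with idx=1, start_idx=4

With the pre-PR loop condition (len(tokens) < max_seqlen and idx < len(indices)), this same call would stop at the end of the list with only two tokens; wrapping around avoids that early termination and keeps the remaining-indices arithmetic consistent.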