Description
Hi! I was following the tutorial and noticed an edge case in the MaskedLMLoader class.
In `__getitem__`, when `len(data) > self.max_seq_len`, the code performs a random slice using:

```python
rand_start_idx = random.choice(list(range(len(data) - self.max_seq_len)))
end_idx = rand_start_idx + self.max_seq_len
data = data[rand_start_idx:end_idx]
```
If the input data relies on special tokens (like `<BOS>` at index 0 and `<EOS>` at index -1), this random slicing often cuts them off, so the model ends up training on sequences that lack the necessary start/end context markers.
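To make the failure concrete, here is a minimal sketch (token IDs are made up, and plain Python lists stand in for tensors) that enumerates every window the random slice above can produce:

```python
max_seq_len = 5
data = [1, 10, 11, 12, 13, 14, 15, 2]  # 1=<BOS>, 2=<EOS>, rest is content

# Every start index random.choice could pick in the current code:
windows = [data[s : s + max_seq_len] for s in range(len(data) - max_seq_len)]

# Windows missing <BOS> at the front or <EOS> at the end:
losing = [w for w in windows if w[0] != 1 or w[-1] != 2]
print(len(losing), "of", len(windows))  # 3 of 3 — every window loses a special token
```

In this toy case no possible window keeps both markers, since keeping `<BOS>` forces the window to start at 0 while keeping `<EOS>` forces it to end at the last position.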
Proposed Fix

We should likely strip the special tokens first, perform the random slice on the content only, and then re-attach the special tokens.
Here is a potential implementation of the fix:
```python
if len(data) > self.max_seq_len:
    # 1. Isolate special tokens
    bos_token = data[0]
    eos_token = data[-1]
    content = data[1:-1]

    # 2. Slice content (reserving 2 slots for <BOS>/<EOS>)
    max_content_len = self.max_seq_len - 2
    if len(content) > max_content_len:
        rand_start_idx = random.choice(list(range(len(content) - max_content_len + 1)))
        content = content[rand_start_idx : rand_start_idx + max_content_len]

    # 3. Reconstruct
    data = torch.cat([
        bos_token.unsqueeze(0),
        content,
        eos_token.unsqueeze(0),
    ])
```
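As a quick sanity check of the same logic, here is a standalone version with plain Python lists (so `torch.cat` becomes list concatenation, and `random.randrange(n)` replaces the equivalent `random.choice(list(range(n)))`; the helper name and token IDs are made up for illustration):

```python
import random

def truncate_keep_specials(data, max_seq_len):
    """Randomly crop data to max_seq_len while keeping the first/last tokens."""
    if len(data) <= max_seq_len:
        return data
    bos, eos, content = data[0], data[-1], data[1:-1]
    max_content_len = max_seq_len - 2
    if len(content) > max_content_len:
        start = random.randrange(len(content) - max_content_len + 1)
        content = content[start : start + max_content_len]
    return [bos] + content + [eos]

data = [1] + list(range(10, 30)) + [2]  # 1=<BOS>, 2=<EOS>, 20 content tokens
out = truncate_keep_specials(data, max_seq_len=8)
print(out[0], out[-1], len(out))  # 1 2 8 — both markers survive, length capped
```

Regardless of which start index is drawn, the output always begins with `<BOS>`, ends with `<EOS>`, and has exactly `max_seq_len` tokens.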