# data_loader.py
import nltk
import torch
import torch.utils.data as data

class Dataset(data.Dataset):
    """Custom data.Dataset compatible with data.DataLoader."""

    def __init__(self, src_path, trg_path, src_word2id, trg_word2id):
        """Reads source and target sequences from txt files."""
        # Use context managers so the file handles are closed after reading.
        with open(src_path) as src_file:
            self.src_seqs = src_file.readlines()
        with open(trg_path) as trg_file:
            self.trg_seqs = trg_file.readlines()
        self.num_total_seqs = len(self.src_seqs)
        self.src_word2id = src_word2id
        self.trg_word2id = trg_word2id

    def __getitem__(self, index):
        """Returns one data pair (source and target)."""
        src_seq = self.src_seqs[index]
        trg_seq = self.trg_seqs[index]
        src_seq = self.preprocess(src_seq, self.src_word2id, trg=False)
        trg_seq = self.preprocess(trg_seq, self.trg_word2id)
        return src_seq, trg_seq

    def __len__(self):
        return self.num_total_seqs

    def preprocess(self, sequence, word2id, trg=True):
        """Converts words to ids. (The trg flag is currently unused.)"""
        # Requires the NLTK 'punkt' tokenizer data: nltk.download('punkt').
        tokens = nltk.tokenize.word_tokenize(sequence.lower())
        sequence = []
        sequence.append(word2id['<start>'])
        # Out-of-vocabulary tokens are silently dropped.
        sequence.extend([word2id[token] for token in tokens if token in word2id])
        sequence.append(word2id['<end>'])
        sequence = torch.LongTensor(sequence)
        return sequence
def collate_fn(data):
    """Creates mini-batch tensors from the list of tuples (src_seq, trg_seq).

    We should build a custom collate_fn rather than using the default collate_fn,
    because merging sequences (including padding) is not supported by default.
    Sequences are padded to the maximum length within the mini-batch (dynamic padding).

    Args:
        data: list of tuples (src_seq, trg_seq).
            - src_seq: torch tensor of shape (?); variable length.
            - trg_seq: torch tensor of shape (?); variable length.

    Returns:
        src_seqs: torch tensor of shape (batch_size, padded_length).
        src_lengths: list of length (batch_size); valid length for each padded source sequence.
        trg_seqs: torch tensor of shape (batch_size, padded_length).
        trg_lengths: list of length (batch_size); valid length for each padded target sequence.
    """
    def merge(sequences):
        lengths = [len(seq) for seq in sequences]
        # Id 0 implicitly serves as the padding index.
        padded_seqs = torch.zeros(len(sequences), max(lengths)).long()
        for i, seq in enumerate(sequences):
            end = lengths[i]
            padded_seqs[i, :end] = seq[:end]
        return padded_seqs, lengths

    # Sort by source sequence length (descending order) so that
    # pack_padded_sequence can be used downstream.
    data.sort(key=lambda x: len(x[0]), reverse=True)

    # Separate source and target sequences.
    src_seqs, trg_seqs = zip(*data)

    # Merge sequences (from tuple of 1D tensors to a single 2D tensor).
    src_seqs, src_lengths = merge(src_seqs)
    trg_seqs, trg_lengths = merge(trg_seqs)

    return src_seqs, src_lengths, trg_seqs, trg_lengths
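

# A minimal sketch of what collate_fn produces, on two hypothetical pairs:
#
#     batch = [(torch.LongTensor([1, 3, 2]), torch.LongTensor([1, 4, 3, 2])),
#              (torch.LongTensor([1, 4, 4, 3, 2]), torch.LongTensor([1, 2]))]
#     src, src_lens, trg, trg_lens = collate_fn(batch)
#     # src.size() == (2, 5), src_lens == [5, 3]  (sorted by source length,
#     # zero-padded); trg.size() == (2, 4), trg_lens == [2, 4]

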
def get_loader(src_path, trg_path, src_word2id, trg_word2id, batch_size=100):
    """Returns data loader for custom dataset.

    Args:
        src_path: txt file path for source domain.
        trg_path: txt file path for target domain.
        src_word2id: word-to-id dictionary (source domain).
        trg_word2id: word-to-id dictionary (target domain).
        batch_size: mini-batch size.

    Returns:
        data_loader: data loader for the custom dataset.
    """
    # Build a custom dataset.
    dataset = Dataset(src_path, trg_path, src_word2id, trg_word2id)

    # Data loader for the custom dataset.
    # This will return (src_seqs, src_lengths, trg_seqs, trg_lengths) for each
    # iteration; please see collate_fn for details.
    data_loader = torch.utils.data.DataLoader(dataset=dataset,
                                              batch_size=batch_size,
                                              shuffle=True,
                                              collate_fn=collate_fn)
    return data_loader
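

# A runnable smoke test (a sketch: 'train.src.txt', 'train.trg.txt', and the
# toy vocabulary below are hypothetical; substitute real corpus files and
# vocabularies built from them).
if __name__ == '__main__':
    word2id = {'<start>': 1, '<end>': 2, 'hello': 3, 'world': 4}
    loader = get_loader('train.src.txt', 'train.trg.txt',
                        word2id, word2id, batch_size=2)
    for src_seqs, src_lengths, trg_seqs, trg_lengths in loader:
        print(src_seqs.size(), src_lengths)
        print(trg_seqs.size(), trg_lengths)
        break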