Skip to content

Commit a1f12ec

Browse files
meilame-tayebjeemicedre
authored andcommitted
feat: enable flexibility in the dataloader creation
Possibility to override default params and also pass any additional params that is accepted by PyTorch DataLoader
1 parent 72fc92c commit a1f12ec

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

torchFastText/datasets/dataset.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -152,15 +152,22 @@ def create_dataloader(
152152
shuffle: bool = False,
153153
drop_last: bool = False,
154154
num_workers: int = os.cpu_count() - 1,
155+
pin_memory: bool = True,
156+
persistent_workers: bool = True,
155157
**kwargs,
156158
) -> torch.utils.data.DataLoader:
157159
"""
158-
Creates a Dataloader.
160+
Creates a Dataloader from the FastTextModelDataset.
161+
Use collate_fn() to tokenize and pad the sequences.
159162
160163
Args:
161164
batch_size (int): Batch size.
162165
shuffle (bool, optional): Shuffle option. Defaults to False.
163166
drop_last (bool, optional): Drop last option. Defaults to False.
167+
num_workers (int, optional): Number of workers. Defaults to os.cpu_count() - 1.
168+
pin_memory (bool, optional): Set True if working on GPU, False if CPU. Defaults to True.
169+
persistent_workers (bool, optional): Set True for training, False for inference. Defaults to True.
170+
**kwargs: Additional arguments for PyTorch DataLoader.
164171
165172
Returns:
166173
torch.utils.data.DataLoader: Dataloader.
@@ -174,7 +181,8 @@ def create_dataloader(
174181
collate_fn=self.collate_fn,
175182
shuffle=shuffle,
176183
drop_last=drop_last,
177-
pin_memory=True,
184+
pin_memory=pin_memory,
178185
num_workers=num_workers,
179-
persistent_workers=True,
186+
persistent_workers=persistent_workers,
187+
**kwargs,
180188
)

0 commit comments

Comments
 (0)