@@ -152,15 +152,22 @@ def create_dataloader(
152152 shuffle : bool = False ,
153153 drop_last : bool = False ,
154154 num_workers : int = os .cpu_count () - 1 ,
155+ pin_memory : bool = True ,
156+ persistent_workers : bool = True ,
155157 ** kwargs ,
156158 ) -> torch .utils .data .DataLoader :
157159 """
158- Creates a Dataloader.
160+ Creates a Dataloader from the FastTextModelDataset.
161+ Use collate_fn() to tokenize and pad the sequences.
159162
160163 Args:
161164 batch_size (int): Batch size.
162165 shuffle (bool, optional): Shuffle option. Defaults to False.
163166 drop_last (bool, optional): Drop last option. Defaults to False.
167+ num_workers (int, optional): Number of workers. Defaults to os.cpu_count() - 1.
168+ pin_memory (bool, optional): Set True if working on GPU, False if CPU. Defaults to True.
169+ persistent_workers (bool, optional): Set True for training, False for inference. Defaults to True.
170+ **kwargs: Additional arguments for PyTorch DataLoader.
164171
165172 Returns:
166173 torch.utils.data.DataLoader: Dataloader.
@@ -174,7 +181,8 @@ def create_dataloader(
174181 collate_fn = self .collate_fn ,
175182 shuffle = shuffle ,
176183 drop_last = drop_last ,
177- pin_memory = True ,
184+ pin_memory = pin_memory ,
178185 num_workers = num_workers ,
179- persistent_workers = True ,
186+ persistent_workers = persistent_workers ,
187+ ** kwargs ,
180188 )
0 commit comments