-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathHFDataModuleExample.py
More file actions
91 lines (70 loc) · 3.35 KB
/
HFDataModuleExample.py
File metadata and controls
91 lines (70 loc) · 3.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import os
import pytorch_lightning as pl
from transformers import (
BertTokenizerFast,
CLIPTokenizer
)
# Allow the HF `tokenizers` library to use parallelism (also silences its
# fork-related warnings when used together with multi-worker DataLoaders).
os.environ["TOKENIZERS_PARALLELISM"]='true'
class CNDataModule(pl.LightningDataModule):
def __init__(self, Cache_dir='.', batch_size=256,ZHtokenizer=None,ENtokenizer=None):
super().__init__()
self.data_dir = Cache_dir
self.batch_size = batch_size
if ZHtokenizer is None:
self.ZHtokenizer = BertTokenizerFast.from_pretrained('bert-base-chinese',cache_dir=self.data_dir)
else :
self.ZHtokenizer = ZHtokenizer
if ENtokenizer is None:
self.ENtokenizer =CLIPTokenizer.from_pretrained("openai/clip-vit-base-patch32",cache_dir=self.data_dir)
else:
self.ENtokenizer = ENtokenizer
# def train_dataloader(self, B=None):
# if B is None:
# B=self.batch_size
# return torch.utils.data.DataLoader(self.train, batch_size=B, shuffle=True, num_workers=2, prefetch_factor=2, pin_memory=True,drop_last=True)
# def val_dataloader(self, B=None):
# if B is None:
# B=self.batch_size
# return torch.utils.data.DataLoader(self.val, batch_size=B, shuffle=False, num_workers=2, prefetch_factor=2, pin_memory=True,drop_last=True)
def test_dataloader(self,B=None):
if B is None:
B=self.batch_size
return torch.utils.data.DataLoader(self.test, batch_size=B, shuffle=True, num_workers=1, prefetch_factor=1, pin_memory=True,drop_last=True)
def prepare_data(self):
'''called only once and on 1 GPU'''
# # download data
if not os.path.exists(self.data_dir):
os.makedirs(self.data_dir,exist_ok=True)
from datasets import load_dataset
self.dataset = load_dataset("msr_zhen_translation_parity",
cache_dir=self.data_dir,
streaming=False,
)
def tokenization(self,sample):
return {'en' : self.ENtokenizer(sample["en"], padding="max_length", truncation=True, max_length=77),
'zh' : self.ZHtokenizer(sample["zh"], padding="max_length", truncation=True, max_length=77)}
def setup(self, stage=None):
'''called on each GPU separately - stage defines if we are at fit or test step'''
#print("Entered COCO datasetup")
from datasets import load_dataset
if not self.dataset:
self.dataset= load_dataset("un_pc","en-zh",
cache_dir=self.data_dir,
streaming=True,
)
self.test = self.dataset["train"]["translation"].map(lambda x: self.tokenization(x), batched=False)
if __name__ == "__main__":
    # Smoke test: download the corpus, tokenize it, and iterate one dataloader.
    import argparse
    from tqdm import tqdm

    parser = argparse.ArgumentParser(description='location of data')
    parser.add_argument('--data', type=str, default='/data', help='location of data')
    args = parser.parse_args()
    print("args", args)
    datalocation = args.data
    # BUG FIX: the original instantiated an undefined `MagicSwordCNDataModule`
    # with a nonexistent `annotations=` kwarg, and called `download_data()` /
    # `train_dataloader()`, neither of which exists on CNDataModule. Use the
    # class and methods actually defined in this file.
    datamodule = CNDataModule(Cache_dir=datalocation, batch_size=2)
    datamodule.prepare_data()
    datamodule.setup()
    dl = datamodule.test_dataloader()
    for batch in tqdm(dl):
        print(batch)