 
 from torch.utils.data import DataLoader
 import pandas as pd
-import os, shutil
+import os, shutil, json
 
 def load_data(train_dataset: str, test_dataset: str, batch_size: int = 64, num_workers: int = 40):
     tokenizer = MethylVocab(k=3)
@@ -14,6 +14,9 @@ def load_data(train_dataset: str, test_dataset: str, batch_size: int = 64, num_workers: int = 40):
     train_dataset = MethylBertFinetuneDataset(train_dataset, tokenizer, seq_len=150)
     test_dataset = MethylBertFinetuneDataset(test_dataset, tokenizer, seq_len=150)
 
+    if len(test_dataset) > 500:
+        test_dataset.subset_data(500)
+
     # Create a data loader
     print("Creating Dataloader")
     local_step_batch_size = int(batch_size / 4)
@@ -53,27 +56,73 @@ def test_finetune_no_pretrain(tokenizer: MethylVocab,
     assert os.path.exists(os.path.join(save_path, "train.csv"))
     assert steps == pd.read_csv(os.path.join(save_path, "train.csv")).shape[0]
 
-def test_finetune_savefreq(tokenizer: MethylVocab,
+def test_finetune_no_pretrain_focal(tokenizer: MethylVocab,
+        save_path: str,
+        train_data_loader: DataLoader,
+        test_data_loader: DataLoader,
+        pretrain_model: str,
+        steps: int = 10):
+
+    trainer = MethylBertFinetuneTrainer(vocab_size=len(tokenizer),
+        save_path=save_path,
+        train_dataloader=train_data_loader,
+        test_dataloader=test_data_loader,
+        with_cuda=False,
+        loss="focal_bce")
+
+    trainer.create_model(config_file=os.path.join(pretrain_model, "config.json"))
+
+    trainer.train(steps)
+    assert os.path.exists(os.path.join(save_path, "config.json"))
+
+    with open(os.path.join(save_path, "config.json")) as fp:
+        config = json.load(fp)
+    assert config["loss"] == "focal_bce"
+
+def test_finetune_no_pretrain_focal(tokenizer: MethylVocab,
         save_path: str,
         train_data_loader: DataLoader,
         test_data_loader: DataLoader,
         pretrain_model: str,
-        steps: int = 10,
-        save_freq: int = 1):
+        steps: int = 10):
+
+    trainer = MethylBertFinetuneTrainer(vocab_size=len(tokenizer),
+        save_path=save_path,
+        train_dataloader=train_data_loader,
+        test_dataloader=test_data_loader,
+        with_cuda=False,
+        loss="focal_bce")
+
+    trainer.create_model(config_file=os.path.join(pretrain_model, "config.json"))
+
+    trainer.train(steps)
+    assert os.path.exists(os.path.join(save_path, "config.json"))
+
+    with open(os.path.join(save_path, "config.json")) as fp:
+        config = json.load(fp)
+    assert config["loss"] == "focal_bce"
+
+
+def test_finetune_focal_multicelltype(tokenizer: MethylVocab,
+        save_path: str,
+        train_data_loader: DataLoader,
+        test_data_loader: DataLoader,
+        pretrain_model: str,
+        steps: int = 10):
 
     trainer = MethylBertFinetuneTrainer(vocab_size=len(tokenizer),
         save_path=save_path + "bert.model/",
         train_dataloader=train_data_loader,
         test_dataloader=test_data_loader,
-        save_freq=save_freq,
-        with_cuda=False)
+        with_cuda=False,
+        loss="focal_bce")
     trainer.load(pretrain_model)
     trainer.train(steps)
 
-    assert os.path.exists(os.path.join(save_path, "bert.model_step0/config.json"))
-    assert os.path.exists(os.path.join(save_path, "bert.model_step0/dmr_encoder.pickle"))
-    assert os.path.exists(os.path.join(save_path, "bert.model_step0/pytorch_model.bin"))
-    assert os.path.exists(os.path.join(save_path, "bert.model_step0/read_classification_model.pickle"))
+    assert os.path.exists(os.path.join(save_path, "bert.model/config.json"))
+    assert os.path.exists(os.path.join(save_path, "bert.model/dmr_encoder.pickle"))
+    assert os.path.exists(os.path.join(save_path, "bert.model/pytorch_model.bin"))
+    assert os.path.exists(os.path.join(save_path, "bert.model/read_classification_model.pickle"))
 
 def test_finetune(tokenizer: MethylVocab,
         save_path: str,
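The focal-loss tests added above only assert that loss="focal_bce" is propagated into the saved config.json; the loss function itself is not part of this diff. For orientation only, a minimal, generic focal binary cross-entropy in PyTorch (following Lin et al., 2017) might look like the sketch below; the alpha/gamma defaults are illustrative and this is not methylbert's own implementation.

import torch
import torch.nn.functional as F

def focal_bce(logits: torch.Tensor, targets: torch.Tensor,
              alpha: float = 0.25, gamma: float = 2.0) -> torch.Tensor:
    # Unreduced BCE so each element can be re-weighted individually.
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p = torch.sigmoid(logits)
    # p_t: predicted probability assigned to the true class of each element.
    p_t = targets * p + (1 - targets) * (1 - p)
    alpha_t = targets * alpha + (1 - targets) * (1 - alpha)
    # Down-weight well-classified (easy) examples by (1 - p_t)^gamma.
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()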
@@ -104,9 +153,9 @@ def reset_dir(dirname):
     # For data processing
     f_bam_list = "data/bam_list.txt"
     f_dmr = "data/dmrs.csv"
-    f_ref = "data/genome/hg19.fa"
+    f_ref = "data/genome.fa"
     out_dir = "data/processed/"
-
+
     # Process data for fine-tuning
     fdg.finetune_data_generate(
         sc_dataset=f_bam_list,
@@ -137,7 +186,23 @@ def reset_dir(dirname):
     reset_dir(save_path)
     test_finetune_no_pretrain(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step)
 
+    reset_dir(save_path)
+    test_finetune_no_pretrain_focal(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step)
+
     reset_dir(save_path)
     test_finetune_savefreq(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step, save_freq=1)
+
+    # Multiple cell type
+    out_dir = "data/multi_cell_type/"
+    tokenizer, train_data_loader, test_data_loader = \
+        load_data(train_dataset=os.path.join(out_dir, "train_seq.csv"),
+                  test_dataset=os.path.join(out_dir, "test_seq.csv"))
+
+    # For fine-tuning
+    model_dir = "data/pretrained_model/"
+    save_path = "data/multi_cell_type/res/"
+
+    reset_dir(save_path)
+    test_finetune_focal_multicelltype(tokenizer, save_path, train_data_loader, test_data_loader, model_dir)
 
     print("Everything passed!")
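Taken together, the diff exercises focal-loss fine-tuning and adds a multi-cell-type fixture for it. Outside the test harness, the same path might be driven roughly as below; the load_data helper and the trainer arguments are taken from this diff, while the concrete paths and the 10-step budget are placeholders rather than values prescribed by the repository.

# Minimal sketch reusing the helpers defined in this test file
# (load_data, MethylBertFinetuneTrainer); paths are illustrative.
tokenizer, train_loader, test_loader = load_data(
    train_dataset="data/multi_cell_type/train_seq.csv",
    test_dataset="data/multi_cell_type/test_seq.csv")

trainer = MethylBertFinetuneTrainer(vocab_size=len(tokenizer),
    save_path="data/multi_cell_type/res/bert.model/",
    train_dataloader=train_loader,
    test_dataloader=test_loader,
    with_cuda=False,          # CPU-only, as in the tests above
    loss="focal_bce")         # the loss option exercised by the new tests
trainer.load("data/pretrained_model/")   # start from the pre-trained MethylBERT
trainer.train(10)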