
Commit 89f4d21

update test cases

1 parent abbc783 commit 89f4d21

File tree

4 files changed: +137 -12 lines

.gitignore
test/test_deconvolute.py
test/test_finetune.py
test/test_finetune_preprocess.py

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -39,6 +39,7 @@ pip-log.txt
 pip-delete-this-directory.txt
 
 # Unit test / coverage reports
+test/data/*
 htmlcov/
 .tox/
 .nox/
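The new rule keeps the generated fixtures under test/data/ out of version control. As a rough illustration of what the glob matches, a sketch using Python's fnmatch; this is only an analogy, since git anchors patterns relative to the .gitignore and treats "/" specially, while fnmatch lets * cross directory boundaries:

# Sketch: approximate the new ignore rule with fnmatch (an analogy only;
# git's pattern rules differ in the ways noted above).
from fnmatch import fnmatch

paths = ["test/data/train_seq.csv", "test/data/dmrs.csv", "test/test_finetune.py"]
print([p for p in paths if fnmatch(p, "test/data/*")])
# -> ['test/data/train_seq.csv', 'test/data/dmrs.csv']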

test/test_deconvolute.py

Lines changed: 28 additions & 0 deletions
@@ -17,13 +17,24 @@ def test_adjustment(trainer, tokenizer, data_loader, output_path, df_train):
 
     assert pd.read_csv(os.path.join(output_path, "FI.csv"), sep="\t").shape[0] == len(df_train["dmr_label"].unique())
 
+def test_multi_cell_type(trainer, tokenizer, data_loader, output_path, df_train):
+    deconvolute(trainer = trainer,
+                tokenizer = tokenizer,
+                data_loader = data_loader,
+                output_path = output_path,
+                df_train = df_train,
+                adjustment = False)
+
+    assert pd.read_csv(os.path.join(output_path, "deconvolution.csv"), sep="\t").shape[0] == 3
+
 if __name__=="__main__":
     f_bulk = "data/processed/test_seq.csv"
     f_train = "data/processed/train_seq.csv"
     model_dir = "res/bert.model/"
     out_dir = "res/deconvolution/"
 
     tokenizer = MethylVocab(k=3)
+
     dataset = MethylBertFinetuneDataset(f_bulk,
                                         tokenizer,
                                         seq_len=150)
@@ -37,4 +48,21 @@ def test_adjustment(trainer, tokenizer, data_loader, output_path, df_train):
     trainer.load(model_dir)
 
     test_adjustment(trainer, tokenizer, data_loader, out_dir, df_train)
+    # multiple cell type
+    model_dir = "data/multi_cell_type/res/bert.model/"
+    f_bulk = "data/multi_cell_type/test_seq.csv"
+    f_train = "data/multi_cell_type/train_seq.csv"
+    dataset = MethylBertFinetuneDataset(f_bulk,
+                                        tokenizer,
+                                        seq_len=150)
+    data_loader = DataLoader(dataset, batch_size=50, num_workers=20)
+    df_train = pd.read_csv(f_train, sep="\t")
+
+    trainer = MethylBertFinetuneTrainer(len(tokenizer),
+                                        train_dataloader=data_loader,
+                                        test_dataloader=data_loader,
+                                        )
+    trainer.load(model_dir)
+    test_multi_cell_type(trainer, tokenizer, data_loader, out_dir, df_train)
+
     print("Everything passed!")

test/test_finetune.py

Lines changed: 77 additions & 12 deletions
@@ -5,7 +5,7 @@
 
 from torch.utils.data import DataLoader
 import pandas as pd
-import os, shutil
+import os, shutil, json
 
 def load_data(train_dataset: str, test_dataset: str, batch_size: int = 64, num_workers: int = 40):
     tokenizer=MethylVocab(k=3)
@@ -14,6 +14,9 @@ def load_data(train_dataset: str, test_dataset: str, batch_size: int = 64, num_workers: int = 40):
     train_dataset = MethylBertFinetuneDataset(train_dataset, tokenizer, seq_len=150)
     test_dataset = MethylBertFinetuneDataset(test_dataset, tokenizer, seq_len=150)
 
+    if len(test_dataset) > 500:
+        test_dataset.subset_data(500)
+
     # Create a data loader
     print("Creating Dataloader")
     local_step_batch_size = int(batch_size/4)
@@ -53,27 +56,73 @@ def test_finetune_no_pretrain(tokenizer : MethylVocab,
     assert os.path.exists(os.path.join(save_path, "train.csv"))
     assert steps == pd.read_csv(os.path.join(save_path, "train.csv")).shape[0]
 
-def test_finetune_savefreq(tokenizer : MethylVocab,
+def test_finetune_no_pretrain_focal(tokenizer : MethylVocab,
+                           save_path : str,
+                           train_data_loader : DataLoader,
+                           test_data_loader : DataLoader,
+                           pretrain_model : str,
+                           steps : int=10):
+
+    trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer),
+                                        save_path=save_path,
+                                        train_dataloader=train_data_loader,
+                                        test_dataloader=test_data_loader,
+                                        with_cuda=False,
+                                        loss="focal_bce")
+
+    trainer.create_model(config_file=os.path.join(pretrain_model, "config.json"))
+
+    trainer.train(steps)
+    assert os.path.exists(os.path.exists(os.path.join(save_path, "config.json")))
+
+    with open(os.path.join(save_path, "config.json")) as fp:
+        config = json.load(fp)
+    assert config["loss"] == "focal_bce"
+
+def test_finetune_no_pretrain_focal(tokenizer : MethylVocab,
                            save_path : str,
                            train_data_loader : DataLoader,
                            test_data_loader : DataLoader,
                            pretrain_model : str,
-                           steps : int=10,
-                           save_freq: int=1):
+                           steps : int=10):
+
+    trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer),
+                                        save_path=save_path,
+                                        train_dataloader=train_data_loader,
+                                        test_dataloader=test_data_loader,
+                                        with_cuda=False,
+                                        loss="focal_bce")
+
+    trainer.create_model(config_file=os.path.join(pretrain_model, "config.json"))
+
+    trainer.train(steps)
+    assert os.path.exists(os.path.exists(os.path.join(save_path, "config.json")))
+
+    with open(os.path.join(save_path, "config.json")) as fp:
+        config = json.load(fp)
+    assert config["loss"] == "focal_bce"
+
+
+def test_finetune_focal_multicelltype(tokenizer : MethylVocab,
+                           save_path : str,
+                           train_data_loader : DataLoader,
+                           test_data_loader : DataLoader,
+                           pretrain_model : str,
+                           steps : int=10):
 
     trainer = MethylBertFinetuneTrainer(vocab_size = len(tokenizer),
                                         save_path=save_path+"bert.model/",
                                         train_dataloader=train_data_loader,
                                         test_dataloader=test_data_loader,
-                                        save_freq=save_freq,
-                                        with_cuda=False)
+                                        with_cuda=False,
+                                        loss="focal_bce")
     trainer.load(pretrain_model)
     trainer.train(steps)
 
-    assert os.path.exists(os.path.join(save_path, "bert.model_step0/config.json"))
-    assert os.path.exists(os.path.join(save_path, "bert.model_step0/dmr_encoder.pickle"))
-    assert os.path.exists(os.path.join(save_path, "bert.model_step0/pytorch_model.bin"))
-    assert os.path.exists(os.path.join(save_path, "bert.model_step0/read_classification_model.pickle"))
+    assert os.path.exists(os.path.join(save_path, "bert.model/config.json"))
+    assert os.path.exists(os.path.join(save_path, "bert.model/dmr_encoder.pickle"))
+    assert os.path.exists(os.path.join(save_path, "bert.model/pytorch_model.bin"))
+    assert os.path.exists(os.path.join(save_path, "bert.model/read_classification_model.pickle"))
 
 def test_finetune(tokenizer : MethylVocab,
                   save_path : str,
@@ -104,9 +153,9 @@ def reset_dir(dirname):
     # For data processing
     f_bam_list = "data/bam_list.txt"
     f_dmr = "data/dmrs.csv"
-    f_ref = "data/genome/hg19.fa"
+    f_ref = "data/genome.fa"
     out_dir = "data/processed/"
-
+
     # Process data for fine-tuning
     fdg.finetune_data_generate(
         sc_dataset = f_bam_list,
@@ -137,7 +186,23 @@ def reset_dir(dirname):
     reset_dir(save_path)
     test_finetune_no_pretrain(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step)
 
+    reset_dir(save_path)
+    test_finetune_no_pretrain_focal(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step)
+
     reset_dir(save_path)
     test_finetune_savefreq(tokenizer, save_path, train_data_loader, test_data_loader, model_dir, train_step, save_freq=1)
+
+    # Multiple cell type
+    out_dir="data/multi_cell_type/"
+    tokenizer, train_data_loader, test_data_loader = \
+        load_data(train_dataset = os.path.join(out_dir, "train_seq.csv"),
+                  test_dataset = os.path.join(out_dir, "test_seq.csv"))
+
+    # For fine-tuning
+    model_dir="data/pretrained_model/"
+    save_path="data/multi_cell_type/res/"
+
+    reset_dir(save_path)
+    test_finetune_focal_multicelltype(tokenizer, save_path, train_data_loader, test_data_loader, model_dir)
 
     print("Everything passed!")

test/test_finetune_preprocess.py

Lines changed: 31 additions & 0 deletions
@@ -102,19 +102,50 @@ def test_dorado_aligned_file(bam_file: str, f_dmr: str, f_ref: str, out_dir = "tmp/"):
 
     print("test_dorado_aligned_file passed!")
 
+
+def test_multi_cell_type(f_bam_file_list: str, f_dmr: str, f_ref: str, out_dir = "tmp/"):
+    fdg.finetune_data_generate(
+        sc_dataset = f_bam_file_list,
+        f_dmr = f_dmr,
+        f_ref = f_ref,
+        output_dir=out_dir,
+        split_ratio = 0.8,
+        n_cores=1
+    )
+
+    assert os.path.exists(out_dir+"train_seq.csv")
+    assert os.path.exists(out_dir+"test_seq.csv")
+    assert os.path.exists(out_dir+"dmrs.csv")
+
+    res = pd.read_csv(out_dir+"train_seq.csv", sep="\t")
+    assert "T" in res["ctype"].tolist()
+    assert "N" in res["ctype"].tolist()
+    assert "P" in res["ctype"].tolist()
+
+    print("test_multi_cell_type passed!")
+
+
 if __name__=="__main__":
     f_bam = "data/T_sample.bam"
     f_bam_list = "data/bam_list.txt"
     f_dmr = "data/dmrs.csv"
     f_ref = "data/genome.fa"
 
+
     test_single_bam_file(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref)
     test_list_bam_file(f_bam_file_list = f_bam_list, f_dmr=f_dmr, f_ref=f_ref)
     test_dmr_subset(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, n_dmrs=10)
     test_multi_cores(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, n_cores=4)
     test_split_ratio(bam_file = f_bam, f_dmr=f_dmr, f_ref=f_ref, split_ratio=0.7)
+
 
+    f_bam_list = "data/multi_cell_type/bam_list.txt"
+    f_dmr = "data/multi_cell_type/dmrs.csv"
+    out_dir = "data/multi_cell_type/"
+    test_multi_cell_type(f_bam_file_list = f_bam_list, f_dmr=f_dmr, f_ref=f_ref, out_dir=out_dir)
+
     f_dorado = "data/dorado_aligned.bam"
     f_ref_hg38="data/hg38_genome.fa"
     test_dorado_aligned_file(bam_file = f_dorado, f_dmr=f_dmr, f_ref=f_ref_hg38)
     print("Everything passed!")
+
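The new preprocessing test asserts that reads from all three cell types (T, N, P) land in the training split. A small follow-up sketch using the same paths and the ctype column from the asserts; extending the check to the held-out split is an extra step, not part of the committed test:

# Sketch: confirm every cell type survives the 0.8/0.2 split on both sides.
import pandas as pd

for split in ("train_seq.csv", "test_seq.csv"):
    df = pd.read_csv("data/multi_cell_type/" + split, sep="\t")
    print(split, sorted(df["ctype"].unique()))
    assert {"T", "N", "P"} <= set(df["ctype"])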
