From 16da6d5454839d4c1e7bc3df3d6e9fa9bcff4dfc Mon Sep 17 00:00:00 2001 From: patience111 Date: Sat, 18 Jun 2022 00:27:24 +0800 Subject: [PATCH 1/7] 220617 --- scripts/argnet_ssnt.py | 56 ++++++++++++++++++++++++++++-------------- 1 file changed, 37 insertions(+), 19 deletions(-) diff --git a/scripts/argnet_ssnt.py b/scripts/argnet_ssnt.py index f0752f6..5f093c9 100644 --- a/scripts/argnet_ssnt.py +++ b/scripts/argnet_ssnt.py @@ -5,12 +5,31 @@ from tensorflow.keras.utils import to_categorical import random import os -os.environ['CUDA_VISIBLE_DEVICES'] = '1' +#os.environ['CUDA_VISIBLE_DEVICES'] = '1' import tqdm - -#load model -filterm = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), '../model/AESS.h5')) -classifier = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), '../model/classifier_ss.h5')) +import cProfile, pstats, io + +#def profile(fnc): +# +# """A decorator that uses cProfile to profile a function""" +# +# def inner(*args, **kwargs): +# +# pr = cProfile.Profile() +# pr.enable() +# retval = fnc(*args, **kwargs) +# pr.disable() +# s = io.StringIO() +# sortby = 'cumulative' +# ps = pstats.Stats(pr, stream=s).sort_stats(sortby) +# ps.print_stats() +# print(s.getvalue()) +# return retval +# +# return inner +#model +filterm = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), './model/AESS_tall.h5')) +classifier = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), './model/classifier-ss_tall.h5')) #encode, encode all the sequence to 1600 aa length char_dict = {} @@ -55,12 +74,9 @@ def test_newEncodeVaryLength(tests): def encode64(seq): char = 'ACDEFGHIKLMNPQRSTVWXY' dimension1 = 64 - train_array = np.zeros((dimension1, 22)) - for i in range(dimension1): - if i < len(seq): - train_array[i] = char_dict[seq[i]] - else: - train_array[i][21] = 1 + pad = np.array(21*[0] + [1]) + time = dimension1-len(seq) + train_array = np.stack([char_dict[c] for c in seq]+[pad]*(time)) return train_array def 
testencode64(seqs): @@ -81,8 +97,8 @@ def test_encode(seqs): #length = length for idx, seq in tqdm.tqdm(enumerate(seqs)): #/print(seq.id) - temp = [seq.seq.translate(), seq.seq[1:].translate(), seq.seq[2:].translate(), seq.seq.reverse_complement().translate(), - seq.seq.reverse_complement()[1:].translate(), seq.seq.reverse_complement()[2:].translate()] + rc = seq.seq.reverse_complement() + temp = [seq.seq.translate(), seq.seq[1:].translate(), seq.seq[2:].translate(), rc.translate(), rc[1:].translate(), rc[2:].translate()] temp_split = [] for ele in temp: if "*" in ele: @@ -107,7 +123,7 @@ def test_encode(seqs): def prediction(seqs): predictions = [] - temp = filterm.predict(seqs, batch_size=8192) + temp = filterm.predict(seqs, batch_size=8196) predictions.append(temp) return predictions @@ -127,14 +143,16 @@ def reconstruction_simi(pres, ori): return reconstructs, simis cuts = [0.8064516129032258, 0.7666666666666667, 0.7752551020408163] + +#@profile def argnet_ssnt(input_file, outfile): testencode_pre = [] test = [i for i in sio.parse(input_file, 'fasta')] test_ids = [ele.id for ele in test] #arg_encode, record_notpre, record_pre, encodeall_dict, ori = test_encode(arg, i[-1]) testencode, not_pre, pre, encodeall_dict, ori = test_encode(test) - for num in range(0, len(testencode), 8192): - testencode_pre += prediction(testencode[num:num+8192]) + for num in range(0, len(testencode), 8196): + testencode_pre += prediction(testencode[num:num+8196]) #testencode_pre = prediction(testencode) # if huge volumn of seqs (~ millions) this will be change to create batch in advance pre_con = np.concatenate(testencode_pre) #print("the encode shape is: ", pre_con.shape) @@ -171,7 +189,7 @@ def argnet_ssnt(input_file, outfile): notpass_idx.append(index) ###classification - train_data = [i for i in sio.parse(os.path.join(os.path.dirname(__file__), "../data/train.fasta"),'fasta')] + train_data = [i for i in sio.parse(os.path.join(os.path.dirname(__file__), 
"./data/train.fasta"),'fasta')] train_labels = [ele.id.split('|')[3].strip() for ele in train_data] encodeder = LabelBinarizer() encoded_train_labels = encodeder.fit_transform(train_labels) @@ -181,14 +199,14 @@ def argnet_ssnt(input_file, outfile): label_dic[index] = ele classifications = [] - classifications = classifier.predict(np.stack(passed_encode, axis=0), batch_size = 2048) + classifications = classifier.predict(np.stack(passed_encode, axis=0), batch_size = 3500) out = {} for i, ele in enumerate(passed_idx): out[ele] = [np.max(classifications[i]), label_dic[np.argmax(classifications[i])]] ### output - with open(os.path.join(os.path.dirname(__file__), "../results/" + outfile) , 'w') as f: + with open(os.path.join(os.path.dirname(__file__), "./results/" + outfile) , 'w') as f: f.write('test_id' + '\t' + 'ARG_prediction' + '\t' + 'resistance_category' + '\t' + 'probability' + '\n') for idx, ele in enumerate(test): if idx in passed_idx: From 8a52d4572b541828983010014a9afd353bd9984c Mon Sep 17 00:00:00 2001 From: patience111 Date: Sat, 18 Jun 2022 10:08:18 +0800 Subject: [PATCH 2/7] back encode64 --- scripts/argnet_ssnt.py | 12 +++++++++--- 1 file changed, 9 insertions(+), 3 deletions(-) diff --git a/scripts/argnet_ssnt.py b/scripts/argnet_ssnt.py index 5f093c9..6887052 100644 --- a/scripts/argnet_ssnt.py +++ b/scripts/argnet_ssnt.py @@ -74,9 +74,15 @@ def test_newEncodeVaryLength(tests): def encode64(seq): char = 'ACDEFGHIKLMNPQRSTVWXY' dimension1 = 64 - pad = np.array(21*[0] + [1]) - time = dimension1-len(seq) - train_array = np.stack([char_dict[c] for c in seq]+[pad]*(time)) + #pad = np.array(21*[0] + [1]) + #time = dimension1-len(seq) + #train_array = np.stack([char_dict[c] for c in seq]+[pad]*(time)) + train_array = np.zeros((dimension1, 22)) + for i in range(dimension1): + if i < len(seq): + train_array[i] = char_dict[seq[i]] + else: + train_array[i][21] = 1 return train_array def testencode64(seqs): From 68cadfe53c62febe9468e54e44df2e5f61ffad38 Mon Sep 17 
00:00:00 2001 From: patience111 Date: Sun, 19 Jun 2022 09:37:55 +0800 Subject: [PATCH 3/7] translate_reconstruct_classification --- scripts/argnet_ssnt.py | 100 ++++++++++++++++++++++++----------------- 1 file changed, 58 insertions(+), 42 deletions(-) diff --git a/scripts/argnet_ssnt.py b/scripts/argnet_ssnt.py index 6887052..5c723c1 100644 --- a/scripts/argnet_ssnt.py +++ b/scripts/argnet_ssnt.py @@ -8,6 +8,7 @@ #os.environ['CUDA_VISIBLE_DEVICES'] = '1' import tqdm import cProfile, pstats, io +import Bio.Data.CodonTable as bdc #def profile(fnc): # @@ -90,6 +91,19 @@ def testencode64(seqs): encode = np.array(encode) return encode +codon_table = bdc.ambiguous_generic_by_name['Standard'] +forwardT = codon_table.forward_table +def translate(seq): + aa = '' + for i in range(0, len(seq)-len(seq)%3, 3): + codon = seq[i:i+3] + print(codon) + try: + aa += forwardT[codon] + except: + aa += '*' + return aa + def test_encode(seqs): """ input as a list of test sequences @@ -103,8 +117,10 @@ def test_encode(seqs): #length = length for idx, seq in tqdm.tqdm(enumerate(seqs)): #/print(seq.id) + seqf = seq.seq rc = seq.seq.reverse_complement() - temp = [seq.seq.translate(), seq.seq[1:].translate(), seq.seq[2:].translate(), rc.translate(), rc[1:].translate(), rc[2:].translate()] + temp = [translate(seqf), translate(seqf[1:]), translate(seqf[2:]), translate(rc), translate(rc[1:]), translate(rc[2:])] + #temp = [seq.seq.translate(), seq.seq[1:].translate(), seq.seq[2:].translate(), rc.translate(), rc[1:].translate(), rc[2:].translate()] temp_split = [] for ele in temp: if "*" in ele: @@ -121,7 +137,7 @@ def test_encode(seqs): record_pre[seq.id] = idx ori.extend(temp_seq) encode = testencode64(temp_seq) - encodeall_dict[seq.id] = list(range(start, start + len(temp_seq))) + encodeall_dict[seq.id] = (start, start + len(temp_seq)) encodeall.extend(encode) start += len(temp_seq) encodeall = np.array(encodeall) @@ -133,20 +149,23 @@ def prediction(seqs): predictions.append(temp) return 
predictions + def reconstruction_simi(pres, ori): simis = [] reconstructs = [] - for index, ele in enumerate(pres): + argmax_pre = np.argmax(pres, axis=2) + for index, ele in enumerate(argmax_pre): length = len(ori[index]) count_simi = 0 - reconstruct = '' + #reconstruct = '' for pos in range(length): - if chars[np.argmax(ele[pos])] == ori[index][pos]: + if chars[ele[pos]] == ori[index][pos]: count_simi += 1 - reconstruct += chars[np.argmax(ele[pos])] + #reconstruct += chars[np.argmax(ele[pos])] simis.append(count_simi / length) - reconstructs.append(reconstruct) - return reconstructs, simis + #reconstructs.append(reconstruct) + return simis + cuts = [0.8064516129032258, 0.7666666666666667, 0.7752551020408163] @@ -163,43 +182,42 @@ def argnet_ssnt(input_file, outfile): pre_con = np.concatenate(testencode_pre) #print("the encode shape is: ", pre_con.shape) #print("the num of origin seqs is: ", len(ori)) - reconstructs, simis = reconstruction_simi(pre_con, ori) + simis = reconstruction_simi(pre_con, ori) passed_encode = [] ### notice list and np.array passed_idx = [] notpass_idx = [] - print(len(simis) == len(ori)) + assert len(simis) == len(ori) simis_edit = [] count_iter = 0 + for k, v in encodeall_dict.items(): - simis_edit.append(max(simis[v[0]:v[-1]+1])) + simis_edit.append(max(simis[v[0]:v[-1]])) count_iter += 1 for index, ele in enumerate(simis_edit): - if len(test[index]) in range(100, 120): - if ele >= cuts[0]: - #passed.append(test[index]) - passed_encode.append(testencode[index]) - passed_idx.append(index) - else: - notpass_idx.append(index) - if len(test[index]) in range(120, 150): - if ele >= cuts[1]: - passed_encode.append(testencode[index]) - passed_idx.append(index) - else: - notpass_idx.append(index) - if len(test[index]) == 150: - if ele >= cuts[-1]: - passed_encode.append(testencode[index]) - passed_idx.append(index) - else: - notpass_idx.append(index) - + if len(test[index]) < 120: + cuts_idx = 0 + elif len(test[index]) < 150: + cuts_idx = 1 + 
else: + cuts_idx = 2 + if ele >= cuts[cuts_idx]: + passed_encode.append(testencode[index]) + passed_idx.append(index) + else: + notpass_idx.append(index) + ###classification - train_data = [i for i in sio.parse(os.path.join(os.path.dirname(__file__), "./data/train.fasta"),'fasta')] - train_labels = [ele.id.split('|')[3].strip() for ele in train_data] - encodeder = LabelBinarizer() - encoded_train_labels = encodeder.fit_transform(train_labels) - prepare = sorted(list(set(train_labels))) + #train_data = [i for i in sio.parse(os.path.join(os.path.dirname(__file__), "./data/train.fasta"),'fasta')] + #train_labels = [ele.id.split('|')[3].strip() for ele in train_data] + #encodeder = LabelBinarizer() + #encoded_train_labels = encodeder.fit_transform(train_labels) + + train_labels = ['beta-lactam', 'multidrug', 'bacitracin', 'MLS', 'aminoglycoside', 'polymyxin', 'tetracycline', + 'fosfomycin', 'chloramphenicol', 'glycopeptide', 'quinolone', 'peptide','sulfonamide', 'trimethoprim', 'rifamycin', + 'qa_compound', 'aminocoumarin', 'kasugamycin', 'nitroimidazole', 'streptothricin', 'elfamycin', 'fusidic_acid', + 'mupirocin', 'tetracenomycin', 'pleuromutilin', 'bleomycin', 'triclosan', 'ethambutol', 'isoniazid', 'tunicamycin', + 'nitrofurantoin', 'puromycin', 'thiostrepton', 'pyrazinamide', 'oxazolidinone', 'fosmidomycin'] + prepare = sorted(train_labels) label_dic = {} for index, ele in enumerate(prepare): label_dic[index] = ele @@ -208,8 +226,11 @@ def argnet_ssnt(input_file, outfile): classifications = classifier.predict(np.stack(passed_encode, axis=0), batch_size = 3500) out = {} + + classification_argmax = np.argmax(classifications, axis=1) + classification_max = np.max(classifications, axis=1) for i, ele in enumerate(passed_idx): - out[ele] = [np.max(classifications[i]), label_dic[np.argmax(classifications[i])]] + out[ele] = [classification_max[i], label_dic[classification_argmax[i]]] ### output with open(os.path.join(os.path.dirname(__file__), "./results/" + outfile) , 
'w') as f: @@ -223,8 +244,3 @@ def argnet_ssnt(input_file, outfile): if idx in notpass_idx: f.write(test[idx].id + '\t') f.write('non-ARG' + '\t' + '' + '\t' + '' + '\n') - - - - - From 09485ed023b38f76f083b0a43e08f6d772621ca2 Mon Sep 17 00:00:00 2001 From: patience111 Date: Sun, 19 Jun 2022 10:56:51 +0800 Subject: [PATCH 4/7] translate_reconstruct_classification --- scripts/argnet_ssnt.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/scripts/argnet_ssnt.py b/scripts/argnet_ssnt.py index 5c723c1..8fb6c2b 100644 --- a/scripts/argnet_ssnt.py +++ b/scripts/argnet_ssnt.py @@ -97,7 +97,7 @@ def translate(seq): aa = '' for i in range(0, len(seq)-len(seq)%3, 3): codon = seq[i:i+3] - print(codon) + #print(codon) try: aa += forwardT[codon] except: From 43804e63be40ddd63f57f6bd4c7532c8d537bc9c Mon Sep 17 00:00:00 2001 From: patience111 Date: Mon, 20 Jun 2022 06:26:18 +0800 Subject: [PATCH 5/7] translate_modelReduce_e5 --- scripts/argnet_ssnt.py | 104 ++++++++++++++++++++++++++++++++++------- 1 file changed, 87 insertions(+), 17 deletions(-) diff --git a/scripts/argnet_ssnt.py b/scripts/argnet_ssnt.py index 8fb6c2b..968a007 100644 --- a/scripts/argnet_ssnt.py +++ b/scripts/argnet_ssnt.py @@ -9,6 +9,8 @@ import tqdm import cProfile, pstats, io import Bio.Data.CodonTable as bdc +from itertools import product +from kito import reduce_keras_model #def profile(fnc): # @@ -31,6 +33,8 @@ #model filterm = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), './model/AESS_tall.h5')) classifier = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), './model/classifier-ss_tall.h5')) +filterm_reduced = reduce_keras_model(filterm) +classifier_reduced = reduce_keras_model(classifier) #encode, encode all the sequence to 1600 aa length char_dict = {} @@ -91,18 +95,86 @@ def testencode64(seqs): encode = np.array(encode) return encode -codon_table = bdc.ambiguous_generic_by_name['Standard'] -forwardT = codon_table.forward_table 
+comprehensive_coden_table = { + "UUU":"F", "UUC":"F", "UUA":"L", "UUG":"L", + "UCU":"S", "UCC":"S", "UCA":"S", "UCG":"S", + "UAU":"Y", "UAC":"Y", "UAA":"*", "UAG":"*", + "UGU":"C", "UGC":"C", "UGA":"*", "UGG":"W", + "CUU":"L", "CUC":"L", "CUA":"L", "CUG":"L", + "CCU":"P", "CCC":"P", "CCA":"P", "CCG":"P", + "CAU":"H", "CAC":"H", "CAA":"Q", "CAG":"Q", + "CGU":"R", "CGC":"R", "CGA":"R", "CGG":"R", + "AUU":"I", "AUC":"I", "AUA":"I", "AUG":"M", + "ACU":"T", "ACC":"T", "ACA":"T", "ACG":"T", + "AAU":"N", "AAC":"N", "AAA":"K", "AAG":"K", + "AGU":"S", "AGC":"S", "AGA":"R", "AGG":"R", + "GUU":"V", "GUC":"V", "GUA":"V", "GUG":"V", + "GCU":"A", "GCC":"A", "GCA":"A", "GCG":"A", + "GAU":"D", "GAC":"D", "GAA":"E", "GAG":"E", + "GGU":"G", "GGC":"G", "GGA":"G", "GGG":"G", + "TTT":"F", "TTC":"F", "TTA":"L", "TTG":"L", + "TCT":"S", "TCC":"S", "TCA":"S", "TCG":"S", + "TAT":"Y", "TAC":"Y", "TAA":"*", "TAG":"*", + "TGT":"C", "TGC":"C", "TGA":"*", "TGG":"W", + "CTT":"L", "CTC":"L", "CTA":"L", "CTG":"L", + "CCT":"P", "CAT":"H", "CGT":"R", "ATT":"I", + "ATC":"I", "ATA":"I", "ATG":"M", "ACT":"T", + "AAT":"N", "AGT":"S", "GTT":"V", "GTC":"V", + "GTA":"V", "GTG":"V", "GCT":"A", "GAT":"D", + "GGT":"G", 'aaa': 'K', 'aac': 'N','aag': 'K', + 'aat': 'N', 'aau': 'N', 'aca': 'T', 'acc': 'T', + 'acg': 'T', 'act': 'T', 'acu': 'T', 'aga': 'R', + 'agc': 'S', 'agg': 'R', 'agt': 'S', 'agu': 'S', + 'ata': 'I', 'atc': 'I', 'atg': 'M', 'att': 'I', + 'aua': 'I', 'auc': 'I', 'aug': 'M', 'auu': 'I', + 'caa': 'Q', 'cac': 'H', 'cag': 'Q', 'cat': 'H', + 'cau': 'H', 'cca': 'P', 'ccc': 'P', 'ccg': 'P', + 'cct': 'P', 'ccu': 'P', 'cga': 'R', 'cgc': 'R', + 'cgg': 'R', 'cgt': 'R', 'cgu': 'R', 'cta': 'L', + 'ctc': 'L', 'ctg': 'L', 'ctt': 'L', 'cua': 'L', + 'cuc': 'L', 'cug': 'L', 'cuu': 'L', 'gaa': 'E', + 'gac': 'D', 'gag': 'E', 'gat': 'D', 'gau': 'D', + 'gca': 'A', 'gcc': 'A', 'gcg': 'A', 'gct': 'A', + 'gcu': 'A', 'gga': 'G', 'ggc': 'G', 'ggg': 'G', + 'ggt': 'G', 'ggu': 'G', 'gta': 'V', 'gtc': 'V', + 'gtg': 'V', 'gtt': 
'V', 'gua': 'V', 'guc': 'V', + 'gug': 'V', 'guu': 'V', 'taa': '*', 'tac': 'Y', + 'tag': '*', 'tat': 'Y', 'tca': 'S', 'tcc': 'S', + 'tcg': 'S', 'tct': 'S', 'tga': '*', 'tgc': 'C', + 'tgg': 'W', 'tgt': 'C', 'tta': 'L', 'ttc': 'F', + 'ttg': 'L', 'ttt': 'F', 'uaa': '*', 'uac': 'Y', + 'uag': '*', 'uau': 'Y', 'uca': 'S', 'ucc': 'S', + 'ucg': 'S', 'ucu': 'S', 'uga': '*', 'ugc': 'C', + 'ugg': 'W', 'ugu': 'C', 'uua': 'L', 'uuc': 'F', + 'uug': 'L', 'uuu': 'F'} + +#codon_table = bdc.ambiguous_generic_by_name['Standard'] +#forwardT = codon_table.forward_table +#for a in 'ACTGU': +# for b in 'ACTGU': +# for c in 'ACTGU': +# if a+b+c not in forwardT: +# forwardT[a+b+c] = "*" +ambiguous = ["A","C","G","T","U","W","S","M","K","R","Y","B","D","H","V","N"] +keywords = [''.join(i) for i in product(ambiguous, repeat = 3)] +keywords_select = [ele for ele in keywords if ele not in comprehensive_coden_table.keys()] +keywords_select_dict = {ele : 'X' for ele in keywords_select} +keywords_select_lower_dict = {ele.lower() : 'X' for ele in keywords_select} + +finalT = {} +finalT.update(comprehensive_coden_table) +finalT.update(keywords_select_dict) +finalT.update(keywords_select_lower_dict) + def translate(seq): - aa = '' - for i in range(0, len(seq)-len(seq)%3, 3): + lenseq = len(seq) + aa = ['*']*(lenseq//3) + for i in range(0, lenseq-lenseq%3, 3): codon = seq[i:i+3] - #print(codon) - try: - aa += forwardT[codon] - except: - aa += '*' - return aa + #if codon in forwardT: + aa[i//3] = finalT[codon] + aastr = ''.join(aa) + return aastr def test_encode(seqs): """ @@ -117,8 +189,8 @@ def test_encode(seqs): #length = length for idx, seq in tqdm.tqdm(enumerate(seqs)): #/print(seq.id) - seqf = seq.seq - rc = seq.seq.reverse_complement() + seqf = str(seq.seq) + rc = str(seq.seq.reverse_complement()) temp = [translate(seqf), translate(seqf[1:]), translate(seqf[2:]), translate(rc), translate(rc[1:]), translate(rc[2:])] #temp = [seq.seq.translate(), seq.seq[1:].translate(), seq.seq[2:].translate(), 
rc.translate(), rc[1:].translate(), rc[2:].translate()] temp_split = [] @@ -145,11 +217,10 @@ def test_encode(seqs): def prediction(seqs): predictions = [] - temp = filterm.predict(seqs, batch_size=8196) + temp = filterm_reduced.predict(seqs, batch_size=8196) predictions.append(temp) return predictions - def reconstruction_simi(pres, ori): simis = [] reconstructs = [] @@ -166,14 +237,13 @@ def reconstruction_simi(pres, ori): #reconstructs.append(reconstruct) return simis - cuts = [0.8064516129032258, 0.7666666666666667, 0.7752551020408163] #@profile def argnet_ssnt(input_file, outfile): testencode_pre = [] test = [i for i in sio.parse(input_file, 'fasta')] - test_ids = [ele.id for ele in test] + #test_ids = [ele.id for ele in test] #arg_encode, record_notpre, record_pre, encodeall_dict, ori = test_encode(arg, i[-1]) testencode, not_pre, pre, encodeall_dict, ori = test_encode(test) for num in range(0, len(testencode), 8196): @@ -223,7 +293,7 @@ def argnet_ssnt(input_file, outfile): label_dic[index] = ele classifications = [] - classifications = classifier.predict(np.stack(passed_encode, axis=0), batch_size = 3500) + classifications = classifier_reduced.predict(np.stack(passed_encode, axis=0), batch_size = 3500) out = {} From b193d252d9e83f2d1c9e46ae964fbe3ac32f0970 Mon Sep 17 00:00:00 2001 From: patience111 Date: Thu, 28 Jul 2022 16:20:56 +0800 Subject: [PATCH 6/7] argnet_lsaa_speed.py --- scripts/argnet_lsaa_speed.py | 174 +++++++++++++++++++++++++++++++++++ 1 file changed, 174 insertions(+) create mode 100644 scripts/argnet_lsaa_speed.py diff --git a/scripts/argnet_lsaa_speed.py b/scripts/argnet_lsaa_speed.py new file mode 100644 index 0000000..0bdd489 --- /dev/null +++ b/scripts/argnet_lsaa_speed.py @@ -0,0 +1,174 @@ +import Bio.SeqIO as sio +import tensorflow as tf +import numpy as np +from sklearn.preprocessing import LabelBinarizer +from tensorflow.keras.utils import to_categorical +import random +import os +os.environ['CUDA_VISIBLE_DEVICES'] = '1' +import 
tqdm + +#load model +filterm = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), '../model/AELS_tall.h5')) +classifier = tf.keras.models.load_model(os.path.join(os.path.dirname(__file__), '../model/classifier-ls_tall.h5')) + +#encode, encode all the sequence to 1600 aa length +char_dict = {} +chars = 'ACDEFGHIKLMNPQRSTVWXYBJZ' +new_chars = "ACDEFGHIKLMNPQRSTVWXY" +for char in chars: + temp = np.zeros(22) + if char == 'B': + for ch in 'DN': + temp[new_chars.index(ch)] = 0.5 + elif char == 'J': + for ch in 'IL': + temp[new_chars.index(ch)] = 0.5 + elif char == 'Z': + for ch in 'EQ': + temp[new_chars.index(ch)] = 0.5 + else: + temp[new_chars.index(char)] = 1 + char_dict[char] = temp + +def encode(seq): + char = 'ACDEFGHIKLMNPQRSTVWXY' + train_array = np.zeros((1600,22)) + for i in range(1600): + if i= 1600: + align = 1600 + else: + align = length + for pos in range(align): + if chars[ele[pos]] == ori[index][pos]: + count_simi += 1 + #reconstruct += chars[np.argmax(ele[pos])] + simis.append(count_simi / length) + #reconstructs.append(reconstruct) + return simis + + +def argnet_lsaa(input_file, outfile): + cut = 0.25868536454055224 + test = [i for i in sio.parse(input_file, 'fasta')] + train_labels = ['beta-lactam', 'multidrug', 'bacitracin', 'MLS', 'aminoglycoside', 'polymyxin', 'tetracycline', + 'fosfomycin', 'chloramphenicol', 'glycopeptide', 'quinolone', 'peptide','sulfonamide', 'trimethoprim', 'rifamycin', + 'qa_compound', 'aminocoumarin', 'kasugamycin', 'nitroimidazole', 'streptothricin', 'elfamycin', 'fusidic_acid', + 'mupirocin', 'tetracenomycin', 'pleuromutilin', 'bleomycin', 'triclosan', 'ethambutol', 'isoniazid', 'tunicamycin', + 'nitrofurantoin', 'puromycin', 'thiostrepton', 'pyrazinamide', 'oxazolidinone', 'fosmidomycin'] + + prepare = sorted(train_labels) + label_dic = {} + for index, ele in enumerate(prepare): + label_dic[index] = ele + + with open(os.path.join(os.path.dirname(__file__), "../results/" + outfile) , 'w') as f: + 
f.write('test_id' + '\t' + 'ARG_prediction' + '\t' + 'resistance_category' + '\t' + 'probability' + '\n') + for idx, test_chunk in enumerate(list(chunks(test, 10000))): + #test_ids = [ele.id for ele in test] + testencode = test_encode(test_chunk) + testencode_pre = filter_prediction_batch(testencode) # if a huge volume of seqs (~ millions) this will be changed to create batches in advance + simis = reconstruction_simi(testencode_pre, test_chunk) + #results = calErrorRate(simis, cut) + #passed = [] + passed_encode = [] ### notice list and np.array + passed_idx = [] + notpass_idx = [] + for index, ele in enumerate(simis): + if ele >= cut: + #passed.append(test[index]) + passed_encode.append(testencode[index]) + passed_idx.append(index) + else: + notpass_idx.append(index) + + ###classification + #train_data = [i for i in sio.parse(os.path.join(os.path.dirname(__file__), "../data/train.fasta"),'fasta')] + #train_labels = [ele.id.split('|')[3].strip() for ele in train_data] + #encodeder = LabelBinarizer() + #encoded_train_labels = encodeder.fit_transform(train_labels) + + classifications = [] + classifications = classifier.predict(np.stack(passed_encode, axis=0), batch_size = 512) + + out = {} + classification_argmax = np.argmax(classifications, axis=1) + classification_max = np.max(classifications, axis=1) + + for i, ele in enumerate(passed_idx): + out[ele] = [classification_max[i], label_dic[classification_argmax[i]]] + ### output + with open(os.path.join(os.path.dirname(__file__), "../results/" + outfile) , 'a') as f: + for idx, ele in enumerate(test_chunk): + if idx in passed_idx: + f.write(test_chunk[idx].id + '\t') + f.write('ARG' + '\t') + f.write(out[idx][-1] + '\t') + f.write(str(out[idx][0]) + '\n') + if idx in notpass_idx: + f.write(test_chunk[idx].id + '\t') + f.write('non-ARG' + '\t' + '' + '\t' + '' + '\n') From abf945e19f158cef4b8524066ca20a80ab8eb362 Mon Sep 17 00:00:00 2001 From: patience111 Date: Thu, 28 Jul 2022 16:22:17 +0800 Subject: [PATCH 7/7] argnet_main --- 
scripts/argnet.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/scripts/argnet.py b/scripts/argnet.py index e1bde44..e20938f 100644 --- a/scripts/argnet.py +++ b/scripts/argnet.py @@ -44,25 +44,25 @@ args = parser.parse_args() -import argnet_lsaa as lsaa -import argnet_lsnt as lsnt -import argnet_ssaa as ssaa -import argnet_ssnt as ssnt ## for AESS_aa -> classifier if args.type == 'aa' and args.model == 'argnet-s': + import argnet_ssaa_chunk as ssaa ssaa.argnet_ssaa(args.input, args.outname) # for AESS_nt -> classifier if args.type == 'nt' and args.model == 'argnet-s': + import argnet_ssnt as ssnt ssnt.argnet_ssnt(args.input, args.outname) # for AELS_aa -> classifier if args.type == 'aa' and args.model == 'argnet-l': + import argnet_lsaa_speed as lsaa lsaa.argnet_lsaa(args.input, args.outname) # for AELS_nt -> classifier if args.type == 'nt' and args.model == 'argnet-l': + import argnet_lsnt as lsnt lsnt.argnet_lsnt(args.input, args.outname)