diff --git a/ms2query/benchmarking/k_fold_cross_validation.py b/ms2query/benchmarking/k_fold_cross_validation.py index c0138626..cbb67772 100644 --- a/ms2query/benchmarking/k_fold_cross_validation.py +++ b/ms2query/benchmarking/k_fold_cross_validation.py @@ -12,7 +12,7 @@ clean_normalize_and_split_annotated_spectra from ms2query.create_new_library.split_data_for_training import ( select_spectra_per_unique_inchikey, split_spectra_in_random_inchikey_sets) -from ms2query.create_new_library.train_models import train_all_models +from ms2query.create_new_library.train_models import train_all_models, SettingsTrainingModels from ms2query.ms2library import create_library_object_from_one_dir from ms2query.utils import (load_matchms_spectrum_objects_from_file, save_pickled_file) @@ -102,7 +102,8 @@ def train_models_and_test_result_from_k_fold_folder(k_fold_split_folder:str, # Train all models train_all_models(annotated_training_spectra, unannotated_training_spectra, - models_folder) + models_folder, + SettingsTrainingModels({"add_compound_classes": False})) # Generate test results ms2library = create_library_object_from_one_dir(models_folder) diff --git a/ms2query/create_new_library/train_models.py b/ms2query/create_new_library/train_models.py index aaf99426..f3a8523b 100644 --- a/ms2query/create_new_library/train_models.py +++ b/ms2query/create_new_library/train_models.py @@ -18,11 +18,20 @@ class SettingsTrainingModels: def __init__(self, - settings): + settings: dict = None): + """ + + :param settings: + preselection_cut_off: + This determines the number of highest scoring matches of MS2Deepscore that are used during training of MS2Query. + For these top library matches all scores are calculated + """ default_settings = {"ms2ds_fraction_validation_spectra": 30, "ms2ds_epochs": 150, "spec2vec_iterations": 30, - "ms2query_fraction_for_making_pairs": 40} + "ms2query_fraction_for_making_pairs": 40, + "add_compound_classes": True, + "preselection_cut_off": 2000} if settings: for setting in settings: assert setting in default_settings, \ @@ -32,15 +41,16 @@ def __init__(self, self.ms2ds_epochs: int = default_settings["ms2ds_epochs"] self.ms2query_fraction_for_making_pairs: int = default_settings["ms2query_fraction_for_making_pairs"] self.spec2vec_iterations = default_settings["spec2vec_iterations"] + self.add_compound_classes = default_settings["add_compound_classes"] + self.preselection_cut_off = default_settings["preselection_cut_off"] def train_all_models(annotated_training_spectra, unannotated_training_spectra, output_folder, - other_settings: dict = None): + settings: SettingsTrainingModels): if not os.path.isdir(output_folder): os.mkdir(output_folder) - settings = SettingsTrainingModels(other_settings) # set file names of new generated files ms2deepscore_model_file_name = os.path.join(output_folder, "ms2deepscore_model.hdf5") spec2vec_model_file_name = os.path.join(output_folder, "spec2vec_model.model") @@ -69,14 +79,15 @@ def train_all_models(annotated_training_spectra, os.path.join(output_folder, "library_for_training_ms2query"), ms2deepscore_model_file_name, spec2vec_model_file_name, - fraction_for_training=settings.ms2query_fraction_for_making_pairs) + settings) convert_to_onnx_model(ms2query_model, ms2query_model_file_name) # Create library with all training spectra library_files_creator = LibraryFilesCreator(annotated_training_spectra, output_folder, spec2vec_model_file_name, - ms2deepscore_model_file_name) + ms2deepscore_model_file_name, + add_compound_classes=settings.add_compound_classes) library_files_creator.create_all_library_files() @@ -92,12 +103,19 @@ def clean_and_train_models(spectrum_file: str, The ion mode of the spectra you want to use for training the models, choose from "positive" or "negative" :param output_folder: The folder in which the models and library files are stored. + :param model_train_settings: + The settings used for training the models, options can be found in SettingsTrainingModels. If None is given + all the default settings are used. The options and default settings are: + {"ms2ds_fraction_validation_spectra": 30, "ms2ds_epochs": 150, "spec2vec_iterations": 30, + "ms2query_fraction_for_making_pairs": 40, "add_compound_classes": False} """ if not os.path.exists(output_folder): os.mkdir(output_folder) assert os.path.isdir(output_folder), "The specified folder is not a folder" assert ion_mode in {"positive", "negative"}, "ion_mode should be set to 'positive' or 'negative'" + settings = SettingsTrainingModels(model_train_settings) + spectra = load_matchms_spectrum_objects_from_file(spectrum_file) annotated_spectra, unnnotated_spectra = clean_normalize_and_split_annotated_spectra(spectra, ion_mode, @@ -105,4 +123,4 @@ def clean_and_train_models(spectrum_file: str, train_all_models(annotated_spectra, unnnotated_spectra, output_folder, - model_train_settings) + settings) diff --git a/ms2query/create_new_library/train_ms2query_model.py b/ms2query/create_new_library/train_ms2query_model.py index 1162fc69..9c1c1f6a 100644 --- a/ms2query/create_new_library/train_ms2query_model.py +++ b/ms2query/create_new_library/train_ms2query_model.py @@ -21,7 +21,7 @@ split_spectra_on_inchikeys, split_training_and_validation_spectra) from ms2query.query_from_sqlite_database import SqliteLibrary from ms2query.utils import return_non_existing_file_name, save_pickled_file - +from ms2query.create_new_library.train_models import SettingsTrainingModels class DataCollectorForTraining(): """Class to collect data needed to train a ms2query random forest""" @@ -115,13 +115,15 @@ def train_ms2query_model(training_spectra, library_files_folder, ms2ds_model_file_name, s2v_model_file_name, - fraction_for_training): + settings: SettingsTrainingModels): # Select spectra belonging to a single InChIKey - library_spectra, unique_inchikey_query_spectra = split_spectra_on_inchikeys(training_spectra, - fraction_for_training) + library_spectra, unique_inchikey_query_spectra = split_spectra_on_inchikeys( + training_spectra, + settings.ms2query_fraction_for_making_pairs) # Select random spectra from the library - library_spectra, single_spectra_query_spectra = split_training_and_validation_spectra(library_spectra, - fraction_for_training) + library_spectra, single_spectra_query_spectra = split_training_and_validation_spectra( + library_spectra, + settings.ms2query_fraction_for_making_pairs) query_spectra_for_training = unique_inchikey_query_spectra + single_spectra_query_spectra # Create library files for training ms2query @@ -138,7 +140,7 @@ def train_ms2query_model(training_spectra, pickled_ms2ds_embeddings_file_name=library_creator_for_training.ms2ds_embeddings_file_name, ms2query_model_file_name=None) # Create training data MS2Query model - collector = DataCollectorForTraining(ms2library_for_training) + collector = DataCollectorForTraining(ms2library_for_training, preselection_cut_off=settings.preselection_cut_off) training_scores, training_labels = collector.get_matches_info_and_tanimoto(query_spectra_for_training) save_pickled_file(training_scores, os.path.join(library_files_folder, "training_scores_ms2query")) diff --git a/tests/test_train_models.py b/tests/test_train_models.py index 1d5efe72..61f6cc01 100644 --- a/tests/test_train_models.py +++ b/tests/test_train_models.py @@ -14,7 +14,8 @@ def test_train_all_models(path_to_general_test_files, tmp_path): {"ms2ds_fraction_validation_spectra": 2, "ms2ds_epochs": 2, "spec2vec_iterations": 2, - "ms2query_fraction_for_making_pairs": 400} + "ms2query_fraction_for_making_pairs": 400, + "add_compound_classes": False} ) ms2library = create_library_object_from_one_dir(models_folder) assert isinstance(ms2library, MS2Library) diff --git a/tests/test_train_ms2query_model.py b/tests/test_train_ms2query_model.py index 379a9599..c4bdbd8f 100644 --- a/tests/test_train_ms2query_model.py +++ b/tests/test_train_ms2query_model.py @@ -4,13 +4,14 @@ import pandas as pd import pytest from onnxruntime import InferenceSession +from ms2query.create_new_library.train_models import SettingsTrainingModels from ms2query.create_new_library.train_ms2query_model import ( DataCollectorForTraining, calculate_tanimoto_scores_with_library, convert_to_onnx_model, train_ms2query_model, train_random_forest) -from ms2query.ms2library import MS2Library from ms2query.utils import predict_onnx_model from matchms import Spectrum + if sys.version_info < (3, 8): pass else: @@ -78,5 +79,7 @@ def test_train_ms2query_model(path_to_general_test_files, tmp_path, hundred_test "ms2ds_siamese_210301_5000_500_400.hdf5"), s2v_model_file_name=os.path.join(path_to_general_test_files, "100_test_spectra_s2v_model.model"), - fraction_for_training=10 + settings=SettingsTrainingModels({"ms2query_fraction_for_making_pairs": 10, + "add_compound_classes": False, + "preselection_cut_off": 10}) )