From 541f34435c533f776894e58dcd4f555b583eed04 Mon Sep 17 00:00:00 2001 From: rukubrakov Date: Wed, 4 Mar 2026 14:42:25 +0100 Subject: [PATCH 1/8] feature: add all recent metadata fixes --- .gitignore | 7 + simba/commands/analog_discovery.py | 3 + simba/configs/model/simba_default.yaml | 1 + simba/core/data/augmentation.py | 130 ++++++++----- .../data/datasets/encoder_dataset_builder.py | 53 +++++ simba/core/data/datasets/multitask_dataset.py | 183 ++++++++++++------ .../datasets/multitask_dataset_builder.py | 75 +++---- simba/core/data/encoding.py | 20 +- simba/core/models/similarity_models.py | 124 ++++++------ simba/core/models/spectrum_encoder.py | 107 +++++++--- simba/workflows/inference.py | 14 +- simba/workflows/preprocessing.py | 54 +++--- simba/workflows/training.py | 3 + test_all_commands.sh | 3 + 14 files changed, 525 insertions(+), 252 deletions(-) diff --git a/.gitignore b/.gitignore index bf04d5b..8775fbe 100644 --- a/.gitignore +++ b/.gitignore @@ -86,6 +86,13 @@ lightning_logs/ data_temp/ temp/ +# output +test_full_workflow/ + +# tfs files +simba/configs/*/tfs*.yaml +run_scripts_tfs* + # ============================================================================ # Code Quality Tools # ============================================================================ diff --git a/simba/commands/analog_discovery.py b/simba/commands/analog_discovery.py index 7061c5c..4089695 100644 --- a/simba/commands/analog_discovery.py +++ b/simba/commands/analog_discovery.py @@ -191,5 +191,8 @@ def _analog_discovery_with_hydra( click.echo("=" * 70) except Exception as e: + import traceback click.echo(f"\n❌ Error during analog discovery: {e}", err=True) + click.echo("\nFull traceback:", err=True) + click.echo(traceback.format_exc(), err=True) raise click.Abort() from e diff --git a/simba/configs/model/simba_default.yaml b/simba/configs/model/simba_default.yaml index bed1e99..cb46a9a 100644 --- a/simba/configs/model/simba_default.yaml +++ b/simba/configs/model/simba_default.yaml @@ -33,6 +33,7 @@ features: use_element_wise: true categorical_adducts: false use_only_protonized_adducts: false + use_ion_mode: false # Metadata features use_ce: false diff --git a/simba/core/data/augmentation.py b/simba/core/data/augmentation.py index 45c68ad..3c69d16 100644 --- a/simba/core/data/augmentation.py +++ b/simba/core/data/augmentation.py @@ -5,6 +5,22 @@ class Augmentation: + @staticmethod + def zero_and_pack(arr, zero_idx): + """ + put zeros to the array and move the nonzero values to the beginning + """ + mask = np.ones(len(arr), dtype=bool) + mask[zero_idx] = False + complement_idx = np.where(mask)[0] + + arr = arr.copy() + arr[zero_idx] = 0 + nonzero = arr[complement_idx] + return np.concatenate( + [nonzero, np.zeros(len(arr) - len(nonzero), dtype=arr.dtype)] + ) + @staticmethod def augment(data_sample, training=False, max_num_peaks=None): new_sample = copy.deepcopy(data_sample) @@ -49,7 +65,7 @@ def normalize_max(data_sample): return data_sample @staticmethod - def peak_augmentation_max_peaks(data_sample, p_augmentation=1.0, max_peaks=100): + def peak_augmentation_max_peaks(data_sample, p_augmentation=0.50, max_peaks=100): # first normalize to maximum ## half of the time select maximum 20, the other half something between 5 and the maximum number of peaks @@ -71,10 +87,14 @@ def peak_augmentation_max_peaks(data_sample, p_augmentation=1.0, max_peaks=100): intensity_ordered_indexes = np.argsort(intensity)[ ::-1 ] # flip the order to have the max at the beginning - indexes_to_be_erased = intensity_ordered_indexes[max_augmented_peaks:-1] + indexes_to_be_erased = intensity_ordered_indexes[ + max_augmented_peaks:-1 + ] - intensity[indexes_to_be_erased] = 0 - mz[indexes_to_be_erased] = 0 + intensity = Augmentation.zero_and_pack( + intensity, indexes_to_be_erased + ) + mz = Augmentation.zero_and_pack(mz, indexes_to_be_erased) # apply data_sample[intensity_column] = intensity @@ -85,7 +105,7 @@ def peak_augmentation_max_peaks(data_sample, p_augmentation=1.0, max_peaks=100): @staticmethod def peak_augmentation_removal_noise( - data_sample, max_percentage=0.01, p_augmentation=1.0 + data_sample, max_percentage=0.01, p_augmentation=0.5 ): if random.random() < p_augmentation: # first normalize to maximum @@ -100,9 +120,13 @@ def peak_augmentation_removal_noise( max_amplitude = random.random() * max_percentage # indexes_to_modify=intensity < max_amplitude - indexes_to_be_erased = intensity < max_amplitude - intensity[indexes_to_be_erased] = 0 - mz[indexes_to_be_erased] = 0 + indexes_to_be_erased = intensity < ( + max_amplitude * np.max(intensity, keepdims=True) + ) + intensity = Augmentation.zero_and_pack( + intensity, indexes_to_be_erased + ) + mz = Augmentation.zero_and_pack(mz, indexes_to_be_erased) # apply data_sample[intensity_column] = intensity @@ -125,26 +149,6 @@ def normalize_intensities(data_sample, intensity_labels=None): data_sample[intensity_column] = intensity return data_sample - @staticmethod - def inversion(data_sample): - # inversion - - new_sample = {} - new_sample["mz_0"] = data_sample["mz_1"] - new_sample["mz_1"] = data_sample["mz_0"] - - new_sample["intensity_0"] = data_sample["intensity_1"] - new_sample["intensity_1"] = data_sample["intensity_0"] - - new_sample["precursor_mass_0"] = data_sample["precursor_mass_1"] - new_sample["precursor_mass_1"] = data_sample["precursor_mass_0"] - - new_sample["precursor_charge_0"] = data_sample["precursor_charge_1"] - new_sample["precursor_charge_1"] = data_sample["precursor_charge_0"] - - new_sample["similarity"] = data_sample["similarity"] - return new_sample - @staticmethod def add_false_precursor_masses_positives( sample, max_noise=0.01, p_augmentation=0.2 @@ -189,7 +193,7 @@ def add_false_precursor_masses_negatives( return sample @staticmethod - def random_peak_dropout(data_sample, dropout_rate=0.10, p_augmentation=1.0): + def random_peak_dropout(data_sample, dropout_rate=0.10, p_augmentation=0.5): """ Randomly zero out a percentage of peaks to simulate partial data loss. """ @@ -203,10 +207,11 @@ def random_peak_dropout(data_sample, dropout_rate=0.10, p_augmentation=1.0): n_peaks = len(intensity_array) n_drop = int(n_peaks * dropout_rate) # choose random peaks to drop - drop_indices = random.sample(range(n_peaks), n_drop) - for idx in drop_indices: - intensity_array[idx] = 0 - mz_array[idx] = 0 + drop_indices = np.array(random.sample(range(n_peaks), n_drop)) + intensity_array = Augmentation.zero_and_pack( + intensity_array, drop_indices + ) + mz_array = Augmentation.zero_and_pack(mz_array, drop_indices) data_sample[int_key] = intensity_array data_sample[mz_key] = mz_array @@ -219,19 +224,50 @@ def masking_metadata(data_sample, p_aug=0.5): """ keys_found = list(data_sample.keys()) if random.random() < p_aug: - if "ionmode_0" in keys_found: - data_sample["ionmode_0"] = 0 * data_sample["ionmode_0"] - data_sample["ionmode_1"] = 0 * data_sample["ionmode_1"] - - if "adduct_0" in keys_found: - data_sample["adduct_0"] = 0 * data_sample["adduct_0"] - data_sample["adduct_1"] = 0 * data_sample["adduct_1"] - - if "ce_0" in keys_found: - data_sample["ce_0"] = 0 * data_sample["ce_0"] - data_sample["ce_1"] = 0 * data_sample["ce_1"] + if random.random() < 0.5: + if "ionmode_0" in keys_found: + data_sample["ionmode_0"] = 0 * data_sample["ionmode_0"] + data_sample["ionmode_1"] = 0 * data_sample["ionmode_1"] + data_sample["precursor_charge_0"] = ( + 0 * data_sample["precursor_charge_0"] + ) + data_sample["precursor_charge_1"] = ( + 0 * data_sample["precursor_charge_1"] + ) + if random.random() < 0.5: + if "adduct_0" in keys_found: + data_sample["adduct_0"] = 0 * data_sample["adduct_0"] + data_sample["adduct_1"] = 0 * data_sample["adduct_1"] + if random.random() < 0.5: + if "ce_0" in keys_found: + data_sample["ce_0"] = 0 * data_sample["ce_0"] + data_sample["ce_1"] = 0 * data_sample["ce_1"] + if random.random() < 0.5: + if "ion_activation_0" in keys_found: + data_sample["ion_activation_0"] = ( + 0 * data_sample["ion_activation_0"] + ) + data_sample["ion_activation_1"] = ( + 0 * data_sample["ion_activation_1"] + ) - if "ion_activation_0" in keys_found: - data_sample["ion_activation_0"] = 0 * data_sample["ion_activation_0"] - data_sample["ion_activation_1"] = 0 * data_sample["ion_activation_1"] + else: + data_sample["ionmode_0"] = 0 * data_sample["ionmode_0"] + data_sample["ionmode_1"] = 0 * data_sample["ionmode_1"] + data_sample["precursor_charge_0"] = ( + 0 * data_sample["precursor_charge_0"] + ) + data_sample["precursor_charge_1"] = ( + 0 * data_sample["precursor_charge_1"] + ) + data_sample["adduct_0"] = 0 * data_sample["adduct_0"] + data_sample["adduct_1"] = 0 * data_sample["adduct_1"] + data_sample["ce_0"] = 0 * data_sample["ce_0"] + data_sample["ce_1"] = 0 * data_sample["ce_1"] + data_sample["ion_activation_0"] = ( + 0 * data_sample["ion_activation_0"] + ) + data_sample["ion_activation_1"] = ( + 0 * data_sample["ion_activation_1"] + ) return data_sample diff --git a/simba/core/data/datasets/encoder_dataset_builder.py b/simba/core/data/datasets/encoder_dataset_builder.py index c9e428f..fe8647b 100644 --- a/simba/core/data/datasets/encoder_dataset_builder.py +++ b/simba/core/data/datasets/encoder_dataset_builder.py @@ -4,6 +4,16 @@ from simba.core.data.datasets.encoder_dataset import CustomDatasetEncoder from simba.core.data.preprocessor import Preprocessor +from simba.core.chemistry.chem_utils import ( + ADDUCT_TO_MASS, +) +from simba.core.data.encoding import ( + IONIZATION_METHODS, + ION_ACTIVATION, + encode_adduct_mass, + encode_ionization_method, + encode_ion_activation, +) def prepare_encoder_dataset(spectra, max_num_peaks=100): @@ -27,6 +37,13 @@ def prepare_encoder_dataset(spectra, max_num_peaks=100): intensity = np.zeros((len(spectra), max_num_peaks), dtype=np.float32) precursor_mass = np.zeros((len(spectra), 1), dtype=np.float32) precursor_charge = np.zeros((len(spectra), 1), dtype=np.int32) + + # Metadata fields (initialized with default values) + ionmode = np.zeros((len(spectra), 1), dtype=np.float32) + adduct = np.zeros((len(spectra), len(ADDUCT_TO_MASS.keys())), dtype=np.float32) + ce = np.zeros((len(spectra), 1), dtype=np.int32) + ia = np.zeros((len(spectra), len(ION_ACTIVATION)), dtype=np.int32) + im = np.zeros((len(spectra), len(IONIZATION_METHODS)), dtype=np.int32) for i, spectrum in enumerate(spectra): # check for maximum length @@ -40,6 +57,37 @@ def prepare_encoder_dataset(spectra, max_num_peaks=100): precursor_mass[i] = spectrum.precursor_mz precursor_charge[i] = spectrum.precursor_charge + + # Extract metadata if available + # Ion mode + if hasattr(spectrum, 'ionmode') and spectrum.ionmode is not None and spectrum.ionmode != "None": + ionmode[i] = 1.0 if spectrum.ionmode.lower() == "positive" else -1.0 + else: + ionmode[i] = 0.0 + + # Adduct + if hasattr(spectrum, 'params') and 'adduct' in spectrum.params: + adduct[i] = encode_adduct_mass(spectrum.params['adduct']) + elif hasattr(spectrum, 'adduct') and spectrum.adduct is not None: + adduct[i] = encode_adduct_mass(spectrum.adduct) + + # Collision Energy + if hasattr(spectrum, 'ce') and spectrum.ce is not None and spectrum.ce != "None": + ce[i] = spectrum.ce + else: + ce[i] = 0 + + # Ion Activation + if hasattr(spectrum, 'ion_activation') and spectrum.ion_activation is not None and spectrum.ion_activation != "None": + ia[i] = encode_ion_activation(spectrum.ion_activation) + else: + ia[i] = np.zeros(len(ION_ACTIVATION), dtype=np.int32) + + # Ionization Method + if hasattr(spectrum, 'ionization_method') and spectrum.ionization_method is not None and spectrum.ionization_method != "None": + im[i] = encode_ionization_method(spectrum.ionization_method) + else: + im[i] = np.zeros(len(IONIZATION_METHODS), dtype=np.int32) # Normalize the intensity array intensity = intensity / np.sqrt(np.sum(intensity**2, axis=1, keepdims=True)) @@ -49,6 +97,11 @@ def prepare_encoder_dataset(spectra, max_num_peaks=100): "intensity": intensity, "precursor_mass": precursor_mass, "precursor_charge": precursor_charge, + "ionmode": ionmode, + "adduct": adduct, + "ce": ce, + "ion_activation": ia, + "ion_method": im, } return CustomDatasetEncoder(spectrum_data) diff --git a/simba/core/data/datasets/multitask_dataset.py b/simba/core/data/datasets/multitask_dataset.py index 15d6ab4..05950ff 100644 --- a/simba/core/data/datasets/multitask_dataset.py +++ b/simba/core/data/datasets/multitask_dataset.py @@ -17,8 +17,7 @@ def __init__( self, your_dict, training=False, - prob_aug=1.0, - # prob_aug=0.2, + prob_aug=0.50, mz=None, intensity=None, precursor_mass=None, @@ -36,6 +35,7 @@ def __init__( ion_activation=None, use_ion_method=False, ion_method=None, + use_ion_mode=False, ): self.data = your_dict self.keys = list(your_dict.keys()) @@ -52,23 +52,19 @@ def __init__( self.use_ce = use_ce self.use_ion_activation = use_ion_activation self.use_ion_method = use_ion_method + self.use_ion_mode = use_ion_mode if self.use_fingerprints: self.fingerprint_0 = fingerprint_0 self.max_num_peaks = max_num_peaks - if self.use_adduct: - self.ionmode = ionmode - self.adduct_mass = adduct - - if self.use_ce: - self.ce = ce + self.adduct_mass = adduct + self.ionmode = ionmode + self.ce = ce - if self.use_ion_activation: - self.ion_activation = ion_activation + self.ion_activation = ion_activation - if self.use_ion_method: - self.ion_method = ion_method + self.ion_method = ion_method def __len__(self): return len(self.data[self.keys[0]]) @@ -82,20 +78,32 @@ def get_original_dictionary(self, max_num_peaks=100): ## Get the mz, intensity values and precursor data dictionary = {} - dictionary["mz_0"] = np.zeros((len_data, max_num_peaks), dtype=np.float32) + dictionary["mz_0"] = np.zeros( + (len_data, max_num_peaks), dtype=np.float32 + ) dictionary["intensity_0"] = np.zeros( (len_data, max_num_peaks), dtype=np.float32 ) - dictionary["mz_1"] = np.zeros((len_data, max_num_peaks), dtype=np.float32) + dictionary["mz_1"] = np.zeros( + (len_data, max_num_peaks), dtype=np.float32 + ) dictionary["intensity_1"] = np.zeros( (len_data, max_num_peaks), dtype=np.float32 ) dictionary["ed"] = np.zeros((len_data, 1), dtype=np.float32) dictionary["mces"] = np.zeros((len_data, 1), dtype=np.float32) - dictionary["precursor_mass_0"] = np.zeros((len_data, 1), dtype=np.float32) - dictionary["precursor_charge_0"] = np.zeros((len_data, 1), dtype=np.int32) - dictionary["precursor_mass_1"] = np.zeros((len_data, 1), dtype=np.float32) - dictionary["precursor_charge_1"] = np.zeros((len_data, 1), dtype=np.int32) + dictionary["precursor_mass_0"] = np.zeros( + (len_data, 1), dtype=np.float32 + ) + dictionary["precursor_charge_0"] = np.zeros( + (len_data, 1), dtype=np.int32 + ) + dictionary["precursor_mass_1"] = np.zeros( + (len_data, 1), dtype=np.float32 + ) + dictionary["precursor_charge_1"] = np.zeros( + (len_data, 1), dtype=np.int32 + ) ### add extra metadata in case it is necessary if self.use_adduct: @@ -130,7 +138,9 @@ def get_original_dictionary(self, max_num_peaks=100): if self.use_fingerprints: print("Defining fingerprints ...") - dictionary["fingerprint_0"] = np.zeros((len_data, 2048), dtype=np.int32) + dictionary["fingerprint_0"] = np.zeros( + (len_data, 2048), dtype=np.int32 + ) for idx in tqdm(range(0, len_data)): sample_unique = {k: self.data[k][idx] for k in self.keys} @@ -139,19 +149,27 @@ def get_original_dictionary(self, max_num_peaks=100): indexes_unique_1 = sample_unique["index_unique_1"] print(f"value of indexes_unique_0 {indexes_unique_0} ") - indexes_original_0 = self.df_smiles.loc[int(indexes_unique_0), "indexes"][0] + indexes_original_0 = self.df_smiles.loc[ + int(indexes_unique_0), "indexes" + ][0] - indexes_original_1 = self.df_smiles.loc[int(indexes_unique_1), "indexes"][0] + indexes_original_1 = self.df_smiles.loc[ + int(indexes_unique_1), "indexes" + ][0] - dictionary["mz_0"][idx] = self.mz[indexes_original_0].astype(np.float32) - dictionary["intensity_0"][idx] = self.intensity[indexes_original_0].astype( + dictionary["mz_0"][idx] = self.mz[indexes_original_0].astype( np.float32 ) + dictionary["intensity_0"][idx] = self.intensity[ + indexes_original_0 + ].astype(np.float32) - dictionary["mz_1"][idx] = self.mz[indexes_original_1].astype(np.float32) - dictionary["intensity_1"][idx] = self.intensity[indexes_original_1].astype( + dictionary["mz_1"][idx] = self.mz[indexes_original_1].astype( np.float32 ) + dictionary["intensity_1"][idx] = self.intensity[ + indexes_original_1 + ].astype(np.float32) dictionary["precursor_mass_0"][idx] = self.precursor_mass[ indexes_original_0 ].astype(np.float32) @@ -166,14 +184,14 @@ def get_original_dictionary(self, max_num_peaks=100): ].astype(np.float32) dictionary["ed"][idx] = sample_unique["ed"].astype(np.float32) dictionary["mces"][idx] = sample_unique["mces"].astype(np.float32) + if self.use_ion_mode: + dictionary["ionmode_0"][idx] = self.ionmode[ + indexes_original_0 + ].astype(np.float32) + dictionary["ionmode_1"][idx] = self.ionmode[ + indexes_original_1 + ].astype(np.float32) if self.use_adduct: - dictionary["ionmode_0"][idx] = self.ionmode[indexes_original_0].astype( - np.float32 - ) - dictionary["ionmode_1"][idx] = self.ionmode[indexes_original_1].astype( - np.float32 - ) - dictionary["adduct_0"][idx] = self.adduct_mass[ indexes_original_0 ].astype(np.float32) @@ -182,8 +200,12 @@ def get_original_dictionary(self, max_num_peaks=100): ].astype(np.float32) if self.use_ce: - dictionary["ce_0"][idx] = self.ce[indexes_original_0].astype(np.float32) - dictionary["ce_1"][idx] = self.ce[indexes_original_1].astype(np.float32) + dictionary["ce_0"][idx] = self.ce[indexes_original_0].astype( + np.float32 + ) + dictionary["ce_1"][idx] = self.ce[indexes_original_1].astype( + np.float32 + ) if self.use_ion_activation: dictionary["ion_activation_0"][idx] = self.ion_activation[ @@ -216,8 +238,12 @@ def __getitem__(self, idx): if self.training: # select random samples - idx_0_original = random.choice(self.df_smiles.loc[int(idx_0[0]), "indexes"]) - idx_1_original = random.choice(self.df_smiles.loc[int(idx_1[0]), "indexes"]) + idx_0_original = random.choice( + self.df_smiles.loc[int(idx_0[0]), "indexes"] + ) + idx_1_original = random.choice( + self.df_smiles.loc[int(idx_1[0]), "indexes"] + ) else: # select the first index idx_0_original = self.df_smiles.loc[int(idx_0[0]), "indexes"][0] @@ -253,20 +279,20 @@ def __getitem__(self, idx): ind = int(idx_0[0]) if self.training: if (ind % 2) == 0: - spectrum_sample["fingerprint_0"] = self.fingerprint_0[ind].astype( - np.float32 - ) + spectrum_sample["fingerprint_0"] = self.fingerprint_0[ + ind + ].astype(np.float32) else: # return 0s spectrum_sample["fingerprint_0"] = 0 * self.fingerprint_0[ ind ].astype(np.float32) else: - spectrum_sample["fingerprint_0"] = self.fingerprint_0[ind].astype( - np.float32 - ) + spectrum_sample["fingerprint_0"] = self.fingerprint_0[ + ind + ].astype(np.float32) - if self.use_adduct: + if self.use_ion_mode: spectrum_sample["ionmode_0"] = self.ionmode[idx_0_original].astype( np.float32 ) @@ -274,17 +300,42 @@ def __getitem__(self, idx): np.float32 ) - spectrum_sample["adduct_0"] = self.adduct_mass[idx_0_original].astype( + else: + spectrum_sample["ionmode_0"] = 0 * self.ionmode[ + idx_0_original + ].astype(np.float32) + spectrum_sample["ionmode_1"] = 0 * self.ionmode[ + idx_1_original + ].astype(np.float32) + + if self.use_adduct: + spectrum_sample["adduct_0"] = self.adduct_mass[ + idx_0_original + ].astype(np.float32) + spectrum_sample["adduct_1"] = self.adduct_mass[ + idx_1_original + ].astype(np.float32) + else: + spectrum_sample["adduct_0"] = 0 * self.adduct_mass[ + idx_0_original + ].astype(np.float32) + spectrum_sample["adduct_1"] = 0 * self.adduct_mass[ + idx_1_original + ].astype(np.float32) + if self.use_ce: + spectrum_sample["ce_0"] = self.ce[idx_0_original].astype( np.float32 ) - spectrum_sample["adduct_1"] = self.adduct_mass[idx_1_original].astype( + spectrum_sample["ce_1"] = self.ce[idx_1_original].astype( + np.float32 + ) + else: + spectrum_sample["ce_0"] = 0 * self.ce[idx_0_original].astype( + np.float32 + ) + spectrum_sample["ce_1"] = 0 * self.ce[idx_1_original].astype( np.float32 ) - - if self.use_ce: - spectrum_sample["ce_0"] = self.ce[idx_0_original].astype(np.float32) - spectrum_sample["ce_1"] = self.ce[idx_1_original].astype(np.float32) - if self.use_ion_activation: spectrum_sample["ion_activation_0"] = self.ion_activation[ idx_0_original @@ -292,16 +343,28 @@ def __getitem__(self, idx): spectrum_sample["ion_activation_1"] = self.ion_activation[ idx_1_original ].astype(np.float32) - + else: + spectrum_sample["ion_activation_0"] = 0 * self.ion_activation[ + idx_0_original + ].astype(np.float32) + spectrum_sample["ion_activation_1"] = 0 * self.ion_activation[ + idx_1_original + ].astype(np.float32) if self.use_ion_method: - spectrum_sample["ion_method_0"] = self.ion_method[idx_0_original].astype( - np.float32 - ) - spectrum_sample["ion_method_1"] = self.ion_method[idx_1_original].astype( - np.float32 - ) - - if self.training and random.random() < self.prob_aug: + spectrum_sample["ion_method_0"] = self.ion_method[ + idx_0_original + ].astype(np.float32) + spectrum_sample["ion_method_1"] = self.ion_method[ + idx_1_original + ].astype(np.float32) + else: + spectrum_sample["ion_method_0"] = 0 * self.ion_method[ + idx_0_original + ].astype(np.float32) + spectrum_sample["ion_method_1"] = 0 * self.ion_method[ + idx_1_original + ].astype(np.float32) + if self.training and (random.random() < self.prob_aug): # augmentation spectrum_sample = Augmentation.augment( spectrum_sample, max_num_peaks=self.max_num_peaks @@ -309,4 +372,4 @@ def __getitem__(self, idx): # normalize spectrum_sample = Augmentation.normalize_intensities(spectrum_sample) - return spectrum_sample + return spectrum_sample \ No newline at end of file diff --git a/simba/core/data/datasets/multitask_dataset_builder.py b/simba/core/data/datasets/multitask_dataset_builder.py index 3626353..666d039 100644 --- a/simba/core/data/datasets/multitask_dataset_builder.py +++ b/simba/core/data/datasets/multitask_dataset_builder.py @@ -10,7 +10,7 @@ from simba.core.data.encoding import ( ION_ACTIVATION, IONIZATION_METHODS, - encode_adduct, + encode_adduct_mass, encode_ion_activation, encode_ionization_method, ) @@ -36,6 +36,7 @@ def from_molecule_pairs_to_dataset( use_ce: bool = False, use_ion_activation: bool = False, use_ion_method: bool = False, + use_ion_mode: bool = False, ) -> CustomDatasetMultitasking: """ Load data from molecule pairs into a Pytorch dataset for multitask learning. @@ -98,35 +99,33 @@ def from_molecule_pairs_to_dataset( precursor_charge = np.zeros( (len(molecule_pairs.original_spectra), 1), dtype=np.int32 ) - if use_adduct: - ionmode = np.zeros( - (len(molecule_pairs.original_spectra), 1), dtype=np.float32 - ) - adduct = np.zeros( - ( - len(molecule_pairs.original_spectra), - len(ADDUCT_TO_MASS.keys()), - ), - dtype=np.float32, - ) - if use_ce: - ce = np.zeros((len(molecule_pairs.original_spectra), 1), dtype=np.int32) - if use_ion_activation: - ia = np.zeros( - ( - len(molecule_pairs.original_spectra), - len(ION_ACTIVATION), - ), - dtype=np.int32, - ) - if use_ion_method: - im = np.zeros( - ( - len(molecule_pairs.original_spectra), - len(IONIZATION_METHODS), - ), - dtype=np.int32, - ) + ionmode = np.zeros( + (len(molecule_pairs.original_spectra), 1), dtype=np.float32 + ) + adduct = np.zeros( + ( + len(molecule_pairs.original_spectra), + len(ADDUCT_TO_MASS.keys()), + ), + dtype=np.float32, + ) + ce = np.zeros( + (len(molecule_pairs.original_spectra), 1), dtype=np.int32 + ) + ia = np.zeros( + ( + len(molecule_pairs.original_spectra), + len(ION_ACTIVATION), + ), + dtype=np.int32, + ) + im = np.zeros( + ( + len(molecule_pairs.original_spectra), + len(IONIZATION_METHODS), + ), + dtype=np.int32, + ) logger.info("Loading mz, intensity and precursor data ...") for i, spec in enumerate(molecule_pairs.original_spectra): @@ -140,14 +139,15 @@ def from_molecule_pairs_to_dataset( precursor_mass[i] = spec.precursor_mz precursor_charge[i] = spec.precursor_charge - if use_adduct: + if use_ion_mode: if (spec.ionmode is None) or ( spec.ionmode == "None" ): # TODO: check if the 2nd condition is needed ionmode[i] = 0 else: ionmode[i] = 1.0 if spec.ionmode == "positive" else -1.0 - adduct[i] = encode_adduct(spec.adduct) + if use_adduct: + adduct[i] = encode_adduct_mass(spec.params["adduct"]) if use_ce: if (spec.ce is None) or (spec.ce == "None"): @@ -217,12 +217,13 @@ def from_molecule_pairs_to_dataset( fingerprint_0=fingerprint_0, max_num_peaks=max_num_peaks, use_adduct=use_adduct, - ionmode=(ionmode if use_adduct else None), - adduct=(adduct if use_adduct else None), + use_ion_mode=use_ion_mode, + ionmode=ionmode, + adduct=adduct, use_ce=use_ce, - ce=(ce if use_ce else None), + ce=ce, use_ion_activation=use_ion_activation, - ion_activation=(ia if use_ion_activation else None), + ion_activation=ia, use_ion_method=use_ion_method, - ion_method=(im if use_ion_method else None), + ion_method=im, ) diff --git a/simba/core/data/encoding.py b/simba/core/data/encoding.py index b1f79a3..f04e9ce 100644 --- a/simba/core/data/encoding.py +++ b/simba/core/data/encoding.py @@ -7,7 +7,25 @@ IONIZATION_METHODS = ["NSI", "ESI", "APCI"] -def encode_adduct(adduct: str): +def encode_adduct_mass(adduct: str): + """Encode adduct as its mass. + + Args: + adduct: Adduct string (e.g., '[M+H]+') + + Returns: + float: Mass of the adduct, or 0 if adduct is not recognized + """ + # TODO: how encode adduct if not recognized? + # Currently returns 0, but might interfere with spectra without adducts + adducts = ADDUCT_TO_MASS.keys() + response = [0 for a in adducts] + if adduct in adducts: + response[0] = ADDUCT_TO_MASS[adduct] + return response + + +def encode_adduct_one_hot(adduct: str): """Encode adduct string as one-hot vector. Args: diff --git a/simba/core/models/similarity_models.py b/simba/core/models/similarity_models.py index 4551fbf..adde69d 100644 --- a/simba/core/models/similarity_models.py +++ b/simba/core/models/similarity_models.py @@ -47,6 +47,7 @@ def __init__( use_ce=False, use_ion_activation=False, use_ion_method=False, + use_ion_mode=False, ): """Initialize the CCSPredictor""" super().__init__() @@ -63,6 +64,7 @@ def __init__( self.use_ce = use_ce self.use_ion_activation = use_ion_activation self.use_ion_method = use_ion_method + self.use_ion_mode = use_ion_mode self.spectrum_encoder = SpectrumTransformerEncoderCustom( d_model=d_model, @@ -72,6 +74,7 @@ def __init__( use_ce=use_ce, use_ion_activation=use_ion_activation, use_ion_method=use_ion_method, + use_ion_mode= use_ion_mode, ) self.regression_loss = nn.MSELoss() @@ -104,13 +107,16 @@ def forward(self, batch): kwargs_0 = { "precursor_mass": batch["precursor_mass_0"].float(), - "precursor_charge": batch["precursor_charge_0"].float(), } kwargs_1 = { "precursor_mass": batch["precursor_mass_1"].float(), - "precursor_charge": batch["precursor_charge_1"].float(), } # extra data + if self.use_ion_mode: + kwargs_0["ionmode"] = batch["ionmode_0"].float() + kwargs_1["ionmode"] = batch["ionmode_1"].float() + kwargs_0["precursor_charge"]=batch["precursor_charge_0"].float() + kwargs_1["precursor_charge"]=batch["precursor_charge_1"].float() if self.use_adduct: kwargs_0["ionmode"] = batch["ionmode_0"].float() kwargs_1["ionmode"] = batch["ionmode_1"].float() @@ -359,6 +365,7 @@ def __init__( use_ce=False, use_ion_activation=False, use_ion_method=False, + use_ion_mode=False, ): """Initialize the CCSPredictor""" super().__init__( @@ -373,6 +380,7 @@ def __init__( use_ce=use_ce, use_ion_activation=use_ion_activation, use_ion_method=use_ion_method, + use_ion_mode= use_ion_mode, ) self.weights = weights @@ -420,17 +428,10 @@ def __init__( # Initialize learnable log variance parameters for each loss self.USE_LEARNABLE_MULTITASK = USE_LEARNABLE_MULTITASK if USE_LEARNABLE_MULTITASK: - initial_log_sigma1 = 1.2490483522415161 - initial_log_sigma2 = -7.0018157958984375 - # self.log_sigma1 = nn.Parameter(torch.tensor(initial_log_sigma1)) - # self.log_sigma2 = nn.Parameter(torch.tensor(initial_log_sigma2)) - self.log_sigma1 = torch.tensor( - float(initial_log_sigma1), dtype=torch.float32 - ) - self.log_sigma2 = torch.tensor( - float(initial_log_sigma2), dtype=torch.float32 - ) - self.use_extra_metadata = use_adduct + initial_log_sigma1= 0.0 + initial_log_sigma2 = -5.3 + self.log_sigma1 = nn.Parameter(torch.tensor(initial_log_sigma1)) + self.log_sigma2 = nn.Parameter(torch.tensor(initial_log_sigma2)) def forward(self, batch, return_spectrum_output=False): # … compute raw emb0, emb1, apply relu, fingerprints, etc. … @@ -465,56 +466,51 @@ def forward(self, batch, return_spectrum_output=False): "precursor_charge": batch["precursor_charge_1"].float(), } - if self.use_adduct: - batch["ionmode_0"] = torch.nan_to_num( - batch["ionmode_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ionmode_1"] = torch.nan_to_num( - batch["ionmode_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["adduct_0"] = torch.nan_to_num( - batch["adduct_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["adduct_1"] = torch.nan_to_num( - batch["adduct_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) + batch["ionmode_0"] = torch.nan_to_num( + batch["ionmode_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["ionmode_1"] = torch.nan_to_num( + batch["ionmode_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) + kwargs_0["ionmode"] = batch["ionmode_0"].float() + kwargs_1["ionmode"] = batch["ionmode_1"].float() + batch["adduct_0"] = torch.nan_to_num( + batch["adduct_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["adduct_1"] = torch.nan_to_num( + batch["adduct_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) - kwargs_0["ionmode"] = batch["ionmode_0"].float() - kwargs_1["ionmode"] = batch["ionmode_1"].float() - kwargs_0["adduct"] = batch["adduct_0"].float() - kwargs_1["adduct"] = batch["adduct_1"].float() + kwargs_0["adduct"] = batch["adduct_0"].float() + kwargs_1["adduct"] = batch["adduct_1"].float() - if self.use_ce: - batch["ce_0"] = torch.nan_to_num( - batch["ce_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ce_1"] = torch.nan_to_num( - batch["ce_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) - kwargs_0["ce"] = batch["ce_0"].float() - kwargs_1["ce"] = batch["ce_1"].float() - - if self.use_ion_activation: - batch["ion_activation_0"] = torch.nan_to_num( - batch["ion_activation_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ion_activation_1"] = torch.nan_to_num( - batch["ion_activation_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) - kwargs_0["ion_activation"] = batch["ion_activation_0"].float() - kwargs_1["ion_activation"] = batch["ion_activation_1"].float() + batch["ce_0"] = torch.nan_to_num( + batch["ce_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["ce_1"] = torch.nan_to_num( + batch["ce_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) + kwargs_0["ce"] = batch["ce_0"].float() + kwargs_1["ce"] = batch["ce_1"].float() - if self.use_ion_method: - batch["ion_method_0"] = torch.nan_to_num( - batch["ion_method_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ion_method_1"] = torch.nan_to_num( - batch["ion_method_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) + batch["ion_activation_0"] = torch.nan_to_num( + batch["ion_activation_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["ion_activation_1"] = torch.nan_to_num( + batch["ion_activation_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) + kwargs_0["ion_activation"] = batch["ion_activation_0"].float() + kwargs_1["ion_activation"] = batch["ion_activation_1"].float() - kwargs_0["ion_method"] = batch["ion_method_0"].float() - kwargs_1["ion_method"] = batch["ion_method_1"].float() + batch["ion_method_0"] = torch.nan_to_num( + batch["ion_method_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["ion_method_1"] = torch.nan_to_num( + batch["ion_method_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) + kwargs_0["ion_method"] = batch["ion_method_0"].float() + kwargs_1["ion_method"] = batch["ion_method_1"].float() # intensity and mz batch["intensity_0"] = torch.nan_to_num( batch["intensity_0"], nan=0.0, posinf=0.0, neginf=0.0 @@ -793,6 +789,18 @@ def forward(self, batch): "precursor_mass": batch["precursor_mass"].float(), "precursor_charge": batch["precursor_charge"].float(), } + + # Add metadata fields if present in batch + if "ionmode" in batch: + kwargs["ionmode"] = batch["ionmode"].float() + if "adduct" in batch: + kwargs["adduct"] = batch["adduct"].float() + if "ce" in batch: + kwargs["ce"] = batch["ce"].float() + if "ion_activation" in batch: + kwargs["ion_activation"] = batch["ion_activation"].float() + if "ion_method" in batch: + kwargs["ion_method"] = batch["ion_method"].float() emb, _ = self.model( mz_array=batch["mz"].float(), diff --git a/simba/core/models/spectrum_encoder.py b/simba/core/models/spectrum_encoder.py index 3a516a3..c303e5b 100644 --- a/simba/core/models/spectrum_encoder.py +++ b/simba/core/models/spectrum_encoder.py @@ -2,6 +2,7 @@ from depthcharge.transformers import ( SpectrumTransformerEncoder, ) # PeptideTransformerEncoder, +from depthcharge.encoders import FloatEncoder class SpectrumTransformerEncoderCustom(SpectrumTransformerEncoder): @@ -12,6 +13,7 @@ def __init__( use_ce: bool = False, use_ion_activation: bool = False, use_ion_method: bool = False, + use_ion_mode: bool =False, **kwargs, ): """ @@ -28,11 +30,28 @@ def __init__( use_ion_method: bool Whether to include ionization method in the encoding (default: False). """ + self.use_encoders=False super().__init__(*args, **kwargs) self.use_adduct = use_adduct self.use_ce = use_ce self.use_ion_activation = use_ion_activation self.use_ion_method = use_ion_method + self.use_ion_mode = use_ion_mode + + if self.use_encoders: + if self.use_adduct: + self.adduct_encoder = FloatEncoder(self.d_model) + self.ionmode_encoder =FloatEncoder(self.d_model) + if self.use_ce: + self.ce_encoder = FloatEncoder(self.d_model, ) + + if self.use_ion_activation: + self.ion_activation_encoder= FloatEncoder(self.d_model) + + if self.use_ion_method: + self.ion_method_encoder = FloatEncoder(self.d_model) + if (self.use_adduct or self.use_ce or self.use_ion_activation or self.use_ion_method): + self.precursor_mz_encoder =FloatEncoder(self.d_model) def precursor_hook( self, @@ -44,46 +63,88 @@ def precursor_hook( dtype = mz_array.dtype batch_size = mz_array.shape[0] - placeholder = torch.zeros( + if not(self.use_encoders): + placeholder = torch.zeros( (batch_size, self.d_model), dtype=dtype, device=device - ) - precursor_mass = kwargs["precursor_mass"].float().to(device).view(batch_size) - placeholder[:, 0] = precursor_mass + ) + precursor_mass = kwargs["precursor_mass"].float().to(device).view(batch_size) + placeholder[:, 0] = precursor_mass - precursor_charge = ( - kwargs["precursor_charge"].float().to(device).view(batch_size) + precursor_charge = ( + kwargs["precursor_charge"].float().to(device).view(batch_size) ) - placeholder[:, 1] = precursor_charge - - current_idx = 2 # keep track of where to insert metadata - if self.use_adduct: - ionmode = kwargs["ionmode"].float().to(device).view(batch_size) - placeholder[:, current_idx] = ionmode + # skip the use of the precursor charge field + if self.use_ion_mode: + placeholder[:, 1] = precursor_charge + + current_idx = 2 # keep track of where to insert metadata + + + ionmode = kwargs["ionmode"].float().to(device).view(batch_size) + if self.use_ion_mode: + placeholder[:, current_idx] = ionmode current_idx += 1 + adduct = kwargs["adduct"].float().to(device).view(batch_size, -1) stop_idx = current_idx + adduct.shape[1] - placeholder[:, current_idx:stop_idx] = adduct + if self.use_adduct: + placeholder[:, current_idx:stop_idx] = adduct current_idx = stop_idx - - if self.use_ce: + ce = kwargs["ce"].float().to(device).view(batch_size) - placeholder[:, current_idx] = ce + if self.use_ce: + placeholder[:, current_idx] = ce current_idx += 1 - if self.use_ion_activation: + ia = kwargs["ion_activation"].float().to(device).view(batch_size, -1) stop_idx = current_idx + ia.shape[1] - placeholder[:, current_idx:stop_idx] = ia + if self.use_ion_activation: + placeholder[:, current_idx:stop_idx] = ia current_idx = stop_idx - if self.use_ion_method: im = kwargs["ion_method"].float().to(device).view(batch_size, -1) stop_idx = current_idx + im.shape[1] - placeholder[:, current_idx:stop_idx] = im + if self.use_ion_method: + placeholder[:, current_idx:stop_idx] = im current_idx = stop_idx - # ensure there are no nans - placeholder = torch.nan_to_num(placeholder, nan=0.0, posinf=0.0, neginf=0.0) + # ensure there are no nans + placeholder = torch.nan_to_num(placeholder, nan=0.0, posinf=0.0, neginf=0.0) + + + else: + + precursor_mass = kwargs["precursor_mass"].float().to(device).view(batch_size) + precursor_mass_rep = self.precursor_mz_encoder(precursor_mass[:,None]).squeeze(1) + placeholder = precursor_mass_rep + 0*precursor_mass_rep + + if self.use_adduct: + ionmode = kwargs["ionmode"].float().to(device).view(batch_size) + adduct = kwargs["adduct"].float().to(device).view(batch_size, -1) + ionmode_rep = self.ionmode_encoder(ionmode[:,None]).squeeze(1) + adduct_rep = self.adduct_encoder(adduct).mean(dim=1) + + placeholder= placeholder + (ionmode_rep + adduct_rep) + + if self.use_ce: + ce = kwargs["ce"].float().to(device).view(batch_size) + ce_rep = self.ce_encoder(ce[:,None]).squeeze(1) + placeholder= placeholder + ce_rep + + if self.use_ion_method: + im = kwargs["ion_method"].float().to(device).view(batch_size, -1) + im_rep = self.ion_method_encoder(im).mean(dim=1) + placeholder = placeholder + im_rep + + if self.use_ion_activation: + ia = kwargs["ion_activation"].float().to(device).view(batch_size, -1) + ia_rep = self.ion_activation_encoder(ia).mean(dim=1) + placeholder = placeholder + ia_rep + + + + placeholder = torch.nan_to_num(placeholder, nan=0.0, posinf=0.0, neginf=0.0) - return placeholder + return placeholder \ No newline at end of file diff --git a/simba/workflows/inference.py b/simba/workflows/inference.py index d6f18c2..4ff0ca9 100644 --- a/simba/workflows/inference.py +++ b/simba/workflows/inference.py @@ -21,6 +21,7 @@ from simba.core.training.train_utils import TrainUtils from simba.utils.logger_setup import logger from simba.workflows.utils import load_spectra +from sklearn.metrics import mean_absolute_error # Backward compatibility: Support loading old pickle files with old module paths @@ -157,6 +158,7 @@ def prepare_inference_dataloaders( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, + use_ion_mode = cfg.model.features.use_ion_mode, ) dataloader_ed = DataLoader( dataset_ed, batch_size=cfg.inference.batch_size, shuffle=False @@ -169,6 +171,7 @@ def prepare_inference_dataloaders( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, + use_ion_mode = cfg.model.features.use_ion_mode, ) dataloader_mces = DataLoader( dataset_mces, batch_size=cfg.inference.batch_size, shuffle=False @@ -202,6 +205,7 @@ def load_model_for_inference(cfg: DictConfig, checkpoint_path: str): "use_ce": cfg.model.features.use_ce, "use_ion_activation": cfg.model.features.use_ion_activation, "use_ion_method": cfg.model.features.use_ion_method, + "use_ion_mode": cfg.model.features.use_ion_mode, } model = SimilarityModelMultitask.load_from_checkpoint( @@ -314,7 +318,10 @@ def evaluate_predictions( # Edit distance correlation corr_model_ed, _ = spearmanr(ed_true_clean, pred_ed_ed_clean) + mae_model_ed = mean_absolute_error(ed_true_clean, pred_ed_ed_clean) + logger.info(f"Edit distance correlation: {corr_model_ed:.4f}") + logger.info(f"Edit distance mean absolute error: {mae_model_ed:.4f}") # Plot confusion matrix _plot_cm(ed_true_clean, pred_ed_ed_clean, cfg, output_dir) @@ -335,11 +342,14 @@ def evaluate_predictions( if len(mces_true) == 0 or len(pred_mces_mces_flat) == 0: logger.warning("No MCES samples after filtering, skipping MCES correlation") - corr_model_mces = float("nan") + corr_model_mces = float("nan") + mae_model_mces =float("nan") + else: corr_model_mces, _ = spearmanr(mces_true, pred_mces_mces_flat) - + mae_model_mces = mean_absolute_error(mces_true, pred_mces_mces_flat) logger.info(f"MCES/Tanimoto correlation: {corr_model_mces:.4f}") + logger.info(f"MCES/Tanimoto mean absolute error: {cfg.data.mces20_max_value * mae_model_mces:.4f}") # Denormalize if using MCES20 if not cfg.data.use_tanimoto: diff --git a/simba/workflows/preprocessing.py b/simba/workflows/preprocessing.py index 4c27c77..c972e9a 100644 --- a/simba/workflows/preprocessing.py +++ b/simba/workflows/preprocessing.py @@ -50,30 +50,36 @@ def write_data( # Lightweight format: save only df_smiles, original MGF indexes, and mgf path # Spectra will be loaded at training time from mgf file using absolute indexes dataset = { - "df_smiles_train": molecule_pairs_train.df_smiles - if molecule_pairs_train is not None - else None, - "df_smiles_val": molecule_pairs_val.df_smiles - if molecule_pairs_val is not None - else None, - "df_smiles_test": molecule_pairs_test.df_smiles - if molecule_pairs_test is not None - else None, - "spectrum_indexes_train": [ - s.mgf_index for s in molecule_pairs_train.original_spectra - ] - if molecule_pairs_train is not None - else None, - "spectrum_indexes_val": [ - s.mgf_index for s in molecule_pairs_val.original_spectra - ] - if molecule_pairs_val is not None - else None, - "spectrum_indexes_test": [ - s.mgf_index for s in molecule_pairs_test.original_spectra - ] - if molecule_pairs_test is not None - else None, + "df_smiles_train": ( + molecule_pairs_train.df_smiles + if molecule_pairs_train is not None + else None + ), + "df_smiles_val": ( + molecule_pairs_val.df_smiles + if molecule_pairs_val is not None + else None + ), + "df_smiles_test": ( + molecule_pairs_test.df_smiles + if molecule_pairs_test is not None + else None + ), + "spectrum_indexes_train": ( + [s.mgf_index for s in molecule_pairs_train.original_spectra] + if molecule_pairs_train is not None + else None + ), + "spectrum_indexes_val": ( + [s.mgf_index for s in molecule_pairs_val.original_spectra] + if molecule_pairs_val is not None + else None + ), + "spectrum_indexes_test": ( + [s.mgf_index for s in molecule_pairs_test.original_spectra] + if molecule_pairs_test is not None + else None + ), "mgf_path": mgf_path, "format_version": "lightweight", } diff --git a/simba/workflows/training.py b/simba/workflows/training.py index 4adcd2e..11ba597 100644 --- a/simba/workflows/training.py +++ b/simba/workflows/training.py @@ -260,6 +260,7 @@ def prepare_data( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, + use_ion_mode = cfg.model.features.use_ion_mode, ) dataset_val = MultitaskDataBuilder.from_molecule_pairs_to_dataset( @@ -269,6 +270,7 @@ def prepare_data( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, + use_ion_mode = cfg.model.features.use_ion_mode, ) # Create samplers @@ -392,6 +394,7 @@ def setup_model(cfg: DictConfig, weights_mces: np.ndarray) -> SimilarityModelMul "use_ce": cfg.model.features.use_ce, "use_ion_activation": cfg.model.features.use_ion_activation, "use_ion_method": cfg.model.features.use_ion_method, + "use_ion_mode": cfg.model.features.use_ion_mode, } # Load pretrained weights if specified diff --git a/test_all_commands.sh b/test_all_commands.sh index ae754c9..cabfc02 100755 --- a/test_all_commands.sh +++ b/test_all_commands.sh @@ -120,6 +120,7 @@ uv run simba train \ training=fast_dev \ paths.preprocessing_dir_train=./test_full_workflow/preprocessed/ \ paths.checkpoint_dir=./test_full_workflow/checkpoints_metadata/ \ + model.features.use_adduct=true \ model.features.use_ce=true \ model.features.use_ion_activation=true \ model.features.use_ion_method=true \ @@ -135,6 +136,7 @@ uv run simba inference \ paths.checkpoint_dir=./test_full_workflow/checkpoints_metadata/ \ paths.preprocessing_dir=./test_full_workflow/preprocessed/ \ inference.preprocessing_pickle=mapping_unique_smiles.pkl \ + model.features.use_adduct=true \ model.features.use_ce=true \ model.features.use_ion_activation=true \ model.features.use_ion_method=true \ @@ -150,6 +152,7 @@ uv run simba analog-discovery \ --reference-spectra data/casmi2022.mgf \ --output-dir ./test_full_workflow/analog_results_metadata/ \ analog_discovery.query_index=0 \ + model.features.use_adduct=true \ model.features.use_ce=true \ model.features.use_ion_activation=true \ model.features.use_ion_method=true \ From 380bdb8a632a9b4210e072798ced03924f6f4f72 Mon Sep 17 00:00:00 2001 From: rukubrakov Date: Thu, 5 Mar 2026 12:44:39 +0100 Subject: [PATCH 2/8] fix: handle missing spectra during training --- simba/workflows/training.py | 27 +++++++++++++++++++++++---- 1 file changed, 23 insertions(+), 4 deletions(-) diff --git a/simba/workflows/training.py b/simba/workflows/training.py index 11ba597..ed5c990 100644 --- a/simba/workflows/training.py +++ b/simba/workflows/training.py @@ -98,12 +98,31 @@ def load_dataset(cfg: DictConfig): df_smiles = mapping[df_smiles_key] spectrum_indexes = mapping[spectrum_indexes_key] - # Load original spectra (all, including duplicates) - original_spectra = [spectra_by_idx[idx] for idx in spectrum_indexes] + # Load original spectra, handling missing ones + original_spectra = [] + idx_map = {} # old_idx -> new_idx + missing = [] + + for old_idx, mgf_idx in enumerate(spectrum_indexes): + if mgf_idx in spectra_by_idx: + idx_map[old_idx] = len(original_spectra) + original_spectra.append(spectra_by_idx[mgf_idx]) + else: + missing.append(mgf_idx) + + if missing: + logger.warning(f"[{split}] Missing {len(missing)} spectra (e.g., MGF index {missing[0]})") + # Filter df_smiles to keep only rows with valid spectra + valid_rows = [] + for i in df_smiles.index: + old_idxs = df_smiles.loc[i, "indexes"] + if all(idx in idx_map for idx in old_idxs): + # Remap to new positions + df_smiles.at[i, "indexes"] = [idx_map[idx] for idx in old_idxs] + valid_rows.append(i) + df_smiles = df_smiles.loc[valid_rows].reset_index(drop=True) # Build unique_spectra from df_smiles indexes - # df_smiles['indexes'] contains lists of indexes into original_spectra - # We take the first spectrum for each unique SMILES unique_spectra = [ original_spectra[df_smiles.loc[i, "indexes"][0]] for i in df_smiles.index From 4ebc042fafe5f8e71acfcc76ad5f09858fda767c Mon Sep 17 00:00:00 2001 From: rukubrakov Date: Thu, 5 Mar 2026 17:44:22 +0100 Subject: [PATCH 3/8] fix: make training respect use_only_protonized_adducts param from config --- simba/workflows/training.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/simba/workflows/training.py b/simba/workflows/training.py index ed5c990..f768b98 100644 --- a/simba/workflows/training.py +++ b/simba/workflows/training.py @@ -84,7 +84,16 @@ def load_dataset(cfg: DictConfig): ) mgf_path = mapping["mgf_path"] - all_spectra = load_spectra(mgf_path, cfg) + + # Use preprocessing config values (if available) to ensure consistent filtering + use_only_protonized = getattr(cfg.preprocessing, 'use_only_protonized_adducts', True) + + all_spectra = load_spectra( + mgf_path, + cfg, + n_samples=-1, # Load all spectra during training + use_only_protonized_adducts=use_only_protonized, + ) # Create spectrum lookup by MGF index spectra_by_idx = {s.mgf_index: s for s in all_spectra} From c2365f4f9b5e6dfa01645e203349379695854496 Mon Sep 17 00:00:00 2001 From: rukubrakov Date: Mon, 9 Mar 2026 12:19:30 +0100 Subject: [PATCH 4/8] fix: make inference script respect preprocessing config --- simba/workflows/inference.py | 36 +++++++++++++++++++++++++++++++++--- 1 file changed, 33 insertions(+), 3 deletions(-) diff --git a/simba/workflows/inference.py b/simba/workflows/inference.py index 4ff0ca9..e4a261a 100644 --- a/simba/workflows/inference.py +++ b/simba/workflows/inference.py @@ -55,7 +55,16 @@ def load_inference_data(cfg: DictConfig): logger.info("Detected lightweight format - reconstructing molecule_pairs_test") mgf_path = dataset["mgf_path"] - all_spectra = load_spectra(mgf_path, cfg) + + # Use preprocessing config values (if available) to ensure consistent filtering + use_only_protonized = getattr(cfg.preprocessing, 'use_only_protonized_adducts', True) + + all_spectra = load_spectra( + mgf_path, + cfg, + n_samples=-1, + use_only_protonized_adducts=use_only_protonized, + ) # Create spectrum lookup by MGF index spectra_by_idx = {s.mgf_index: s for s in all_spectra} @@ -65,8 +74,29 @@ def load_inference_data(cfg: DictConfig): df_smiles = dataset["df_smiles_test"] spectrum_indexes = dataset["spectrum_indexes_test"] - # Load original spectra (all, including duplicates) - original_spectra = [spectra_by_idx[idx] for idx in spectrum_indexes] + # Load original spectra, handling missing ones + original_spectra = [] + idx_map = {} # old_idx -> new_idx + missing = [] + + for old_idx, mgf_idx in enumerate(spectrum_indexes): + if mgf_idx in spectra_by_idx: + idx_map[old_idx] = len(original_spectra) + original_spectra.append(spectra_by_idx[mgf_idx]) + else: + missing.append(mgf_idx) + + if missing: + logger.warning(f"[test] Missing {len(missing)} spectra (e.g., MGF index {missing[0]})") + # Filter df_smiles to keep only rows with valid spectra + valid_rows = [] + for i in df_smiles.index: + old_idxs = df_smiles.loc[i, "indexes"] + if all(idx in idx_map for idx in old_idxs): + # Remap to new positions + df_smiles.at[i, "indexes"] = [idx_map[idx] for idx in old_idxs] + valid_rows.append(i) + df_smiles = df_smiles.loc[valid_rows].reset_index(drop=True) # Build unique_spectra from df_smiles indexes # df_smiles['indexes'] contains lists of indexes into original_spectra From ed8b3c6119ce33acb4f384c8ae53df2d2137637b Mon Sep 17 00:00:00 2001 From: rukubrakov Date: Mon, 9 Mar 2026 15:18:43 +0100 Subject: [PATCH 5/8] fix: Remove reset_index() in training/inference to keep original row numbers --- simba/workflows/inference.py | 34 ++++++++++++++++++++-------------- simba/workflows/training.py | 2 +- 2 files changed, 21 insertions(+), 15 deletions(-) diff --git a/simba/workflows/inference.py b/simba/workflows/inference.py index e4a261a..e057196 100644 --- a/simba/workflows/inference.py +++ b/simba/workflows/inference.py @@ -10,6 +10,7 @@ import numpy as np from omegaconf import DictConfig from scipy.stats import spearmanr +from sklearn.metrics import mean_absolute_error from torch.utils.data import DataLoader import simba.core.data.molecule_pairs @@ -21,7 +22,6 @@ from simba.core.training.train_utils import TrainUtils from simba.utils.logger_setup import logger from simba.workflows.utils import load_spectra -from sklearn.metrics import mean_absolute_error # Backward compatibility: Support loading old pickle files with old module paths @@ -55,10 +55,12 @@ def load_inference_data(cfg: DictConfig): logger.info("Detected lightweight format - reconstructing molecule_pairs_test") mgf_path = dataset["mgf_path"] - + # Use preprocessing config values (if available) to ensure consistent filtering - use_only_protonized = getattr(cfg.preprocessing, 'use_only_protonized_adducts', True) - + use_only_protonized = getattr( + cfg.preprocessing, "use_only_protonized_adducts", True + ) + all_spectra = load_spectra( mgf_path, cfg, @@ -78,16 +80,18 @@ def load_inference_data(cfg: DictConfig): original_spectra = [] idx_map = {} # old_idx -> new_idx missing = [] - + for old_idx, mgf_idx in enumerate(spectrum_indexes): if mgf_idx in spectra_by_idx: idx_map[old_idx] = len(original_spectra) original_spectra.append(spectra_by_idx[mgf_idx]) else: missing.append(mgf_idx) - + if missing: - logger.warning(f"[test] Missing {len(missing)} spectra (e.g., MGF index {missing[0]})") + logger.warning( + f"[test] Missing {len(missing)} spectra (e.g., MGF index {missing[0]})" + ) # Filter df_smiles to keep only rows with valid spectra valid_rows = [] for i in df_smiles.index: @@ -96,7 +100,7 @@ def load_inference_data(cfg: DictConfig): # Remap to new positions df_smiles.at[i, "indexes"] = [idx_map[idx] for idx in old_idxs] valid_rows.append(i) - df_smiles = df_smiles.loc[valid_rows].reset_index(drop=True) + df_smiles = df_smiles.loc[valid_rows] # Build unique_spectra from df_smiles indexes # df_smiles['indexes'] contains lists of indexes into original_spectra @@ -188,7 +192,7 @@ def prepare_inference_dataloaders( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, - use_ion_mode = cfg.model.features.use_ion_mode, + use_ion_mode=cfg.model.features.use_ion_mode, ) dataloader_ed = DataLoader( dataset_ed, batch_size=cfg.inference.batch_size, shuffle=False @@ -201,7 +205,7 @@ def prepare_inference_dataloaders( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, - use_ion_mode = cfg.model.features.use_ion_mode, + use_ion_mode=cfg.model.features.use_ion_mode, ) dataloader_mces = DataLoader( dataset_mces, batch_size=cfg.inference.batch_size, shuffle=False @@ -372,14 +376,16 @@ def evaluate_predictions( if len(mces_true) == 0 or len(pred_mces_mces_flat) == 0: logger.warning("No MCES samples after filtering, skipping MCES correlation") - corr_model_mces = float("nan") - mae_model_mces =float("nan") + corr_model_mces = float("nan") + mae_model_mces = float("nan") else: corr_model_mces, _ = spearmanr(mces_true, pred_mces_mces_flat) - mae_model_mces = mean_absolute_error(mces_true, pred_mces_mces_flat) + mae_model_mces = mean_absolute_error(mces_true, pred_mces_mces_flat) logger.info(f"MCES/Tanimoto correlation: {corr_model_mces:.4f}") - logger.info(f"MCES/Tanimoto mean absolute error: {cfg.data.mces20_max_value * mae_model_mces:.4f}") + logger.info( + f"MCES/Tanimoto mean absolute error: {cfg.data.mces20_max_value * mae_model_mces:.4f}" + ) # Denormalize if using MCES20 if not cfg.data.use_tanimoto: diff --git a/simba/workflows/training.py b/simba/workflows/training.py index f768b98..4352745 100644 --- a/simba/workflows/training.py +++ b/simba/workflows/training.py @@ -129,7 +129,7 @@ def load_dataset(cfg: DictConfig): # Remap to new positions df_smiles.at[i, "indexes"] = [idx_map[idx] for idx in old_idxs] valid_rows.append(i) - df_smiles = df_smiles.loc[valid_rows].reset_index(drop=True) + df_smiles = df_smiles.loc[valid_rows] # Build unique_spectra from df_smiles indexes unique_spectra = [ From 40bb215ff8139e27fc0ab054d809782b3c5ac7ab Mon Sep 17 00:00:00 2001 From: rukubrakov Date: Tue, 17 Mar 2026 12:23:57 +0100 Subject: [PATCH 6/8] Add precomputed distances cache to reuse molecular similarity calculations across preprocessing runs --- README.md | 32 +++ simba/configs/preprocessing/default.yaml | 8 + .../chemistry/edit_distance/edit_distance.py | 241 ++++++++++++++++++ simba/core/chemistry/mces/mces_computation.py | 8 + simba/workflows/preprocessing.py | 16 ++ test_all_commands.sh | 27 +- 6 files changed, 323 insertions(+), 9 deletions(-) diff --git a/README.md b/README.md index bf5b005..5249e0a 100644 --- a/README.md +++ b/README.md @@ -275,6 +275,38 @@ simba preprocess \ --- +**Reusing Precomputed Distances:** + +To speed up preprocessing when working with related datasets (e.g., MS2-only, MS3-only, and joint MS2+MS3), you can reuse previously computed molecular distances: + +```bash +# First: preprocess MS2-only data +simba preprocess \ + paths.spectra_path=ms2_spectra.mgf \ + paths.preprocessing_dir=./ms2_preprocessing/ + +# Then: preprocess MS3-only data +simba preprocess \ + paths.spectra_path=ms3_spectra.mgf \ + paths.preprocessing_dir=./ms3_preprocessing/ + +# Finally: preprocess joint dataset, reusing distances from both +simba preprocess \ + paths.spectra_path=joint_spectra.mgf \ + paths.preprocessing_dir=./joint_preprocessing/ \ + 'preprocessing.precomputed_distances=[./ms2_preprocessing/, ./ms3_preprocessing/]' +``` + +The cache automatically: +- Finds all distance files (`edit_distance_*.npy`, `mces_*.npy`) in each directory +- Loads SMILES mappings from `mapping_unique_smiles.pkl` +- Matches molecules by SMILES strings (robust to different splits/filters) +- Logs cache hit/miss statistics during computation + +**Cache hit rate = % of molecule pairs that were reused instead of recomputed!** + +--- + **Quick Testing (Fast Dev Mode):** ```bash diff --git a/simba/configs/preprocessing/default.yaml b/simba/configs/preprocessing/default.yaml index e87ddcd..7a430a1 100644 --- a/simba/configs/preprocessing/default.yaml +++ b/simba/configs/preprocessing/default.yaml @@ -24,5 +24,13 @@ test_split: 0.1 # Test split fraction (0.0-1.0) random_mces_sampling: false use_only_protonized_adducts: true +# Precomputed distances - reuse distances from previous preprocessing runs +# Just list preprocessing directories - auto-discovers all distance files +precomputed_distances: + # Examples: + # - "./test_precomputed_cache/dataset1/" + # - "./ms2_preprocessing/" + # - "./ms3_preprocessing/" + # Subsampling subsample_preprocessing: false diff --git a/simba/core/chemistry/edit_distance/edit_distance.py b/simba/core/chemistry/edit_distance/edit_distance.py index 92793ea..fe34fda 100644 --- a/simba/core/chemistry/edit_distance/edit_distance.py +++ b/simba/core/chemistry/edit_distance/edit_distance.py @@ -1,6 +1,7 @@ import os from functools import lru_cache +import dill import numpy as np import pandas as pd from myopic_mces.myopic_mces import MCES as MCES2 @@ -18,6 +19,218 @@ VERY_HIGH_DISTANCE = 666 +def load_precomputed_distances_cache(preprocessing_dirs: list[str]) -> dict: + """ + Load precomputed distances from preprocessing directories. + + Auto-discovers files in each directory: + - mapping_unique_smiles.pkl (for SMILES mapping) + - edit_distance_*.npy and mces_*.npy files (for distances) + + Parameters + ---------- + preprocessing_dirs : list[str] + List of preprocessing directory paths. + + Returns + ------- + dict + Cache with (smiles1, smiles2) tuple keys -> [ed_distance, mces_distance] values. + """ + cache = {} + + for prep_dir in preprocessing_dirs: + if not os.path.exists(prep_dir): + logger.warning(f"Preprocessing directory not found: {prep_dir}") + continue + + # Find the pickle file + pickle_path = os.path.join(prep_dir, 'mapping_unique_smiles.pkl') + if not os.path.exists(pickle_path): + logger.warning(f"No mapping_unique_smiles.pkl found in {prep_dir}") + continue + + # Find distance files by prefix and split + distance_files = {} + for split in ['train', 'val', 'test']: + distance_files[split] = {'ed': [], 'mces': []} + + for filename in os.listdir(prep_dir): + if not filename.endswith('.npy'): + continue + + # Determine split + split = None + for s in ['train', 'val', 'test']: + if f'_{s}_' in filename or filename.endswith(f'_{s}.npy'): + split = s + break + + if split is None: + continue + + # Categorize by type + if filename.startswith('edit_distance_'): + distance_files[split]['ed'].append(os.path.join(prep_dir, filename)) + elif filename.startswith('mces_'): + distance_files[split]['mces'].append(os.path.join(prep_dir, filename)) + + # Count total files + total_ed = sum(len(v['ed']) for v in distance_files.values()) + total_mces = sum(len(v['mces']) for v in distance_files.values()) + + if total_ed == 0 and total_mces == 0: + logger.warning(f"No distance files found in {prep_dir}") + continue + + logger.info(f"Auto-discovered in {prep_dir}: {total_ed} ED files, {total_mces} MCES files") + + # Load SMILES mapping per split + try: + with open(pickle_path, 'rb') as f: + data = dill.load(f) + except Exception as e: + logger.error(f"Error loading pickle {pickle_path}: {e}") + continue + + # Process each split separately (indices are per-split) + pairs_added = 0 + for split in ['train', 'val', 'test']: + # Extract SMILES for this split only + split_smiles = [] + for key_prefix in ['molecule_pairs_', 'df_smiles_']: + key = key_prefix + split + if key not in data or data[key] is None: + continue + + # Extract DataFrame + if key_prefix == 'molecule_pairs_': + if not hasattr(data[key], 'df_smiles'): + continue + df = data[key].df_smiles + else: + df = data[key] + + if df is None or df.empty: + continue + + # Extract SMILES from DataFrame + if 'canon_smiles' in df.columns: + smiles_list = df['canon_smiles'].tolist() + elif 'smiles' in df.columns: + smiles_list = df['smiles'].tolist() + elif hasattr(df.index, 'tolist'): + smiles_list = df.index.tolist() + else: + continue + + # Remove duplicates while preserving order + seen = set() + for s in smiles_list: + if s not in seen: + seen.add(s) + split_smiles.append(s) + break # Only need one match + + if not split_smiles: + continue + + # Load distances for this split + ed_cache_split = _load_distances_from_files(distance_files[split]['ed'], split_smiles) + mces_cache_split = _load_distances_from_files(distance_files[split]['mces'], split_smiles) + + # Combine into final cache (only pairs with both ED and MCES) + for key in set(ed_cache_split.keys()) & set(mces_cache_split.keys()): + cache[key] = [ed_cache_split[key], mces_cache_split[key]] + pairs_added += 1 + + logger.info(f"Loaded {pairs_added} valid pairs from {prep_dir}") + + logger.info(f"Total unique pairs in cache: {len(cache)}") + return cache + + +def _load_smiles_from_pickle(pickle_path: str) -> list[str]: + """Extract unique SMILES list from preprocessing pickle.""" + try: + with open(pickle_path, 'rb') as f: + data = dill.load(f) + + unique_smiles = [] + + # Handle both old format (molecule_pairs_X) and new lightweight format (df_smiles_X) + for key_prefix in ['molecule_pairs_', 'df_smiles_']: + for split in ['train', 'val', 'test']: + key = key_prefix + split + if key not in data or data[key] is None: + continue + + # Extract DataFrame + if key_prefix == 'molecule_pairs_': + if not hasattr(data[key], 'df_smiles'): + continue + df = data[key].df_smiles + else: + df = data[key] + + if df is None or df.empty: + continue + + # Extract SMILES from DataFrame + if 'canon_smiles' in df.columns: + smiles_list = df['canon_smiles'].tolist() + elif 'smiles' in df.columns: + smiles_list = df['smiles'].tolist() + elif hasattr(df.index, 'tolist'): + smiles_list = df.index.tolist() + else: + continue + + unique_smiles.extend(smiles_list) + + # Remove duplicates while preserving order + seen = set() + result = [] + for s in unique_smiles: + if s not in seen: + seen.add(s) + result.append(s) + + logger.info(f"Extracted {len(result)} unique SMILES from {pickle_path}") + return result + + except Exception as e: + logger.error(f"Error loading SMILES from {pickle_path}: {e}") + return [] + + +def _load_distances_from_files(distance_files: list[str], smiles_list: list[str]) -> dict: + """Load distances from multiple .npy files.""" + distance_cache = {} + + for dist_file in sorted(distance_files): + try: + distances = np.load(dist_file) + for row in distances: + idx1, idx2 = int(row[0]), int(row[1]) + + # Skip out of bounds + if idx1 >= len(smiles_list) or idx2 >= len(smiles_list): + continue + + # Map to SMILES + smiles1, smiles2 = smiles_list[idx1], smiles_list[idx2] + key = tuple(sorted([smiles1, smiles2])) + + # Store distance (3rd column) + distance_cache[key] = float(row[2]) + + except Exception as e: + logger.warning(f"Error loading {dist_file}: {e}") + + return distance_cache + + def create_input_df(smiles, indexes_0, indexes_1): df = pd.DataFrame() logger.info(f"Number of spectra: {len(smiles)}") @@ -42,6 +255,7 @@ def compute_ed_and_mces_both( output_file: str = None, progress_position: int = 0, progress_desc: str = "Computing", + precomputed_cache: dict = None, ) -> np.ndarray | dict: """ Compute BOTH edit distance AND MCES for a batch of molecule pairs in a single pass. @@ -80,6 +294,8 @@ def compute_ed_and_mces_both( Position for the tqdm progress bar (default: 0). Used for stacking multiple progress bars. progress_desc : str, optional Description text for the tqdm progress bar (default: "Computing"). + precomputed_cache : dict, optional + Cache of precomputed distances with (smiles1, smiles2) keys -> [ed, mces] values. Returns ------- @@ -105,6 +321,10 @@ def compute_ed_and_mces_both( ed_distances = [] mces_distances = [] + + # Track cache hits/misses + cache_hits = 0 + cache_misses = 0 for index in tqdm( range(pair_distances.shape[0]), @@ -118,6 +338,21 @@ def compute_ed_and_mces_both( s0 = smiles[int(pair[0])] s1 = smiles[int(pair[1])] + + # Check precomputed cache first + if precomputed_cache is not None: + cache_key = tuple(sorted([s0, s1])) + if cache_key in precomputed_cache: + cached_values = precomputed_cache[cache_key] + ed_dist = cached_values[0] + mces_dist = cached_values[1] + cache_hits += 1 + ed_distances.append(ed_dist) + mces_distances.append(mces_dist) + continue + + # Not in cache, compute + cache_misses += 1 fp0 = fps[int(pair[0])] fp1 = fps[int(pair[1])] mol0 = mols[int(pair[0])] @@ -134,6 +369,12 @@ def compute_ed_and_mces_both( pair_distances[:, 2] = ed_distances pair_distances[:, 3] = mces_distances + + # Log cache statistics + if precomputed_cache is not None: + total = cache_hits + cache_misses + hit_rate = (cache_hits / total * 100) if total > 0 else 0 + logger.info(f"Cache: {cache_hits} hits, {cache_misses} misses ({hit_rate:.1f}% hit rate)") if output_file: os.makedirs(os.path.dirname(output_file), exist_ok=True) diff --git a/simba/core/chemistry/mces/mces_computation.py b/simba/core/chemistry/mces/mces_computation.py index 2fbaca5..1e3536c 100644 --- a/simba/core/chemistry/mces/mces_computation.py +++ b/simba/core/chemistry/mces/mces_computation.py @@ -36,6 +36,7 @@ def compute_all_mces_results_unique( use_edit_distance: bool = False, loaded_molecule_pairs: MolecularPairsSet | None = None, compute_both_metrics: bool = False, + precomputed_cache: dict = None, ) -> MoleculePairsOpt: """ Compute MCES or edit distance for all pairs of spectra using multiprocessing. @@ -72,6 +73,8 @@ def compute_all_mces_results_unique( Precomputed molecule pairs to use instead of computing new ones, by default None; compute_both_metrics : bool, optional Whether to compute both MCES and edit distance metrics, by default False. + precomputed_cache : dict, optional + Cache of precomputed distances to reuse from previous runs, by default None. Returns ------- @@ -100,6 +103,7 @@ def compute_all_mces_results_unique( identifier=identifier, use_edit_distance=use_edit_distance, compute_both_metrics=compute_both_metrics, + precomputed_cache=precomputed_cache, ) else: molecular_pairs = loaded_molecule_pairs @@ -289,6 +293,7 @@ def compute_all_mces_results_exhaustive( identifier: str = "", use_edit_distance=False, compute_both_metrics: bool = False, + precomputed_cache: dict = None, ) -> MolecularPairsSet: """ Compute MCES or edit distance for all pairs of spectra using multiprocessing. @@ -311,6 +316,8 @@ def compute_all_mces_results_exhaustive( If True, compute edit distance instead of MCES, by default False. compute_both_metrics : bool, optional If True, compute both ED and MCES in single pass (optimized), by default False. + precomputed_cache : dict, optional + Cache of precomputed distances to reuse from previous runs, by default None. Returns ------- @@ -447,6 +454,7 @@ def compute_all_mces_results_exhaustive( worker_output, sub_index, f"Node {current_node} Chunk {chunk_idx} Worker {sub_index}", + precomputed_cache, ), ) results.append(result) diff --git a/simba/workflows/preprocessing.py b/simba/workflows/preprocessing.py index c972e9a..1251cb6 100644 --- a/simba/workflows/preprocessing.py +++ b/simba/workflows/preprocessing.py @@ -17,6 +17,9 @@ with contextlib.suppress(RuntimeError): multiprocessing.set_start_method("spawn", force=True) +from simba.core.chemistry.edit_distance.edit_distance import ( + load_precomputed_distances_cache, +) from simba.core.chemistry.mces.mces_computation import MCES from simba.core.training.train_utils import TrainUtils from simba.utils.logger_setup import logger @@ -207,6 +210,18 @@ def preprocess(cfg: DictConfig) -> None: with open(output_file, "wb") as file: pickle.dump(dataset, file) + # Load precomputed distances cache if configured + precomputed_cache = {} + if hasattr(cfg.preprocessing, 'precomputed_distances') and cfg.preprocessing.precomputed_distances: + dirs = cfg.preprocessing.precomputed_distances + if dirs and len(dirs) > 0: + logger.info(f"Loading precomputed distances from {len(dirs)} directory(ies)...") + precomputed_cache = load_precomputed_distances_cache(dirs) + else: + logger.info("No precomputed distances configured, computing all from scratch") + else: + logger.info("No precomputed distances configured, computing all from scratch") + # Compute distances for each partition molecule_pairs = {} for type_data, spectra in [ @@ -231,6 +246,7 @@ def preprocess(cfg: DictConfig) -> None: use_edit_distance=True, # Ignored when compute_both_metrics=True loaded_molecule_pairs=None, compute_both_metrics=True, + precomputed_cache=precomputed_cache if len(precomputed_cache) > 0 else None, ) # Combine edit distance and MCES files diff --git a/test_all_commands.sh b/test_all_commands.sh index cabfc02..ad88156 100755 --- a/test_all_commands.sh +++ b/test_all_commands.sh @@ -44,7 +44,7 @@ rm -rf test_full_workflow/ mkdir -p test_full_workflow echo "" -echo "1/9 Testing: simba preprocess" +echo "1/10 Testing: simba preprocess" echo "--------------------------------" uv run simba preprocess \ preprocessing=fast_dev \ @@ -52,7 +52,16 @@ uv run simba preprocess \ paths.preprocessing_dir=./test_full_workflow/preprocessed/ echo "" -echo "2/9 Testing: simba train" +echo "2/10 Testing: simba preprocess with cache (reuse distances)" +echo "------------------------------------------------------------" +uv run simba preprocess \ + preprocessing=fast_dev \ + paths.spectra_path=data/casmi2022.mgf \ + paths.preprocessing_dir=./test_full_workflow/preprocessed_cached/ \ + 'preprocessing.precomputed_distances=[./test_full_workflow/preprocessed/]' + +echo "" +echo "3/10 Testing: simba train" echo "--------------------------------" uv run simba train \ training=fast_dev \ @@ -63,7 +72,7 @@ uv run simba train \ hardware.accelerator=$DEVICE echo "" -echo "3/9 Testing: simba inference" +echo "4/10 Testing: simba inference" echo "--------------------------------" uv run simba inference \ inference=fast_dev \ @@ -73,7 +82,7 @@ uv run simba inference \ hardware.accelerator=$DEVICE echo "" -echo "4/9 Testing: simba analog-discovery" +echo "5/10 Testing: simba analog-discovery" echo "--------------------------------" uv run simba analog-discovery \ analog_discovery=fast_dev \ @@ -86,7 +95,7 @@ uv run simba analog-discovery \ if [[ "$SKIP_PRETRAINED" == "false" ]]; then echo "" - echo "5/9 Testing: simba inference (pretrained model)" + echo "6/10 Testing: simba inference (pretrained model)" echo "--------------------------------" uv run simba inference \ inference=fast_dev \ @@ -96,7 +105,7 @@ if [[ "$SKIP_PRETRAINED" == "false" ]]; then hardware.accelerator=$DEVICE echo "" - echo "6/9 Testing: simba analog-discovery (pretrained model)" + echo "7/10 Testing: simba analog-discovery (pretrained model)" echo "--------------------------------" uv run simba analog-discovery \ analog_discovery=fast_dev \ @@ -114,7 +123,7 @@ else fi echo "" -echo "7/9 Testing: simba train (with metadata features)" +echo "8/10 Testing: simba train (with metadata features)" echo "--------------------------------" uv run simba train \ training=fast_dev \ @@ -129,7 +138,7 @@ uv run simba train \ hardware.accelerator=$DEVICE echo "" -echo "8/9 Testing: simba inference (with metadata features)" +echo "9/10 Testing: simba inference (with metadata features)" echo "--------------------------------" uv run simba inference \ inference=fast_dev \ @@ -143,7 +152,7 @@ uv run simba inference \ hardware.accelerator=$DEVICE echo "" -echo "9/9 Testing: simba analog-discovery (with metadata features)" +echo "10/10 Testing: simba analog-discovery (with metadata features)" echo "--------------------------------" uv run simba analog-discovery \ analog_discovery=fast_dev \ From a0c1787e6551b54af6e961aaee41ece4c55155ee Mon Sep 17 00:00:00 2001 From: rukubrakov Date: Thu, 19 Mar 2026 11:14:44 +0100 Subject: [PATCH 7/8] Apply quality fixes: fix ruff, add missing metadata fields to tests, switch to balanced_accuracy_score metric in inference, and configure HiGHS solver with highspy dependency. --- pyproject.toml | 1 + simba/commands/analog_discovery.py | 1 + .../chemistry/edit_distance/edit_distance.py | 172 +++++++++--------- simba/core/models/similarity_models.py | 22 +-- simba/core/models/spectrum_encoder.py | 76 ++++---- simba/workflows/inference.py | 4 +- simba/workflows/preprocessing.py | 17 +- simba/workflows/training.py | 26 ++- tests/unit/test_embedder_multitask.py | 6 + 9 files changed, 178 insertions(+), 147 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 117d62b..58db336 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,6 +46,7 @@ dependencies = [ "pyteomics>=4.6.0", "depthcharge-ms @ git+https://github.com/wfondrie/depthcharge.git@bd2861f", "myopic-mces>=1.0.0,<2.0.0", + "highspy>=1.13.1", # Data processing "h5py>=3.10.0", "pyarrow>=15.0.0", diff --git a/simba/commands/analog_discovery.py b/simba/commands/analog_discovery.py index 4089695..8c86fab 100644 --- a/simba/commands/analog_discovery.py +++ b/simba/commands/analog_discovery.py @@ -192,6 +192,7 @@ def _analog_discovery_with_hydra( except Exception as e: import traceback + click.echo(f"\n❌ Error during analog discovery: {e}", err=True) click.echo("\nFull traceback:", err=True) click.echo(traceback.format_exc(), err=True) diff --git a/simba/core/chemistry/edit_distance/edit_distance.py b/simba/core/chemistry/edit_distance/edit_distance.py index fe34fda..9f927ba 100644 --- a/simba/core/chemistry/edit_distance/edit_distance.py +++ b/simba/core/chemistry/edit_distance/edit_distance.py @@ -22,108 +22,110 @@ def load_precomputed_distances_cache(preprocessing_dirs: list[str]) -> dict: """ Load precomputed distances from preprocessing directories. - + Auto-discovers files in each directory: - mapping_unique_smiles.pkl (for SMILES mapping) - edit_distance_*.npy and mces_*.npy files (for distances) - + Parameters ---------- preprocessing_dirs : list[str] List of preprocessing directory paths. - + Returns ------- dict Cache with (smiles1, smiles2) tuple keys -> [ed_distance, mces_distance] values. """ cache = {} - + for prep_dir in preprocessing_dirs: if not os.path.exists(prep_dir): logger.warning(f"Preprocessing directory not found: {prep_dir}") continue - + # Find the pickle file - pickle_path = os.path.join(prep_dir, 'mapping_unique_smiles.pkl') + pickle_path = os.path.join(prep_dir, "mapping_unique_smiles.pkl") if not os.path.exists(pickle_path): logger.warning(f"No mapping_unique_smiles.pkl found in {prep_dir}") continue - + # Find distance files by prefix and split distance_files = {} - for split in ['train', 'val', 'test']: - distance_files[split] = {'ed': [], 'mces': []} - + for split in ["train", "val", "test"]: + distance_files[split] = {"ed": [], "mces": []} + for filename in os.listdir(prep_dir): - if not filename.endswith('.npy'): + if not filename.endswith(".npy"): continue - + # Determine split split = None - for s in ['train', 'val', 'test']: - if f'_{s}_' in filename or filename.endswith(f'_{s}.npy'): + for s in ["train", "val", "test"]: + if f"_{s}_" in filename or filename.endswith(f"_{s}.npy"): split = s break - + if split is None: continue - + # Categorize by type - if filename.startswith('edit_distance_'): - distance_files[split]['ed'].append(os.path.join(prep_dir, filename)) - elif filename.startswith('mces_'): - distance_files[split]['mces'].append(os.path.join(prep_dir, filename)) - + if filename.startswith("edit_distance_"): + distance_files[split]["ed"].append(os.path.join(prep_dir, filename)) + elif filename.startswith("mces_"): + distance_files[split]["mces"].append(os.path.join(prep_dir, filename)) + # Count total files - total_ed = sum(len(v['ed']) for v in distance_files.values()) - total_mces = sum(len(v['mces']) for v in distance_files.values()) - + total_ed = sum(len(v["ed"]) for v in distance_files.values()) + total_mces = sum(len(v["mces"]) for v in distance_files.values()) + if total_ed == 0 and total_mces == 0: logger.warning(f"No distance files found in {prep_dir}") continue - - logger.info(f"Auto-discovered in {prep_dir}: {total_ed} ED files, {total_mces} MCES files") - + + logger.info( + f"Auto-discovered in {prep_dir}: {total_ed} ED files, {total_mces} MCES files" + ) + # Load SMILES mapping per split try: - with open(pickle_path, 'rb') as f: + with open(pickle_path, "rb") as f: data = dill.load(f) except Exception as e: logger.error(f"Error loading pickle {pickle_path}: {e}") continue - + # Process each split separately (indices are per-split) pairs_added = 0 - for split in ['train', 'val', 'test']: + for split in ["train", "val", "test"]: # Extract SMILES for this split only split_smiles = [] - for key_prefix in ['molecule_pairs_', 'df_smiles_']: + for key_prefix in ["molecule_pairs_", "df_smiles_"]: key = key_prefix + split if key not in data or data[key] is None: continue - + # Extract DataFrame - if key_prefix == 'molecule_pairs_': - if not hasattr(data[key], 'df_smiles'): + if key_prefix == "molecule_pairs_": + if not hasattr(data[key], "df_smiles"): continue df = data[key].df_smiles else: df = data[key] - + if df is None or df.empty: continue - + # Extract SMILES from DataFrame - if 'canon_smiles' in df.columns: - smiles_list = df['canon_smiles'].tolist() - elif 'smiles' in df.columns: - smiles_list = df['smiles'].tolist() - elif hasattr(df.index, 'tolist'): + if "canon_smiles" in df.columns: + smiles_list = df["canon_smiles"].tolist() + elif "smiles" in df.columns: + smiles_list = df["smiles"].tolist() + elif hasattr(df.index, "tolist"): smiles_list = df.index.tolist() else: continue - + # Remove duplicates while preserving order seen = set() for s in smiles_list: @@ -131,21 +133,25 @@ def load_precomputed_distances_cache(preprocessing_dirs: list[str]) -> dict: seen.add(s) split_smiles.append(s) break # Only need one match - + if not split_smiles: continue - + # Load distances for this split - ed_cache_split = _load_distances_from_files(distance_files[split]['ed'], split_smiles) - mces_cache_split = _load_distances_from_files(distance_files[split]['mces'], split_smiles) - + ed_cache_split = _load_distances_from_files( + distance_files[split]["ed"], split_smiles + ) + mces_cache_split = _load_distances_from_files( + distance_files[split]["mces"], split_smiles + ) + # Combine into final cache (only pairs with both ED and MCES) for key in set(ed_cache_split.keys()) & set(mces_cache_split.keys()): cache[key] = [ed_cache_split[key], mces_cache_split[key]] pairs_added += 1 - + logger.info(f"Loaded {pairs_added} valid pairs from {prep_dir}") - + logger.info(f"Total unique pairs in cache: {len(cache)}") return cache @@ -153,41 +159,41 @@ def load_precomputed_distances_cache(preprocessing_dirs: list[str]) -> dict: def _load_smiles_from_pickle(pickle_path: str) -> list[str]: """Extract unique SMILES list from preprocessing pickle.""" try: - with open(pickle_path, 'rb') as f: + with open(pickle_path, "rb") as f: data = dill.load(f) - + unique_smiles = [] - + # Handle both old format (molecule_pairs_X) and new lightweight format (df_smiles_X) - for key_prefix in ['molecule_pairs_', 'df_smiles_']: - for split in ['train', 'val', 'test']: + for key_prefix in ["molecule_pairs_", "df_smiles_"]: + for split in ["train", "val", "test"]: key = key_prefix + split if key not in data or data[key] is None: continue - + # Extract DataFrame - if key_prefix == 'molecule_pairs_': - if not hasattr(data[key], 'df_smiles'): + if key_prefix == "molecule_pairs_": + if not hasattr(data[key], "df_smiles"): continue df = data[key].df_smiles else: df = data[key] - + if df is None or df.empty: continue - + # Extract SMILES from DataFrame - if 'canon_smiles' in df.columns: - smiles_list = df['canon_smiles'].tolist() - elif 'smiles' in df.columns: - smiles_list = df['smiles'].tolist() - elif hasattr(df.index, 'tolist'): + if "canon_smiles" in df.columns: + smiles_list = df["canon_smiles"].tolist() + elif "smiles" in df.columns: + smiles_list = df["smiles"].tolist() + elif hasattr(df.index, "tolist"): smiles_list = df.index.tolist() else: continue - + unique_smiles.extend(smiles_list) - + # Remove duplicates while preserving order seen = set() result = [] @@ -195,39 +201,41 @@ def _load_smiles_from_pickle(pickle_path: str) -> list[str]: if s not in seen: seen.add(s) result.append(s) - + logger.info(f"Extracted {len(result)} unique SMILES from {pickle_path}") return result - + except Exception as e: logger.error(f"Error loading SMILES from {pickle_path}: {e}") return [] -def _load_distances_from_files(distance_files: list[str], smiles_list: list[str]) -> dict: +def _load_distances_from_files( + distance_files: list[str], smiles_list: list[str] +) -> dict: """Load distances from multiple .npy files.""" distance_cache = {} - + for dist_file in sorted(distance_files): try: distances = np.load(dist_file) for row in distances: idx1, idx2 = int(row[0]), int(row[1]) - + # Skip out of bounds if idx1 >= len(smiles_list) or idx2 >= len(smiles_list): continue - + # Map to SMILES smiles1, smiles2 = smiles_list[idx1], smiles_list[idx2] key = tuple(sorted([smiles1, smiles2])) - + # Store distance (3rd column) distance_cache[key] = float(row[2]) - + except Exception as e: logger.warning(f"Error loading {dist_file}: {e}") - + return distance_cache @@ -321,7 +329,7 @@ def compute_ed_and_mces_both( ed_distances = [] mces_distances = [] - + # Track cache hits/misses cache_hits = 0 cache_misses = 0 @@ -338,7 +346,7 @@ def compute_ed_and_mces_both( s0 = smiles[int(pair[0])] s1 = smiles[int(pair[1])] - + # Check precomputed cache first if precomputed_cache is not None: cache_key = tuple(sorted([s0, s1])) @@ -350,7 +358,7 @@ def compute_ed_and_mces_both( ed_distances.append(ed_dist) mces_distances.append(mces_dist) continue - + # Not in cache, compute cache_misses += 1 fp0 = fps[int(pair[0])] @@ -369,12 +377,14 @@ def compute_ed_and_mces_both( pair_distances[:, 2] = ed_distances pair_distances[:, 3] = mces_distances - + # Log cache statistics if precomputed_cache is not None: total = cache_hits + cache_misses hit_rate = (cache_hits / total * 100) if total > 0 else 0 - logger.info(f"Cache: {cache_hits} hits, {cache_misses} misses ({hit_rate:.1f}% hit rate)") + logger.info( + f"Cache: {cache_hits} hits, {cache_misses} misses ({hit_rate:.1f}% hit rate)" + ) if output_file: os.makedirs(os.path.dirname(output_file), exist_ok=True) @@ -597,7 +607,7 @@ def simba_solve_pair_mces( threshold=threshold, i=0, # solver='CPLEX_CMD', # or another fast solver you have installed - solver="PULP_CBC_CMD", + solver="HiGHS", solver_options={ "threads": 1, "msg": False, diff --git a/simba/core/models/similarity_models.py b/simba/core/models/similarity_models.py index adde69d..67afa99 100644 --- a/simba/core/models/similarity_models.py +++ b/simba/core/models/similarity_models.py @@ -74,7 +74,7 @@ def __init__( use_ce=use_ce, use_ion_activation=use_ion_activation, use_ion_method=use_ion_method, - use_ion_mode= use_ion_mode, + use_ion_mode=use_ion_mode, ) self.regression_loss = nn.MSELoss() @@ -115,8 +115,8 @@ def forward(self, batch): if self.use_ion_mode: kwargs_0["ionmode"] = batch["ionmode_0"].float() kwargs_1["ionmode"] = batch["ionmode_1"].float() - kwargs_0["precursor_charge"]=batch["precursor_charge_0"].float() - kwargs_1["precursor_charge"]=batch["precursor_charge_1"].float() + kwargs_0["precursor_charge"] = batch["precursor_charge_0"].float() + kwargs_1["precursor_charge"] = batch["precursor_charge_1"].float() if self.use_adduct: kwargs_0["ionmode"] = batch["ionmode_0"].float() kwargs_1["ionmode"] = batch["ionmode_1"].float() @@ -380,7 +380,7 @@ def __init__( use_ce=use_ce, use_ion_activation=use_ion_activation, use_ion_method=use_ion_method, - use_ion_mode= use_ion_mode, + use_ion_mode=use_ion_mode, ) self.weights = weights @@ -428,8 +428,8 @@ def __init__( # Initialize learnable log variance parameters for each loss self.USE_LEARNABLE_MULTITASK = USE_LEARNABLE_MULTITASK if USE_LEARNABLE_MULTITASK: - initial_log_sigma1= 0.0 - initial_log_sigma2 = -5.3 + initial_log_sigma1 = 0.0 + initial_log_sigma2 = -5.3 self.log_sigma1 = nn.Parameter(torch.tensor(initial_log_sigma1)) self.log_sigma2 = nn.Parameter(torch.tensor(initial_log_sigma2)) @@ -484,12 +484,8 @@ def forward(self, batch, return_spectrum_output=False): kwargs_0["adduct"] = batch["adduct_0"].float() kwargs_1["adduct"] = batch["adduct_1"].float() - batch["ce_0"] = torch.nan_to_num( - batch["ce_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ce_1"] = torch.nan_to_num( - batch["ce_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) + batch["ce_0"] = torch.nan_to_num(batch["ce_0"], nan=0.0, posinf=0.0, neginf=0.0) + batch["ce_1"] = torch.nan_to_num(batch["ce_1"], nan=0.0, posinf=0.0, neginf=0.0) kwargs_0["ce"] = batch["ce_0"].float() kwargs_1["ce"] = batch["ce_1"].float() @@ -789,7 +785,7 @@ def forward(self, batch): "precursor_mass": batch["precursor_mass"].float(), "precursor_charge": batch["precursor_charge"].float(), } - + # Add metadata fields if present in batch if "ionmode" in batch: kwargs["ionmode"] = batch["ionmode"].float() diff --git a/simba/core/models/spectrum_encoder.py b/simba/core/models/spectrum_encoder.py index c303e5b..44e2614 100644 --- a/simba/core/models/spectrum_encoder.py +++ b/simba/core/models/spectrum_encoder.py @@ -1,8 +1,8 @@ import torch +from depthcharge.encoders import FloatEncoder from depthcharge.transformers import ( SpectrumTransformerEncoder, ) # PeptideTransformerEncoder, -from depthcharge.encoders import FloatEncoder class SpectrumTransformerEncoderCustom(SpectrumTransformerEncoder): @@ -13,7 +13,7 @@ def __init__( use_ce: bool = False, use_ion_activation: bool = False, use_ion_method: bool = False, - use_ion_mode: bool =False, + use_ion_mode: bool = False, **kwargs, ): """ @@ -30,7 +30,7 @@ def __init__( use_ion_method: bool Whether to include ionization method in the encoding (default: False). """ - self.use_encoders=False + self.use_encoders = False super().__init__(*args, **kwargs) self.use_adduct = use_adduct self.use_ce = use_ce @@ -41,17 +41,24 @@ def __init__( if self.use_encoders: if self.use_adduct: self.adduct_encoder = FloatEncoder(self.d_model) - self.ionmode_encoder =FloatEncoder(self.d_model) + self.ionmode_encoder = FloatEncoder(self.d_model) if self.use_ce: - self.ce_encoder = FloatEncoder(self.d_model, ) + self.ce_encoder = FloatEncoder( + self.d_model, + ) if self.use_ion_activation: - self.ion_activation_encoder= FloatEncoder(self.d_model) + self.ion_activation_encoder = FloatEncoder(self.d_model) if self.use_ion_method: self.ion_method_encoder = FloatEncoder(self.d_model) - if (self.use_adduct or self.use_ce or self.use_ion_activation or self.use_ion_method): - self.precursor_mz_encoder =FloatEncoder(self.d_model) + if ( + self.use_adduct + or self.use_ce + or self.use_ion_activation + or self.use_ion_method + ): + self.precursor_mz_encoder = FloatEncoder(self.d_model) def precursor_hook( self, @@ -63,41 +70,40 @@ def precursor_hook( dtype = mz_array.dtype batch_size = mz_array.shape[0] - if not(self.use_encoders): + if not (self.use_encoders): placeholder = torch.zeros( - (batch_size, self.d_model), dtype=dtype, device=device + (batch_size, self.d_model), dtype=dtype, device=device + ) + precursor_mass = ( + kwargs["precursor_mass"].float().to(device).view(batch_size) ) - precursor_mass = kwargs["precursor_mass"].float().to(device).view(batch_size) placeholder[:, 0] = precursor_mass precursor_charge = ( kwargs["precursor_charge"].float().to(device).view(batch_size) - ) + ) # skip the use of the precursor charge field if self.use_ion_mode: placeholder[:, 1] = precursor_charge - + current_idx = 2 # keep track of where to insert metadata - - - ionmode = kwargs["ionmode"].float().to(device).view(batch_size) - if self.use_ion_mode: + + ionmode = kwargs["ionmode"].float().to(device).view(batch_size) + if self.use_ion_mode: placeholder[:, current_idx] = ionmode current_idx += 1 - adduct = kwargs["adduct"].float().to(device).view(batch_size, -1) stop_idx = current_idx + adduct.shape[1] if self.use_adduct: placeholder[:, current_idx:stop_idx] = adduct current_idx = stop_idx - + ce = kwargs["ce"].float().to(device).view(batch_size) if self.use_ce: - placeholder[:, current_idx] = ce + placeholder[:, current_idx] = ce current_idx += 1 - ia = kwargs["ion_activation"].float().to(device).view(batch_size, -1) stop_idx = current_idx + ia.shape[1] if self.use_ion_activation: @@ -106,32 +112,34 @@ def precursor_hook( im = kwargs["ion_method"].float().to(device).view(batch_size, -1) stop_idx = current_idx + im.shape[1] - if self.use_ion_method: + if self.use_ion_method: placeholder[:, current_idx:stop_idx] = im current_idx = stop_idx # ensure there are no nans placeholder = torch.nan_to_num(placeholder, nan=0.0, posinf=0.0, neginf=0.0) - else: - - precursor_mass = kwargs["precursor_mass"].float().to(device).view(batch_size) - precursor_mass_rep = self.precursor_mz_encoder(precursor_mass[:,None]).squeeze(1) - placeholder = precursor_mass_rep + 0*precursor_mass_rep - + precursor_mass = ( + kwargs["precursor_mass"].float().to(device).view(batch_size) + ) + precursor_mass_rep = self.precursor_mz_encoder( + precursor_mass[:, None] + ).squeeze(1) + placeholder = precursor_mass_rep + 0 * precursor_mass_rep + if self.use_adduct: ionmode = kwargs["ionmode"].float().to(device).view(batch_size) adduct = kwargs["adduct"].float().to(device).view(batch_size, -1) - ionmode_rep = self.ionmode_encoder(ionmode[:,None]).squeeze(1) + ionmode_rep = self.ionmode_encoder(ionmode[:, None]).squeeze(1) adduct_rep = self.adduct_encoder(adduct).mean(dim=1) - placeholder= placeholder + (ionmode_rep + adduct_rep) + placeholder = placeholder + (ionmode_rep + adduct_rep) if self.use_ce: ce = kwargs["ce"].float().to(device).view(batch_size) - ce_rep = self.ce_encoder(ce[:,None]).squeeze(1) - placeholder= placeholder + ce_rep + ce_rep = self.ce_encoder(ce[:, None]).squeeze(1) + placeholder = placeholder + ce_rep if self.use_ion_method: im = kwargs["ion_method"].float().to(device).view(batch_size, -1) @@ -143,8 +151,6 @@ def precursor_hook( ia_rep = self.ion_activation_encoder(ia).mean(dim=1) placeholder = placeholder + ia_rep - - placeholder = torch.nan_to_num(placeholder, nan=0.0, posinf=0.0, neginf=0.0) - return placeholder \ No newline at end of file + return placeholder diff --git a/simba/workflows/inference.py b/simba/workflows/inference.py index e057196..8f85267 100644 --- a/simba/workflows/inference.py +++ b/simba/workflows/inference.py @@ -499,10 +499,10 @@ def _plot_cm( ) -> None: """Plot confusion matrix.""" import matplotlib.pyplot as plt - from sklearn.metrics import accuracy_score, confusion_matrix + from sklearn.metrics import balanced_accuracy_score, confusion_matrix cm = confusion_matrix(true, preds) - accuracy = accuracy_score(true, preds) + accuracy = balanced_accuracy_score(true, preds) # Normalize cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] diff --git a/simba/workflows/preprocessing.py b/simba/workflows/preprocessing.py index 1251cb6..2cfe359 100644 --- a/simba/workflows/preprocessing.py +++ b/simba/workflows/preprocessing.py @@ -59,9 +59,7 @@ def write_data( else None ), "df_smiles_val": ( - molecule_pairs_val.df_smiles - if molecule_pairs_val is not None - else None + molecule_pairs_val.df_smiles if molecule_pairs_val is not None else None ), "df_smiles_test": ( molecule_pairs_test.df_smiles @@ -212,13 +210,20 @@ def preprocess(cfg: DictConfig) -> None: # Load precomputed distances cache if configured precomputed_cache = {} - if hasattr(cfg.preprocessing, 'precomputed_distances') and cfg.preprocessing.precomputed_distances: + if ( + hasattr(cfg.preprocessing, "precomputed_distances") + and cfg.preprocessing.precomputed_distances + ): dirs = cfg.preprocessing.precomputed_distances if dirs and len(dirs) > 0: - logger.info(f"Loading precomputed distances from {len(dirs)} directory(ies)...") + logger.info( + f"Loading precomputed distances from {len(dirs)} directory(ies)..." + ) precomputed_cache = load_precomputed_distances_cache(dirs) else: - logger.info("No precomputed distances configured, computing all from scratch") + logger.info( + "No precomputed distances configured, computing all from scratch" + ) else: logger.info("No precomputed distances configured, computing all from scratch") diff --git a/simba/workflows/training.py b/simba/workflows/training.py index 4352745..987f3bb 100644 --- a/simba/workflows/training.py +++ b/simba/workflows/training.py @@ -84,12 +84,14 @@ def load_dataset(cfg: DictConfig): ) mgf_path = mapping["mgf_path"] - + # Use preprocessing config values (if available) to ensure consistent filtering - use_only_protonized = getattr(cfg.preprocessing, 'use_only_protonized_adducts', True) - + use_only_protonized = getattr( + cfg.preprocessing, "use_only_protonized_adducts", True + ) + all_spectra = load_spectra( - mgf_path, + mgf_path, cfg, n_samples=-1, # Load all spectra during training use_only_protonized_adducts=use_only_protonized, @@ -111,23 +113,27 @@ def load_dataset(cfg: DictConfig): original_spectra = [] idx_map = {} # old_idx -> new_idx missing = [] - + for old_idx, mgf_idx in enumerate(spectrum_indexes): if mgf_idx in spectra_by_idx: idx_map[old_idx] = len(original_spectra) original_spectra.append(spectra_by_idx[mgf_idx]) else: missing.append(mgf_idx) - + if missing: - logger.warning(f"[{split}] Missing {len(missing)} spectra (e.g., MGF index {missing[0]})") + logger.warning( + f"[{split}] Missing {len(missing)} spectra (e.g., MGF index {missing[0]})" + ) # Filter df_smiles to keep only rows with valid spectra valid_rows = [] for i in df_smiles.index: old_idxs = df_smiles.loc[i, "indexes"] if all(idx in idx_map for idx in old_idxs): # Remap to new positions - df_smiles.at[i, "indexes"] = [idx_map[idx] for idx in old_idxs] + df_smiles.at[i, "indexes"] = [ + idx_map[idx] for idx in old_idxs + ] valid_rows.append(i) df_smiles = df_smiles.loc[valid_rows] @@ -288,7 +294,7 @@ def prepare_data( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, - use_ion_mode = cfg.model.features.use_ion_mode, + use_ion_mode=cfg.model.features.use_ion_mode, ) dataset_val = MultitaskDataBuilder.from_molecule_pairs_to_dataset( @@ -298,7 +304,7 @@ def prepare_data( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, - use_ion_mode = cfg.model.features.use_ion_mode, + use_ion_mode=cfg.model.features.use_ion_mode, ) # Create samplers diff --git a/tests/unit/test_embedder_multitask.py b/tests/unit/test_embedder_multitask.py index c21a8b7..b8f1a03 100644 --- a/tests/unit/test_embedder_multitask.py +++ b/tests/unit/test_embedder_multitask.py @@ -116,6 +116,12 @@ def sample_batch(self): "adduct_1": torch.zeros(batch_size, n_adducts), # One-hot encoded "ionmode_0": torch.ones(batch_size, 1), "ionmode_1": torch.ones(batch_size, 1), + "ce_0": torch.ones(batch_size, 1) * 30.0, + "ce_1": torch.ones(batch_size, 1) * 30.0, + "ion_activation_0": torch.zeros(batch_size, 1), + "ion_activation_1": torch.zeros(batch_size, 1), + "ion_method_0": torch.zeros(batch_size, 1), + "ion_method_1": torch.zeros(batch_size, 1), "similarity": torch.tensor([0.8, 0.6]), "similarity_2": torch.tensor([0.7, 0.5]), } From 68627f48f37849ff151a36bde9840338eaee051c Mon Sep 17 00:00:00 2001 From: rukubrakov Date: Wed, 25 Mar 2026 09:33:54 +0100 Subject: [PATCH 8/8] Add global cache with filtering for precomputed molecular distances --- .../chemistry/edit_distance/edit_distance.py | 38 +++++++++++++++++-- simba/core/chemistry/mces/mces_computation.py | 8 ++-- simba/workflows/preprocessing.py | 20 ++++++++++ 3 files changed, 58 insertions(+), 8 deletions(-) diff --git a/simba/core/chemistry/edit_distance/edit_distance.py b/simba/core/chemistry/edit_distance/edit_distance.py index 9f927ba..b1d7efc 100644 --- a/simba/core/chemistry/edit_distance/edit_distance.py +++ b/simba/core/chemistry/edit_distance/edit_distance.py @@ -18,6 +18,38 @@ # Sentinel value indicating very dissimilar molecules (Tanimoto < 0.2) VERY_HIGH_DISTANCE = 666 +# Global cache for precomputed distances (shared across workers via fork) +_GLOBAL_CACHE = None + + +def set_global_cache(cache: dict) -> None: + """Set global cache before forking workers.""" + global _GLOBAL_CACHE + _GLOBAL_CACHE = cache + + +def get_global_cache() -> dict: + """Get global cache in worker process.""" + return _GLOBAL_CACHE + + +def filter_cache_by_smiles(cache: dict, smiles_list: list[str]) -> dict: + """Filter cache to only keep entries for given SMILES.""" + if not cache: + return {} + + smiles_set = set(smiles_list) + filtered = { + key: val + for key, val in cache.items() + if key[0] in smiles_set and key[1] in smiles_set + } + + logger.info( + f"Cache: {len(filtered):,} entries kept out of {len(cache):,} ({len(filtered) / len(cache) * 100:.1f}%)" + ) + return filtered + def load_precomputed_distances_cache(preprocessing_dirs: list[str]) -> dict: """ @@ -263,7 +295,6 @@ def compute_ed_and_mces_both( output_file: str = None, progress_position: int = 0, progress_desc: str = "Computing", - precomputed_cache: dict = None, ) -> np.ndarray | dict: """ Compute BOTH edit distance AND MCES for a batch of molecule pairs in a single pass. @@ -302,8 +333,6 @@ def compute_ed_and_mces_both( Position for the tqdm progress bar (default: 0). Used for stacking multiple progress bars. progress_desc : str, optional Description text for the tqdm progress bar (default: "Computing"). - precomputed_cache : dict, optional - Cache of precomputed distances with (smiles1, smiles2) keys -> [ed, mces] values. Returns ------- @@ -334,6 +363,9 @@ def compute_ed_and_mces_both( cache_hits = 0 cache_misses = 0 + # Get global cache (set before pool creation) + precomputed_cache = get_global_cache() + for index in tqdm( range(pair_distances.shape[0]), desc=progress_desc, diff --git a/simba/core/chemistry/mces/mces_computation.py b/simba/core/chemistry/mces/mces_computation.py index 1e3536c..f9ee29d 100644 --- a/simba/core/chemistry/mces/mces_computation.py +++ b/simba/core/chemistry/mces/mces_computation.py @@ -73,8 +73,6 @@ def compute_all_mces_results_unique( Precomputed molecule pairs to use instead of computing new ones, by default None; compute_both_metrics : bool, optional Whether to compute both MCES and edit distance metrics, by default False. - precomputed_cache : dict, optional - Cache of precomputed distances to reuse from previous runs, by default None. Returns ------- @@ -316,8 +314,6 @@ def compute_all_mces_results_exhaustive( If True, compute edit distance instead of MCES, by default False. compute_both_metrics : bool, optional If True, compute both ED and MCES in single pass (optimized), by default False. - precomputed_cache : dict, optional - Cache of precomputed distances to reuse from previous runs, by default None. Returns ------- @@ -413,6 +409,9 @@ def compute_all_mces_results_exhaustive( if not (os.path.exists(filename)): # do not overwrite existing files logger.info(f"Processing chunk {chunk_idx}/{len(chunks)}") + # Set global cache before creating pool (inherited via fork) + edit_distance.set_global_cache(precomputed_cache) + pool = multiprocessing.Pool(processes=num_workers) mols = [Chem.MolFromSmiles(s) for s in all_smiles] @@ -454,7 +453,6 @@ def compute_all_mces_results_exhaustive( worker_output, sub_index, f"Node {current_node} Chunk {chunk_idx} Worker {sub_index}", - precomputed_cache, ), ) results.append(result) diff --git a/simba/workflows/preprocessing.py b/simba/workflows/preprocessing.py index 2cfe359..bd34fe5 100644 --- a/simba/workflows/preprocessing.py +++ b/simba/workflows/preprocessing.py @@ -220,6 +220,26 @@ def preprocess(cfg: DictConfig) -> None: f"Loading precomputed distances from {len(dirs)} directory(ies)..." ) precomputed_cache = load_precomputed_distances_cache(dirs) + + # Filter cache for SMILES in current dataset + if precomputed_cache: + all_smiles = set() + for spectra_list in [ + all_spectra_train, + all_spectra_val, + all_spectra_test, + ]: + if spectra_list: + all_smiles.update(s.params["smiles"] for s in spectra_list) + + if all_smiles: + from simba.core.chemistry.edit_distance.edit_distance import ( + filter_cache_by_smiles, + ) + + precomputed_cache = filter_cache_by_smiles( + precomputed_cache, list(all_smiles) + ) else: logger.info( "No precomputed distances configured, computing all from scratch"