diff --git a/.gitignore b/.gitignore
index bf04d5b..8775fbe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -86,6 +86,13 @@
 lightning_logs/
 data_temp/
 temp/
+# output
+test_full_workflow/
+
+# tfs files
+simba/configs/*/tfs*.yaml
+run_scripts_tfs*
+
 # ============================================================================
 # Code Quality Tools
 # ============================================================================
diff --git a/README.md b/README.md
index bf5b005..5249e0a 100644
--- a/README.md
+++ b/README.md
@@ -275,6 +275,38 @@ simba preprocess \

 ---

+**Reusing Precomputed Distances:**
+
+To speed up preprocessing when working with related datasets (e.g., MS2-only, MS3-only, and joint MS2+MS3), you can reuse previously computed molecular distances:
+
+```bash
+# First: preprocess MS2-only data
+simba preprocess \
+    paths.spectra_path=ms2_spectra.mgf \
+    paths.preprocessing_dir=./ms2_preprocessing/
+
+# Then: preprocess MS3-only data
+simba preprocess \
+    paths.spectra_path=ms3_spectra.mgf \
+    paths.preprocessing_dir=./ms3_preprocessing/
+
+# Finally: preprocess joint dataset, reusing distances from both
+simba preprocess \
+    paths.spectra_path=joint_spectra.mgf \
+    paths.preprocessing_dir=./joint_preprocessing/ \
+    'preprocessing.precomputed_distances=[./ms2_preprocessing/, ./ms3_preprocessing/]'
+```
+
+The cache automatically:
+- Finds all distance files (`edit_distance_*.npy`, `mces_*.npy`) in each directory
+- Loads SMILES mappings from `mapping_unique_smiles.pkl`
+- Matches molecules by SMILES strings (robust to different splits/filters)
+- Logs cache hit/miss statistics during computation
+
+The reported **cache hit rate** is the percentage of molecule pairs that were reused instead of recomputed.
+
+---
+
 **Quick Testing (Fast Dev Mode):**

 ```bash
diff --git a/pyproject.toml b/pyproject.toml
index 117d62b..58db336 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -46,6 +46,7 @@ dependencies = [
     "pyteomics>=4.6.0",
     "depthcharge-ms @ git+https://github.com/wfondrie/depthcharge.git@bd2861f",
     "myopic-mces>=1.0.0,<2.0.0",
+    "highspy>=1.13.1",
     # Data processing
     "h5py>=3.10.0",
     "pyarrow>=15.0.0",
diff --git a/simba/commands/analog_discovery.py b/simba/commands/analog_discovery.py
index 7061c5c..8c86fab 100644
--- a/simba/commands/analog_discovery.py
+++ b/simba/commands/analog_discovery.py
@@ -191,5 +191,9 @@ def _analog_discovery_with_hydra(
         click.echo("=" * 70)

     except Exception as e:
+        import traceback
+
         click.echo(f"\n❌ Error during analog discovery: {e}", err=True)
+        click.echo("\nFull traceback:", err=True)
+        click.echo(traceback.format_exc(), err=True)
         raise click.Abort() from e
diff --git a/simba/configs/model/simba_default.yaml b/simba/configs/model/simba_default.yaml
index bed1e99..cb46a9a 100644
--- a/simba/configs/model/simba_default.yaml
+++ b/simba/configs/model/simba_default.yaml
@@ -33,6 +33,7 @@ features:
   use_element_wise: true
   categorical_adducts: false
   use_only_protonized_adducts: false
+  use_ion_mode: false

   # Metadata features
   use_ce: false
diff --git a/simba/configs/preprocessing/default.yaml b/simba/configs/preprocessing/default.yaml
index e87ddcd..7a430a1 100644
--- a/simba/configs/preprocessing/default.yaml
+++ b/simba/configs/preprocessing/default.yaml
@@ -24,5 +24,13 @@ test_split: 0.1  # Test split fraction (0.0-1.0)
 random_mces_sampling: false
 use_only_protonized_adducts: true

+# Precomputed distances: reuse distances from previous preprocessing runs.
+# List preprocessing directories; all distance files in them are auto-discovered.
+precomputed_distances:
+  # Examples:
+  #   - "./test_precomputed_cache/dataset1/"
+  #   - "./ms2_preprocessing/"
+  #   - "./ms3_preprocessing/"
+
 # Subsampling
 subsample_preprocessing: false
diff --git a/simba/core/chemistry/edit_distance/edit_distance.py b/simba/core/chemistry/edit_distance/edit_distance.py
index 92793ea..b1d7efc 100644
--- a/simba/core/chemistry/edit_distance/edit_distance.py
+++ b/simba/core/chemistry/edit_distance/edit_distance.py
@@ -1,6 +1,7 @@
 import os
 from functools import lru_cache

+import dill
 import numpy as np
 import pandas as pd
 from myopic_mces.myopic_mces import MCES as MCES2
@@ -17,6 +18,258 @@
 # Sentinel value indicating very dissimilar molecules (Tanimoto < 0.2)
 VERY_HIGH_DISTANCE = 666

+# Global cache for precomputed distances (inherited by workers only with the "fork" start method)
+_GLOBAL_CACHE = None
+
+
+def set_global_cache(cache: dict) -> None:
+    """Set global cache before forking workers."""
+    global _GLOBAL_CACHE
+    _GLOBAL_CACHE = cache
+
+
+def get_global_cache() -> dict:
+    """Get global cache in worker process."""
+    return _GLOBAL_CACHE
+
+
+def filter_cache_by_smiles(cache: dict, smiles_list: list[str]) -> dict:
+    """Filter cache to only keep entries for given SMILES."""
+    if not cache:
+        return {}
+
+    smiles_set = set(smiles_list)
+    filtered = {
+        key: val
+        for key, val in cache.items()
+        if key[0] in smiles_set and key[1] in smiles_set
+    }
+
+    logger.info(
+        f"Cache: {len(filtered):,} entries kept out of {len(cache):,} ({len(filtered) / len(cache) * 100:.1f}%)"
+    )
+    return filtered
+
+
+def load_precomputed_distances_cache(preprocessing_dirs: list[str]) -> dict:
+    """
+    Load precomputed distances from preprocessing directories.
+
+    Auto-discovers files in each directory:
+    - mapping_unique_smiles.pkl (for SMILES mapping)
+    - edit_distance_*.npy and mces_*.npy files (for distances)
+
+    Parameters
+    ----------
+    preprocessing_dirs : list[str]
+        List of preprocessing directory paths.
+
+    Returns
+    -------
+    dict
+        Cache with (smiles1, smiles2) tuple keys -> [ed_distance, mces_distance]
+        values.
+ """ + cache = {} + + for prep_dir in preprocessing_dirs: + if not os.path.exists(prep_dir): + logger.warning(f"Preprocessing directory not found: {prep_dir}") + continue + + # Find the pickle file + pickle_path = os.path.join(prep_dir, "mapping_unique_smiles.pkl") + if not os.path.exists(pickle_path): + logger.warning(f"No mapping_unique_smiles.pkl found in {prep_dir}") + continue + + # Find distance files by prefix and split + distance_files = {} + for split in ["train", "val", "test"]: + distance_files[split] = {"ed": [], "mces": []} + + for filename in os.listdir(prep_dir): + if not filename.endswith(".npy"): + continue + + # Determine split + split = None + for s in ["train", "val", "test"]: + if f"_{s}_" in filename or filename.endswith(f"_{s}.npy"): + split = s + break + + if split is None: + continue + + # Categorize by type + if filename.startswith("edit_distance_"): + distance_files[split]["ed"].append(os.path.join(prep_dir, filename)) + elif filename.startswith("mces_"): + distance_files[split]["mces"].append(os.path.join(prep_dir, filename)) + + # Count total files + total_ed = sum(len(v["ed"]) for v in distance_files.values()) + total_mces = sum(len(v["mces"]) for v in distance_files.values()) + + if total_ed == 0 and total_mces == 0: + logger.warning(f"No distance files found in {prep_dir}") + continue + + logger.info( + f"Auto-discovered in {prep_dir}: {total_ed} ED files, {total_mces} MCES files" + ) + + # Load SMILES mapping per split + try: + with open(pickle_path, "rb") as f: + data = dill.load(f) + except Exception as e: + logger.error(f"Error loading pickle {pickle_path}: {e}") + continue + + # Process each split separately (indices are per-split) + pairs_added = 0 + for split in ["train", "val", "test"]: + # Extract SMILES for this split only + split_smiles = [] + for key_prefix in ["molecule_pairs_", "df_smiles_"]: + key = key_prefix + split + if key not in data or data[key] is None: + continue + + # Extract DataFrame + if key_prefix == "molecule_pairs_": + if not hasattr(data[key], "df_smiles"): + continue + df = data[key].df_smiles + else: + df = data[key] + + if df is None or df.empty: + continue + + # Extract SMILES from DataFrame + if "canon_smiles" in df.columns: + smiles_list = df["canon_smiles"].tolist() + elif "smiles" in df.columns: + smiles_list = df["smiles"].tolist() + elif hasattr(df.index, "tolist"): + smiles_list = df.index.tolist() + else: + continue + + # Remove duplicates while preserving order + seen = set() + for s in smiles_list: + if s not in seen: + seen.add(s) + split_smiles.append(s) + break # Only need one match + + if not split_smiles: + continue + + # Load distances for this split + ed_cache_split = _load_distances_from_files( + distance_files[split]["ed"], split_smiles + ) + mces_cache_split = _load_distances_from_files( + distance_files[split]["mces"], split_smiles + ) + + # Combine into final cache (only pairs with both ED and MCES) + for key in set(ed_cache_split.keys()) & set(mces_cache_split.keys()): + cache[key] = [ed_cache_split[key], mces_cache_split[key]] + pairs_added += 1 + + logger.info(f"Loaded {pairs_added} valid pairs from {prep_dir}") + + logger.info(f"Total unique pairs in cache: {len(cache)}") + return cache + + +def _load_smiles_from_pickle(pickle_path: str) -> list[str]: + """Extract unique SMILES list from preprocessing pickle.""" + try: + with open(pickle_path, "rb") as f: + data = dill.load(f) + + unique_smiles = [] + + # Handle both old format (molecule_pairs_X) and new lightweight format (df_smiles_X) + for 
key_prefix in ["molecule_pairs_", "df_smiles_"]: + for split in ["train", "val", "test"]: + key = key_prefix + split + if key not in data or data[key] is None: + continue + + # Extract DataFrame + if key_prefix == "molecule_pairs_": + if not hasattr(data[key], "df_smiles"): + continue + df = data[key].df_smiles + else: + df = data[key] + + if df is None or df.empty: + continue + + # Extract SMILES from DataFrame + if "canon_smiles" in df.columns: + smiles_list = df["canon_smiles"].tolist() + elif "smiles" in df.columns: + smiles_list = df["smiles"].tolist() + elif hasattr(df.index, "tolist"): + smiles_list = df.index.tolist() + else: + continue + + unique_smiles.extend(smiles_list) + + # Remove duplicates while preserving order + seen = set() + result = [] + for s in unique_smiles: + if s not in seen: + seen.add(s) + result.append(s) + + logger.info(f"Extracted {len(result)} unique SMILES from {pickle_path}") + return result + + except Exception as e: + logger.error(f"Error loading SMILES from {pickle_path}: {e}") + return [] + + +def _load_distances_from_files( + distance_files: list[str], smiles_list: list[str] +) -> dict: + """Load distances from multiple .npy files.""" + distance_cache = {} + + for dist_file in sorted(distance_files): + try: + distances = np.load(dist_file) + for row in distances: + idx1, idx2 = int(row[0]), int(row[1]) + + # Skip out of bounds + if idx1 >= len(smiles_list) or idx2 >= len(smiles_list): + continue + + # Map to SMILES + smiles1, smiles2 = smiles_list[idx1], smiles_list[idx2] + key = tuple(sorted([smiles1, smiles2])) + + # Store distance (3rd column) + distance_cache[key] = float(row[2]) + + except Exception as e: + logger.warning(f"Error loading {dist_file}: {e}") + + return distance_cache + def create_input_df(smiles, indexes_0, indexes_1): df = pd.DataFrame() @@ -106,6 +359,13 @@ def compute_ed_and_mces_both( ed_distances = [] mces_distances = [] + # Track cache hits/misses + cache_hits = 0 + cache_misses = 0 + + # Get global cache (set before pool creation) + precomputed_cache = get_global_cache() + for index in tqdm( range(pair_distances.shape[0]), desc=progress_desc, @@ -118,6 +378,21 @@ def compute_ed_and_mces_both( s0 = smiles[int(pair[0])] s1 = smiles[int(pair[1])] + + # Check precomputed cache first + if precomputed_cache is not None: + cache_key = tuple(sorted([s0, s1])) + if cache_key in precomputed_cache: + cached_values = precomputed_cache[cache_key] + ed_dist = cached_values[0] + mces_dist = cached_values[1] + cache_hits += 1 + ed_distances.append(ed_dist) + mces_distances.append(mces_dist) + continue + + # Not in cache, compute + cache_misses += 1 fp0 = fps[int(pair[0])] fp1 = fps[int(pair[1])] mol0 = mols[int(pair[0])] @@ -135,6 +410,14 @@ def compute_ed_and_mces_both( pair_distances[:, 2] = ed_distances pair_distances[:, 3] = mces_distances + # Log cache statistics + if precomputed_cache is not None: + total = cache_hits + cache_misses + hit_rate = (cache_hits / total * 100) if total > 0 else 0 + logger.info( + f"Cache: {cache_hits} hits, {cache_misses} misses ({hit_rate:.1f}% hit rate)" + ) + if output_file: os.makedirs(os.path.dirname(output_file), exist_ok=True) np.save(output_file, pair_distances) @@ -356,7 +639,7 @@ def simba_solve_pair_mces( threshold=threshold, i=0, # solver='CPLEX_CMD', # or another fast solver you have installed - solver="PULP_CBC_CMD", + solver="HiGHS", solver_options={ "threads": 1, "msg": False, diff --git a/simba/core/chemistry/mces/mces_computation.py b/simba/core/chemistry/mces/mces_computation.py index 
2fbaca5..f9ee29d 100644
--- a/simba/core/chemistry/mces/mces_computation.py
+++ b/simba/core/chemistry/mces/mces_computation.py
@@ -36,6 +36,7 @@ def compute_all_mces_results_unique(
     use_edit_distance: bool = False,
     loaded_molecule_pairs: MolecularPairsSet | None = None,
     compute_both_metrics: bool = False,
+    precomputed_cache: dict | None = None,
 ) -> MoleculePairsOpt:
     """
     Compute MCES or edit distance for all pairs of spectra using multiprocessing.
@@ -100,6 +101,7 @@ def compute_all_mces_results_unique(
             identifier=identifier,
             use_edit_distance=use_edit_distance,
             compute_both_metrics=compute_both_metrics,
+            precomputed_cache=precomputed_cache,
         )
     else:
         molecular_pairs = loaded_molecule_pairs
@@ -289,6 +291,7 @@ def compute_all_mces_results_exhaustive(
     identifier: str = "",
     use_edit_distance=False,
     compute_both_metrics: bool = False,
+    precomputed_cache: dict | None = None,
 ) -> MolecularPairsSet:
     """
     Compute MCES or edit distance for all pairs of spectra using multiprocessing.
@@ -406,6 +409,9 @@ def compute_all_mces_results_exhaustive(
         if not (os.path.exists(filename)):  # do not overwrite existing files
             logger.info(f"Processing chunk {chunk_idx}/{len(chunks)}")

+            # Set global cache before creating the pool (inherited only with the "fork" start method)
+            edit_distance.set_global_cache(precomputed_cache)
+
             pool = multiprocessing.Pool(processes=num_workers)

             mols = [Chem.MolFromSmiles(s) for s in all_smiles]
diff --git a/simba/core/data/augmentation.py b/simba/core/data/augmentation.py
index 45c68ad..3c69d16 100644
--- a/simba/core/data/augmentation.py
+++ b/simba/core/data/augmentation.py
@@ -5,6 +5,22 @@

 class Augmentation:
+    @staticmethod
+    def zero_and_pack(arr, zero_idx):
+        """
+        Zero out the entries at ``zero_idx`` and pack the remaining values
+        at the beginning of the array, padding the end with zeros.
+        """
+        mask = np.ones(len(arr), dtype=bool)
+        mask[zero_idx] = False
+        complement_idx = np.where(mask)[0]
+
+        arr = arr.copy()
+        arr[zero_idx] = 0
+        nonzero = arr[complement_idx]
+        return np.concatenate(
+            [nonzero, np.zeros(len(arr) - len(nonzero), dtype=arr.dtype)]
+        )
+
     @staticmethod
     def augment(data_sample, training=False, max_num_peaks=None):
         new_sample = copy.deepcopy(data_sample)
@@ -49,7 +65,7 @@ def normalize_max(data_sample):
         return data_sample

     @staticmethod
-    def peak_augmentation_max_peaks(data_sample, p_augmentation=1.0, max_peaks=100):
+    def peak_augmentation_max_peaks(data_sample, p_augmentation=0.50, max_peaks=100):
         # first normalize to maximum

         ## half of the time select maximum 20, the other half something between 5 and the maximum number of peaks
@@ -71,10 +87,14 @@ def peak_augmentation_max_peaks(data_sample, p_augmentation=1.0, max_peaks=100):
             intensity_ordered_indexes = np.argsort(intensity)[
                 ::-1
             ]  # flip the order to have the max at the beginning
-            indexes_to_be_erased = intensity_ordered_indexes[max_augmented_peaks:-1]
+            indexes_to_be_erased = intensity_ordered_indexes[
+                max_augmented_peaks:-1
+            ]

-            intensity[indexes_to_be_erased] = 0
-            mz[indexes_to_be_erased] = 0
+            intensity = Augmentation.zero_and_pack(
+                intensity, indexes_to_be_erased
+            )
+            mz = Augmentation.zero_and_pack(mz, indexes_to_be_erased)

             # apply
             data_sample[intensity_column] = intensity
@@ -85,7 +105,7 @@ def peak_augmentation_max_peaks(data_sample, p_augmentation=1.0, max_peaks=100):

     @staticmethod
     def peak_augmentation_removal_noise(
-        data_sample, max_percentage=0.01, p_augmentation=1.0
+        data_sample, max_percentage=0.01, p_augmentation=0.5
     ):
         if random.random() < p_augmentation:
             # first normalize to maximum
@@ -100,9 +120,13 @@ def peak_augmentation_removal_noise(
             max_amplitude =
random.random() * max_percentage # indexes_to_modify=intensity < max_amplitude - indexes_to_be_erased = intensity < max_amplitude - intensity[indexes_to_be_erased] = 0 - mz[indexes_to_be_erased] = 0 + indexes_to_be_erased = intensity < ( + max_amplitude * np.max(intensity, keepdims=True) + ) + intensity = Augmentation.zero_and_pack( + intensity, indexes_to_be_erased + ) + mz = Augmentation.zero_and_pack(mz, indexes_to_be_erased) # apply data_sample[intensity_column] = intensity @@ -125,26 +149,6 @@ def normalize_intensities(data_sample, intensity_labels=None): data_sample[intensity_column] = intensity return data_sample - @staticmethod - def inversion(data_sample): - # inversion - - new_sample = {} - new_sample["mz_0"] = data_sample["mz_1"] - new_sample["mz_1"] = data_sample["mz_0"] - - new_sample["intensity_0"] = data_sample["intensity_1"] - new_sample["intensity_1"] = data_sample["intensity_0"] - - new_sample["precursor_mass_0"] = data_sample["precursor_mass_1"] - new_sample["precursor_mass_1"] = data_sample["precursor_mass_0"] - - new_sample["precursor_charge_0"] = data_sample["precursor_charge_1"] - new_sample["precursor_charge_1"] = data_sample["precursor_charge_0"] - - new_sample["similarity"] = data_sample["similarity"] - return new_sample - @staticmethod def add_false_precursor_masses_positives( sample, max_noise=0.01, p_augmentation=0.2 @@ -189,7 +193,7 @@ def add_false_precursor_masses_negatives( return sample @staticmethod - def random_peak_dropout(data_sample, dropout_rate=0.10, p_augmentation=1.0): + def random_peak_dropout(data_sample, dropout_rate=0.10, p_augmentation=0.5): """ Randomly zero out a percentage of peaks to simulate partial data loss. """ @@ -203,10 +207,11 @@ def random_peak_dropout(data_sample, dropout_rate=0.10, p_augmentation=1.0): n_peaks = len(intensity_array) n_drop = int(n_peaks * dropout_rate) # choose random peaks to drop - drop_indices = random.sample(range(n_peaks), n_drop) - for idx in drop_indices: - intensity_array[idx] = 0 - mz_array[idx] = 0 + drop_indices = np.array(random.sample(range(n_peaks), n_drop)) + intensity_array = Augmentation.zero_and_pack( + intensity_array, drop_indices + ) + mz_array = Augmentation.zero_and_pack(mz_array, drop_indices) data_sample[int_key] = intensity_array data_sample[mz_key] = mz_array @@ -219,19 +224,50 @@ def masking_metadata(data_sample, p_aug=0.5): """ keys_found = list(data_sample.keys()) if random.random() < p_aug: - if "ionmode_0" in keys_found: - data_sample["ionmode_0"] = 0 * data_sample["ionmode_0"] - data_sample["ionmode_1"] = 0 * data_sample["ionmode_1"] - - if "adduct_0" in keys_found: - data_sample["adduct_0"] = 0 * data_sample["adduct_0"] - data_sample["adduct_1"] = 0 * data_sample["adduct_1"] - - if "ce_0" in keys_found: - data_sample["ce_0"] = 0 * data_sample["ce_0"] - data_sample["ce_1"] = 0 * data_sample["ce_1"] + if random.random() < 0.5: + if "ionmode_0" in keys_found: + data_sample["ionmode_0"] = 0 * data_sample["ionmode_0"] + data_sample["ionmode_1"] = 0 * data_sample["ionmode_1"] + data_sample["precursor_charge_0"] = ( + 0 * data_sample["precursor_charge_0"] + ) + data_sample["precursor_charge_1"] = ( + 0 * data_sample["precursor_charge_1"] + ) + if random.random() < 0.5: + if "adduct_0" in keys_found: + data_sample["adduct_0"] = 0 * data_sample["adduct_0"] + data_sample["adduct_1"] = 0 * data_sample["adduct_1"] + if random.random() < 0.5: + if "ce_0" in keys_found: + data_sample["ce_0"] = 0 * data_sample["ce_0"] + data_sample["ce_1"] = 0 * data_sample["ce_1"] + if random.random() < 0.5: 
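+                # Mask ion activation independently of the other metadata
+                # groups (each group is dropped with probability 0.5)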
+                if "ion_activation_0" in keys_found:
+                    data_sample["ion_activation_0"] = (
+                        0 * data_sample["ion_activation_0"]
+                    )
+                    data_sample["ion_activation_1"] = (
+                        0 * data_sample["ion_activation_1"]
+                    )

-        if "ion_activation_0" in keys_found:
-            data_sample["ion_activation_0"] = 0 * data_sample["ion_activation_0"]
-            data_sample["ion_activation_1"] = 0 * data_sample["ion_activation_1"]
+        else:
+            # Mask every metadata field at once; the guards mirror the branch
+            # above so that absent fields do not raise a KeyError
+            if "ionmode_0" in keys_found:
+                data_sample["ionmode_0"] = 0 * data_sample["ionmode_0"]
+                data_sample["ionmode_1"] = 0 * data_sample["ionmode_1"]
+                data_sample["precursor_charge_0"] = (
+                    0 * data_sample["precursor_charge_0"]
+                )
+                data_sample["precursor_charge_1"] = (
+                    0 * data_sample["precursor_charge_1"]
+                )
+            if "adduct_0" in keys_found:
+                data_sample["adduct_0"] = 0 * data_sample["adduct_0"]
+                data_sample["adduct_1"] = 0 * data_sample["adduct_1"]
+            if "ce_0" in keys_found:
+                data_sample["ce_0"] = 0 * data_sample["ce_0"]
+                data_sample["ce_1"] = 0 * data_sample["ce_1"]
+            if "ion_activation_0" in keys_found:
+                data_sample["ion_activation_0"] = (
+                    0 * data_sample["ion_activation_0"]
+                )
+                data_sample["ion_activation_1"] = (
+                    0 * data_sample["ion_activation_1"]
+                )

         return data_sample
diff --git a/simba/core/data/datasets/encoder_dataset_builder.py b/simba/core/data/datasets/encoder_dataset_builder.py
index c9e428f..fe8647b 100644
--- a/simba/core/data/datasets/encoder_dataset_builder.py
+++ b/simba/core/data/datasets/encoder_dataset_builder.py
@@ -4,6 +4,16 @@

 from simba.core.data.datasets.encoder_dataset import CustomDatasetEncoder
 from simba.core.data.preprocessor import Preprocessor
+from simba.core.chemistry.chem_utils import (
+    ADDUCT_TO_MASS,
+)
+from simba.core.data.encoding import (
+    IONIZATION_METHODS,
+    ION_ACTIVATION,
+    encode_adduct_mass,
+    encode_ionization_method,
+    encode_ion_activation,
+)


 def prepare_encoder_dataset(spectra, max_num_peaks=100):
@@ -27,6 +37,13 @@ def prepare_encoder_dataset(spectra, max_num_peaks=100):
     intensity = np.zeros((len(spectra), max_num_peaks), dtype=np.float32)
     precursor_mass = np.zeros((len(spectra), 1), dtype=np.float32)
     precursor_charge = np.zeros((len(spectra), 1), dtype=np.int32)
+
+    # Metadata fields (initialized with default values)
+    ionmode = np.zeros((len(spectra), 1), dtype=np.float32)
+    adduct = np.zeros((len(spectra), len(ADDUCT_TO_MASS.keys())), dtype=np.float32)
+    ce = np.zeros((len(spectra), 1), dtype=np.int32)
+    ia = np.zeros((len(spectra), len(ION_ACTIVATION)), dtype=np.int32)
+    im = np.zeros((len(spectra), len(IONIZATION_METHODS)), dtype=np.int32)

     for i, spectrum in enumerate(spectra):
         # check for maximum length
@@ -40,6 +57,37 @@ def prepare_encoder_dataset(spectra, max_num_peaks=100):

         precursor_mass[i] = spectrum.precursor_mz
         precursor_charge[i] = spectrum.precursor_charge
+
+        # Extract metadata if available
+        # Ion mode
+        if (
+            hasattr(spectrum, "ionmode")
+            and spectrum.ionmode is not None
+            and spectrum.ionmode != "None"
+        ):
+            ionmode[i] = 1.0 if spectrum.ionmode.lower() == "positive" else -1.0
+        else:
+            ionmode[i] = 0.0
+
+        # Adduct
+        if hasattr(spectrum, "params") and "adduct" in spectrum.params:
+            adduct[i] = encode_adduct_mass(spectrum.params["adduct"])
+        elif hasattr(spectrum, "adduct") and spectrum.adduct is not None:
+            adduct[i] = encode_adduct_mass(spectrum.adduct)
+
+        # Collision Energy
+        if hasattr(spectrum, "ce") and spectrum.ce is not None and spectrum.ce != "None":
+            ce[i] = spectrum.ce
+        else:
+            ce[i] = 0
+
+        # Ion Activation
+        if (
+            hasattr(spectrum, "ion_activation")
+            and spectrum.ion_activation is not None
+            and spectrum.ion_activation != "None"
+        ):
+            ia[i] = encode_ion_activation(spectrum.ion_activation)
+        else:
+            ia[i] = np.zeros(len(ION_ACTIVATION),
dtype=np.int32) + + # Ionization Method + if hasattr(spectrum, 'ionization_method') and spectrum.ionization_method is not None and spectrum.ionization_method != "None": + im[i] = encode_ionization_method(spectrum.ionization_method) + else: + im[i] = np.zeros(len(IONIZATION_METHODS), dtype=np.int32) # Normalize the intensity array intensity = intensity / np.sqrt(np.sum(intensity**2, axis=1, keepdims=True)) @@ -49,6 +97,11 @@ def prepare_encoder_dataset(spectra, max_num_peaks=100): "intensity": intensity, "precursor_mass": precursor_mass, "precursor_charge": precursor_charge, + "ionmode": ionmode, + "adduct": adduct, + "ce": ce, + "ion_activation": ia, + "ion_method": im, } return CustomDatasetEncoder(spectrum_data) diff --git a/simba/core/data/datasets/multitask_dataset.py b/simba/core/data/datasets/multitask_dataset.py index 15d6ab4..05950ff 100644 --- a/simba/core/data/datasets/multitask_dataset.py +++ b/simba/core/data/datasets/multitask_dataset.py @@ -17,8 +17,7 @@ def __init__( self, your_dict, training=False, - prob_aug=1.0, - # prob_aug=0.2, + prob_aug=0.50, mz=None, intensity=None, precursor_mass=None, @@ -36,6 +35,7 @@ def __init__( ion_activation=None, use_ion_method=False, ion_method=None, + use_ion_mode=False, ): self.data = your_dict self.keys = list(your_dict.keys()) @@ -52,23 +52,19 @@ def __init__( self.use_ce = use_ce self.use_ion_activation = use_ion_activation self.use_ion_method = use_ion_method + self.use_ion_mode = use_ion_mode if self.use_fingerprints: self.fingerprint_0 = fingerprint_0 self.max_num_peaks = max_num_peaks - if self.use_adduct: - self.ionmode = ionmode - self.adduct_mass = adduct - - if self.use_ce: - self.ce = ce + self.adduct_mass = adduct + self.ionmode = ionmode + self.ce = ce - if self.use_ion_activation: - self.ion_activation = ion_activation + self.ion_activation = ion_activation - if self.use_ion_method: - self.ion_method = ion_method + self.ion_method = ion_method def __len__(self): return len(self.data[self.keys[0]]) @@ -82,20 +78,32 @@ def get_original_dictionary(self, max_num_peaks=100): ## Get the mz, intensity values and precursor data dictionary = {} - dictionary["mz_0"] = np.zeros((len_data, max_num_peaks), dtype=np.float32) + dictionary["mz_0"] = np.zeros( + (len_data, max_num_peaks), dtype=np.float32 + ) dictionary["intensity_0"] = np.zeros( (len_data, max_num_peaks), dtype=np.float32 ) - dictionary["mz_1"] = np.zeros((len_data, max_num_peaks), dtype=np.float32) + dictionary["mz_1"] = np.zeros( + (len_data, max_num_peaks), dtype=np.float32 + ) dictionary["intensity_1"] = np.zeros( (len_data, max_num_peaks), dtype=np.float32 ) dictionary["ed"] = np.zeros((len_data, 1), dtype=np.float32) dictionary["mces"] = np.zeros((len_data, 1), dtype=np.float32) - dictionary["precursor_mass_0"] = np.zeros((len_data, 1), dtype=np.float32) - dictionary["precursor_charge_0"] = np.zeros((len_data, 1), dtype=np.int32) - dictionary["precursor_mass_1"] = np.zeros((len_data, 1), dtype=np.float32) - dictionary["precursor_charge_1"] = np.zeros((len_data, 1), dtype=np.int32) + dictionary["precursor_mass_0"] = np.zeros( + (len_data, 1), dtype=np.float32 + ) + dictionary["precursor_charge_0"] = np.zeros( + (len_data, 1), dtype=np.int32 + ) + dictionary["precursor_mass_1"] = np.zeros( + (len_data, 1), dtype=np.float32 + ) + dictionary["precursor_charge_1"] = np.zeros( + (len_data, 1), dtype=np.int32 + ) ### add extra metadata in case it is necessary if self.use_adduct: @@ -130,7 +138,9 @@ def get_original_dictionary(self, max_num_peaks=100): if 
self.use_fingerprints: print("Defining fingerprints ...") - dictionary["fingerprint_0"] = np.zeros((len_data, 2048), dtype=np.int32) + dictionary["fingerprint_0"] = np.zeros( + (len_data, 2048), dtype=np.int32 + ) for idx in tqdm(range(0, len_data)): sample_unique = {k: self.data[k][idx] for k in self.keys} @@ -139,19 +149,27 @@ def get_original_dictionary(self, max_num_peaks=100): indexes_unique_1 = sample_unique["index_unique_1"] print(f"value of indexes_unique_0 {indexes_unique_0} ") - indexes_original_0 = self.df_smiles.loc[int(indexes_unique_0), "indexes"][0] + indexes_original_0 = self.df_smiles.loc[ + int(indexes_unique_0), "indexes" + ][0] - indexes_original_1 = self.df_smiles.loc[int(indexes_unique_1), "indexes"][0] + indexes_original_1 = self.df_smiles.loc[ + int(indexes_unique_1), "indexes" + ][0] - dictionary["mz_0"][idx] = self.mz[indexes_original_0].astype(np.float32) - dictionary["intensity_0"][idx] = self.intensity[indexes_original_0].astype( + dictionary["mz_0"][idx] = self.mz[indexes_original_0].astype( np.float32 ) + dictionary["intensity_0"][idx] = self.intensity[ + indexes_original_0 + ].astype(np.float32) - dictionary["mz_1"][idx] = self.mz[indexes_original_1].astype(np.float32) - dictionary["intensity_1"][idx] = self.intensity[indexes_original_1].astype( + dictionary["mz_1"][idx] = self.mz[indexes_original_1].astype( np.float32 ) + dictionary["intensity_1"][idx] = self.intensity[ + indexes_original_1 + ].astype(np.float32) dictionary["precursor_mass_0"][idx] = self.precursor_mass[ indexes_original_0 ].astype(np.float32) @@ -166,14 +184,14 @@ def get_original_dictionary(self, max_num_peaks=100): ].astype(np.float32) dictionary["ed"][idx] = sample_unique["ed"].astype(np.float32) dictionary["mces"][idx] = sample_unique["mces"].astype(np.float32) + if self.use_ion_mode: + dictionary["ionmode_0"][idx] = self.ionmode[ + indexes_original_0 + ].astype(np.float32) + dictionary["ionmode_1"][idx] = self.ionmode[ + indexes_original_1 + ].astype(np.float32) if self.use_adduct: - dictionary["ionmode_0"][idx] = self.ionmode[indexes_original_0].astype( - np.float32 - ) - dictionary["ionmode_1"][idx] = self.ionmode[indexes_original_1].astype( - np.float32 - ) - dictionary["adduct_0"][idx] = self.adduct_mass[ indexes_original_0 ].astype(np.float32) @@ -182,8 +200,12 @@ def get_original_dictionary(self, max_num_peaks=100): ].astype(np.float32) if self.use_ce: - dictionary["ce_0"][idx] = self.ce[indexes_original_0].astype(np.float32) - dictionary["ce_1"][idx] = self.ce[indexes_original_1].astype(np.float32) + dictionary["ce_0"][idx] = self.ce[indexes_original_0].astype( + np.float32 + ) + dictionary["ce_1"][idx] = self.ce[indexes_original_1].astype( + np.float32 + ) if self.use_ion_activation: dictionary["ion_activation_0"][idx] = self.ion_activation[ @@ -216,8 +238,12 @@ def __getitem__(self, idx): if self.training: # select random samples - idx_0_original = random.choice(self.df_smiles.loc[int(idx_0[0]), "indexes"]) - idx_1_original = random.choice(self.df_smiles.loc[int(idx_1[0]), "indexes"]) + idx_0_original = random.choice( + self.df_smiles.loc[int(idx_0[0]), "indexes"] + ) + idx_1_original = random.choice( + self.df_smiles.loc[int(idx_1[0]), "indexes"] + ) else: # select the first index idx_0_original = self.df_smiles.loc[int(idx_0[0]), "indexes"][0] @@ -253,20 +279,20 @@ def __getitem__(self, idx): ind = int(idx_0[0]) if self.training: if (ind % 2) == 0: - spectrum_sample["fingerprint_0"] = self.fingerprint_0[ind].astype( - np.float32 - ) + spectrum_sample["fingerprint_0"] = 
self.fingerprint_0[ + ind + ].astype(np.float32) else: # return 0s spectrum_sample["fingerprint_0"] = 0 * self.fingerprint_0[ ind ].astype(np.float32) else: - spectrum_sample["fingerprint_0"] = self.fingerprint_0[ind].astype( - np.float32 - ) + spectrum_sample["fingerprint_0"] = self.fingerprint_0[ + ind + ].astype(np.float32) - if self.use_adduct: + if self.use_ion_mode: spectrum_sample["ionmode_0"] = self.ionmode[idx_0_original].astype( np.float32 ) @@ -274,17 +300,42 @@ def __getitem__(self, idx): np.float32 ) - spectrum_sample["adduct_0"] = self.adduct_mass[idx_0_original].astype( + else: + spectrum_sample["ionmode_0"] = 0 * self.ionmode[ + idx_0_original + ].astype(np.float32) + spectrum_sample["ionmode_1"] = 0 * self.ionmode[ + idx_1_original + ].astype(np.float32) + + if self.use_adduct: + spectrum_sample["adduct_0"] = self.adduct_mass[ + idx_0_original + ].astype(np.float32) + spectrum_sample["adduct_1"] = self.adduct_mass[ + idx_1_original + ].astype(np.float32) + else: + spectrum_sample["adduct_0"] = 0 * self.adduct_mass[ + idx_0_original + ].astype(np.float32) + spectrum_sample["adduct_1"] = 0 * self.adduct_mass[ + idx_1_original + ].astype(np.float32) + if self.use_ce: + spectrum_sample["ce_0"] = self.ce[idx_0_original].astype( np.float32 ) - spectrum_sample["adduct_1"] = self.adduct_mass[idx_1_original].astype( + spectrum_sample["ce_1"] = self.ce[idx_1_original].astype( + np.float32 + ) + else: + spectrum_sample["ce_0"] = 0 * self.ce[idx_0_original].astype( + np.float32 + ) + spectrum_sample["ce_1"] = 0 * self.ce[idx_1_original].astype( np.float32 ) - - if self.use_ce: - spectrum_sample["ce_0"] = self.ce[idx_0_original].astype(np.float32) - spectrum_sample["ce_1"] = self.ce[idx_1_original].astype(np.float32) - if self.use_ion_activation: spectrum_sample["ion_activation_0"] = self.ion_activation[ idx_0_original @@ -292,16 +343,28 @@ def __getitem__(self, idx): spectrum_sample["ion_activation_1"] = self.ion_activation[ idx_1_original ].astype(np.float32) - + else: + spectrum_sample["ion_activation_0"] = 0 * self.ion_activation[ + idx_0_original + ].astype(np.float32) + spectrum_sample["ion_activation_1"] = 0 * self.ion_activation[ + idx_1_original + ].astype(np.float32) if self.use_ion_method: - spectrum_sample["ion_method_0"] = self.ion_method[idx_0_original].astype( - np.float32 - ) - spectrum_sample["ion_method_1"] = self.ion_method[idx_1_original].astype( - np.float32 - ) - - if self.training and random.random() < self.prob_aug: + spectrum_sample["ion_method_0"] = self.ion_method[ + idx_0_original + ].astype(np.float32) + spectrum_sample["ion_method_1"] = self.ion_method[ + idx_1_original + ].astype(np.float32) + else: + spectrum_sample["ion_method_0"] = 0 * self.ion_method[ + idx_0_original + ].astype(np.float32) + spectrum_sample["ion_method_1"] = 0 * self.ion_method[ + idx_1_original + ].astype(np.float32) + if self.training and (random.random() < self.prob_aug): # augmentation spectrum_sample = Augmentation.augment( spectrum_sample, max_num_peaks=self.max_num_peaks @@ -309,4 +372,4 @@ def __getitem__(self, idx): # normalize spectrum_sample = Augmentation.normalize_intensities(spectrum_sample) - return spectrum_sample + return spectrum_sample \ No newline at end of file diff --git a/simba/core/data/datasets/multitask_dataset_builder.py b/simba/core/data/datasets/multitask_dataset_builder.py index 3626353..666d039 100644 --- a/simba/core/data/datasets/multitask_dataset_builder.py +++ b/simba/core/data/datasets/multitask_dataset_builder.py @@ -10,7 +10,7 @@ from 
simba.core.data.encoding import (
     ION_ACTIVATION,
     IONIZATION_METHODS,
-    encode_adduct,
+    encode_adduct_mass,
     encode_ion_activation,
     encode_ionization_method,
 )
@@ -36,6 +36,7 @@ def from_molecule_pairs_to_dataset(
     use_ce: bool = False,
     use_ion_activation: bool = False,
     use_ion_method: bool = False,
+    use_ion_mode: bool = False,
 ) -> CustomDatasetMultitasking:
     """
     Load data from molecule pairs into a Pytorch dataset for multitask learning.
@@ -98,35 +99,33 @@ def from_molecule_pairs_to_dataset(
     precursor_charge = np.zeros(
         (len(molecule_pairs.original_spectra), 1), dtype=np.int32
     )
-    if use_adduct:
-        ionmode = np.zeros(
-            (len(molecule_pairs.original_spectra), 1), dtype=np.float32
-        )
-        adduct = np.zeros(
-            (
-                len(molecule_pairs.original_spectra),
-                len(ADDUCT_TO_MASS.keys()),
-            ),
-            dtype=np.float32,
-        )
-    if use_ce:
-        ce = np.zeros((len(molecule_pairs.original_spectra), 1), dtype=np.int32)
-    if use_ion_activation:
-        ia = np.zeros(
-            (
-                len(molecule_pairs.original_spectra),
-                len(ION_ACTIVATION),
-            ),
-            dtype=np.int32,
-        )
-    if use_ion_method:
-        im = np.zeros(
-            (
-                len(molecule_pairs.original_spectra),
-                len(IONIZATION_METHODS),
-            ),
-            dtype=np.int32,
-        )
+    ionmode = np.zeros(
+        (len(molecule_pairs.original_spectra), 1), dtype=np.float32
+    )
+    adduct = np.zeros(
+        (
+            len(molecule_pairs.original_spectra),
+            len(ADDUCT_TO_MASS.keys()),
+        ),
+        dtype=np.float32,
+    )
+    ce = np.zeros(
+        (len(molecule_pairs.original_spectra), 1), dtype=np.int32
+    )
+    ia = np.zeros(
+        (
+            len(molecule_pairs.original_spectra),
+            len(ION_ACTIVATION),
+        ),
+        dtype=np.int32,
+    )
+    im = np.zeros(
+        (
+            len(molecule_pairs.original_spectra),
+            len(IONIZATION_METHODS),
+        ),
+        dtype=np.int32,
+    )

     logger.info("Loading mz, intensity and precursor data ...")
     for i, spec in enumerate(molecule_pairs.original_spectra):
@@ -140,14 +139,15 @@ def from_molecule_pairs_to_dataset(
         precursor_mass[i] = spec.precursor_mz
         precursor_charge[i] = spec.precursor_charge

-        if use_adduct:
+        if use_ion_mode:
             if (spec.ionmode is None) or (
                 spec.ionmode == "None"
             ):  # TODO: check if the 2nd condition is needed
                 ionmode[i] = 0
             else:
                 ionmode[i] = 1.0 if spec.ionmode == "positive" else -1.0
-            adduct[i] = encode_adduct(spec.adduct)
+        if use_adduct:
+            adduct[i] = encode_adduct_mass(spec.params["adduct"])

         if use_ce:
             if (spec.ce is None) or (spec.ce == "None"):
@@ -217,12 +217,13 @@ def from_molecule_pairs_to_dataset(
         fingerprint_0=fingerprint_0,
         max_num_peaks=max_num_peaks,
         use_adduct=use_adduct,
-        ionmode=(ionmode if use_adduct else None),
-        adduct=(adduct if use_adduct else None),
+        use_ion_mode=use_ion_mode,
+        ionmode=ionmode,
+        adduct=adduct,
         use_ce=use_ce,
-        ce=(ce if use_ce else None),
+        ce=ce,
         use_ion_activation=use_ion_activation,
-        ion_activation=(ia if use_ion_activation else None),
+        ion_activation=ia,
         use_ion_method=use_ion_method,
-        ion_method=(im if use_ion_method else None),
+        ion_method=im,
     )
diff --git a/simba/core/data/encoding.py b/simba/core/data/encoding.py
index b1f79a3..f04e9ce 100644
--- a/simba/core/data/encoding.py
+++ b/simba/core/data/encoding.py
@@ -7,7 +7,25 @@
 IONIZATION_METHODS = ["NSI", "ESI", "APCI"]


-def encode_adduct(adduct: str):
+def encode_adduct_mass(adduct: str):
+    """Encode adduct as a mass-based feature vector.
+
+    Args:
+        adduct: Adduct string (e.g., '[M+H]+')
+
+    Returns:
+        list: Vector of length len(ADDUCT_TO_MASS) with the adduct mass in the
+        first position, or all zeros if the adduct is not recognized
+    """
+    # TODO: how encode adduct if not recognized?
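+    # e.g. encode_adduct_mass("[M+H]+") -> [1.00728, 0, 0, ...] assuming
+    # "[M+H]+" is a key of ADDUCT_TO_MASS (values here are illustrative)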
+ # Currently returns 0, but might interfere with spectra without adducts + adducts = ADDUCT_TO_MASS.keys() + response = [0 for a in adducts] + if adduct in adducts: + response[0] = ADDUCT_TO_MASS[adduct] + return response + + +def encode_adduct_one_hot(adduct: str): """Encode adduct string as one-hot vector. Args: diff --git a/simba/core/models/similarity_models.py b/simba/core/models/similarity_models.py index 4551fbf..67afa99 100644 --- a/simba/core/models/similarity_models.py +++ b/simba/core/models/similarity_models.py @@ -47,6 +47,7 @@ def __init__( use_ce=False, use_ion_activation=False, use_ion_method=False, + use_ion_mode=False, ): """Initialize the CCSPredictor""" super().__init__() @@ -63,6 +64,7 @@ def __init__( self.use_ce = use_ce self.use_ion_activation = use_ion_activation self.use_ion_method = use_ion_method + self.use_ion_mode = use_ion_mode self.spectrum_encoder = SpectrumTransformerEncoderCustom( d_model=d_model, @@ -72,6 +74,7 @@ def __init__( use_ce=use_ce, use_ion_activation=use_ion_activation, use_ion_method=use_ion_method, + use_ion_mode=use_ion_mode, ) self.regression_loss = nn.MSELoss() @@ -104,13 +107,16 @@ def forward(self, batch): kwargs_0 = { "precursor_mass": batch["precursor_mass_0"].float(), - "precursor_charge": batch["precursor_charge_0"].float(), } kwargs_1 = { "precursor_mass": batch["precursor_mass_1"].float(), - "precursor_charge": batch["precursor_charge_1"].float(), } # extra data + if self.use_ion_mode: + kwargs_0["ionmode"] = batch["ionmode_0"].float() + kwargs_1["ionmode"] = batch["ionmode_1"].float() + kwargs_0["precursor_charge"] = batch["precursor_charge_0"].float() + kwargs_1["precursor_charge"] = batch["precursor_charge_1"].float() if self.use_adduct: kwargs_0["ionmode"] = batch["ionmode_0"].float() kwargs_1["ionmode"] = batch["ionmode_1"].float() @@ -359,6 +365,7 @@ def __init__( use_ce=False, use_ion_activation=False, use_ion_method=False, + use_ion_mode=False, ): """Initialize the CCSPredictor""" super().__init__( @@ -373,6 +380,7 @@ def __init__( use_ce=use_ce, use_ion_activation=use_ion_activation, use_ion_method=use_ion_method, + use_ion_mode=use_ion_mode, ) self.weights = weights @@ -420,17 +428,10 @@ def __init__( # Initialize learnable log variance parameters for each loss self.USE_LEARNABLE_MULTITASK = USE_LEARNABLE_MULTITASK if USE_LEARNABLE_MULTITASK: - initial_log_sigma1 = 1.2490483522415161 - initial_log_sigma2 = -7.0018157958984375 - # self.log_sigma1 = nn.Parameter(torch.tensor(initial_log_sigma1)) - # self.log_sigma2 = nn.Parameter(torch.tensor(initial_log_sigma2)) - self.log_sigma1 = torch.tensor( - float(initial_log_sigma1), dtype=torch.float32 - ) - self.log_sigma2 = torch.tensor( - float(initial_log_sigma2), dtype=torch.float32 - ) - self.use_extra_metadata = use_adduct + initial_log_sigma1 = 0.0 + initial_log_sigma2 = -5.3 + self.log_sigma1 = nn.Parameter(torch.tensor(initial_log_sigma1)) + self.log_sigma2 = nn.Parameter(torch.tensor(initial_log_sigma2)) def forward(self, batch, return_spectrum_output=False): # … compute raw emb0, emb1, apply relu, fingerprints, etc. 
… @@ -465,56 +466,47 @@ def forward(self, batch, return_spectrum_output=False): "precursor_charge": batch["precursor_charge_1"].float(), } - if self.use_adduct: - batch["ionmode_0"] = torch.nan_to_num( - batch["ionmode_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ionmode_1"] = torch.nan_to_num( - batch["ionmode_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["adduct_0"] = torch.nan_to_num( - batch["adduct_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["adduct_1"] = torch.nan_to_num( - batch["adduct_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) - - kwargs_0["ionmode"] = batch["ionmode_0"].float() - kwargs_1["ionmode"] = batch["ionmode_1"].float() - kwargs_0["adduct"] = batch["adduct_0"].float() - kwargs_1["adduct"] = batch["adduct_1"].float() + batch["ionmode_0"] = torch.nan_to_num( + batch["ionmode_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["ionmode_1"] = torch.nan_to_num( + batch["ionmode_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) + kwargs_0["ionmode"] = batch["ionmode_0"].float() + kwargs_1["ionmode"] = batch["ionmode_1"].float() + batch["adduct_0"] = torch.nan_to_num( + batch["adduct_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["adduct_1"] = torch.nan_to_num( + batch["adduct_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) - if self.use_ce: - batch["ce_0"] = torch.nan_to_num( - batch["ce_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ce_1"] = torch.nan_to_num( - batch["ce_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) - kwargs_0["ce"] = batch["ce_0"].float() - kwargs_1["ce"] = batch["ce_1"].float() + kwargs_0["adduct"] = batch["adduct_0"].float() + kwargs_1["adduct"] = batch["adduct_1"].float() - if self.use_ion_activation: - batch["ion_activation_0"] = torch.nan_to_num( - batch["ion_activation_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ion_activation_1"] = torch.nan_to_num( - batch["ion_activation_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) - kwargs_0["ion_activation"] = batch["ion_activation_0"].float() - kwargs_1["ion_activation"] = batch["ion_activation_1"].float() + batch["ce_0"] = torch.nan_to_num(batch["ce_0"], nan=0.0, posinf=0.0, neginf=0.0) + batch["ce_1"] = torch.nan_to_num(batch["ce_1"], nan=0.0, posinf=0.0, neginf=0.0) + kwargs_0["ce"] = batch["ce_0"].float() + kwargs_1["ce"] = batch["ce_1"].float() - if self.use_ion_method: - batch["ion_method_0"] = torch.nan_to_num( - batch["ion_method_0"], nan=0.0, posinf=0.0, neginf=0.0 - ) - batch["ion_method_1"] = torch.nan_to_num( - batch["ion_method_1"], nan=0.0, posinf=0.0, neginf=0.0 - ) + batch["ion_activation_0"] = torch.nan_to_num( + batch["ion_activation_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["ion_activation_1"] = torch.nan_to_num( + batch["ion_activation_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) + kwargs_0["ion_activation"] = batch["ion_activation_0"].float() + kwargs_1["ion_activation"] = batch["ion_activation_1"].float() - kwargs_0["ion_method"] = batch["ion_method_0"].float() - kwargs_1["ion_method"] = batch["ion_method_1"].float() + batch["ion_method_0"] = torch.nan_to_num( + batch["ion_method_0"], nan=0.0, posinf=0.0, neginf=0.0 + ) + batch["ion_method_1"] = torch.nan_to_num( + batch["ion_method_1"], nan=0.0, posinf=0.0, neginf=0.0 + ) + kwargs_0["ion_method"] = batch["ion_method_0"].float() + kwargs_1["ion_method"] = batch["ion_method_1"].float() # intensity and mz batch["intensity_0"] = torch.nan_to_num( batch["intensity_0"], nan=0.0, posinf=0.0, neginf=0.0 @@ -794,6 +786,18 @@ def forward(self, batch): "precursor_charge": batch["precursor_charge"].float(), } + # Add metadata fields if present in batch + 
if "ionmode" in batch: + kwargs["ionmode"] = batch["ionmode"].float() + if "adduct" in batch: + kwargs["adduct"] = batch["adduct"].float() + if "ce" in batch: + kwargs["ce"] = batch["ce"].float() + if "ion_activation" in batch: + kwargs["ion_activation"] = batch["ion_activation"].float() + if "ion_method" in batch: + kwargs["ion_method"] = batch["ion_method"].float() + emb, _ = self.model( mz_array=batch["mz"].float(), intensity_array=batch["intensity"].float(), diff --git a/simba/core/models/spectrum_encoder.py b/simba/core/models/spectrum_encoder.py index 3a516a3..44e2614 100644 --- a/simba/core/models/spectrum_encoder.py +++ b/simba/core/models/spectrum_encoder.py @@ -1,4 +1,5 @@ import torch +from depthcharge.encoders import FloatEncoder from depthcharge.transformers import ( SpectrumTransformerEncoder, ) # PeptideTransformerEncoder, @@ -12,6 +13,7 @@ def __init__( use_ce: bool = False, use_ion_activation: bool = False, use_ion_method: bool = False, + use_ion_mode: bool = False, **kwargs, ): """ @@ -28,11 +30,35 @@ def __init__( use_ion_method: bool Whether to include ionization method in the encoding (default: False). """ + self.use_encoders = False super().__init__(*args, **kwargs) self.use_adduct = use_adduct self.use_ce = use_ce self.use_ion_activation = use_ion_activation self.use_ion_method = use_ion_method + self.use_ion_mode = use_ion_mode + + if self.use_encoders: + if self.use_adduct: + self.adduct_encoder = FloatEncoder(self.d_model) + self.ionmode_encoder = FloatEncoder(self.d_model) + if self.use_ce: + self.ce_encoder = FloatEncoder( + self.d_model, + ) + + if self.use_ion_activation: + self.ion_activation_encoder = FloatEncoder(self.d_model) + + if self.use_ion_method: + self.ion_method_encoder = FloatEncoder(self.d_model) + if ( + self.use_adduct + or self.use_ce + or self.use_ion_activation + or self.use_ion_method + ): + self.precursor_mz_encoder = FloatEncoder(self.d_model) def precursor_hook( self, @@ -44,46 +70,87 @@ def precursor_hook( dtype = mz_array.dtype batch_size = mz_array.shape[0] - placeholder = torch.zeros( - (batch_size, self.d_model), dtype=dtype, device=device - ) - precursor_mass = kwargs["precursor_mass"].float().to(device).view(batch_size) - placeholder[:, 0] = precursor_mass + if not (self.use_encoders): + placeholder = torch.zeros( + (batch_size, self.d_model), dtype=dtype, device=device + ) + precursor_mass = ( + kwargs["precursor_mass"].float().to(device).view(batch_size) + ) + placeholder[:, 0] = precursor_mass - precursor_charge = ( - kwargs["precursor_charge"].float().to(device).view(batch_size) - ) - placeholder[:, 1] = precursor_charge + precursor_charge = ( + kwargs["precursor_charge"].float().to(device).view(batch_size) + ) + # skip the use of the precursor charge field + if self.use_ion_mode: + placeholder[:, 1] = precursor_charge + + current_idx = 2 # keep track of where to insert metadata - current_idx = 2 # keep track of where to insert metadata - if self.use_adduct: ionmode = kwargs["ionmode"].float().to(device).view(batch_size) - placeholder[:, current_idx] = ionmode + if self.use_ion_mode: + placeholder[:, current_idx] = ionmode current_idx += 1 adduct = kwargs["adduct"].float().to(device).view(batch_size, -1) stop_idx = current_idx + adduct.shape[1] - placeholder[:, current_idx:stop_idx] = adduct + if self.use_adduct: + placeholder[:, current_idx:stop_idx] = adduct current_idx = stop_idx - if self.use_ce: ce = kwargs["ce"].float().to(device).view(batch_size) - placeholder[:, current_idx] = ce + if self.use_ce: + placeholder[:, 
current_idx] = ce current_idx += 1 - if self.use_ion_activation: ia = kwargs["ion_activation"].float().to(device).view(batch_size, -1) stop_idx = current_idx + ia.shape[1] - placeholder[:, current_idx:stop_idx] = ia + if self.use_ion_activation: + placeholder[:, current_idx:stop_idx] = ia current_idx = stop_idx - if self.use_ion_method: im = kwargs["ion_method"].float().to(device).view(batch_size, -1) stop_idx = current_idx + im.shape[1] - placeholder[:, current_idx:stop_idx] = im + if self.use_ion_method: + placeholder[:, current_idx:stop_idx] = im current_idx = stop_idx - # ensure there are no nans - placeholder = torch.nan_to_num(placeholder, nan=0.0, posinf=0.0, neginf=0.0) + # ensure there are no nans + placeholder = torch.nan_to_num(placeholder, nan=0.0, posinf=0.0, neginf=0.0) + + else: + precursor_mass = ( + kwargs["precursor_mass"].float().to(device).view(batch_size) + ) + precursor_mass_rep = self.precursor_mz_encoder( + precursor_mass[:, None] + ).squeeze(1) + placeholder = precursor_mass_rep + 0 * precursor_mass_rep + + if self.use_adduct: + ionmode = kwargs["ionmode"].float().to(device).view(batch_size) + adduct = kwargs["adduct"].float().to(device).view(batch_size, -1) + ionmode_rep = self.ionmode_encoder(ionmode[:, None]).squeeze(1) + adduct_rep = self.adduct_encoder(adduct).mean(dim=1) + + placeholder = placeholder + (ionmode_rep + adduct_rep) + + if self.use_ce: + ce = kwargs["ce"].float().to(device).view(batch_size) + ce_rep = self.ce_encoder(ce[:, None]).squeeze(1) + placeholder = placeholder + ce_rep + + if self.use_ion_method: + im = kwargs["ion_method"].float().to(device).view(batch_size, -1) + im_rep = self.ion_method_encoder(im).mean(dim=1) + placeholder = placeholder + im_rep + + if self.use_ion_activation: + ia = kwargs["ion_activation"].float().to(device).view(batch_size, -1) + ia_rep = self.ion_activation_encoder(ia).mean(dim=1) + placeholder = placeholder + ia_rep + + placeholder = torch.nan_to_num(placeholder, nan=0.0, posinf=0.0, neginf=0.0) return placeholder diff --git a/simba/workflows/inference.py b/simba/workflows/inference.py index d6f18c2..8f85267 100644 --- a/simba/workflows/inference.py +++ b/simba/workflows/inference.py @@ -10,6 +10,7 @@ import numpy as np from omegaconf import DictConfig from scipy.stats import spearmanr +from sklearn.metrics import mean_absolute_error from torch.utils.data import DataLoader import simba.core.data.molecule_pairs @@ -54,7 +55,18 @@ def load_inference_data(cfg: DictConfig): logger.info("Detected lightweight format - reconstructing molecule_pairs_test") mgf_path = dataset["mgf_path"] - all_spectra = load_spectra(mgf_path, cfg) + + # Use preprocessing config values (if available) to ensure consistent filtering + use_only_protonized = getattr( + cfg.preprocessing, "use_only_protonized_adducts", True + ) + + all_spectra = load_spectra( + mgf_path, + cfg, + n_samples=-1, + use_only_protonized_adducts=use_only_protonized, + ) # Create spectrum lookup by MGF index spectra_by_idx = {s.mgf_index: s for s in all_spectra} @@ -64,8 +76,31 @@ def load_inference_data(cfg: DictConfig): df_smiles = dataset["df_smiles_test"] spectrum_indexes = dataset["spectrum_indexes_test"] - # Load original spectra (all, including duplicates) - original_spectra = [spectra_by_idx[idx] for idx in spectrum_indexes] + # Load original spectra, handling missing ones + original_spectra = [] + idx_map = {} # old_idx -> new_idx + missing = [] + + for old_idx, mgf_idx in enumerate(spectrum_indexes): + if mgf_idx in spectra_by_idx: + idx_map[old_idx] = 
len(original_spectra) + original_spectra.append(spectra_by_idx[mgf_idx]) + else: + missing.append(mgf_idx) + + if missing: + logger.warning( + f"[test] Missing {len(missing)} spectra (e.g., MGF index {missing[0]})" + ) + # Filter df_smiles to keep only rows with valid spectra + valid_rows = [] + for i in df_smiles.index: + old_idxs = df_smiles.loc[i, "indexes"] + if all(idx in idx_map for idx in old_idxs): + # Remap to new positions + df_smiles.at[i, "indexes"] = [idx_map[idx] for idx in old_idxs] + valid_rows.append(i) + df_smiles = df_smiles.loc[valid_rows] # Build unique_spectra from df_smiles indexes # df_smiles['indexes'] contains lists of indexes into original_spectra @@ -157,6 +192,7 @@ def prepare_inference_dataloaders( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, + use_ion_mode=cfg.model.features.use_ion_mode, ) dataloader_ed = DataLoader( dataset_ed, batch_size=cfg.inference.batch_size, shuffle=False @@ -169,6 +205,7 @@ def prepare_inference_dataloaders( use_ce=cfg.model.features.use_ce, use_ion_activation=cfg.model.features.use_ion_activation, use_ion_method=cfg.model.features.use_ion_method, + use_ion_mode=cfg.model.features.use_ion_mode, ) dataloader_mces = DataLoader( dataset_mces, batch_size=cfg.inference.batch_size, shuffle=False @@ -202,6 +239,7 @@ def load_model_for_inference(cfg: DictConfig, checkpoint_path: str): "use_ce": cfg.model.features.use_ce, "use_ion_activation": cfg.model.features.use_ion_activation, "use_ion_method": cfg.model.features.use_ion_method, + "use_ion_mode": cfg.model.features.use_ion_mode, } model = SimilarityModelMultitask.load_from_checkpoint( @@ -314,7 +352,10 @@ def evaluate_predictions( # Edit distance correlation corr_model_ed, _ = spearmanr(ed_true_clean, pred_ed_ed_clean) + mae_model_ed = mean_absolute_error(ed_true_clean, pred_ed_ed_clean) + logger.info(f"Edit distance correlation: {corr_model_ed:.4f}") + logger.info(f"Edit distance mean absolute error: {mae_model_ed:.4f}") # Plot confusion matrix _plot_cm(ed_true_clean, pred_ed_ed_clean, cfg, output_dir) @@ -336,10 +377,15 @@ def evaluate_predictions( if len(mces_true) == 0 or len(pred_mces_mces_flat) == 0: logger.warning("No MCES samples after filtering, skipping MCES correlation") corr_model_mces = float("nan") + mae_model_mces = float("nan") + else: corr_model_mces, _ = spearmanr(mces_true, pred_mces_mces_flat) - + mae_model_mces = mean_absolute_error(mces_true, pred_mces_mces_flat) logger.info(f"MCES/Tanimoto correlation: {corr_model_mces:.4f}") + logger.info( + f"MCES/Tanimoto mean absolute error: {cfg.data.mces20_max_value * mae_model_mces:.4f}" + ) # Denormalize if using MCES20 if not cfg.data.use_tanimoto: @@ -453,10 +499,10 @@ def _plot_cm( ) -> None: """Plot confusion matrix.""" import matplotlib.pyplot as plt - from sklearn.metrics import accuracy_score, confusion_matrix + from sklearn.metrics import balanced_accuracy_score, confusion_matrix cm = confusion_matrix(true, preds) - accuracy = accuracy_score(true, preds) + accuracy = balanced_accuracy_score(true, preds) # Normalize cm_normalized = cm.astype("float") / cm.sum(axis=1)[:, np.newaxis] diff --git a/simba/workflows/preprocessing.py b/simba/workflows/preprocessing.py index 4c27c77..bd34fe5 100644 --- a/simba/workflows/preprocessing.py +++ b/simba/workflows/preprocessing.py @@ -17,6 +17,9 @@ with contextlib.suppress(RuntimeError): multiprocessing.set_start_method("spawn", force=True) +from 
simba.core.chemistry.edit_distance.edit_distance import ( + load_precomputed_distances_cache, +) from simba.core.chemistry.mces.mces_computation import MCES from simba.core.training.train_utils import TrainUtils from simba.utils.logger_setup import logger @@ -50,30 +53,34 @@ def write_data( # Lightweight format: save only df_smiles, original MGF indexes, and mgf path # Spectra will be loaded at training time from mgf file using absolute indexes dataset = { - "df_smiles_train": molecule_pairs_train.df_smiles - if molecule_pairs_train is not None - else None, - "df_smiles_val": molecule_pairs_val.df_smiles - if molecule_pairs_val is not None - else None, - "df_smiles_test": molecule_pairs_test.df_smiles - if molecule_pairs_test is not None - else None, - "spectrum_indexes_train": [ - s.mgf_index for s in molecule_pairs_train.original_spectra - ] - if molecule_pairs_train is not None - else None, - "spectrum_indexes_val": [ - s.mgf_index for s in molecule_pairs_val.original_spectra - ] - if molecule_pairs_val is not None - else None, - "spectrum_indexes_test": [ - s.mgf_index for s in molecule_pairs_test.original_spectra - ] - if molecule_pairs_test is not None - else None, + "df_smiles_train": ( + molecule_pairs_train.df_smiles + if molecule_pairs_train is not None + else None + ), + "df_smiles_val": ( + molecule_pairs_val.df_smiles if molecule_pairs_val is not None else None + ), + "df_smiles_test": ( + molecule_pairs_test.df_smiles + if molecule_pairs_test is not None + else None + ), + "spectrum_indexes_train": ( + [s.mgf_index for s in molecule_pairs_train.original_spectra] + if molecule_pairs_train is not None + else None + ), + "spectrum_indexes_val": ( + [s.mgf_index for s in molecule_pairs_val.original_spectra] + if molecule_pairs_val is not None + else None + ), + "spectrum_indexes_test": ( + [s.mgf_index for s in molecule_pairs_test.original_spectra] + if molecule_pairs_test is not None + else None + ), "mgf_path": mgf_path, "format_version": "lightweight", } @@ -201,6 +208,45 @@ def preprocess(cfg: DictConfig) -> None: with open(output_file, "wb") as file: pickle.dump(dataset, file) + # Load precomputed distances cache if configured + precomputed_cache = {} + if ( + hasattr(cfg.preprocessing, "precomputed_distances") + and cfg.preprocessing.precomputed_distances + ): + dirs = cfg.preprocessing.precomputed_distances + if dirs and len(dirs) > 0: + logger.info( + f"Loading precomputed distances from {len(dirs)} directory(ies)..." 
diff --git a/simba/workflows/training.py b/simba/workflows/training.py
index 4adcd2e..987f3bb 100644
--- a/simba/workflows/training.py
+++ b/simba/workflows/training.py
@@ -84,7 +84,18 @@ def load_dataset(cfg: DictConfig):
     )
     mgf_path = mapping["mgf_path"]
-    all_spectra = load_spectra(mgf_path, cfg)
+
+    # Use preprocessing config values (if available) to ensure consistent filtering
+    use_only_protonized = getattr(
+        cfg.preprocessing, "use_only_protonized_adducts", True
+    )
+
+    all_spectra = load_spectra(
+        mgf_path,
+        cfg,
+        n_samples=-1,  # Load all spectra during training
+        use_only_protonized_adducts=use_only_protonized,
+    )

     # Create spectrum lookup by MGF index
     spectra_by_idx = {s.mgf_index: s for s in all_spectra}
@@ -98,12 +109,35 @@
         df_smiles = mapping[df_smiles_key]
         spectrum_indexes = mapping[spectrum_indexes_key]

-        # Load original spectra (all, including duplicates)
-        original_spectra = [spectra_by_idx[idx] for idx in spectrum_indexes]
+        # Load original spectra, handling missing ones
+        original_spectra = []
+        idx_map = {}  # old_idx -> new_idx
+        missing = []
+
+        for old_idx, mgf_idx in enumerate(spectrum_indexes):
+            if mgf_idx in spectra_by_idx:
+                idx_map[old_idx] = len(original_spectra)
+                original_spectra.append(spectra_by_idx[mgf_idx])
+            else:
+                missing.append(mgf_idx)
+
+        if missing:
+            logger.warning(
+                f"[{split}] Missing {len(missing)} spectra (e.g., MGF index {missing[0]})"
+            )
+            # Filter df_smiles to keep only rows with valid spectra
+            valid_rows = []
+            for i in df_smiles.index:
+                old_idxs = df_smiles.loc[i, "indexes"]
+                if all(idx in idx_map for idx in old_idxs):
+                    # Remap to new positions
+                    df_smiles.at[i, "indexes"] = [
+                        idx_map[idx] for idx in old_idxs
+                    ]
+                    valid_rows.append(i)
+            df_smiles = df_smiles.loc[valid_rows]

         # Build unique_spectra from df_smiles indexes
-        # df_smiles['indexes'] contains lists of indexes into original_spectra
-        # We take the first spectrum for each unique SMILES
         unique_spectra = [
             original_spectra[df_smiles.loc[i, "indexes"][0]]
             for i in df_smiles.index
@@ -260,6 +294,7 @@ def prepare_data(
         use_ce=cfg.model.features.use_ce,
         use_ion_activation=cfg.model.features.use_ion_activation,
         use_ion_method=cfg.model.features.use_ion_method,
+        use_ion_mode=cfg.model.features.use_ion_mode,
     )

     dataset_val = MultitaskDataBuilder.from_molecule_pairs_to_dataset(
@@ -269,6 +304,7 @@
         use_ce=cfg.model.features.use_ce,
         use_ion_activation=cfg.model.features.use_ion_activation,
         use_ion_method=cfg.model.features.use_ion_method,
+        use_ion_mode=cfg.model.features.use_ion_mode,
     )

     # Create samplers
@@ -392,6 +428,7 @@ def setup_model(cfg: DictConfig, weights_mces: np.ndarray) -> SimilarityModelMul
         "use_ce": cfg.model.features.use_ce,
         "use_ion_activation": cfg.model.features.use_ion_activation,
         "use_ion_method": cfg.model.features.use_ion_method,
+        "use_ion_mode": cfg.model.features.use_ion_mode,
     }

     # Load pretrained weights if specified
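The `getattr(cfg.preprocessing, "use_only_protonized_adducts", True)` fallback above keeps training runs working with older preprocessing configs that predate the flag. A minimal illustration, with `SimpleNamespace` standing in for the Hydra config object (hypothetical values):

```python
from types import SimpleNamespace

# An older config written before the flag existed vs. a newer one that sets it.
old_cfg = SimpleNamespace(preprocessing=SimpleNamespace())
new_cfg = SimpleNamespace(
    preprocessing=SimpleNamespace(use_only_protonized_adducts=False)
)

# getattr with a default mirrors the fallback used in load_dataset above.
print(getattr(old_cfg.preprocessing, "use_only_protonized_adducts", True))  # True
print(getattr(new_cfg.preprocessing, "use_only_protonized_adducts", True))  # False
```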
diff --git a/test_all_commands.sh b/test_all_commands.sh
index ae754c9..ad88156 100755
--- a/test_all_commands.sh
+++ b/test_all_commands.sh
@@ -44,7 +44,7 @@
 rm -rf test_full_workflow/
 mkdir -p test_full_workflow

 echo ""
-echo "1/9 Testing: simba preprocess"
+echo "1/10 Testing: simba preprocess"
 echo "--------------------------------"
 uv run simba preprocess \
     preprocessing=fast_dev \
@@ -52,7 +52,16 @@ uv run simba preprocess \
     paths.preprocessing_dir=./test_full_workflow/preprocessed/

 echo ""
-echo "2/9 Testing: simba train"
+echo "2/10 Testing: simba preprocess with cache (reuse distances)"
+echo "------------------------------------------------------------"
+uv run simba preprocess \
+    preprocessing=fast_dev \
+    paths.spectra_path=data/casmi2022.mgf \
+    paths.preprocessing_dir=./test_full_workflow/preprocessed_cached/ \
+    'preprocessing.precomputed_distances=[./test_full_workflow/preprocessed/]'
+
+echo ""
+echo "3/10 Testing: simba train"
 echo "--------------------------------"
 uv run simba train \
     training=fast_dev \
@@ -63,7 +72,7 @@ uv run simba train \
     hardware.accelerator=$DEVICE

 echo ""
-echo "3/9 Testing: simba inference"
+echo "4/10 Testing: simba inference"
 echo "--------------------------------"
 uv run simba inference \
     inference=fast_dev \
@@ -73,7 +82,7 @@
     hardware.accelerator=$DEVICE

 echo ""
-echo "4/9 Testing: simba analog-discovery"
+echo "5/10 Testing: simba analog-discovery"
 echo "--------------------------------"
 uv run simba analog-discovery \
     analog_discovery=fast_dev \
@@ -86,7 +95,7 @@

 if [[ "$SKIP_PRETRAINED" == "false" ]]; then
     echo ""
-    echo "5/9 Testing: simba inference (pretrained model)"
+    echo "6/10 Testing: simba inference (pretrained model)"
     echo "--------------------------------"
     uv run simba inference \
         inference=fast_dev \
@@ -96,7 +105,7 @@
         hardware.accelerator=$DEVICE

     echo ""
-    echo "6/9 Testing: simba analog-discovery (pretrained model)"
+    echo "7/10 Testing: simba analog-discovery (pretrained model)"
     echo "--------------------------------"
     uv run simba analog-discovery \
         analog_discovery=fast_dev \
@@ -114,12 +123,13 @@
 else
 fi

 echo ""
-echo "7/9 Testing: simba train (with metadata features)"
+echo "8/10 Testing: simba train (with metadata features)"
 echo "--------------------------------"
 uv run simba train \
     training=fast_dev \
     paths.preprocessing_dir_train=./test_full_workflow/preprocessed/ \
     paths.checkpoint_dir=./test_full_workflow/checkpoints_metadata/ \
+    model.features.use_adduct=true \
     model.features.use_ce=true \
     model.features.use_ion_activation=true \
     model.features.use_ion_method=true \
@@ -128,20 +138,21 @@
     hardware.accelerator=$DEVICE

 echo ""
-echo "8/9 Testing: simba inference (with metadata features)"
+echo "9/10 Testing: simba inference (with metadata features)"
 echo "--------------------------------"
 uv run simba inference \
     inference=fast_dev \
     paths.checkpoint_dir=./test_full_workflow/checkpoints_metadata/ \
     paths.preprocessing_dir=./test_full_workflow/preprocessed/ \
     inference.preprocessing_pickle=mapping_unique_smiles.pkl \
+    model.features.use_adduct=true \
     model.features.use_ce=true \
     model.features.use_ion_activation=true \
     model.features.use_ion_method=true \
     hardware.accelerator=$DEVICE

 echo ""
-echo "9/9 Testing: simba analog-discovery (with metadata features)"
+echo "10/10 Testing: simba analog-discovery (with metadata features)"
 echo "--------------------------------"
 uv run simba analog-discovery \
     analog_discovery=fast_dev \
@@ -150,6 +161,7 @@
     --reference-spectra data/casmi2022.mgf \
     --output-dir ./test_full_workflow/analog_results_metadata/ \
     analog_discovery.query_index=0 \
+    model.features.use_adduct=true \
     model.features.use_ce=true \
     model.features.use_ion_activation=true \
     model.features.use_ion_method=true \
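One way to sanity-check step 2/10 is to compare the distance files written by the plain and cached preprocessing runs: assuming the splits are deterministic between the two runs, reused distances should match recomputed ones. A hedged sketch (the `edit_distance_*.npy` naming follows the cache's discovery pattern; exact filenames and matching splits are assumptions, not guaranteed by the test script):

```python
from pathlib import Path

import numpy as np

plain = Path("test_full_workflow/preprocessed")
cached = Path("test_full_workflow/preprocessed_cached")

# Compare each edit-distance array against its twin from the cached run.
for f in sorted(plain.glob("edit_distance_*.npy")):
    twin = cached / f.name
    if not twin.exists():
        continue
    a, b = np.load(f), np.load(twin)
    same = a.shape == b.shape and np.allclose(a, b, equal_nan=True)
    print(f"{f.name}: {a.shape} vs {b.shape}, match={same}")
```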
diff --git a/tests/unit/test_embedder_multitask.py b/tests/unit/test_embedder_multitask.py
index c21a8b7..b8f1a03 100644
--- a/tests/unit/test_embedder_multitask.py
+++ b/tests/unit/test_embedder_multitask.py
@@ -116,6 +116,12 @@ def sample_batch(self):
             "adduct_1": torch.zeros(batch_size, n_adducts),  # One-hot encoded
             "ionmode_0": torch.ones(batch_size, 1),
             "ionmode_1": torch.ones(batch_size, 1),
+            "ce_0": torch.ones(batch_size, 1) * 30.0,
+            "ce_1": torch.ones(batch_size, 1) * 30.0,
+            "ion_activation_0": torch.zeros(batch_size, 1),
+            "ion_activation_1": torch.zeros(batch_size, 1),
+            "ion_method_0": torch.zeros(batch_size, 1),
+            "ion_method_1": torch.zeros(batch_size, 1),
             "similarity": torch.tensor([0.8, 0.6]),
             "similarity_2": torch.tensor([0.7, 0.5]),
         }
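The fixture above duplicates every metadata feature with `_0`/`_1` suffixes, one copy per spectrum in the pair. A small helper that mirrors that convention (a sketch based on the test fixture's keys and shapes, not a documented API; `make_metadata_pair` is hypothetical):

```python
import torch

def make_metadata_pair(batch_size: int, ce_ev: float = 30.0) -> dict:
    """Build paired metadata tensors following the _0/_1 suffix convention."""
    features = {
        "ce": torch.full((batch_size, 1), ce_ev),      # collision energy
        "ion_activation": torch.zeros(batch_size, 1),  # activation type id
        "ion_method": torch.zeros(batch_size, 1),      # ionization method id
    }
    # Each feature appears twice: once per spectrum in the pair.
    return {
        f"{name}_{side}": tensor.clone()
        for name, tensor in features.items()
        for side in (0, 1)
    }

batch = make_metadata_pair(batch_size=2)
print(sorted(batch))  # ['ce_0', 'ce_1', 'ion_activation_0', 'ion_activation_1', ...]
```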