diff --git a/.dvc/.gitignore b/.dvc/.gitignore new file mode 100644 index 0000000..528f30c --- /dev/null +++ b/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/.dvc/config b/.dvc/config new file mode 100644 index 0000000..4cf322d --- /dev/null +++ b/.dvc/config @@ -0,0 +1,2 @@ +[core] + autostage = true diff --git a/.dvcignore b/.dvcignore new file mode 100644 index 0000000..5197305 --- /dev/null +++ b/.dvcignore @@ -0,0 +1,3 @@ +# Add patterns of files dvc should ignore, which could improve +# the performance. Learn more at +# https://dvc.org/doc/user-guide/dvcignore diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 2507587..e621c1e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -8,17 +8,17 @@ repos: - id: trailing-whitespace - id: detect-private-key - repo: https://github.com/psf/black - rev: 22.12.0 + rev: 23.3.0 hooks: - id: black language_version: python3.9 - repo: https://github.com/pycqa/isort - rev: 5.11.4 + rev: 5.12.0 hooks: - id: isort name: isort (python) - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: v0.0.228 + rev: v0.0.263 hooks: - id: ruff args: ['--fix'] diff --git a/README.md b/README.md index 6ece584..9dd9711 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,35 @@ + +![GitHub](https://img.shields.io/github/license/TalusBio/diadem?style=flat-square) +![GitHub release (latest SemVer)](https://img.shields.io/github/v/release/talusbio/diadem?style=flat-square) + + # diadem A feature-centric DIA search engine +## Current stage + +Under development. + +## Installation + +```shell +git clone git@github.com:TalusBio/diadem.git +cd diadem +pip install . 
+# pip install -e ".[dev,profiling,test]" # for development install +``` + +## Usage + +```shell +``` + + ## Release milestones - [x] Search prototype -- [ ] Stable search engine (more as in -) +- [-] Stable search engine (more as in -) - [ ] Quantification module - [ ] Stable quant module - [ ] RT alignment module diff --git a/benchmarking_tests/conftest.py b/benchmarking_tests/conftest.py index 955c08c..c59510c 100644 --- a/benchmarking_tests/conftest.py +++ b/benchmarking_tests/conftest.py @@ -33,7 +33,7 @@ def make_protein(prot_length: int) -> str: """Makes the sequence of a fake protein of the passed length.""" return "".join( - sample(list(aa_counts.keys()), counts=list(aa_counts.values()), k=prot_length) + sample(list(aa_counts.keys()), counts=list(aa_counts.values()), k=prot_length), ) diff --git a/benchmarking_tests/test_deisotoping.py b/benchmarking_tests/test_deisotoping.py new file mode 100644 index 0000000..d0422d2 --- /dev/null +++ b/benchmarking_tests/test_deisotoping.py @@ -0,0 +1,346 @@ +from __future__ import annotations + +import numpy as np +from numpy import array + +# These are just some hard-coded distributions +average_distributions = { + 0: array([1.000]), + 200: array([1.000, 0.108, 0.011, 0.001, 0.000, 0.000]), + 400: array([1.000, 0.219, 0.033, 0.004, 0.000, 0.000, 0.000]), + 600: array([1.000, 0.326, 0.068, 0.011, 0.001, 0.000, 0.000]), + 800: array([1.000, 0.436, 0.116, 0.023, 0.004, 0.000, 0.000, 0.000]), + 1000: array([1.000, 0.536, 0.167, 0.038, 0.007, 0.001, 0.000, 0.000, 0.000]), + 1200: array([1.000, 0.645, 0.238, 0.064, 0.014, 0.003, 0.000, 0.000, 0.000]), + 1400: array([1.000, 0.757, 0.367, 0.133, 0.039, 0.009, 0.002]), + 1600: array([1.000, 0.868, 0.461, 0.181, 0.057, 0.015, 0.003, 0.001]), + 1800: array([1.000, 0.976, 0.565, 0.242, 0.083, 0.024, 0.006, 0.001]), + 2000: array([0.923, 1.000, 0.629, 0.290, 0.107, 0.033, 0.009, 0.002]), + 2200: array([0.837, 1.000, 0.680, 0.336, 0.133, 0.044, 0.013, 0.003, 0.001]), + 2400: 
array([0.768, 1.000, 0.731, 0.386, 0.163, 0.058, 0.018, 0.005, 0.001]), + 2600: array([0.708, 1.000, 0.784, 0.442, 0.198, 0.074, 0.024, 0.007, 0.002]), + 2800: array([0.662, 1.000, 0.831, 0.494, 0.233, 0.092, 0.031, 0.009, 0.003, 0.001]), + 3000: array([0.617, 1.000, 0.884, 0.557, 0.277, 0.115, 0.041, 0.013, 0.004, 0.001]), + 3200: array([0.578, 1.000, 0.936, 0.622, 0.327, 0.143, 0.054, 0.018, 0.005, 0.001]), + 3400: array( + [0.543, 1.000, 0.990, 0.692, 0.381, 0.175, 0.069, 0.024, 0.008, 0.002, 0.001], + ), + 3600: array( + [0.493, 0.959, 1.000, 0.735, 0.424, 0.203, 0.084, 0.031, 0.010, 0.003, 0.001], + ), + 3800: array( + [0.444, 0.913, 1.000, 0.770, 0.464, 0.233, 0.101, 0.038, 0.013, 0.004, 0.001], + ), +} + +max_len = max([len(x) for x in average_distributions.values()]) +average_distributions_array = np.array( + [np.pad(x, (0, max_len - len(x))) for x in average_distributions.values()], +) + + +def mass_to_dist(mass): + """Convert a mass to a distribution. + + It is a hard coded-ugly way of doing it that approximates + to the closest 200 Da. + Will only work for masses lower than 4000. + """ + return average_distributions_array[(mass // 200)] + + +def make_isotope_envelope_vectorized( + mz, + charge, + intensity, + min_intensity, + ims: np.ndarray | None = None, +): + """Make an isotopic envelope for a given mz, charge, and intensity. + + Returns a tuple of (mz, intensity) arrays. + + Examples + -------- + >>> mzs = np.array([300., 300.]) + >>> ints = np.array([1000., 1000.]) + >>> charges = np.array([1, 2]) + >>> out = make_isotope_envelope_vectorized(mzs, charges, ints, 100) + >>> out + (array([300. , 301.003 , 300. , 300.5015]), array([1000., 108., 1000., 326.])) + >>> out = make_isotope_envelope_vectorized(mzs, charges, ints, 100, ims=np.array([1.0, 2.0])) + >>> out + (array([300. , 301.003 , 300. 
, 300.5015]), array([1000., 108., 1000., 326.]), array([1., 1., 2., 2.])) + """ # noqa: E501 + dist = average_distributions_array[((mz * charge) // 200).astype(int)] + dist = np.einsum("ij, i -> ij", dist, intensity) + + mz_offset = np.expand_dims(np.arange(dist.shape[1]), axis=-1) * 1.003 / charge + # shape (num_isotopes, 1) + mz_dist = (mz_offset + np.expand_dims(mz, axis=0)).T + + intensities = dist.flatten() + mz_dists = mz_dist.flatten() + + mask = intensities > min_intensity + + mz_dist = mz_dists[mask] + intensities = intensities[mask] + if ims is not None: + ims = np.repeat(ims, dist.shape[1]) + ims = ims[mask] + return mz_dist, intensities, ims + + return mz_dist, intensities + + +def simulate_isotopes( + min_mz, + max_mz, + num_peaks, + min_charge, + max_charge, + min_intensity, + max_intensity, +): + """Simulate a set of isotopic peaks.""" + mzs = np.random.uniform(min_mz, max_mz, num_peaks) + ints = np.random.uniform(min_intensity, max_intensity, size=num_peaks) + charges = np.random.randint(min_charge, max_charge, size=num_peaks) + + mzs, intensities = make_isotope_envelope_vectorized( + mzs, + charges, + ints, + min_intensity, + ) + + return mzs, intensities + + +def simulate_isotopes_ims( + min_mz, + max_mz, + num_peaks, + min_charge, + max_charge, + min_intensity, + max_intensity, + min_ims, + max_ims, +): + """Simulate a set of isotopic peaks.""" + mzs = np.random.uniform(min_mz, max_mz, num_peaks) + ints = np.random.uniform(min_intensity, max_intensity, size=num_peaks) + charges = np.random.randint(min_charge, max_charge, size=num_peaks) + + mzs, intensities = make_isotope_envelope_vectorized( + mzs, + charges, + ints, + min_intensity, + ) + + return mzs, intensities + + +def _split_ims(ims: float, intensity: float, ims_std: float, ims_binwidth=0.002): + """Split an ims peak into multiple peaks. + + The number of peaks is the integer square root of the intensity. 
+ + The new ims is drawn from a normal distribution with a mean of the + provided ims, with the provided standard deviation. + + The intensity of each peak is drawn from a uniform distribution between 0 + and 2*(intensity)/len(num of peaks). + """ + ims_out = np.random.normal(ims, ims_std, size=int(np.sqrt(intensity)) + 1) + ims_intensity = np.random.uniform( + 0, + 2 * (intensity) / len(ims_out), + size=len(ims_out), + ) + ims_intensity_out = np.histogram( + ims_out, + bins=np.arange(ims.min(), ims.max(), ims_binwidth), + weights=ims_intensity, + ) + return ims_out, ims_intensity_out + + +def extend_ims(seed_mzs, intensities, ims_start=0.7, ims_end=1.2, std_ims=0.01): + """Extend a set of peaks to include IMS coordinates. + + A random ims is assigned to every peak and a distribution of peaks is generated + splitting the intensity of the peak between the new peaks. (the new distribution + is random so the total might not be the same). + + The IMS coordinates are drawn from a normal distribution with a mean of the + average IMS value and a standard deviation of the standard deviation of the + IMS values. + + Returns + ------- + mzs : np.ndarray + The m/z values of the peaks. + intensities : np.ndarray + The intensities of the peaks. + ims : np.ndarray + The IMS values of the peaks. + indices : np.ndarray + This indices correspond to the originally passed seed mzs and intensities. 
+ """ + num_ims = len(seed_mzs) + ims = np.random.uniform(ims_start, ims_end, size=num_ims) + ims_inten_pairs = [ + _split_ims(im, inten, ims_std=std_ims, ims_binwidth=0.002) + for im, inten in zip(ims, intensities) + ] + + out_imss = [] + out_intens = [] + out_mzs = [] + out_indices = [] + + for i, (imss, intens, seed_mz) in enumerate(zip(*ims_inten_pairs, seed_mzs)): + out_imss.extend(imss) + out_intens.extend(intens) + out_mzs.extend([seed_mz] * len(imss)) + out_indices.extend([i] * len(imss)) + + return ( + np.array(out_mzs), + np.array(out_intens), + np.array(out_imss), + np.array(out_indices), + ) + + +def simulate_ims_isotopes( # noqa: D103 + min_mz, + max_mz, + num_peaks, + min_charge, + max_charge, + min_intensity, + max_intensity, + min_ims, + max_ims, +): + raise NotImplementedError + + +def add_noise(values, snr): + """Add noise to a set of peaks. + + The noise is added by adding a random number to each peak. The random number is + drawn from a normal distribution with a standard deviation equal to the intensity + of the peak divided by the SNR. + """ + noise = np.random.normal(0, values / snr) + return values + noise + + +def jitter_values(values, std): + """Jitter a set of values. + + The values are jittered by adding a random number to each value. The random number + is drawn from a normal distribution with a standard deviation equal to the standard + deviation of the values. + """ + noise = np.random.normal(0, std, size=len(values)) + return values + noise + + +def get_noise_peaks(mz_min, mz_max, ints, quantile, pct): + """Generates noise peaks from the distribution of intensities. + + Adds uniformly distributed peaks in the given mz range that do not belong to + an isotope envelope. Using the lowest quantile of the intensity distribution, + a threshold is determined below which no peaks are added. 
+ """ + noise_intensity = np.quantile(ints, quantile) + noise_mzs = np.random.uniform(mz_min, mz_max, size=int(len(ints) * pct)) + noise_ints = np.random.uniform(0, noise_intensity, size=len(noise_mzs)) + + return noise_mzs, noise_ints + + +# clean simple spectrum +def clean_simple_spectrum() -> tuple[np.ndarray, np.ndarray]: + """Generate a clean simple spectrum. + + Returns + ------- + mzs : np.ndarray + The m/z values of the peaks. + intensities : np.ndarray + The intensities of the peaks. + """ + mzs, ints = simulate_isotopes(1000, 2000, 100, 1, 5, 1000, 10000) + ints = add_noise(ints, 100) + return mzs, ints + + +# clean complicated spectrum +def clean_complicated_spectrum(): + """Generate a clean simple spectrum. + + Returns + ------- + mzs : np.ndarray + The m/z values of the peaks. + intensities : np.ndarray + The intensities of the peaks. + """ + mzs, ints = simulate_isotopes(1000, 2000, 100, 1, 5, 1000, 10000) + ints = add_noise(ints, 100) + mzs = jitter_values(mzs, 0.01) + mzs2, ints2 = get_noise_peaks(1000, 2000, ints, 0.1, 0.1) + mzs = np.concatenate([mzs, mzs2]) + ints = np.concatenate([ints, ints2]) + return mzs, ints + + +# noisy simple spectrum +def noisy_simple_spectrum(): + """Generate a noisy simple spectrum. + + Returns + ------- + mzs : np.ndarray + The m/z values of the peaks. + intensities : np.ndarray + The intensities of the peaks. + """ + raise NotImplementedError + + +# noisy complicated spectrum +def noisy_complicated_spectrum(): + """Generate a noisy complicated spectrum. + + Returns + ------- + mzs : np.ndarray + The m/z values of the peaks. + intensities : np.ndarray + The intensities of the peaks. 
+ """ + npeaks = 5_000 + + mzs, ints = simulate_isotopes( + 1000, + 2000, + npeaks, + 1, + 5, + min_intensity=1_000, + max_intensity=100_000, + ) + raise NotImplementedError + + +if __name__ == "__main__": + out = simulate_isotopes(1000, 2000, 3, 1, 2, 1000, 10000) diff --git a/benchmarking_tests/test_scoring_speed.py b/benchmarking_tests/test_scoring_speed.py index 0a6383c..0a09283 100644 --- a/benchmarking_tests/test_scoring_speed.py +++ b/benchmarking_tests/test_scoring_speed.py @@ -12,7 +12,9 @@ def fake_database(fake_5k_fasta, request): Right not it is parametrized so it uses multiple chunk sizes. """ db = db_from_fasta( - fake_5k_fasta, chunksize=request.param, config=DiademConfig(run_parallelism=1) + fake_5k_fasta, + chunksize=request.param, + config=DiademConfig(run_parallelism=1), ) return db @@ -29,7 +31,7 @@ def fake_prefiltered_database(fake_database: IndexedDb, request): def score_all_specs_open(db: IndexedDb, specs): """Helper function that scores all the spectra passed as tuples with a database.""" - for mzs, ints, prec_mz in tqdm(specs): + for mzs, ints, _prec_mz in tqdm(specs): db.hyperscore(precursor_mz=(700.0, 720.0), spec_int=ints, spec_mz=mzs) @@ -42,7 +44,9 @@ def test_db_scoring_speed_unfiltered(fake_database, fake_spectra_tuples_100, ben def test_db_scoring_speed_filtered( - fake_prefiltered_database, fake_spectra_tuples_100, benchmark + fake_prefiltered_database, + fake_spectra_tuples_100, + benchmark, ): """Benchmarks how long it takes to search 100 spectra. 
@@ -55,7 +59,9 @@ def score_all_specs_closed(db: IndexedDb, specs): """Runs a closed search on all spectra passed.""" for mzs, ints, prec_mz in tqdm(specs): db.hyperscore( - precursor_mz=(prec_mz - 0.01, prec_mz + 0.01), spec_int=ints, spec_mz=mzs + precursor_mz=(prec_mz - 0.01, prec_mz + 0.01), + spec_int=ints, + spec_mz=mzs, ) diff --git a/diadem/aggregate/__init__.py b/diadem/aggregate/__init__.py new file mode 100644 index 0000000..43b805b --- /dev/null +++ b/diadem/aggregate/__init__.py @@ -0,0 +1 @@ +"""The aggregate module.""" diff --git a/diadem/aggregate/imputers.py b/diadem/aggregate/imputers.py new file mode 100644 index 0000000..562c66f --- /dev/null +++ b/diadem/aggregate/imputers.py @@ -0,0 +1,335 @@ +"""The retention time matrix factorization model.""" +from __future__ import annotations + +import logging +from collections.abc import Iterable + +import numpy as np +import pandas as pd +import torch +from sklearn.exceptions import NotFittedError +from torch import nn +from tqdm import trange + +LOGGER = logging.getLogger(__name__) + + +class MatrixFactorizationModel(nn.Module): + """The PyTorch matrix factorization model. + + Parameters + ---------- + n_peptides : int + The number of peptides. + n_runs : int + The number of runs. + n_factors: int + The number of latent factors. + rng : int | numpy.random.Generator | None + The random number generator. + """ + + def __init__( + self, + n_peptides: int, + n_runs: int, + n_factors: int, + rng: int | np.random.Generator | None = None, + ) -> None: + """Initialize an ImputerModel.""" + super().__init__() + self.n_peptides = n_peptides + self.n_runs = n_runs + self.n_factors = n_factors + self.rng = np.random.default_rng(rng) + + torch.manual_seed(self.rng.integers(1, 100000)) + + # The model: + self.peptide_factors = nn.Parameter(torch.randn(n_peptides, n_factors)) + self.run_factors = nn.Parameter(torch.randn(n_factors, n_runs)) + + def forward(self) -> torch.Tensor: + """Reconstruct the matrix. 
+ + Returns + ------- + torch.Tensor of shape (n_peptide, n_runs) + The reconstructed matrix. + """ + return torch.mm(self.peptide_factors, self.run_factors) + + +class MFImputer: + """A matrix factorization imputation model. + + The MFImputer is a PyTorch model wrapped in a sklearn API. + + Parameters + ---------- + n_factors : int | None, optional + The number of latent factors. + max_iter : int, optional + The maximum number of training iterations + tol : float, optional + The percent improvement over the previous loss required to + continue trianing. Used in conjuction with ``n_iter_no_change`` + to trigger early stopping. Set to ``None`` to disable. + n_iter_no_change : int, optional + The number of iterations to wait before triggering early + stopping. + lr: float, optional + The learning rate. + device : str or torch.Device, optional + A valid PyTorch device on which to perform the optimization. + rng : int | np.random.Generator | None, optional + The random number generator. + task : str | None, optional + A sting specifying the task. This is only used for logging. + silent : bool, optional + Suppress logging messages. 
+ """ + + def __init__( + self, + n_factors: int = None, + max_iter: int = 10000, + tol: float = 1e-4, + n_iter_no_change: int = 20, + lr: float = 0.1, + device: str | torch.device = "cpu", + rng: int | np.random.Generator | None = None, + task: str | None = None, + silent: bool = False, + ) -> None: + """Initialize the RTImputer.""" + # Parameters: + self.n_factors = n_factors + self.max_iter = max_iter + self.tol = tol + self.n_iter_no_change = n_iter_no_change + self.device = device + self.lr = lr + self.rng = np.random.default_rng(rng) + self.task = task + self.silent = silent + + # Set during fit: + self._model = None + self._history = None + self._shape = None + self._std = None + self._means = None + + @property + def model_(self) -> MatrixFactorizationModel: + """The underlying PyTorch model.""" + if self._model is None: + raise NotFittedError("This model has not been fit yet.") + + return self._model + + @property + def history_(self) -> pd.DataFrame: + """The training history.""" + return pd.DataFrame( + self._history, + columns=["iteration", "train_loss"], + ) + + def _info(self, msg: str, *args: str | None) -> None: + """Log at the info level. + + Parameters + ---------- + msg : str + The message to be logged. + *args : str, optional + Values to be formatted into the message. + """ + if not self.silent: + LOGGER.info(msg, *args) + + def transform( + self, + X: np.array | torch.Tensor | None = None, + ) -> np.array: # noqa: N803 + """Impute missing retention times. + + Parameters + ---------- + X : array of shape (n_peptides, n_runs), optional + The value matrix. Missing peptides should be denoted as np.nan. + If ``none``, the full reconstruction will be returned. + + Returns + ------- + np.array of shape (n_peptides, n_runs) + The matrix with missing values imputed. 
+ """ + if X is None: + return self.model_().to("cpu").detach().numpy() + + # Prepare the input and initialize model + X = to_tensor(X) + mask = torch.isnan(X) + X_hat = self.model_().to("cpu").detach().type_as(X) + X[mask] = X_hat[mask] + return X.numpy() + + def fit(self, X: np.ndarray | torch.Tensor) -> MFImputer: + """Fit the model. + + Parameters + ---------- + X : array of shape (n_peptides, n_runs) + The value matrix. Missing peptides should be denoted as np.nan. + + Returns + ------- + self + """ + if self.n_factors is None: + raise ValueError( + ( + "n_factors must be specified using search_factors() " + "to find the best value." + ), + ) + + if self.task: + self._info("Training %s model...", self.task) + + # Prepare the input and initialize model + X = to_tensor(X) + mask = ~torch.isnan(X) + X = X.to(self.device, torch.float32) + + self._shape = X.shape + self._model = MatrixFactorizationModel(*X.shape, self.n_factors, self.rng) + self._model.to(self.device) + self._history = [] + optimizer = torch.optim.Adam(self._model.parameters(), lr=self.lr) + log_interval = max(1, self.max_iter // 20) + self._info("Logging every %i iterations...", log_interval) + + # The main training loop: + best_loss = np.inf + early_stopping_counter = 0 + self._info("+-----------------------+") + self._info("| Iteration | Train MSE |") + self._info("+-----------------------+") + bar = trange(self.max_iter, disable=self.silent) + for iteration in bar: + optimizer.zero_grad() + X_hat = self._model() + loss = ((X - X_hat)[mask] ** 2).mean() + loss.backward() + optimizer.step() + self._history.append((iteration, loss.item())) + + if not iteration % log_interval: + self._info("| %9i | %9.4f |", iteration, loss.item()) + + bar.set_postfix(loss=f"{loss:,.3f}") + if self.tol is not None: + if loss < best_loss: + best_loss = loss.item() + early_stopping_counter = 0 + continue + early_stopping_counter += 1 + if early_stopping_counter >= self.n_iter_no_change: + break + + 
self._info("+-----------------------+") + self._model.to("cpu") + self._info("DONE!") + return self + + def fit_transform(self, X: np.array | torch.Tensor) -> np.ndarray: + """Fit and impute missing retention times. + + Parameters + ---------- + X : array of shape (n_peptides, n_runs) + The value matrix. Missing values should be denoted as np.nan. + + Returns + ------- + np.array of shape (n_peptides, n_runs) + The predicted retention time matrix. + """ + return self.fit(X).transform(X) + + def search_factors( + self, + X: np.array | torch.Tensor, + n_factors: Iterable[int], + folds: int = 3, + ) -> MFImputer: + """Perform line search for the number of latent factors. + + Parameters + ---------- + X : array of shape (n_peptides, n_runs) + The value matrix. Missing values should be denoted as np.nan. + n_factors : Iterable[int] + The numbers of latent factors to try. + folds: int, optional + The number of cross-validation folds. + + Returns + ------- + self + """ + X = to_tensor(X) + prev_silent = self.silent + + # Create CV splits + indices = np.dstack(np.meshgrid(range(X.shape[0]), range(X.shape[1]))) + indices = indices.reshape(-1, 2) + + self.rng.shuffle(indices, axis=0) + splits = torch.split(torch.tensor(indices), indices.shape[0] // folds) + + # Do the line search + self._info("Searching for the best number of latent factors...") + self.silent = True + scores = [] + # This could be parallelized, but I don't think its worth it yet. 
+ for num in n_factors: + self.n_factors = num + split_scores = [] + for split in splits: + split = tuple(split.T) + train = X.clone() + train[split] = np.nan + pred = to_tensor(self.fit_transform(train)) + split_scores.append(((X[split] - pred[split]) ** 2).mean()) + + scores.append(sum(split_scores).item()) + + self.n_factors = n_factors[np.argmin(scores)] + self.silent = prev_silent + self._model = None + self._info(" -> Chose %s factors", self.n_factors) + return self + + +def to_tensor(array: np.ndarray | torch.Tensor) -> torch.Tensor: + """Transform an array into a PyTorch Tensor, copying the data. + + Parameters + ---------- + array : numpy.ndarray or torch.Tensor + The array to transform + + Returns + ------- + torch.Tensor + The converted PyTorch tensor. + """ + if isinstance(array, torch.Tensor): + return array.to("cpu").clone().detach() + + return torch.tensor(array) diff --git a/diadem/aggregate/quants.py b/diadem/aggregate/quants.py new file mode 100644 index 0000000..5ebdcff --- /dev/null +++ b/diadem/aggregate/quants.py @@ -0,0 +1,6 @@ +"""Aggregate quantification results.""" + + +def quants() -> None: + """Aggregate quantification results.""" + raise NotImplementedError("This hasn't been written yet...") diff --git a/diadem/aggregate/rt_model.py b/diadem/aggregate/rt_model.py new file mode 100644 index 0000000..c11a9ae --- /dev/null +++ b/diadem/aggregate/rt_model.py @@ -0,0 +1,244 @@ +"""The retention time matrix factorization model.""" +from __future__ import annotations + +import logging + +import numpy as np +import pandas as pd +import torch +import torch.nn as nn +from sklearn.base import BaseEstimator +from sklearn.exceptions import NotFittedError +from tqdm import trange + +LOGGER = logging.getLogger(__name__) + + +class MatrixFactorizationModel(nn.Module): + """The PyTorch matrix factorization model. + + Parameters + ---------- + n_peptides : int + The number of peptides. + n_runs : int + The number of runs. 
+ n_factors: int + The number of latent factors. + rng : int | numpy.random.Generator | None + The random number generator. + """ + + def __init__( + self, + n_peptides: int, + n_runs: int, + n_factors: int, + rng: int | np.random.Generator | None = None, + ) -> None: + """Initialize an ImputerModel.""" + super().__init__() + self.n_peptides = n_peptides + self.n_runs = n_runs + self.n_factors = n_factors + self.rng = np.random.default_rng(rng) + + torch.manual_seed(self.rng.integers(1, 100000)) + + # The model: + self.peptide_factors = nn.Parameter(torch.randn(n_peptides, n_factors)) + self.run_factors = nn.Parameter(torch.randn(n_factors, n_runs)) + + def forward(self) -> torch.Tensor: + """Reconstruct the matrix. + + Returns + ------- + torch.Tensor of shape (n_peptide, n_runs) + The reconstructed matrix. + """ + return torch.mm(self.peptide_factors, self.run_factors) + + +class RTImputer(BaseEstimator): + """A retention time prediction model. + + RTImputer is a PyTorch model wrapped in a sklearn API. + + Parameters + ---------- + n_factors : int + The number of latent factors. + max_iter : int, optional + The maximum number of training iterations + tol : float, optional + The percent improvement over the previous loss required to + continue trianing. Used in conjuction with ``n_iter_no_change`` + to trigger early stopping. Set to ``None`` to disable. + n_iter_no_change : int, optional + The number of iterations to wait before triggering early + stopping. + lr: float, optional + The learning rate. + device : str or torch.Device, optional + A valid PyTorch device on which to perform the optimization. + silent : bool, optional + Disable logging. + rng : int | np.random.Generator | None, optional + The random number generator. 
+ """ + + def __init__( + self, + n_factors: int, + max_iter: int = 1000, + tol: float = 1e-4, + n_iter_no_change: int = 20, + lr: float = 0.1, + device: str | torch.device = "cpu", + silent: bool = False, + rng: int | np.random.Generator | None = None, + ) -> None: + """Initialize the RTImputer.""" + # Parameters: + self.n_factors = n_factors + self.max_iter = max_iter + self.tol = tol + self.n_iter_no_change = n_iter_no_change + self.device = device + self.lr = lr + self.silent = silent + + # Set during fit: + self._model = None + self._history = None + self._shape = None + self._std = None + self._means = None + + @property + def model_(self) -> MatrixFactorizationModel: + """The underlying PyTorch model.""" + if self._model is None: + raise NotFittedError("This model has not been fit yet.") + + return self._model + + @property + def history_(self) -> pd.DataFrame: + """The training history.""" + return pd.DataFrame( + self._history, + columns=["iteration", "train_loss"], + ) + + def transform(self, X: np.array | torch.Tensor) -> np.array: # noqa: N803 + """Impute missing retention times. + + Parameters + ---------- + X : torch.Tensor of shape (n_peptides, n_runs) + The retention time array. Missing peptides should be denoted as np.nan. + + Returns + ------- + np.array of shape (n_peptides, n_runs) + The predicted retention time matrix. + """ + # Prepare the input and initialize model + X = to_tensor(X) + mask = torch.isnan(X) + X_hat = self.model().to("cpu").detach() + X[mask] = X_hat[mask] + return X.numpy() + + def fit(self, X: np.ndarray | torch.Tensor) -> RTImputer: + """Fit the model. + + Parameters + ---------- + X : array of shape (n_peptides, n_runs) + The retention time array. Missing peptides should be denoted as np.nan. 
+ + Returns + ------- + self + """ + LOGGER.info("Training retention time predictor...") + + # Prepare the input and initialize model + X = to_tensor(X) + mask = ~torch.isnan(X) + X = X.to(self.device, torch.float32) + + self._shape = X.shape + self._model = MatrixFactorizationModel(*X.shape, self.n_factors).to(self.device) + self._history = [] + optimizer = torch.optim.RMSprop(self._model.parameters(), lr=self.lr) + log_interval = max(1, self.max_iter // 20) + LOGGER.info("Logging every %i iterations...", log_interval) + + # The main training loop: + best_loss = np.inf + early_stopping_counter = 0 + LOGGER.info("Iteration | Train Loss") + LOGGER.info("----------------------") + bar = trange(self.max_iter) + for iteration in bar: + optimizer.zero_grad() + X_hat = self._model() + loss = ((X - X_hat)[mask] ** 2).mean() + loss.backward() + optimizer.step() + self._history.append((iteration, loss.item())) + + if not iteration % log_interval: + LOGGER.info("%9i | %10.4f", iteration, loss.item()) + + bar.set_postfix(loss=f"{loss:,.3f}") + if self.tol is not None: + if loss < best_loss: + best_loss = loss.item() + early_stopping_counter = 0 + continue + early_stopping_counter += 1 + if early_stopping_counter >= self.n_iter_no_change: + LOGGER.info("Stopping...") + break + + LOGGER.info("DONE!") + return self + + def fit_transform(self, X: np.array | torch.Tensor) -> np.ndarray: + """Fit and impute missing retention times. + + Parameters + ---------- + X : torch.Tensor of shape (n_peptides, n_runs) + The retention time array. Missing peptides should be denoted as np.nan. + + Returns + ------- + np.array of shape (n_peptides, n_runs) + The predicted retention time matrix. + """ + return self.fit(X).transform(X) + + +def to_tensor(array: np.ndarray | torch.Tensor) -> torch.Tensor: + """Transform an array into a PyTorch Tensor, copying the data. 
+ + Parameters + ---------- + array : numpy.ndarray or torch.Tensor + The array to transform + + Returns + ------- + torch.Tensor + The converted PyTorch tensor. + """ + if isinstance(array, torch.Tensor): + return array.to("cpu").clone().detach() + + return torch.tensor(array) diff --git a/diadem/aggregate/search.py b/diadem/aggregate/search.py new file mode 100644 index 0000000..3c58607 --- /dev/null +++ b/diadem/aggregate/search.py @@ -0,0 +1,219 @@ +"""Aggregate diadem search results.""" +from collections.abc import Iterable +from os import PathLike +from pathlib import Path + +import polars as pl +from mokapot import LinearPsmDataset + +from diadem.aggregate.imputers import MFImputer +from diadem.config import DiademConfig + + +def searches( + scores: Iterable[pl.DataFrame | pl.LazyFrame | PathLike], + config: DiademConfig, +) -> tuple[Path]: + """Aggregate search results and align retention times. + + Parameters + ---------- + scores : Iterable[PathLike] + The run-level mokapot parquet files to read. + fasta_file : Pathlike + The FASTA file used for the database search. + config : DiademConfig + The configuration options. + + Returns + ------- + peptides : Path + The accepted peptides + proteins : Path + The accepted proteins + """ + agg = SearchAggregator(scores, config) + return agg.peptide_path, agg.protein_path + + +class SearchAggregator: + """Aggregate search results and align retention times. + + Parameters + ---------- + scores : Iterable[PathLike] + The run-level mokapot parquet files to read. + fasta_file : Pathlike + The FASTA file used for the database search. + config : DiademConfig + The configuration options. 
+ """ + + def __init__( + self, + scores: Iterable[PathLike], + config: DiademConfig, + ) -> None: + """Initialize the search aggregator.""" + self.scores = scores + self.config = config + + base_path = Path(self.config.output_dir) / "aggregate" + self.peptide_path = base_path / "diadem.search.peptides.parquet" + self.protein_path = base_path / "diadem.search.proteins.parquet" + + if not self.config.overwrite: + base_msg = "%s already exists and overwrite is disabled." + if self.peptide_path.exists(): + raise RuntimeError(base_msg.format(str(self.peptide_path))) + if self.protein_path.exists(): + raise RuntimeError(base_msg.format(str(self.protein_path))) + + self._peptides = None + self._proteins = None + self._ret_time = None + self._ion_mobility = None + + if len(self.scores) < 2: + raise ValueError("At least two search results must be provided.") + + # 1. Compute global FDR + self.assign_confidence() + + # 2. Gather RT/IM for accepted peptides: + self.collect_imputer_data() + + # 3. Impute RT/IM for missing peptides in each run: + self.impute() + + # 4. Save the results. 
+ self.save() + + def assign_confidence(self) -> None: + """Assign confidence across all runs.""" + keep_cols = ["peptide", "target_pair", "is_target", "mokapot score"] + try: + score_df = ( + pl.concat(self.scores, how="vertical") + .lazy() + .select(keep_cols) + .collect() + ) + except TypeError: + score_df = pl.concat( + [pl.read_parquet(s, columns=keep_cols) for s in self.scores], + how="vertical", + ) + + peptides = LinearPsmDataset( + psms=(score_df.with_columns(pl.lit("").alias("proteins")).to_pandas()), + target_column="is_target", + spectrum_columns="target_pair", + peptide_column="peptide", + protein_column="protein", + feature_columns="mokapot score", + copy_data=False, + ) + + peptides.add_proteins( + self.config.fasta_file, + enzyme=self.config.db_enzyme, + missed_cleavages=self.config.db_max_missed_cleavages, + min_length=self.config.peptide_length_range[0], + max_length=self.config.peptide_length_range[1], + ) + + # Global FDR: + results = peptides.assign_confidence( + "mokapot score", + eval_fdr=self.config.eval_fdr, + desc=True, + ) + + self._peptides = ( + pl.DataFrame(results.peptides) + .filter(pl.col("mokapot q-value") <= self.config.eval_fdr) + .drop("target_pair") + ) + + self._proteins = pl.DataFrame(results.proteins).filter( + pl.col("mokapot q-value") <= self.config.eval_fdr, + ) + + def collect_imputer_data(self) -> None: + """Filter run results for confident peptides.""" + keep_cols = ["peptide", "filename", "mokapot q-value"] + rt_df = [] + im_df = [] + for run in self.scores: + try: + run_df = pl.read_parquet(run, columns=keep_cols).lazy() + except TypeError: + run_df = run.select(keep_cols).lazy() + + run_df = run_df.filter( + pl.col("mokapot q-value") <= self.config.eval_fdr, + ).drop("mokapot q-value") + + fname = run_df["filename"][0] + + # Join with accepted peptides, maintaining order. 
+ run_df = ( + self._peptides.lazy() + .join(run_df, how="left", on="peptide") + .drop( + [ + "mokapot q-value", + "mokapot PEP", + "mokapot score", + "filename", + "peptide", + ], + ) + .collect() + ) + + rt_df.append( + run_df.select(pl.col("RetentionTime").alias(f"retention_time_{fname}")), + ) + try: + im_df.append( + run_df.select(pl.col("IonMobility").alias(f"ion_mobility_{fname}")), + ) + except pl.exceptions.ColumnNotFoundError: + pass + + rt_df = pl.concat(rt_df, how="diagonal") + self._ret_time = rt_df + if im_df: + im_df = pl.concat(im_df, how="diagonal") + self._ion_mobility = im_df + + def impute(self) -> None: + """Imput missing retention times and ion mobility values.""" + rt_mat = ( + MFImputer(rng=self.config.seed, task="retention time") + .search_factors(self._ret_time.to_numpy(), [2, 4, 8, 16]) + .fit_transform(self._ret_time.to_numpy()) + ) + + self._ret_time = pl.DataFrame(rt_mat, schema=self._ret_time.columns) + + if self._ion_mobility is not None: + im_mat = ( + MFImputer(rng=self.config.seed, task="ion mobility") + .search_factors(self._ion_mobility.to_numpy(), [2, 4, 8, 16]) + .fit_transform(self._ion_mobility.to_numpy()) + ) + + self._ion_mobility = pl.DataFrame( + im_mat, + schema=self._ion_mobility.columns, + ) + + def save(self) -> tuple(Path): + """Save the aggregated results.""" + self.peptide_path.parent.mkdir(exist_ok=True) + pep_dfs = [self._peptides, self._ret_time, self._ion_mobility] + pl.concat(pep_dfs, how="horizontally").write_parquet(self.peptide_path) + self._proteins.write_parquet(self.protein_path) diff --git a/diadem/cli.py b/diadem/cli.py index 0c05bcb..e1128d5 100644 --- a/diadem/cli.py +++ b/diadem/cli.py @@ -38,16 +38,18 @@ def main_cli() -> None: @main_cli.command(help="Runs the search module of DIAdem") @click.option( - "--mzml_file", - help="mzML file to use as an input for the search", + "--data_path", + help="mzML or .d file to use as an input for the search", ) @click.option("--fasta", help="fasta file to use 
as an input") @click.option("--out_prefix", help="Prefix to add to all output files") @click.option( - "--mode", type=click.Choice(["dda", "dia"], case_sensitive=False), default="dia" + "--mode", + type=click.Choice(["dda", "dia"], case_sensitive=False), + default="dia", ) @click.option("--config", help="Path to the config toml configuration file to use.") -def search(mzml_file, fasta, out_prefix, mode, config) -> None: +def search(data_path, fasta, out_prefix, mode, config) -> None: setup_logger() if config: config = DiademConfig.from_toml(config) @@ -56,11 +58,17 @@ def search(mzml_file, fasta, out_prefix, mode, config) -> None: config = DiademConfig() if mode == "dia": diadem_main( - fasta_path=fasta, mzml_path=mzml_file, config=config, out_prefix=out_prefix + fasta_path=fasta, + data_path=data_path, + config=config, + out_prefix=out_prefix, ) elif mode == "dda": dda_main( - mzml_path=mzml_file, fasta_path=fasta, config=config, out_prefix=out_prefix + mzml_path=data_path, + fasta_path=fasta, + config=config, + out_prefix=out_prefix, ) else: raise NotImplementedError diff --git a/diadem/config.py b/diadem/config.py index c146c5c..adbb59b 100644 --- a/diadem/config.py +++ b/diadem/config.py @@ -3,7 +3,7 @@ import hashlib import sys from argparse import Namespace -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, fields from typing import Literal import tomli_w @@ -24,9 +24,13 @@ @dataclass(frozen=True, eq=True) -class DiademConfig: # noqa - g_tolerances: tuple[float, ...] = field(default=(20, 10)) - g_tolerance_units: tuple[MassError, ...] = field(default=("ppm", "ppm")) +class DiademIndexConfig: + """Configuration to generate an index. + + Base class for the diadem index, this is mean to contain the configuration + options relevant to generate the index, and should not have parameters that + are used at runtime (index gen vs index use). 
+ """ peptide_length_range: tuple[int, int] = field( default=(7, 25), @@ -43,63 +47,19 @@ class DiademConfig: # noqa default="by", ) ion_charges: tuple[int, ...] = field( - default=(1,), + default=(1, 2), ) ion_mz_range: tuple[float, float] = field( default=(250, 2000.0), ) - db_enzyme: str = field(default="trypsin") + db_enzyme: str = field(default="[KR]") # This needs to be a regex for mokapot. db_max_missed_cleavages: int = 2 db_bucket_size: int = 2**15 # Variable mods # Static mods - # Main score - # Currently unused ... - scoring_score_function: ScoringFunctions = "Hyperscore" - - run_max_peaks: int = 1e6 - - # the 5k number comes from the neviskii lab paper on deisotoping - run_max_peaks_per_spec: int = 5_000 - - # Prallelism 1 means no parallelism, -1 means all cores, any other positive - # integer means use that many cores. - run_parallelism: int = -2 - run_deconvolute_spectra: bool = True - run_min_peak_intensity: float = 100 - run_debug_log_frequency: int = 20 - run_allowed_fails: int = 500 - run_window_size: int = 21 - run_max_peaks_per_window: int = 150 - - # Min intensity to consider for matching and extracting - run_min_intensity_ratio: float = 0.011 - run_min_correlation_score: float = 0.25 - - run_scaling_ratio = 0.01 - run_scalin_limits: tuple[float, float] = (0.01, 0.999) - - @property - def ms2ml_config(self) -> Config: - """Returns the ms2ml config. - - It exports all the parameters that are used inside of - ms2ml to its own configuration object. 
- """ - conf = Config( - g_tolerances=self.g_tolerances, - g_tolerance_units=self.g_tolerance_units, - peptide_length_range=self.peptide_length_range, - precursor_charges=self.precursor_charges, - ion_series=self.ion_series, - ion_charges=self.ion_charges, - peptide_mz_range=self.peptide_mz_range, - ) - return conf - def log(self, logger: Logger, level: str = "INFO") -> None: """Logs all the configurations using the passed logger.""" logger.log(level, "Diadem Configuration:") @@ -152,17 +112,107 @@ def from_args(cls, args: Namespace) -> DiademConfig: def hash(self) -> str: """Hashes the config in a reproducible manner. - Notes + Notes: ----- Python adds a seed to the hash, therefore the has will be different - Example - ------- - >>> DiademConfig().hash() - '19922fcb81d81062169e5a677517e00b' - >>> DiademConfig(ion_series = "y").hash() - 'a365384390de3ba5f448096a73155005' + Examples + -------- + >>> DiademIndexConfig().hash() + '1a23e68d04576bb73dbd5e0173679e64' + >>> DiademIndexConfig(ion_series = "y").hash() + '846dbaf6adb3e2ddc5779fc5169ec675' """ h = hashlib.md5() h.update(tomli_w.dumps(self.toml_dict()).encode()) return h.hexdigest() + + @property + def ms2ml_config(self) -> Config: + """Returns the ms2ml config. + + It exports all the parameters that are used inside of + ms2ml to its own configuration object. + """ + conf = Config( + g_tolerances=[], + g_tolerance_units=[], + peptide_length_range=self.peptide_length_range, + precursor_charges=self.precursor_charges, + ion_series=self.ion_series, + ion_charges=self.ion_charges, + peptide_mz_range=self.peptide_mz_range, + ) + return conf + + +@dataclass(frozen=True, eq=True) +class DiademConfig(DiademIndexConfig): # noqa + # TODO split tolerances in 'within spectrum' and 'between spectra' + # since tolerances for deisotoping should be a lot lower than they should be + # for database matching ... 5ppm for a database match is ok, 1 ppm for + # an isotope envelope is barely acceptable. 
+ g_tolerances: tuple[float, ...] = field(default=(20, 20)) + g_tolerance_units: tuple[MassError, ...] = field(default=("ppm", "ppm")) + + g_ims_tolerance: float = 0.03 + g_ims_tolerance_unit: Literal["abs"] = "abs" + # Main score + # Currently unused ... + scoring_score_function: ScoringFunctions = "Hyperscore" + + run_max_peaks: int = 1e6 + + # the 5k number comes from the neviskii lab paper on deisotoping + run_max_peaks_per_spec: int = 5_000 + + # Prallelism 1 means no parallelism, -1 means all cores, any other positive + # integer means use that many cores. + run_parallelism: int = -4 + run_deconvolute_spectra: bool = True + run_min_peak_intensity: float = 100 + run_debug_log_frequency: int = 50 + run_allowed_fails: int = 700 + run_window_size: int = 21 + run_max_peaks_per_window: int = 150 + + # Min intensity to consider for matching and extracting + run_min_intensity_ratio: float = 0.01 + run_min_correlation_score: float = 0.2 + + run_scaling_ratio: float = 0.001 + run_scaling_limits: tuple[float, float] = (0.001, 0.999) + + # Mokapot parameters + train_fdr: float = 0.01 + eval_fdr: float = 0.01 + + @property + def ms2ml_config(self) -> Config: + """Returns the ms2ml config. + + It exports all the parameters that are used inside of + ms2ml to its own configuration object. + """ + conf = Config( + g_tolerances=self.g_tolerances, + g_tolerance_units=self.g_tolerance_units, + peptide_length_range=self.peptide_length_range, + precursor_charges=self.precursor_charges, + ion_series=self.ion_series, + ion_charges=self.ion_charges, + peptide_mz_range=self.peptide_mz_range, + ) + return conf + + @property + def index_config(self) -> DiademIndexConfig: + """Generates an index config. + + The index config is a subset of the DiademConfig. + Therefore generating this subset allow us to hash + it in a way that would identify the index generation. 
+ """ + self_dict = asdict(self) + kwargs = {x.name: self_dict[x.name] for x in fields(DiademIndexConfig)} + return DiademIndexConfig(**kwargs) diff --git a/diadem/data_io/__init__.py b/diadem/data_io/__init__.py new file mode 100644 index 0000000..fca2b25 --- /dev/null +++ b/diadem/data_io/__init__.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +from os import PathLike + +from diadem.config import DiademConfig +from diadem.data_io.mzml import SpectrumStacker +from diadem.data_io.timstof import TimsSpectrumStacker + + +def read_raw_data( + filepath: PathLike, + config: DiademConfig, +) -> TimsSpectrumStacker | SpectrumStacker: + """Generic function to read data for DIA. + + It uses the file extension to know whether to dispatch the data to an + mzML or timsTOF reader. + """ + if str(filepath).endswith(".d") or str(filepath).endswith("hdf"): + rf = TimsSpectrumStacker(filepath=filepath, config=config) + elif str(filepath).lower().endswith(".mzml"): + rf = SpectrumStacker(filepath, config=config) + else: + raise NotImplementedError + return rf diff --git a/diadem/data_io/mzml.py b/diadem/data_io/mzml.py new file mode 100644 index 0000000..58f809f --- /dev/null +++ b/diadem/data_io/mzml.py @@ -0,0 +1,773 @@ +from __future__ import annotations + +import os +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Literal + +import numpy as np +import polars as pl +from joblib import Parallel, delayed +from loguru import logger +from msflattener.mzml import get_mzml_data +from numpy.typing import NDArray + +from diadem.config import DiademConfig, MassError +from diadem.data_io.utils import slice_from_center, strictzip, xic +from diadem.search.metrics import get_ref_trace_corrs +from diadem.utilities.utils import is_sorted, plot_to_log + +# TODO re-make some of these classes as ABCs +# It would make intent explicit, since they are sub-classed +# for timstof data 
@dataclass
class ScanGroup:
    """Represents all spectra that share an isolation window."""

    # Fragment-level (MS2) data, one entry per spectrum in the window:
    precursor_range: tuple[float, float]
    mzs: list[NDArray]
    intensities: list[NDArray]
    base_peak_mz: NDArray[np.float32]
    base_peak_int: NDArray[np.float32]
    retention_times: NDArray
    scan_ids: list[str]
    iso_window_name: str

    # Precursor-level (MS1) data matched to this isolation window:
    precursor_mzs: list[NDArray]
    precursor_intensities: list[NDArray]
    precursor_retention_times: NDArray

    def __post_init__(self) -> None:
        """Check that all the arrays have the same length."""
        # All per-spectrum containers must stay index-aligned.
        elems = [
            self.mzs,
            self.intensities,
            self.base_peak_int,
            self.base_peak_mz,
            self.retention_times,
            self.scan_ids,
        ]

        # TODO move this to assertions so they can be skipped
        # during runtime
        lengths = {len(x) for x in elems}
        if len(lengths) != 1:
            raise ValueError("Not all lengths are the same")
        if len(self.precursor_range) != 2:
            raise ValueError(
                (
                    "Precursor mass range should have 2 elements,"
                    f" has {len(self.precursor_range)}"
                ),
            )
        plot_to_log(
            self.base_peak_int,
            title=f"Base peak chromatogram for the Group in {self.iso_window_name}",
        )
        for x in self.mzs:
            if not is_sorted(x):
                # BUG FIX: original message read "is ScanGroup".
                raise ValueError("m/z arrays are not sorted in ScanGroup")

    @property
    def cache_file_stem(self) -> str:
        """Filesystem-safe stem derived from the isolation-window name.

        Every non-alphanumeric character is replaced with an underscore so
        the stem can be used in cache file names.
        """
        return "".join(x if x.isalnum() else "_" for x in self.iso_window_name)

    def to_cache(self, dir: Path) -> None:
        """Saves the group to a cache file.

        Writes two parquet files into ``dir`` (created if needed):
        ``<stem>_precursors.parquet`` and ``<stem>_fragments.parquet``.
        Note: the parameter name ``dir`` shadows the builtin; kept for
        backward compatibility with keyword callers.
        """
        Path(dir).mkdir(parents=True, exist_ok=True)
        fragment_df = self.as_dataframe()
        precursor_df = self.precursor_dataframe()

        stem = self.cache_file_stem

        precursor_df.write_parquet(dir / f"{stem}_precursors.parquet")
        fragment_df.write_parquet(dir / f"{stem}_fragments.parquet")
[np.argmax(x) for x in fragment_data["intensities"]] + base_peak_mz = np.array( + [x[i] for x, i in zip(fragment_data["mzs"], ind_max_int)], + ) + base_peak_int = np.array( + [x[i] for x, i in zip(fragment_data["intensities"], ind_max_int)], + ) + + out = { + "precursor_range": precursor_range, + "mzs": fragment_data["mzs"], + "intensities": fragment_data["intensities"], + "base_peak_mz": base_peak_mz, + "base_peak_int": base_peak_int, + "retention_times": fragment_data["retention_times"], + "scan_ids": fragment_data["scan_ids"], + } + + return out, fragment_data + + def _precursor_elems_from_cache(self, file): + precursor_data = pl.read_parquet(file).to_dict() + out = { + "precursor_mzs": precursor_data["precursor_mzs"], + "precursor_intensities": precursor_data["precursor_intensities"], + "precursor_retention_times": precursor_data["precursor_retention_times"], + } + return out, precursor_data + + @classmethod + def from_cache(cls, dir: Path, name: str) -> ScanGroup: + """Loads a group from a cache file.""" + raise ValueError("Why am I here??") + fragment_elems, _fragment_data = cls._elems_from_fragment_cache( + dir / f"{name}_fragments.parquet", + ) + precursor_elems, _fragment_data = cls._precursor_elems_from_cache( + dir / f"{name}_precursors.parquet", + ) + return cls( + iso_window_name=name, + **precursor_elems, + **fragment_elems, + ) + + def as_dataframe(self) -> pl.DataFrame: + """Returns a dataframe with the data in the group. 
+ + The dataframe has the following columns: + - mzs: list of mzs for each spectrum + - intensities: list of intensities for each spectrum + - retention_times: retention times for each spectrum + - precursor_start: start of the precursor range + - precursor_end: end of the precursor range + """ + out = pl.DataFrame( + { + "mzs": self.mzs, + "intensities": self.intensities, + "retention_times": self.retention_times, + "scan_ids": self.scan_ids, + }, + ) + out["precursor_start"] = min(self.precursor_range) + out["precursor_end"] = max(self.precursor_range) + return out + + def precursor_dataframe(self) -> pl.DataFrame: + """Returns a dataframe with the metadata for the group. + + The dataframe has the following columns: + - precursor_mzs: list of precursor mzs for each spectrum + - precursor_intensities: list of precursor intensities for each spectrum + - precursor_retention_times: precursor retention times for each spectrum + - precursor_start: start of the precursor range + - precursor_end: end of the precursor range + """ + out = pl.DataFrame( + { + "precursor_mzs": self.precursor_mzs, + "precursor_intensities": self.precursor_intensities, + "precursor_retention_times": self.precursor_retention_times, + }, + ) + return out + + def get_highest_window( + self, + window: int, + min_intensity_ratio: float, + mz_tolerance: float, + mz_tolerance_unit: str, + min_correlation: float, + max_peaks: int, + ) -> StackedChromatograms: + """Gets the highest intensity window of the chromatogram. + + Briefly ... + 1. Gets the highes peak accross all spectra in the chromatogram range. + 2. Finds what peaks are in that same spectrum. + 3. Looks for spectra around that spectrum. + 4. 
extracts the chromatogram for all mzs in the "parent spectrum" + + """ + top_index = np.argmax(self.base_peak_int) + window = StackedChromatograms.from_group( + group=self, + index=top_index, + window=window, + min_intensity_ratio=min_intensity_ratio, + min_correlation=min_correlation, + mz_tolerance=mz_tolerance, + mz_tolerance_unit=mz_tolerance_unit, + max_peaks=max_peaks, + ) + + return window + + # TODO make this just take the stacked chromatogram object + def scale_window_intensities( + self, + index: int, + scaling: NDArray, + window_indices: list[list[int]], + ) -> None: + """Scales the intensities of specific mzs in a window of the chromatogram. + + Parameters + ---------- + index : int + The index of the center spectrum for the the window to scale. + scaling : NDArray + The scaling factors to apply to the intensities. Size should be + the same as the length of the window. + mzs : NDArray + The m/z values of the peaks to scale. + window_indices : list[list[int]] + The indices of the peaks in the window to scale. These are tracked + internally during the workflow. + window_mzs : NDArray + The m/z values of the peaks in the window to scale. These are tracked + internally during the workflow. 
+ """ + window = len(scaling) + slc, center_index = slice_from_center( + center=index, + window=window, + length=len(self), + ) + slc = range(*slc.indices(len(self))) + + zipped = strictzip(slc, scaling, window_indices) + for i, s, si in zipped: + inds = [np.array(x) for x in si if len(x)] + if inds: + inds = np.unique(np.concatenate(inds)) + self._scale_spectrum_at( + spectrum_index=i, + value_indices=inds, + scaling_factor=s, + ) + + def _scale_spectrum_at( + self, + spectrum_index: int, + value_indices: NDArray[np.int64], + scaling_factor: float, + ) -> None: + i = spectrum_index # Alias for brevity in within this function + + if len(value_indices) > 0: + self.intensities[i][value_indices] = ( + self.intensities[i][value_indices] * scaling_factor + ) + else: + return None + + # TODO this is hard-coded right now, change as a param + int_remove = self.intensities[i] < 10 + if np.any(int_remove): + self.intensities[i] = self.intensities[i][np.invert(int_remove)] + self.mzs[i] = self.mzs[i][np.invert(int_remove)] + + if len(self.intensities[i]): + self.base_peak_int[i] = np.max(self.intensities[i]) + self.base_peak_mz[i] = self.mzs[i][np.argmax(self.intensities[i])] + else: + self.base_peak_int[i] = -1 + self.base_peak_mz[i] = -1 + + def get_precursor_evidence( + self, + rt: float, + mzs: NDArray[np.float32], + mz_tolerance: float, + mz_tolerance_unit: Literal["ppm", "Da"] = "ppm", + ) -> tuple[NDArray[np.float32], list[NDArray[np.float32]]]: + """Finds precursor information for a given RT and m/z. + + NOTE: This is a first implementation of the functionality, + therefore it is very simple and prone to optimization and + rework. + + 1. Find the closest RT. + 2. Find if there are peaks that match the mzs. + 3. Return a list of dm and a list of intensities for each. + + Parameters + ---------- + rt : float + The retention time to find precursor information for. + mzs : NDArray[np.float32] + The m/z values to find precursor information for. 
+ mz_tolerance : float + The m/z tolerance to use when finding precursor information. + mz_tolerance_unit : str, optional + The unit of the m/z tolerance, by default "ppm" + + Returns + ------- + tuple[NDArray[np.float32], list[NDArray[np.float32]]]] + A array with the list of intensities and a list arrays, + each of which is the for the values integrated for the + intensity values. + """ + index = np.searchsorted(self.precursor_retention_times, rt) + slc, center_index = slice_from_center( + index, + window=11, + length=len(self.precursor_mzs), + ) + q_mzs = self.precursor_mzs[index] + + q_intensities = self.precursor_intensities[index] + # TODO change preprocessing of the MS1 level to make it more + # permissive, cleaner spectra is not critical here. + + out_ints = [] + out_dms = [] + + if len(q_intensities) > 0: + for q_mzs, q_intensities in zip( + self.precursor_mzs[slc], + self.precursor_intensities[slc], + ): + intensities, indices = xic( + query_mz=q_mzs, + query_int=q_intensities, + mzs=mzs, + tolerance=mz_tolerance, + tolerance_unit=mz_tolerance_unit, + ) + dms = [q_mzs[inds] - match_mz for inds, match_mz in zip(indices, mzs)] + out_ints.append(intensities) + out_dms.append(dms) + + intensities = np.stack(out_ints, axis=0).sum(axis=0) + dms = out_dms[center_index] + else: + intensities = np.zeros_like(mzs) + dms = [[] for _ in range(len(mzs))] + + return intensities, dms + + def __len__(self) -> int: + """Returns the number of spectra in the group.""" + return len(self.intensities) + + +@dataclass +class StackedChromatograms: + """A class containing the elements of a stacked chromatogram. + + The stacked chromatogram is the extracted ion chromatogram + across a window of spectra. 
+ + Parameters + ---------- + array : + An array of shape [i, w] + mzs : + An array of shape [i] + ref_index : + An integer in the range [0, i] + parent_index : + Identifier of the range where the window was extracted + base_peak_intensity : + Intensity of the base peak in the reference spectrum + stack_peak_indices : + List of indices used to stack the array, it is a list of dimensions [w], + where each element can be either a list of indices or an empty list. + + Details + ------- + The dimensions of the arrays are `w` the window + size of the extracted ion chromatogram. `i` the number + of m/z peaks that were extracted. + + """ + + array: NDArray[np.float32] + mzs: NDArray[np.float32] + ref_index: int + parent_index: int + base_peak_intensity: float + stack_peak_indices: list[list[int]] | list[NDArray[np.int32]] + center_intensities: NDArray[np.float32] + correlations: NDArray + + def __post_init__(self) -> None: + """Checks that the dimensions of the arrays are correct. + + Since they are assertions, they are not meant to be needed for the + correct working of the + """ + array_i = self.array.shape[-2] + array_w = self.array.shape[-1] + + mz_i = self.mzs.shape[-1] + + assert ( + self.ref_index <= mz_i + ), f"Reference index outside of mz values {self.ref_index} > {mz_i}" + assert ( + array_i == mz_i + ), f"Intensity Array and mzs have different lengths {array_i} != {mz_i}" + for i, x in enumerate(self.stack_peak_indices): + assert len(x) == mz_i, ( + f"Number of mzs and number of indices {len(x)} != {mz_i} is different" + f" for {i}" + ) + assert array_w == len( + self.stack_peak_indices, + ), "Window size is not respected in the stack" + + @property + def ref_trace(self) -> NDArray[np.float32]: + """Returns the reference trace. + + The reference trace is the extracted ion chromatogram of the + mz that corresponds to the highest intensity peak. + """ + return self.array[self.ref_index, ...] 
+ + @property + def ref_mz(self) -> float: + """Returns the m/z value of the reference trace.""" + return self.mzs[self.ref_index] + + @property + def ref_fwhm(self) -> int: + """Returns the number of points in the reference trace above half max. + + Not really fwhm, just number of elements above half max. + """ + rt = self.ref_trace + rt = rt - rt.min() + above_hm = rt >= (rt.max() / 2) + return above_hm.astype(int).sum() + + def plot(self, plt, matches=None) -> None: # noqa + """Plots the stacked chromatogram as lines.""" + # TODO reconsider this implementation, maybe lazy import + # of matplotlib. + plt.plot(self.array.T, color="gray", alpha=0.5, linewidth=0.5) + plt.plot(self.array[self.ref_index, ...].T, color="black", linewidth=2) + if matches is not None: + plt.plot(self.array[matches, ...].T, color="magenta") + + def trace_correlation(self) -> NDArray[np.float32]: + """Calculate the correlation between the reference trace and all other traces. + + Returns + ------- + NDArray[np.float32] + An array of shape [i] where i is the number of traces + in the stacked chromatogram. + """ + return get_ref_trace_corrs(arr=self.array, ref_idx=self.ref_index) + + @staticmethod + # @profile + def from_group( + group: ScanGroup, + index: int, + window: int = 21, + mz_tolerance: float = 0.02, + mz_tolerance_unit: MassError = "da", + min_intensity_ratio: float = 0.01, + min_correlation: float = 0.5, + max_peaks: int = 150, + ) -> StackedChromatograms: + """Create a stacked chromatogram from a scan group. 
+ + Parameters + ---------- + group : ScanGroup + A scan group containing the spectra to stack + index : int + The index of the spectrum to use as the reference + window : int, optional + The number of spectra to stack, by default 21 + mz_tolerance : float, optional + The tolerance to use when matching m/z values, by default 0.02 + mz_tolerance_unit : MassError, optional + The unit of the tolerance, by default "da" + min_intensity_ratio : float, optional + The minimum intensity ratio to use when stacking, by default 0.01 + min_correlation : float, optional + The minimum correlation to use when stacking, by default 0.5 + max_peaks : int, optional + The maximum number of peaks to return in a group, by default is 150. + If the candidates is more than this number, it will the best co-eluting + peaks. + + """ + # The center index is the same as the provided index + # Except in cases where the edge of the group is reached, where + # the center index is adjusted to the edge of the group + slice_q, center_index = slice_from_center( + center=index, + window=window, + length=len(group.mzs), + ) + mzs = group.mzs[slice_q] + intensities = group.intensities[slice_q] + + center_mzs = mzs[center_index] + center_intensities = intensities[center_index] + + int_keep = center_intensities >= ( + center_intensities.max() * min_intensity_ratio + ) + + # num_keep = int_keep.sum() + # logger.debug("Number of peaks to stack: " + # f"{len(center_mzs)}, number above 0.1% intensity {num_keep} " + # f"[{100*num_keep/len(center_mzs):.02f} %]") + center_mzs = center_mzs[int_keep] + center_intensities = center_intensities[int_keep] + + xic_outs = [] + + for i, (m, inten) in enumerate(zip(mzs, intensities)): + xic_outs.append( + xic( + query_mz=m, + query_int=inten, + mzs=center_mzs, + tolerance=mz_tolerance, + tolerance_unit=mz_tolerance_unit, + ), + ) + if i == center_index: + assert xic_outs[-1][0].sum() >= center_intensities.max() + + stacked_arr = np.stack([x[0] for x in xic_outs], axis=-1) + 
+ # TODO make this an array and subset it in line 457 + indices = [x[1] for x in xic_outs] + + if stacked_arr.shape[-2] > 1: + ref_id = np.argmax(stacked_arr[..., center_index]) + corrs = get_ref_trace_corrs(arr=stacked_arr, ref_idx=ref_id) + + # I think adding the 1e-5 is needed here due to numric instability + # in the flaoting point operation + assert np.max(corrs) <= ( + corrs[ref_id] + 1e-5 + ), "Reference does not have max corrr" + + max_peak_corr = np.sort(corrs)[-max_peaks] if len(corrs) > max_peaks else -1 + keep = corrs >= max(min_correlation, max_peak_corr) + keep_corrs = corrs[keep] + + stacked_arr = stacked_arr[..., keep, ::1] + center_mzs = center_mzs[keep] + center_intensities = center_intensities[keep] + indices = [[y for y, k in zip(x, keep) if k] for x in indices] + else: + keep_corrs = np.array([1.0]) + + ref_id = np.argmax(stacked_arr[..., center_index]) + bp_int = stacked_arr[ref_id, center_index] + + out = StackedChromatograms( + array=stacked_arr, + mzs=center_mzs, + ref_index=ref_id, + parent_index=index, + base_peak_intensity=bp_int, + stack_peak_indices=indices, + center_intensities=center_intensities, + correlations=keep_corrs, + ) + return out + + +class SpectrumStacker: + """Helper class that stacks the spectra of an mzml file into chromatograms.""" + + def __init__(self, mzml_file: Path | str, config: DiademConfig) -> None: + """Initializes the SpectrumStacker class. + + Parameters + ---------- + mzml_file : Path | str + Path to the mzml file. + config : DiademConfig + The configuration object. Note that this is an DiademConfig + configuration object. 
+ """ + self.config = config + self.cache_location = Path(mzml_file).with_suffix(".parquet") + if self.cache_location.exists(): + logger.info(f"Found cache file at {self.cache_location}") + else: + df = get_mzml_data(mzml_file, min_peaks=15) + df.write_parquet(self.cache_location) + del df + + unique_windows = ( + pl.scan_parquet(self.cache_location) + .select(pl.col(["quad_low_mz_values", "quad_high_mz_values"])) + .filter(pl.col("quad_low_mz_values") > 0) + .sort("quad_low_mz_values") + .unique() + .collect() + ) + + if "DEBUG_DIADEM" in os.environ: + logger.error("RUNNING DIADEM IN DEBUG MODE (only the 4th precursor index)") + self.unique_precursor_windows = unique_windows[3:4].rows(named=True) + else: + self.unique_precursor_windows = unique_windows.rows(named=True) + + @contextmanager + def lazy_datafile(self) -> pl.LazyFrame: + """Scans the cached version of the data and yields it as a context manager.""" + yield pl.scan_parquet(self.cache_location) + + def _precursor_iso_window_elements( + self, + precursor_window: dict[str, float], + mz_range: None | tuple[float, float] = None, + ) -> dict[str : dict[str:NDArray]]: + # TODO make this a more generic function + # this is pretty much the same for timstof data but + # with ims values... 
+ with self.lazy_datafile() as datafile: + datafile: pl.LazyFrame + promise = ( + pl.col("quad_low_mz_values") == precursor_window["quad_low_mz_values"] + ) & ( + pl.col("quad_high_mz_values") == precursor_window["quad_high_mz_values"] + ) + ms_data = datafile.filter(promise).sort("rt_values") + + if mz_range is not None: + nested_cols = [ + "mz_values", + "corrected_intensity_values", + ] + non_nested_cols = [ + x for x in ms_data.head().collect().columns if x not in nested_cols + ] + ms_data = ( + ms_data.explode(nested_cols) + .filter(pl.col("mz_values").is_between(mz_range[0], mz_range[1])) + .groupby(pl.col(non_nested_cols)) + .agg(nested_cols) + .sort("rt_values") + ) + + ms_data = ms_data.collect() + + bp_indices = [np.argmax(x) for x in ms_data["corrected_intensity_values"]] + bp_ints = [ + x1.to_numpy()[x2] + for x1, x2 in zip(ms_data["corrected_intensity_values"], bp_indices) + ] + bp_ints = np.array(bp_ints) + bp_mz = [ + x1.to_numpy()[x2] for x1, x2 in zip(ms_data["mz_values"], bp_indices) + ] + bp_mz = np.array(bp_mz) + bp_indices = np.array(bp_indices) + rts = ms_data["rt_values"].to_numpy(zero_copy_only=True) + assert is_sorted(rts) + + quad_high = ms_data["quad_high_mz_values"][0] + quad_low = ms_data["quad_low_mz_values"][0] + window_name = str(quad_low) + "_" + str(quad_high) + + template = window_name + "_{}" + scan_indices = [template.format(i) for i in range(len(rts))] + + x = { + "precursor_range": (quad_low, quad_high), + "base_peak_int": bp_ints, + "base_peak_mz": bp_mz, + "iso_window_name": window_name, + "retention_times": rts, + "scan_ids": scan_indices, + } + orders = [np.argsort(x.to_numpy()) for x in ms_data["mz_values"]] + + x.update( + { + "mzs": [ + x.to_numpy()[o] for x, o in zip(ms_data["mz_values"], orders) + ], + "intensities": [ + x.to_numpy()[o] + for x, o in zip(ms_data["corrected_intensity_values"], orders) + ], + }, + ) + + return x + + def _precursor_iso_window_groups( + self, + precursor_window: dict[str, float], + ) -> 
dict[str:ScanGroup]: + elems = self._precursor_iso_window_elements(precursor_window) + prec_info = self._precursor_iso_window_elements( + {"quad_low_mz_values": -1, "quad_high_mz_values": -1}, + mz_range=list(precursor_window.values()), + ) + + assert is_sorted(prec_info["retention_times"]) + + out = ScanGroup( + precursor_mzs=prec_info["mzs"], + precursor_intensities=prec_info["intensities"], + precursor_retention_times=prec_info["retention_times"], + **elems, + ) + return out + + def get_iso_window_groups(self, workerpool: None | Parallel) -> list[ScanGroup]: + """Get scan groups for each unique isolation window. + + Parameters + ---------- + workerpool : None | Parallel + If None, the function will be run in serial mode. + If Parallel, the function will be run in parallel mode. + The Parallel is created using joblib.Parallel. + + Returns + ------- + list[ScanGroup] + A list of ScanGroup objects. + Each of them corresponding to an unique isolation window from + the quadrupole. + """ + if workerpool is None: + results = [ + self._precursor_iso_window_groups(i) + for i in self.unique_precursor_windows + ] + else: + results = workerpool( + delayed(self._precursor_iso_window_groups)(i) + for i in self.unique_precursor_windows + ) + + return results + + def yield_iso_window_groups(self) -> Iterator[ScanGroup]: + """Yield scan groups for each unique isolation window.""" + for i in self.unique_precursor_windows: + results = self._precursor_iso_window_groups(i) + yield results diff --git a/diadem/data_io/timstof.py b/diadem/data_io/timstof.py new file mode 100644 index 0000000..94e2841 --- /dev/null +++ b/diadem/data_io/timstof.py @@ -0,0 +1,698 @@ +from __future__ import annotations + +import os +from collections.abc import Iterator +from contextlib import contextmanager +from dataclasses import dataclass +from os import PathLike +from pathlib import Path +from typing import Literal + +import numpy as np +import polars as pl +from joblib import Parallel, delayed +from 
loguru import logger +from ms2ml.utils.mz_utils import get_tolerance +from msflattener.bruker import get_timstof_data +from numpy.typing import NDArray + +from diadem.config import DiademConfig +from diadem.data_io.mzml import ( + MassError, + ScanGroup, + SpectrumStacker, + StackedChromatograms, +) +from diadem.data_io.utils import slice_from_center, xic +from diadem.search.metrics import get_ref_trace_corrs +from diadem.utilities.utils import is_sorted + +if "PLOTDIADEM" in os.environ: + pass + +IMSError = Literal["abs", "pct"] + + +@dataclass +class TimsStackedChromatograms(StackedChromatograms): + """A class containing the elements of a stacked chromatogram. + + The stacked chromatogram is the extracted ion chromatogram + across a window of spectra. + + Parameters + ---------- + array : + An array of shape [i, w] + mzs : + An array of shape [i] + ref_index : + An integer in the range [0, i] + ref_ims : + IMS value for the reference peak + parent_index : + Identifier of the range where the window was extracted + base_peak_intensity : + Intensity of the base peak in the reference spectrum + stack_peak_indices : + List of indices used to stack the array, it is a list of dimensions [w], + where each element can be either a list of indices or an empty list. + + Details + ------- + The dimensions of the arrays are `w` the window + size of the extracted ion chromatogram. `i` the number + of m/z peaks that were extracted. + + """ + + ref_ims: float + + @staticmethod + # @profile + def from_group( + group: TimsScanGroup, + index: int, + window: int = 21, + mz_tolerance: float = 0.02, + mz_tolerance_unit: MassError = "da", + ims_tolerance: float = 0.03, + ims_tolerance_unit: IMSError = "abs", + # TODO implement abs and pct ims error ... + # maybe just use the mz ppm tolerance an multiply by 10000 ... 
+        min_intensity_ratio: float = 0.01,
+        min_correlation: float = 0.5,
+        max_peaks: int = 150,
+    ) -> TimsStackedChromatograms:
+        """Generates a stacked chromatogram from a TimsScanGroup.
+
+        Parameters
+        ----------
+        group : ScanGroup
+            A scan group containing the spectra to stack
+        index : int
+            The index of the spectrum to use as the reference
+        window : int, optional
+            The number of spectra to stack, by default 21
+        mz_tolerance : float, optional
+            The tolerance to use when matching m/z values, by default 0.02
+        mz_tolerance_unit : MassError, optional
+            The unit of the tolerance, by default "da"
+        ims_tolerance : float, optional
+            The tolerance to use for the ion mobility dimension.
+        ims_tolerance_unit : IMSError
+            The unit of the IMS tolerance to use, 'pct' (percent) and
+            'abs' (absolute) are acceptable values.
+        min_intensity_ratio : float, optional
+            The minimum intensity ratio to use when stacking, by default 0.01
+        min_correlation : float, optional
+            The minimum correlation to use when stacking, by default 0.5
+        max_peaks : int, optional
+            The maximum number of peaks to return in a group, by default 150.
+            If there are more candidates than this number, only the best
+            co-eluting peaks will be kept.
+
+        """
+        # TODO consider if deisotoping should happen at this stage,
+        # not at the pre-processing stage.
+        # The issue with that was that tracking indices is harder.
+ # TODO abstract this so it is not as redundant with super().from_group() + + # The center index is the same as the provided index + # Except in cases where the edge of the group is reached, where + # the center index is adjusted to the edge of the group + slice_q, center_index = slice_from_center( + center=index, + window=window, + length=len(group.mzs), + ) + mzs = group.mzs[slice_q] + intensities = group.intensities[slice_q] + imss = group.imss[slice_q] + + center_mzs = mzs[center_index] + center_intensities = intensities[center_index] + center_ims = imss[center_index] + + bp_intensity_index = np.argmax(center_intensities) + bp_ims = center_ims[bp_intensity_index] + + if ims_tolerance_unit != "abs": + raise NotImplementedError() + ims_keep = np.abs(center_ims - bp_ims) <= ims_tolerance + + center_mzs = center_mzs[ims_keep] + center_intensities = center_intensities[ims_keep] + assert is_sorted(center_mzs) + + # TODO move this to its own helper function (collapse_unique ??) + # ... Maybe even "proprocess_ims_spec(mzs, imss, ints, ref_ims, ...)" + # We collapse all unique mzs, after filtering for IMS tolerance + # Note: Getting what indices were used to generate u_mzs[0] + # would be np.where(inv == 0) + # TODO testing equality and uniqueness on floats might not be wise. + # I should change this to an int ... maybe setting a "intensity bin" + # value like comet does (0.02??) + u_center_mzs, u_center_intensities, inv = _bin_spectrum_intensities( + center_mzs, + center_intensities, + bin_width=0.001, + bin_offset=0, + ) + assert is_sorted(u_center_mzs) + + xic_outs = [] + + for i, (m, inten, ims) in enumerate(zip(mzs, intensities, imss)): + # We first filter the peaks that are inside our IMS tolerance + # By getting their indices. 
+ t_int_keep = np.abs(ims - bp_ims) <= ims_tolerance + t_int_keep = np.where(t_int_keep)[0] + + m = m[t_int_keep] + inten = inten[t_int_keep] + + u_mzs, u_intensities, inv = _bin_spectrum_intensities( + m, + inten, + bin_width=0.001, + bin_offset=0, + ) + + outs, inds = xic( + query_mz=u_mzs, + query_int=u_intensities, + mzs=u_center_mzs, + tolerance=mz_tolerance, + tolerance_unit=mz_tolerance_unit, + ) + + # Since inds are the indices used from that suvset array; + # We find what indices in the original array were used for + # each value. + out_inds = [] + for y in inds: + if len(y) > 0: + collapsed_indices = np.concatenate( + [np.where(inv == w)[0] for w in y], + ) + out_inds.append(np.unique(t_int_keep[collapsed_indices])) + else: + out_inds.append([]) + + xic_outs.append((outs, out_inds)) + if i == center_index: + assert xic_outs[-1][0].sum() >= u_center_intensities.max() + + stacked_arr = np.stack([x[0] for x in xic_outs], axis=-1) + + indices = [x[1] for x in xic_outs] + + if stacked_arr.shape[-2] > 1: + ref_id = np.argmax(stacked_arr[..., center_index]) + corrs = get_ref_trace_corrs(arr=stacked_arr, ref_idx=ref_id) + + # I think adding the 1e-5 is needed here due to numric instability + # in the flaoting point operation + assert np.max(corrs) <= ( + corrs[ref_id] + 1e-5 + ), "Reference does not have max corrr" + + max_peak_corr = np.sort(corrs)[-max_peaks] if len(corrs) > max_peaks else -1 + keep = corrs >= max(min_correlation, max_peak_corr) + keep_corrs = corrs[keep] + + stacked_arr = stacked_arr[..., keep, ::1] + u_center_mzs = u_center_mzs[keep] + u_center_intensities = u_center_intensities[keep] + indices = [[y for y, k in zip(x, keep) if k] for x in indices] + else: + keep_corrs = np.array([1.0]) + + ref_id = np.argmax(stacked_arr[..., center_index]) + bp_int = stacked_arr[ref_id, center_index] + # TODO: This might be a good place to plot the stacked chromatogram + + out = TimsStackedChromatograms( + array=stacked_arr, + mzs=u_center_mzs, + 
ref_index=ref_id, + parent_index=index, + base_peak_intensity=bp_int, + stack_peak_indices=indices, + center_intensities=u_center_intensities, + ref_ims=bp_ims, + correlations=keep_corrs, + ) + return out + + +def _bin_spectrum_intensities( + mzs: NDArray, + intensities: NDArray, + bin_width: float = 0.02, + bin_offset: float = 0.0, +) -> tuple[NDArray, NDArray, list[list[int]]]: + """Bins the intensities based on the mz values. + + Returns + ------- + new_mzs: + The new mz array + new_intensities + The new intensity array + inv: + Index for the new mzs in the original mzs and intensities + Note: Getting what indices were used to generate new_mzs[0] + would be np.where(inv == 0) + + """ + new_mzs, inv = np.unique( + np.rint((mzs + bin_offset) / bin_width), + return_inverse=True, + ) + new_mzs = (new_mzs * bin_width) - bin_offset + new_intensities = np.zeros(len(new_mzs), dtype=intensities.dtype) + np.add.at(new_intensities, inv, intensities) + return new_mzs, new_intensities, inv + + +@dataclass +class TimsScanGroup(ScanGroup): + """Represent all 'spectra' that share an isolation window.""" + + imss: list[NDArray] + precursor_imss: list[NDArray] + + def __post_init__(self) -> None: + """Validates that the values in the instance are consistent. + + Automatically runs when a new instance is created. 
+ """ + super().__post_init__() + if len(self.imss) != len(self.mzs): + raise ValueError("IMS values do not have the same lenth as the MZ values") + + @classmethod + def _elems_from_fragment_cache(cls, file): + elems, data = super()._elems_from_fragment_cache(file) + elems["imss"] = data["imss"] + return elems, data + + @classmethod + def _precursor_elems_from_cache(cls, file): + elems, data = super()._precursor_elems_from_cache(file) + elems["imss"] = data["imss"] + return elems, data + + def to_cache(self, Path): + """Saves the group to a cache file.""" + super().to_cache(Path) + + def as_dataframe(self) -> pl.DataFrame: + """Returns a dataframe with the data in the group. + + The dataframe has the following columns: + - mzs: list of mzs for each spectrum + - intensities: list of intensities for each spectrum + - retention_times: retention times for each spectrum + - precursor_start: start of the precursor range + - precursor_end: end of the precursor range + - ims: list of ims values for each spectrum + """ + out = super().as_dataframe() + out["ims"] = self.imss + return out + + def precursor_dataframe(self) -> pl.DataFrame: + df = super().precursor_dataframe() + df = df.with_columns( + pl.Series(name="precursor_imss", values=self.precursor_imss), + ) + return df + + def get_highest_window( + self, + window: int, + min_intensity_ratio: float, + mz_tolerance: float, + mz_tolerance_unit: MassError, + ims_tolerance: float, + ims_tolerance_unit: IMSError, + min_correlation: float, + max_peaks: int, + ) -> TimsStackedChromatograms: + """Gets the highest intensity window of the chromatogram. + + Briefly ... + 1. Gets the highes peak accross all spectra in the chromatogram range. + 2. Finds what peaks are in that same spectrum. + 3. Looks for spectra around that spectrum. + 4. 
extracts the chromatogram for all mzs in the "parent spectrum" + + """ + top_index = np.argmax(self.base_peak_int) + window = TimsStackedChromatograms.from_group( + self, + window=window, + index=top_index, + min_intensity_ratio=min_intensity_ratio, + min_correlation=min_correlation, + mz_tolerance=mz_tolerance, + mz_tolerance_unit=mz_tolerance_unit, + ims_tolerance=ims_tolerance, + ims_tolerance_unit=ims_tolerance_unit, + max_peaks=max_peaks, + ) + + return window + + # This is an implementation of a method used by the parent class + def _scale_spectrum_at( + self, + spectrum_index: int, + value_indices: NDArray[np.int64], + scaling_factor: float, + ) -> None: + i = spectrum_index # Alias for brevity in within this function + + if len(value_indices) > 0: + self.intensities[i][value_indices] = ( + self.intensities[i][value_indices] * scaling_factor + ) + else: + return None + + # TODO this is hard-coded right now, change as a param + int_remove = self.intensities[i] < 10 + if np.any(int_remove): + self.intensities[i] = self.intensities[i][np.invert(int_remove)] + self.mzs[i] = self.mzs[i][np.invert(int_remove)] + self.imss[i] = self.imss[i][np.invert(int_remove)] + + if len(self.intensities[i]): + self.base_peak_int[i] = np.max(self.intensities[i]) + self.base_peak_mz[i] = self.mzs[i][np.argmax(self.intensities[i])] + else: + self.base_peak_int[i] = 0 + self.base_peak_mz[i] = 0 + + def __len__(self) -> int: + """Returns the number of spectra in the group.""" + return len(self.imss) + + +class TimsSpectrumStacker(SpectrumStacker): + """Helper class that stacks the spectra of TimsTof file into chromatograms.""" + + def __init__(self, filepath: PathLike, config: DiademConfig) -> None: + """Initializes the class. 
+ + Parameters + ---------- + filepath : PathLike + Path to the TimsTof file + config : DiademConfig + Configuration object + """ + self.filepath = filepath + self.config = config + self.cache_location = Path(filepath).with_suffix(".centroided.parquet") + if self.cache_location.exists(): + logger.info(f"Found cache file at {self.cache_location}") + else: + df = get_timstof_data(filepath, centroid=True) + df.write_parquet(self.cache_location) + del df + + unique_windows = ( + pl.scan_parquet(self.cache_location) + .select(pl.col(["quad_low_mz_values", "quad_high_mz_values"])) + .filter(pl.col("quad_low_mz_values") > 0) + .sort("quad_low_mz_values") + .unique() + .collect() + ) + + if "DEBUG_DIADEM" in os.environ: + logger.error("RUNNING DIADEM IN DEBUG MODE (only the 4th precursor index)") + self.unique_precursor_windows = unique_windows[3:4].rows(named=True) + else: + self.unique_precursor_windows = unique_windows.rows(named=True) + + @contextmanager + def lazy_datafile(self) -> pl.LazyFrame: + """Scans the cached version of the data and yields it as a context manager.""" + yield pl.scan_parquet(self.cache_location) + + # @profile + def _precursor_iso_window_groups( + self, + precursor_window: dict[str, float], + ) -> dict[str:TimsScanGroup]: + elems = self._precursor_iso_window_elements(precursor_window) + prec_info = self._precursor_iso_window_elements( + {"quad_low_mz_values": -1, "quad_high_mz_values": -1}, + mz_range=list(precursor_window.values()), + ) + + assert is_sorted(prec_info["retention_times"]) + + out = TimsScanGroup( + precursor_mzs=prec_info["mzs"], + precursor_intensities=prec_info["intensities"], + precursor_retention_times=prec_info["retention_times"], + precursor_imss=prec_info["imss"], + **elems, + ) + return out + + # @profile + def _precursor_iso_window_elements( + self, + precursor_window: dict[str, float], + mz_range: None | tuple[float, float] = None, + ) -> dict[str : dict[str:NDArray]]: + with self.lazy_datafile() as datafile: + 
datafile: pl.LazyFrame + promise = ( + pl.col("quad_low_mz_values") == precursor_window["quad_low_mz_values"] + ) & ( + pl.col("quad_high_mz_values") == precursor_window["quad_high_mz_values"] + ) + ms_data = datafile.filter(promise).sort("rt_values") + + if mz_range is not None: + nested_cols = [ + "mz_values", + "corrected_intensity_values", + "mobility_values", + ] + non_nested_cols = [ + x for x in ms_data.head().collect().columns if x not in nested_cols + ] + ms_data = ( + ms_data.explode(nested_cols) + .filter(pl.col("mz_values").is_between(mz_range[0], mz_range[1])) + .groupby(pl.col(non_nested_cols)) + .agg(nested_cols) + .sort("rt_values") + ) + + ms_data = ms_data.collect() + + bp_indices = [np.argmax(x) for x in ms_data["corrected_intensity_values"]] + bp_ints = [ + x1.to_numpy()[x2] + for x1, x2 in zip(ms_data["corrected_intensity_values"], bp_indices) + ] + bp_ints = np.array(bp_ints) + bp_mz = [ + x1.to_numpy()[x2] for x1, x2 in zip(ms_data["mz_values"], bp_indices) + ] + bp_mz = np.array(bp_mz) + bp_indices = np.array(bp_indices) + rts = ms_data["rt_values"].to_numpy(zero_copy_only=True) + assert is_sorted(rts) + + quad_high = ms_data["quad_high_mz_values"][0] + quad_low = ms_data["quad_low_mz_values"][0] + window_name = str(quad_low) + "_" + str(quad_high) + + template = window_name + "_{}" + scan_indices = [template.format(i) for i in range(len(rts))] + + x = { + "precursor_range": (quad_low, quad_high), + "base_peak_int": bp_ints, + "base_peak_mz": bp_mz, + # 'base_peak_ims':bp_ims, + "iso_window_name": window_name, + "retention_times": rts, + "scan_ids": scan_indices, + } + orders = [np.argsort(x.to_numpy()) for x in ms_data["mz_values"]] + + x.update( + { + "mzs": [ + x.to_numpy()[o] for x, o in zip(ms_data["mz_values"], orders) + ], + "intensities": [ + x.to_numpy()[o] + for x, o in zip(ms_data["corrected_intensity_values"], orders) + ], + "imss": [ + x.to_numpy()[o] + for x, o in zip(ms_data["mobility_values"], orders) + ], + }, + ) + + return 
x + + def get_iso_window_groups(self, workerpool: None | Parallel) -> list[TimsScanGroup]: + """Get scan groups for each unique isolation window. + + Parameters + ---------- + workerpool : None | Parallel + If None, the function will be run in serial mode. + If Parallel, the function will be run in parallel mode. + The Parallel is created using joblib.Parallel. + + Returns + ------- + list[TimsScanGroup] + A list of TimsScanGroup objects. + Each of them corresponding to an unique isolation window from + the quadrupole. + """ + if workerpool is None: + results = [ + self._precursor_iso_window_groups(i) + for i in self.unique_precursor_windows + ] + else: + results = workerpool( + delayed(self._precursor_iso_window_groups)(i) + for i in self.unique_precursor_windows + ) + + return results + + def yield_iso_window_groups(self) -> Iterator[TimsScanGroup]: + """Yield scan groups for each unique isolation window.""" + for i in self.unique_precursor_windows: + results = self._precursor_iso_window_groups(i) + yield results + + +# @profile +def find_neighbors_mzsort( + ims_vals: NDArray[np.float32], + sorted_mz_values: NDArray[np.float32], + intensities: NDArray[np.float32], + top_n: int = 500, + top_n_pct: float = 0.1, + ims_tol: float = 0.02, + mz_tol: float = 0.02, + mz_tol_unit: MassError = "Da", +) -> dict[int : list[int]]: + """Finds the neighbors of the most intense peaks. + + It finds the neighboring peaks for the `top_n` most intense peaks + or the (TOTAL_PEAKS) * `top_n_pct` peaks, whichever is largest. + + Arguments: + --------- + ims_vals: NDArray[np.float32] + Array containing the ion mobility values of the precursor. + sorted_mz_values: NDArray[np.float32] + Sorted array contianing the mz values + intensities: NDArray[np.float32] + Array containing the intensities of the peaks + top_n : int + Number of peaks to use as seeds for neighborhood finding, + defautls to 500. 
+        It will internally use the largest of either this number
+        or `len(intensities)*top_n_pct`
+    top_n_pct : float
+        Minimum percentage of the intensities to use as seeds for the
+        neighbors. defaults to 0.1
+    ims_tol : float
+        Maximum distance to consider as a neighbor along the IMS dimension.
+        defaults to 0.02
+    mz_tol : float
+        Maximum distance to consider as a neighbor along the MZ dimension.
+        defaults to 0.02
+    mz_tol_unit : Literal['ppm', 'Da']
+        Unit that describes the mz tolerance.
+        defaults to 'Da'
+
+    """
+    if mz_tol_unit.lower() == "da":
+        mz_tol = np.full_like(sorted_mz_values, mz_tol)
+    elif mz_tol_unit.lower() == "ppm":
+        mz_tol: NDArray[np.float32] = get_tolerance(
+            mz_tol,
+            theoretical=sorted_mz_values,
+            unit="ppm",
+        )
+    else:
+        raise ValueError("Only 'Da' and 'ppm' values are supported as mass errors")
+
+    top_n = int(max(len(intensities) * top_n_pct, top_n))
+    if len(intensities) > top_n:
+        top_indices = np.argpartition(intensities, -top_n)[-top_n:]
+    else:
+        top_indices = None
+
+    opts = {}
+    for i1, (ims1, mz1) in enumerate(zip(ims_vals, sorted_mz_values)):
+        if top_indices is not None and i1 not in top_indices:
+            opts.setdefault(i1, []).append(i1)
+            continue
+
+        # mz_tol is a per-peak array (constant for 'Da', mass-dependent for
+        # 'ppm'), so index it by the current peak rather than broadcasting
+        # the whole array into searchsorted.
+        candidates = np.arange(
+            np.searchsorted(sorted_mz_values, mz1 - mz_tol[i1]),
+            np.searchsorted(sorted_mz_values, mz1 + mz_tol[i1], side="right"),
+        )
+
+        tmp_ims = ims_vals[candidates]
+
+        match_indices = np.abs(tmp_ims - ims1) <= ims_tol
+        match_indices = np.where(match_indices)[0]
+        for i2 in candidates[match_indices]:
+            opts.setdefault(i1, [i1]).append(i2)
+            opts.setdefault(i2, [i2]).append(i1)
+
+    opts = {k: list(set(v)) for k, v in opts.items()}
+    return opts
+
+
+def get_break_indices(
+    inds: NDArray[np.int64],
+    min_diff: float | int = 1,
+    break_values: NDArray | None = None,
+) -> tuple[NDArray[np.int64], NDArray[np.int64]]:
+    """Gets the indices and break values for an increasing array.
+ + Example: + ------- + >>> tmp = np.array([1,2,3,4,7,8,9,11,12,13]) + >>> bi = get_break_indices(tmp) + >>> bi + (array([ 0, 4, 7, 10]), array([ 1, 7, 11, 13])) + >>> [tmp[si: ei] for si, ei in zip(bi[0][:-1], bi[0][1:])] + [array([1, 2, 3, 4]), array([7, 8, 9]), array([11, 12, 13])] + """ + if break_values is None: + break_values = inds + breaks = 1 + np.where(np.diff(break_values) > min_diff)[0] + breaks = np.concatenate([np.array([0]), breaks, np.array([inds.size - 1])]) + + break_indices = inds[breaks] + breaks[-1] += 1 + + return breaks, break_indices diff --git a/diadem/data_io/utils.py b/diadem/data_io/utils.py new file mode 100644 index 0000000..4de178a --- /dev/null +++ b/diadem/data_io/utils.py @@ -0,0 +1,122 @@ +from collections.abc import Iterable + +import numpy as np +from ms2ml.utils.mz_utils import annotate_peaks +from numpy.typing import NDArray + +from diadem.config import MassError + + +def slice_from_center(center: int, window: int, length: int) -> tuple[slice, int]: + """Generates a slice provided a center and window size. + + Creates a slice that accounts for the endings of an iterable + in such way that the window size is maintained. + + Examples + -------- + >>> my_list = [0,1,2,3,4,5,6] + >>> slc, center_index = slice_from_center( + ... 
center=4, window=3, length=len(my_list))
+    >>> slc
+    slice(3, 6, None)
+    >>> my_list[slc]
+    [3, 4, 5]
+    >>> my_list[slc][center_index] == my_list[4]
+    True
+
+    >>> slc = slice_from_center(1, 3, len(my_list))
+    >>> slc
+    (slice(0, 3, None), 1)
+    >>> my_list[slc[0]]
+    [0, 1, 2]
+
+    >>> slc = slice_from_center(6, 3, len(my_list))
+    >>> slc
+    (slice(4, 7, None), 2)
+    >>> my_list[slc[0]]
+    [4, 5, 6]
+    >>> my_list[slc[0]][slc[1]] == my_list[6]
+    True
+
+    """
+    start = center - (window // 2)
+    end = center + (window // 2) + 1
+    center_index = window // 2
+
+    if start < 0:
+        start = 0
+        end = window
+        center_index = center
+
+    if end >= length:
+        end = length
+        start = end - window
+        center_index = window - (length - center)
+
+    slice_q = slice(start, end)
+    return slice_q, center_index
+
+
+try:
+    zip([], [], strict=True)
+
+    def strictzip(*args: Iterable) -> Iterable:
+        """Like zip but checks that the length of all elements is the same."""
+        return zip(*args, strict=True)
+
+except TypeError:
+
+    def strictzip(*args: Iterable) -> Iterable:
+        """Like zip but checks that the length of all elements is the same."""
+        # TODO optimize this, try to get the length and fallback to making it a list
+        args = [list(arg) for arg in args]
+        lengs = {len(x) for x in args}
+        if len(lengs) > 1:
+            raise ValueError("All arguments need to have the same lengths")
+        return zip(*args)
+
+
+# @profile
+def xic(
+    query_mz: NDArray[np.float32],
+    query_int: NDArray[np.float32],
+    mzs: NDArray[np.float32],
+    tolerance_unit: MassError = "da",
+    tolerance: float = 0.02,
+) -> tuple[NDArray[np.float32], list[list[int]]]:
+    """Gets the extracted ion chromatogram from arrays.
+
+    Gets the extracted ion chromatogram from the passed mzs and intensities
+    The output should be the same length as the passed mzs.
+
+    Returns
+    -------
+    NDArray[np.float32]
+        An array of length `len(mzs)` that integrates such masses in
+        the query_int (matching with the query_mz array ...)
+ + list[list[int]] + A nested list of length `len(mzs)` where each sub-list contains + the indices of the `query_int` array that were integrated. + + """ + theo_mz_indices, obs_mz_indices = annotate_peaks( + theo_mz=mzs, + mz=query_mz, + tolerance=tolerance, + unit=tolerance_unit, + ) + + outs = np.zeros_like(mzs, dtype="float") + inds = [] + for i in range(len(outs)): + query_indices = obs_mz_indices[theo_mz_indices == i] + ints_subset = query_int[query_indices] + if len(ints_subset) == 0: + inds.append([]) + else: + outs[i] = np.sum(ints_subset) + inds.append(query_indices) + + return outs, inds diff --git a/diadem/deisotoping.py b/diadem/deisotoping.py index 3cd4759..f7f64fb 100644 --- a/diadem/deisotoping.py +++ b/diadem/deisotoping.py @@ -1,6 +1,10 @@ +from __future__ import annotations + import numpy as np from numpy.typing import NDArray +from diadem.utilities.neighborhood import _multidim_neighbor_search + NEUTRON = 1.00335 @@ -21,17 +25,135 @@ def ppm_to_delta_mass(obs: float, ppm: float) -> float: return ppm * obs / 1_000_000.0 +# @profile +def _deisotope_with_ims_arrays( + mzs: NDArray[np.float32], + intensities: NDArray[np.float32], + imss: NDArray[np.float32] | None = None, + max_charge: int = 5, + ims_tolerance: float = 0.01, + mz_tolerance: float = 0.01, + track_indices: bool = False, +) -> dict[str, NDArray[np.float32]]: + """Deisotope a spectrum with IMS data. + + The current implementation allows only absolute values for the ims and mz + (no ppm). + + This function assumes that the IMS data has already been centroided (collapsed?) + and that the 3 arrays share ordering. + + If imss is None, will assume there is no IMS dimension in the data. 
+ """ + if imss is None: + peak_iter = zip(mzs, intensities) + peaks = [ + { + "mz": mz, + "intensity": intensity, + "orig_intensity": intensity, + "envelope": None, + "charge": None, + } + for mz, intensity in peak_iter + ] + dim_order = ["mz"] + extract_values = ["mz", "intensity"] + spec = {"mz": mzs, "intensity": intensities} + else: + peak_iter = zip(mzs, intensities, imss) + peaks = [ + { + "mz": mz, + "ims": ims, + "intensity": intensity, + "orig_intensity": intensity, + "envelope": None, + "charge": None, + } + for mz, intensity, ims in peak_iter + ] + dim_order = ["mz", "ims"] + extract_values = ["mz", "intensity", "ims"] + spec = {"mz": mzs, "intensity": intensities, "ims": imss} + + dist_funs = {k: lambda x, y: y - x for k in dim_order} + # sort all elements by their first dimension + spec_order = np.argsort(spec["mz"]) + peaks = [peaks[i] for i in spec_order] + spec = {k: v[spec_order] for k, v in spec.items()} + + spec_indices = np.arange(len(spec["mz"])) + if track_indices: + extract_values.append("indices") + for i, peak in enumerate(peaks): + peak["indices"] = [i] + + # It might be faster to just generate an expanded + # array with all the charge variants and then do a + # single search. + for charge in range(max_charge + 1, 0, -1): + dist_ranges = { + "mz": ( + (NEUTRON / charge) - mz_tolerance, + (NEUTRON / charge) + mz_tolerance, + ), + "ims": (-ims_tolerance, ims_tolerance), + } + + out = _multidim_neighbor_search( + elems_1=spec, + elems_2=spec, + elems_1_indices=spec_indices, + elems_2_indices=spec_indices, + dist_funs=dist_funs, + dist_ranges=dist_ranges, + dimension_order=dim_order, + ) + + isotope_graph = out.left_neighbors + for i in np.argsort(-spec["mz"]): + if i not in isotope_graph: + continue + + for j in isotope_graph[i]: + # 0.5 is a magical number ... it is meant to account + # for the fact that intensity should in theory always increase + # (except for envelopes of high mass) but should also have some + # wiggle room for noise. 
+ intensity_inrange = ( + 0.5 * (peaks[j]["orig_intensity"]) < peaks[i]["orig_intensity"] + ) + if intensity_inrange and peaks[j]["envelope"] is None: + peaks[i]["intensity"] += peaks[j]["intensity"] + peaks[j]["charge"] = charge + peaks[i]["charge"] = charge + peaks[j]["envelope"] = i + if track_indices: + peaks[i]["indices"].extend(peaks[j]["indices"]) + # At the end of this, all peaks that belong to an envelope + # have a value for "envelope". + # Therefore, peaks that need to be filtered out (since their + # intensity now bleongs to another peak). + + f_peaks = _filter_peaks(peaks, extract=extract_values) + f_peaks = dict(zip(extract_values, f_peaks)) + return f_peaks + + # Ported implementation from sage +# @profile def deisotope( - mz: NDArray[np.float32] | list[float], - inten: NDArray[np.float32] | list[float], + mz: NDArray[np.float32], + inten: NDArray[np.float32], max_charge: int, diff: float, unit: str, + track_indices: bool = False, ) -> tuple[NDArray[np.float32], NDArray[np.float32]]: """Deisotopes the passed spectra. - TE MZS NEED TO BE SORTED!! + THE MZS NEED TO BE SORTED!! Parameters ---------- @@ -45,44 +167,124 @@ def deisotope( Tolerance to use when searching (typically 20 for ppm or 0.02 for da) unit, str Unit for the diff. ppm or da - """ - if unit.lower() == "da": + track_indices, bool + Whether to return the indices of the combined indices as well. - def mass_unit_fun(x: float, y: float) -> float: - return y + Examples + -------- + >>> my_mzs = np.array([800.9, 803.408, 804.4108, 805.4106]) + >>> my_intens = np.array([1-(0.1*i) for i,_ in enumerate(my_mzs)]) + >>> deisotope(my_mzs, my_intens, max_charge=2, diff=5.0, unit="ppm") + (array([800.9 , 803.408]), array([1. , 2.4])) + >>> deisotope(my_mzs, my_intens, max_charge=2, diff=5.0, unit="ppm", track_indices=True) + (array([800.9 , 803.408]), array([1. 
, 2.4]), ([0], [1, 2, 3])) + """ # noqa: + if unit.lower() == "da": + mz_tolerance = diff elif unit.lower() == "ppm": - mass_unit_fun = ppm_to_delta_mass + # Might give wider tolerances than wanted to the lower end of the values + # but should be good enough for most cases. + mz_tolerance = ppm_to_delta_mass(mz.max(), diff) else: raise NotImplementedError("Masses need to be either 'da' or 'ppm'") - peaks = [(mz, intensity) for mz, intensity in zip(mz, inten)] - peaks = [ - {"mz": mz, "intensity": intensity, "envelope": None, "charge": None} - for mz, intensity in peaks - ] - - for i in range(len(mz) - 1, -1, -1): - j = i - 1 - while j >= 0 and mz[i] - mz[j] <= NEUTRON + mass_unit_fun(mz[i], diff): - delta = mz[i] - mz[j] - tol = mass_unit_fun(mz[i], diff) - for charge in range(1, max_charge + 1): - iso = NEUTRON / charge - if abs(delta - iso) <= tol and inten[i] < inten[j]: - peaks[j]["intensity"] += peaks[i]["intensity"] - if peaks[i]["charge"] and peaks[i]["charge"] != charge: - continue - peaks[j]["charge"] = charge - peaks[i]["charge"] = charge - peaks[i]["envelope"] = j - j -= 1 + peaks = _deisotope_with_ims_arrays( + imss=None, + mzs=mz, + intensities=inten, + max_charge=max_charge, + mz_tolerance=mz_tolerance, + track_indices=track_indices, + ) + + return tuple(peaks.values()) + + +# TODO this is a function prototype, if it works abstract it and +# combine with the parent function. + + +# @profile +def deisotope_with_ims( + mz: NDArray[np.float32], + inten: NDArray[np.float32], + imss: NDArray[np.float32], + max_charge: int, + mz_diff: float, + mz_unit: str, + ims_diff: float, + ims_unit: float, + track_indices: bool = False, +) -> tuple[NDArray[np.float32], NDArray[np.float32]]: + """Deisotopes the passed spectra. + + THE MZS NEED TO BE SORTED!! + + Parameters + ---------- + mz, list[float] + A list of the mz values to be deisotoped. (NEEDS to be sorted) + inten, list[float] + A list of the intensities that correspond in order to the elements in mz. 
+ imss, list[float] + A list of the ims values to use for the search. + max_charge, int + Maximum charge to look for in the isotopes. + mz_diff, float + Tolerance to use when searching (typically 20 for ppm or 0.02 for da) + mz_unit, str + Unit for the diff. ppm or da + ims_diff, float + Tolerance to use when searching in the IMS dimension. + ims_unit, float + Tolerance unit to use when searching in the ims dimension. + track_indices, bool + Whether to return the indices of the combined indices as well. + + Examples + -------- + >>> my_mzs = np.array([800.9, 803.408, 803.409, 804.4108, 804.4109, 805.4106]) + >>> my_imss = np.array([0.7, 0.7, 0.8, 0.7, 0.8, 0.7]) + >>> my_intens = np.array([1-(0.1*i) for i,_ in enumerate(my_mzs)]) + >>> deisotope_with_ims(my_mzs, my_intens, my_imss, max_charge=2, + ... mz_diff=5.0, mz_unit="ppm", ims_diff=0.01, ims_unit="abs") + (array([800.9 , 803.408, 803.409]), array([1. , 2.1, 1.4]), array([0.7, 0.7, 0.8])) + >>> deisotope_with_ims(my_mzs, my_intens, my_imss, max_charge=2, + ... mz_diff=5.0, mz_unit="ppm", ims_diff=0.01, ims_unit="abs", track_indices=True) + (array([800.9 , 803.408, 803.409]), array([1. 
, 2.1, 1.4]), array([0.7, 0.7, 0.8]), ([0], [1, 3, 5], [2, 4])) + """ # noqa E501 + if mz_unit.lower() == "da": + mz_tolerance = mz_diff + + elif mz_unit.lower() == "ppm": + mz_tolerance = ppm_to_delta_mass(mz.max(), mz_diff) + else: + raise NotImplementedError("Masses need to be either 'da' or 'ppm'") + + if ims_unit.lower() == "abs": + ims_tolerance = ims_diff + + else: + raise NotImplementedError("only abs is supported as an IMS difference") + + peaks = _deisotope_with_ims_arrays( + imss=imss, + mzs=mz, + intensities=inten, + max_charge=max_charge, + mz_tolerance=mz_tolerance, + ims_tolerance=ims_tolerance, + track_indices=track_indices, + ) - peaks = _filter_peaks(peaks) - return peaks + return tuple(peaks.values()) -def _filter_peaks(peaks: dict) -> tuple[NDArray[np.float32], NDArray[np.float32]]: +def _filter_peaks( + peaks: list[dict], + extract: tuple[str], +) -> tuple[NDArray[np.float32] | list, ...]: """Filters peaks to remove isotope envelopes. When passed a list of dictionaries that look like this: @@ -90,9 +292,13 @@ def _filter_peaks(peaks: dict) -> tuple[NDArray[np.float32], NDArray[np.float32] It filters the ones that are not assigned to be in an envelope, thus keeping only monoisotopic peaks. 
""" - peaktuples = [(x["mz"], x["intensity"]) for x in peaks if x["envelope"] is None] + peaktuples = [[x[y] for y in extract] for x in peaks if x["envelope"] is None] if len(peaktuples) == 0: - mzs, ints = [], [] + out_tuple = tuple([] for _ in extract) else: - mzs, ints = zip(*peaktuples) - return np.array(mzs), np.array(ints) + out_tuple = zip(*peaktuples) + out_tuple = tuple( + np.array(x) if y in ["mz", "intensity", "ims"] else x + for x, y in zip(out_tuple, extract) + ) + return out_tuple diff --git a/diadem/index/fragment_buckets.py b/diadem/index/fragment_buckets.py index 259f190..9883510 100644 --- a/diadem/index/fragment_buckets.py +++ b/diadem/index/fragment_buckets.py @@ -1,13 +1,14 @@ from __future__ import annotations +from collections.abc import Iterator from dataclasses import dataclass -from typing import Iterator, Literal +from typing import Literal import numpy as np from numpy.typing import NDArray from tqdm.auto import tqdm -from diadem.utils import get_slice_inds, is_sorted +from diadem.utilities.utils import get_slice_inds, is_sorted SortingLevel = Literal["ms1", "ms2"] @@ -152,12 +153,14 @@ def concatenate(cls, *args: FragmentBucket) -> FragmentBucket: assert ( len(sorting_level) == 1 ), "Cannot concatenate buckets with different sorting levels" + if len(args) == 1: + return args[0] return cls( fragment_mzs=np.concatenate([x.fragment_mzs for x in args]), fragment_series=np.concatenate([x.fragment_series for x in args]), precursor_ids=np.concatenate([x.precursor_ids for x in args]), precursor_mzs=np.concatenate([x.precursor_mzs for x in args]), - is_sorted=True, + is_sorted=False, sorting_level=sorting_level.pop(), ) @@ -234,8 +237,10 @@ def __getitem__(self, val: int | slice) -> FragmentBucket | FragmentBucketList: return self.buckets[val] else: raise ValueError( - f"Subsetting FragmentBucketList with {type(val)}: {val} is not" - " supported" + ( + f"Subsetting FragmentBucketList with {type(val)}: {val} is not" + " supported" + ), ) 
@classmethod @@ -283,7 +288,7 @@ def from_arrays( precursor_mzs=precursor_mzs[i : i + chunksize], sorting_level=sorting_level, is_sorted=been_sorted, - ) + ), ) return cls(buckets) @@ -294,7 +299,9 @@ def sort(self, level: SortingLevel) -> None: # @profile def yield_candidates( - self, ms2_range: tuple[float, float], ms1_range: tuple[float, float] + self, + ms2_range: tuple[float, float], + ms1_range: tuple[float, float], ) -> None | Iterator[tuple[int, float, str]]: """Yields fragments that match the passed masses. @@ -326,7 +333,9 @@ def yield_candidates( continue yield from zip( - bucket.precursor_ids, bucket.fragment_mzs, bucket.fragment_series + bucket.precursor_ids, + bucket.fragment_mzs, + bucket.fragment_series, ) @@ -365,7 +374,9 @@ def __init__( for k, v in unpacked.items(): self.buckets[k].append(v) - iterator = enumerate(tqdm(self.buckets, disable=not progress)) + iterator = enumerate( + tqdm(self.buckets, disable=not progress, desc="Concatenating buckets"), + ) # This gets progressively updated with the minimum bucket size. min_ms2_mz = 2**15 @@ -383,6 +394,7 @@ def __init__( self.min_ms2_mz = min_ms2_mz + # @profile def unpack_bucket(self, bucket: FragmentBucket) -> dict[int, FragmentBucket]: """Unpacks a bucket into a dictionary of buckets. @@ -392,11 +404,18 @@ def unpack_bucket(self, bucket: FragmentBucket) -> dict[int, FragmentBucket]: """ # TODO decide if this should be a fragment bucket method... 
integerized = (bucket.fragment_mzs * self.prod_num).astype(int) + sorted_lookup = len(integerized > 1000) and is_sorted(integerized) uniqs = np.unique(integerized) out = {} for u in uniqs: - idxs = integerized == u + if sorted_lookup: + idxs = slice( + np.searchsorted(integerized, u, "left"), + np.searchsorted(integerized, u, "right"), + ) + else: + idxs = integerized == u # TODO consider here not traking precursor mzs anymore out[u] = FragmentBucket( @@ -408,7 +427,9 @@ def unpack_bucket(self, bucket: FragmentBucket) -> dict[int, FragmentBucket]: return out def yield_buckets_matching_ms2( - self, min_mz: float, max_mz: float + self, + min_mz: float, + max_mz: float, ) -> Iterator[FragmentBucket]: """Yields buckets that match the passed ms2 range.""" min_index = max(0, int(min_mz * self.prod_num) - 1) @@ -420,7 +441,9 @@ def yield_buckets_matching_ms2( # @profile def yield_candidates( - self, ms2_range: tuple[float, float], ms1_range: tuple[float, float] + self, + ms2_range: tuple[float, float], + ms1_range: tuple[float, float], ) -> Iterator[tuple[int, float, str]]: """Yields fragments that match the passed masses. 
@@ -446,10 +469,12 @@ def yield_candidates( min_mz, max_mz = ms2_range last_mz = 0 for x in self.yield_buckets_matching_ms2(min_mz, max_mz): + assert is_sorted(x.fragment_mzs) + last_val = np.searchsorted(x.fragment_mzs, max_mz, "right") for precursor_id, fragmz, fragseries in zip( - x.precursor_ids, x.fragment_mzs, x.fragment_series + x.precursor_ids[:last_val], + x.fragment_mzs[:last_val], + x.fragment_series[:last_val], ): - if fragmz > max_mz: - break assert last_mz <= (last_mz := fragmz), "Fragment mzs not sorted" yield precursor_id, fragmz, fragseries diff --git a/diadem/index/indexed_db.py b/diadem/index/indexed_db.py index fd05903..4a8035f 100644 --- a/diadem/index/indexed_db.py +++ b/diadem/index/indexed_db.py @@ -2,10 +2,10 @@ import copy from collections import namedtuple +from collections.abc import Iterable, Iterator from functools import lru_cache from os import PathLike from pathlib import Path -from typing import Iterable, Iterator import numpy as np import pandas as pd @@ -26,7 +26,7 @@ FragmentBucketList, PrefilteredMS1BucketList, ) -from diadem.utils import disabled_gc, is_sorted, make_decoy +from diadem.utilities.utils import disabled_gc, is_sorted, make_decoy # Pre-calculating factorials so I do not need # to calculate them repeatedly while scoring @@ -63,14 +63,15 @@ def _make_score_dict(ions: str) -> dict[str, dict[str, float | list[float]]]: SeqProperties = namedtuple( - "SeqProperties", "fragments, ion_series, prec_mz, proforma_seq, num_frags" + "SeqProperties", + "fragments, fragment_positions, ion_series, prec_mz, proforma_seq, num_frags", ) class PeptideScore: """Accumulates elements to calculate the score for a peptide.""" - __slots__ = ("id", "ions", "partial_scores", "tot_peaks") + __slots__ = ("id", "ions", "partial_scores", "tot_peaks", "peak_ids") def __init__(self, id: int, ions: str) -> None: """Accumulates elements to calculate the score for a peptide. 
@@ -85,20 +86,28 @@ def __init__(self, id: int, ions: str) -> None: Examples -------- >>> score = PeptideScore(1, "by") - >>> score.add_peak('y', mz = 234.22, intensity = 100, error = 0.01) - >>> score.add_peak('y', mz = 534.22, intensity = 200, error = 0.012) + >>> score.add_peak("y", mz=234.22, intensity=100, error=0.01, peak_id=1) + >>> score.add_peak("y", mz=534.22, intensity=200, error=0.012, peak_id=2) >>> outs = score.as_row_entry() - >>> [x for x in outs] # doctest: +NORMALIZE_WHITESPACE - ['id', 'b_intensity', 'b_npeaks', 'b_mzs', 'y_intensity', 'y_npeaks',\ - 'y_mzs', 'log_intensity_sums', 'log_factorial_peak_sum', 'mzs',\ - 'mass_errors', 'avg_abs_dm', 'med_abs_dm'] + >>> [x for x in outs] # doctest: +NORMALIZE_WHITESPACE + ['id', 'b_intensity', 'b_npeaks', 'b_mzs', 'y_intensity', \ + 'y_npeaks', 'y_mzs', 'log_intensity_sums', 'log_factorial_peak_sum', \ + 'mzs', 'mass_errors', 'avg_abs_dm', 'med_abs_dm', 'spec_indices'] """ # noqa self.id: int = id self.ions = ions self.partial_scores = copy.deepcopy(_make_score_dict(ions)) self.tot_peaks = 0 + self.peak_ids = [] - def add_peak(self, ion: str, mz: float, intensity: float, error: float) -> None: + def add_peak( + self, + ion: str, + mz: float, + intensity: float, + error: float, + peak_id: int, + ) -> None: """Adds a peak to the partial score. Check the class docstring for more details. @@ -113,11 +122,15 @@ def add_peak(self, ion: str, mz: float, intensity: float, error: float) -> None: The intensity of the peak to add. error : float The mass error of the peak to add. + peak_id: int + Id of the peak. It is usefull to track what peaks within + a spectrum were a match. 
""" self.partial_scores[ion]["intensities"] += intensity self.partial_scores[ion]["npeaks"] += 1 self.partial_scores[ion]["mzs"].append(mz) self.partial_scores[ion]["mass_errors"].append(error) + self.peak_ids.append(peak_id) self.tot_peaks += 1 def as_row_entry(self) -> dict[str, float | list[float] | int]: @@ -150,6 +163,7 @@ def as_row_entry(self) -> dict[str, float | list[float] | int]: out["mass_errors"] = mass_errors out["avg_abs_dm"] = np.abs(np.array(mass_errors)).mean() out["med_abs_dm"] = np.median(np.abs(np.array(mass_errors))) + out["spec_indices"] = self.peak_ids return out @@ -167,7 +181,10 @@ class IndexedDb: """ def __init__( - self, chunksize: int, config: DiademConfig = DEFAULT_CONFIG, name: str = "db" + self, + chunksize: int, + config: DiademConfig = DEFAULT_CONFIG, + name: str = "db", ) -> None: """Creates a new IndexedDb object. @@ -215,11 +232,21 @@ def targets(self, value: list[Peptide]) -> None: value : list[Peptide] A list of peptide objects to set as the targets for the database. """ + seen = set() + use = [] + # Since right now my peptide implementation is not hashable, + # this needs to be done to make sure all targets are unique. 
for x in value: x.config = self.ms2ml_config - self._targets = value + if (proforma := x.to_proforma()) in seen: + continue + else: + use.append(x) + seen.add(proforma) + + self._targets = use self._decoys = None - self.target_proforma = {x.to_proforma() for x in value} + self.target_proforma = {x.to_proforma() for x in use} @property def decoys(self) -> list[Peptide]: @@ -235,8 +262,10 @@ def decoys(self) -> list[Peptide]: make_decoy(x) for x in tqdm(targets, desc="Generating Decoys") ] logger.info( - f"Generating database with {len(self.decoys)} decoys," - f" and {len(self.targets)} targets" + ( + f"Generating database with {len(self.decoys)} decoys," + f" and {len(self.targets)} targets" + ), ) return self._decoys @@ -269,15 +298,18 @@ def ms1_filter(pep: Peptide) -> Peptide | None: config=self.config.ms2ml_config, only_unique=True, enzyme=self.config.db_enzyme, - missed_cleavages=1, + missed_cleavages=self.config.db_max_missed_cleavages, allow_modifications=False, out_hook=ms1_filter, ) sequences = list(adapter.parse()) + assert len(sequences) == len({x.to_proforma() for x in sequences}) self.targets = sequences def prefilter_ms1( - self, ms1_range: tuple[float, float], num_decimals: int = 3 + self, + ms1_range: tuple[float, float], + num_decimals: int = 3, ) -> IndexedDb: """Prefilters the database. 
@@ -304,7 +336,8 @@ def prefilter_ms1( out = copy.copy(self) out.bucketlist = self.bucketlist.prefilter_ms1( - *ms1_range, num_decimals=num_decimals + *ms1_range, + num_decimals=num_decimals, ) out.prefiltered_ms1 = True out.seq_prec_mzs = self.seq_prec_mzs @@ -358,7 +391,7 @@ def index_from_parquet(self, dir: Path | str) -> None: self.seqs = seqs_df["seq_proforma"].values self.target_proforma = set( - (seqs_df["seq_proforma"][np.invert(seqs_df["decoy"])]).values + (seqs_df["seq_proforma"][np.invert(seqs_df["decoy"])]).values, ) frags_df = pd.read_parquet( @@ -368,7 +401,7 @@ def index_from_parquet(self, dir: Path | str) -> None: frags_df = frags_df[frags_df["mz"] < self.config.ion_mz_range[1]] self.target_proforma = set( - (seqs_df["seq_proforma"][np.invert(seqs_df["decoy"])]).values + (seqs_df["seq_proforma"][np.invert(seqs_df["decoy"])]).values, ) self.index_from_arrays( frags_df["mz"].values, @@ -376,6 +409,7 @@ def index_from_parquet(self, dir: Path | str) -> None: frag_to_prec_ids=frags_df["seq_id"].values, prec_mzs=seqs_df["seq_mz"].values, prec_seqs=seqs_df["seq_proforma"].values, + frag_positions=frags_df["ion_position"].values, ) # @profile @@ -414,22 +448,42 @@ def _dump_peptides_parquet( frag_chunk = { "mz": [], "ion_series": [], + "ion_position": [], "seq_id": [], + "precursor_mz": [], } append = False if seq_file_path.exists(): append = True - for seq_id, (frag_mzs, ion_series, prec_mzs, prec_seqs, num_frags) in enumerate( - (self.seq_properties(x) for x in iter_seqs), start=start_id - ): + + my_iter = enumerate( + (self.seq_properties(x) for x in iter_seqs), + start=start_id, + ) + for seq_id, ( + frag_mzs, + ion_positions, + ion_series, + prec_mzs, + prec_seqs, + num_frags, + ) in my_iter: seq_chunk["seq_id"].append(seq_id) seq_chunk["seq_mz"].append(prec_mzs) seq_chunk["seq_proforma"].append(prec_seqs) seq_chunk["decoy"].append(decoy) - for x, y, z in zip(frag_mzs, ion_series, [seq_id] * num_frags): + for w, x, x2, y, z in zip( + [prec_mzs] * 
num_frags, + frag_mzs, + ion_positions, + ion_series, + [seq_id] * num_frags, + ): + frag_chunk["precursor_mz"].append(float(w)) frag_chunk["mz"].append(float(x)) + frag_chunk["ion_position"].append(x2) frag_chunk["ion_series"].append(y) frag_chunk["seq_id"].append(z) @@ -438,8 +492,13 @@ def _dump_peptides_parquet( append = True write_parquet(seq_file_path, pd.DataFrame(seq_chunk), append=append) write_parquet( - fragment_file_path, pd.DataFrame(frag_chunk), append=append + fragment_file_path, + pd.DataFrame(frag_chunk), + append=append, ) + + # This just flushes the chunk so next iteration starts + # clean. for x in seq_chunk: seq_chunk[x] = [] for x in frag_chunk: @@ -457,11 +516,13 @@ def index_from_sequences(self) -> None: Usage ----- - > db = IndexedDb(...) - > db.targets = [Peptide(...), Peptide(...)] + ``` + db = IndexedDb(...) + db.targets = [Peptide(...), Peptide(...)] or - > db.targets_from_fasta(...) - > db.index_from_sequences() + db.targets_from_fasta(...) + db.index_from_sequences() + ``` Note: ---- @@ -480,14 +541,15 @@ def index_from_sequences(self) -> None: miniters=one_pct, ) - frag_mzs, frag_series, prec_mzs, prec_seqs, num_frags = zip( - *(self.seq_properties(x) for x in iter_seqs) + frag_mzs, frag_position, frag_series, prec_mzs, prec_seqs, num_frags = zip( + *(self.seq_properties(x) for x in iter_seqs), ) # NOTE: Changing to float16 does not give the correct result prec_mzs = np.array(prec_mzs, dtype="float32") prec_seqs = np.array(prec_seqs, dtype="object") frag_mzs = np.concatenate(list(frag_mzs)).astype("float32") + frag_position = np.concatenate(list(frag_position)).astype("int8") frag_series = np.concatenate(list(frag_series)) seq_ids = np.empty_like(frag_mzs, dtype=int) @@ -508,12 +570,14 @@ def index_from_sequences(self) -> None: frag_to_prec_ids=seq_ids, prec_mzs=prec_mzs, prec_seqs=prec_seqs, + frag_positions=frag_position, ) def index_from_arrays( self, frag_mzs: NDArray[np.float32], frag_series: NDArray[np.str], + frag_positions: 
NDArray[np.int8], frag_to_prec_ids: NDArray[np.int64], prec_mzs: NDArray[np.float32], prec_seqs: NDArray[np.str], @@ -527,6 +591,9 @@ def index_from_arrays( An array of fragment m/z values. frag_series : NDArray[np.str] An array of fragment ion series. + frag_series : NDArray[np.int8] + An array of the positions of the fragment ions. + (for example 1 for b1, 2 for b2, 3 for b3, etc.) frag_to_prec_ids : NDArray[np.int64] An array of sequence ids. (unique identifier of a peptide sequence) prec_mzs : NDArray[np.float32] @@ -545,34 +612,62 @@ def index_from_arrays( also have to be the same. """ - assert len(prec_mzs) == len(prec_seqs) - assert all(len(frag_mzs) == len(x) for x in [frag_series, frag_to_prec_ids]) + if not len(prec_mzs) == len(prec_seqs): + raise ValueError( + ( + "The length of the precursor mzs and the precursor sequences needs" + " to be the same." + ), + ) + if not all( + len(frag_mzs) == len(x) + for x in [frag_series, frag_to_prec_ids, frag_positions] + ): + raise ValueError( + ( + f"The length of the frag_mz {len(frag_mzs)}, frag_series" + f" {len(frag_series)} and frag_to_prec_ids {len(frag_to_prec_ids)}," + f" frag_positions {len(frag_positions)} need to be the same" + ), + ) + if not len(prec_seqs) == len(np.unique(prec_seqs)): + raise ValueError("All precursor sequences need to be unique!") # Sorted externally by ms2 mz with disabled_gc(): logger.debug( - f"Sorting by ms2 mz. {frag_mzs.size} total fragments (if needed)" + f"Sorting by ms2 mz. 
{frag_mzs.size} total fragments (if needed)", ) if not is_sorted(frag_mzs): - sorted_frags, sorted_frag_series, sorted_seq_ids = sort_all( - frag_mzs, frag_series, frag_to_prec_ids + sorted_frags, sorted_frag_series, sorted_seq_ids, frag_positions = ( + sort_all( + frag_mzs, + frag_series, + frag_to_prec_ids, + frag_positions, + ) ) logger.debug("Done sorting (and GC), generating bucketlists") else: - logger.debug("Skippping sortinb because it is already sorted.") - sorted_frags, sorted_frag_series, sorted_seq_ids = ( + logger.debug("Skippping sorting because it is already sorted.") + sorted_frags, sorted_frag_series, sorted_seq_ids, frag_positions = ( frag_mzs, frag_series, frag_to_prec_ids, + frag_positions, ) del frag_mzs, frag_to_prec_ids, frag_series + # Temporary location for this, will be moved if it seems to give better results + # TODO + MIN_POSITION = 3 + self.bucketlist = FragmentBucketList.from_arrays( - fragment_mzs=sorted_frags, - fragment_series=sorted_frag_series, - precursor_ids=sorted_seq_ids, - precursor_mzs=prec_mzs[sorted_seq_ids], + fragment_mzs=sorted_frags[frag_positions >= MIN_POSITION], + fragment_series=sorted_frag_series[frag_positions >= MIN_POSITION], + precursor_ids=sorted_seq_ids[frag_positions >= MIN_POSITION], + precursor_mzs=prec_mzs[sorted_seq_ids[frag_positions >= MIN_POSITION]], chunksize=self.chunksize, sorting_level="ms2", been_sorted=True, @@ -588,22 +683,38 @@ def index_from_arrays( # @profile def seq_properties(self, x: Peptide) -> SeqProperties: """Internal method that extracts the peptide properties to build the index.""" - masses = { - k: np.concatenate( - [x.ion_series(ion_type=k, charge=c) for c in x.config.ion_charges] - ) - for k in x.config.ion_series - } + masses = {} + ion_positions = [] + for k in x.config.ion_series: + ions = [] + positions = [] + for c in x.config.ion_charges: + curr_ions = x.ion_series(ion_type=k, charge=c) + curr_pos = np.arange(len(curr_ions)) + 1 + ions.append(curr_ions) + 
positions.append(curr_pos) + + ion_positions.extend(positions) + masses[k] = np.concatenate(ions) ion_series = np.concatenate( - [np.full_like(v, k, dtype=str) for k, v in masses.items()] + [np.full_like(v, k, dtype=str) for k, v in masses.items()], ) masses = np.concatenate(list(masses.values())) + positions = np.concatenate(ion_positions) # TODO move this to the config ... mass_mask = (masses > 150) * (masses < 2000) masses = masses[mass_mask] ion_series = ion_series[mass_mask] - out = SeqProperties(masses, ion_series, x.mz, x.to_proforma(), len(masses)) + positions = positions[mass_mask] + out = SeqProperties( + masses, + positions, + ion_series, + x.mz, + x.to_proforma(), + len(masses), + ) return out # @profile @@ -614,6 +725,8 @@ def yield_candidates( ) -> Iterator[tuple[int, float, str]]: """Yields candidate fragments that match both an ms1 and an ms2 range. + Nore: MS1 range is ignored when the database has been pre-filtered. + Parameters ---------- ms2_range : tuple[float, float] @@ -631,7 +744,7 @@ def score_arrays( precursor_mz: float | tuple[float, float], spec_mz: Iterable[float], spec_int: Iterable[float], - ) -> DataFrame: + ) -> DataFrame | None: """Scores a spectrum against the index. 
The result is a data frame containing all generic data required to @@ -657,7 +770,7 @@ def score_arrays( ms1_range = precursor_mz else: raise ValueError( - "precursor_mz has to be of length 2 or a single number" + "precursor_mz has to be of length 2 or a single number", ) else: ms1_tol = get_tolerance( @@ -669,7 +782,7 @@ def score_arrays( peaks = [] - for fragment_mz, fragment_intensity in zip(spec_mz, spec_int): + for i, (fragment_mz, fragment_intensity) in enumerate(zip(spec_mz, spec_int)): ms2_tol = get_tolerance( self.config.g_tolerances[1], theoretical=fragment_mz, @@ -680,20 +793,24 @@ def score_arrays( ms1_range=ms1_range, ms2_range=(fragment_mz - ms2_tol, fragment_mz + ms2_tol), ) + candidates = list(candidates) + dms = np.array([x[1] for x in candidates]) + dms = np.abs(dms - fragment_mz) + candidates = [x for x, y in zip(candidates, dms <= ms2_tol) if y] - for seq, frag, series in candidates: + for (seq, _frag, series), dm in zip(candidates, dms): # Should tolerances be checked here? - dm = frag - fragment_mz - if abs(dm) <= ms2_tol: - peaks.append( - { - "seq": seq, - "ion": series, - "mz": fragment_mz, - "intensity": fragment_intensity, - "error": dm, - } - ) + # IN THEORY, they should have been filtered in the past. + peaks.append( + { + "seq": seq, + "ion": series, + "mz": fragment_mz, + "intensity": fragment_intensity, + "error": dm, + "peak_id": i, + }, + ) peptide_ids = np.array([x["seq"] for x in peaks]) ids, counts = np.unique(peptide_ids, return_counts=True) @@ -749,20 +866,27 @@ def hyperscore( A dataframe with the top scoring peptides. 
""" assert len(self.seq_ids) == len(self.seqs) + assert len(self.seq_prec_mzs) == len(self.seqs) scores = self.score_arrays( - precursor_mz=precursor_mz, spec_mz=spec_mz, spec_int=spec_int + precursor_mz=precursor_mz, + spec_mz=spec_mz, + spec_int=spec_int, ) if scores is None or len(scores) == 0: return None scores["Score"] = ( - scores["log_factorial_peak_sum"] + scores["log_intensity_sums"] + scores["log_factorial_peak_sum"] + + scores["log_intensity_sums"] + # scores["log_intensity_sums"] ) # Calculate requirements for the z score among all other proposed scores! score_mean = scores["Score"].mean() score_sd = scores["Score"].std() + # TODO reconsider if I want to do the filtering here. + # If i dont, I could use the precursor information as a filter ... scores = scores.nlargest(top_n, "Score", keep="all") scores.sort_values("Score", ascending=False, inplace=True) @@ -774,14 +898,35 @@ def hyperscore( indices_seqs_local = np.searchsorted(self.seq_ids, scores["id"].values) assert np.allclose(self.seq_ids[indices_seqs_local], scores["id"].values) - scores["Peptide"] = self.seqs[indices_seqs_local] - scores["decoy"] = [s not in self.target_proforma for s in scores["Peptide"]] + scores["peptide"] = self.seqs[indices_seqs_local] + scores["PrecursorMZ"] = self.seq_prec_mzs[indices_seqs_local] + scores["decoy"] = [s not in self.target_proforma for s in scores["peptide"]] + try: + assert len(np.unique(scores["peptide"])) == len(scores), np.unique( + scores["peptide"], + return_counts=True, + ) + except AssertionError: + # There is a bug that gets detected here where a single peptide gets + # scored multiple times ... 
Usually with different IDs + + # This issue happens when a sequence is also in the decoys + logger.error( + ( + f"{scores} has multiple peptides with the " + "same id (ocasionally happens when it is both a " + "target and a decoy)" + ), + ) return scores # @profile def index_prefiltered_from_parquet( - self, cache_path: PathLike, min_mz: float, max_mz: float + self, + cache_path: PathLike, + min_mz: float, + max_mz: float, ) -> IndexedDb: """Generates a pre-filtered index from a parquet cache. @@ -795,61 +940,76 @@ def index_prefiltered_from_parquet( seqs_df = pl.scan_parquet(str(cache_path) + "/seqs.parquet") logger.info( - f"Filtering ms1 ranges {min_mz} to {max_mz} in database {self.name}" - ) - chunk_seq_df = seqs_df.filter(pl.col("seq_mz") >= min_mz).filter( - pl.col("seq_mz") < max_mz + f"Filtering ms1 ranges {min_mz} to {max_mz} in database {self.name}", ) + joint_frags = ( - frags_df.join( - chunk_seq_df.select(["seq_id", "seq_mz"]), on="seq_id", how="inner" - ) + frags_df.filter(pl.col("precursor_mz") >= min_mz) + .filter(pl.col("precursor_mz") < max_mz) + # .unique(maintain_order=False) .sort(pl.col("mz")) .collect() ) + logger.debug(f"Loaded {len(joint_frags)} fragments from parquet cache.") out = copy.copy(self) out.bucketlist = PrefilteredMS1BucketList( [ FragmentBucket( - fragment_mzs=joint_frags["mz"].to_numpy(), + fragment_mzs=joint_frags["mz"].to_numpy().astype(np.float32), fragment_series=joint_frags["ion_series"].to_numpy(), precursor_ids=joint_frags["seq_id"].to_numpy(), - precursor_mzs=joint_frags["seq_mz"].to_numpy(), + precursor_mzs=joint_frags["precursor_mz"] + .to_numpy() + .astype(np.float32), sorting_level="ms2", is_sorted=True, - ) + ), ], num_decimal=2, max_frag_mz=2000, + progress=True, ) out.prefiltered_ms1 = True # TODO change it so the only required section # is the proforma seqs that are in the mz range - chunk_seq_df_coll = chunk_seq_df.sort(pl.col("seq_id")).collect() - out.seq_prec_mzs = chunk_seq_df_coll["seq_mz"].to_numpy() - 
out.seqs = chunk_seq_df_coll["seq_proforma"].to_numpy() - out.seq_ids = chunk_seq_df_coll["seq_id"].to_numpy() - + chunk_seq_df = ( + seqs_df.filter(pl.col("seq_mz") >= min_mz) + .filter(pl.col("seq_mz") < max_mz) + .unique(maintain_order=False) + .sort(pl.col("seq_id")) + ).collect() + out.seq_prec_mzs = chunk_seq_df["seq_mz"].to_numpy().astype(np.float32) + out.seqs = chunk_seq_df["seq_proforma"].to_numpy() + out.seq_ids = chunk_seq_df["seq_id"].to_numpy() + + target_set = chunk_seq_df.select(["seq_proforma", "decoy"]) target_set = set( - chunk_seq_df.filter(pl.col("decoy") is False) - .select(["seq_proforma"]) - .collect()["seq_proforma"] + target_set["seq_proforma"].to_numpy()[ + np.invert(target_set["decoy"].to_numpy()) + ], ) + + if len(target_set) == 0: + logger.warning(f"No targets were found in range {min_mz}-{max_mz}") out.target_proforma = target_set return out def db_from_fasta( - fasta: Path | str, chunksize: int, config: DiademConfig, index: bool = True + fasta: Path | str, + chunksize: int, + config: DiademConfig, + index: bool = True, ) -> tuple[IndexedDb, str]: """Created a peak index database from a fasta file. It internally checks the existance of a cache in the form of an sqlite file. Future implementations will allow cahching in the form of parquet. """ - config_hash = config.hash() + index_config = config.index_config + config_hash = index_config.hash() file_cache = file_cache_dir(file=fasta) curr_cache = file_cache / config_hash diff --git a/diadem/index/protein_index.py b/diadem/index/protein_index.py index cc31cd0..285d445 100644 --- a/diadem/index/protein_index.py +++ b/diadem/index/protein_index.py @@ -62,7 +62,9 @@ def search_ngram(self, entry: str) -> list[str]: @staticmethod def from_fasta( - fasta_file: PathLike | str, ngram_size: int = 4, progress: bool = True + fasta_file: PathLike | str, + ngram_size: int = 4, + progress: bool = True, ) -> ProteinNGram: """Builds a protein n-gram from a fasta file. 
@@ -84,7 +86,7 @@ def from_fasta(
 
     inv_alias = {}
     for i, entry in tqdm(
-        enumerate(FASTA(fasta_file)),
+        enumerate(FASTA(str(fasta_file))),
         disable=not progress,
         desc="Building peptide n-gram index",
     ):
diff --git a/diadem/interfaces.py b/diadem/interfaces.py
new file mode 100644
index 0000000..1901726
--- /dev/null
+++ b/diadem/interfaces.py
@@ -0,0 +1,149 @@
+"""Interfaces for results from Diadem modules.
+
+This module provides a base class for defining interfaces between
+Diadem modules. Each child class should provide access to dataframes
+representing results from a Diadem function/method, optionally backed
+by a parquet file. The child class can then be used by other Diadem
+modules to perform the next step of an algorithm.
+"""
+from __future__ import annotations
+
+from abc import ABC, abstractmethod
+from collections.abc import Generator, Iterable
+from dataclasses import dataclass
+from os import PathLike
+
+import polars as pl
+
+POLY_DTYPES = (
+    pl.datatypes.FLOAT_DTYPES,
+    pl.datatypes.INTEGER_DTYPES,
+    pl.datatypes.UNSIGNED_INTEGER_DTYPES,
+)
+
+
+class BaseDiademInterface(ABC):
+    """A base class for interfaces in Diadem.
+
+    Parameters
+    ----------
+    data : DataFrame | LazyFrame
+        A polars or pandas DataFrame containing the data for the interface.
+    """
+
+    @classmethod
+    def from_parquet(cls, source: PathLike) -> BaseDiademInterface:
+        """Initialize the interface from a parquet file.
+
+        Parameters
+        ----------
+        source : PathLike
+            The parquet file.
+
+        """
+        return cls(pl.scan_parquet(source))
+
+    def __init__(self, data: pl.DataFrame | pl.LazyFrame) -> None:
+        """Initialize the interface."""
+        try:
+            self.data = data.lazy()
+        except AttributeError:
+            self.data = pl.from_pandas(data).lazy()
+
+        self.validate_schema()
+
+    def validate_schema(self) -> None:
+        """Verify that the required columns are present and the correct dtype."""
+        req_dtypes = {c.name: _check_for_poly_dtype(c.dtype) for c in self.schema}
+
+        dtype_errors = []
+        missing_columns = set(req_dtypes.keys())
+        for col, dtype in zip(self.data.columns, self.data.dtypes):
+            dtype = _check_for_poly_dtype(dtype)
+            if col in req_dtypes and not dtype == req_dtypes[col]:
+                dtype_errors.append((col, dtype, req_dtypes[col]))
+
+            try:
+                missing_columns.remove(col)
+            except KeyError:
+                pass
+
+        if not dtype_errors and not missing_columns:
+            return
+
+        msg = []
+        if dtype_errors:
+            dtype_msg = [
+                f" - {n}: {d} (Expected {','.join([str(i) for i in r])})"
+                for n, d, r in dtype_errors
+            ]
+            msg.append("Some columns were of the wrong data type:")
+            msg += dtype_msg
+
+        if missing_columns:
+            missing_msg = [f" - {c}" for c in missing_columns]
+            msg.append("Some columns were missing:")
+            msg += missing_msg
+
+        raise ValueError("\n".join(msg))
+
+    @property
+    @abstractmethod
+    def schema(self) -> Iterable[RequiredColumn]:
+        """The required columns for the underlying DataFrame."""
+
+
+@dataclass
+class RequiredColumn:
+    """Specify a required column.
+
+    Parameters
+    ----------
+    name : str
+        The column name.
+    dtype : pl.datatypes.Datatype
+        The polars data type for the column.
+    """
+
+    name: str
+    dtype: pl.datatypes.DataType
+
+    @classmethod
+    def from_iter(
+        cls,
+        columns: Iterable[tuple[str, pl.DataType], ...],
+    ) -> Generator[RequiredColumn, ...]:
+        """Create required columns from an iterable.
+
+        Parameters
+        ----------
+        columns : Iterable[tuple[str, pl.DataType]]
+            2-tuples of name-dtype pairs to be required.
+ + Yields + ------ + RequiredColumn + """ + for col, dtype in columns: + yield cls(col, dtype) + + +def _check_for_poly_dtype( + dtype: pl.datatypes.DataType, +) -> set[pl.datatypes.DataType, ...]: + """Check for poly-dtypes, like floats and ints. + + Parameters + ---------- + dtype : pl.datatypes.DataType + A polars datatype + + Returns + ------- + set[pl.datatypes.DataType] + """ + dtype = {dtype} + for poly_dtype in POLY_DTYPES: + if dtype.issubset(poly_dtype): + return poly_dtype + + return dtype diff --git a/diadem/mzml.py b/diadem/mzml.py deleted file mode 100644 index 5d94dab..0000000 --- a/diadem/mzml.py +++ /dev/null @@ -1,638 +0,0 @@ -from __future__ import annotations - -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Iterable, Iterator - -import numpy as np -from joblib import Parallel, delayed -from loguru import logger -from ms2ml import Spectrum -from ms2ml.data.adapters import MZMLAdapter -from ms2ml.utils.mz_utils import annotate_peaks -from numpy.typing import NDArray -from pandas import DataFrame -from tqdm.auto import tqdm - -from diadem.config import DiademConfig, MassError -from diadem.deisotoping import deisotope -from diadem.search.metrics import get_ref_trace_corrs -from diadem.utils import check_sorted - -try: - zip([], [], strict=True) - - def strictzip(*args: Iterable) -> Iterable: - """Like zip but checks that the length of all elements is the same.""" - return zip(*args, strict=True) - -except TypeError: - - def strictzip(*args: Iterable) -> Iterable: - """Like zip but checks that the length of all elements is the same.""" - args = [list(arg) for arg in args] - lengs = {len(x) for x in args} - if len(lengs) > 1: - raise ValueError("All arguments need to have the same legnths") - return zip(*args) - - -@dataclass -class ScanGroup: - """Represents all spectra that share an isolation window.""" - - precursor_range: tuple[float, float] - mzs: list[NDArray] - intensities: list[NDArray] - 
base_peak_mz: NDArray[np.float32] - base_peak_int: NDArray[np.float32] - retention_times: NDArray - scan_ids: list[str] - iso_window_name: str - - def __post_init__(self) -> None: - """Check that all the arrays have the same length.""" - elems = [ - self.mzs, - self.intensities, - self.base_peak_int, - self.base_peak_mz, - self.retention_times, - self.scan_ids, - ] - - # TODO move this to assertions so they can be skipped - # during runtime - lengths = {len(x) for x in elems} - if len(lengths) != 1: - raise ValueError("Not all lengths are the same") - if len(self.precursor_range) != 2: - raise ValueError( - "Precursor mass range should have 2 elements," - f" has {len(self.precursor_range)}" - ) - - def get_highest_window( - self, - window: int, - min_intensity_ratio: float, - tolerance: float, - tolerance_unit: str, - min_correlation: float, - max_peaks: int, - ) -> StackedChromatograms: - """Gets the highest intensity window of the chromatogram. - - Briefly ... - 1. Gets the highes peak accross all spectra in the chromatogram range. - 2. Finds what peaks are in that same spectrum. - 3. Looks for spectra around that spectrum. - 4. extracts the chromatogram for all mzs in the "parent spectrum" - - """ - top_index = np.argmax(self.base_peak_int) - window = StackedChromatograms.from_group( - group=self, - index=top_index, - window=window, - min_intensity_ratio=min_intensity_ratio, - min_correlation=min_correlation, - tolerance=tolerance, - tolerance_unit=tolerance_unit, - max_peaks=max_peaks, - ) - - return window - - # TODO make this just take the stacked chromatogram object - def scale_window_intensities( - self, - index: int, - scaling: NDArray, - mzs: NDArray, - window_indices: list[list[int]], - window_mzs: NDArray, - ) -> None: - """Scales the intensities of specific mzs in a window of the chromatogram. - - Parameters - ---------- - index : int - The index of the center spectrum for the the window to scale. 
- scaling : NDArray - The scaling factors to apply to the intensities. Size should be - the same as the length of the window. - mzs : NDArray - The m/z values of the peaks to scale. - window_indices : list[list[int]] - The indices of the peaks in the window to scale. These are tracked - internally during the workflow. - window_mzs : NDArray - The m/z values of the peaks in the window to scale. These are tracked - internally during the workflow. - """ - window = len(scaling) - slc, center_index = slice_from_center( - center=index, window=window, length=len(self) - ) - - # TODO this can be tracked internally ... - match_obs_mz_indices, match_win_mz_indices = annotate_peaks( - theo_mz=mzs, - mz=window_mzs, - tolerance=0.02, - unit="da", - ) - - match_win_mz_indices = np.unique(match_win_mz_indices) - - zipped = strictzip(range(*slc.indices(len(self))), scaling, window_indices) - for i, s, si in zipped: - for mz_i in match_win_mz_indices: - sim = si[mz_i] - if len(sim) > 0: - self.intensities[i][sim] = self.intensities[i][sim] * s - else: - continue - self.base_peak_int[i] = np.max(self.intensities[i]) - self.base_peak_mz[i] = self.mzs[i][np.argmax(self.intensities[i])] - - def __len__(self) -> int: - """Returns the number of spectra in the group.""" - return len(self.intensities) - - -# @profile -def xic( - query_mz: NDArray[np.float32], - query_int: NDArray[np.float32], - mzs: NDArray[np.float32], - tolerance_unit: MassError = "da", - tolerance: float = 0.02, -) -> tuple[NDArray[np.float32], list[list[int]]]: - """Gets the extracted ion chromatogram form arrays. - - Gets the extracted ion chromatogram from the passed mzs and intensities - The output should be the same length as the passed mzs. - - Returns - ------- - NDArray[np.float32] - An array of length `len(mzs)` that integrates such masses in - the query_int (matching with the query_mz array ...) 
- - list[list[int]] - A nested list of length `len(mzs)` where each sub-list contains - the indices of the `query_int` array that were integrated. - - """ - theo_mz_indices, obs_mz_indices = annotate_peaks( - theo_mz=mzs, - mz=query_mz, - tolerance=tolerance, - unit=tolerance_unit, - ) - - outs = np.zeros_like(mzs, dtype="float") - inds = [] - for i in range(len(outs)): - query_indices = obs_mz_indices[theo_mz_indices == i] - ints_subset = query_int[query_indices] - if len(ints_subset) == 0: - inds.append([]) - else: - outs[i] = np.sum(ints_subset) - inds.append(query_indices) - - return outs, inds - - -def slice_from_center(center: int, window: int, length: int) -> tuple[slice, int]: - """Generates a slice provided a center and window size. - - Creates a slice that accounts for the endings of an iterable - in such way that the window size is maintained. - - Examples - -------- - >>> my_list = [0,1,2,3,4,5,6] - >>> slc, center_index = slice_from_center( - ... center=4, window=3, length=len(my_list)) - >>> slc - slice(3, 6, None) - >>> my_list[slc] - [3, 4, 5] - >>> my_list[slc][center_index] == my_list[4] - True - - >>> slc = slice_from_center(1, 3, len(my_list)) - >>> slc - (slice(0, 3, None), 1) - >>> my_list[slc[0]] - [0, 1, 2] - - >>> slc = slice_from_center(6, 3, len(my_list)) - >>> slc - (slice(4, 7, None), 2) - >>> my_list[slc[0]] - [4, 5, 6] - >>> my_list[slc[0]][slc[1]] == my_list[6] - True - - """ - start = center - (window // 2) - end = center + (window // 2) + 1 - center_index = window // 2 - - if start < 0: - start = 0 - end = window - center_index = center - - if end >= length: - end = length - start = end - window - center_index = window - (length - center) - - slice_q = slice(start, end) - return slice_q, center_index - - -@dataclass -class StackedChromatograms: - """A class containing the elements of a stacked chromatogram. - - The stacked chromatogram is the extracted ion chromatogram - across a window of spectra. 
- - Parameters - ---------- - array : - An array of shape [i, w] - mzs : - An array of shape [i] - ref_index : - An integer in the range [0, i] - parent_index : - Identifier of the range where the window was extracted - base_peak_intensity : - Intensity of the base peak in the reference spectrum - stack_peak_indices : - List of indices used to stack the array, it is a list of dimensions [w] - - Details - ------- - The dimensions of the arrays are `w` the window - size of the extracted ion chromatogram. `i` the number - of m/z peaks that were extracted. - - """ - - array: NDArray[np.float32] - mzs: NDArray[np.float32] - ref_index: int - parent_index: int - base_peak_intensity: float - stack_peak_indices: list[list[int]] | list[NDArray[np.int32]] - center_intensities: NDArray[np.float32] - - def __post_init__(self) -> None: - """Checks that the dimensions of the arrays are correct. - - Since they are assertions, they are not meant to be needed for the - correct working of the - """ - array_i = self.array.shape[-2] - array_w = self.array.shape[-1] - - mz_i = self.mzs.shape[-1] - - assert ( - self.ref_index <= mz_i - ), f"Reference index outside of mz values {self.ref_index} > {mz_i}" - assert ( - array_i == mz_i - ), f"Intensity Array and mzs have different lengths {array_i} != {mz_i}" - for i, x in enumerate(self.stack_peak_indices): - assert len(x) == mz_i, ( - f"Number of mzs and number of indices {len(x)} != {mz_i} is different" - f" for {i}" - ) - assert array_w == len( - self.stack_peak_indices - ), "Window size is not respected in the stack" - - @property - def ref_trace(self) -> NDArray[np.float32]: - """Returns the reference trace. - - The reference trace is the extracted ion chromatogram of the - mz that corresponds to the highest intensity peak. - """ - return self.array[self.ref_index, ...] 
- - @property - def ref_mz(self) -> float: - """Returns the m/z value of the reference trace.""" - return self.mzs[self.ref_index] - - @property - def ref_fwhm(self) -> int: - """Returns the number of points in the reference trace above half max. - - Not really fwhm, just number of elements above half max. - """ - rt = self.ref_trace - rt = rt - rt.min() - above_hm = rt >= (rt.max() / 2) - return above_hm.astype(int).sum() - - def plot(self, plt) -> None: # noqa - """Plots the stacked chromatogram as lines.""" - # TODO reconsider this implementation, maybe lazy import - # of matplotlib. - plt.plot(self.array.T) - plt.plot(self.array[self.ref_index, ...].T, color="black") - plt.show() - - def trace_correlation(self) -> NDArray[np.float32]: - """Calculate the correlation between the reference trace and all other traces. - - Returns - ------- - NDArray[np.float32] - An array of shape [i] where i is the number of traces - in the stacked chromatogram. - """ - return get_ref_trace_corrs(arr=self.array, ref_idx=self.ref_index) - - @staticmethod - # @profile - def from_group( - group: ScanGroup, - index: int, - window: int = 21, - tolerance: float = 0.02, - tolerance_unit: MassError = "da", - min_intensity_ratio: float = 0.01, - min_correlation: float = 0.5, - max_peaks: int = 150, - ) -> StackedChromatograms: - """Create a stacked chromatogram from a scan group. 
- - Parameters - ---------- - group : ScanGroup - A scan group containing the spectra to stack - index : int - The index of the spectrum to use as the reference - window : int, optional - The number of spectra to stack, by default 21 - tolerance : float, optional - The tolerance to use when matching m/z values, by default 0.02 - tolerance_unit : MassError, optional - The unit of the tolerance, by default "da" - min_intensity_ratio : float, optional - The minimum intensity ratio to use when stacking, by default 0.01 - min_correlation : float, optional - The minimum correlation to use when stacking, by default 0.5 - max_peaks : int, optional - The maximum number of peaks to return in a group, by default is 150. - If the candidates is more than this number, it will the best co-eluting - peaks. - - """ - # The center index is the same as the provided index - # Except in cases where the edge of the group is reached, where - # the center index is adjusted to the edge of the group - slice_q, center_index = slice_from_center( - center=index, window=window, length=len(group.mzs) - ) - mzs = group.mzs[slice_q] - intensities = group.intensities[slice_q] - - center_mzs = mzs[center_index] - center_intensities = intensities[center_index] - - int_keep = center_intensities >= ( - center_intensities.max() * min_intensity_ratio - ) - - # num_keep = int_keep.sum() - # logger.debug("Number of peaks to stack: " - # f"{len(center_mzs)}, number above 0.1% intensity {num_keep} " - # f"[{100*num_keep/len(center_mzs):.02f} %]") - center_mzs = center_mzs[int_keep] - center_intensities = center_intensities[int_keep] - - xic_outs = [] - - for i, (m, inten) in enumerate(zip(mzs, intensities)): - xic_outs.append( - xic( - query_mz=m, - query_int=inten, - mzs=center_mzs, - tolerance=tolerance, - tolerance_unit=tolerance_unit, - ) - ) - if i == center_index: - assert xic_outs[-1][0].sum() >= center_intensities.max() - - stacked_arr = np.stack([x[0] for x in xic_outs], axis=-1) - - # TODO make 
this an array and subset it in line 457 - indices = [x[1] for x in xic_outs] - - if stacked_arr.shape[-2] > 1: - ref_id = np.argmax(stacked_arr[..., center_index]) - corrs = get_ref_trace_corrs(arr=stacked_arr, ref_idx=ref_id) - - # I think adding the 1e-5 is needed here due to numric instability - # in the flaoting point operation - assert np.max(corrs) <= ( - corrs[ref_id] + 1e-5 - ), "Reference does not have max corrr" - - max_peak_corr = np.sort(corrs)[-max_peaks] if len(corrs) > max_peaks else -1 - keep = corrs >= max(min_correlation, max_peak_corr) - - stacked_arr = stacked_arr[..., keep, ::1] - center_mzs = center_mzs[keep] - center_intensities = center_intensities[keep] - indices = [[y for y, k in zip(x, keep) if k] for x in indices] - - ref_id = np.argmax(stacked_arr[..., center_index]) - bp_int = stacked_arr[ref_id, center_index] - - out = StackedChromatograms( - array=stacked_arr, - mzs=center_mzs, - ref_index=ref_id, - parent_index=index, - base_peak_intensity=bp_int, - stack_peak_indices=indices, - center_intensities=center_intensities, - ) - return out - - -class SpectrumStacker: - """Helper class that stacks the spectra of an mzml file into chromatograms.""" - - def __init__(self, mzml_file: Path | str, config: DiademConfig) -> None: - """Initializes the SpectrumStacker class. - - Parameters - ---------- - mzml_file : Path | str - Path to the mzml file. - config : DiademConfig - The configuration object. Note that this is an DiademConfig - configuration object. - """ - self.adapter = MZMLAdapter(mzml_file, config=config) - - # TODO check if directly reading the xml is faster ... 
- # also evaluate if that is needed - scaninfo = self.adapter.get_scan_info() - self.config = config - if "DEBUG_DIADEM" in os.environ: - logger.error("RUNNING DIADEM IN DEBUG MODE (only 700-710 mz iso windows)") - scaninfo = scaninfo[scaninfo.ms_level > 1] - scaninfo = scaninfo[ - [x[0] > 700 and x[0] < 705 for x in scaninfo.iso_window] - ] - self.ms2info = scaninfo[scaninfo.ms_level > 1].copy().reset_index() - self.unique_iso_windows = set(np.array(self.ms2info.iso_window)) - - def _get_iso_window_group( - self, iso_window_name: str, iso_window: tuple[float, float], chunk: DataFrame - ) -> ScanGroup: - logger.debug(f"Processing iso window {iso_window_name}") - - window_mzs = [] - window_ints = [] - window_bp_mz = [] - window_bp_int = [] - window_rtinsecs = [] - window_scanids = [] - - npeaks_raw = [] - npeaks_deisotope = [] - - for row in tqdm( - chunk.itertuples(), desc=f"Preprocessing spectra for {iso_window_name}" - ): - spec_id = row.spec_id - curr_spec: Spectrum = self.adapter[spec_id] - # NOTE instrument seems to have a wrong value ... - # Also activation seems to not be recorded ... - curr_spec = curr_spec.filter_top(self.config.run_max_peaks_per_spec) - curr_spec = curr_spec.filter_mz_range(*self.config.ion_mz_range) - - # Deisotoping! - if self.config.run_deconvolute_spectra: - # Masses need to be ordered for the deisotoping function! 
- order = np.argsort(curr_spec.mz) - npeaks_raw.append(len(order)) - - mzs = curr_spec.mz[order] - intensities = curr_spec.intensity[order] - mzs, intensities = deisotope( - mzs, - intensities, - max_charge=5, - diff=self.config.g_tolerances[1], - unit=self.config.g_tolerance_units[1], - ) - npeaks_deisotope.append(len(mzs)) - else: - mzs = curr_spec.mz - intensities = curr_spec.intensity - # TODO evaluate this scaling - intensities = np.sqrt(intensities) - if len(mzs) == 0: - mzs, intensities = np.array([0]), np.array([0]) - bp_index = np.argmax(intensities) - bp_mz = mzs[bp_index] - bp_int = intensities[bp_index] - rtinsecs = curr_spec.retention_time.seconds() - - window_mzs.append(mzs) - window_ints.append(intensities) - window_bp_mz.append(bp_mz) - window_bp_int.append(bp_int) - window_rtinsecs.append(rtinsecs) - window_scanids.append(spec_id) - - avg_peaks_raw = np.array(npeaks_raw).mean() - avg_peaks_deisotope = np.array(npeaks_deisotope).mean() - # Create datasets within each group - logger.info(f"Saving group {iso_window_name} with length {len(window_mzs)}") - logger.info( - f"{avg_peaks_raw} peaks/spec; {avg_peaks_deisotope} peaks/spec after" - " deisotoping" - ) - - window_bp_mz = np.array(window_bp_mz).astype(np.float32) - window_bp_int = np.array(window_bp_int).astype(np.float32) - window_rtinsecs = np.array(window_rtinsecs).astype(np.float16) - check_sorted(window_rtinsecs) - window_scanids = np.array(window_scanids, dtype="object") - - group = ScanGroup( - iso_window_name=iso_window_name, - precursor_range=iso_window, - mzs=window_mzs, - intensities=window_ints, - base_peak_mz=window_bp_mz, - base_peak_int=window_bp_int, - retention_times=window_rtinsecs, - scan_ids=window_scanids, - ) - return group - - def get_iso_window_groups( - self, workerpool: None | Parallel = None - ) -> list[ScanGroup]: - """Returns a list of all ScanGroups in an mzML file.""" - grouped = self.ms2info.sort_values("RTinSeconds").groupby("iso_window") - iso_windows, chunks = 
zip(*list(grouped)) - - iso_window_names = [ - "({:.06f}, {:.06f})".format(*iso_window) for iso_window in iso_windows - ] - - if workerpool is None: - results = [ - self._get_iso_window_group( - iso_window_name=iwn, iso_window=iw, chunk=chunk - ) - for iwn, iw, chunk in zip(iso_window_names, iso_windows, chunks) - ] - else: - results = workerpool( - delayed(self._get_iso_window_group)( - iso_window_name=iwn, iso_window=iw, chunk=chunk - ) - for iwn, iw, chunk in zip(iso_window_names, iso_windows, chunks) - ) - - return results - - def yield_iso_window_groups(self, progress: bool = False) -> Iterator[ScanGroup]: - """Yield scan groups for each unique isolation window.""" - grouped = self.ms2info.sort_values("RTinSeconds").groupby("iso_window") - - for i, (iso_window, chunk) in enumerate( - tqdm(grouped, disable=not progress, desc="Unique Isolation Windows") - ): - iso_window_name = "({:.06f}, {:.06f})".format(*iso_window) - - group = self._get_iso_window_group( - iso_window_name=iso_window_name, iso_window=iso_window, chunk=chunk - ) - yield group diff --git a/diadem/search/dda.py b/diadem/search/dda.py index bf7e1b7..500a589 100644 --- a/diadem/search/dda.py +++ b/diadem/search/dda.py @@ -13,7 +13,7 @@ from diadem.config import DiademConfig from diadem.index.indexed_db import IndexedDb, db_from_fasta -from diadem.search.search_utils import make_pin +from diadem.search.mokapot import brew_run def score(db: IndexedDb, spec: Spectrum, mzml_stem: str) -> DataFrame | None: @@ -25,7 +25,10 @@ def score(db: IndexedDb, spec: Spectrum, mzml_stem: str) -> DataFrame | None: if spec is None: return None spec_results = db.hyperscore( - spec.precursor_mz, spec_mz=spec.mz, spec_int=spec.intensity, top_n=10 + spec.precursor_mz, + spec_mz=spec.mz, + spec_int=spec.intensity, + top_n=10, ) if spec_results is not None: spec_results["ScanID"] = f"{mzml_stem}::{spec.extras['id']}" @@ -75,7 +78,10 @@ def out_hook(spec: Spectrum) -> Spectrum | None: if spec is None: continue spec_results 
= db.hyperscore( - spec.precursor_mz, spec_mz=spec.mz, spec_int=spec.intensity, top_n=10 + spec.precursor_mz, + spec_mz=spec.mz, + spec_int=spec.intensity, + top_n=10, ) if spec_results is not None: spec_results["ScanID"] = f"{mzml_stem}::{spec.extras['id']}" @@ -86,12 +92,17 @@ def out_hook(spec: Spectrum) -> Spectrum | None: logger.info(f"Writting {prefix+'.csv'} and {prefix+'.parquet'}") results.to_csv(prefix + ".csv", index=False) results.to_parquet(prefix + ".parquet", index=False) - make_pin( - results, - fasta_path=fasta_path, - mzml_path=mzml_path, - pin_path=prefix + ".tsv.pin", - ) + try: + # Right now I am bypassing the mokapot results, because they break a test + # meant to check that no decoys are detected (which is true in that case). + mokapot_results = brew_run( + results, + fasta_path=fasta_path, + ms_data_path=mzml_path, + ) + mokapot_results.to_parquet(prefix + ".peptides.parquet") + except ValueError as e: + logger.error(f"Could not run mokapot: {e}") end_time = time.time() elapsed_time = end_time - start_time diff --git a/diadem/search/diadem.py b/diadem/search/diadem.py index 9cdfc95..88ccc50 100644 --- a/diadem/search/diadem.py +++ b/diadem/search/diadem.py @@ -1,6 +1,7 @@ from __future__ import annotations -import itertools +import logging +import os import time from pathlib import Path @@ -14,23 +15,22 @@ from diadem.config import DiademConfig from diadem.index.indexed_db import IndexedDb, db_from_fasta -from diadem.mzml import ScanGroup, SpectrumStacker, StackedChromatograms -from diadem.search.search_utils import make_pin +from diadem.search.mokapot import brew_run +from diadem.utilities.logging import InterceptHandler +from diadem.utilities.utils import plot_to_log +logging.basicConfig(handlers=[InterceptHandler()], level=logging.INFO, force=True) -def plot_to_log(*args, **kwargs) -> None: # noqa - """Plot to log. - - Generates a plot of the passed data to the function. - All arguments are passed internally to uniplot.plot_to_string. 
- """ - for line in uniplot.plot_to_string(*args, **kwargs): - logger.debug(line) +if "PLOTDIADEM" in os.environ and os.environ["PLOTDIADEM"]: + import matplotlib.pyplot as plt # noqa: I001 # @profile def search_group( - group: ScanGroup, db: IndexedDb, config: DiademConfig, progress: bool = True + group: ScanGroup, + db: IndexedDb, + config: DiademConfig, + progress: bool = True, ) -> DataFrame: """Search a group of scans. @@ -47,7 +47,7 @@ def search_group( WINDOWSIZE = config.run_window_size # noqa WINDOW_MAX_PEAKS = config.run_max_peaks_per_window # noqa - MIN_INTENSITY_SCALING, MAX_INTENSITY_SCALING = config.run_scalin_limits # noqa + MIN_INTENSITY_SCALING, MAX_INTENSITY_SCALING = config.run_scaling_limits # noqa SCALING_RATIO = config.run_scaling_ratio # noqa # Min intensity required on a peak in the base @@ -59,6 +59,38 @@ def search_group( MS2_TOLERANCE = config.g_tolerances[1] # noqa MS2_TOLERANCE_UNIT = config.g_tolerance_units[1] # noqa + MS1_TOLERANCE = config.g_tolerances[0] # noqa + MS1_TOLERANCE_UNIT = config.g_tolerance_units[0] # noqa + + IMS_TOLERANCE = config.g_ims_tolerance # noqa + IMS_TOLERANCE_UNIT = config.g_ims_tolerance_unit # noqa + MAX_NUM_CONSECUTIVE_FAILS = 100 # noqa + + start_rts, start_bpc = group.retention_times.copy(), group.base_peak_int.copy() + + new_window_kwargs = { + "window": WINDOWSIZE, + "min_intensity_ratio": MIN_INTENSITY_RATIO, + "min_correlation": MIN_CORR_SCORE, + "mz_tolerance": MS2_TOLERANCE, + "mz_tolerance_unit": MS2_TOLERANCE_UNIT, + "max_peaks": WINDOW_MAX_PEAKS, + } + + if hasattr(group, "imss"): + logger.info("Detected a diaPASEF dataset") + new_window_kwargs.update( + { + "ims_tolerance": IMS_TOLERANCE, + "ims_tolerance_unit": IMS_TOLERANCE_UNIT, + }, + ) + stack_getter = TimsStackedChromatograms.from_group + + else: + logger.info("No IMS detected") + stack_getter = StackedChromatograms.from_group + # Results and stats related variables group_results = [] intensity_log = [] @@ -79,8 +111,10 @@ def 
search_group( if num_fails > ALLOWED_FAILS: logger.warning( - "Exiting scoring loop because number of" - f" failes reached the maximum {ALLOWED_FAILS}" + ( + "Exiting scoring loop because number of" + f" failes reached the maximum {ALLOWED_FAILS}" + ), ) break @@ -97,8 +131,10 @@ def search_group( if last_id == new_stack.parent_index: logger.debug( - "Array generated on same index " - f"{new_stack.parent_index} as last iteration" + ( + "Array generated on same index " + f"{new_stack.parent_index} as last iteration" + ), ) num_fails += 1 else: @@ -132,16 +168,62 @@ def search_group( scores = None if scores is not None: - scores["id"] = match_id - ref_peak_mz = new_stack.mzs[new_stack.ref_index] - - mzs = itertools.chain( - *[scores[x].iloc[0] for x in scores.columns if "_mzs" in x] + scores["peak_id"] = match_id + scores["RetentionTime"] = group.retention_times[new_stack.parent_index] + if hasattr(group, "imss"): + scores["IonMobility"] = new_stack.ref_ims + if scores["decoy"].iloc[0]: + num_decoys += 1 + else: + num_targets += 1 + + scores = scores.sort_values(by="Score", ascending=False).iloc[:1] + match_indices = scores["spec_indices"].iloc[0] + [new_stack.ref_index] + match_indices = np.sort(np.unique(np.array(match_indices))) + + min_corr_score_scale = np.quantile( + new_stack.correlations[match_indices], + 0.75, ) - best_match_mzs = np.sort( - np.array(tuple(itertools.chain(mzs, [ref_peak_mz]))) + scores["q75_correlation"] = min_corr_score_scale + corr_match_indices = np.where( + new_stack.correlations > min_corr_score_scale, + )[0] + match_indices = np.sort( + np.unique(np.concatenate([match_indices, corr_match_indices])), ) + if "PLOTDIADEM" in os.environ and os.environ["PLOTDIADEM"]: + try: + ax1.cla() + ax2.cla() + except NameError: + fig, (ax1, ax2) = plt.subplots(1, 2) + + new_stack.plot(ax1, matches=match_indices) + + ax2.plot(start_rts, np.sqrt(start_bpc), alpha=0.2, color="gray") + ax2.plot(group.retention_times, np.sqrt(group.base_peak_int)) + 
ax2.vlines( + x=group.retention_times[new_stack.parent_index], + ymin=0, + ymax=np.sqrt(new_stack.base_peak_intensity), + color="r", + ) + plt.title( + ( + f"Score: {scores['Score'].iloc[0]} \n" + f" Peptide: {scores['peptide'].iloc[0]} \n" + f"@ RT: {scores['RetentionTime'].iloc[0]} \n" + f"Corr Score: {min_corr_score_scale}" + ), + ) + plt.pause(0.01) + + scaling_window_indices = [ + [x[y] for y in match_indices] for x in new_stack.stack_peak_indices + ] + # Scale based on the inverse of the reference chromatogram normalized_trace = new_stack.ref_trace / new_stack.ref_trace.max() # scaling = SCALING_RATIO * (1-normalized_trace) @@ -199,20 +281,23 @@ def search_group( pbar.update(1) if (num_peaks % DEBUG_FREQUENCY) == 0: logger.debug( - f"peak {num_peaks}/{MAX_PEAKS} max ; Intensity {curr_highest_peak_int}" + f"peak {num_peaks}/{MAX_PEAKS} max ; Intensity {curr_highest_peak_int}", ) pbar.close() plot_to_log( - np.log1p(np.array(intensity_log)), title="Max (log) intensity over time" + np.log1p(np.array(intensity_log)), + title="Max (log) intensity over time", ) plot_to_log(np.array(index_log), title="Requested index over time") plot_to_log(np.array(fwhm_log), title="FWHM across time") logger.info( - f"Done with window {group.iso_window_name}, " - f"scored {num_peaks} peaks in {len(group.base_peak_int)} spectra. " - f"Intensity of the last scored peak {curr_highest_peak_int} " - f"on index {last_id}" + ( + f"Done with window {group.iso_window_name}, " + f"scored {num_peaks} peaks in {len(group.base_peak_int)} spectra. 
" + f"Intensity of the last scored peak {curr_highest_peak_int} " + f"on index {last_id}" + ), ) group_results = pd.concat(group_results) return group_results @@ -244,7 +329,10 @@ def diadem_main( # Set up database db, cache = db_from_fasta( - fasta=fasta_path, chunksize=None, config=config, index=False + fasta=fasta_path, + chunksize=None, + config=config, + index=False, ) # set up mzml file @@ -256,9 +344,11 @@ def diadem_main( results = [] if config.run_parallelism == 1: - for group in ss.yield_iso_window_groups(progress=True): + for group in ss.yield_iso_window_groups(): group_db = db.index_prefiltered_from_parquet(cache, *group.precursor_range) group_results = search_group(group=group, db=group_db, config=config) + if group_results is not None: + group_results.to_parquet("latestresults.parquet") results.append(group_results) else: with Parallel(n_jobs=config.run_parallelism) as workerpool: @@ -282,14 +372,35 @@ def diadem_main( logger.info(f"Writting {prefix+'.csv'} and {prefix+'.parquet'}") results.to_csv(prefix + ".csv", index=False) - results.to_parquet(prefix + ".parquet", index=False, engine="pyarrow") - make_pin( - results, - fasta_path=fasta_path, - mzml_path=mzml_path, - pin_path=prefix + ".tsv.pin", - ) - end_time = time.time() - elapsed_time = end_time - start_time - logger.info(f"Elapsed time: {elapsed_time}") + # RTs are stored as f16, which need to be converted to f32 for parquet + f16_cols = list(results.select_dtypes("float16")) + if f16_cols: + for col in f16_cols: + results[col] = results[col].astype("float32") + results.to_parquet(prefix + ".parquet", index=False, engine="pyarrow") + try: + # Right now I am bypassing the mokapot results, because they break a test + # meant to check that no decoys are detected (which is true in that case). 
+ logger.info("Running mokapot") + mokapot_results = brew_run( + results, + fasta_path=fasta_path, + ms_data_path=data_path, + ) + logger.info(f"Writting mokapot results to {prefix}.peptides.parquet") + mokapot_results.to_parquet(prefix + ".peptides.parquet") + except ValueError as e: + if "decoy PSMs were detected" in str(e): + logger.warning(f"Could not run mokapot: {e}") + else: + logger.error(f"Could not run mokapot: {e}") + raise e + except RuntimeError as e: + logger.warning(f"Could not run mokapot: {e}") + logger.error(results) + raise e + finally: + end_time = time.time() + elapsed_time = end_time - start_time + logger.info(f"Elapsed time: {elapsed_time}") diff --git a/diadem/search/metrics.py b/diadem/search/metrics.py index 549b81e..55b6b5b 100644 --- a/diadem/search/metrics.py +++ b/diadem/search/metrics.py @@ -26,6 +26,34 @@ def cosinesim(x: NDArray, y: NDArray) -> NDArray: return out +def max_rolling(a: np.ndarray, window: int, axis: int = 1) -> np.ndarray: + """Max window smoothing on a numpy array. + + Parameters + ---------- + a: np.ndarray + The array to smooth. + window: int + The size of the window. + axis: int + The axis along which to smooth. + + Examples + -------- + >>> foo = np.array([[1,2,3,4,5,6,7,8], [5,6,7,8,9,10,11,12]]) + >>> max_rolling(foo, 3, -1) + array([[ 3, 4, 5, 6, 7, 8], + [ 7, 8, 9, 10, 11, 12]]) + + From this answer: + https://stackoverflow.com/a/52219082. + """ + shape = a.shape[:-1] + (a.shape[-1] - window + 1, window) + strides = a.strides + (a.strides[-1],) + rolling = np.lib.stride_tricks.as_strided(a, shape=shape, strides=strides) + return np.max(rolling, axis=axis) + + def spectral_angle(x: NDArray, y: NDArray) -> NDArray: """Computes the spectral angle between two vectors. 
@@ -75,13 +103,16 @@ def get_ref_trace_corrs(arr: NDArray[np.float32], ref_idx: int) -> NDArray[np.fl >>> [round(x, 4) for x in out] [0.8355, 0.8597, 0.8869, 0.9182, 0.9551, 0.9989, 0.9436, 0.8704, 0.7722, 0.6385] """ - norm = np.linalg.norm(arr + 1e-5, axis=-1) - normalized_arr = arr / np.expand_dims(norm, axis=-1) + arr2 = max_rolling(arr, 3, axis=-1) + arr2 = np.sqrt(arr2) + norm = np.linalg.norm(arr2 + 1e-5, axis=-1) + normalized_arr = arr2 / np.expand_dims(norm, axis=-1) ref_trace = normalized_arr[..., ref_idx, ::1] # ref_trace = np.stack([ref_trace, ref_trace[..., ::-1]]).min(axis=0) # ref_trace = np.stack([ref_trace, ref_trace[..., ::-1]]).min(axis=0) spec_angle_weights = spectral_angle( - normalized_arr.astype("float"), ref_trace.astype("float") + normalized_arr.astype("float"), + ref_trace.astype("float"), ) return spec_angle_weights diff --git a/diadem/search/mokapot.py b/diadem/search/mokapot.py new file mode 100644 index 0000000..79ba133 --- /dev/null +++ b/diadem/search/mokapot.py @@ -0,0 +1,176 @@ +"""Run-level mokapot analyses.""" +import re +from collections.abc import Iterable +from os import PathLike +from pathlib import Path + +import mokapot +import numpy as np +import pandas as pd + +from diadem.config import DiademConfig +from diadem.index.protein_index import ProteinNGram + + +def brew_run( + results: pd.DataFrame, + fasta_path: PathLike, + ms_data_path: PathLike, + config: DiademConfig, +) -> pd.DataFrame: + """Prepare the result DataFrame for mokapot. + + Parameters + ---------- + results: pd.DataFrame + The diadem search results. + fasta_path : PathLike + The FASTA file that was used for the search. + ms_data_path : PathLike + The mass spectrometry data file that was searched. + config : DiademConfig + The configuration setting. + + Returns + ------- + pd.DataFrame + The run-level peptide scores and confidence estimates. 
+ """ + input_df = _prepare_df(results, fasta_path, ms_data_path) + nonfeat = [ + "id", + "peptide", + "proteins", + "filename", + "target_pair", + "peak_id", + "is_target", + "PrecursorMZ", + "RetentionTime", + ] + # Retention time could be used if we had a model that makes use of the + # combination of RT/MZ and IMS. Since an SVM doesn't, we'll drop it. + peptides = mokapot.LinearPsmDataset( + psms=input_df, + target_column="is_target", + spectrum_columns="target_pair", + peptide_column="peptide", + protein_column="proteins", + feature_columns=[c for c in input_df.columns if c not in nonfeat], + filename_column="filename", + copy_data=False, + ) + + mokapot.PercolatorModel(train_fdr=config.train_fdr) + results, _ = mokapot.brew(peptides, test_fdr=config.eval_fdr) + targets = results.confidence_estimates["peptides"] + decoys = results.decoy_confidence_estiamtes["peptides"] + targets["is_target"] = True + decoys["is_target"] = False + return pd.concat([targets, decoys], axis=1) + + +def _prepare_df( + results: pd.DataFrame, + fasta_path: PathLike, + ms_data_path: PathLike, +) -> pd.DataFrame: + """Prepare the result DataFrame for mokapot. + + Parameters + ---------- + results: pd.DataFrame + The diadem search results. + fasta_path : PathLike + The FASTA file that was used for the search. + ms_data_path : PathLike + The mass spectrometry data file that was searched. + + Returns + ------- + pd.DataFrame + The input DataFrame for mokapot. 
+ """ + # Keep only rank 1 peptides + results = results.loc[results["rank"] == 1, :] + + # Remove all list columns + non_list_cols = [c for c in results.columns if not isinstance(results[c][0], list)] + non_list_cols = [ + c for c in non_list_cols if not isinstance(results[c][0], np.ndarray) + ] + + results = results.loc[:, non_list_cols].drop(columns="rank") + + results["filename"] = Path(ms_data_path).stem + results["target_pair"] = results["peptide"] + results.loc[results["decoy"], "target_pair"] = results.loc[ + results["decoy"], + "peptide", + ].apply(_decoy_to_target) + results["decoy"] = ~results["decoy"] + stripped_peptides = results["target_pair"].str.replace("\\[.*?\\]", "", regex=True) + results["peptide_length"] = stripped_peptides.str.len() + + # Add the pct-features + npeak_cols = [x for x in results.columns if "npeaks" in x] + for x in npeak_cols: + results[f"{x}_pct"] = 100 * results[x] / results["peptide_length"] + + # Get proteins, although not enirely necessary: + ## For decoys, this is the corresponding target protein. + results["proteins"] = stripped_peptides.apply( + _get_proteins, + ngram=ProteinNGram.from_fasta(fasta_path, progress=False), + ) + + # Rename columns to the expected values + expected_names = {"decoy": "is_target"} + results.rename(columns=expected_names, inplace=True) + return results + + +def _decoy_to_target(seq: str, permutation: Iterable[int] | None = None) -> str: + """Get the target sequence for a decoy peptide. + + Parameters + ---------- + seq : str + The decoy peptide sequence with Proforma-style modifications. + permutation : Sequence[int] | None + The permuation that was used to generate the decoy from the target sequence. + If ``None,`` it is assumed to be reversed between the termini. + + Returns + ------- + str + The target sequence that generated the decoy sequence. 
+ """ + seq = re.split(r"(?=[A-Z])", seq)[1:] + if permutation is None: + inverted = list(range(len(seq))) + inverted[1:-1] = reversed(inverted[1:-1]) + else: + inverted = [None] * len(seq) + for decoy_idx, target_idx in enumerate(permutation): + inverted[target_idx] = decoy_idx + + return "".join([seq[i] for i in inverted]) + + +def _get_proteins(peptide: str, ngram: ProteinNGram) -> str: + """Get the protein(s) that may have generated a peptide. + + Parameters + ---------- + peptide : str + The stripped peptide sequence. + ngram : ProteinNGram + The n-gram object to look-up sequences. + + Returns + ------- + str + The protein or proteins delimited by semi-colons. + """ + return ";".join(ngram.search_ngram(peptide)) diff --git a/diadem/search/search_utils.py b/diadem/search/search_utils.py deleted file mode 100644 index 6fdd459..0000000 --- a/diadem/search/search_utils.py +++ /dev/null @@ -1,57 +0,0 @@ -from os import PathLike -from pathlib import Path - -from pandas import DataFrame -from pyteomics.proforma import parse - -from diadem.index.protein_index import ProteinNGram - - -def make_pin( - results: DataFrame, fasta_path: PathLike, mzml_path: PathLike, pin_path: PathLike -) -> None: - """Makes a '.pin' file from a dataframe of results. - - It writes the .pin file to disk (`pin_path` argument). - """ - # Postprocessing of the results dataframe for percolator - ## 1. 
class InterceptHandler(logging.Handler):
    """Intercept a logging call and send it to Loguru."""

    def emit(self, record: LogRecord) -> None:
        """Intercept a logging call and send it to Loguru."""
        # Map the stdlib level name onto the Loguru level of the same name,
        # falling back to the numeric level for custom/unknown levels.
        try:
            loguru_level = logger.level(record.levelname).name
        except ValueError:
            loguru_level = record.levelno

        # Walk up the stack past the stdlib logging machinery so Loguru
        # reports the original call site rather than this handler.
        depth = 6
        frame = sys._getframe(depth)
        while frame and frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(
            loguru_level,
            record.getMessage(),
        )
@dataclass
class IndexBipartite:
    """Simple bipartite graph structure.

    This dataclass is meant to contain connections between indices
    in other structures: "left" indices connect to "right" indices.

    Examples
    --------
    >>> bg = IndexBipartite()
    >>> bg.add_connection(1, 2)
    >>> bg.add_connection(1, 3)
    >>> bg.add_connection(2, 2)
    >>> bg.get_neighbor_indices()
    (array([1, 2]), array([2, 3]))
    """

    left_neighbors: dict = field(default_factory=dict)
    right_neighbors: dict = field(default_factory=dict)

    def add_connection(self, left: int, right: int) -> None:
        """Add a single left<->right connection to the graph."""
        self.left_neighbors.setdefault(left, []).append(right)
        self.right_neighbors.setdefault(right, []).append(left)

    def get_neighbor_indices(self) -> tuple[NDArray, NDArray]:
        """Return the sorted indices that have at least one neighbor.

        Returns
        -------
        x, y: tuple[NDArray, NDArray]
            Arrays of the left (x) and right (y) indices that have a
            neighbor in their counterpart. Every element in each array
            is unique.

        Examples
        --------
        >>> bg = IndexBipartite()
        >>> bg.add_connection(1, 1)
        >>> bg.add_connection(1, 2)
        >>> bg.add_connection(1, 3)
        >>> bg.add_connection(2, 2)
        >>> bg.get_neighbor_indices()
        (array([1, 2]), array([1, 2, 3]))
        """
        lefts = np.sort(np.array(list(self.left_neighbors)))
        rights = np.sort(np.array(list(self.right_neighbors)))
        return lefts, rights

    def get_matching_indices(self) -> tuple[NDArray, NDArray]:
        """Return the paired neighbor indices as two same-length arrays.

        Unlike `get_neighbor_indices`, element i of the first array is a
        neighbor of element i of the second, so indices are repeated when
        an element has several neighbors.

        Examples
        --------
        >>> bg = IndexBipartite()
        >>> bg.add_connection(1, 1)
        >>> bg.add_connection(1, 2)
        >>> bg.add_connection(1, 3)
        >>> bg.add_connection(2, 2)
        >>> bg.get_matching_indices()
        (array([1, 1, 1, 2]), array([1, 2, 3, 2]))
        """
        out_left = []
        out_right = []
        for left, neighbors in self.left_neighbors.items():
            for right in neighbors:
                out_left.append(left)
                out_right.append(right)
        return np.array(out_left), np.array(out_right)

    @classmethod
    def from_arrays(cls, x: NDArray, y: NDArray) -> IndexBipartite:
        """Build a bipartite index from two equal-length index arrays.

        Element i of ``y`` is recorded as a neighbor of element i of ``x``.

        Raises
        ------
        ValueError
            If the two arrays differ in length.

        Examples
        --------
        >>> a = np.array([1, 1, 1, 51])
        >>> b = np.array([2, 3, 4, 1])
        >>> IndexBipartite.from_arrays(a, b).left_neighbors
        {1: [2, 3, 4], 51: [1]}
        """
        if len(x) != len(y):
            raise ValueError("x and y must be the same length")

        new = cls()
        for left, right in zip(x, y):
            new.add_connection(left, right)
        return new

    def intersect(self, other: IndexBipartite) -> IndexBipartite:
        """Return a new graph containing only connections present in both.

        Examples
        --------
        >>> bg1 = IndexBipartite.from_arrays([1, 1, 1, 1, 2], [1, 2, 3, 4, 2])
        >>> bg2 = IndexBipartite.from_arrays([1, 1, 1, 2, 2], [1, 2, 3, 2, 3])
        >>> bg1.intersect(bg2).left_neighbors
        {1: [1, 2, 3], 2: [2]}
        """
        out = IndexBipartite()
        for left, neighbors in self.left_neighbors.items():
            other_neighbors = other.left_neighbors.get(left)
            if other_neighbors is None:
                continue
            for right in neighbors:
                if right in other_neighbors:
                    out.add_connection(left, right)
        return out


def _default_dist_fun(x: float, y: float) -> float:
    """Default (signed, additive) distance function to use."""
    return y - x


def _is_sorted(a: NDArray) -> bool:
    """Check that ``a`` is in non-decreasing order.

    Local stand-in for ``diadem.utilities.utils.is_sorted`` so that this
    module's search routines are self-contained.
    """
    return bool(np.all(a[:-1] <= a[1:]))


@dataclass
class NeighborFinder:
    """Class to find neighbors in a multidimensional space.

    Parameters
    ----------
    dist_ranges: dict[str, tuple[float, float]]
        Minimum and maximum allowed distance per dimension, keyed by the
        dimension name.
    dist_funs: dict[str, callable]
        Distance function per dimension; each takes two floats and returns
        a float. Defaults to ``y - x`` for every dimension.
    order: tuple[str]
        Order in which the dimensions are searched. If ``None``, the key
        order of ``dist_ranges`` is used.
    force_vectorized: bool
        Reserved flag; the current search implementation does not use it.
    """

    dist_ranges: dict[str, tuple[float, float]]
    dist_funs: dict[str, callable]
    order: tuple[str]
    force_vectorized: bool

    def __post_init__(self) -> None:
        """Fill in defaults and validate that the keys are consistent."""
        if self.order is None:
            self.order = tuple(self.dist_ranges.keys())

        if self.dist_funs is None:
            self.dist_funs = {k: _default_dist_fun for k in self.dist_ranges}

        if not set(self.dist_ranges.keys()) == set(self.dist_funs.keys()):
            raise ValueError("dist_ranges and dist_funs must have the same keys")

    def find_neighbors(
        self,
        elems1: dict[str, NDArray],
        elems2: dict[str, NDArray],
    ) -> IndexBipartite:
        """Find neighbors between ``elems1`` and ``elems2``.

        BUG FIX: this previously forwarded the unsupported keyword
        arguments ``order=`` and ``force_vectorized=`` to
        `multidim_neighbor_search`, raising a TypeError on every call.
        """
        return multidim_neighbor_search(
            elems1=elems1,
            elems2=elems2,
            dist_ranges=self.dist_ranges,
            dist_funs=self.dist_funs,
            dimension_order=self.order,
        )


# @profile
def find_neighbors_sorted(
    x: NDArray,
    y: NDArray,
    dist_fun: callable,
    low_dist: float,
    high_dist: float,
    allowed_neighbors: dict[int, set[int]] | None = None,
) -> IndexBipartite:
    """Find neighbors between two sorted arrays.

    Parameters
    ----------
    x: NDArray
        First sorted array to use to find neighbors.
    y: NDArray
        Second sorted array to use to find neighbors.
    dist_fun: callable
        Function to calculate the distance between an element in ``x`` and
        an element in ``y``. The distance is directional and must increase
        with ``y``: ``dist_fun(x[0], y[0]) < dist_fun(x[0], y[1])``
        whenever ``y[1] > y[0]``.
    low_dist: float
        Lowest value allowable as a distance for two elements to be
        considered neighbors.
    high_dist: float
        Highest value allowable as a distance for two elements to be
        considered neighbors.
    allowed_neighbors: dict[int, set[int]] | None
        Optional restriction: keys are indices into ``x`` and values are
        the ``y`` indices each is allowed to match. ``None`` allows all
        pairs.

    Examples
    --------
    >>> x = np.array([1., 2., 3., 4., 5., 15., 25.])
    >>> y = np.array([1.1, 2.3, 3.1, 4., 25., 25.1])
    >>> out = find_neighbors_sorted(x, y, lambda a, b: b - a, -0.11, 0.11)
    >>> out.left_neighbors
    {0: [0], 2: [2], 3: [3], 6: [4, 5]}
    """
    assert _is_sorted(x)
    assert _is_sorted(y)
    assert low_dist < high_dist
    assert dist_fun(low_dist, high_dist) > 0
    assert dist_fun(high_dist, low_dist) < 0

    neighbors = IndexBipartite()
    restart_idx = 0

    if allowed_neighbors is None:
        iter_x = range(len(x))
    else:
        iter_x = sorted(allowed_neighbors.keys())

    for i in iter_x:
        x_val = x[i]
        last_diff = None

        if allowed_neighbors is None:
            iter_y = range(restart_idx, len(y))
        else:
            iter_y = sorted(allowed_neighbors[i])

        for j in iter_y:
            diff = dist_fun(x_val, y[j])

            # Sanity-check that dist_fun is monotonic in y.
            # TODO disable this for performance ...
            if last_diff is not None:
                assert diff >= last_diff
            last_diff = diff

            if diff < low_dist:
                # j is too far below x_val; since both arrays are sorted,
                # later x values can also skip everything up to j.
                restart_idx = j
                continue
            if diff > high_dist:
                break

            neighbors.add_connection(i, j)
    return neighbors


# @profile
def multidim_neighbor_search(
    elems1: dict[str, NDArray],
    elems2: dict[str, NDArray] | None,
    dist_ranges: dict[str, tuple[float, float]],
    dist_funs: None | dict[str, callable] = None,
    dimension_order: None | tuple[str] = None,
) -> IndexBipartite:
    """Search for neighbors in multiple dimensions.

    Parameters
    ----------
    elems1: dict[str, NDArray]
        See elems2.
    elems2: dict[str, NDArray] | None
        Dictionaries of equal-length arrays, one entry per dimension.
        If ``None``, ``elems1`` is searched against itself.
    dist_ranges: dict[str, tuple[float, float]]
        Maximum and minimum ranges for each of the dimensions.
        NOTE(review): matching uses additive bounds
        (``x + low <= y <= x + high``); the ``dist_funs`` are only used
        for validation — confirm this is intended for ppm-style ranges.
    dist_funs:
        Dictionary of functions used to calculate distances. For details
        check the documentation of `find_neighbors_sorted`.
    dimension_order:
        Optional tuple of strings denoting what dimensions to use, and in
        what order.

    Examples
    --------
    >>> x1 = {"d1": np.array([1000., 1000., 2001., 3000.]),
    ...       "d2": np.array([1000., 1000.3, 2000., 3000.01])}
    >>> x2 = {"d1": np.array([1000.01, 1000.01, 2000., 3000.]),
    ...       "d2": np.array([1000.01, 1000.01, 2000., 3001.01])}
    >>> d_funs = {"d1": lambda x, y: 1e6 * (y - x) / abs(x),
    ...           "d2": lambda x, y: y - x}
    >>> d_ranges = {"d1": (-10, 10), "d2": (-0.02, 0.02)}
    >>> multidim_neighbor_search(x1, x2, d_ranges, d_funs).left_neighbors
    {0: {0, 1}, 2: {2}}
    """
    if dimension_order is None:
        dimension_order = list(elems1.keys())

    if dist_funs is None:
        dist_funs = {k: _default_dist_fun for k in dimension_order}

    # Sort both element sets by the first dimension, remembering the
    # original positions so results can be mapped back at the end.
    elems_1_order = np.argsort(elems1[dimension_order[0]])
    elems_1 = {k: v[elems_1_order] for k, v in elems1.items()}
    elems_1_indices = np.arange(len(elems1[dimension_order[0]]))[elems_1_order]

    if elems2 is not None:
        elems_2_order = np.argsort(elems2[dimension_order[0]])
        elems_2 = {k: v[elems_2_order] for k, v in elems2.items()}
        elems_2_indices = np.arange(len(elems2[dimension_order[0]]))[elems_2_order]
    else:
        elems_2 = elems_1
        elems_2_indices = elems_1_indices

    return _multidim_neighbor_search(
        elems_1=elems_1,
        elems_2=elems_2,
        elems_1_indices=elems_1_indices,
        elems_2_indices=elems_2_indices,
        dist_ranges=dist_ranges,
        dist_funs=dist_funs,
        dimension_order=dimension_order,
    )


# @profile
def _multidim_neighbor_search(
    elems_1: dict[str, NDArray],
    elems_2: dict[str, NDArray],
    elems_1_indices: NDArray,
    elems_2_indices: NDArray,
    dist_ranges: dict[str, tuple[float, float]],
    dist_funs: dict[str, callable],
    dimension_order: tuple[str],
) -> IndexBipartite:
    """Search for neighbors in multiple dimensions.

    This internal function is used by `multidim_neighbor_search` and is
    not intended to be used directly, since it does not have the safety
    guarantees that the public function has. For the parameter details
    please check the documentation of the public function.
    """
    assert _is_sorted(elems_1[dimension_order[0]])

    # Validate the per-dimension ranges/functions once; they are invariant
    # over the main loop (previously re-asserted per element, per dimension).
    for dimension in dimension_order:
        low_dist, high_dist = dist_ranges[dimension]
        dist_fun = dist_funs[dimension]
        assert low_dist < high_dist
        assert dist_fun(low_dist, high_dist) > 0
        assert dist_fun(high_dist, low_dist) < 0

    neighbors = IndexBipartite()
    for i in range(len(elems_1[dimension_order[0]])):
        # `allowable_indices` narrows, dimension by dimension, the set of
        # elems_2 indices that may still match element i.
        allowable_indices = None
        curr_dimension_matches = []
        for dimension in dimension_order:
            low_dist, high_dist = dist_ranges[dimension]
            y = elems_2[dimension]
            x_val = elems_1[dimension][i]

            if allowable_indices is None:
                # First dimension: both sides are sorted on it, so the
                # candidate window is found by bisection.
                # NOTE(review): the upper searchsorted uses side="left",
                # so an exact upper-boundary match is excluded here but
                # included (<=) in later dimensions — confirm intended.
                curr_dimension_matches = np.arange(
                    np.searchsorted(y, x_val + low_dist),
                    np.searchsorted(y, x_val + high_dist),
                )
            else:
                if len(allowable_indices) == 0:
                    curr_dimension_matches = []
                    break
                # Only the survivors of the previous dimension are tested.
                # (A dead "speed test" statement recomputing a full-array
                # mask every pass was removed here.)
                candidates = y[allowable_indices]
                in_range = (candidates >= x_val + low_dist) & (
                    candidates <= x_val + high_dist
                )
                curr_dimension_matches = allowable_indices[in_range]

            allowable_indices = curr_dimension_matches

        # After going through all dimensions, the remaining indices are
        # the neighbors of element i (in sorted coordinates).
        for j in curr_dimension_matches:
            neighbors.add_connection(i, j)

    # Map the sorted positions back to the original indices.
    # TODO check if vectorizing this is faster
    return IndexBipartite(
        left_neighbors={
            elems_1_indices[k]: {elems_2_indices[w] for w in v}
            for k, v in neighbors.left_neighbors.items()
        },
        right_neighbors={
            elems_2_indices[k]: {elems_1_indices[w] for w in v}
            for k, v in neighbors.right_neighbors.items()
        },
    )
@@ -114,9 +126,10 @@ def get_slice_inds(arr: NDArray, minval: float, maxval: float) -> slice: # slice_max = np.searchsorted(arr[slice_min:], maxval, side="right") # slice_max = slice_min + slice_max i = 0 - for i, val in enumerate(arr[slice_min:]): + for val in arr[slice_min:]: if val > maxval: break + i += 1 slice_max = slice_min + i return slice(slice_min, slice_max) diff --git a/profiling/.dockerignore b/profiling/.dockerignore new file mode 100644 index 0000000..235a3a8 --- /dev/null +++ b/profiling/.dockerignore @@ -0,0 +1,2 @@ + +* diff --git a/profiling/.dvc/.gitignore b/profiling/.dvc/.gitignore new file mode 100644 index 0000000..528f30c --- /dev/null +++ b/profiling/.dvc/.gitignore @@ -0,0 +1,3 @@ +/config.local +/tmp +/cache diff --git a/profiling/.dvc/config b/profiling/.dvc/config new file mode 100644 index 0000000..e69de29 diff --git a/profiling/.dvcignore b/profiling/.dvcignore new file mode 100644 index 0000000..5197305 --- /dev/null +++ b/profiling/.dvcignore @@ -0,0 +1,3 @@ +# Add patterns of files dvc should ignore, which could improve +# the performance. 
Learn more at +# https://dvc.org/doc/user-guide/dvcignore diff --git a/profiling/.gitignore b/profiling/.gitignore index 055cf29..5e0e3bf 100644 --- a/profiling/.gitignore +++ b/profiling/.gitignore @@ -1,6 +1,18 @@ diadem.csv diadem.parquet +**diadem.csv +**diadem.tsv.pin +**diadem.parquet profiling_data/* line_profile*.txt profile*.txt + +**mokapot*.txt + +profiling_data_bkp +*/*.parquet +*.parquet + +dvc_plots +bkp diff --git a/profiling/Dockerfile b/profiling/Dockerfile new file mode 100644 index 0000000..04764bc --- /dev/null +++ b/profiling/Dockerfile @@ -0,0 +1,7 @@ +# FROM python:3.9-slim +FROM --platform=linux/amd64 python:3.9-bullseye + +RUN apt-get update && apt-get install -y build-essential gcc python3-dev +RUN python3 -m pip install git+https://github.com/jspaezp/alphatims.git@feature/dockerfile + +ENTRYPOINT [ "alphatims" ] diff --git a/profiling/configs/orbi.toml b/profiling/configs/orbi.toml new file mode 100644 index 0000000..d6b2586 --- /dev/null +++ b/profiling/configs/orbi.toml @@ -0,0 +1,50 @@ +peptide_length_range = [ + 7, + 25, +] +peptide_mz_range = [ + 400, + 2000, +] +precursor_charges = [ + 2, + 3, +] +ion_series = "by" +ion_charges = [ + 1, +] +ion_mz_range = [ + 250, + 2000.0, +] +db_enzyme = "trypsin" +db_max_missed_cleavages = 2 +db_bucket_size = 32768 +g_tolerances = [ + 0.02, + 0.02, +] +g_tolerance_units = [ + "da", + "da", +] +g_ims_tolerance = 0.01 +g_ims_tolerance_unit = "abs" +scoring_score_function = "Hyperscore" +run_max_peaks = 20000 +run_max_peaks_per_spec = 5000 +run_parallelism = 1 +run_deconvolute_spectra = true +run_min_peak_intensity = 100 +run_debug_log_frequency = 50 +run_allowed_fails = 5000 +run_window_size = 21 +run_max_peaks_per_window = 500 +run_min_intensity_ratio = 0.01 +run_min_correlation_score = 0.5 +run_scaling_ratio = 0.001 +run_scaling_limits = [ + 0.001, + 0.9, +] diff --git a/profiling/configs/tims.toml b/profiling/configs/tims.toml new file mode 100644 index 0000000..21670b8 --- /dev/null +++ 
b/profiling/configs/tims.toml @@ -0,0 +1,50 @@ +peptide_length_range = [ + 7, + 25, +] +peptide_mz_range = [ + 400, + 2000, +] +precursor_charges = [ + 2, + 3, +] +ion_series = "by" +ion_charges = [ + 1, +] +ion_mz_range = [ + 250, + 2000.0, +] +db_enzyme = "trypsin" +db_max_missed_cleavages = 2 +db_bucket_size = 32768 +g_tolerances = [ + 0.02, + 0.02, +] +g_tolerance_units = [ + "da", + "da", +] +g_ims_tolerance = 0.01 +g_ims_tolerance_unit = "abs" +scoring_score_function = "Hyperscore" +run_max_peaks = 20000 +run_max_peaks_per_spec = 5000 +run_parallelism = 1 +run_deconvolute_spectra = true +run_min_peak_intensity = 100 +run_debug_log_frequency = 50 +run_allowed_fails = 5000 +run_window_size = 21 +run_max_peaks_per_window = 150 +run_min_intensity_ratio = 0.01 +run_min_correlation_score = 0.2 +run_scaling_ratio = 0.001 +run_scaling_limits = [ + 0.001, + 0.9, +] diff --git a/profiling/dvc.lock b/profiling/dvc.lock new file mode 100644 index 0000000..dec5cfe --- /dev/null +++ b/profiling/dvc.lock @@ -0,0 +1,272 @@ +schema: '2.0' +stages: + tims_run_hela: + cmd: mkdir -p results/tims && python src/run.py --config configs/tims.toml --fasta + profiling_data/UP000005640_9606_crap.fasta --ms_data profiling_data/230426_Hela_01_S4-E5_1_662.hdf + --output results/tims/hela --threads 4 && python src/plot_results.py results/tims + --prefix hela + deps: + - path: profiling_data/230426_Hela_01_S4-E5_1_662.hdf + md5: 02f5846145d04ff9301f78d2b6299b75 + size: 4412473566 + - path: profiling_data/UP000005640_9606_crap.fasta + md5: 6a22e753f0e35eae9e0bf3d95759d089 + size: 13676378 + - path: src/run.py + md5: 074ac3e7f0c1020f179d8fdec9c677c2 + size: 1095 + params: + configs/tims.toml: + db_bucket_size: 32768 + db_enzyme: trypsin + db_max_missed_cleavages: 2 + g_ims_tolerance: 0.01 + g_ims_tolerance_unit: abs + g_tolerance_units: + - da + - da + g_tolerances: + - 0.02 + - 0.02 + ion_charges: + - 1 + ion_mz_range: + - 250 + - 2000.0 + ion_series: by + peptide_length_range: + - 7 + - 25 
+ peptide_mz_range: + - 400 + - 2000 + precursor_charges: + - 2 + - 3 + run_allowed_fails: 5000 + run_debug_log_frequency: 50 + run_deconvolute_spectra: true + run_max_peaks: 20000 + run_max_peaks_per_spec: 5000 + run_max_peaks_per_window: 150 + run_min_correlation_score: 0.2 + run_min_intensity_ratio: 0.01 + run_min_peak_intensity: 100 + run_parallelism: 1 + run_scaling_limits: + - 0.001 + - 0.9 + run_scaling_ratio: 0.001 + run_window_size: 21 + scoring_score_function: Hyperscore + outs: + - path: results/tims/hela.diadem.csv + md5: 9d3443227ddc4d7c0c742548336e4e9a + size: 15179553 + - path: results/tims/hela_log_score_histogram_peptide.png + md5: 3b04c6c5d0871886be7c9c366cd558ec + size: 22777 + - path: results/tims/hela_metrics.toml + md5: 592c82dba0c84c0e5114bb1ca058ef5b + size: 160 + - path: results/tims/hela_peptide_qval.png + md5: fc3f68c6dd84ef4306915186c3aaf69e + size: 19942 + - path: results/tims/hela_runtime.toml + md5: 6235b1e40562e00e58c2a7eaff18a064 + size: 27 + - path: results/tims/hela_score_histogram_peptide.png + md5: 03826bdd891ad915a2187016b8a82a6b + size: 21475 + - path: results/tims/hela_score_histogram_psm.png + md5: 3e88a2f284818880ba3abddc95d7b465 + size: 22476 + tims_run_ecoli: + cmd: python src/run.py --config configs/tims.toml --fasta profiling_data/UP000000625_83333_crap.fasta + --ms_data profiling_data/LFQ_timsTOFPro_diaPASEF_Ecoli_01.hdf --output results/tims/ecoli + --threads 4 && python src/plot_results.py results/tims --prefix ecoli + deps: + - path: profiling_data/LFQ_timsTOFPro_diaPASEF_Ecoli_01.hdf + md5: 764764d9da856aab25d562b4c8c835e2 + size: 17521478250 + - path: profiling_data/UP000000625_83333_crap.fasta + md5: 1b9a2db686eb2024edbd33b7f24117ef + size: 1891750 + - path: src/run.py + md5: 074ac3e7f0c1020f179d8fdec9c677c2 + size: 1095 + - path: src/run.py + md5: 074ac3e7f0c1020f179d8fdec9c677c2 + size: 1095 + params: + configs/tims.toml: + db_bucket_size: 32768 + db_enzyme: trypsin + db_max_missed_cleavages: 2 + 
g_ims_tolerance: 0.01 + g_ims_tolerance_unit: abs + g_tolerance_units: + - da + - da + g_tolerances: + - 0.02 + - 0.02 + ion_charges: + - 1 + ion_mz_range: + - 250 + - 2000.0 + ion_series: by + peptide_length_range: + - 7 + - 25 + peptide_mz_range: + - 400 + - 2000 + precursor_charges: + - 2 + - 3 + run_allowed_fails: 5000 + run_debug_log_frequency: 50 + run_deconvolute_spectra: true + run_max_peaks: 20000 + run_max_peaks_per_spec: 5000 + run_max_peaks_per_window: 150 + run_min_correlation_score: 0.2 + run_min_intensity_ratio: 0.01 + run_min_peak_intensity: 100 + run_parallelism: 1 + run_scaling_limits: + - 0.001 + - 0.9 + run_scaling_ratio: 0.001 + run_window_size: 21 + scoring_score_function: Hyperscore + outs: + - path: results/tims/ecoli.diadem.csv + md5: 999fb7588b75653e5a245753bf4b4070 + size: 70538015 + - path: results/tims/ecoli_log_score_histogram_peptide.png + md5: df9d4b3a8a21c54ae46b476ec0697224 + size: 23229 + - path: results/tims/ecoli_metrics.toml + md5: 0988437bd1b818d6aeeb1d23bb7793f4 + size: 166 + - path: results/tims/ecoli_peptide_qval.png + md5: 498b784d3e0cd132eac8f93b8bf3fd7a + size: 18364 + - path: results/tims/ecoli_runtime.toml + md5: a076a35fbcf20c8aecae926d53802364 + size: 27 + - path: results/tims/ecoli_score_histogram_peptide.png + md5: e0edfb0a28f3c264c6de36dd1d38c9b4 + size: 20845 + - path: results/tims/ecoli_score_histogram_psm.png + md5: 0b62302a8e3cbd56d80084501158cd49 + size: 24514 + get_data: + cmd: zsh src/get_data.zsh + deps: + - path: Dockerfile + md5: a5f9ec03e6ab364f6179b60079fba806 + size: 264 + - path: src/get_data.zsh + md5: 83ea92e422447d99c7d8c6508b5f0807 + size: 1896 + outs: + - path: profiling_data/230407_Chrom_60m_1ug_v2_01.mzML + md5: 19efb578f116318ce63f707dd0193ad9 + size: 1452019888 + - path: profiling_data/230426_Hela_01_S4-E5_1_662.hdf + md5: 02f5846145d04ff9301f78d2b6299b75 + size: 4412473566 + - path: profiling_data/LFQ_timsTOFPro_diaPASEF_Ecoli_01.hdf + md5: 764764d9da856aab25d562b4c8c835e2 + size: 
17521478250 + - path: profiling_data/UP000000625_83333_crap.fasta + md5: 1b9a2db686eb2024edbd33b7f24117ef + size: 1891750 + - path: profiling_data/UP000005640_9606_crap.fasta + md5: 6a22e753f0e35eae9e0bf3d95759d089 + size: 13676378 + orbi_run_hela: + cmd: mkdir -p results/orbi && python src/run.py --config configs/orbi.toml --fasta + profiling_data/UP000005640_9606_crap.fasta --ms_data profiling_data/230407_Chrom_60m_1ug_v2_01.mzML + --output results/orbi/hela --threads 4 && python src/plot_results.py results/orbi + --prefix hela + deps: + - path: profiling_data/230407_Chrom_60m_1ug_v2_01.mzML + md5: 19efb578f116318ce63f707dd0193ad9 + size: 1452019888 + - path: profiling_data/UP000005640_9606_crap.fasta + md5: 6a22e753f0e35eae9e0bf3d95759d089 + size: 13676378 + - path: src/run.py + md5: 074ac3e7f0c1020f179d8fdec9c677c2 + size: 1095 + params: + configs/orbi.toml: + db_bucket_size: 32768 + db_enzyme: trypsin + db_max_missed_cleavages: 2 + g_ims_tolerance: 0.01 + g_ims_tolerance_unit: abs + g_tolerance_units: + - da + - da + g_tolerances: + - 0.02 + - 0.02 + ion_charges: + - 1 + ion_mz_range: + - 250 + - 2000.0 + ion_series: by + peptide_length_range: + - 7 + - 25 + peptide_mz_range: + - 400 + - 2000 + precursor_charges: + - 2 + - 3 + run_allowed_fails: 5000 + run_debug_log_frequency: 50 + run_deconvolute_spectra: true + run_max_peaks: 20000 + run_max_peaks_per_spec: 5000 + run_max_peaks_per_window: 500 + run_min_correlation_score: 0.5 + run_min_intensity_ratio: 0.01 + run_min_peak_intensity: 100 + run_parallelism: 1 + run_scaling_limits: + - 0.001 + - 0.9 + run_scaling_ratio: 0.001 + run_window_size: 21 + scoring_score_function: Hyperscore + outs: + - path: results/orbi/hela.diadem.csv + md5: a8d10c84ee9694b796b94fceed271bea + size: 28876637 + - path: results/orbi/hela_log_score_histogram_peptide.png + md5: e3f464612fa3dd216b493dbddce16b89 + size: 23397 + - path: results/orbi/hela_metrics.toml + md5: 8a906bf8cf89ed1bc13b4efda147bba1 + size: 160 + - path: 
results/orbi/hela_peptide_qval.png + md5: 965d5a315b0911a3a51d3c9acc980500 + size: 18880 + - path: results/orbi/hela_runtime.toml + md5: cdc4367963d2d0888807c4875405f114 + size: 28 + - path: results/orbi/hela_score_histogram_peptide.png + md5: 2c4ef5bb0055f822bf5c8cf88c85e21e + size: 23869 + - path: results/orbi/hela_score_histogram_psm.png + md5: 4d496191e3cca110117abb5b5f9ad478 + size: 20897 diff --git a/profiling/dvc.yaml b/profiling/dvc.yaml new file mode 100644 index 0000000..5f093aa --- /dev/null +++ b/profiling/dvc.yaml @@ -0,0 +1,119 @@ + +stages: + get_data: + cmd: zsh src/get_data.zsh + deps: + - src/get_data.zsh + - Dockerfile + outs: + - profiling_data/LFQ_timsTOFPro_diaPASEF_Ecoli_01.hdf + - profiling_data/230426_Hela_01_S4-E5_1_662.hdf + - profiling_data/230407_Chrom_60m_1ug_v2_01.mzML + - profiling_data/UP000000625_83333_crap.fasta + - profiling_data/UP000005640_9606_crap.fasta + + orbi_run_hela: + params: + - configs/orbi.toml: + plots: + - results/orbi/hela_log_score_histogram_peptide.png: + cache: false + - results/orbi/hela_score_histogram_peptide.png: + cache: false + - results/orbi/hela_peptide_qval.png: + cache: false + - results/orbi/hela_score_histogram_psm.png: + cache: false + metrics: + - results/orbi/hela_metrics.toml: + cache: false + - results/orbi/hela_runtime.toml: + cache: false + deps: + - src/run.py + - profiling_data/230407_Chrom_60m_1ug_v2_01.mzML + - profiling_data/UP000005640_9606_crap.fasta + outs: + - results/orbi/hela.diadem.csv + cmd: >- + mkdir -p results/orbi && + python src/run.py + --config configs/orbi.toml + --fasta profiling_data/UP000005640_9606_crap.fasta + --ms_data profiling_data/230407_Chrom_60m_1ug_v2_01.mzML + --output results/orbi/hela + --threads 4 && + python src/plot_results.py results/orbi --prefix hela + + tims_run_hela: + params: + - configs/tims.toml: + deps: + - src/run.py + - profiling_data/UP000005640_9606_crap.fasta + - profiling_data/230426_Hela_01_S4-E5_1_662.hdf + outs: + - 
results/tims/hela.diadem.csv + plots: + - results/tims/hela_log_score_histogram_peptide.png: + cache: false + - results/tims/hela_score_histogram_peptide.png: + cache: false + - results/tims/hela_peptide_qval.png: + cache: false + - results/tims/hela_score_histogram_psm.png: + cache: false + metrics: + - results/tims/hela_metrics.toml: + cache: false + - results/tims/hela_runtime.toml: + cache: false + cmd: >- + mkdir -p results/tims && + python src/run.py + --config configs/tims.toml + --fasta profiling_data/UP000005640_9606_crap.fasta + --ms_data profiling_data/230426_Hela_01_S4-E5_1_662.hdf + --output results/tims/hela + --threads 4 && + python src/plot_results.py results/tims --prefix hela + + tims_run_ecoli: + params: + - configs/tims.toml: + deps: + - src/run.py + - profiling_data/UP000000625_83333_crap.fasta + - profiling_data/LFQ_timsTOFPro_diaPASEF_Ecoli_01.hdf + - src/run.py + outs: + - results/tims/ecoli.diadem.csv + plots: + - results/tims/ecoli_log_score_histogram_peptide.png: + cache: false + - results/tims/ecoli_score_histogram_peptide.png: + cache: false + - results/tims/ecoli_peptide_qval.png: + cache: false + - results/tims/ecoli_score_histogram_psm.png: + cache: false + metrics: + - results/tims/ecoli_metrics.toml: + cache: false + - results/tims/ecoli_runtime.toml: + cache: false + cmd: >- + python src/run.py + --config configs/tims.toml + --fasta profiling_data/UP000000625_83333_crap.fasta + --ms_data profiling_data/LFQ_timsTOFPro_diaPASEF_Ecoli_01.hdf + --output results/tims/ecoli + --threads 4 && + python src/plot_results.py results/tims --prefix ecoli + + +plots: + - results/tims/ecoli_log_score_histogram_peptide.png + - results/tims/ecoli_score_histogram_peptide.png + - results/tims/ecoli_peptide_qval.png + - results/tims/ecoli_score_histogram_psm.png diff --git a/profiling/get_data.bash b/profiling/get_data.bash deleted file mode 100644 index c3372c5..0000000 --- a/profiling/get_data.bash +++ /dev/null @@ -1,10 +0,0 @@ -#!/bin/bash - 
-mkdir -p profiling_data -aws s3 cp --profile mfa s3://data-pipeline-mzml-bucket/221229_ChessFest_Plate3/Chessfest_Plate3_RH4_DMSO_DIA.mzML.gz ./profiling_data/. -aws s3 cp --profile mfa s3://data-pipeline-metadata-bucket/uniprot_human_sp_canonical_2021-11-19_crap.fasta ./profiling_data/. -curl ftp.pride.ebi.ac.uk/pride/data/archive/2022/02/PXD028735/LFQ_timsTOFPro_PASEF_Ecoli_01.d.zip --output ./profiling_data/ecoli_timsTOFPro_PASEF.d.zip -curl ftp.pride.ebi.ac.uk/pride/data/archive/2022/02/PXD028735/LFQ_timsTOFPro_diaPASEF_Ecoli_01.d.zip --output ./profiling_data/ecoli_timsTOFPro_diaPASEF.d.zip -gunzip ./profiling_data/*.gz -for i in ./profiling_data/*.zip ; do unzip $i ; done -mv LFQ_timsTOFPro_* ./profiling_data/. diff --git a/profiling/lineprofile_dia.zsh b/profiling/lineprofile_dia.zsh deleted file mode 100644 index 1f6aab8..0000000 --- a/profiling/lineprofile_dia.zsh +++ /dev/null @@ -1,42 +0,0 @@ - -# This one does need you to go into the code and decorate with @profile -# what you want profiled -mkdir -p lineprofile_results -sed -ie "s/# @profile/@profile/g" ../diadem/**/*.py -python -m pip install ../. 
-sed -ie "s/@profile/# @profile/g" ../diadem/**/*.py - -set -x -set -e - -# PYTHONOPTIMIZE=1 -DEBUG_DIADEM=1 python -m kernprof -l run_profile.py -python -m line_profiler run_profile.py.lprof > "line_profile_$(date '+%Y%m%d_%H%M').txt" -python -m line_profiler run_profile.py.lprof > "line_profile_latest.txt" - -R -e 'library(tidyverse) ; foo = readr::read_tsv("lineprofile_results/results.diadem.tsv.pin") ; g <- ggplot(foo, aes(x = `Score`, fill = factor(Label))) + geom_density(alpha=0.4) ; ggsave("lineprofile_results/raw_td_plot.png", plot = g)' -R -e 'library(tidyverse) ; foo = readr::read_tsv("lineprofile_results/results.diadem.tsv.pin") ; g <- ggplot(foo, aes(x = seq_along(`Score`), y = Score, colour = factor(Label))) + geom_point(alpha=0.4) ; ggsave("lineprofile_results/iter_score_plot.png", plot = g)' - -mokapot lineprofile_results/results.diadem.tsv.pin --test_fdr 0.01 --keep_decoys -R -e 'library(tidyverse) ; foo = readr::read_tsv("mokapot.peptides.txt") ; foo2 = readr::read_tsv("mokapot.decoy.peptides.txt") ; g <- ggplot(bind_rows(foo, foo2), aes(x = `mokapot score`, fill = Label)) + geom_density(alpha=0.4) ; ggsave("lineprofile_results/td_plot.png", plot = g)' - -# 20230207 6pm Elapsed time: 1571.8632419109344 -# 2023-02-07 20:02:42.108 | INFO | diadem.search.diadem:diadem_main:284 - Elapsed time: 1022.3919858932495 - -# Python deisotoping -# 2023-02-08 20:46:37.470 | INFO | diadem.search.diadem:diadem_main:290 - Elapsed time: 313.4971549510956 -# [INFO] - 83 target PSMs and 20 decoy PSMs detected. - -# Sage Deisotoping -# 2023-02-09 17:07:52.691 | INFO | diadem.search.diadem:diadem_main:289 - Elapsed time: 193.32771015167236 -# [INFO] - 180 target PSMs and 40 decoy PSMs detected. 
- -# Changing score tracking from objects to list and increased number of fails -# [INFO] - Found 356 peptides with q<=0.05 -# 2023-02-10 13:26:32.888 | INFO | diadem.search.diadem:diadem_main:290 - Elapsed time: 698.4216511249542 - -# DEBUG_DIADEM=1 PYTHONOPTIMIZE=1 python -m cProfile -s tottime run_profile.py > profile_singlethread.txt -# 2023-02-07 18:48:02.815 | INFO | diadem.search.diadem:diadem_main:284 - Elapsed time: 572.985399723053 - -# Changing num_decimal to 2 from 3, this will make buckets larger -# Changing number of isotopes diff --git a/profiling/profile_dia.zsh b/profiling/profile_dia.zsh deleted file mode 100644 index a5c00c8..0000000 --- a/profiling/profile_dia.zsh +++ /dev/null @@ -1,6 +0,0 @@ - -python -m pip install ../. -DEBUG_DIADEM=1 python -m cProfile -s tottime -m diadem.cli search \ - --mzml_file profiling_data/Chessfest_Plate3_RH4_DMSO_DIA.mzML \ - --fasta profiling_data/uniprot_human_sp_canonical_2021-11-19_crap.fasta \ - --out_prefix "" --mode DIA > profile.txt diff --git a/profiling/results/orbi/hela_log_score_histogram_peptide.png b/profiling/results/orbi/hela_log_score_histogram_peptide.png new file mode 100644 index 0000000..ec05210 Binary files /dev/null and b/profiling/results/orbi/hela_log_score_histogram_peptide.png differ diff --git a/profiling/results/orbi/hela_metrics.toml b/profiling/results/orbi/hela_metrics.toml new file mode 100644 index 0000000..3abd18c --- /dev/null +++ b/profiling/results/orbi/hela_metrics.toml @@ -0,0 +1,5 @@ +NumPeptides = 8697 +AvgTargetScore = 32.83279324619224 +TargetQ95Score = 44.880507377442605 +AvgDecoyScore = 29.001643309619226 +DecoyQ95Score = 35.013654165931776 diff --git a/profiling/results/orbi/hela_peptide_qval.png b/profiling/results/orbi/hela_peptide_qval.png new file mode 100644 index 0000000..0940849 Binary files /dev/null and b/profiling/results/orbi/hela_peptide_qval.png differ diff --git a/profiling/results/orbi/hela_runtime.toml b/profiling/results/orbi/hela_runtime.toml new 
file mode 100644 index 0000000..8352549 --- /dev/null +++ b/profiling/results/orbi/hela_runtime.toml @@ -0,0 +1 @@ +runtime = 3261.9539229869843 diff --git a/profiling/results/orbi/hela_score_histogram_peptide.png b/profiling/results/orbi/hela_score_histogram_peptide.png new file mode 100644 index 0000000..af3c2b4 Binary files /dev/null and b/profiling/results/orbi/hela_score_histogram_peptide.png differ diff --git a/profiling/results/orbi/hela_score_histogram_psm.png b/profiling/results/orbi/hela_score_histogram_psm.png new file mode 100644 index 0000000..1b4ed6b Binary files /dev/null and b/profiling/results/orbi/hela_score_histogram_psm.png differ diff --git a/profiling/results/orbi/hela_scores_over_time.png b/profiling/results/orbi/hela_scores_over_time.png new file mode 100644 index 0000000..5875b14 Binary files /dev/null and b/profiling/results/orbi/hela_scores_over_time.png differ diff --git a/profiling/results/tims/ecoli_log_score_histogram_peptide.png b/profiling/results/tims/ecoli_log_score_histogram_peptide.png new file mode 100644 index 0000000..d4e951d Binary files /dev/null and b/profiling/results/tims/ecoli_log_score_histogram_peptide.png differ diff --git a/profiling/results/tims/ecoli_metrics.toml b/profiling/results/tims/ecoli_metrics.toml new file mode 100644 index 0000000..7c9e580 --- /dev/null +++ b/profiling/results/tims/ecoli_metrics.toml @@ -0,0 +1,5 @@ +NumPeptides_q_0.01 = 6904 +AvgTargetScore = 22.11691792428762 +TargetQ95Score = 34.39576249711483 +AvgDecoyScore = 20.221616627128082 +DecoyQ95Score = 25.000040717505303 diff --git a/profiling/results/tims/ecoli_peptide_qval.png b/profiling/results/tims/ecoli_peptide_qval.png new file mode 100644 index 0000000..e800d51 Binary files /dev/null and b/profiling/results/tims/ecoli_peptide_qval.png differ diff --git a/profiling/results/tims/ecoli_runtime.toml b/profiling/results/tims/ecoli_runtime.toml new file mode 100644 index 0000000..71611c4 --- /dev/null +++ 
b/profiling/results/tims/ecoli_runtime.toml @@ -0,0 +1 @@ +runtime = 4270.734843254089 diff --git a/profiling/results/tims/ecoli_score_histogram_peptide.png b/profiling/results/tims/ecoli_score_histogram_peptide.png new file mode 100644 index 0000000..1240aee Binary files /dev/null and b/profiling/results/tims/ecoli_score_histogram_peptide.png differ diff --git a/profiling/results/tims/ecoli_score_histogram_psm.png b/profiling/results/tims/ecoli_score_histogram_psm.png new file mode 100644 index 0000000..648038c Binary files /dev/null and b/profiling/results/tims/ecoli_score_histogram_psm.png differ diff --git a/profiling/results/tims/ecoli_scores_over_time.png b/profiling/results/tims/ecoli_scores_over_time.png new file mode 100644 index 0000000..fc711ca Binary files /dev/null and b/profiling/results/tims/ecoli_scores_over_time.png differ diff --git a/profiling/results/tims/hela_log_score_histogram_peptide.png b/profiling/results/tims/hela_log_score_histogram_peptide.png new file mode 100644 index 0000000..c88c5bf Binary files /dev/null and b/profiling/results/tims/hela_log_score_histogram_peptide.png differ diff --git a/profiling/results/tims/hela_metrics.toml b/profiling/results/tims/hela_metrics.toml new file mode 100644 index 0000000..a2233a9 --- /dev/null +++ b/profiling/results/tims/hela_metrics.toml @@ -0,0 +1,5 @@ +NumPeptides_q_0.01 = 2442 +AvgTargetScore = 22.473263643744477 +TargetQ95Score = 32.60714890338063 +AvgDecoyScore = 20.481250731007073 +DecoyQ95Score = 25.627900520272156 diff --git a/profiling/results/tims/hela_peptide_qval.png b/profiling/results/tims/hela_peptide_qval.png new file mode 100644 index 0000000..1f47926 Binary files /dev/null and b/profiling/results/tims/hela_peptide_qval.png differ diff --git a/profiling/results/tims/hela_runtime.toml b/profiling/results/tims/hela_runtime.toml new file mode 100644 index 0000000..4607ce7 --- /dev/null +++ b/profiling/results/tims/hela_runtime.toml @@ -0,0 +1 @@ +runtime = 1483.474233865738 diff 
--git a/profiling/results/tims/hela_score_histogram_peptide.png b/profiling/results/tims/hela_score_histogram_peptide.png new file mode 100644 index 0000000..d8c5f4d Binary files /dev/null and b/profiling/results/tims/hela_score_histogram_peptide.png differ diff --git a/profiling/results/tims/hela_score_histogram_psm.png b/profiling/results/tims/hela_score_histogram_psm.png new file mode 100644 index 0000000..eec9356 Binary files /dev/null and b/profiling/results/tims/hela_score_histogram_psm.png differ diff --git a/profiling/results/tims/hela_scores_over_time.png b/profiling/results/tims/hela_scores_over_time.png new file mode 100644 index 0000000..699253e Binary files /dev/null and b/profiling/results/tims/hela_scores_over_time.png differ diff --git a/profiling/run_profile.py b/profiling/run_profile.py deleted file mode 100644 index 3897ea7..0000000 --- a/profiling/run_profile.py +++ /dev/null @@ -1,10 +0,0 @@ -from diadem import cli -from diadem.config import DiademConfig - -cli.setup_logger() -cli.diadem_main( - fasta_path="./profiling_data/uniprot_human_sp_canonical_2021-11-19_crap.fasta", - mzml_path="./profiling_data/Chessfest_Plate3_RH4_DMSO_DIA.mzML", - config=DiademConfig(run_parallelism=1), - out_prefix="lineprofile_results/results", -) diff --git a/profiling/run_profile_multithread.py b/profiling/run_profile_multithread.py deleted file mode 100644 index 9f95ab1..0000000 --- a/profiling/run_profile_multithread.py +++ /dev/null @@ -1,10 +0,0 @@ -from diadem import cli -from diadem.config import DiademConfig - -cli.setup_logger(level="INFO") -cli.diadem_main( - fasta_path="./profiling_data/uniprot_human_sp_canonical_2021-11-19_crap.fasta", - mzml_path="./profiling_data/Chessfest_Plate3_RH4_DMSO_DIA.mzML", - config=DiademConfig(run_parallelism=-2), - out_prefix="lineprofile_results_multithread/results", -) diff --git a/profiling/run_profile_multithread.zsh b/profiling/run_profile_multithread.zsh deleted file mode 100644 index 0fe745d..0000000 --- 
a/profiling/run_profile_multithread.zsh +++ /dev/null @@ -1,9 +0,0 @@ - -PYTHONOPTIMIZE=1 python run_profile_multithread.py - -R -e 'library(tidyverse) ; foo = readr::read_tsv("mokapot.peptides.txt") ; foo2 = readr::read_tsv("mokapot.decoy.peptides.txt") ; g <- ggplot(bind_rows(foo, foo2), aes(x = `mokapot score`, fill = Label)) + geom_density(alpha=0.4) ; ggsave("lineprofile_results_multithread/td_plot.png", plot = g)' -R -e 'library(tidyverse) ; foo = readr::read_tsv("lineprofile_results_multithread/results.diadem.tsv.pin") ; g <- ggplot(foo, aes(x = `Score`, fill = factor(Label))) + geom_density(alpha=0.4) ; ggsave("lineprofile_results_multithread/raw_td_plot.png", plot = g)' -R -e 'library(tidyverse) ; foo = readr::read_tsv("lineprofile_results_multithread/results.diadem.tsv.pin") ; g <- ggplot(foo, aes(x = seq_along(`Score`), y = Score, colour = factor(Label))) + geom_point(alpha=0.4) ; ggsave("lineprofile_results_multithread/iter_score_plot.png", plot = g)' - -# [INFO] - Found 26288 peptides with q<=0.01 -# 2023-02-13 12:43:55.191 | INFO | diadem.search.diadem:diadem_main:295 - Elapsed time: 3487.5772829055786 diff --git a/profiling/src/get_data.zsh b/profiling/src/get_data.zsh new file mode 100644 index 0000000..80784ac --- /dev/null +++ b/profiling/src/get_data.zsh @@ -0,0 +1,38 @@ +#!/bin/bash + +mkdir -p profiling_data + +## Fasta Files +# Ecoli +curl https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/reference_proteomes/Bacteria/UP000000625/UP000000625_83333.fasta.gz --output ./profiling_data/UP000000625_83333.fasta.gz +gunzip ./profiling_data/UP000000625_83333.fasta.gz + +# Human +curl https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/reference_proteomes/Eukaryota/UP000005640/UP000005640_9606.fasta.gz --output ./profiling_data/UP000005640_9606.fasta.gz +gunzip ./profiling_data/UP000005640_9606.fasta.gz +aws s3 cp --profile mfa s3://data-pipeline-metadata-bucket/contaminants.fasta ./profiling_data/. 
+ +cat ./profiling_data/contaminants.fasta ./profiling_data/UP000000625_83333.fasta >> ./profiling_data/UP000000625_83333_crap.fasta +cat ./profiling_data/contaminants.fasta ./profiling_data/UP000005640_9606.fasta >> ./profiling_data/UP000005640_9606_crap.fasta + +# Raw Data +# Ecoli +# TimsTof +curl ftp.pride.ebi.ac.uk/pride/data/archive/2022/02/PXD028735/LFQ_timsTOFPro_diaPASEF_Ecoli_01.d.zip --output ./profiling_data/ecoli_timsTOFPro_diaPASEF.d.zip + +# Human +# Orbi +aws s3 cp --profile mfa s3://tmp-jspp-diadem-assets/230407_Chrom_60m_1ug_v2_01.mzML.gz ./profiling_data/. + +# TimsTof +aws s3 cp --profile mfa s3://tmp-jspp-diadem-assets/230426_Hela_01.d.tar ./profiling_data/. + + +for i in ./profiling_data/*.zip ; do unzip $i -d profiling_data ; done +for i in ./profiling_data/*.tar ; do tar -xf $i -C profiling_data ; done +for i in ./profiling_data/*.gz ; do gunzip -d $i ; done + +# This is done in docker ... still waiting for the mann lab to check my PR +docker build -t alphatims_docker . 
+docker run --rm -it -v ${PWD}/profiling_data/:/data/ alphatims_docker export hdf /data/230426_Hela_01_S4-E5_1_662.d +docker run --rm -it -v ${PWD}/profiling_data/:/data/ alphatims_docker export hdf /data/LFQ_timsTOFPro_diaPASEF_Ecoli_01.d diff --git a/profiling/src/plot_results.py b/profiling/src/plot_results.py new file mode 100644 index 0000000..7ebd6c8 --- /dev/null +++ b/profiling/src/plot_results.py @@ -0,0 +1,136 @@ +import argparse +from pathlib import Path + +import matplotlib.pyplot as plt +import numpy as np +import polars as pl +import vizta # First import vizta +from loguru import logger + +vizta.mpl.set_theme() + +parser = argparse.ArgumentParser() +parser.add_argument("base_dir", type=str) +parser.add_argument( + "--prefix", + type=str, + help=( + "prefix to use to subset the parquet files, for instance 'hela' will only match" + " '{base_dir}/hela*.parquet'" + ), +) + + +def main(): + args = parser.parse_args() + base_dir = args.base_dir + prefix = args.prefix + + # Find the data + parquet_matches = list(Path(base_dir).glob(f"{prefix}*.parquet")) + peptide_matches = [x for x in parquet_matches if "peptides" in x.name] + parquet_matches = [x for x in parquet_matches if "peptides" not in x.name] + logger.info(f"Found {len(parquet_matches)} parquet files") + logger.info(f"Found {len(peptide_matches)} peptide files") + + # Load the data + assert len(parquet_matches) == 1 + assert len(peptide_matches) == 1 + parquet_path = parquet_matches[0] + + # Plot the data + foo = pl.scan_parquet(parquet_path) + df = foo.select(pl.col(["Score", "decoy", "peptide"])).collect() + bins = np.histogram_bin_edges(df["Score"].to_numpy(), bins=100) + plt.hist( + df.filter(pl.col("decoy").is_not())["Score"], + alpha=0.6, + label="target", + bins=bins, + linewidth=0.01, + ) + plt.hist( + df.filter(pl.col("decoy"))["Score"], + alpha=0.6, + label="decoy", + bins=bins, + linewidth=0.01, + ) + plt.legend() + 
plt.title(f"PSM Score Histogram\n{prefix}") + plt.xlabel("score") + plt.ylabel("Frequency") + plt.savefig(Path(base_dir) / f"{prefix}_score_histogram_psm.png") + plt.clf() + + df = df.groupby("peptide").max() + bins = np.histogram_bin_edges(df["Score"].to_numpy(), bins=100) + plt.hist( + df.filter(pl.col("decoy").is_not())["Score"], + alpha=0.6, + label="target", + bins=bins, + linewidth=0.01, + ) + plt.hist( + df.filter(pl.col("decoy"))["Score"], + alpha=0.6, + label="decoy", + bins=bins, + linewidth=0.01, + ) + plt.legend() + plt.title(f"Peptide Score Histogram\n{prefix}") + plt.xlabel("score") + plt.ylabel("Frequency") + plt.savefig(Path(base_dir) / f"{prefix}_score_histogram_peptide.png") + + plt.yscale("log") + plt.title(f"Peptide Score Histogram (log scale)\n{prefix}") + plt.xlabel("log(score)") + plt.ylabel("Frequency") + plt.savefig(Path(base_dir) / f"{prefix}_log_score_histogram_peptide.png") + plt.clf() + + plt.scatter( + y=df["Score"], + x=np.arange(len(df)), + c=["red" if x else "blue" for x in df["decoy"]], + s=0.5, + alpha=0.4, + ) + plt.xlabel("Iteration") + plt.ylabel("Score") + plt.title(f"Peptide Score Over Iterations\n{prefix}") + plt.savefig(Path(base_dir) / f"{prefix}_scores_over_time.png") + plt.clf() + + pep_parquet = pl.scan_parquet(peptide_matches[0]) + qvals = ( + pep_parquet.filter(pl.col("is_target") & (pl.col("mokapot q-value") < 0.05)) + .select(pl.col(["mokapot q-value"])) + .sort("mokapot q-value") + .collect() + ) + plt.plot(qvals, np.arange(len(qvals))) + plt.title(f"Peptide q-values\n{prefix}") + plt.savefig(Path(base_dir) / f"{prefix}_peptide_qval.png") + plt.clf() + + metrics = {} + metrics["NumPeptides_q_0.01"] = len(qvals.filter(pl.col("mokapot q-value") < 0.01)) + metrics["AvgTargetScore"] = df.filter(pl.col("decoy").is_not())["Score"].mean() + metrics["TargetQ95Score"] = df.filter(pl.col("decoy").is_not())["Score"].quantile( + 0.95, + ) + metrics["AvgDecoyScore"] = df.filter(pl.col("decoy"))["Score"].mean() + 
metrics["DecoyQ95Score"] = df.filter(pl.col("decoy"))["Score"].quantile(0.95) + + # Write metrics to toml file + with open(Path(base_dir) / f"{prefix}_metrics.toml", "w") as f: + for k, v in metrics.items(): + f.write(f"{k} = {v}\n") + + +if __name__ == "__main__": + main() diff --git a/profiling/src/run.py b/profiling/src/run.py new file mode 100644 index 0000000..fe84a9a --- /dev/null +++ b/profiling/src/run.py @@ -0,0 +1,34 @@ +import argparse +import time +from dataclasses import replace + +from diadem import cli +from diadem.config import DiademConfig + +parser = argparse.ArgumentParser() +parser.add_argument("--config", type=str, help="Path to the config file") +parser.add_argument("--fasta", type=str, help="Path to the FASTA file") +parser.add_argument("--ms_data", type=str, help="Path to the MS data file") +parser.add_argument("--output", type=str, help="Path to the output file path") +parser.add_argument("--threads", type=int, help="Number of threads to use") + + +if __name__ == "__main__": + args, unk = parser.parse_known_args() + if unk: + raise ValueError(f"Unrecognized arguments: {unk}") + + config = DiademConfig.from_toml(args.config) + config = replace(config, run_parallelism=args.threads) + + st = time.time() + cli.setup_logger() + cli.diadem_main( + fasta_path=args.fasta, + data_path=args.ms_data, + config=config, + out_prefix=args.output, + ) + tt = time.time() - st + with open(args.output + "_runtime.toml", "w") as f: + f.write(f"runtime = {tt}\n") diff --git a/pyproject.toml b/pyproject.toml index b077b1e..077413c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,11 +6,11 @@ build-backend = "setuptools.build_meta" name = "diadem" authors = [ {name = "Sebastian Paez", email = "spaez@talus.bio"}, - {name = "William E. Fondrie", email = "wfrondie@talus.bio"}, {name = "Carollyn Allen", email = "callen@talus.bio"}, + {name = "William E. 
Fondrie", email = "wfrondie@talus.bio"}, ] description = "A modular, feature-centric toolkit for DIA proteomics" -requires-python = ">=3.9,<=3.12" +requires-python = ">=3.9,<3.11" keywords = ["proteomics", "dia", "mass spec"] license = {text = "Apache 2.0"} classifiers = [ @@ -20,10 +20,9 @@ classifiers = [ "Topic :: Scientific/Engineering :: Bio-Informatics", ] dependencies = [ - "pandas >= 1.5.2", - "numpy >= 1.23.5", - # "ms2ml >= 0.0.33", - "ms2ml >= 0.0.31", + "pandas >= 2.0.0", + "numpy == 1.23.5", # Pinned because of numba + "ms2ml >= 0.0.35", "tqdm >= 4.64.1", "loguru >= 0.6.0", "rich-click >= 1.6.0", @@ -31,13 +30,18 @@ dependencies = [ "pyarrow >= 10.0.1", "platformdirs >= 2.6.0", "joblib >= 1.2.0", - # Deisotoping - "brain-isotopic-distribution >= 1.5.11", - "ms-peak-picker >= 0.1.40", - "ms-deisotope >= 0.0.46", + "mokapot >= 0.9.1", + "alphatims >= 1.0.6", + "hdf5plugin", + "polars >= 0.16.9", + "torch >= 2.0.0", + "scikit-learn >= 1.2.2" ] dynamic = ["version"] +[project.scripts] +diadem = "diadem.cli:main_cli" + [project.readme] file = "README.md" content-type = "text/markdown" @@ -55,6 +59,16 @@ test = [ profiling = [ "line_profiler", ] +plot = [ + "matplotlib", + "vizta", +] +dev = [ + "ruff >= 0.0.253", + "black >= 23.1.0", + "isort >= 5.12.0", + "pylance >= 0.3.9", +] [tool.setuptools.packages.find] @@ -63,6 +77,7 @@ include = ["diadem"] [tool.pytest.ini_options] minversion = "6.0" addopts = "--doctest-modules -v" +doctest_optionflags = "NORMALIZE_WHITESPACE" testpaths = [ "diadem", "tests", @@ -70,7 +85,8 @@ testpaths = [ [tool.ruff] line-length = 88 -select = ["E", "F", "W", "C", "I", "D", "UP", "N", "ANN", "T20"] +select = ["E", "F", "B","W", "C", "I", "D", "UP", "N", "ANN", "T20", "COM"] +target-version = "py39" # ANN101 Missing type annotation for `self` in method # D213 Multi-line docstring summary should start at the second lin @@ -86,6 +102,8 @@ fix = true "*tests/*.py" = ["ANN"] # D104 is missing docstring in public package 
"**__init__.py" = ["D104"] +# Implements a sklearn interface with X and X_hat variables/params. +"diadem/aggregate/imputers.py" = ["N803", "N806"] # ANN001 Missing type annotation for function argument # Ignoring in the cli since it is redundant with the click options diff --git a/tests/conftest.py b/tests/conftest.py index cd2bcec..66aab9b 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -166,5 +166,8 @@ def albumin_peptides(): "LVAASQAALGLLVAASQAALGL", ] + # ALBUMIN_PEPTIDE_SEQS = set(ALBUMIN_PEPTIDE_SEQS) + # I am not making this a set on purposen to make sere the runtime handles it + # correctly out = [Peptide.from_proforma_seq(f"{x}/2") for x in ALBUMIN_PEPTIDE_SEQS] return out diff --git a/tests/test_database_build.py b/tests/test_database_build.py index e955a73..e8c522c 100644 --- a/tests/test_database_build.py +++ b/tests/test_database_build.py @@ -1,3 +1,5 @@ +import numpy as np + from diadem.config import DiademConfig from diadem.index.indexed_db import IndexedDb @@ -13,7 +15,7 @@ def test_peptide_scoring(sample_peaks, albumin_peptides): db.targets = albumin_peptides db.index_from_sequences() scores = db.hyperscore(z2_mass, mzs, ints) - assert "VPQVSTPTLVEVSR/2" in set(scores["Peptide"]) + assert "VPQVSTPTLVEVSR/2" in set(scores["peptide"]) return db @@ -26,5 +28,5 @@ def test_database_from_fasta(shared_datadir, sample_peaks): mzs, ints, z2_mass = sample_peaks scores = db.hyperscore(z2_mass, mzs, ints) - assert "VPQVSTPTLVEVSR/2" in set(scores["Peptide"]) + assert "VPQVSTPTLVEVSR/2" in set(scores["peptide"][np.invert(scores["decoy"])]) return db diff --git a/tests/test_database_parquet_cache.py b/tests/test_database_parquet_cache.py index 5e37e01..0c63795 100644 --- a/tests/test_database_parquet_cache.py +++ b/tests/test_database_parquet_cache.py @@ -1,3 +1,4 @@ +import numpy as np import pandas as pd from diadem.config import DiademConfig @@ -26,4 +27,7 @@ def test_parquet_generation(shared_datadir, tmpdir, sample_peaks): 
db.index_from_parquet(tmpdir) mzs, ints, z2_mass = sample_peaks scores = db.hyperscore(z2_mass, mzs, ints) - assert "VPQVSTPTLVEVSR/2" in set(scores["Peptide"]) + assert "VPQVSTPTLVEVSR/2" in set(scores["peptide"]) + + pep_scores = scores[scores["peptide"] == "VPQVSTPTLVEVSR/2"] + assert all(np.invert(pep_scores["decoy"])) diff --git a/tests/test_database_scoring.py b/tests/test_database_scoring.py index c9b12ea..5baca34 100644 --- a/tests/test_database_scoring.py +++ b/tests/test_database_scoring.py @@ -15,7 +15,9 @@ def test_fasta_shows_in_db(shared_datadir): config = DiademConfig() ms2ml_config = config.ms2ml_config adapter = FastaAdapter( - shared_datadir / "BSA.fasta", config=ms2ml_config, only_unique=True + shared_datadir / "BSA.fasta", + config=ms2ml_config, + only_unique=True, ) sequences = list(adapter.parse()) @@ -29,6 +31,7 @@ def test_fasta_shows_in_db(shared_datadir): ms1_range = (s.mz - 10, s.mz + 10) score_df = db.hyperscore(ms1_range, spec_mz=mzs, spec_int=intens, top_n=2) + score_df = score_df[np.invert(score_df["decoy"])] assert s.to_proforma() in list( - score_df["Peptide"] - ), f"Peptide i={i} {s.to_proforma()} not in db" + score_df["peptide"], + ), f"peptide i={i} {s.to_proforma()} not in db" diff --git a/tests/test_deisotoping.py b/tests/test_deisotoping.py index b1a79b2..29a41f9 100644 --- a/tests/test_deisotoping.py +++ b/tests/test_deisotoping.py @@ -1,7 +1,27 @@ import numpy as np +from numpy import array # noqa: F401 from diadem.deisotoping import NEUTRON, deisotope +single_clean_envelope_text = """ +{ + "mz": array( + [470.2321, 469.5762, 469.2416, 469.9097, + 470.2435, 470.5771, 467.7124, 467.2108, 471.6956, ] + ), + "ims": array( + [0.77803, 0.79777, 0.79688, 0.79611, 0.79404, + 0.79568, 0.80016, 0.79466, 0.81096, ] + ), + "intensity": array( + [92.0, 13729.0, 18491.0, 7138.0, 3081.0, + 1739.0, 751.0, 356.0, 293.0] + ), +} +""" + +single_clean_envelope = eval(single_clean_envelope_text) + def test_deisotoping() -> None: """Tests that 
isotopes are collapsed correctly.""" @@ -15,7 +35,8 @@ def test_deisotoping() -> None: 812.0, # Envelope 812.0 + NEUTRON / 2.0, ] - inten = [1.0, 4.0, 3.0, 2.0, 1.0, 1.0, 9.0, 4.5] + mz = np.array(mz) + inten = np.array([1.0, 4.0, 3.0, 2.0, 1.0, 1.0, 9.0, 4.5]) out_mz, out_inten = deisotope(mz, inten, 2, 5.0, "ppm") assert np.allclose(out_inten, np.array([1.0, 10.0, 1.0, 13.5])) assert np.allclose(out_mz, np.array([800.9, 803.408, 810.0, 812.0])) diff --git a/tests/test_dia_search.py b/tests/test_dia_search.py index 989f391..d7e83b9 100644 --- a/tests/test_dia_search.py +++ b/tests/test_dia_search.py @@ -1,3 +1,4 @@ +import numpy as np import pandas as pd import pytest from ms2ml import Peptide @@ -6,8 +7,8 @@ from diadem.search.diadem import diadem_main -@pytest.mark.parametrize("parallel", [True, False], ids=["Parallel", "NoParallel"]) -def test_dia_search_works(tmpdir, shared_datadir, parallel): +@pytest.mark.parametrize("parallel", [False, True], ids=["NoParallel", "Parallel"]) +def test_dia_search_works_mzml(tmpdir, shared_datadir, parallel): """Uses simulated data to test diadem. 
Uses data simulated using Synthedia to check if the full search @@ -17,14 +18,16 @@ def test_dia_search_works(tmpdir, shared_datadir, parallel): mzml = shared_datadir / "mzml/FGFR1_600_800_5min_group_0_sample_0.mzML" config = DiademConfig(run_parallelism=2 if parallel else 1) out = str(tmpdir / "out") - diadem_main(config=config, mzml_path=mzml, fasta_path=fasta, out_prefix=out) + diadem_main(config=config, data_path=mzml, fasta_path=fasta, out_prefix=out) expected_csv = out + ".diadem.csv" df = pd.read_csv(expected_csv) - peptides = {Peptide.from_sequence(x).stripped_sequence for x in df.Peptide.unique()} + df = df[np.invert(df["decoy"])] + peptides = {Peptide.from_sequence(x).stripped_sequence for x in df.peptide.unique()} theo_table = pd.read_csv( - shared_datadir / "mzml/FGFR1_600_800_5min_peptide_table.tsv", sep="\t" + shared_datadir / "mzml/FGFR1_600_800_5min_peptide_table.tsv", + sep="\t", ) theo_table = theo_table[ theo_table["MS2 chromatographic points group_0_sample_0"] > 0 diff --git a/tests/unit_tests/aggregate/test_confidence.py b/tests/unit_tests/aggregate/test_confidence.py new file mode 100644 index 0000000..c6c8b3b --- /dev/null +++ b/tests/unit_tests/aggregate/test_confidence.py @@ -0,0 +1 @@ +"""Tests global FDR estimation.""" diff --git a/tests/unit_tests/aggregate/test_imputers.py b/tests/unit_tests/aggregate/test_imputers.py new file mode 100644 index 0000000..e68e87d --- /dev/null +++ b/tests/unit_tests/aggregate/test_imputers.py @@ -0,0 +1,51 @@ +"""Test our RT alignment module.""" +import logging + +import numpy as np + +from diadem.aggregate.imputers import MatrixFactorizationModel, MFImputer + + +def test_model(): + """Test the base pytorch model.""" + model = MatrixFactorizationModel(n_peptides=10, n_runs=5, n_factors=3, rng=1) + assert model.peptide_factors.shape == (10, 3) + assert model.run_factors.shape == (3, 5) + assert model().shape == (10, 5) + + +def test_imputer(caplog): + """Test the imputer.""" + 
caplog.set_level(logging.INFO) + rng = np.random.default_rng(42) + model = MFImputer(n_factors=3, task="retention time", rng=1) + peptide_factors = rng.random((100, 3)) + run_factors = rng.random((3, 20)) + mat = peptide_factors @ run_factors + assert mat.shape == (100, 20) + + # Without NaNs: + pred = model.fit(mat).transform() + np.testing.assert_allclose(pred, mat, rtol=1e-5) + + # With NaNs: + mask = rng.binomial(1, 0.1, size=mat.shape).astype(bool) + missing = mat.copy() + missing[mask] = np.nan + pred = model.fit_transform(missing) + np.testing.assert_allclose(pred, mat, rtol=1e-5) + + +def test_search_factors(caplog): + """Test searching for the best number of factors.""" + caplog.set_level(logging.INFO) + rng = np.random.default_rng(42) + model = MFImputer(n_factors=None, task="retention time", rng=1) + peptide_factors = rng.random((101, 3)) + run_factors = rng.random((3, 20)) + mat = peptide_factors @ run_factors + assert mat.shape == (101, 20) + assert model.n_factors is None + + model = model.search_factors(mat, (2, 3, 4)) + assert model.n_factors == 3 diff --git a/tests/unit_tests/search/test_mokapot.py b/tests/unit_tests/search/test_mokapot.py new file mode 100644 index 0000000..1a6ae95 --- /dev/null +++ b/tests/unit_tests/search/test_mokapot.py @@ -0,0 +1,83 @@ +"""Unit tests for mokapot interactions.""" +import numpy as np +import pandas as pd +import pytest + +from diadem.index.protein_index import ProteinNGram +from diadem.search.mokapot import _decoy_to_target, _get_proteins, _prepare_df + + +@pytest.fixture +def kid_fasta(tmp_path): + """A tiny fasta.""" + fasta = """ + > sp|KID1|KID1_HUMAN + LESLIEKAAAAAR + > sp|KID3|KID3_HUMAN + EDITHKAAAAAR + """ + + fasta_file = tmp_path / "test.fasta" + with fasta_file.open("w+") as fout: + fout.write(fasta) + + return fasta_file + + +def test_prepare_df(kid_fasta): + """Test that our dataframe is prepared correctly.""" + in_df = pd.DataFrame( + { + "rank": [1, 2, 1, 1, 1], + "peptide": [ + 
"<[UNIMOD:4]@T>EDITH/2", + "<[UNIMOD:4]@T>EDITH/2", + "LES[+79.9]LIE/3", + "LILS[+79.9]EE/3", + "AAAR/2", + ], + "list_col": [np.array([1, 2])] * 5, + "cool_npeaks": [5, 5, 6, 6, 4], + "decoy": [False, False, False, True, False], + }, + ) + + expected = pd.DataFrame( + { + "peptide": ["EDITH", "LES[+79.9]LIE", "LILS[+79.9]EE", "AAAR"], + "cool_npeaks": [5, 6, 6, 4], + "is_target": [True, True, False, True], + "filename": "test", + "target_pair": ["EDITH", "LES[+79.9]LIE", "LES[+79.9]LIE", "AAAR"], + "peptide_length": [5, 6, 6, 4], + "cool_npeaks_pct": [100.0, 100.0, 100.0, 100.0], + "proteins": ["KID3", "KID1", "KID1", "KID1;KID3"], + }, + index=[0, 2, 3, 4], + ) + + out_df = _prepare_df(in_df, kid_fasta, "test.mzML") + pd.testing.assert_frame_equal(out_df, expected) + + +def test_decoy_to_target(): + """Test that our decoy to target function works correctly.""" + target = "LES[+79.9]LIEK" + + # Test reversal + decoy = "LEILS[+79.9]EK" + assert target == _decoy_to_target(decoy) + + # Test another permutation: + perm = [2, 0, 1, 3, 4, 5, 6] + decoy = "S[+79.9]LELIEK" + + assert target == _decoy_to_target(decoy, perm) + + +def test_get_proteins(kid_fasta): + """Test that _get_proteins works correctly.""" + ngram = ProteinNGram.from_fasta(kid_fasta) + assert _get_proteins("LESLIE", ngram) == "KID1" + assert _get_proteins("EDITH", ngram) == "KID3" + assert _get_proteins("AAAR", ngram) == "KID1;KID3" diff --git a/tests/unit_tests/test_interface.py b/tests/unit_tests/test_interface.py new file mode 100644 index 0000000..7f7e4a6 --- /dev/null +++ b/tests/unit_tests/test_interface.py @@ -0,0 +1,75 @@ +"""Verify that our interface base class works.""" +import polars as pl +import pytest +from polars.testing import assert_frame_equal + +from diadem.interfaces import BaseDiademInterface, RequiredColumn + + +def test_required_column(): + """Test the required column dataclass.""" + cols = [("foo", pl.datatypes.Float32), ("bar", pl.datatypes.Field)] + req = 
RequiredColumn.from_iter(cols) + + for col, result in zip(cols, req): + assert col[0] == result.name + assert col[1] == result.dtype + + +def test_base_interface_init(): + """Test our initialization.""" + + class RealInterface(BaseDiademInterface): + """A dummy interface.""" + + def __init__(self, data): + """Init the thing.""" + super().__init__(data) + + @property + def required_columns(self): + """The required columns.""" + return [ + RequiredColumn("foo", pl.datatypes.Float32), + RequiredColumn("bar", pl.datatypes.Utf8), + ] + + good_df = pl.DataFrame({"foo": [1.0, 2.0], "bar": ["a", "b"]}) + interface = RealInterface(good_df) + assert_frame_equal(interface.data.collect(), good_df) + + bad_df = pl.DataFrame({"foo": ["a", "b"], "bar": ["a", "b"]}) + with pytest.raises(ValueError) as err: + RealInterface(bad_df) + + assert "wrong data type" in str(err.value) + + bad_df = pl.DataFrame({"bar": ["a", "b"]}) + with pytest.raises(ValueError) as err: + RealInterface(bad_df) + + assert "missing" in str(err.value) + + +def test_base_interface_from_parquet(tmp_path): + """Test loading a parquet file.""" + + class RealInterface(BaseDiademInterface): + """A dummy interface.""" + + def __init__(self, data): + """Init the thing.""" + super().__init__(data) + + @property + def required_columns(self): + """The required columns.""" + return [ + RequiredColumn("foo", pl.datatypes.Float32), + RequiredColumn("bar", pl.datatypes.Utf8), + ] + + good_df = pl.DataFrame({"foo": [1.0, 2.0], "bar": ["a", "b"]}) + good_df.write_parquet(tmp_path / "test.parquet") + interface = RealInterface.from_parquet(tmp_path / "test.parquet") + pl.testing.assert_frame_equal(interface.data.collect(), good_df)