diff --git a/.github/workflows/deploy-docs.yml b/.github/workflows/deploy-docs.yml index 25f61dd..8b38f1a 100644 --- a/.github/workflows/deploy-docs.yml +++ b/.github/workflows/deploy-docs.yml @@ -2,35 +2,62 @@ name: Deploy Documentation on: push: - branches: ["main"] + branches: [main] + paths: + - 'docs/**' + - 'mkdocs.yml' + - '.github/workflows/deploy-docs.yml' + pull_request: + paths: + - 'docs/**' + - 'mkdocs.yml' workflow_dispatch: permissions: - contents: write + contents: read + pages: write + id-token: write + +concurrency: + group: "pages" + cancel-in-progress: false jobs: - deploy-docs: + build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v4 - with: - fetch-depth: 0 + - name: Checkout + uses: actions/checkout@v4 - - name: Set up Python 3.10 - uses: actions/setup-python@v4 + - name: Setup Python + uses: actions/setup-python@v5 with: - python-version: "3.10" + python-version: '3.11' - - name: Cache pip dependencies - uses: actions/cache@v3 - with: - path: ~/.cache/pip - key: ${{ runner.os }}-pip-docs-${{ hashFiles('requirements-docs.txt') }} - restore-keys: | - ${{ runner.os }}-pip-docs- + - name: Install MkDocs + run: pip install mkdocs-material + + - name: Build docs + run: mkdocs build --site-dir site - - name: Install documentation dependencies - run: pip install -r requirements-docs.txt + - name: Setup Pages + if: github.ref == 'refs/heads/main' + uses: actions/configure-pages@v5 - - name: Build and deploy documentation - run: 'mkdocs gh-deploy --force --message "docs: deploy documentation [skip ci]"' + - name: Upload artifact + if: github.ref == 'refs/heads/main' + uses: actions/upload-pages-artifact@v3 + with: + path: site + + deploy: + if: github.ref == 'refs/heads/main' + environment: + name: github-pages + url: ${{ steps.deployment.outputs.page_url }} + runs-on: ubuntu-latest + needs: build + steps: + - name: Deploy to GitHub Pages + id: deployment + uses: actions/deploy-pages@v4 diff --git a/.gitignore b/.gitignore index 3feab05..8643c2e 100644 --- a/.gitignore +++ b/.gitignore @@ -248,3 +248,4 @@ tests/example/QCprofile.pdf poetry.lock .vscode/ .DS_Store +tmp/ diff --git a/docs/images/favicon.svg b/docs/images/favicon.svg new file mode 100644 index 0000000..4e9090c --- /dev/null +++ b/docs/images/favicon.svg @@ -0,0 +1,10 @@ + + + + + + + + + + diff --git a/docs/logo/mokume_logo.svg b/docs/logo/mokume_logo.svg new file mode 100644 index 0000000..932fdb8 --- /dev/null +++ b/docs/logo/mokume_logo.svg @@ -0,0 +1,11 @@ + + + + + + + + + + mokume + \ No newline at end of file diff --git a/docs/logo/mokume_logo_darkbg.svg b/docs/logo/mokume_logo_darkbg.svg new file mode 100644 index 0000000..54a4b5c --- /dev/null +++ b/docs/logo/mokume_logo_darkbg.svg @@ -0,0 +1,11 @@ + + + + + + + + + + mokume + \ No newline at end of file diff --git a/docs/logo/mokume_mark.svg b/docs/logo/mokume_mark.svg new file mode 100644 index 0000000..ae936eb --- /dev/null +++ b/docs/logo/mokume_mark.svg @@ -0,0 +1,8 @@ + + + + + + + + \ No newline at end of file diff --git a/docs/overrides/main.html b/docs/overrides/main.html new file mode 100644 index 0000000..003720b --- /dev/null +++ b/docs/overrides/main.html @@ -0,0 +1,67 @@ +{% extends "base.html" %} + +{% block footer %} + +{% endblock %} diff --git a/mkdocs.yml b/mkdocs.yml index 6bc1273..5ba88e5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -8,22 +8,22 @@ repo_url: https://github.com/bigbio/mokume theme: name: material + custom_dir: docs/overrides + font: false palette: - scheme: default - primary: indigo + primary: blue accent: 
indigo toggle: icon: material/brightness-7 name: Switch to dark mode - scheme: slate - primary: indigo + primary: blue accent: indigo toggle: icon: material/brightness-4 name: Switch to light mode features: - - navigation.tabs - - navigation.tabs.sticky - navigation.sections - navigation.expand - navigation.path @@ -36,7 +36,6 @@ theme: - search.highlight - search.share - toc.follow - - toc.integrate - content.code.copy - content.code.annotate - content.tabs.link @@ -99,31 +98,34 @@ plugins: - en nav: - - Home: + - Getting Started: - Home: index.md - Installation: installation.md - Quick Start: quickstart.md - - Concepts: + - Community: community.md + - Library: - concepts/index.md - Quantification Methods: concepts/quantification.md - Normalization: concepts/normalization.md - Batch Correction: concepts/batch-correction.md - IRS Normalization: concepts/irs.md - - Preprocessing Filters: concepts/preprocessing.md - - User Guide: + - Preprocessing: concepts/preprocessing.md + - CLI: - user-guide/index.md - - "features2proteins: Unified Pipeline": user-guide/features2proteins.md - - "features2peptides: Peptide Normalization": user-guide/features2peptides.md - - "peptides2protein: Protein Quantification": user-guide/peptides2protein.md - - "batch-correct: Batch Correction": user-guide/batch-correct.md - - Visualization & Reports: user-guide/visualization.md - - Reference: + - features2proteins: user-guide/features2proteins.md + - features2peptides: user-guide/features2peptides.md + - peptides2protein: user-guide/peptides2protein.md + - batch-correct: user-guide/batch-correct.md + - Visualization: user-guide/visualization.md + - API Reference: - reference/index.md - - CLI Reference: reference/cli.md + - CLI Options: reference/cli.md - Python API: reference/python-api.md - Configuration: reference/configuration.md - Computed Values: reference/computed-values.md - - Community: community.md + +extra_css: + - https://quantms.org/css/quantms-theme.css extra: social: diff --git a/mokume/core/dataset.py b/mokume/core/dataset.py new file mode 100644 index 0000000..18b1460 --- /dev/null +++ b/mokume/core/dataset.py @@ -0,0 +1,941 @@ +""" +QpxDataset — Hierarchical proteomics data container. + +This module provides the core data container for the mokume pipeline. +QpxDataset mirrors the qpx format's hierarchical structure: PSMs, +features, peptides, and proteins are distinct data levels. Each +processing step reads from the appropriate level and writes results back. + +Each data level can be backed by either a pandas DataFrame or a DuckDB +``LazyFrame``. Lazy frames defer computation until results are explicitly +requested, enabling mokume to handle datasets with millions of features +without loading them all into memory. + +Serialization is native: the dataset saves/loads as a directory of +parquet files with a metadata JSON sidecar. 
+ +Example +------- +>>> dataset = QpxDataset.from_parquet("data.parquet") +>>> dataset.validate_level("features") +>>> wide = dataset.to_wide_matrix(level="proteins", value_col="Intensity") +>>> adata = dataset.to_anndata(level="proteins", value_col="Intensity") +>>> dataset.save("output_dir/") +>>> +>>> # Lazy loading for large datasets: +>>> dataset = QpxDataset.from_parquet_lazy("large_data.parquet") +>>> dataset.features # LazyFrame, no materialization yet +>>> df = dataset.get_level("features") # Materializes to DataFrame +""" + +import json +import os +from copy import deepcopy +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Any, Dict, List, Optional, Union + +import pandas as pd + +from mokume.core.constants import PROTEIN_NAME, SAMPLE_ID +from mokume.core.schema import validate_schema + +# Type alias for data that can be either eager or lazy +DataLevel = Union[pd.DataFrame, "LazyFrame"] + +# Sentinel to avoid circular import at module level +_LazyFrame = None + + +def _get_lazy_frame_class(): + """Lazy import of LazyFrame to avoid circular imports.""" + global _LazyFrame + if _LazyFrame is None: + from mokume.core.duckdb_backend import LazyFrame + _LazyFrame = LazyFrame + return _LazyFrame + + +def _is_lazy(obj) -> bool: + """Check if an object is a LazyFrame without importing at module level.""" + if obj is None: + return False + LazyFrame = _get_lazy_frame_class() + return isinstance(obj, LazyFrame) + + +def _ensure_df(obj: Optional[DataLevel]) -> Optional[pd.DataFrame]: + """Materialize a LazyFrame to DataFrame if needed, or return as-is.""" + if obj is None: + return None + if _is_lazy(obj): + return obj.df() + return obj + + +def _extract_scalar(val): + """Extract a scalar string from a value that may be a list/ndarray. + + QPX sample metadata stores SDRF characteristics as arrays (one element + per unique value within a sample). For AnnData obs we need plain + strings, so we join multi-valued fields with ``"; "``. + """ + if val is None or (hasattr(val, "__class__") and val.__class__.__name__ == "NAType"): + return None + if hasattr(val, "__len__") and not isinstance(val, (str, dict)): + parts = [str(v) for v in val if v is not None and str(v) != ""] + return "; ".join(parts) if parts else None + return str(val) if pd.notna(val) else None + + +def _flatten_sample_meta(meta: pd.DataFrame) -> pd.DataFrame: + """Prepare QPX sample metadata for AnnData obs. + + 1. Flatten ndarray/list columns to scalar strings. + 2. Expand ``additional_properties`` dicts into individual columns. + 3. Drop columns that are entirely NA. 
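+
+    A minimal sketch of the flattening (metadata values are hypothetical):
+
+    Examples
+    --------
+    >>> meta = pd.DataFrame({"organism part": [["liver", "kidney"]]})
+    >>> _flatten_sample_meta(meta)["organism part"].iloc[0]
+    'liver; kidney'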
+ """ + out = meta.copy() + + # Expand additional_properties dicts into individual columns + if "additional_properties" in out.columns: + extra_rows = [] + for idx, val in out["additional_properties"].items(): + row_extra = {} + if isinstance(val, dict): + for k, v in val.items(): + # Normalise the key: "characteristics[biological replicate]" -> "biological_replicate" + clean_key = k + import re + m = re.match(r"characteristics\[(.+)\]", k) + if m: + clean_key = m.group(1).strip().replace(" ", "_") + row_extra[clean_key] = _extract_scalar(v) + extra_rows.append(row_extra) + if extra_rows and any(extra_rows): + extra_df = pd.DataFrame(extra_rows, index=out.index) + # Only add columns that don't already exist + for col in extra_df.columns: + if col not in out.columns: + out[col] = extra_df[col] + out.drop(columns=["additional_properties"], inplace=True) + + # Flatten ndarray/list values to scalars + for col in out.columns: + if out[col].dtype == object: + out[col] = out[col].apply(_extract_scalar) + + # Drop columns that are entirely NA + out.dropna(axis=1, how="all", inplace=True) + + return out + + +@dataclass +class QpxDataset: + """Hierarchical proteomics data container backed by qpx format. + + Each data level can hold either a ``pd.DataFrame`` (eager) or a + ``LazyFrame`` (lazy DuckDB-backed). Use ``get_level()`` to always + get a materialized DataFrame, or ``get_level_raw()`` to access the + underlying object (which may be lazy). + + Attributes + ---------- + psms : DataFrame or LazyFrame, optional + PSM-level data (used by ratio quantification). + features : DataFrame or LazyFrame, optional + Feature-level data (charge states, fractions). + peptides : DataFrame or LazyFrame, optional + Peptide-level data (assembled peptidoforms, normalized). + proteins : DataFrame or LazyFrame, optional + Protein-level quantification results. + sample_info : pd.DataFrame, optional + Per-sample metadata (from SDRF or future metadata format). + protein_info : pd.DataFrame, optional + Per-protein metadata (accessions, gene names, MW, etc.). + uns : dict + Unstructured metadata (pipeline config, provenance, DE results). + layers : dict + Named alternative representations at any level. + E.g., ``layers["normalized_features"]``, ``layers["batch_corrected"]``. + """ + + psms: Optional[DataLevel] = None + features: Optional[DataLevel] = None + peptides: Optional[DataLevel] = None + proteins: Optional[DataLevel] = None + + sample_info: Optional[pd.DataFrame] = None + protein_info: Optional[pd.DataFrame] = None + + uns: Dict[str, Any] = field(default_factory=dict) + layers: Dict[str, pd.DataFrame] = field(default_factory=dict) + + # ------------------------------------------------------------------ + # Data level access + # ------------------------------------------------------------------ + + _VALID_LEVELS = ("psms", "features", "peptides", "proteins") + + def get_level(self, level: str) -> Optional[pd.DataFrame]: + """Get a data level by name, materializing if lazy. + + Parameters + ---------- + level : str + One of "psms", "features", "peptides", "proteins". + + Returns + ------- + pd.DataFrame or None + Always returns a DataFrame (never a LazyFrame). + """ + if level not in self._VALID_LEVELS: + raise ValueError( + f"Unknown data level: '{level}'. 
" + f"Must be one of: {', '.join(self._VALID_LEVELS)}" + ) + raw = getattr(self, level, None) + if raw is None: + return None + if _is_lazy(raw): + # Materialize and cache + df = raw.df() + setattr(self, level, df) + return df + return raw + + def get_level_raw(self, level: str) -> Optional[DataLevel]: + """Get a data level without materializing. + + Returns the underlying object which may be a LazyFrame + (DuckDB-backed) or a DataFrame. + + Parameters + ---------- + level : str + One of "psms", "features", "peptides", "proteins". + + Returns + ------- + pd.DataFrame, LazyFrame, or None + """ + if level not in self._VALID_LEVELS: + raise ValueError( + f"Unknown data level: '{level}'. " + f"Must be one of: {', '.join(self._VALID_LEVELS)}" + ) + return getattr(self, level, None) + + def set_level(self, level: str, data: DataLevel) -> None: + """Set a data level by name. + + Parameters + ---------- + level : str + One of "psms", "features", "peptides", "proteins". + data : pd.DataFrame or LazyFrame + The data to assign. + """ + if level not in self._VALID_LEVELS: + raise ValueError( + f"Unknown data level: '{level}'. " + f"Must be one of: {', '.join(self._VALID_LEVELS)}" + ) + setattr(self, level, data) + + def is_lazy(self, level: str) -> bool: + """Check if a data level is lazy (DuckDB-backed). + + Parameters + ---------- + level : str + Data level name. + + Returns + ------- + bool + """ + return _is_lazy(getattr(self, level, None)) + + @property + def populated_levels(self) -> List[str]: + """Return names of data levels that have been populated.""" + levels = [] + for name in self._VALID_LEVELS: + if getattr(self, name) is not None: + levels.append(name) + return levels + + @property + def lazy_levels(self) -> List[str]: + """Return names of data levels that are lazy (DuckDB-backed).""" + return [name for name in self._VALID_LEVELS if self.is_lazy(name)] + + # ------------------------------------------------------------------ + # Materialization + # ------------------------------------------------------------------ + + def materialize(self, levels: Optional[List[str]] = None) -> "QpxDataset": + """Force materialization of lazy levels to DataFrames. + + Parameters + ---------- + levels : list[str], optional + Specific levels to materialize. If None, materializes all. + + Returns + ------- + QpxDataset + Self (for chaining). + """ + target_levels = levels or self._VALID_LEVELS + for level_name in target_levels: + raw = getattr(self, level_name, None) + if raw is not None and _is_lazy(raw): + setattr(self, level_name, raw.df()) + return self + + # ------------------------------------------------------------------ + # Schema validation + # ------------------------------------------------------------------ + + def validate_level(self, level: str) -> List[str]: + """Validate that a data level has required columns. + + For lazy levels, validation checks column names without + materializing the full dataset. + + Parameters + ---------- + level : str + Data level to validate. + + Returns + ------- + list[str] + List of error messages (empty if valid). 
+ """ + raw = self.get_level_raw(level) + if raw is None: + return [f"Data level '{level}' is not populated"] + + if _is_lazy(raw): + # Validate column names without materializing + df_stub = pd.DataFrame(columns=raw.columns) + return validate_schema(df_stub, level) + + return validate_schema(raw, level) + + # ------------------------------------------------------------------ + # Convenience methods + # ------------------------------------------------------------------ + + def to_wide_matrix( + self, + level: str = "proteins", + value_col: str = "Intensity", + protein_col: str = PROTEIN_NAME, + sample_col: str = SAMPLE_ID, + ) -> pd.DataFrame: + """Pivot a data level to a protein x sample matrix. + + Parameters + ---------- + level : str + Data level to pivot. + value_col : str + Column containing values for the matrix cells. + protein_col : str + Column to use as row index. + sample_col : str + Column to use as column headers. + + Returns + ------- + pd.DataFrame + Wide-format matrix with protein_col as index. + """ + df = self.get_level(level) + if df is None: + raise ValueError(f"Data level '{level}' is not populated") + + return df.pivot_table( + index=protein_col, + columns=sample_col, + values=value_col, + aggfunc="first", + ) + + def to_long_format( + self, + level: str = "proteins", + ) -> pd.DataFrame: + """Return a data level in long format (copy). + + Parameters + ---------- + level : str + Data level to return. + + Returns + ------- + pd.DataFrame + Copy of the level's DataFrame. + """ + df = self.get_level(level) + if df is None: + raise ValueError(f"Data level '{level}' is not populated") + return df.copy() + + def peptide_protein_map(self) -> pd.DataFrame: + """Return a peptide-to-protein mapping table. + + Uses the first available level that has both ProteinName and + a peptide column. + + Returns + ------- + pd.DataFrame + DataFrame with ProteinName and peptide columns. + """ + from mokume.core.constants import PEPTIDE_CANONICAL, PEPTIDE_SEQUENCE + + for level_name in ("peptides", "features", "psms"): + raw = self.get_level_raw(level_name) + if raw is None: + continue + + # Get column names (works for both DataFrame and LazyFrame) + cols = raw.columns + pep_col = None + if PEPTIDE_CANONICAL in cols: + pep_col = PEPTIDE_CANONICAL + elif PEPTIDE_SEQUENCE in cols: + pep_col = PEPTIDE_SEQUENCE + + if pep_col and PROTEIN_NAME in cols: + df = self.get_level(level_name) # materialize if needed + return df[[PROTEIN_NAME, pep_col]].drop_duplicates() + + raise ValueError("No data level with protein and peptide columns found") + + def sample_metadata(self) -> pd.DataFrame: + """Return sample metadata. + + Returns + ------- + pd.DataFrame + Sample metadata DataFrame. + + Raises + ------ + ValueError + If sample_info is not populated. + """ + if self.sample_info is None: + raise ValueError("sample_info is not populated") + return self.sample_info.copy() + + # ------------------------------------------------------------------ + # Subsetting + # ------------------------------------------------------------------ + + def subset_samples(self, sample_ids: List[str]) -> "QpxDataset": + """Subset all populated levels to given samples. + + Creates a new QpxDataset where every data level and sample_info + are filtered to only include the specified samples. Lazy levels + are materialized during subsetting. + + Parameters + ---------- + sample_ids : list[str] + Sample identifiers to keep. + + Returns + ------- + QpxDataset + New dataset with subsetted data. 
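+
+        A small sketch (the sample identifiers are hypothetical):
+
+        Examples
+        --------
+        >>> qc_pass = dataset.subset_samples(["Sample-1", "Sample-2"])
+        >>> len(qc_pass.sample_metadata())  # assumes sample_info is loaded
+        2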
+ """ + sample_set = set(sample_ids) + new = QpxDataset( + uns=deepcopy(self.uns), + layers={}, + protein_info=self.protein_info.copy() if self.protein_info is not None else None, + ) + + for level_name in self._VALID_LEVELS: + df = self.get_level(level_name) # materializes if lazy + if df is not None and SAMPLE_ID in df.columns: + new.set_level(level_name, df[df[SAMPLE_ID].isin(sample_set)].copy()) + elif df is not None: + # Wide format — filter columns + sample_cols = [c for c in df.columns if c in sample_set] + non_sample_cols = [c for c in df.columns if c not in sample_set and c not in sample_ids] + if sample_cols: + new.set_level(level_name, df[non_sample_cols + sample_cols].copy()) + + if self.sample_info is not None: + # Try common sample ID column names + for col in [SAMPLE_ID, "sample_accession", "source name"]: + if col in self.sample_info.columns: + new.sample_info = self.sample_info[ + self.sample_info[col].isin(sample_set) + ].copy() + break + else: + new.sample_info = self.sample_info.copy() + + for key, layer_df in self.layers.items(): + if SAMPLE_ID in layer_df.columns: + new.layers[key] = layer_df[layer_df[SAMPLE_ID].isin(sample_set)].copy() + else: + new.layers[key] = layer_df.copy() + + return new + + def subset_proteins(self, protein_ids: List[str]) -> "QpxDataset": + """Subset all populated levels to given proteins. + + Parameters + ---------- + protein_ids : list[str] + Protein identifiers to keep. + + Returns + ------- + QpxDataset + New dataset with subsetted data. + """ + protein_set = set(protein_ids) + new = QpxDataset( + uns=deepcopy(self.uns), + layers={}, + sample_info=self.sample_info.copy() if self.sample_info is not None else None, + ) + + for level_name in self._VALID_LEVELS: + df = self.get_level(level_name) # materializes if lazy + if df is not None and PROTEIN_NAME in df.columns: + new.set_level(level_name, df[df[PROTEIN_NAME].isin(protein_set)].copy()) + + if self.protein_info is not None and PROTEIN_NAME in self.protein_info.columns: + new.protein_info = self.protein_info[ + self.protein_info[PROTEIN_NAME].isin(protein_set) + ].copy() + + for key, layer_df in self.layers.items(): + if PROTEIN_NAME in layer_df.columns: + new.layers[key] = layer_df[layer_df[PROTEIN_NAME].isin(protein_set)].copy() + else: + new.layers[key] = layer_df.copy() + + return new + + # ------------------------------------------------------------------ + # Serialization + # ------------------------------------------------------------------ + + def save(self, directory: str) -> None: + """Save the dataset to a directory of parquet files. + + Lazy levels are materialized during save. + + Directory structure:: + + directory/ + ├── psms.parquet (if populated) + ├── features.parquet (if populated) + ├── peptides.parquet (if populated) + ├── proteins.parquet (if populated) + ├── sample_info.parquet (if populated) + ├── protein_info.parquet(if populated) + ├── layers/ + │ ├── .parquet (for each layer) + └── uns.json + + Parameters + ---------- + directory : str + Output directory path. Created if it doesn't exist. 
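+
+        A save/load round trip (the output path is hypothetical):
+
+        Examples
+        --------
+        >>> dataset.save("results/run1")
+        >>> restored = QpxDataset.load("results/run1")
+        >>> restored.populated_levels == dataset.populated_levels
+        True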
+ """ + os.makedirs(directory, exist_ok=True) + + # Save data levels (materializing lazy frames) + for level_name in self._VALID_LEVELS: + raw = self.get_level_raw(level_name) + if raw is not None: + df = _ensure_df(raw) + df.to_parquet(os.path.join(directory, f"{level_name}.parquet"), index=False) + + # Save metadata + if self.sample_info is not None: + self.sample_info.to_parquet( + os.path.join(directory, "sample_info.parquet"), index=False + ) + if self.protein_info is not None: + self.protein_info.to_parquet( + os.path.join(directory, "protein_info.parquet"), index=False + ) + + # Save layers + if self.layers: + layers_dir = os.path.join(directory, "layers") + os.makedirs(layers_dir, exist_ok=True) + for name, layer_df in self.layers.items(): + layer_df.to_parquet( + os.path.join(layers_dir, f"{name}.parquet"), index=False + ) + + # Save uns as JSON (convert non-serializable values to strings) + uns_path = os.path.join(directory, "uns.json") + with open(uns_path, "w") as f: + json.dump(self.uns, f, indent=2, default=str) + + @classmethod + def load(cls, directory: str, lazy: bool = False) -> "QpxDataset": + """Load a dataset from a directory of parquet files. + + Parameters + ---------- + directory : str + Directory containing the saved dataset. + lazy : bool + If True, load data levels as LazyFrames (DuckDB-backed) + instead of materializing to DataFrames. + + Returns + ------- + QpxDataset + """ + dataset = cls() + + if lazy: + LazyFrame = _get_lazy_frame_class() + + # Load data levels + for level_name in cls._VALID_LEVELS: + path = os.path.join(directory, f"{level_name}.parquet") + if os.path.exists(path): + if lazy: + dataset.set_level(level_name, LazyFrame.from_parquet(path)) + else: + dataset.set_level(level_name, pd.read_parquet(path)) + + # Load metadata (always eager — these are small) + sample_info_path = os.path.join(directory, "sample_info.parquet") + if os.path.exists(sample_info_path): + dataset.sample_info = pd.read_parquet(sample_info_path) + + protein_info_path = os.path.join(directory, "protein_info.parquet") + if os.path.exists(protein_info_path): + dataset.protein_info = pd.read_parquet(protein_info_path) + + # Load layers (always eager) + layers_dir = os.path.join(directory, "layers") + if os.path.isdir(layers_dir): + for filename in os.listdir(layers_dir): + if filename.endswith(".parquet"): + name = filename.rsplit(".parquet", 1)[0] + dataset.layers[name] = pd.read_parquet( + os.path.join(layers_dir, filename) + ) + + # Load uns + uns_path = os.path.join(directory, "uns.json") + if os.path.exists(uns_path): + with open(uns_path, "r") as f: + dataset.uns = json.load(f) + + return dataset + + # ------------------------------------------------------------------ + # Export + # ------------------------------------------------------------------ + + def to_anndata( + self, + level: str = "proteins", + value_col: str = "Intensity", + protein_col: str = PROTEIN_NAME, + sample_col: str = SAMPLE_ID, + layer_names: Optional[List[str]] = None, + ): + """Export a data level as an AnnData object. + + Parameters + ---------- + level : str + Data level to export. + value_col : str + Column for the main X matrix values. + protein_col : str + Column for variable (protein) identifiers. + sample_col : str + Column for observation (sample) identifiers. + layer_names : list[str], optional + Names of layers (from self.layers) to include in AnnData.layers. + + Returns + ------- + anndata.AnnData + AnnData object with X matrix, obs, var, and optional layers. 
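+
+        A minimal export sketch; the ``batch_corrected`` layer is an
+        assumption and must exist in ``self.layers``:
+
+        Examples
+        --------
+        >>> adata = dataset.to_anndata(level="proteins", layer_names=["batch_corrected"])
+        >>> list(adata.layers)
+        ['batch_corrected']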
+ """ + try: + import anndata as ad + except ImportError: + raise ImportError( + "anndata is required for to_anndata(). " + "Install with: pip install mokume[anndata]" + ) + + df = self.get_level(level) + if df is None: + raise ValueError(f"Data level '{level}' is not populated") + + # Detect wide vs long format. + # Long: has sample_col and value_col as columns (classic long format). + # Wide: first column is protein IDs, remaining columns are sample names. + is_long = sample_col in df.columns and value_col in df.columns + + if is_long: + wide = df.pivot_table( + index=sample_col, + columns=protein_col, + values=value_col, + aggfunc="first", + ) + else: + # Wide format: identify the protein ID column (first column or + # matching protein_col / common names). + pid_col = None + for candidate in [protein_col, "protein", "ProteinName", df.columns[0]]: + if candidate in df.columns: + pid_col = candidate + break + wide = df.set_index(pid_col).T + wide.index.name = sample_col + + # Build obs (sample metadata) + obs = pd.DataFrame(index=wide.index) + obs.index.name = sample_col + if self.sample_info is not None: + # Resolve the sample ID column in sample_info; may be + # ``sample_col`` (SampleID) or ``sample_accession`` (QPX). + sid_col = None + for candidate in [sample_col, "sample_accession"]: + if candidate in self.sample_info.columns: + sid_col = candidate + break + if sid_col is not None: + sample_meta = self.sample_info.set_index(sid_col).copy() + # Flatten ndarray/list values to scalar strings (QPX stores + # SDRF characteristics as arrays). + sample_meta = _flatten_sample_meta(sample_meta) + obs = obs.join(sample_meta, how="left") + + # Build var (protein metadata) + var = pd.DataFrame(index=wide.columns) + var.index.name = protein_col + if self.protein_info is not None and protein_col in self.protein_info.columns: + protein_meta = self.protein_info.set_index(protein_col) + var = var.join(protein_meta, how="left") + + # Build AnnData + adata = ad.AnnData( + X=wide.values, + obs=obs, + var=var, + ) + + # Add layers if requested + if layer_names: + for layer_name in layer_names: + if layer_name in self.layers: + layer_df = self.layers[layer_name] + if sample_col in layer_df.columns and protein_col in layer_df.columns: + layer_wide = layer_df.pivot_table( + index=sample_col, + columns=protein_col, + values=value_col, + aggfunc="first", + ) + # Align to same shape as X + layer_wide = layer_wide.reindex( + index=wide.index, columns=wide.columns + ) + adata.layers[layer_name] = layer_wide.values + + # Store provenance in uns (serialize complex structures as JSON + # strings because h5py cannot write deeply-nested dicts/lists). + if self.uns: + import json as _json + safe = {} + for k, v in self.uns.items(): + if isinstance(v, (str, int, float, bool)): + safe[k] = v + elif isinstance(v, (list, dict)): + safe[k] = _json.dumps(v, default=str) + # skip other types + adata.uns["mokume"] = safe + + return adata + + # ------------------------------------------------------------------ + # Provenance + # ------------------------------------------------------------------ + + def record_step( + self, + name: str, + method: Optional[str] = None, + duration_seconds: Optional[float] = None, + rows_in: Optional[int] = None, + rows_out: Optional[int] = None, + **extra, + ) -> None: + """Record a pipeline step in provenance metadata. + + Parameters + ---------- + name : str + Step name (e.g., "loading", "normalization", "quantification"). + method : str, optional + Method used (e.g., "median", "directlfq"). 
+ duration_seconds : float, optional + Wall-clock time for the step. + rows_in : int, optional + Number of input rows. + rows_out : int, optional + Number of output rows. + **extra + Additional metadata for the step. + """ + if "provenance" not in self.uns: + self.uns["provenance"] = { + "timestamp": datetime.now(timezone.utc).isoformat(), + "steps": [], + } + + step = {"name": name} + if method is not None: + step["method"] = method + if duration_seconds is not None: + step["duration_seconds"] = round(duration_seconds, 3) + if rows_in is not None: + step["rows_in"] = rows_in + if rows_out is not None: + step["rows_out"] = rows_out + step.update(extra) + + self.uns["provenance"]["steps"].append(step) + + # ------------------------------------------------------------------ + # Factory methods + # ------------------------------------------------------------------ + + @classmethod + def from_parquet( + cls, + parquet_path: str, + sdrf_path: Optional[str] = None, + level: str = "features", + ) -> "QpxDataset": + """Create a QpxDataset from a qpx parquet file (eager loading). + + Reads the entire parquet file into a DataFrame. For large + datasets, use ``from_parquet_lazy()`` instead. + + Parameters + ---------- + parquet_path : str + Path to the qpx parquet file. + sdrf_path : str, optional + Path to SDRF TSV for sample metadata. + level : str + Data level to assign the loaded data to. + + Returns + ------- + QpxDataset + """ + dataset = cls() + + # Load parquet + df = pd.read_parquet(parquet_path) + dataset.set_level(level, df) + + # Load SDRF if provided + if sdrf_path is not None: + from mokume.core.constants import load_sdrf + dataset.sample_info = load_sdrf(sdrf_path) + + return dataset + + @classmethod + def from_parquet_lazy( + cls, + parquet_path: str, + sdrf_path: Optional[str] = None, + level: str = "features", + ) -> "QpxDataset": + """Create a QpxDataset with DuckDB lazy backing from a parquet file. + + The parquet file is registered as a DuckDB scan but not loaded + into memory. Data is only materialized when explicitly requested + via ``get_level()`` or ``materialize()``. + + Parameters + ---------- + parquet_path : str + Path to the qpx parquet file. + sdrf_path : str, optional + Path to SDRF TSV for sample metadata. + level : str + Data level to assign the lazy scan to. 
+ + Returns + ------- + QpxDataset + """ + LazyFrame = _get_lazy_frame_class() + dataset = cls() + + # Create lazy scan + lazy = LazyFrame.from_parquet(parquet_path) + dataset.set_level(level, lazy) + + # Load SDRF if provided (always eager — small file) + if sdrf_path is not None: + from mokume.core.constants import load_sdrf + dataset.sample_info = load_sdrf(sdrf_path) + + return dataset + + # ------------------------------------------------------------------ + # Representation + # ------------------------------------------------------------------ + + def __repr__(self) -> str: + parts = ["QpxDataset("] + for level_name in self._VALID_LEVELS: + raw = self.get_level_raw(level_name) + if raw is None: + continue + if _is_lazy(raw): + ncols = len(raw.columns) + parts.append(f" {level_name}: LazyFrame ({ncols} cols, not materialized)") + else: + parts.append(f" {level_name}: {raw.shape[0]} rows x {raw.shape[1]} cols") + if self.sample_info is not None: + parts.append(f" sample_info: {len(self.sample_info)} samples") + if self.protein_info is not None: + parts.append(f" protein_info: {len(self.protein_info)} proteins") + if self.layers: + parts.append(f" layers: {list(self.layers.keys())}") + if self.uns: + parts.append(f" uns: {list(self.uns.keys())}") + parts.append(")") + return "\n".join(parts) diff --git a/mokume/core/duckdb_backend.py b/mokume/core/duckdb_backend.py new file mode 100644 index 0000000..bb2fdda --- /dev/null +++ b/mokume/core/duckdb_backend.py @@ -0,0 +1,351 @@ +""" +DuckDB lazy backing for QpxDataset. + +Provides the ``LazyFrame`` class — a thin wrapper around DuckDB relations +that defers computation until results are explicitly requested. This allows +QpxDataset to represent millions of PSMs/features without loading them all +into memory. + +A LazyFrame can be: +- Created from a parquet file (``LazyFrame.from_parquet``) +- Created from a SQL query (``LazyFrame.from_sql``) +- Created from an existing DataFrame (``LazyFrame.from_dataframe``) +- Materialized to a pandas DataFrame on demand (``.df()``) + +Key operations (filter, select, head, describe) stay lazy until ``.df()`` +is called. + +Example +------- +>>> lf = LazyFrame.from_parquet("data.parquet") +>>> lf.columns +['ProteinName', 'SampleID', 'Intensity', ...] +>>> lf.shape # (row_count, col_count) — row count via COUNT(*) +(1500000, 12) +>>> filtered = lf.filter("Intensity > 0") +>>> df = filtered.df() # Materializes to pandas DataFrame +""" + +import logging +from typing import List, Optional, Union + +import duckdb +import pandas as pd + +logger = logging.getLogger(__name__) + + +class LazyFrame: + """Lazy wrapper around a DuckDB relation. + + Parameters + ---------- + relation : duckdb.DuckDBPyRelation + The underlying DuckDB relation. + connection : duckdb.DuckDBPyConnection + The DuckDB connection that owns this relation. + source : str, optional + Description of the data source (for repr/logging). 
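+
+    A typical lazy chain (the file name is hypothetical); nothing is
+    materialized until ``.head()`` or ``.df()`` runs:
+
+    Examples
+    --------
+    >>> lf = LazyFrame.from_parquet("features.parquet")
+    >>> preview = lf.filter("Intensity > 0").select("ProteinName", "Intensity").head()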
+ """ + + def __init__( + self, + relation: "duckdb.DuckDBPyRelation", + connection: "duckdb.DuckDBPyConnection", + source: str = "unknown", + owns_connection: bool = False, + ): + self._relation = relation + self._connection = connection + self._source = source + self._owns_connection = owns_connection + # Cache column names (cheap to compute from relation metadata) + self._columns: Optional[List[str]] = None + self._row_count: Optional[int] = None + + # ------------------------------------------------------------------ + # Factory methods + # ------------------------------------------------------------------ + + @classmethod + def from_parquet( + cls, + path: str, + connection: Optional["duckdb.DuckDBPyConnection"] = None, + ) -> "LazyFrame": + """Create a LazyFrame from a parquet file. + + Parameters + ---------- + path : str + Path to the parquet file. + connection : duckdb.DuckDBPyConnection, optional + Existing connection to reuse. If None, creates a new one. + + Returns + ------- + LazyFrame + """ + owns = connection is None + if connection is None: + connection = duckdb.connect() + + safe_path = path.replace("'", "''") + relation = connection.sql( + f"SELECT * FROM parquet_scan('{safe_path}')" + ) + return cls(relation, connection, source=f"parquet:{path}", owns_connection=owns) + + @classmethod + def from_sql( + cls, + sql: str, + connection: "duckdb.DuckDBPyConnection", + source: str = "sql", + ) -> "LazyFrame": + """Create a LazyFrame from an arbitrary SQL query. + + Parameters + ---------- + sql : str + SQL query to execute lazily. + connection : duckdb.DuckDBPyConnection + The DuckDB connection. + source : str + Description for repr. + + Returns + ------- + LazyFrame + """ + relation = connection.sql(sql) + return cls(relation, connection, source=source) + + @classmethod + def from_dataframe( + cls, + df: pd.DataFrame, + connection: Optional["duckdb.DuckDBPyConnection"] = None, + ) -> "LazyFrame": + """Create a LazyFrame from a pandas DataFrame. + + This registers the DataFrame as a temporary table in DuckDB, + allowing lazy operations on in-memory data. + + Parameters + ---------- + df : pd.DataFrame + Input DataFrame. + connection : duckdb.DuckDBPyConnection, optional + Existing connection. If None, creates a new one. 
+ + Returns + ------- + LazyFrame + """ + owns = connection is None + if connection is None: + connection = duckdb.connect() + + # Use DuckDB's ability to query DataFrames directly + relation = connection.sql("SELECT * FROM df") + return cls(relation, connection, source="dataframe", owns_connection=owns) + + # ------------------------------------------------------------------ + # Properties + # ------------------------------------------------------------------ + + @property + def columns(self) -> List[str]: + """Column names (lazy — extracted from relation metadata).""" + if self._columns is None: + self._columns = self._relation.columns + return self._columns + + @property + def dtypes(self) -> List[str]: + """Column data types as strings.""" + return self._relation.dtypes + + @property + def row_count(self) -> int: + """Number of rows (executes COUNT(*) query).""" + if self._row_count is None: + result = self._relation.aggregate("COUNT(*)") + self._row_count = result.fetchone()[0] + return self._row_count + + @property + def shape(self) -> tuple: + """(rows, columns) — row count triggers a COUNT(*) query.""" + return (self.row_count, len(self.columns)) + + @property + def connection(self) -> "duckdb.DuckDBPyConnection": + """The underlying DuckDB connection.""" + return self._connection + + @property + def relation(self) -> "duckdb.DuckDBPyRelation": + """The underlying DuckDB relation.""" + return self._relation + + # ------------------------------------------------------------------ + # Lazy operations + # ------------------------------------------------------------------ + + def filter(self, condition: str) -> "LazyFrame": + """Apply a SQL WHERE filter lazily. + + Parameters + ---------- + condition : str + SQL condition (e.g., "Intensity > 0 AND ProteinName != 'DECOY'"). + + Returns + ------- + LazyFrame + New LazyFrame with the filter applied. + """ + new_relation = self._relation.filter(condition) + return LazyFrame(new_relation, self._connection, source=f"{self._source}[filtered]") + + def select(self, *columns: str) -> "LazyFrame": + """Select specific columns lazily. + + Parameters + ---------- + *columns : str + Column names to select. + + Returns + ------- + LazyFrame + New LazyFrame with only the selected columns. + """ + col_str = ", ".join(f'"{c}"' for c in columns) + new_relation = self._relation.project(col_str) + return LazyFrame(new_relation, self._connection, source=f"{self._source}[select]") + + def head(self, n: int = 5) -> pd.DataFrame: + """Materialize the first n rows as a DataFrame. + + Parameters + ---------- + n : int + Number of rows. + + Returns + ------- + pd.DataFrame + """ + return self._relation.limit(n).df() + + def describe(self) -> pd.DataFrame: + """Compute basic statistics (materializes aggregation only).""" + return self._relation.describe().df() + + def unique(self, column: str) -> List: + """Get unique values for a column. + + Parameters + ---------- + column : str + Column name. + + Returns + ------- + list + Unique values. + """ + result = self._relation.unique(column).df() + return result[column].tolist() + + # ------------------------------------------------------------------ + # Materialization + # ------------------------------------------------------------------ + + def df(self) -> pd.DataFrame: + """Materialize the full result as a pandas DataFrame. + + This triggers the actual computation. 
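+        The resulting row count is cached, so a later ``len(lf)`` or
+        ``lf.shape`` will not re-run a COUNT(*) query.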
+ + Returns + ------- + pd.DataFrame + """ + logger.debug("Materializing LazyFrame (%s)", self._source) + result = self._relation.df() + self._row_count = len(result) + return result + + def to_arrow(self): + """Materialize as a PyArrow Table.""" + return self._relation.arrow() + + # ------------------------------------------------------------------ + # Representation + # ------------------------------------------------------------------ + + # ------------------------------------------------------------------ + # Resource cleanup + # ------------------------------------------------------------------ + + def close(self) -> None: + """Close the underlying DuckDB connection if owned by this LazyFrame.""" + if self._owns_connection and self._connection is not None: + try: + self._connection.close() + except Exception: + pass + self._connection = None + + def __del__(self) -> None: + self.close() + + def __repr__(self) -> str: + cols = self.columns + ncols = len(cols) + col_preview = ", ".join(cols[:5]) + if ncols > 5: + col_preview += f", ... ({ncols - 5} more)" + return f"LazyFrame(source={self._source}, cols=[{col_preview}])" + + def __len__(self) -> int: + return self.row_count + + def __contains__(self, item: str) -> bool: + """Check if a column name exists.""" + return item in self.columns + + +def is_lazy(obj) -> bool: + """Check if an object is a LazyFrame. + + Parameters + ---------- + obj : any + Object to check. + + Returns + ------- + bool + """ + return isinstance(obj, LazyFrame) + + +def ensure_dataframe(obj: Union[pd.DataFrame, "LazyFrame"]) -> pd.DataFrame: + """Convert a LazyFrame to DataFrame if needed. + + Parameters + ---------- + obj : pd.DataFrame or LazyFrame + Input data. + + Returns + ------- + pd.DataFrame + """ + if isinstance(obj, LazyFrame): + return obj.df() + return obj diff --git a/mokume/core/registry.py b/mokume/core/registry.py new file mode 100644 index 0000000..85feef5 --- /dev/null +++ b/mokume/core/registry.py @@ -0,0 +1,297 @@ +""" +Plugin registry for the mokume package. + +This module provides a central registry for all mokume extension types. +Built-in methods register via decorators at import time. Third-party +packages register via Python entry points, discovered on first access. + +Extension groups: + - quantification: Protein quantification algorithms + - normalization.feature: Feature-level normalization methods + - normalization.sample: Sample/peptide-level normalization methods + - harmonization: Batch effect correction methods + - imputation: Missing value imputation methods + - filter: Quality control filters + +Example — registering a built-in method:: + + from mokume.core.registry import PluginRegistry + + @PluginRegistry.register("quantification", "directlfq") + class DirectLFQQuantification(QuantificationMethod): + ... + +Example — third-party registration via pyproject.toml:: + + [project.entry-points."mokume.quantification"] + spectral_counting = "my_package:SpectralCountingMethod" +""" + +import importlib.metadata +import logging +import re +from typing import Any, Dict, List, Optional, Set, Type + +logger = logging.getLogger(__name__) + +# Sentinel for TopN pattern matching +_TOPN_PATTERN = re.compile(r"^top(\d+)$") + +# Valid input_level values for quantification methods. +# Must match keys in FLOW_DISPATCH (runner.py) and QpxDataset._VALID_LEVELS. +VALID_INPUT_LEVELS: Set[str] = {"peptides", "psms", "peptides_raw", "features"} + + +class PluginRegistry: + """Central registry for all mokume extension types. 
+
+    Manages registration and discovery of plugins across six extension
+    groups. Supports both decorator-based registration (built-in) and
+    entry-point discovery (third-party packages).
+    """
+
+    _stores: Dict[str, Dict[str, Any]] = {
+        "quantification": {},
+        "normalization.feature": {},
+        "normalization.sample": {},
+        "harmonization": {},
+        "imputation": {},
+        "filter": {},
+    }
+
+    _discovered: bool = False
+
+    @classmethod
+    def register(cls, group: str, name: str):
+        """Decorator to register a plugin class.
+
+        Parameters
+        ----------
+        group : str
+            Extension group (e.g., "quantification", "normalization.feature").
+        name : str
+            Name to register the plugin under (e.g., "maxlfq").
+
+        Returns
+        -------
+        Callable
+            Decorator that registers the class and returns it unchanged.
+
+        Raises
+        ------
+        ValueError
+            If the group is not recognized.
+
+        Examples
+        --------
+        >>> @PluginRegistry.register("quantification", "my_method")
+        ... class MyMethod(QuantificationMethod):
+        ...     ...
+        """
+        if group not in cls._stores:
+            raise ValueError(
+                f"Unknown plugin group: '{group}'. "
+                f"Available groups: {list(cls._stores.keys())}"
+            )
+
+        def decorator(klass: Type) -> Type:
+            cls._stores[group][name.lower()] = klass
+            return klass
+
+        return decorator
+
+    @classmethod
+    def register_instance_factory(cls, group: str, name: str, factory):
+        """Register a callable factory for a plugin.
+
+        Useful for registering aliases like top3/top5 that create
+        instances with specific parameters.
+
+        Parameters
+        ----------
+        group : str
+            Extension group.
+        name : str
+            Name to register under.
+        factory : callable
+            A callable that accepts **kwargs and returns a plugin instance.
+        """
+        if group not in cls._stores:
+            raise ValueError(
+                f"Unknown plugin group: '{group}'. "
+                f"Available groups: {list(cls._stores.keys())}"
+            )
+        cls._stores[group][name.lower()] = factory
+
+    @classmethod
+    def get(cls, group: str, name: str, **kwargs: Any) -> Any:
+        """Get a plugin instance by group and name.
+
+        Handles special patterns like topN (top3, top5, top10) by
+        parsing the numeric suffix.
+
+        Parameters
+        ----------
+        group : str
+            Extension group.
+        name : str
+            Plugin name.
+        **kwargs
+            Arguments passed to the plugin constructor.
+
+        Returns
+        -------
+        Any
+            An instance of the requested plugin.
+
+        Raises
+        ------
+        ValueError
+            If the plugin is not found.
+        """
+        cls._ensure_discovered()
+        name_lower = name.lower()
+
+        # Check direct match first. The stored entry is either a class or
+        # a factory callable; both are invoked the same way.
+        entry = cls._stores.get(group, {}).get(name_lower)
+        instance = None
+        if entry is not None:
+            instance = entry(**kwargs)
+        else:
+            # Handle topN pattern: top3, top5, top10, etc.
+            if group == "quantification":
+                match = _TOPN_PATTERN.match(name_lower)
+                if match:
+                    topn_cls = cls._stores.get(group, {}).get("topn")
+                    if topn_cls is not None:
+                        n = int(match.group(1))
+                        instance = topn_cls(n=n, **kwargs)
+
+        if instance is None:
+            available = cls.available(group)
+            raise ValueError(
+                f"Unknown {group} method: '{name}'. "
+                f"Available: {available}"
+            )
+
+        # Validate input_level for quantification methods
+        if group == "quantification" and hasattr(instance, "input_level"):
+            level = instance.input_level
+            if level not in VALID_INPUT_LEVELS:
+                raise ValueError(
+                    f"Quantification method '{name}' declares "
+                    f"input_level='{level}', which is not valid. "
" + f"Must be one of: {sorted(VALID_INPUT_LEVELS)}" + ) + + return instance + + @classmethod + def get_class(cls, group: str, name: str) -> Optional[Type]: + """Get the registered class (not an instance) for a plugin. + + Parameters + ---------- + group : str + Extension group. + name : str + Plugin name. + + Returns + ------- + Type or None + The registered class, or None if not found. + """ + cls._ensure_discovered() + return cls._stores.get(group, {}).get(name.lower()) + + @classmethod + def available(cls, group: str) -> List[str]: + """List registered plugin names for a group. + + Parameters + ---------- + group : str + Extension group. + + Returns + ------- + list[str] + Sorted list of available plugin names. + """ + cls._ensure_discovered() + return sorted(cls._stores.get(group, {}).keys()) + + @classmethod + def is_registered(cls, group: str, name: str) -> bool: + """Check if a plugin is registered. + + Parameters + ---------- + group : str + Extension group. + name : str + Plugin name. + + Returns + ------- + bool + """ + cls._ensure_discovered() + return name.lower() in cls._stores.get(group, {}) + + @classmethod + def _ensure_discovered(cls): + """Discover entry-point plugins once on first access.""" + if cls._discovered: + return + cls._discovered = True + + for group in cls._stores: + ep_group = f"mokume.{group}" + try: + # Python 3.12+: entry_points(group=...) returns a SelectableGroups + # Python 3.9-3.11: entry_points() returns a dict + # Python 3.12+: entry_points(group=...) returns matching entries + # Python 3.9-3.11: entry_points() returns SelectableGroups + try: + group_eps = importlib.metadata.entry_points(group=ep_group) + except TypeError: + # Fallback for older Python versions + eps = importlib.metadata.entry_points() + if isinstance(eps, dict): + group_eps = eps.get(ep_group, []) + else: + group_eps = [ + ep for ep in eps if ep.group == ep_group + ] + + for ep in group_eps: + try: + klass = ep.load() + cls._stores[group][ep.name.lower()] = klass + logger.debug( + "Discovered plugin: %s.%s -> %s", + group, ep.name, klass, + ) + except Exception as exc: + logger.warning( + "Failed to load plugin '%s' from group '%s': %s", + ep.name, ep_group, exc, + ) + except Exception as exc: + logger.debug( + "Entry point discovery failed for group '%s': %s", + ep_group, exc, + ) + + @classmethod + def reset(cls): + """Reset the registry. Mainly useful for testing.""" + for group in cls._stores: + cls._stores[group].clear() + cls._discovered = False diff --git a/mokume/core/schema.py b/mokume/core/schema.py new file mode 100644 index 0000000..9c0d24c --- /dev/null +++ b/mokume/core/schema.py @@ -0,0 +1,120 @@ +""" +Schema definitions for the qpx data format. + +This module centralizes column name constants and schema validation +for the qpx parquet format used throughout mokume. It is the single +source of truth for column names at each data level. + +Column constants are re-exported here from constants.py for backward +compatibility. New code should import from this module. 
+ +Data levels +----------- +- features: Raw feature-level data from qpx parquet +- peptides: Assembled peptide-level data after normalization +- proteins: Protein-level quantification results +- psms: PSM-level data (used by ratio quantification) +""" + +from typing import Dict, FrozenSet, List + +import pandas as pd + +# Re-export column constants from the canonical location +from mokume.core.constants import ( + BIOREPLICATE, + CHANNEL, + CONDITION, + FRACTION, + INTENSITY, + NORM_INTENSITY, + PARQUET_COLUMNS, + PEPTIDE_CANONICAL, + PEPTIDE_CHARGE, + PEPTIDE_SEQUENCE, + PROTEIN_NAME, + REFERENCE, + RUN, + SAMPLE_ID, + TECHREPLICATE, + parquet_map, +) + + +# --- Schema definitions per data level --- + +FEATURE_REQUIRED_COLS: FrozenSet[str] = frozenset({ + PROTEIN_NAME, + PEPTIDE_SEQUENCE, + SAMPLE_ID, + INTENSITY, +}) + +PEPTIDE_REQUIRED_COLS: FrozenSet[str] = frozenset({ + PROTEIN_NAME, + PEPTIDE_CANONICAL, + SAMPLE_ID, + NORM_INTENSITY, +}) + +PROTEIN_REQUIRED_COLS: FrozenSet[str] = frozenset({ + PROTEIN_NAME, + SAMPLE_ID, +}) + +PSM_REQUIRED_COLS: FrozenSet[str] = frozenset({ + PROTEIN_NAME, + PEPTIDE_SEQUENCE, + PEPTIDE_CHARGE, + SAMPLE_ID, + INTENSITY, +}) + +_LEVEL_SCHEMAS: Dict[str, FrozenSet[str]] = { + "features": FEATURE_REQUIRED_COLS, + "peptides": PEPTIDE_REQUIRED_COLS, + "proteins": PROTEIN_REQUIRED_COLS, + "psms": PSM_REQUIRED_COLS, +} + + +def validate_schema(df: pd.DataFrame, level: str) -> List[str]: + """Validate that a DataFrame has the required columns for a data level. + + Parameters + ---------- + df : pd.DataFrame + The DataFrame to validate. + level : str + Data level: one of "features", "peptides", "proteins", "psms". + + Returns + ------- + list[str] + List of error messages. Empty if the DataFrame is valid. + + Examples + -------- + >>> errors = validate_schema(my_df, "peptides") + >>> if errors: + ... raise ValueError(f"Schema errors: {errors}") + """ + required = _LEVEL_SCHEMAS.get(level) + if required is None: + return [f"Unknown data level: '{level}'. Available: {list(_LEVEL_SCHEMAS.keys())}"] + + missing = required - set(df.columns) + return [ + f"Missing required column '{col}' for {level} level" + for col in sorted(missing) + ] + + +def available_levels() -> List[str]: + """Return the list of available data levels. + + Returns + ------- + list[str] + """ + return list(_LEVEL_SCHEMAS.keys()) diff --git a/mokume/export/__init__.py b/mokume/export/__init__.py new file mode 100644 index 0000000..e589a45 --- /dev/null +++ b/mokume/export/__init__.py @@ -0,0 +1,17 @@ +""" +Export utilities for the mokume package. + +This module provides export functions for converting mokume data +into various formats: AnnData, CSV/TSV, and wide-format matrices. +""" + +from mokume.export.anndata import to_anndata, dataset_to_anndata +from mokume.export.csv import to_wide_csv, to_long_csv, dataset_to_csv + +__all__ = [ + "to_anndata", + "dataset_to_anndata", + "to_wide_csv", + "to_long_csv", + "dataset_to_csv", +] diff --git a/mokume/export/anndata.py b/mokume/export/anndata.py new file mode 100644 index 0000000..56233ca --- /dev/null +++ b/mokume/export/anndata.py @@ -0,0 +1,101 @@ +""" +AnnData export for mokume data. + +Provides functions for converting DataFrames and QpxDatasets to AnnData +objects. AnnData is an export-only format — used for downstream analysis +with scanpy, scvi-tools, etc. 
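+
+A minimal sketch (``long_df`` and the value column are hypothetical)::
+
+    from mokume.export import to_anndata
+
+    adata = to_anndata(long_df, value_col="NormIntensity")
+    adata.write_h5ad("proteins.h5ad")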
+ +Requires: pip install mokume[anndata] +""" + +from typing import TYPE_CHECKING, List, Optional + +import pandas as pd + +from mokume.core.constants import PROTEIN_NAME, SAMPLE_ID + +if TYPE_CHECKING: + import anndata as ad + + from mokume.core.dataset import QpxDataset + + +def to_anndata( + df: pd.DataFrame, + obs_col: str = SAMPLE_ID, + var_col: str = PROTEIN_NAME, + value_col: str = "Intensity", + layer_cols: Optional[List[str]] = None, + obs_metadata_cols: Optional[List[str]] = None, + var_metadata_cols: Optional[List[str]] = None, +) -> "ad.AnnData": + """Create an AnnData object from a long-format DataFrame. + + This wraps the original ``mokume.io.parquet.create_anndata`` function + with better default column names for the standard mokume workflow. + + Parameters + ---------- + df : pd.DataFrame + Input data in long format. + obs_col : str + Column for observations (samples). Default: SampleID. + var_col : str + Column for variables (proteins). Default: ProteinName. + value_col : str + Column for the main X matrix values. + layer_cols : list[str], optional + Additional columns to include as AnnData layers. + obs_metadata_cols : list[str], optional + Columns to include as observation metadata. + var_metadata_cols : list[str], optional + Columns to include as variable metadata. + + Returns + ------- + anndata.AnnData + AnnData object with X matrix, obs, var, and optional layers. + """ + from mokume.io.parquet import create_anndata + + return create_anndata( + df=df, + obs_col=obs_col, + var_col=var_col, + value_col=value_col, + layer_cols=layer_cols, + obs_metadata_cols=obs_metadata_cols, + var_metadata_cols=var_metadata_cols, + ) + + +def dataset_to_anndata( + dataset: "QpxDataset", + level: str = "proteins", + value_col: str = "Intensity", + layer_names: Optional[List[str]] = None, +) -> "ad.AnnData": + """Export a QpxDataset level as AnnData. + + Convenience wrapper around ``QpxDataset.to_anndata()``. + + Parameters + ---------- + dataset : QpxDataset + The dataset to export. + level : str + Data level to export. + value_col : str + Column for the main X matrix values. + layer_names : list[str], optional + Names of dataset layers to include. + + Returns + ------- + anndata.AnnData + """ + return dataset.to_anndata( + level=level, + value_col=value_col, + layer_names=layer_names, + ) diff --git a/mokume/export/csv.py b/mokume/export/csv.py new file mode 100644 index 0000000..71cf32f --- /dev/null +++ b/mokume/export/csv.py @@ -0,0 +1,118 @@ +""" +CSV / TSV export utilities for mokume data. + +Provides functions for writing protein matrices and long-format data +to CSV/TSV files. +""" + +import logging +from typing import TYPE_CHECKING, Optional + +import pandas as pd + +from mokume.core.constants import PROTEIN_NAME, SAMPLE_ID + +if TYPE_CHECKING: + from mokume.core.dataset import QpxDataset + +logger = logging.getLogger(__name__) + + +def to_wide_csv( + df: pd.DataFrame, + output_path: str, + protein_col: str = PROTEIN_NAME, + sample_col: str = SAMPLE_ID, + value_col: str = "Intensity", + sep: str = ",", +) -> None: + """Export a long-format DataFrame as a wide protein x sample CSV. + + Parameters + ---------- + df : pd.DataFrame + Long-format data with protein, sample, and intensity columns. + output_path : str + Output file path. + protein_col : str + Column for protein identifiers (becomes row index). + sample_col : str + Column for sample identifiers (becomes column headers). + value_col : str + Column containing values for matrix cells. + sep : str + Delimiter. 
Use '\\t' for TSV. + """ + wide = df.pivot_table( + index=protein_col, + columns=sample_col, + values=value_col, + aggfunc="first", + ) + wide.to_csv(output_path, sep=sep) + logger.info( + "Exported wide matrix (%d proteins x %d samples) to %s", + wide.shape[0], + wide.shape[1], + output_path, + ) + + +def to_long_csv( + df: pd.DataFrame, + output_path: str, + sep: str = ",", + columns: Optional[list] = None, +) -> None: + """Export a DataFrame to CSV in long format. + + Parameters + ---------- + df : pd.DataFrame + Data to export. + output_path : str + Output file path. + sep : str + Delimiter. Use '\\t' for TSV. + columns : list, optional + Subset of columns to export. If None, all columns are exported. + """ + out = df[columns] if columns else df + out.to_csv(output_path, sep=sep, index=False) + logger.info("Exported %d rows to %s", len(out), output_path) + + +def dataset_to_csv( + dataset: "QpxDataset", + output_path: str, + level: str = "proteins", + wide: bool = True, + value_col: str = "Intensity", + sep: str = ",", +) -> None: + """Export a QpxDataset level to CSV. + + Parameters + ---------- + dataset : QpxDataset + The dataset to export. + output_path : str + Output file path. + level : str + Data level to export. + wide : bool + If True, export as wide protein x sample matrix. + If False, export as long format. + value_col : str + Column for values (used when wide=True). + sep : str + Delimiter. + """ + df = dataset.get_level(level) + if df is None: + raise ValueError(f"Data level '{level}' is not populated") + + if wide: + to_wide_csv(df, output_path, value_col=value_col, sep=sep) + else: + to_long_csv(df, output_path, sep=sep) diff --git a/mokume/harmonization/__init__.py b/mokume/harmonization/__init__.py new file mode 100644 index 0000000..9d7ad08 --- /dev/null +++ b/mokume/harmonization/__init__.py @@ -0,0 +1,46 @@ +""" +Batch effect correction (harmonization) for the mokume package. + +This package consolidates all batch-correction-related code: +- Base class for batch correctors (plugin ABC) +- ComBat implementation (via inmoose) +- Configuration and enums +- Core correction functions +""" + +from mokume.harmonization.base import BatchCorrector +from mokume.harmonization.combat import ComBatCorrector +from mokume.harmonization.models import BatchDetectionMethod, BatchCorrectionConfig +from mokume.harmonization.correction import ( + apply_batch_correction, + compute_pca, + detect_batches, + extract_covariates_from_sdrf, + get_batch_info_from_sample_names, + is_batch_correction_available, + is_inmoose_available, + iterative_outlier_removal, + remove_single_sample_batches, + TooFewSamplesInBatch, +) + +__all__ = [ + # Base + "BatchCorrector", + # Implementations + "ComBatCorrector", + # Config / enums + "BatchDetectionMethod", + "BatchCorrectionConfig", + # Functions + "apply_batch_correction", + "compute_pca", + "detect_batches", + "extract_covariates_from_sdrf", + "get_batch_info_from_sample_names", + "is_batch_correction_available", + "is_inmoose_available", + "iterative_outlier_removal", + "remove_single_sample_batches", + "TooFewSamplesInBatch", +] diff --git a/mokume/harmonization/base.py b/mokume/harmonization/base.py new file mode 100644 index 0000000..811aea9 --- /dev/null +++ b/mokume/harmonization/base.py @@ -0,0 +1,65 @@ +""" +Base class for batch correction methods. + +This module provides an abstract base class for batch effect correction +algorithms. Implementations should register with the PluginRegistry. 
+""" + +from abc import ABC, abstractmethod +from typing import List, Optional + +import pandas as pd + + +class BatchCorrector(ABC): + """Base class for batch effect correction methods. + + Batch correctors remove systematic technical variation (batch effects) + while preserving biological signal. They operate on protein-level + wide-format data (proteins x samples). + + Subclasses should register with:: + + from mokume.core.registry import PluginRegistry + + @PluginRegistry.register("harmonization", "combat") + class ComBatCorrector(BatchCorrector): + ... + """ + + @property + @abstractmethod + def name(self) -> str: + """Human-readable method name.""" + + @abstractmethod + def correct( + self, + df: pd.DataFrame, + batch: List[int], + covariates: Optional[List[List[int]]] = None, + **kwargs, + ) -> pd.DataFrame: + """Apply batch correction to a protein intensity matrix. + + Parameters + ---------- + df : pd.DataFrame + Wide-format protein intensity matrix (proteins x samples). + Index is protein identifiers, columns are sample identifiers. + batch : list[int] + Batch assignment for each sample (column). + covariates : list[list[int]], optional + Biological covariates to preserve. Each inner list is a + covariate with one value per sample. + **kwargs + Method-specific parameters. + + Returns + ------- + pd.DataFrame + Batch-corrected protein intensity matrix, same shape as input. + """ + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/mokume/harmonization/combat.py b/mokume/harmonization/combat.py new file mode 100644 index 0000000..aec083a --- /dev/null +++ b/mokume/harmonization/combat.py @@ -0,0 +1,78 @@ +""" +ComBat batch correction implementation. + +Wraps the existing batch correction logic and registers it with the PluginRegistry. + +Requires: pip install mokume[batch-correction] +""" + +from typing import List, Optional + +import pandas as pd + +from mokume.harmonization.base import BatchCorrector +from mokume.core.registry import PluginRegistry + + +@PluginRegistry.register("harmonization", "combat") +class ComBatCorrector(BatchCorrector): + """ComBat batch correction using inmoose. + + Removes batch effects while optionally preserving biological signal + specified via covariates (e.g., sex, tissue from SDRF). + + Parameters + ---------- + parametric : bool + Use parametric empirical Bayes estimation. Default True. + mean_only : bool + Only adjust batch means. Default False. + """ + + def __init__(self, parametric: bool = True, mean_only: bool = False): + self.parametric = parametric + self.mean_only = mean_only + + @property + def name(self) -> str: + return "ComBat" + + def correct( + self, + df: pd.DataFrame, + batch: List[int], + covariates: Optional[List[List[int]]] = None, + **kwargs, + ) -> pd.DataFrame: + """Apply ComBat batch correction. + + Delegates to mokume.harmonization.correction.apply_batch_correction(). + + Parameters + ---------- + df : pd.DataFrame + Wide-format protein intensity matrix (proteins x samples). + batch : list[int] + Batch assignment for each sample. + covariates : list[list[int]], optional + Biological covariates to preserve. + **kwargs + Additional keyword arguments passed to apply_batch_correction. + + Returns + ------- + pd.DataFrame + Batch-corrected intensity matrix. 
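+
+        Examples
+        --------
+        A minimal sketch; ``wide_df`` and the batch labels are
+        illustrative::
+
+            corrector = ComBatCorrector(parametric=True)
+            corrected = corrector.correct(wide_df, batch=[0, 0, 1, 1])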
+ """ + from mokume.harmonization.correction import apply_batch_correction + + return apply_batch_correction( + df=df, + batch=batch, + covs=covariates, + kwargs={ + "par_prior": self.parametric, + "mean_only": self.mean_only, + **kwargs, + }, + ) diff --git a/mokume/harmonization/correction.py b/mokume/harmonization/correction.py new file mode 100644 index 0000000..71b7719 --- /dev/null +++ b/mokume/harmonization/correction.py @@ -0,0 +1,477 @@ +""" +Batch correction utilities for the mokume package. + +This module provides batch effect correction using ComBat (via inmoose). + +Key Concepts: +- Batch: Technical variation to REMOVE (e.g., different runs, labs, processing days) +- Covariates: Biological variables to PRESERVE (e.g., sex, tissue from SDRF characteristics) + +Note: This module requires the optional 'inmoose' dependency. +Install it with: pip install mokume[batch-correction] +""" + +import logging +import warnings +from typing import List, Optional, Dict, Union + +import numpy as np +import pandas as pd + +warnings.filterwarnings( + "ignore", category=PendingDeprecationWarning, module="numpy.matrixlib.defmatrix" +) + +from sklearn.cluster._hdbscan import hdbscan +from sklearn.decomposition import PCA + +from mokume.plotting import is_plotting_available +from mokume.harmonization.models import BatchDetectionMethod + +logger = logging.getLogger(__name__) +logger.addHandler(logging.NullHandler()) + + +def is_inmoose_available() -> bool: + """Check if inmoose is installed.""" + try: + import inmoose + return True + except ImportError: + return False + + +def is_batch_correction_available() -> bool: + """ + Check if batch correction dependencies are installed. + + Returns + ------- + bool + True if inmoose is installed, False otherwise. + + Notes + ----- + Install batch correction support with: pip install mokume[batch-correction] + """ + return is_inmoose_available() + + +def compute_pca(df, n_components=5) -> pd.DataFrame: + """Compute principal components for a given dataframe.""" + pca = PCA(n_components=n_components) + pca.fit(df) + df_pca = pca.transform(df) + df_pca = pd.DataFrame( + df_pca, index=df.index, columns=[f"PC{i}" for i in range(1, n_components + 1)] + ) + return df_pca + + +def get_batch_info_from_sample_names(sample_list: List[str]) -> List[int]: + """Get batch indices from sample names (legacy function, use detect_batches instead).""" + samples = [s.split("-")[0] for s in sample_list] + batches = list(set(samples)) + index = {i: batches.index(i) for i in batches} + return [index[i] for i in samples] + + +def detect_batches( + sample_ids: List[str], + method: Union[BatchDetectionMethod, str] = BatchDetectionMethod.SAMPLE_PREFIX, + run_info: Optional[Dict[str, str]] = None, + batch_column_values: Optional[List[str]] = None, +) -> List[int]: + """ + Detect batch assignments for samples. + + Parameters + ---------- + sample_ids : List[str] + Sample identifiers. + method : BatchDetectionMethod or str + How to determine batches. Options: + - "sample_prefix": Extract from sample name prefix (PXD001-S1 → PXD001) + - "run": Use run/reference file name + - "fraction": Use fraction identifier + - "techreplicate": Use technical replicate identifier + - "column": Use explicit batch values + run_info : Optional[Dict[str, str]] + Mapping of sample_id → run_name (for "run" method). + batch_column_values : Optional[List[str]] + Explicit batch values for each sample (for "column" method). + + Returns + ------- + List[int] + Batch index for each sample (0-indexed). 
+ + Raises + ------ + ValueError + If required parameters are missing for the selected method. + + Examples + -------- + >>> samples = ["PXD001-S1", "PXD001-S2", "PXD002-S1", "PXD002-S2"] + >>> detect_batches(samples, method="sample_prefix") + [0, 0, 1, 1] + + >>> detect_batches(samples, method="column", batch_column_values=["A", "A", "B", "B"]) + [0, 0, 1, 1] + """ + if isinstance(method, str): + method = BatchDetectionMethod.from_str(method) + + if method == BatchDetectionMethod.SAMPLE_PREFIX: + # Extract batch prefix from sample names. + # Supports multiple conventions: + # PXD001-S1 → PXD001 (hyphen-separated) + # p1_1 → p1 (quantms TMT plex prefix: letters + digits before _digit) + import re + prefixes = [] + for s in sample_ids: + if "-" in s: + prefixes.append(s.split("-")[0]) + else: + # Match leading non-numeric + optional digits as prefix (e.g. p1_1 → p1) + m = re.match(r"^([a-zA-Z]+\d+)_", s) + if m: + prefixes.append(m.group(1)) + else: + prefixes.append(s) + indices, _ = pd.factorize(pd.array(prefixes)) + return indices.tolist() + + elif method == BatchDetectionMethod.RUN_NAME: + if run_info is None: + raise ValueError("run_info required for RUN_NAME method") + runs = [run_info.get(s, s) for s in sample_ids] + indices, _ = pd.factorize(pd.array(runs)) + return indices.tolist() + + elif method == BatchDetectionMethod.FRACTION: + # Requires fraction info - fall back to sample prefix if not available + if run_info is None: + logger.warning("No fraction info provided, falling back to sample_prefix") + return detect_batches(sample_ids, BatchDetectionMethod.SAMPLE_PREFIX) + fractions = [run_info.get(s, "1") for s in sample_ids] + indices, _ = pd.factorize(pd.array(fractions)) + return indices.tolist() + + elif method == BatchDetectionMethod.TECHREPLICATE: + # Requires tech rep info - fall back to sample prefix if not available + if run_info is None: + logger.warning("No tech rep info provided, falling back to sample_prefix") + return detect_batches(sample_ids, BatchDetectionMethod.SAMPLE_PREFIX) + tech_reps = [run_info.get(s, "1") for s in sample_ids] + indices, _ = pd.factorize(pd.array(tech_reps)) + return indices.tolist() + + elif method == BatchDetectionMethod.EXPLICIT_COLUMN: + if batch_column_values is None: + raise ValueError("batch_column_values required for EXPLICIT_COLUMN method") + if len(batch_column_values) != len(sample_ids): + raise ValueError( + f"batch_column_values length ({len(batch_column_values)}) " + f"must match sample_ids length ({len(sample_ids)})" + ) + indices, _ = pd.factorize(pd.array(batch_column_values)) + return indices.tolist() + + else: + raise ValueError(f"Unknown batch detection method: {method}") + + +def extract_covariates_from_sdrf( + sdrf_path: str, + sample_ids: List[str], + covariate_columns: List[str], +) -> Optional[List[List[int]]]: + """ + Extract categorical covariates from SDRF for batch correction. + + Covariates represent biological variables whose signal should be PRESERVED + during batch correction. For example, if samples from different batches + share the same sex or tissue type, ComBat will preserve this biological + signal while removing technical batch effects. + + Parameters + ---------- + sdrf_path : str + Path to SDRF file. + sample_ids : List[str] + Sample IDs matching the protein matrix columns (in order). + covariate_columns : List[str] + SDRF columns to use as covariates. 
+ e.g., ["characteristics[sex]", "characteristics[organism part]"] + + Returns + ------- + List[List[int]] or None + Covariate matrix as list of lists (samples × covariates) with + categorical encoding, or None if no valid covariates found. + + Notes + ----- + - Covariates MUST be categorical (ComBat requirement) + - Samples in covariate matrix must match protein matrix column order + - Signal from these variables is PRESERVED after batch correction + + Examples + -------- + SDRF with columns: + source name | characteristics[sex] | characteristics[tissue] + Sample1 | male | liver + Sample2 | female | liver + Sample3 | male | brain + + >>> extract_covariates_from_sdrf( + ... "experiment.sdrf.tsv", + ... ["Sample1", "Sample2", "Sample3"], + ... ["characteristics[sex]", "characteristics[tissue]"] + ... ) + [[0, 0], [1, 0], [0, 1]] # [sex_encoded, tissue_encoded] per sample + """ + if not covariate_columns: + return None + + try: + sdrf = pd.read_csv(sdrf_path, sep="\t") + except Exception as e: + logger.warning(f"Failed to read SDRF file: {e}") + return None + + sdrf.columns = [c.lower() for c in sdrf.columns] + + # Find the sample name column + sample_col = None + for col in ["source name", "sample name", "source_name", "sample_name"]: + if col in sdrf.columns: + sample_col = col + break + + if sample_col is None: + logger.warning("Could not find sample name column in SDRF") + return None + + # Build sample → row index mapping + sdrf_samples = sdrf[sample_col].tolist() + + covar_data = [] + valid_columns = [] + + for col in covariate_columns: + col_lower = col.lower() + + # Find matching column (exact or partial match) + matched_col = None + if col_lower in sdrf.columns: + matched_col = col_lower + else: + # Try partial match for characteristics columns + for sdrf_col in sdrf.columns: + if col_lower in sdrf_col or sdrf_col in col_lower: + matched_col = sdrf_col + break + + if matched_col is None: + logger.warning(f"Covariate column '{col}' not found in SDRF, skipping") + continue + + # Create sample → value mapping + sample_to_value = dict(zip(sdrf[sample_col], sdrf[matched_col])) + + # Get values for our samples in order + values = [] + for sample_id in sample_ids: + value = sample_to_value.get(sample_id) + if value is None: + # Try partial match + for sdrf_sample in sdrf_samples: + if sample_id in sdrf_sample or sdrf_sample in sample_id: + value = sample_to_value.get(sdrf_sample) + break + values.append(value if value is not None else "unknown") + + # Check if all values are the same (no information) + unique_values = set(values) + if len(unique_values) <= 1: + logger.warning( + f"Covariate '{col}' has only one unique value, skipping (no information)" + ) + continue + + # Encode as categorical integers + encoded, _ = pd.factorize(pd.array(values)) + covar_data.append(encoded.tolist()) + valid_columns.append(col) + logger.info(f"Extracted covariate '{col}' with {len(unique_values)} unique values") + + if not covar_data: + return None + + # Transpose: from [covariates][samples] to [samples][covariates] + # pycombat expects covar_mod as (n_samples, n_covariates) + n_samples = len(sample_ids) + n_covariates = len(covar_data) + result = [[covar_data[j][i] for j in range(n_covariates)] for i in range(n_samples)] + + logger.info(f"Extracted {n_covariates} covariates for {n_samples} samples: {valid_columns}") + return result + + +def remove_single_sample_batches(df: pd.DataFrame, batch: list) -> pd.DataFrame: + """Remove batches with only one sample.""" + batch_dict = dict(zip(df.columns, batch)) 
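+    # A batch value that appears only once has no within-batch replicate,
+    # so the corresponding sample column is dropped below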
+    single_sample_batch = [
+        k for k, v in batch_dict.items() if list(batch_dict.values()).count(v) == 1
+    ]
+    df_single_batches_removed = df.drop(single_sample_batch, axis=1)
+    return df_single_batches_removed
+
+
+class TooFewSamplesInBatch(ValueError):
+    """Raised when one or more batch factors have fewer than two samples."""
+
+    def __init__(self, batches):
+        super().__init__(
+            f"Batches must contain at least two samples; the following batch factors did not: {batches}"
+        )
+
+
+def apply_batch_correction(
+    df: pd.DataFrame,
+    batch: List[int],
+    covs: Optional[List[List[int]]] = None,
+    kwargs: Optional[dict] = None,
+) -> pd.DataFrame:
+    """
+    Apply batch correction using pycombat from inmoose.
+
+    Note: Requires the optional 'inmoose' dependency.
+    Install it with: pip install mokume[batch-correction]
+
+    Parameters
+    ----------
+    df : pd.DataFrame
+        DataFrame with samples as columns and features as rows.
+    batch : List[int]
+        Batch indices for each sample.
+    covs : Optional[List[List[int]]]
+        Covariate matrix with one row per sample and one column per
+        covariate, as returned by extract_covariates_from_sdrf().
+    kwargs : Optional[dict]
+        Additional arguments for pycombat_norm.
+
+    Returns
+    -------
+    pd.DataFrame
+        Batch-corrected DataFrame.
+
+    Raises
+    ------
+    ImportError
+        If inmoose is not installed.
+    ValueError
+        If sample counts don't match batch/covariate counts.
+    TooFewSamplesInBatch
+        If any batch has fewer than 2 samples.
+    """
+    if not is_inmoose_available():
+        raise ImportError(
+            "inmoose is required for batch correction but is not installed. "
+            "Install it with: pip install mokume[batch-correction]"
+        )
+
+    if kwargs is None:
+        kwargs = {}
+
+    if len(df.columns) != len(batch):
+        raise ValueError(
+            f"The number of samples should match the number of batch "
+            f"indices. There were {len(batch)} batch indices and {len(df.columns)} samples"
+        )
+
+    short_batches = [i for i in set(batch) if batch.count(i) < 2]
+    if short_batches:
+        raise TooFewSamplesInBatch(short_batches)
+
+    if covs:
+        if len(df.columns) != len(covs):
+            raise ValueError(
+                f"The number of samples should match the number of covariate rows. "
+                f"There were {len(covs)} covariate rows and {len(df.columns)} samples"
+            )
+
+    from inmoose.pycombat import pycombat_norm
+
+    df_co = pycombat_norm(counts=df, batch=batch, covar_mod=covs, **kwargs)
+    return df_co
+
+
+def find_clusters(df, min_cluster_size, min_samples) -> pd.DataFrame:
+    """Compute clusters for a given dataframe using HDBSCAN."""
+    clusterer = hdbscan.HDBSCAN(
+        min_cluster_size=min_cluster_size,
+        min_samples=min_samples,
+        metric="euclidean",
+        cluster_selection_method="eom",
+        allow_single_cluster=True,
+        cluster_selection_epsilon=0.01,
+    )
+    clusterer.fit(df)
+    df["cluster"] = clusterer.labels_
+    return df
+
+
+def iterative_outlier_removal(
+    df: pd.DataFrame,
+    batch: List[int],
+    n_components: int = 5,
+    min_cluster_size: int = 10,
+    min_samples: int = 10,
+    n_iter: int = 10,
+    verbose: bool = True,
+) -> pd.DataFrame:
+    """Iteratively remove outliers using PCA and HDBSCAN clustering."""
+    batch_dict = dict(zip(df.columns, batch))
+
+    # Check plotting availability once if verbose
+    can_plot = verbose and is_plotting_available()
+    if verbose and not can_plot:
+        logger.warning(
+            "Plotting skipped: plotting dependencies not installed. 
" + "Install with: pip install mokume[plotting]" + ) + + for i in range(n_iter): + logger.info("Running iteration: {}".format(i + 1)) + + df_pca = compute_pca(df.T, n_components=n_components) + df_clusters = find_clusters( + df_pca, min_cluster_size=min_cluster_size, min_samples=min_samples + ) + logger.info(df_clusters) + + outliers = df_clusters[df_clusters["cluster"] == -1].index.tolist() + df_filtered_outliers = df.drop(outliers, axis=1) + logger.info(f"Number of outliers in iteration {i + 1}: {len(outliers)}") + logger.info(f"Outliers in iteration {i + 1}: {str(outliers)}") + + batch_dict = {col: batch_dict[col] for col in df_filtered_outliers.columns} + df = df_filtered_outliers + + if can_plot: + from mokume.plotting import plot_pca + + plot_pca( + df_clusters, + output_file=f"iterative_outlier_removal_{i + 1}.png", + x_col="PC1", + y_col="PC2", + hue_col="cluster", + title=f"Iteration {i + 1}: Number of outliers: {len(outliers)}", + ) + + if len(outliers) == 0: + break + + return df diff --git a/mokume/harmonization/models.py b/mokume/harmonization/models.py new file mode 100644 index 0000000..34e0ba8 --- /dev/null +++ b/mokume/harmonization/models.py @@ -0,0 +1,127 @@ +""" +Batch correction configuration and enums for the mokume package. + +This module provides configuration classes and enums for batch effect +correction using ComBat (via inmoose). + +Key Concepts: +- Batch: Technical variation to REMOVE (e.g., different runs, labs, processing days) +- Covariates: Biological variables to PRESERVE (e.g., sex, tissue from SDRF characteristics) +""" + +from dataclasses import dataclass, field +from enum import Enum +from typing import Optional, List + + +class BatchDetectionMethod(Enum): + """ + Methods for detecting/assigning batch labels. + + Attributes + ---------- + SAMPLE_PREFIX : str + Extract batch from sample name prefix (e.g., PXD001-S1 → batch=PXD001). + RUN_NAME : str + Use run/reference file name as batch identifier. + FRACTION : str + Treat each fraction as a separate batch. + TECHREPLICATE : str + Treat each technical replicate as a separate batch. + EXPLICIT_COLUMN : str + Use values from a user-specified column. + """ + + SAMPLE_PREFIX = "sample_prefix" + RUN_NAME = "run" + FRACTION = "fraction" + TECHREPLICATE = "techreplicate" + EXPLICIT_COLUMN = "column" + + @classmethod + def from_str(cls, name: str) -> "BatchDetectionMethod": + """ + Convert a string to a BatchDetectionMethod. + + Parameters + ---------- + name : str + The name of the batch detection method. + + Returns + ------- + BatchDetectionMethod + The batch detection method enum value. + + Raises + ------ + ValueError + If the name does not match any method. + """ + name_lower = name.lower().replace("-", "_").replace(" ", "_") + for member in cls: + if member.value == name_lower: + return member + valid = [m.value for m in cls] + raise ValueError(f"Unknown batch detection method: {name}. Valid options: {valid}") + + +@dataclass +class BatchCorrectionConfig: + """ + Configuration for batch effect correction. + + This configuration controls how batch effects are detected and corrected + using the ComBat algorithm (via inmoose). + + Attributes + ---------- + enabled : bool + Whether to apply batch correction. Default False. + batch_method : BatchDetectionMethod + How to detect/assign batch labels. Default SAMPLE_PREFIX. + batch_column : str, optional + Column name for explicit batch assignment (when batch_method=EXPLICIT_COLUMN). 
+ covariate_columns : List[str] + SDRF columns to use as covariates (biological signal to preserve). + Example: ["characteristics[sex]", "characteristics[organism part]"] + parametric : bool + Use parametric empirical Bayes estimation. Default True. + Set False for non-parametric estimation. + mean_only : bool + Only adjust batch means, not individual effects. Default False. + ref_batch : int, optional + Batch ID to use as reference (all other batches adjusted to this one). + + Examples + -------- + >>> config = BatchCorrectionConfig( + ... enabled=True, + ... batch_method=BatchDetectionMethod.SAMPLE_PREFIX, + ... covariate_columns=["characteristics[sex]", "characteristics[tissue]"], + ... ) + """ + + enabled: bool = False + + # Batch detection + batch_method: BatchDetectionMethod = BatchDetectionMethod.SAMPLE_PREFIX + batch_column: Optional[str] = None + + # Covariates from SDRF (biological signal to preserve) + covariate_columns: List[str] = field(default_factory=list) + + # ComBat parameters + parametric: bool = True + mean_only: bool = False + ref_batch: Optional[int] = None + + def __post_init__(self): + """Validate configuration after initialization.""" + if isinstance(self.batch_method, str): + self.batch_method = BatchDetectionMethod.from_str(self.batch_method) + + if self.batch_method == BatchDetectionMethod.EXPLICIT_COLUMN and not self.batch_column: + raise ValueError( + "batch_column must be specified when batch_method is EXPLICIT_COLUMN" + ) diff --git a/mokume/imputation/base.py b/mokume/imputation/base.py new file mode 100644 index 0000000..410645c --- /dev/null +++ b/mokume/imputation/base.py @@ -0,0 +1,53 @@ +""" +Base class for imputation methods. + +This module provides an abstract base class for missing value imputation +algorithms. Implementations should register with the PluginRegistry. +""" + +from abc import ABC, abstractmethod + +import pandas as pd + + +class ImputationMethod(ABC): + """Base class for missing value imputation methods. + + Imputation methods fill in missing (NaN) values in wide-format + numeric matrices. They can be applied at any data level (features, + peptides, proteins). + + Subclasses should register with:: + + from mokume.core.registry import PluginRegistry + + @PluginRegistry.register("imputation", "knn") + class KNNImputation(ImputationMethod): + ... + """ + + @property + @abstractmethod + def name(self) -> str: + """Human-readable method name.""" + + @abstractmethod + def impute(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame: + """Impute missing values in a wide-format matrix. + + Parameters + ---------- + df : pd.DataFrame + Wide-format numeric matrix (observations x variables). + Missing values are represented as NaN. + **kwargs + Method-specific parameters. + + Returns + ------- + pd.DataFrame + Matrix with missing values imputed, same shape as input. + """ + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/mokume/imputation/knn.py b/mokume/imputation/knn.py new file mode 100644 index 0000000..6c0e8bd --- /dev/null +++ b/mokume/imputation/knn.py @@ -0,0 +1,74 @@ +""" +KNN imputation implementation. + +Wraps sklearn.impute.KNNImputer and registers it with the PluginRegistry. +""" + +import pandas as pd +from sklearn.impute import KNNImputer + +from mokume.core.registry import PluginRegistry +from mokume.imputation.base import ImputationMethod + + +@PluginRegistry.register("imputation", "knn") +class KNNImputation(ImputationMethod): + """K-Nearest Neighbors imputation. 
+ + Uses sklearn's KNNImputer to fill missing values based on the + values of the nearest neighbors in the feature space. + + Parameters + ---------- + n_neighbors : int + Number of neighboring samples to use. Default 5. + weights : str + Weight function for prediction: "uniform" or "distance". Default "uniform". + metric : str + Distance metric for neighbor search. Default "nan_euclidean". + keep_empty_features : bool + Whether to keep features that are entirely NaN. Default True. + """ + + def __init__( + self, + n_neighbors: int = 5, + weights: str = "uniform", + metric: str = "nan_euclidean", + keep_empty_features: bool = True, + ): + self.n_neighbors = n_neighbors + self.weights = weights + self.metric = metric + self.keep_empty_features = keep_empty_features + + @property + def name(self) -> str: + return "KNN" + + def impute(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame: + """Impute missing values using K-Nearest Neighbors. + + Parameters + ---------- + df : pd.DataFrame + Wide-format numeric matrix with NaN values. + **kwargs + Ignored (reserved for future use). + + Returns + ------- + pd.DataFrame + Matrix with missing values imputed. + """ + imputer = KNNImputer( + n_neighbors=self.n_neighbors, + weights=self.weights, + metric=self.metric, + keep_empty_features=self.keep_empty_features, + ) + imputed = imputer.fit_transform(df) + return pd.DataFrame(imputed, columns=df.columns, index=df.index) + + def __repr__(self) -> str: + return f"KNNImputation(n_neighbors={self.n_neighbors})" diff --git a/mokume/imputation/simple.py b/mokume/imputation/simple.py new file mode 100644 index 0000000..65b0833 --- /dev/null +++ b/mokume/imputation/simple.py @@ -0,0 +1,91 @@ +""" +Simple imputation methods (mean, median, most_frequent, constant). + +Wraps sklearn.impute.SimpleImputer and registers each strategy +with the PluginRegistry. +""" + +import pandas as pd +from sklearn.impute import SimpleImputer + +from mokume.core.registry import PluginRegistry +from mokume.imputation.base import ImputationMethod + + +class _SimpleImputation(ImputationMethod): + """Shared base for sklearn SimpleImputer strategies.""" + + _strategy: str = "" + + def __init__(self, fill_value: float = 0.0): + self.fill_value = fill_value + + def impute(self, df: pd.DataFrame, **kwargs) -> pd.DataFrame: + imputer = SimpleImputer(strategy=self._strategy, fill_value=self.fill_value) + imputed = imputer.fit_transform(df) + return pd.DataFrame(imputed, columns=df.columns, index=df.index) + + +@PluginRegistry.register("imputation", "mean") +class MeanImputation(_SimpleImputation): + """Impute missing values with the column mean. + + Each column's NaN values are replaced by the mean of the + non-missing values in that column. + """ + + _strategy = "mean" + + @property + def name(self) -> str: + return "Mean" + + +@PluginRegistry.register("imputation", "median") +class MedianImputation(_SimpleImputation): + """Impute missing values with the column median. + + Each column's NaN values are replaced by the median of the + non-missing values in that column. + """ + + _strategy = "median" + + @property + def name(self) -> str: + return "Median" + + +@PluginRegistry.register("imputation", "most_frequent") +class MostFrequentImputation(_SimpleImputation): + """Impute missing values with the most frequent value. + + Each column's NaN values are replaced by the most frequent value + in that column. 
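+
+    Note that sklearn's ``SimpleImputer`` breaks ties by taking the
+    smallest of the equally frequent values.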
+ """ + + _strategy = "most_frequent" + + @property + def name(self) -> str: + return "MostFrequent" + + +@PluginRegistry.register("imputation", "constant") +class ConstantImputation(_SimpleImputation): + """Impute missing values with a constant. + + Parameters + ---------- + fill_value : float + The value to fill NaN entries with. Default 0.0. + """ + + _strategy = "constant" + + @property + def name(self) -> str: + return "Constant" + + def __repr__(self) -> str: + return f"ConstantImputation(fill_value={self.fill_value})" diff --git a/mokume/io/qpx_adapter.py b/mokume/io/qpx_adapter.py new file mode 100644 index 0000000..7f59561 --- /dev/null +++ b/mokume/io/qpx_adapter.py @@ -0,0 +1,363 @@ +""" +QPX adapter — thin wrapper over qpx for reading and validation. + +All reading, validation, and dataset assembly are delegated to qpx. +Mokume only consumes qpx APIs and applies column-name mapping so +algorithm code receives the names it expects (reference_file_name, +precursor_charge, pg_accessions, etc.). + +See docs/plans/qpx-migration-principles.md for division of responsibilities. +""" + +from __future__ import annotations + +import logging +from pathlib import Path +from typing import Iterator, Optional + +import numpy as np +import pandas as pd + +from mokume.model.labeling import QuantificationCategory, IsobaricLabel + +logger = logging.getLogger(__name__) + +# Column mapping: QPX names -> mokume-internal names used by algorithms +QPX_TO_MOKUME_COLS = { + "run_file_name": "reference_file_name", + "charge": "precursor_charge", + "anchor_protein": "pg_accessions", # we expose as list of one element for compatibility +} + + +def _map_qpx_long_to_mokume(df: pd.DataFrame) -> pd.DataFrame: + """Map QPX long-form columns to names mokume algorithms expect.""" + out = df.copy() + if "run_file_name" in out.columns: + out["reference_file_name"] = out["run_file_name"] + out["run"] = out["run_file_name"] + if "charge" in out.columns: + out["precursor_charge"] = out["charge"] + if "anchor_protein" in out.columns: + import ast + def _parse_pg(x): + if pd.isna(x): + return [] + s = str(x) + if s.startswith("["): + try: + return ast.literal_eval(s) + except (ValueError, SyntaxError): + pass + return [s] + out["pg_accessions"] = out["anchor_protein"].apply(_parse_pg) + if "label" in out.columns: + out["channel"] = out["label"] + # biological_replicate: from QPX run.samples or default to 1 + if "biological_replicate" not in out.columns: + out["biological_replicate"] = 1 + else: + out["biological_replicate"] = out["biological_replicate"].fillna(1).astype(int) + # fraction: from QPX run table or default to "1" + if "fraction" not in out.columns: + out["fraction"] = "1" + else: + out["fraction"] = out["fraction"].fillna("1").astype(str) + return out + + +def open_qpx(path: str | Path, structures: Optional[list[str]] = None): + """Open a QPX dataset directory using qpx. All reading/validation done by qpx.""" + try: + import qpx + except ImportError as e: + raise ImportError( + "QPX directory input requires the 'qpx' package. Install with: pip install qpx" + ) from e + return qpx.Dataset(path, structures=structures or ["feature", "sample", "run"]) + + +def open_qpx_feature_single(path: str | Path): + """Open a single feature.parquet file using qpx. All reading done by qpx.""" + try: + import qpx + except ImportError as e: + raise ImportError( + "Single-file QPX mode requires the 'qpx' package. 
Install with: pip install qpx" + ) from e + return qpx.read_feature(str(path)) + + +class QpxFeatureAdapter: + """ + Feature-like facade over qpx.Dataset for quantification. + + Delegates all reading, iteration, and validation to qpx. Only applies + column-name mapping so mokume algorithms receive expected column names. + """ + + def __init__(self, dataset, filter_builder=None): + """ + Parameters + ---------- + dataset : qpx.Dataset + Opened QPX dataset (must have feature; run/sample optional but recommended). + filter_builder : object, optional + SQLFilterBuilder-compatible object; filtering is applied in Python + after fetching from qpx (qpx does not accept mokume's WHERE strings). + """ + self._ds = dataset + self.filter_builder = filter_builder + self._feature = dataset.feature + self._run = getattr(dataset, "run", None) + self._sample = getattr(dataset, "sample", None) + self._samples: Optional[list[str]] = None + self._long_form: Optional[pd.DataFrame] = None + + def _get_long_form(self) -> pd.DataFrame: + """Get feature data in long form from qpx; cache and map column names.""" + if self._long_form is not None: + return self._long_form + if self._feature is None: + raise ValueError("QPX dataset has no feature structure") + # Use qpx API: for_quantification when run is present, else peptide_intensities + if self._run is not None: + result = self._feature.for_quantification(self._ds) + else: + result = self._feature.peptide_intensities() + # peptide_intensities has label but not sample_accession; use label as sample + df = result.to_df() + if "sample_accession" not in df.columns and "label" in df.columns: + df = df.copy() + df["sample_accession"] = df["label"].astype(str) + self._long_form = _map_qpx_long_to_mokume(df) + if "unique" not in self._long_form.columns: + self._long_form["unique"] = 1 + return self._long_form + df = result.to_df() + self._long_form = _map_qpx_long_to_mokume(df) + if "unique" not in self._long_form.columns: + self._long_form["unique"] = 1 + return self._long_form + + @property + def samples(self) -> list[str]: + """Unique sample accessions from qpx (run/sample or from feature labels).""" + if self._samples is not None: + return self._samples + df = self._get_long_form() + self._samples = df["sample_accession"].dropna().unique().tolist() + return self._samples + + def iter_samples( + self, sample_num: int = 20, columns: Optional[list] = None + ) -> Iterator[tuple[list[str], pd.DataFrame]]: + """Iterate over samples in batches. Data and grouping from qpx.""" + df = self._get_long_form() + if columns: + available = [c for c in columns if c in df.columns] + df = df[available] if available else df + ref_list = [ + self.samples[i : i + sample_num] + for i in range(0, len(self.samples), sample_num) + ] + for refs in ref_list: + batch = df[df["sample_accession"].isin(refs)] + yield refs, batch + + def get_median_map(self) -> dict[str, float]: + """Sample median intensity map (sample median / global median). Filtering in Python.""" + df = self._get_long_form() + if self.filter_builder: + df = _apply_filter_builder(df, self.filter_builder) + med = df.groupby("sample_accession")["intensity"].median() + global_med = float(med.median()) + if global_med <= 0: + return {s: 1.0 for s in med.index} + return (med / global_med).to_dict() + + def get_median_map_to_condition(self) -> dict[str, dict[str, float]]: + """Per-condition median map. 
Uses 'condition' if present, else sample as condition.""" + df = self._get_long_form() + if self.filter_builder: + df = _apply_filter_builder(df, self.filter_builder) + if "condition" not in df.columns: + df = df.copy() + df["condition"] = df["sample_accession"] + grp = df.groupby(["condition", "sample_accession"])["intensity"].median().unstack(level=0) + med_map = {} + for cond in grp.columns: + s = grp[cond].dropna() + if s.empty: + continue + mean_val = s.mean() + if mean_val <= 0: + med_map[cond] = {k: 1.0 for k in s.index} + else: + med_map[cond] = (s / mean_val).to_dict() + return med_map + + @property + def experimental_inference(self) -> tuple[int, QuantificationCategory, list[str], Optional[IsobaricLabel]]: + """Infer label type and samples from qpx data.""" + df = self._get_long_form() + labels = df["label"].dropna().unique().tolist() if "label" in df.columns else [] + sample_names = self.samples + label_enum, choice = QuantificationCategory.classify(labels) + run_col = df.get("run_file_name", df.get("reference_file_name", None)) + if run_col is not None: + tech_reps = run_col.nunique() + else: + tech_reps = 1 + return tech_reps, label_enum, sample_names, choice + + def get_sample_metadata(self) -> Optional[pd.DataFrame]: + """Return sample table as DataFrame from qpx (one row per sample). + Returns None if this dataset has no sample structure. + Use config condition_column, batch_column, etc. to pick columns.""" + if self._sample is None: + return None + return self._sample.to_df() + + def get_run_to_fraction(self) -> Optional[dict[str, str]]: + """Return run_file_name -> fraction from qpx run table when present. + + Used for ratio quantification so Fraction comes from QPX instead of SDRF. + Returns None if no run structure or no fraction column; callers use "1" then. + """ + if self._run is None: + return None + try: + run_df = self._run.to_df() + except Exception: + return None + if run_df is None or run_df.empty: + return None + run_col = "run_file_name" if "run_file_name" in run_df.columns else None + frac_col = "fraction" if "fraction" in run_df.columns else None + if not run_col or not frac_col: + return None + out = {} + for _, row in run_df.iterrows(): + fname = str(row[run_col]).strip() if pd.notna(row.get(run_col)) else None + if not fname: + continue + frac = row.get(frac_col) + frac_str = str(frac).strip() if pd.notna(frac) and str(frac).strip() else "1" + out[fname] = frac_str + stem = fname.rsplit(".", 1)[0] if "." in fname else fname + out[stem] = frac_str + return out if out else None + + def enrich_with_sdrf(self, sdrf_path: str) -> None: + """No-op when using QPX; sample/run metadata come from qpx.""" + logger.debug("enrich_with_sdrf is a no-op when using QPX directory; metadata from qpx") + + def get_feature_long_form(self) -> pd.DataFrame: + """Long-form feature table with mokume column names for DirectLFQ path.""" + return self._get_long_form() + + def get_report_from_database(self, samples: list, columns: Optional[list] = None) -> pd.DataFrame: + """Subset long-form to given samples (and optional columns).""" + df = self._get_long_form() + df = df[df["sample_accession"].isin(samples)] + if columns: + available = [c for c in columns if c in df.columns] + if available: + df = df[available] + return df + + def get_low_frequency_peptides(self, percentage: float = 0.2) -> tuple: + """Peptides that appear in less than percentage of samples. 
Filtering applied in Python."""
+        df = self._get_long_form()
+        df = _apply_filter_builder(df, self.filter_builder)
+        # pg_accessions holds lists, which are unhashable; group on a tuple key instead
+        df = df.copy()
+        df["_pg_key"] = df["pg_accessions"].apply(
+            lambda x: tuple(x) if isinstance(x, list) else (x,)
+        )
+        grp = df.groupby(["sequence", "_pg_key"])["sample_accession"].nunique().reset_index()
+        grp = grp[grp["sample_accession"] < percentage * len(self.samples)]
+        grp = grp[grp["_pg_key"].apply(len) > 0]
+        # Parse protein accession: first element of the group, then split by | and take [1] if present
+        def _first_acc(x):
+            if hasattr(x, "__len__") and len(x):
+                v = x[0] if not isinstance(x, str) else x
+            else:
+                return ""
+            return v.split("|")[1] if "|" in str(v) else v
+        proteins = grp["_pg_key"].apply(_first_acc)
+        return tuple(zip(proteins.tolist(), grp["sequence"].tolist()))
+
+    def get_irs_scaling_factors(
+        self,
+        irs_channel: str,
+        irs_stat: str = "median",
+        irs_scope: str = "global",
+    ) -> dict[int, float]:
+        """IRS scaling factors from long-form data. Filtering applied in Python."""
+        df = self._get_long_form()
+        col_label = "label" if "label" in df.columns else "sample_accession"
+        if col_label not in df.columns:
+            return {}
+        df = df[df[col_label].astype(str) == str(irs_channel)]
+        df = _apply_filter_builder(df, self.filter_builder)
+        run_col = df.get("run_file_name", df.get("reference_file_name", None))
+        if run_col is None or run_col.isna().all():
+            return {}
+        techrep = run_col.astype(str).str.split("_").str.get(-1)
+        techrep = pd.to_numeric(techrep, errors="coerce").fillna(0).astype(int)
+        df = df.copy()
+        df["_techrep"] = techrep
+        agg = df.groupby("_techrep")["intensity"].agg(irs_stat.lower())
+        agg = agg[agg > 0]
+        if agg.empty:
+            return {}
+        # irs_scope does not change the computation yet: both scopes center
+        # every technical replicate on the global statistic
+        center = agg.median() if irs_stat.lower() == "median" else agg.mean()
+        scale = center / agg
+        return dict(zip(agg.index.tolist(), scale.tolist()))
+
+
+def _apply_filter_builder(df: pd.DataFrame, filter_builder) -> pd.DataFrame:
+    """Apply filter_builder logic in Python (intensity, length, unique, contaminants)."""
+    out = df[df["intensity"] > 0].copy()
+    if getattr(filter_builder, "min_intensity", 0) > 0:
+        out = out[out["intensity"] >= filter_builder.min_intensity]
+    if getattr(filter_builder, "min_peptide_length", 0) > 0 and "sequence" in out.columns:
+        out = out[out["sequence"].str.len() >= filter_builder.min_peptide_length]
+    if getattr(filter_builder, "require_unique", False) and "unique" in out.columns:
+        out = out[out["unique"] == 1]
+    if getattr(filter_builder, "remove_contaminants", True) and "pg_accessions" in out.columns:
+        patterns = getattr(filter_builder, "contaminant_patterns", ["CONTAMINANT", "ENTRAP", "DECOY"])
+        for pat in patterns:
+            out = out[~out["pg_accessions"].astype(str).str.contains(pat, regex=False, na=False)]
+    return out
+
+
+class _SingleFeatureDataset:
+    """Minimal wrapper so QpxFeatureAdapter can consume a single qpx Feature (no run)."""
+
+    def __init__(self, feature):
+        self.feature = feature
+        self.run = None
+        self.sample = None
+
+
+def create_feature_for_input(
+    parquet: Optional[str],
+    qpx_dir: Optional[str],
+    filter_builder=None,
+):
+    """
+    Create a Feature-like object for pipeline input. All reading is done by qpx.
+
+    - If qpx_dir is set: open qpx.Dataset(qpx_dir) and return QpxFeatureAdapter.
+    - If parquet is set: open qpx.read_feature(parquet) and return QpxFeatureAdapter over it.
+
+    At least one of parquet or qpx_dir must be set. 
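+
+    Example (the directory name is illustrative)::
+
+        feature = create_feature_for_input(parquet=None, qpx_dir="PXD000001.qpx")
+        for refs, batch_df in feature.iter_samples(sample_num=10):
+            ...  # quantify one batch of samples at a time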
+ """ + if qpx_dir: + ds = open_qpx(qpx_dir) + return QpxFeatureAdapter(ds, filter_builder=filter_builder) + if parquet: + feat = open_qpx_feature_single(parquet) + return QpxFeatureAdapter(_SingleFeatureDataset(feat), filter_builder=filter_builder) + raise ValueError("Either parquet or qpx_dir must be set for pipeline input") diff --git a/mokume/normalization/base.py b/mokume/normalization/base.py new file mode 100644 index 0000000..e21f5de --- /dev/null +++ b/mokume/normalization/base.py @@ -0,0 +1,212 @@ +""" +Base classes for normalization methods. + +This module provides abstract base classes for feature-level and +sample-level normalization methods. Implementations should register +with the PluginRegistry for automatic discovery. + +Feature normalizers operate within a single run/sample to correct +for technical variation in intensity measurements. + +Sample normalizers operate across samples to make them comparable. +Some (TMM, IRS, hierarchical) need the full dataset; others +(globalMedian, conditionMedian) adjust each sample independently. +""" + +from abc import ABC, abstractmethod +from typing import Optional + +import pandas as pd + +from mokume.core.constants import NORM_INTENSITY, SAMPLE_ID, TECHREPLICATE + + +class FeatureNormalizer(ABC): + """Base class for feature-level (within-run) normalization. + + Feature normalizers correct for technical variation within a single + MS run. They operate on intensity values grouped by run or sample. + + Subclasses must implement ``transform_series()`` with the per-run + normalization math. The concrete ``normalize()`` method handles the + orchestration: iterating samples → runs, calling ``transform_series()`` + per run, computing per-run metrics, and balancing runs within a sample. + + Subclasses should register with:: + + from mokume.core.registry import PluginRegistry + + @PluginRegistry.register("normalization.feature", "median") + class MedianFeatureNormalizer(FeatureNormalizer): + ... + """ + + @property + @abstractmethod + def name(self) -> str: + """Human-readable normalizer name.""" + + @abstractmethod + def transform_series(self, series: pd.Series) -> pd.Series: + """Apply normalization to a single intensity series (one run). + + Parameters + ---------- + series : pd.Series + Raw intensity values for a single run. + + Returns + ------- + pd.Series + Transformed metric (e.g., series / series.median()). + """ + + def normalize( + self, + df: pd.DataFrame, + intensity_col: str = NORM_INTENSITY, + group_col: str = TECHREPLICATE, + sample_col: str = SAMPLE_ID, + ) -> pd.DataFrame: + """Normalize intensities across runs within each sample. + + Iterates over samples, applies ``transform_series()`` per run to + compute a per-run metric, then scales each run so that its metric + equals the sample-average metric. + + Parameters + ---------- + df : pd.DataFrame + Feature-level DataFrame. + intensity_col : str + Column containing intensities to normalize. + group_col : str + Column defining run groups within a sample (e.g., TechReplicate). + sample_col : str + Column identifying samples. + + Returns + ------- + pd.DataFrame + DataFrame with normalized intensities. 
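+
+        Notes
+        -----
+        Samples with a single run are left unchanged. For multi-run
+        samples, each run is rescaled so that its per-run metric matches
+        the average metric across the sample's runs.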
+ """ + samples = df[sample_col].unique() + for sample in samples: + runs = df.loc[df[sample_col] == sample, group_col].unique().tolist() + if len(runs) <= 1: + continue + + sample_mask = df[sample_col] == sample + sample_df = df.loc[sample_mask] + + # Compute per-run metric + run_metrics = {} + total_metric = 0 + for run in runs: + run = str(run) + run_series = sample_df.loc[ + sample_df[group_col] == run, intensity_col + ] + metric = self.transform_series(run_series) + run_metrics[run] = metric + total_metric += metric + + sample_avg_metric = total_metric / len(runs) + + # Scale each run + for run in runs: + run = str(run) + mask = (df[sample_col] == sample) & (df[group_col] == run) + run_intensity = df.loc[mask, intensity_col] + df.loc[mask, intensity_col] = run_intensity / ( + run_metrics[run] / sample_avg_metric + ) + + return df + + def __call__( + self, + df: pd.DataFrame, + technical_replicates: int, + intensity_col: str = NORM_INTENSITY, + group_col: str = TECHREPLICATE, + sample_col: str = SAMPLE_ID, + ) -> pd.DataFrame: + """Callable interface matching the old enum signature.""" + if technical_replicates <= 1: + return df + return self.normalize(df, intensity_col, group_col, sample_col) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" + + +class SampleNormalizer(ABC): + """Base class for sample-level (across-sample) normalization. + + Sample normalizers make samples comparable by adjusting for + systematic differences in total intensity, loading, etc. + + Some normalizers (TMM, IRS, hierarchical) need the full dataset + to compute normalization factors. Others (globalMedian, + conditionMedian) adjust each sample using pre-computed statistics. + + Subclasses should register with:: + + from mokume.core.registry import PluginRegistry + + @PluginRegistry.register("normalization.sample", "tmm") + class TMMSampleNormalizer(SampleNormalizer): + ... + """ + + @property + @abstractmethod + def name(self) -> str: + """Human-readable normalizer name.""" + + @property + def is_dataset_level(self) -> bool: + """Whether this normalizer needs the full dataset at once. + + True for: TMM, IRS, hierarchical (need cross-sample statistics). + False for: globalMedian, conditionMedian (per-sample adjustment). + + Returns + ------- + bool + """ + return False + + @abstractmethod + def normalize( + self, + df: pd.DataFrame, + intensity_col: str, + sample_col: str, + condition_col: Optional[str] = None, + **kwargs, + ) -> pd.DataFrame: + """Apply normalization. + + Parameters + ---------- + df : pd.DataFrame + DataFrame with peptide/protein intensities. + intensity_col : str + Column containing intensities. + sample_col : str + Column identifying samples. + condition_col : str, optional + Column identifying conditions/groups. + **kwargs + Method-specific parameters. + + Returns + ------- + pd.DataFrame + DataFrame with normalized intensities. + """ + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/mokume/normalization/feature_normalizers.py b/mokume/normalization/feature_normalizers.py new file mode 100644 index 0000000..6fe9bc4 --- /dev/null +++ b/mokume/normalization/feature_normalizers.py @@ -0,0 +1,100 @@ +""" +Concrete feature-level (within-run) normalization methods. + +Each class registers with the PluginRegistry and implements +``transform_series()`` — the per-run normalization math. +The run-balancing orchestration is handled by the base class. 
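+
+Example (direct use; the callable interface is a no-op when there is
+only one technical replicate)::
+
+    from mokume.normalization.feature_normalizers import MedianFeatureNormalizer
+
+    normalized = MedianFeatureNormalizer()(df, technical_replicates=2)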
+""" + +import pandas as pd + +from mokume.core.registry import PluginRegistry +from mokume.normalization.base import FeatureNormalizer + + +@PluginRegistry.register("normalization.feature", "none") +class NoneFeatureNormalizer(FeatureNormalizer): + """No-op normalizer — returns data unchanged.""" + + @property + def name(self) -> str: + return "none" + + def transform_series(self, series: pd.Series) -> pd.Series: + return series + + def normalize(self, df, intensity_col=None, group_col=None, sample_col=None): + return df + + +@PluginRegistry.register("normalization.feature", "mean") +class MeanFeatureNormalizer(FeatureNormalizer): + """Mean normalization: intensity / mean(intensity).""" + + @property + def name(self) -> str: + return "mean" + + def transform_series(self, series: pd.Series) -> pd.Series: + return series / series.mean() + + +@PluginRegistry.register("normalization.feature", "median") +class MedianFeatureNormalizer(FeatureNormalizer): + """Median normalization: intensity / median(intensity).""" + + @property + def name(self) -> str: + return "median" + + def transform_series(self, series: pd.Series) -> pd.Series: + return series / series.median() + + +@PluginRegistry.register("normalization.feature", "max") +class MaxFeatureNormalizer(FeatureNormalizer): + """Max normalization: intensity / max(intensity).""" + + @property + def name(self) -> str: + return "max" + + def transform_series(self, series: pd.Series) -> pd.Series: + return series / series.max() + + +@PluginRegistry.register("normalization.feature", "global") +class GlobalFeatureNormalizer(FeatureNormalizer): + """Global normalization: intensity / sum(intensity).""" + + @property + def name(self) -> str: + return "global" + + def transform_series(self, series: pd.Series) -> pd.Series: + return series / series.sum() + + +@PluginRegistry.register("normalization.feature", "max_min") +class MaxMinFeatureNormalizer(FeatureNormalizer): + """Max-Min normalization: (intensity - min) / (max - min).""" + + @property + def name(self) -> str: + return "max_min" + + def transform_series(self, series: pd.Series) -> pd.Series: + min_val = series.min() + return (series - min_val) / (series.max() - min_val) + + +@PluginRegistry.register("normalization.feature", "iqr") +class IQRFeatureNormalizer(FeatureNormalizer): + """IQR normalization: mean of 25th and 75th quantiles.""" + + @property + def name(self) -> str: + return "iqr" + + def transform_series(self, series: pd.Series) -> pd.Series: + return series.quantile([0.75, 0.25], interpolation="linear").mean() diff --git a/mokume/normalization/sample_normalizers.py b/mokume/normalization/sample_normalizers.py new file mode 100644 index 0000000..a8c1c15 --- /dev/null +++ b/mokume/normalization/sample_normalizers.py @@ -0,0 +1,251 @@ +""" +Concrete sample-level (across-sample) normalization methods. + +Per-sample normalizers (is_dataset_level=False): + none, globalmedian, conditionmedian + +Dataset-level normalizers (is_dataset_level=True): + hierarchical, tmm, irs + +Dataset-level normalizers are thin adapters that pivot long → wide, +delegate to the existing fit/transform classes, and melt back. 
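+
+Example (per-sample path; ``med_map`` would come from the feature
+adapter's ``get_median_map()`` and the sample id is illustrative)::
+
+    from mokume.normalization.sample_normalizers import GlobalMedianNormalizer
+
+    df = GlobalMedianNormalizer().normalize(df, med_map=med_map, sample="S1")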
+""" + +from typing import Optional + +import numpy as np +import pandas as pd + +from mokume.core.registry import PluginRegistry +from mokume.core.constants import ( + CONDITION, + NORM_INTENSITY, + PROTEIN_NAME, + PEPTIDE_CANONICAL, + SAMPLE_ID, +) +from mokume.core.logger import get_logger +from mokume.normalization.base import SampleNormalizer + +logger = get_logger("mokume.normalization.sample_normalizers") + +# --------------------------------------------------------------------------- +# Per-sample normalizers +# --------------------------------------------------------------------------- + + +@PluginRegistry.register("normalization.sample", "none") +class NoneSampleNormalizer(SampleNormalizer): + """No-op normalizer — returns data unchanged.""" + + @property + def name(self) -> str: + return "none" + + def normalize(self, df, intensity_col=NORM_INTENSITY, sample_col=SAMPLE_ID, + condition_col=None, **kwargs): + return df + + +@PluginRegistry.register("normalization.sample", "globalmedian") +class GlobalMedianNormalizer(SampleNormalizer): + """Normalize each sample by its median relative to the global median. + + Expects ``med_map`` (dict: sample → median ratio) and ``sample`` (str) + in kwargs. + """ + + @property + def name(self) -> str: + return "globalmedian" + + def normalize(self, df, intensity_col=NORM_INTENSITY, sample_col=SAMPLE_ID, + condition_col=None, **kwargs): + med_map = kwargs.get("med_map", {}) + sample = kwargs.get("sample") + if sample is not None and sample in med_map: + df.loc[:, intensity_col] = df[intensity_col] / med_map[sample] + return df + + +@PluginRegistry.register("normalization.sample", "conditionmedian") +class ConditionMedianNormalizer(SampleNormalizer): + """Normalize each sample by its condition-specific median ratio. + + Expects ``med_map`` (dict: condition → {sample → ratio}) and + ``sample`` (str) in kwargs. The condition is read from ``condition_col``. + """ + + @property + def name(self) -> str: + return "conditionmedian" + + def normalize(self, df, intensity_col=NORM_INTENSITY, sample_col=SAMPLE_ID, + condition_col=None, **kwargs): + med_map = kwargs.get("med_map", {}) + sample = kwargs.get("sample") + if condition_col is None: + condition_col = CONDITION + if sample is not None and med_map: + con = df[condition_col].unique()[0] + if con in med_map and sample in med_map[con]: + df.loc[:, intensity_col] = df[intensity_col] / med_map[con][sample] + return df + + +# --------------------------------------------------------------------------- +# Dataset-level normalizers +# --------------------------------------------------------------------------- + + +@PluginRegistry.register("normalization.sample", "hierarchical") +class HierarchicalSampleNormalizerPlugin(SampleNormalizer): + """Adapter for HierarchicalSampleNormalizer (DirectLFQ-style). + + Operates on the full dataset: long → wide → fit_transform → long. + + Parameters + ---------- + num_samples_quadratic : int + Use quadratic optimization for datasets with fewer samples. + selected_proteins : list[str], optional + Proteins to use for computing normalization factors. 
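+
+    Example (``long_df`` stands for a long-format peptide table using the
+    default protein/peptide/sample column names)::
+
+        plugin = HierarchicalSampleNormalizerPlugin(num_samples_quadratic=50)
+        normalized = plugin.normalize(long_df)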
+ """ + + def __init__( + self, + num_samples_quadratic: int = 50, + selected_proteins: Optional[list] = None, + ): + self.num_samples_quadratic = num_samples_quadratic + self.selected_proteins = selected_proteins + + @property + def name(self) -> str: + return "hierarchical" + + @property + def is_dataset_level(self) -> bool: + return True + + def normalize(self, df, intensity_col=NORM_INTENSITY, sample_col=SAMPLE_ID, + condition_col=None, **kwargs): + from mokume.normalization.hierarchical import HierarchicalSampleNormalizer + + protein_col = kwargs.get("protein_col", PROTEIN_NAME) + peptide_col = kwargs.get("peptide_col", PEPTIDE_CANONICAL) + + logger.info("Applying hierarchical sample normalization...") + + # Long → wide + wide = df.pivot_table( + index=[protein_col, peptide_col], + columns=sample_col, + values=intensity_col, + aggfunc="sum", + ) + wide = wide.replace(0, np.nan) + wide_log2 = np.log2(wide) + + normalizer = HierarchicalSampleNormalizer( + num_samples_quadratic=self.num_samples_quadratic, + selected_proteins=self.selected_proteins, + ) + normalized_log2 = normalizer.fit_transform(wide_log2) + + # Back to linear scale + normalized_wide = 2 ** normalized_log2 + + # Wide → long + result = normalized_wide.reset_index().melt( + id_vars=[protein_col, peptide_col], + var_name=sample_col, + value_name=intensity_col, + ) + result = result.dropna(subset=[intensity_col]) + + logger.info(f"Hierarchical normalization complete: {len(result)} rows") + return result + + +@PluginRegistry.register("normalization.sample", "tmm") +class TMMSampleNormalizerPlugin(SampleNormalizer): + """Adapter for TMMNormalizer. + + Operates on the full dataset: long → wide → fit_transform → long. + """ + + @property + def name(self) -> str: + return "tmm" + + @property + def is_dataset_level(self) -> bool: + return True + + def normalize(self, df, intensity_col=NORM_INTENSITY, sample_col=SAMPLE_ID, + condition_col=None, **kwargs): + from mokume.normalization.tmm import TMMNormalizer + + protein_col = kwargs.get("protein_col", PROTEIN_NAME) + peptide_col = kwargs.get("peptide_col", PEPTIDE_CANONICAL) + + logger.info("Applying TMM sample normalization...") + + # Long → wide + wide = df.pivot_table( + index=[protein_col, peptide_col], + columns=sample_col, + values=intensity_col, + aggfunc="sum", + ) + + normalizer = TMMNormalizer() + normalized_wide = normalizer.fit_transform(wide) + + # Wide → long + result = normalized_wide.reset_index().melt( + id_vars=[protein_col, peptide_col], + var_name=sample_col, + value_name=intensity_col, + ) + result = result.dropna(subset=[intensity_col]) + + logger.info(f"TMM normalization complete: {len(result)} rows") + return result + + +@PluginRegistry.register("normalization.sample", "irs") +class IRSSampleNormalizerPlugin(SampleNormalizer): + """Adapter for IRSNormalizer (Internal Reference Scaling). + + Operates on the full dataset in wide format. Requires + ``reference_samples`` and ``sample_to_plex`` in kwargs. 
+ """ + + def __init__(self, reference_samples=None, stat="median"): + self.reference_samples = reference_samples or [] + self.stat = stat + + @property + def name(self) -> str: + return "irs" + + @property + def is_dataset_level(self) -> bool: + return True + + def normalize(self, df, intensity_col=NORM_INTENSITY, sample_col=SAMPLE_ID, + condition_col=None, **kwargs): + from mokume.normalization.irs import IRSNormalizer + + sample_to_plex = kwargs.get("sample_to_plex", {}) + reference_samples = kwargs.get("reference_samples", self.reference_samples) + stat = kwargs.get("stat", self.stat) + + if not reference_samples: + logger.warning("No reference samples provided for IRS, skipping") + return df + + normalizer = IRSNormalizer(reference_samples=reference_samples, stat=stat) + return normalizer.fit_transform(df, sample_to_plex) diff --git a/mokume/pipeline/flows/__init__.py b/mokume/pipeline/flows/__init__.py new file mode 100644 index 0000000..1dc7833 --- /dev/null +++ b/mokume/pipeline/flows/__init__.py @@ -0,0 +1,14 @@ +""" +Pipeline flow modules. + +Each flow handles a different quantification paradigm: +- standard: iBAQ, TopN, sum, median (input_level="peptides") +- ratio: PS protocol log2 ratios (input_level="psms") +- directlfq: DirectLFQ package delegation (input_level="peptides_raw") + +Flow dispatch is handled by `pipeline/runner.py`. +""" + +from mokume.pipeline.flows import standard, ratio, directlfq + +__all__ = ["standard", "ratio", "directlfq"] diff --git a/mokume/pipeline/flows/directlfq.py b/mokume/pipeline/flows/directlfq.py new file mode 100644 index 0000000..1338e79 --- /dev/null +++ b/mokume/pipeline/flows/directlfq.py @@ -0,0 +1,108 @@ +""" +DirectLFQ quantification flow. + +Delegates normalization and quantification entirely to the DirectLFQ +package. Uses adapter logic to convert between qpx parquet format +and DirectLFQ's expected input (input_level="peptides_raw"). + +Flow: + Load qpx parquet -> QpxDataset(.features) + Convert to DirectLFQ format (adapter) + Run DirectLFQ normalization + estimation + Convert back -> QpxDataset(.proteins) +""" + +from mokume.core.dataset import QpxDataset +from mokume.core.logger import get_logger +from mokume.pipeline.config import PipelineConfig +from mokume.quantification.base import QuantificationMethod + +logger = get_logger("mokume.pipeline.flows.directlfq") + + +def run(method: QuantificationMethod, config: PipelineConfig) -> QpxDataset: + """Execute the DirectLFQ quantification flow. + + Parameters + ---------- + method : QuantificationMethod + The resolved DirectLFQ method (used for metadata). + config : PipelineConfig + Pipeline configuration. + + Returns + ------- + QpxDataset + Dataset with proteins populated. 
+ """ + try: + import directlfq.protein_intensity_estimation as lfq_estimation + import directlfq.normalization as lfq_norm + import directlfq.config as lfq_config + except ImportError: + raise ImportError( + "DirectLFQ quantification requires the directlfq package.\n" + "Install with: pip install directlfq\n" + "Or: pip install mokume[directlfq]" + ) + + from mokume.pipeline.stages import LoadingStage + + dataset = QpxDataset() + + # Load and filter data + loading = LoadingStage(config) + logger.info("Loading and filtering data for DirectLFQ...") + filtered_df, sample_metadata = loading.load_for_directlfq() + dataset.features = filtered_df + if sample_metadata is not None: + dataset.sample_info = sample_metadata + dataset.record_step("loading", method="directlfq", rows_out=len(filtered_df)) + logger.info(f"Filtered data: {len(filtered_df)} features") + + # Validate schema + errors = dataset.validate_level("features") + if errors: + logger.warning("Schema validation warnings for features: %s", errors) + + # Convert to DirectLFQ format + logger.info("Converting to DirectLFQ format...") + directlfq_input = loading.convert_to_directlfq_format(filtered_df) + logger.info(f"DirectLFQ input shape: {directlfq_input.shape}") + + # Configure DirectLFQ + lfq_config.set_global_protein_and_ion_id(protein_id="protein", quant_id="ion") + lfq_config.set_compile_normalized_ion_table( + config.output.export_ions is not None + ) + + # Run DirectLFQ normalization + logger.info("Running DirectLFQ sample normalization...") + normed_df = lfq_norm.NormalizationManagerSamplesOnSelectedProteins( + directlfq_input, + num_samples_quadratic=config.quantification.directlfq_num_samples_quadratic, + ).complete_dataframe + + # Run DirectLFQ protein estimation + logger.info("Running DirectLFQ protein estimation...") + protein_df, ion_df = lfq_estimation.estimate_protein_intensities( + normed_df, + min_nonan=config.quantification.directlfq_min_nonan, + num_samples_quadratic=10, + num_cores=config.quantification.directlfq_num_cores, + ) + + # Export ions if requested + if config.output.export_ions and ion_df is not None: + logger.info(f"Exporting ions to {config.output.export_ions}") + ion_df.to_csv(config.output.export_ions) + + dataset.proteins = protein_df + dataset.record_step( + "quantification", + method="directlfq", + rows_out=len(protein_df), + ) + + logger.info(f"DirectLFQ complete: {len(protein_df)} proteins") + return dataset diff --git a/mokume/pipeline/flows/ratio.py b/mokume/pipeline/flows/ratio.py new file mode 100644 index 0000000..c264cfe --- /dev/null +++ b/mokume/pipeline/flows/ratio.py @@ -0,0 +1,91 @@ +""" +Ratio quantification flow (PS protocol). + +Handles ratio-based quantification where each sample is normalized +against a per-plex reference channel. Works at PSM level +(input_level="psms"). Reference samples and plex mapping come from +the qpx sample table only. + +Flow: + Load qpx parquet at PSM level -> QpxDataset(.psms) + Reference/plex from qpx sample table (--irs-reference-*, --plex-column) + Ratio quantification (log2(sample/reference) per plex) + -> QpxDataset(.proteins) +""" + +from mokume.core.dataset import QpxDataset +from mokume.core.logger import get_logger +from mokume.pipeline.config import PipelineConfig +from mokume.quantification.base import QuantificationMethod + +logger = get_logger("mokume.pipeline.flows.ratio") + + +def run(method: QuantificationMethod, config: PipelineConfig) -> QpxDataset: + """Execute the ratio quantification flow. 
diff --git a/mokume/pipeline/flows/ratio.py b/mokume/pipeline/flows/ratio.py
new file mode 100644
index 0000000..c264cfe
--- /dev/null
+++ b/mokume/pipeline/flows/ratio.py
@@ -0,0 +1,91 @@
+"""
+Ratio quantification flow (PS protocol).
+
+Handles ratio-based quantification where each sample is normalized
+against a per-plex reference channel. Works at PSM level
+(input_level="psms"). Reference samples and plex mapping come from
+the qpx sample table only.
+
+Flow:
+    Load qpx parquet at PSM level -> QpxDataset(.psms)
+    Reference/plex from qpx sample table (--irs-reference-*, --plex-column)
+    Ratio quantification (log2(sample/reference) per plex)
+    -> QpxDataset(.proteins)
+"""
+
+from mokume.core.dataset import QpxDataset
+from mokume.core.logger import get_logger
+from mokume.pipeline.config import PipelineConfig
+from mokume.quantification.base import QuantificationMethod
+
+logger = get_logger("mokume.pipeline.flows.ratio")
+
+
+def run(method: QuantificationMethod, config: PipelineConfig) -> QpxDataset:
+    """Execute the ratio quantification flow.
+
+    Parameters
+    ----------
+    method : QuantificationMethod
+        The resolved ratio quantification method.
+    config : PipelineConfig
+        Pipeline configuration.
+
+    Returns
+    -------
+    QpxDataset
+        Dataset with proteins populated (log2 ratios).
+    """
+    from mokume.quantification.ratio import RatioQuantification
+    from mokume.pipeline.stages import LoadingStage
+
+    dataset = QpxDataset()
+
+    logger.info("Running ratio-based quantification (PS protocol)...")
+
+    # Load PSM data and get reference/plex from qpx sample table
+    loading = LoadingStage(config)
+    psm_df, ref_samples, sample_to_plex, sample_metadata = loading.load_for_ratio()
+    dataset.psms = psm_df
+    if sample_metadata is not None:
+        dataset.sample_info = sample_metadata
+    dataset.record_step("loading", method="ratio", rows_out=len(psm_df))
+
+    # Validate schema
+    errors = dataset.validate_level("psms")
+    if errors:
+        logger.warning("Schema validation warnings for psms: %s", errors)
+
+    # Run ratio quantification
+    ratio_quant = RatioQuantification(
+        reference_samples=ref_samples,
+        sample_to_plex=sample_to_plex,
+    )
+    protein_df = ratio_quant.quantify(psm_df)
+
+    # Remove reference samples from output columns (log2(ref/ref) = 0)
+    protein_col = protein_df.columns[0]
+    seen = set()
+    unique_cols = []
+    for c in protein_df.columns:
+        if c == protein_col or c not in ref_samples:
+            if c not in seen:
+                seen.add(c)
+                unique_cols.append(c)
+    protein_df = protein_df[unique_cols]
+
+    dataset.proteins = protein_df
+    dataset.record_step(
+        "quantification",
+        method="ratio",
+        rows_out=len(protein_df),
+        reference_samples=ref_samples,
+    )
+
+    # Store ratio-specific metadata
+    dataset.uns["ratio_config"] = {
+        "reference_samples": ref_samples,
+    }
+
+    logger.info(f"Ratio pipeline complete: {len(protein_df)} proteins")
+    return dataset
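For intuition on what the proteins table holds after this flow: each value is log2(sample / reference) within a plex, which is also why the reference columns dropped above would be identically zero. A two-line sanity check (numbers invented):

    import numpy as np

    np.log2(2.0e6 / 5.0e5)  # sample vs. reference -> 2.0, i.e. 4-fold above reference
    np.log2(5.0e5 / 5.0e5)  # reference vs. itself -> 0.0, hence the column removal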
+ """ + dataset = QpxDataset() + + # Load and process peptides (filtering + normalization) + loading = LoadingStage(config) + logger.info("Loading and filtering data...") + peptide_df, sample_metadata = loading.load_for_mokume() + dataset.peptides = peptide_df + if sample_metadata is not None: + dataset.sample_info = sample_metadata + dataset.record_step("loading", method="standard", rows_out=len(peptide_df)) + logger.info(f"Processed peptides: {len(peptide_df)} rows") + + # Validate schema + errors = dataset.validate_level("peptides") + if errors: + logger.warning("Schema validation warnings for peptides: %s", errors) + + # Export peptides if requested + if config.output.export_peptides: + logger.info(f"Exporting peptides to {config.output.export_peptides}") + peptide_df.to_csv(config.output.export_peptides, index=False) + + # Quantify proteins + quant_stage = QuantificationStage(config) + logger.info(f"Quantifying proteins with method: {config.quantification.method}") + protein_df = quant_stage.quantify(peptide_df) + dataset.proteins = protein_df + dataset.record_step( + "quantification", + method=config.quantification.method, + rows_out=len(protein_df), + ) + logger.info(f"Quantification complete: {len(protein_df)} proteins") + + return dataset diff --git a/mokume/pipeline/runner.py b/mokume/pipeline/runner.py new file mode 100644 index 0000000..17738c2 --- /dev/null +++ b/mokume/pipeline/runner.py @@ -0,0 +1,231 @@ +""" +Pipeline runner with metadata-driven flow dispatch. + +The runner resolves the quantification method from the PluginRegistry, +selects the appropriate flow based on the method's ``input_level`` +property, executes it, and runs common post-processing. + +Third-party plugins can register new flows by adding entries to +``FLOW_DISPATCH``. +""" + +from typing import Dict, Callable, Optional + +from mokume.core.dataset import QpxDataset +from mokume.core.logger import get_logger +from mokume.core.registry import PluginRegistry +from mokume.pipeline.config import PipelineConfig +from mokume.pipeline import flows +from mokume.quantification.base import QuantificationMethod + +logger = get_logger("mokume.pipeline.runner") + + +# Map input_level -> flow module. +# Each flow module must have a run(method, config) -> QpxDataset function. +FLOW_DISPATCH: Dict[str, object] = { + "peptides": flows.standard, # iBAQ, TopN, sum, median + "psms": flows.ratio, # Ratio quantification (PS protocol) + "peptides_raw": flows.directlfq, # DirectLFQ (handles its own normalization) +} + + +def run_pipeline(config: PipelineConfig) -> QpxDataset: + """Execute the quantification pipeline. + + Resolves the quantification method, selects the appropriate flow, + and runs common post-processing. + + Parameters + ---------- + config : PipelineConfig + Pipeline configuration. + + Returns + ------- + QpxDataset + Dataset with proteins populated and optional DE results in uns. + """ + quant_method_name = config.quantification.method.lower() + logger.info(f"Starting pipeline with quant_method={quant_method_name}") + + # Ensure built-in methods are registered + import mokume.quantification # noqa: F401 + + # Resolve method from registry + method = PluginRegistry.get("quantification", quant_method_name) + + # Select flow based on method's declared input_level + flow = FLOW_DISPATCH.get(method.input_level) + if flow is None: + raise ValueError( + f"No pipeline flow registered for input_level='{method.input_level}'. 
" + f"Available flows: {list(FLOW_DISPATCH.keys())}" + ) + + # Execute the flow + dataset = flow.run(method, config) + + # Common post-processing + dataset = _postprocess(dataset, config) + + return dataset + + +def _postprocess(dataset: QpxDataset, config: PipelineConfig) -> QpxDataset: + """Run common post-processing steps. + + These steps apply regardless of which flow was used: + - IRS normalization (multi-plex TMT) + - Coverage filter + - Batch correction + - Differential expression + - Plotting and reports + + Parameters + ---------- + dataset : QpxDataset + Dataset with proteins populated. + config : PipelineConfig + Pipeline configuration. + + Returns + ------- + QpxDataset + Updated dataset. + """ + from mokume.pipeline.stages import NormalizationStage, PostprocessingStage + + protein_df = dataset.proteins + if protein_df is None: + logger.warning("No protein data after quantification, skipping post-processing") + return dataset + + quant_method = config.quantification.method.lower() + + # Create stages once (reused across steps) + norm_stage = NormalizationStage(config) + post_stage = PostprocessingStage(config) + + # IRS normalization (skip for ratio — handles cross-plex via reference division) + if config.irs.enabled and quant_method != "ratio": + protein_df = norm_stage.apply_irs(protein_df, dataset=dataset) + dataset.proteins = protein_df + dataset.record_step("irs_normalization", method=config.irs.stat) + + # Coverage filter + if config.quantification.coverage_threshold is not None: + protein_df = norm_stage.apply_coverage_filter(protein_df, dataset=dataset) + dataset.proteins = protein_df + dataset.record_step( + "coverage_filter", + threshold=config.quantification.coverage_threshold, + rows_out=len(protein_df), + ) + + # Batch correction + if config.batch.enabled: + protein_df = post_stage.apply_batch_correction(protein_df, dataset=dataset) + dataset.proteins = protein_df + dataset.record_step("batch_correction", method=config.batch.method) + + # Differential expression + if config.de.enabled: + de_results = post_stage.run_differential_expression(protein_df, dataset=dataset) + if de_results: + dataset.uns["de_results"] = { + k: v.to_dict(orient="records") for k, v in de_results.items() + } + + # Reconstruct DE DataFrames for plotting/report (shared helper) + de_dfs = None + if config.de.enabled and "de_results" in dataset.uns: + import pandas as pd + de_dfs = { + k: pd.DataFrame(v) for k, v in dataset.uns["de_results"].items() + } + + # Plotting + if config.output.plot_dir and any([ + config.output.plot_volcano, + config.output.plot_heatmap, + config.output.plot_pca, + ]): + post_stage.generate_plots(protein_df, de_dfs, dataset=dataset) + + # Interactive report + if config.output.interactive_report and config.de.enabled and de_dfs: + post_stage.generate_interactive_report(protein_df, de_dfs, dataset=dataset) + + # AnnData export — uses QPX naming convention via qpx.Dataset.save_anndata() + if config.output.export_anndata and config.input.qpx_dir: + _export_anndata(dataset, config) + + return dataset + + +def _export_anndata(dataset: QpxDataset, config: PipelineConfig) -> None: + """Export the dataset as AnnData using QPX's save_anndata API. + + If an AnnData file already exists for this view, the new quantification + is added as a layer (keyed by the quant method name) instead of + overwriting the file. + """ + try: + import anndata as ad + import qpx + except ImportError: + logger.warning( + "AnnData export requires the 'qpx' and 'anndata' packages. 
" + "Install with: pip install qpx anndata" + ) + return + import numpy as np + from pathlib import Path + + qpx_ds = qpx.Dataset(config.input.qpx_dir) + view = config.output.anndata_view or "ae" + prefix = Path(config.input.qpx_dir).name + existing_path = Path(config.input.qpx_dir) / f"{prefix}.{view}.h5ad" + + quant_method = config.quantification.method.lower() + + if existing_path.exists(): + # Append as a layer to existing AnnData + logger.info( + f"Existing AnnData found at {existing_path}, " + f"adding '{quant_method}' as layer" + ) + adata = ad.read_h5ad(existing_path) + new_adata = dataset.to_anndata(level="proteins", value_col="Intensity") + + # Align the new matrix to the existing obs/var indices + import pandas as pd + new_wide = pd.DataFrame( + new_adata.X, + index=new_adata.obs.index, + columns=new_adata.var.index, + ) + aligned = new_wide.reindex( + index=adata.obs.index, columns=adata.var.index + ) + adata.layers[quant_method] = aligned.values.astype(np.float32) + + # Also add log2 layer + log2_vals = np.log2( + np.where(aligned.values > 0, aligned.values, np.nan) + ).astype(np.float32) + adata.layers[f"{quant_method}_log2"] = log2_vals + else: + # Create new AnnData with X = current quantification + adata = dataset.to_anndata(level="proteins", value_col="Intensity") + + # Add log2 layer + x_vals = adata.X.copy() + log2_vals = np.log2( + np.where(x_vals > 0, x_vals, np.nan) + ).astype(np.float32) + adata.layers[f"{quant_method}_log2"] = log2_vals + + output_path = qpx_ds.save_anndata(adata, view=view) + logger.info(f"AnnData saved to {output_path}") diff --git a/mokume/quantification/median.py b/mokume/quantification/median.py new file mode 100644 index 0000000..c250fcf --- /dev/null +++ b/mokume/quantification/median.py @@ -0,0 +1,83 @@ +""" +Median protein quantification method. + +This module provides a quantification method that computes the median +of peptide intensities for each protein. +""" + +from typing import Optional + +import pandas as pd + +from mokume.quantification.base import QuantificationMethod +from mokume.core.constants import ( + PROTEIN_NAME, + PEPTIDE_CANONICAL, + NORM_INTENSITY, + SAMPLE_ID, +) +from mokume.core.registry import PluginRegistry + + +@PluginRegistry.register("quantification", "median") +class MedianQuantification(QuantificationMethod): + """ + Median protein quantification method. + + Calculates protein abundance as the median of all peptide intensities + for each protein in each sample (or run if run_column is provided). + """ + + @property + def name(self) -> str: + return "Median" + + def quantify( + self, + peptide_df: pd.DataFrame, + protein_column: str = PROTEIN_NAME, + peptide_column: str = PEPTIDE_CANONICAL, + intensity_column: str = NORM_INTENSITY, + sample_column: str = SAMPLE_ID, + run_column: Optional[str] = None, + ) -> pd.DataFrame: + """ + Quantify proteins using median of peptide intensities. + + Parameters + ---------- + peptide_df : pd.DataFrame + DataFrame containing peptide-level data. + protein_column : str + Column name for protein identifiers. + peptide_column : str + Column name for peptide sequences. + intensity_column : str + Column name for intensity values. + sample_column : str + Column name for sample identifiers. + run_column : str, optional + Column name for run identifiers. If provided, quantification + is performed at the run level instead of sample level. + + Returns + ------- + pd.DataFrame + DataFrame with columns: protein_column, sample_column, + (run_column if provided), 'Intensity'. 
+ """ + # Determine grouping columns based on aggregation level + if run_column is not None and run_column in peptide_df.columns: + group_cols = [protein_column, sample_column, run_column] + else: + group_cols = [protein_column, sample_column] + + result = ( + peptide_df.groupby(group_cols)[intensity_column] + .median() + .reset_index() + ) + + # Rename intensity column + result = result.rename(columns={intensity_column: "Intensity"}) + return result diff --git a/tests/test_cecilia_integration.py b/tests/test_cecilia_integration.py new file mode 100644 index 0000000..1e93824 --- /dev/null +++ b/tests/test_cecilia_integration.py @@ -0,0 +1,238 @@ +""" +Integration test: run the plugin-architecture pipeline against the +cecilia NASH TMT dataset (2-plex TMT11). + +Tests multiple quantification methods end-to-end: + - median (standard flow) + - maxlfq (standard flow) + - top3 (standard flow, TopN pattern) + - ratio (ratio flow, PS protocol) + - median + IRS normalization (standard flow + post-processing) + +Each test validates: + 1. Pipeline completes without error + 2. Output proteins DataFrame is non-empty + 3. QpxDataset provenance is recorded + 4. Schema validation passes +""" + +import os +import tempfile + +import pandas as pd +import pytest + +# Paths --------------------------------------------------------------- +PARQUET = os.path.join( + os.path.dirname(__file__), "..", "..", + "cecilia-problem", "qpx_output", + "cecilia.feature.parquet", +) +SDRF = os.path.join( + os.path.dirname(__file__), "..", "..", + "cecilia-problem", "sdrf", + "combined_plex_data_sheet.sdrf.tsv", +) + +# Skip the entire module if dataset files are missing +pytestmark = pytest.mark.skipif( + not (os.path.exists(PARQUET) and os.path.exists(SDRF)), + reason="Cecilia NASH dataset not available", +) + +# Imports (after path setup) ------------------------------------------- +from mokume.pipeline.config import ( + PipelineConfig, + InputConfig, + FilterConfig, + NormalizationConfig, + QuantificationConfig, + IRSConfig, + OutputConfig, +) +from mokume.pipeline.features_to_proteins import QuantificationPipeline +from mokume.core.dataset import QpxDataset + + +def _make_config( + quant_method: str = "median", + irs: bool = False, + irs_remove_ref: bool = False, + coverage_threshold: float = None, + export_peptides: str = None, +) -> PipelineConfig: + """Build a PipelineConfig pointing at the cecilia dataset.""" + return PipelineConfig( + input=InputConfig(parquet=PARQUET, sdrf=SDRF), + filtering=FilterConfig( + min_aa=7, + min_unique_peptides=2, + remove_contaminants=True, + ), + normalization=NormalizationConfig( + run_method="median", + sample_method="globalMedian", + ), + quantification=QuantificationConfig( + method=quant_method, + coverage_threshold=coverage_threshold, + ), + irs=IRSConfig( + enabled=irs, + remove_reference=irs_remove_ref, + ), + output=OutputConfig( + export_peptides=export_peptides, + ), + ) + + +def _assert_valid_result(dataset: QpxDataset, min_proteins: int = 50): + """Common assertions for all pipeline results.""" + # Proteins must be populated + assert dataset.proteins is not None, "proteins level is None" + + protein_df = dataset.proteins + if isinstance(protein_df, pd.DataFrame): + assert len(protein_df) >= min_proteins, ( + f"Expected >= {min_proteins} proteins, got {len(protein_df)}" + ) + + # Provenance must be recorded + assert "provenance" in dataset.uns, "No provenance in uns" + steps = dataset.uns["provenance"]["steps"] + assert len(steps) >= 2, f"Expected >= 2 provenance steps, got 
{len(steps)}" + + step_names = [s["name"] for s in steps] + assert "loading" in step_names + assert "quantification" in step_names + + +# ====================================================================== +# Tests +# ====================================================================== + +class TestMedianQuantification: + """Standard flow: median summarization.""" + + def test_median_pipeline(self): + config = _make_config(quant_method="median") + pipeline = QuantificationPipeline(config) + dataset = pipeline.run_dataset() + _assert_valid_result(dataset) + print(f"\n median: {len(dataset.proteins)} proteins, " + f"{len(dataset.proteins.columns) - 1} samples") + + +class TestMaxLFQQuantification: + """Standard flow: MaxLFQ.""" + + def test_maxlfq_pipeline(self): + config = _make_config(quant_method="maxlfq") + pipeline = QuantificationPipeline(config) + dataset = pipeline.run_dataset() + _assert_valid_result(dataset) + print(f"\n maxlfq: {len(dataset.proteins)} proteins, " + f"{len(dataset.proteins.columns) - 1} samples") + + +class TestTopNQuantification: + """Standard flow: TopN pattern (top3 via registry).""" + + def test_top3_pipeline(self): + config = _make_config(quant_method="top3") + pipeline = QuantificationPipeline(config) + dataset = pipeline.run_dataset() + _assert_valid_result(dataset) + print(f"\n top3: {len(dataset.proteins)} proteins, " + f"{len(dataset.proteins.columns) - 1} samples") + + +class TestRatioQuantification: + """Ratio flow: PS protocol log2 ratios.""" + + def test_ratio_pipeline(self): + config = _make_config( + quant_method="ratio", + coverage_threshold=0.65, + ) + pipeline = QuantificationPipeline(config) + dataset = pipeline.run_dataset() + _assert_valid_result(dataset, min_proteins=30) + + # Ratio-specific checks + assert "ratio_config" in dataset.uns + assert dataset.uns["ratio_config"]["reference_samples"] + print(f"\n ratio: {len(dataset.proteins)} proteins, " + f"refs={dataset.uns['ratio_config']['reference_samples']}") + + +class TestIRSNormalization: + """Standard flow + IRS post-processing.""" + + def test_median_irs_pipeline(self): + config = _make_config( + quant_method="median", + irs=True, + irs_remove_ref=True, + ) + pipeline = QuantificationPipeline(config) + dataset = pipeline.run_dataset() + _assert_valid_result(dataset) + + # IRS should have added a step + step_names = [s["name"] for s in dataset.uns["provenance"]["steps"]] + assert "irs_normalization" in step_names + print(f"\n median+IRS: {len(dataset.proteins)} proteins, " + f"{len(dataset.proteins.columns) - 1} samples") + + +class TestQpxDatasetSaveLoad: + """Test save/load roundtrip with real pipeline output.""" + + def test_save_load_roundtrip(self): + config = _make_config(quant_method="median") + pipeline = QuantificationPipeline(config) + dataset = pipeline.run_dataset() + + with tempfile.TemporaryDirectory() as tmpdir: + dataset.save(tmpdir) + loaded = QpxDataset.load(tmpdir) + + assert loaded.proteins is not None + assert len(loaded.proteins) == len(dataset.proteins) + assert loaded.uns.get("provenance") is not None + + +class TestPeptideExport: + """Test peptide export path.""" + + def test_export_peptides(self): + with tempfile.TemporaryDirectory() as tmpdir: + pep_path = os.path.join(tmpdir, "peptides.csv") + config = _make_config( + quant_method="median", + export_peptides=pep_path, + ) + pipeline = QuantificationPipeline(config) + dataset = pipeline.run_dataset() + _assert_valid_result(dataset) + + assert os.path.exists(pep_path), "Peptide export file not created" + 
+            pep_df = pd.read_csv(pep_path)
+            assert len(pep_df) > 0
+            print(f"\n Exported {len(pep_df)} peptide rows")
+
+
+class TestSchemaValidation:
+    """Verify schema validation is wired in."""
+
+    def test_peptides_schema_valid(self):
+        config = _make_config(quant_method="median")
+        pipeline = QuantificationPipeline(config)
+        dataset = pipeline.run_dataset()
+
+        # Peptides should pass validation
+        if dataset.peptides is not None:
+            errors = dataset.validate_level("peptides")
+            assert errors == [], f"Schema errors: {errors}"
diff --git a/tests/test_dataset.py b/tests/test_dataset.py
new file mode 100644
index 0000000..1fe0943
--- /dev/null
+++ b/tests/test_dataset.py
@@ -0,0 +1,221 @@
+"""Tests for QpxDataset."""
+
+import os
+import tempfile
+
+import pandas as pd
+import pytest
+
+from mokume.core.constants import PROTEIN_NAME, SAMPLE_ID, PEPTIDE_CANONICAL, NORM_INTENSITY
+from mokume.core.dataset import QpxDataset
+
+
+# ---------------------------------------------------------------------------
+# Fixtures
+# ---------------------------------------------------------------------------
+
+@pytest.fixture
+def sample_peptide_df():
+    """Minimal peptide-level DataFrame."""
+    return pd.DataFrame({
+        PROTEIN_NAME: ["P1", "P1", "P2", "P2"],
+        PEPTIDE_CANONICAL: ["PEPTIDE", "ANOTHERPEP", "THIRDPEP", "FOURTHPEP"],
+        SAMPLE_ID: ["S1", "S2", "S1", "S2"],
+        NORM_INTENSITY: [100.0, 200.0, 300.0, 400.0],
+    })
+
+
+@pytest.fixture
+def sample_protein_df():
+    """Minimal protein-level DataFrame."""
+    return pd.DataFrame({
+        PROTEIN_NAME: ["P1", "P2"],
+        SAMPLE_ID: ["S1", "S1"],
+        "Intensity": [500.0, 600.0],
+    })
+
+
+@pytest.fixture
+def dataset_with_peptides(sample_peptide_df):
+    ds = QpxDataset()
+    ds.peptides = sample_peptide_df
+    return ds
+
+
+# ---------------------------------------------------------------------------
+# Tests: data level access
+# ---------------------------------------------------------------------------
+
+class TestDataLevelAccess:
+    def test_get_level_returns_dataframe(self, dataset_with_peptides):
+        df = dataset_with_peptides.get_level("peptides")
+        assert isinstance(df, pd.DataFrame)
+        assert len(df) == 4
+
+    def test_get_level_none_when_empty(self):
+        ds = QpxDataset()
+        assert ds.get_level("peptides") is None
+
+    def test_get_level_invalid_raises(self):
+        ds = QpxDataset()
+        with pytest.raises(ValueError, match="Unknown data level"):
+            ds.get_level("invalid_level")
+
+    def test_set_level(self, sample_peptide_df):
+        ds = QpxDataset()
+        ds.set_level("peptides", sample_peptide_df)
+        assert ds.peptides is not None
+        assert len(ds.get_level("peptides")) == 4
+
+    def test_set_level_invalid_raises(self, sample_peptide_df):
+        ds = QpxDataset()
+        with pytest.raises(ValueError, match="Unknown data level"):
+            ds.set_level("invalid", sample_peptide_df)
+
+    def test_populated_levels(self, sample_peptide_df, sample_protein_df):
+        ds = QpxDataset()
+        assert ds.populated_levels == []
+        ds.peptides = sample_peptide_df
+        assert ds.populated_levels == ["peptides"]
+        ds.proteins = sample_protein_df
+        assert set(ds.populated_levels) == {"peptides", "proteins"}
+
+
+# ---------------------------------------------------------------------------
+# Tests: schema validation
+# ---------------------------------------------------------------------------
+
+class TestSchemaValidation:
+    def test_validate_valid_peptides(self, dataset_with_peptides):
+        errors = dataset_with_peptides.validate_level("peptides")
+        assert errors == []
+
+    def test_validate_missing_columns(self):
+        ds = QpxDataset()
+        ds.peptides = pd.DataFrame({"A": [1]})
+        errors = ds.validate_level("peptides")
+        assert len(errors) > 0
+        assert any("Missing required column" in e for e in errors)
+
+    def test_validate_unpopulated_level(self):
+        ds = QpxDataset()
+        errors = ds.validate_level("proteins")
+        assert len(errors) == 1
+        assert "not populated" in errors[0]
+
+
+# ---------------------------------------------------------------------------
+# Tests: wide matrix
+# ---------------------------------------------------------------------------
+
+class TestToWideMatrix:
+    def test_pivot_proteins(self, sample_protein_df):
+        ds = QpxDataset()
+        ds.proteins = sample_protein_df
+        wide = ds.to_wide_matrix(level="proteins")
+        assert isinstance(wide, pd.DataFrame)
+        assert wide.index.name == PROTEIN_NAME
+
+    def test_wide_empty_level_raises(self):
+        ds = QpxDataset()
+        with pytest.raises(ValueError, match="not populated"):
+            ds.to_wide_matrix(level="proteins")
+
+
+# ---------------------------------------------------------------------------
+# Tests: subsetting
+# ---------------------------------------------------------------------------
+
+class TestSubsetting:
+    def test_subset_samples(self, sample_peptide_df):
+        ds = QpxDataset()
+        ds.peptides = sample_peptide_df
+        subset = ds.subset_samples(["S1"])
+        df = subset.get_level("peptides")
+        assert list(df[SAMPLE_ID].unique()) == ["S1"]
+
+    def test_subset_proteins(self, sample_peptide_df):
+        ds = QpxDataset()
+        ds.peptides = sample_peptide_df
+        subset = ds.subset_proteins(["P1"])
+        df = subset.get_level("peptides")
+        assert list(df[PROTEIN_NAME].unique()) == ["P1"]
+
+
+# ---------------------------------------------------------------------------
+# Tests: save / load roundtrip
+# ---------------------------------------------------------------------------
+
+class TestSaveLoad:
+    def test_roundtrip(self, sample_peptide_df, sample_protein_df):
+        ds = QpxDataset()
+        ds.peptides = sample_peptide_df
+        ds.proteins = sample_protein_df
+        ds.uns = {"pipeline": "test", "version": 1}
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ds.save(tmpdir)
+
+            # Check files exist
+            assert os.path.exists(os.path.join(tmpdir, "peptides.parquet"))
+            assert os.path.exists(os.path.join(tmpdir, "proteins.parquet"))
+            assert os.path.exists(os.path.join(tmpdir, "uns.json"))
+
+            # Load back
+            loaded = QpxDataset.load(tmpdir)
+            assert loaded.get_level("peptides") is not None
+            assert len(loaded.get_level("peptides")) == 4
+            assert loaded.get_level("proteins") is not None
+            assert loaded.uns["pipeline"] == "test"
+
+    def test_save_with_layers(self, sample_peptide_df):
+        ds = QpxDataset()
+        ds.peptides = sample_peptide_df
+        ds.layers["normalized"] = sample_peptide_df.copy()
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            ds.save(tmpdir)
+            loaded = QpxDataset.load(tmpdir)
+            assert "normalized" in loaded.layers
+            assert len(loaded.layers["normalized"]) == 4
+
+
+# ---------------------------------------------------------------------------
+# Tests: provenance
+# ---------------------------------------------------------------------------
+
+class TestProvenance:
+    def test_record_step(self):
+        ds = QpxDataset()
+        ds.record_step("loading", method="standard", rows_out=100)
+        assert "provenance" in ds.uns
+        assert len(ds.uns["provenance"]["steps"]) == 1
+        step = ds.uns["provenance"]["steps"][0]
+        assert step["name"] == "loading"
+        assert step["method"] == "standard"
+        assert step["rows_out"] == 100
+
+    def test_record_multiple_steps(self):
+        ds = QpxDataset()
+        ds.record_step("loading")
ds.record_step("normalization") + ds.record_step("quantification") + assert len(ds.uns["provenance"]["steps"]) == 3 + + +# --------------------------------------------------------------------------- +# Tests: repr +# --------------------------------------------------------------------------- + +class TestRepr: + def test_repr_empty(self): + ds = QpxDataset() + r = repr(ds) + assert "QpxDataset(" in r + + def test_repr_with_data(self, sample_peptide_df): + ds = QpxDataset() + ds.peptides = sample_peptide_df + r = repr(ds) + assert "peptides:" in r + assert "4 rows" in r diff --git a/tests/test_registry.py b/tests/test_registry.py new file mode 100644 index 0000000..db3609b --- /dev/null +++ b/tests/test_registry.py @@ -0,0 +1,159 @@ +"""Tests for the PluginRegistry.""" + +import pytest + +from mokume.core.registry import PluginRegistry, VALID_INPUT_LEVELS +from mokume.quantification.base import QuantificationMethod + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +class _DummyMethod(QuantificationMethod): + """Minimal concrete method for testing.""" + + @property + def name(self): + return "dummy" + + def quantify(self, peptide_df, **kwargs): + return peptide_df + + +class _BadLevelMethod(QuantificationMethod): + """Method with an invalid input_level.""" + + @property + def name(self): + return "bad_level" + + @property + def input_level(self): + return "nonexistent" + + def quantify(self, peptide_df, **kwargs): + return peptide_df + + +@pytest.fixture(autouse=True) +def _reset_registry(): + """Reset registry before and after each test. + + After reset, we must reload all quantification submodules so + that @register decorators fire again and re-populate the store. 
+ """ + import importlib + import mokume.quantification.topn + import mokume.quantification.ratio + import mokume.quantification as quant_mod + + PluginRegistry.reset() + # Reload submodules first (they have @register decorators) + importlib.reload(mokume.quantification.topn) + importlib.reload(mokume.quantification.ratio) + # Then reload __init__ which registers aliases + importlib.reload(quant_mod) + yield + PluginRegistry.reset() + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestPluginRegistryRegister: + def test_register_decorator(self): + PluginRegistry.reset() + + @PluginRegistry.register("quantification", "test_method") + class TestMethod(_DummyMethod): + @property + def name(self): + return "test" + + assert PluginRegistry.is_registered("quantification", "test_method") + instance = PluginRegistry.get("quantification", "test_method") + assert instance.name == "test" + + def test_register_unknown_group_raises(self): + with pytest.raises(ValueError, match="Unknown plugin group"): + PluginRegistry.register("nonexistent_group", "foo") + + def test_register_instance_factory(self): + PluginRegistry.reset() + PluginRegistry.register_instance_factory( + "quantification", "factory_test", + lambda **kw: _DummyMethod(), + ) + instance = PluginRegistry.get("quantification", "factory_test") + assert instance.name == "dummy" + + def test_register_instance_factory_unknown_group(self): + with pytest.raises(ValueError, match="Unknown plugin group"): + PluginRegistry.register_instance_factory( + "nonexistent", "foo", lambda **kw: None + ) + + +class TestPluginRegistryGet: + def test_get_maxlfq_alias_resolves_to_directlfq(self): + """'maxlfq' alias should resolve to DirectLFQ.""" + method = PluginRegistry.get("quantification", "maxlfq") + assert method.name == "DirectLFQ" + + def test_get_case_insensitive(self): + method = PluginRegistry.get("quantification", "MaxLFQ") + assert method is not None + + def test_get_unknown_raises(self): + with pytest.raises(ValueError, match="Unknown quantification method"): + PluginRegistry.get("quantification", "totally_nonexistent_xyz") + + def test_topn_pattern(self): + method = PluginRegistry.get("quantification", "top5") + assert method is not None + assert method.name in ("TopN", "Top5", "top5", "TopNQuantification") + + def test_topn_large_n(self): + method = PluginRegistry.get("quantification", "top100") + assert method is not None + + def test_invalid_input_level_raises(self): + """A method with invalid input_level should raise at get() time.""" + PluginRegistry.reset() + PluginRegistry.register("quantification", "bad_level")(_BadLevelMethod) + with pytest.raises(ValueError, match="input_level='nonexistent'"): + PluginRegistry.get("quantification", "bad_level") + + +class TestPluginRegistryAvailable: + def test_available_returns_sorted(self): + names = PluginRegistry.available("quantification") + assert names == sorted(names) + + def test_available_includes_builtins(self): + names = PluginRegistry.available("quantification") + assert "maxlfq" in names # backward-compat alias + assert "topn" in names + + def test_available_empty_group(self): + names = PluginRegistry.available("filter") + assert isinstance(names, list) + + +class TestPluginRegistryReset: + def test_reset_clears_registrations(self): + PluginRegistry.reset() + names = PluginRegistry.available("quantification") + # After reset + entry point discovery, only 
+        # (or empty if not installed as package)
+        assert isinstance(names, list)
+
+
+class TestValidInputLevels:
+    def test_valid_levels_set(self):
+        assert "peptides" in VALID_INPUT_LEVELS
+        assert "psms" in VALID_INPUT_LEVELS
+        assert "peptides_raw" in VALID_INPUT_LEVELS
+        assert "features" in VALID_INPUT_LEVELS
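Taken together, these tests pin down the registry API an end user actually touches. A compact recap of that lookup path (method names taken from the tests above; `peptide_df` stands for any long-format peptide table):

    from mokume.core.registry import PluginRegistry

    PluginRegistry.available("quantification")             # sorted names, aliases included
    method = PluginRegistry.get("quantification", "top3")  # TopN pattern, case-insensitive
    proteins = method.quantify(peptide_df)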