diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..8499f19 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,33 @@ +name: Generate API Documentation + +on: + push: + branches: [ dev ] + workflow_dispatch: + +jobs: + deploy-docs: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + pip install pdoc + pip install -e . + + - name: Generate documentation + run: | + pdoc -d numpy --output-dir ./docs finemo + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./docs \ No newline at end of file diff --git a/.gitignore b/.gitignore index f68e48a..5e3aa36 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,7 @@ -.conda -.DS_Store *.egg-info __pycache__ +/.* +!/.github /notebooks -/notebooks/old -/scratch.txt \ No newline at end of file +/scratch.txt +/scratch \ No newline at end of file diff --git a/LICENSE b/LICENSE index dd86235..ebb7494 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Austin Wang +Copyright (c) 2025 Austin Wang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/README.md b/README.md index 138c738..115d397 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,38 @@ -# finemo_gpu +# Fi-NeMo: Finding Neural Network Motifs -**Fi-NeMo** (**Fi**nding **Ne**ural network **Mo**tifs) is a GPU-accelerated hit caller for identifying occurrences of TFMoDISCo motifs within contribution scores generated by machine learning models. +**Fi-NeMo** (**Fi**nding **Ne**ural Network **Mo**tifs) is a GPU-accelerated motif instance calling tool for identifying transcription factor binding sites from neural network contribution scores. 
+ +## Overview + +Fi-NeMo implements a competitive optimization approach using proximal gradient descent to identify motif instances by solving a sparse linear reconstruction problem. Unlike traditional sequence-based methods, Fi-NeMo leverages context-aware importance scores from deep neural networks to comprehensively map transcription factor binding sites, enabling the identification of both high-confidence canonical motifs and low-prevalence cofactor motifs that are often missed by conventional approaches. + +The algorithm represents contribution scores as weighted combinations of motif contribution weight matrices (CWMs) at specific genomic positions. This competitive assignment process more closely reflects the biological reality of transcription factors competing for binding sites, resulting in superior sensitivity and specificity compared to sequence-only methods. + +### Features + +- **GPU-accelerated optimization**: Fast processing of large contribution score datasets using PyTorch +- **Competitive motif assignment**: Biologically-motivated algorithm that resolves similar motifs +- **Context-aware analysis**: Leverages neural network importance scores for improved sensitivity and specificity +- **Comprehensive evaluation**: Built-in tools for assessing and visualizing motif discovery quality and hit calling performance +- **Multiple input formats**: Support for bigWig, HDF5, and TF-MoDISco output formats + +## Method + +Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs, formulated as an L1-regularized linear model. This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. + +
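In symbols (notation ours, simplified from the description above), the reconstruction problem for a contribution track $X$, motif CWMs $c_m$, and per-position coefficient tracks $a_m$ can be written as:

```math
\min_{a}\; \frac{1}{2}\,\Big\| X - \sum_{m=1}^{M} c_m * a_m \Big\|_2^2 \;+\; \lambda \sum_{m=1}^{M} \lVert a_m \rVert_1
```

where $*$ denotes convolution; positions where $a_m$ remains nonzero at convergence become candidate hits for motif $m$. Details such as coefficient constraints and per-motif $\lambda$ values follow the implementation.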
+ +
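As a concrete illustration, here is a minimal single-motif, single-track sketch of this optimization using plain NumPy ISTA (proximal gradient descent with soft-thresholding). The function names, nonnegativity constraint, and parameter scales are simplifying assumptions for illustration only, not the `finemo.hitcaller` implementation:

```python
import numpy as np


def soft_threshold(x: np.ndarray, lam: float) -> np.ndarray:
    # Proximal operator of lam * ||x||_1, restricted to nonnegative coefficients.
    return np.maximum(x - lam, 0.0)


def fit_hits_1d(track: np.ndarray, cwm: np.ndarray,
                lam: float = 0.1, step: float = 0.05,
                n_steps: int = 500) -> np.ndarray:
    """Find per-position coefficients a >= 0 minimizing
    0.5 * ||track - conv(a, cwm)||^2 + lam * sum(a)."""
    n_pos = len(track) - len(cwm) + 1
    coeffs = np.zeros(n_pos)
    for _ in range(n_steps):
        recon = np.convolve(coeffs, cwm)         # reconstructed scores (forward model)
        grad = np.correlate(recon - track, cwm)  # gradient of the squared-error term
        coeffs = soft_threshold(coeffs - step * grad, step * lam)
    return coeffs
```

Positions whose coefficients survive soft-thresholding are the called hits; overlapping candidates compete through the shared reconstruction term, which is the "competitive assignment" behavior described above.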
+ +## References + +Fi-NeMo is described in: +> Tseng, Ramalingam, Wang, Schreiber, et al. "Decoding predictive motif lexicons and syntax from deep learning models of transcription factor binding profiles." (manuscript in preparation) + +Related tools: +- [TF-MoDISco](https://github.com/jmschrei/tfmodisco-lite): *De novo* motif discovery from importance scores +- [BPNet](https://github.com/kundajelab/bpnet-refactor): Deep learning models for TF binding prediction +- [ChromBPNet](https://github.com/kundajelab/chrombpnet): Deep learning models for chromatin accessibility prediction ## Installation @@ -19,7 +51,7 @@ cd finemo_gpu #### Create a Conda Environment with Dependencies -This step is optional but recommended +This step is optional but recommended for conda users. ```sh conda env create -f environment.yml -n $ENV_NAME @@ -58,9 +90,19 @@ Recommended: - Peak region coordinates in uncompressed [ENCODE NarrowPeak](https://genome.ucsc.edu/FAQ/FAQformat.html#format12) format. -## Usage +## API Documentation + +For Fi-NeMo's Python API documentation, see: https://www.austintwang.com/finemo_gpu/finemo.html -Fi-NeMo includes a command-line utility named `finemo`. Here, we describe basic usage for each subcommand. For all options, run `finemo -h`. +## Command-Line Usage + +Fi-NeMo provides a command-line utility named `finemo` for motif instance calling and analysis. The typical workflow involves three main steps: + +1. **Preprocessing**: Transform input contributions and sequences into a unified format +2. **Hit Calling**: Identify motif instances using the Fi-NeMo algorithm +3. **Reporting and Analysis**: Generate visualizations and perform post-processing + +For detailed options for any subcommand, run `finemo -h`. ### Preprocessing @@ -126,14 +168,14 @@ Usage: `finemo extract-regions-modisco-fmt -s -a -o < #### `finemo call-hits` -Identify hits in input regions using TFMoDISCo CWM's. 
+Identify motif instances in input regions using the Fi-NeMo competitive optimization algorithm. This is the core functionality that leverages TF-MoDISco CWMs to find motif occurrences in contribution score data. Usage: `finemo call-hits -r -m -o [-p ] [-t ] [-l ] [-b ] [-J]` - `-r/--regions`: A `.npz` file of input sequences, contributions, and coordinates. Created with a `finemo extract-regions-*` command. - `-m/--modisco-h5`: A tfmodisco-lite output H5 file of motif patterns. - `-o/--out-dir`: The path to the output directory. -- `-t/--cwm-trim-threshold`: The threshold to determine motif start and end positions within the full CWMs. Default is 0.3. +- `-t/--cwm-trim-threshold`: The threshold to determine motif start and end positions within the full CWMs. Default is 0.3. If you need finer control over motif trimming, check out the `-T/--cwm-trim-thresholds` and `-R/--cwm-trim-coords` options. - `-l/--global-lambda`: The L1 regularization weight determining the sparsity of hits. Default is 0.7. - `-b/--batch-size`: The batch size used for optimization. Default is 2000. - `-J/--compile`: Enable JIT compilation for faster execution. This option may not work on older GPUs. @@ -156,7 +198,7 @@ Usage: `finemo call-hits -r -m -o [-p ] - `peak_name`: The name of the peak region containing the hit, taken from the `name` field of the input peak data. `NA` if peak coordinates are not provided. - `peak_id`: The numerical index of the peak region containing the hit. -`hits_unique.tsv`: A deduplicated list of hits in the same format as `hits.tsv`. In cases where peak regions overlap, `hits.tsv` may list multiple instances of a hit, each linked to a different peak. `hits_unique.tsv` arbitrarily selects one instance per duplicated hit. This file is generated only if peak coordinates are provided. +`hits_unique.tsv`: A deduplicated list of hits in the same format as `hits.tsv`. 
In cases where peak regions overlap, `hits.tsv` may list multiple instances of a hit, each linked to a different peak. `hits_unique.tsv` arbitrarily selects one instance per duplicated hit. **This file is empty if peak coordinates are not provided.** `hits.bed`: A coordinate-sorted BED file of unique hits. It includes: @@ -195,20 +237,37 @@ Usage: `finemo call-hits -r -m -o [-p ] `params.json`: The parameters used for hit calling. -#### Additional notes +#### Parameter Guidelines + +**Sensitivity Control (`-l/--global-lambda`)** +- Controls sparsity and sensitivity of hit calling +- Higher values (e.g., 0.8-0.9) → fewer, higher-confidence hits +- Lower values (e.g., 0.5-0.6) → more sensitive, may include weaker hits +- Default of 0.7 works well for chromatin accessibility data +- ChIP-seq data may benefit from lower values (0.6) + +**Motif Trimming (`-t/--cwm-trim-threshold`)** +- Determines where motif boundaries are set within full CWMs +- Lower values → more conservative trimming, longer motifs +- Higher values → more aggressive trimming, shorter core motifs +- Affects resolution of closely-spaced motif instances + +**Performance Optimization (`-b/--batch-size`, `-J`)** +- Set batch size to utilize available GPU memory efficiently +- Reduce batch size if you encounter out-of-memory errors +- Enable JIT compilation (`-J`) for faster execution on newer GPUs -- The `-l/--global-lambda` parameter controls the sensitivity of the hit-calling algorithm, with higher values resulting in fewer but more confident hits. This parameter represents the minimum cosine similarity between a query contribution score window and a CWM to be considered a hit. The default value of 0.7 typically works well for chromatin accessibility data. ChIP-Seq data may require a lower value (e.g. 0.6). -- The `-t/--cwm-trim-threshold` parameter sets the maximum relative contribution score in trimmed-out CWM flanks. 
If you find that motif flanks are being trimmed too aggressively, consider lowering this value. However, a too-low value may result in closely-spaced motif instances being missed. -- Set `-b/--batch-size` to fill a significant fraction of your GPU memory. **If you encounter GPU out-of-memory errors, try lowering this value.** -- Legacy TFMoDISCo H5 files can be updated to the newer TFMoDISCo-lite format with the `modisco convert` command found in the [tfmodisco-lite](https://github.com/jmschrei/tfmodisco-lite/tree/main) package. -- The hit-calling thresholding procedure is scale-invariant. That is, whether a position is assigned a hit depends on the shapes of the motif CWM and the contribution scores, not the absolute magnitude of the scores. If you wish to prioritize hits based on the magnitude of the contribution scores, set a per-motif rank threshold the `hit_coefficient_global` field in the `hits.tsv` file, which captures both the absolute importance and the closeness of match. +#### Important Notes -### Output reporting +- **Scale Invariance**: Hit calling depends on motif and contribution score shapes, not absolute magnitudes. Use `hit_coefficient_global` or `hit_importance` for importance-based thresholding. +- **Legacy Format Support**: Convert older TF-MoDISco files using `modisco convert` from [tfmodisco-lite](https://github.com/jmschrei/tfmodisco-lite). + +### Output reporting and post-processing #### `finemo report` -Generate an HTML report (`report.html`) visualizing TF-MoDISCo seqlet recall and hit distributions. -If `-n/--no-recall` is not set, the regions used for hit calling must exactly match those used during the TF-MoDISCo motif discovery process. +Generate an HTML report (`report.html`) visualizing TF-MoDISco seqlet recall and hit distributions. +If `-n/--no-recall` is not set, the regions used for hit calling must exactly match those used during the TF-MoDISco motif discovery process. This command does not utilize the GPU. 
Usage: `finemo report -r -H -o [-m ] [-W ] [-n]` @@ -220,12 +279,29 @@ Usage: `finemo report -r -H -o [-m ] [-W - `-W/--modisco-region-width`: The width of the region around each peak summit used by tfmodisco-lite. Default is 400. - `-n/--no-recall`: Do not compute motif recall metrics. Default is False. -#### Additional outputs +Additional report outputs: + +- `motif_report.tsv`: Statistics on the distribution of hits per motif. The columns and values correspond to those in the HTML report's table. +- `motif_occurrences.tsv`: The number of hits of each motif in each input region. Also includes the total number of hits per region. +- `CWMs`: A directory containing visualizations of motif CWMs, as well as corresponding tables with numerical CWM values. +- `seqlets.tsv`: tf-modisco seqlet coordinates for each motif in each region. Only generated if `-m/--modisco-h5` is provided. + +#### `finemo collapse-hits` + +Identify the best hits by motif similarity within groups of overlapping hits. Adds a 0/1 `is_primary` column to the `hits.tsv` file, indicating whether a hit is the best hit in its group. This command does not utilize the GPU. + +Usage: `finemo collapse-hits -i -o [-O ]` + +- `-i/--hits`: The path to the input hits file. This should be the `hits.tsv` or `hits_unique.tsv` file generated by the `finemo call-hits` command. +- `-o/--out-path`: The path to the output file. This will be a copy of the input file with an additional `is_primary` column. +- `-O/--overlap-frac`: The minimum fraction overlap required for two hits to be considered overlapping. Precisely, given two hits of lengths `x` and `y`, the minimum number of overlapping bases is `overlap_frac * (x + y) / 2`. Default is 0.2. -`motif_report.tsv`: Statistics on the distribution of hits per motif. The columns and values correspond to those in the HTML report's table. #### `finemo intersect-hits` -`motif_occurrences.tsv`: The number of hits of each motif in each input region. 
Also includes the total number of hits per region. +Find the intersection of hits across multiple runs. This command does not utilize the GPU. -`CWMs`: A directory containing visualizations of motif CWMs, as well as corresponding tables with numerical CWM values. +Usage: `finemo intersect-hits -i -o [-r]` -`seqlets.tsv`: tf-modisco seqlet coordinates for each motif in each region. Only generated if `-m/--modisco-h5` is provided. +- `-i/--hits`: The paths to one or more input hits files. These should be `hits.tsv` or `hits_unique.tsv` files generated by the `finemo call-hits` command. +- `-o/--out-path`: The path to the output file. Recurring columns are suffixed with the positional index of the input file (e.g. `hit_importance_1`), with the exception of index 0. +- `-r/--relaxed`: By default, the intersection assumes consistent input region definitions (name and coordinates) and motif trimming across runs. In contrast, this relaxed intersection criterion uses only motif names and untrimmed hit coordinates. However, this is not suitable when hit genomic coordinates are unknown. Default is False. 
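The `-O/--overlap-frac` rule documented for `collapse-hits` above can be expressed as a small predicate. This is an illustrative sketch of the documented criterion; `hits_overlap` is a hypothetical helper name, not part of the `finemo` API:

```python
def hits_overlap(start_a: int, end_a: int, start_b: int, end_b: int,
                 overlap_frac: float = 0.2) -> bool:
    """Two hits of lengths x and y are considered overlapping if they share
    at least overlap_frac * (x + y) / 2 bases (half-open intervals)."""
    shared = min(end_a, end_b) - max(start_a, start_b)  # overlapping bases
    required = overlap_frac * ((end_a - start_a) + (end_b - start_b)) / 2
    return shared >= required
```

For example, two 10-bp hits under the default `overlap_frac` of 0.2 must share at least 2 bases to fall into the same group.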
diff --git a/assets/methods.svg b/assets/methods.svg new file mode 100644 index 0000000..90ff4ea --- /dev/null +++ b/assets/methods.svg @@ -0,0 +1 @@ +Minimize DifferenceConvolveBackpropagateObserved scoresReconstructed scoresMotif CWMsHit coefficients (learned) \ No newline at end of file diff --git a/environment.yml b/environment.yml index 01e7af9..0de69bf 100644 --- a/environment.yml +++ b/environment.yml @@ -1,13 +1,10 @@ channels: - - pytorch - - nvidia - conda-forge - bioconda - nodefaults dependencies: - - pytorch=2.5.1 - - pytorch-cuda=12.4 - python=3.11 + - numba=0.61.2 - numpy=2.2.0 - scipy=1.14.1 - polars=1.17.1 @@ -17,4 +14,8 @@ dependencies: - tqdm=4.67.1 - jinja2=3.1.4 - pybigwig=0.3.23 - - pyfaidx=0.8.1.3 \ No newline at end of file + - pyfaidx=0.8.1.3 + - jaxtyping=0.3.2 + - pip=25.2 + - pip: + - torch==2.5.1 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 111f8c9..2344165 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,17 +6,18 @@ build-backend = "setuptools.build_meta" name = "finemo" description = "Identification of regulatory elements from neural network contribution scores for DNA." 
keywords = ["deep learning", "genomics"] -version = "0.30" +version = "0.40" readme = "README.md" license = {file = "LICENSE"} authors = [ {name = "Austin Wang", email = "austin.wang1357@gmail.com"}, - {name = "Anshul Kundaje"} + {name = "Anshul Kundaje", email = "akundaje@stanford.edu"} ] dependencies = [ "numpy", "scipy", "torch", + "numba", "polars>=1.0", "matplotlib", "h5py", @@ -24,7 +25,8 @@ dependencies = [ "tqdm", "pyBigWig", "pyfaidx", - "jinja2" + "jinja2", + "jaxtyping" ] [project.scripts] @@ -33,3 +35,6 @@ finemo = "finemo.main:cli" [project.urls] Homepage = "https://github.com/austintwang/finemo_gpu" Repository = "https://github.com/austintwang/finemo_gpu.git" + +[tool.ruff] +ignore = ["F722"] \ No newline at end of file diff --git a/setup.py b/setup.py index 88c95dc..3e0f86b 100644 --- a/setup.py +++ b/setup.py @@ -3,4 +3,4 @@ # Empty setup.py for compatibility with pip<21.1 # See pyproject.toml for package configuration -setup() \ No newline at end of file +setup() diff --git a/src/finemo/__init__.py b/src/finemo/__init__.py index e69de29..6b18638 100644 --- a/src/finemo/__init__.py +++ b/src/finemo/__init__.py @@ -0,0 +1,80 @@ +"""Fi-NeMo: Finding Neural Network Motifs. + +A GPU-accelerated motif instance calling tool for identifying transcription factor +binding sites from neural network contribution scores. + +Fi-NeMo implements a competitive optimization approach using proximal gradient descent +to identify motif instances by solving a sparse linear reconstruction problem. The +algorithm represents contribution scores as weighted combinations of motif contribution +weight matrices (CWMs) at specific genomic positions. 
+ +Key Features +------------ +- GPU-accelerated hit calling using PyTorch +- Support for multiple input formats (bigWig, HDF5, TF-MoDISco) +- Competitive motif instance assignment +- Comprehensive evaluation and visualization tools +- Post-processing utilities for hit refinement + +Modules +------- +- hitcaller : Core Fi-NeMo algorithm implementation +- data_io : Data input/output utilities +- main : Command-line interface +- evaluation : Performance assessment tools +- visualization : Plotting and report generation +- postprocessing : Hit refinement and analysis + +Examples +-------- +Basic hit calling workflow: + +>>> import numpy as np +>>> from finemo import data_io, hitcaller +>>> +>>> # Load preprocessed data +>>> sequences, contribs, peaks_df, has_peaks = data_io.load_regions_npz('regions.npz') +>>> cwms, trim_masks = data_io.load_motif_cwms('motifs.h5') +>>> +>>> # Call hits +>>> hits_df, qc_df = hitcaller.fit_contribs( +... cwms=cwms, +... contribs=contribs, +... sequences=sequences, +... cwm_trim_mask=trim_masks, +... use_hypothetical=False, +... lambdas=np.array([0.7] * len(cwms)), +... step_size_max=3.0, +... step_size_min=0.08, +... sqrt_transform=False, +... convergence_tol=0.0005, +... max_steps=10000, +... batch_size=1000, +... step_adjust=0.7, +... post_filter=True, +... device=None, +... compile_optimizer=False +... ) + +See Also +-------- +TF-MoDISco : https://github.com/jmschrei/tfmodisco-lite +BPNet : https://github.com/kundajelab/bpnet-refactor +ChromBPNet : https://github.com/kundajelab/chrombpnet +""" + +from . import data_io +from . import hitcaller +from . import evaluation +from . import visualization +from . import postprocessing +from . 
import main + +__all__ = [ + "data_io", + "hitcaller", + "evaluation", + "visualization", + "postprocessing", + "main", +] diff --git a/src/finemo/__main__.py b/src/finemo/__main__.py new file mode 100644 index 0000000..4a6bf2e --- /dev/null +++ b/src/finemo/__main__.py @@ -0,0 +1,8 @@ +""" +Entry point for running finemo's CLI as a module via 'python -m finemo'. +""" + +from .main import cli + +if __name__ == "__main__": + cli() diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index d5b119a..6dc6074 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -1,256 +1,848 @@ +"""Data input/output module for the Fi-NeMo motif instance calling pipeline. + +This module handles loading and processing of various genomic data formats including: +- Peak region files (ENCODE NarrowPeak format) +- Genome sequences (FASTA format) +- Contribution scores (bigWig, HDF5 formats) +- Neural network model outputs +- Motif data from TF-MoDISco +- Hit calling results + +The module supports multiple input formats used for contribution scores +and provides utilities for data conversion and quality control. +""" + import json import os import warnings from contextlib import ExitStack +from typing import List, Dict, Tuple, Optional, Any, Union, Callable import numpy as np +from numpy import ndarray import h5py -import hdf5plugin +import hdf5plugin # noqa: F401, imported for side effects (HDF5 plugin registration) import polars as pl import pyBigWig import pyfaidx +from jaxtyping import Float, Int from tqdm import tqdm -def load_txt(path): +def load_txt(path: str) -> List[str]: + """Load a text file containing one item per line. + + Parameters + ---------- + path : str + Path to the text file. + + Returns + ------- + List[str] + List of strings, one per line (first column if tab-delimited). 
+ """ entries = [] with open(path) as f: for line in f: item = line.rstrip("\n").split("\t")[0] entries.append(item) - + return entries -def load_mapping(path, type): +def load_mapping(path: str, value_type: Callable[[str], Any]) -> Dict[str, Any]: + """Load a two-column tab-delimited mapping file. + + Parameters + ---------- + path : str + Path to the mapping file. Must be tab-delimited with exactly two columns. + value_type : Callable[[str], Any] + Type constructor to apply to values (e.g., int, float, str). + Must accept a string and return the converted value. + + Returns + ------- + Dict[str, Any] + Dictionary mapping keys to values of the specified type. + + Raises + ------ + ValueError + If lines don't contain exactly two tab-separated values. + FileNotFoundError + If the specified file does not exist. + """ mapping = {} with open(path) as f: for line in f: key, val = line.rstrip("\n").split("\t") - mapping[key] = type(val) + mapping[key] = value_type(val) return mapping -NARROWPEAK_SCHEMA = ["chr", "peak_start", "peak_end", "peak_name", "peak_score", - "peak_strand", "peak_signal", "peak_pval", "peak_qval", "peak_summit"] -NARROWPEAK_DTYPES = [pl.String, pl.Int32, pl.Int32, pl.String, pl.UInt32, - pl.String, pl.Float32, pl.Float32, pl.Float32, pl.Int32] +def load_mapping_tuple( + path: str, value_type: Callable[[str], Any] +) -> Dict[str, Tuple[Any, ...]]: + """Load a mapping file where values are tuples from multiple columns. + + Parameters + ---------- + path : str + Path to the mapping file. Must be tab-delimited with multiple columns. + value_type : Callable[[str], Any] + Type constructor to apply to each value element. + Must accept a string and return the converted value. + + Returns + ------- + Dict[str, Tuple[Any, ...]] + Dictionary mapping keys to tuples of values of the specified type. + The first column is used as the key, remaining columns as tuple values. 
+ + Raises + ------ + ValueError + If lines don't contain at least two tab-separated values. + FileNotFoundError + If the specified file does not exist. + """ + mapping = {} + with open(path) as f: + for line in f: + entries = line.rstrip("\n").split("\t") + key = entries[0] + val = entries[1:] + mapping[key] = tuple(value_type(i) for i in val) + + return mapping + -def load_peaks(peaks_path, chrom_order_path, half_width): +# ENCODE NarrowPeak format column definitions +NARROWPEAK_SCHEMA: List[str] = [ + "chr", + "peak_start", + "peak_end", + "peak_name", + "peak_score", + "peak_strand", + "peak_signal", + "peak_pval", + "peak_qval", + "peak_summit", +] +NARROWPEAK_DTYPES: List[Any] = [ + pl.String, + pl.Int32, + pl.Int32, + pl.String, + pl.UInt32, + pl.String, + pl.Float32, + pl.Float32, + pl.Float32, + pl.Int32, +] + + +def load_peaks( + peaks_path: str, chrom_order_path: Optional[str], half_width: int +) -> pl.DataFrame: + """Load peak region data from ENCODE NarrowPeak format file. + + Parameters + ---------- + peaks_path : str + Path to the NarrowPeak format file. + chrom_order_path : str, optional + Path to file defining chromosome ordering. If None, uses order from peaks file. + half_width : int + Half-width of regions around peak summits. 
+ + Returns + ------- + pl.DataFrame + DataFrame containing peak information with columns: + - chr: Chromosome name + - peak_region_start: Start coordinate of centered region + - peak_name: Peak identifier + - peak_id: Sequential peak index + - chr_id: Numeric chromosome identifier + """ peaks = ( - pl.scan_csv(peaks_path, has_header=False, new_columns=NARROWPEAK_SCHEMA, separator='\t', - quote_char=None, schema_overrides=NARROWPEAK_DTYPES, null_values=['.', 'NA', 'null', 'NaN']) + pl.scan_csv( + peaks_path, + has_header=False, + new_columns=NARROWPEAK_SCHEMA, + separator="\t", + quote_char=None, + schema_overrides=NARROWPEAK_DTYPES, + null_values=[".", "NA", "null", "NaN"], + ) .select( chr=pl.col("chr"), peak_region_start=pl.col("peak_start") + pl.col("peak_summit") - half_width, - peak_name=pl.col("peak_name") + peak_name=pl.col("peak_name"), ) .with_row_index(name="peak_id") .collect() ) - + if chrom_order_path is not None: chrom_order = load_txt(chrom_order_path) else: chrom_order = [] chrom_order_set = set(chrom_order) - chrom_order_peaks = [i for i in peaks.get_column("chr").unique(maintain_order=True) if i not in chrom_order_set] + chrom_order_peaks = [ + i + for i in peaks.get_column("chr").unique(maintain_order=True) + if i not in chrom_order_set + ] chrom_order.extend(chrom_order_peaks) chrom_ind_map = {val: ind for ind, val in enumerate(chrom_order)} peaks = peaks.with_columns( pl.col("chr").replace_strict(chrom_ind_map).alias("chr_id") ) - + return peaks -SEQ_ALPHABET = np.array(["A","C","G","T"], dtype="S1") +# DNA sequence alphabet for one-hot encoding +SEQ_ALPHABET: np.ndarray = np.array(["A", "C", "G", "T"], dtype="S1") + -def one_hot_encode(sequence, dtype=np.int8): +def one_hot_encode(sequence: str, dtype: Any = np.int8) -> Int[ndarray, "4 L"]: + """Convert DNA sequence string to one-hot encoded matrix. + + Parameters + ---------- + sequence : str + DNA sequence string containing A, C, G, T characters. 
+ dtype : np.dtype, default np.int8 + Data type for the output array. + + Returns + ------- + Int[ndarray, "4 L"] + One-hot encoded sequence where rows correspond to A, C, G, T and + L is the sequence length. + + Notes + ----- + The output array has shape (4, len(sequence)) with rows corresponding to + nucleotides A, C, G, T in that order. Non-standard nucleotides (N, etc.) + result in all-zero columns. + """ sequence = sequence.upper() - seq_chararray = np.frombuffer(sequence.encode('UTF-8'), dtype='S1') - one_hot = (seq_chararray[None,:] == SEQ_ALPHABET[:,None]).astype(dtype) + seq_chararray = np.frombuffer(sequence.encode("UTF-8"), dtype="S1") + one_hot = (seq_chararray[None, :] == SEQ_ALPHABET[:, None]).astype(dtype) return one_hot -def load_regions_from_bw(peaks, fa_path, bw_paths, half_width): +def load_regions_from_bw( + peaks: pl.DataFrame, fa_path: str, bw_paths: List[str], half_width: int +) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N L"]]: + """Load genomic sequences and contribution scores from FASTA and bigWig files. + + Parameters + ---------- + peaks : pl.DataFrame + Peak regions DataFrame from load_peaks() containing columns: + 'chr', 'peak_region_start'. + fa_path : str + Path to genome FASTA file (.fa or .fasta format). + bw_paths : List[str] + List of paths to bigWig files containing contribution scores. + Must be non-empty. + half_width : int + Half-width of regions to extract around peak centers. + Total region width will be 2 * half_width. + + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of peaks, + 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width). + contribs : Float[ndarray, "N L"] + Contribution scores averaged across input bigWig files. + Shape is (N peaks, L region_length). + + Notes + ----- + BigWig files only provide projected contribution scores, not hypothetical scores. + Regions extending beyond chromosome boundaries are zero-padded. 
+ Missing values in bigWig files are converted to zero. + """ num_peaks = peaks.height + region_width = half_width * 2 - sequences = np.zeros((num_peaks, 4, half_width * 2), dtype=np.int8) - contribs = np.zeros((num_peaks, half_width * 2), dtype=np.float16) + sequences = np.zeros((num_peaks, 4, region_width), dtype=np.int8) + contribs = np.zeros((num_peaks, region_width), dtype=np.float16) + # Load genome reference genome = pyfaidx.Fasta(fa_path, one_based_attributes=False) - + bws = [pyBigWig.open(i) for i in bw_paths] contrib_buffer = np.zeros((len(bw_paths), half_width * 2), dtype=np.float16) try: - for ind, row in tqdm(enumerate(peaks.iter_rows(named=True)), disable=None, unit="regions", total=num_peaks): + for ind, row in tqdm( + enumerate(peaks.iter_rows(named=True)), + disable=None, + unit="regions", + total=num_peaks, + ): chrom = row["chr"] start = row["peak_region_start"] end = start + 2 * half_width - - sequence_data = genome[chrom][start:end] - sequence = sequence_data.seq - start_adj = sequence_data.start - end_adj = sequence_data.end + + sequence_data: pyfaidx.FastaRecord = genome[chrom][start:end] # type: ignore + sequence: str = sequence_data.seq # type: ignore + start_adj: int = sequence_data.start # type: ignore + end_adj: int = sequence_data.end # type: ignore a = start_adj - start b = end_adj - start if b > a: - sequences[ind,:,a:b] = one_hot_encode(sequence) + sequences[ind, :, a:b] = one_hot_encode(sequence) for j, bw in enumerate(bws): - contrib_buffer[j,:] = np.nan_to_num(bw.values(chrom, start_adj, end_adj, numpy=True)) + contrib_buffer[j, :] = np.nan_to_num( + bw.values(chrom, start_adj, end_adj, numpy=True) + ) + + contribs[ind, a:b] = np.mean(contrib_buffer, axis=0) - contribs[ind,a:b] = np.mean(contrib_buffer, axis=0) - finally: for bw in bws: bw.close() - + return sequences, contribs -def load_regions_from_chrombpnet_h5(h5_paths, half_width): +def load_regions_from_chrombpnet_h5( + h5_paths: List[str], half_width: int +) -> 
Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]: + """Load genomic sequences and contribution scores from ChromBPNet HDF5 files. + + Parameters + ---------- + h5_paths : List[str] + List of paths to ChromBPNet HDF5 files containing sequences and SHAP scores. + Must be non-empty and contain compatible data shapes. + half_width : int + Half-width of regions to extract around the center. + Total region width will be 2 * half_width. + + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width). + contribs : Float[ndarray, "N 4 L"] + SHAP contribution scores averaged across input files. + Shape is (N regions, 4 nucleotides, L region_length). + + Notes + ----- + ChromBPNet files store sequences in 'raw/seq' and SHAP scores in 'shap/seq'. + All input files must have the same dimensions and number of regions. + Missing values in contribution scores are converted to zero. 
+ """ with ExitStack() as stack: h5s = [stack.enter_context(h5py.File(i)) for i in h5_paths] - start = h5s[0]['raw/seq'].shape[-1] // 2 - half_width + start = h5s[0]["raw/seq"].shape[-1] // 2 - half_width # type: ignore # HDF5 array access end = start + 2 * half_width - - sequences = h5s[0]['raw/seq'][:,:,start:end].astype(np.int8) - contribs = np.mean([np.nan_to_num(f['shap/seq'][:,:,start:end]) for f in h5s], axis=0, dtype=np.float16) - - return sequences, contribs + sequences = h5s[0]["raw/seq"][:, :, start:end].astype(np.int8) # type: ignore # HDF5 array access + contribs = np.mean( + [np.nan_to_num(f["shap/seq"][:, :, start:end]) for f in h5s], # type: ignore # HDF5 array access + axis=0, + dtype=np.float16, + ) -def load_regions_from_bpnet_h5(h5_paths, half_width): + return sequences, contribs # type: ignore # HDF5 arrays converted to NumPy + + +def load_regions_from_bpnet_h5( + h5_paths: List[str], half_width: int +) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]: + """Load genomic sequences and contribution scores from BPNet HDF5 files. + + Parameters + ---------- + h5_paths : List[str] + List of paths to BPNet HDF5 files containing sequences and contribution scores. + Must be non-empty and contain compatible data shapes. + half_width : int + Half-width of regions to extract around the center. + Total region width will be 2 * half_width. + + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width). + contribs : Float[ndarray, "N 4 L"] + Hypothetical contribution scores averaged across input files. + Shape is (N regions, 4 nucleotides, L region_length). + + Notes + ----- + BPNet files store sequences in 'input_seqs' and hypothetical scores in 'hyp_scores'. + The data requires axis swapping to convert from (n, length, 4) to (n, 4, length) format. 
+ All input files must have the same dimensions and number of regions. + Missing values in contribution scores are converted to zero. + """ with ExitStack() as stack: h5s = [stack.enter_context(h5py.File(i)) for i in h5_paths] - start = h5s[0]['input_seqs'].shape[-2] // 2 - half_width + start = h5s[0]["input_seqs"].shape[-2] // 2 - half_width # type: ignore # HDF5 array access end = start + 2 * half_width - sequences = h5s[0]['input_seqs'][:,start:end,:].swapaxes(1,2).astype(np.int8) - contribs = np.mean([np.nan_to_num(f['hyp_scores'][:,start:end,:].swapaxes(1,2)) for f in h5s], axis=0, dtype=np.float16) + sequences = h5s[0]["input_seqs"][:, start:end, :].swapaxes(1, 2).astype(np.int8) # type: ignore # HDF5 array access with axis swap + contribs = np.mean( + [ + np.nan_to_num(f["hyp_scores"][:, start:end, :].swapaxes(1, 2)) # type: ignore # HDF5 array access + for f in h5s + ], + axis=0, + dtype=np.float16, + ) return sequences, contribs -def load_npy_or_npz(path): +def load_npy_or_npz(path: str) -> ndarray: + """Load array data from .npy or .npz file. + + Parameters + ---------- + path : str + Path to .npy or .npz file. File must exist and contain valid NumPy data. + + Returns + ------- + ndarray + Loaded array data. For .npz files, returns the first array ('arr_0'). + For .npy files, returns the array directly. + + Raises + ------ + FileNotFoundError + If the specified file does not exist. + KeyError + If .npz file does not contain 'arr_0' key. + """ f = np.load(path) if isinstance(f, np.ndarray): arr = f else: - arr = f['arr_0'] + arr = f["arr_0"] return arr -def load_regions_from_modisco_fmt(shaps_paths, ohe_path, half_width): + +def load_regions_from_modisco_fmt( + shaps_paths: List[str], ohe_path: str, half_width: int +) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]: + """Load genomic sequences and contribution scores from TF-MoDISco format files. 
+ + Parameters + ---------- + shaps_paths : List[str] + List of paths to .npy/.npz files containing SHAP/attribution scores. + Must be non-empty and all files must have compatible shapes. + ohe_path : str + Path to .npy/.npz file containing one-hot encoded sequences. + Must have shape (n_regions, 4, sequence_length). + half_width : int + Half-width of regions to extract around the center. + Total region width will be 2 * half_width. + + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width). + contribs : Float[ndarray, "N 4 L"] + SHAP contribution scores averaged across input files. + Shape is (N regions, 4 nucleotides, L region_length). + + Notes + ----- + All SHAP files must have the same shape as the sequence file. + Missing values in contribution scores are converted to zero. + The center of the input sequences is used as the reference point for extraction. + """ sequences_raw = load_npy_or_npz(ohe_path) start = sequences_raw.shape[-1] // 2 - half_width end = start + 2 * half_width - sequences = sequences_raw[:,:,start:end].astype(np.int8) + sequences = sequences_raw[:, :, start:end].astype(np.int8) - shaps = [np.nan_to_num(load_npy_or_npz(p)[:,:,start:end]) for p in shaps_paths] + shaps = [np.nan_to_num(load_npy_or_npz(p)[:, :, start:end]) for p in shaps_paths] contribs = np.mean(shaps, axis=0, dtype=np.float16) return sequences, contribs -def load_regions_npz(npz_path): +def load_regions_npz( + npz_path: str, +) -> Tuple[ + Int[ndarray, "N 4 L"], + Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]], + pl.DataFrame, + bool, +]: + """Load preprocessed genomic regions from NPZ file. + + Parameters + ---------- + npz_path : str + Path to NPZ file containing sequences, contributions, and optional coordinates. + Must contain 'sequences' and 'contributions' arrays at minimum. 
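All of these loaders extract a fixed-width window centered on the input regions; the arithmetic can be sketched in isolation (`crop_center` is a hypothetical helper, not part of the finemo API):

```python
import numpy as np

def crop_center(arr, half_width):
    """Crop a (N, 4, L) array to the central 2 * half_width positions."""
    start = arr.shape[-1] // 2 - half_width
    end = start + 2 * half_width
    return arr[:, :, start:end]

x = np.arange(2 * 4 * 10).reshape(2, 4, 10)
cropped = crop_center(x, 3)

assert cropped.shape == (2, 4, 6)
# For L = 10 and half_width = 3, positions 2..7 are kept.
assert (cropped[0, 0] == np.arange(2, 8)).all()
```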
+ + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length. + contributions : Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]] + Contribution scores in either hypothetical format (N, 4, L) or + projected format (N, L). Shape depends on the input data format. + peaks_df : pl.DataFrame + DataFrame containing peak region information with columns: + 'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'. + has_peaks : bool + Whether the file contains genomic coordinate information. + If False, placeholder coordinate data is used. + + Notes + ----- + If genomic coordinates are not present in the NPZ file, creates placeholder + coordinate data and issues a warning. The placeholder data uses 'NA' for + chromosome names and sequential indices for peak IDs. + + Raises + ------ + KeyError + If required arrays 'sequences' or 'contributions' are missing from the file. + """ data = np.load(npz_path) - + if "chr" not in data.keys(): - warnings.warn("No genome coordinates present in the input .npz file. Returning sequences and contributions only.") + warnings.warn( + "No genome coordinates present in the input .npz file. Returning sequences and contributions only." 
+ ) has_peaks = False num_regions = data["sequences"].shape[0] - peak_data = {"chr": np.array(["NA"] * num_regions, dtype='U'), "chr_id": np.arange(num_regions, dtype=np.uint32), - "peak_region_start": np.zeros(num_regions, dtype=np.int32), "peak_id": np.arange(num_regions, dtype=np.uint32), - "peak_name": np.array(["NA"] * num_regions, dtype='U')} + peak_data = { + "chr": np.array(["NA"] * num_regions, dtype="U"), + "chr_id": np.arange(num_regions, dtype=np.uint32), + "peak_region_start": np.zeros(num_regions, dtype=np.int32), + "peak_id": np.arange(num_regions, dtype=np.uint32), + "peak_name": np.array(["NA"] * num_regions, dtype="U"), + } else: has_peaks = True - peak_data = {"chr": data["chr"], "chr_id": data["chr_id"], "peak_region_start": data["start"], - "peak_id": data["peak_id"], "peak_name": data["peak_name"]} - + peak_data = { + "chr": data["chr"], + "chr_id": data["chr_id"], + "peak_region_start": data["start"], + "peak_id": data["peak_id"], + "peak_name": data["peak_name"], + } + peaks_df = pl.DataFrame(peak_data) return data["sequences"], data["contributions"], peaks_df, has_peaks -def write_regions_npz(sequences, contributions, out_path, peaks_df=None): +def write_regions_npz( + sequences: Int[ndarray, "N 4 L"], + contributions: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]], + out_path: str, + peaks_df: Optional[pl.DataFrame] = None, +) -> None: + """Write genomic regions and contribution scores to compressed NPZ file. + + Parameters + ---------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length. + contributions : Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]] + Contribution scores in either hypothetical format (N, 4, L) or + projected format (N, L). + out_path : str + Output path for the NPZ file. Parent directory must exist. 
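The `"chr"`-key check that drives the placeholder branch can be exercised end-to-end with an in-memory `.npz` (illustrative sketch only):

```python
import io
import numpy as np

sequences = np.zeros((3, 4, 8), dtype=np.int8)
contributions = np.zeros((3, 4, 8), dtype=np.float16)

# Write an .npz without coordinate arrays, as write_regions_npz does
# when no peaks_df is supplied.
buf = io.BytesIO()
np.savez_compressed(buf, sequences=sequences, contributions=contributions)
buf.seek(0)

data = np.load(buf)
has_peaks = "chr" in data.keys()
assert not has_peaks

# Placeholder coordinates mirroring the fallback branch above.
num_regions = data["sequences"].shape[0]
peak_ids = np.arange(num_regions, dtype=np.uint32)
assert peak_ids.tolist() == [0, 1, 2]
```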
+ peaks_df : Optional[pl.DataFrame] + DataFrame containing peak region information with columns: + 'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'. + If None, only sequences and contributions are saved. + + Raises + ------ + ValueError + If the number of regions in sequences/contributions doesn't match peaks_df. + FileNotFoundError + If the parent directory of out_path does not exist. + + Notes + ----- + The output file is compressed using NumPy's savez_compressed format. + If peaks_df is provided, genomic coordinate information is included + in the output file for downstream analysis. + """ if peaks_df is None: - warnings.warn("No genome coordinates provided. Writing sequences and contributions only.") + warnings.warn( + "No genome coordinates provided. Writing sequences and contributions only." + ) np.savez_compressed(out_path, sequences=sequences, contributions=contributions) else: num_regions = peaks_df.height - if (num_regions != sequences.shape[0]) or (num_regions != contributions.shape[0]): - raise ValueError(f"Input sequences of shape {sequences.shape} and/or " - f"input contributions of shape {contributions.shape} " - f"are not compatible with peak region count of {num_regions}" ) - - chr_arr = peaks_df.get_column("chr").to_numpy().astype('U') + if (num_regions != sequences.shape[0]) or ( + num_regions != contributions.shape[0] + ): + raise ValueError( + f"Input sequences of shape {sequences.shape} and/or " + f"input contributions of shape {contributions.shape} " + f"are not compatible with peak region count of {num_regions}" + ) + + chr_arr = peaks_df.get_column("chr").to_numpy().astype("U") chr_id_arr = peaks_df.get_column("chr_id").to_numpy() start_arr = peaks_df.get_column("peak_region_start").to_numpy() peak_id_arr = peaks_df.get_column("peak_id").to_numpy() - peak_name_arr = peaks_df.get_column("peak_name").to_numpy().astype('U') - np.savez_compressed(out_path, sequences=sequences, contributions=contributions, - chr=chr_arr, 
chr_id=chr_id_arr, start=start_arr, peak_id=peak_id_arr, peak_name=peak_name_arr) + peak_name_arr = peaks_df.get_column("peak_name").to_numpy().astype("U") + np.savez_compressed( + out_path, + sequences=sequences, + contributions=contributions, + chr=chr_arr, + chr_id=chr_id_arr, + start=start_arr, + peak_id=peak_id_arr, + peak_name=peak_name_arr, + ) +def trim_motif(cwm: Float[ndarray, "4 W"], trim_threshold: float) -> Tuple[int, int]: + """Determine trimmed start and end positions for a motif based on contribution magnitude. + + This function identifies the core region of a motif by finding positions where + the total absolute contribution exceeds a threshold relative to the maximum. + + Parameters + ---------- + cwm : Float[ndarray, "4 W"] + Contribution weight matrix for the motif where 4 represents A,C,G,T + nucleotides and W is the motif width. + trim_threshold : float + Fraction of maximum score to use as trimming threshold (0.0 to 1.0). + Higher values result in more aggressive trimming. + + Returns + ------- + start : int + Start position of the trimmed motif (inclusive). + end : int + End position of the trimmed motif (exclusive). + + Notes + ----- + The trimming is based on the sum of absolute contributions across all nucleotides + at each position. Positions with contributions below trim_threshold * max_score + are removed from the motif edges. 
-def trim_motif(cwm, trim_threshold): - """ Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L213-L236 """ score = np.sum(np.abs(cwm), axis=0) trim_thresh = np.max(score) * trim_threshold pass_inds = np.nonzero(score >= trim_thresh) - start = max(np.min(pass_inds), 0) - end = min(np.max(pass_inds) + 1, len(score)) + start = max(int(np.min(pass_inds)), 0) # type: ignore # nonzero returns tuple of arrays + end = min(int(np.max(pass_inds)) + 1, len(score)) # type: ignore # nonzero returns tuple of arrays return start, end -def softmax(x, temp=100): +def softmax(x: Float[ndarray, "4 W"], temp: float = 100) -> Float[ndarray, "4 W"]: + """Apply softmax transformation with temperature scaling. + + Parameters + ---------- + x : Float[ndarray, "4 W"] + Input array to transform where 4 represents A,C,G,T nucleotides + and W is the motif width. + temp : float, default 100 + Temperature parameter for softmax scaling. Higher values create + sharper probability distributions. + + Returns + ------- + Float[ndarray, "4 W"] + Softmax-transformed array with same shape as input. Each column + sums to 1.0, representing nucleotide probabilities at each position. + + Notes + ----- + The softmax is applied along the nucleotide axis (axis=0), normalizing + each position to have probabilities that sum to 1. The temperature + parameter controls the sharpness of the distribution. + """ norm_x = x - np.mean(x, axis=1, keepdims=True) exp = np.exp(temp * norm_x) return exp / np.sum(exp, axis=0, keepdims=True) -def _motif_name_sort_key(data): - name = data[0] - if name.startswith("pattern_"): - pattern_num = int(name.split("_")[-1]) - return (pattern_num,) - else: - return (-1, name) +def _motif_name_sort_key(data: Tuple[str, Any]) -> Union[Tuple[int, int], Tuple[int, str]]: + """Generate sort key for TF-MoDISco motif names. 
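To make the trimming and reverse-complement bookkeeping concrete, here is a standalone re-implementation of the thresholding logic (illustrative only, not the packaged functions):

```python
import numpy as np

def trim_bounds(cwm, trim_threshold):
    # Keep positions whose total |contribution| reaches a fraction of the max.
    score = np.sum(np.abs(cwm), axis=0)
    thresh = np.max(score) * trim_threshold
    pass_inds = np.nonzero(score >= thresh)[0]
    return int(pass_inds.min()), int(pass_inds.max()) + 1

# Width-6 CWM with its core at positions 1-2 and edge noise at position 5.
cwm = np.zeros((4, 6))
cwm[0, 1] = 1.0
cwm[1, 2] = 0.8
cwm[2, 5] = 0.05  # below threshold, trimmed away

start_fwd, end_fwd = trim_bounds(cwm, 0.3)
assert (start_fwd, end_fwd) == (1, 3)

# Reverse-complement trim coordinates are the mirror image.
width = cwm.shape[1]
start_rev, end_rev = width - end_fwd, width - start_fwd
assert (start_rev, end_rev) == (3, 5)

def softmax_cols(x, temp=100):
    # Mirrors the softmax above: center per row, sharpen, normalize per column.
    norm_x = x - np.mean(x, axis=1, keepdims=True)
    exp = np.exp(temp * norm_x)
    return exp / np.sum(exp, axis=0, keepdims=True)

probs = softmax_cols(cwm * 0.01)
assert np.allclose(probs.sum(axis=0), 1.0)
```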
+ + This function creates a sort key that orders motifs by pattern number, + with non-standard patterns sorted to the end. -MODISCO_PATTERN_GROUPS = ['pos_patterns', 'neg_patterns'] + Parameters + ---------- + data : Tuple[str, Any] + Tuple containing motif name as first element and additional data. + The motif name should follow the format 'pattern_N' or 'pattern#N' where N is an integer. -def load_modisco_motifs(modisco_h5_path, trim_threshold, motif_type, motifs_include, - motif_name_map, motif_lambdas, motif_lambda_default, include_rc): + Returns + ------- + Union[Tuple[int, int], Tuple[int, str]] + Sort key tuple for ordering motifs. Standard pattern names return + (0, pattern_number) while non-standard names return (1, name). + + Notes + ----- + This function is used internally by load_modisco_motifs to ensure + consistent motif ordering across runs. """ + pattern_name = data[0] + try: + return (0, int(pattern_name.split("_")[-1])) + except (ValueError, IndexError): + try: + return (0, int(pattern_name.split("#")[-1])) + except (ValueError, IndexError): + return (1, pattern_name) + + +MODISCO_PATTERN_GROUPS = ["pos_patterns", "neg_patterns"] + + +def load_modisco_motifs( + modisco_h5_path: str, + trim_coords: Optional[Dict[str, Tuple[int, int]]], + trim_thresholds: Optional[Dict[str, float]], + trim_threshold_default: float, + motif_type: str, + motifs_include: Optional[List[str]], + motif_name_map: Optional[Dict[str, str]], + motif_lambdas: Optional[Dict[str, float]], + motif_lambda_default: float, + include_rc: bool, +) -> Tuple[pl.DataFrame, Float[ndarray, "M 4 W"], Int[ndarray, "M W"], ndarray]: + """Load motif data from TF-MoDISco HDF5 file with customizable processing options. + + This function extracts contribution weight matrices and associated metadata from + TF-MoDISco results, with support for custom naming, trimming, and regularization + parameters. 
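The sort-key behavior is easiest to see on a small example: numeric suffixes order numerically, so `pattern_10` follows `pattern_2`, and unparseable names sort last (standalone mirror of the function above):

```python
def motif_name_sort_key(data):
    # Mirrors _motif_name_sort_key: try 'pattern_N', then 'pattern#N',
    # then fall back to name-based ordering after all numeric patterns.
    name = data[0]
    try:
        return (0, int(name.split("_")[-1]))
    except (ValueError, IndexError):
        try:
            return (0, int(name.split("#")[-1]))
        except (ValueError, IndexError):
            return (1, name)

items = [("pattern_10", None), ("pattern_2", None), ("custom", None)]
ordered = [name for name, _ in sorted(items, key=motif_name_sort_key)]
assert ordered == ["pattern_2", "pattern_10", "custom"]
```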
+ + Parameters + ---------- + modisco_h5_path : str + Path to TF-MoDISco HDF5 results file containing pattern groups. + trim_coords : Optional[Dict[str, Tuple[int, int]]] + Manual trim coordinates for specific motifs {motif_name: (start, end)}. + Takes precedence over automatic trimming based on thresholds. + trim_thresholds : Optional[Dict[str, float]] + Custom trim thresholds for specific motifs {motif_name: threshold}. + Values should be between 0.0 and 1.0. + trim_threshold_default : float + Default trim threshold for motifs not in trim_thresholds. + Fraction of maximum contribution used for trimming. + motif_type : str + Type of motif to extract. Must be one of: + - 'cwm': Contribution weight matrix (normalized) + - 'hcwm': Hypothetical contribution weight matrix + - 'pfm': Position frequency matrix + - 'pfm_softmax': Softmax-transformed position frequency matrix + motifs_include : Optional[List[str]] + List of motif names to include. If None, includes all motifs found. + Names should follow format 'pos_patterns.pattern_N' or 'neg_patterns.pattern_N'. + motif_name_map : Optional[Dict[str, str]] + Mapping from original to custom motif names {orig_name: new_name}. + New names must be unique across all motifs. + motif_lambdas : Optional[Dict[str, float]] + Custom lambda regularization values for specific motifs {motif_name: lambda}. + Higher values increase sparsity penalty for the corresponding motif. + motif_lambda_default : float + Default lambda value for motifs not specified in motif_lambdas. + include_rc : bool + Whether to include reverse complement motifs in addition to forward motifs. + If True, doubles the number of motifs returned. + + Returns + ------- + motifs_df : pl.DataFrame + DataFrame containing motif metadata with columns: motif_id, motif_name, + motif_name_orig, strand, motif_start, motif_end, motif_scale, lambda. 
+ cwms : Float[ndarray, "M 4 W"] + Contribution weight matrices for all motifs where M is the number of motifs, + 4 represents A,C,G,T nucleotides, and W is the motif width. + trim_masks : Int[ndarray, "M W"] + Binary masks indicating core motif regions (1) vs trimmed regions (0). + Shape is (M motifs, W motif_width). + names : ndarray + Array of unique motif names (forward strand only). + + Raises + ------ + ValueError + If motif_type is not one of the supported types, or if motif names + in motif_name_map are not unique. + FileNotFoundError + If the specified HDF5 file does not exist. + KeyError + If required datasets are missing from the HDF5 file. + + Notes + ----- + Motif trimming removes low-contribution positions from the edges based on + the position-wise sum of absolute contributions across nucleotides. The trimming + helps focus on the core binding site. + Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L252-L272 """ - motif_data_lsts = {"motif_name": [], "motif_name_orig": [], "strand": [], "motif_start": [], - "motif_end": [], "motif_scale": [], "lambda": []} - motif_lst = [] + motif_data_lsts = { + "motif_name": [], + "motif_name_orig": [], + "strand": [], + "motif_start": [], + "motif_end": [], + "motif_scale": [], + "lambda": [], + } + motif_lst = [] trim_mask_lst = [] if motifs_include is not None: - motifs_include = set(motifs_include) + motifs_include_set = set(motifs_include) + else: + motifs_include_set = None if motif_name_map is None: motif_name_map = {} @@ -258,33 +850,52 @@ def load_modisco_motifs(modisco_h5_path, trim_threshold, motif_type, motifs_incl if motif_lambdas is None: motif_lambdas = {} + if trim_coords is None: + trim_coords = {} + if trim_thresholds is None: + trim_thresholds = {} + if len(motif_name_map.values()) != len(set(motif_name_map.values())): raise ValueError("Specified motif names are not unique") - with h5py.File(modisco_h5_path, 'r') as 
modisco_results: + with h5py.File(modisco_h5_path, "r") as modisco_results: for name in MODISCO_PATTERN_GROUPS: if name not in modisco_results.keys(): continue metacluster = modisco_results[name] - for ind, (pattern_name, pattern) in enumerate(sorted(metacluster.items(), key=_motif_name_sort_key)): - pattern_tag = f'{name}.{pattern_name}' - - if motifs_include is not None and pattern_tag not in motifs_include: + for _, (pattern_name, pattern) in enumerate( + sorted(metacluster.items(), key=_motif_name_sort_key) # type: ignore # HDF5 access + ): + pattern_tag = f"{name}.{pattern_name}" + + if ( + motifs_include_set is not None + and pattern_tag not in motifs_include_set + ): continue motif_lambda = motif_lambdas.get(pattern_tag, motif_lambda_default) pattern_tag_orig = pattern_tag pattern_tag = motif_name_map.get(pattern_tag, pattern_tag) - cwm_raw = pattern['contrib_scores'][:].T + cwm_raw = pattern["contrib_scores"][:].T # type: ignore cwm_norm = np.sqrt((cwm_raw**2).sum()) cwm_fwd = cwm_raw / cwm_norm - cwm_rev = cwm_fwd[::-1,::-1] - start_fwd, end_fwd = trim_motif(cwm_fwd, trim_threshold) - start_rev, end_rev = trim_motif(cwm_rev, trim_threshold) - + cwm_rev = cwm_fwd[::-1, ::-1] + + if pattern_tag in trim_coords: + start_fwd, end_fwd = trim_coords[pattern_tag] + else: + trim_threshold = trim_thresholds.get( + pattern_tag, trim_threshold_default + ) + start_fwd, end_fwd = trim_motif(cwm_fwd, trim_threshold) + + cwm_len = cwm_fwd.shape[1] + start_rev, end_rev = cwm_len - end_fwd, cwm_len - start_fwd + trim_mask_fwd = np.zeros(cwm_fwd.shape[1], dtype=np.int8) trim_mask_fwd[start_fwd:end_fwd] = 1 trim_mask_rev = np.zeros(cwm_rev.shape[1], dtype=np.int8) @@ -296,29 +907,34 @@ def load_modisco_motifs(modisco_h5_path, trim_threshold, motif_type, motifs_incl motif_norm = cwm_norm elif motif_type == "hcwm": - motif_raw = pattern['hypothetical_contribs'][:].T + motif_raw = pattern["hypothetical_contribs"][:].T # type: ignore motif_norm = np.sqrt((motif_raw**2).sum()) 
motif_fwd = motif_raw / motif_norm - motif_rev = motif_fwd[::-1,::-1] + motif_rev = motif_fwd[::-1, ::-1] elif motif_type == "pfm": - motif_raw = pattern['sequence'][:].T + motif_raw = pattern["sequence"][:].T # type: ignore motif_norm = 1 motif_fwd = motif_raw / np.sum(motif_raw, axis=0, keepdims=True) - motif_rev = motif_fwd[::-1,::-1] + motif_rev = motif_fwd[::-1, ::-1] elif motif_type == "pfm_softmax": - motif_raw = pattern['sequence'][:].T + motif_raw = pattern["sequence"][:].T # type: ignore motif_norm = 1 motif_fwd = softmax(motif_raw) - motif_rev = motif_fwd[::-1,::-1] + motif_rev = motif_fwd[::-1, ::-1] + + else: + raise ValueError( + f"Invalid motif_type: {motif_type}. Must be one of 'cwm', 'hcwm', 'pfm', 'pfm_softmax'." + ) motif_data_lsts["motif_name"].append(pattern_tag) motif_data_lsts["motif_name_orig"].append(pattern_tag_orig) - motif_data_lsts["strand"].append('+') + motif_data_lsts["strand"].append("+") motif_data_lsts["motif_start"].append(start_fwd) motif_data_lsts["motif_end"].append(end_fwd) motif_data_lsts["motif_scale"].append(motif_norm) @@ -327,7 +943,7 @@ def load_modisco_motifs(modisco_h5_path, trim_threshold, motif_type, motifs_incl if include_rc: motif_data_lsts["motif_name"].append(pattern_tag) motif_data_lsts["motif_name_orig"].append(pattern_tag_orig) - motif_data_lsts["strand"].append('-') + motif_data_lsts["strand"].append("-") motif_data_lsts["motif_start"].append(start_rev) motif_data_lsts["motif_end"].append(end_rev) motif_data_lsts["motif_scale"].append(motif_norm) @@ -339,17 +955,75 @@ def load_modisco_motifs(modisco_h5_path, trim_threshold, motif_type, motifs_incl else: motif_lst.append(motif_fwd) trim_mask_lst.append(trim_mask_fwd) - + motifs_df = pl.DataFrame(motif_data_lsts).with_row_index(name="motif_id") cwms = np.stack(motif_lst, dtype=np.float16, axis=0) trim_masks = np.stack(trim_mask_lst, dtype=np.int8, axis=0) - names = motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy() + names = ( + 
motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy() + ) return motifs_df, cwms, trim_masks, names -def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modisco_half_width, lazy=False): - +def load_modisco_seqlets( + modisco_h5_path: str, + peaks_df: pl.DataFrame, + motifs_df: pl.DataFrame, + half_width: int, + modisco_half_width: int, + lazy: bool = False, +) -> Union[pl.DataFrame, pl.LazyFrame]: + """Load seqlet data from TF-MoDISco HDF5 file and convert to genomic coordinates. + + This function extracts seqlet instances from TF-MoDISco results and converts + their relative positions to absolute genomic coordinates using peak region + information. + + Parameters + ---------- + modisco_h5_path : str + Path to TF-MoDISco HDF5 results file containing seqlet data. + peaks_df : pl.DataFrame + DataFrame containing peak region information with columns: + 'peak_id', 'chr', 'chr_id', 'peak_region_start'. + motifs_df : pl.DataFrame + DataFrame containing motif metadata with columns: + 'motif_name_orig', 'strand', 'motif_name', 'motif_start', 'motif_end'. + half_width : int + Half-width of the current analysis regions. + modisco_half_width : int + Half-width of the regions used in the original TF-MoDISco analysis. + Used to calculate coordinate offsets. + lazy : bool, default False + If True, returns a LazyFrame for efficient chaining of operations. + If False, collects the result into a DataFrame. 
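The coordinate conversion described in this docstring is plain integer arithmetic; a sketch with made-up values (variable names mirror the docstring columns):

```python
# Hypothetical values for a single seqlet.
peak_region_start = 1000            # genomic start of the peak region
seqlet_start, seqlet_end = 40, 55   # seqlet position within the modisco window
motif_start, motif_end = 2, 12      # trimmed bounds within the motif
half_width = 200
modisco_half_width = 150

# Regions of different sizes share a center, so the offset is the half-width gap.
offset = half_width - modisco_half_width

start_untrimmed = peak_region_start + seqlet_start + offset
end_untrimmed = peak_region_start + seqlet_end + offset
start = start_untrimmed + motif_start
end = peak_region_start + seqlet_start + motif_end + offset

assert offset == 50
assert (start_untrimmed, end_untrimmed) == (1090, 1105)
assert (start, end) == (1092, 1102)
```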
+ + Returns + ------- + Union[pl.DataFrame, pl.LazyFrame] + Seqlets with genomic coordinates containing columns: + - chr: Chromosome name + - chr_id: Numeric chromosome identifier + - start: Start coordinate of trimmed motif instance + - end: End coordinate of trimmed motif instance + - start_untrimmed: Start coordinate of full motif instance + - end_untrimmed: End coordinate of full motif instance + - is_revcomp: Whether the motif is reverse complemented + - strand: Motif strand ('+' or '-') + - motif_name: Motif name (may be remapped) + - peak_id: Peak identifier + - peak_region_start: Peak region start coordinate + + Notes + ----- + Seqlets are deduplicated based on chromosome ID, start position (untrimmed), + motif name, and reverse complement status to avoid redundant instances. + + The coordinate transformation accounts for differences in region sizes + between the original TF-MoDISco analysis and the current analysis. + """ + start_lst = [] end_lst = [] is_revcomp_lst = [] @@ -357,23 +1031,26 @@ def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modis peak_id_lst = [] pattern_tags = [] - with h5py.File(modisco_h5_path, 'r') as modisco_results: + with h5py.File(modisco_h5_path, "r") as modisco_results: for name in MODISCO_PATTERN_GROUPS: if name not in modisco_results.keys(): continue metacluster = modisco_results[name] - key = lambda x: int(x[0].split("_")[-1]) - for ind, (pattern_name, pattern) in enumerate(sorted(metacluster.items(), key=key)): - pattern_tag = f'{name}.{pattern_name}' - starts = pattern['seqlets/start'][:].astype(np.int32) - ends = pattern['seqlets/end'][:].astype(np.int32) - is_revcomps = pattern['seqlets/is_revcomp'][:].astype(bool) - strands = ['+' if not i else '-' for i in is_revcomps] - peak_ids = pattern['seqlets/example_idx'][:].astype(np.uint32) + key = _motif_name_sort_key + for _, (pattern_name, pattern) in enumerate( + sorted(metacluster.items(), key=key) # type: ignore # HDF5 access + ): + pattern_tag = 
f"{name}.{pattern_name}" - n_seqlets = int(pattern['seqlets/n_seqlets'][0]) + starts = pattern["seqlets/start"][:].astype(np.int32) # type: ignore + ends = pattern["seqlets/end"][:].astype(np.int32) # type: ignore + is_revcomps = pattern["seqlets/is_revcomp"][:].astype(bool) # type: ignore + strands = ["+" if not i else "-" for i in is_revcomps] + peak_ids = pattern["seqlets/example_idx"][:].astype(np.uint32) # type: ignore + + n_seqlets = int(pattern["seqlets/n_seqlets"][0]) # type: ignore start_lst.append(starts) end_lst.append(ends) @@ -390,7 +1067,7 @@ def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modis "peak_id": np.concatenate(peak_id_lst), "motif_name_orig": pattern_tags, } - + offset = half_width - modisco_half_width seqlets_df = ( @@ -400,15 +1077,23 @@ def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modis .select( chr=pl.col("chr"), chr_id=pl.col("chr_id"), - start=pl.col("peak_region_start") + pl.col("seqlet_start") + pl.col("motif_start") + offset, - end=pl.col("peak_region_start") + pl.col("seqlet_start") + pl.col("motif_end") + offset, - start_untrimmed=pl.col("peak_region_start") + pl.col("seqlet_start") + offset, + start=pl.col("peak_region_start") + + pl.col("seqlet_start") + + pl.col("motif_start") + + offset, + end=pl.col("peak_region_start") + + pl.col("seqlet_start") + + pl.col("motif_end") + + offset, + start_untrimmed=pl.col("peak_region_start") + + pl.col("seqlet_start") + + offset, end_untrimmed=pl.col("peak_region_start") + pl.col("seqlet_end") + offset, is_revcomp=pl.col("is_revcomp"), strand=pl.col("strand"), motif_name=pl.col("motif_name"), peak_id=pl.col("peak_id"), - peak_region_start=pl.col("peak_region_start") + peak_region_start=pl.col("peak_region_start"), ) .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) ) @@ -418,8 +1103,27 @@ def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modis return seqlets_df -def 
write_modisco_seqlets(seqlets_df, out_path): +def write_modisco_seqlets( + seqlets_df: Union[pl.DataFrame, pl.LazyFrame], out_path: str +) -> None: + """Write TF-MoDISco seqlets to TSV file. + + Parameters + ---------- + seqlets_df : Union[pl.DataFrame, pl.LazyFrame] + Seqlets DataFrame with genomic coordinates, including the + internal columns 'chr_id' and 'is_revcomp'. + out_path : str + Output TSV file path. + + Notes + ----- + Removes internal columns 'chr_id' and 'is_revcomp' before writing + to create a clean output format suitable for downstream analysis. + """ seqlets_df = seqlets_df.drop(["chr_id", "is_revcomp"]) + if isinstance(seqlets_df, pl.LazyFrame): + seqlets_df = seqlets_df.collect() seqlets_df.write_csv(out_path, separator="\t") @@ -439,68 +1143,151 @@ def write_modisco_seqlets(seqlets_df, out_path): "strand": pl.String, "peak_name": pl.String, "peak_id": pl.UInt32, } - -def load_hits(hits_path, lazy=False): - hits_df = ( - pl.scan_csv(hits_path, separator='\t', quote_char=None, schema=HITS_DTYPES) - .with_columns(pl.lit(1).alias("count")) - ) +HITS_COLLAPSED_DTYPES = HITS_DTYPES | {"is_primary": pl.UInt32} + + +def load_hits( + hits_path: str, lazy: bool = False, schema: Dict[str, Any] = HITS_DTYPES +) -> Union[pl.DataFrame, pl.LazyFrame]: + """Load motif hit data from TSV file. + + Parameters + ---------- + hits_path : str + Path to TSV file containing motif hit results. + lazy : bool, default False + If True, returns a LazyFrame for efficient chaining of operations. + If False, collects the result into a DataFrame. + schema : Dict[str, Any], default HITS_DTYPES + Schema defining column names and data types for the hit data. + + Returns + ------- + Union[pl.DataFrame, pl.LazyFrame] + Hit data with an additional 'count' column set to 1 for aggregation.
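`HITS_COLLAPSED_DTYPES` above is built with the dict-union operator, available from Python 3.9; its semantics in brief (placeholder string dtypes stand in for the polars type objects):

```python
# Placeholder string dtypes stand in for the polars type objects.
base_schema = {"chr": "String", "start": "Int32", "peak_id": "UInt32"}

# `|` (Python 3.9+) returns a new dict: left operand unchanged, key order
# preserved, right-hand keys appended (or overriding on collision).
collapsed_schema = base_schema | {"is_primary": "UInt32"}

assert list(collapsed_schema) == ["chr", "start", "peak_id", "is_primary"]
assert "is_primary" not in base_schema
```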
+ """ + hits_df = pl.scan_csv( + hits_path, separator="\t", quote_char=None, schema=schema + ).with_columns(pl.lit(1).alias("count")) return hits_df if lazy else hits_df.collect() -def write_hits(hits_df, peaks_df, motifs_df, qc_df, out_dir, motif_width): +def write_hits_processed( + hits_df: Union[pl.DataFrame, pl.LazyFrame], + out_path: str, + schema: Optional[Dict[str, Any]] = HITS_DTYPES, +) -> None: + """Write processed hit data to TSV file with optional column filtering. + + Parameters + ---------- + hits_df : Union[pl.DataFrame, pl.LazyFrame] + Hit data to write to file. + out_path : str + Output path for the TSV file. + schema : Optional[Dict[str, Any]], default HITS_DTYPES + Schema defining which columns to include in output. + If None, all columns are written. + """ + if schema is not None: + hits_df = hits_df.select(schema.keys()) + + if isinstance(hits_df, pl.LazyFrame): + hits_df = hits_df.collect() + + hits_df.write_csv(out_path, separator="\t") + + +def write_hits( + hits_df: Union[pl.DataFrame, pl.LazyFrame], + peaks_df: pl.DataFrame, + motifs_df: pl.DataFrame, + qc_df: pl.DataFrame, + out_dir: str, + motif_width: int, +) -> None: + """Write comprehensive hit results to multiple output files. + + This function combines hit data with peak, motif, and quality control information + to generate complete output files including genomic coordinates and scores. + + Parameters + ---------- + hits_df : Union[pl.DataFrame, pl.LazyFrame] + Hit data containing motif instance information. + peaks_df : pl.DataFrame + Peak region information for coordinate conversion. + motifs_df : pl.DataFrame + Motif metadata for annotation and trimming information. + qc_df : pl.DataFrame + Quality control data for normalization factors. + out_dir : str + Output directory for results files. Will be created if it doesn't exist. + motif_width : int + Width of motif instances for coordinate calculations. 
+ + Notes + ----- + Creates three output files: + - hits.tsv: Complete hit data with all instances + - hits_unique.tsv: Deduplicated hits by genomic position and motif (excludes rows with NA chromosome coordinates) + - hits.bed: BED format file for genome browser visualization + + Rows where the chromosome field is NA are filtered out during deduplication + to ensure that data_unique only contains well-defined genomic coordinates. + """ os.makedirs(out_dir, exist_ok=True) out_path_tsv = os.path.join(out_dir, "hits.tsv") out_path_tsv_unique = os.path.join(out_dir, "hits_unique.tsv") out_path_bed = os.path.join(out_dir, "hits.bed") data_all = ( - hits_df - .lazy() + hits_df.lazy() .join(peaks_df.lazy(), on="peak_id", how="inner") .join(qc_df.lazy(), on="peak_id", how="inner") .join(motifs_df.lazy(), on="motif_id", how="inner") .select( chr_id=pl.col("chr_id"), chr=pl.col("chr"), - start=pl.col("peak_region_start") + pl.col("hit_start") + pl.col("motif_start"), + start=pl.col("peak_region_start") + + pl.col("hit_start") + + pl.col("motif_start"), end=pl.col("peak_region_start") + pl.col("hit_start") + pl.col("motif_end"), start_untrimmed=pl.col("peak_region_start") + pl.col("hit_start"), - end_untrimmed=pl.col("peak_region_start") + pl.col("hit_start") + motif_width, + end_untrimmed=pl.col("peak_region_start") + + pl.col("hit_start") + + motif_width, motif_name=pl.col("motif_name"), hit_coefficient=pl.col("hit_coefficient"), - hit_coefficient_global=pl.col("hit_coefficient") * (pl.col("global_scale")**2), + hit_coefficient_global=pl.col("hit_coefficient") + * (pl.col("global_scale") ** 2), hit_similarity=pl.col("hit_similarity"), hit_correlation=pl.col("hit_similarity"), hit_importance=pl.col("hit_importance") * pl.col("global_scale"), - hit_importance_sq=pl.col("hit_importance_sq") * (pl.col("global_scale")**2), + hit_importance_sq=pl.col("hit_importance_sq") + * (pl.col("global_scale") ** 2), strand=pl.col("strand"), peak_name=pl.col("peak_name"), 
peak_id=pl.col("peak_id"), - motif_lambda = pl.col("lambda"), + motif_lambda=pl.col("lambda"), ) .sort(["chr_id", "start"]) .select(HITS_DTYPES.keys()) ) - data_unique = ( - data_all - .unique(subset=["chr", "start", "motif_name", "strand"], maintain_order=True) + data_unique = data_all.filter(pl.col("chr").is_not_null()).unique( + subset=["chr", "start", "motif_name", "strand"], maintain_order=True ) - data_bed = ( - data_unique - .select( - chr=pl.col("chr"), - start=pl.col("start"), - end=pl.col("end"), - motif_name=pl.col("motif_name"), - score=pl.lit(0), - strand=pl.col("strand") - ) + data_bed = data_unique.select( + chr=pl.col("chr"), + start=pl.col("start"), + end=pl.col("end"), + motif_name=pl.col("motif_name"), + score=pl.lit(0), + strand=pl.col("strand"), ) data_all.collect().write_csv(out_path_tsv, separator="\t") @@ -508,10 +1295,20 @@ def write_hits(hits_df, peaks_df, motifs_df, qc_df, out_dir, motif_width): data_bed.collect().write_csv(out_path_bed, include_header=False, separator="\t") -def write_qc(qc_df, peaks_df, out_path): +def write_qc(qc_df: pl.DataFrame, peaks_df: pl.DataFrame, out_path: str) -> None: + """Write quality control data with peak information to TSV file. + + Parameters + ---------- + qc_df : pl.DataFrame + Quality control metrics for each peak region. + peaks_df : pl.DataFrame + Peak region information for coordinate annotation. + out_path : str + Output path for the TSV file. + """ df = ( - qc_df - .lazy() + qc_df.lazy() .join(peaks_df.lazy(), on="peak_id", how="inner") .sort(["chr_id", "peak_region_start"]) .drop("chr_id") @@ -520,7 +1317,16 @@ def write_qc(qc_df, peaks_df, out_path): df.write_csv(out_path, separator="\t") -def write_motifs_df(motifs_df, out_path): +def write_motifs_df(motifs_df: pl.DataFrame, out_path: str) -> None: + """Write motif metadata to TSV file. + + Parameters + ---------- + motifs_df : pl.DataFrame + Motif metadata DataFrame. + out_path : str + Output path for the TSV file. 
+ """ motifs_df.write_csv(out_path, separator="\t") @@ -535,38 +1341,132 @@ def write_motifs_df(motifs_df, out_path): "lambda": pl.Float32, } -def load_motifs_df(motifs_path): + +def load_motifs_df(motifs_path: str) -> Tuple[pl.DataFrame, ndarray]: + """Load motif metadata from TSV file. + + Parameters + ---------- + motifs_path : str + Path to motif metadata TSV file. + + Returns + ------- + motifs_df : pl.DataFrame + Motif metadata with predefined schema. + motif_names : ndarray + Array of unique forward-strand motif names. + """ motifs_df = pl.read_csv(motifs_path, separator="\t", schema=MOTIF_DTYPES) - motif_names = motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy() + motif_names = ( + motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy() + ) return motifs_df, motif_names -def write_motif_cwms(cwms, out_path): +def write_motif_cwms(cwms: Float[ndarray, "M 4 W"], out_path: str) -> None: + """Write motif contribution weight matrices to .npy file. + + Parameters + ---------- + cwms : Float[ndarray, "M 4 W"] + Contribution weight matrices for M motifs, 4 nucleotides, W width. + out_path : str + Output path for the .npy file. + """ np.save(out_path, cwms) -def load_motif_cwms(cwms_path): +def load_motif_cwms(cwms_path: str) -> Float[ndarray, "M 4 W"]: + """Load motif contribution weight matrices from .npy file. + + Parameters + ---------- + cwms_path : str + Path to .npy file containing CWMs. + + Returns + ------- + Float[ndarray, "M 4 W"] + Loaded contribution weight matrices. + """ return np.load(cwms_path) -def write_params(params, out_path): +def write_params(params: Dict[str, Any], out_path: str) -> None: + """Write parameter dictionary to JSON file. + + Parameters + ---------- + params : Dict[str, Any] + Parameter dictionary to serialize. + out_path : str + Output path for the JSON file. 
+ """ with open(out_path, "w") as f: json.dump(params, f, indent=4) -def load_params(params_path): +def load_params(params_path: str) -> Dict[str, Any]: + """Load parameter dictionary from JSON file. + + Parameters + ---------- + params_path : str + Path to JSON file containing parameters. + + Returns + ------- + Dict[str, Any] + Loaded parameter dictionary. + """ with open(params_path) as f: params = json.load(f) return params -def write_occ_df(occ_df, out_path): +def write_occ_df(occ_df: pl.DataFrame, out_path: str) -> None: + """Write occurrence data to TSV file. + + Parameters + ---------- + occ_df : pl.DataFrame + Occurrence data DataFrame. + out_path : str + Output path for the TSV file. + """ occ_df.write_csv(out_path, separator="\t") -def write_report_data(report_df, cwms, out_dir): +def write_seqlet_confusion_df(seqlet_confusion_df: pl.DataFrame, out_path: str) -> None: + """Write seqlet confusion matrix data to TSV file. + + Parameters + ---------- + seqlet_confusion_df : pl.DataFrame + Seqlet confusion matrix DataFrame. + out_path : str + Output path for the TSV file. + """ + seqlet_confusion_df.write_csv(out_path, separator="\t") + + +def write_report_data( + report_df: pl.DataFrame, cwms: Dict[str, Dict[str, ndarray]], out_dir: str +) -> None: + """Write comprehensive motif report data including CWMs and metadata. + + Parameters + ---------- + report_df : pl.DataFrame + Report metadata DataFrame. + cwms : Dict[str, Dict[str, ndarray]] + Nested dictionary of motif names to CWM types to arrays. + out_dir : str + Output directory for report files. 
+ """ cwms_dir = os.path.join(out_dir, "CWMs") os.makedirs(cwms_dir, exist_ok=True) @@ -577,4 +1477,3 @@ def write_report_data(report_df, cwms, out_dir): np.savetxt(os.path.join(motif_dir, f"{cwm_type}.txt"), cwm) report_df.write_csv(os.path.join(out_dir, "motif_report.tsv"), separator="\t") - diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index 8fecafc..cf74375 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -1,50 +1,69 @@ -import os +"""Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance. + +This module provides functions for: +- Computing motif occurrence statistics and co-occurrence patterns +- Evaluating motif discovery quality against TF-MoDISco results +- Analyzing hit calling performance and recall metrics +- Generating confusion matrices for seqlet-hit comparisons +""" + import warnings -import importlib +from typing import List, Tuple, Dict, Any, Union import numpy as np +from numpy import ndarray import polars as pl -import matplotlib.pyplot as plt -from matplotlib.patheffects import AbstractPathEffect -from matplotlib.textpath import TextPath -from matplotlib.transforms import Affine2D -from matplotlib.font_manager import FontProperties -from jinja2 import Template - -from . import templates - - -def abbreviate_motif_name(name): - try: - group, motif = name.split(".") - - if group == "pos_patterns": - group_short = "+" - elif group == "neg_patterns": - group_short = "-" - else: - raise Exception - - motif_num = motif.split("_")[1] - - return f"{group_short}/{motif_num}" - - except: - return name - - -def get_motif_occurences(hits_df, motif_names): +from jaxtyping import Float, Int + + +def get_motif_occurences( + hits_df: pl.LazyFrame, motif_names: List[str] +) -> Tuple[pl.DataFrame, Int[ndarray, "M M"]]: + """Compute motif occurrence statistics and co-occurrence matrix. 
+ + This function analyzes motif occurrence patterns across peaks by creating + a pivot table of hit counts and computing pairwise co-occurrence statistics. + + Parameters + ---------- + hits_df : pl.LazyFrame + Lazy DataFrame containing hit data with required columns: + - peak_id : Peak identifier + - motif_name : Name of the motif + Additional columns are ignored. + motif_names : List[str] + List of motif names to include in analysis. Missing motifs + will be added as columns with zero counts. + + Returns + ------- + occ_df : pl.DataFrame + DataFrame with motif occurrence counts per peak. Contains: + - peak_id column + - One column per motif with hit counts + - 'total' column summing all motif counts per peak + coocc : Int[ndarray, "M M"] + Co-occurrence matrix where M = len(motif_names). + Entry (i,j) indicates number of peaks containing both motif i and motif j. + Diagonal entries show total peaks containing each motif. + + Notes + ----- + The co-occurrence matrix is computed using binary occurrence indicators, + so multiple hits of the same motif in a peak are treated as a single occurrence. 
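The co-occurrence computation described in the Notes reduces to a single matrix product on binarized counts. A small numpy sketch with made-up hit counts:

```python
import numpy as np

# Hypothetical occurrence matrix: 4 peaks x 3 motifs, entries are hit counts.
occ_mat = np.array([
    [2, 0, 1],
    [0, 1, 0],
    [1, 1, 0],
    [0, 0, 0],
])

# Binarize so repeated hits of a motif in one peak count once; then one
# matrix product yields all pairwise peak-level co-occurrence counts.
occ_bin = (occ_mat > 0).astype(np.int32)
coocc = occ_bin.T @ occ_bin

# Diagonal entries give the number of peaks containing each motif;
# off-diagonal entry (i, j) counts peaks containing both motif i and motif j.
```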
+ """ occ_df = ( - hits_df - .collect() - .pivot(index="peak_id", columns="motif_name", values="count", aggregate_function="sum") + hits_df.collect() + .with_columns(pl.lit(1).alias("count")) + .pivot( + on="motif_name", index="peak_id", values="count", aggregate_function="sum" + ) .fill_null(0) ) missing_cols = set(motif_names) - set(occ_df.columns) occ_df = ( - occ_df - .with_columns([pl.lit(0).alias(m) for m in missing_cols]) + occ_df.with_columns([pl.lit(0).alias(m) for m in missing_cols]) .with_columns(total=pl.sum_horizontal(*motif_names)) .sort(["peak_id"]) ) @@ -54,7 +73,7 @@ def get_motif_occurences(hits_df, motif_names): occ_mat = np.zeros((num_peaks, num_motifs), dtype=np.int16) for i, m in enumerate(motif_names): - occ_mat[:,i] = occ_df.get_column(m).to_numpy() + occ_mat[:, i] = occ_df.get_column(m).to_numpy() occ_bin = (occ_mat > 0).astype(np.int32) coocc = occ_bin.T @ occ_bin @@ -62,178 +81,250 @@ def get_motif_occurences(hits_df, motif_names): return occ_df, coocc -def plot_hit_distributions(occ_df, motif_names, plot_dir): - motifs_dir = os.path.join(plot_dir, "motif_hit_distributions") - os.makedirs(motifs_dir, exist_ok=True) - - for m in motif_names: - fig, ax = plt.subplots(figsize=(6, 2)) - - unique, counts = np.unique(occ_df.get_column(m), return_counts=True) - freq = counts / counts.sum() - num_bins = np.amax(unique, initial=0) + 1 - x = np.arange(num_bins) - y = np.zeros(num_bins) - y[unique] = freq - ax.bar(x, y) - - output_path = os.path.join(motifs_dir, f"{m}.png") - plt.savefig(output_path, dpi=300) - - plt.close(fig) - - fig, ax = plt.subplots(figsize=(8, 4)) - - unique, counts = np.unique(occ_df.get_column("total"), return_counts=True) - freq = counts / counts.sum() - num_bins = np.amax(unique, initial=0) + 1 - x = np.arange(num_bins) - y = np.zeros(num_bins) - y[unique] = freq - ax.bar(x, y) - - ax.set_xlabel("Motifs per peak") - ax.set_ylabel("Frequency") - - output_path = os.path.join(plot_dir, "total_hit_distribution.png") - 
plt.savefig(output_path, dpi=300) - - plt.close(fig) - - -def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_path): - """ - Plots a simple indicator heatmap of the motifs in each peak. +def get_cwms( + regions: Float[ndarray, "N 4 L"], positions_df: pl.DataFrame, motif_width: int +) -> Float[ndarray, "4 W"]: + """Extract contribution weight matrices from regions based on hit positions. + + This function extracts motif-sized windows from contribution score regions + at positions specified by hit coordinates. It handles both forward and + reverse complement orientations and filters out invalid positions. + + Parameters + ---------- + regions : Float[ndarray, "N 4 L"] + Input contribution score regions multiplied by one-hot sequences. + Shape: (n_peaks, 4, region_width) where 4 represents DNA bases (A,C,G,T). + positions_df : pl.DataFrame + DataFrame containing hit positions with required columns: + - peak_id : int, Peak index (0-based) + - start_untrimmed : int, Start position in genomic coordinates + - peak_region_start : int, Peak region start coordinate + - is_revcomp : bool, Whether hit is on reverse complement strand + motif_width : int + Width of motifs to extract. Must be positive. + + Returns + ------- + cwms : Float[ndarray, "4 W"] + Contribution weight matrix averaged over all valid hits. + Shape: (4, motif_width) + Invalid hits (outside region boundaries) are excluded from the average. + + Notes + ----- + - Start positions are converted from genomic to region-relative coordinates + - Reverse complement hits have their sequence order reversed + - Hits extending beyond region boundaries are excluded + - The mean is computed across all valid hits, with warnings suppressed + for empty slices or invalid operations + + Raises + ------ + ValueError + If motif_width is non-positive or positions_df lacks required columns.
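The strand-aware window extraction used by `get_cwms` amounts to a single vectorized gather built from broadcastable index arrays. A self-contained numpy sketch with toy shapes (the positions are made up):

```python
import numpy as np

# Toy data: 2 regions, 4 nucleotide channels, length 8 (values are random).
rng = np.random.default_rng(0)
regions = rng.normal(size=(2, 4, 8))
motif_width = 3

# One forward hit in region 0 at offset 2, one reverse-complement hit
# in region 1 at offset 4 (hypothetical positions).
peak_idx = np.array([0, 1])
start_idx = np.array([2, 4])
is_revcomp = np.array([False, True])

# Broadcastable index arrays: reverse-complement hits flip both the
# position axis and the nucleotide axis (A<->T, C<->G).
row_idx = peak_idx[:, None, None]
pos_idx = start_idx[:, None, None] + np.zeros((1, 1, motif_width), dtype=int)
pos_idx[~is_revcomp, :, :] += np.arange(motif_width)[None, None, :]
pos_idx[is_revcomp, :, :] += np.arange(motif_width)[None, None, ::-1]
nuc_idx = np.zeros((peak_idx.shape[0], 4, 1), dtype=int)
nuc_idx[~is_revcomp, :, :] += np.arange(4)[None, :, None]
nuc_idx[is_revcomp, :, :] += np.arange(4)[None, ::-1, None]

# A single fancy-indexing call gathers every window at once.
windows = regions[row_idx, nuc_idx, pos_idx]
cwm = windows.mean(axis=0)  # averaged CWM of shape (4, motif_width)
```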
""" - cov_norm = 1 / np.sqrt(np.diag(peak_hit_counts)) - matrix = peak_hit_counts * cov_norm[:,None] * cov_norm[None,:] - motif_keys = [abbreviate_motif_name(m) for m in motif_names] - - fig, ax = plt.subplots(figsize=(8, 8)) - - # Plot the heatmap - ax.imshow(matrix, interpolation="nearest", aspect="auto", cmap="Greens") - - # Set axes on heatmap - ax.set_yticks(np.arange(len(motif_keys))) - ax.set_yticklabels(motif_keys) - ax.set_xticks(np.arange(len(motif_keys))) - ax.set_xticklabels(motif_keys, rotation=90) - ax.set_xlabel("Motif i") - ax.set_ylabel("Motif j") - - plt.savefig(output_path, dpi=300) - - plt.close() - - -def get_cwms(regions, positions_df, motif_width): - idx_df = ( - positions_df - .select( - peak_idx=pl.col("peak_id"), - start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"), - is_revcomp=pl.col("is_revcomp") - ) + idx_df = positions_df.select( + peak_idx=pl.col("peak_id"), + start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"), + is_revcomp=pl.col("is_revcomp"), ) - peak_idx = idx_df.get_column('peak_idx').to_numpy() - start_idx = idx_df.get_column('start_idx').to_numpy() + peak_idx = idx_df.get_column("peak_idx").to_numpy() + start_idx = idx_df.get_column("start_idx").to_numpy() is_revcomp = idx_df.get_column("is_revcomp").to_numpy().astype(bool) - # Ignore hits outside of region + # Filter hits that fall outside the region boundaries valid_mask = (start_idx >= 0) & (start_idx + motif_width <= regions.shape[2]) peak_idx = peak_idx[valid_mask] start_idx = start_idx[valid_mask] is_revcomp = is_revcomp[valid_mask] - row_idx = peak_idx[:,None,None] - pos_idx = start_idx[:,None,None] + np.zeros((1,1,motif_width), dtype=int) - pos_idx[~is_revcomp,:,:] += np.arange(motif_width)[None,None,:] - pos_idx[is_revcomp,:,:] += np.arange(motif_width)[None,None,::-1] - nuc_idx = np.zeros((peak_idx.shape[0],4,1), dtype=int) - nuc_idx[~is_revcomp,:,:] += np.arange(4)[None,:,None] - nuc_idx[is_revcomp,:,:] += np.arange(4)[None,::-1,None] + 
row_idx = peak_idx[:, None, None] + pos_idx = start_idx[:, None, None] + np.zeros((1, 1, motif_width), dtype=int) + pos_idx[~is_revcomp, :, :] += np.arange(motif_width)[None, None, :] + pos_idx[is_revcomp, :, :] += np.arange(motif_width)[None, None, ::-1] + nuc_idx = np.zeros((peak_idx.shape[0], 4, 1), dtype=int) + nuc_idx[~is_revcomp, :, :] += np.arange(4)[None, :, None] + nuc_idx[is_revcomp, :, :] += np.arange(4)[None, ::-1, None] seqs = regions[row_idx, nuc_idx, pos_idx] - + with warnings.catch_warnings(): - warnings.filterwarnings(action='ignore', message='invalid value encountered in divide') - warnings.filterwarnings(action='ignore', message='Mean of empty slice') + warnings.filterwarnings( + action="ignore", message="invalid value encountered in divide" + ) + warnings.filterwarnings(action="ignore", message="Mean of empty slice") cwms = seqs.mean(axis=0) return cwms -def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms_modisco, - motif_names, modisco_half_width, motif_width, compute_recall): +def tfmodisco_comparison( + regions: Float[ndarray, "N 4 L"], + hits_df: Union[pl.DataFrame, pl.LazyFrame], + peaks_df: pl.DataFrame, + seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None], + motifs_df: pl.DataFrame, + cwms_modisco: Float[ndarray, "M 4 W"], + motif_names: List[str], + modisco_half_width: int, + motif_width: int, + compute_recall: bool, +) -> Tuple[ + Dict[str, Dict[str, Any]], + pl.DataFrame, + Dict[str, Dict[str, Float[ndarray, "4 W"]]], + Dict[str, Dict[str, Tuple[int, int]]], +]: + """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics. + + This function performs comprehensive comparison between Fi-NeMo hit calls + and TF-MoDISco seqlets, computing recall metrics, CWM similarities, + and extracting contribution weight matrices for visualization. + + Parameters + ---------- + regions : Float[ndarray, "N 4 L"] + Contribution score regions multiplied by one-hot sequences. 
+        Shape: (n_peaks, 4, region_length) +    hits_df : Union[pl.DataFrame, pl.LazyFrame] +        Fi-NeMo hit calls with required columns: +        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name +    peaks_df : pl.DataFrame +        Peak metadata with columns: +        - peak_id, chr_id, peak_region_start +    seqlets_df : Optional[Union[pl.DataFrame, pl.LazyFrame]] +        TF-MoDISco seqlets with columns: +        - chr_id, start_untrimmed, is_revcomp, motif_name +        If None, only basic hit statistics are computed. +    motifs_df : pl.DataFrame +        Motif metadata with columns: +        - motif_name, strand, motif_id, motif_start, motif_end +    cwms_modisco : Float[ndarray, "M 4 W"] +        TF-MoDISco contribution weight matrices. +        Shape: (n_modisco_motifs, 4, motif_width) +    motif_names : List[str] +        Names of motifs to analyze. +    modisco_half_width : int +        Half-width for restricting hits to central region for fair comparison. +    motif_width : int +        Width of motifs for CWM extraction. +    compute_recall : bool +        Whether to compute recall metrics requiring seqlets_df. + +    Returns +    ------- +    report_data : Dict[str, Dict[str, Any]] +        Per-motif evaluation metrics including: +        - num_hits_total, num_hits_restricted, num_seqlets +        - num_overlaps, seqlet_recall, cwm_similarity +    report_df : pl.DataFrame +        Tabular format of report_data for easy analysis. +    cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]] +        Extracted CWMs for each motif and condition: +        - hits_fc, hits_rc: Forward/reverse complement hits +        - modisco_fc, modisco_rc: TF-MoDISco forward/reverse +        - seqlets_only, hits_restricted_only: Non-overlapping instances +    cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]] +        Trimming boundaries for each CWM type and motif.
+ + Notes + ----- + - Hits are filtered to central region defined by modisco_half_width + - CWM similarity is computed as normalized dot product between hit and TF-MoDISco CWMs + - Recall metrics require both hits_df and seqlets_df to be non-empty + - Missing motifs are handled gracefully with empty DataFrames + + Raises + ------ + ValueError + If required columns are missing from input DataFrames. + """ + + # Ensure hits_df is LazyFrame for consistent operations + if isinstance(hits_df, pl.DataFrame): + hits_df = hits_df.lazy() + hits_df = ( - hits_df - .with_columns(pl.col('peak_id').cast(pl.UInt32)) - .join( - peaks_df.lazy(), on="peak_id", how="inner" - ) + hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32)) + .join(peaks_df.lazy(), on="peak_id", how="inner") .select( chr_id=pl.col("chr_id"), start_untrimmed=pl.col("start_untrimmed"), end_untrimmed=pl.col("end_untrimmed"), - is_revcomp=pl.col("strand") == '-', + is_revcomp=pl.col("strand") == "-", motif_name=pl.col("motif_name"), peak_region_start=pl.col("peak_region_start"), - peak_id=pl.col("peak_id") + peak_id=pl.col("peak_id"), ) ) - hits_unique = hits_df.unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) - - region_len = regions.shape[2] - center = region_len / 2 - hits_filtered = ( - hits_df - .filter( - ((pl.col("start_untrimmed") - pl.col("peak_region_start")) >= (center - modisco_half_width)) - & ((pl.col("end_untrimmed") - pl.col("peak_region_start")) <= (center + modisco_half_width)) - ) - .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) + hits_unique = hits_df.unique( + subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"] ) - - if compute_recall: - overlaps_df = ( - hits_filtered.join( - seqlets_df, - on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], - how="inner", - ) - .collect() - ) - seqlets_only_df = ( - seqlets_df.join( - hits_df, - on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], - how="anti", - ) - 
.collect() + region_len = regions.shape[2] + center = region_len / 2 + hits_filtered = hits_df.filter( + ( + (pl.col("start_untrimmed") - pl.col("peak_region_start")) + >= (center - modisco_half_width) ) - - hits_only_filtered_df = ( - hits_filtered.join( - seqlets_df, - on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], - how="anti", - ) - .collect() + & ( + (pl.col("end_untrimmed") - pl.col("peak_region_start")) + <= (center + modisco_half_width) ) + ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True) - hits_fitered_by_motif = hits_filtered.collect().partition_by("motif_name", as_dict=True) - - if seqlets_df is not None: - seqlets_by_motif = seqlets_df.collect().partition_by("motif_name", as_dict=True) + hits_filtered_by_motif = hits_filtered.collect().partition_by( + "motif_name", as_dict=True + ) - if compute_recall: + if seqlets_df is None: + seqlets_collected = None + seqlets_lazy = None + elif isinstance(seqlets_df, pl.LazyFrame): + seqlets_collected = seqlets_df.collect() + seqlets_lazy = seqlets_df + else: + seqlets_collected = seqlets_df + seqlets_lazy = seqlets_df.lazy() + + if seqlets_collected is not None: + seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True) + else: + seqlets_by_motif = {} + + if compute_recall and seqlets_lazy is not None: + overlaps_df = hits_filtered.join( + seqlets_lazy, + on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], + how="inner", + ).collect() + + seqlets_only_df = seqlets_lazy.join( + hits_df, + on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], + how="anti", + ).collect() + + hits_only_filtered_df = hits_filtered.join( + seqlets_lazy, + on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], + how="anti", + ).collect() + + # Create partition dictionaries overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True) seqlets_only_by_motif = 
seqlets_only_df.partition_by("motif_name", as_dict=True) - hits_only_filtered_by_motif = hits_only_filtered_df.partition_by("motif_name", as_dict=True) + hits_only_filtered_by_motif = hits_only_filtered_df.partition_by( + "motif_name", as_dict=True + ) + else: + overlaps_by_motif = {} + seqlets_only_by_motif = {} + hits_only_filtered_by_motif = {} report_data = {} cwms = {} @@ -241,12 +332,18 @@ def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms dummy_df = hits_df.clear().collect() for m in motif_names: hits = hits_by_motif.get((m,), dummy_df) - hits_filtered = hits_fitered_by_motif.get((m,), dummy_df) + hits_filtered = hits_filtered_by_motif.get((m,), dummy_df) + + # Initialize default values + seqlets = dummy_df + overlaps = dummy_df + seqlets_only = dummy_df + hits_only_filtered = dummy_df if seqlets_df is not None: seqlets = seqlets_by_motif.get((m,), dummy_df) - if compute_recall: + if compute_recall and seqlets_df is not None: overlaps = overlaps_by_motif.get((m,), dummy_df) seqlets_only = seqlets_only_by_motif.get((m,), dummy_df) hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df) @@ -259,48 +356,56 @@ def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms if seqlets_df is not None: report_data[m]["num_seqlets"] = seqlets.height - if compute_recall: + if compute_recall and seqlets_df is not None: report_data[m] |= { "num_overlaps": overlaps.height, "num_seqlets_only": seqlets_only.height, "num_hits_restricted_only": hits_only_filtered.height, "seqlet_recall": np.float64(overlaps.height) / seqlets.height + if seqlets.height > 0 + else 0.0, } - motif_data_fc = motifs_df.row(by_predicate=(pl.col("motif_name") == m) - & (pl.col("strand") == "+"), named=True) - motif_data_rc = motifs_df.row(by_predicate=(pl.col("motif_name") == m) - & (pl.col("strand") == "-"), named=True) + motif_data_fc = motifs_df.row( + by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"), + named=True, + 
) + motif_data_rc = motifs_df.row( + by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"), + named=True, + ) cwms[m] = { "hits_fc": get_cwms(regions, hits, motif_width), "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]], "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]], } - cwms[m]["hits_rc"] = cwms[m]["hits_fc"][::-1,::-1] + cwms[m]["hits_rc"] = cwms[m]["hits_fc"][::-1, ::-1] - if compute_recall: + if compute_recall and seqlets_df is not None: cwms[m] |= { "seqlets_only": get_cwms(regions, seqlets_only, motif_width), - "hits_restricted_only": get_cwms(regions, hits_only_filtered, motif_width), + "hits_restricted_only": get_cwms( + regions, hits_only_filtered, motif_width + ), } bounds_fc = (motif_data_fc["motif_start"], motif_data_fc["motif_end"]) bounds_rc = (motif_data_rc["motif_start"], motif_data_rc["motif_end"]) - + cwm_trim_bounds[m] = { "hits_fc": bounds_fc, "modisco_fc": bounds_fc, "modisco_rc": bounds_rc, - "hits_rc": bounds_rc + "hits_rc": bounds_rc, } - if compute_recall: + if compute_recall and seqlets_df is not None: cwm_trim_bounds[m] |= { "seqlets_only": bounds_fc, "hits_restricted_only": bounds_fc, } - + hits_cwm = cwms[m]["hits_fc"] modisco_cwm = cwms[m]["modisco_fc"] hnorm = np.sqrt((hits_cwm**2).sum()) @@ -315,132 +420,127 @@ def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms return report_data, report_df, cwms, cwm_trim_bounds -class LogoGlyph(AbstractPathEffect): - def __init__(self, glyph, ref_glyph='E', font_props=None, - offset=(0., 0.), **kwargs): - - super().__init__(offset) +def seqlet_confusion( + hits_df: Union[pl.DataFrame, pl.LazyFrame], + seqlets_df: Union[pl.DataFrame, pl.LazyFrame], + peaks_df: pl.DataFrame, + motif_names: List[str], + motif_width: int, +) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]: + """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits. 
+ +    This function creates a confusion matrix showing the overlap between +    TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs. +    Overlap frequencies are estimated using binned genomic coordinates. + +    Parameters +    ---------- +    hits_df : Union[pl.DataFrame, pl.LazyFrame] +        Fi-NeMo hit calls with required columns: +        - peak_id, start_untrimmed, end_untrimmed, strand, motif_name +    seqlets_df : Union[pl.DataFrame, pl.LazyFrame] +        TF-MoDISco seqlets with required columns: +        - chr_id, start_untrimmed, end_untrimmed, motif_name +    peaks_df : pl.DataFrame +        Peak metadata for joining coordinates: +        - peak_id, chr_id +    motif_names : List[str] +        Names of motifs to include in confusion matrix. +        Determines matrix dimensions. +    motif_width : int +        Width used for binning genomic coordinates. +        Positions are binned to motif_width resolution. + +    Returns +    ------- +    confusion_df : pl.DataFrame +        Detailed confusion matrix in tabular format with columns: +        - motif_name_seqlets : Seqlet motif labels (rows) +        - motif_name_hits : Hit motif labels (columns) +        - frac_overlap : Fraction of seqlets overlapping with hits +    confusion_mat : Float[ndarray, "M M"] +        Confusion matrix where M = len(motif_names). +        Entry (i,j) = fraction of motif i seqlets overlapping with motif j hits. +        Rows represent seqlet motifs, columns represent hit motifs. + +    Notes +    ----- +    - Genomic coordinates are binned to motif_width resolution for overlap detection +    - Only exact bin overlaps are considered (same chr_id, start_bin, end_bin) +    - Fractions are computed as: overlaps / total_seqlets_per_motif +    - Missing motif combinations result in zero entries in the confusion matrix + +    Raises +    ------ +    ValueError +        If required columns are missing from input DataFrames. +    KeyError +        If motif names in data don't match those in motif_names list.
+ """ + bin_size = motif_width - path_orig = TextPath((0, 0), glyph, size=1, prop=font_props) - dims = path_orig.get_extents() - ref_dims = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents() + # Ensure hits_df is LazyFrame for consistent operations + if isinstance(hits_df, pl.DataFrame): + hits_df = hits_df.lazy() - h_scale = 1 / dims.height - ref_width = max(dims.width, ref_dims.width) - w_scale = 1 / ref_width - w_shift = (1 - dims.width / ref_width) / 2 - x_shift = -dims.x0 - y_shift = -dims.y0 - stretch = ( - Affine2D() - .translate(tx=x_shift, ty=y_shift) - .scale(sx=w_scale, sy=h_scale) - .translate(tx=w_shift, ty=0) + hits_binned = ( + hits_df.with_columns( + peak_id=pl.col("peak_id").cast(pl.UInt32), + is_revcomp=pl.col("strand") == "-", ) + .join(peaks_df.lazy(), on="peak_id", how="inner") + .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) + .select( + chr_id=pl.col("chr_id"), + start_bin=pl.col("start_untrimmed") // bin_size, + end_bin=pl.col("end_untrimmed") // bin_size, + motif_name=pl.col("motif_name"), + ) + ) - self.path = stretch.transform_path(path_orig) - - #: The dictionary of keywords to update the graphics collection with. 
- self._gc = kwargs - - def draw_path(self, renderer, gc, tpath, affine, rgbFace): - return renderer.draw_path(gc, self.path, affine, rgbFace) - + seqlets_lazy = seqlets_df.lazy() + seqlets_binned = seqlets_lazy.select( + chr_id=pl.col("chr_id"), + start_bin=pl.col("start_untrimmed") // bin_size, + end_bin=pl.col("end_untrimmed") // bin_size, + motif_name=pl.col("motif_name"), + ) -def plot_logo(ax, heights, glyphs, colors=None, font_props=None, shade_bounds=None): - if colors is None: - colors = {g: None for g in glyphs} + overlaps_df = seqlets_binned.join( + hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits" + ) - ax.margins(x=0, y=0) - - pos_values = np.clip(heights, 0, None) - neg_values = np.clip(heights, None, 0) - pos_order = np.argsort(pos_values, axis=0) - neg_order = np.argsort(neg_values, axis=0)[::-1,:] - pos_reorder = np.argsort(pos_order, axis=0) - neg_reorder = np.argsort(neg_order, axis=0) - pos_offsets = np.take_along_axis( - np.cumsum( - np.take_along_axis(pos_values, pos_order, axis=0), axis=0 - ), pos_reorder, axis=0 + seqlet_counts = ( + seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect() ) - neg_offsets = np.take_along_axis( - np.cumsum( - np.take_along_axis(neg_values, neg_order, axis=0), axis=0 - ), neg_reorder, axis=0 + overlap_counts = ( + overlaps_df.group_by(["motif_name", "motif_name_hits"]) + .len(name="num_overlaps") + .collect() ) - bottoms = pos_offsets + neg_offsets - heights - - x = np.arange(heights.shape[1]) - - for glyph, height, bottom in zip(glyphs, heights, bottoms): - ax.bar(x, height, 0.95, bottom=bottom, - path_effects=[LogoGlyph(glyph, font_props=font_props)], color=colors[glyph]) - - if shade_bounds is not None: - start, end = shade_bounds - ax.axvspan(start - 0.5, end - 0.5, color='0.9', zorder=-1) - - ax.axhline(zorder=-1, linewidth=0.5, color='black') - - -LOGO_ALPHABET = 'ACGT' -LOGO_COLORS = {"A": '#109648', "C": '#255C99', "G": '#F7B32B', "T": '#D62839'} -LOGO_FONT 
= FontProperties(weight="bold") - -def plot_cwms(cwms, trim_bounds, out_dir, alphabet=LOGO_ALPHABET, colors=LOGO_COLORS, font=LOGO_FONT): - for m, v in cwms.items(): - motif_dir = os.path.join(out_dir, m) - os.makedirs(motif_dir, exist_ok=True) - for cwm_type, cwm in v.items(): - output_path = os.path.join(motif_dir, f"{cwm_type}.png") - - fig, ax = plt.subplots(figsize=(10,2)) - plot_logo(ax, cwm, alphabet, colors=colors, font_props=font, shade_bounds=trim_bounds[m][cwm_type]) - - for name, spine in ax.spines.items(): - spine.set_visible(False) - - plt.savefig(output_path, dpi=100) - plt.close(fig) - - -def plot_hit_vs_seqlet_counts(recall_data, output_path): - x = [] - y = [] - m = [] - for k, v in recall_data.items(): - x.append(v["num_hits_total"]) - y.append(v["num_seqlets"]) - m.append(k) - - lim = max(np.amax(x), np.amax(y)) - - fig, ax = plt.subplots(figsize=(8,8)) - ax.axline((0, 0), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5))) - ax.scatter(x, y, s=5) - for i, txt in enumerate(m): - short = abbreviate_motif_name(txt) - ax.annotate(short, (x[i], y[i]), fontsize=8, weight="bold") - - ax.set_yscale('log') - ax.set_xscale('log') - - ax.set_xlabel("Hits per motif") - ax.set_ylabel("Seqlets per motif") - - plt.savefig(output_path, dpi=300) - plt.close() + num_motifs = len(motif_names) + confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32) + name_to_idx = {m: i for i, m in enumerate(motif_names)} + + confusion_df = overlap_counts.join( + seqlet_counts, on="motif_name", how="inner" + ).select( + motif_name_seqlets=pl.col("motif_name"), + motif_name_hits=pl.col("motif_name_hits"), + frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"), + ) + confusion_idx_df = confusion_df.select( + row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx), + col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx), + frac_overlap=pl.col("frac_overlap"), + ) -def write_report(report_df, motif_names, out_path, compute_recall, 
use_seqlets): - template_str = importlib.resources.files(templates).joinpath('report.html').read_text() - template = Template(template_str) - report = template.render(report_data=report_df.iter_rows(named=True), - motif_names=motif_names, compute_recall=compute_recall, - use_seqlets=use_seqlets) - with open(out_path, "w") as f: - f.write(report) + row_idx = confusion_idx_df["row_idx"].to_numpy() + col_idx = confusion_idx_df["col_idx"].to_numpy() + frac_overlap = confusion_idx_df["frac_overlap"].to_numpy() + confusion_mat[row_idx, col_idx] = frac_overlap + return confusion_df, confusion_mat diff --git a/src/finemo/hitcaller.py b/src/finemo/hitcaller.py index ee0e666..27cea8c 100644 --- a/src/finemo/hitcaller.py +++ b/src/finemo/hitcaller.py @@ -1,173 +1,554 @@ +"""Hit caller module implementing the Fi-NeMo motif instance calling algorithm. + +This module provides the core functionality for identifying transcription factor +binding motif instances in neural network contribution scores using a competitive +optimization approach based on proximal gradient descent. + +The main algorithm fits a sparse linear model where contribution scores are +reconstructed as a weighted combination of motif contribution weight matrices (CWMs) +at specific genomic positions. The sparsity constraint ensures that only the most +significant motif instances are called. 
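+
+A typical entry point is ``fit_contribs`` (a sketch; in practice the CLI in
+``finemo.main`` assembles these arguments from preprocessed inputs):
+
+    from finemo.hitcaller import fit_contribs
+
+    hits_df, qc_df = fit_contribs(
+        cwms, contribs, sequences, cwm_trim_mask,
+        use_hypothetical=False, lambdas=lambdas,
+    )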
+""" + import warnings +from typing import Tuple, Union, Optional, Dict, List +from abc import ABC, abstractmethod import numpy as np +from numpy import ndarray import torch import torch.nn.functional as F +from torch import Tensor import polars as pl +from jaxtyping import Float, Int, Bool from tqdm import tqdm -def prox_grad_step(coefficients, importance_scale, cwms, contribs, sequences, - lambdas, step_sizes): - """ - Proximal gradient descent optimization step for non-negative lasso - - coefficients: (b, m, l - w + 1) - importance_scale: (b, 1, l - w + 1) - cwms: (m, 4, w) - contribs: (b, 4, l) - sequences: (b, 4, l) or dummy scalar - lambdas: (1, m, 1) - - For details on proximal gradient descent: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slide 22 - For details on duality gap computation: https://stanford.edu/~boyd/papers/pdf/l1_ls.pdf, Section III +def prox_grad_step( + coefficients: Float[Tensor, "B M P"], + importance_scale: Float[Tensor, "B 1 P"], + cwms: Float[Tensor, "M 4 W"], + contribs: Float[Tensor, "B 4 L"], + sequences: Union[Int[Tensor, "B 4 L"], int], + lambdas: Float[Tensor, "1 M 1"], + step_sizes: Float[Tensor, "B 1 1"], +) -> Tuple[Float[Tensor, "B M P"], Float[Tensor, " B"], Float[Tensor, " B"]]: + """Perform a proximal gradient descent optimization step for non-negative lasso. + + This function implements a single optimization step of the Fi-NeMo algorithm, + which uses proximal gradient descent to solve a sparse reconstruction problem. + The goal is to represent contribution scores as a sparse linear combination + of motif contribution weight matrices (CWMs). 
+
+    Dimension notation:
+    - B = batch size (number of regions processed simultaneously)
+    - M = number of motifs
+    - L = sequence length
+    - W = motif width (length of each motif)
+    - P = L - W + 1 (number of valid motif positions)
+
+    Parameters
+    ----------
+    coefficients : Float[Tensor, "B M P"]
+        Current coefficient matrix representing motif instance strengths.
+    importance_scale : Float[Tensor, "B 1 P"]
+        Scaling factors for importance-weighted reconstruction.
+    cwms : Float[Tensor, "M 4 W"]
+        Motif contribution weight matrices for all motifs.
+        4 represents the DNA bases (A, C, G, T).
+    contribs : Float[Tensor, "B 4 L"]
+        Target contribution scores to reconstruct.
+    sequences : Union[Int[Tensor, "B 4 L"], int]
+        One-hot encoded DNA sequences. Can be a scalar (1) for hypothetical mode.
+    lambdas : Float[Tensor, "1 M 1"]
+        L1 regularization weights for each motif.
+    step_sizes : Float[Tensor, "B 1 1"]
+        Optimization step sizes for each batch element.
+
+    Returns
+    -------
+    c_next : Float[Tensor, "B M P"]
+        Updated coefficient matrix after the optimization step (shape: batch_size × motifs × positions).
+    dual_gap : Float[Tensor, " B"]
+        Duality gap for convergence assessment (shape: batch_size).
+    nll : Float[Tensor, " B"]
+        Negative log likelihood (proportional to MSE, shape: batch_size).
+ + Notes + ----- + The algorithm uses proximal gradient descent to solve: + + minimize_c: ||contribs - conv_transpose(c * importance_scale, cwms) * sequences||²₂ + λ||c||₁ + + subject to: c ≥ 0 + + References + ---------- + - Proximal gradient descent: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slide 22 + - Duality gap computation: https://stanford.edu/~boyd/papers/pdf/l1_ls.pdf, Section III """ - # Forward pass + # Forward pass: convolution operations require specific tensor layouts coef_adj = coefficients * importance_scale - pred_unmasked = F.conv_transpose1d(coef_adj, cwms) # (b, 4, l) - pred = pred_unmasked * sequences # (b, 4, l) + pred_unmasked = F.conv_transpose1d(coef_adj, cwms) # (B, 4, L) + pred = ( + pred_unmasked * sequences + ) # (B, 4, L), element-wise masking for projected mode # Compute gradient * -1 - residuals = contribs - pred # (b, 4, l) - ngrad = F.conv1d(residuals, cwms) * importance_scale # (b, m, l - w + 1) + residuals = contribs - pred # (B, 4, L) + ngrad = F.conv1d(residuals, cwms) * importance_scale # (B, M, P) # Negative log likelihood (proportional to MSE) - nll = (residuals**2).sum(dim=(1,2)) # (b) - - # Compute duality gap - dual_norm = (ngrad / lambdas).amax(dim=(1,2)) # (b) - dual_scale = (torch.clamp(1 / dual_norm, max=1.)**2 + 1) / 2 # (b) - nll_scaled = nll * dual_scale # (b) - - dual_diff = (residuals * contribs).sum(dim=(1,2)) # (b) - l1_term = (torch.abs(coefficients).sum(dim=2, keepdim=True) * lambdas).sum(dim=(1,2)) # (b) - # l1_term = torch.linalg.vector_norm((coefficients * lambdas), ord=1, dim=(1,2)) # (b) - dual_gap = (nll_scaled - dual_diff + l1_term).abs() # (b) + nll = (residuals**2).sum(dim=(1, 2)) # (B) - # Compute proximal gradient descent step - c_next = coefficients + step_sizes * (ngrad - lambdas) # (b, m, l - w + 1) - c_next = F.relu(c_next) # (b, m, l - w + 1) + # Compute duality gap for convergence assessment + dual_norm = (ngrad / lambdas).amax(dim=(1, 2)) # 
(B) + dual_scale = (torch.clamp(1 / dual_norm, max=1.0) ** 2 + 1) / 2 # (B) + nll_scaled = nll * dual_scale # (B) - return c_next, dual_gap, nll + dual_diff = (residuals * contribs).sum(dim=(1, 2)) # (B) + l1_term = (torch.abs(coefficients).sum(dim=2, keepdim=True) * lambdas).sum( + dim=(1, 2) + ) # (B) + dual_gap = (nll_scaled - dual_diff + l1_term).abs() # (B) + # Compute proximal gradient descent step + c_next = coefficients + step_sizes * (ngrad - lambdas) # (B, M, P) + c_next = F.relu(c_next) # Ensure non-negativity constraint -def optimizer_step(cwms, contribs, importance_scale, sequences, coef_inter, coef, i, step_sizes, l, lambdas): - """ - Non-negative lasso optimizer step with momentum. + return c_next, dual_gap, nll - cwms: (m, 4, w) - contribs: (b, 4, l) - importance_scale: (b, 1, l - w + 1) - sequences: (b, 4, l) or dummy scalar - coef_inter, coef: (b, m, l - w + 1) - i, step_sizes: (b,) - For details on optimization algorithm: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slides 22, 27 +def optimizer_step( + cwms: Float[Tensor, "M 4 W"], + contribs: Float[Tensor, "B 4 L"], + importance_scale: Float[Tensor, "B 1 P"], + sequences: Union[Int[Tensor, "B 4 L"], int], + coef_inter: Float[Tensor, "B M P"], + coef: Float[Tensor, "B M P"], + i: Float[Tensor, "B 1 1"], + step_sizes: Float[Tensor, "B 1 1"], + L: int, + lambdas: Float[Tensor, "1 M 1"], +) -> Tuple[ + Float[Tensor, "B M P"], + Float[Tensor, "B M P"], + Float[Tensor, " B"], + Float[Tensor, " B"], +]: + """Perform a non-negative lasso optimizer step with Nesterov momentum. + + This function combines proximal gradient descent with momentum acceleration + to improve convergence speed while maintaining the non-negative constraint + on coefficients. 
+ + Dimension notation: + - B = batch size (number of regions processed simultaneously) + - M = number of motifs + - L = sequence length + - W = motif width (length of each motif) + - P = L - W + 1 (number of valid motif positions) + + Parameters + ---------- + cwms : Float[Tensor, "M 4 W"] + Motif contribution weight matrices. + contribs : Float[Tensor, "B 4 L"] + Target contribution scores. + importance_scale : Float[Tensor, "B 1 P"] + Importance scaling factors. + sequences : Union[Int[Tensor, "B 4 L"], int] + One-hot encoded sequences or scalar for hypothetical mode. + coef_inter : Float[Tensor, "B M P"] + Intermediate coefficient matrix (with momentum). + coef : Float[Tensor, "B M P"] + Current coefficient matrix. + i : Float[Tensor, "B 1 1"] + Iteration counter for each batch element. + step_sizes : Float[Tensor, "B 1 1"] + Step sizes for optimization. + L : int + Sequence length for normalization. + lambdas : Float[Tensor, "1 M 1"] + Regularization parameters. + + Returns + ------- + coef_inter : Float[Tensor, "B M P"] + Updated intermediate coefficients with momentum (shape: batch_size × motifs × positions). + coef : Float[Tensor, "B M P"] + Updated coefficient matrix (shape: batch_size × motifs × positions). + gap : Float[Tensor, " B"] + Normalized duality gap (shape: batch_size). + nll : Float[Tensor, " B"] + Normalized negative log likelihood (shape: batch_size). + + Notes + ----- + Uses Nesterov momentum with momentum coefficient i/(i+3) for improved + convergence properties. The duality gap and NLL are normalized by + sequence length for scale-invariant convergence assessment. 
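+
+    Written as an explicit update (restating the code below), with
+    ``mom = i / (i + 3)``:
+
+        coef_inter_new = (1 + mom) * coef_new - mom * coef_prev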
+ + References + ---------- + https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slides 22, 27 """ coef_prev = coef # Proximal gradient descent step - coef, gap, nll = prox_grad_step(coef_inter, importance_scale, cwms, contribs, sequences, - lambdas, step_sizes) - gap = gap / l - nll = nll / (2 * l) - - # Compute updated coefficients - mom_term = i / (i + 3.) + coef, gap, nll = prox_grad_step( + coef_inter, importance_scale, cwms, contribs, sequences, lambdas, step_sizes + ) + gap = gap / L + nll = nll / (2 * L) + + # Compute updated coefficients with Nesterov momentum + mom_term = i / (i + 3.0) coef_inter = (1 + mom_term) * coef - mom_term * coef_prev return coef_inter, coef, gap, nll -def _to_channel_last_layout(tensor, **kwargs): - return tensor[:,:,:,None].to(memory_format=torch.channels_last, **kwargs).squeeze(3) +def _to_channel_last_layout(tensor: Tensor, **kwargs) -> torch.Tensor: + """Convert tensor to channel-last memory layout for optimized convolution operations. + Parameters + ---------- + tensor : torch.Tensor + Input tensor to convert. + **kwargs + Additional keyword arguments passed to tensor.to(). -def _signed_sqrt(x): + Returns + ------- + torch.Tensor + Tensor with channel-last memory layout. + """ + return ( + tensor[:, :, :, None].to(memory_format=torch.channels_last, **kwargs).squeeze(3) + ) + + +def _signed_sqrt(x: torch.Tensor) -> torch.Tensor: + """Apply signed square root transformation to input tensor. + + This transformation preserves the sign while applying square root to the + absolute value, which can help with numerical stability and gradient flow. + + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Transformed tensor with same shape as input. 
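+
+    Examples
+    --------
+    >>> x = torch.tensor([-4.0, 9.0])
+    >>> torch.equal(_signed_sqrt(x), torch.tensor([-2.0, 3.0]))
+    True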
+ """ return torch.sign(x) * torch.sqrt(torch.abs(x)) -class BatchLoaderBase: - def __init__(self, contribs, sequences, l, device): +class BatchLoaderBase(ABC): + """Base class for loading batches of contribution scores and sequences. + + This class provides common functionality for different input formats + including batch indexing and padding for consistent batch sizes. + + Dimension notation: + - N = number of sequences/regions in dataset + - L = sequence length + - B = batch size (number of regions processed simultaneously) + + Parameters + ---------- + contribs : Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]] + Contribution scores array. + sequences : Int[Tensor, "N 4 L"] + One-hot encoded sequences array. + L : int + Sequence length. + device : torch.device + Target device for tensor operations. + """ + + def __init__( + self, + contribs: Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]], + sequences: Int[Tensor, "N 4 L"], + L: int, + device: torch.device, + ) -> None: self.contribs = contribs self.sequences = sequences - self.l = l + self.L = L self.device = device - def _get_inds_and_pad_lens(self, start, end): - n = end - start + def _get_inds_and_pad_lens( + self, start: int, end: int + ) -> Tuple[Int[Tensor, " Z"], Tuple[int, ...]]: + """Get indices and padding lengths for batch loading. + + Parameters + ---------- + start : int + Start index for batch. + end : int + End index for batch. + + Returns + ------- + inds : Int[Tensor, " Z"] + Padded indices tensor with -1 for padding positions (shape: padded_batch_size). + pad_lens : tuple + Padding specification for F.pad (left, right, top, bottom, front, back). 
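+
+        Examples
+        --------
+        With 10 regions in total, ``start=8, end=12`` yields
+        ``inds = [8, 9, -1, -1]`` and ``pad_lens = (0, 0, 0, 0, 0, 2)``.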
+ """ + N = end - start end = min(end, self.contribs.shape[0]) - overhang = n - (end - start) + overhang = N - (end - start) pad_lens = (0, 0, 0, 0, 0, overhang) - inds = F.pad(torch.arange(start, end, dtype=torch.int), (0, overhang), value=-1).to(device=self.device) + inds = F.pad( + torch.arange(start, end, dtype=torch.int), (0, overhang), value=-1 + ).to(device=self.device) return inds, pad_lens - def load_batch(self, start, end): - raise NotImplementedError - + @abstractmethod + def load_batch( + self, start: int, end: int + ) -> Tuple[ + Float[Tensor, "B 4 L"], Union[Int[Tensor, "B 4 L"], int], Int[Tensor, " B"] + ]: + """Load a batch of data. + + Dimension notation: + - B = batch size (number of regions in this batch) + - L = sequence length + + Parameters + ---------- + start : int + Start index (used by subclasses). + end : int + End index (used by subclasses). + + Returns + ------- + contribs_batch : Float[Tensor, "B 4 L"] + Batch of contribution scores (shape: batch_size × 4_bases × L). + sequences_batch : Union[Int[Tensor, "B 4 L"], int] + Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L) or scalar 1 for hypothetical mode. + inds_batch : Int[Tensor, " B"] + Batch indices mapping to original sequence indices (shape: batch_size). + + Notes + ----- + This is an abstract method that must be implemented by subclasses. + Parameters are intentionally unused in the base implementation. + """ + pass + class BatchLoaderCompactFmt(BatchLoaderBase): - def load_batch(self, start, end): - inds, pad_lens = self._get_inds_and_pad_lens(start, end) + """Batch loader for compact format contribution scores. + + Handles contribution scores in shape (N, L) representing projected + scores that need to be broadcasted to (N, 4, L) format. 
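+
+    For example, a compact score ``s`` at position ``i`` becomes the row
+    ``s * sequences[:, :, i]`` after the masking multiply in ``load_batch``:
+    ``s`` on the observed base and 0 on the other three.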
+ """ - contribs_compact = F.pad(self.contribs[start:end,None,:], pad_lens) - contribs_batch = _to_channel_last_layout(contribs_compact, device=self.device, dtype=torch.float32) - sequences_batch = F.pad(self.sequences[start:end,:,:], pad_lens) # (b, 4, l) - sequences_batch = _to_channel_last_layout(sequences_batch, device=self.device, dtype=torch.int8) + def load_batch( + self, start: int, end: int + ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]: + inds, pad_lens = self._get_inds_and_pad_lens(start, end) - # global_scale = ((contribs_batch**2).sum(dim=(1,2)) / self.l).sqrt() + contribs_compact = F.pad(self.contribs[start:end, None, :], pad_lens) + contribs_batch = _to_channel_last_layout( + contribs_compact, device=self.device, dtype=torch.float32 + ) + sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens) # (B, 4, L) + sequences_batch = _to_channel_last_layout( + sequences_batch, device=self.device, dtype=torch.int8 + ) - contribs_batch = contribs_batch * sequences_batch # (b, 4, l) + contribs_batch = contribs_batch * sequences_batch # (B, 4, L) return contribs_batch, sequences_batch, inds class BatchLoaderProj(BatchLoaderBase): - def load_batch(self, start, end): + """Batch loader for projected contribution scores. + + Handles contribution scores in shape (N, 4, L) where scores are + element-wise multiplied by one-hot sequences to get projected contributions. 
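+
+    For example, multiplying hypothetical-format scores of shape (B, 4, L)
+    by the one-hot sequences zeroes the three unobserved base channels at
+    each position, keeping only observed-base contributions.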
+ """ + + def load_batch( + self, start: int, end: int + ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]: inds, pad_lens = self._get_inds_and_pad_lens(start, end) - contribs_hyp = F.pad(self.contribs[start:end,:,:], pad_lens) - contribs_hyp = _to_channel_last_layout(contribs_hyp, device=self.device, dtype=torch.float32) - sequences_batch = F.pad(self.sequences[start:end,:,:], pad_lens) # (b, 4, l) - sequences_batch = _to_channel_last_layout(sequences_batch, device=self.device, dtype=torch.int8) + contribs_hyp = F.pad(self.contribs[start:end, :, :], pad_lens) + contribs_hyp = _to_channel_last_layout( + contribs_hyp, device=self.device, dtype=torch.float32 + ) + sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens) # (B, 4, L) + sequences_batch = _to_channel_last_layout( + sequences_batch, device=self.device, dtype=torch.int8 + ) contribs_batch = contribs_hyp * sequences_batch - # global_scale = ((contribs_batch**2).sum(dim=(1,2)) / self.l).sqrt() - # contribs_batch = torch.nan_to_num(contribs_batch / global_scale[:,None,None]) - return contribs_batch, sequences_batch, inds - + class BatchLoaderHyp(BatchLoaderBase): - def load_batch(self, start, end): - inds, pad_lens = self._get_inds_and_pad_lens(start, end) + """Batch loader for hypothetical contribution scores. + + Handles hypothetical contribution scores in shape (N, 4, L) where + scores represent counterfactual effects of base substitutions. 
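+
+    Because hypothetical mode scores all four base channels directly,
+    ``load_batch`` returns the scalar 1 in place of a sequence tensor, so
+    the masking multiply in the optimizer is a no-op.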
+ """ - contribs_batch = F.pad(self.contribs[start:end,:,:], pad_lens) - contribs_batch = _to_channel_last_layout(contribs_batch, device=self.device, dtype=torch.float32) + def load_batch( + self, start: int, end: int + ) -> Tuple[Float[Tensor, "B 4 L"], int, Int[Tensor, " B"]]: + inds, pad_lens = self._get_inds_and_pad_lens(start, end) - # global_scale = ((contribs_batch**2).sum(dim=(1,2)) / self.l).sqrt() - # contribs_batch = torch.nan_to_num(contribs_batch / global_scale[:,None,None]) + contribs_batch = F.pad(self.contribs[start:end, :, :], pad_lens) + contribs_batch = _to_channel_last_layout( + contribs_batch, device=self.device, dtype=torch.float32 + ) return contribs_batch, 1, inds -def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lambdas, step_size_max, step_size_min, sqrt_transform, - convergence_tol, max_steps, batch_size, step_adjust, post_filter, device, compile_optimizer, eps=1.): - """ - Call hits by fitting sparse linear model to contributions - - cwms: (m, 4, w) - contribs: (n, 4, l) or (n, l) - sequences: (n, 4, l) - cwm_trim_mask: (m, w) +def fit_contribs( + cwms: Float[ndarray, "M 4 W"], + contribs: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]], + sequences: Int[ndarray, "N 4 L"], + cwm_trim_mask: Float[ndarray, "M W"], + use_hypothetical: bool, + lambdas: Float[ndarray, " M"], + step_size_max: float = 3.0, + step_size_min: float = 0.08, + sqrt_transform: bool = False, + convergence_tol: float = 0.0005, + max_steps: int = 10000, + batch_size: int = 2000, + step_adjust: float = 0.7, + post_filter: bool = True, + device: Optional[torch.device] = None, + compile_optimizer: bool = False, + eps: float = 1.0, +) -> Tuple[pl.DataFrame, pl.DataFrame]: + """Call motif hits by fitting sparse linear model to contribution scores. + + This is the main function implementing the Fi-NeMo algorithm. 
It identifies + motif instances by solving a sparse reconstruction problem where contribution + scores are approximated as a linear combination of motif CWMs at specific + positions. The optimization uses proximal gradient descent with momentum. + + Parameters + ---------- + cwms : Float[ndarray, "M 4 W"] + Motif contribution weight matrices where: + - M = number of motifs (transcription factor binding patterns) + - 4 = DNA bases (A, C, G, T dimensions) + - W = motif width (length of each motif pattern) + contribs : Float[ndarray, "N 4 L"] | Float[ndarray, "N L"] + Neural network contribution scores where: + - N = number of regions in dataset + - L = sequence length + Can be hypothetical (N, 4, L) or projected (N, L) format. + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences (shape: num_regions × 4_bases × L). + cwm_trim_mask : Float[ndarray, "M W"] + Binary mask indicating which positions of each CWM to use (shape: num_motifs × motif_width). + use_hypothetical : bool + Whether to use hypothetical contribution scores (True) or + projected scores (False). + lambdas : Float[ndarray, " M"] + L1 regularization weights for each motif (shape: num_motifs). + step_size_max : float, default 3.0 + Maximum optimization step size. + step_size_min : float, default 0.08 + Minimum optimization step size (for convergence failure detection). + sqrt_transform : bool, default False + Whether to apply signed square root transformation to inputs. + convergence_tol : float, default 0.0005 + Convergence tolerance based on duality gap. + max_steps : int, default 10000 + Maximum number of optimization steps. + batch_size : int, default 2000 + Number of regions to process simultaneously. + step_adjust : float, default 0.7 + Factor to reduce step size when optimization diverges. + post_filter : bool, default True + Whether to filter hits based on similarity threshold. + device : torch.device, optional + Target device for computation. Auto-detected if None. 
+ compile_optimizer : bool, default False + Whether to JIT compile the optimizer for speed. + eps : float, default 1.0 + Small constant for numerical stability. + + Returns + ------- + hits_df : pl.DataFrame + DataFrame containing called motif hits with columns: + - peak_id: Region index + - motif_id: Motif index + - hit_start: Start position of hit + - hit_coefficient: Hit strength coefficient + - hit_similarity: Cosine similarity with motif + - hit_importance: Total contribution score in hit region + - hit_importance_sq: Sum of squared contributions (for normalization) + qc_df : pl.DataFrame + DataFrame containing quality control metrics with columns: + - peak_id: Region index + - nll: Final negative log likelihood + - dual_gap: Final duality gap + - num_steps: Number of optimization steps + - step_size: Final step size + - global_scale: Region-level scaling factor + + Notes + ----- + The algorithm solves the optimization problem: + + minimize_c: ||contribs - Σⱼ convolve(c * scale, cwms[j]) * sequences||²₂ + Σⱼ λⱼ||c[:,j]||₁ + + subject to: c ≥ 0 + + where c[i,j] represents the strength of motif j at position i. + + The importance scaling balances reconstruction across different + motifs and positions based on the local contribution magnitude. + + Examples + -------- + >>> hits_df, qc_df = fit_contribs( + ... cwms=motif_cwms, + ... contribs=contrib_scores, + ... sequences=onehot_seqs, + ... cwm_trim_mask=trim_masks, + ... use_hypothetical=False, + ... lambdas=np.array([0.7, 0.8]), + ... step_size_max=3.0, + ... step_size_min=0.08, + ... sqrt_transform=False, + ... convergence_tol=0.0005, + ... max_steps=10000, + ... batch_size=1000, + ... step_adjust=0.7, + ... post_filter=True, + ... device=None, + ... compile_optimizer=False + ... 
) """ - m, _, w = cwms.shape - n, _, l = sequences.shape + M, _, W = cwms.shape + N, _, L = sequences.shape - b = batch_size + B = batch_size # Using uppercase for consistency with dimension notation if device is None: if torch.cuda.is_available(): @@ -175,138 +556,188 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam else: device = torch.device("cpu") warnings.warn("No GPU available. Running on CPU.", RuntimeWarning) - + # Compile optimizer if requested global optimizer_step if compile_optimizer: optimizer_step = torch.compile(optimizer_step, fullgraph=True) - # Convert inputs to pytorch tensors - cwms = torch.from_numpy(cwms) - contribs = torch.from_numpy(contribs) - sequences = torch.from_numpy(sequences) - cwm_trim_mask = torch.from_numpy(cwm_trim_mask)[:,None,:].repeat(1,4,1) - lambdas = torch.from_numpy(lambdas)[None,:,None].to(device=device, dtype=torch.float32) - - cwms = _to_channel_last_layout(cwms, device=device, dtype=torch.float32) - cwm_trim_mask = _to_channel_last_layout(cwm_trim_mask, device=device, dtype=torch.float32) - cwms = cwms * cwm_trim_mask + # Convert inputs to PyTorch tensors with proper device placement + cwms_tensor: torch.Tensor = torch.from_numpy(cwms) + contribs_tensor: torch.Tensor = torch.from_numpy(contribs) + sequences_tensor: torch.Tensor = torch.from_numpy(sequences) + cwm_trim_mask_tensor = torch.from_numpy(cwm_trim_mask)[:, None, :].repeat(1, 4, 1) + lambdas_tensor: torch.Tensor = torch.from_numpy(lambdas)[None, :, None].to( + device=device, dtype=torch.float32 + ) + + # Convert to channel-last layout for optimized convolution operations + cwms_tensor = _to_channel_last_layout( + cwms_tensor, device=device, dtype=torch.float32 + ) + cwm_trim_mask_tensor = _to_channel_last_layout( + cwm_trim_mask_tensor, device=device, dtype=torch.float32 + ) + cwms_tensor = cwms_tensor * cwm_trim_mask_tensor # Apply trimming mask if sqrt_transform: - cwms = _signed_sqrt(cwms) - cwm_norm = 
(cwms**2).sum(dim=(1,2)).sqrt() - cwms = cwms / cwm_norm[:,None,None] + cwms_tensor = _signed_sqrt(cwms_tensor) + cwm_norm = (cwms_tensor**2).sum(dim=(1, 2)).sqrt() + cwms_tensor = cwms_tensor / cwm_norm[:, None, None] # Initialize batch loader - if len(contribs.shape) == 3: + if len(contribs_tensor.shape) == 3: if use_hypothetical: - batch_loader = BatchLoaderHyp(contribs, sequences, l, device) + batch_loader = BatchLoaderHyp(contribs_tensor, sequences_tensor, L, device) else: - batch_loader = BatchLoaderProj(contribs, sequences, l, device) - elif len(contribs.shape) == 2: + batch_loader = BatchLoaderProj(contribs_tensor, sequences_tensor, L, device) + elif len(contribs_tensor.shape) == 2: if use_hypothetical: - raise ValueError("Input regions do not contain hypothetical contribution scores") + raise ValueError( + "Input regions do not contain hypothetical contribution scores" + ) else: - batch_loader = BatchLoaderCompactFmt(contribs, sequences, l, device) + batch_loader = BatchLoaderCompactFmt( + contribs_tensor, sequences_tensor, L, device + ) else: - raise ValueError(f"Input contributions array is of incorrect shape {contribs.shape}") - - # Intialize output container objects - hit_idxs_lst = [] - coefficients_lst = [] - similarity_lst = [] - importance_lst = [] - importance_sq_lst = [] - qc_lsts = {"nll": [], "dual_gap": [], "num_steps": [], "step_size": [], "global_scale": [], "peak_id": []} + raise ValueError( + f"Input contributions array is of incorrect shape {contribs_tensor.shape}" + ) + + # Initialize output container objects + hit_idxs_lst: List[ndarray] = [] + coefficients_lst: List[ndarray] = [] + similarity_lst: List[ndarray] = [] + importance_lst: List[ndarray] = [] + importance_sq_lst: List[ndarray] = [] + qc_lsts: Dict[str, List[ndarray]] = { + "nll": [], + "dual_gap": [], + "num_steps": [], + "step_size": [], + "global_scale": [], + "peak_id": [], + } # Initialize buffers for optimizer - coef_inter = torch.zeros((b, m, l - w + 1)) # (b, m, l - w 
+ 1) + coef_inter: Float[Tensor, "B M P"] = torch.zeros( + (B, M, L - W + 1) + ) # (B, M, P) where P = L - W + 1 coef_inter = _to_channel_last_layout(coef_inter, device=device, dtype=torch.float32) - coef = torch.zeros_like(coef_inter) - i = torch.zeros((b, 1, 1), dtype=torch.int, device=device) - step_sizes = torch.full((b, 1, 1), step_size_max, dtype=torch.float32, device=device) - - converged = torch.full((b,), True, dtype=torch.bool, device=device) - num_load = b - - contribs_buf = torch.zeros((b, 4, l)) - contribs_buf = _to_channel_last_layout(contribs_buf, device=device, dtype=torch.float32) - + coef: Float[Tensor, "B M P"] = torch.zeros_like(coef_inter) + i: Float[Tensor, "B 1 1"] = torch.zeros((B, 1, 1), dtype=torch.int, device=device) + step_sizes: Float[Tensor, "B 1 1"] = torch.full( + (B, 1, 1), step_size_max, dtype=torch.float32, device=device + ) + + converged: Bool[Tensor, " B"] = torch.full( + (B,), True, dtype=torch.bool, device=device + ) + num_load = B + + contribs_buf: Float[Tensor, "B 4 L"] = torch.zeros((B, 4, L)) + contribs_buf = _to_channel_last_layout( + contribs_buf, device=device, dtype=torch.float32 + ) + + seqs_buf: Union[Int[Tensor, "B 4 L"], int] if use_hypothetical: seqs_buf = 1 else: - seqs_buf = torch.zeros((b, 4, l)) + seqs_buf = torch.zeros((B, 4, L)) seqs_buf = _to_channel_last_layout(seqs_buf, device=device, dtype=torch.int8) - importance_scale_buf = torch.zeros((b, m, l - w + 1)) - importance_scale_buf = _to_channel_last_layout(importance_scale_buf, device=device, dtype=torch.float32) + importance_scale_buf: Float[Tensor, "B M P"] = torch.zeros((B, M, L - W + 1)) + importance_scale_buf = _to_channel_last_layout( + importance_scale_buf, device=device, dtype=torch.float32 + ) - inds_buf = torch.zeros((b,), dtype=torch.int, device=device) - global_scale_buf = torch.zeros((b,), dtype=torch.float, device=device) + inds_buf: Int[Tensor, " B"] = torch.zeros((B,), dtype=torch.int, device=device) + global_scale_buf: Float[Tensor, " B"] 
= torch.zeros( + (B,), dtype=torch.float, device=device + ) - with tqdm(disable=None, unit="regions", total=n, ncols=120) as pbar: + with tqdm(disable=None, unit="regions", total=N, ncols=120) as pbar: num_complete = 0 next_ind = 0 - while num_complete < n: + while num_complete < N: # Retire converged peaks and fill buffer with new data if num_load > 0: load_start = next_ind load_end = load_start + num_load - next_ind = min(load_end, contribs.shape[0]) + next_ind = min(load_end, contribs_tensor.shape[0]) - batch_data = batch_loader.load_batch(load_start, load_end) + batch_data = batch_loader.load_batch(int(load_start), int(load_end)) contribs_batch, seqs_batch, inds_batch = batch_data if sqrt_transform: contribs_batch = _signed_sqrt(contribs_batch) - - global_scale_batch = ((contribs_batch**2).sum(dim=(1,2)) / l).sqrt() - contribs_batch = torch.nan_to_num(contribs_batch / global_scale_batch[:,None,None]) - importance_scale_batch = (F.conv1d(contribs_batch**2, cwm_trim_mask) + eps)**(-0.5) + global_scale_batch = ((contribs_batch**2).sum(dim=(1, 2)) / L).sqrt() + contribs_batch = torch.nan_to_num( + contribs_batch / global_scale_batch[:, None, None] + ) + + importance_scale_batch = ( + F.conv1d(contribs_batch**2, cwm_trim_mask_tensor) + eps + ) ** (-0.5) importance_scale_batch = importance_scale_batch.clamp(max=10) - contribs_buf[converged,:,:] = contribs_batch + contribs_buf[converged, :, :] = contribs_batch if not use_hypothetical: - seqs_buf[converged,:,:] = seqs_batch + seqs_buf[converged, :, :] = seqs_batch # type: ignore + + importance_scale_buf[converged, :, :] = importance_scale_batch - importance_scale_buf[converged,:,:] = importance_scale_batch - inds_buf[converged] = inds_batch global_scale_buf[converged] = global_scale_batch - coef_inter[converged,:,:] *= 0 - coef[converged,:,:] *= 0 + coef_inter[converged, :, :] *= 0 + coef[converged, :, :] *= 0 i[converged] *= 0 step_sizes[converged] = step_size_max # Optimization step - coef_inter, coef, gap, nll = 
optimizer_step(cwms, contribs_buf, importance_scale_buf, seqs_buf, coef_inter, coef, - i, step_sizes, l, lambdas) + coef_inter, coef, gap, nll = optimizer_step( + cwms_tensor, + contribs_buf, + importance_scale_buf, + seqs_buf, + coef_inter, + coef, + i, + step_sizes, + L, + lambdas_tensor, + ) i += 1 # Assess convergence of each peak being optimized. Reset diverged peaks with lower step size. active = inds_buf >= 0 diverged = ~torch.isfinite(gap) & active - coef_inter[diverged,:,:] *= 0 - coef[diverged,:,:] *= 0 + coef_inter[diverged, :, :] *= 0 + coef[diverged, :, :] *= 0 i[diverged] *= 0 - step_sizes[diverged,:,:] *= step_adjust + step_sizes[diverged, :, :] *= step_adjust timeouts = (i > max_steps).squeeze() & active if timeouts.sum().item() > 0: timeout_inds = inds_buf[timeouts] for ind in timeout_inds: - warnings.warn(f"Region {ind} has not converged within max_steps={max_steps} iterations.", RuntimeWarning) + warnings.warn( + f"Region {ind} has not converged within max_steps={max_steps} iterations.", + RuntimeWarning, + ) fails = (step_sizes < step_size_min).squeeze() & active if fails.sum().item() > 0: fail_inds = inds_buf[fails] for ind in fail_inds: warnings.warn(f"Optimizer failed for region {ind}.", RuntimeWarning) - + converged = ((gap <= convergence_tol) | timeouts | fails) & active num_load = converged.sum().item() @@ -315,27 +746,32 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam inds_out = inds_buf[converged] global_scale_out = global_scale_buf[converged] - # Compute hit scores - coef_out = coef[converged,:,:] - importance_scale_out_dense = importance_scale_buf[converged,:,:] - importance_sq = importance_scale_out_dense**(-2) - eps + # Compute hit scores + coef_out = coef[converged, :, :] + importance_scale_out_dense = importance_scale_buf[converged, :, :] + importance_sq = importance_scale_out_dense ** (-2) - eps xcor_scale = importance_sq.sqrt() - contribs_converged = contribs_buf[converged,:,:] - 
importance_sum_out_dense = F.conv1d(torch.abs(contribs_converged), cwm_trim_mask) - xcov_out_dense = F.conv1d(contribs_converged, cwms) - # xcov_out_dense = F.conv1d(torch.abs(contribs_converged), cwms) + contribs_converged = contribs_buf[converged, :, :] + importance_sum_out_dense = F.conv1d( + torch.abs(contribs_converged), cwm_trim_mask_tensor + ) + xcov_out_dense = F.conv1d(contribs_converged, cwms_tensor) + # xcov_out_dense = F.conv1d(torch.abs(contribs_converged), cwms_tensor) xcor_out_dense = xcov_out_dense / xcor_scale if post_filter: - coef_out = coef_out * (xcor_out_dense >= lambdas) + coef_out = coef_out * (xcor_out_dense >= lambdas_tensor) - # Extract hit coordinates + # Extract hit coordinates using sparse tensor representation coef_out = coef_out.to_sparse() - hit_idxs_out = torch.clone(coef_out.indices()) - hit_idxs_out[0,:] = F.embedding(hit_idxs_out[0,:], inds_out[:,None]).squeeze() - # Map buffer index to peak index + hit_idxs_out = torch.clone(coef_out.indices()) # Nonzero coefficient indices: (buffer row, motif, position) + hit_idxs_out[0, :] = F.embedding( + hit_idxs_out[0, :], inds_out[:, None] + ).squeeze() # Map buffer index to peak index ind_tuple = torch.unbind(coef_out.indices()) importance_out = importance_sum_out_dense[ind_tuple] @@ -347,8 +783,8 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam # Store outputs gap_out = gap[converged] nll_out = nll[converged] - step_out = i[converged,0,0] - step_sizes_out = step_sizes[converged,0,0] + step_out = i[converged, 0, 0] + step_sizes_out = step_sizes[converged, 0, 0] hit_idxs_lst.append(hit_idxs_out.numpy(force=True).T) coefficients_lst.append(scores_out_raw.numpy(force=True)) @@ -373,19 +809,19 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam scores_importance = np.concatenate(importance_lst, axis=0) scores_importance_sq = np.concatenate(importance_sq_lst, axis=0) - 
hits = { - "peak_id": hit_idxs[:,0].astype(np.uint32), - "motif_id": hit_idxs[:,1].astype(np.uint32), - "hit_start": hit_idxs[:,2], + hits: Dict[str, ndarray] = { + "peak_id": hit_idxs[:, 0].astype(np.uint32), + "motif_id": hit_idxs[:, 1].astype(np.uint32), + "hit_start": hit_idxs[:, 2], "hit_coefficient": scores_coefficient, "hit_similarity": scores_similarity, "hit_importance": scores_importance, "hit_importance_sq": scores_importance_sq, } - qc = {k: np.concatenate(v, axis=0) for k, v in qc_lsts.items()} + qc: Dict[str, ndarray] = {k: np.concatenate(v, axis=0) for k, v in qc_lsts.items()} hits_df = pl.DataFrame(hits) qc_df = pl.DataFrame(qc) - return hits_df, qc_df \ No newline at end of file + return hits_df, qc_df diff --git a/src/finemo/main.py b/src/finemo/main.py index 77ce256..db78241 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -1,20 +1,89 @@ +"""Main CLI module for the Fi-NeMo motif instance calling pipeline. + +This module provides the command-line interface for all Fi-NeMo operations: +- Data preprocessing from various genomic formats +- Motif hit calling using the Fi-NeMo algorithm +- Report generation and result visualization +- Post-processing operations (hit collapsing, intersection) + +The CLI supports multiple input formats including bigWig, HDF5 (ChromBPNet/BPNet), +and TF-MoDISco format. +""" + from . import data_io import os import argparse import warnings - - -def extract_regions_bw(peaks_path, chrom_order_path, fa_path, bw_paths, out_path, region_width): +import inspect +from typing import Optional, List + +import polars as pl + + +def extract_regions_bw( + peaks_path: str, + chrom_order_path: Optional[str], + fa_path: str, + bw_paths: List[str], + out_path: str, + region_width: int = 1000, +) -> None: + """Extract genomic regions and contribution scores from bigWig and FASTA files. + + Parameters + ---------- + peaks_path : str + Path to ENCODE NarrowPeak format file. 
+ chrom_order_path : str, optional + Path to chromosome ordering file. + fa_path : str + Path to genome FASTA file. + bw_paths : List[str] + List of bigWig file paths containing contribution scores. + out_path : str + Output path for NPZ file. + region_width : int, default 1000 + Width of regions to extract around peak summits. + + Notes + ----- + BigWig files only provide projected contribution scores. + """ half_width = region_width // 2 + # Load peak regions and extract sequences/contributions peaks_df = data_io.load_peaks(peaks_path, chrom_order_path, half_width) - sequences, contribs = data_io.load_regions_from_bw(peaks_df, fa_path, bw_paths, half_width) + sequences, contribs = data_io.load_regions_from_bw( + peaks_df, fa_path, bw_paths, half_width + ) + # Save processed data to NPZ format data_io.write_regions_npz(sequences, contribs, out_path, peaks_df=peaks_df) -def extract_regions_chrombpnet_h5(peaks_path, chrom_order_path, h5_paths, out_path, region_width): +def extract_regions_chrombpnet_h5( + peaks_path: Optional[str], + chrom_order_path: Optional[str], + h5_paths: List[str], + out_path: str, + region_width: int = 1000, +) -> None: + """Extract genomic regions and contribution scores from ChromBPNet HDF5 files. + + Parameters + ---------- + peaks_path : str, optional + Path to ENCODE NarrowPeak format file. If None, output lacks absolute genomic coordinates. + chrom_order_path : str, optional + Path to chromosome ordering file. + h5_paths : List[str] + List of ChromBPNet HDF5 file paths. + out_path : str + Output path for NPZ file. + region_width : int, default 1000 + Width of regions to extract around peak summits. 
+ """ half_width = region_width // 2 if peaks_path is not None: @@ -27,7 +96,28 @@ def extract_regions_chrombpnet_h5(peaks_path, chrom_order_path, h5_paths, out_pa data_io.write_regions_npz(sequences, contribs, out_path, peaks_df=peaks_df) -def extract_regions_bpnet_h5(peaks_path, chrom_order_path, h5_paths, out_path, region_width): +def extract_regions_bpnet_h5( + peaks_path: Optional[str], + chrom_order_path: Optional[str], + h5_paths: List[str], + out_path: str, + region_width: int = 1000, +) -> None: + """Extract genomic regions and contribution scores from BPNet HDF5 files. + + Parameters + ---------- + peaks_path : str, optional + Path to ENCODE NarrowPeak format file. If None, output lacks absolute coordinates. + chrom_order_path : str, optional + Path to chromosome ordering file. + h5_paths : List[str] + List of BPNet HDF5 file paths. + out_path : str + Output path for NPZ file. + region_width : int, default 1000 + Width of regions to extract around peak summits. + """ half_width = region_width // 2 if peaks_path is not None: @@ -40,7 +130,31 @@ def extract_regions_bpnet_h5(peaks_path, chrom_order_path, h5_paths, out_path, r data_io.write_regions_npz(sequences, contribs, out_path, peaks_df=peaks_df) -def extract_regions_modisco_fmt(peaks_path, chrom_order_path, shaps_paths, ohe_path, out_path, region_width): +def extract_regions_modisco_fmt( + peaks_path: Optional[str], + chrom_order_path: Optional[str], + shaps_paths: List[str], + ohe_path: str, + out_path: str, + region_width: int = 1000, +) -> None: + """Extract genomic regions and contribution scores from TF-MoDISco format files. + + Parameters + ---------- + peaks_path : str, optional + Path to ENCODE NarrowPeak format file. If None, output lacks absolute coordinates. + chrom_order_path : str, optional + Path to chromosome ordering file. + shaps_paths : List[str] + List of paths to .npy/.npz files containing SHAP/attribution scores. 
+ ohe_path : str + Path to .npy/.npz file containing one-hot encoded sequences. + out_path : str + Output path for NPZ file. + region_width : int, default 1000 + Width of regions to extract around peak summits. + """ half_width = region_width // 2 if peaks_path is not None: @@ -48,37 +162,135 @@ def extract_regions_modisco_fmt(peaks_path, chrom_order_path, shaps_paths, ohe_p else: peaks_df = None - sequences, contribs = data_io.load_regions_from_modisco_fmt(shaps_paths, ohe_path, half_width) + sequences, contribs = data_io.load_regions_from_modisco_fmt( + shaps_paths, ohe_path, half_width + ) data_io.write_regions_npz(sequences, contribs, out_path, peaks_df=peaks_df) -def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motifs_include_path, motif_names_path, - motif_lambdas_path, out_dir, cwm_trim_threshold, lambda_default, step_size_max, step_size_min, sqrt_transform, - convergence_tol, max_steps, batch_size, step_adjust, device, mode, no_post_filter, compile_optimizer): - +def call_hits( + regions_path: str, + peaks_path: Optional[str], + modisco_h5_path: str, + chrom_order_path: Optional[str], + motifs_include_path: Optional[str], + motif_names_path: Optional[str], + motif_lambdas_path: Optional[str], + out_dir: str, + cwm_trim_coords_path: Optional[str] = None, + cwm_trim_thresholds_path: Optional[str] = None, + cwm_trim_threshold_default: float = 0.3, + lambda_default: float = 0.7, + step_size_max: float = 3.0, + step_size_min: float = 0.08, + sqrt_transform: bool = False, + convergence_tol: float = 0.0005, + max_steps: int = 10000, + batch_size: int = 2000, + step_adjust: float = 0.7, + device: Optional[str] = None, + mode: str = "pp", + no_post_filter: bool = False, + compile_optimizer: bool = False, +) -> None: + """Call motif hits using the Fi-NeMo algorithm on preprocessed genomic regions. 
+ + This function implements the core Fi-NeMo hit calling pipeline, which identifies + motif instances by solving a sparse reconstruction problem using proximal gradient + descent. The algorithm represents contribution scores as weighted combinations of + motif CWMs at specific positions. + + Parameters + ---------- + regions_path : str + Path to NPZ file containing preprocessed regions (sequences, contributions, + and optional peak coordinates). + peaks_path : str, optional + DEPRECATED. Path to ENCODE NarrowPeak format file. Peak data should be + included during preprocessing instead. + modisco_h5_path : str + Path to TF-MoDISco H5 file containing motif CWMs. + chrom_order_path : str, optional + DEPRECATED. Path to chromosome ordering file. + motifs_include_path : str, optional + Path to file listing motif names to include in analysis. + motif_names_path : str, optional + Path to file mapping motif IDs to custom names. + motif_lambdas_path : str, optional + Path to file specifying per-motif lambda values. + out_dir : str + Output directory for results. + cwm_trim_coords_path : str, optional + Path to file specifying custom motif trimming coordinates. + cwm_trim_thresholds_path : str, optional + Path to file specifying custom motif trimming thresholds. + cwm_trim_threshold_default : float, default 0.3 + Default threshold for motif trimming. + lambda_default : float, default 0.7 + Default L1 regularization weight. + step_size_max : float, default 3.0 + Maximum optimization step size. + step_size_min : float, default 0.08 + Minimum optimization step size. + sqrt_transform : bool, default False + Whether to apply signed square root transform to contributions. + convergence_tol : float, default 0.0005 + Convergence tolerance for duality gap. + max_steps : int, default 10000 + Maximum number of optimization steps. + batch_size : int, default 2000 + Batch size for GPU processing. + step_adjust : float, default 0.7 + Step size adjustment factor on divergence. 
+ device : str, optional + DEPRECATED. Use CUDA_VISIBLE_DEVICES environment variable instead. + mode : str, default "pp" + Contribution type mode ('pp', 'ph', 'hp', 'hh') where 'p'=projected, 'h'=hypothetical. + no_post_filter : bool, default False + If True, skip post-hit-calling similarity filtering. + compile_optimizer : bool, default False + Whether to JIT-compile the optimizer for speed. + + Notes + ----- + The Fi-NeMo algorithm solves the optimization problem: + minimize_c: ||contribs - reconstruction(c)||²₂ + λ||c||₁ + subject to: c ≥ 0 + + where c represents motif hit coefficients and reconstruction uses convolution + with motif CWMs. + """ + params = locals() + import torch from . import hitcaller if device is not None: - warnings.warn("The `--device` flag is deprecated and will be removed in a future version. Please use the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU device.") - + warnings.warn( + "The `--device` flag is deprecated and will be removed in a future version. Please use the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU device." + ) + sequences, contribs, peaks_df, has_peaks = data_io.load_regions_npz(regions_path) region_width = sequences.shape[2] if region_width % 2 != 0: raise ValueError(f"Region width of {region_width} is not divisible by 2.") - + half_width = region_width // 2 num_regions = contribs.shape[0] if peaks_path is not None: - warnings.warn("Providing a peaks file to `call-hits` is deprecated, and this option will be removed in a future version. Peaks should instead be provided in the preprocessing step to be included in `regions.npz`.") + warnings.warn( + "Providing a peaks file to `call-hits` is deprecated, and this option will be removed in a future version. Peaks should instead be provided in the preprocessing step to be included in `regions.npz`." 
+ ) peaks_df = data_io.load_peaks(peaks_path, chrom_order_path, half_width) has_peaks = True if not has_peaks: - warnings.warn("No peak region data provided. Output hits will lack absolute genomic coordinates.") + warnings.warn( + "No peak region data provided. Output hits will lack absolute genomic coordinates." + ) if mode == "pp": motif_type = "cwm" @@ -92,6 +304,10 @@ def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motif elif mode == "hh": motif_type = "hcwm" use_hypothetical_contribs = True + else: + raise ValueError( + f"Invalid mode: {mode}. Must be one of 'pp', 'ph', 'hp', 'hh'." + ) if motifs_include_path is not None: motifs_include = data_io.load_txt(motifs_include_path) @@ -107,15 +323,52 @@ def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motif motif_lambdas = data_io.load_mapping(motif_lambdas_path, float) else: motif_lambdas = None - - motifs_df, cwms, trim_masks, motif_names = data_io.load_modisco_motifs(modisco_h5_path, cwm_trim_threshold, motif_type, motifs_include, - motif_name_map, motif_lambdas, lambda_default, True) + + if cwm_trim_coords_path is not None: + trim_coords = data_io.load_mapping_tuple(cwm_trim_coords_path, int) + else: + trim_coords = None + + if cwm_trim_thresholds_path is not None: + trim_thresholds = data_io.load_mapping(cwm_trim_thresholds_path, float) + else: + trim_thresholds = None + + motifs_df, cwms, trim_masks, _ = data_io.load_modisco_motifs( + modisco_h5_path, + trim_coords, + trim_thresholds, + cwm_trim_threshold_default, + motif_type, + motifs_include, + motif_name_map, + motif_lambdas, + lambda_default, + True, + ) num_motifs = cwms.shape[0] motif_width = cwms.shape[2] lambdas = motifs_df.get_column("lambda").to_numpy(writable=True) - hits_df, qc_df = hitcaller.fit_contribs(cwms, contribs, sequences, trim_masks, use_hypothetical_contribs, lambdas, step_size_max, step_size_min, - sqrt_transform, convergence_tol, max_steps, batch_size, step_adjust, not 
no_post_filter, device, compile_optimizer) + device_obj = torch.device(device) if device is not None else None + hits_df, qc_df = hitcaller.fit_contribs( + cwms, + contribs, + sequences, + trim_masks, + use_hypothetical_contribs, + lambdas, + step_size_max, + step_size_min, + sqrt_transform, + convergence_tol, + max_steps, + batch_size, + step_adjust, + not no_post_filter, + device_obj, + compile_optimizer, + ) os.makedirs(out_dir, exist_ok=True) out_path_qc = os.path.join(out_dir, "peaks_qc.tsv") @@ -139,28 +392,86 @@ def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motif data_io.write_params(params, out_path_params) -def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_path, motif_names_path, - out_dir, modisco_region_width, cwm_trim_threshold, compute_recall, use_seqlets): - from . import evaluation +def report( + regions_path: str, + hits_dir: str, + modisco_h5_path: Optional[str], + peaks_path: Optional[str], + motifs_include_path: Optional[str], + motif_names_path: Optional[str], + out_dir: str, + modisco_region_width: int = 400, + cwm_trim_threshold: float = 0.3, + compute_recall: bool = True, + use_seqlets: bool = True, +) -> None: + """Generate comprehensive HTML report with statistics and visualizations. + + This function creates detailed analysis reports comparing Fi-NeMo hit calling + results with TF-MoDISco seqlets, including performance metrics, distribution + plots, and motif visualization. The report provides insights into hit calling + quality and motif discovery accuracy. + + Parameters + ---------- + regions_path : str + Path to NPZ file containing the same regions used for hit calling. + hits_dir : str + Path to directory containing Fi-NeMo hit calling outputs. + modisco_h5_path : str, optional + Path to TF-MoDISco H5 file. If None, seqlet comparisons are skipped. + peaks_path : str, optional + DEPRECATED. Peak coordinates should be included in regions file. 
+ motifs_include_path : str, optional + DEPRECATED. This information is inferred from hit calling outputs. + motif_names_path : str, optional + DEPRECATED. This information is inferred from hit calling outputs. + out_dir : str + Output directory for report files. + modisco_region_width : int, default 400 + Width of regions used by TF-MoDISco (needed for coordinate conversion). + cwm_trim_threshold : float, default 0.3 + DEPRECATED. This information is inferred from hit calling outputs. + compute_recall : bool, default True + Whether to compute recall metrics against TF-MoDISco seqlets. + use_seqlets : bool, default True + Whether to include seqlet-based comparisons in the report. + + Notes + ----- + The generated report includes: + - Hit vs seqlet count comparisons + - Motif CWM visualizations + - Hit statistic distributions + - Co-occurrence heatmaps + - Confusion matrices for overlapping motifs + """ + from . import evaluation, visualization sequences, contribs, peaks_df, _ = data_io.load_regions_npz(regions_path) if len(contribs.shape) == 3: regions = contribs * sequences elif len(contribs.shape) == 2: - regions = contribs[:,None,:] * sequences + regions = contribs[:, None, :] * sequences + else: + raise ValueError(f"Unexpected contribs shape: {contribs.shape}") half_width = regions.shape[2] // 2 modisco_half_width = modisco_region_width // 2 if peaks_path is not None: - warnings.warn("Providing a peaks file to `report` is deprecated, and this option will be removed in a future version. Peaks should instead be provided in the preprocessing step to be included in `regions.npz`.") - peaks_df = data_io.load_peaks(peaks_path, None, half_width) + warnings.warn( + "Providing a peaks file to `report` is deprecated, and this option will be removed in a future version. Peaks should instead be provided in the preprocessing step to be included in `regions.npz`." 
+ ) + peaks_df = data_io.load_peaks(peaks_path, None, half_width) if hits_dir.endswith(".tsv"): - warnings.warn("Passing a hits.tsv file to `finemo report` is deprecated. Please provide the directory containing the hits.tsv file instead.") + warnings.warn( + "Passing a hits.tsv file to `finemo report` is deprecated. Please provide the directory containing the hits.tsv file instead." + ) hits_path = hits_dir - + hits_df = data_io.load_hits(hits_path, lazy=True) if motifs_include_path is not None: @@ -169,12 +480,29 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p motifs_include = None if motif_names_path is not None: - motif_name_map = data_io.load_txt(motif_names_path) + motif_name_map = data_io.load_mapping(motif_names_path, str) else: motif_name_map = None - motifs_df, cwms_modisco, trim_masks, motif_names = data_io.load_modisco_motifs(modisco_h5_path, cwm_trim_threshold, "cwm", - motifs_include, motif_name_map, None, None, True) + if modisco_h5_path is not None: + motifs_df, cwms_modisco, _, motif_names = data_io.load_modisco_motifs( + modisco_h5_path, + None, + None, + cwm_trim_threshold, + "cwm", + motifs_include, + motif_name_map, + None, + 1.0, + True, + ) + else: + # In legacy TSV mode, motif CWMs cannot be recovered without a + # TF-MoDISco H5 file, so fail early with a clear error + raise ValueError( + "modisco_h5_path is required when providing a hits.tsv file directly" + ) else: hits_df_path = os.path.join(hits_dir, "hits.tsv") @@ -188,273 +516,851 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p params_path = os.path.join(hits_dir, "parameters.json") params = data_io.load_params(params_path) - cwm_trim_threshold = params["cwm_trim_threshold"] + cwm_trim_threshold = params["cwm_trim_threshold_default"] if not use_seqlets: - warnings.warn("Usage of the `--no-seqlets` flag is deprecated and will be removed in a future version. 
Please omit the `--modisco-h5` argument instead.") + warnings.warn( + "Usage of the `--no-seqlets` flag is deprecated and will be removed in a future version. Please omit the `--modisco-h5` argument instead." + ) seqlets_df = None elif modisco_h5_path is None: compute_recall = False seqlets_df = None else: - seqlets_df = data_io.load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modisco_half_width, lazy=True) + seqlets_df = data_io.load_modisco_seqlets( + modisco_h5_path, + peaks_df, + motifs_df, + half_width, + modisco_half_width, + lazy=True, + ) motif_width = cwms_modisco.shape[2] - occ_df, coooc = evaluation.get_motif_occurences(hits_df, motif_names) + # Convert to LazyFrame if needed and ensure motif_names is a list + if isinstance(hits_df, pl.LazyFrame): + hits_df_lazy: pl.LazyFrame = hits_df + else: + hits_df_lazy: pl.LazyFrame = hits_df.lazy() + + motif_names_list: List[str] = list(motif_names) + + occ_df, coooc = evaluation.get_motif_occurences(hits_df_lazy, motif_names_list) + + report_data, report_df, cwms, trim_bounds = evaluation.tfmodisco_comparison( + regions, + hits_df, + peaks_df, + seqlets_df, + motifs_df, + cwms_modisco, + motif_names_list, + modisco_half_width, + motif_width, + compute_recall, + ) + + if seqlets_df is not None: + confusion_df, confusion_mat = evaluation.seqlet_confusion( + hits_df, seqlets_df, peaks_df, motif_names_list, motif_width + ) + else: + confusion_df, confusion_mat = None, None - report_data, report_df, cwms, trim_bounds = evaluation.tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, - cwms_modisco, motif_names, modisco_half_width, - motif_width, compute_recall) - os.makedirs(out_dir, exist_ok=True) - + occ_path = os.path.join(out_dir, "motif_occurrences.tsv") data_io.write_occ_df(occ_df, occ_path) data_io.write_report_data(report_df, cwms, out_dir) - evaluation.plot_hit_distributions(occ_df, motif_names, out_dir) - - coooc_path = os.path.join(out_dir, "motif_cooocurrence.png") - 
evaluation.plot_peak_motif_indicator_heatmap(coooc, motif_names, coooc_path) + visualization.plot_hit_stat_distributions(hits_df_lazy, motif_names_list, out_dir) + visualization.plot_hit_peak_distributions(occ_df, motif_names_list, out_dir) + visualization.plot_peak_motif_indicator_heatmap(coooc, motif_names_list, out_dir) plot_dir = os.path.join(out_dir, "CWMs") - evaluation.plot_cwms(cwms, trim_bounds, plot_dir) + visualization.plot_cwms(cwms, trim_bounds, plot_dir) if seqlets_df is not None: - seqlets_df = seqlets_df.collect() + seqlets_collected = ( + seqlets_df.collect() if isinstance(seqlets_df, pl.LazyFrame) else seqlets_df + ) seqlets_path = os.path.join(out_dir, "seqlets.tsv") - data_io.write_modisco_seqlets(seqlets_df, seqlets_path) + data_io.write_modisco_seqlets(seqlets_collected, seqlets_path) - plot_path = os.path.join(out_dir, "hit_vs_seqlet_counts.png") - evaluation.plot_hit_vs_seqlet_counts(report_data, plot_path) - - report_path = os.path.join(out_dir, "report.html") - evaluation.write_report(report_df, motif_names, report_path, compute_recall, seqlets_df is not None) + if confusion_df is not None and confusion_mat is not None: + seqlet_confusion_path = os.path.join(out_dir, "seqlet_confusion.tsv") + data_io.write_seqlet_confusion_df(confusion_df, seqlet_confusion_path) + visualization.plot_hit_vs_seqlet_counts(report_data, out_dir) + visualization.plot_seqlet_confusion_heatmap( + confusion_mat, motif_names_list, out_dir + ) -def cli(): + report_path = os.path.join(out_dir, "report.html") + visualization.write_report( + report_df, motif_names_list, report_path, compute_recall, seqlets_df is not None + ) + + +def collapse_hits(hits_path: str, out_path: str, overlap_frac: float = 0.2) -> None: + """Collapse overlapping hits by selecting the best hit per overlapping group. + + This function processes a set of motif hits and identifies overlapping hits, + keeping only the hit with the highest similarity score within each overlapping + group. 
This reduces redundancy in hit calls while preserving the most confident + predictions. + + Parameters + ---------- + hits_path : str + Path to input TSV file containing hit data (hits.tsv or hits_unique.tsv). + out_path : str + Path to output TSV file with additional 'is_primary' column. + overlap_frac : float, default 0.2 + Minimum fractional overlap for considering hits as overlapping. + For hits of lengths x and y, minimum overlap = overlap_frac * (x + y) / 2. + + Notes + ----- + The algorithm uses a sweep line approach with a heap data structure to + efficiently identify overlapping intervals and select the best hit based + on similarity scores. + """ + from . import postprocessing + + hits_df = data_io.load_hits(hits_path, lazy=False) + hits_collapsed_df = postprocessing.collapse_hits(hits_df, overlap_frac) + + data_io.write_hits_processed( + hits_collapsed_df, out_path, schema=data_io.HITS_COLLAPSED_DTYPES + ) + + +def intersect_hits(hits_paths: List[str], out_path: str, relaxed: bool = False) -> None: + """Find intersection of hits across multiple Fi-NeMo runs. + + This function identifies motif hits that are consistently called across + multiple independent runs, providing a way to assess reproducibility and + identify high-confidence hits. + + Parameters + ---------- + hits_paths : List[str] + List of paths to input TSV files from different runs. + out_path : str + Path to output TSV file containing intersection results. + Duplicate columns are suffixed with run index. + relaxed : bool, default False + If True, uses relaxed intersection criteria based only on motif names + and untrimmed coordinates. If False, assumes consistent region definitions + and motif trimming across runs. + + Notes + ----- + The strict intersection mode requires consistent input regions and motif + processing parameters across all runs. The relaxed mode is more permissive + but may not be suitable when genomic coordinates are unavailable. + """ + from . 
import postprocessing + + hits_dfs = [data_io.load_hits(hits_path, lazy=False) for hits_path in hits_paths] + hits_df = postprocessing.intersect_hits(hits_dfs, relaxed) + + data_io.write_hits_processed(hits_df, out_path, schema=None) + + +def cli() -> None: + """Command-line interface for the Fi-NeMo motif instance calling pipeline. + + This function provides the main entry point for all Fi-NeMo operations including: + - Data preprocessing from various formats (bigWig, HDF5, TF-MoDISco) + - Motif hit calling using the Fi-NeMo algorithm + - Report generation and visualization + - Post-processing operations (hit collapsing, intersection) + """ parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(required=True, dest='cmd') - - - extract_regions_bw_parser = subparsers.add_parser("extract-regions-bw", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from FASTA and bigwig files.") - - extract_regions_bw_parser.add_argument("-p", "--peaks", type=str, required=True, - help="A peak regions file in ENCODE NarrowPeak format.") - extract_regions_bw_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - extract_regions_bw_parser.add_argument("-f", "--fasta", type=str, required=True, - help="A genome FASTA file. If an .fai index file doesn't exist in the same directory, it will be created.") - extract_regions_bw_parser.add_argument("-b", "--bigwigs", type=str, required=True, nargs='+', - help="One or more bigwig files of contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.") - - extract_regions_bw_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_regions_bw_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - extract_chrombpnet_regions_h5_parser = subparsers.add_parser("extract-regions-chrombpnet-h5", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from ChromBPNet contributions H5 files.") - - extract_chrombpnet_regions_h5_parser.add_argument("-p", "--peaks", type=str, default=None, - help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.") - extract_chrombpnet_regions_h5_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - - extract_chrombpnet_regions_h5_parser.add_argument("-c", "--h5s", type=str, required=True, nargs='+', - help="One or more H5 files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.") - - extract_chrombpnet_regions_h5_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_chrombpnet_regions_h5_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - extract_regions_h5_parser = subparsers.add_parser("extract-regions-h5", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from ChromBPNet contributions H5 files. 
DEPRECATED: Use `extract-regions-chrombpnet-h5` instead.") - - extract_regions_h5_parser.add_argument("-p", "--peaks", type=str, default=None, - help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.") - extract_regions_h5_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - - extract_regions_h5_parser.add_argument("-c", "--h5s", type=str, required=True, nargs='+', - help="One or more H5 files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.") - - extract_regions_h5_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_regions_h5_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - extract_bpnet_regions_h5_parser = subparsers.add_parser("extract-regions-bpnet-h5", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from BPNet contributions H5 files.") - - extract_bpnet_regions_h5_parser.add_argument("-p", "--peaks", type=str, default=None, - help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.") - extract_bpnet_regions_h5_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - - extract_bpnet_regions_h5_parser.add_argument("-c", "--h5s", type=str, required=True, nargs='+', - help="One or more H5 files of contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.") - - extract_bpnet_regions_h5_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_bpnet_regions_h5_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - extract_regions_modisco_fmt_parser = subparsers.add_parser("extract-regions-modisco-fmt", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from tfmodisco-lite input files.") - - extract_regions_modisco_fmt_parser.add_argument("-p", "--peaks", type=str, default=None, - help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.") - extract_regions_modisco_fmt_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - - extract_regions_modisco_fmt_parser.add_argument("-s", "--sequences", type=str, required=True, - help="A .npy or .npz file containing one-hot encoded sequences.") - - extract_regions_modisco_fmt_parser.add_argument("-a", "--attributions", type=str, required=True, nargs='+', - help="One or more .npy or .npz files of hypothetical contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.") - - extract_regions_modisco_fmt_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_regions_modisco_fmt_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - call_hits_parser = subparsers.add_parser("call-hits", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Call hits on provided sequences, contributions, and motif CWM's.") - - call_hits_parser.add_argument("-M", "--mode", type=str, default="pp", choices={"pp", "ph", "hp", "hh"}, - help="The type of attributions to use for CWM's and input contribution scores, respectively. 'h' for hypothetical and 'p' for projected.") - - call_hits_parser.add_argument("-r", "--regions", type=str, required=True, - help="A .npz file of input sequences, contributions, and coordinates. Can be generated using `finemo extract-regions-*` subcommands.") - call_hits_parser.add_argument("-m", "--modisco-h5", type=str, required=True, - help="A tfmodisco-lite output H5 file of motif patterns.") - - call_hits_parser.add_argument("-p", "--peaks", type=str, default=None, - help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.") - call_hits_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.") - - call_hits_parser.add_argument("-I", "--motifs-include", type=str, default=None, - help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column to include in hit calling. 
If omitted, all motifs in the modisco H5 file are used.") - call_hits_parser.add_argument("-N", "--motif-names", type=str, default=None, - help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom names in the second column. Omitted motifs default to tfmodisco names.") - - call_hits_parser.add_argument("-o", "--out-dir", type=str, required=True, - help="The path to the output directory.") - - call_hits_parser.add_argument("-t", "--cwm-trim-threshold", type=float, default=0.3, - help="The threshold to determine motif start and end positions within the full CWMs.") - - call_hits_parser.add_argument("-l", "--global-lambda", type=float, default=0.7, - help="The L1 regularization weight determining the sparsity of hits.") - call_hits_parser.add_argument("-L", "--motif-lambdas", type=str, default=None, - help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and motif-specific lambdas in the second column. Omitted motifs default to the `--global-lambda` value.") - call_hits_parser.add_argument("-a", "--alpha", type=float, default=None, - help="DEPRECATED: Please use the `--lambda` argument instead.") - call_hits_parser.add_argument("-A", "--motif-alphas", type=str, default=None, - help="DEPRECATED: Please use the `--motif-lambdas` argument instead.") - - call_hits_parser.add_argument("-f", "--no-post-filter", action='store_true', - help="Do not perform post-hit-calling filtering. 
By default, hits are filtered based on a minimum cosine similarity of `lambda` with the input contributions.") - call_hits_parser.add_argument("-q", "--sqrt-transform", action='store_true', - help="Apply a signed square root transform to the input contributions and CWMs before hit calling.") - call_hits_parser.add_argument("-s", "--step-size-max", type=float, default=3., - help="The maximum optimizer step size.") - call_hits_parser.add_argument("-i", "--step-size-min", type=float, default=0.08, - help="The minimum optimizer step size.") - call_hits_parser.add_argument("-j", "--step-adjust", type=float, default=0.7, - help="The optimizer step size adjustment factor. If the optimizer diverges, the step size is multiplicatively adjusted by this factor") - call_hits_parser.add_argument("-c", "--convergence-tol", type=float, default=0.0005, - help="The tolerance for determining convergence. The optimizer exits when the duality gap is less than the tolerance.") - call_hits_parser.add_argument("-S", "--max-steps", type=int, default=10000, - help="The maximum number of optimization steps.") - call_hits_parser.add_argument("-b", "--batch-size", type=int, default=2000, - help="The batch size used for optimization.") - call_hits_parser.add_argument("-d", "--device", type=str, default=None, - help="DEPRECATED: Please use the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU device.") - call_hits_parser.add_argument("-J", "--compile", action='store_true', - help="JIT-compile the optimizer for faster performance. This may not be supported on older GPUs.") - - - report_parser = subparsers.add_parser("report", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Generate statistics and visualizations from hits and tfmodisco-lite motif data.") - - report_parser.add_argument("-r", "--regions", type=str, required=True, - help="A .npz file containing input sequences, contributions, and coordinates. 
Must be the same as that used for `finemo call-hits`.") - report_parser.add_argument("-H", "--hits", type=str, required=True, - help="The output directory generated by the `finemo call-hits` command on the regions specified in `--regions`.") - report_parser.add_argument("-p", "--peaks", type=str, default=None, - help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.") - report_parser.add_argument("-m", "--modisco-h5", type=str, default=None, - help="The tfmodisco-lite output H5 file of motif patterns. Must be the same as that used for hit calling unless `--no-recall` is set. If omitted, seqlet-derived metrics will not be computed.") - report_parser.add_argument("-I", "--motifs-include", type=str, default=None, - help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.") - report_parser.add_argument("-N", "--motif-names", type=str, default=None, - help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.") - - report_parser.add_argument("-o", "--out-dir", type=str, required=True, - help="The path to the report output directory.") - - report_parser.add_argument("-W", "--modisco-region-width", type=int, default=400, - help="The width of the region around each peak summit used by tfmodisco-lite.") - report_parser.add_argument("-t", "--cwm-trim-threshold", type=float, default=0.3, - help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.") - report_parser.add_argument("-n", "--no-recall", action='store_true', - help="Do not compute motif recall metrics.") - report_parser.add_argument("-s", "--no-seqlets", action='store_true', - help="DEPRECATED: Please omit the `--modisco-h5` argument instead.") - + subparsers = parser.add_subparsers(required=True, dest="cmd") + + extract_regions_bw_parser = subparsers.add_parser( + "extract-regions-bw", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract 
sequences and contributions from FASTA and bigwig files.", + ) + + extract_regions_bw_parser.add_argument( + "-p", + "--peaks", + type=str, + required=True, + help="A peak regions file in ENCODE NarrowPeak format.", + ) + extract_regions_bw_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + extract_regions_bw_parser.add_argument( + "-f", + "--fasta", + type=str, + required=True, + help="A genome FASTA file. If an .fai index file doesn't exist in the same directory, it will be created.", + ) + extract_regions_bw_parser.add_argument( + "-b", + "--bigwigs", + type=str, + required=True, + nargs="+", + help="One or more bigwig files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.", + ) + + extract_regions_bw_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_regions_bw_parser.add_argument( + "-w", + "--region-width", + type=int, + default=inspect.signature(extract_regions_bw) + .parameters["region_width"] + .default, + help="The width of the input region centered around each peak summit.", + ) + + extract_chrombpnet_regions_h5_parser = subparsers.add_parser( + "extract-regions-chrombpnet-h5", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from ChromBPNet contributions H5 files.", + ) + + extract_chrombpnet_regions_h5_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="A peak regions file in ENCODE NarrowPeak format. 
If omitted, downstream outputs will lack absolute genomic coordinates.", + ) + extract_chrombpnet_regions_h5_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + + extract_chrombpnet_regions_h5_parser.add_argument( + "-c", + "--h5s", + type=str, + required=True, + nargs="+", + help="One or more H5 files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.", + ) + + extract_chrombpnet_regions_h5_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_chrombpnet_regions_h5_parser.add_argument( + "-w", + "--region-width", + type=int, + default=inspect.signature(extract_regions_chrombpnet_h5) + .parameters["region_width"] + .default, + help="The width of the input region centered around each peak summit.", + ) + + extract_regions_h5_parser = subparsers.add_parser( + "extract-regions-h5", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from ChromBPNet contributions H5 files. DEPRECATED: Use `extract-regions-chrombpnet-h5` instead.", + ) + + extract_regions_h5_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.", + ) + extract_regions_h5_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. 
Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + + extract_regions_h5_parser.add_argument( + "-c", + "--h5s", + type=str, + required=True, + nargs="+", + help="One or more H5 files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.", + ) + + extract_regions_h5_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_regions_h5_parser.add_argument( + "-w", + "--region-width", + type=int, + default=inspect.signature(extract_regions_chrombpnet_h5) + .parameters["region_width"] + .default, + help="The width of the input region centered around each peak summit.", + ) + + extract_bpnet_regions_h5_parser = subparsers.add_parser( + "extract-regions-bpnet-h5", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from BPNet contributions H5 files.", + ) + + extract_bpnet_regions_h5_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.", + ) + extract_bpnet_regions_h5_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + + extract_bpnet_regions_h5_parser.add_argument( + "-c", + "--h5s", + type=str, + required=True, + nargs="+", + help="One or more H5 files of contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.", + ) + + extract_bpnet_regions_h5_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_bpnet_regions_h5_parser.add_argument( + "-w", + "--region-width", + type=int, + default=inspect.signature(extract_regions_bpnet_h5) + .parameters["region_width"] + .default, + help="The width of the input region centered around each peak summit.", + ) + + extract_regions_modisco_fmt_parser = subparsers.add_parser( + "extract-regions-modisco-fmt", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from tfmodisco-lite input files.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.", + ) + extract_regions_modisco_fmt_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-s", + "--sequences", + type=str, + required=True, + help="A .npy or .npz file containing one-hot encoded sequences.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-a", + "--attributions", + type=str, + required=True, + nargs="+", + help="One or more .npy or .npz files of hypothetical contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-w", + "--region-width", + type=int, + default=inspect.signature(extract_regions_modisco_fmt) + .parameters["region_width"] + .default, + help="The width of the input region centered around each peak summit.", + ) + + call_hits_parser = subparsers.add_parser( + "call-hits", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Call hits on provided sequences, contributions, and motif CWMs.", + ) + + call_hits_parser.add_argument( + "-M", + "--mode", + type=str, + default=inspect.signature(call_hits).parameters["mode"].default, + choices={"pp", "ph", "hp", "hh"}, + help="The type of attributions to use for CWMs and input contribution scores, respectively. 'h' for hypothetical and 'p' for projected.", + ) + + call_hits_parser.add_argument( + "-r", + "--regions", + type=str, + required=True, + help="A .npz file of input sequences, contributions, and coordinates. Can be generated using `finemo extract-regions-*` subcommands.", + ) + call_hits_parser.add_argument( + "-m", + "--modisco-h5", + type=str, + required=True, + help="A tfmodisco-lite output H5 file of motif patterns.", + ) + + call_hits_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.", + ) + call_hits_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.", + ) + + call_hits_parser.add_argument( + "-I", + "--motifs-include", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g. pos_patterns.pattern_0) in the first column to include in hit calling. 
If omitted, all motifs in the modisco H5 file are used.", + ) + call_hits_parser.add_argument( + "-N", + "--motif-names", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g. pos_patterns.pattern_0) in the first column and custom names in the second column. Omitted motifs default to tfmodisco names.", + ) + + call_hits_parser.add_argument( + "-o", + "--out-dir", + type=str, + required=True, + help="The path to the output directory.", + ) + + call_hits_parser.add_argument( + "-t", + "--cwm-trim-threshold", + type=float, + default=inspect.signature(call_hits) + .parameters["cwm_trim_threshold_default"] + .default, + help="The default threshold to determine motif start and end positions within the full CWMs.", + ) + call_hits_parser.add_argument( + "-T", + "--cwm-trim-thresholds", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g. pos_patterns.pattern_0) in the first column and custom trim thresholds in the second column. Omitted motifs default to the `--cwm-trim-threshold` value.", + ) + call_hits_parser.add_argument( + "-R", + "--cwm-trim-coords", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g. pos_patterns.pattern_0) in the first column and custom trim start and end coordinates in the second and third columns, respectively. Omitted motifs default to `--cwm-trim-thresholds` values.", + ) + + call_hits_parser.add_argument( + "-l", + "--global-lambda", + type=float, + default=inspect.signature(call_hits).parameters["lambda_default"].default, + help="The default L1 regularization weight determining the sparsity of hits.", + ) + call_hits_parser.add_argument( + "-L", + "--motif-lambdas", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g. pos_patterns.pattern_0) in the first column and motif-specific lambdas in the second column. 
Omitted motifs default to the `--global-lambda` value.", + ) + call_hits_parser.add_argument( + "-a", + "--alpha", + type=float, + default=None, + help="DEPRECATED: Please use the `--global-lambda` argument instead.", + ) + call_hits_parser.add_argument( + "-A", + "--motif-alphas", + type=str, + default=None, + help="DEPRECATED: Please use the `--motif-lambdas` argument instead.", + ) + + call_hits_parser.add_argument( + "-f", + "--no-post-filter", + action="store_true", + help="Do not perform post-hit-calling filtering. By default, hits are filtered based on a minimum cosine similarity of `lambda` with the input contributions.", + ) + call_hits_parser.add_argument( + "-q", + "--sqrt-transform", + action="store_true", + help="Apply a signed square root transform to the input contributions and CWMs before hit calling.", + ) + call_hits_parser.add_argument( + "-s", + "--step-size-max", + type=float, + default=inspect.signature(call_hits).parameters["step_size_max"].default, + help="The maximum optimizer step size.", + ) + call_hits_parser.add_argument( + "-i", + "--step-size-min", + type=float, + default=inspect.signature(call_hits).parameters["step_size_min"].default, + help="The minimum optimizer step size.", + ) + call_hits_parser.add_argument( + "-j", + "--step-adjust", + type=float, + default=inspect.signature(call_hits).parameters["step_adjust"].default, + help="The optimizer step size adjustment factor. If the optimizer diverges, the step size is multiplicatively adjusted by this factor.", + ) + call_hits_parser.add_argument( + "-c", + "--convergence-tol", + type=float, + default=inspect.signature(call_hits).parameters["convergence_tol"].default, + help="The tolerance for determining convergence. 
The optimizer exits when the duality gap is less than the tolerance.", + ) + call_hits_parser.add_argument( + "-S", + "--max-steps", + type=int, + default=inspect.signature(call_hits).parameters["max_steps"].default, + help="The maximum number of optimization steps.", + ) + call_hits_parser.add_argument( + "-b", + "--batch-size", + type=int, + default=inspect.signature(call_hits).parameters["batch_size"].default, + help="The batch size used for optimization.", + ) + call_hits_parser.add_argument( + "-d", + "--device", + type=str, + default=None, + help="DEPRECATED: Please use the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU device.", + ) + call_hits_parser.add_argument( + "-J", + "--compile", + action="store_true", + help="JIT-compile the optimizer for faster performance. This may not be supported on older GPUs.", + ) + + report_parser = subparsers.add_parser( + "report", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Generate statistics and visualizations from hits and tfmodisco-lite motif data.", + ) + + report_parser.add_argument( + "-r", + "--regions", + type=str, + required=True, + help="A .npz file containing input sequences, contributions, and coordinates. Must be the same as that used for `finemo call-hits`.", + ) + report_parser.add_argument( + "-H", + "--hits", + type=str, + required=True, + help="The output directory generated by the `finemo call-hits` command on the regions specified in `--regions`.", + ) + report_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.", + ) + report_parser.add_argument( + "-m", + "--modisco-h5", + type=str, + default=None, + help="The tfmodisco-lite output H5 file of motif patterns. Must be the same as that used for hit calling unless `--no-recall` is set. 
If omitted, seqlet-derived metrics will not be computed.", + ) + report_parser.add_argument( + "-I", + "--motifs-include", + type=str, + default=None, + help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.", + ) + report_parser.add_argument( + "-N", + "--motif-names", + type=str, + default=None, + help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.", + ) + + report_parser.add_argument( + "-o", + "--out-dir", + type=str, + required=True, + help="The path to the report output directory.", + ) + + report_parser.add_argument( + "-W", + "--modisco-region-width", + type=int, + default=inspect.signature(report).parameters["modisco_region_width"].default, + help="The width of the region around each peak summit used by tfmodisco-lite.", + ) + report_parser.add_argument( + "-t", + "--cwm-trim-threshold", + type=float, + default=inspect.signature(report).parameters["cwm_trim_threshold"].default, + help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.", + ) + report_parser.add_argument( + "-n", + "--no-recall", + action="store_true", + help="Do not compute motif recall metrics.", + ) + report_parser.add_argument( + "-s", + "--no-seqlets", + action="store_true", + help="DEPRECATED: Please omit the `--modisco-h5` argument instead.", + ) + + collapse_hits_parser = subparsers.add_parser( + "collapse-hits", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Identify best hit by motif similarity among sets of overlapping hits.", + ) + + collapse_hits_parser.add_argument( + "-i", + "--hits", + type=str, + required=True, + help="The `hits.tsv` or `hits_unique.tsv` file from `call-hits`.", + ) + collapse_hits_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help='The path to the output .tsv file with an additional "is_primary" column.', + ) + collapse_hits_parser.add_argument( + "-O", + "--overlap-frac", + type=float, + 
default=inspect.signature(collapse_hits).parameters["overlap_frac"].default, + help="The threshold for determining overlapping hits. For two hits with lengths x and y, the minimum overlap is defined as `overlap_frac * (x + y) / 2`. The default value of 0.2 means that two hits must overlap by at least 20%% of their average length to be considered overlapping.", + ) + + intersect_hits_parser = subparsers.add_parser( + "intersect-hits", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Intersect hits across multiple runs.", + ) + + intersect_hits_parser.add_argument( + "-i", + "--hits", + type=str, + required=True, + nargs="+", + help="One or more hits.tsv or hits_unique.tsv files, with paths delimited by whitespace.", + ) + intersect_hits_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .tsv file. Duplicate columns are suffixed with the positional index of the input file.", + ) + intersect_hits_parser.add_argument( + "-r", + "--relaxed", + action="store_true", + help="Use relaxed intersection criteria based only on motif names and untrimmed coordinates. By default, the intersection assumes consistent region definitions and motif trimming. This option is not recommended if genomic coordinates are unavailable.", + ) args = parser.parse_args() - + if args.cmd == "extract-regions-bw": - extract_regions_bw(args.peaks, args.chrom_order, args.fasta, args.bigwigs, args.out_path, args.region_width) + extract_regions_bw( + args.peaks, + args.chrom_order, + args.fasta, + args.bigwigs, + args.out_path, + args.region_width, + ) elif args.cmd == "extract-regions-chrombpnet-h5": - extract_regions_chrombpnet_h5(args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width) + extract_regions_chrombpnet_h5( + args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width + ) elif args.cmd == "extract-regions-h5": - print("WARNING: The `extract-regions-h5` command is deprecated. 
Use `extract-regions-chrombpnet-h5` instead.") - extract_regions_chrombpnet_h5(args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width) + print( + "WARNING: The `extract-regions-h5` command is deprecated. Use `extract-regions-chrombpnet-h5` instead." + ) + extract_regions_chrombpnet_h5( + args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width + ) elif args.cmd == "extract-regions-bpnet-h5": - extract_regions_bpnet_h5(args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width) + extract_regions_bpnet_h5( + args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width + ) elif args.cmd == "extract-regions-modisco-fmt": - extract_regions_modisco_fmt(args.peaks, args.chrom_order, args.attributions, args.sequences, args.out_path, args.region_width) - + extract_regions_modisco_fmt( + args.peaks, + args.chrom_order, + args.attributions, + args.sequences, + args.out_path, + args.region_width, + ) + elif args.cmd == "call-hits": if args.alpha is not None: - warnings.warn("The `--alpha` flag is deprecated and will be removed in a future version. Please use the `--global-lambda` flag instead.") + warnings.warn( + "The `--alpha` flag is deprecated and will be removed in a future version. Please use the `--global-lambda` flag instead." + ) args.global_lambda = args.alpha if args.motif_alphas is not None: - warnings.warn("The `--motif-alphas` flag is deprecated and will be removed in a future version. Please use the `--motif-lambdas` flag instead.") + warnings.warn( + "The `--motif-alphas` flag is deprecated and will be removed in a future version. Please use the `--motif-lambdas` flag instead." 
+ ) args.motif_lambdas = args.motif_alphas - call_hits(args.regions, args.peaks, args.modisco_h5, args.chrom_order, args.motifs_include, args.motif_names, - args.motif_lambdas, args.out_dir, args.cwm_trim_threshold, args.global_lambda, args.step_size_max, - args.step_size_min, args.sqrt_transform, args.convergence_tol, args.max_steps, args.batch_size, - args.step_adjust, args.device, args.mode, args.no_post_filter, args.compile) + call_hits( + args.regions, + args.peaks, + args.modisco_h5, + args.chrom_order, + args.motifs_include, + args.motif_names, + args.motif_lambdas, + args.out_dir, + args.cwm_trim_coords, + args.cwm_trim_thresholds, + args.cwm_trim_threshold, + args.global_lambda, + args.step_size_max, + args.step_size_min, + args.sqrt_transform, + args.convergence_tol, + args.max_steps, + args.batch_size, + args.step_adjust, + args.device, + args.mode, + args.no_post_filter, + args.compile, + ) elif args.cmd == "report": - if args.no_recall and not args.no_seqlets: - raise ValueError("The `--no-seqlets` flag must be set in conjunction with `--no-recall`.") - - report(args.regions, args.hits, args.modisco_h5, args.peaks, args.motifs_include, - args.motif_names, args.out_dir, args.modisco_region_width, args.cwm_trim_threshold, - not args.no_recall, not args.no_seqlets) + report( + args.regions, + args.hits, + args.modisco_h5, + args.peaks, + args.motifs_include, + args.motif_names, + args.out_dir, + args.modisco_region_width, + args.cwm_trim_threshold, + not args.no_recall, + not args.no_seqlets, + ) + + elif args.cmd == "collapse-hits": + collapse_hits(args.hits, args.out_path, args.overlap_frac) + + elif args.cmd == "intersect-hits": + intersect_hits(args.hits, args.out_path, args.relaxed) diff --git a/src/finemo/postprocessing.py b/src/finemo/postprocessing.py new file mode 100644 index 0000000..55c616b --- /dev/null +++ b/src/finemo/postprocessing.py @@ -0,0 +1,298 @@ +"""Post-processing utilities for Fi-NeMo hit calling results. 
+ +This module provides functions for: +- Collapsing overlapping hits based on similarity scores +- Intersecting hit sets across multiple runs +- Quality control and filtering operations + +The main operations are optimized using Numba for efficient processing +of large hit datasets. +""" + +import heapq +from typing import List, Union + +import numpy as np +from numpy import ndarray +import polars as pl +from numba import njit +from numba.types import Array, uint32, int32, float32 # type: ignore[attr-defined] +from jaxtyping import Float, Int + + +@njit( + uint32[:]( + Array(uint32, 1, "C", readonly=True), + Array(int32, 1, "C", readonly=True), + Array(int32, 1, "C", readonly=True), + Array(float32, 1, "C", readonly=True), + ), + cache=True, +) +def _collapse_hits( + chrom_ids: Int[ndarray, " N"], + starts: Int[ndarray, " N"], + ends: Int[ndarray, " N"], + similarities: Float[ndarray, " N"], +) -> Int[ndarray, " N"]: + """Identify primary hits among overlapping hits using a sweep line algorithm. + + This function uses a heap-based sweep line algorithm to efficiently identify + the best hit (highest similarity) among sets of overlapping hits within each + chromosome. Only one hit per overlapping group is marked as primary. + + Parameters + ---------- + chrom_ids : Int[ndarray, "N"] + Chromosome identifiers for each hit, where N is the number of hits. + Dtype should be uint32 for Numba compatibility. + starts : Int[ndarray, "N"] + Start positions of hits (adjusted for overlap computation). + Dtype should be int32 for Numba compatibility. + ends : Int[ndarray, "N"] + End positions of hits (adjusted for overlap computation). + Dtype should be int32 for Numba compatibility. + similarities : Float[ndarray, "N"] + Similarity scores used for selecting the best hit. + Dtype should be float32 for Numba compatibility. + + Returns + ------- + Int[ndarray, "N"] + Binary array where 1 indicates the hit is primary, 0 otherwise. 
+ Returns uint32 array for consistency with input types. + + Notes + ----- + This function is JIT-compiled with Numba for performance on large datasets. + The algorithm maintains active intervals in a heap and resolves overlaps + by keeping only the hit with the highest similarity score. + + The sweep line algorithm processes hits in order and maintains a heap of + currently active intervals. When a new interval is encountered, it is + compared against all overlapping intervals in the heap, and only the + interval with the highest similarity score remains marked as primary. + """ + n = chrom_ids.shape[0] + out = np.ones(n, dtype=np.uint32) + heap = [(np.uint32(0), np.int32(0), -1) for _ in range(0)] + + for i in range(n): + chrom_new = chrom_ids[i] + start_new = starts[i] + end_new = ends[i] + sim_new = similarities[i] + + # Remove expired intervals from heap + while heap and heap[0] < (chrom_new, start_new, -1): + heapq.heappop(heap) + + # Check overlaps with active intervals, demoting the lower-similarity hit + for _, _, idx in heap: + cmp = sim_new > similarities[idx] + out[idx] &= not cmp + out[i] &= cmp + + # Add current interval to heap + heapq.heappush(heap, (chrom_new, end_new, i)) + + return out + + +def collapse_hits( + hits_df: Union[pl.DataFrame, pl.LazyFrame], overlap_frac: float +) -> pl.DataFrame: + """Collapse overlapping hits by selecting the best hit per overlapping group. + + This function identifies overlapping hits and marks only the highest-similarity + hit as primary in each overlapping group. Overlap is determined by a fractional + threshold based on the average length of the two hits being compared. + + Parameters + ---------- + hits_df : Union[pl.DataFrame, pl.LazyFrame] + Hit data containing required columns: chr (or peak_id if no chr), start, end, + hit_similarity. Will be collected to DataFrame if passed as LazyFrame. + overlap_frac : float + Overlap fraction threshold for considering hits as overlapping. 
+ For two hits with lengths x and y, minimum overlap = overlap_frac * (x + y) / 2. + Must be between 0 and 1, where 0 means any overlap and 1 means complete overlap. + + Returns + ------- + pl.DataFrame + Original hit data with an additional 'is_primary' column (1 for primary hits, 0 otherwise). + All original columns are preserved, with the new column added at the end. + + Raises + ------ + KeyError + If required columns (chr/peak_id, start, end, hit_similarity) are missing. + + Notes + ----- + The algorithm transforms coordinates by scaling by 2 and adjusting by the overlap + fraction to create effective overlap regions for efficient processing. This allows + using a sweep line algorithm to identify overlaps in a single pass. + + The transformation works as follows: + - Original coordinates: [start, end] + - Length = end - start + - Adjusted start = start * 2 + length * overlap_frac + - Adjusted end = end * 2 - length * overlap_frac + + This creates regions that overlap only when the original regions have sufficient + overlap according to the specified fraction. 
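The coordinate trick described above can be sanity-checked in isolation. The sketch below (plain Python; `overlaps_after_transform` is a hypothetical helper, not part of Fi-NeMo, and the interval endpoints are illustrative) shows that the doubled-and-shrunk intervals intersect exactly when two similar-length hits share at least the required fraction of bases:

```python
def overlaps_after_transform(a, b, frac):
    """Scaled-coordinate overlap test mirroring collapse_hits.

    Endpoints are doubled and pulled inward by frac * length, so the
    transformed intervals intersect when the original hits share roughly
    frac * (len_a + len_b) / 2 bases (exact for similar-length hits).
    """
    def transform(interval):
        start, end = interval
        length = end - start
        return 2 * start + frac * length, 2 * end - frac * length

    s1, e1 = transform(a)
    s2, e2 = transform(b)
    return max(s1, s2) < min(e1, e2)

# Two 10 bp hits offset by 7 bp share 3 bp of overlap:
print(overlaps_after_transform((0, 10), (7, 17), 0.5))  # 3 < 5 bases required -> False
print(overlaps_after_transform((0, 10), (7, 17), 0.2))  # 3 > 2 bases required -> True
```

This is why a single sweep over the transformed coordinates suffices: the fractional-overlap test reduces to a plain interval-intersection test.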
+ + Examples + -------- + >>> hits_collapsed = collapse_hits(hits_df, overlap_frac=0.2) + >>> primary_hits = hits_collapsed.filter(pl.col("is_primary") == 1) + >>> print(f"Kept {primary_hits.height}/{hits_df.height} hits as primary") + """ + # Ensure we're working with a DataFrame + if isinstance(hits_df, pl.LazyFrame): + hits_df = hits_df.collect() + + if "chr" in hits_df.columns: + chroms = hits_df["chr"].unique(maintain_order=True) + chrom_to_id = {chrom: i for i, chrom in enumerate(chroms)} + # Transform coordinates for overlap computation + # Scale by 2 and adjust by overlap fraction to create effective overlap regions + df = hits_df.select( + chrom_id=pl.col("chr").replace_strict(chrom_to_id, return_dtype=pl.UInt32), + start_trim=pl.col("start") * 2 + + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + end_trim=pl.col("end") * 2 + - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + similarity=pl.col("hit_similarity"), + ) + else: + # Fall back to peak_id when chr column is not available + df = hits_df.select( + chrom_id=pl.col("peak_id"), + start_trim=pl.col("start") * 2 + + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + end_trim=pl.col("end") * 2 + - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + similarity=pl.col("hit_similarity"), + ) + + # Rechunk for efficient array access + df = df.rechunk() + chrom_ids = df["chrom_id"].to_numpy(allow_copy=False) + starts = df["start_trim"].to_numpy(allow_copy=False) + ends = df["end_trim"].to_numpy(allow_copy=False) + similarities = df["similarity"].to_numpy(allow_copy=False) + + # Run the collapse algorithm + is_primary = _collapse_hits(chrom_ids, starts, ends, similarities) + + # Add primary indicator column to original DataFrame + df_out = hits_df.with_columns(is_primary=pl.Series(is_primary, dtype=pl.UInt32)) + + return df_out + + +def intersect_hits( + hits_dfs: List[Union[pl.DataFrame, pl.LazyFrame]], relaxed: bool +) -> 
pl.DataFrame: + """Intersect hit datasets across multiple runs to find common hits. + + This function finds hits that appear consistently across multiple Fi-NeMo + runs, which can be useful for identifying robust motif instances that are + not sensitive to parameter variations or random initialization. + + Parameters + ---------- + hits_dfs : List[Union[pl.DataFrame, pl.LazyFrame]] + List of hit DataFrames from different Fi-NeMo runs. Each DataFrame must + contain the columns specified by the intersection criteria. LazyFrames + will be collected before processing. + relaxed : bool + If True, uses relaxed intersection criteria with only motif names and + untrimmed coordinates. If False, uses strict criteria including all + coordinate and metadata columns. + + Returns + ------- + pl.DataFrame + DataFrame containing hits that appear in all input datasets. + Columns from later datasets are suffixed with their index (e.g., '_1', '_2'). + The first dataset's columns retain their original names. + + Raises + ------ + ValueError + If fewer than one hits DataFrame is provided. + KeyError + If required columns for the specified intersection criteria are missing + from any of the input DataFrames. + + Notes + ----- + Relaxed intersection is useful when comparing results across different + region definitions or motif trimming parameters, but may produce less + precise matches. Strict intersection requires identical region definitions + and is recommended for most use cases. + + The intersection columns used are: + - Relaxed: ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"] + - Strict: ["chr", "start", "end", "start_untrimmed", "end_untrimmed", + "motif_name", "strand", "peak_name", "peak_id"] + + The function performs successive inner joins starting with the first DataFrame, + so the final result contains only hits present in all input datasets. 
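The successive inner joins can be illustrated without polars. In this stdlib sketch, `intersect_on_keys` and its toy rows are hypothetical stand-ins for the polars `join` calls on `join_cols`; it keeps only rows from the first run whose key tuple appears in every run:

```python
def intersect_on_keys(runs, key_cols):
    """Keep rows of runs[0] whose key tuple appears in every run.

    Stdlib stand-in for successive inner joins: `runs` is a list of
    lists of dicts, and `key_cols` plays the role of `join_cols`.
    """
    def key_set(rows):
        return {tuple(row[c] for c in key_cols) for row in rows}

    common = key_set(runs[0])
    for rows in runs[1:]:
        common &= key_set(rows)
    return [row for row in runs[0] if tuple(row[c] for c in key_cols) in common]

run_a = [{"chr": "chr1", "start": 10, "motif_name": "m0"},
         {"chr": "chr1", "start": 50, "motif_name": "m1"}]
run_b = [{"chr": "chr1", "start": 10, "motif_name": "m0"},
         {"chr": "chr2", "start": 50, "motif_name": "m1"}]

shared = intersect_on_keys([run_a, run_b], ["chr", "start", "motif_name"])
# Only the chr1:10 m0 hit appears in both runs.
```

Because set intersection is associative, the result is order-independent, just as chaining inner joins is.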
+ + Examples + -------- + >>> common_hits = intersect_hits([hits_df1, hits_df2], relaxed=False) + >>> print(f"Found {common_hits.height} hits common to both runs") + >>> + >>> # Compare relaxed vs strict intersection + >>> relaxed_hits = intersect_hits([hits_df1, hits_df2], relaxed=True) + >>> strict_hits = intersect_hits([hits_df1, hits_df2], relaxed=False) + >>> print(f"Relaxed: {relaxed_hits.height}, Strict: {strict_hits.height}") + """ + if relaxed: + # Relaxed criteria: only motif identity and untrimmed positions + join_cols = ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"] + else: + # Strict criteria: all coordinate and metadata columns + join_cols = [ + "chr", + "start", + "end", + "start_untrimmed", + "end_untrimmed", + "motif_name", + "strand", + "peak_name", + "peak_id", + ] + + if len(hits_dfs) < 1: + raise ValueError("At least one hits dataframe required") + + # Ensure all DataFrames are collected + collected_dfs = [] + for df in hits_dfs: + if isinstance(df, pl.LazyFrame): + collected_dfs.append(df.collect()) + else: + collected_dfs.append(df) + + # Start with first DataFrame and successively intersect with others + hits_df = collected_dfs[0] + for i in range(1, len(collected_dfs)): + hits_df = hits_df.join( + collected_dfs[i], + on=join_cols, + how="inner", + suffix=f"_{i}", + join_nulls=True, + coalesce=True, + ) + + return hits_df diff --git a/src/finemo/templates/report.html b/src/finemo/templates/report.html index dff2f92..c3a6ffc 100644 --- a/src/finemo/templates/report.html +++ b/src/finemo/templates/report.html @@ -170,6 +170,7 @@ } .distplot img { + max-width: unset; margin-bottom: 0 } @@ -177,15 +178,24 @@ -

Fi-NeMo hit calling report

+

Fi-NeMo Motif Hit Calling Report

+ +

+ This report provides a comprehensive analysis of motif instance calling results from Fi-NeMo, + a GPU-accelerated method for identifying transcription factor binding sites using neural network + contribution scores. Fi-NeMo uses a competitive optimization approach to comprehensively map motif + instances by solving a sparse linear reconstruction problem. The report compares Fi-NeMo hits + with TF-MoDISco seqlets (when available) and provides detailed statistics on hit quality and + motif discovery performance. +

{% if not use_seqlets %}
- Seqlet comparisons are not shown because a TF-MoDISco H5 file with seqlet data is not provided. + Note: Seqlet comparisons are not shown because a TF-MoDISco H5 file with seqlet data was not provided.
{% elif not compute_recall %}
- Seqlet recall and other statistics directly comparing hits and seqlets are not computed because the -n/--no-recall argument is set. + Note: Seqlet recall and other statistics directly comparing hits and seqlets are not computed because the -n/--no-recall argument was specified.
{% endif %} @@ -197,66 +207,69 @@

TF-MoDISco seqlet comparisons

Hit vs. seqlet counts

- This figure shows the number of hits called vs. the number of TF-MoDISco seqlets identified for each motif. - The dashed line is the identity line. - When comparing a shared set of regions, the hit counts should be mostly greater than the corresponding seqlet counts, since TF-MoDISco stringently filters seqlets and usually uses a smaller input window. + This scatter plot compares the number of motif instances called by Fi-NeMo versus the number of TF-MoDISco seqlets + identified for each motif. The dashed line represents perfect agreement (y = x). Fi-NeMo typically identifies + an order of magnitude more motif instances than TF-MoDISco because: (1) TF-MoDISco applies stringent filtering criteria + during seqlet identification, and (2) TF-MoDISco often analyzes smaller genomic windows than those used for hit calling.

- + {% endif %} -

Hit and seqlet motif comparisons

+

Motif-specific hit and seqlet analysis

- For each motif, this table examines the consistency between hits and TF-MoDISco seqlets. + This table provides detailed statistics for each motif, comparing the consistency between Fi-NeMo hits + and TF-MoDISco seqlets. The analysis includes hit counts, overlap statistics, and visual comparisons + of contribution weight matrices (CWMs).

- The following statistics report the number of hits, seqlets, and their relationships: + Statistical measures include:

    -
  • Hits: The number of hits called by Fi-NeMo
  • +
  • Hits: Total number of motif instances called by Fi-NeMo across all genomic regions
  • {% if compute_recall %} -
  • Restricted Hits: The number of Fi-NeMo hits within the TF-MoDISco input regions
  • +
  • Restricted Hits: Fi-NeMo hits overlapping with TF-MoDISco input regions (enables direct comparison)
  • {% endif %} {% if use_seqlets %} -
  • Seqlets: The number of unique TF-MoDISco seqlets
  • +
  • Seqlets: Unique TF-MoDISco seqlets used to construct this motif pattern
  • {% endif %} {% if compute_recall %} -
  • Hit/Seqlet Overlaps: The number of hits that coincide with TF-MoDISco seqlets
  • -
  • Missed Seqlets: The number of TF-MoDISco seqlets not called as hits
  • -
  • Additional Restricted Hits: The number of hits within the TF-MoDISco input regions that are not identified as seqlets by TF-MoDISco
  • -
  • Seqlet Recall: The fraction of seqlets that are called as hits
  • +
  • Hit/Seqlet Overlaps: Fi-NeMo hits that spatially coincide with TF-MoDISco seqlets (successful recovery)
  • +
  • Missed Seqlets: TF-MoDISco seqlets not identified as hits by Fi-NeMo (potential false negatives)
  • +
  • Additional Restricted Hits: Fi-NeMo hits not identified as seqlets by TF-MoDISco (potential new discoveries)
  • +
  • Seqlet Recall: Fraction of TF-MoDISco seqlets successfully recovered as Fi-NeMo hits
  • {% endif %} -
  • Hit-Seqlet CWM Similarity: The cosine similarity between the hit CWM and the TF-MoDISco CWM
  • +
  • Hit-Seqlet CWM Similarity: Cosine similarity between average contribution scores of hits vs. seqlets

- Note that the seqlet counts here may be lower than those shown in the tfmodisco-lite report due to double-counting in overlapping regions. - The seqlet counts shown here are unique while the counts in the tfmodisco-lite report are not de-duplicated. -

-{% if compute_recall %} -

- Note that palindromic motifs may have lower recall due to disagreements on orientation. - If seqlet recall is near zero for all motifs, the -W/--modisco-region-width argument is likely incorrect. - This value is required to infer genomic coordinates of seqlets from the tfmodisco-lite output H5. -

-{% endif %} -

- Motif CWMs (contribution weight matrices) are average contribution scores over a set of regions. The CWMs plotted here are: + Important notes:

    -
  • Hit CWM (FC): The forward-strand CWM of all hits
  • -
  • Hit CWM (RC): The reverse-strand CWM of all hits
  • -
  • TF-MoDISco CWM (FC/RC): The CWM of all TF-MoDISco seqlets
  • +
  • Seqlet counts may appear lower than in TF-MoDISco-lite reports due to removal of duplicate seqlets
  • {% if compute_recall %} -
  • Missed-Seqlet-Only CWM: The CWM of all TF-MoDISco seqlets that were not called as hits
  • -
  • Additional-Restricted-Hit CWM: The CWM of all hits within the TF-MoDISco input regions that were not identified as seqlets by TF-MoDISco
  • +
  • Palindromic motifs may show reduced recall due to strand orientation ambiguity
  • +
  • If seqlet recall is near zero across all motifs, verify that the -W/--modisco-region-width parameter matches the original TF-MoDISco analysis window
  • {% endif %}

- The plots span the full untrimmed motif, with the trimmed motif shaded. + Contribution Weight Matrix (CWM) visualizations:
+ CWMs represent average contribution scores across motif instances and show the functional importance + of each nucleotide position. The following CWMs are displayed for comparison:

+
    +
  • Hit CWM (FC/RC): Average contribution patterns from Fi-NeMo hits on forward/reverse strands
  • +
  • TF-MoDISco CWM (FC/RC): Average contribution patterns from TF-MoDISco seqlets on forward/reverse strands
  • + {% if compute_recall %} +
  • Missed-Seqlet-Only CWM: Contribution patterns from seqlets not recovered by Fi-NeMo (identifies potential algorithmic disagreements)
  • +
  • Additional-Restricted-Hit CWM: Contribution patterns from Fi-NeMo hits not identified by TF-MoDISco
  • + {% endif %} +

- The hit-seqlet similarity is the cosine similarity between the additional-restricted-hits CWM and the seqlet CWM. - This statistic measures the similarity between hits that were missed by TF-MoDISco and the seqlets used to construct the motif. + All CWM plots span the full untrimmed motif width, with the core trimmed region highlighted by shading. + {% if compute_recall %} + The hit-seqlet CWM similarity quantifies the overall agreement between Fi-NeMo's discovered instances + and TF-MoDISco's original motif definitions. + {% endif %}

@@ -309,59 +322,94 @@

Hit and seqlet motif comparisons

{% endif %} - - - - + + + + {% if compute_recall %} - - + + {% endif %} {% endfor %}
{{ item.num_seqlets_only }} {{ item.num_hits_restricted_only }}
-

Hit distributions

+{% if compute_recall %} + +

Motif cross-assignment analysis

+

+ This confusion matrix identifies cases where Fi-NeMo hits of one motif type spatially overlap with + TF-MoDISco seqlets of different motif types. Such cross-assignments can reveal related motif families, + algorithm differences, or cases where similar-looking motifs compete for the same binding sites. +

+

+ The y-axis represents seqlet motif identity, the x-axis represents hit motif identity, and color intensity + indicates the estimated overlap frequency per base of seqlet sequence. High off-diagonal values suggest + potential motif ambiguity and/or algorithmic disagreements at groups of putative TF binding sites. +

+ + +{% endif %} + +

Hit Quality and Distribution Analysis

- The following figures visualize the distribution of hits across motifs and peaks. + These visualizations examine the quality and distribution of Fi-NeMo hits across genomic regions and motifs, + measuring algorithm performance and signal strength.

-

Overall distribution of hits per peak

+

Genome-wide hit density

- This plot shows the distribution of hit counts per peak for any motif. - The number of peaks with no hits should be near zero. + This histogram shows the distribution of total hit counts per genomic region (across all motifs). + A good distribution should show nearly all regions containing at least one hit.

- + -

Per-motif distributions of hits per peak

+

Motif-specific hit quality metrics

- These plots show the distribution of hit counts per peak for each motif. + These distribution plots characterize the quality and prevalence of hits for individual motifs:

+
    +
  • Hits Per Region: Frequency of motif occurrence across genomic regions (higher values suggest more prevalent motifs)
  • +
  • Hit Coefficient: Strength of motif instance assignment by the optimization algorithm (higher values indicate stronger matches)
  • +
  • Hit Similarity: Cosine similarity between individual hits and the motif CWM (higher values indicate closer pattern matching)
  • +
  • Hit Importance: Total contribution score magnitude within hit regions (reflects functional significance from the neural network model)
  • +
- + + + + {% for m in motif_names %} - + + + + {% endfor %}
Motif NameHits Per PeakHits Per RegionHit CoefficientHit SimilarityHit Importance
{{ m }}
-

Motif co-occurrence

+

Motif co-occurrence analysis

+

+ This correlation heatmap reveals which motifs tend to occur together in the same genomic regions, + potentially indicating cooperative transcription factor binding or shared regulatory mechanisms. + Color intensity represents cosine similarity between motif occurrence patterns, where occurrence + is defined as the presence of at least one hit for each motif within individual regions. +

- This heatmap shows the co-occurrence of motifs across peaks. - The color intensity here represents the cosine similarity between the motifs' occurrence across peaks, - where occurence is defined as the presence of a hit for a motif in a peak. + High positive correlations (dark colors) suggest motifs that frequently co-occur. + Low correlations suggest independent or mutually exclusive binding patterns.

- + diff --git a/src/finemo/visualization.py b/src/finemo/visualization.py new file mode 100644 index 0000000..d331cb9 --- /dev/null +++ b/src/finemo/visualization.py @@ -0,0 +1,700 @@ +"""Visualization module for generating plots and reports for Fi-NeMo results. + +This module provides functions for: +- Plotting motif contribution weight matrices (CWMs) as sequence logos +- Generating distribution plots for hit statistics +- Creating co-occurrence heatmaps +- Producing HTML reports with interactive visualizations +- Plotting confusion matrices and performance metrics +""" + +import os +import importlib.resources +from typing import List, Optional, Dict, Any, Tuple, Union, Mapping, Iterable + +import numpy as np +from numpy import ndarray +import matplotlib.pyplot as plt +from matplotlib.axes import Axes +from matplotlib.patheffects import AbstractPathEffect +from matplotlib.textpath import TextPath +from matplotlib.transforms import Affine2D +from matplotlib.font_manager import FontProperties +from jinja2 import Template +import polars as pl +from jaxtyping import Float, Int + +from . import templates + + +def abbreviate_motif_name(name: str) -> str: + """Convert TF-MoDISco motif names to abbreviated format. + + Converts full TF-MoDISco pattern names to shorter, more readable format + for display in plots and reports. + + Parameters + ---------- + name : str + Full motif name (e.g., 'pos_patterns.pattern_0'). + + Returns + ------- + str + Abbreviated name (e.g., '+/0') or original name if parsing fails. 
+ + Examples + -------- + >>> abbreviate_motif_name('pos_patterns.pattern_0') + '+/0' + >>> abbreviate_motif_name('neg_patterns.pattern_1') + '-/1' + >>> abbreviate_motif_name('invalid_name') + 'invalid_name' + """ + try: + group, motif = name.split(".") + if group == "pos_patterns": + group_short = "+" + elif group == "neg_patterns": + group_short = "-" + else: + raise Exception + motif_num = motif.split("_")[1] + return f"{group_short}/{motif_num}" + except Exception: + return name + + +def plot_hit_stat_distributions( + hits_df: pl.LazyFrame, motif_names: List[str], plot_dir: str +) -> None: + """Plot distributions of hit statistics for each motif. + + Creates separate histogram plots for coefficient, similarity, and importance + score distributions for each motif. Saves plots in both PNG (high-res) and + SVG (vector) formats. + + Parameters + ---------- + hits_df : pl.LazyFrame + Lazy DataFrame containing hit data with required columns: + - motif_name : str, name of the motif + - hit_coefficient_global : float, global coefficient values + - hit_similarity : float, similarity scores to motif CWM + - hit_importance : float, importance scores from attribution + motif_names : List[str] + List of motif names to generate plots for. Motifs not present + in hits_df will result in empty histograms. + plot_dir : str + Directory path where plots will be saved. Creates subdirectory + 'motif_stat_distributions' if it doesn't exist. 
+ + Notes + ----- + For each motif, creates three separate plots: + - {motif_name}_coefficients.{png,svg} : coefficient distribution + - {motif_name}_similarities.{png,svg} : similarity distribution + - {motif_name}_importances.{png,svg} : importance distribution + """ + hits_df_collected = hits_df.collect() + hits_by_motif = hits_df_collected.partition_by("motif_name", as_dict=True) + dummy_df = hits_df_collected.clear() + + motifs_dir = os.path.join(plot_dir, "motif_stat_distributions") + os.makedirs(motifs_dir, exist_ok=True) + for m in motif_names: + hits = hits_by_motif.get((m,), dummy_df) + coefficients = hits.get_column("hit_coefficient_global").to_numpy() + similarities = hits.get_column("hit_similarity").to_numpy() + importances = hits.get_column("hit_importance").to_numpy() + + fig, ax = plt.subplots(figsize=(5, 2)) + + # Plot coefficient distribution + try: + ax.hist(coefficients, bins=50, density=True) + except ValueError: + ax.hist(coefficients, bins=1, density=True) + + output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}_coefficients.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + fig, ax = plt.subplots(figsize=(5, 2)) + + # Plot similarity distribution + try: + ax.hist(similarities, bins=50, density=True) + except ValueError: + ax.hist(similarities, bins=1, density=True) + + output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}_similarities.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + fig, ax = plt.subplots(figsize=(5, 2)) + + # Plot importance distribution + try: + ax.hist(importances, bins=50, density=True) + except ValueError: + ax.hist(importances, bins=1, density=True) + + output_path_png = os.path.join(motifs_dir, f"{m}_importances.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = 
os.path.join(motifs_dir, f"{m}_importances.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + +def plot_hit_peak_distributions( + occ_df: pl.DataFrame, motif_names: List[str], plot_dir: str +) -> None: + """Plot distribution of hits per peak for each motif. + + Creates bar plots showing the frequency distribution of hit counts per peak + for each motif, plus an overall distribution of total hits per peak. + + Parameters + ---------- + occ_df : pl.DataFrame + DataFrame containing motif occurrence counts per peak. Expected to have: + - One column per motif name with integer hit counts + - 'total' column with sum of all motif hits per peak + - Each row represents a peak/genomic region + motif_names : List[str] + List of motif names corresponding to columns in occ_df. + plot_dir : str + Directory to save plots. Creates 'motif_hit_distributions' subdirectory. + + Notes + ----- + Generates the following plots: + - Individual motif hit distributions: {motif_name}.{png,svg} + - Overall hit distribution: total_hit_distribution.{png,svg} + + Bar plots show frequency (proportion) on y-axis and hit count on x-axis. 
+ """ + motifs_dir = os.path.join(plot_dir, "motif_hit_distributions") + os.makedirs(motifs_dir, exist_ok=True) + + for m in motif_names: + fig, ax = plt.subplots(figsize=(5, 2)) + + unique, counts = np.unique(occ_df.get_column(m), return_counts=True) + freq = counts / counts.sum() + num_bins = np.amax(unique, initial=0) + 1 + x = np.arange(num_bins) + y = np.zeros(num_bins) + y[unique] = freq + ax.bar(x, y) + + output_path_png = os.path.join(motifs_dir, f"{m}.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}.svg") + plt.savefig(output_path_svg) + + plt.close(fig) + + fig, ax = plt.subplots(figsize=(8, 4)) + + unique, counts = np.unique(occ_df.get_column("total"), return_counts=True) + freq = counts / counts.sum() + num_bins = np.amax(unique, initial=0) + 1 + x = np.arange(num_bins) + y = np.zeros(num_bins) + y[unique] = freq + ax.bar(x, y) + + ax.set_xlabel("Total hits per region") + ax.set_ylabel("Frequency") + + output_path_png = os.path.join(plot_dir, "total_hit_distribution.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(plot_dir, "total_hit_distribution.svg") + plt.savefig(output_path_svg, dpi=300) + + plt.close(fig) + + +def plot_peak_motif_indicator_heatmap( + peak_hit_counts: Int[ndarray, "M M"], motif_names: List[str], output_dir: str +) -> None: + """Plot co-occurrence heatmap showing motif associations across peaks. + + Creates a normalized correlation heatmap showing how frequently pairs of + motifs co-occur within the same genomic peaks. Values are normalized by + the geometric mean of individual motif frequencies. + + Parameters + ---------- + peak_hit_counts : Int[ndarray, "M M"] + Co-occurrence matrix where M = len(motif_names). + Entry (i,j) represents the number of peaks containing both motif i and j. + Diagonal entries represent total peaks containing each individual motif. + motif_names : List[str] + List of motif names for axis labels. 
Order must match matrix dimensions. + output_dir : str + Directory path where the heatmap plots will be saved. + + Notes + ----- + Saves plots as: + - motif_cooocurrence.png : High-resolution raster format + - motif_cooocurrence.svg : Vector format + + The heatmap uses correlation normalization: matrix[i,j] / sqrt(matrix[i,i] * matrix[j,j]) + Colors use the 'Greens' colormap with values typically in [0, 1] range. + """ + cov_norm = 1 / np.sqrt(np.diag(peak_hit_counts)) + matrix = peak_hit_counts * cov_norm[:, None] * cov_norm[None, :] + motif_keys = [abbreviate_motif_name(m) for m in motif_names] + + fig, ax = plt.subplots(figsize=(8, 8), layout="constrained") + + # Plot the heatmap + cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens") + + # Set axes on heatmap + ax.set_yticks(np.arange(len(motif_keys))) + ax.set_yticklabels(motif_keys) + ax.set_xticks(np.arange(len(motif_keys))) + ax.set_xticklabels(motif_keys, rotation=90) + ax.set_xlabel("Motif i") + ax.set_ylabel("Motif j") + + ax.tick_params(axis="both", labelsize=8) + + cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) + cbar.ax.tick_params(labelsize=8) + + output_path_png = os.path.join(output_dir, "motif_cooocurrence.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(output_dir, "motif_cooocurrence.svg") + plt.savefig(output_path_svg, dpi=300) + + plt.close() + + +def plot_seqlet_confusion_heatmap( + seqlet_confusion: Int[ndarray, "M M"], motif_names: List[str], output_dir: str +) -> None: + """Plot confusion matrix heatmap comparing seqlets to hit calls. + + Creates a heatmap showing the overlap between TF-MoDISco seqlets and + Fi-NeMo hit calls. Rows represent seqlet motifs, columns represent hit motifs. + + Parameters + ---------- + seqlet_confusion : Int[ndarray, "M M"] + Confusion matrix where M = len(motif_names). + Entry (i,j) represents the number of seqlets of motif i that overlap + with hits called for motif j. 
+ motif_names : List[str] + List of motif names for axis labels. Order must match matrix dimensions. + output_dir : str + Directory path where the confusion matrix plots will be saved. + + Notes + ----- + Saves plots as: + - seqlet_confusion.png : High-resolution raster format + - seqlet_confusion.svg : Vector format + + The heatmap uses 'Blues' colormap. Perfect agreement would show a diagonal + pattern with high values along the diagonal and low off-diagonal values. + """ + motif_keys = [abbreviate_motif_name(m) for m in motif_names] + + fig, ax = plt.subplots(figsize=(8, 8), layout="constrained") + + # Plot the heatmap + cax = ax.imshow( + seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues" + ) + + # Set axes on heatmap + ax.set_yticks(np.arange(len(motif_keys))) + ax.set_yticklabels(motif_keys) + ax.set_xticks(np.arange(len(motif_keys))) + ax.set_xticklabels(motif_keys, rotation=90) + ax.set_xlabel("Hit motif") + ax.set_ylabel("Seqlet motif") + + ax.tick_params(axis="both", labelsize=8) + + cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) + cbar.ax.tick_params(labelsize=8) + + output_path_png = os.path.join(output_dir, "seqlet_confusion.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(output_dir, "seqlet_confusion.svg") + plt.savefig(output_path_svg, dpi=300) + + plt.close() + + +class LogoGlyph(AbstractPathEffect): + """Path effect for creating sequence logo glyphs with normalized dimensions. + + This class creates properly scaled and positioned text glyphs for sequence + logos by normalizing character dimensions and applying appropriate transforms. + + Parameters + ---------- + glyph : str + Single character to render (e.g., 'A', 'C', 'G', 'T'). + ref_glyph : str, default 'E' + Reference character used for width normalization. + font_props : FontProperties, optional + Font properties for the glyph rendering. + offset : Tuple[float, float], default (0., 0.) 
+ Offset for glyph positioning. + **kwargs + Additional graphics collection parameters. + """ + + def __init__( + self, + glyph: str, + ref_glyph: str = "E", + font_props: Optional[FontProperties] = None, + offset: Tuple[float, float] = (0.0, 0.0), + **kwargs, + ) -> None: + super().__init__(offset) + + path_orig = TextPath((0, 0), glyph, size=1, prop=font_props) + dims = path_orig.get_extents() + ref_dims = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents() + + h_scale = 1 / dims.height + ref_width = max(dims.width, ref_dims.width) + w_scale = 1 / ref_width + w_shift = (1 - dims.width / ref_width) / 2 + x_shift = -dims.x0 + y_shift = -dims.y0 + stretch = ( + Affine2D() + .translate(tx=x_shift, ty=y_shift) + .scale(sx=w_scale, sy=h_scale) + .translate(tx=w_shift, ty=0) + ) + + self.path = stretch.transform_path(path_orig) + + #: The dictionary of keywords to update the graphics collection with. + self._gc = kwargs + + def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any: # type: ignore[override] + """Draw the glyph path using the renderer. + + Parameters + ---------- + renderer : matplotlib renderer + The renderer to draw with. + gc : GraphicsContext + Graphics context for drawing properties. + tpath : Path + Original text path (unused, using self.path instead). + affine : Transform + Affine transformation to apply. + rgbFace : color + Face color for the glyph. + + Returns + ------- + Any + Result from renderer.draw_path. + """ + return renderer.draw_path(gc, self.path, affine, rgbFace) + + +def plot_logo( + ax: Axes, + heights: Float[ndarray, "B W"], + glyphs: Iterable[str], + colors: Optional[Mapping[str, Optional[str]]] = None, + font_props: Optional[FontProperties] = None, + shade_bounds: Optional[Tuple[int, int]] = None, +) -> None: + """Plot sequence logo from contribution weight matrix. + + Creates a sequence logo visualization where letter heights represent + the contribution or information content at each position. 
Supports + both positive and negative contributions with proper stacking. + + Parameters + ---------- + ax : Axes + Matplotlib axes object to plot on. + heights : Float[ndarray, "B W"] + Height matrix where B = len(glyphs) and W = motif width. + Entry (i,j) represents the height/contribution of base i at position j. + Can contain both positive and negative values. + glyphs : Iterable[str] + Sequence of base characters corresponding to rows in heights matrix. + Typically ['A', 'C', 'G', 'T'] for DNA. + colors : Dict[str, str], optional + Color mapping for each base. Keys should match glyphs. + If None, all bases will use default matplotlib colors. + font_props : FontProperties, optional + Font properties for letter rendering. If None, uses default font. + shade_bounds : Tuple[int, int], optional + (start, end) position indices to shade in background. + Useful for highlighting core motif regions. + + Notes + ----- + Positive and negative contributions are handled separately: + - Positive values are stacked above zero line in order of descending absolute value + - Negative values are stacked below zero line in order of descending absolute value + - A horizontal line is drawn at y=0 for reference + + The resulting plot has: + - X-axis: Position in motif (0-indexed) + - Y-axis: Contribution magnitude + - Bar width: 0.95 (small gaps between positions) + """ + if colors is None: + colors = {g: None for g in glyphs} + + ax.margins(x=0, y=0) + + pos_values = np.clip(heights, 0, None) + neg_values = np.clip(heights, None, 0) + pos_order = np.argsort(pos_values, axis=0) + neg_order = np.argsort(neg_values, axis=0)[::-1, :] + pos_reorder = np.argsort(pos_order, axis=0) + neg_reorder = np.argsort(neg_order, axis=0) + pos_offsets = np.take_along_axis( + np.cumsum(np.take_along_axis(pos_values, pos_order, axis=0), axis=0), + pos_reorder, + axis=0, + ) + neg_offsets = np.take_along_axis( + np.cumsum(np.take_along_axis(neg_values, neg_order, axis=0), axis=0), + neg_reorder, + 
axis=0, + ) + bottoms = pos_offsets + neg_offsets - heights + + x = np.arange(heights.shape[1]) + + for glyph, height, bottom in zip(glyphs, heights, bottoms): + ax.bar( + x, + height, + 0.95, + bottom=bottom, + path_effects=[LogoGlyph(glyph, font_props=font_props)], + color=colors[glyph], + ) + + if shade_bounds is not None: + start, end = shade_bounds + ax.axvspan(start - 0.5, end - 0.5, color="0.9", zorder=-1) + + ax.axhline(zorder=-1, linewidth=0.5, color="black") + + +LOGO_ALPHABET = "ACGT" +LOGO_COLORS = {"A": "#109648", "C": "#255C99", "G": "#F7B32B", "T": "#D62839"} +LOGO_FONT = FontProperties(weight="bold") + + +def plot_cwms( + cwms: Dict[str, Dict[str, Float[ndarray, "4 W"]]], + trim_bounds: Dict[str, Dict[str, Tuple[int, int]]], + out_dir: str, + alphabet: str = LOGO_ALPHABET, + colors: Dict[str, str] = LOGO_COLORS, + font: FontProperties = LOGO_FONT, +) -> None: + """Plot contribution weight matrices as sequence logos. + + Creates sequence logo plots for all motifs and CWM types, with optional + shading to highlight trimmed regions. Saves plots in both PNG and SVG formats. + + Parameters + ---------- + cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]] + Nested dictionary structure: {motif_name: {cwm_type: cwm_array}}. + Each cwm_array has shape (4, W) where W is motif width. + Rows correspond to bases in alphabet order. + trim_bounds : Dict[str, Dict[str, Tuple[int, int]]] + Nested dictionary: {motif_name: {cwm_type: (start, end)}}. + Defines regions to shade in the sequence logos. + out_dir : str + Output directory where motif subdirectories will be created. + alphabet : str, default LOGO_ALPHABET + DNA alphabet string, typically 'ACGT'. + colors : Dict[str, str], default LOGO_COLORS + Color mapping for DNA bases. Keys should match alphabet characters. + font : FontProperties, default LOGO_FONT + Font properties for sequence logo rendering. 
+
+    Notes
+    -----
+    Directory structure created:
+    ```
+    out_dir/
+    ├── motif1/
+    │   ├── cwm_type1.png
+    │   ├── cwm_type1.svg
+    │   └── ...
+    └── motif2/
+        └── ...
+    ```
+
+    Each plot is 10x2 inches with trimmed regions shaded if specified.
+    Spines (plot borders) are hidden for cleaner appearance.
+    """
+    for m, v in cwms.items():
+        motif_dir = os.path.join(out_dir, m)
+        os.makedirs(motif_dir, exist_ok=True)
+        for cwm_type, cwm in v.items():
+            fig, ax = plt.subplots(figsize=(10, 2))
+
+            plot_logo(
+                ax,
+                cwm,
+                alphabet,
+                colors=colors,
+                font_props=font,
+                shade_bounds=trim_bounds[m][cwm_type],
+            )
+
+            for spine in ax.spines.values():
+                spine.set_visible(False)
+
+            output_path_png = os.path.join(motif_dir, f"{cwm_type}.png")
+            fig.savefig(output_path_png, dpi=100)
+            output_path_svg = os.path.join(motif_dir, f"{cwm_type}.svg")
+            fig.savefig(output_path_svg)
+
+            plt.close(fig)
+
+
+def plot_hit_vs_seqlet_counts(
+    recall_data: Dict[str, Dict[str, Union[int, float]]], output_dir: str
+) -> None:
+    """Plot a scatter comparing hit counts to seqlet counts per motif.
+
+    Creates a log-log scatter plot showing the relationship between the number
+    of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISco
+    for each motif. Includes a diagonal reference line and motif annotations.
+
+    Parameters
+    ----------
+    recall_data : Dict[str, Dict[str, Union[int, float]]]
+        Dictionary with motif names as keys and metrics dictionaries as values.
+        Each metrics dictionary must contain:
+        - 'num_hits_total' : int, total number of hits for the motif
+        - 'num_seqlets' : int, total number of seqlets for the motif
+    output_dir : str
+        Directory path where the scatter plot will be saved.
+
+    Notes
+    -----
+    Saves plots as:
+    - hit_vs_seqlet_counts.png : High-resolution raster format
+    - hit_vs_seqlet_counts.svg : Vector format
+
+    Plot features:
+    - Log-log scale on both axes
+    - Diagonal reference line (y = x) as dashed line
+    - Points annotated with abbreviated motif names
+    """
+    x = []
+    y = []
+    m = []
+    for k, v in recall_data.items():
+        x.append(v["num_hits_total"])
+        y.append(v["num_seqlets"])
+        m.append(k)
+
+    lim = max(np.amax(x), np.amax(y))
+
+    fig, ax = plt.subplots(figsize=(8, 8), layout="constrained")
+    # Anchor the y = x reference line at (1, 1) rather than (0, 0), since
+    # (0, 0) is not finite under the log scales applied below.
+    ax.axline((1, 1), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5)))
+    ax.scatter(x, y, s=5)
+    for i, txt in enumerate(m):
+        short = abbreviate_motif_name(txt)
+        ax.annotate(short, (x[i], y[i]), fontsize=8, weight="bold")
+
+    ax.set_yscale("log")
+    ax.set_xscale("log")
+
+    ax.set_xlabel("Hits per motif")
+    ax.set_ylabel("Seqlets per motif")
+
+    output_path_png = os.path.join(output_dir, "hit_vs_seqlet_counts.png")
+    plt.savefig(output_path_png, dpi=300)
+    output_path_svg = os.path.join(output_dir, "hit_vs_seqlet_counts.svg")
+    plt.savefig(output_path_svg)
+
+    plt.close()
+
+
+def write_report(
+    report_df: pl.DataFrame,
+    motif_names: List[str],
+    out_path: str,
+    compute_recall: bool,
+    use_seqlets: bool,
+) -> None:
+    """Generate and write HTML report from motif analysis results.
+
+    Creates a comprehensive HTML report with tables and visualizations
+    summarizing the Fi-NeMo motif discovery and hit calling results.
+
+    Parameters
+    ----------
+    report_df : pl.DataFrame
+        DataFrame containing motif statistics and performance metrics.
+        Expected columns depend on compute_recall and use_seqlets flags.
+    motif_names : List[str]
+        List of motif names to include in the report.
+        Order determines presentation sequence in the report.
+    out_path : str
+        File path where the HTML report will be written.
+        Parent directory must exist.
+ compute_recall : bool + Whether recall metrics were computed and should be included + in the report template. + use_seqlets : bool + Whether TF-MoDISco seqlet data was used in the analysis + and should be referenced in the report. + + Notes + ----- + Uses Jinja2 templating with the report.html template from the + templates package. The template receives: + - report_data: Iterator of DataFrame rows as named tuples + - motif_names: List of motif names + - compute_recall: Boolean flag for recall metrics + - use_seqlets: Boolean flag for seqlet usage + + Raises + ------ + OSError + If the output path cannot be written. + """ + template_str = ( + importlib.resources.files(templates).joinpath("report.html").read_text() + ) + template = Template(template_str) + report = template.render( + report_data=report_df.iter_rows(named=True), + motif_names=motif_names, + compute_recall=compute_recall, + use_seqlets=use_seqlets, + ) + with open(out_path, "w") as f: + f.write(report)
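The positive/negative stacking logic inside `plot_logo` is the subtlest part of this module. The standalone sketch below (the `stacked_bottoms` helper name is ours, not part of Fi-NeMo) isolates that argsort/cumsum trick on a toy height matrix and checks its key invariant: per column, positive bars tile the interval from zero up to the positive sum, and negative bars tile from the negative sum up to zero, with smaller-magnitude letters placed closer to the zero line.

```python
import numpy as np


def stacked_bottoms(heights: np.ndarray) -> np.ndarray:
    """Bottom offsets replicating plot_logo's stacking: positive values
    stack upward from zero, negative values stack downward."""
    pos_values = np.clip(heights, 0, None)
    neg_values = np.clip(heights, None, 0)
    pos_order = np.argsort(pos_values, axis=0)
    neg_order = np.argsort(neg_values, axis=0)[::-1, :]
    # Inverse permutations map the cumulative sums back to original row order.
    pos_reorder = np.argsort(pos_order, axis=0)
    neg_reorder = np.argsort(neg_order, axis=0)
    pos_offsets = np.take_along_axis(
        np.cumsum(np.take_along_axis(pos_values, pos_order, axis=0), axis=0),
        pos_reorder,
        axis=0,
    )
    # Each offset is the running total up to and including that bar's value,
    # so subtracting heights once yields each bar's bottom edge.
    neg_offsets = np.take_along_axis(
        np.cumsum(np.take_along_axis(neg_values, neg_order, axis=0), axis=0),
        neg_reorder,
        axis=0,
    )
    return pos_offsets + neg_offsets - heights


# Toy 4-base x 2-position height matrix with mixed signs.
heights = np.array([[0.5, -0.2], [0.3, 0.4], [-0.1, 0.1], [0.2, -0.3]])
bottoms = stacked_bottoms(heights)
tops = bottoms + heights

# Stacked bars span exactly [negative column sum, positive column sum].
assert np.allclose(
    np.maximum(bottoms, tops).max(axis=0), np.clip(heights, 0, None).sum(axis=0)
)
assert np.allclose(
    np.minimum(bottoms, tops).min(axis=0), np.clip(heights, None, 0).sum(axis=0)
)
```

These bottoms are what `plot_logo` passes as the `bottom=` argument of `ax.bar`, so overlapping letters at one position never occlude one another.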