From da37737e71b200ef348ed9fe105c366d20be04b3 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Fri, 11 Apr 2025 11:34:22 -0700 Subject: [PATCH 01/39] Vectorize report graphics --- pyproject.toml | 2 +- src/finemo/evaluation.py | 35 ++++++++++++++++++++++---------- src/finemo/main.py | 6 ++---- src/finemo/templates/report.html | 20 +++++++++--------- 4 files changed, 37 insertions(+), 26 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 111f8c9..0b33113 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "finemo" description = "Identification of regulatory elements from neural network contribution scores for DNA." keywords = ["deep learning", "genomics"] -version = "0.30" +version = "0.31" readme = "README.md" license = {file = "LICENSE"} authors = [ diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index 8fecafc..a121fb5 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -77,8 +77,10 @@ def plot_hit_distributions(occ_df, motif_names, plot_dir): y[unique] = freq ax.bar(x, y) - output_path = os.path.join(motifs_dir, f"{m}.png") - plt.savefig(output_path, dpi=300) + output_path_png = os.path.join(motifs_dir, f"{m}.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}.svg") + plt.savefig(output_path_svg) plt.close(fig) @@ -95,13 +97,15 @@ def plot_hit_distributions(occ_df, motif_names, plot_dir): ax.set_xlabel("Motifs per peak") ax.set_ylabel("Frequency") - output_path = os.path.join(plot_dir, "total_hit_distribution.png") - plt.savefig(output_path, dpi=300) + output_path_png = os.path.join(plot_dir, "total_hit_distribution.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(plot_dir, "total_hit_distribution.svg") + plt.savefig(output_path_svg, dpi=300) plt.close(fig) -def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_path): +def 
plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): """ Plots a simple indicator heatmap of the motifs in each peak. """ @@ -122,7 +126,10 @@ def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_path) ax.set_xlabel("Motif i") ax.set_ylabel("Motif j") - plt.savefig(output_path, dpi=300) + output_path_png = os.path.join(output_dir, "motif_cooocurrence.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(output_dir, "motif_cooocurrence.svg") + plt.savefig(output_path_svg) plt.close() @@ -393,8 +400,6 @@ def plot_cwms(cwms, trim_bounds, out_dir, alphabet=LOGO_ALPHABET, colors=LOGO_CO motif_dir = os.path.join(out_dir, m) os.makedirs(motif_dir, exist_ok=True) for cwm_type, cwm in v.items(): - output_path = os.path.join(motif_dir, f"{cwm_type}.png") - fig, ax = plt.subplots(figsize=(10,2)) plot_logo(ax, cwm, alphabet, colors=colors, font_props=font, shade_bounds=trim_bounds[m][cwm_type]) @@ -402,11 +407,15 @@ def plot_cwms(cwms, trim_bounds, out_dir, alphabet=LOGO_ALPHABET, colors=LOGO_CO for name, spine in ax.spines.items(): spine.set_visible(False) - plt.savefig(output_path, dpi=100) + output_path_png = os.path.join(motif_dir, f"{cwm_type}.png") + plt.savefig(output_path_png, dpi=100) + output_path_svg = os.path.join(motif_dir, f"{cwm_type}.svg") + plt.savefig(output_path_svg) + plt.close(fig) -def plot_hit_vs_seqlet_counts(recall_data, output_path): +def plot_hit_vs_seqlet_counts(recall_data, output_dir): x = [] y = [] m = [] @@ -430,7 +439,11 @@ def plot_hit_vs_seqlet_counts(recall_data, output_path): ax.set_xlabel("Hits per motif") ax.set_ylabel("Seqlets per motif") - plt.savefig(output_path, dpi=300) + output_path_png = os.path.join(output_dir, "hit_vs_seqlet_counts.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(output_dir, "hit_vs_seqlet_counts.svg") + plt.savefig(output_path_svg) + plt.close() diff --git a/src/finemo/main.py b/src/finemo/main.py index 
77ce256..fd031fd 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -216,8 +216,7 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p evaluation.plot_hit_distributions(occ_df, motif_names, out_dir) - coooc_path = os.path.join(out_dir, "motif_cooocurrence.png") - evaluation.plot_peak_motif_indicator_heatmap(coooc, motif_names, coooc_path) + evaluation.plot_peak_motif_indicator_heatmap(coooc, motif_names, out_dir) plot_dir = os.path.join(out_dir, "CWMs") evaluation.plot_cwms(cwms, trim_bounds, plot_dir) @@ -227,8 +226,7 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p seqlets_path = os.path.join(out_dir, "seqlets.tsv") data_io.write_modisco_seqlets(seqlets_df, seqlets_path) - plot_path = os.path.join(out_dir, "hit_vs_seqlet_counts.png") - evaluation.plot_hit_vs_seqlet_counts(report_data, plot_path) + evaluation.plot_hit_vs_seqlet_counts(report_data, out_dir) report_path = os.path.join(out_dir, "report.html") evaluation.write_report(report_df, motif_names, report_path, compute_recall, seqlets_df is not None) diff --git a/src/finemo/templates/report.html b/src/finemo/templates/report.html index dff2f92..2b58a7e 100644 --- a/src/finemo/templates/report.html +++ b/src/finemo/templates/report.html @@ -201,7 +201,7 @@

Hit vs. seqlet counts

The dashed line is the identity line. When comparing a shared set of regions, the hit counts should be mostly greater than the corresponding seqlet counts, since TF-MoDISco stringently filters seqlets and usually uses a smaller input window.

- + {% endif %}

Hit and seqlet motif comparisons

@@ -309,13 +309,13 @@

Hit and seqlet motif comparisons

{{ item.num_seqlets_only }} {{ item.num_hits_restricted_only }} {% endif %} - - - - + + + + {% if compute_recall %} - - + + {% endif %} {% endfor %} @@ -332,7 +332,7 @@

Overall distribution of hits per peak

This plot shows the distribution of hit counts per peak for any motif. The number of peaks with no hits should be near zero.

- +

Per-motif distributions of hits per peak

@@ -349,7 +349,7 @@

Per-motif distributions of hits per peak

{% for m in motif_names %} {{ m }} - + {% endfor %} @@ -361,7 +361,7 @@

Motif co-occurrence

The color intensity here represents the cosine similarity between the motifs' occurrence across peaks, where occurence is defined as the presence of a hit for a motif in a peak.

- + From ac1112ac201a123d68f94742a762e90b865adf0d Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sat, 12 Apr 2025 06:19:44 -0700 Subject: [PATCH 02/39] Additional hit distribution plots --- src/finemo/evaluation.py | 50 ++++++++++++++++++++++++++++++-- src/finemo/main.py | 4 ++- src/finemo/templates/report.html | 37 +++++++++++++++-------- 3 files changed, 75 insertions(+), 16 deletions(-) diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index a121fb5..4dddd38 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -62,12 +62,56 @@ def get_motif_occurences(hits_df, motif_names): return occ_df, coocc -def plot_hit_distributions(occ_df, motif_names, plot_dir): +def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): + hits_df = hits_df.collect() + hits_by_motif = hits_df.partition_by("motif_name", as_dict=True) + dummy_df = hits_df.clear() + + motifs_dir = os.path.join(plot_dir, "motif_stat_distributions") + os.makedirs(motifs_dir, exist_ok=True) + for m in motif_names: + hits = hits_by_motif.get((m,), dummy_df) + coefficients = hits.get_column("hit_coefficient_global").to_numpy() + similarities = hits.get_column("hit_similarity").to_numpy() + importances = hits.get_column("hit_importance").to_numpy() + + fig, ax = plt.subplots(figsize=(5, 2)) + + ax.hist(coefficients, bins=50) + + output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}_coefficients.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + fig, ax = plt.subplots(figsize=(5, 2)) + + ax.hist(similarities, bins=50) + + output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}_similarities.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + fig, ax = plt.subplots(figsize=(5, 2)) + + ax.hist(importances, bins=50) + + output_path_png = 
os.path.join(motifs_dir, f"{m}_importances.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}_importances.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + +def plot_hit_peak_distributions(occ_df, motif_names, plot_dir): motifs_dir = os.path.join(plot_dir, "motif_hit_distributions") os.makedirs(motifs_dir, exist_ok=True) for m in motif_names: - fig, ax = plt.subplots(figsize=(6, 2)) + fig, ax = plt.subplots(figsize=(5, 2)) unique, counts = np.unique(occ_df.get_column(m), return_counts=True) freq = counts / counts.sum() @@ -94,7 +138,7 @@ def plot_hit_distributions(occ_df, motif_names, plot_dir): y[unique] = freq ax.bar(x, y) - ax.set_xlabel("Motifs per peak") + ax.set_xlabel("Total hits per region") ax.set_ylabel("Frequency") output_path_png = os.path.join(plot_dir, "total_hit_distribution.png") diff --git a/src/finemo/main.py b/src/finemo/main.py index fd031fd..caff09e 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -214,7 +214,9 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p data_io.write_report_data(report_df, cwms, out_dir) - evaluation.plot_hit_distributions(occ_df, motif_names, out_dir) + evaluation.plot_hit_stat_distributions(hits_df, motif_names, out_dir) + + evaluation.plot_hit_peak_distributions(occ_df, motif_names, out_dir) evaluation.plot_peak_motif_indicator_heatmap(coooc, motif_names, out_dir) diff --git a/src/finemo/templates/report.html b/src/finemo/templates/report.html index 2b58a7e..0b2ed57 100644 --- a/src/finemo/templates/report.html +++ b/src/finemo/templates/report.html @@ -170,6 +170,7 @@ } .distplot img { + max-width: unset; margin-bottom: 0 } @@ -322,34 +323,46 @@

Hit and seqlet motif comparisons

-

Hit distributions

+

Hit statistic distributions

- The following figures visualize the distribution of hits across motifs and peaks. + The following figures visualize the distribution of hit statistics across motifs and regions.

-

Overall distribution of hits per peak

+

Overall distribution of hit counts per region

- This plot shows the distribution of hit counts per peak for any motif. - The number of peaks with no hits should be near zero. + This plot shows the distribution of hit counts per region for any motif. + The number of regions with no hits should be near zero.

-

Per-motif distributions of hits per peak

+

Per-motif distributions of hit statistics

- These plots show the distribution of hit counts per peak for each motif. + These plots show the distribution of hit statistics for each motif, specifically: +

- + + + + {% for m in motif_names %} - + + + + {% endfor %} @@ -357,9 +370,9 @@

Per-motif distributions of hits per peak

Motif co-occurrence

- This heatmap shows the co-occurrence of motifs across peaks. - The color intensity here represents the cosine similarity between the motifs' occurrence across peaks, - where occurence is defined as the presence of a hit for a motif in a peak. + This heatmap shows the co-occurrence of motifs across regions. + The color intensity here represents the cosine similarity between the motifs' occurrence across regions, + where occurence is defined as the presence of a hit for a motif in a region.

From 20853997f44ab1297e13ce9c934c5b2edd8f471e Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 13 Apr 2025 16:17:24 -0700 Subject: [PATCH 03/39] Utility for collapsing overlapping hits --- environment.yml | 1 + pyproject.toml | 3 +- src/finemo/data_io.py | 13 +++++-- src/finemo/evaluation.py | 6 +-- src/finemo/main.py | 24 ++++++++++++ src/finemo/postprocessing.py | 72 ++++++++++++++++++++++++++++++++++++ 6 files changed, 111 insertions(+), 8 deletions(-) create mode 100644 src/finemo/postprocessing.py diff --git a/environment.yml b/environment.yml index 01e7af9..98e4cde 100644 --- a/environment.yml +++ b/environment.yml @@ -8,6 +8,7 @@ dependencies: - pytorch=2.5.1 - pytorch-cuda=12.4 - python=3.11 + - numba=0.61.2 - numpy=2.2.0 - scipy=1.14.1 - polars=1.17.1 diff --git a/pyproject.toml b/pyproject.toml index 0b33113..58ddb2e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "finemo" description = "Identification of regulatory elements from neural network contribution scores for DNA." 
keywords = ["deep learning", "genomics"] -version = "0.31" +version = "0.32" readme = "README.md" license = {file = "LICENSE"} authors = [ @@ -17,6 +17,7 @@ dependencies = [ "numpy", "scipy", "torch", + "numba", "polars>=1.0", "matplotlib", "h5py", diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index d5b119a..cd74fac 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -209,7 +209,6 @@ def write_regions_npz(sequences, contributions, out_path, peaks_df=None): chr=chr_arr, chr_id=chr_id_arr, start=start_arr, peak_id=peak_id_arr, peak_name=peak_name_arr) - def trim_motif(cwm, trim_threshold): """ Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L213-L236 @@ -439,18 +438,24 @@ def write_modisco_seqlets(seqlets_df, out_path): "strand": pl.String, "peak_name": pl.String, "peak_id": pl.UInt32, - } +HITS_COLLAPSED_DTYPES = HITS_DTYPES | {"is_primary": pl.UInt32} + -def load_hits(hits_path, lazy=False): +def load_hits(hits_path, lazy=False, schema=HITS_DTYPES): hits_df = ( - pl.scan_csv(hits_path, separator='\t', quote_char=None, schema=HITS_DTYPES) + pl.scan_csv(hits_path, separator='\t', quote_char=None, schema=schema) .with_columns(pl.lit(1).alias("count")) ) return hits_df if lazy else hits_df.collect() +def write_hits_processed(hits_df, out_path, schema=HITS_DTYPES): + hits_df = hits_df.select(schema.keys()) + hits_df.write_csv(out_path, separator="\t") + + def write_hits(hits_df, peaks_df, motifs_df, qc_df, out_dir, motif_width): os.makedirs(out_dir, exist_ok=True) out_path_tsv = os.path.join(out_dir, "hits.tsv") diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index 4dddd38..72358c5 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -77,7 +77,7 @@ def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): fig, ax = plt.subplots(figsize=(5, 2)) - ax.hist(coefficients, bins=50) + ax.hist(coefficients, bins=50, density=True) 
output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png") plt.savefig(output_path_png, dpi=300) @@ -87,7 +87,7 @@ def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): fig, ax = plt.subplots(figsize=(5, 2)) - ax.hist(similarities, bins=50) + ax.hist(similarities, bins=50, density=True) output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png") plt.savefig(output_path_png, dpi=300) @@ -97,7 +97,7 @@ def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): fig, ax = plt.subplots(figsize=(5, 2)) - ax.hist(importances, bins=50) + ax.hist(importances, bins=50, density=True) output_path_png = os.path.join(motifs_dir, f"{m}_importances.png") plt.savefig(output_path_png, dpi=300) diff --git a/src/finemo/main.py b/src/finemo/main.py index caff09e..6bae4ff 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -234,6 +234,15 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p evaluation.write_report(report_df, motif_names, report_path, compute_recall, seqlets_df is not None) +def collapse_hits(hits_path, out_path, overlap): + from . 
import postprocessing + + hits_df = data_io.load_hits(hits_path, lazy=False) + hits_collapsed_df = postprocessing.collapse_hits(hits_df, overlap) + + data_io.write_hits_processed(hits_collapsed_df, out_path, schema=data_io.HITS_COLLAPSED_DTYPES) + + def cli(): parser = argparse.ArgumentParser() subparsers = parser.add_subparsers(required=True, dest='cmd') @@ -418,6 +427,17 @@ def cli(): help="Do not compute motif recall metrics.") report_parser.add_argument("-s", "--no-seqlets", action='store_true', help="DEPRECATED: Please omit the `--modisco-h5` argument instead.") + + + collapse_hits_parser = subparsers.add_parser("collapse-hits", formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Identify best hit among sets of overlapping hits by motif similarity.") + + collapse_hits_parser.add_argument("-i", "--hits", type=str, required=True, + help="The `hits.tsv` or `hits_unique.tsv` file from `call-hits`.") + collapse_hits_parser.add_argument("-o", "--out-path", type=str, required=True, + help="The path to the output .tsv file with an additional \"is_primary\" column.") + collapse_hits_parser.add_argument("-O", "--overlap", type=int, default=3, + help="The minimum number of base pairs to consider as overlapping.") args = parser.parse_args() @@ -458,3 +478,7 @@ def cli(): report(args.regions, args.hits, args.modisco_h5, args.peaks, args.motifs_include, args.motif_names, args.out_dir, args.modisco_region_width, args.cwm_trim_threshold, not args.no_recall, not args.no_seqlets) + + elif args.cmd == "collapse-hits": + collapse_hits(args.hits, args.out_path, args.overlap) + diff --git a/src/finemo/postprocessing.py b/src/finemo/postprocessing.py new file mode 100644 index 0000000..be312d0 --- /dev/null +++ b/src/finemo/postprocessing.py @@ -0,0 +1,72 @@ +import heapq + +import numpy as np +import polars as pl +from numba import njit +from numba.types import Array, uint32, int32, float32 + +@njit( + uint32[:]( + Array(uint32, 1, 'C', readonly=True), + Array(int32, 1, 
'C', readonly=True), + Array(int32, 1, 'C', readonly=True), + Array(float32, 1, 'C', readonly=True) + ) +) +def _collapse_hits(chrom_ids, starts, ends, similarities): + n = chrom_ids.shape[0] + out = np.ones(n, dtype=np.uint32) + heap = [(np.uint32(0), np.int32(0), -1) for _ in range(0)] + + for i in range(n): + chrom_new = chrom_ids[i] + start_new = starts[i] + end_new = ends[i] + sim_new = similarities[i] + + while heap and heap[0] < (chrom_new, start_new, -1): + heapq.heappop(heap) + + for _, _, idx in heap: + cmp = sim_new > similarities[idx] + out[idx] &= cmp + out[i] &= not cmp + + heapq.heappush(heap, (chrom_new, end_new, i)) + + return out + + +def collapse_hits(hits_df, overlap): + chroms = hits_df["chr"].unique(maintain_order=True) + + if not chroms.is_empty(): + chrom_to_id = { + chrom: i for i, chrom in enumerate(chroms) + } + df = hits_df.select( + chrom_id=pl.col("chr").replace_strict(chrom_to_id, return_dtype=pl.UInt32), + start_trim=pl.col("start") * 2 + overlap, + end_trim=pl.col("end") * 2 - overlap, + similarity=pl.col("hit_similarity") + ) + else: + df = hits_df.select( + chrom_id=pl.col("peak_id"), + start_trim=pl.col("start") * 2 + overlap, + end_trim=pl.col("end") * 2 - overlap, + similarity=pl.col("hit_similarity") + ) + + df = df.rechunk() + chrom_ids = df["chrom_id"].to_numpy(allow_copy=False) + starts = df["start_trim"].to_numpy(allow_copy=False) + ends = df["end_trim"].to_numpy(allow_copy=False) + similarities = df["similarity"].to_numpy(allow_copy=False) + is_primary = _collapse_hits(chrom_ids, starts, ends, similarities) + + df_out = hits_df.with_columns( + is_primary=pl.Series(is_primary, dtype=pl.UInt32) + ) + + return df_out From c4560942820a9183c0674e77dbc586fbfacf7c20 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 13 Apr 2025 16:27:21 -0700 Subject: [PATCH 04/39] Documentation for `collapse-hits` --- README.md | 19 +++++++++++++------ 1 file changed, 13 insertions(+), 6 deletions(-) diff --git a/README.md b/README.md index 
138c738..52bb243 100644 --- a/README.md +++ b/README.md @@ -203,7 +203,7 @@ Usage: `finemo call-hits -r -m -o [-p ] - Legacy TFMoDISCo H5 files can be updated to the newer TFMoDISCo-lite format with the `modisco convert` command found in the [tfmodisco-lite](https://github.com/jmschrei/tfmodisco-lite/tree/main) package. - The hit-calling thresholding procedure is scale-invariant. That is, whether a position is assigned a hit depends on the shapes of the motif CWM and the contribution scores, not the absolute magnitude of the scores. If you wish to prioritize hits based on the magnitude of the contribution scores, set a per-motif rank threshold the `hit_coefficient_global` field in the `hits.tsv` file, which captures both the absolute importance and the closeness of match. -### Output reporting +### Output reporting and post-processing #### `finemo report` @@ -220,12 +220,19 @@ Usage: `finemo report -r -H -o [-m ] [-W - `-W/--modisco-region-width`: The width of the region around each peak summit used by tfmodisco-lite. Default is 400. - `-n/--no-recall`: Do not compute motif recall metrics. Default is False. -#### Additional outputs +Additional report outputs: -`motif_report.tsv`: Statistics on the distribution of hits per motif. The columns and values correspond to those in the HTML report's table. +- `motif_report.tsv`: Statistics on the distribution of hits per motif. The columns and values correspond to those in the HTML report's table. +- `motif_occurrences.tsv`: The number of hits of each motif in each input region. Also includes the total number of hits per region. +- `CWMs`: A directory containing visualizations of motif CWMs, as well as corresponding tables with numerical CWM values. +- `seqlets.tsv`: tf-modisco seqlet coordinates for each motif in each region. Only generated if `-m/--modisco-h5` is provided. -`motif_occurrences.tsv`: The number of hits of each motif in each input region. Also includes the total number of hits per region. 
+#### `finemo collapse-hits` -`CWMs`: A directory containing visualizations of motif CWMs, as well as corresponding tables with numerical CWM values. +Identify the best hits by motif similarity within groups of overlapping hits. Adds a 0/1 `is_primary` column to the `hits.tsv` file, indicating whether a hit is the best hit in its group. This command does not utilize the GPU. -`seqlets.tsv`: tf-modisco seqlet coordinates for each motif in each region. Only generated if `-m/--modisco-h5` is provided. +Usage: `usage: finemo collapse-hits [-h] -i -o [-O ]` + +- `-i/--hits`: The path to the input hits file. This should be the `hits.tsv` or `hits_unique.tsv` file generated by the `finemo call-hits` command. +- `-o/--out-path`: The path to the output file. This will be a copy of the input file with an additional `is_primary` column. +- `-O/--overlap`: The minimum overlap (in base pairs) required for two hits to be considered overlapping. Default is 3 bp. From 5bc44d6319a18cf387606bc44b9666e59799e87b Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 13 Apr 2025 16:32:34 -0700 Subject: [PATCH 05/39] Cache jit outputs --- README.md | 2 +- src/finemo/postprocessing.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/README.md b/README.md index 52bb243..df8c798 100644 --- a/README.md +++ b/README.md @@ -231,7 +231,7 @@ Additional report outputs: Identify the best hits by motif similarity within groups of overlapping hits. Adds a 0/1 `is_primary` column to the `hits.tsv` file, indicating whether a hit is the best hit in its group. This command does not utilize the GPU. -Usage: `usage: finemo collapse-hits [-h] -i -o [-O ]` +Usage: `usage: finemo collapse-hits -i -o [-O ]` - `-i/--hits`: The path to the input hits file. This should be the `hits.tsv` or `hits_unique.tsv` file generated by the `finemo call-hits` command. - `-o/--out-path`: The path to the output file. This will be a copy of the input file with an additional `is_primary` column. 
diff --git a/src/finemo/postprocessing.py b/src/finemo/postprocessing.py index be312d0..296758a 100644 --- a/src/finemo/postprocessing.py +++ b/src/finemo/postprocessing.py @@ -11,7 +11,7 @@ Array(int32, 1, 'C', readonly=True), Array(int32, 1, 'C', readonly=True), Array(float32, 1, 'C', readonly=True) - ) + ), cache=True ) def _collapse_hits(chrom_ids, starts, ends, similarities): n = chrom_ids.shape[0] From 3618533f50bf854f46df0a08d5d381e38b0a1c62 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 13 Apr 2025 18:48:17 -0700 Subject: [PATCH 06/39] Hit intersection utility --- README.md | 10 ++++++++++ pyproject.toml | 2 +- src/finemo/data_io.py | 3 ++- src/finemo/main.py | 25 ++++++++++++++++++++++++- src/finemo/postprocessing.py | 29 +++++++++++++++++++++++++++++ 5 files changed, 66 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index df8c798..2e30598 100644 --- a/README.md +++ b/README.md @@ -236,3 +236,13 @@ Usage: `usage: finemo collapse-hits -i -o [-O ]` - `-i/--hits`: The path to the input hits file. This should be the `hits.tsv` or `hits_unique.tsv` file generated by the `finemo call-hits` command. - `-o/--out-path`: The path to the output file. This will be a copy of the input file with an additional `is_primary` column. - `-O/--overlap`: The minimum overlap (in base pairs) required for two hits to be considered overlapping. Default is 3 bp. + +#### `finemo intersect-hits` + +Find the intersection of hits across multiple runs. This command does not utilize the GPU. + +Usage: `finemo intersect-hits -i -o [-r]` + +- `-i/--hits`: The path to one or more input hits file. This should be the `hits.tsv` or `hits_unique.tsv` file generated by the `finemo call-hits` command. +- `-o/--out-path`: The path to the output file. Reoccuring columns are suffixed with the positional index of the input file (e.g. `hit_importance_1`), with the exception of index 0. 
+- `-r/--relaxed`: By default, the intersection assumes consistent input region definitions (name and coordinates) and motif trimming across runs. In contrast, this relaxed intersection criteria uses only motif names and untrimmed hit coordinates. However, this is not suitable when hit genomic coordinates are unknown (e.g. when using `finemo call-hits` with `-p/--peaks`). Default is False. diff --git a/pyproject.toml b/pyproject.toml index 58ddb2e..9e9d76a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "finemo" description = "Identification of regulatory elements from neural network contribution scores for DNA." keywords = ["deep learning", "genomics"] -version = "0.32" +version = "0.33" readme = "README.md" license = {file = "LICENSE"} authors = [ diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index cd74fac..35e36e9 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -452,7 +452,8 @@ def load_hits(hits_path, lazy=False, schema=HITS_DTYPES): def write_hits_processed(hits_df, out_path, schema=HITS_DTYPES): - hits_df = hits_df.select(schema.keys()) + if schema is not None: + hits_df = hits_df.select(schema.keys()) hits_df.write_csv(out_path, separator="\t") diff --git a/src/finemo/main.py b/src/finemo/main.py index 6bae4ff..1fed513 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -243,6 +243,15 @@ def collapse_hits(hits_path, out_path, overlap): data_io.write_hits_processed(hits_collapsed_df, out_path, schema=data_io.HITS_COLLAPSED_DTYPES) +def intersect_hits(hits_paths, out_path, relaxed): + from . 
import postprocessing + + hits_dfs = [data_io.load_hits(hits_path, lazy=False) for hits_path in hits_paths] + hits_df = postprocessing.intersect_hits(hits_dfs, relaxed) + + data_io.write_hits_processed(hits_df, out_path, schema=None) + + def cli(): parser = argparse.ArgumentParser() subparsers = parser.add_subparsers(required=True, dest='cmd') @@ -430,7 +439,7 @@ def cli(): collapse_hits_parser = subparsers.add_parser("collapse-hits", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Identify best hit among sets of overlapping hits by motif similarity.") + help="Identify best hit by motif similarity among sets of overlapping hits.") collapse_hits_parser.add_argument("-i", "--hits", type=str, required=True, help="The `hits.tsv` or `hits_unique.tsv` file from `call-hits`.") @@ -440,6 +449,17 @@ def cli(): help="The minimum number of base pairs to consider as overlapping.") + intersect_hits_parser = subparsers.add_parser("intersect-hits", formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Intersect hits across multiple runs.") + + intersect_hits_parser.add_argument("-i", "--hits", type=str, required=True, nargs='+', + help="One or more hits.tsv or hits_unique.tsv files, with paths delimited by whitespace.") + intersect_hits_parser.add_argument("-o", "--out-path", type=str, required=True, + help="The path to the output .tsv file. Duplicate columns are suffixed with the positional index of the input file.") + intersect_hits_parser.add_argument("-r", "--relaxed", action='store_true', + help="Use relaxed intersection criteria, using only motif names and untrimmed coordinates. By default, the intersection assumes consistent region definitions and motif trimming. 
This option is not recommended if genomic coordinates are unavailable.") + + args = parser.parse_args() if args.cmd == "extract-regions-bw": @@ -482,3 +502,6 @@ def cli(): elif args.cmd == "collapse-hits": collapse_hits(args.hits, args.out_path, args.overlap) + elif args.cmd == "intersect-hits": + intersect_hits(args.hits, args.out_path, args.relaxed) + diff --git a/src/finemo/postprocessing.py b/src/finemo/postprocessing.py index 296758a..364f371 100644 --- a/src/finemo/postprocessing.py +++ b/src/finemo/postprocessing.py @@ -70,3 +70,32 @@ def collapse_hits(hits_df, overlap): ) return df_out + + +def intersect_hits(hits_dfs, relaxed): + if relaxed: + join_cols = [ + "chr", "start_untrimmed", "end_untrimmed", + "motif_name", "strand" + ] + else: + join_cols = [ + "chr", "start", "end", "start_untrimmed", "end_untrimmed", + "motif_name", "strand", "peak_name", "peak_id" + ] + + if len(hits_dfs) < 1: + raise ValueError("At least one hits dataframe required") + + hits_df = hits_dfs[0] + for i in range(1, len(hits_dfs)): + hits_df = hits_df.join( + hits_dfs[i], + on=join_cols, + how="inner", + suffix=f"_{i}", + join_nulls=True, + coalesce=True + ) + + return hits_df \ No newline at end of file From 0241df390b42c05d49084f58aa66d8f4bede38cb Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 14 Apr 2025 12:12:46 -0700 Subject: [PATCH 07/39] Additional motif trimming options --- README.md | 2 +- pyproject.toml | 2 +- src/finemo/data_io.py | 31 +++++++++++++++++++++++++++---- src/finemo/main.py | 39 +++++++++++++++++++++++++++------------ 4 files changed, 56 insertions(+), 18 deletions(-) diff --git a/README.md b/README.md index 2e30598..c71ded2 100644 --- a/README.md +++ b/README.md @@ -133,7 +133,7 @@ Usage: `finemo call-hits -r -m -o [-p ] - `-r/--regions`: A `.npz` file of input sequences, contributions, and coordinates. Created with a `finemo extract-regions-*` command. - `-m/--modisco-h5`: A tfmodisco-lite output H5 file of motif patterns. 
- `-o/--out-dir`: The path to the output directory. -- `-t/--cwm-trim-threshold`: The threshold to determine motif start and end positions within the full CWMs. Default is 0.3. +- `-t/--cwm-trim-threshold`: The threshold to determine motif start and end positions within the full CWMs. Default is 0.3. If you need finer control over motif trimming, check out the `-T/--cwm-trim-thresholds` and `-R/--cwm-trim-coords` options. - `-l/--global-lambda`: The L1 regularization weight determining the sparsity of hits. Default is 0.7. - `-b/--batch-size`: The batch size used for optimization. Default is 2000. - `-J/--compile`: Enable JIT compilation for faster execution. This option may not work on older GPUs. diff --git a/pyproject.toml b/pyproject.toml index 9e9d76a..e66a29a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "finemo" description = "Identification of regulatory elements from neural network contribution scores for DNA." keywords = ["deep learning", "genomics"] -version = "0.33" +version = "0.34" readme = "README.md" license = {file = "LICENSE"} authors = [ diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index 35e36e9..7091fa2 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -32,6 +32,17 @@ def load_mapping(path, type): return mapping +def load_mapping_tuple(path, type): + mapping = {} + with open(path) as f: + for line in f: + entries = line.rstrip("\n").split("\t") + key = entries[0] + val = entries[1:] + mapping[key] = tuple(type(i) for i in val) + + return mapping + NARROWPEAK_SCHEMA = ["chr", "peak_start", "peak_end", "peak_name", "peak_score", "peak_strand", "peak_signal", "peak_pval", "peak_qval", "peak_summit"] @@ -238,8 +249,8 @@ def _motif_name_sort_key(data): MODISCO_PATTERN_GROUPS = ['pos_patterns', 'neg_patterns'] -def load_modisco_motifs(modisco_h5_path, trim_threshold, motif_type, motifs_include, - motif_name_map, motif_lambdas, motif_lambda_default, 
include_rc): +def load_modisco_motifs(modisco_h5_path, trim_coords, trim_thresholds, trim_threshold_default, motif_type, + motifs_include, motif_name_map, motif_lambdas, motif_lambda_default, include_rc): """ Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L252-L272 """ @@ -257,6 +268,11 @@ def load_modisco_motifs(modisco_h5_path, trim_threshold, motif_type, motifs_incl if motif_lambdas is None: motif_lambdas = {} + if trim_coords is None: + trim_coords = {} + if trim_thresholds is None: + trim_thresholds = {} + if len(motif_name_map.values()) != len(set(motif_name_map.values())): raise ValueError("Specified motif names are not unique") @@ -281,8 +297,15 @@ def load_modisco_motifs(modisco_h5_path, trim_threshold, motif_type, motifs_incl cwm_fwd = cwm_raw / cwm_norm cwm_rev = cwm_fwd[::-1,::-1] - start_fwd, end_fwd = trim_motif(cwm_fwd, trim_threshold) - start_rev, end_rev = trim_motif(cwm_rev, trim_threshold) + + if pattern_tag in trim_coords: + start_fwd, end_fwd = trim_coords[pattern_tag] + else: + trim_threshold = trim_thresholds.get(pattern_tag, trim_threshold_default) + start_fwd, end_fwd = trim_motif(cwm_fwd, trim_threshold) + + cwm_len = cwm_fwd.shape[1] + start_rev, end_rev = cwm_len - end_fwd, cwm_len - start_fwd trim_mask_fwd = np.zeros(cwm_fwd.shape[1], dtype=np.int8) trim_mask_fwd[start_fwd:end_fwd] = 1 diff --git a/src/finemo/main.py b/src/finemo/main.py index 1fed513..4897974 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -54,8 +54,9 @@ def extract_regions_modisco_fmt(peaks_path, chrom_order_path, shaps_paths, ohe_p def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motifs_include_path, motif_names_path, - motif_lambdas_path, out_dir, cwm_trim_threshold, lambda_default, step_size_max, step_size_min, sqrt_transform, - convergence_tol, max_steps, batch_size, step_adjust, device, mode, no_post_filter, compile_optimizer): + motif_lambdas_path, 
out_dir, cwm_trim_coords_path, cwm_trim_thresholds_path, cwm_trim_threshold_default, + lambda_default, step_size_max, step_size_min, sqrt_transform, convergence_tol, max_steps, batch_size, + step_adjust, device, mode, no_post_filter, compile_optimizer): params = locals() from . import hitcaller @@ -107,9 +108,19 @@ def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motif motif_lambdas = data_io.load_mapping(motif_lambdas_path, float) else: motif_lambdas = None + + if cwm_trim_coords_path is not None: + trim_coords = data_io.load_mapping_tuple(cwm_trim_coords_path, int) + else: + trim_coords = None + + if cwm_trim_thresholds_path is not None: + trim_thresholds = data_io.load_mapping(cwm_trim_thresholds_path, float) + else: + trim_thresholds = None - motifs_df, cwms, trim_masks, motif_names = data_io.load_modisco_motifs(modisco_h5_path, cwm_trim_threshold, motif_type, motifs_include, - motif_name_map, motif_lambdas, lambda_default, True) + motifs_df, cwms, trim_masks, motif_names = data_io.load_modisco_motifs(modisco_h5_path, trim_coords, trim_thresholds, cwm_trim_threshold_default, + motif_type, motifs_include, motif_name_map, motif_lambdas, lambda_default, True) num_motifs = cwms.shape[0] motif_width = cwms.shape[2] lambdas = motifs_df.get_column("lambda").to_numpy(writable=True) @@ -173,8 +184,8 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p else: motif_name_map = None - motifs_df, cwms_modisco, trim_masks, motif_names = data_io.load_modisco_motifs(modisco_h5_path, cwm_trim_threshold, "cwm", - motifs_include, motif_name_map, None, None, True) + motifs_df, cwms_modisco, trim_masks, motif_names = data_io.load_modisco_motifs(modisco_h5_path, None, None, cwm_trim_threshold, "cwm", + motifs_include, motif_name_map, None, None, True) else: hits_df_path = os.path.join(hits_dir, "hits.tsv") @@ -188,7 +199,7 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p params_path = 
os.path.join(hits_dir, "parameters.json") params = data_io.load_params(params_path) - cwm_trim_threshold = params["cwm_trim_threshold"] + cwm_trim_threshold = params["cwm_trim_threshold_default"] if not use_seqlets: warnings.warn("Usage of the `--no-seqlets` flag is deprecated and will be removed in a future version. Please omit the `--modisco-h5` argument instead.") @@ -376,10 +387,14 @@ def cli(): help="The path to the output directory.") call_hits_parser.add_argument("-t", "--cwm-trim-threshold", type=float, default=0.3, - help="The threshold to determine motif start and end positions within the full CWMs.") + help="The default threshold to determine motif start and end positions within the full CWMs.") + call_hits_parser.add_argument("-T", "--cwm-trim-thresholds", type=str, default=None, + help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom trim thresholds in the second column. Omitted motifs default to the `--cwm-trim-threshold` value.") + call_hits_parser.add_argument("-R", "--cwm-trim-coords", type=str, default=None, + help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom trim start and end coordinates in the second and third columns, respectively. Omitted motifs default to `--cwm-trim-thresholds` values.") call_hits_parser.add_argument("-l", "--global-lambda", type=float, default=0.7, - help="The L1 regularization weight determining the sparsity of hits.") + help="The default L1 regularization weight determining the sparsity of hits.") call_hits_parser.add_argument("-L", "--motif-lambdas", type=str, default=None, help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and motif-specific lambdas in the second column. 
Omitted motifs default to the `--global-lambda` value.") call_hits_parser.add_argument("-a", "--alpha", type=float, default=None, @@ -487,9 +502,9 @@ def cli(): args.motif_lambdas = args.motif_alphas call_hits(args.regions, args.peaks, args.modisco_h5, args.chrom_order, args.motifs_include, args.motif_names, - args.motif_lambdas, args.out_dir, args.cwm_trim_threshold, args.global_lambda, args.step_size_max, - args.step_size_min, args.sqrt_transform, args.convergence_tol, args.max_steps, args.batch_size, - args.step_adjust, args.device, args.mode, args.no_post_filter, args.compile) + args.motif_lambdas, args.out_dir, args.cwm_trim_coords, args.cwm_trim_thresholds, args.cwm_trim_threshold, + args.global_lambda, args.step_size_max, args.step_size_min, args.sqrt_transform, args.convergence_tol, + args.max_steps, args.batch_size, args.step_adjust, args.device, args.mode, args.no_post_filter, args.compile) elif args.cmd == "report": if args.no_recall and not args.no_seqlets: From 58ebbd8362bf45510dfe7044adf2daed543746c3 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Tue, 29 Apr 2025 05:59:07 -0700 Subject: [PATCH 08/39] Switch to fraction threshold in `collapse-hits` --- README.md | 2 +- pyproject.toml | 2 +- src/finemo/main.py | 10 +++++----- src/finemo/postprocessing.py | 10 +++++----- 4 files changed, 12 insertions(+), 12 deletions(-) diff --git a/README.md b/README.md index c71ded2..623b220 100644 --- a/README.md +++ b/README.md @@ -235,7 +235,7 @@ Usage: `usage: finemo collapse-hits -i -o [-O ]` - `-i/--hits`: The path to the input hits file. This should be the `hits.tsv` or `hits_unique.tsv` file generated by the `finemo call-hits` command. - `-o/--out-path`: The path to the output file. This will be a copy of the input file with an additional `is_primary` column. -- `-O/--overlap`: The minimum overlap (in base pairs) required for two hits to be considered overlapping. Default is 3 bp. 
+- `-O/--overlap-frac`: The minimum fraction overlap required for two hits to be considered overlapping. Precisely, given two hits of lengths `x` and `y`, the minimum number of overlapping bases is `overlap_frac * (x + y) / 2`. Default is 0.2. #### `finemo intersect-hits` diff --git a/pyproject.toml b/pyproject.toml index e66a29a..8c68141 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "finemo" description = "Identification of regulatory elements from neural network contribution scores for DNA." keywords = ["deep learning", "genomics"] -version = "0.34" +version = "0.35" readme = "README.md" license = {file = "LICENSE"} authors = [ diff --git a/src/finemo/main.py b/src/finemo/main.py index 4897974..50be9b2 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -245,11 +245,11 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p evaluation.write_report(report_df, motif_names, report_path, compute_recall, seqlets_df is not None) -def collapse_hits(hits_path, out_path, overlap): +def collapse_hits(hits_path, out_path, overlap_frac): from . 
import postprocessing hits_df = data_io.load_hits(hits_path, lazy=False) - hits_collapsed_df = postprocessing.collapse_hits(hits_df, overlap) + hits_collapsed_df = postprocessing.collapse_hits(hits_df, overlap_frac) data_io.write_hits_processed(hits_collapsed_df, out_path, schema=data_io.HITS_COLLAPSED_DTYPES) @@ -460,8 +460,8 @@ def cli(): help="The `hits.tsv` or `hits_unique.tsv` file from `call-hits`.") collapse_hits_parser.add_argument("-o", "--out-path", type=str, required=True, help="The path to the output .tsv file with an additional \"is_primary\" column.") - collapse_hits_parser.add_argument("-O", "--overlap", type=int, default=3, - help="The minimum number of base pairs to consider as overlapping.") + collapse_hits_parser.add_argument("-O", "--overlap-frac", type=float, default=0.2, + help="The threshold for determining overlapping hits. For two hits with lengths x and y, the minimum overlap is defined as `overlap_frac * (x + y) / 2`. The default value of 0.2 means that two hits must overlap by at least 20% of their average lengths to be considered overlapping.") intersect_hits_parser = subparsers.add_parser("intersect-hits", formatter_class=argparse.ArgumentDefaultsHelpFormatter, @@ -515,7 +515,7 @@ def cli(): not args.no_recall, not args.no_seqlets) elif args.cmd == "collapse-hits": - collapse_hits(args.hits, args.out_path, args.overlap) + collapse_hits(args.hits, args.out_path, args.overlap_frac) elif args.cmd == "intersect-hits": intersect_hits(args.hits, args.out_path, args.relaxed) diff --git a/src/finemo/postprocessing.py b/src/finemo/postprocessing.py index 364f371..5663a4a 100644 --- a/src/finemo/postprocessing.py +++ b/src/finemo/postprocessing.py @@ -37,7 +37,7 @@ def _collapse_hits(chrom_ids, starts, ends, similarities): return out -def collapse_hits(hits_df, overlap): +def collapse_hits(hits_df, overlap_frac): chroms = hits_df["chr"].unique(maintain_order=True) if not chroms.is_empty(): @@ -46,15 +46,15 @@ def collapse_hits(hits_df, overlap): 
} df = hits_df.select( chrom_id=pl.col("chr").replace_strict(chrom_to_id, return_dtype=pl.UInt32), - start_trim=pl.col("start") * 2 + overlap, - end_trim=pl.col("end") * 2 - overlap, + start_trim=pl.col("start") * 2 + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + end_trim=pl.col("end") * 2 - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), similarity=pl.col("hit_similarity") ) else: df = hits_df.select( chrom_id=pl.col("peak_id"), - start_trim=pl.col("start") * 2 + overlap, - end_trim=pl.col("end") * 2 - overlap, + start_trim=pl.col("start") * 2 + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + end_trim=pl.col("end") * 2 - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), similarity=pl.col("hit_similarity") ) From 8a998154d32e427e92b436881f566f34f984a957 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Thu, 1 May 2025 06:16:10 -0700 Subject: [PATCH 09/39] Confusion matrix visualization --- pyproject.toml | 2 +- src/finemo/data_io.py | 4 ++ src/finemo/evaluation.py | 117 ++++++++++++++++++++++++++++++- src/finemo/main.py | 7 ++ src/finemo/templates/report.html | 12 ++++ 5 files changed, 140 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 8c68141..e527bce 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,7 +6,7 @@ build-backend = "setuptools.build_meta" name = "finemo" description = "Identification of regulatory elements from neural network contribution scores for DNA." 
keywords = ["deep learning", "genomics"] -version = "0.35" +version = "0.36" readme = "README.md" license = {file = "LICENSE"} authors = [ diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index 7091fa2..1088d68 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -595,6 +595,10 @@ def write_occ_df(occ_df, out_path): occ_df.write_csv(out_path, separator="\t") +def write_seqlet_confusion_df(seqlet_confusion_df, out_path): + seqlet_confusion_df.write_csv(out_path, separator="\t") + + def write_report_data(report_df, cwms, out_dir): cwms_dir = os.path.join(out_dir, "CWMs") os.makedirs(cwms_dir, exist_ok=True) diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index 72358c5..ff85d73 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -160,7 +160,7 @@ def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): fig, ax = plt.subplots(figsize=(8, 8)) # Plot the heatmap - ax.imshow(matrix, interpolation="nearest", aspect="auto", cmap="Greens") + cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens") # Set axes on heatmap ax.set_yticks(np.arange(len(motif_keys))) @@ -170,6 +170,11 @@ def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): ax.set_xlabel("Motif i") ax.set_ylabel("Motif j") + ax.tick_params(axis='both', labelsize=8) + + cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) + cbar.ax.tick_params(labelsize=8) + output_path_png = os.path.join(output_dir, "motif_cooocurrence.png") plt.savefig(output_path_png, dpi=300) output_path_svg = os.path.join(output_dir, "motif_cooocurrence.svg") @@ -178,6 +183,35 @@ def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): plt.close() +def plot_seqlet_confusion_heatmap(seqlet_confusion, motif_names, output_dir): + motif_keys = [abbreviate_motif_name(m) for m in motif_names] + + fig, ax = plt.subplots(figsize=(8, 8)) + + # Plot the heatmap + cax = 
ax.imshow(seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues") + + # Set axes on heatmap + ax.set_yticks(np.arange(len(motif_keys))) + ax.set_yticklabels(motif_keys) + ax.set_xticks(np.arange(len(motif_keys))) + ax.set_xticklabels(motif_keys, rotation=90) + ax.set_xlabel("Hit motif") + ax.set_ylabel("Seqlet motif") + + ax.tick_params(axis='both', labelsize=8) + + cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) + cbar.ax.tick_params(labelsize=8) + + output_path_png = os.path.join(output_dir, "seqlet_confusion.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(output_dir, "seqlet_confusion.svg") + plt.savefig(output_path_svg) + + plt.close() + + def get_cwms(regions, positions_df, motif_width): idx_df = ( positions_df @@ -366,6 +400,87 @@ def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms return report_data, report_df, cwms, cwm_trim_bounds +def seqlet_confusion(hits_df, seqlets_df, peaks_df, motif_names, motif_width): + bin_size = motif_width - 1 + + hits_binned = ( + hits_df + .with_columns( + peak_id=pl.col('peak_id').cast(pl.UInt32), + is_revcomp=pl.col("strand") == '-' + ) + .join( + peaks_df.lazy(), on="peak_id", how="inner" + ) + .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) + .select( + chr_id=pl.col("chr_id"), + start_bin=pl.col("start_untrimmed") // bin_size, + end_bin=pl.col("end_untrimmed") // bin_size, + motif_name=pl.col("motif_name") + ) + ) + + seqlets_binned = ( + seqlets_df + .select( + chr_id=pl.col("chr_id"), + start_bin=pl.col("start_untrimmed") // bin_size, + end_bin=pl.col("end_untrimmed") // bin_size, + motif_name=pl.col("motif_name") + ) + ) + + overlaps_df = ( + seqlets_binned.join( + hits_binned, + on=["chr_id", "start_bin", "end_bin"], + how="inner", + suffix="_hits" + ) + ) + + seqlet_counts = seqlets_binned.group_by("motif_name").len(name="num_seqlets") + overlap_counts = 
overlaps_df.group_by(["motif_name", "motif_name_hits"]).len(name="num_overlaps") + + num_motifs = len(motif_names) + confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32) + name_to_idx = {m: i for i, m in enumerate(motif_names)} + + confusion_df = ( + overlap_counts + .join( + seqlet_counts, + on="motif_name", + how="inner" + ) + .select( + motif_name_seqlets=pl.col("motif_name"), + motif_name_hits=pl.col("motif_name_hits"), + frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"), + ) + .collect() + ) + + confusion_idx_df = ( + confusion_df + .select( + row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx), + col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx), + frac_overlap=pl.col("frac_overlap") + ) + ) + + row_idx = confusion_idx_df["row_idx"].to_numpy() + col_idx = confusion_idx_df["col_idx"].to_numpy() + frac_overlap = confusion_idx_df["frac_overlap"].to_numpy() + + confusion_mat[row_idx, col_idx] = frac_overlap + + return confusion_df, confusion_mat + + + class LogoGlyph(AbstractPathEffect): def __init__(self, glyph, ref_glyph='E', font_props=None, offset=(0., 0.), **kwargs): diff --git a/src/finemo/main.py b/src/finemo/main.py index 50be9b2..0f1101c 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -218,6 +218,9 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p cwms_modisco, motif_names, modisco_half_width, motif_width, compute_recall) + if seqlets_df is not None: + confusion_df, confusion_mat = evaluation.seqlet_confusion(hits_df, seqlets_df, peaks_df, motif_names, motif_width) + os.makedirs(out_dir, exist_ok=True) occ_path = os.path.join(out_dir, "motif_occurrences.tsv") @@ -239,7 +242,11 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p seqlets_path = os.path.join(out_dir, "seqlets.tsv") data_io.write_modisco_seqlets(seqlets_df, seqlets_path) + seqlet_confusion_path = os.path.join(out_dir, "seqlet_confusion.tsv") + 
data_io.write_seqlet_confusion_df(confusion_df, seqlet_confusion_path) + evaluation.plot_hit_vs_seqlet_counts(report_data, out_dir) + evaluation.plot_seqlet_confusion_heatmap(confusion_mat, motif_names, out_dir) report_path = os.path.join(out_dir, "report.html") evaluation.write_report(report_df, motif_names, report_path, compute_recall, seqlets_df is not None) diff --git a/src/finemo/templates/report.html b/src/finemo/templates/report.html index 0b2ed57..8465ead 100644 --- a/src/finemo/templates/report.html +++ b/src/finemo/templates/report.html @@ -323,6 +323,18 @@

Hit and seqlet motif comparisons

Motif NameHits Per PeakHits Per RegionHit CoefficientHit SimilarityHit Importance
{{ m }}
+{% if compute_recall %} + +

Seqlet-hit confusion matrix

+

+ This heatmap shows the prevalence of motifs whose (untrimmed) hits overlap with TF-MoDISco seqlets of other motifs. + The vertical axis shows the motif of the seqlet, while the horizontal axis shows the motif of the hit. + The color intensity here represents an estimator of the expected number of bases of hit overlap per base of seqlet. +

+ + +{% endif %} +

Hit statistic distributions

The following figures visualize the distribution of hit statistics across motifs and regions. From 2cb2bf2dc1b5735b29ce18e48339f28cbbee048d Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Thu, 1 May 2025 11:03:50 -0700 Subject: [PATCH 10/39] Adjust plot padding --- src/finemo/evaluation.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index ff85d73..5e44529 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -157,7 +157,7 @@ def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): matrix = peak_hit_counts * cov_norm[:,None] * cov_norm[None,:] motif_keys = [abbreviate_motif_name(m) for m in motif_names] - fig, ax = plt.subplots(figsize=(8, 8)) + fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') # Plot the heatmap cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens") @@ -186,7 +186,7 @@ def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): def plot_seqlet_confusion_heatmap(seqlet_confusion, motif_names, output_dir): motif_keys = [abbreviate_motif_name(m) for m in motif_names] - fig, ax = plt.subplots(figsize=(8, 8)) + fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') # Plot the heatmap cax = ax.imshow(seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues") @@ -585,7 +585,7 @@ def plot_hit_vs_seqlet_counts(recall_data, output_dir): lim = max(np.amax(x), np.amax(y)) - fig, ax = plt.subplots(figsize=(8,8)) + fig, ax = plt.subplots(figsize=(8,8), layout='constrained') ax.axline((0, 0), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5))) ax.scatter(x, y, s=5) for i, txt in enumerate(m): From 61182f7794e5fb387dad9cc06b8052ddf7ff49f9 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Thu, 1 May 2025 15:38:53 -0700 Subject: [PATCH 11/39] Adjust confusion matrix calculation --- src/finemo/evaluation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) 
diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index 5e44529..488cd48 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -401,7 +401,7 @@ def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms def seqlet_confusion(hits_df, seqlets_df, peaks_df, motif_names, motif_width): - bin_size = motif_width - 1 + bin_size = motif_width hits_binned = ( hits_df From 1947e80e79d1fd88cce1ee55fa59f2f3b92442c0 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 5 May 2025 15:16:42 -0700 Subject: [PATCH 12/39] Refactor visualization and eval code --- src/finemo/evaluation.py | 315 ------------------------------------ src/finemo/main.py | 18 +-- src/finemo/visualization.py | 313 +++++++++++++++++++++++++++++++++++ 3 files changed, 321 insertions(+), 325 deletions(-) create mode 100644 src/finemo/visualization.py diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index 488cd48..98bb919 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -1,36 +1,7 @@ -import os import warnings -import importlib import numpy as np import polars as pl -import matplotlib.pyplot as plt -from matplotlib.patheffects import AbstractPathEffect -from matplotlib.textpath import TextPath -from matplotlib.transforms import Affine2D -from matplotlib.font_manager import FontProperties -from jinja2 import Template - -from . 
import templates - - -def abbreviate_motif_name(name): - try: - group, motif = name.split(".") - - if group == "pos_patterns": - group_short = "+" - elif group == "neg_patterns": - group_short = "-" - else: - raise Exception - - motif_num = motif.split("_")[1] - - return f"{group_short}/{motif_num}" - - except: - return name def get_motif_occurences(hits_df, motif_names): @@ -62,156 +33,6 @@ def get_motif_occurences(hits_df, motif_names): return occ_df, coocc -def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): - hits_df = hits_df.collect() - hits_by_motif = hits_df.partition_by("motif_name", as_dict=True) - dummy_df = hits_df.clear() - - motifs_dir = os.path.join(plot_dir, "motif_stat_distributions") - os.makedirs(motifs_dir, exist_ok=True) - for m in motif_names: - hits = hits_by_motif.get((m,), dummy_df) - coefficients = hits.get_column("hit_coefficient_global").to_numpy() - similarities = hits.get_column("hit_similarity").to_numpy() - importances = hits.get_column("hit_importance").to_numpy() - - fig, ax = plt.subplots(figsize=(5, 2)) - - ax.hist(coefficients, bins=50, density=True) - - output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png") - plt.savefig(output_path_png, dpi=300) - output_path_svg = os.path.join(motifs_dir, f"{m}_coefficients.svg") - plt.savefig(output_path_svg) - plt.close(fig) - - fig, ax = plt.subplots(figsize=(5, 2)) - - ax.hist(similarities, bins=50, density=True) - - output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png") - plt.savefig(output_path_png, dpi=300) - output_path_svg = os.path.join(motifs_dir, f"{m}_similarities.svg") - plt.savefig(output_path_svg) - plt.close(fig) - - fig, ax = plt.subplots(figsize=(5, 2)) - - ax.hist(importances, bins=50, density=True) - - output_path_png = os.path.join(motifs_dir, f"{m}_importances.png") - plt.savefig(output_path_png, dpi=300) - output_path_svg = os.path.join(motifs_dir, f"{m}_importances.svg") - plt.savefig(output_path_svg) - plt.close(fig) - - -def 
plot_hit_peak_distributions(occ_df, motif_names, plot_dir): - motifs_dir = os.path.join(plot_dir, "motif_hit_distributions") - os.makedirs(motifs_dir, exist_ok=True) - - for m in motif_names: - fig, ax = plt.subplots(figsize=(5, 2)) - - unique, counts = np.unique(occ_df.get_column(m), return_counts=True) - freq = counts / counts.sum() - num_bins = np.amax(unique, initial=0) + 1 - x = np.arange(num_bins) - y = np.zeros(num_bins) - y[unique] = freq - ax.bar(x, y) - - output_path_png = os.path.join(motifs_dir, f"{m}.png") - plt.savefig(output_path_png, dpi=300) - output_path_svg = os.path.join(motifs_dir, f"{m}.svg") - plt.savefig(output_path_svg) - - plt.close(fig) - - fig, ax = plt.subplots(figsize=(8, 4)) - - unique, counts = np.unique(occ_df.get_column("total"), return_counts=True) - freq = counts / counts.sum() - num_bins = np.amax(unique, initial=0) + 1 - x = np.arange(num_bins) - y = np.zeros(num_bins) - y[unique] = freq - ax.bar(x, y) - - ax.set_xlabel("Total hits per region") - ax.set_ylabel("Frequency") - - output_path_png = os.path.join(plot_dir, "total_hit_distribution.png") - plt.savefig(output_path_png, dpi=300) - output_path_svg = os.path.join(plot_dir, "total_hit_distribution.svg") - plt.savefig(output_path_svg, dpi=300) - - plt.close(fig) - - -def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): - """ - Plots a simple indicator heatmap of the motifs in each peak. 
- """ - cov_norm = 1 / np.sqrt(np.diag(peak_hit_counts)) - matrix = peak_hit_counts * cov_norm[:,None] * cov_norm[None,:] - motif_keys = [abbreviate_motif_name(m) for m in motif_names] - - fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') - - # Plot the heatmap - cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens") - - # Set axes on heatmap - ax.set_yticks(np.arange(len(motif_keys))) - ax.set_yticklabels(motif_keys) - ax.set_xticks(np.arange(len(motif_keys))) - ax.set_xticklabels(motif_keys, rotation=90) - ax.set_xlabel("Motif i") - ax.set_ylabel("Motif j") - - ax.tick_params(axis='both', labelsize=8) - - cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) - cbar.ax.tick_params(labelsize=8) - - output_path_png = os.path.join(output_dir, "motif_cooocurrence.png") - plt.savefig(output_path_png, dpi=300) - output_path_svg = os.path.join(output_dir, "motif_cooocurrence.svg") - plt.savefig(output_path_svg) - - plt.close() - - -def plot_seqlet_confusion_heatmap(seqlet_confusion, motif_names, output_dir): - motif_keys = [abbreviate_motif_name(m) for m in motif_names] - - fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') - - # Plot the heatmap - cax = ax.imshow(seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues") - - # Set axes on heatmap - ax.set_yticks(np.arange(len(motif_keys))) - ax.set_yticklabels(motif_keys) - ax.set_xticks(np.arange(len(motif_keys))) - ax.set_xticklabels(motif_keys, rotation=90) - ax.set_xlabel("Hit motif") - ax.set_ylabel("Seqlet motif") - - ax.tick_params(axis='both', labelsize=8) - - cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) - cbar.ax.tick_params(labelsize=8) - - output_path_png = os.path.join(output_dir, "seqlet_confusion.png") - plt.savefig(output_path_png, dpi=300) - output_path_svg = os.path.join(output_dir, "seqlet_confusion.svg") - plt.savefig(output_path_svg) - - plt.close() - - def get_cwms(regions, 
positions_df, motif_width): idx_df = ( positions_df @@ -480,139 +301,3 @@ def seqlet_confusion(hits_df, seqlets_df, peaks_df, motif_names, motif_width): return confusion_df, confusion_mat - -class LogoGlyph(AbstractPathEffect): - def __init__(self, glyph, ref_glyph='E', font_props=None, - offset=(0., 0.), **kwargs): - - super().__init__(offset) - - path_orig = TextPath((0, 0), glyph, size=1, prop=font_props) - dims = path_orig.get_extents() - ref_dims = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents() - - h_scale = 1 / dims.height - ref_width = max(dims.width, ref_dims.width) - w_scale = 1 / ref_width - w_shift = (1 - dims.width / ref_width) / 2 - x_shift = -dims.x0 - y_shift = -dims.y0 - stretch = ( - Affine2D() - .translate(tx=x_shift, ty=y_shift) - .scale(sx=w_scale, sy=h_scale) - .translate(tx=w_shift, ty=0) - ) - - self.path = stretch.transform_path(path_orig) - - #: The dictionary of keywords to update the graphics collection with. - self._gc = kwargs - - def draw_path(self, renderer, gc, tpath, affine, rgbFace): - return renderer.draw_path(gc, self.path, affine, rgbFace) - - -def plot_logo(ax, heights, glyphs, colors=None, font_props=None, shade_bounds=None): - if colors is None: - colors = {g: None for g in glyphs} - - ax.margins(x=0, y=0) - - pos_values = np.clip(heights, 0, None) - neg_values = np.clip(heights, None, 0) - pos_order = np.argsort(pos_values, axis=0) - neg_order = np.argsort(neg_values, axis=0)[::-1,:] - pos_reorder = np.argsort(pos_order, axis=0) - neg_reorder = np.argsort(neg_order, axis=0) - pos_offsets = np.take_along_axis( - np.cumsum( - np.take_along_axis(pos_values, pos_order, axis=0), axis=0 - ), pos_reorder, axis=0 - ) - neg_offsets = np.take_along_axis( - np.cumsum( - np.take_along_axis(neg_values, neg_order, axis=0), axis=0 - ), neg_reorder, axis=0 - ) - bottoms = pos_offsets + neg_offsets - heights - - x = np.arange(heights.shape[1]) - - for glyph, height, bottom in zip(glyphs, heights, bottoms): - ax.bar(x, 
height, 0.95, bottom=bottom, - path_effects=[LogoGlyph(glyph, font_props=font_props)], color=colors[glyph]) - - if shade_bounds is not None: - start, end = shade_bounds - ax.axvspan(start - 0.5, end - 0.5, color='0.9', zorder=-1) - - ax.axhline(zorder=-1, linewidth=0.5, color='black') - - -LOGO_ALPHABET = 'ACGT' -LOGO_COLORS = {"A": '#109648', "C": '#255C99', "G": '#F7B32B', "T": '#D62839'} -LOGO_FONT = FontProperties(weight="bold") - -def plot_cwms(cwms, trim_bounds, out_dir, alphabet=LOGO_ALPHABET, colors=LOGO_COLORS, font=LOGO_FONT): - for m, v in cwms.items(): - motif_dir = os.path.join(out_dir, m) - os.makedirs(motif_dir, exist_ok=True) - for cwm_type, cwm in v.items(): - fig, ax = plt.subplots(figsize=(10,2)) - - plot_logo(ax, cwm, alphabet, colors=colors, font_props=font, shade_bounds=trim_bounds[m][cwm_type]) - - for name, spine in ax.spines.items(): - spine.set_visible(False) - - output_path_png = os.path.join(motif_dir, f"{cwm_type}.png") - plt.savefig(output_path_png, dpi=100) - output_path_svg = os.path.join(motif_dir, f"{cwm_type}.svg") - plt.savefig(output_path_svg) - - plt.close(fig) - - -def plot_hit_vs_seqlet_counts(recall_data, output_dir): - x = [] - y = [] - m = [] - for k, v in recall_data.items(): - x.append(v["num_hits_total"]) - y.append(v["num_seqlets"]) - m.append(k) - - lim = max(np.amax(x), np.amax(y)) - - fig, ax = plt.subplots(figsize=(8,8), layout='constrained') - ax.axline((0, 0), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5))) - ax.scatter(x, y, s=5) - for i, txt in enumerate(m): - short = abbreviate_motif_name(txt) - ax.annotate(short, (x[i], y[i]), fontsize=8, weight="bold") - - ax.set_yscale('log') - ax.set_xscale('log') - - ax.set_xlabel("Hits per motif") - ax.set_ylabel("Seqlets per motif") - - output_path_png = os.path.join(output_dir, "hit_vs_seqlet_counts.png") - plt.savefig(output_path_png, dpi=300) - output_path_svg = os.path.join(output_dir, "hit_vs_seqlet_counts.svg") - plt.savefig(output_path_svg) - - 
plt.close() - - -def write_report(report_df, motif_names, out_path, compute_recall, use_seqlets): - template_str = importlib.resources.files(templates).joinpath('report.html').read_text() - template = Template(template_str) - report = template.render(report_data=report_df.iter_rows(named=True), - motif_names=motif_names, compute_recall=compute_recall, - use_seqlets=use_seqlets) - with open(out_path, "w") as f: - f.write(report) - - diff --git a/src/finemo/main.py b/src/finemo/main.py index 0f1101c..852147c 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -152,7 +152,7 @@ def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motif def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_path, motif_names_path, out_dir, modisco_region_width, cwm_trim_threshold, compute_recall, use_seqlets): - from . import evaluation + from . import evaluation, visualization sequences, contribs, peaks_df, _ = data_io.load_regions_npz(regions_path) if len(contribs.shape) == 3: @@ -228,14 +228,12 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p data_io.write_report_data(report_df, cwms, out_dir) - evaluation.plot_hit_stat_distributions(hits_df, motif_names, out_dir) - - evaluation.plot_hit_peak_distributions(occ_df, motif_names, out_dir) - - evaluation.plot_peak_motif_indicator_heatmap(coooc, motif_names, out_dir) + visualization.plot_hit_stat_distributions(hits_df, motif_names, out_dir) + visualization.plot_hit_peak_distributions(occ_df, motif_names, out_dir) + visualization.plot_peak_motif_indicator_heatmap(coooc, motif_names, out_dir) plot_dir = os.path.join(out_dir, "CWMs") - evaluation.plot_cwms(cwms, trim_bounds, plot_dir) + visualization.plot_cwms(cwms, trim_bounds, plot_dir) if seqlets_df is not None: seqlets_df = seqlets_df.collect() @@ -245,11 +243,11 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p seqlet_confusion_path = os.path.join(out_dir, 
"seqlet_confusion.tsv") data_io.write_seqlet_confusion_df(confusion_df, seqlet_confusion_path) - evaluation.plot_hit_vs_seqlet_counts(report_data, out_dir) - evaluation.plot_seqlet_confusion_heatmap(confusion_mat, motif_names, out_dir) + visualization.plot_hit_vs_seqlet_counts(report_data, out_dir) + visualization.plot_seqlet_confusion_heatmap(confusion_mat, motif_names, out_dir) report_path = os.path.join(out_dir, "report.html") - evaluation.write_report(report_df, motif_names, report_path, compute_recall, seqlets_df is not None) + visualization.write_report(report_df, motif_names, report_path, compute_recall, seqlets_df is not None) def collapse_hits(hits_path, out_path, overlap_frac): diff --git a/src/finemo/visualization.py b/src/finemo/visualization.py new file mode 100644 index 0000000..fd348a1 --- /dev/null +++ b/src/finemo/visualization.py @@ -0,0 +1,313 @@ +import os +import importlib + +import numpy as np +import matplotlib.pyplot as plt +from matplotlib.patheffects import AbstractPathEffect +from matplotlib.textpath import TextPath +from matplotlib.transforms import Affine2D +from matplotlib.font_manager import FontProperties +from jinja2 import Template + +from . 
import templates + + +def abbreviate_motif_name(name): + try: + group, motif = name.split(".") + if group == "pos_patterns": + group_short = "+" + elif group == "neg_patterns": + group_short = "-" + else: + raise Exception + motif_num = motif.split("_")[1] + return f"{group_short}/{motif_num}" + except: + return name + + +def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): + hits_df = hits_df.collect() + hits_by_motif = hits_df.partition_by("motif_name", as_dict=True) + dummy_df = hits_df.clear() + + motifs_dir = os.path.join(plot_dir, "motif_stat_distributions") + os.makedirs(motifs_dir, exist_ok=True) + for m in motif_names: + hits = hits_by_motif.get((m,), dummy_df) + coefficients = hits.get_column("hit_coefficient_global").to_numpy() + similarities = hits.get_column("hit_similarity").to_numpy() + importances = hits.get_column("hit_importance").to_numpy() + + fig, ax = plt.subplots(figsize=(5, 2)) + + ax.hist(coefficients, bins=50, density=True) + + output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}_coefficients.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + fig, ax = plt.subplots(figsize=(5, 2)) + + ax.hist(similarities, bins=50, density=True) + + output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}_similarities.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + fig, ax = plt.subplots(figsize=(5, 2)) + + ax.hist(importances, bins=50, density=True) + + output_path_png = os.path.join(motifs_dir, f"{m}_importances.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}_importances.svg") + plt.savefig(output_path_svg) + plt.close(fig) + + +def plot_hit_peak_distributions(occ_df, motif_names, plot_dir): + motifs_dir = os.path.join(plot_dir, "motif_hit_distributions") + 
os.makedirs(motifs_dir, exist_ok=True) + + for m in motif_names: + fig, ax = plt.subplots(figsize=(5, 2)) + + unique, counts = np.unique(occ_df.get_column(m), return_counts=True) + freq = counts / counts.sum() + num_bins = np.amax(unique, initial=0) + 1 + x = np.arange(num_bins) + y = np.zeros(num_bins) + y[unique] = freq + ax.bar(x, y) + + output_path_png = os.path.join(motifs_dir, f"{m}.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(motifs_dir, f"{m}.svg") + plt.savefig(output_path_svg) + + plt.close(fig) + + fig, ax = plt.subplots(figsize=(8, 4)) + + unique, counts = np.unique(occ_df.get_column("total"), return_counts=True) + freq = counts / counts.sum() + num_bins = np.amax(unique, initial=0) + 1 + x = np.arange(num_bins) + y = np.zeros(num_bins) + y[unique] = freq + ax.bar(x, y) + + ax.set_xlabel("Total hits per region") + ax.set_ylabel("Frequency") + + output_path_png = os.path.join(plot_dir, "total_hit_distribution.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(plot_dir, "total_hit_distribution.svg") + plt.savefig(output_path_svg, dpi=300) + + plt.close(fig) + + +def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): + """ + Plots a simple indicator heatmap of the motifs in each peak. 
+ """ + cov_norm = 1 / np.sqrt(np.diag(peak_hit_counts)) + matrix = peak_hit_counts * cov_norm[:, None] * cov_norm[None, :] + motif_keys = [abbreviate_motif_name(m) for m in motif_names] + + fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') + + # Plot the heatmap + cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens") + + # Set axes on heatmap + ax.set_yticks(np.arange(len(motif_keys))) + ax.set_yticklabels(motif_keys) + ax.set_xticks(np.arange(len(motif_keys))) + ax.set_xticklabels(motif_keys, rotation=90) + ax.set_xlabel("Motif i") + ax.set_ylabel("Motif j") + + ax.tick_params(axis='both', labelsize=8) + + cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) + cbar.ax.tick_params(labelsize=8) + + output_path_png = os.path.join(output_dir, "motif_cooocurrence.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(output_dir, "motif_cooocurrence.svg") + plt.savefig(output_path_svg) + + plt.close() + + +def plot_seqlet_confusion_heatmap(seqlet_confusion, motif_names, output_dir): + motif_keys = [abbreviate_motif_name(m) for m in motif_names] + + fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') + + # Plot the heatmap + cax = ax.imshow(seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues") + + # Set axes on heatmap + ax.set_yticks(np.arange(len(motif_keys))) + ax.set_yticklabels(motif_keys) + ax.set_xticks(np.arange(len(motif_keys))) + ax.set_xticklabels(motif_keys, rotation=90) + ax.set_xlabel("Hit motif") + ax.set_ylabel("Seqlet motif") + + ax.tick_params(axis='both', labelsize=8) + + cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) + cbar.ax.tick_params(labelsize=8) + + output_path_png = os.path.join(output_dir, "seqlet_confusion.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(output_dir, "seqlet_confusion.svg") + plt.savefig(output_path_svg) + + plt.close() + + +class 
LogoGlyph(AbstractPathEffect): + def __init__(self, glyph, ref_glyph='E', font_props=None, + offset=(0., 0.), **kwargs): + + super().__init__(offset) + + path_orig = TextPath((0, 0), glyph, size=1, prop=font_props) + dims = path_orig.get_extents() + ref_dims = TextPath((0, 0), ref_glyph, size=1, prop=font_props).get_extents() + + h_scale = 1 / dims.height + ref_width = max(dims.width, ref_dims.width) + w_scale = 1 / ref_width + w_shift = (1 - dims.width / ref_width) / 2 + x_shift = -dims.x0 + y_shift = -dims.y0 + stretch = ( + Affine2D() + .translate(tx=x_shift, ty=y_shift) + .scale(sx=w_scale, sy=h_scale) + .translate(tx=w_shift, ty=0) + ) + + self.path = stretch.transform_path(path_orig) + + #: The dictionary of keywords to update the graphics collection with. + self._gc = kwargs + + def draw_path(self, renderer, gc, tpath, affine, rgbFace): + return renderer.draw_path(gc, self.path, affine, rgbFace) + + +def plot_logo(ax, heights, glyphs, colors=None, font_props=None, shade_bounds=None): + if colors is None: + colors = {g: None for g in glyphs} + + ax.margins(x=0, y=0) + + pos_values = np.clip(heights, 0, None) + neg_values = np.clip(heights, None, 0) + pos_order = np.argsort(pos_values, axis=0) + neg_order = np.argsort(neg_values, axis=0)[::-1, :] + pos_reorder = np.argsort(pos_order, axis=0) + neg_reorder = np.argsort(neg_order, axis=0) + pos_offsets = np.take_along_axis( + np.cumsum( + np.take_along_axis(pos_values, pos_order, axis=0), axis=0 + ), pos_reorder, axis=0 + ) + neg_offsets = np.take_along_axis( + np.cumsum( + np.take_along_axis(neg_values, neg_order, axis=0), axis=0 + ), neg_reorder, axis=0 + ) + bottoms = pos_offsets + neg_offsets - heights + + x = np.arange(heights.shape[1]) + + for glyph, height, bottom in zip(glyphs, heights, bottoms): + ax.bar(x, height, 0.95, bottom=bottom, + path_effects=[LogoGlyph(glyph, font_props=font_props)], color=colors[glyph]) + + if shade_bounds is not None: + start, end = shade_bounds + ax.axvspan(start - 0.5, end 
- 0.5, color='0.9', zorder=-1) + + ax.axhline(zorder=-1, linewidth=0.5, color='black') + + +LOGO_ALPHABET = 'ACGT' +LOGO_COLORS = {"A": '#109648', "C": '#255C99', "G": '#F7B32B', "T": '#D62839'} +LOGO_FONT = FontProperties(weight="bold") + + +def plot_cwms(cwms, trim_bounds, out_dir, alphabet=LOGO_ALPHABET, colors=LOGO_COLORS, font=LOGO_FONT): + for m, v in cwms.items(): + motif_dir = os.path.join(out_dir, m) + os.makedirs(motif_dir, exist_ok=True) + for cwm_type, cwm in v.items(): + fig, ax = plt.subplots(figsize=(10, 2)) + + plot_logo(ax, cwm, alphabet, colors=colors, font_props=font, shade_bounds=trim_bounds[m][cwm_type]) + + for name, spine in ax.spines.items(): + spine.set_visible(False) + + output_path_png = os.path.join(motif_dir, f"{cwm_type}.png") + plt.savefig(output_path_png, dpi=100) + output_path_svg = os.path.join(motif_dir, f"{cwm_type}.svg") + plt.savefig(output_path_svg) + + plt.close(fig) + + +def plot_hit_vs_seqlet_counts(recall_data, output_dir): + x = [] + y = [] + m = [] + for k, v in recall_data.items(): + x.append(v["num_hits_total"]) + y.append(v["num_seqlets"]) + m.append(k) + + lim = max(np.amax(x), np.amax(y)) + + fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') + ax.axline((0, 0), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5))) + ax.scatter(x, y, s=5) + for i, txt in enumerate(m): + short = abbreviate_motif_name(txt) + ax.annotate(short, (x[i], y[i]), fontsize=8, weight="bold") + + ax.set_yscale('log') + ax.set_xscale('log') + + ax.set_xlabel("Hits per motif") + ax.set_ylabel("Seqlets per motif") + + output_path_png = os.path.join(output_dir, "hit_vs_seqlet_counts.png") + plt.savefig(output_path_png, dpi=300) + output_path_svg = os.path.join(output_dir, "hit_vs_seqlet_counts.svg") + plt.savefig(output_path_svg) + + plt.close() + + +def write_report(report_df, motif_names, out_path, compute_recall, use_seqlets): + template_str = importlib.resources.files(templates).joinpath('report.html').read_text() + template 
= Template(template_str) + report = template.render(report_data=report_df.iter_rows(named=True), + motif_names=motif_names, compute_recall=compute_recall, + use_seqlets=use_seqlets) + with open(out_path, "w") as f: + f.write(report) \ No newline at end of file From ff0282057e6fecdf5219c6a5927674e0d4bae2dc Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sat, 30 Aug 2025 07:38:39 -0700 Subject: [PATCH 13/39] Update gitignore --- .gitignore | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/.gitignore b/.gitignore index f68e48a..5c3c2fa 100644 --- a/.gitignore +++ b/.gitignore @@ -4,4 +4,5 @@ __pycache__ /notebooks /notebooks/old -/scratch.txt \ No newline at end of file +/scratch.txt +/scratch \ No newline at end of file From 670e1fe188b7ad7bc29fb9ee35188960f12abed4 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 31 Aug 2025 13:34:41 -0700 Subject: [PATCH 14/39] Code documentation --- .gitignore | 4 +- README.md | 76 +- environment.yml | 3 +- pyproject.toml | 7 +- setup.py | 2 +- src/finemo/__init__.py | 63 ++ src/finemo/data_io.py | 1200 +++++++++++++++++++++---- src/finemo/evaluation.py | 527 ++++++++--- src/finemo/hitcaller.py | 786 +++++++++++++---- src/finemo/main.py | 1416 ++++++++++++++++++++++++------ src/finemo/postprocessing.py | 261 +++++- src/finemo/templates/report.html | 139 +-- src/finemo/visualization.py | 488 ++++++++-- 13 files changed, 4030 insertions(+), 942 deletions(-) diff --git a/.gitignore b/.gitignore index 5c3c2fa..78dbf8b 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,6 @@ -.conda -.DS_Store *.egg-info __pycache__ +/.* /notebooks -/notebooks/old /scratch.txt /scratch \ No newline at end of file diff --git a/README.md b/README.md index 623b220..9231afd 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,36 @@ -# finemo_gpu +# Fi-NeMo: Finding Neural network Motifs -**Fi-NeMo** (**Fi**nding **Ne**ural network **Mo**tifs) is a GPU-accelerated hit caller for identifying occurrences of TFMoDISCo motifs within 
contribution scores generated by machine learning models. +**Fi-NeMo** (**Fi**nding **Ne**ural network **Mo**tifs) is a GPU-accelerated motif instance calling tool for identifying transcription factor binding sites from neural network contribution scores. + +## Overview + +Fi-NeMo implements a competitive optimization approach using proximal gradient descent to identify motif instances by solving a sparse linear reconstruction problem. Unlike traditional sequence-based methods, Fi-NeMo leverages context-aware importance scores from deep neural networks to accurately map transcription factor binding sites, enabling the discovery of both high-confidence canonical motifs and low-prevalence cofactor motifs that are often missed by conventional approaches. + +The algorithm represents contribution scores as weighted combinations of motif contribution weight matrices (CWMs) at specific genomic positions. This competitive assignment process more closely reflects the biological reality of transcription factors competing for binding sites, resulting in superior sensitivity and specificity compared to sequence-only methods. + +### Features + +- **GPU-accelerated optimization**: Fast processing of large contribution score datasets using PyTorch +- **Competitive motif assignment**: Biologically-motivated algorithm that resolves similar motifs +- **Context-aware analysis**: Leverages neural network importance scores for improved sensitivity and specificity +- **Comprehensive evaluation**: Built-in tools for assessing and visualizing motif discovery quality and hit calling performance +- **Multiple input formats**: Support for bigWig, HDF5, and TF-MoDISCo output formats + +## Method + +Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs. The algorithm formulates this as an L1-regularized linear model. 
+ +This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. + +## References + +Fi-NeMo is described in: +> Tseng, Ramalingam, Wang, Schreiber, et al. "Decoding predictive motif lexicons and syntax from deep learning models of transcription factor binding profiles." (manuscript in preparation) + +Related tools: +- [TF-MoDISCo](https://github.com/jmschrei/tfmodisco-lite): *De novo* motif discovery from importance scores +- [BPNet](https://github.com/kundajelab/bpnet-refactor): Deep learning models for TF binding prediction +- [ChromBPNet](https://github.com/kundajelab/chrombpnet): Deep learning models for chromatin accessibility prediction ## Installation @@ -19,7 +49,7 @@ cd finemo_gpu #### Create a Conda Environment with Dependencies -This step is optional but recommended +This step is optional but recommended for conda users. ```sh conda env create -f environment.yml -n $ENV_NAME @@ -60,7 +90,13 @@ Recommended: ## Usage -Fi-NeMo includes a command-line utility named `finemo`. Here, we describe basic usage for each subcommand. For all options, run `finemo -h`. +Fi-NeMo provides a command-line utility named `finemo` for motif instance calling and analysis. The typical workflow involves three main steps: + +1. **Preprocessing**: Transform input contributions and sequences into a compressed format +2. **Hit Calling**: Identify motif instances using the Fi-NeMo algorithm +3. **Reporting and Analysis**: Generate visualizations and perform post-processing + +For detailed options for any subcommand, run `finemo -h`. ### Preprocessing @@ -126,7 +162,7 @@ Usage: `finemo extract-regions-modisco-fmt -s -a -o < #### `finemo call-hits` -Identify hits in input regions using TFMoDISCo CWM's. +Identify motif instances in input regions using the Fi-NeMo competitive optimization algorithm. 
This is the core functionality that leverages TF-MoDISCo CWMs to find motif occurrences in contribution score data. Usage: `finemo call-hits -r -m -o [-p ] [-t ] [-l ] [-b ] [-J]` @@ -195,13 +231,31 @@ Usage: `finemo call-hits -r -m -o [-p ] `params.json`: The parameters used for hit calling. -#### Additional notes +#### Parameter Guidelines + +**Sensitivity Control (`-l/--global-lambda`)** +- Controls sparsity and sensitivity of hit calling +- Higher values (e.g., 0.8-0.9) → fewer, higher-confidence hits +- Lower values (e.g., 0.5-0.6) → more sensitive, may include weaker hits +- Default of 0.7 works well for chromatin accessibility data +- ChIP-seq data may benefit from lower values (0.6) + +**Motif Trimming (`-t/--cwm-trim-threshold`)** +- Determines where motif boundaries are set within full CWMs +- Lower values → more conservative trimming, longer motifs +- Higher values → more aggressive trimming, shorter core motifs +- Affects resolution of closely-spaced motif instances + +**Performance Optimization (`-b/--batch-size`, `-J`)** +- Set batch size to utilize available GPU memory efficiently +- Reduce batch size if you encounter out-of-memory errors +- Enable JIT compilation (`-J`) for faster execution on newer GPUs + +#### Important Notes -- The `-l/--global-lambda` parameter controls the sensitivity of the hit-calling algorithm, with higher values resulting in fewer but more confident hits. This parameter represents the minimum cosine similarity between a query contribution score window and a CWM to be considered a hit. The default value of 0.7 typically works well for chromatin accessibility data. ChIP-Seq data may require a lower value (e.g. 0.6). -- The `-t/--cwm-trim-threshold` parameter sets the maximum relative contribution score in trimmed-out CWM flanks. If you find that motif flanks are being trimmed too aggressively, consider lowering this value. However, a too-low value may result in closely-spaced motif instances being missed. 
-- Set `-b/--batch-size` to fill a significant fraction of your GPU memory. **If you encounter GPU out-of-memory errors, try lowering this value.** -- Legacy TFMoDISCo H5 files can be updated to the newer TFMoDISCo-lite format with the `modisco convert` command found in the [tfmodisco-lite](https://github.com/jmschrei/tfmodisco-lite/tree/main) package. -- The hit-calling thresholding procedure is scale-invariant. That is, whether a position is assigned a hit depends on the shapes of the motif CWM and the contribution scores, not the absolute magnitude of the scores. If you wish to prioritize hits based on the magnitude of the contribution scores, set a per-motif rank threshold the `hit_coefficient_global` field in the `hits.tsv` file, which captures both the absolute importance and the closeness of match. +- **Scale Invariance**: Hit calling depends on motif and contribution score shapes, not absolute magnitudes. Use `hit_coefficient_global` or `hit_importance` for importance-based thresholding. +- **Competitive Assignment**: Overlapping motif candidates compete; only the best-fitting motif at each position receives a non-zero coefficient. +- **Legacy Format Support**: Convert older TF-MoDISCo files using `modisco convert` from [tfmodisco-lite](https://github.com/jmschrei/tfmodisco-lite). ### Output reporting and post-processing diff --git a/environment.yml b/environment.yml index 98e4cde..3720c11 100644 --- a/environment.yml +++ b/environment.yml @@ -18,4 +18,5 @@ dependencies: - tqdm=4.67.1 - jinja2=3.1.4 - pybigwig=0.3.23 - - pyfaidx=0.8.1.3 \ No newline at end of file + - pyfaidx=0.8.1.3 + - jaxtyping=0.3.2 \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index e527bce..747bb6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,12 +6,12 @@ build-backend = "setuptools.build_meta" name = "finemo" description = "Identification of regulatory elements from neural network contribution scores for DNA." 
keywords = ["deep learning", "genomics"] -version = "0.36" +version = "0.40" readme = "README.md" license = {file = "LICENSE"} authors = [ {name = "Austin Wang", email = "austin.wang1357@gmail.com"}, - {name = "Anshul Kundaje"} + {name = "Anshul Kundaje", email = "akundaje@stanford.edu"} ] dependencies = [ "numpy", @@ -34,3 +34,6 @@ finemo = "finemo.main:cli" [project.urls] Homepage = "https://github.com/austintwang/finemo_gpu" Repository = "https://github.com/austintwang/finemo_gpu.git" + +[tool.ruff] +ignore = ["F722"] \ No newline at end of file diff --git a/setup.py b/setup.py index 88c95dc..3e0f86b 100644 --- a/setup.py +++ b/setup.py @@ -3,4 +3,4 @@ # Empty setup.py for compatibility with pip<21.1 # See pyproject.toml for package configuration -setup() \ No newline at end of file +setup() diff --git a/src/finemo/__init__.py b/src/finemo/__init__.py index e69de29..b170d01 100644 --- a/src/finemo/__init__.py +++ b/src/finemo/__init__.py @@ -0,0 +1,63 @@ +"""Fi-NeMo: Finding Neural network Motifs. + +A GPU-accelerated motif instance calling tool for identifying transcription factor +binding sites from neural network contribution scores. + +Fi-NeMo implements a competitive optimization approach using proximal gradient descent +to identify motif instances by solving a sparse linear reconstruction problem. The +algorithm represents contribution scores as weighted combinations of motif contribution +weight matrices (CWMs) at specific genomic positions. 
+ +Key Features +------------ +- GPU-accelerated hit calling using PyTorch +- Support for multiple input formats (bigWig, HDF5, TF-MoDISCo) +- Competitive motif instance assignment +- Comprehensive evaluation and visualization tools +- Post-processing utilities for hit refinement + +Modules +------- +hitcaller : Core Fi-NeMo algorithm implementation +data_io : Data input/output utilities +main : Command-line interface +evaluation : Performance assessment tools +visualization : Plotting and report generation +postprocessing : Hit refinement and analysis + +Examples +-------- +Basic hit calling workflow: + +>>> import finemo +>>> from finemo import data_io, hitcaller +>>> +>>> # Load preprocessed data +>>> sequences, contribs, peaks_df, has_peaks = data_io.load_regions_npz('regions.npz') +>>> cwms, trim_masks = data_io.load_motif_cwms('motifs.h5') +>>> +>>> # Call hits +>>> hits_df, qc_df = hitcaller.fit_contribs( +... cwms=cwms, +... contribs=contribs, +... sequences=sequences, +... cwm_trim_mask=trim_masks, +... use_hypothetical=False, +... lambdas=np.array([0.7] * len(cwms)), +... step_size_max=3.0, +... step_size_min=0.08, +... sqrt_transform=False, +... convergence_tol=0.0005, +... max_steps=10000, +... batch_size=1000, +... step_adjust=0.7, +... post_filter=True, +... device=None, +... compile_optimizer=False +... ) + +See Also +-------- +TF-MoDISCo : https://github.com/jmschrei/tfmodisco-lite +BPNet : https://github.com/kundajelab/bpnet-refactor +""" diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index 1088d68..a604904 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -1,266 +1,846 @@ +"""Data input/output module for the Fi-NeMo motif instance calling pipeline. 
+ +This module handles loading and processing of various genomic data formats including: +- Peak region files (ENCODE NarrowPeak format) +- Genome sequences (FASTA format) +- Contribution scores (bigWig, HDF5 formats) +- Neural network model outputs +- Motif data from TF-MoDISCo +- Hit calling results + +The module supports multiple input formats used for contribution scores +and provides utilities for data conversion and quality control. +""" + import json import os import warnings from contextlib import ExitStack +from typing import List, Dict, Tuple, Optional, Any, Union, Callable import numpy as np +from numpy import ndarray import h5py -import hdf5plugin +import hdf5plugin # noqa: F401, imported for side effects (HDF5 plugin registration) import polars as pl import pyBigWig import pyfaidx +from jaxtyping import Float, Int from tqdm import tqdm -def load_txt(path): +def load_txt(path: str) -> List[str]: + """Load a text file containing one item per line. + + Parameters + ---------- + path : str + Path to the text file. + + Returns + ------- + List[str] + List of strings, one per line (first column if tab-delimited). + """ entries = [] with open(path) as f: for line in f: item = line.rstrip("\n").split("\t")[0] entries.append(item) - + return entries -def load_mapping(path, type): +def load_mapping(path: str, value_type: Callable[[str], Any]) -> Dict[str, Any]: + """Load a two-column tab-delimited mapping file. + + Parameters + ---------- + path : str + Path to the mapping file. Must be tab-delimited with exactly two columns. + value_type : Callable[[str], Any] + Type constructor to apply to values (e.g., int, float, str). + Must accept a string and return the converted value. + + Returns + ------- + Dict[str, Any] + Dictionary mapping keys to values of the specified type. + + Raises + ------ + ValueError + If lines don't contain exactly two tab-separated values. + FileNotFoundError + If the specified file does not exist. 
+ """ mapping = {} with open(path) as f: for line in f: key, val = line.rstrip("\n").split("\t") - mapping[key] = type(val) + mapping[key] = value_type(val) return mapping -def load_mapping_tuple(path, type): + +def load_mapping_tuple( + path: str, value_type: Callable[[str], Any] +) -> Dict[str, Tuple[Any, ...]]: + """Load a mapping file where values are tuples from multiple columns. + + Parameters + ---------- + path : str + Path to the mapping file. Must be tab-delimited with multiple columns. + value_type : Callable[[str], Any] + Type constructor to apply to each value element. + Must accept a string and return the converted value. + + Returns + ------- + Dict[str, Tuple[Any, ...]] + Dictionary mapping keys to tuples of values of the specified type. + The first column is used as the key, remaining columns as tuple values. + + Raises + ------ + ValueError + If lines don't contain at least two tab-separated values. + FileNotFoundError + If the specified file does not exist. + """ mapping = {} with open(path) as f: for line in f: entries = line.rstrip("\n").split("\t") key = entries[0] val = entries[1:] - mapping[key] = tuple(type(i) for i in val) + mapping[key] = tuple(value_type(i) for i in val) return mapping -NARROWPEAK_SCHEMA = ["chr", "peak_start", "peak_end", "peak_name", "peak_score", - "peak_strand", "peak_signal", "peak_pval", "peak_qval", "peak_summit"] -NARROWPEAK_DTYPES = [pl.String, pl.Int32, pl.Int32, pl.String, pl.UInt32, - pl.String, pl.Float32, pl.Float32, pl.Float32, pl.Int32] - -def load_peaks(peaks_path, chrom_order_path, half_width): +# ENCODE NarrowPeak format column definitions +NARROWPEAK_SCHEMA: List[str] = [ + "chr", + "peak_start", + "peak_end", + "peak_name", + "peak_score", + "peak_strand", + "peak_signal", + "peak_pval", + "peak_qval", + "peak_summit", +] +NARROWPEAK_DTYPES: List[Any] = [ + pl.String, + pl.Int32, + pl.Int32, + pl.String, + pl.UInt32, + pl.String, + pl.Float32, + pl.Float32, + pl.Float32, + pl.Int32, +] + + +def 
load_peaks( + peaks_path: str, chrom_order_path: Optional[str], half_width: int +) -> pl.DataFrame: + """Load peak region data from ENCODE NarrowPeak format file. + + Parameters + ---------- + peaks_path : str + Path to the NarrowPeak format file. + chrom_order_path : str, optional + Path to file defining chromosome ordering. If None, uses order from peaks file. + half_width : int + Half-width of regions around peak summits. + + Returns + ------- + pl.DataFrame + DataFrame containing peak information with columns: + - chr: Chromosome name + - peak_region_start: Start coordinate of centered region + - peak_name: Peak identifier + - peak_id: Sequential peak index + - chr_id: Numeric chromosome identifier + """ peaks = ( - pl.scan_csv(peaks_path, has_header=False, new_columns=NARROWPEAK_SCHEMA, separator='\t', - quote_char=None, schema_overrides=NARROWPEAK_DTYPES, null_values=['.', 'NA', 'null', 'NaN']) + pl.scan_csv( + peaks_path, + has_header=False, + new_columns=NARROWPEAK_SCHEMA, + separator="\t", + quote_char=None, + schema_overrides=NARROWPEAK_DTYPES, + null_values=[".", "NA", "null", "NaN"], + ) .select( chr=pl.col("chr"), peak_region_start=pl.col("peak_start") + pl.col("peak_summit") - half_width, - peak_name=pl.col("peak_name") + peak_name=pl.col("peak_name"), ) .with_row_index(name="peak_id") .collect() ) - + if chrom_order_path is not None: chrom_order = load_txt(chrom_order_path) else: chrom_order = [] chrom_order_set = set(chrom_order) - chrom_order_peaks = [i for i in peaks.get_column("chr").unique(maintain_order=True) if i not in chrom_order_set] + chrom_order_peaks = [ + i + for i in peaks.get_column("chr").unique(maintain_order=True) + if i not in chrom_order_set + ] chrom_order.extend(chrom_order_peaks) chrom_ind_map = {val: ind for ind, val in enumerate(chrom_order)} peaks = peaks.with_columns( pl.col("chr").replace_strict(chrom_ind_map).alias("chr_id") ) - + return peaks -SEQ_ALPHABET = np.array(["A","C","G","T"], dtype="S1") +# DNA sequence 
alphabet for one-hot encoding +SEQ_ALPHABET: np.ndarray = np.array(["A", "C", "G", "T"], dtype="S1") + + +def one_hot_encode(sequence: str, dtype: Any = np.int8) -> Int[ndarray, "4 L"]: + """Convert DNA sequence string to one-hot encoded matrix. -def one_hot_encode(sequence, dtype=np.int8): + Parameters + ---------- + sequence : str + DNA sequence string containing A, C, G, T characters. + dtype : np.dtype, default np.int8 + Data type for the output array. + + Returns + ------- + Int[ndarray, "4 L"] + One-hot encoded sequence where rows correspond to A, C, G, T and + L is the sequence length. + + Notes + ----- + The output array has shape (4, len(sequence)) with rows corresponding to + nucleotides A, C, G, T in that order. Non-standard nucleotides (N, etc.) + result in all-zero columns. + """ sequence = sequence.upper() - seq_chararray = np.frombuffer(sequence.encode('UTF-8'), dtype='S1') - one_hot = (seq_chararray[None,:] == SEQ_ALPHABET[:,None]).astype(dtype) + seq_chararray = np.frombuffer(sequence.encode("UTF-8"), dtype="S1") + one_hot = (seq_chararray[None, :] == SEQ_ALPHABET[:, None]).astype(dtype) return one_hot -def load_regions_from_bw(peaks, fa_path, bw_paths, half_width): +def load_regions_from_bw( + peaks: pl.DataFrame, fa_path: str, bw_paths: List[str], half_width: int +) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N L"]]: + """Load genomic sequences and contribution scores from FASTA and bigWig files. + + Parameters + ---------- + peaks : pl.DataFrame + Peak regions DataFrame from load_peaks() containing columns: + 'chr', 'peak_region_start'. + fa_path : str + Path to genome FASTA file (.fa or .fasta format). + bw_paths : List[str] + List of paths to bigWig files containing contribution scores. + Must be non-empty. + half_width : int + Half-width of regions to extract around peak centers. + Total region width will be 2 * half_width. 
+ + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of peaks, + 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width). + contribs : Float[ndarray, "N L"] + Contribution scores averaged across input bigWig files. + Shape is (N peaks, L region_length). + + Notes + ----- + BigWig files only provide projected contribution scores, not hypothetical scores. + Regions extending beyond chromosome boundaries are zero-padded. + Missing values in bigWig files are converted to zero. + """ num_peaks = peaks.height + region_width = half_width * 2 - sequences = np.zeros((num_peaks, 4, half_width * 2), dtype=np.int8) - contribs = np.zeros((num_peaks, half_width * 2), dtype=np.float16) + sequences = np.zeros((num_peaks, 4, region_width), dtype=np.int8) + contribs = np.zeros((num_peaks, region_width), dtype=np.float16) + # Load genome reference genome = pyfaidx.Fasta(fa_path, one_based_attributes=False) - + bws = [pyBigWig.open(i) for i in bw_paths] contrib_buffer = np.zeros((len(bw_paths), half_width * 2), dtype=np.float16) try: - for ind, row in tqdm(enumerate(peaks.iter_rows(named=True)), disable=None, unit="regions", total=num_peaks): + for ind, row in tqdm( + enumerate(peaks.iter_rows(named=True)), + disable=None, + unit="regions", + total=num_peaks, + ): chrom = row["chr"] start = row["peak_region_start"] end = start + 2 * half_width - - sequence_data = genome[chrom][start:end] - sequence = sequence_data.seq - start_adj = sequence_data.start - end_adj = sequence_data.end + + sequence_data: pyfaidx.FastaRecord = genome[chrom][start:end] # type: ignore + sequence: str = sequence_data.seq # type: ignore + start_adj: int = sequence_data.start # type: ignore + end_adj: int = sequence_data.end # type: ignore a = start_adj - start b = end_adj - start if b > a: - sequences[ind,:,a:b] = one_hot_encode(sequence) + sequences[ind, :, a:b] = one_hot_encode(sequence) for j, bw in enumerate(bws): - 
contrib_buffer[j,:] = np.nan_to_num(bw.values(chrom, start_adj, end_adj, numpy=True)) + contrib_buffer[j, :] = np.nan_to_num( + bw.values(chrom, start_adj, end_adj, numpy=True) + ) + + contribs[ind, a:b] = np.mean(contrib_buffer, axis=0) - contribs[ind,a:b] = np.mean(contrib_buffer, axis=0) - finally: for bw in bws: bw.close() - + return sequences, contribs -def load_regions_from_chrombpnet_h5(h5_paths, half_width): +def load_regions_from_chrombpnet_h5( + h5_paths: List[str], half_width: int +) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]: + """Load genomic sequences and contribution scores from ChromBPNet HDF5 files. + + Parameters + ---------- + h5_paths : List[str] + List of paths to ChromBPNet HDF5 files containing sequences and SHAP scores. + Must be non-empty and contain compatible data shapes. + half_width : int + Half-width of regions to extract around the center. + Total region width will be 2 * half_width. + + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width). + contribs : Float[ndarray, "N 4 L"] + SHAP contribution scores averaged across input files. + Shape is (N regions, 4 nucleotides, L region_length). + + Notes + ----- + ChromBPNet files store sequences in 'raw/seq' and SHAP scores in 'shap/seq'. + All input files must have the same dimensions and number of regions. + Missing values in contribution scores are converted to zero. 
+ """ with ExitStack() as stack: h5s = [stack.enter_context(h5py.File(i)) for i in h5_paths] - start = h5s[0]['raw/seq'].shape[-1] // 2 - half_width + start = h5s[0]["raw/seq"].shape[-1] // 2 - half_width # type: ignore # HDF5 array access end = start + 2 * half_width - - sequences = h5s[0]['raw/seq'][:,:,start:end].astype(np.int8) - contribs = np.mean([np.nan_to_num(f['shap/seq'][:,:,start:end]) for f in h5s], axis=0, dtype=np.float16) - - return sequences, contribs + sequences = h5s[0]["raw/seq"][:, :, start:end].astype(np.int8) # type: ignore # HDF5 array access + contribs = np.mean( + [np.nan_to_num(f["shap/seq"][:, :, start:end]) for f in h5s], # type: ignore # HDF5 array access + axis=0, + dtype=np.float16, + ) -def load_regions_from_bpnet_h5(h5_paths, half_width): + return sequences, contribs # type: ignore # HDF5 arrays converted to NumPy + + +def load_regions_from_bpnet_h5( + h5_paths: List[str], half_width: int +) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]: + """Load genomic sequences and contribution scores from BPNet HDF5 files. + + Parameters + ---------- + h5_paths : List[str] + List of paths to BPNet HDF5 files containing sequences and contribution scores. + Must be non-empty and contain compatible data shapes. + half_width : int + Half-width of regions to extract around the center. + Total region width will be 2 * half_width. + + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width). + contribs : Float[ndarray, "N 4 L"] + Hypothetical contribution scores averaged across input files. + Shape is (N regions, 4 nucleotides, L region_length). + + Notes + ----- + BPNet files store sequences in 'input_seqs' and hypothetical scores in 'hyp_scores'. + The data requires axis swapping to convert from (n, length, 4) to (n, 4, length) format. 
+ All input files must have the same dimensions and number of regions. + Missing values in contribution scores are converted to zero. + """ with ExitStack() as stack: h5s = [stack.enter_context(h5py.File(i)) for i in h5_paths] - start = h5s[0]['input_seqs'].shape[-2] // 2 - half_width + start = h5s[0]["input_seqs"].shape[-2] // 2 - half_width # type: ignore # HDF5 array access end = start + 2 * half_width - sequences = h5s[0]['input_seqs'][:,start:end,:].swapaxes(1,2).astype(np.int8) - contribs = np.mean([np.nan_to_num(f['hyp_scores'][:,start:end,:].swapaxes(1,2)) for f in h5s], axis=0, dtype=np.float16) + sequences = h5s[0]["input_seqs"][:, start:end, :].swapaxes(1, 2).astype(np.int8) # type: ignore # HDF5 array access with axis swap + contribs = np.mean( + [ + np.nan_to_num(f["hyp_scores"][:, start:end, :].swapaxes(1, 2)) # type: ignore # HDF5 array access + for f in h5s + ], + axis=0, + dtype=np.float16, + ) return sequences, contribs -def load_npy_or_npz(path): +def load_npy_or_npz(path: str) -> ndarray: + """Load array data from .npy or .npz file. + + Parameters + ---------- + path : str + Path to .npy or .npz file. File must exist and contain valid NumPy data. + + Returns + ------- + ndarray + Loaded array data. For .npz files, returns the first array ('arr_0'). + For .npy files, returns the array directly. + + Raises + ------ + FileNotFoundError + If the specified file does not exist. + KeyError + If .npz file does not contain 'arr_0' key. + """ f = np.load(path) if isinstance(f, np.ndarray): arr = f else: - arr = f['arr_0'] + arr = f["arr_0"] return arr -def load_regions_from_modisco_fmt(shaps_paths, ohe_path, half_width): + +def load_regions_from_modisco_fmt( + shaps_paths: List[str], ohe_path: str, half_width: int +) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]: + """Load genomic sequences and contribution scores from TF-MoDISCo format files. 
+ + Parameters + ---------- + shaps_paths : List[str] + List of paths to .npy/.npz files containing SHAP/attribution scores. + Must be non-empty and all files must have compatible shapes. + ohe_path : str + Path to .npy/.npz file containing one-hot encoded sequences. + Must have shape (n_regions, 4, sequence_length). + half_width : int + Half-width of regions to extract around the center. + Total region width will be 2 * half_width. + + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length (2 * half_width). + contribs : Float[ndarray, "N 4 L"] + SHAP contribution scores averaged across input files. + Shape is (N regions, 4 nucleotides, L region_length). + + Notes + ----- + All SHAP files must have the same shape as the sequence file. + Missing values in contribution scores are converted to zero. + The center of the input sequences is used as the reference point for extraction. + """ sequences_raw = load_npy_or_npz(ohe_path) start = sequences_raw.shape[-1] // 2 - half_width end = start + 2 * half_width - sequences = sequences_raw[:,:,start:end].astype(np.int8) + sequences = sequences_raw[:, :, start:end].astype(np.int8) - shaps = [np.nan_to_num(load_npy_or_npz(p)[:,:,start:end]) for p in shaps_paths] + shaps = [np.nan_to_num(load_npy_or_npz(p)[:, :, start:end]) for p in shaps_paths] contribs = np.mean(shaps, axis=0, dtype=np.float16) return sequences, contribs -def load_regions_npz(npz_path): +def load_regions_npz( + npz_path: str, +) -> Tuple[ + Int[ndarray, "N 4 L"], + Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]], + pl.DataFrame, + bool, +]: + """Load preprocessed genomic regions from NPZ file. + + Parameters + ---------- + npz_path : str + Path to NPZ file containing sequences, contributions, and optional coordinates. + Must contain 'sequences' and 'contributions' arrays at minimum. 
+ + Returns + ------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length. + contributions : Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]] + Contribution scores in either hypothetical format (N, 4, L) or + projected format (N, L). Shape depends on the input data format. + peaks_df : pl.DataFrame + DataFrame containing peak region information with columns: + 'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'. + has_peaks : bool + Whether the file contains genomic coordinate information. + If False, placeholder coordinate data is used. + + Notes + ----- + If genomic coordinates are not present in the NPZ file, creates placeholder + coordinate data and issues a warning. The placeholder data uses 'NA' for + chromosome names and sequential indices for peak IDs. + + Raises + ------ + KeyError + If required arrays 'sequences' or 'contributions' are missing from the file. + """ data = np.load(npz_path) - + if "chr" not in data.keys(): - warnings.warn("No genome coordinates present in the input .npz file. Returning sequences and contributions only.") + warnings.warn( + "No genome coordinates present in the input .npz file. Returning sequences and contributions only." 
+ ) has_peaks = False num_regions = data["sequences"].shape[0] - peak_data = {"chr": np.array(["NA"] * num_regions, dtype='U'), "chr_id": np.arange(num_regions, dtype=np.uint32), - "peak_region_start": np.zeros(num_regions, dtype=np.int32), "peak_id": np.arange(num_regions, dtype=np.uint32), - "peak_name": np.array(["NA"] * num_regions, dtype='U')} + peak_data = { + "chr": np.array(["NA"] * num_regions, dtype="U"), + "chr_id": np.arange(num_regions, dtype=np.uint32), + "peak_region_start": np.zeros(num_regions, dtype=np.int32), + "peak_id": np.arange(num_regions, dtype=np.uint32), + "peak_name": np.array(["NA"] * num_regions, dtype="U"), + } else: has_peaks = True - peak_data = {"chr": data["chr"], "chr_id": data["chr_id"], "peak_region_start": data["start"], - "peak_id": data["peak_id"], "peak_name": data["peak_name"]} - + peak_data = { + "chr": data["chr"], + "chr_id": data["chr_id"], + "peak_region_start": data["start"], + "peak_id": data["peak_id"], + "peak_name": data["peak_name"], + } + peaks_df = pl.DataFrame(peak_data) return data["sequences"], data["contributions"], peaks_df, has_peaks -def write_regions_npz(sequences, contributions, out_path, peaks_df=None): +def write_regions_npz( + sequences: Int[ndarray, "N 4 L"], + contributions: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]], + out_path: str, + peaks_df: Optional[pl.DataFrame] = None, +) -> None: + """Write genomic regions and contribution scores to compressed NPZ file. + + Parameters + ---------- + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences where N is the number of regions, + 4 represents A,C,G,T nucleotides, and L is the region length. + contributions : Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]] + Contribution scores in either hypothetical format (N, 4, L) or + projected format (N, L). + out_path : str + Output path for the NPZ file. Parent directory must exist. 
+ peaks_df : Optional[pl.DataFrame] + DataFrame containing peak region information with columns: + 'chr', 'chr_id', 'peak_region_start', 'peak_id', 'peak_name'. + If None, only sequences and contributions are saved. + + Raises + ------ + ValueError + If the number of regions in sequences/contributions doesn't match peaks_df. + FileNotFoundError + If the parent directory of out_path does not exist. + + Notes + ----- + The output file is compressed using NumPy's savez_compressed format. + If peaks_df is provided, genomic coordinate information is included + in the output file for downstream analysis. + """ if peaks_df is None: - warnings.warn("No genome coordinates provided. Writing sequences and contributions only.") + warnings.warn( + "No genome coordinates provided. Writing sequences and contributions only." + ) np.savez_compressed(out_path, sequences=sequences, contributions=contributions) else: num_regions = peaks_df.height - if (num_regions != sequences.shape[0]) or (num_regions != contributions.shape[0]): - raise ValueError(f"Input sequences of shape {sequences.shape} and/or " - f"input contributions of shape {contributions.shape} " - f"are not compatible with peak region count of {num_regions}" ) - - chr_arr = peaks_df.get_column("chr").to_numpy().astype('U') + if (num_regions != sequences.shape[0]) or ( + num_regions != contributions.shape[0] + ): + raise ValueError( + f"Input sequences of shape {sequences.shape} and/or " + f"input contributions of shape {contributions.shape} " + f"are not compatible with peak region count of {num_regions}" + ) + + chr_arr = peaks_df.get_column("chr").to_numpy().astype("U") chr_id_arr = peaks_df.get_column("chr_id").to_numpy() start_arr = peaks_df.get_column("peak_region_start").to_numpy() peak_id_arr = peaks_df.get_column("peak_id").to_numpy() - peak_name_arr = peaks_df.get_column("peak_name").to_numpy().astype('U') - np.savez_compressed(out_path, sequences=sequences, contributions=contributions, - chr=chr_arr, 
chr_id=chr_id_arr, start=start_arr, peak_id=peak_id_arr, peak_name=peak_name_arr) + peak_name_arr = peaks_df.get_column("peak_name").to_numpy().astype("U") + np.savez_compressed( + out_path, + sequences=sequences, + contributions=contributions, + chr=chr_arr, + chr_id=chr_id_arr, + start=start_arr, + peak_id=peak_id_arr, + peak_name=peak_name_arr, + ) -def trim_motif(cwm, trim_threshold): - """ +def trim_motif(cwm: Float[ndarray, "4 W"], trim_threshold: float) -> Tuple[int, int]: + """Determine trimmed start and end positions for a motif based on contribution magnitude. + + This function identifies the core region of a motif by finding positions where + the total absolute contribution exceeds a threshold relative to the maximum. + + Parameters + ---------- + cwm : Float[ndarray, "4 W"] + Contribution weight matrix for the motif where 4 represents A,C,G,T + nucleotides and W is the motif width. + trim_threshold : float + Fraction of maximum score to use as trimming threshold (0.0 to 1.0). + Higher values result in more aggressive trimming. + + Returns + ------- + start : int + Start position of the trimmed motif (inclusive). + end : int + End position of the trimmed motif (exclusive). + + Notes + ----- + The trimming is based on the sum of absolute contributions across all nucleotides + at each position. Positions with contributions below trim_threshold * max_score + are removed from the motif edges. 
+    Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L213-L236
     """
     score = np.sum(np.abs(cwm), axis=0)
     trim_thresh = np.max(score) * trim_threshold
     pass_inds = np.nonzero(score >= trim_thresh)
 
-    start = max(np.min(pass_inds), 0)
-    end = min(np.max(pass_inds) + 1, len(score))
+    start = max(int(np.min(pass_inds)), 0)  # type: ignore # nonzero returns tuple of arrays
+    end = min(int(np.max(pass_inds)) + 1, len(score))  # type: ignore # nonzero returns tuple of arrays
 
     return start, end
 
 
-def softmax(x, temp=100):
+def softmax(x: Float[ndarray, "4 W"], temp: float = 100) -> Float[ndarray, "4 W"]:
+    """Apply softmax transformation with temperature scaling.
+
+    Parameters
+    ----------
+    x : Float[ndarray, "4 W"]
+        Input array to transform where 4 represents A,C,G,T nucleotides
+        and W is the motif width.
+    temp : float, default 100
+        Temperature parameter for softmax scaling. Higher values create
+        sharper probability distributions.
+
+    Returns
+    -------
+    Float[ndarray, "4 W"]
+        Softmax-transformed array with same shape as input. Each column
+        sums to 1.0, representing nucleotide probabilities at each position.
+
+    Notes
+    -----
+    The softmax is applied along the nucleotide axis (axis=0), normalizing
+    each position to have probabilities that sum to 1. The temperature
+    parameter controls the sharpness of the distribution.
+    """
     norm_x = x - np.mean(x, axis=1, keepdims=True)
     exp = np.exp(temp * norm_x)
     return exp / np.sum(exp, axis=0, keepdims=True)
 
 
-def _motif_name_sort_key(data):
+def _motif_name_sort_key(data: Tuple[str, Any]) -> Union[Tuple[int], Tuple[int, str]]:
+    """Generate sort key for TF-MoDISCo motif names.
+
+    This function creates a sort key that orders motifs by pattern number,
+    with non-standard patterns sorted to the beginning.
+
+    Parameters
+    ----------
+    data : Tuple[str, Any]
+        Tuple containing motif name as first element and additional data.
+ The motif name should follow the format 'pattern_N' where N is an integer. + + Returns + ------- + Union[Tuple[int], Tuple[int, str]] + Sort key tuple for ordering motifs. Standard pattern names return + (pattern_number,) while non-standard names return (-1, name). + + Notes + ----- + This function is used internally by load_modisco_motifs to ensure + consistent motif ordering across runs. + """ name = data[0] if name.startswith("pattern_"): pattern_num = int(name.split("_")[-1]) return (pattern_num,) else: - return (-1, name) + return (-1, name) # Mixed tuple types for sorting + + +MODISCO_PATTERN_GROUPS = ["pos_patterns", "neg_patterns"] + + +def load_modisco_motifs( + modisco_h5_path: str, + trim_coords: Optional[Dict[str, Tuple[int, int]]], + trim_thresholds: Optional[Dict[str, float]], + trim_threshold_default: float, + motif_type: str, + motifs_include: Optional[List[str]], + motif_name_map: Optional[Dict[str, str]], + motif_lambdas: Optional[Dict[str, float]], + motif_lambda_default: float, + include_rc: bool, +) -> Tuple[pl.DataFrame, Float[ndarray, "M 4 W"], Int[ndarray, "M W"], ndarray]: + """Load motif data from TF-MoDISCo HDF5 file with customizable processing options. + + This function extracts contribution weight matrices and associated metadata from + TF-MoDISCo results, with support for custom naming, trimming, and regularization + parameters. + + Parameters + ---------- + modisco_h5_path : str + Path to TF-MoDISCo HDF5 results file containing pattern groups. + trim_coords : Optional[Dict[str, Tuple[int, int]]] + Manual trim coordinates for specific motifs {motif_name: (start, end)}. + Takes precedence over automatic trimming based on thresholds. + trim_thresholds : Optional[Dict[str, float]] + Custom trim thresholds for specific motifs {motif_name: threshold}. + Values should be between 0.0 and 1.0. + trim_threshold_default : float + Default trim threshold for motifs not in trim_thresholds. + Fraction of maximum contribution used for trimming. 
+ motif_type : str + Type of motif to extract. Must be one of: + - 'cwm': Contribution weight matrix (normalized) + - 'hcwm': Hypothetical contribution weight matrix + - 'pfm': Position frequency matrix + - 'pfm_softmax': Softmax-transformed position frequency matrix + motifs_include : Optional[List[str]] + List of motif names to include. If None, includes all motifs found. + Names should follow format 'pos_patterns.pattern_N' or 'neg_patterns.pattern_N'. + motif_name_map : Optional[Dict[str, str]] + Mapping from original to custom motif names {orig_name: new_name}. + New names must be unique across all motifs. + motif_lambdas : Optional[Dict[str, float]] + Custom lambda regularization values for specific motifs {motif_name: lambda}. + Higher values increase sparsity penalty for the corresponding motif. + motif_lambda_default : float + Default lambda value for motifs not specified in motif_lambdas. + include_rc : bool + Whether to include reverse complement motifs in addition to forward motifs. + If True, doubles the number of motifs returned. + + Returns + ------- + motifs_df : pl.DataFrame + DataFrame containing motif metadata with columns: motif_id, motif_name, + motif_name_orig, strand, motif_start, motif_end, motif_scale, lambda. + cwms : Float[ndarray, "M 4 W"] + Contribution weight matrices for all motifs where M is the number of motifs, + 4 represents A,C,G,T nucleotides, and W is the motif width. + trim_masks : Int[ndarray, "M W"] + Binary masks indicating core motif regions (1) vs trimmed regions (0). + Shape is (M motifs, W motif_width). + names : ndarray + Array of unique motif names (forward strand only). + + Raises + ------ + ValueError + If motif_type is not one of the supported types, or if motif names + in motif_name_map are not unique. + FileNotFoundError + If the specified HDF5 file does not exist. + KeyError + If required datasets are missing from the HDF5 file. 
+ + Notes + ----- + Motif trimming removes low-contribution positions from the edges based on + the position-wise sum of absolute contributions across nucleotides. The trimming + helps focus on the core binding site. -MODISCO_PATTERN_GROUPS = ['pos_patterns', 'neg_patterns'] - -def load_modisco_motifs(modisco_h5_path, trim_coords, trim_thresholds, trim_threshold_default, motif_type, - motifs_include, motif_name_map, motif_lambdas, motif_lambda_default, include_rc): - """ Adapted from https://github.com/jmschrei/tfmodisco-lite/blob/570535ee5ccf43d670e898d92d63af43d68c38c5/modiscolite/report.py#L252-L272 """ - motif_data_lsts = {"motif_name": [], "motif_name_orig": [], "strand": [], "motif_start": [], - "motif_end": [], "motif_scale": [], "lambda": []} - motif_lst = [] + motif_data_lsts = { + "motif_name": [], + "motif_name_orig": [], + "strand": [], + "motif_start": [], + "motif_end": [], + "motif_scale": [], + "lambda": [], + } + motif_lst = [] trim_mask_lst = [] if motifs_include is not None: - motifs_include = set(motifs_include) + motifs_include_set = set(motifs_include) + else: + motifs_include_set = None if motif_name_map is None: motif_name_map = {} @@ -276,37 +856,44 @@ def load_modisco_motifs(modisco_h5_path, trim_coords, trim_thresholds, trim_thre if len(motif_name_map.values()) != len(set(motif_name_map.values())): raise ValueError("Specified motif names are not unique") - with h5py.File(modisco_h5_path, 'r') as modisco_results: + with h5py.File(modisco_h5_path, "r") as modisco_results: for name in MODISCO_PATTERN_GROUPS: if name not in modisco_results.keys(): continue metacluster = modisco_results[name] - for ind, (pattern_name, pattern) in enumerate(sorted(metacluster.items(), key=_motif_name_sort_key)): - pattern_tag = f'{name}.{pattern_name}' - - if motifs_include is not None and pattern_tag not in motifs_include: + for _, (pattern_name, pattern) in enumerate( + sorted(metacluster.items(), key=_motif_name_sort_key) # type: ignore # HDF5 access + ): + 
pattern_tag = f"{name}.{pattern_name}" + + if ( + motifs_include_set is not None + and pattern_tag not in motifs_include_set + ): continue motif_lambda = motif_lambdas.get(pattern_tag, motif_lambda_default) pattern_tag_orig = pattern_tag pattern_tag = motif_name_map.get(pattern_tag, pattern_tag) - cwm_raw = pattern['contrib_scores'][:].T + cwm_raw = pattern["contrib_scores"][:].T # type: ignore cwm_norm = np.sqrt((cwm_raw**2).sum()) cwm_fwd = cwm_raw / cwm_norm - cwm_rev = cwm_fwd[::-1,::-1] + cwm_rev = cwm_fwd[::-1, ::-1] if pattern_tag in trim_coords: start_fwd, end_fwd = trim_coords[pattern_tag] else: - trim_threshold = trim_thresholds.get(pattern_tag, trim_threshold_default) + trim_threshold = trim_thresholds.get( + pattern_tag, trim_threshold_default + ) start_fwd, end_fwd = trim_motif(cwm_fwd, trim_threshold) cwm_len = cwm_fwd.shape[1] start_rev, end_rev = cwm_len - end_fwd, cwm_len - start_fwd - + trim_mask_fwd = np.zeros(cwm_fwd.shape[1], dtype=np.int8) trim_mask_fwd[start_fwd:end_fwd] = 1 trim_mask_rev = np.zeros(cwm_rev.shape[1], dtype=np.int8) @@ -318,29 +905,34 @@ def load_modisco_motifs(modisco_h5_path, trim_coords, trim_thresholds, trim_thre motif_norm = cwm_norm elif motif_type == "hcwm": - motif_raw = pattern['hypothetical_contribs'][:].T + motif_raw = pattern["hypothetical_contribs"][:].T # type: ignore motif_norm = np.sqrt((motif_raw**2).sum()) motif_fwd = motif_raw / motif_norm - motif_rev = motif_fwd[::-1,::-1] + motif_rev = motif_fwd[::-1, ::-1] elif motif_type == "pfm": - motif_raw = pattern['sequence'][:].T + motif_raw = pattern["sequence"][:].T # type: ignore motif_norm = 1 motif_fwd = motif_raw / np.sum(motif_raw, axis=0, keepdims=True) - motif_rev = motif_fwd[::-1,::-1] + motif_rev = motif_fwd[::-1, ::-1] elif motif_type == "pfm_softmax": - motif_raw = pattern['sequence'][:].T + motif_raw = pattern["sequence"][:].T # type: ignore motif_norm = 1 motif_fwd = softmax(motif_raw) - motif_rev = motif_fwd[::-1,::-1] + motif_rev = motif_fwd[::-1, 
::-1] + + else: + raise ValueError( + f"Invalid motif_type: {motif_type}. Must be one of 'cwm', 'hcwm', 'pfm', 'pfm_softmax'." + ) motif_data_lsts["motif_name"].append(pattern_tag) motif_data_lsts["motif_name_orig"].append(pattern_tag_orig) - motif_data_lsts["strand"].append('+') + motif_data_lsts["strand"].append("+") motif_data_lsts["motif_start"].append(start_fwd) motif_data_lsts["motif_end"].append(end_fwd) motif_data_lsts["motif_scale"].append(motif_norm) @@ -349,7 +941,7 @@ def load_modisco_motifs(modisco_h5_path, trim_coords, trim_thresholds, trim_thre if include_rc: motif_data_lsts["motif_name"].append(pattern_tag) motif_data_lsts["motif_name_orig"].append(pattern_tag_orig) - motif_data_lsts["strand"].append('-') + motif_data_lsts["strand"].append("-") motif_data_lsts["motif_start"].append(start_rev) motif_data_lsts["motif_end"].append(end_rev) motif_data_lsts["motif_scale"].append(motif_norm) @@ -361,17 +953,75 @@ def load_modisco_motifs(modisco_h5_path, trim_coords, trim_thresholds, trim_thre else: motif_lst.append(motif_fwd) trim_mask_lst.append(trim_mask_fwd) - + motifs_df = pl.DataFrame(motif_data_lsts).with_row_index(name="motif_id") cwms = np.stack(motif_lst, dtype=np.float16, axis=0) trim_masks = np.stack(trim_mask_lst, dtype=np.int8, axis=0) - names = motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy() + names = ( + motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy() + ) return motifs_df, cwms, trim_masks, names -def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modisco_half_width, lazy=False): - +def load_modisco_seqlets( + modisco_h5_path: str, + peaks_df: pl.DataFrame, + motifs_df: pl.DataFrame, + half_width: int, + modisco_half_width: int, + lazy: bool = False, +) -> Union[pl.DataFrame, pl.LazyFrame]: + """Load seqlet data from TF-MoDISCo HDF5 file and convert to genomic coordinates. 
+ + This function extracts seqlet instances from TF-MoDISCo results and converts + their relative positions to absolute genomic coordinates using peak region + information. + + Parameters + ---------- + modisco_h5_path : str + Path to TF-MoDISCo HDF5 results file containing seqlet data. + peaks_df : pl.DataFrame + DataFrame containing peak region information with columns: + 'peak_id', 'chr', 'chr_id', 'peak_region_start'. + motifs_df : pl.DataFrame + DataFrame containing motif metadata with columns: + 'motif_name_orig', 'strand', 'motif_name', 'motif_start', 'motif_end'. + half_width : int + Half-width of the current analysis regions. + modisco_half_width : int + Half-width of the regions used in the original TF-MoDISCo analysis. + Used to calculate coordinate offsets. + lazy : bool, default False + If True, returns a LazyFrame for efficient chaining of operations. + If False, collects the result into a DataFrame. + + Returns + ------- + Union[pl.DataFrame, pl.LazyFrame] + Seqlets with genomic coordinates containing columns: + - chr: Chromosome name + - chr_id: Numeric chromosome identifier + - start: Start coordinate of trimmed motif instance + - end: End coordinate of trimmed motif instance + - start_untrimmed: Start coordinate of full motif instance + - end_untrimmed: End coordinate of full motif instance + - is_revcomp: Whether the motif is reverse complemented + - strand: Motif strand ('+' or '-') + - motif_name: Motif name (may be remapped) + - peak_id: Peak identifier + - peak_region_start: Peak region start coordinate + + Notes + ----- + Seqlets are deduplicated based on chromosome ID, start position (untrimmed), + motif name, and reverse complement status to avoid redundant instances. + + The coordinate transformation accounts for differences in region sizes + between the original TF-MoDISCo analysis and the current analysis. 
+ """ + start_lst = [] end_lst = [] is_revcomp_lst = [] @@ -379,23 +1029,29 @@ def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modis peak_id_lst = [] pattern_tags = [] - with h5py.File(modisco_h5_path, 'r') as modisco_results: + with h5py.File(modisco_h5_path, "r") as modisco_results: for name in MODISCO_PATTERN_GROUPS: if name not in modisco_results.keys(): continue metacluster = modisco_results[name] - key = lambda x: int(x[0].split("_")[-1]) - for ind, (pattern_name, pattern) in enumerate(sorted(metacluster.items(), key=key)): - pattern_tag = f'{name}.{pattern_name}' - starts = pattern['seqlets/start'][:].astype(np.int32) - ends = pattern['seqlets/end'][:].astype(np.int32) - is_revcomps = pattern['seqlets/is_revcomp'][:].astype(bool) - strands = ['+' if not i else '-' for i in is_revcomps] - peak_ids = pattern['seqlets/example_idx'][:].astype(np.uint32) + def get_pattern_number(x): + return int(x[0].split("_")[-1]) - n_seqlets = int(pattern['seqlets/n_seqlets'][0]) + key = get_pattern_number + for _, (pattern_name, pattern) in enumerate( + sorted(metacluster.items(), key=key) # type: ignore # HDF5 access + ): + pattern_tag = f"{name}.{pattern_name}" + + starts = pattern["seqlets/start"][:].astype(np.int32) # type: ignore + ends = pattern["seqlets/end"][:].astype(np.int32) # type: ignore + is_revcomps = pattern["seqlets/is_revcomp"][:].astype(bool) # type: ignore + strands = ["+" if not i else "-" for i in is_revcomps] + peak_ids = pattern["seqlets/example_idx"][:].astype(np.uint32) # type: ignore + + n_seqlets = int(pattern["seqlets/n_seqlets"][0]) # type: ignore start_lst.append(starts) end_lst.append(ends) @@ -412,7 +1068,7 @@ def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modis "peak_id": np.concatenate(peak_id_lst), "motif_name_orig": pattern_tags, } - + offset = half_width - modisco_half_width seqlets_df = ( @@ -422,15 +1078,23 @@ def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, 
half_width, modis .select( chr=pl.col("chr"), chr_id=pl.col("chr_id"), - start=pl.col("peak_region_start") + pl.col("seqlet_start") + pl.col("motif_start") + offset, - end=pl.col("peak_region_start") + pl.col("seqlet_start") + pl.col("motif_end") + offset, - start_untrimmed=pl.col("peak_region_start") + pl.col("seqlet_start") + offset, + start=pl.col("peak_region_start") + + pl.col("seqlet_start") + + pl.col("motif_start") + + offset, + end=pl.col("peak_region_start") + + pl.col("seqlet_start") + + pl.col("motif_end") + + offset, + start_untrimmed=pl.col("peak_region_start") + + pl.col("seqlet_start") + + offset, end_untrimmed=pl.col("peak_region_start") + pl.col("seqlet_end") + offset, is_revcomp=pl.col("is_revcomp"), strand=pl.col("strand"), motif_name=pl.col("motif_name"), peak_id=pl.col("peak_id"), - peak_region_start=pl.col("peak_region_start") + peak_region_start=pl.col("peak_region_start"), ) .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) ) @@ -440,8 +1104,27 @@ def load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modis return seqlets_df -def write_modisco_seqlets(seqlets_df, out_path): +def write_modisco_seqlets( + seqlets_df: Union[pl.DataFrame, pl.LazyFrame], out_path: str +) -> None: + """Write TF-MoDISCo seqlets to TSV file. + + Parameters + ---------- + seqlets_df : Union[pl.DataFrame, pl.LazyFrame] + Seqlets DataFrame with genomic coordinates. Must contain columns + that are safe to drop: 'chr_id', 'is_revcomp'. + out_path : str + Output TSV file path. + + Notes + ----- + Removes internal columns 'chr_id' and 'is_revcomp' before writing + to create a clean output format suitable for downstream analysis. 
+ """ seqlets_df = seqlets_df.drop(["chr_id", "is_revcomp"]) + if isinstance(seqlets_df, pl.LazyFrame): + seqlets_df = seqlets_df.collect() seqlets_df.write_csv(out_path, separator="\t") @@ -465,71 +1148,144 @@ def write_modisco_seqlets(seqlets_df, out_path): HITS_COLLAPSED_DTYPES = HITS_DTYPES | {"is_primary": pl.UInt32} -def load_hits(hits_path, lazy=False, schema=HITS_DTYPES): - hits_df = ( - pl.scan_csv(hits_path, separator='\t', quote_char=None, schema=schema) - .with_columns(pl.lit(1).alias("count")) - ) +def load_hits( + hits_path: str, lazy: bool = False, schema: Dict[str, Any] = HITS_DTYPES +) -> Union[pl.DataFrame, pl.LazyFrame]: + """Load motif hit data from TSV file. + + Parameters + ---------- + hits_path : str + Path to TSV file containing motif hit results. + lazy : bool, default False + If True, returns a LazyFrame for efficient chaining operations. + If False, collects the result into a DataFrame. + schema : Dict[str, Any], default HITS_DTYPES + Schema defining column names and data types for the hit data. + + Returns + ------- + Union[pl.DataFrame, pl.LazyFrame] + Hit data with an additional 'count' column set to 1 for aggregation. + """ + hits_df = pl.scan_csv( + hits_path, separator="\t", quote_char=None, schema=schema + ).with_columns(pl.lit(1).alias("count")) return hits_df if lazy else hits_df.collect() -def write_hits_processed(hits_df, out_path, schema=HITS_DTYPES): +def write_hits_processed( + hits_df: Union[pl.DataFrame, pl.LazyFrame], + out_path: str, + schema: Optional[Dict[str, Any]] = HITS_DTYPES, +) -> None: + """Write processed hit data to TSV file with optional column filtering. + + Parameters + ---------- + hits_df : Union[pl.DataFrame, pl.LazyFrame] + Hit data to write to file. + out_path : str + Output path for the TSV file. + schema : Optional[Dict[str, Any]], default HITS_DTYPES + Schema defining which columns to include in output. + If None, all columns are written. 
+ """ if schema is not None: hits_df = hits_df.select(schema.keys()) + + if isinstance(hits_df, pl.LazyFrame): + hits_df = hits_df.collect() + hits_df.write_csv(out_path, separator="\t") -def write_hits(hits_df, peaks_df, motifs_df, qc_df, out_dir, motif_width): +def write_hits( + hits_df: Union[pl.DataFrame, pl.LazyFrame], + peaks_df: pl.DataFrame, + motifs_df: pl.DataFrame, + qc_df: pl.DataFrame, + out_dir: str, + motif_width: int, +) -> None: + """Write comprehensive hit results to multiple output files. + + This function combines hit data with peak, motif, and quality control information + to generate complete output files including genomic coordinates and scores. + + Parameters + ---------- + hits_df : Union[pl.DataFrame, pl.LazyFrame] + Hit data containing motif instance information. + peaks_df : pl.DataFrame + Peak region information for coordinate conversion. + motifs_df : pl.DataFrame + Motif metadata for annotation and trimming information. + qc_df : pl.DataFrame + Quality control data for normalization factors. + out_dir : str + Output directory for results files. Will be created if it doesn't exist. + motif_width : int + Width of motif instances for coordinate calculations. 
+ + Notes + ----- + Creates three output files: + - hits.tsv: Complete hit data with all instances + - hits_unique.tsv: Deduplicated hits by genomic position and motif + - hits.bed: BED format file for genome browser visualization + """ os.makedirs(out_dir, exist_ok=True) out_path_tsv = os.path.join(out_dir, "hits.tsv") out_path_tsv_unique = os.path.join(out_dir, "hits_unique.tsv") out_path_bed = os.path.join(out_dir, "hits.bed") data_all = ( - hits_df - .lazy() + hits_df.lazy() .join(peaks_df.lazy(), on="peak_id", how="inner") .join(qc_df.lazy(), on="peak_id", how="inner") .join(motifs_df.lazy(), on="motif_id", how="inner") .select( chr_id=pl.col("chr_id"), chr=pl.col("chr"), - start=pl.col("peak_region_start") + pl.col("hit_start") + pl.col("motif_start"), + start=pl.col("peak_region_start") + + pl.col("hit_start") + + pl.col("motif_start"), end=pl.col("peak_region_start") + pl.col("hit_start") + pl.col("motif_end"), start_untrimmed=pl.col("peak_region_start") + pl.col("hit_start"), - end_untrimmed=pl.col("peak_region_start") + pl.col("hit_start") + motif_width, + end_untrimmed=pl.col("peak_region_start") + + pl.col("hit_start") + + motif_width, motif_name=pl.col("motif_name"), hit_coefficient=pl.col("hit_coefficient"), - hit_coefficient_global=pl.col("hit_coefficient") * (pl.col("global_scale")**2), + hit_coefficient_global=pl.col("hit_coefficient") + * (pl.col("global_scale") ** 2), hit_similarity=pl.col("hit_similarity"), hit_correlation=pl.col("hit_similarity"), hit_importance=pl.col("hit_importance") * pl.col("global_scale"), - hit_importance_sq=pl.col("hit_importance_sq") * (pl.col("global_scale")**2), + hit_importance_sq=pl.col("hit_importance_sq") + * (pl.col("global_scale") ** 2), strand=pl.col("strand"), peak_name=pl.col("peak_name"), peak_id=pl.col("peak_id"), - motif_lambda = pl.col("lambda"), + motif_lambda=pl.col("lambda"), ) .sort(["chr_id", "start"]) .select(HITS_DTYPES.keys()) ) - data_unique = ( - data_all - .unique(subset=["chr", "start", 
"motif_name", "strand"], maintain_order=True) + data_unique = data_all.unique( + subset=["chr", "start", "motif_name", "strand"], maintain_order=True ) - data_bed = ( - data_unique - .select( - chr=pl.col("chr"), - start=pl.col("start"), - end=pl.col("end"), - motif_name=pl.col("motif_name"), - score=pl.lit(0), - strand=pl.col("strand") - ) + data_bed = data_unique.select( + chr=pl.col("chr"), + start=pl.col("start"), + end=pl.col("end"), + motif_name=pl.col("motif_name"), + score=pl.lit(0), + strand=pl.col("strand"), ) data_all.collect().write_csv(out_path_tsv, separator="\t") @@ -537,10 +1293,20 @@ def write_hits(hits_df, peaks_df, motifs_df, qc_df, out_dir, motif_width): data_bed.collect().write_csv(out_path_bed, include_header=False, separator="\t") -def write_qc(qc_df, peaks_df, out_path): +def write_qc(qc_df: pl.DataFrame, peaks_df: pl.DataFrame, out_path: str) -> None: + """Write quality control data with peak information to TSV file. + + Parameters + ---------- + qc_df : pl.DataFrame + Quality control metrics for each peak region. + peaks_df : pl.DataFrame + Peak region information for coordinate annotation. + out_path : str + Output path for the TSV file. + """ df = ( - qc_df - .lazy() + qc_df.lazy() .join(peaks_df.lazy(), on="peak_id", how="inner") .sort(["chr_id", "peak_region_start"]) .drop("chr_id") @@ -549,7 +1315,16 @@ def write_qc(qc_df, peaks_df, out_path): df.write_csv(out_path, separator="\t") -def write_motifs_df(motifs_df, out_path): +def write_motifs_df(motifs_df: pl.DataFrame, out_path: str) -> None: + """Write motif metadata to TSV file. + + Parameters + ---------- + motifs_df : pl.DataFrame + Motif metadata DataFrame. + out_path : str + Output path for the TSV file. 
+ """ motifs_df.write_csv(out_path, separator="\t") @@ -564,42 +1339,132 @@ def write_motifs_df(motifs_df, out_path): "lambda": pl.Float32, } -def load_motifs_df(motifs_path): + +def load_motifs_df(motifs_path: str) -> Tuple[pl.DataFrame, ndarray]: + """Load motif metadata from TSV file. + + Parameters + ---------- + motifs_path : str + Path to motif metadata TSV file. + + Returns + ------- + motifs_df : pl.DataFrame + Motif metadata with predefined schema. + motif_names : ndarray + Array of unique forward-strand motif names. + """ motifs_df = pl.read_csv(motifs_path, separator="\t", schema=MOTIF_DTYPES) - motif_names = motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy() + motif_names = ( + motifs_df.filter(pl.col("strand") == "+").get_column("motif_name").to_numpy() + ) return motifs_df, motif_names -def write_motif_cwms(cwms, out_path): +def write_motif_cwms(cwms: Float[ndarray, "M 4 W"], out_path: str) -> None: + """Write motif contribution weight matrices to .npy file. + + Parameters + ---------- + cwms : Float[ndarray, "M 4 W"] + Contribution weight matrices for M motifs, 4 nucleotides, W width. + out_path : str + Output path for the .npy file. + """ np.save(out_path, cwms) -def load_motif_cwms(cwms_path): +def load_motif_cwms(cwms_path: str) -> Float[ndarray, "M 4 W"]: + """Load motif contribution weight matrices from .npy file. + + Parameters + ---------- + cwms_path : str + Path to .npy file containing CWMs. + + Returns + ------- + Float[ndarray, "M 4 W"] + Loaded contribution weight matrices. + """ return np.load(cwms_path) -def write_params(params, out_path): +def write_params(params: Dict[str, Any], out_path: str) -> None: + """Write parameter dictionary to JSON file. + + Parameters + ---------- + params : Dict[str, Any] + Parameter dictionary to serialize. + out_path : str + Output path for the JSON file. 
+ """ with open(out_path, "w") as f: json.dump(params, f, indent=4) -def load_params(params_path): +def load_params(params_path: str) -> Dict[str, Any]: + """Load parameter dictionary from JSON file. + + Parameters + ---------- + params_path : str + Path to JSON file containing parameters. + + Returns + ------- + Dict[str, Any] + Loaded parameter dictionary. + """ with open(params_path) as f: params = json.load(f) return params -def write_occ_df(occ_df, out_path): +def write_occ_df(occ_df: pl.DataFrame, out_path: str) -> None: + """Write occurrence data to TSV file. + + Parameters + ---------- + occ_df : pl.DataFrame + Occurrence data DataFrame. + out_path : str + Output path for the TSV file. + """ occ_df.write_csv(out_path, separator="\t") -def write_seqlet_confusion_df(seqlet_confusion_df, out_path): +def write_seqlet_confusion_df(seqlet_confusion_df: pl.DataFrame, out_path: str) -> None: + """Write seqlet confusion matrix data to TSV file. + + Parameters + ---------- + seqlet_confusion_df : pl.DataFrame + Seqlet confusion matrix DataFrame. + out_path : str + Output path for the TSV file. + """ seqlet_confusion_df.write_csv(out_path, separator="\t") -def write_report_data(report_df, cwms, out_dir): +def write_report_data( + report_df: pl.DataFrame, cwms: Dict[str, Dict[str, ndarray]], out_dir: str +) -> None: + """Write comprehensive motif report data including CWMs and metadata. + + Parameters + ---------- + report_df : pl.DataFrame + Report metadata DataFrame. + cwms : Dict[str, Dict[str, ndarray]] + Nested dictionary of motif names to CWM types to arrays. + out_dir : str + Output directory for report files. 
+ """ cwms_dir = os.path.join(out_dir, "CWMs") os.makedirs(cwms_dir, exist_ok=True) @@ -610,4 +1475,3 @@ def write_report_data(report_df, cwms, out_dir): np.savetxt(os.path.join(motif_dir, f"{cwm_type}.txt"), cwm) report_df.write_csv(os.path.join(out_dir, "motif_report.tsv"), separator="\t") - diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index 98bb919..61a89b2 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -1,21 +1,69 @@ +"""Evaluation module for assessing Fi-NeMo motif discovery and hit calling performance. + +This module provides functions for: +- Computing motif occurrence statistics and co-occurrence patterns +- Evaluating motif discovery quality against TF-MoDISCo results +- Analyzing hit calling performance and recall metrics +- Generating confusion matrices for seqlet-hit comparisons +""" + import warnings +from typing import List, Tuple, Dict, Any, Union import numpy as np +from numpy import ndarray import polars as pl - - -def get_motif_occurences(hits_df, motif_names): +from jaxtyping import Float, Int + + +def get_motif_occurences( + hits_df: pl.LazyFrame, motif_names: List[str] +) -> Tuple[pl.DataFrame, Int[ndarray, "M M"]]: + """Compute motif occurrence statistics and co-occurrence matrix. + + This function analyzes motif occurrence patterns across peaks by creating + a pivot table of hit counts and computing pairwise co-occurrence statistics. + + Parameters + ---------- + hits_df : pl.LazyFrame + Lazy DataFrame containing hit data with required columns: + - peak_id : Peak identifier + - motif_name : Name of the motif + Additional columns are ignored. + motif_names : List[str] + List of motif names to include in analysis. Missing motifs + will be added as columns with zero counts. + + Returns + ------- + occ_df : pl.DataFrame + DataFrame with motif occurrence counts per peak. 
Contains: + - peak_id column + - One column per motif with hit counts + - 'total' column summing all motif counts per peak + coocc : Int[ndarray, "M M"] + Co-occurrence matrix where M = len(motif_names). + Entry (i,j) indicates number of peaks containing both motif i and motif j. + Diagonal entries show total peaks containing each motif. + + Notes + ----- + The co-occurrence matrix is computed using binary occurrence indicators, + so multiple hits of the same motif in a peak are treated as a single occurrence. + """ occ_df = ( - hits_df - .collect() - .pivot(index="peak_id", columns="motif_name", values="count", aggregate_function="sum") + hits_df.collect() + .with_columns(pl.lit(1).alias("count")) + .pivot( + on="motif_name", index="peak_id", values="count", aggregate_function="sum" + ) .fill_null(0) ) missing_cols = set(motif_names) - set(occ_df.columns) occ_df = ( - occ_df - .with_columns([pl.lit(0).alias(m) for m in missing_cols]) + occ_df.with_columns([pl.lit(0).alias(m) for m in missing_cols]) .with_columns(total=pl.sum_horizontal(*motif_names)) .sort(["peak_id"]) ) @@ -25,7 +73,7 @@ def get_motif_occurences(hits_df, motif_names): occ_mat = np.zeros((num_peaks, num_motifs), dtype=np.int16) for i, m in enumerate(motif_names): - occ_mat[:,i] = occ_df.get_column(m).to_numpy() + occ_mat[:, i] = occ_df.get_column(m).to_numpy() occ_bin = (occ_mat > 0).astype(np.int32) coocc = occ_bin.T @ occ_bin @@ -33,113 +81,250 @@ def get_motif_occurences(hits_df, motif_names): return occ_df, coocc -def get_cwms(regions, positions_df, motif_width): - idx_df = ( - positions_df - .select( - peak_idx=pl.col("peak_id"), - start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"), - is_revcomp=pl.col("is_revcomp") - ) +def get_cwms( + regions: Float[ndarray, "N 4 L"], positions_df: pl.DataFrame, motif_width: int +) -> Float[ndarray, "H 4 W"]: + """Extract contribution weight matrices from regions based on hit positions. 
+ + This function extracts motif-sized windows from contribution score regions + at positions specified by hit coordinates. It handles both forward and + reverse complement orientations and filters out invalid positions. + + Parameters + ---------- + regions : Float[ndarray, "N 4 L"] + Input contribution score regions multiplied by one-hot sequences. + Shape: (n_peaks, 4, region_width) where 4 represents DNA bases (A,C,G,T). + positions_df : pl.DataFrame + DataFrame containing hit positions with required columns: + - peak_id : int, Peak index (0-based) + - start_untrimmed : int, Start position in genomic coordinates + - peak_region_start : int, Peak region start coordinate + - is_revcomp : bool, Whether hit is on reverse complement strand + motif_width : int + Width of motifs to extract. Must be positive. + + Returns + ------- + cwms : Float[ndarray, "H 4 W"] + Extracted contribution weight matrices for valid hits. + Shape: (n_valid_hits, 4, motif_width) + Invalid hits (outside region boundaries) are filtered out. + + Notes + ----- + - Start positions are converted from genomic to region-relative coordinates + - Reverse complement hits have their sequence order reversed + - Hits extending beyond region boundaries are excluded + - The mean is computed across all valid hits, with warnings suppressed + for empty slices or invalid operations + + Raises + ------ + ValueError + If motif_width is non-positive or positions_df lacks required columns. 
+ """ + idx_df = positions_df.select( + peak_idx=pl.col("peak_id"), + start_idx=pl.col("start_untrimmed") - pl.col("peak_region_start"), + is_revcomp=pl.col("is_revcomp"), ) - peak_idx = idx_df.get_column('peak_idx').to_numpy() - start_idx = idx_df.get_column('start_idx').to_numpy() + peak_idx = idx_df.get_column("peak_idx").to_numpy() + start_idx = idx_df.get_column("start_idx").to_numpy() is_revcomp = idx_df.get_column("is_revcomp").to_numpy().astype(bool) - # Ignore hits outside of region + # Filter hits that fall outside the region boundaries valid_mask = (start_idx >= 0) & (start_idx + motif_width <= regions.shape[2]) peak_idx = peak_idx[valid_mask] start_idx = start_idx[valid_mask] is_revcomp = is_revcomp[valid_mask] - row_idx = peak_idx[:,None,None] - pos_idx = start_idx[:,None,None] + np.zeros((1,1,motif_width), dtype=int) - pos_idx[~is_revcomp,:,:] += np.arange(motif_width)[None,None,:] - pos_idx[is_revcomp,:,:] += np.arange(motif_width)[None,None,::-1] - nuc_idx = np.zeros((peak_idx.shape[0],4,1), dtype=int) - nuc_idx[~is_revcomp,:,:] += np.arange(4)[None,:,None] - nuc_idx[is_revcomp,:,:] += np.arange(4)[None,::-1,None] + row_idx = peak_idx[:, None, None] + pos_idx = start_idx[:, None, None] + np.zeros((1, 1, motif_width), dtype=int) + pos_idx[~is_revcomp, :, :] += np.arange(motif_width)[None, None, :] + pos_idx[is_revcomp, :, :] += np.arange(motif_width)[None, None, ::-1] + nuc_idx = np.zeros((peak_idx.shape[0], 4, 1), dtype=int) + nuc_idx[~is_revcomp, :, :] += np.arange(4)[None, :, None] + nuc_idx[is_revcomp, :, :] += np.arange(4)[None, ::-1, None] seqs = regions[row_idx, nuc_idx, pos_idx] - + with warnings.catch_warnings(): - warnings.filterwarnings(action='ignore', message='invalid value encountered in divide') - warnings.filterwarnings(action='ignore', message='Mean of empty slice') + warnings.filterwarnings( + action="ignore", message="invalid value encountered in divide" + ) + warnings.filterwarnings(action="ignore", message="Mean of empty slice") 
cwms = seqs.mean(axis=0) return cwms -def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms_modisco, - motif_names, modisco_half_width, motif_width, compute_recall): +def tfmodisco_comparison( + regions: Float[ndarray, "N 4 L"], + hits_df: Union[pl.DataFrame, pl.LazyFrame], + peaks_df: pl.DataFrame, + seqlets_df: Union[pl.DataFrame, pl.LazyFrame, None], + motifs_df: pl.DataFrame, + cwms_modisco: Float[ndarray, "M 4 W"], + motif_names: List[str], + modisco_half_width: int, + motif_width: int, + compute_recall: bool, +) -> Tuple[ + Dict[str, Dict[str, Any]], + pl.DataFrame, + Dict[str, Dict[str, Float[ndarray, "4 W"]]], + Dict[str, Dict[str, Tuple[int, int]]], +]: + """Compare Fi-NeMo hits with TF-MoDISCo seqlets and compute evaluation metrics. + + This function performs comprehensive comparison between Fi-NeMo hit calls + and TF-MoDISCo seqlets, computing recall metrics, CWM similarities, + and extracting contribution weight matrices for visualization. + + Parameters + ---------- + regions : Float[ndarray, "N 4 L"] + Contribution score regions multiplied by one-hot sequences. + Shape: (n_peaks, 4, region_length) + hits_df : Union[pl.DataFrame, pl.LazyFrame] + Fi-NeMo hit calls with required columns: + - peak_id, start_untrimmed, end_untrimmed, strand, motif_name + peaks_df : pl.DataFrame + Peak metadata with columns: + - peak_id, chr_id, peak_region_start + seqlets_df : Optional[pl.DataFrame] + TF-MoDISCo seqlets with columns: + - chr_id, start_untrimmed, is_revcomp, motif_name + If None, only basic hit statistics are computed. + motifs_df : pl.DataFrame + Motif metadata with columns: + - motif_name, strand, motif_id, motif_start, motif_end + cwms_modisco : Float[ndarray, "M 4 W"] + TF-MoDISCo contribution weight matrices. + Shape: (n_modisco_motifs, 4, motif_width) + motif_names : List[str] + Names of motifs to analyze. + modisco_half_width : int + Half-width for restricting hits to central region for fair comparison. 
+ motif_width : int + Width of motifs for CWM extraction. + compute_recall : bool + Whether to compute recall metrics requiring seqlets_df. + + Returns + ------- + report_data : Dict[str, Dict[str, Any]] + Per-motif evaluation metrics including: + - num_hits_total, num_hits_restricted, num_seqlets + - num_overlaps, seqlet_recall, cwm_similarity + report_df : pl.DataFrame + Tabular format of report_data for easy analysis. + cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]] + Extracted CWMs for each motif and condition: + - hits_fc, hits_rc: Forward/reverse complement hits + - modisco_fc, modisco_rc: TF-MoDISCo forward/reverse + - seqlets_only, hits_restricted_only: Non-overlapping instances + cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]] + Trimming boundaries for each CWM type and motif. + + Notes + ----- + - Hits are filtered to central region defined by modisco_half_width + - CWM similarity is computed as normalized dot product between hit and TF-MoDISCo CWMs + - Recall metrics require both hits_df and seqlets_df to be non-empty + - Missing motifs are handled gracefully with empty DataFrames + + Raises + ------ + ValueError + If required columns are missing from input DataFrames. 
+ """ + + # Ensure hits_df is LazyFrame for consistent operations + if isinstance(hits_df, pl.DataFrame): + hits_df = hits_df.lazy() + hits_df = ( - hits_df - .with_columns(pl.col('peak_id').cast(pl.UInt32)) - .join( - peaks_df.lazy(), on="peak_id", how="inner" - ) + hits_df.with_columns(pl.col("peak_id").cast(pl.UInt32)) + .join(peaks_df.lazy(), on="peak_id", how="inner") .select( chr_id=pl.col("chr_id"), start_untrimmed=pl.col("start_untrimmed"), end_untrimmed=pl.col("end_untrimmed"), - is_revcomp=pl.col("strand") == '-', + is_revcomp=pl.col("strand") == "-", motif_name=pl.col("motif_name"), peak_region_start=pl.col("peak_region_start"), - peak_id=pl.col("peak_id") + peak_id=pl.col("peak_id"), ) ) - hits_unique = hits_df.unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) - + hits_unique = hits_df.unique( + subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"] + ) + region_len = regions.shape[2] center = region_len / 2 - hits_filtered = ( - hits_df - .filter( - ((pl.col("start_untrimmed") - pl.col("peak_region_start")) >= (center - modisco_half_width)) - & ((pl.col("end_untrimmed") - pl.col("peak_region_start")) <= (center + modisco_half_width)) + hits_filtered = hits_df.filter( + ( + (pl.col("start_untrimmed") - pl.col("peak_region_start")) + >= (center - modisco_half_width) ) - .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) - ) - - if compute_recall: - overlaps_df = ( - hits_filtered.join( - seqlets_df, - on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], - how="inner", - ) - .collect() + & ( + (pl.col("end_untrimmed") - pl.col("peak_region_start")) + <= (center + modisco_half_width) ) + ).unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"]) - seqlets_only_df = ( - seqlets_df.join( - hits_df, - on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], - how="anti", - ) - .collect() - ) + hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True) + 
hits_filtered_by_motif = hits_filtered.collect().partition_by( + "motif_name", as_dict=True + ) - hits_only_filtered_df = ( - hits_filtered.join( - seqlets_df, - on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], - how="anti", - ) - .collect() - ) + if seqlets_df is None: + seqlets_collected = None + seqlets_lazy = None + elif isinstance(seqlets_df, pl.LazyFrame): + seqlets_collected = seqlets_df.collect() + seqlets_lazy = seqlets_df + else: + seqlets_collected = seqlets_df + seqlets_lazy = seqlets_df.lazy() + + if seqlets_collected is not None: + seqlets_by_motif = seqlets_collected.partition_by("motif_name", as_dict=True) + else: + seqlets_by_motif = {} + + if compute_recall and seqlets_lazy is not None: + overlaps_df = hits_filtered.join( + seqlets_lazy, + on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], + how="inner", + ).collect() - hits_by_motif = hits_unique.collect().partition_by("motif_name", as_dict=True) - hits_fitered_by_motif = hits_filtered.collect().partition_by("motif_name", as_dict=True) + seqlets_only_df = seqlets_lazy.join( + hits_df, + on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], + how="anti", + ).collect() - if seqlets_df is not None: - seqlets_by_motif = seqlets_df.collect().partition_by("motif_name", as_dict=True) + hits_only_filtered_df = hits_filtered.join( + seqlets_lazy, + on=["chr_id", "start_untrimmed", "is_revcomp", "motif_name"], + how="anti", + ).collect() - if compute_recall: + # Create partition dictionaries overlaps_by_motif = overlaps_df.partition_by("motif_name", as_dict=True) seqlets_only_by_motif = seqlets_only_df.partition_by("motif_name", as_dict=True) - hits_only_filtered_by_motif = hits_only_filtered_df.partition_by("motif_name", as_dict=True) + hits_only_filtered_by_motif = hits_only_filtered_df.partition_by( + "motif_name", as_dict=True + ) + else: + overlaps_by_motif = {} + seqlets_only_by_motif = {} + hits_only_filtered_by_motif = {} report_data = {} cwms = {} @@ -147,12 +332,18 
@@ def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms dummy_df = hits_df.clear().collect() for m in motif_names: hits = hits_by_motif.get((m,), dummy_df) - hits_filtered = hits_fitered_by_motif.get((m,), dummy_df) + hits_filtered = hits_filtered_by_motif.get((m,), dummy_df) + + # Initialize default values + seqlets = dummy_df + overlaps = dummy_df + seqlets_only = dummy_df + hits_only_filtered = dummy_df if seqlets_df is not None: seqlets = seqlets_by_motif.get((m,), dummy_df) - if compute_recall: + if compute_recall and seqlets_df is not None: overlaps = overlaps_by_motif.get((m,), dummy_df) seqlets_only = seqlets_only_by_motif.get((m,), dummy_df) hits_only_filtered = hits_only_filtered_by_motif.get((m,), dummy_df) @@ -165,48 +356,56 @@ def tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, cwms if seqlets_df is not None: report_data[m]["num_seqlets"] = seqlets.height - if compute_recall: + if compute_recall and seqlets_df is not None: report_data[m] |= { "num_overlaps": overlaps.height, "num_seqlets_only": seqlets_only.height, "num_hits_restricted_only": hits_only_filtered.height, "seqlet_recall": np.float64(overlaps.height) / seqlets.height + if seqlets.height > 0 + else 0.0, } - motif_data_fc = motifs_df.row(by_predicate=(pl.col("motif_name") == m) - & (pl.col("strand") == "+"), named=True) - motif_data_rc = motifs_df.row(by_predicate=(pl.col("motif_name") == m) - & (pl.col("strand") == "-"), named=True) + motif_data_fc = motifs_df.row( + by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "+"), + named=True, + ) + motif_data_rc = motifs_df.row( + by_predicate=(pl.col("motif_name") == m) & (pl.col("strand") == "-"), + named=True, + ) cwms[m] = { "hits_fc": get_cwms(regions, hits, motif_width), "modisco_fc": cwms_modisco[motif_data_fc["motif_id"]], "modisco_rc": cwms_modisco[motif_data_rc["motif_id"]], } - cwms[m]["hits_rc"] = cwms[m]["hits_fc"][::-1,::-1] + cwms[m]["hits_rc"] = 
def seqlet_confusion(hits_df, seqlets_df, peaks_df, motif_names, motif_width):
    """Compute a confusion matrix between TF-MoDISCo seqlets and Fi-NeMo hits.

    Overlap frequencies are estimated by binning genomic coordinates to
    ``motif_width`` resolution and counting exact bin matches
    (same ``chr_id``, ``start_bin``, ``end_bin``).

    Parameters
    ----------
    hits_df : pl.DataFrame or pl.LazyFrame
        Fi-NeMo hit calls with columns:
        peak_id, start_untrimmed, end_untrimmed, strand, motif_name.
    seqlets_df : pl.DataFrame or pl.LazyFrame
        TF-MoDISCo seqlets with columns:
        chr_id, start_untrimmed, end_untrimmed, motif_name.
    peaks_df : pl.DataFrame
        Peak metadata used to map peak_id -> chr_id.
    motif_names : list of str
        Motif names defining the confusion-matrix dimensions.
    motif_width : int
        Bin size used for coordinate binning.

    Returns
    -------
    confusion_df : pl.DataFrame
        Long-format confusion table with columns motif_name_seqlets,
        motif_name_hits, and frac_overlap (overlaps / seqlets per motif).
    confusion_mat : np.ndarray, shape (M, M)
        Entry (i, j) is the fraction of motif-i seqlets overlapping
        motif-j hits; combinations absent from the data remain zero.

    Notes
    -----
    - NOTE(review): a motif name present in the data but missing from
      ``motif_names`` makes ``replace_strict`` raise a polars error —
      confirm upstream filtering guarantees this cannot happen.
    """
    bin_size = motif_width

    # Normalize to LazyFrame so all joins/aggregations build one query plan.
    if isinstance(hits_df, pl.DataFrame):
        hits_df = hits_df.lazy()

    hits_binned = (
        hits_df.with_columns(
            peak_id=pl.col("peak_id").cast(pl.UInt32),
            is_revcomp=pl.col("strand") == "-",
        )
        .join(peaks_df.lazy(), on="peak_id", how="inner")
        # Deduplicate hits that share a start position and orientation.
        .unique(subset=["chr_id", "start_untrimmed", "motif_name", "is_revcomp"])
        .select(
            chr_id=pl.col("chr_id"),
            start_bin=pl.col("start_untrimmed") // bin_size,
            end_bin=pl.col("end_untrimmed") // bin_size,
            motif_name=pl.col("motif_name"),
        )
    )

    seqlets_binned = seqlets_df.lazy().select(
        chr_id=pl.col("chr_id"),
        start_bin=pl.col("start_untrimmed") // bin_size,
        end_bin=pl.col("end_untrimmed") // bin_size,
        motif_name=pl.col("motif_name"),
    )

    # A seqlet and a hit "overlap" when they fall in identical bins.
    overlaps_df = seqlets_binned.join(
        hits_binned, on=["chr_id", "start_bin", "end_bin"], how="inner", suffix="_hits"
    )

    seqlet_counts = (
        seqlets_binned.group_by("motif_name").len(name="num_seqlets").collect()
    )
    overlap_counts = (
        overlaps_df.group_by(["motif_name", "motif_name_hits"])
        .len(name="num_overlaps")
        .collect()
    )

    num_motifs = len(motif_names)
    confusion_mat = np.zeros((num_motifs, num_motifs), dtype=np.float32)
    name_to_idx = {m: i for i, m in enumerate(motif_names)}

    confusion_df = overlap_counts.join(
        seqlet_counts, on="motif_name", how="inner"
    ).select(
        motif_name_seqlets=pl.col("motif_name"),
        motif_name_hits=pl.col("motif_name_hits"),
        frac_overlap=pl.col("num_overlaps") / pl.col("num_seqlets"),
    )

    # Map motif names to matrix indices and scatter fractions into the matrix.
    confusion_idx_df = confusion_df.select(
        row_idx=pl.col("motif_name_seqlets").replace_strict(name_to_idx),
        col_idx=pl.col("motif_name_hits").replace_strict(name_to_idx),
        frac_overlap=pl.col("frac_overlap"),
    )

    row_idx = confusion_idx_df["row_idx"].to_numpy()
    col_idx = confusion_idx_df["col_idx"].to_numpy()
    confusion_mat[row_idx, col_idx] = confusion_idx_df["frac_overlap"].to_numpy()

    return confusion_df, confusion_mat
+ +The main algorithm fits a sparse linear model where contribution scores are +reconstructed as a weighted combination of motif contribution weight matrices (CWMs) +at specific genomic positions. The sparsity constraint ensures that only the most +significant motif instances are called. +""" + import warnings +from typing import Tuple, Union, Optional, Dict, List +from abc import ABC, abstractmethod import numpy as np +from numpy import ndarray import torch import torch.nn.functional as F +from torch import Tensor import polars as pl +from jaxtyping import Float, Int, Bool from tqdm import tqdm - -def prox_grad_step(coefficients, importance_scale, cwms, contribs, sequences, - lambdas, step_sizes): - """ - Proximal gradient descent optimization step for non-negative lasso - - coefficients: (b, m, l - w + 1) - importance_scale: (b, 1, l - w + 1) - cwms: (m, 4, w) - contribs: (b, 4, l) - sequences: (b, 4, l) or dummy scalar - lambdas: (1, m, 1) - - For details on proximal gradient descent: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slide 22 - For details on duality gap computation: https://stanford.edu/~boyd/papers/pdf/l1_ls.pdf, Section III +# Type aliases for tensor operations +ArrayLike = Union[ndarray, torch.Tensor] + + +def prox_grad_step( + coefficients: Float[Tensor, "B M P"], + importance_scale: Float[Tensor, "B 1 P"], + cwms: Float[Tensor, "M 4 W"], + contribs: Float[Tensor, "B 4 L"], + sequences: Union[Int[Tensor, "B 4 L"], int], + lambdas: Float[Tensor, "1 M 1"], + step_sizes: Float[Tensor, "B 1 1"], +) -> Tuple[Float[Tensor, "B M P"], Float[Tensor, " B"], Float[Tensor, " B"]]: + """Perform a proximal gradient descent optimization step for non-negative lasso. + + This function implements a single optimization step of the Fi-NeMo algorithm, + which uses proximal gradient descent to solve a sparse reconstruction problem. 
+ The goal is to represent contribution scores as a sparse linear combination + of motif contribution weight matrices (CWMs). + + B = batch size, M = number of motifs, L = sequence length, W = motif width. + P = L - W + 1 (the number of positions with coefficients). + + Parameters + ---------- + coefficients : Float[Tensor, "B M P"] + Current coefficient matrix representing motif instance strengths. + importance_scale : Float[Tensor, "B 1 P"] + Scaling factors for importance-weighted reconstruction. + cwms : Float[Tensor, "M 4 W"] + Motif contribution weight matrices for all motifs. + 4 represents the DNA bases (A, C, G, T). + contribs : Float[Tensor, "B 4 L"] + Target contribution scores to reconstruct. + sequences : Float[Tensor, "B 4 L"] | int + One-hot encoded DNA sequences. Can be a scalar (1) for hypothetical mode. + lambdas : Float[Tensor, "1 M 1"] + L1 regularization weights for each motif. + step_sizes : Float[Tensor, "B 1 1"] + Optimization step sizes for each batch element. + + Returns + ------- + c_next : Float[Tensor, "B M P"], shape (b, m, l - w + 1) + Updated coefficient matrix after the optimization step. + dual_gap : Float[Tensor, "B"] + Duality gap for convergence assessment. + nll : Float[Tensor, "B"] + Negative log likelihood (proportional to MSE). 
+ + Notes + ----- + The algorithm uses proximal gradient descent to solve: + + minimize_c: ||contribs - conv_transpose(c * importance_scale, cwms) * sequences||²₂ + λ||c||₁ + + subject to: c ≥ 0 + + References + ---------- + - Proximal gradient descent: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slide 22 + - Duality gap computation: https://stanford.edu/~boyd/papers/pdf/l1_ls.pdf, Section III """ - # Forward pass + # Forward pass: convolution operations require specific tensor layouts coef_adj = coefficients * importance_scale - pred_unmasked = F.conv_transpose1d(coef_adj, cwms) # (b, 4, l) - pred = pred_unmasked * sequences # (b, 4, l) + pred_unmasked = F.conv_transpose1d(coef_adj, cwms) # (b, 4, l) + pred = ( + pred_unmasked * sequences + ) # (b, 4, l), element-wise masking for projected mode # Compute gradient * -1 - residuals = contribs - pred # (b, 4, l) - ngrad = F.conv1d(residuals, cwms) * importance_scale # (b, m, l - w + 1) + residuals = contribs - pred # (b, 4, l) + ngrad = F.conv1d(residuals, cwms) * importance_scale # (b, m, l - w + 1) # Negative log likelihood (proportional to MSE) - nll = (residuals**2).sum(dim=(1,2)) # (b) - - # Compute duality gap - dual_norm = (ngrad / lambdas).amax(dim=(1,2)) # (b) - dual_scale = (torch.clamp(1 / dual_norm, max=1.)**2 + 1) / 2 # (b) - nll_scaled = nll * dual_scale # (b) - - dual_diff = (residuals * contribs).sum(dim=(1,2)) # (b) - l1_term = (torch.abs(coefficients).sum(dim=2, keepdim=True) * lambdas).sum(dim=(1,2)) # (b) - # l1_term = torch.linalg.vector_norm((coefficients * lambdas), ord=1, dim=(1,2)) # (b) - dual_gap = (nll_scaled - dual_diff + l1_term).abs() # (b) + nll = (residuals**2).sum(dim=(1, 2)) # (b) - # Compute proximal gradient descent step - c_next = coefficients + step_sizes * (ngrad - lambdas) # (b, m, l - w + 1) - c_next = F.relu(c_next) # (b, m, l - w + 1) + # Compute duality gap for convergence assessment + dual_norm = (ngrad / lambdas).amax(dim=(1, 
2)) # (b) + dual_scale = (torch.clamp(1 / dual_norm, max=1.0) ** 2 + 1) / 2 # (b) + nll_scaled = nll * dual_scale # (b) - return c_next, dual_gap, nll + dual_diff = (residuals * contribs).sum(dim=(1, 2)) # (b) + l1_term = (torch.abs(coefficients).sum(dim=2, keepdim=True) * lambdas).sum( + dim=(1, 2) + ) # (b) + dual_gap = (nll_scaled - dual_diff + l1_term).abs() # (b) + # Compute proximal gradient descent step + c_next = coefficients + step_sizes * (ngrad - lambdas) # (b, m, l - w + 1) + c_next = F.relu(c_next) # Ensure non-negativity constraint -def optimizer_step(cwms, contribs, importance_scale, sequences, coef_inter, coef, i, step_sizes, l, lambdas): - """ - Non-negative lasso optimizer step with momentum. + return c_next, dual_gap, nll - cwms: (m, 4, w) - contribs: (b, 4, l) - importance_scale: (b, 1, l - w + 1) - sequences: (b, 4, l) or dummy scalar - coef_inter, coef: (b, m, l - w + 1) - i, step_sizes: (b,) - For details on optimization algorithm: https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slides 22, 27 +def optimizer_step( + cwms: Float[Tensor, "M 4 W"], + contribs: Float[Tensor, "B 4 L"], + importance_scale: Float[Tensor, "B 1 P"], + sequences: Union[Int[Tensor, "B 4 L"], int], + coef_inter: Float[Tensor, "B M P"], + coef: Float[Tensor, "B M P"], + i: Float[Tensor, "B 1 1"], + step_sizes: Float[Tensor, "B 1 1"], + sequence_length: int, + lambdas: Float[Tensor, "1 M 1"], +) -> Tuple[ + Float[Tensor, "B M P"], + Float[Tensor, "B M P"], + Float[Tensor, " B"], + Float[Tensor, " B"], +]: + """Perform a non-negative lasso optimizer step with Nesterov momentum. + + This function combines proximal gradient descent with momentum acceleration + to improve convergence speed while maintaining the non-negative constraint + on coefficients. + + B = batch size, M = number of motifs, L = sequence length, W = motif width. + P = L - W + 1 (the number of positions with coefficients). 
+ + Parameters + ---------- + cwms : Float[Tensor, "M 4 W"] + Motif contribution weight matrices. + contribs : Float[Tensor, "B 4 L"] + Target contribution scores. + importance_scale : Float[Tensor, "B 1 P"] + Importance scaling factors. + sequences : Union[Int[Tensor, "B 4 L"], int] + One-hot encoded sequences or scalar for hypothetical mode. + coef_inter : Float[Tensor, "B M P"] + Intermediate coefficient matrix (with momentum). + coef : Float[Tensor, "B M P"] + Current coefficient matrix. + i : Float[Tensor, "B 1 1"] + Iteration counter for each batch element. + step_sizes : Float[Tensor, "B 1 1"] + Step sizes for optimization. + sequence_length : int + Sequence length for normalization. + lambdas : Float[Tensor, "1 M 1"] + Regularization parameters. + + Returns + ------- + coef_inter : Float[Tensor, "B M P"] + Updated intermediate coefficients with momentum. + coef : Float[Tensor, "B M P"] + Updated coefficient matrix. + gap : Float[Tensor, " B"] + Normalized duality gap. + nll : Float[Tensor, " B"] + Normalized negative log likelihood. + + Notes + ----- + Uses Nesterov momentum with momentum coefficient i/(i+3) for improved + convergence properties. The duality gap and NLL are normalized by + sequence length for scale-invariant convergence assessment. + + References + ---------- + https://yuxinchen2020.github.io/ele520_math_data/lectures/lasso_algorithm_extension.pdf, slides 22, 27 """ coef_prev = coef # Proximal gradient descent step - coef, gap, nll = prox_grad_step(coef_inter, importance_scale, cwms, contribs, sequences, - lambdas, step_sizes) - gap = gap / l - nll = nll / (2 * l) - - # Compute updated coefficients - mom_term = i / (i + 3.) 
+ coef, gap, nll = prox_grad_step( + coef_inter, importance_scale, cwms, contribs, sequences, lambdas, step_sizes + ) + gap = gap / sequence_length + nll = nll / (2 * sequence_length) + + # Compute updated coefficients with Nesterov momentum + mom_term = i / (i + 3.0) coef_inter = (1 + mom_term) * coef - mom_term * coef_prev return coef_inter, coef, gap, nll -def _to_channel_last_layout(tensor, **kwargs): - return tensor[:,:,:,None].to(memory_format=torch.channels_last, **kwargs).squeeze(3) +def _to_channel_last_layout(tensor: Tensor, **kwargs) -> torch.Tensor: + """Convert tensor to channel-last memory layout for optimized convolution operations. + + Parameters + ---------- + tensor : torch.Tensor + Input tensor to convert. + **kwargs + Additional keyword arguments passed to tensor.to(). + + Returns + ------- + torch.Tensor + Tensor with channel-last memory layout. + """ + return ( + tensor[:, :, :, None].to(memory_format=torch.channels_last, **kwargs).squeeze(3) + ) + + +def _signed_sqrt(x: torch.Tensor) -> torch.Tensor: + """Apply signed square root transformation to input tensor. + This transformation preserves the sign while applying square root to the + absolute value, which can help with numerical stability and gradient flow. -def _signed_sqrt(x): + Parameters + ---------- + x : torch.Tensor + Input tensor. + + Returns + ------- + torch.Tensor + Transformed tensor with same shape as input. + """ return torch.sign(x) * torch.sqrt(torch.abs(x)) -class BatchLoaderBase: - def __init__(self, contribs, sequences, l, device): +class BatchLoaderBase(ABC): + """Base class for loading batches of contribution scores and sequences. + + This class provides common functionality for different input formats + including batch indexing and padding for consistent batch sizes. + + N = number of sequences, L = sequence length. + + Parameters + ---------- + contribs : Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]] + Contribution scores array. 
+ sequences : Int[Tensor, "N 4 L"] + One-hot encoded sequences array. + sequence_length : int + Sequence length. + device : torch.device + Target device for tensor operations. + """ + + def __init__( + self, + contribs: Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]], + sequences: Int[Tensor, "N 4 L"], + sequence_length: int, + device: torch.device, + ) -> None: self.contribs = contribs self.sequences = sequences - self.l = l + self.sequence_length = sequence_length self.device = device - def _get_inds_and_pad_lens(self, start, end): + def _get_inds_and_pad_lens( + self, start: int, end: int + ) -> Tuple[Int[Tensor, " Z"], Tuple[int, ...]]: + """Get indices and padding lengths for batch loading. + + Parameters + ---------- + start : int + Start index for batch. + end : int + End index for batch. + + Returns + ------- + inds : Int[Tensor, " Z"] + Padded indices tensor with -1 for padding positions. + pad_lens : tuple + Padding specification for F.pad (left, right, top, bottom, front, back). + """ n = end - start end = min(end, self.contribs.shape[0]) overhang = n - (end - start) pad_lens = (0, 0, 0, 0, 0, overhang) - inds = F.pad(torch.arange(start, end, dtype=torch.int), (0, overhang), value=-1).to(device=self.device) + inds = F.pad( + torch.arange(start, end, dtype=torch.int), (0, overhang), value=-1 + ).to(device=self.device) return inds, pad_lens - def load_batch(self, start, end): - raise NotImplementedError - + @abstractmethod + def load_batch( + self, start: int, end: int + ) -> Tuple[ + Float[Tensor, "B 4 L"], Union[Int[Tensor, "B 4 L"], int], Int[Tensor, " B"] + ]: + """Load a batch of data. + + B = batch size, L = sequence length. + + Parameters + ---------- + start : int + Start index (used by subclasses). + end : int + End index (used by subclasses). + + Returns + ------- + contribs_batch : Float[Tensor, "B 4 L"] + Batch of contribution scores. + sequences_batch : Union[Int[Tensor, "B 4 L"], int] + Batch of sequences or scalar for hypothetical mode. 
+ inds_batch : Int[Tensor, "B"] + Batch indices. + + Notes + ----- + This is an abstract method that must be implemented by subclasses. + Parameters are intentionally unused in the base implementation. + """ + pass + class BatchLoaderCompactFmt(BatchLoaderBase): - def load_batch(self, start, end): - inds, pad_lens = self._get_inds_and_pad_lens(start, end) + """Batch loader for compact format contribution scores. - contribs_compact = F.pad(self.contribs[start:end,None,:], pad_lens) - contribs_batch = _to_channel_last_layout(contribs_compact, device=self.device, dtype=torch.float32) - sequences_batch = F.pad(self.sequences[start:end,:,:], pad_lens) # (b, 4, l) - sequences_batch = _to_channel_last_layout(sequences_batch, device=self.device, dtype=torch.int8) + Handles contribution scores in shape (N, L) representing projected + scores that need to be broadcasted to (N, 4, L) format. + """ + + def load_batch( + self, start: int, end: int + ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]: + inds, pad_lens = self._get_inds_and_pad_lens(start, end) - # global_scale = ((contribs_batch**2).sum(dim=(1,2)) / self.l).sqrt() + contribs_compact = F.pad(self.contribs[start:end, None, :], pad_lens) + contribs_batch = _to_channel_last_layout( + contribs_compact, device=self.device, dtype=torch.float32 + ) + sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens) # (b, 4, l) + sequences_batch = _to_channel_last_layout( + sequences_batch, device=self.device, dtype=torch.int8 + ) - contribs_batch = contribs_batch * sequences_batch # (b, 4, l) + contribs_batch = contribs_batch * sequences_batch # (b, 4, l) return contribs_batch, sequences_batch, inds class BatchLoaderProj(BatchLoaderBase): - def load_batch(self, start, end): + """Batch loader for projected contribution scores. + + Handles contribution scores in shape (N, 4, L) where scores are + element-wise multiplied by one-hot sequences to get projected contributions. 
+ """ + + def load_batch( + self, start: int, end: int + ) -> Tuple[Float[Tensor, "B 4 L"], Int[Tensor, "B 4 L"], Int[Tensor, " B"]]: inds, pad_lens = self._get_inds_and_pad_lens(start, end) - contribs_hyp = F.pad(self.contribs[start:end,:,:], pad_lens) - contribs_hyp = _to_channel_last_layout(contribs_hyp, device=self.device, dtype=torch.float32) - sequences_batch = F.pad(self.sequences[start:end,:,:], pad_lens) # (b, 4, l) - sequences_batch = _to_channel_last_layout(sequences_batch, device=self.device, dtype=torch.int8) + contribs_hyp = F.pad(self.contribs[start:end, :, :], pad_lens) + contribs_hyp = _to_channel_last_layout( + contribs_hyp, device=self.device, dtype=torch.float32 + ) + sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens) # (b, 4, l) + sequences_batch = _to_channel_last_layout( + sequences_batch, device=self.device, dtype=torch.int8 + ) contribs_batch = contribs_hyp * sequences_batch - # global_scale = ((contribs_batch**2).sum(dim=(1,2)) / self.l).sqrt() - # contribs_batch = torch.nan_to_num(contribs_batch / global_scale[:,None,None]) - return contribs_batch, sequences_batch, inds - + class BatchLoaderHyp(BatchLoaderBase): - def load_batch(self, start, end): - inds, pad_lens = self._get_inds_and_pad_lens(start, end) + """Batch loader for hypothetical contribution scores. - contribs_batch = F.pad(self.contribs[start:end,:,:], pad_lens) - contribs_batch = _to_channel_last_layout(contribs_batch, device=self.device, dtype=torch.float32) + Handles hypothetical contribution scores in shape (N, 4, L) where + scores represent counterfactual effects of base substitutions. 
+ """ + + def load_batch( + self, start: int, end: int + ) -> Tuple[Float[Tensor, "B 4 L"], int, Int[Tensor, " B"]]: + inds, pad_lens = self._get_inds_and_pad_lens(start, end) - # global_scale = ((contribs_batch**2).sum(dim=(1,2)) / self.l).sqrt() - # contribs_batch = torch.nan_to_num(contribs_batch / global_scale[:,None,None]) + contribs_batch = F.pad(self.contribs[start:end, :, :], pad_lens) + contribs_batch = _to_channel_last_layout( + contribs_batch, device=self.device, dtype=torch.float32 + ) return contribs_batch, 1, inds -def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lambdas, step_size_max, step_size_min, sqrt_transform, - convergence_tol, max_steps, batch_size, step_adjust, post_filter, device, compile_optimizer, eps=1.): - """ - Call hits by fitting sparse linear model to contributions - - cwms: (m, 4, w) - contribs: (n, 4, l) or (n, l) - sequences: (n, 4, l) - cwm_trim_mask: (m, w) +def fit_contribs( + cwms: Float[ndarray, "M 4 W"], + contribs: Union[Float[ndarray, "N 4 L"], Float[ndarray, "N L"]], + sequences: Int[ndarray, "N 4 L"], + cwm_trim_mask: Float[ndarray, "M W"], + use_hypothetical: bool, + lambdas: Float[ndarray, " M"], + step_size_max: float, + step_size_min: float, + sqrt_transform: bool, + convergence_tol: float, + max_steps: int, + batch_size: int, + step_adjust: float, + post_filter: bool, + device: Optional[torch.device], + compile_optimizer: bool, + eps: float = 1.0, +) -> Tuple[pl.DataFrame, pl.DataFrame]: + """Call motif hits by fitting sparse linear model to contribution scores. + + This is the main function implementing the Fi-NeMo algorithm. It identifies + motif instances by solving a sparse reconstruction problem where contribution + scores are approximated as a linear combination of motif CWMs at specific + positions. The optimization uses proximal gradient descent with momentum. 
+ + Parameters + ---------- + cwms : Float[ndarray, "M 4 W"] + Motif contribution weight matrices where M = number of motifs, + 4 = DNA bases (A, C, G, T), W = motif width. + contribs : Float[ndarray, "N 4 L"] | Float[ndarray, "N L"] + Neural network contribution scores where N = number of regions, + L = sequence length. Can be hypothetical (N, 4, L) or projected (N, L). + sequences : Float[ndarray, "N 4 L"] + One-hot encoded DNA sequences. + cwm_trim_mask : Float[ndarray, "M W"] + Binary mask indicating which positions of each CWM to use. + use_hypothetical : bool + Whether to use hypothetical contribution scores (True) or + projected scores (False). + lambdas : Float[ndarray, "M"] + L1 regularization weights for each motif. + step_size_max : float + Maximum optimization step size. + step_size_min : float + Minimum optimization step size (for convergence failure detection). + sqrt_transform : bool + Whether to apply signed square root transformation to inputs. + convergence_tol : float + Convergence tolerance based on duality gap. + max_steps : int + Maximum number of optimization steps. + batch_size : int + Number of regions to process simultaneously. + step_adjust : float + Factor to reduce step size when optimization diverges. + post_filter : bool + Whether to filter hits based on similarity threshold. + device : torch.device, optional + Target device for computation. Auto-detected if None. + compile_optimizer : bool + Whether to JIT compile the optimizer for speed. + eps : float, default 1.0 + Small constant for numerical stability. 
+ + Returns + ------- + hits_df : pl.DataFrame + DataFrame containing called motif hits with columns: + - peak_id: Region index + - motif_id: Motif index + - hit_start: Start position of hit + - hit_coefficient: Hit strength coefficient + - hit_similarity: Cosine similarity with motif + - hit_importance: Total contribution score in hit region + - hit_importance_sq: Sum of squared contributions (for normalization) + qc_df : pl.DataFrame + DataFrame containing quality control metrics with columns: + - peak_id: Region index + - nll: Final negative log likelihood + - dual_gap: Final duality gap + - num_steps: Number of optimization steps + - step_size: Final step size + - global_scale: Region-level scaling factor + + Notes + ----- + The algorithm solves the optimization problem: + + minimize_c: ||contribs - Σⱼ convolve(c * scale, cwms[j]) * sequences||²₂ + Σⱼ λⱼ||c[:,j]||₁ + + subject to: c ≥ 0 + + where c[i,j] represents the strength of motif j at position i. + + The importance scaling balances reconstruction across different + motifs and positions based on the local contribution magnitude. + + Examples + -------- + >>> hits_df, qc_df = fit_contribs( + ... cwms=motif_cwms, + ... contribs=contrib_scores, + ... sequences=onehot_seqs, + ... cwm_trim_mask=trim_masks, + ... use_hypothetical=False, + ... lambdas=np.array([0.7, 0.8]), + ... step_size_max=3.0, + ... step_size_min=0.08, + ... sqrt_transform=False, + ... convergence_tol=0.0005, + ... max_steps=10000, + ... batch_size=1000, + ... step_adjust=0.7, + ... post_filter=True, + ... device=None, + ... compile_optimizer=False + ... ) """ m, _, w = cwms.shape - n, _, l = sequences.shape + n, _, sequence_length = sequences.shape b = batch_size @@ -175,74 +542,113 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam else: device = torch.device("cpu") warnings.warn("No GPU available. 
Running on CPU.", RuntimeWarning) - + # Compile optimizer if requested global optimizer_step if compile_optimizer: optimizer_step = torch.compile(optimizer_step, fullgraph=True) - # Convert inputs to pytorch tensors - cwms = torch.from_numpy(cwms) - contribs = torch.from_numpy(contribs) - sequences = torch.from_numpy(sequences) - cwm_trim_mask = torch.from_numpy(cwm_trim_mask)[:,None,:].repeat(1,4,1) - lambdas = torch.from_numpy(lambdas)[None,:,None].to(device=device, dtype=torch.float32) - - cwms = _to_channel_last_layout(cwms, device=device, dtype=torch.float32) - cwm_trim_mask = _to_channel_last_layout(cwm_trim_mask, device=device, dtype=torch.float32) - cwms = cwms * cwm_trim_mask + # Convert inputs to PyTorch tensors with proper device placement + cwms_tensor: torch.Tensor = torch.from_numpy(cwms) + contribs_tensor: torch.Tensor = torch.from_numpy(contribs) + sequences_tensor: torch.Tensor = torch.from_numpy(sequences) + cwm_trim_mask_tensor = torch.from_numpy(cwm_trim_mask)[:, None, :].repeat(1, 4, 1) + lambdas_tensor: torch.Tensor = torch.from_numpy(lambdas)[None, :, None].to( + device=device, dtype=torch.float32 + ) + + # Convert to channel-last layout for optimized convolution operations + cwms_tensor = _to_channel_last_layout( + cwms_tensor, device=device, dtype=torch.float32 + ) + cwm_trim_mask_tensor = _to_channel_last_layout( + cwm_trim_mask_tensor, device=device, dtype=torch.float32 + ) + cwms_tensor = cwms_tensor * cwm_trim_mask_tensor # Apply trimming mask if sqrt_transform: - cwms = _signed_sqrt(cwms) - cwm_norm = (cwms**2).sum(dim=(1,2)).sqrt() - cwms = cwms / cwm_norm[:,None,None] + cwms_tensor = _signed_sqrt(cwms_tensor) + cwm_norm = (cwms_tensor**2).sum(dim=(1, 2)).sqrt() + cwms_tensor = cwms_tensor / cwm_norm[:, None, None] # Initialize batch loader - if len(contribs.shape) == 3: + if len(contribs_tensor.shape) == 3: if use_hypothetical: - batch_loader = BatchLoaderHyp(contribs, sequences, l, device) + batch_loader = BatchLoaderHyp( + 
contribs_tensor, sequences_tensor, sequence_length, device + ) else: - batch_loader = BatchLoaderProj(contribs, sequences, l, device) - elif len(contribs.shape) == 2: + batch_loader = BatchLoaderProj( + contribs_tensor, sequences_tensor, sequence_length, device + ) + elif len(contribs_tensor.shape) == 2: if use_hypothetical: - raise ValueError("Input regions do not contain hypothetical contribution scores") + raise ValueError( + "Input regions do not contain hypothetical contribution scores" + ) else: - batch_loader = BatchLoaderCompactFmt(contribs, sequences, l, device) + batch_loader = BatchLoaderCompactFmt( + contribs_tensor, sequences_tensor, sequence_length, device + ) else: - raise ValueError(f"Input contributions array is of incorrect shape {contribs.shape}") - - # Intialize output container objects - hit_idxs_lst = [] - coefficients_lst = [] - similarity_lst = [] - importance_lst = [] - importance_sq_lst = [] - qc_lsts = {"nll": [], "dual_gap": [], "num_steps": [], "step_size": [], "global_scale": [], "peak_id": []} + raise ValueError( + f"Input contributions array is of incorrect shape {contribs_tensor.shape}" + ) + + # Initialize output container objects + hit_idxs_lst: List[ndarray] = [] + coefficients_lst: List[ndarray] = [] + similarity_lst: List[ndarray] = [] + importance_lst: List[ndarray] = [] + importance_sq_lst: List[ndarray] = [] + qc_lsts: Dict[str, List[ndarray]] = { + "nll": [], + "dual_gap": [], + "num_steps": [], + "step_size": [], + "global_scale": [], + "peak_id": [], + } # Initialize buffers for optimizer - coef_inter = torch.zeros((b, m, l - w + 1)) # (b, m, l - w + 1) + coef_inter: Float[Tensor, "B M P"] = torch.zeros( + (b, m, sequence_length - w + 1) + ) # (b, m, sequence_length - w + 1) coef_inter = _to_channel_last_layout(coef_inter, device=device, dtype=torch.float32) - coef = torch.zeros_like(coef_inter) - i = torch.zeros((b, 1, 1), dtype=torch.int, device=device) - step_sizes = torch.full((b, 1, 1), step_size_max, 
dtype=torch.float32, device=device) - - converged = torch.full((b,), True, dtype=torch.bool, device=device) + coef: Float[Tensor, "B M P"] = torch.zeros_like(coef_inter) + i: Float[Tensor, "B 1 1"] = torch.zeros((b, 1, 1), dtype=torch.int, device=device) + step_sizes: Float[Tensor, "B 1 1"] = torch.full( + (b, 1, 1), step_size_max, dtype=torch.float32, device=device + ) + + converged: Bool[Tensor, " B"] = torch.full( + (b,), True, dtype=torch.bool, device=device + ) num_load = b - contribs_buf = torch.zeros((b, 4, l)) - contribs_buf = _to_channel_last_layout(contribs_buf, device=device, dtype=torch.float32) + contribs_buf: Float[Tensor, "B 4 L"] = torch.zeros((b, 4, sequence_length)) + contribs_buf = _to_channel_last_layout( + contribs_buf, device=device, dtype=torch.float32 + ) + seqs_buf: Union[Int[Tensor, "B 4 L"], int] if use_hypothetical: seqs_buf = 1 else: - seqs_buf = torch.zeros((b, 4, l)) + seqs_buf = torch.zeros((b, 4, sequence_length)) seqs_buf = _to_channel_last_layout(seqs_buf, device=device, dtype=torch.int8) - importance_scale_buf = torch.zeros((b, m, l - w + 1)) - importance_scale_buf = _to_channel_last_layout(importance_scale_buf, device=device, dtype=torch.float32) + importance_scale_buf: Float[Tensor, "B M P"] = torch.zeros( + (b, m, sequence_length - w + 1) + ) + importance_scale_buf = _to_channel_last_layout( + importance_scale_buf, device=device, dtype=torch.float32 + ) - inds_buf = torch.zeros((b,), dtype=torch.int, device=device) - global_scale_buf = torch.zeros((b,), dtype=torch.float, device=device) + inds_buf: Int[Tensor, " B"] = torch.zeros((b,), dtype=torch.int, device=device) + global_scale_buf: Float[Tensor, " B"] = torch.zeros( + (b,), dtype=torch.float, device=device + ) with tqdm(disable=None, unit="regions", total=n, ncols=120) as pbar: num_complete = 0 @@ -252,61 +658,80 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam if num_load > 0: load_start = next_ind load_end = load_start + num_load - 
next_ind = min(load_end, contribs.shape[0]) + next_ind = min(load_end, contribs_tensor.shape[0]) - batch_data = batch_loader.load_batch(load_start, load_end) + batch_data = batch_loader.load_batch(int(load_start), int(load_end)) contribs_batch, seqs_batch, inds_batch = batch_data if sqrt_transform: contribs_batch = _signed_sqrt(contribs_batch) - - global_scale_batch = ((contribs_batch**2).sum(dim=(1,2)) / l).sqrt() - contribs_batch = torch.nan_to_num(contribs_batch / global_scale_batch[:,None,None]) - importance_scale_batch = (F.conv1d(contribs_batch**2, cwm_trim_mask) + eps)**(-0.5) + global_scale_batch = ( + (contribs_batch**2).sum(dim=(1, 2)) / sequence_length + ).sqrt() + contribs_batch = torch.nan_to_num( + contribs_batch / global_scale_batch[:, None, None] + ) + + importance_scale_batch = ( + F.conv1d(contribs_batch**2, cwm_trim_mask_tensor) + eps + ) ** (-0.5) importance_scale_batch = importance_scale_batch.clamp(max=10) - contribs_buf[converged,:,:] = contribs_batch + contribs_buf[converged, :, :] = contribs_batch if not use_hypothetical: - seqs_buf[converged,:,:] = seqs_batch + seqs_buf[converged, :, :] = seqs_batch # type: ignore + + importance_scale_buf[converged, :, :] = importance_scale_batch - importance_scale_buf[converged,:,:] = importance_scale_batch - inds_buf[converged] = inds_batch global_scale_buf[converged] = global_scale_batch - coef_inter[converged,:,:] *= 0 - coef[converged,:,:] *= 0 + coef_inter[converged, :, :] *= 0 + coef[converged, :, :] *= 0 i[converged] *= 0 step_sizes[converged] = step_size_max # Optimization step - coef_inter, coef, gap, nll = optimizer_step(cwms, contribs_buf, importance_scale_buf, seqs_buf, coef_inter, coef, - i, step_sizes, l, lambdas) + coef_inter, coef, gap, nll = optimizer_step( + cwms_tensor, + contribs_buf, + importance_scale_buf, + seqs_buf, + coef_inter, + coef, + i, + step_sizes, + sequence_length, + lambdas_tensor, + ) i += 1 # Assess convergence of each peak being optimized. 
Reset diverged peaks with lower step size. active = inds_buf >= 0 diverged = ~torch.isfinite(gap) & active - coef_inter[diverged,:,:] *= 0 - coef[diverged,:,:] *= 0 + coef_inter[diverged, :, :] *= 0 + coef[diverged, :, :] *= 0 i[diverged] *= 0 - step_sizes[diverged,:,:] *= step_adjust + step_sizes[diverged, :, :] *= step_adjust timeouts = (i > max_steps).squeeze() & active if timeouts.sum().item() > 0: timeout_inds = inds_buf[timeouts] for ind in timeout_inds: - warnings.warn(f"Region {ind} has not converged within max_steps={max_steps} iterations.", RuntimeWarning) + warnings.warn( + f"Region {ind} has not converged within max_steps={max_steps} iterations.", + RuntimeWarning, + ) fails = (step_sizes < step_size_min).squeeze() & active if fails.sum().item() > 0: fail_inds = inds_buf[fails] for ind in fail_inds: warnings.warn(f"Optimizer failed for region {ind}.", RuntimeWarning) - + converged = ((gap <= convergence_tol) | timeouts | fails) & active num_load = converged.sum().item() @@ -315,27 +740,32 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam inds_out = inds_buf[converged] global_scale_out = global_scale_buf[converged] - # Compute hit scores - coef_out = coef[converged,:,:] - importance_scale_out_dense = importance_scale_buf[converged,:,:] - importance_sq = importance_scale_out_dense**(-2) - eps + # Compute hit scores + coef_out = coef[converged, :, :] + importance_scale_out_dense = importance_scale_buf[converged, :, :] + importance_sq = importance_scale_out_dense ** (-2) - eps xcor_scale = importance_sq.sqrt() - contribs_converged = contribs_buf[converged,:,:] - importance_sum_out_dense = F.conv1d(torch.abs(contribs_converged), cwm_trim_mask) - xcov_out_dense = F.conv1d(contribs_converged, cwms) - # xcov_out_dense = F.conv1d(torch.abs(contribs_converged), cwms) + contribs_converged = contribs_buf[converged, :, :] + importance_sum_out_dense = F.conv1d( + torch.abs(contribs_converged), cwm_trim_mask_tensor + ) + 
xcov_out_dense = F.conv1d(contribs_converged, cwms_tensor) + # xcov_out_dense = F.conv1d(torch.abs(contribs_converged), cwms_tensor) xcor_out_dense = xcov_out_dense / xcor_scale if post_filter: - coef_out = coef_out * (xcor_out_dense >= lambdas) + coef_out = coef_out * (xcor_out_dense >= lambdas_tensor) - # Extract hit coordinates + # Extract hit coordinates using sparse tensor representation coef_out = coef_out.to_sparse() - hit_idxs_out = torch.clone(coef_out.indices()) - hit_idxs_out[0,:] = F.embedding(hit_idxs_out[0,:], inds_out[:,None]).squeeze() - # Map buffer index to peak index + # Tensor indexing operations for hit extraction + hit_idxs_out = torch.clone(coef_out.indices()) # Sparse tensor indices + hit_idxs_out[0, :] = F.embedding( + hit_idxs_out[0, :], inds_out[:, None] + ).squeeze() # Embedding lookup with complex indexing + # Map buffer index to peak index ind_tuple = torch.unbind(coef_out.indices()) importance_out = importance_sum_out_dense[ind_tuple] @@ -347,8 +777,8 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam # Store outputs gap_out = gap[converged] nll_out = nll[converged] - step_out = i[converged,0,0] - step_sizes_out = step_sizes[converged,0,0] + step_out = i[converged, 0, 0] + step_sizes_out = step_sizes[converged, 0, 0] hit_idxs_lst.append(hit_idxs_out.numpy(force=True).T) coefficients_lst.append(scores_out_raw.numpy(force=True)) @@ -373,19 +803,19 @@ def fit_contribs(cwms, contribs, sequences, cwm_trim_mask, use_hypothetical, lam scores_importance = np.concatenate(importance_lst, axis=0) scores_importance_sq = np.concatenate(importance_sq_lst, axis=0) - hits = { - "peak_id": hit_idxs[:,0].astype(np.uint32), - "motif_id": hit_idxs[:,1].astype(np.uint32), - "hit_start": hit_idxs[:,2], + hits: Dict[str, ndarray] = { + "peak_id": hit_idxs[:, 0].astype(np.uint32), + "motif_id": hit_idxs[:, 1].astype(np.uint32), + "hit_start": hit_idxs[:, 2], "hit_coefficient": scores_coefficient, "hit_similarity": 
scores_similarity, "hit_importance": scores_importance, "hit_importance_sq": scores_importance_sq, } - qc = {k: np.concatenate(v, axis=0) for k, v in qc_lsts.items()} + qc: Dict[str, ndarray] = {k: np.concatenate(v, axis=0) for k, v in qc_lsts.items()} hits_df = pl.DataFrame(hits) qc_df = pl.DataFrame(qc) - return hits_df, qc_df \ No newline at end of file + return hits_df, qc_df diff --git a/src/finemo/main.py b/src/finemo/main.py index 852147c..16d03af 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -1,20 +1,88 @@ +"""Main CLI module for the Fi-NeMo motif instance calling pipeline. + +This module provides the command-line interface for all Fi-NeMo operations: +- Data preprocessing from various genomic formats +- Motif hit calling using the Fi-NeMo algorithm +- Report generation and result visualization +- Post-processing operations (hit collapsing, intersection) + +The CLI supports multiple input formats including bigWig, HDF5 (ChromBPNet/BPNet), +and TF-MoDISCo format. +""" + from . import data_io import os import argparse import warnings - - -def extract_regions_bw(peaks_path, chrom_order_path, fa_path, bw_paths, out_path, region_width): +from typing import Optional, List + +import polars as pl + + +def extract_regions_bw( + peaks_path: str, + chrom_order_path: Optional[str], + fa_path: str, + bw_paths: List[str], + out_path: str, + region_width: int, +) -> None: + """Extract genomic regions and contribution scores from bigWig and FASTA files. + + Parameters + ---------- + peaks_path : str + Path to ENCODE NarrowPeak format file. + chrom_order_path : str, optional + Path to chromosome ordering file. + fa_path : str + Path to genome FASTA file. + bw_paths : List[str] + List of bigWig file paths containing contribution scores. + out_path : str + Output path for NPZ file. + region_width : int + Width of regions to extract around peak summits. + + Notes + ----- + BigWig files only provide projected contribution scores. 
+ """ half_width = region_width // 2 + # Load peak regions and extract sequences/contributions peaks_df = data_io.load_peaks(peaks_path, chrom_order_path, half_width) - sequences, contribs = data_io.load_regions_from_bw(peaks_df, fa_path, bw_paths, half_width) + sequences, contribs = data_io.load_regions_from_bw( + peaks_df, fa_path, bw_paths, half_width + ) + # Save processed data to NPZ format data_io.write_regions_npz(sequences, contribs, out_path, peaks_df=peaks_df) -def extract_regions_chrombpnet_h5(peaks_path, chrom_order_path, h5_paths, out_path, region_width): +def extract_regions_chrombpnet_h5( + peaks_path: Optional[str], + chrom_order_path: Optional[str], + h5_paths: List[str], + out_path: str, + region_width: int, +) -> None: + """Extract genomic regions and contribution scores from ChromBPNet HDF5 files. + + Parameters + ---------- + peaks_path : str, optional + Path to ENCODE NarrowPeak format file. If None, lacks absolute coordinates. + chrom_order_path : str, optional + Path to chromosome ordering file. + h5_paths : List[str] + List of ChromBPNet HDF5 file paths. + out_path : str + Output path for NPZ file. + region_width : int + Width of regions to extract around peak summits. + """ half_width = region_width // 2 if peaks_path is not None: @@ -27,7 +95,28 @@ def extract_regions_chrombpnet_h5(peaks_path, chrom_order_path, h5_paths, out_pa data_io.write_regions_npz(sequences, contribs, out_path, peaks_df=peaks_df) -def extract_regions_bpnet_h5(peaks_path, chrom_order_path, h5_paths, out_path, region_width): +def extract_regions_bpnet_h5( + peaks_path: Optional[str], + chrom_order_path: Optional[str], + h5_paths: List[str], + out_path: str, + region_width: int, +) -> None: + """Extract genomic regions and contribution scores from BPNet HDF5 files. + + Parameters + ---------- + peaks_path : str, optional + Path to ENCODE NarrowPeak format file. If None, output lacks absolute coordinates. 
+ chrom_order_path : str, optional + Path to chromosome ordering file. + h5_paths : List[str] + List of BPNet HDF5 file paths. + out_path : str + Output path for NPZ file. + region_width : int + Width of regions to extract around peak summits. + """ half_width = region_width // 2 if peaks_path is not None: @@ -40,7 +129,31 @@ def extract_regions_bpnet_h5(peaks_path, chrom_order_path, h5_paths, out_path, r data_io.write_regions_npz(sequences, contribs, out_path, peaks_df=peaks_df) -def extract_regions_modisco_fmt(peaks_path, chrom_order_path, shaps_paths, ohe_path, out_path, region_width): +def extract_regions_modisco_fmt( + peaks_path: Optional[str], + chrom_order_path: Optional[str], + shaps_paths: List[str], + ohe_path: str, + out_path: str, + region_width: int, +) -> None: + """Extract genomic regions and contribution scores from TF-MoDISCo format files. + + Parameters + ---------- + peaks_path : str, optional + Path to ENCODE NarrowPeak format file. If None, output lacks absolute coordinates. + chrom_order_path : str, optional + Path to chromosome ordering file. + shaps_paths : List[str] + List of paths to .npy/.npz files containing SHAP/attribution scores. + ohe_path : str + Path to .npy/.npz file containing one-hot encoded sequences. + out_path : str + Output path for NPZ file. + region_width : int + Width of regions to extract around peak summits. 
+ """ half_width = region_width // 2 if peaks_path is not None: @@ -48,38 +161,135 @@ def extract_regions_modisco_fmt(peaks_path, chrom_order_path, shaps_paths, ohe_p else: peaks_df = None - sequences, contribs = data_io.load_regions_from_modisco_fmt(shaps_paths, ohe_path, half_width) + sequences, contribs = data_io.load_regions_from_modisco_fmt( + shaps_paths, ohe_path, half_width + ) data_io.write_regions_npz(sequences, contribs, out_path, peaks_df=peaks_df) -def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motifs_include_path, motif_names_path, - motif_lambdas_path, out_dir, cwm_trim_coords_path, cwm_trim_thresholds_path, cwm_trim_threshold_default, - lambda_default, step_size_max, step_size_min, sqrt_transform, convergence_tol, max_steps, batch_size, - step_adjust, device, mode, no_post_filter, compile_optimizer): - +def call_hits( + regions_path: str, + peaks_path: Optional[str], + modisco_h5_path: str, + chrom_order_path: Optional[str], + motifs_include_path: Optional[str], + motif_names_path: Optional[str], + motif_lambdas_path: Optional[str], + out_dir: str, + cwm_trim_coords_path: Optional[str], + cwm_trim_thresholds_path: Optional[str], + cwm_trim_threshold_default: float, + lambda_default: float, + step_size_max: float, + step_size_min: float, + sqrt_transform: bool, + convergence_tol: float, + max_steps: int, + batch_size: int, + step_adjust: float, + device: Optional[str], + mode: str, + no_post_filter: bool, + compile_optimizer: bool, +) -> None: + """Call motif hits using the Fi-NeMo algorithm on preprocessed genomic regions. + + This function implements the core Fi-NeMo hit calling pipeline, which identifies + motif instances by solving a sparse reconstruction problem using proximal gradient + descent. The algorithm represents contribution scores as weighted combinations of + motif CWMs at specific positions. 
+ + Parameters + ---------- + regions_path : str + Path to NPZ file containing preprocessed regions (sequences, contributions, + and optional peak coordinates). + peaks_path : str, optional + DEPRECATED. Path to ENCODE NarrowPeak format file. Peak data should be + included during preprocessing instead. + modisco_h5_path : str + Path to TF-MoDISco H5 file containing motif CWMs. + chrom_order_path : str, optional + DEPRECATED. Path to chromosome ordering file. + motifs_include_path : str, optional + Path to file listing motif names to include in analysis. + motif_names_path : str, optional + Path to file mapping motif IDs to custom names. + motif_lambdas_path : str, optional + Path to file specifying per-motif lambda values. + out_dir : str + Output directory for results. + cwm_trim_coords_path : str, optional + Path to file specifying custom motif trimming coordinates. + cwm_trim_thresholds_path : str, optional + Path to file specifying custom motif trimming thresholds. + cwm_trim_threshold_default : float + Default threshold for motif trimming (typically 0.3). + lambda_default : float + Default L1 regularization weight (typically 0.7). + step_size_max : float + Maximum optimization step size. + step_size_min : float + Minimum optimization step size. + sqrt_transform : bool + Whether to apply signed square root transform to contributions. + convergence_tol : float + Convergence tolerance for duality gap. + max_steps : int + Maximum number of optimization steps. + batch_size : int + Batch size for GPU processing. + step_adjust : float + Step size adjustment factor on divergence. + device : str, optional + DEPRECATED. Use CUDA_VISIBLE_DEVICES environment variable instead. + mode : str + Contribution type mode ('pp', 'ph', 'hp', 'hh') where 'p'=projected, 'h'=hypothetical. + no_post_filter : bool + If True, skip post-hit-calling similarity filtering. + compile_optimizer : bool + Whether to JIT-compile the optimizer for speed. 
+ + Notes + ----- + The Fi-NeMo algorithm solves the optimization problem: + minimize_c: ||contribs - reconstruction(c)||²₂ + λ||c||₁ + subject to: c ≥ 0 + + where c represents motif hit coefficients and reconstruction uses convolution + with motif CWMs. + """ + params = locals() + import torch from . import hitcaller if device is not None: - warnings.warn("The `--device` flag is deprecated and will be removed in a future version. Please use the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU device.") - + warnings.warn( + "The `--device` flag is deprecated and will be removed in a future version. Please use the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU device." + ) + sequences, contribs, peaks_df, has_peaks = data_io.load_regions_npz(regions_path) region_width = sequences.shape[2] if region_width % 2 != 0: raise ValueError(f"Region width of {region_width} is not divisible by 2.") - + half_width = region_width // 2 num_regions = contribs.shape[0] if peaks_path is not None: - warnings.warn("Providing a peaks file to `call-hits` is deprecated, and this option will be removed in a future version. Peaks should instead be provided in the preprocessing step to be included in `regions.npz`.") + warnings.warn( + "Providing a peaks file to `call-hits` is deprecated, and this option will be removed in a future version. Peaks should instead be provided in the preprocessing step to be included in `regions.npz`." + ) peaks_df = data_io.load_peaks(peaks_path, chrom_order_path, half_width) has_peaks = True if not has_peaks: - warnings.warn("No peak region data provided. Output hits will lack absolute genomic coordinates.") + warnings.warn( + "No peak region data provided. Output hits will lack absolute genomic coordinates." 
+ ) if mode == "pp": motif_type = "cwm" @@ -93,6 +303,10 @@ def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motif elif mode == "hh": motif_type = "hcwm" use_hypothetical_contribs = True + else: + raise ValueError( + f"Invalid mode: {mode}. Must be one of 'pp', 'ph', 'hp', 'hh'." + ) if motifs_include_path is not None: motifs_include = data_io.load_txt(motifs_include_path) @@ -118,15 +332,42 @@ def call_hits(regions_path, peaks_path, modisco_h5_path, chrom_order_path, motif trim_thresholds = data_io.load_mapping(cwm_trim_thresholds_path, float) else: trim_thresholds = None - - motifs_df, cwms, trim_masks, motif_names = data_io.load_modisco_motifs(modisco_h5_path, trim_coords, trim_thresholds, cwm_trim_threshold_default, - motif_type, motifs_include, motif_name_map, motif_lambdas, lambda_default, True) + + motifs_df, cwms, trim_masks, _ = data_io.load_modisco_motifs( + modisco_h5_path, + trim_coords, + trim_thresholds, + cwm_trim_threshold_default, + motif_type, + motifs_include, + motif_name_map, + motif_lambdas, + lambda_default, + True, + ) num_motifs = cwms.shape[0] motif_width = cwms.shape[2] lambdas = motifs_df.get_column("lambda").to_numpy(writable=True) - hits_df, qc_df = hitcaller.fit_contribs(cwms, contribs, sequences, trim_masks, use_hypothetical_contribs, lambdas, step_size_max, step_size_min, - sqrt_transform, convergence_tol, max_steps, batch_size, step_adjust, not no_post_filter, device, compile_optimizer) + device_obj = torch.device(device) if device is not None else None + hits_df, qc_df = hitcaller.fit_contribs( + cwms, + contribs, + sequences, + trim_masks, + use_hypothetical_contribs, + lambdas, + step_size_max, + step_size_min, + sqrt_transform, + convergence_tol, + max_steps, + batch_size, + step_adjust, + not no_post_filter, + device_obj, + compile_optimizer, + ) os.makedirs(out_dir, exist_ok=True) out_path_qc = os.path.join(out_dir, "peaks_qc.tsv") @@ -150,28 +391,86 @@ def call_hits(regions_path, peaks_path, 
modisco_h5_path, chrom_order_path, motif data_io.write_params(params, out_path_params) -def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_path, motif_names_path, - out_dir, modisco_region_width, cwm_trim_threshold, compute_recall, use_seqlets): - from . import evaluation, visualization +def report( + regions_path: str, + hits_dir: str, + modisco_h5_path: Optional[str], + peaks_path: Optional[str], + motifs_include_path: Optional[str], + motif_names_path: Optional[str], + out_dir: str, + modisco_region_width: int, + cwm_trim_threshold: float, + compute_recall: bool, + use_seqlets: bool, +) -> None: + """Generate comprehensive HTML report with statistics and visualizations. + + This function creates detailed analysis reports comparing Fi-NeMo hit calling + results with TF-MoDISCo seqlets, including performance metrics, distribution + plots, and motif visualization. The report provides insights into hit calling + quality and motif discovery accuracy. + + Parameters + ---------- + regions_path : str + Path to NPZ file containing the same regions used for hit calling. + hits_dir : str + Path to directory containing Fi-NeMo hit calling outputs. + modisco_h5_path : str, optional + Path to TF-MoDISCo H5 file. If None, seqlet comparisons are skipped. + peaks_path : str, optional + DEPRECATED. Peak coordinates should be included in regions file. + motifs_include_path : str, optional + DEPRECATED. This information is inferred from hit calling outputs. + motif_names_path : str, optional + DEPRECATED. This information is inferred from hit calling outputs. + out_dir : str + Output directory for report files. + modisco_region_width : int + Width of regions used by TF-MoDISCo (needed for coordinate conversion). + cwm_trim_threshold : float + DEPRECATED. This information is inferred from hit calling outputs. + compute_recall : bool + Whether to compute recall metrics against TF-MoDISCo seqlets. 
+ use_seqlets : bool + Whether to include seqlet-based comparisons in the report. + + Notes + ----- + The generated report includes: + - Hit vs seqlet count comparisons + - Motif CWM visualizations + - Hit statistic distributions + - Co-occurrence heatmaps + - Confusion matrices for overlapping motifs + """ + from . import evaluation, visualization sequences, contribs, peaks_df, _ = data_io.load_regions_npz(regions_path) if len(contribs.shape) == 3: regions = contribs * sequences elif len(contribs.shape) == 2: - regions = contribs[:,None,:] * sequences + regions = contribs[:, None, :] * sequences + else: + raise ValueError(f"Unexpected contribs shape: {contribs.shape}") half_width = regions.shape[2] // 2 modisco_half_width = modisco_region_width // 2 if peaks_path is not None: - warnings.warn("Providing a peaks file to `report` is deprecated, and this option will be removed in a future version. Peaks should instead be provided in the preprocessing step to be included in `regions.npz`.") - peaks_df = data_io.load_peaks(peaks_path, None, half_width) + warnings.warn( + "Providing a peaks file to `report` is deprecated, and this option will be removed in a future version. Peaks should instead be provided in the preprocessing step to be included in `regions.npz`." + ) + peaks_df = data_io.load_peaks(peaks_path, None, half_width) if hits_dir.endswith(".tsv"): - warnings.warn("Passing a hits.tsv file to `finemo report` is deprecated. Please provide the directory containing the hits.tsv file instead.") + warnings.warn( + "Passing a hits.tsv file to `finemo report` is deprecated. Please provide the directory containing the hits.tsv file instead." 
+ ) hits_path = hits_dir - + hits_df = data_io.load_hits(hits_path, lazy=True) if motifs_include_path is not None: @@ -180,12 +479,29 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p motifs_include = None if motif_names_path is not None: - motif_name_map = data_io.load_txt(motif_names_path) + motif_name_map = data_io.load_mapping(motif_names_path, str) else: motif_name_map = None - motifs_df, cwms_modisco, trim_masks, motif_names = data_io.load_modisco_motifs(modisco_h5_path, None, None, cwm_trim_threshold, "cwm", - motifs_include, motif_name_map, None, None, True) + if modisco_h5_path is not None: + motifs_df, cwms_modisco, _, motif_names = data_io.load_modisco_motifs( + modisco_h5_path, + None, + None, + cwm_trim_threshold, + "cwm", + motifs_include, + motif_name_map, + None, + 1.0, + True, + ) + else: + # When no modisco_h5_path is provided in legacy TSV mode, we can't compute motifs + # This will cause an error later, but that's expected behavior + raise ValueError( + "modisco_h5_path is required when providing a hits.tsv file directly" + ) else: hits_df_path = os.path.join(hits_dir, "hits.tsv") @@ -202,64 +518,150 @@ def report(regions_path, hits_dir, modisco_h5_path, peaks_path, motifs_include_p cwm_trim_threshold = params["cwm_trim_threshold_default"] if not use_seqlets: - warnings.warn("Usage of the `--no-seqlets` flag is deprecated and will be removed in a future version. Please omit the `--modisco-h5` argument instead.") + warnings.warn( + "Usage of the `--no-seqlets` flag is deprecated and will be removed in a future version. Please omit the `--modisco-h5` argument instead." 
+ ) seqlets_df = None elif modisco_h5_path is None: compute_recall = False seqlets_df = None else: - seqlets_df = data_io.load_modisco_seqlets(modisco_h5_path, peaks_df, motifs_df, half_width, modisco_half_width, lazy=True) + seqlets_df = data_io.load_modisco_seqlets( + modisco_h5_path, + peaks_df, + motifs_df, + half_width, + modisco_half_width, + lazy=True, + ) motif_width = cwms_modisco.shape[2] - occ_df, coooc = evaluation.get_motif_occurences(hits_df, motif_names) + # Convert to LazyFrame if needed and ensure motif_names is a list + if isinstance(hits_df, pl.LazyFrame): + hits_df_lazy: pl.LazyFrame = hits_df + else: + hits_df_lazy: pl.LazyFrame = hits_df.lazy() + + motif_names_list: List[str] = list(motif_names) + + occ_df, coooc = evaluation.get_motif_occurences(hits_df_lazy, motif_names_list) + + report_data, report_df, cwms, trim_bounds = evaluation.tfmodisco_comparison( + regions, + hits_df, + peaks_df, + seqlets_df, + motifs_df, + cwms_modisco, + motif_names_list, + modisco_half_width, + motif_width, + compute_recall, + ) - report_data, report_df, cwms, trim_bounds = evaluation.tfmodisco_comparison(regions, hits_df, peaks_df, seqlets_df, motifs_df, - cwms_modisco, motif_names, modisco_half_width, - motif_width, compute_recall) - if seqlets_df is not None: - confusion_df, confusion_mat = evaluation.seqlet_confusion(hits_df, seqlets_df, peaks_df, motif_names, motif_width) - + confusion_df, confusion_mat = evaluation.seqlet_confusion( + hits_df, seqlets_df, peaks_df, motif_names_list, motif_width + ) + else: + confusion_df, confusion_mat = None, None + os.makedirs(out_dir, exist_ok=True) - + occ_path = os.path.join(out_dir, "motif_occurrences.tsv") data_io.write_occ_df(occ_df, occ_path) data_io.write_report_data(report_df, cwms, out_dir) - visualization.plot_hit_stat_distributions(hits_df, motif_names, out_dir) - visualization.plot_hit_peak_distributions(occ_df, motif_names, out_dir) - visualization.plot_peak_motif_indicator_heatmap(coooc, motif_names, 
out_dir) + visualization.plot_hit_stat_distributions(hits_df_lazy, motif_names_list, out_dir) + visualization.plot_hit_peak_distributions(occ_df, motif_names_list, out_dir) + visualization.plot_peak_motif_indicator_heatmap(coooc, motif_names_list, out_dir) plot_dir = os.path.join(out_dir, "CWMs") visualization.plot_cwms(cwms, trim_bounds, plot_dir) if seqlets_df is not None: - seqlets_df = seqlets_df.collect() + seqlets_collected = ( + seqlets_df.collect() if isinstance(seqlets_df, pl.LazyFrame) else seqlets_df + ) seqlets_path = os.path.join(out_dir, "seqlets.tsv") - data_io.write_modisco_seqlets(seqlets_df, seqlets_path) + data_io.write_modisco_seqlets(seqlets_collected, seqlets_path) - seqlet_confusion_path = os.path.join(out_dir, "seqlet_confusion.tsv") - data_io.write_seqlet_confusion_df(confusion_df, seqlet_confusion_path) + if confusion_df is not None and confusion_mat is not None: + seqlet_confusion_path = os.path.join(out_dir, "seqlet_confusion.tsv") + data_io.write_seqlet_confusion_df(confusion_df, seqlet_confusion_path) - visualization.plot_hit_vs_seqlet_counts(report_data, out_dir) - visualization.plot_seqlet_confusion_heatmap(confusion_mat, motif_names, out_dir) + visualization.plot_hit_vs_seqlet_counts(report_data, out_dir) + visualization.plot_seqlet_confusion_heatmap( + confusion_mat, motif_names_list, out_dir + ) report_path = os.path.join(out_dir, "report.html") - visualization.write_report(report_df, motif_names, report_path, compute_recall, seqlets_df is not None) - - -def collapse_hits(hits_path, out_path, overlap_frac): + visualization.write_report( + report_df, motif_names_list, report_path, compute_recall, seqlets_df is not None + ) + + +def collapse_hits(hits_path: str, out_path: str, overlap_frac: float) -> None: + """Collapse overlapping hits by selecting the best hit per overlapping group. 
+ + This function processes a set of motif hits and identifies overlapping hits, + keeping only the hit with the highest similarity score within each overlapping + group. This reduces redundancy in hit calls while preserving the most confident + predictions. + + Parameters + ---------- + hits_path : str + Path to input TSV file containing hit data (hits.tsv or hits_unique.tsv). + out_path : str + Path to output TSV file with additional 'is_primary' column. + overlap_frac : float + Minimum fractional overlap for considering hits as overlapping. + For hits of lengths x and y, minimum overlap = overlap_frac * (x + y) / 2. + + Notes + ----- + The algorithm uses a sweep line approach with a heap data structure to + efficiently identify overlapping intervals and select the best hit based + on similarity scores. + """ from . import postprocessing hits_df = data_io.load_hits(hits_path, lazy=False) hits_collapsed_df = postprocessing.collapse_hits(hits_df, overlap_frac) - data_io.write_hits_processed(hits_collapsed_df, out_path, schema=data_io.HITS_COLLAPSED_DTYPES) - - -def intersect_hits(hits_paths, out_path, relaxed): + data_io.write_hits_processed( + hits_collapsed_df, out_path, schema=data_io.HITS_COLLAPSED_DTYPES + ) + + +def intersect_hits(hits_paths: List[str], out_path: str, relaxed: bool) -> None: + """Find intersection of hits across multiple Fi-NeMo runs. + + This function identifies motif hits that are consistently called across + multiple independent runs, providing a way to assess reproducibility and + identify high-confidence hits. + + Parameters + ---------- + hits_paths : List[str] + List of paths to input TSV files from different runs. + out_path : str + Path to output TSV file containing intersection results. + Duplicate columns are suffixed with run index. + relaxed : bool + If True, uses relaxed intersection criteria based only on motif names + and untrimmed coordinates. If False, assumes consistent region definitions + and motif trimming across runs. 
+ + Notes + ----- + The strict intersection mode requires consistent input regions and motif + processing parameters across all runs. The relaxed mode is more permissive + but may not be suitable when genomic coordinates are unavailable. + """ from . import postprocessing hits_dfs = [data_io.load_hits(hits_path, lazy=False) for hits_path in hits_paths] @@ -268,260 +670,692 @@ def intersect_hits(hits_paths, out_path, relaxed): data_io.write_hits_processed(hits_df, out_path, schema=None) -def cli(): - parser = argparse.ArgumentParser() - subparsers = parser.add_subparsers(required=True, dest='cmd') - - - extract_regions_bw_parser = subparsers.add_parser("extract-regions-bw", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from FASTA and bigwig files.") - - extract_regions_bw_parser.add_argument("-p", "--peaks", type=str, required=True, - help="A peak regions file in ENCODE NarrowPeak format.") - extract_regions_bw_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - extract_regions_bw_parser.add_argument("-f", "--fasta", type=str, required=True, - help="A genome FASTA file. If an .fai index file doesn't exist in the same directory, it will be created.") - extract_regions_bw_parser.add_argument("-b", "--bigwigs", type=str, required=True, nargs='+', - help="One or more bigwig files of contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.") - - extract_regions_bw_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_regions_bw_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - extract_chrombpnet_regions_h5_parser = subparsers.add_parser("extract-regions-chrombpnet-h5", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from ChromBPNet contributions H5 files.") - - extract_chrombpnet_regions_h5_parser.add_argument("-p", "--peaks", type=str, default=None, - help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.") - extract_chrombpnet_regions_h5_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - - extract_chrombpnet_regions_h5_parser.add_argument("-c", "--h5s", type=str, required=True, nargs='+', - help="One or more H5 files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.") - - extract_chrombpnet_regions_h5_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_chrombpnet_regions_h5_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - extract_regions_h5_parser = subparsers.add_parser("extract-regions-h5", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from ChromBPNet contributions H5 files. 
DEPRECATED: Use `extract-regions-chrombpnet-h5` instead.") - - extract_regions_h5_parser.add_argument("-p", "--peaks", type=str, default=None, - help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.") - extract_regions_h5_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - - extract_regions_h5_parser.add_argument("-c", "--h5s", type=str, required=True, nargs='+', - help="One or more H5 files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.") - - extract_regions_h5_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_regions_h5_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - extract_bpnet_regions_h5_parser = subparsers.add_parser("extract-regions-bpnet-h5", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from BPNet contributions H5 files.") - - extract_bpnet_regions_h5_parser.add_argument("-p", "--peaks", type=str, default=None, - help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.") - extract_bpnet_regions_h5_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - - extract_bpnet_regions_h5_parser.add_argument("-c", "--h5s", type=str, required=True, nargs='+', - help="One or more H5 files of contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.") - - extract_bpnet_regions_h5_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_bpnet_regions_h5_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - extract_regions_modisco_fmt_parser = subparsers.add_parser("extract-regions-modisco-fmt", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Extract sequences and contributions from tfmodisco-lite input files.") - - extract_regions_modisco_fmt_parser.add_argument("-p", "--peaks", type=str, default=None, - help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.") - extract_regions_modisco_fmt_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.") - - extract_regions_modisco_fmt_parser.add_argument("-s", "--sequences", type=str, required=True, - help="A .npy or .npz file containing one-hot encoded sequences.") - - extract_regions_modisco_fmt_parser.add_argument("-a", "--attributions", type=str, required=True, nargs='+', - help="One or more .npy or .npz files of hypothetical contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.") - - extract_regions_modisco_fmt_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .npz file.") - - extract_regions_modisco_fmt_parser.add_argument("-w", "--region-width", type=int, default=1000, - help="The width of the input region centered around each peak summit.") - - - call_hits_parser = subparsers.add_parser("call-hits", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Call hits on provided sequences, contributions, and motif CWM's.") - - call_hits_parser.add_argument("-M", "--mode", type=str, default="pp", choices={"pp", "ph", "hp", "hh"}, - help="The type of attributions to use for CWM's and input contribution scores, respectively. 'h' for hypothetical and 'p' for projected.") - - call_hits_parser.add_argument("-r", "--regions", type=str, required=True, - help="A .npz file of input sequences, contributions, and coordinates. Can be generated using `finemo extract-regions-*` subcommands.") - call_hits_parser.add_argument("-m", "--modisco-h5", type=str, required=True, - help="A tfmodisco-lite output H5 file of motif patterns.") - - call_hits_parser.add_argument("-p", "--peaks", type=str, default=None, - help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.") - call_hits_parser.add_argument("-C", "--chrom-order", type=str, default=None, - help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.") - - call_hits_parser.add_argument("-I", "--motifs-include", type=str, default=None, - help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column to include in hit calling. 
If omitted, all motifs in the modisco H5 file are used.") - call_hits_parser.add_argument("-N", "--motif-names", type=str, default=None, - help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom names in the second column. Omitted motifs default to tfmodisco names.") - - call_hits_parser.add_argument("-o", "--out-dir", type=str, required=True, - help="The path to the output directory.") - - call_hits_parser.add_argument("-t", "--cwm-trim-threshold", type=float, default=0.3, - help="The default threshold to determine motif start and end positions within the full CWMs.") - call_hits_parser.add_argument("-T", "--cwm-trim-thresholds", type=str, default=None, - help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom trim thresholds in the second column. Omitted motifs default to the `--cwm-trim-threshold` value.") - call_hits_parser.add_argument("-R", "--cwm-trim-coords", type=str, default=None, - help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom trim start and end coordinates in the second and third columns, respectively. Omitted motifs default to `--cwm-trim-thresholds` values.") - - call_hits_parser.add_argument("-l", "--global-lambda", type=float, default=0.7, - help="The default L1 regularization weight determining the sparsity of hits.") - call_hits_parser.add_argument("-L", "--motif-lambdas", type=str, default=None, - help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and motif-specific lambdas in the second column. 
Omitted motifs default to the `--global-lambda` value.") - call_hits_parser.add_argument("-a", "--alpha", type=float, default=None, - help="DEPRECATED: Please use the `--lambda` argument instead.") - call_hits_parser.add_argument("-A", "--motif-alphas", type=str, default=None, - help="DEPRECATED: Please use the `--motif-lambdas` argument instead.") - - call_hits_parser.add_argument("-f", "--no-post-filter", action='store_true', - help="Do not perform post-hit-calling filtering. By default, hits are filtered based on a minimum cosine similarity of `lambda` with the input contributions.") - call_hits_parser.add_argument("-q", "--sqrt-transform", action='store_true', - help="Apply a signed square root transform to the input contributions and CWMs before hit calling.") - call_hits_parser.add_argument("-s", "--step-size-max", type=float, default=3., - help="The maximum optimizer step size.") - call_hits_parser.add_argument("-i", "--step-size-min", type=float, default=0.08, - help="The minimum optimizer step size.") - call_hits_parser.add_argument("-j", "--step-adjust", type=float, default=0.7, - help="The optimizer step size adjustment factor. If the optimizer diverges, the step size is multiplicatively adjusted by this factor") - call_hits_parser.add_argument("-c", "--convergence-tol", type=float, default=0.0005, - help="The tolerance for determining convergence. 
The optimizer exits when the duality gap is less than the tolerance.") - call_hits_parser.add_argument("-S", "--max-steps", type=int, default=10000, - help="The maximum number of optimization steps.") - call_hits_parser.add_argument("-b", "--batch-size", type=int, default=2000, - help="The batch size used for optimization.") - call_hits_parser.add_argument("-d", "--device", type=str, default=None, - help="DEPRECATED: Please use the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU device.") - call_hits_parser.add_argument("-J", "--compile", action='store_true', - help="JIT-compile the optimizer for faster performance. This may not be supported on older GPUs.") - - - report_parser = subparsers.add_parser("report", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Generate statistics and visualizations from hits and tfmodisco-lite motif data.") - - report_parser.add_argument("-r", "--regions", type=str, required=True, - help="A .npz file containing input sequences, contributions, and coordinates. Must be the same as that used for `finemo call-hits`.") - report_parser.add_argument("-H", "--hits", type=str, required=True, - help="The output directory generated by the `finemo call-hits` command on the regions specified in `--regions`.") - report_parser.add_argument("-p", "--peaks", type=str, default=None, - help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.") - report_parser.add_argument("-m", "--modisco-h5", type=str, default=None, - help="The tfmodisco-lite output H5 file of motif patterns. Must be the same as that used for hit calling unless `--no-recall` is set. 
If omitted, seqlet-derived metrics will not be computed.") - report_parser.add_argument("-I", "--motifs-include", type=str, default=None, - help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.") - report_parser.add_argument("-N", "--motif-names", type=str, default=None, - help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.") - - report_parser.add_argument("-o", "--out-dir", type=str, required=True, - help="The path to the report output directory.") - - report_parser.add_argument("-W", "--modisco-region-width", type=int, default=400, - help="The width of the region around each peak summit used by tfmodisco-lite.") - report_parser.add_argument("-t", "--cwm-trim-threshold", type=float, default=0.3, - help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.") - report_parser.add_argument("-n", "--no-recall", action='store_true', - help="Do not compute motif recall metrics.") - report_parser.add_argument("-s", "--no-seqlets", action='store_true', - help="DEPRECATED: Please omit the `--modisco-h5` argument instead.") - - - collapse_hits_parser = subparsers.add_parser("collapse-hits", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Identify best hit by motif similarity among sets of overlapping hits.") - - collapse_hits_parser.add_argument("-i", "--hits", type=str, required=True, - help="The `hits.tsv` or `hits_unique.tsv` file from `call-hits`.") - collapse_hits_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .tsv file with an additional \"is_primary\" column.") - collapse_hits_parser.add_argument("-O", "--overlap-frac", type=float, default=0.2, - help="The threshold for determining overlapping hits. For two hits with lengths x and y, the minimum overlap is defined as `overlap_frac * (x + y) / 2`. 
The default value of 0.2 means that two hits must overlap by at least 20% of their average lengths to be considered overlapping.") - - - intersect_hits_parser = subparsers.add_parser("intersect-hits", formatter_class=argparse.ArgumentDefaultsHelpFormatter, - help="Intersect hits across multiple runs.") - - intersect_hits_parser.add_argument("-i", "--hits", type=str, required=True, nargs='+', - help="One or more hits.tsv or hits_unique.tsv files, with paths delimited by whitespace.") - intersect_hits_parser.add_argument("-o", "--out-path", type=str, required=True, - help="The path to the output .tsv file. Duplicate columns are suffixed with the positional index of the input file.") - intersect_hits_parser.add_argument("-r", "--relaxed", action='store_true', - help="Use relaxed intersection criteria, using only motif names and untrimmed coordinates. By default, the intersection assumes consistent region definitions and motif trimming. This option is not recommended if genomic coordinates are unavailable.") +def cli() -> None: + """Command-line interface for the Fi-NeMo motif instance calling pipeline. + This function provides the main entry point for all Fi-NeMo operations including: + - Data preprocessing from various formats (bigWig, HDF5, TF-MoDISCo) + - Motif hit calling using the Fi-NeMo algorithm + - Report generation and visualization + - Post-processing operations (hit collapsing, intersection) + + The CLI supports comprehensive workflows for transcription factor motif + analysis from raw genomic data to publication-ready visualizations. 
+ """ + parser = argparse.ArgumentParser() + subparsers = parser.add_subparsers(required=True, dest="cmd") + + extract_regions_bw_parser = subparsers.add_parser( + "extract-regions-bw", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from FASTA and bigwig files.", + ) + + extract_regions_bw_parser.add_argument( + "-p", + "--peaks", + type=str, + required=True, + help="A peak regions file in ENCODE NarrowPeak format.", + ) + extract_regions_bw_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + extract_regions_bw_parser.add_argument( + "-f", + "--fasta", + type=str, + required=True, + help="A genome FASTA file. If an .fai index file doesn't exist in the same directory, it will be created.", + ) + extract_regions_bw_parser.add_argument( + "-b", + "--bigwigs", + type=str, + required=True, + nargs="+", + help="One or more bigwig files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.", + ) + + extract_regions_bw_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_regions_bw_parser.add_argument( + "-w", + "--region-width", + type=int, + default=1000, + help="The width of the input region centered around each peak summit.", + ) + + extract_chrombpnet_regions_h5_parser = subparsers.add_parser( + "extract-regions-chrombpnet-h5", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from ChromBPNet contributions H5 files.", + ) + + extract_chrombpnet_regions_h5_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="A peak regions file in ENCODE NarrowPeak format. 
If omitted, downstream outputs will lack absolute genomic coordinates.", + ) + extract_chrombpnet_regions_h5_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + + extract_chrombpnet_regions_h5_parser.add_argument( + "-c", + "--h5s", + type=str, + required=True, + nargs="+", + help="One or more H5 files of contribution scores, with paths delimited by whitespace. Scores are averaged across files.", + ) + + extract_chrombpnet_regions_h5_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_chrombpnet_regions_h5_parser.add_argument( + "-w", + "--region-width", + type=int, + default=1000, + help="The width of the input region centered around each peak summit.", + ) + + extract_regions_h5_parser = subparsers.add_parser( + "extract-regions-h5", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from ChromBPNet contributions H5 files. DEPRECATED: Use `extract-regions-chrombpnet-h5` instead.", + ) + + extract_regions_h5_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.", + ) + extract_regions_h5_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + + extract_regions_h5_parser.add_argument( + "-c", + "--h5s", + type=str, + required=True, + nargs="+", + help="One or more H5 files of contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.", + ) + + extract_regions_h5_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_regions_h5_parser.add_argument( + "-w", + "--region-width", + type=int, + default=1000, + help="The width of the input region centered around each peak summit.", + ) + + extract_bpnet_regions_h5_parser = subparsers.add_parser( + "extract-regions-bpnet-h5", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from BPNet contributions H5 files.", + ) + + extract_bpnet_regions_h5_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.", + ) + extract_bpnet_regions_h5_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + + extract_bpnet_regions_h5_parser.add_argument( + "-c", + "--h5s", + type=str, + required=True, + nargs="+", + help="One or more H5 files of contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.", + ) + + extract_bpnet_regions_h5_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_bpnet_regions_h5_parser.add_argument( + "-w", + "--region-width", + type=int, + default=1000, + help="The width of the input region centered around each peak summit.", + ) + + extract_regions_modisco_fmt_parser = subparsers.add_parser( + "extract-regions-modisco-fmt", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Extract sequences and contributions from tfmodisco-lite input files.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="A peak regions file in ENCODE NarrowPeak format. If omitted, downstream outputs will lack absolute genomic coordinates.", + ) + extract_regions_modisco_fmt_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="A tab-delimited file with chromosome names in the first column to define sort order of chromosomes. Missing chromosomes are ordered as they appear in -p/--peaks.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-s", + "--sequences", + type=str, + required=True, + help="A .npy or .npz file containing one-hot encoded sequences.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-a", + "--attributions", + type=str, + required=True, + nargs="+", + help="One or more .npy or .npz files of hypothetical contribution scores, with paths delimited by whitespace. 
Scores are averaged across files.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help="The path to the output .npz file.", + ) + + extract_regions_modisco_fmt_parser.add_argument( + "-w", + "--region-width", + type=int, + default=1000, + help="The width of the input region centered around each peak summit.", + ) + + call_hits_parser = subparsers.add_parser( + "call-hits", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Call hits on provided sequences, contributions, and motif CWM's.", + ) + + call_hits_parser.add_argument( + "-M", + "--mode", + type=str, + default="pp", + choices={"pp", "ph", "hp", "hh"}, + help="The type of attributions to use for CWM's and input contribution scores, respectively. 'h' for hypothetical and 'p' for projected.", + ) + + call_hits_parser.add_argument( + "-r", + "--regions", + type=str, + required=True, + help="A .npz file of input sequences, contributions, and coordinates. Can be generated using `finemo extract-regions-*` subcommands.", + ) + call_hits_parser.add_argument( + "-m", + "--modisco-h5", + type=str, + required=True, + help="A tfmodisco-lite output H5 file of motif patterns.", + ) + + call_hits_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.", + ) + call_hits_parser.add_argument( + "-C", + "--chrom-order", + type=str, + default=None, + help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.", + ) + + call_hits_parser.add_argument( + "-I", + "--motifs-include", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column to include in hit calling. 
If omitted, all motifs in the modisco H5 file are used.", + ) + call_hits_parser.add_argument( + "-N", + "--motif-names", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom names in the second column. Omitted motifs default to tfmodisco names.", + ) + + call_hits_parser.add_argument( + "-o", + "--out-dir", + type=str, + required=True, + help="The path to the output directory.", + ) + + call_hits_parser.add_argument( + "-t", + "--cwm-trim-threshold", + type=float, + default=0.3, + help="The default threshold to determine motif start and end positions within the full CWMs.", + ) + call_hits_parser.add_argument( + "-T", + "--cwm-trim-thresholds", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom trim thresholds in the second column. Omitted motifs default to the `--cwm-trim-threshold` value.", + ) + call_hits_parser.add_argument( + "-R", + "--cwm-trim-coords", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and custom trim start and end coordinates in the second and third columns, respectively. Omitted motifs default to `--cwm-trim-thresholds` values.", + ) + + call_hits_parser.add_argument( + "-l", + "--global-lambda", + type=float, + default=0.7, + help="The default L1 regularization weight determining the sparsity of hits.", + ) + call_hits_parser.add_argument( + "-L", + "--motif-lambdas", + type=str, + default=None, + help="A tab-delimited file with tfmodisco motif names (e.g pos_patterns.pattern_0) in the first column and motif-specific lambdas in the second column. 
Omitted motifs default to the `--global-lambda` value.", + ) + call_hits_parser.add_argument( + "-a", + "--alpha", + type=float, + default=None, + help="DEPRECATED: Please use the `--lambda` argument instead.", + ) + call_hits_parser.add_argument( + "-A", + "--motif-alphas", + type=str, + default=None, + help="DEPRECATED: Please use the `--motif-lambdas` argument instead.", + ) + + call_hits_parser.add_argument( + "-f", + "--no-post-filter", + action="store_true", + help="Do not perform post-hit-calling filtering. By default, hits are filtered based on a minimum cosine similarity of `lambda` with the input contributions.", + ) + call_hits_parser.add_argument( + "-q", + "--sqrt-transform", + action="store_true", + help="Apply a signed square root transform to the input contributions and CWMs before hit calling.", + ) + call_hits_parser.add_argument( + "-s", + "--step-size-max", + type=float, + default=3.0, + help="The maximum optimizer step size.", + ) + call_hits_parser.add_argument( + "-i", + "--step-size-min", + type=float, + default=0.08, + help="The minimum optimizer step size.", + ) + call_hits_parser.add_argument( + "-j", + "--step-adjust", + type=float, + default=0.7, + help="The optimizer step size adjustment factor. If the optimizer diverges, the step size is multiplicatively adjusted by this factor", + ) + call_hits_parser.add_argument( + "-c", + "--convergence-tol", + type=float, + default=0.0005, + help="The tolerance for determining convergence. 
The optimizer exits when the duality gap is less than the tolerance.", + ) + call_hits_parser.add_argument( + "-S", + "--max-steps", + type=int, + default=10000, + help="The maximum number of optimization steps.", + ) + call_hits_parser.add_argument( + "-b", + "--batch-size", + type=int, + default=2000, + help="The batch size used for optimization.", + ) + call_hits_parser.add_argument( + "-d", + "--device", + type=str, + default=None, + help="DEPRECATED: Please use the `CUDA_VISIBLE_DEVICES` environment variable to specify the GPU device.", + ) + call_hits_parser.add_argument( + "-J", + "--compile", + action="store_true", + help="JIT-compile the optimizer for faster performance. This may not be supported on older GPUs.", + ) + + report_parser = subparsers.add_parser( + "report", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Generate statistics and visualizations from hits and tfmodisco-lite motif data.", + ) + + report_parser.add_argument( + "-r", + "--regions", + type=str, + required=True, + help="A .npz file containing input sequences, contributions, and coordinates. Must be the same as that used for `finemo call-hits`.", + ) + report_parser.add_argument( + "-H", + "--hits", + type=str, + required=True, + help="The output directory generated by the `finemo call-hits` command on the regions specified in `--regions`.", + ) + report_parser.add_argument( + "-p", + "--peaks", + type=str, + default=None, + help="DEPRECATED: Please provide this file to a preprocessing `finemo extract-regions-*` subcommand instead.", + ) + report_parser.add_argument( + "-m", + "--modisco-h5", + type=str, + default=None, + help="The tfmodisco-lite output H5 file of motif patterns. Must be the same as that used for hit calling unless `--no-recall` is set. 
If omitted, seqlet-derived metrics will not be computed.", + ) + report_parser.add_argument( + "-I", + "--motifs-include", + type=str, + default=None, + help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.", + ) + report_parser.add_argument( + "-N", + "--motif-names", + type=str, + default=None, + help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.", + ) + + report_parser.add_argument( + "-o", + "--out-dir", + type=str, + required=True, + help="The path to the report output directory.", + ) + + report_parser.add_argument( + "-W", + "--modisco-region-width", + type=int, + default=400, + help="The width of the region around each peak summit used by tfmodisco-lite.", + ) + report_parser.add_argument( + "-t", + "--cwm-trim-threshold", + type=float, + default=0.3, + help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.", + ) + report_parser.add_argument( + "-n", + "--no-recall", + action="store_true", + help="Do not compute motif recall metrics.", + ) + report_parser.add_argument( + "-s", + "--no-seqlets", + action="store_true", + help="DEPRECATED: Please omit the `--modisco-h5` argument instead.", + ) + + collapse_hits_parser = subparsers.add_parser( + "collapse-hits", + formatter_class=argparse.ArgumentDefaultsHelpFormatter, + help="Identify best hit by motif similarity among sets of overlapping hits.", + ) + + collapse_hits_parser.add_argument( + "-i", + "--hits", + type=str, + required=True, + help="The `hits.tsv` or `hits_unique.tsv` file from `call-hits`.", + ) + collapse_hits_parser.add_argument( + "-o", + "--out-path", + type=str, + required=True, + help='The path to the output .tsv file with an additional "is_primary" column.', + ) + collapse_hits_parser.add_argument( + "-O", + "--overlap-frac", + type=float, + default=0.2, + help="The threshold for determining overlapping hits. 
For two hits with lengths x and y, the minimum overlap is defined as `overlap_frac * (x + y) / 2`. The default value of 0.2 means that two hits must overlap by at least 20%% of their average lengths to be considered overlapping.",
+    )
+
+    intersect_hits_parser = subparsers.add_parser(
+        "intersect-hits",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+        help="Intersect hits across multiple runs.",
+    )
+
+    intersect_hits_parser.add_argument(
+        "-i",
+        "--hits",
+        type=str,
+        required=True,
+        nargs="+",
+        help="One or more hits.tsv or hits_unique.tsv files, with paths delimited by whitespace.",
+    )
+    intersect_hits_parser.add_argument(
+        "-o",
+        "--out-path",
+        type=str,
+        required=True,
+        help="The path to the output .tsv file. Duplicate columns are suffixed with the positional index of the input file.",
+    )
+    intersect_hits_parser.add_argument(
+        "-r",
+        "--relaxed",
+        action="store_true",
+        help="Use relaxed intersection criteria, using only motif names and untrimmed coordinates. By default, the intersection assumes consistent region definitions and motif trimming. This option is not recommended if genomic coordinates are unavailable.",
+    )
 
     args = parser.parse_args()
-    
+
     if args.cmd == "extract-regions-bw":
-        extract_regions_bw(args.peaks, args.chrom_order, args.fasta, args.bigwigs, args.out_path, args.region_width)
+        extract_regions_bw(
+            args.peaks,
+            args.chrom_order,
+            args.fasta,
+            args.bigwigs,
+            args.out_path,
+            args.region_width,
+        )
     elif args.cmd == "extract-regions-chrombpnet-h5":
-        extract_regions_chrombpnet_h5(args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width)
+        extract_regions_chrombpnet_h5(
+            args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width
+        )
     elif args.cmd == "extract-regions-h5":
-        print("WARNING: The `extract-regions-h5` command is deprecated. 
Use `extract-regions-chrombpnet-h5` instead.") - extract_regions_chrombpnet_h5(args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width) + print( + "WARNING: The `extract-regions-h5` command is deprecated. Use `extract-regions-chrombpnet-h5` instead." + ) + extract_regions_chrombpnet_h5( + args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width + ) elif args.cmd == "extract-regions-bpnet-h5": - extract_regions_bpnet_h5(args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width) + extract_regions_bpnet_h5( + args.peaks, args.chrom_order, args.h5s, args.out_path, args.region_width + ) elif args.cmd == "extract-regions-modisco-fmt": - extract_regions_modisco_fmt(args.peaks, args.chrom_order, args.attributions, args.sequences, args.out_path, args.region_width) - + extract_regions_modisco_fmt( + args.peaks, + args.chrom_order, + args.attributions, + args.sequences, + args.out_path, + args.region_width, + ) + elif args.cmd == "call-hits": if args.alpha is not None: - warnings.warn("The `--alpha` flag is deprecated and will be removed in a future version. Please use the `--global-lambda` flag instead.") + warnings.warn( + "The `--alpha` flag is deprecated and will be removed in a future version. Please use the `--global-lambda` flag instead." + ) args.global_lambda = args.alpha if args.motif_alphas is not None: - warnings.warn("The `--motif-alphas` flag is deprecated and will be removed in a future version. Please use the `--motif-lambdas` flag instead.") + warnings.warn( + "The `--motif-alphas` flag is deprecated and will be removed in a future version. Please use the `--motif-lambdas` flag instead." 
+ ) args.motif_lambdas = args.motif_alphas - call_hits(args.regions, args.peaks, args.modisco_h5, args.chrom_order, args.motifs_include, args.motif_names, - args.motif_lambdas, args.out_dir, args.cwm_trim_coords, args.cwm_trim_thresholds, args.cwm_trim_threshold, - args.global_lambda, args.step_size_max, args.step_size_min, args.sqrt_transform, args.convergence_tol, - args.max_steps, args.batch_size, args.step_adjust, args.device, args.mode, args.no_post_filter, args.compile) + call_hits( + args.regions, + args.peaks, + args.modisco_h5, + args.chrom_order, + args.motifs_include, + args.motif_names, + args.motif_lambdas, + args.out_dir, + args.cwm_trim_coords, + args.cwm_trim_thresholds, + args.cwm_trim_threshold, + args.global_lambda, + args.step_size_max, + args.step_size_min, + args.sqrt_transform, + args.convergence_tol, + args.max_steps, + args.batch_size, + args.step_adjust, + args.device, + args.mode, + args.no_post_filter, + args.compile, + ) elif args.cmd == "report": if args.no_recall and not args.no_seqlets: - raise ValueError("The `--no-seqlets` flag must be set in conjunction with `--no-recall`.") - - report(args.regions, args.hits, args.modisco_h5, args.peaks, args.motifs_include, - args.motif_names, args.out_dir, args.modisco_region_width, args.cwm_trim_threshold, - not args.no_recall, not args.no_seqlets) + raise ValueError( + "The `--no-seqlets` flag must be set in conjunction with `--no-recall`." 
+ ) + + report( + args.regions, + args.hits, + args.modisco_h5, + args.peaks, + args.motifs_include, + args.motif_names, + args.out_dir, + args.modisco_region_width, + args.cwm_trim_threshold, + not args.no_recall, + not args.no_seqlets, + ) elif args.cmd == "collapse-hits": collapse_hits(args.hits, args.out_path, args.overlap_frac) elif args.cmd == "intersect-hits": intersect_hits(args.hits, args.out_path, args.relaxed) - diff --git a/src/finemo/postprocessing.py b/src/finemo/postprocessing.py index 5663a4a..55c616b 100644 --- a/src/finemo/postprocessing.py +++ b/src/finemo/postprocessing.py @@ -1,19 +1,78 @@ +"""Post-processing utilities for Fi-NeMo hit calling results. + +This module provides functions for: +- Collapsing overlapping hits based on similarity scores +- Intersecting hit sets across multiple runs +- Quality control and filtering operations + +The main operations are optimized using Numba for efficient processing +of large hit datasets. +""" + import heapq +from typing import List, Union import numpy as np +from numpy import ndarray import polars as pl from numba import njit -from numba.types import Array, uint32, int32, float32 +from numba.types import Array, uint32, int32, float32 # type: ignore[attr-defined] +from jaxtyping import Float, Int + @njit( uint32[:]( - Array(uint32, 1, 'C', readonly=True), - Array(int32, 1, 'C', readonly=True), - Array(int32, 1, 'C', readonly=True), - Array(float32, 1, 'C', readonly=True) - ), cache=True + Array(uint32, 1, "C", readonly=True), + Array(int32, 1, "C", readonly=True), + Array(int32, 1, "C", readonly=True), + Array(float32, 1, "C", readonly=True), + ), + cache=True, ) -def _collapse_hits(chrom_ids, starts, ends, similarities): +def _collapse_hits( + chrom_ids: Int[ndarray, " N"], + starts: Int[ndarray, " N"], + ends: Int[ndarray, " N"], + similarities: Float[ndarray, " N"], +) -> Int[ndarray, " N"]: + """Identify primary hits among overlapping hits using a sweep line algorithm. 
+ + This function uses a heap-based sweep line algorithm to efficiently identify + the best hit (highest similarity) among sets of overlapping hits within each + chromosome. Only one hit per overlapping group is marked as primary. + + Parameters + ---------- + chrom_ids : Int[ndarray, "N"] + Chromosome identifiers for each hit, where N is the number of hits. + Dtype should be uint32 for Numba compatibility. + starts : Int[ndarray, "N"] + Start positions of hits (adjusted for overlap computation). + Dtype should be int32 for Numba compatibility. + ends : Int[ndarray, "N"] + End positions of hits (adjusted for overlap computation). + Dtype should be int32 for Numba compatibility. + similarities : Float[ndarray, "N"] + Similarity scores used for selecting the best hit. + Dtype should be float32 for Numba compatibility. + + Returns + ------- + Int[ndarray, "N"] + Binary array where 1 indicates the hit is primary, 0 otherwise. + Returns uint32 array for consistency with input types. + + Notes + ----- + This function is JIT-compiled with Numba for performance on large datasets. + The algorithm maintains active intervals in a heap and resolves overlaps + by keeping only the hit with the highest similarity score. + + The sweep line algorithm processes hits in order and maintains a heap of + currently active intervals. When a new interval is encountered, it is + compared against all overlapping intervals in the heap, and only the + interval with the highest similarity score remains marked as primary. 
+ """ n = chrom_ids.shape[0] out = np.ones(n, dtype=np.uint32) heap = [(np.uint32(0), np.int32(0), -1) for _ in range(0)] @@ -24,78 +83,216 @@ def _collapse_hits(chrom_ids, starts, ends, similarities): end_new = ends[i] sim_new = similarities[i] + # Remove expired intervals from heap while heap and heap[0] < (chrom_new, start_new, -1): heapq.heappop(heap) + # Check overlaps with active intervals for _, _, idx in heap: cmp = sim_new > similarities[idx] out[idx] &= cmp out[i] &= not cmp + # Add current interval to heap heapq.heappush(heap, (chrom_new, end_new, i)) return out -def collapse_hits(hits_df, overlap_frac): +def collapse_hits( + hits_df: Union[pl.DataFrame, pl.LazyFrame], overlap_frac: float +) -> pl.DataFrame: + """Collapse overlapping hits by selecting the best hit per overlapping group. + + This function identifies overlapping hits and marks only the highest-similarity + hit as primary in each overlapping group. Overlap is determined by a fractional + threshold based on the average length of the two hits being compared. + + Parameters + ---------- + hits_df : Union[pl.DataFrame, pl.LazyFrame] + Hit data containing required columns: chr (or peak_id if no chr), start, end, + hit_similarity. Will be collected to DataFrame if passed as LazyFrame. + overlap_frac : float + Overlap fraction threshold for considering hits as overlapping. + For two hits with lengths x and y, minimum overlap = overlap_frac * (x + y) / 2. + Must be between 0 and 1, where 0 means any overlap and 1 means complete overlap. + + Returns + ------- + pl.DataFrame + Original hit data with an additional 'is_primary' column (1 for primary hits, 0 otherwise). + All original columns are preserved, with the new column added at the end. + + Raises + ------ + KeyError + If required columns (chr/peak_id, start, end, hit_similarity) are missing. 
+ + Notes + ----- + The algorithm transforms coordinates by scaling by 2 and adjusting by the overlap + fraction to create effective overlap regions for efficient processing. This allows + using a sweep line algorithm to identify overlaps in a single pass. + + The transformation works as follows: + - Original coordinates: [start, end] + - Length = end - start + - Adjusted start = start * 2 + length * overlap_frac + - Adjusted end = end * 2 - length * overlap_frac + + This creates regions that overlap only when the original regions have sufficient + overlap according to the specified fraction. + + Examples + -------- + >>> hits_collapsed = collapse_hits(hits_df, overlap_frac=0.2) + >>> primary_hits = hits_collapsed.filter(pl.col("is_primary") == 1) + >>> print(f"Kept {primary_hits.height}/{hits_df.height} hits as primary") + """ + # Ensure we're working with a DataFrame + if isinstance(hits_df, pl.LazyFrame): + hits_df = hits_df.collect() + chroms = hits_df["chr"].unique(maintain_order=True) if not chroms.is_empty(): - chrom_to_id = { - chrom: i for i, chrom in enumerate(chroms) - } + chrom_to_id = {chrom: i for i, chrom in enumerate(chroms)} + # Transform coordinates for overlap computation + # Scale by 2 and adjust by overlap fraction to create effective overlap regions df = hits_df.select( chrom_id=pl.col("chr").replace_strict(chrom_to_id, return_dtype=pl.UInt32), - start_trim=pl.col("start") * 2 + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), - end_trim=pl.col("end") * 2 - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), - similarity=pl.col("hit_similarity") + start_trim=pl.col("start") * 2 + + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + end_trim=pl.col("end") * 2 + - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + similarity=pl.col("hit_similarity"), ) else: + # Fall back to peak_id when chr column is not available df = hits_df.select( chrom_id=pl.col("peak_id"), - 
start_trim=pl.col("start") * 2 + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), - end_trim=pl.col("end") * 2 - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), - similarity=pl.col("hit_similarity") + start_trim=pl.col("start") * 2 + + ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + end_trim=pl.col("end") * 2 + - ((pl.col("end") - pl.col("start")) * overlap_frac).cast(pl.Int32), + similarity=pl.col("hit_similarity"), ) + # Rechunk for efficient array access df = df.rechunk() chrom_ids = df["chrom_id"].to_numpy(allow_copy=False) starts = df["start_trim"].to_numpy(allow_copy=False) ends = df["end_trim"].to_numpy(allow_copy=False) similarities = df["similarity"].to_numpy(allow_copy=False) + + # Run the collapse algorithm is_primary = _collapse_hits(chrom_ids, starts, ends, similarities) - df_out = hits_df.with_columns( - is_primary=pl.Series(is_primary, dtype=pl.UInt32) - ) + # Add primary indicator column to original DataFrame + df_out = hits_df.with_columns(is_primary=pl.Series(is_primary, dtype=pl.UInt32)) return df_out -def intersect_hits(hits_dfs, relaxed): +def intersect_hits( + hits_dfs: List[Union[pl.DataFrame, pl.LazyFrame]], relaxed: bool +) -> pl.DataFrame: + """Intersect hit datasets across multiple runs to find common hits. + + This function finds hits that appear consistently across multiple Fi-NeMo + runs, which can be useful for identifying robust motif instances that are + not sensitive to parameter variations or random initialization. + + Parameters + ---------- + hits_dfs : List[Union[pl.DataFrame, pl.LazyFrame]] + List of hit DataFrames from different Fi-NeMo runs. Each DataFrame must + contain the columns specified by the intersection criteria. LazyFrames + will be collected before processing. + relaxed : bool + If True, uses relaxed intersection criteria with only motif names and + untrimmed coordinates. If False, uses strict criteria including all + coordinate and metadata columns. 
+ + Returns + ------- + pl.DataFrame + DataFrame containing hits that appear in all input datasets. + Columns from later datasets are suffixed with their index (e.g., '_1', '_2'). + The first dataset's columns retain their original names. + + Raises + ------ + ValueError + If fewer than one hits DataFrame is provided. + KeyError + If required columns for the specified intersection criteria are missing + from any of the input DataFrames. + + Notes + ----- + Relaxed intersection is useful when comparing results across different + region definitions or motif trimming parameters, but may produce less + precise matches. Strict intersection requires identical region definitions + and is recommended for most use cases. + + The intersection columns used are: + - Relaxed: ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"] + - Strict: ["chr", "start", "end", "start_untrimmed", "end_untrimmed", + "motif_name", "strand", "peak_name", "peak_id"] + + The function performs successive inner joins starting with the first DataFrame, + so the final result contains only hits present in all input datasets. 
+ + Examples + -------- + >>> common_hits = intersect_hits([hits_df1, hits_df2], relaxed=False) + >>> print(f"Found {common_hits.height} hits common to both runs") + >>> + >>> # Compare relaxed vs strict intersection + >>> relaxed_hits = intersect_hits([hits_df1, hits_df2], relaxed=True) + >>> strict_hits = intersect_hits([hits_df1, hits_df2], relaxed=False) + >>> print(f"Relaxed: {relaxed_hits.height}, Strict: {strict_hits.height}") + """ if relaxed: - join_cols = [ - "chr", "start_untrimmed", "end_untrimmed", - "motif_name", "strand" - ] + # Relaxed criteria: only motif identity and untrimmed positions + join_cols = ["chr", "start_untrimmed", "end_untrimmed", "motif_name", "strand"] else: + # Strict criteria: all coordinate and metadata columns join_cols = [ - "chr", "start", "end", "start_untrimmed", "end_untrimmed", - "motif_name", "strand", "peak_name", "peak_id" + "chr", + "start", + "end", + "start_untrimmed", + "end_untrimmed", + "motif_name", + "strand", + "peak_name", + "peak_id", ] if len(hits_dfs) < 1: raise ValueError("At least one hits dataframe required") - hits_df = hits_dfs[0] - for i in range(1, len(hits_dfs)): + # Ensure all DataFrames are collected + collected_dfs = [] + for df in hits_dfs: + if isinstance(df, pl.LazyFrame): + collected_dfs.append(df.collect()) + else: + collected_dfs.append(df) + + # Start with first DataFrame and successively intersect with others + hits_df = collected_dfs[0] + for i in range(1, len(collected_dfs)): hits_df = hits_df.join( - hits_dfs[i], + collected_dfs[i], on=join_cols, how="inner", suffix=f"_{i}", join_nulls=True, - coalesce=True + coalesce=True, ) - return hits_df \ No newline at end of file + return hits_df diff --git a/src/finemo/templates/report.html b/src/finemo/templates/report.html index 8465ead..a04417b 100644 --- a/src/finemo/templates/report.html +++ b/src/finemo/templates/report.html @@ -178,15 +178,24 @@ -

Fi-NeMo hit calling report

+

Fi-NeMo Motif Hit Calling Report

+ +

+ This report provides a comprehensive analysis of motif instance calling results from Fi-NeMo, + a GPU-accelerated method for identifying transcription factor binding sites using neural network + contribution scores. Fi-NeMo uses a competitive optimization approach to comprehensively map motif + instances by solving a sparse linear reconstruction problem. The report compares Fi-NeMo hits + with TF-MoDISCo seqlets (when available) and provides detailed statistics on hit quality and + motif discovery performance. +

{% if not use_seqlets %}
- Seqlet comparisons are not shown because a TF-MoDISco H5 file with seqlet data is not provided. + Note: Seqlet comparisons are not shown because a TF-MoDISco H5 file with seqlet data was not provided.
{% elif not compute_recall %}
- Seqlet recall and other statistics directly comparing hits and seqlets are not computed because the -n/--no-recall argument is set. + Note: Seqlet recall and other statistics directly comparing hits and seqlets are not computed because the -n/--no-recall argument was specified.
{% endif %} @@ -198,66 +207,69 @@

TF-MoDISco seqlet comparisons

Hit vs. seqlet counts

- This figure shows the number of hits called vs. the number of TF-MoDISco seqlets identified for each motif. - The dashed line is the identity line. - When comparing a shared set of regions, the hit counts should be mostly greater than the corresponding seqlet counts, since TF-MoDISco stringently filters seqlets and usually uses a smaller input window. + This scatter plot compares the number of motif instances called by Fi-NeMo versus the number of TF-MoDISCo seqlets + identified for each motif. The dashed line represents perfect agreement (y = x). Fi-NeMo typically identifies + an order of magnitude more motif instances than TF-MoDISCo because: (1) TF-MoDISCo applies stringent filtering criteria + during seqlet identification, and (2) TF-MoDISCo often analyzes smaller genomic windows than those used for hit calling.

{% endif %} -

Hit and seqlet motif comparisons

+

Motif-specific hit and seqlet analysis

- For each motif, this table examines the consistency between hits and TF-MoDISco seqlets. + This table provides detailed statistics for each motif, comparing the consistency between Fi-NeMo hits + and TF-MoDISCo seqlets. The analysis includes hit counts, overlap statistics, and visual comparisons + of contribution weight matrices (CWMs).

- The following statistics report the number of hits, seqlets, and their relationships: + Statistical measures include:

- Note that the seqlet counts here may be lower than those shown in the tfmodisco-lite report due to double-counting in overlapping regions. - The seqlet counts shown here are unique while the counts in the tfmodisco-lite report are not de-duplicated. -

-{% if compute_recall %} -

- Note that palindromic motifs may have lower recall due to disagreements on orientation. - If seqlet recall is near zero for all motifs, the -W/--modisco-region-width argument is likely incorrect. - This value is required to infer genomic coordinates of seqlets from the tfmodisco-lite output H5. -

-{% endif %} -

- Motif CWMs (contribution weight matrices) are average contribution scores over a set of regions. The CWMs plotted here are: + Important notes:

- The plots span the full untrimmed motif, with the trimmed motif shaded. + Contribution Weight Matrix (CWM) visualizations:
+ CWMs represent average contribution scores across motif instances and show the functional importance + of each nucleotide position. The following CWMs are displayed for comparison:

+

- The hit-seqlet similarity is the cosine similarity between the additional-restricted-hits CWM and the seqlet CWM. - This statistic measures the similarity between hits that were missed by TF-MoDISco and the seqlets used to construct the motif. + All CWM plots span the full untrimmed motif width, with the core trimmed region highlighted by shading. + {% if compute_recall %} + The hit-seqlet CWM similarity quantifies the overall agreement between Fi-NeMo's discovered instances + and TF-MoDISco's original motif definitions. + {% endif %}

@@ -325,38 +337,44 @@

Hit and seqlet motif comparisons

{% if compute_recall %} -

Seqlet-hit confusion matrix

+

Motif cross-assignment analysis

+

+ This confusion matrix identifies cases where Fi-NeMo hits of one motif type spatially overlap with + TF-MoDISco seqlets of different motif types. Such cross-assignments can reveal related motif families, + algorithm differences, or cases where similar-looking motifs compete for the same binding sites. +

- This heatmap shows the prevalence of motifs whose (untrimmed) hits overlap with TF-MoDISco seqlets of other motifs. - The vertical axis shows the motif of the seqlet, while the horizontal axis shows the motif of the hit. - The color intensity here represents an estimator of the expected number of bases of hit overlap per base of seqlet. + The y-axis represents seqlet motif identity, the x-axis represents hit motif identity, and color intensity + indicates the estimated overlap frequency per base of seqlet sequence. High off-diagonal values suggest + potential motif ambiguity and/or algorithmic disagreements at groups of putative TF binding sites.

{% endif %} -

Hit statistic distributions

+

Hit Quality and Distribution Analysis

- The following figures visualize the distribution of hit statistics across motifs and regions. + These visualizations examine the quality and distribution of Fi-NeMo hits across genomic regions and motifs, + measuring algorithm performance and signal strength.

-

Overall distribution of hit counts per region

+

Genome-wide hit density

- This plot shows the distribution of hit counts per region for any motif. - The number of regions with no hits should be near zero. + This histogram shows the distribution of total hit counts per genomic region (across all motifs). + A good distribution should show nearly all regions containing at least one hit.

-

Per-motif distributions of hit statistics

+

Motif-specific hit quality metrics

- These plots show the distribution of hit statistics for each motif, specifically: -

+ These distribution plots characterize the quality and prevalence of hits for individual motifs:

+
@@ -380,11 +398,16 @@

Per-motif distributions of hit statistics

-

Motif co-occurrence

+

Motif co-occurrence analysis

+

+ This correlation heatmap reveals which motifs tend to occur together in the same genomic regions, + potentially indicating cooperative transcription factor binding or shared regulatory mechanisms. + Color intensity represents cosine similarity between motif occurrence patterns, where occurrence + is defined as the presence of at least one hit for each motif within individual regions. +

- This heatmap shows the co-occurrence of motifs across regions. - The color intensity here represents the cosine similarity between the motifs' occurrence across regions, - where occurence is defined as the presence of a hit for a motif in a region. + High positive correlations (dark colors) suggest motifs that frequently co-occur. + Low correlations suggest independent or mutually exclusive binding patterns.

diff --git a/src/finemo/visualization.py b/src/finemo/visualization.py index fd348a1..6346857 100644 --- a/src/finemo/visualization.py +++ b/src/finemo/visualization.py @@ -1,18 +1,57 @@ +"""Visualization module for generating plots and reports for Fi-NeMo results. + +This module provides functions for: +- Plotting motif contribution weight matrices (CWMs) as sequence logos +- Generating distribution plots for hit statistics +- Creating co-occurrence heatmaps +- Producing HTML reports with interactive visualizations +- Plotting confusion matrices and performance metrics +""" + import os -import importlib +import importlib.resources +from typing import List, Optional, Dict, Any, Tuple, Union, Mapping, Iterable import numpy as np +from numpy import ndarray import matplotlib.pyplot as plt +from matplotlib.axes import Axes from matplotlib.patheffects import AbstractPathEffect from matplotlib.textpath import TextPath from matplotlib.transforms import Affine2D from matplotlib.font_manager import FontProperties from jinja2 import Template +import polars as pl +from jaxtyping import Float, Int from . import templates -def abbreviate_motif_name(name): +def abbreviate_motif_name(name: str) -> str: + """Convert TF-MoDISCo motif names to abbreviated format. + + Converts full TF-MoDISCo pattern names to shorter, more readable format + for display in plots and reports. + + Parameters + ---------- + name : str + Full motif name (e.g., 'pos_patterns.pattern_0'). + + Returns + ------- + str + Abbreviated name (e.g., '+/0') or original name if parsing fails. 
+ + Examples + -------- + >>> abbreviate_motif_name('pos_patterns.pattern_0') + '+/0' + >>> abbreviate_motif_name('neg_patterns.pattern_1') + '-/1' + >>> abbreviate_motif_name('invalid_name') + 'invalid_name' + """ try: group, motif = name.split(".") if group == "pos_patterns": @@ -23,14 +62,44 @@ def abbreviate_motif_name(name): raise Exception motif_num = motif.split("_")[1] return f"{group_short}/{motif_num}" - except: + except Exception: return name -def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): - hits_df = hits_df.collect() - hits_by_motif = hits_df.partition_by("motif_name", as_dict=True) - dummy_df = hits_df.clear() +def plot_hit_stat_distributions( + hits_df: pl.LazyFrame, motif_names: List[str], plot_dir: str +) -> None: + """Plot distributions of hit statistics for each motif. + + Creates separate histogram plots for coefficient, similarity, and importance + score distributions for each motif. Saves plots in both PNG (high-res) and + SVG (vector) formats. + + Parameters + ---------- + hits_df : pl.LazyFrame + Lazy DataFrame containing hit data with required columns: + - motif_name : str, name of the motif + - hit_coefficient_global : float, global coefficient values + - hit_similarity : float, similarity scores to motif CWM + - hit_importance : float, importance scores from attribution + motif_names : List[str] + List of motif names to generate plots for. Motifs not present + in hits_df will result in empty histograms. + plot_dir : str + Directory path where plots will be saved. Creates subdirectory + 'motif_stat_distributions' if it doesn't exist. 
+ + Notes + ----- + For each motif, creates three separate plots: + - {motif_name}_coefficients.{png,svg} : coefficient distribution + - {motif_name}_similarities.{png,svg} : similarity distribution + - {motif_name}_importances.{png,svg} : importance distribution + """ + hits_df_collected = hits_df.collect() + hits_by_motif = hits_df_collected.partition_by("motif_name", as_dict=True) + dummy_df = hits_df_collected.clear() motifs_dir = os.path.join(plot_dir, "motif_stat_distributions") os.makedirs(motifs_dir, exist_ok=True) @@ -42,6 +111,7 @@ def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): fig, ax = plt.subplots(figsize=(5, 2)) + # Plot coefficient distribution ax.hist(coefficients, bins=50, density=True) output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png") @@ -52,6 +122,7 @@ def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): fig, ax = plt.subplots(figsize=(5, 2)) + # Plot similarity distribution ax.hist(similarities, bins=50, density=True) output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png") @@ -62,6 +133,7 @@ def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): fig, ax = plt.subplots(figsize=(5, 2)) + # Plot importance distribution ax.hist(importances, bins=50, density=True) output_path_png = os.path.join(motifs_dir, f"{m}_importances.png") @@ -71,7 +143,34 @@ def plot_hit_stat_distributions(hits_df, motif_names, plot_dir): plt.close(fig) -def plot_hit_peak_distributions(occ_df, motif_names, plot_dir): +def plot_hit_peak_distributions( + occ_df: pl.DataFrame, motif_names: List[str], plot_dir: str +) -> None: + """Plot distribution of hits per peak for each motif. + + Creates bar plots showing the frequency distribution of hit counts per peak + for each motif, plus an overall distribution of total hits per peak. + + Parameters + ---------- + occ_df : pl.DataFrame + DataFrame containing motif occurrence counts per peak. 
Expected to have: + - One column per motif name with integer hit counts + - 'total' column with sum of all motif hits per peak + - Each row represents a peak/genomic region + motif_names : List[str] + List of motif names corresponding to columns in occ_df. + plot_dir : str + Directory to save plots. Creates 'motif_hit_distributions' subdirectory. + + Notes + ----- + Generates the following plots: + - Individual motif hit distributions: {motif_name}.{png,svg} + - Overall hit distribution: total_hit_distribution.{png,svg} + + Bar plots show frequency (proportion) on y-axis and hit count on x-axis. + """ motifs_dir = os.path.join(plot_dir, "motif_hit_distributions") os.makedirs(motifs_dir, exist_ok=True) @@ -92,7 +191,7 @@ def plot_hit_peak_distributions(occ_df, motif_names, plot_dir): plt.savefig(output_path_svg) plt.close(fig) - + fig, ax = plt.subplots(figsize=(8, 4)) unique, counts = np.unique(occ_df.get_column("total"), return_counts=True) @@ -114,16 +213,41 @@ def plot_hit_peak_distributions(occ_df, motif_names, plot_dir): plt.close(fig) -def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): - """ - Plots a simple indicator heatmap of the motifs in each peak. +def plot_peak_motif_indicator_heatmap( + peak_hit_counts: Int[ndarray, "M M"], motif_names: List[str], output_dir: str +) -> None: + """Plot co-occurrence heatmap showing motif associations across peaks. + + Creates a normalized correlation heatmap showing how frequently pairs of + motifs co-occur within the same genomic peaks. Values are normalized by + the geometric mean of individual motif frequencies. + + Parameters + ---------- + peak_hit_counts : Int[ndarray, "M M"] + Co-occurrence matrix where M = len(motif_names). + Entry (i,j) represents the number of peaks containing both motif i and j. + Diagonal entries represent total peaks containing each individual motif. + motif_names : List[str] + List of motif names for axis labels. Order must match matrix dimensions. 
+ output_dir : str + Directory path where the heatmap plots will be saved. + + Notes + ----- + Saves plots as: + - motif_cooocurrence.png : High-resolution raster format + - motif_cooocurrence.svg : Vector format + + The heatmap uses correlation normalization: matrix[i,j] / sqrt(matrix[i,i] * matrix[j,j]) + Colors use the 'Greens' colormap with values typically in [0, 1] range. """ cov_norm = 1 / np.sqrt(np.diag(peak_hit_counts)) matrix = peak_hit_counts * cov_norm[:, None] * cov_norm[None, :] motif_keys = [abbreviate_motif_name(m) for m in motif_names] - fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') - + fig, ax = plt.subplots(figsize=(8, 8), layout="constrained") + # Plot the heatmap cax = ax.imshow(matrix, interpolation="nearest", aspect="equal", cmap="Greens") @@ -135,26 +259,55 @@ def plot_peak_motif_indicator_heatmap(peak_hit_counts, motif_names, output_dir): ax.set_xlabel("Motif i") ax.set_ylabel("Motif j") - ax.tick_params(axis='both', labelsize=8) + ax.tick_params(axis="both", labelsize=8) cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) - cbar.ax.tick_params(labelsize=8) - + cbar.ax.tick_params(labelsize=8) + output_path_png = os.path.join(output_dir, "motif_cooocurrence.png") plt.savefig(output_path_png, dpi=300) output_path_svg = os.path.join(output_dir, "motif_cooocurrence.svg") - plt.savefig(output_path_svg) + plt.savefig(output_path_svg, dpi=300) plt.close() -def plot_seqlet_confusion_heatmap(seqlet_confusion, motif_names, output_dir): +def plot_seqlet_confusion_heatmap( + seqlet_confusion: Int[ndarray, "M M"], motif_names: List[str], output_dir: str +) -> None: + """Plot confusion matrix heatmap comparing seqlets to hit calls. + + Creates a heatmap showing the overlap between TF-MoDISCo seqlets and + Fi-NeMo hit calls. Rows represent seqlet motifs, columns represent hit motifs. + + Parameters + ---------- + seqlet_confusion : Int[ndarray, "M M"] + Confusion matrix where M = len(motif_names). 
+ Entry (i,j) represents the number of seqlets of motif i that overlap + with hits called for motif j. + motif_names : List[str] + List of motif names for axis labels. Order must match matrix dimensions. + output_dir : str + Directory path where the confusion matrix plots will be saved. + + Notes + ----- + Saves plots as: + - seqlet_confusion.png : High-resolution raster format + - seqlet_confusion.svg : Vector format + + The heatmap uses 'Blues' colormap. Perfect agreement would show a diagonal + pattern with high values along the diagonal and low off-diagonal values. + """ motif_keys = [abbreviate_motif_name(m) for m in motif_names] - fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') - + fig, ax = plt.subplots(figsize=(8, 8), layout="constrained") + # Plot the heatmap - cax = ax.imshow(seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues") + cax = ax.imshow( + seqlet_confusion, interpolation="nearest", aspect="equal", cmap="Blues" + ) # Set axes on heatmap ax.set_yticks(np.arange(len(motif_keys))) @@ -164,23 +317,47 @@ def plot_seqlet_confusion_heatmap(seqlet_confusion, motif_names, output_dir): ax.set_xlabel("Hit motif") ax.set_ylabel("Seqlet motif") - ax.tick_params(axis='both', labelsize=8) + ax.tick_params(axis="both", labelsize=8) cbar = fig.colorbar(cax, ax=ax, orientation="vertical", shrink=0.6, aspect=30) - cbar.ax.tick_params(labelsize=8) + cbar.ax.tick_params(labelsize=8) output_path_png = os.path.join(output_dir, "seqlet_confusion.png") plt.savefig(output_path_png, dpi=300) output_path_svg = os.path.join(output_dir, "seqlet_confusion.svg") - plt.savefig(output_path_svg) + plt.savefig(output_path_svg, dpi=300) plt.close() class LogoGlyph(AbstractPathEffect): - def __init__(self, glyph, ref_glyph='E', font_props=None, - offset=(0., 0.), **kwargs): + """Path effect for creating sequence logo glyphs with normalized dimensions. 
+ + This class creates properly scaled and positioned text glyphs for sequence + logos by normalizing character dimensions and applying appropriate transforms. + + Parameters + ---------- + glyph : str + Single character to render (e.g., 'A', 'C', 'G', 'T'). + ref_glyph : str, default 'E' + Reference character used for width normalization. + font_props : FontProperties, optional + Font properties for the glyph rendering. + offset : Tuple[float, float], default (0., 0.) + Offset for glyph positioning. + **kwargs + Additional graphics collection parameters. + """ + def __init__( + self, + glyph: str, + ref_glyph: str = "E", + font_props: Optional[FontProperties] = None, + offset: Tuple[float, float] = (0.0, 0.0), + **kwargs, + ) -> None: super().__init__(offset) path_orig = TextPath((0, 0), glyph, size=1, prop=font_props) @@ -205,16 +382,81 @@ def __init__(self, glyph, ref_glyph='E', font_props=None, #: The dictionary of keywords to update the graphics collection with. self._gc = kwargs - def draw_path(self, renderer, gc, tpath, affine, rgbFace): + def draw_path(self, renderer, gc, tpath, affine, rgbFace) -> Any: # type: ignore[override] + """Draw the glyph path using the renderer. + + Parameters + ---------- + renderer : matplotlib renderer + The renderer to draw with. + gc : GraphicsContext + Graphics context for drawing properties. + tpath : Path + Original text path (unused, using self.path instead). + affine : Transform + Affine transformation to apply. + rgbFace : color + Face color for the glyph. + + Returns + ------- + Any + Result from renderer.draw_path. 
+ """ return renderer.draw_path(gc, self.path, affine, rgbFace) -def plot_logo(ax, heights, glyphs, colors=None, font_props=None, shade_bounds=None): +def plot_logo( + ax: Axes, + heights: Float[ndarray, "B W"], + glyphs: Iterable[str], + colors: Optional[Mapping[str, Optional[str]]] = None, + font_props: Optional[FontProperties] = None, + shade_bounds: Optional[Tuple[int, int]] = None, +) -> None: + """Plot sequence logo from contribution weight matrix. + + Creates a sequence logo visualization where letter heights represent + the contribution or information content at each position. Supports + both positive and negative contributions with proper stacking. + + Parameters + ---------- + ax : Axes + Matplotlib axes object to plot on. + heights : Float[ndarray, "B W"] + Height matrix where B = len(glyphs) and W = motif width. + Entry (i,j) represents the height/contribution of base i at position j. + Can contain both positive and negative values. + glyphs : Iterable[str] + Sequence of base characters corresponding to rows in heights matrix. + Typically ['A', 'C', 'G', 'T'] for DNA. + colors : Dict[str, str], optional + Color mapping for each base. Keys should match glyphs. + If None, all bases will use default matplotlib colors. + font_props : FontProperties, optional + Font properties for letter rendering. If None, uses default font. + shade_bounds : Tuple[int, int], optional + (start, end) position indices to shade in background. + Useful for highlighting core motif regions. 
+ + Notes + ----- + Positive and negative contributions are handled separately: + - Positive values are stacked above zero line in order of descending absolute value + - Negative values are stacked below zero line in order of descending absolute value + - A horizontal line is drawn at y=0 for reference + + The resulting plot has: + - X-axis: Position in motif (0-indexed) + - Y-axis: Contribution magnitude + - Bar width: 0.95 (small gaps between positions) + """ if colors is None: colors = {g: None for g in glyphs} ax.margins(x=0, y=0) - + pos_values = np.clip(heights, 0, None) neg_values = np.clip(heights, None, 0) pos_order = np.argsort(pos_values, axis=0) @@ -222,47 +464,106 @@ def plot_logo(ax, heights, glyphs, colors=None, font_props=None, shade_bounds=No pos_reorder = np.argsort(pos_order, axis=0) neg_reorder = np.argsort(neg_order, axis=0) pos_offsets = np.take_along_axis( - np.cumsum( - np.take_along_axis(pos_values, pos_order, axis=0), axis=0 - ), pos_reorder, axis=0 + np.cumsum(np.take_along_axis(pos_values, pos_order, axis=0), axis=0), + pos_reorder, + axis=0, ) neg_offsets = np.take_along_axis( - np.cumsum( - np.take_along_axis(neg_values, neg_order, axis=0), axis=0 - ), neg_reorder, axis=0 + np.cumsum(np.take_along_axis(neg_values, neg_order, axis=0), axis=0), + neg_reorder, + axis=0, ) bottoms = pos_offsets + neg_offsets - heights x = np.arange(heights.shape[1]) for glyph, height, bottom in zip(glyphs, heights, bottoms): - ax.bar(x, height, 0.95, bottom=bottom, - path_effects=[LogoGlyph(glyph, font_props=font_props)], color=colors[glyph]) + ax.bar( + x, + height, + 0.95, + bottom=bottom, + path_effects=[LogoGlyph(glyph, font_props=font_props)], + color=colors[glyph], + ) if shade_bounds is not None: start, end = shade_bounds - ax.axvspan(start - 0.5, end - 0.5, color='0.9', zorder=-1) + ax.axvspan(start - 0.5, end - 0.5, color="0.9", zorder=-1) - ax.axhline(zorder=-1, linewidth=0.5, color='black') + ax.axhline(zorder=-1, linewidth=0.5, color="black") 
-LOGO_ALPHABET = 'ACGT' -LOGO_COLORS = {"A": '#109648', "C": '#255C99', "G": '#F7B32B', "T": '#D62839'} +LOGO_ALPHABET = "ACGT" +LOGO_COLORS = {"A": "#109648", "C": "#255C99", "G": "#F7B32B", "T": "#D62839"} LOGO_FONT = FontProperties(weight="bold") -def plot_cwms(cwms, trim_bounds, out_dir, alphabet=LOGO_ALPHABET, colors=LOGO_COLORS, font=LOGO_FONT): +def plot_cwms( + cwms: Dict[str, Dict[str, Float[ndarray, "4 W"]]], + trim_bounds: Dict[str, Dict[str, Tuple[int, int]]], + out_dir: str, + alphabet: str = LOGO_ALPHABET, + colors: Dict[str, str] = LOGO_COLORS, + font: FontProperties = LOGO_FONT, +) -> None: + """Plot contribution weight matrices as sequence logos. + + Creates sequence logo plots for all motifs and CWM types, with optional + shading to highlight trimmed regions. Saves plots in both PNG and SVG formats. + + Parameters + ---------- + cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]] + Nested dictionary structure: {motif_name: {cwm_type: cwm_array}}. + Each cwm_array has shape (4, W) where W is motif width. + Rows correspond to bases in alphabet order. + trim_bounds : Dict[str, Dict[str, Tuple[int, int]]] + Nested dictionary: {motif_name: {cwm_type: (start, end)}}. + Defines regions to shade in the sequence logos. + out_dir : str + Output directory where motif subdirectories will be created. + alphabet : str, default LOGO_ALPHABET + DNA alphabet string, typically 'ACGT'. + colors : Dict[str, str], default LOGO_COLORS + Color mapping for DNA bases. Keys should match alphabet characters. + font : FontProperties, default LOGO_FONT + Font properties for sequence logo rendering. + + Notes + ----- + Directory structure created: + ``` + out_dir/ + ├── motif1/ + │ ├── cwm_type1.png + │ ├── cwm_type1.svg + │ └── ... + └── motif2/ + └── ... + ``` + + Each plot is 10x2 inches with trimmed regions shaded if specified. + Spines (plot borders) are hidden for cleaner appearance. 
+ """ for m, v in cwms.items(): motif_dir = os.path.join(out_dir, m) os.makedirs(motif_dir, exist_ok=True) for cwm_type, cwm in v.items(): fig, ax = plt.subplots(figsize=(10, 2)) - plot_logo(ax, cwm, alphabet, colors=colors, font_props=font, shade_bounds=trim_bounds[m][cwm_type]) + plot_logo( + ax, + cwm, + alphabet, + colors=colors, + font_props=font, + shade_bounds=trim_bounds[m][cwm_type], + ) for name, spine in ax.spines.items(): spine.set_visible(False) - + output_path_png = os.path.join(motif_dir, f"{cwm_type}.png") plt.savefig(output_path_png, dpi=100) output_path_svg = os.path.join(motif_dir, f"{cwm_type}.svg") @@ -271,7 +572,36 @@ def plot_cwms(cwms, trim_bounds, out_dir, alphabet=LOGO_ALPHABET, colors=LOGO_CO plt.close(fig) -def plot_hit_vs_seqlet_counts(recall_data, output_dir): +def plot_hit_vs_seqlet_counts( + recall_data: Dict[str, Dict[str, Union[int, float]]], output_dir: str +) -> None: + """Plot scatter plot comparing hit counts to seqlet counts per motif. + + Creates a log-log scatter plot showing the relationship between the number + of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISCo + for each motif. Includes diagonal reference line and motif annotations. + + Parameters + ---------- + recall_data : Dict[str, Dict[str, Union[int, float]]] + Dictionary with motif names as keys and metrics dictionaries as values. + Each metrics dictionary must contain: + - 'num_hits_total' : int, total number of hits for the motif + - 'num_seqlets' : int, total number of seqlets for the motif + output_dir : str + Directory path where the scatter plot will be saved. 
+ + Notes + ----- + Saves plots as: + - hit_vs_seqlet_counts.png : High-resolution raster format + - hit_vs_seqlet_counts.svg : Vector format + + Plot features: + - Log-log scale on both axes + - Diagonal reference line (y = x) as dashed line + - Points annotated with abbreviated motif names + """ x = [] y = [] m = [] @@ -282,15 +612,15 @@ def plot_hit_vs_seqlet_counts(recall_data, output_dir): lim = max(np.amax(x), np.amax(y)) - fig, ax = plt.subplots(figsize=(8, 8), layout='constrained') + fig, ax = plt.subplots(figsize=(8, 8), layout="constrained") ax.axline((0, 0), (lim, lim), color="0.3", linewidth=0.7, linestyle=(0, (5, 5))) ax.scatter(x, y, s=5) for i, txt in enumerate(m): short = abbreviate_motif_name(txt) ax.annotate(short, (x[i], y[i]), fontsize=8, weight="bold") - ax.set_yscale('log') - ax.set_xscale('log') + ax.set_yscale("log") + ax.set_xscale("log") ax.set_xlabel("Hits per motif") ax.set_ylabel("Seqlets per motif") @@ -303,11 +633,59 @@ def plot_hit_vs_seqlet_counts(recall_data, output_dir): plt.close() -def write_report(report_df, motif_names, out_path, compute_recall, use_seqlets): - template_str = importlib.resources.files(templates).joinpath('report.html').read_text() +def write_report( + report_df: pl.DataFrame, + motif_names: List[str], + out_path: str, + compute_recall: bool, + use_seqlets: bool, +) -> None: + """Generate and write HTML report from motif analysis results. + + Creates a comprehensive HTML report with tables and visualizations + summarizing the Fi-NeMo motif discovery and hit calling results. + + Parameters + ---------- + report_df : pl.DataFrame + DataFrame containing motif statistics and performance metrics. + Expected columns depend on compute_recall and use_seqlets flags. + motif_names : List[str] + List of motif names to include in the report. + Order determines presentation sequence in the report. + out_path : str + File path where the HTML report will be written. + Parent directory must exist. 
+ compute_recall : bool + Whether recall metrics were computed and should be included + in the report template. + use_seqlets : bool + Whether TF-MoDISCo seqlet data was used in the analysis + and should be referenced in the report. + + Notes + ----- + Uses Jinja2 templating with the report.html template from the + templates package. The template receives: + - report_data: Iterator of DataFrame rows as named tuples + - motif_names: List of motif names + - compute_recall: Boolean flag for recall metrics + - use_seqlets: Boolean flag for seqlet usage + + Raises + ------ + OSError + If the output path cannot be written. + """ + template_str = ( + importlib.resources.files(templates).joinpath("report.html").read_text() + ) template = Template(template_str) - report = template.render(report_data=report_df.iter_rows(named=True), - motif_names=motif_names, compute_recall=compute_recall, - use_seqlets=use_seqlets) + report = template.render( + report_data=report_df.iter_rows(named=True), + motif_names=motif_names, + compute_recall=compute_recall, + use_seqlets=use_seqlets, + ) with open(out_path, "w") as f: - f.write(report) \ No newline at end of file + f.write(report) From e9cd90f4eaa81edb815282e4085a13a00a21f2b0 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 31 Aug 2025 13:51:08 -0700 Subject: [PATCH 15/39] Methods diagram --- README.md | 8 ++++---- assets/methods.svg | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) create mode 100644 assets/methods.svg diff --git a/README.md b/README.md index 9231afd..c2c9bbc 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ ## Overview -Fi-NeMo implements a competitive optimization approach using proximal gradient descent to identify motif instances by solving a sparse linear reconstruction problem. 
Unlike traditional sequence-based methods, Fi-NeMo leverages context-aware importance scores from deep neural networks to accurately map transcription factor binding sites, enabling the discovery of both high-confidence canonical motifs and low-prevalence cofactor motifs that are often missed by conventional approaches. +Fi-NeMo implements a competitive optimization approach using proximal gradient descent to identify motif instances by solving a sparse linear reconstruction problem. Unlike traditional sequence-based methods, Fi-NeMo leverages context-aware importance scores from deep neural networks to comprehensively map transcription factor binding sites, enabling the identification of both high-confidence canonical motifs and low-prevalence cofactor motifs that are often missed by conventional approaches. The algorithm represents contribution scores as weighted combinations of motif contribution weight matrices (CWMs) at specific genomic positions. This competitive assignment process more closely reflects the biological reality of transcription factors competing for binding sites, resulting in superior sensitivity and specificity compared to sequence-only methods. @@ -18,9 +18,9 @@ The algorithm represents contribution scores as weighted combinations of motif c ## Method -Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs. The algorithm formulates this as an L1-regularized linear model. +Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs, formulated as an L1-regularized linear model. This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. 
-This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. +![Methods diagram](/assets/methods.svg | width=100) ## References @@ -92,7 +92,7 @@ Recommended: Fi-NeMo provides a command-line utility named `finemo` for motif instance calling and analysis. The typical workflow involves three main steps: -1. **Preprocessing**: Transform input contributions and sequences into a compressed format +1. **Preprocessing**: Transform input contributions and sequences into a unified format 2. **Hit Calling**: Identify motif instances using the Fi-NeMo algorithm 3. **Reporting and Analysis**: Generate visualizations and perform post-processing diff --git a/assets/methods.svg b/assets/methods.svg new file mode 100644 index 0000000..90ff4ea --- /dev/null +++ b/assets/methods.svg @@ -0,0 +1 @@ +Minimize DifferenceConvolveBackpropagateObserved scoresReconstructed scoresMotif CWMsHit coefficients (learned) \ No newline at end of file From 747908c281b40052cc2191e9bd3aaaaa53aa42bb Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 31 Aug 2025 13:52:32 -0700 Subject: [PATCH 16/39] Figure width --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index c2c9bbc..17eea6c 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ The algorithm represents contribution scores as weighted combinations of motif c Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs, formulated as an L1-regularized linear model. This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. 
-![Methods diagram](/assets/methods.svg | width=100) +![Methods diagram](/assets/methods.svg) ## References From 92a67bf2ef86326372ca868cbb94f4dfa5a17e71 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 31 Aug 2025 13:53:40 -0700 Subject: [PATCH 17/39] Figure width --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 17eea6c..7ce7d1a 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ The algorithm represents contribution scores as weighted combinations of motif c Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs, formulated as an L1-regularized linear model. This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. -![Methods diagram](/assets/methods.svg) + ## References From 76cb5ab0c5d046adc35e4f1b67c32a7ea730e8f2 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 31 Aug 2025 13:54:11 -0700 Subject: [PATCH 18/39] Figure width --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 7ce7d1a..81812ae 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ The algorithm represents contribution scores as weighted combinations of motif c Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs, formulated as an L1-regularized linear model. This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. 
- + ## References From b9a17c4fad3f47db08ffa4c472c5dfe2976fe432 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 31 Aug 2025 13:54:34 -0700 Subject: [PATCH 19/39] Figure width --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 81812ae..a127c91 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ The algorithm represents contribution scores as weighted combinations of motif c Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs, formulated as an L1-regularized linear model. This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. - + ## References From 516cfc7520b970db289cfc9a09b77fe3d1104777 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 31 Aug 2025 13:55:18 -0700 Subject: [PATCH 20/39] Figure width --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index a127c91..ddf9cfc 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,7 @@ The algorithm represents contribution scores as weighted combinations of motif c Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs, formulated as an L1-regularized linear model. This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. 
- +Methods diagram ## References From 9f6034167dec133de13f25dc70273345466fc87d Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 31 Aug 2025 13:55:51 -0700 Subject: [PATCH 21/39] Figure width --- README.md | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index ddf9cfc..196be85 100644 --- a/README.md +++ b/README.md @@ -20,7 +20,9 @@ The algorithm represents contribution scores as weighted combinations of motif c Fi-NeMo solves motif instance calling as an optimization problem that reconstructs contribution score tracks as sparse linear combinations of motif CWMs, formulated as an L1-regularized linear model. This competitive assignment encourages overlapping motif instances to be resolved in a meaningful way, with stronger matches receiving higher coefficients while weaker or redundant matches are suppressed. -Methods diagram +
+ +
## References From fb58facd1880b1f5eccf4178ace437e8c6777333 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 12:38:02 -0700 Subject: [PATCH 22/39] TF-MoDISco capitalization --- README.md | 12 ++++++------ src/finemo/__init__.py | 4 ++-- src/finemo/data_io.py | 24 ++++++++++++------------ src/finemo/evaluation.py | 20 ++++++++++---------- src/finemo/main.py | 14 +++++++------- src/finemo/templates/report.html | 10 +++++----- src/finemo/visualization.py | 10 +++++----- 7 files changed, 47 insertions(+), 47 deletions(-) diff --git a/README.md b/README.md index 196be85..32affa9 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ The algorithm represents contribution scores as weighted combinations of motif c - **Competitive motif assignment**: Biologically-motivated algorithm that resolves similar motifs - **Context-aware analysis**: Leverages neural network importance scores for improved sensitivity and specificity - **Comprehensive evaluation**: Built-in tools for assessing and visualizing motif discovery quality and hit calling performance -- **Multiple input formats**: Support for bigWig, HDF5, and TF-MoDISCo output formats +- **Multiple input formats**: Support for bigWig, HDF5, and TF-MoDISco output formats ## Method @@ -30,7 +30,7 @@ Fi-NeMo is described in: > Tseng, Ramalingam, Wang, Schreiber, et al. "Decoding predictive motif lexicons and syntax from deep learning models of transcription factor binding profiles." 
(manuscript in preparation) Related tools: -- [TF-MoDISCo](https://github.com/jmschrei/tfmodisco-lite): *De novo* motif discovery from importance scores +- [TF-MoDISco](https://github.com/jmschrei/tfmodisco-lite): *De novo* motif discovery from importance scores - [BPNet](https://github.com/kundajelab/bpnet-refactor): Deep learning models for TF binding prediction - [ChromBPNet](https://github.com/kundajelab/chrombpnet): Deep learning models for chromatin accessibility prediction @@ -164,7 +164,7 @@ Usage: `finemo extract-regions-modisco-fmt -s -a -o < #### `finemo call-hits` -Identify motif instances in input regions using the Fi-NeMo competitive optimization algorithm. This is the core functionality that leverages TF-MoDISCo CWMs to find motif occurrences in contribution score data. +Identify motif instances in input regions using the Fi-NeMo competitive optimization algorithm. This is the core functionality that leverages TF-MoDISco CWMs to find motif occurrences in contribution score data. Usage: `finemo call-hits -r -m -o [-p ] [-t ] [-l ] [-b ] [-J]` @@ -257,14 +257,14 @@ Usage: `finemo call-hits -r -m -o [-p ] - **Scale Invariance**: Hit calling depends on motif and contribution score shapes, not absolute magnitudes. Use `hit_coefficient_global` or `hit_importance` for importance-based thresholding. - **Competitive Assignment**: Overlapping motif candidates compete; only the best-fitting motif at each position receives a non-zero coefficient. -- **Legacy Format Support**: Convert older TF-MoDISCo files using `modisco convert` from [tfmodisco-lite](https://github.com/jmschrei/tfmodisco-lite). +- **Legacy Format Support**: Convert older TF-MoDISco files using `modisco convert` from [tfmodisco-lite](https://github.com/jmschrei/tfmodisco-lite). ### Output reporting and post-processing #### `finemo report` -Generate an HTML report (`report.html`) visualizing TF-MoDISCo seqlet recall and hit distributions. 
-If `-n/--no-recall` is not set, the regions used for hit calling must exactly match those used during the TF-MoDISCo motif discovery process. +Generate an HTML report (`report.html`) visualizing TF-MoDISco seqlet recall and hit distributions. +If `-n/--no-recall` is not set, the regions used for hit calling must exactly match those used during the TF-MoDISco motif discovery process. This command does not utilize the GPU. Usage: `finemo report -r -H -o [-m ] [-W ] [-n]` diff --git a/src/finemo/__init__.py b/src/finemo/__init__.py index b170d01..1e891a4 100644 --- a/src/finemo/__init__.py +++ b/src/finemo/__init__.py @@ -11,7 +11,7 @@ Key Features ------------ - GPU-accelerated hit calling using PyTorch -- Support for multiple input formats (bigWig, HDF5, TF-MoDISCo) +- Support for multiple input formats (bigWig, HDF5, TF-MoDISco) - Competitive motif instance assignment - Comprehensive evaluation and visualization tools - Post-processing utilities for hit refinement @@ -58,6 +58,6 @@ See Also -------- -TF-MoDISCo : https://github.com/jmschrei/tfmodisco-lite +TF-MoDISco : https://github.com/jmschrei/tfmodisco-lite BPNet : https://github.com/kundajelab/bpnet-refactor """ diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index a604904..d266178 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -5,7 +5,7 @@ - Genome sequences (FASTA format) - Contribution scores (bigWig, HDF5 formats) - Neural network model outputs -- Motif data from TF-MoDISCo +- Motif data from TF-MoDISco - Hit calling results The module supports multiple input formats used for contribution scores @@ -454,7 +454,7 @@ def load_npy_or_npz(path: str) -> ndarray: def load_regions_from_modisco_fmt( shaps_paths: List[str], ohe_path: str, half_width: int ) -> Tuple[Int[ndarray, "N 4 L"], Float[ndarray, "N 4 L"]]: - """Load genomic sequences and contribution scores from TF-MoDISCo format files. + """Load genomic sequences and contribution scores from TF-MoDISco format files. 
Parameters ---------- @@ -708,7 +708,7 @@ def softmax(x: Float[ndarray, "4 W"], temp: float = 100) -> Float[ndarray, "4 W" def _motif_name_sort_key(data: Tuple[str, Any]) -> Union[Tuple[int], Tuple[int, str]]: - """Generate sort key for TF-MoDISCo motif names. + """Generate sort key for TF-MoDISco motif names. This function creates a sort key that orders motifs by pattern number, with non-standard patterns sorted to the end. @@ -753,16 +753,16 @@ def load_modisco_motifs( motif_lambda_default: float, include_rc: bool, ) -> Tuple[pl.DataFrame, Float[ndarray, "M 4 W"], Int[ndarray, "M W"], ndarray]: - """Load motif data from TF-MoDISCo HDF5 file with customizable processing options. + """Load motif data from TF-MoDISco HDF5 file with customizable processing options. This function extracts contribution weight matrices and associated metadata from - TF-MoDISCo results, with support for custom naming, trimming, and regularization + TF-MoDISco results, with support for custom naming, trimming, and regularization parameters. Parameters ---------- modisco_h5_path : str - Path to TF-MoDISCo HDF5 results file containing pattern groups. + Path to TF-MoDISco HDF5 results file containing pattern groups. trim_coords : Optional[Dict[str, Tuple[int, int]]] Manual trim coordinates for specific motifs {motif_name: (start, end)}. Takes precedence over automatic trimming based on thresholds. @@ -972,16 +972,16 @@ def load_modisco_seqlets( modisco_half_width: int, lazy: bool = False, ) -> Union[pl.DataFrame, pl.LazyFrame]: - """Load seqlet data from TF-MoDISCo HDF5 file and convert to genomic coordinates. + """Load seqlet data from TF-MoDISco HDF5 file and convert to genomic coordinates. - This function extracts seqlet instances from TF-MoDISCo results and converts + This function extracts seqlet instances from TF-MoDISco results and converts their relative positions to absolute genomic coordinates using peak region information. 
Parameters ---------- modisco_h5_path : str - Path to TF-MoDISCo HDF5 results file containing seqlet data. + Path to TF-MoDISco HDF5 results file containing seqlet data. peaks_df : pl.DataFrame DataFrame containing peak region information with columns: 'peak_id', 'chr', 'chr_id', 'peak_region_start'. @@ -991,7 +991,7 @@ def load_modisco_seqlets( half_width : int Half-width of the current analysis regions. modisco_half_width : int - Half-width of the regions used in the original TF-MoDISCo analysis. + Half-width of the regions used in the original TF-MoDISco analysis. Used to calculate coordinate offsets. lazy : bool, default False If True, returns a LazyFrame for efficient chaining of operations. @@ -1019,7 +1019,7 @@ def load_modisco_seqlets( motif name, and reverse complement status to avoid redundant instances. The coordinate transformation accounts for differences in region sizes - between the original TF-MoDISCo analysis and the current analysis. + between the original TF-MoDISco analysis and the current analysis. """ start_lst = [] @@ -1107,7 +1107,7 @@ def get_pattern_number(x): def write_modisco_seqlets( seqlets_df: Union[pl.DataFrame, pl.LazyFrame], out_path: str ) -> None: - """Write TF-MoDISCo seqlets to TSV file. + """Write TF-MoDISco seqlets to TSV file. 
Parameters ---------- diff --git a/src/finemo/evaluation.py b/src/finemo/evaluation.py index 61a89b2..cf74375 100644 --- a/src/finemo/evaluation.py +++ b/src/finemo/evaluation.py @@ -2,7 +2,7 @@ This module provides functions for: - Computing motif occurrence statistics and co-occurrence patterns -- Evaluating motif discovery quality against TF-MoDISCo results +- Evaluating motif discovery quality against TF-MoDISco results - Analyzing hit calling performance and recall metrics - Generating confusion matrices for seqlet-hit comparisons """ @@ -176,10 +176,10 @@ def tfmodisco_comparison( Dict[str, Dict[str, Float[ndarray, "4 W"]]], Dict[str, Dict[str, Tuple[int, int]]], ]: - """Compare Fi-NeMo hits with TF-MoDISCo seqlets and compute evaluation metrics. + """Compare Fi-NeMo hits with TF-MoDISco seqlets and compute evaluation metrics. This function performs comprehensive comparison between Fi-NeMo hit calls - and TF-MoDISCo seqlets, computing recall metrics, CWM similarities, + and TF-MoDISco seqlets, computing recall metrics, CWM similarities, and extracting contribution weight matrices for visualization. Parameters @@ -194,14 +194,14 @@ def tfmodisco_comparison( Peak metadata with columns: - peak_id, chr_id, peak_region_start seqlets_df : Optional[pl.DataFrame] - TF-MoDISCo seqlets with columns: + TF-MoDISco seqlets with columns: - chr_id, start_untrimmed, is_revcomp, motif_name If None, only basic hit statistics are computed. motifs_df : pl.DataFrame Motif metadata with columns: - motif_name, strand, motif_id, motif_start, motif_end cwms_modisco : Float[ndarray, "M 4 W"] - TF-MoDISCo contribution weight matrices. + TF-MoDISco contribution weight matrices. Shape: (n_modisco_motifs, 4, motif_width) motif_names : List[str] Names of motifs to analyze. 
@@ -223,7 +223,7 @@ def tfmodisco_comparison( cwms : Dict[str, Dict[str, Float[ndarray, "4 W"]]] Extracted CWMs for each motif and condition: - hits_fc, hits_rc: Forward/reverse complement hits - - modisco_fc, modisco_rc: TF-MoDISCo forward/reverse + - modisco_fc, modisco_rc: TF-MoDISco forward/reverse - seqlets_only, hits_restricted_only: Non-overlapping instances cwm_trim_bounds : Dict[str, Dict[str, Tuple[int, int]]] Trimming boundaries for each CWM type and motif. @@ -231,7 +231,7 @@ def tfmodisco_comparison( Notes ----- - Hits are filtered to central region defined by modisco_half_width - - CWM similarity is computed as normalized dot product between hit and TF-MoDISCo CWMs + - CWM similarity is computed as normalized dot product between hit and TF-MoDISco CWMs - Recall metrics require both hits_df and seqlets_df to be non-empty - Missing motifs are handled gracefully with empty DataFrames @@ -427,10 +427,10 @@ def seqlet_confusion( motif_names: List[str], motif_width: int, ) -> Tuple[pl.DataFrame, Float[ndarray, "M M"]]: - """Compute confusion matrix between TF-MoDISCo seqlets and Fi-NeMo hits. + """Compute confusion matrix between TF-MoDISco seqlets and Fi-NeMo hits. This function creates a confusion matrix showing the overlap between - TF-MoDISCo seqlets (ground truth) and Fi-NeMo hits across different motifs. + TF-MoDISco seqlets (ground truth) and Fi-NeMo hits across different motifs. Overlap frequencies are estimated using binned genomic coordinates. 
Parameters @@ -439,7 +439,7 @@ def seqlet_confusion( Fi-NeMo hit calls with required columns: - peak_id, start_untrimmed, end_untrimmed, strand, motif_name seqlets_df : pl.DataFrame - TF-MoDISCo seqlets with required columns: + TF-MoDISco seqlets with required columns: - chr_id, start_untrimmed, end_untrimmed, motif_name peaks_df : pl.DataFrame Peak metadata for joining coordinates: diff --git a/src/finemo/main.py b/src/finemo/main.py index 16d03af..9ab84f4 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -7,7 +7,7 @@ - Post-processing operations (hit collapsing, intersection) The CLI supports multiple input formats including bigWig, HDF5 (ChromBPNet/BPNet), -and TF-MoDISCo format. +and TF-MoDISco format. """ from . import data_io @@ -137,7 +137,7 @@ def extract_regions_modisco_fmt( out_path: str, region_width: int, ) -> None: - """Extract genomic regions and contribution scores from TF-MoDISCo format files. + """Extract genomic regions and contribution scores from TF-MoDISco format files. Parameters ---------- @@ -407,7 +407,7 @@ def report( """Generate comprehensive HTML report with statistics and visualizations. This function creates detailed analysis reports comparing Fi-NeMo hit calling - results with TF-MoDISCo seqlets, including performance metrics, distribution + results with TF-MoDISco seqlets, including performance metrics, distribution plots, and motif visualization. The report provides insights into hit calling quality and motif discovery accuracy. @@ -418,7 +418,7 @@ def report( hits_dir : str Path to directory containing Fi-NeMo hit calling outputs. modisco_h5_path : str, optional - Path to TF-MoDISCo H5 file. If None, seqlet comparisons are skipped. + Path to TF-MoDISco H5 file. If None, seqlet comparisons are skipped. peaks_path : str, optional DEPRECATED. Peak coordinates should be included in regions file. motifs_include_path : str, optional @@ -428,11 +428,11 @@ def report( out_dir : str Output directory for report files. 
modisco_region_width : int - Width of regions used by TF-MoDISCo (needed for coordinate conversion). + Width of regions used by TF-MoDISco (needed for coordinate conversion). cwm_trim_threshold : float DEPRECATED. This information is inferred from hit calling outputs. compute_recall : bool - Whether to compute recall metrics against TF-MoDISCo seqlets. + Whether to compute recall metrics against TF-MoDISco seqlets. use_seqlets : bool Whether to include seqlet-based comparisons in the report. @@ -674,7 +674,7 @@ def cli() -> None: """Command-line interface for the Fi-NeMo motif instance calling pipeline. This function provides the main entry point for all Fi-NeMo operations including: - - Data preprocessing from various formats (bigWig, HDF5, TF-MoDISCo) + - Data preprocessing from various formats (bigWig, HDF5, TF-MoDISco) - Motif hit calling using the Fi-NeMo algorithm - Report generation and visualization - Post-processing operations (hit collapsing, intersection) diff --git a/src/finemo/templates/report.html b/src/finemo/templates/report.html index a04417b..a2389d1 100644 --- a/src/finemo/templates/report.html +++ b/src/finemo/templates/report.html @@ -185,7 +185,7 @@

Fi-NeMo Motif Hit Calling Report

a GPU-accelerated method for identifying transcription factor binding sites using neural network contribution scores. Fi-NeMo uses a competitive optimization approach to comprehensively map motif instances by solving a sparse linear reconstruction problem. The report compares Fi-NeMo hits - with TF-MoDISCo seqlets (when available) and provides detailed statistics on hit quality and + with TF-MoDISco seqlets (when available) and provides detailed statistics on hit quality and motif discovery performance.

@@ -207,10 +207,10 @@

TF-MoDISco seqlet comparisons

Hit vs. seqlet counts

- This scatter plot compares the number of motif instances called by Fi-NeMo versus the number of TF-MoDISCo seqlets + This scatter plot compares the number of motif instances called by Fi-NeMo versus the number of TF-MoDISco seqlets identified for each motif. The dashed line represents perfect agreement (y = x). Fi-NeMo typically identifies - an order of magnitude more motif instances than TF-MoDISCo because: (1) TF-MoDISCo applies stringent filtering criteria - during seqlet identification, and (2) TF-MoDISCo often analyzes smaller genomic windows than those used for hit calling. + an order of magnitude more motif instances than TF-MoDISco because: (1) TF-MoDISco applies stringent filtering criteria + during seqlet identification, and (2) TF-MoDISco often analyzes smaller genomic windows than those used for hit calling.

{% endif %} @@ -218,7 +218,7 @@

Hit vs. seqlet counts

Motif-specific hit and seqlet analysis

This table provides detailed statistics for each motif, comparing the consistency between Fi-NeMo hits - and TF-MoDISCo seqlets. The analysis includes hit counts, overlap statistics, and visual comparisons + and TF-MoDISco seqlets. The analysis includes hit counts, overlap statistics, and visual comparisons of contribution weight matrices (CWMs).

diff --git a/src/finemo/visualization.py b/src/finemo/visualization.py index 6346857..74e4aa2 100644 --- a/src/finemo/visualization.py +++ b/src/finemo/visualization.py @@ -28,9 +28,9 @@ def abbreviate_motif_name(name: str) -> str: - """Convert TF-MoDISCo motif names to abbreviated format. + """Convert TF-MoDISco motif names to abbreviated format. - Converts full TF-MoDISCo pattern names to shorter, more readable format + Converts full TF-MoDISco pattern names to shorter, more readable format for display in plots and reports. Parameters @@ -277,7 +277,7 @@ def plot_seqlet_confusion_heatmap( ) -> None: """Plot confusion matrix heatmap comparing seqlets to hit calls. - Creates a heatmap showing the overlap between TF-MoDISCo seqlets and + Creates a heatmap showing the overlap between TF-MoDISco seqlets and Fi-NeMo hit calls. Rows represent seqlet motifs, columns represent hit motifs. Parameters @@ -578,7 +578,7 @@ def plot_hit_vs_seqlet_counts( """Plot scatter plot comparing hit counts to seqlet counts per motif. Creates a log-log scatter plot showing the relationship between the number - of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISCo + of hits called by Fi-NeMo and the number of seqlets identified by TF-MoDISco for each motif. Includes diagonal reference line and motif annotations. Parameters @@ -660,7 +660,7 @@ def write_report( Whether recall metrics were computed and should be included in the report template. use_seqlets : bool - Whether TF-MoDISCo seqlet data was used in the analysis + Whether TF-MoDISco seqlet data was used in the analysis and should be referenced in the report. 
Notes From 03fb6aa8f566348ccd4fd5c92611d3097376ea68 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 12:40:04 -0700 Subject: [PATCH 23/39] Update license --- LICENSE | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/LICENSE b/LICENSE index dd86235..ebb7494 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,6 @@ MIT License -Copyright (c) 2023 Austin Wang +Copyright (c) 2025 Austin Wang Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal From abf970fe91ffa80586b34b152dc7f98d870dec1b Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 12:54:02 -0700 Subject: [PATCH 24/39] API docs --- .github/workflows/docs.yml | 33 +++++++++++++++++++++++++++++++++ .gitignore | 1 + 2 files changed, 34 insertions(+) create mode 100644 .github/workflows/docs.yml diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000..a5346ae --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,33 @@ +name: Generate API Documentation + +on: + push: + branches: [ dev ] + workflow_dispatch: + +jobs: + deploy-docs: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v3 + + - name: Set up Python + uses: actions/setup-python@v4 + with: + python-version: '3.x' + + - name: Install dependencies + run: | + pip install pdoc + pip install -e . 
+ + - name: Generate documentation + run: | + pdoc --html --output-dir ./docs --force finemo + + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@v3 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./docs/finemo \ No newline at end of file diff --git a/.gitignore b/.gitignore index 78dbf8b..5e3aa36 100644 --- a/.gitignore +++ b/.gitignore @@ -1,6 +1,7 @@ *.egg-info __pycache__ /.* +!/.github /notebooks /scratch.txt /scratch \ No newline at end of file From 4c2f2faea90a6c28201698b62384944a92dc2f7d Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 12:58:15 -0700 Subject: [PATCH 25/39] API docs --- .github/workflows/docs.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml index a5346ae..8499f19 100644 --- a/.github/workflows/docs.yml +++ b/.github/workflows/docs.yml @@ -24,10 +24,10 @@ jobs: - name: Generate documentation run: | - pdoc --html --output-dir ./docs --force finemo + pdoc -d numpy --output-dir ./docs finemo - name: Deploy to GitHub Pages uses: peaceiris/actions-gh-pages@v3 with: github_token: ${{ secrets.GITHUB_TOKEN }} - publish_dir: ./docs/finemo \ No newline at end of file + publish_dir: ./docs \ No newline at end of file From b4998f9d8639379a8572cb8aa9da98355515e554 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 13:01:53 -0700 Subject: [PATCH 26/39] Fix missing dependencies --- pyproject.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 747bb6a..2344165 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,8 @@ dependencies = [ "tqdm", "pyBigWig", "pyfaidx", - "jinja2" + "jinja2", + "jaxtyping" ] [project.scripts] From 1736b28674a1a141076007eb5cbf5895e13dd66b Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 13:22:59 -0700 Subject: [PATCH 27/39] Module-based entry point --- src/finemo/__main__.py | 8 ++++++++ 1 file changed, 8 
insertions(+) create mode 100644 src/finemo/__main__.py diff --git a/src/finemo/__main__.py b/src/finemo/__main__.py new file mode 100644 index 0000000..4a6bf2e --- /dev/null +++ b/src/finemo/__main__.py @@ -0,0 +1,8 @@ +""" +Entry point for running finemo's CLI as a module via 'python -m finemo'. +""" + +from .main import cli + +if __name__ == "__main__": + cli() From 6e065fd4aa6375791c2f59d165363db8b291b142 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 13:45:04 -0700 Subject: [PATCH 28/39] Move defaults to functions --- src/finemo/__init__.py | 13 ++-- src/finemo/hitcaller.py | 38 +++++------ src/finemo/main.py | 145 ++++++++++++++++++++++------------------ 3 files changed, 105 insertions(+), 91 deletions(-) diff --git a/src/finemo/__init__.py b/src/finemo/__init__.py index 1e891a4..d3bb472 100644 --- a/src/finemo/__init__.py +++ b/src/finemo/__init__.py @@ -18,12 +18,12 @@ Modules ------- -hitcaller : Core Fi-NeMo algorithm implementation -data_io : Data input/output utilities -main : Command-line interface -evaluation : Performance assessment tools -visualization : Plotting and report generation -postprocessing : Hit refinement and analysis +- hitcaller : Core Fi-NeMo algorithm implementation +- data_io : Data input/output utilities +- main : Command-line interface +- evaluation : Performance assessment tools +- visualization : Plotting and report generation +- postprocessing : Hit refinement and analysis Examples -------- @@ -60,4 +60,5 @@ -------- TF-MoDISco : https://github.com/jmschrei/tfmodisco-lite BPNet : https://github.com/kundajelab/bpnet-refactor +ChromBPNet: https://github.com/kundajelab/chrombpnet """ diff --git a/src/finemo/hitcaller.py b/src/finemo/hitcaller.py index b95c360..a968b44 100644 --- a/src/finemo/hitcaller.py +++ b/src/finemo/hitcaller.py @@ -418,16 +418,16 @@ def fit_contribs( cwm_trim_mask: Float[ndarray, "M W"], use_hypothetical: bool, lambdas: Float[ndarray, " M"], - step_size_max: float, - step_size_min: float, - 
sqrt_transform: bool, - convergence_tol: float, - max_steps: int, - batch_size: int, - step_adjust: float, - post_filter: bool, - device: Optional[torch.device], - compile_optimizer: bool, + step_size_max: float = 3.0, + step_size_min: float = 0.08, + sqrt_transform: bool = False, + convergence_tol: float = 0.0005, + max_steps: int = 10000, + batch_size: int = 2000, + step_adjust: float = 0.7, + post_filter: bool = True, + device: Optional[torch.device] = None, + compile_optimizer: bool = False, eps: float = 1.0, ) -> Tuple[pl.DataFrame, pl.DataFrame]: """Call motif hits by fitting sparse linear model to contribution scores. @@ -454,25 +454,25 @@ def fit_contribs( projected scores (False). lambdas : Float[ndarray, "M"] L1 regularization weights for each motif. - step_size_max : float + step_size_max : float, default 3.0 Maximum optimization step size. - step_size_min : float + step_size_min : float, default 0.08 Minimum optimization step size (for convergence failure detection). - sqrt_transform : bool + sqrt_transform : bool, default False Whether to apply signed square root transformation to inputs. - convergence_tol : float + convergence_tol : float, default 0.0005 Convergence tolerance based on duality gap. - max_steps : int + max_steps : int, default 10000 Maximum number of optimization steps. - batch_size : int + batch_size : int, default 2000 Number of regions to process simultaneously. - step_adjust : float + step_adjust : float, default 0.7 Factor to reduce step size when optimization diverges. - post_filter : bool + post_filter : bool, default True Whether to filter hits based on similarity threshold. device : torch.device, optional Target device for computation. Auto-detected if None. - compile_optimizer : bool + compile_optimizer : bool, default False Whether to JIT compile the optimizer for speed. eps : float, default 1.0 Small constant for numerical stability. 
diff --git a/src/finemo/main.py b/src/finemo/main.py index 9ab84f4..fe8623f 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -15,6 +15,7 @@ import os import argparse import warnings +import inspect from typing import Optional, List import polars as pl @@ -26,7 +27,7 @@ def extract_regions_bw( fa_path: str, bw_paths: List[str], out_path: str, - region_width: int, + region_width: int = 1000, ) -> None: """Extract genomic regions and contribution scores from bigWig and FASTA files. @@ -42,7 +43,7 @@ def extract_regions_bw( List of bigWig file paths containing contribution scores. out_path : str Output path for NPZ file. - region_width : int + region_width : int, default 1000 Width of regions to extract around peak summits. Notes @@ -66,7 +67,7 @@ def extract_regions_chrombpnet_h5( chrom_order_path: Optional[str], h5_paths: List[str], out_path: str, - region_width: int, + region_width: int = 1000, ) -> None: """Extract genomic regions and contribution scores from ChromBPNet HDF5 files. @@ -80,7 +81,7 @@ def extract_regions_chrombpnet_h5( List of ChromBPNet HDF5 file paths. out_path : str Output path for NPZ file. - region_width : int + region_width : int, default 1000 Width of regions to extract around peak summits. """ half_width = region_width // 2 @@ -100,7 +101,7 @@ def extract_regions_bpnet_h5( chrom_order_path: Optional[str], h5_paths: List[str], out_path: str, - region_width: int, + region_width: int = 1000, ) -> None: """Extract genomic regions and contribution scores from BPNet HDF5 files. @@ -114,7 +115,7 @@ def extract_regions_bpnet_h5( List of BPNet HDF5 file paths. out_path : str Output path for NPZ file. - region_width : int + region_width : int, default 1000 Width of regions to extract around peak summits. 
""" half_width = region_width // 2 @@ -135,7 +136,7 @@ def extract_regions_modisco_fmt( shaps_paths: List[str], ohe_path: str, out_path: str, - region_width: int, + region_width: int = 1000, ) -> None: """Extract genomic regions and contribution scores from TF-MoDISco format files. @@ -151,7 +152,7 @@ def extract_regions_modisco_fmt( Path to .npy/.npz file containing one-hot encoded sequences. out_path : str Output path for NPZ file. - region_width : int + region_width : int, default 1000 Width of regions to extract around peak summits. """ half_width = region_width // 2 @@ -177,21 +178,21 @@ def call_hits( motif_names_path: Optional[str], motif_lambdas_path: Optional[str], out_dir: str, - cwm_trim_coords_path: Optional[str], - cwm_trim_thresholds_path: Optional[str], - cwm_trim_threshold_default: float, - lambda_default: float, - step_size_max: float, - step_size_min: float, - sqrt_transform: bool, - convergence_tol: float, - max_steps: int, - batch_size: int, - step_adjust: float, - device: Optional[str], - mode: str, - no_post_filter: bool, - compile_optimizer: bool, + cwm_trim_coords_path: Optional[str] = None, + cwm_trim_thresholds_path: Optional[str] = None, + cwm_trim_threshold_default: float = 0.3, + lambda_default: float = 0.7, + step_size_max: float = 3.0, + step_size_min: float = 0.08, + sqrt_transform: bool = False, + convergence_tol: float = 0.0005, + max_steps: int = 10000, + batch_size: int = 2000, + step_adjust: float = 0.7, + device: Optional[str] = None, + mode: str = "pp", + no_post_filter: bool = False, + compile_optimizer: bool = False, ) -> None: """Call motif hits using the Fi-NeMo algorithm on preprocessed genomic regions. @@ -224,31 +225,31 @@ def call_hits( Path to file specifying custom motif trimming coordinates. cwm_trim_thresholds_path : str, optional Path to file specifying custom motif trimming thresholds. - cwm_trim_threshold_default : float - Default threshold for motif trimming (typically 0.3). 
- lambda_default : float - Default L1 regularization weight (typically 0.7). - step_size_max : float + cwm_trim_threshold_default : float, default 0.3 + Default threshold for motif trimming. + lambda_default : float, default 0.7 + Default L1 regularization weight. + step_size_max : float, default 3.0 Maximum optimization step size. - step_size_min : float + step_size_min : float, default 0.08 Minimum optimization step size. - sqrt_transform : bool + sqrt_transform : bool, default False Whether to apply signed square root transform to contributions. - convergence_tol : float + convergence_tol : float, default 0.0005 Convergence tolerance for duality gap. - max_steps : int + max_steps : int, default 10000 Maximum number of optimization steps. - batch_size : int + batch_size : int, default 2000 Batch size for GPU processing. - step_adjust : float + step_adjust : float, default 0.7 Step size adjustment factor on divergence. device : str, optional DEPRECATED. Use CUDA_VISIBLE_DEVICES environment variable instead. - mode : str + mode : str, default "pp" Contribution type mode ('pp', 'ph', 'hp', 'hh') where 'p'=projected, 'h'=hypothetical. - no_post_filter : bool + no_post_filter : bool, default False If True, skip post-hit-calling similarity filtering. - compile_optimizer : bool + compile_optimizer : bool, default False Whether to JIT-compile the optimizer for speed. Notes @@ -399,10 +400,10 @@ def report( motifs_include_path: Optional[str], motif_names_path: Optional[str], out_dir: str, - modisco_region_width: int, - cwm_trim_threshold: float, - compute_recall: bool, - use_seqlets: bool, + modisco_region_width: int = 400, + cwm_trim_threshold: float = 0.3, + compute_recall: bool = True, + use_seqlets: bool = True, ) -> None: """Generate comprehensive HTML report with statistics and visualizations. @@ -427,13 +428,13 @@ def report( DEPRECATED. This information is inferred from hit calling outputs. out_dir : str Output directory for report files. 
- modisco_region_width : int + modisco_region_width : int, default 400 Width of regions used by TF-MoDISco (needed for coordinate conversion). - cwm_trim_threshold : float + cwm_trim_threshold : float, default 0.3 DEPRECATED. This information is inferred from hit calling outputs. - compute_recall : bool + compute_recall : bool, default True Whether to compute recall metrics against TF-MoDISco seqlets. - use_seqlets : bool + use_seqlets : bool, default True Whether to include seqlet-based comparisons in the report. Notes @@ -603,7 +604,7 @@ def report( ) -def collapse_hits(hits_path: str, out_path: str, overlap_frac: float) -> None: +def collapse_hits(hits_path: str, out_path: str, overlap_frac: float = 0.2) -> None: """Collapse overlapping hits by selecting the best hit per overlapping group. This function processes a set of motif hits and identifies overlapping hits, @@ -617,7 +618,7 @@ def collapse_hits(hits_path: str, out_path: str, overlap_frac: float) -> None: Path to input TSV file containing hit data (hits.tsv or hits_unique.tsv). out_path : str Path to output TSV file with additional 'is_primary' column. - overlap_frac : float + overlap_frac : float, default 0.2 Minimum fractional overlap for considering hits as overlapping. For hits of lengths x and y, minimum overlap = overlap_frac * (x + y) / 2. @@ -637,7 +638,7 @@ def collapse_hits(hits_path: str, out_path: str, overlap_frac: float) -> None: ) -def intersect_hits(hits_paths: List[str], out_path: str, relaxed: bool) -> None: +def intersect_hits(hits_paths: List[str], out_path: str, relaxed: bool = False) -> None: """Find intersection of hits across multiple Fi-NeMo runs. This function identifies motif hits that are consistently called across @@ -651,7 +652,7 @@ def intersect_hits(hits_paths: List[str], out_path: str, relaxed: bool) -> None: out_path : str Path to output TSV file containing intersection results. Duplicate columns are suffixed with run index. 
- relaxed : bool + relaxed : bool, default False If True, uses relaxed intersection criteria based only on motif names and untrimmed coordinates. If False, assumes consistent region definitions and motif trimming across runs. @@ -733,7 +734,9 @@ def cli() -> None: "-w", "--region-width", type=int, - default=1000, + default=inspect.signature(extract_regions_bw) + .parameters["region_width"] + .default, help="The width of the input region centered around each peak summit.", ) @@ -779,7 +782,9 @@ def cli() -> None: "-w", "--region-width", type=int, - default=1000, + default=inspect.signature(extract_regions_chrombpnet_h5) + .parameters["region_width"] + .default, help="The width of the input region centered around each peak summit.", ) @@ -825,7 +830,9 @@ def cli() -> None: "-w", "--region-width", type=int, - default=1000, + default=inspect.signature(extract_regions_chrombpnet_h5) + .parameters["region_width"] + .default, help="The width of the input region centered around each peak summit.", ) @@ -871,7 +878,9 @@ def cli() -> None: "-w", "--region-width", type=int, - default=1000, + default=inspect.signature(extract_regions_bpnet_h5) + .parameters["region_width"] + .default, help="The width of the input region centered around each peak summit.", ) @@ -925,7 +934,9 @@ def cli() -> None: "-w", "--region-width", type=int, - default=1000, + default=inspect.signature(extract_regions_modisco_fmt) + .parameters["region_width"] + .default, help="The width of the input region centered around each peak summit.", ) @@ -939,7 +950,7 @@ def cli() -> None: "-M", "--mode", type=str, - default="pp", + default=inspect.signature(call_hits).parameters["mode"].default, choices={"pp", "ph", "hp", "hh"}, help="The type of attributions to use for CWM's and input contribution scores, respectively. 
'h' for hypothetical and 'p' for projected.", ) @@ -1001,7 +1012,9 @@ def cli() -> None: "-t", "--cwm-trim-threshold", type=float, - default=0.3, + default=inspect.signature(call_hits) + .parameters["cwm_trim_threshold_default"] + .default, help="The default threshold to determine motif start and end positions within the full CWMs.", ) call_hits_parser.add_argument( @@ -1023,7 +1036,7 @@ def cli() -> None: "-l", "--global-lambda", type=float, - default=0.7, + default=inspect.signature(call_hits).parameters["lambda_default"].default, help="The default L1 regularization weight determining the sparsity of hits.", ) call_hits_parser.add_argument( @@ -1064,42 +1077,42 @@ def cli() -> None: "-s", "--step-size-max", type=float, - default=3.0, + default=inspect.signature(call_hits).parameters["step_size_max"].default, help="The maximum optimizer step size.", ) call_hits_parser.add_argument( "-i", "--step-size-min", type=float, - default=0.08, + default=inspect.signature(call_hits).parameters["step_size_min"].default, help="The minimum optimizer step size.", ) call_hits_parser.add_argument( "-j", "--step-adjust", type=float, - default=0.7, + default=inspect.signature(call_hits).parameters["step_adjust"].default, help="The optimizer step size adjustment factor. If the optimizer diverges, the step size is multiplicatively adjusted by this factor", ) call_hits_parser.add_argument( "-c", "--convergence-tol", type=float, - default=0.0005, + default=inspect.signature(call_hits).parameters["convergence_tol"].default, help="The tolerance for determining convergence. 
The optimizer exits when the duality gap is less than the tolerance.", ) call_hits_parser.add_argument( "-S", "--max-steps", type=int, - default=10000, + default=inspect.signature(call_hits).parameters["max_steps"].default, help="The maximum number of optimization steps.", ) call_hits_parser.add_argument( "-b", "--batch-size", type=int, - default=2000, + default=inspect.signature(call_hits).parameters["batch_size"].default, help="The batch size used for optimization.", ) call_hits_parser.add_argument( @@ -1177,14 +1190,14 @@ def cli() -> None: "-W", "--modisco-region-width", type=int, - default=400, + default=inspect.signature(report).parameters["modisco_region_width"].default, help="The width of the region around each peak summit used by tfmodisco-lite.", ) report_parser.add_argument( "-t", "--cwm-trim-threshold", type=float, - default=0.3, + default=inspect.signature(report).parameters["cwm_trim_threshold"].default, help="DEPRECATED: This information is now inferred from the outputs of `finemo call-hits`.", ) report_parser.add_argument( @@ -1224,7 +1237,7 @@ def cli() -> None: "-O", "--overlap-frac", type=float, - default=0.2, + default=inspect.signature(collapse_hits).parameters["overlap_frac"].default, help="The threshold for determining overlapping hits. For two hits with lengths x and y, the minimum overlap is defined as `overlap_frac * (x + y) / 2`. 
The default value of 0.2 means that two hits must overlap by at least 20% of their average lengths to be considered overlapping.", ) From 5a1489c69ee8690466446f61dc261a37ebc9e1c1 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 13:50:31 -0700 Subject: [PATCH 29/39] Link to API docs --- README.md | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index 32affa9..0b7ce64 100644 --- a/README.md +++ b/README.md @@ -90,7 +90,11 @@ Recommended: - Peak region coordinates in uncompressed [ENCODE NarrowPeak](https://genome.ucsc.edu/FAQ/FAQformat.html#format12) format. -## Usage +## API Documentation + +For detailed Python API documentation, see: https://www.austintwang.com/finemo_gpu/finemo.html + +## Command-Line Usage Fi-NeMo provides a command-line utility named `finemo` for motif instance calling and analysis. The typical workflow involves three main steps: From 36ed9ea27231dd2b147f3529faac4f4432989ab2 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 14:04:07 -0700 Subject: [PATCH 30/39] Define `__all__` --- README.md | 2 +- src/finemo/__init__.py | 16 ++++++++++++++++ src/finemo/main.py | 3 --- 3 files changed, 17 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 0b7ce64..e64a859 100644 --- a/README.md +++ b/README.md @@ -92,7 +92,7 @@ Recommended: ## API Documentation -For detailed Python API documentation, see: https://www.austintwang.com/finemo_gpu/finemo.html +For Fi-NeMo's Python API documentation, see: https://www.austintwang.com/finemo_gpu/finemo.html ## Command-Line Usage diff --git a/src/finemo/__init__.py b/src/finemo/__init__.py index d3bb472..dc14e91 100644 --- a/src/finemo/__init__.py +++ b/src/finemo/__init__.py @@ -62,3 +62,19 @@ BPNet : https://github.com/kundajelab/bpnet-refactor ChromBPNet: https://github.com/kundajelab/chrombpnet """ + +from . import data_io +from . import hitcaller +from . import evaluation +from . import visualization +from . 
import postprocessing +from . import main + +__all__ = [ + "data_io", + "hitcaller", + "evaluation", + "visualization", + "postprocessing", + "main", +] diff --git a/src/finemo/main.py b/src/finemo/main.py index fe8623f..3ffb2af 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -679,9 +679,6 @@ def cli() -> None: - Motif hit calling using the Fi-NeMo algorithm - Report generation and visualization - Post-processing operations (hit collapsing, intersection) - - The CLI supports comprehensive workflows for transcription factor motif - analysis from raw genomic data to publication-ready visualizations. """ parser = argparse.ArgumentParser() subparsers = parser.add_subparsers(required=True, dest="cmd") From 6f130307a9dff03c5620c231b82e2d461bb3d528 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 14:22:57 -0700 Subject: [PATCH 31/39] README tweaks --- README.md | 5 ++--- src/finemo/__init__.py | 2 +- 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index e64a859..1b1ba64 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,6 @@ -# Fi-NeMo: Finding Neural network Motifs +# Fi-NeMo: Finding Neural Network Motifs -**Fi-NeMo** (**Fi**nding **Ne**ural network **Mo**tifs) is a GPU-accelerated motif instance calling tool for identifying transcription factor binding sites from neural network contribution scores. +**Fi-NeMo** (**Fi**nding **Ne**ural Network **Mo**tifs) is a GPU-accelerated motif instance calling tool for identifying transcription factor binding sites from neural network contribution scores. ## Overview @@ -260,7 +260,6 @@ Usage: `finemo call-hits -r -m -o [-p ] #### Important Notes - **Scale Invariance**: Hit calling depends on motif and contribution score shapes, not absolute magnitudes. Use `hit_coefficient_global` or `hit_importance` for importance-based thresholding. 
-- **Competitive Assignment**: Overlapping motif candidates compete; only the best-fitting motif at each position receives a non-zero coefficient. - **Legacy Format Support**: Convert older TF-MoDISco files using `modisco convert` from [tfmodisco-lite](https://github.com/jmschrei/tfmodisco-lite). ### Output reporting and post-processing diff --git a/src/finemo/__init__.py b/src/finemo/__init__.py index dc14e91..6b18638 100644 --- a/src/finemo/__init__.py +++ b/src/finemo/__init__.py @@ -1,4 +1,4 @@ -"""Fi-NeMo: Finding Neural network Motifs. +"""Fi-NeMo: Finding Neural Network Motifs. A GPU-accelerated motif instance calling tool for identifying transcription factor binding sites from neural network contribution scores. From fb962d4d0c4027d43d0d1c0e6b22052ced6fc5c8 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Mon, 1 Sep 2025 18:45:12 -0700 Subject: [PATCH 32/39] Make dimension names consistent --- src/finemo/hitcaller.py | 178 +++++++++++++++++++++------------------- 1 file changed, 92 insertions(+), 86 deletions(-) diff --git a/src/finemo/hitcaller.py b/src/finemo/hitcaller.py index a968b44..27cea8c 100644 --- a/src/finemo/hitcaller.py +++ b/src/finemo/hitcaller.py @@ -24,9 +24,6 @@ from tqdm import tqdm -# Type aliases for tensor operations -ArrayLike = Union[ndarray, torch.Tensor] - def prox_grad_step( coefficients: Float[Tensor, "B M P"], @@ -44,8 +41,12 @@ def prox_grad_step( The goal is to represent contribution scores as a sparse linear combination of motif contribution weight matrices (CWMs). - B = batch size, M = number of motifs, L = sequence length, W = motif width. - P = L - W + 1 (the number of positions with coefficients). 
+ Dimension notation: + - B = batch size (number of regions processed simultaneously) + - M = number of motifs + - L = sequence length + - W = motif width (length of each motif) + - P = L - W + 1 (number of valid motif positions) Parameters ---------- @@ -67,12 +68,12 @@ def prox_grad_step( Returns ------- - c_next : Float[Tensor, "B M P"], shape (b, m, l - w + 1) - Updated coefficient matrix after the optimization step. - dual_gap : Float[Tensor, "B"] - Duality gap for convergence assessment. - nll : Float[Tensor, "B"] - Negative log likelihood (proportional to MSE). + c_next : Float[Tensor, "B M P"] + Updated coefficient matrix after the optimization step (shape: batch_size × motifs × positions). + dual_gap : Float[Tensor, " B"] + Duality gap for convergence assessment (shape: batch_size). + nll : Float[Tensor, " B"] + Negative log likelihood (proportional to MSE, shape: batch_size). Notes ----- @@ -89,31 +90,31 @@ def prox_grad_step( """ # Forward pass: convolution operations require specific tensor layouts coef_adj = coefficients * importance_scale - pred_unmasked = F.conv_transpose1d(coef_adj, cwms) # (b, 4, l) + pred_unmasked = F.conv_transpose1d(coef_adj, cwms) # (B, 4, L) pred = ( pred_unmasked * sequences - ) # (b, 4, l), element-wise masking for projected mode + ) # (B, 4, L), element-wise masking for projected mode # Compute gradient * -1 - residuals = contribs - pred # (b, 4, l) - ngrad = F.conv1d(residuals, cwms) * importance_scale # (b, m, l - w + 1) + residuals = contribs - pred # (B, 4, L) + ngrad = F.conv1d(residuals, cwms) * importance_scale # (B, M, P) # Negative log likelihood (proportional to MSE) - nll = (residuals**2).sum(dim=(1, 2)) # (b) + nll = (residuals**2).sum(dim=(1, 2)) # (B) # Compute duality gap for convergence assessment - dual_norm = (ngrad / lambdas).amax(dim=(1, 2)) # (b) - dual_scale = (torch.clamp(1 / dual_norm, max=1.0) ** 2 + 1) / 2 # (b) - nll_scaled = nll * dual_scale # (b) + dual_norm = (ngrad / lambdas).amax(dim=(1, 2)) 
# (B) + dual_scale = (torch.clamp(1 / dual_norm, max=1.0) ** 2 + 1) / 2 # (B) + nll_scaled = nll * dual_scale # (B) - dual_diff = (residuals * contribs).sum(dim=(1, 2)) # (b) + dual_diff = (residuals * contribs).sum(dim=(1, 2)) # (B) l1_term = (torch.abs(coefficients).sum(dim=2, keepdim=True) * lambdas).sum( dim=(1, 2) - ) # (b) - dual_gap = (nll_scaled - dual_diff + l1_term).abs() # (b) + ) # (B) + dual_gap = (nll_scaled - dual_diff + l1_term).abs() # (B) # Compute proximal gradient descent step - c_next = coefficients + step_sizes * (ngrad - lambdas) # (b, m, l - w + 1) + c_next = coefficients + step_sizes * (ngrad - lambdas) # (B, M, P) c_next = F.relu(c_next) # Ensure non-negativity constraint return c_next, dual_gap, nll @@ -128,7 +129,7 @@ def optimizer_step( coef: Float[Tensor, "B M P"], i: Float[Tensor, "B 1 1"], step_sizes: Float[Tensor, "B 1 1"], - sequence_length: int, + L: int, lambdas: Float[Tensor, "1 M 1"], ) -> Tuple[ Float[Tensor, "B M P"], @@ -142,8 +143,12 @@ def optimizer_step( to improve convergence speed while maintaining the non-negative constraint on coefficients. - B = batch size, M = number of motifs, L = sequence length, W = motif width. - P = L - W + 1 (the number of positions with coefficients). + Dimension notation: + - B = batch size (number of regions processed simultaneously) + - M = number of motifs + - L = sequence length + - W = motif width (length of each motif) + - P = L - W + 1 (number of valid motif positions) Parameters ---------- @@ -163,7 +168,7 @@ def optimizer_step( Iteration counter for each batch element. step_sizes : Float[Tensor, "B 1 1"] Step sizes for optimization. - sequence_length : int + L : int Sequence length for normalization. lambdas : Float[Tensor, "1 M 1"] Regularization parameters. @@ -171,13 +176,13 @@ def optimizer_step( Returns ------- coef_inter : Float[Tensor, "B M P"] - Updated intermediate coefficients with momentum. 
+ Updated intermediate coefficients with momentum (shape: batch_size × motifs × positions). coef : Float[Tensor, "B M P"] - Updated coefficient matrix. + Updated coefficient matrix (shape: batch_size × motifs × positions). gap : Float[Tensor, " B"] - Normalized duality gap. + Normalized duality gap (shape: batch_size). nll : Float[Tensor, " B"] - Normalized negative log likelihood. + Normalized negative log likelihood (shape: batch_size). Notes ----- @@ -195,8 +200,8 @@ def optimizer_step( coef, gap, nll = prox_grad_step( coef_inter, importance_scale, cwms, contribs, sequences, lambdas, step_sizes ) - gap = gap / sequence_length - nll = nll / (2 * sequence_length) + gap = gap / L + nll = nll / (2 * L) # Compute updated coefficients with Nesterov momentum mom_term = i / (i + 3.0) @@ -250,7 +255,10 @@ class BatchLoaderBase(ABC): This class provides common functionality for different input formats including batch indexing and padding for consistent batch sizes. - N = number of sequences, L = sequence length. + Dimension notation: + - N = number of sequences/regions in dataset + - L = sequence length + - B = batch size (number of regions processed simultaneously) Parameters ---------- @@ -258,7 +266,7 @@ class BatchLoaderBase(ABC): Contribution scores array. sequences : Int[Tensor, "N 4 L"] One-hot encoded sequences array. - sequence_length : int + L : int Sequence length. device : torch.device Target device for tensor operations. @@ -268,12 +276,12 @@ def __init__( self, contribs: Union[Float[Tensor, "N 4 L"], Float[Tensor, "N L"]], sequences: Int[Tensor, "N 4 L"], - sequence_length: int, + L: int, device: torch.device, ) -> None: self.contribs = contribs self.sequences = sequences - self.sequence_length = sequence_length + self.L = L self.device = device def _get_inds_and_pad_lens( @@ -291,13 +299,13 @@ def _get_inds_and_pad_lens( Returns ------- inds : Int[Tensor, " Z"] - Padded indices tensor with -1 for padding positions. 
+ Padded indices tensor with -1 for padding positions (shape: padded_batch_size). pad_lens : tuple Padding specification for F.pad (left, right, top, bottom, front, back). """ - n = end - start + N = end - start end = min(end, self.contribs.shape[0]) - overhang = n - (end - start) + overhang = N - (end - start) pad_lens = (0, 0, 0, 0, 0, overhang) inds = F.pad( @@ -314,7 +322,9 @@ def load_batch( ]: """Load a batch of data. - B = batch size, L = sequence length. + Dimension notation: + - B = batch size (number of regions in this batch) + - L = sequence length Parameters ---------- @@ -326,11 +336,11 @@ def load_batch( Returns ------- contribs_batch : Float[Tensor, "B 4 L"] - Batch of contribution scores. + Batch of contribution scores (shape: batch_size × 4_bases × L). sequences_batch : Union[Int[Tensor, "B 4 L"], int] - Batch of sequences or scalar for hypothetical mode. - inds_batch : Int[Tensor, "B"] - Batch indices. + Batch of one-hot encoded sequences (shape: batch_size × 4_bases × L) or scalar 1 for hypothetical mode. + inds_batch : Int[Tensor, " B"] + Batch indices mapping to original sequence indices (shape: batch_size). 
Notes ----- @@ -356,12 +366,12 @@ def load_batch( contribs_batch = _to_channel_last_layout( contribs_compact, device=self.device, dtype=torch.float32 ) - sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens) # (b, 4, l) + sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens) # (B, 4, L) sequences_batch = _to_channel_last_layout( sequences_batch, device=self.device, dtype=torch.int8 ) - contribs_batch = contribs_batch * sequences_batch # (b, 4, l) + contribs_batch = contribs_batch * sequences_batch # (B, 4, L) return contribs_batch, sequences_batch, inds @@ -382,7 +392,7 @@ def load_batch( contribs_hyp = _to_channel_last_layout( contribs_hyp, device=self.device, dtype=torch.float32 ) - sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens) # (b, 4, l) + sequences_batch = F.pad(self.sequences[start:end, :, :], pad_lens) # (B, 4, L) sequences_batch = _to_channel_last_layout( sequences_batch, device=self.device, dtype=torch.int8 ) @@ -440,20 +450,24 @@ def fit_contribs( Parameters ---------- cwms : Float[ndarray, "M 4 W"] - Motif contribution weight matrices where M = number of motifs, - 4 = DNA bases (A, C, G, T), W = motif width. + Motif contribution weight matrices where: + - M = number of motifs (transcription factor binding patterns) + - 4 = DNA bases (A, C, G, T dimensions) + - W = motif width (length of each motif pattern) contribs : Float[ndarray, "N 4 L"] | Float[ndarray, "N L"] - Neural network contribution scores where N = number of regions, - L = sequence length. Can be hypothetical (N, 4, L) or projected (N, L). - sequences : Float[ndarray, "N 4 L"] - One-hot encoded DNA sequences. + Neural network contribution scores where: + - N = number of regions in dataset + - L = sequence length + Can be hypothetical (N, 4, L) or projected (N, L) format. + sequences : Int[ndarray, "N 4 L"] + One-hot encoded DNA sequences (shape: num_regions × 4_bases × L). 
cwm_trim_mask : Float[ndarray, "M W"] - Binary mask indicating which positions of each CWM to use. + Binary mask indicating which positions of each CWM to use (shape: num_motifs × motif_width). use_hypothetical : bool Whether to use hypothetical contribution scores (True) or projected scores (False). - lambdas : Float[ndarray, "M"] - L1 regularization weights for each motif. + lambdas : Float[ndarray, " M"] + L1 regularization weights for each motif (shape: num_motifs). step_size_max : float, default 3.0 Maximum optimization step size. step_size_min : float, default 0.08 @@ -531,10 +545,10 @@ def fit_contribs( ... compile_optimizer=False ... ) """ - m, _, w = cwms.shape - n, _, sequence_length = sequences.shape + M, _, W = cwms.shape + N, _, L = sequences.shape - b = batch_size + B = batch_size # Using uppercase for consistency with dimension notation if device is None: if torch.cuda.is_available(): @@ -574,13 +588,9 @@ def fit_contribs( # Initialize batch loader if len(contribs_tensor.shape) == 3: if use_hypothetical: - batch_loader = BatchLoaderHyp( - contribs_tensor, sequences_tensor, sequence_length, device - ) + batch_loader = BatchLoaderHyp(contribs_tensor, sequences_tensor, L, device) else: - batch_loader = BatchLoaderProj( - contribs_tensor, sequences_tensor, sequence_length, device - ) + batch_loader = BatchLoaderProj(contribs_tensor, sequences_tensor, L, device) elif len(contribs_tensor.shape) == 2: if use_hypothetical: raise ValueError( @@ -588,7 +598,7 @@ def fit_contribs( ) else: batch_loader = BatchLoaderCompactFmt( - contribs_tensor, sequences_tensor, sequence_length, device + contribs_tensor, sequences_tensor, L, device ) else: raise ValueError( @@ -612,21 +622,21 @@ def fit_contribs( # Initialize buffers for optimizer coef_inter: Float[Tensor, "B M P"] = torch.zeros( - (b, m, sequence_length - w + 1) - ) # (b, m, sequence_length - w + 1) + (B, M, L - W + 1) + ) # (B, M, P) where P = L - W + 1 coef_inter = _to_channel_last_layout(coef_inter, 
device=device, dtype=torch.float32) coef: Float[Tensor, "B M P"] = torch.zeros_like(coef_inter) - i: Float[Tensor, "B 1 1"] = torch.zeros((b, 1, 1), dtype=torch.int, device=device) + i: Float[Tensor, "B 1 1"] = torch.zeros((B, 1, 1), dtype=torch.int, device=device) step_sizes: Float[Tensor, "B 1 1"] = torch.full( - (b, 1, 1), step_size_max, dtype=torch.float32, device=device + (B, 1, 1), step_size_max, dtype=torch.float32, device=device ) converged: Bool[Tensor, " B"] = torch.full( - (b,), True, dtype=torch.bool, device=device + (B,), True, dtype=torch.bool, device=device ) - num_load = b + num_load = B - contribs_buf: Float[Tensor, "B 4 L"] = torch.zeros((b, 4, sequence_length)) + contribs_buf: Float[Tensor, "B 4 L"] = torch.zeros((B, 4, L)) contribs_buf = _to_channel_last_layout( contribs_buf, device=device, dtype=torch.float32 ) @@ -635,25 +645,23 @@ def fit_contribs( if use_hypothetical: seqs_buf = 1 else: - seqs_buf = torch.zeros((b, 4, sequence_length)) + seqs_buf = torch.zeros((B, 4, L)) seqs_buf = _to_channel_last_layout(seqs_buf, device=device, dtype=torch.int8) - importance_scale_buf: Float[Tensor, "B M P"] = torch.zeros( - (b, m, sequence_length - w + 1) - ) + importance_scale_buf: Float[Tensor, "B M P"] = torch.zeros((B, M, L - W + 1)) importance_scale_buf = _to_channel_last_layout( importance_scale_buf, device=device, dtype=torch.float32 ) - inds_buf: Int[Tensor, " B"] = torch.zeros((b,), dtype=torch.int, device=device) + inds_buf: Int[Tensor, " B"] = torch.zeros((B,), dtype=torch.int, device=device) global_scale_buf: Float[Tensor, " B"] = torch.zeros( - (b,), dtype=torch.float, device=device + (B,), dtype=torch.float, device=device ) - with tqdm(disable=None, unit="regions", total=n, ncols=120) as pbar: + with tqdm(disable=None, unit="regions", total=N, ncols=120) as pbar: num_complete = 0 next_ind = 0 - while num_complete < n: + while num_complete < N: # Retire converged peaks and fill buffer with new data if num_load > 0: load_start = next_ind @@ 
-666,9 +674,7 @@ def fit_contribs( if sqrt_transform: contribs_batch = _signed_sqrt(contribs_batch) - global_scale_batch = ( - (contribs_batch**2).sum(dim=(1, 2)) / sequence_length - ).sqrt() + global_scale_batch = ((contribs_batch**2).sum(dim=(1, 2)) / L).sqrt() contribs_batch = torch.nan_to_num( contribs_batch / global_scale_batch[:, None, None] ) @@ -703,7 +709,7 @@ def fit_contribs( coef, i, step_sizes, - sequence_length, + L, lambdas_tensor, ) i += 1 From 2c9154fbe4bd3df4f247da4fc1cf5e347a3fe54d Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Tue, 9 Sep 2025 13:51:42 -0700 Subject: [PATCH 33/39] Ensure that hits_unique is well-defined --- README.md | 2 +- src/finemo/data_io.py | 7 +++++-- 2 files changed, 6 insertions(+), 3 deletions(-) diff --git a/README.md b/README.md index 1b1ba64..72e25df 100644 --- a/README.md +++ b/README.md @@ -198,7 +198,7 @@ Usage: `finemo call-hits -r -m -o [-p ] - `peak_name`: The name of the peak region containing the hit, taken from the `name` field of the input peak data. `NA` if peak coordinates are not provided. - `peak_id`: The numerical index of the peak region containing the hit. -`hits_unique.tsv`: A deduplicated list of hits in the same format as `hits.tsv`. In cases where peak regions overlap, `hits.tsv` may list multiple instances of a hit, each linked to a different peak. `hits_unique.tsv` arbitrarily selects one instance per duplicated hit. This file is generated only if peak coordinates are provided. +`hits_unique.tsv`: A deduplicated list of hits in the same format as `hits.tsv`. In cases where peak regions overlap, `hits.tsv` may list multiple instances of a hit, each linked to a different peak. `hits_unique.tsv` arbitrarily selects one instance per duplicated hit. **This file is empty if peak coordinates are not provided.** `hits.bed`: A coordinate-sorted BED file of unique hits. 
It includes: diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index d266178..8369108 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -1233,8 +1233,11 @@ def write_hits( ----- Creates three output files: - hits.tsv: Complete hit data with all instances - - hits_unique.tsv: Deduplicated hits by genomic position and motif + - hits_unique.tsv: Deduplicated hits by genomic position and motif (excludes rows with NA chromosome coordinates) - hits.bed: BED format file for genome browser visualization + + Rows where the chromosome field is NA are filtered out during deduplication + to ensure that data_unique only contains well-defined genomic coordinates. """ os.makedirs(out_dir, exist_ok=True) out_path_tsv = os.path.join(out_dir, "hits.tsv") @@ -1275,7 +1278,7 @@ def write_hits( .select(HITS_DTYPES.keys()) ) - data_unique = data_all.unique( + data_unique = data_all.filter(pl.col("chr").is_not_null()).unique( subset=["chr", "start", "motif_name", "strand"], maintain_order=True ) From 821bcb2756d31eac424934342cf4bd2baee1c485 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Tue, 9 Sep 2025 16:27:36 -0700 Subject: [PATCH 34/39] Remove unneeded guard --- src/finemo/main.py | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/finemo/main.py b/src/finemo/main.py index 3ffb2af..db78241 100644 --- a/src/finemo/main.py +++ b/src/finemo/main.py @@ -1345,11 +1345,6 @@ def cli() -> None: ) elif args.cmd == "report": - if args.no_recall and not args.no_seqlets: - raise ValueError( - "The `--no-seqlets` flag must be set in conjunction with `--no-recall`." 
- ) - report( args.regions, args.hits, From 55cb36a3527a7de5089150ac122c3fcda754e153 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Wed, 10 Sep 2025 18:28:21 -0700 Subject: [PATCH 35/39] Improve motif name sorting logic --- src/finemo/data_io.py | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/src/finemo/data_io.py b/src/finemo/data_io.py index 8369108..6dc6074 100644 --- a/src/finemo/data_io.py +++ b/src/finemo/data_io.py @@ -707,7 +707,7 @@ def softmax(x: Float[ndarray, "4 W"], temp: float = 100) -> Float[ndarray, "4 W" return exp / np.sum(exp, axis=0, keepdims=True) -def _motif_name_sort_key(data: Tuple[str, Any]) -> Union[Tuple[int], Tuple[int, str]]: +def _motif_name_sort_key(data: Tuple[str, Any]) -> Union[Tuple[int, int], Tuple[int, str]]: """Generate sort key for TF-MoDISco motif names. This function creates a sort key that orders motifs by pattern number, @@ -717,25 +717,27 @@ def _motif_name_sort_key(data: Tuple[str, Any]) -> Union[Tuple[int], Tuple[int, ---------- data : Tuple[str, Any] Tuple containing motif name as first element and additional data. - The motif name should follow the format 'pattern_N' where N is an integer. + The motif name should follow the format 'pattern_N' or 'pattern#N' where N is an integer. Returns ------- - Union[Tuple[int], Tuple[int, str]] + Union[Tuple[int, int], Tuple[int, str]] Sort key tuple for ordering motifs. Standard pattern names return - (pattern_number,) while non-standard names return (-1, name). + (0, pattern_number) while non-standard names return (1, name). Notes ----- This function is used internally by load_modisco_motifs to ensure consistent motif ordering across runs. 
""" - name = data[0] - if name.startswith("pattern_"): - pattern_num = int(name.split("_")[-1]) - return (pattern_num,) - else: - return (-1, name) # Mixed tuple types for sorting + pattern_name = data[0] + try: + return (0, int(pattern_name.split("_")[-1])) + except (ValueError, IndexError): + try: + return (0, int(pattern_name.split("#")[-1])) + except (ValueError, IndexError): + return (1, pattern_name) MODISCO_PATTERN_GROUPS = ["pos_patterns", "neg_patterns"] @@ -1036,10 +1038,7 @@ def load_modisco_seqlets( metacluster = modisco_results[name] - def get_pattern_number(x): - return int(x[0].split("_")[-1]) - - key = get_pattern_number + key = _motif_name_sort_key for _, (pattern_name, pattern) in enumerate( sorted(metacluster.items(), key=key) # type: ignore # HDF5 access ): From 56f21f4784195975092a55652930f9442e1f1093 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Thu, 11 Sep 2025 04:51:03 -0700 Subject: [PATCH 36/39] Properly encode report URLs --- src/finemo/templates/report.html | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/src/finemo/templates/report.html b/src/finemo/templates/report.html index a2389d1..c3a6ffc 100644 --- a/src/finemo/templates/report.html +++ b/src/finemo/templates/report.html @@ -322,13 +322,13 @@

Motif-specific hit and seqlet analysis

{{ item.num_seqlets_only }} {{ item.num_hits_restricted_only }} {% endif %} - - - - + + + + {% if compute_recall %} - - + + {% endif %} {% endfor %} @@ -389,10 +389,10 @@

Motif-specific hit quality metrics

{% for m in motif_names %} {{ m }} - - - - + + + + {% endfor %} From ed38c9dad392a765200ba3957635e7b22fea6d97 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 14 Sep 2025 07:41:36 -0700 Subject: [PATCH 37/39] Histogram edge case fallback --- src/finemo/visualization.py | 15 ++++++++++++--- 1 file changed, 12 insertions(+), 3 deletions(-) diff --git a/src/finemo/visualization.py b/src/finemo/visualization.py index 74e4aa2..d331cb9 100644 --- a/src/finemo/visualization.py +++ b/src/finemo/visualization.py @@ -112,7 +112,10 @@ def plot_hit_stat_distributions( fig, ax = plt.subplots(figsize=(5, 2)) # Plot coefficient distribution - ax.hist(coefficients, bins=50, density=True) + try: + ax.hist(coefficients, bins=50, density=True) + except ValueError: + ax.hist(coefficients, bins=1, density=True) output_path_png = os.path.join(motifs_dir, f"{m}_coefficients.png") plt.savefig(output_path_png, dpi=300) @@ -123,7 +126,10 @@ def plot_hit_stat_distributions( fig, ax = plt.subplots(figsize=(5, 2)) # Plot similarity distribution - ax.hist(similarities, bins=50, density=True) + try: + ax.hist(similarities, bins=50, density=True) + except ValueError: + ax.hist(similarities, bins=1, density=True) output_path_png = os.path.join(motifs_dir, f"{m}_similarities.png") plt.savefig(output_path_png, dpi=300) @@ -134,7 +140,10 @@ def plot_hit_stat_distributions( fig, ax = plt.subplots(figsize=(5, 2)) # Plot importance distribution - ax.hist(importances, bins=50, density=True) + try: + ax.hist(importances, bins=50, density=True) + except ValueError: + ax.hist(importances, bins=1, density=True) output_path_png = os.path.join(motifs_dir, f"{m}_importances.png") plt.savefig(output_path_png, dpi=300) From 9df34e83255401acc571926cb56e639d5aa1071d Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 14 Sep 2025 16:59:33 -0700 Subject: [PATCH 38/39] README typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 72e25df..115d397 100644 
--- a/README.md +++ b/README.md @@ -304,4 +304,4 @@ Usage: `finemo intersect-hits -i -o [-r]` - `-i/--hits`: The path to one or more input hits file. This should be the `hits.tsv` or `hits_unique.tsv` file generated by the `finemo call-hits` command. - `-o/--out-path`: The path to the output file. Reoccuring columns are suffixed with the positional index of the input file (e.g. `hit_importance_1`), with the exception of index 0. -- `-r/--relaxed`: By default, the intersection assumes consistent input region definitions (name and coordinates) and motif trimming across runs. In contrast, this relaxed intersection criteria uses only motif names and untrimmed hit coordinates. However, this is not suitable when hit genomic coordinates are unknown (e.g. when using `finemo call-hits` with `-p/--peaks`). Default is False. +- `-r/--relaxed`: By default, the intersection assumes consistent input region definitions (name and coordinates) and motif trimming across runs. In contrast, this relaxed intersection criteria uses only motif names and untrimmed hit coordinates. However, this is not suitable when hit genomic coordinates are unknown. Default is False. From d5429df9237515f3af743cbce871e354fa308759 Mon Sep 17 00:00:00 2001 From: Austin Wang Date: Sun, 14 Sep 2025 17:01:37 -0700 Subject: [PATCH 39/39] Install torch from pip --- environment.yml | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/environment.yml b/environment.yml index 3720c11..0de69bf 100644 --- a/environment.yml +++ b/environment.yml @@ -1,12 +1,8 @@ channels: - - pytorch - - nvidia - conda-forge - bioconda - nodefaults dependencies: - - pytorch=2.5.1 - - pytorch-cuda=12.4 - python=3.11 - numba=0.61.2 - numpy=2.2.0 @@ -19,4 +15,7 @@ dependencies: - jinja2=3.1.4 - pybigwig=0.3.23 - pyfaidx=0.8.1.3 - - jaxtyping=0.3.2 \ No newline at end of file + - jaxtyping=0.3.2 + - pip=25.2 + - pip: + - torch==2.5.1 \ No newline at end of file