diff --git a/README.md b/README.md
index 286871b..9691046 100644
--- a/README.md
+++ b/README.md
@@ -221,13 +221,14 @@ The two output formats are tables of comma-separated values with a header.
| Start | Positive integer | Starting position of the feature (inclusive) |
| End | Positive integer | Ending position of the feature (inclusive) |
| Strand | `1` or `-1` | Whether the features is located on the positive (5'->3') or negative (3'->5') strand |
-| CoveredSites | Positive integer | Number of sites in the feature that satisfy the minimum level of coverage |
-| GenomeBases | Comma-separated positive integers | Frequencies of the bases in the feature in the reference genome (order: A, C, G, T) |
-| SiteBasePairings | Comma-separated positive integers | Number of sites in which each genome-variant base pairings is found in the feature (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) |
-| ReadBasePairings | Comma-separated positive integers | Frequencies of genome-variant base pairings in the feature (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) |
+| TotalSites | Positive integer | Number of sites in the feature |
+| ObservedBases | Comma-separated positive integers | Number of bases of each type (order: A, C, G, T) observed in the feature in the reference genome. The total of the 4 values corresponds to the total number of observed sites (as reported by the editing tool, e.g. Reditools3) |
+| QualifiedBases | Comma-separated positive integers | Number of bases of each type (order: A, C, G, T) in the feature in the reference genome that satisfy the minimum level of coverage and editing. The total of the 4 values corresponds to the total number of qualified sites (coverage above the minimum) |
+| SiteBasePairingsQualified | Comma-separated positive integers | Number of sites in which each genome-variant base pairing is found in the feature, at the reference level (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT), counting only sites that satisfy the minimum level of coverage and editing |
+| ReadBasePairingsQualified | Comma-separated positive integers | Number of reads supporting each genome-variant base pairing in the feature (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT), counted at sites that satisfy the minimum level of coverage and editing |
> [!note]
-> The number of **CoveredSites** can be higher than the sum of **SiteBasePairings** because of the presence of ambiguous bases (e.g. N)
+> The number of **QualifiedBases** can differ from the sum of AA, CC, GG and TT in **SiteBasePairingsQualified** because a site that is 100% edited does not fall into any of these categories.
An example of the feature output format is shown below, with some alterations to make the text line up in columns.
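Since the per-base and per-pairing counts are packed as comma-separated strings, downstream scripts have to split them on load before any arithmetic. A minimal sketch of unpacking one feature row (column names follow the table above; the count values are made-up examples):

```python
# Unpack the comma-separated count columns of one feature row.
# Base order A, C, G, T and pairing order AA..TT follow the table above;
# the numbers below are hypothetical example values.
BASES = ["A", "C", "G", "T"]
PAIRINGS = [g + v for g in BASES for v in BASES]  # AA, AC, AG, AT, CA, ..., TT

row = {
    "QualifiedBases": "10,5,8,7",
    "SiteBasePairingsQualified": ",".join(str(i) for i in range(16)),
}

qualified = dict(zip(BASES, (int(x) for x in row["QualifiedBases"].split(","))))
site_pairs = dict(zip(PAIRINGS, (int(x) for x in row["SiteBasePairingsQualified"].split(","))))

# e.g. the proportion of qualified A sites showing an A->G change
ag_prop = site_pairs["AG"] / qualified["A"] if qualified["A"] else float("nan")
print(qualified["A"], site_pairs["AG"], round(ag_prop, 3))  # prints: 10 2 0.2
```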
@@ -275,10 +276,11 @@ This hierarchical information is provided in the same manner in the aggregate fi
| ParentType | String | Type of the parent of the feature under which the aggregation was done |
| AggregateType | String | Type of the features that are aggregated |
| AggregationMode | `all_isoforms`, `longest_isoform`, `chimaera`, `feature` or `all-sites` | Way in which the aggregation was performed |
-| CoveredSites | Positive integer | Number of sites in the aggregated features that satisfy the minimum level of coverage |
-| GenomeBases | Comma-separated positive integers | Frequencies of the bases in the aggregated features in the reference genome (order: A, C, G, T) |
-| SiteBasePairings | Comma-separated positive integers | Number of sites in which each genome-variant base pairings is found in the aggregated features (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) |
-| ReadBasePairings | Comma-separated positive integers | Frequencies of genome-variant base pairings in the aggregated features (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) |
+| TotalSites | Positive integer | Number of sites in the aggregated features |
+| ObservedBases | Comma-separated positive integers | Number of bases of each type (order: A, C, G, T) observed in the aggregated features in the reference genome. The total of the 4 values corresponds to the total number of observed sites (as reported by the editing tool, e.g. Reditools3) |
+| QualifiedBases | Comma-separated positive integers | Number of bases of each type (order: A, C, G, T) in the aggregated features in the reference genome that satisfy the minimum level of coverage and editing. The total of the 4 values corresponds to the total number of qualified sites (coverage above the minimum) |
+| SiteBasePairingsQualified | Comma-separated positive integers | Number of sites in which each genome-variant base pairing is found in the aggregated features, at the reference level (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT), counting only sites that satisfy the minimum level of coverage and editing |
+| ReadBasePairingsQualified | Comma-separated positive integers | Number of reads supporting each genome-variant base pairing in the aggregated features (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT), counted at sites that satisfy the minimum level of coverage and editing |
In the output of Pluviometer, **aggregation** is the sum of counts from several features of the same type at some feature level. For instance, exons can be aggregated at transcript level, gene level, chromosome level, and genome level.
@@ -344,6 +346,25 @@
$$
AG\ editing\ level = \sum_{i=0}^{n} \dfrac{AG_i}{AA_i + AC_i + AG_i + AT_i}
$$
+
+## Drip
+
+### espf (edited sites proportion in feature)
+
+```python
+denom_espf = df[f'{genome_base}_count']        # X_QualifiedBases (e.g. C_count)
+df[espf_col] = df[f'{bp}_sites'] / denom_espf  # XY_SiteBasePairingsQualified / X_QualifiedBases
+```
+
+### espr (edited sites proportion in reads)
+
+```python
+df[total_reads_col] = XA_reads + XC_reads + XG_reads + XT_reads  # all reads at X positions
+df[espr_col] = df[f'{bp}_reads'] / df[total_reads_col]           # XY_reads / sum(X*_reads)
+```
+
+Drip retains a line only if at least one metric value is neither NA nor zero (i.e., at least one edit has been detected somewhere). Lines containing only NA values, only 0.0 values, or a mix of both are removed by default.
+
@@ -355,3 +376,7 @@ Jacques Dainat (@Juke34)
## Contributing
Contributions from the community are welcome ! See the [Contributing guidelines](https://github.com/Juke34/rain/blob/main/CONTRIBUTING.md)
+
+## TODO
+
+Update pluviometer to set `NA` for Start, End and Strand instead of `.`
to be able to use the columns as Int64 in drip and barometer, e.g. `dtype={"SeqID": str, "Start": "Int64", "End": "Int64", "Strand": str}`
\ No newline at end of file
diff --git a/bin/README b/bin/README
index d48d843..8caf0ac 100644
--- a/bin/README
+++ b/bin/README
@@ -20,20 +20,12 @@ python -m pluviometer --sites SITES --gff GFF [OPTIONS]
python pluviometer_wrapper.py --sites SITES --gff GFF [OPTIONS]
```

-### drip_features.py
+### drip.py
Post-processing tool for pluviometer feature output. Analyzes RNA editing from feature TSV files, calculating editing metrics (espf and espr) for all 16 genome-variant base pair combinations across multiple samples. Combines data into unified matrix format.

**Usage:**
```bash
-./drip_features.py OUTPUT_PREFIX FILE1:SAMPLE1 FILE2:SAMPLE2 [...]
-```
-
-### drip_aggregates.py
-Post-processing tool for pluviometer aggregate output. Similar to drip_features.py but operates on aggregate-level data, calculating editing metrics for aggregated genomic regions across samples.
-
-**Usage:**
-```bash
-./drip_aggregates.py OUTPUT_PREFIX FILE1:SAMPLE1 FILE2:SAMPLE2 [...]
+./drip.py OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REPLICATE1 FILE2:GROUP1:SAMPLE2:REPLICATE1 [...]
```

### restore_sequences.py
diff --git a/bin/barometer_analyze.py b/bin/barometer_analyze.py
new file mode 100755
index 0000000..2d4f5b2
--- /dev/null
+++ b/bin/barometer_analyze.py
@@ -0,0 +1,2297 @@
+#!/usr/bin/env python3
+"""
+barometer_analyze.py – Exhaustive analysis pipeline for rain biomarker data.
+
+Loads barometer_aggregates_AG.tsv and barometer_features_AG.tsv, performs QC,
+descriptive statistics, differential editing analysis, multivariate analysis,
+correlation/network analysis, feature selection / biomarker ranking,
+classification, stability analysis, and generates all tables + matplotlib figures.
+
+Results are saved into an output directory (default: barometer_results/) structured
+by value_type → mtype → section.
+ +Usage: + python barometer_analyze.py [--aggregates FILE] [--features FILE] [--outdir DIR] +""" + +import argparse +import gc +import json +import logging +import multiprocessing +import os +import sys +import time +import warnings +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from itertools import combinations +from pathlib import Path + +try: + import psutil +except ImportError: + psutil = None + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from scipy import stats +from scipy.cluster.hierarchy import linkage, dendrogram, fcluster +from scipy.spatial.distance import pdist +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier +from sklearn.model_selection import cross_val_score, StratifiedKFold, LeaveOneOut +from sklearn.preprocessing import StandardScaler, LabelEncoder +from sklearn.metrics import classification_report +import statsmodels.api as sm +from statsmodels.stats.multicomp import pairwise_tukeyhsd +from statsmodels.stats.multitest import multipletests + +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def safe_mkdir(path): + Path(path).mkdir(parents=True, exist_ok=True) + + +def prepare_df_for_task(df): + """Convert DataFrame to a memory-efficient format for task submission. + + Uses pickle which is much faster than .to_dict('list') for large DataFrames. 
+ Returns a tuple ('pickled', bytes_data) that can be unpickled in the worker. + """ + import pickle + return ('pickled', pickle.dumps(df, protocol=pickle.HIGHEST_PROTOCOL)) + + +def parse_sample_columns(columns): + """Parse sample column names like 'test1::rain_chr21_small::rep1::espf'. + + Returns a list of dicts with keys: col, group, sample, rep, value_type. + """ + parsed = [] + meta_cols = {"SeqID", "ParentIDs", "ID", "Mtype", "Ptype", "Type", "Ctype", "Mode", "Start", "End", "Strand"} + for c in columns: + if c in meta_cols: + continue + parts = c.split("::") + if len(parts) == 4: + parsed.append({ + "col": c, + "group": parts[0], + "sample": parts[1], + "rep": parts[2], + "value_type": parts[3], + }) + return parsed + + +def get_value_types(sample_info): + return sorted(set(s["value_type"] for s in sample_info)) + + +def cols_for_vtype(sample_info, vtype): + return [s["col"] for s in sample_info if s["value_type"] == vtype] + + +def group_for_col(sample_info, col): + for s in sample_info: + if s["col"] == col: + return s["group"] + return None + + +def sample_info_for_vtype(sample_info, vtype): + return [s for s in sample_info if s["value_type"] == vtype] + + +def numeric_df(df, cols): + """Return a copy with the given columns cast to numeric (coerce errors).""" + out = df.copy() + for c in cols: + out[c] = pd.to_numeric(out[c], errors="coerce") + return out + + +def save_fig(fig, path, dpi=150): + fig.savefig(path, dpi=dpi, bbox_inches="tight") + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Quality Control +# --------------------------------------------------------------------------- + +def qc_analysis(df, sample_cols, outdir): + """Basic quality control: missing values, distributions, outliers.""" + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # Missing value counts + log.info("Missing value counts") + missing = ndf[sample_cols].isnull().sum() + total = len(ndf) + 
missing_pct = (missing / total * 100).round(2) + miss_df = pd.DataFrame({"missing_count": missing, "missing_pct": missing_pct}) + miss_df.to_csv(os.path.join(outdir, "missing_values.csv")) + results["missing"] = miss_df.to_dict() + + # Distribution summary per sample + log.info("Distribution summary per sample") + desc = ndf[sample_cols].describe().T + desc.to_csv(os.path.join(outdir, "distribution_summary.csv")) + + # Box plot of all samples + log.info("Box plot of all samples") + fig, ax = plt.subplots(figsize=(max(6, len(sample_cols) * 0.8), 5)) + ndf[sample_cols].boxplot(ax=ax, rot=90) + ax.set_title("Distribution per sample") + ax.set_ylabel("Value") + save_fig(fig, os.path.join(outdir, "boxplot_samples.png")) + + # Heatmap of missing values (limit to avoid matplotlib memory errors with large datasets) + log.info("Heatmap of missing values") + max_heatmap_rows = 10000 # limit to prevent memory issues + heatmap_data = ndf[sample_cols].isnull().astype(int) + + if len(heatmap_data) > max_heatmap_rows: + # Sample: prioritize BMKs with most missing values + missing_counts = heatmap_data.sum(axis=1) + top_missing_idx = missing_counts.nlargest(max_heatmap_rows).index + heatmap_data = heatmap_data.loc[top_missing_idx] + log.info(f" Limiting missing values heatmap to {max_heatmap_rows} BMKs with most missing (from {len(ndf)} total)") + + fig_height = min(20, max(4, len(heatmap_data) * 0.02)) # Cap at 20 inches + try: + fig, ax = plt.subplots(figsize=(max(6, len(sample_cols) * 0.6), fig_height)) + sns.heatmap(heatmap_data, cbar=False, ax=ax, yticklabels=False) + ax.set_title(f"Missing values heatmap ({len(heatmap_data)} BMKs)") + save_fig(fig, os.path.join(outdir, "missing_heatmap.png")) + log.info(" Missing values heatmap saved") + except Exception as e: + log.warning(f" Missing values heatmap failed: {e}") + + results["desc"] = desc.to_dict() + return results + + +# --------------------------------------------------------------------------- +# Descriptive Statistics 
+# --------------------------------------------------------------------------- + +def descriptive_stats(df, sample_cols, sample_info, outdir): + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # Per-group statistics + groups = sorted(set(s["group"] for s in sample_info)) + group_stats = {} + for g in groups: + gcols = [s["col"] for s in sample_info if s["group"] == g] + vals = ndf[gcols].values.flatten() + vals = vals[~np.isnan(vals)] + if len(vals) == 0: + continue + group_stats[g] = { + "n": int(len(vals)), + "mean": float(np.mean(vals)), + "median": float(np.median(vals)), + "std": float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0, + "min": float(np.min(vals)), + "max": float(np.max(vals)), + "q25": float(np.percentile(vals, 25)), + "q75": float(np.percentile(vals, 75)), + } + gs_df = pd.DataFrame(group_stats).T + gs_df.to_csv(os.path.join(outdir, "group_statistics.csv")) + results["group_stats"] = group_stats + + # Per-BMK mean by group + bmk_means = {} + for g in groups: + gcols = [s["col"] for s in sample_info if s["group"] == g] + bmk_means[g] = ndf[gcols].mean(axis=1) + bmk_mean_df = pd.DataFrame(bmk_means, index=df.index) + if "ID" in df.columns: + bmk_mean_df.index = df["ID"] + bmk_mean_df.to_csv(os.path.join(outdir, "bmk_mean_by_group.csv")) + + # Violin plot by group (optimized with pd.melt) + sample_cols_list = [s["col"] for s in sample_info] + if sample_cols_list: + # Create group mapping for columns + col_to_group = {s["col"]: s["group"] for s in sample_info} + # Melt the dataframe efficiently + melted = ndf[sample_cols_list].melt(var_name="sample_col", value_name="value") + melted["group"] = melted["sample_col"].map(col_to_group) + melted = melted.dropna(subset=["value"]) + + if len(melted) > 0: + fig, ax = plt.subplots(figsize=(max(6, len(groups) * 2), 5)) + sns.violinplot(data=melted, x="group", y="value", ax=ax, inner="box") + ax.set_title("Value distribution by group") + save_fig(fig, os.path.join(outdir, 
"violin_by_group.png")) + + return results + + +# --------------------------------------------------------------------------- +# Differential Editing Analysis +# --------------------------------------------------------------------------- + +def test_normality_and_homogeneity(group_values): + """Test for normality (Shapiro-Wilk) and homoscedasticity (Bartlett). + + Returns: + dict with keys: is_normal, is_homogeneous, shapiro_pvals, bartlett_pval + """ + results = { + "is_normal": False, + "is_homogeneous": False, + "shapiro_pvals": [], + "bartlett_pval": None + } + + non_empty = [gv for gv in group_values.values() if len(gv) >= 3] # Shapiro needs n>=3 + if len(non_empty) < 2: + return results + + # Test normality per group (Shapiro-Wilk) + shapiro_pvals = [] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + for gv in non_empty: + if len(gv) >= 3: + try: + # Check for near-constant data before testing + if np.std(gv) < 1e-10: + continue # Skip constant data + _, p = stats.shapiro(gv) + shapiro_pvals.append(p) + except Exception: + pass + + results["shapiro_pvals"] = shapiro_pvals + results["is_normal"] = all(p > 0.05 for p in shapiro_pvals) if shapiro_pvals else False + + # Test homogeneity of variances (Bartlett) + if len(non_empty) >= 2: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + try: + # Check all groups have variance + if all(np.std(gv) > 1e-10 for gv in non_empty): + _, p = stats.bartlett(*non_empty) + results["bartlett_pval"] = p + results["is_homogeneous"] = p > 0.05 + except Exception: + pass + + return results + + +def differential_analysis(df, sample_cols, sample_info, outdir, stat_test="auto"): + """Pairwise and global differential tests between groups. 
+ + Args: + stat_test: 'auto', 'parametric', 'nonparametric', 'welch', 'kruskal' + - auto: test normality/homogeneity and choose best test + - parametric: Student t-test / ANOVA (assumes normality + equal variances) + - nonparametric: Mann-Whitney U / Kruskal-Wallis (no assumptions) + - welch: Welch t-test / Welch ANOVA (assumes normality, unequal variances OK) + - kruskal: alias for nonparametric (backward compatibility) + + Note: Expects df to already be pre-filtered for variable BMKs (done in analyze_section) + """ + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + groups = sorted(set(s["group"] for s in sample_info)) + results = {"pairwise": {}, "global": {}, "stat_test_used": stat_test} + + if len(groups) < 2: + log.warning("Less than 2 groups, skipping differential analysis.") + return results + + if len(ndf) == 0: + log.warning("No BMKs remaining for differential analysis") + return results + + # Normalize test names + if stat_test == "kruskal": + stat_test = "nonparametric" + + # OPTIMIZATION: Pre-compute group columns to avoid O(N_samples) per BMK per group + group_cols = {g: [s["col"] for s in sample_info if s["group"] == g] for g in groups} + + rows = [] + + # Suppress scipy warnings about numerical issues (we handle them with try/except) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Precision loss occurred") + warnings.filterwarnings("ignore", message="invalid value encountered") + warnings.filterwarnings("ignore", message="divide by zero encountered") + + for idx in ndf.index: + row_data = {"index": idx} + if "ID" in df.columns: + row_data["ID"] = df.loc[idx, "ID"] + + group_values = {} + for g in groups: + gcols = group_cols[g] # Use pre-computed dict + vals = ndf.loc[idx, gcols].dropna().values.astype(float) + group_values[g] = vals + row_data[f"mean_{g}"] = np.mean(vals) if len(vals) > 0 else np.nan + row_data[f"n_{g}"] = len(vals) + + non_empty = [gv for gv in group_values.values() if len(gv) > 0] + + # 
Determine which test to use + test_choice = stat_test + if stat_test == "auto" and len(non_empty) >= 2: + # Test assumptions + assumptions = test_normality_and_homogeneity(group_values) + row_data["shapiro_pvals"] = ",".join([f"{p:.4f}" for p in assumptions["shapiro_pvals"]]) if assumptions["shapiro_pvals"] else "" + row_data["bartlett_pval"] = assumptions["bartlett_pval"] + + # Choose test based on assumptions + if assumptions["is_normal"] and assumptions["is_homogeneous"]: + test_choice = "parametric" + elif assumptions["is_normal"] and not assumptions["is_homogeneous"]: + test_choice = "welch" + else: + test_choice = "nonparametric" + + row_data["test_selected"] = test_choice + + # Global tests (comparing all groups) + if len(non_empty) >= 2 and all(len(v) >= 1 for v in non_empty): + # Kruskal-Wallis (always compute for backward compatibility) + try: + stat, pval = stats.kruskal(*non_empty) + row_data["kruskal_stat"] = stat + row_data["kruskal_pval"] = pval + except Exception: + row_data["kruskal_stat"] = np.nan + row_data["kruskal_pval"] = np.nan + + # ANOVA (parametric) + valid_for_anova = [gv for gv in group_values.values() if len(gv) >= 2] + if len(valid_for_anova) >= 2: + try: + stat, pval = stats.f_oneway(*valid_for_anova) + row_data["anova_stat"] = stat + row_data["anova_pval"] = pval + except Exception: + row_data["anova_stat"] = np.nan + row_data["anova_pval"] = np.nan + else: + row_data["anova_stat"] = np.nan + row_data["anova_pval"] = np.nan + + # Welch ANOVA (parametric with unequal variances) + if len(valid_for_anova) >= 2: + try: + # Welch ANOVA using scipy (one-way test with equal_var=False not available directly) + # We'll use a workaround: for 2 groups use Welch t-test, for 3+ use Kruskal as fallback + if len(groups) == 2: + g1_vals = group_values[groups[0]] + g2_vals = group_values[groups[1]] + if len(g1_vals) >= 2 and len(g2_vals) >= 2: + stat, pval = stats.ttest_ind(g1_vals, g2_vals, equal_var=False) + row_data["welch_stat"] = stat + 
row_data["welch_pval"] = pval + else: + row_data["welch_stat"] = np.nan + row_data["welch_pval"] = np.nan + else: + # For 3+ groups, Welch ANOVA is complex; use oneway with unequal var assumption + # scipy doesn't have direct Welch ANOVA, so we use Kruskal as robust alternative + row_data["welch_stat"] = row_data.get("kruskal_stat", np.nan) + row_data["welch_pval"] = row_data.get("kruskal_pval", np.nan) + except Exception: + row_data["welch_stat"] = np.nan + row_data["welch_pval"] = np.nan + else: + row_data["welch_stat"] = np.nan + row_data["welch_pval"] = np.nan + + # Determine primary test based on choice + if test_choice == "parametric": + row_data["primary_stat"] = row_data.get("anova_stat", np.nan) + row_data["primary_pval"] = row_data.get("anova_pval", np.nan) + elif test_choice == "welch": + row_data["primary_stat"] = row_data.get("welch_stat", np.nan) + row_data["primary_pval"] = row_data.get("welch_pval", np.nan) + else: # nonparametric (default) + row_data["primary_stat"] = row_data.get("kruskal_stat", np.nan) + row_data["primary_pval"] = row_data.get("kruskal_pval", np.nan) + else: + row_data["kruskal_stat"] = np.nan + row_data["kruskal_pval"] = np.nan + row_data["anova_stat"] = np.nan + row_data["anova_pval"] = np.nan + row_data["welch_stat"] = np.nan + row_data["welch_pval"] = np.nan + row_data["primary_stat"] = np.nan + row_data["primary_pval"] = np.nan + + # Pairwise tests + for g1, g2 in combinations(groups, 2): + v1, v2 = group_values.get(g1, []), group_values.get(g2, []) + pair_key = f"{g1}_vs_{g2}" + + if len(v1) >= 1 and len(v2) >= 1: + # Mann-Whitney U (nonparametric, always compute) + try: + stat, pval = stats.mannwhitneyu(v1, v2, alternative="two-sided") + row_data[f"mwu_stat_{pair_key}"] = stat + row_data[f"mwu_pval_{pair_key}"] = pval + except Exception: + row_data[f"mwu_stat_{pair_key}"] = np.nan + row_data[f"mwu_pval_{pair_key}"] = np.nan + + # Student t-test (parametric, equal variances) + if len(v1) >= 2 and len(v2) >= 2: + try: + 
stat, pval = stats.ttest_ind(v1, v2, equal_var=True)
+                            row_data[f"student_stat_{pair_key}"] = stat
+                            row_data[f"student_pval_{pair_key}"] = pval
+                        except Exception:
+                            row_data[f"student_stat_{pair_key}"] = np.nan
+                            row_data[f"student_pval_{pair_key}"] = np.nan
+                    else:
+                        row_data[f"student_stat_{pair_key}"] = np.nan
+                        row_data[f"student_pval_{pair_key}"] = np.nan
+
+                    # Welch t-test (parametric, unequal variances)
+                    if len(v1) >= 2 and len(v2) >= 2:
+                        try:
+                            stat, pval = stats.ttest_ind(v1, v2, equal_var=False)
+                            row_data[f"welch_stat_{pair_key}"] = stat
+                            row_data[f"welch_pval_{pair_key}"] = pval
+                        except Exception:
+                            row_data[f"welch_stat_{pair_key}"] = np.nan
+                            row_data[f"welch_pval_{pair_key}"] = np.nan
+                    else:
+                        row_data[f"welch_stat_{pair_key}"] = np.nan
+                        row_data[f"welch_pval_{pair_key}"] = np.nan
+
+                    # Effect size (rank-biserial for Mann-Whitney)
+                    n1, n2 = len(v1), len(v2)
+                    if n1 * n2 > 0 and not np.isnan(row_data.get(f"mwu_stat_{pair_key}", np.nan)):
+                        row_data[f"effect_size_{pair_key}"] = 1 - 2 * row_data[f"mwu_stat_{pair_key}"] / (n1 * n2)
+
+                    # Raw difference and log2 fold change of means.
+                    # Compute the difference even when a mean is zero, so that
+                    # volcano plots keep BMKs fully unedited in one group;
+                    # log2fc is only defined for strictly positive means.
+                    m1, m2 = np.mean(v1), np.mean(v2)
+                    row_data[f"diff_{pair_key}"] = m1 - m2
+                    if m1 > 0 and m2 > 0:
+                        row_data[f"log2fc_{pair_key}"] = np.log2(m1 / m2)
+                else:
+                    row_data[f"mwu_stat_{pair_key}"] = np.nan
+                    row_data[f"mwu_pval_{pair_key}"] = np.nan
+                    row_data[f"student_stat_{pair_key}"] = np.nan
+                    row_data[f"student_pval_{pair_key}"] = np.nan
+                    row_data[f"welch_stat_{pair_key}"] = np.nan
+                    row_data[f"welch_pval_{pair_key}"] = np.nan
+
+            rows.append(row_data)
+
+    # Convert rows to DataFrame
+    res_df = pd.DataFrame(rows)
+
+    # Multiple testing correction (FDR Benjamini-Hochberg) for ALL p-value columns
+    for col in res_df.columns:
+        if col.endswith("_pval"):
+            pvals = res_df[col].values
+            mask = ~np.isnan(pvals)
+            if mask.sum() > 0:
+                _, corrected, _, _ = multipletests(pvals[mask], method="fdr_bh")
+                adj_col = col.replace("_pval", "_padj")
+
res_df[adj_col] = np.nan + res_df.loc[mask, adj_col] = corrected + + res_df.to_csv(os.path.join(outdir, "differential_results.csv"), index=False) + results["table"] = os.path.join(outdir, "differential_results.csv") + results["stat_test_method"] = stat_test + + # Volcano-like plot for each pairwise comparison (using primary test) + for g1, g2 in combinations(groups, 2): + pair_key = f"{g1}_vs_{g2}" + diff_col = f"diff_{pair_key}" + + # Choose p-value column based on test method + if stat_test == "parametric": + pval_col = f"student_pval_{pair_key}" + padj_col = f"student_padj_{pair_key}" + test_label = "Student" + elif stat_test == "welch": + pval_col = f"welch_pval_{pair_key}" + padj_col = f"welch_padj_{pair_key}" + test_label = "Welch" + else: # nonparametric or auto (default to nonparametric) + pval_col = f"mwu_pval_{pair_key}" + padj_col = f"mwu_padj_{pair_key}" + test_label = "Mann-Whitney U" + if diff_col in res_df.columns and padj_col in res_df.columns: + pdf = res_df[[diff_col, padj_col]].dropna() + if len(pdf) > 0: + fig, ax = plt.subplots(figsize=(7, 5)) + neg_log_p = -np.log10(pdf[padj_col].clip(lower=1e-300)) + colors = ["red" if p < 0.05 else "grey" for p in pdf[padj_col]] + ax.scatter(pdf[diff_col], neg_log_p, c=colors, alpha=0.6, s=20) + ax.axhline(-np.log10(0.05), color="blue", linestyle="--", alpha=0.5) + ax.set_xlabel(f"Difference ({g1} - {g2})") + ax.set_ylabel("-log10(adjusted p-value)") + ax.set_title(f"Volcano plot: {g1} vs {g2} ({test_label})") + save_fig(fig, os.path.join(outdir, f"volcano_{pair_key}.png")) + + return results + + +# --------------------------------------------------------------------------- +# Multivariate Analysis (PCA, hierarchical clustering) +# --------------------------------------------------------------------------- + +def multivariate_analysis(df, sample_cols, sample_info, outdir): + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # Transpose: samples as rows, BMKs as columns + mat = 
ndf[sample_cols].T.copy() + mat.columns = df["ID"].values if "ID" in df.columns else range(len(df)) + mat = mat.dropna(axis=1, how="all").fillna(0) + + if mat.shape[1] < 2 or mat.shape[0] < 2: + log.warning("Not enough data for multivariate analysis") + return results + + # PCA + scaler = StandardScaler() + scaled = scaler.fit_transform(mat) + n_components = min(mat.shape[0], mat.shape[1], 10) + pca = PCA(n_components=n_components) + coords = pca.fit_transform(scaled) + var_exp = pca.explained_variance_ratio_ + + labels = [s["group"] for s in sample_info if s["col"] in mat.index] + sample_labels = [s["sample"] for s in sample_info if s["col"] in mat.index] + + pca_df = pd.DataFrame(coords[:, :min(5, n_components)], + columns=[f"PC{i+1}" for i in range(min(5, n_components))], + index=mat.index) + pca_df["group"] = labels + pca_df["sample"] = sample_labels + pca_df.to_csv(os.path.join(outdir, "pca_coordinates.csv")) + + # Variance explained + var_df = pd.DataFrame({"PC": [f"PC{i+1}" for i in range(len(var_exp))], + "variance_explained": var_exp, + "cumulative": np.cumsum(var_exp)}) + var_df.to_csv(os.path.join(outdir, "pca_variance.csv"), index=False) + + # Scree plot + fig, ax = plt.subplots(figsize=(6, 4)) + ax.bar(range(1, len(var_exp) + 1), var_exp * 100, alpha=0.7, label="Individual") + ax.plot(range(1, len(var_exp) + 1), np.cumsum(var_exp) * 100, "ro-", label="Cumulative") + ax.set_xlabel("Principal Component") + ax.set_ylabel("Variance Explained (%)") + ax.set_title("PCA Scree Plot") + ax.legend() + save_fig(fig, os.path.join(outdir, "pca_scree.png")) + + # PCA biplot PC1 vs PC2 + if n_components >= 2: + fig, ax = plt.subplots(figsize=(8, 6)) + unique_groups = sorted(set(labels)) + colors = plt.cm.Set1(np.linspace(0, 1, max(len(unique_groups), 1))) + for i, g in enumerate(unique_groups): + mask = [l == g for l in labels] + ax.scatter(coords[mask, 0], coords[mask, 1], c=[colors[i]], label=g, s=80, alpha=0.8) + for j, m in enumerate(mask): + if m: + 
ax.annotate(sample_labels[j], (coords[j, 0], coords[j, 1]), + fontsize=7, alpha=0.7) + ax.set_xlabel(f"PC1 ({var_exp[0]*100:.1f}%)") + ax.set_ylabel(f"PC2 ({var_exp[1]*100:.1f}%)") + ax.set_title("PCA - Samples") + ax.legend() + save_fig(fig, os.path.join(outdir, "pca_biplot.png")) + + # Hierarchical clustering on samples + if mat.shape[0] >= 2: + try: + dist = pdist(scaled, metric="euclidean") + Z = linkage(dist, method="ward") + fig, ax = plt.subplots(figsize=(max(6, len(sample_cols) * 0.8), 5)) + dendrogram(Z, labels=[s.split("::")[1] for s in mat.index], ax=ax, leaf_rotation=90) + ax.set_title("Hierarchical Clustering of Samples") + save_fig(fig, os.path.join(outdir, "dendrogram_samples.png")) + except Exception as e: + log.warning(f"Dendrogram failed: {e}") + + # Sample correlation heatmap + corr = mat.T.corr() + corr.to_csv(os.path.join(outdir, "sample_correlation.csv")) + fig, ax = plt.subplots(figsize=(max(6, len(sample_cols) * 0.8), max(5, len(sample_cols) * 0.6))) + short_labels = [s.split("::")[0] + "::" + s.split("::")[1] for s in corr.index] + sns.heatmap(corr, annot=True, fmt=".2f", cmap="RdBu_r", center=0, ax=ax, + xticklabels=short_labels, yticklabels=short_labels) + ax.set_title("Sample Correlation Matrix") + save_fig(fig, os.path.join(outdir, "correlation_heatmap.png")) + + results["pca_variance"] = var_df.to_dict(orient="records") + return results + + +# --------------------------------------------------------------------------- +# Correlation / Network Analysis +# --------------------------------------------------------------------------- + +def correlation_network(df, sample_cols, outdir, max_bmks=50): + """Compute BMK-BMK correlation matrix and heatmap. + + MEMORY FIX: Reduced max_bmks from 200 to 50 to prevent OOM. 
+ - Correlation matrix: N×N float64 = N² × 8 bytes + - With N=200: 320 KB per section × 1491 sections = 477 MB accumulated + - With N=50: 20 KB per section × 1491 sections = 30 MB accumulated (16x reduction) + - Heatmap figure with N=50: ~4 MB instead of 64 MB per image + """ + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # BMK-BMK correlation (limit to top variable BMKs) + mat = ndf[sample_cols].copy() + mat.index = df["ID"].values if "ID" in df.columns else range(len(df)) + mat = mat.dropna(axis=0, how="all") + + # Early exit if too few BMKs + if len(mat) < 3: + log.info(f" Skipping correlation analysis: only {len(mat)} BMKs (need at least 3)") + return results + + var = mat.var(axis=1) + top_idx = var.nlargest(min(max_bmks, len(var))).index + sub = mat.loc[top_idx] + + if len(sub) >= 3: + # Compute correlation matrix (memory intensive: N×N) + corr = sub.T.corr() + corr.to_csv(os.path.join(outdir, "bmk_correlation.csv")) + + # Limit figure size to prevent excessive memory usage + fig_width = min(15, max(6, len(sub) * 0.15)) # Cap at 15 inches + fig_height = min(12, max(5, len(sub) * 0.12)) # Cap at 12 inches + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + sns.heatmap(corr, cmap="RdBu_r", center=0, ax=ax, + xticklabels=len(sub) < 50, yticklabels=len(sub) < 50) + ax.set_title(f"BMK Correlation (top {len(sub)} by variance)") + save_fig(fig, os.path.join(outdir, "bmk_correlation_heatmap.png")) + results["n_bmks_corr"] = len(sub) + + return results + + +# --------------------------------------------------------------------------- +# Feature Selection / Biomarker Ranking +# --------------------------------------------------------------------------- + +def feature_ranking(df, sample_cols, sample_info, outdir, max_bmks=500, bmk_filter_cols=None): + """Feature ranking using RandomForest importance. 
+
+    Args:
+        bmk_filter_cols: list of column names to try in order for filtering significant BMKs
+    """
+    safe_mkdir(outdir)
+    ndf = numeric_df(df, sample_cols)
+    groups = sorted(set(s["group"] for s in sample_info))
+    results = {}
+
+    if len(groups) < 2:
+        return results
+
+    # Build matrix: rows = samples, cols = BMKs
+    mat = ndf[sample_cols].T.copy()
+    bmk_ids = df["ID"].values if "ID" in df.columns else [str(i) for i in range(len(df))]
+    mat.columns = bmk_ids
+    mat = mat.fillna(0)
+
+    labels = [group_for_col(sample_info, c) for c in sample_cols]
+    le = LabelEncoder()
+    y = le.fit_transform(labels)
+
+    # Default filter column priority if not specified
+    if bmk_filter_cols is None:
+        bmk_filter_cols = ["primary_padj", "kruskal_padj", "welch_padj", "anova_padj"]
+
+    # Filter BMKs: cascade through filter columns to collect significant ones
+    # PASS 1: adjusted p-values (*_padj < 0.05)
+    # PASS 2: raw p-values (*_pval < 0.05) if not enough
+    diff_file = os.path.join(os.path.dirname(outdir), "differential", "differential_results.csv")
+    selected_bmks = set()
+    n_bmks_total = len(mat.columns)
+    filter_log = []
+
+    if os.path.exists(diff_file):
+        diff_df = pd.read_csv(diff_file)
+
+        # PASS 1: Cascade through adjusted p-values (strong evidence)
+        for filter_col in bmk_filter_cols:
+            if len(selected_bmks) >= max_bmks:
+                break
+
+            if filter_col in diff_df.columns and "ID" in diff_df.columns:
+                sig_bmks = diff_df[diff_df[filter_col] < 0.05]["ID"].values
+                valid_sig_bmks = [b for b in sig_bmks if b in mat.columns and b not in selected_bmks]
+                remaining = max_bmks - len(selected_bmks)
+                to_add = valid_sig_bmks[:remaining]
+
+                if to_add:
+                    selected_bmks.update(to_add)
+                    filter_log.append(f"{filter_col}: +{len(to_add)}")
+
+        # PASS 2: If not enough, cascade through raw p-values (suggestive evidence)
+        if len(selected_bmks) < max_bmks:
+            filter_log.append("|")
+            for filter_col in bmk_filter_cols:
+                if len(selected_bmks) >= max_bmks:
+                    break
+
+                # Convert *_padj to *_pval
+                pval_col = filter_col.replace("_padj", "_pval")
+                if pval_col != filter_col and pval_col in diff_df.columns and "ID" in diff_df.columns:
+                    sig_bmks = diff_df[diff_df[pval_col] < 0.05]["ID"].values
+                    valid_sig_bmks = [b for b in sig_bmks if b in mat.columns and b not in selected_bmks]
+                    remaining = max_bmks - len(selected_bmks)
+                    to_add = valid_sig_bmks[:remaining]
+
+                    if to_add:
+                        selected_bmks.update(to_add)
+                        filter_log.append(f"{pval_col}: +{len(to_add)}")
+
+    if selected_bmks:
+        selected_bmks = list(selected_bmks)
+        log.info(f"    Collected {len(selected_bmks)} significant BMKs from {n_bmks_total} total for ranking")
+        log.info(f"    Cascade: {' → '.join(filter_log)}")
+    else:
+        selected_bmks = None
+
+    # If not enough significant BMKs, complete with top variable ones
+    if selected_bmks is None or len(selected_bmks) < 10:
+        if selected_bmks is None:
+            selected_bmks = []
+        var_scores = mat.var(axis=0)
+        # Exclude already selected
+        var_scores = var_scores[[b for b in var_scores.index if b not in selected_bmks]]
+        remaining = max_bmks - len(selected_bmks)
+        top_var = var_scores.nlargest(min(remaining, len(var_scores))).index.tolist()
+        selected_bmks.extend(top_var)
+        log.info(f"    Completed with {len(top_var)} top variable BMKs (total: {len(selected_bmks)} from {n_bmks_total})")
+
+    # Filter matrix to selected BMKs
+    mat_filtered = mat[selected_bmks].copy()
+
+    # Random Forest importance
+    if len(set(y)) >= 2 and mat_filtered.shape[0] >= 4 and mat_filtered.shape[1] >= 2:
+        try:
+            rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=1, max_depth=10)
+            rf.fit(mat_filtered, y)
+            importances = rf.feature_importances_
+            imp_df = pd.DataFrame({"bmk": mat_filtered.columns, "importance": importances})
+            imp_df = imp_df.sort_values("importance", ascending=False)
+            imp_df.to_csv(os.path.join(outdir, "rf_importance.csv"), index=False)
+            results["rf_top10"] = imp_df.head(10).to_dict(orient="records")
+
+            # Plot top 30 (or fewer if not enough BMKs)
+            top_n = min(30, len(imp_df))
+            top = imp_df.head(top_n)
+            fig, ax = plt.subplots(figsize=(8, max(4, top_n * 0.3)))
+            ax.barh(range(top_n), top["importance"].values[::-1])
+            ax.set_yticks(range(top_n))
+            ax.set_yticklabels(top["bmk"].values[::-1], fontsize=7)
+            ax.set_xlabel("Importance")
+            ax.set_title(f"Random Forest Feature Importance (top {top_n} from {len(selected_bmks)} BMKs)")
+            save_fig(fig, os.path.join(outdir, "rf_importance.png"))
+        except Exception as e:
+            log.warning(f"RF importance failed: {e}")
+
+    # Variance-based ranking
+    var_scores = ndf[sample_cols].var(axis=1)
+    var_df = pd.DataFrame({"bmk": bmk_ids, "variance": var_scores.values})
+    var_df = var_df.sort_values("variance", ascending=False)
+    var_df.to_csv(os.path.join(outdir, "variance_ranking.csv"), index=False)
+
+    # Kruskal-Wallis based ranking (from differential analysis if available)
+    diff_file = os.path.join(os.path.dirname(outdir), "differential", "differential_results.csv")
+    if os.path.exists(diff_file):
+        diff_df = pd.read_csv(diff_file)
+        if "kruskal_padj" in diff_df.columns:
+            rank_df = diff_df[["ID", "kruskal_padj"]].dropna().sort_values("kruskal_padj")
+            rank_df.to_csv(os.path.join(outdir, "kruskal_ranking.csv"), index=False)
+            results["kruskal_top10"] = rank_df.head(10).to_dict(orient="records")
+
+    return results
+
+
+# ---------------------------------------------------------------------------
+# Classification / Predictive Modeling
+# ---------------------------------------------------------------------------
+
+def classification_analysis(df, sample_cols, sample_info, outdir, max_bmks=500, bmk_filter_cols=None):
+    """Classification analysis with multiple classifiers.
+
+    Args:
+        bmk_filter_cols: list of column names to try in order for filtering significant BMKs
+    """
+    safe_mkdir(outdir)
+    ndf = numeric_df(df, sample_cols)
+    groups = sorted(set(s["group"] for s in sample_info))
+    results = {}
+
+    if len(groups) < 2:
+        return results
+
+    mat = ndf[sample_cols].T.fillna(0).copy()
+    bmk_ids = df["ID"].values if "ID" in df.columns else [str(i) for i in range(len(df))]
+    mat.columns = bmk_ids
+    labels = [group_for_col(sample_info, c) for c in sample_cols]
+    le = LabelEncoder()
+    y = le.fit_transform(labels)
+
+    n_samples = len(y)
+    n_classes = len(set(y))
+    n_bmks_total = mat.shape[1]
+
+    if n_samples < 4 or n_classes < 2:
+        log.warning("Not enough samples for classification")
+        return results
+
+    # Default filter column priority if not specified
+    if bmk_filter_cols is None:
+        bmk_filter_cols = ["primary_padj", "kruskal_padj", "welch_padj", "anova_padj"]
+
+    # Filter BMKs: cascade through filter columns to collect significant ones
+    diff_file = os.path.join(os.path.dirname(outdir), "differential", "differential_results.csv")
+    selected_bmks = set()
+    filter_log = []
+
+    if os.path.exists(diff_file):
+        diff_df = pd.read_csv(diff_file)
+
+        # Cascade through filter columns in priority order
+        for filter_col in bmk_filter_cols:
+            if len(selected_bmks) >= max_bmks:
+                break  # Stop if we have enough
+
+            if filter_col in diff_df.columns and "ID" in diff_df.columns:
+                # Get significant BMKs from this column
+                sig_bmks = diff_df[diff_df[filter_col] < 0.05]["ID"].values
+                valid_sig_bmks = [b for b in sig_bmks if b in mat.columns and b not in selected_bmks]
+
+                # Add up to remaining slots
+                remaining = max_bmks - len(selected_bmks)
+                to_add = valid_sig_bmks[:remaining]
+
+                if to_add:
+                    selected_bmks.update(to_add)
+                    filter_log.append(f"{filter_col}: +{len(to_add)} BMKs")
+
+    if selected_bmks:
+        selected_bmks = list(selected_bmks)
+        log.info(f"    Collected {len(selected_bmks)} significant BMKs from {n_bmks_total} total for classification")
+        log.info(f"    Cascade: {' → '.join(filter_log)}")
+    else:
+        selected_bmks = None
+
+    # If not enough significant BMKs, complete with top variable ones
+    if selected_bmks is None or len(selected_bmks) < 10:
+        if selected_bmks is None:
+            selected_bmks = []
+        var_scores = mat.var(axis=0)
+        # Exclude already selected
+        var_scores = var_scores[[b for b in var_scores.index if b not in selected_bmks]]
+        remaining = max_bmks - len(selected_bmks)
+        top_var = var_scores.nlargest(min(remaining, len(var_scores))).index.tolist()
+        selected_bmks.extend(top_var)
+        log.info(f"    Completed with {len(top_var)} top variable BMKs (total: {len(selected_bmks)} from {n_bmks_total})")
+
+    # Filter matrix to selected BMKs
+    mat = mat[selected_bmks].copy()
+    n_bmks = mat.shape[1]
+
+    if n_bmks < max(2, n_classes):
+        log.warning(f"Not enough biomarkers for classification ({n_bmks} BMKs after filtering, need at least {max(2, n_classes)})")
+        return results
+
+    # Use LOO or stratified k-fold depending on sample count
+    if n_samples < 10:
+        cv = LeaveOneOut()
+        cv_name = "LOO"
+    else:
+        cv = StratifiedKFold(n_splits=min(5, n_samples), shuffle=True, random_state=42)
+        cv_name = "StratifiedKFold"
+
+    classifiers = {}
+    classifiers["RandomForest"] = RandomForestClassifier(n_estimators=50, random_state=42, n_jobs=1, max_depth=10)
+
+    if n_classes == 2:
+        classifiers["GradientBoosting"] = GradientBoostingClassifier(n_estimators=50, random_state=42)
+
+    # LDA only if feasible
+    if n_samples > n_classes and mat.shape[1] > 0:
+        try:
+            classifiers["LDA"] = LinearDiscriminantAnalysis()
+        except Exception:
+            pass
+
+    clf_results = {}
+    for name, clf in classifiers.items():
+        try:
+            scores = cross_val_score(clf, mat, y, cv=cv, scoring="accuracy")
+            clf_results[name] = {
+                "mean_accuracy": float(np.mean(scores)),
+                "std_accuracy": float(np.std(scores)),
+                "cv_method": cv_name,
+                "scores": scores.tolist(),
+                "n_bmks_used": n_bmks,
+            }
+        except Exception as e:
+            log.warning(f"Classifier {name} failed: {e}")
+
+    results["classifiers"] = clf_results
+    clf_df = pd.DataFrame([
+        {"classifier": k, "mean_accuracy": v["mean_accuracy"],
+         "std_accuracy": v["std_accuracy"], "cv_method": v["cv_method"],
+         "n_bmks_used": v.get("n_bmks_used", n_bmks)}
+        for k, v in clf_results.items()
+    ])
+    clf_df.to_csv(os.path.join(outdir, "classification_results.csv"), index=False)
+
+    # Plot
+    if clf_results:
+        fig, ax = plt.subplots(figsize=(6, 4))
+        names = list(clf_results.keys())
+        means = [clf_results[n]["mean_accuracy"] for n in names]
+        stds = [clf_results[n]["std_accuracy"] for n in names]
+        ax.bar(names, means, yerr=stds, alpha=0.7, capsize=5)
+        ax.set_ylabel("Accuracy")
+        ax.set_title(f"Classification Accuracy ({cv_name}, {n_bmks} BMKs)")
+        ax.set_ylim(0, 1.1)
+        save_fig(fig, os.path.join(outdir, "classification_accuracy.png"))
+
+    return results
+
+
+# ---------------------------------------------------------------------------
+# Stability / Robustness (replicate concordance)
+# ---------------------------------------------------------------------------
+
+def stability_analysis(df, sample_cols, sample_info, outdir):
+    safe_mkdir(outdir)
+    ndf = numeric_df(df, sample_cols)
+    results = {}
+
+    # Group by (group, sample) and check replicate correlation
+    rep_groups = defaultdict(list)
+    for s in sample_info:
+        key = (s["group"], s["sample"])
+        rep_groups[key].append(s["col"])
+
+    rep_corrs = []
+    for key, cols in rep_groups.items():
+        if len(cols) >= 2:
+            for c1, c2 in combinations(cols, 2):
+                v1 = pd.to_numeric(ndf[c1], errors="coerce")
+                v2 = pd.to_numeric(ndf[c2], errors="coerce")
+                mask = v1.notna() & v2.notna()
+                if mask.sum() >= 3:
+                    r, p = stats.pearsonr(v1[mask], v2[mask])
+                    rep_corrs.append({
+                        "group": key[0], "sample": key[1],
+                        "rep1": c1, "rep2": c2,
+                        "pearson_r": r, "pval": p
+                    })
+
+    if rep_corrs:
+        rc_df = pd.DataFrame(rep_corrs)
+        rc_df.to_csv(os.path.join(outdir, "replicate_correlation.csv"), index=False)
+        results["replicate_correlations"] = rc_df.to_dict(orient="records")
+
+    # Coefficient of variation per BMK across all samples
+    cv_vals = ndf[sample_cols].apply(lambda row: row.std() / row.mean() if row.mean() != 0 else np.nan, axis=1)
+    cv_df = pd.DataFrame({"bmk": df["ID"].values if "ID" in df.columns else range(len(df)),
+                          "cv": cv_vals.values})
+    cv_df = cv_df.sort_values("cv")
+    cv_df.to_csv(os.path.join(outdir, "coefficient_variation.csv"), index=False)
+
+    # CV histogram
+    fig, ax = plt.subplots(figsize=(6, 4))
+    ax.hist(cv_df["cv"].dropna(), bins=30, alpha=0.7, edgecolor="black")
+    ax.set_xlabel("Coefficient of Variation")
+    ax.set_ylabel("Count")
+    ax.set_title("Distribution of CV across BMKs")
+    save_fig(fig, os.path.join(outdir, "cv_histogram.png"))
+
+    return results
+
+
+# ---------------------------------------------------------------------------
+# Batch Effect Detection
+# ---------------------------------------------------------------------------
+
+def batch_effect_analysis(df, sample_cols, sample_info, outdir):
+    safe_mkdir(outdir)
+    ndf = numeric_df(df, sample_cols)
+    results = {}
+
+    # Check if samples cluster by batch (sample) rather than group
+    # Use PCA and check if samples separate by sample-id vs group
+    mat = ndf[sample_cols].T.fillna(0)
+    if mat.shape[0] < 3 or mat.shape[1] < 2:
+        return results
+
+    scaler = StandardScaler()
+    scaled = scaler.fit_transform(mat)
+    pca = PCA(n_components=min(3, mat.shape[0], mat.shape[1]))
+    coords = pca.fit_transform(scaled)
+
+    groups = [group_for_col(sample_info, c) for c in sample_cols]
+    samples = [s["sample"] for s in sample_info if s["col"] in sample_cols[:len(groups)]]
+
+    # Plot colored by sample (batch)
+    fig, axes = plt.subplots(1, 2, figsize=(14, 5))
+    unique_groups = sorted(set(groups))
+    colors_g = plt.cm.Set1(np.linspace(0, 1, max(len(unique_groups), 1)))
+    for i, g in enumerate(unique_groups):
+        mask = [l == g for l in groups]
+        axes[0].scatter(coords[mask, 0], coords[mask, 1], c=[colors_g[i]], label=g, s=80)
+    axes[0].set_title("PCA colored by Group")
+    axes[0].set_xlabel("PC1")
+    axes[0].set_ylabel("PC2")
+    axes[0].legend()
+
+    unique_samples = sorted(set(samples))
+    colors_s = plt.cm.Set2(np.linspace(0, 1, max(len(unique_samples), 1)))
+    for i, s in enumerate(unique_samples):
+        mask = [l == s for l in samples]
+        axes[1].scatter(coords[mask, 0], coords[mask, 1], c=[colors_s[i]], label=s, s=80)
+    axes[1].set_title("PCA colored by Sample (batch)")
+    axes[1].set_xlabel("PC1")
+    axes[1].set_ylabel("PC2")
+    axes[1].legend(fontsize=7)
+
+    save_fig(fig, os.path.join(outdir, "batch_effect_pca.png"))
+    return results
+
+
+# ---------------------------------------------------------------------------
+# Comprehensive heatmap for a section
+# ---------------------------------------------------------------------------
+
+def section_heatmap(df, sample_cols, sample_info, outdir, title="", max_rows=100):
+    safe_mkdir(outdir)
+    ndf = numeric_df(df, sample_cols)
+
+    # Sort sample_cols by group, then rep
+    sample_order = sorted(sample_info, key=lambda s: (s["group"], s["rep"], s["sample"]))
+    sorted_cols = [s["col"] for s in sample_order if s["col"] in sample_cols]
+
+    mat = ndf[sorted_cols].copy()
+    mat = mat.dropna(how="all")
+
+    # CRITICAL: Limit rows BEFORE any processing to avoid matplotlib memory errors
+    # With 200k+ BMKs, matplotlib cannot create images (pixel limit: 2^23 per dimension)
+    if len(mat) > max_rows:
+        var = mat.var(axis=1)
+        mat = mat.loc[var.nlargest(max_rows).index]
+
+    if len(mat) == 0:
+        return
+
+    # Create descriptive Y-axis labels
+    # Add SeqID/Ptype/Mode affixes when multiple values exist
+    if "Mode" in df.columns and "Ctype" in df.columns:
+        # Check if we need SeqID/Ptype/Mode affixes (mixed BMK types or chromosomes)
+        df_subset = df.loc[mat.index]
+        has_multiple_seqids = "SeqID" in df_subset.columns and df_subset["SeqID"].nunique() > 1
+        has_multiple_ptypes = "Ptype" in df_subset.columns and df_subset["Ptype"].nunique() > 1
+        has_multiple_modes = "Mode" in df_subset.columns and df_subset["Mode"].nunique() > 1
+        add_affixes = has_multiple_seqids or has_multiple_ptypes or has_multiple_modes
+
+        ylabels = []
+        sort_keys = []  # For custom sorting: (priority, label)
+        for idx in mat.index:
+            mode = df.loc[idx, "Mode"] if "Mode" in df.columns else ""
+            ctype = df.loc[idx, "Ctype"] if "Ctype" in df.columns else ""
+            ptype = df.loc[idx, "Ptype"] if "Ptype" in df.columns else ""
+            seqid = df.loc[idx, "SeqID"] if "SeqID" in df.columns else ""
+
+            # Build base label
+            if mode == "all_sites":
+                base_label = "all_sites"
+            elif ctype and ctype != ".":
+                base_label = str(ctype)
+            elif "ID" in df.columns:
+                base_label = str(df.loc[idx, "ID"])
+            else:
+                base_label = str(idx)
+
+            # Add affixes only if values vary across rows
+            if add_affixes:
+                # SeqID as PREFIX (first, if multiple chromosomes)
+                if has_multiple_seqids and seqid and seqid != ".":
+                    label = f"{seqid}_{base_label}"
+                else:
+                    label = base_label
+
+                # Ptype as PREFIX or SUFFIX depending on context
+                if ptype and ptype != ".":
+                    # If Ptypes vary OR SeqIDs vary, add Ptype
+                    if has_multiple_ptypes or has_multiple_seqids:
+                        if has_multiple_seqids:
+                            # For all_sequence_bmks: SeqID_Ptype_Ctype
+                            label = f"{seqid}_{ptype}_{base_label}" if seqid and seqid != "." else f"{ptype}_{base_label}"
+                        else:
+                            # For all_global_bmks: Ptype_Ctype
+                            label = f"{ptype}_{base_label}"
+
+                # Mode as SUFFIX (after), if not "all_sites"
+                if mode and mode != "all_sites":
+                    label += f"_{mode}"
+
+                # Determine sort priority for mixed BMKs:
+                # 1. all_sites first
+                # 2. no Ptype (Ptype == ".")
+                # 3. rest alphabetically
+                if mode == "all_sites":
+                    priority = 0
+                elif not ptype or ptype == ".":
+                    priority = 1
+                else:
+                    priority = 2
+                sort_keys.append((priority, label, idx))
+            else:
+                label = base_label
+                sort_keys.append((0, label, idx))
+
+            ylabels.append(label)
+
+        # Sort rows if affixes were added (mixed BMK types)
+        if add_affixes and len(sort_keys) > 1:
+            # Sort by (priority, label alphabetically)
+            sorted_items = sorted(sort_keys, key=lambda x: (x[0], x[1]))
+            sorted_indices = [item[2] for item in sorted_items]
+            mat = mat.loc[sorted_indices]
+            ylabels = [item[1] for item in sorted_items]
+
+        mat.index = ylabels
+
+    mat = mat.dropna(how="all")
+    if len(mat) == 0:
+        return
+
+    if len(mat) > max_rows:
+        # Keep most variable
+        var = mat.var(axis=1)
+        mat = mat.loc[var.nlargest(max_rows).index]
+
+    fig_h = max(4, len(mat) * 0.25)
+    fig_w = max(6, len(sorted_cols) * 0.8)
+    fig, ax = plt.subplots(figsize=(fig_w, fig_h))
+
+    # Create labels: group::sample::rep for x-axis
+    xlabels = [f"{s['group']}::{s['sample']}::{s['rep']}" for s in sample_order if s['col'] in sorted_cols]
+
+    # Determine y-axis label display and font size
+    show_ylabels = len(mat) < 200  # Show labels up to 200 rows
+    yticklabel_fontsize = max(4, min(8, 500 / len(mat)))  # Adaptive font size
+
+    sns.heatmap(mat.astype(float), cmap="YlOrRd", ax=ax,
+                xticklabels=xlabels,
+                yticklabels=show_ylabels,
+                linewidths=0.5 if len(mat) < 30 else 0,
+                cbar_kws={'label': 'Value'})
+
+    if show_ylabels:
+        ax.set_yticklabels(ax.get_yticklabels(), fontsize=yticklabel_fontsize, rotation=0)
+
+    ax.set_xlabel("Sample (Group::Sample::Replicate)", fontsize=10)
+    ax.set_ylabel("Biomarker Type" if show_ylabels else f"Biomarker ({len(mat)} total)", fontsize=10)
+    ax.set_title(title or "Heatmap")
+    plt.xticks(rotation=90, fontsize=8)
+    save_fig(fig, os.path.join(outdir, "heatmap.png"))
+
+
+# ---------------------------------------------------------------------------
+# Top-level orchestration per section
+# ---------------------------------------------------------------------------
+
+def analyze_section(df, sample_cols, sample_info, outdir, section_name, stat_test="auto", bmk_filter_cols=None, max_bmks=500):
+    """Run all analyses on a filtered dataframe for a given report section."""
+    len_df = len(df)
+    len_sample_cols = len(sample_cols)
+    if len_df == 0:
+        log.info(f"  Section '{section_name}' is empty, skipping.")
+        return {}
+
+    # Create log prefix for this section
+    log_prefix = f"[{section_name[:50]}]"  # Limit to 50 chars
+
+    log.info(f"{log_prefix} Analyzing section: {section_name} ({len_df} BMKs, {len_sample_cols} samples)")
+    safe_mkdir(outdir)
+    results = {"n_bmks": len_df, "n_samples": len_sample_cols, "section": section_name}
+
+    # EARLY PRE-FILTER: For very large sections (>3k BMKs), reduce to top 3k by variance
+    # This prevents memory crashes in parallel workers for all_sequence_bmks / all_global_bmks
+    MAX_BMKS_FOR_ANALYSIS = 3000  # Reduced from 5000 to prevent hangs
+    if len_df > MAX_BMKS_FOR_ANALYSIS:
+        log.info(f"{log_prefix} {len_df} BMKs exceeds {MAX_BMKS_FOR_ANALYSIS} limit")
+        log.info(f"{log_prefix} Pre-filtering to top {MAX_BMKS_FOR_ANALYSIS} BMKs by variance to prevent memory issues...")
+        ndf_prefilter = numeric_df(df, sample_cols)
+        variance_prefilter = ndf_prefilter[sample_cols].var(axis=1)
+        top_indices = variance_prefilter.nlargest(MAX_BMKS_FOR_ANALYSIS).index
+        df = df.loc[top_indices].reset_index(drop=True).copy()
+        len_df = len(df)
+        log.info(f"{log_prefix} Pre-filtered to {len_df} BMKs")
+        results["n_bmks_prefiltered"] = len_df
+
+    # Save filtered data
+    df.to_csv(os.path.join(outdir, "data.csv"), index=False)
+
+    # 1. QC
+    log.info(f"{log_prefix} QC analysis...")
+    try:
+        results["qc"] = qc_analysis(df.reset_index(drop=True), sample_cols, os.path.join(outdir, "qc"))
+    except Exception as e:
+        log.warning(f"{log_prefix} QC failed: {e}")
+
+    # 2. Descriptive stats
+    log.info(f"{log_prefix} Descriptive stats...")
+    try:
+        results["descriptive"] = descriptive_stats(df.reset_index(drop=True), sample_cols, sample_info, os.path.join(outdir, "descriptive"))
+    except Exception as e:
+        log.warning(f"{log_prefix} Descriptive stats failed: {e}")
+
+    # PRE-FILTER: Remove BMKs with near-zero variance for all subsequent analyses
+    # This prevents numerical issues in: differential tests, PCA, correlation, heatmaps, ML models
+    # QC and descriptive stats above use ALL BMKs to identify constants
+    ndf_tmp = numeric_df(df, sample_cols)
+    variance = ndf_tmp[sample_cols].var(axis=1)
+    variable_mask = variance > 1e-10
+    n_filtered_out = (~variable_mask).sum()
+
+    if n_filtered_out > 0:
+        log.info(f"{log_prefix} Removing {n_filtered_out} BMKs with near-zero variance (keeping {variable_mask.sum()} variable)")
+        df_variable = df[variable_mask].reset_index(drop=True).copy()
+        n_bmks_variable = len(df_variable)
+    else:
+        df_variable = df.copy()
+        n_bmks_variable = len_df
+
+    # Early exit if no variable BMKs
+    if n_bmks_variable == 0:
+        log.warning(f"{log_prefix} No variable BMKs remaining after pre-filtering")
+        return results
+
+    # Update results with variable BMK count
+    results["n_bmks_variable"] = n_bmks_variable
+
+    # 3. Differential analysis (use variable BMKs only)
+    log.info(f"{log_prefix} Differential analysis ({n_bmks_variable} BMKs)...")
+    try:
+        results["differential"] = differential_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "differential"), stat_test=stat_test)
+    except Exception as e:
+        log.warning(f"{log_prefix} Differential analysis failed: {e}")
+
+    # 4. Multivariate (use variable BMKs only)
+    log.info(f"{log_prefix} Multivariate analysis ({n_bmks_variable} BMKs)...")
+    try:
+        results["multivariate"] = multivariate_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "multivariate"))
+    except Exception as e:
+        log.warning(f"{log_prefix} Multivariate analysis failed: {e}")
+
+    # 5. Correlation / Network (use variable BMKs only)
+    log.info(f"{log_prefix} Correlation / Network analysis ({n_bmks_variable} BMKs)...")
+    try:
+        results["correlation"] = correlation_network(df_variable, sample_cols, os.path.join(outdir, "correlation"))
+    except Exception as e:
+        log.warning(f"{log_prefix} Correlation analysis failed: {e}")
+
+    # 6. Feature ranking (use variable BMKs only)
+    log.info(f"{log_prefix} Feature ranking ({n_bmks_variable} BMKs)...")
+    try:
+        results["ranking"] = feature_ranking(df_variable, sample_cols, sample_info, os.path.join(outdir, "ranking"), max_bmks=max_bmks, bmk_filter_cols=bmk_filter_cols)
+    except Exception as e:
+        log.warning(f"{log_prefix} Feature ranking failed: {e}")
+
+    # 7. Classification (use variable BMKs only)
+    log.info(f"{log_prefix} Classification ({n_bmks_variable} BMKs)...")
+    try:
+        results["classification"] = classification_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "classification"), max_bmks=max_bmks, bmk_filter_cols=bmk_filter_cols)
+    except Exception as e:
+        log.warning(f"{log_prefix} Classification failed: {e}")
+
+    # 8. Stability (use variable BMKs only)
+    log.info(f"{log_prefix} Stability analysis ({n_bmks_variable} BMKs)...")
+    try:
+        results["stability"] = stability_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "stability"))
+    except Exception as e:
+        log.warning(f"{log_prefix} Stability analysis failed: {e}")
+
+    # 9. Batch effect (use variable BMKs only)
+    log.info(f"{log_prefix} Batch effect analysis ({n_bmks_variable} BMKs)...")
+    try:
+        results["batch"] = batch_effect_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "batch"))
+    except Exception as e:
+        log.warning(f"{log_prefix} Batch effect analysis failed: {e}")
+
+    # 10. Heatmap (use variable BMKs only)
+    log.info(f"{log_prefix} Heatmap generation ({n_bmks_variable} BMKs)...")
+    try:
+        section_heatmap(df_variable, sample_cols, sample_info, outdir, title=section_name)
+        log.info(f"{log_prefix} Heatmap generation completed")
+    except Exception as e:
+        log.warning(f"{log_prefix} Heatmap failed: {e}")
+
+    log.info(f"{log_prefix} Section analysis completed successfully")
+    return results
+
+
+# ---------------------------------------------------------------------------
+# Build feature hierarchy tree
+# ---------------------------------------------------------------------------
+
+def build_feature_tree(features_df):
+    """Build parent-child tree for features.
+
+    Each row's 'direct parent' is the last element in the comma-separated
+    ParentIDs column (the first element is always '.').
+    Top-level features have ParentIDs == '.'.
+    """
+    tree = {}  # id -> {"data": row, "children": []}
+    for _, row in features_df.iterrows():
+        fid = row["ID"]
+        tree[fid] = {"data": row, "children": []}
+
+    for _, row in features_df.iterrows():
+        fid = row["ID"]
+        parents = str(row["ParentIDs"])
+        if parents == ".":
+            continue  # top-level
+        parts = [p.strip() for p in parents.split(",") if p.strip() != "."]
+        if parts:
+            direct_parent = parts[-1]
+            if direct_parent in tree:
+                tree[direct_parent]["children"].append(fid)
+
+    # Find top-level features
+    top_ids = [fid for fid, node in tree.items()
+               if str(features_df.loc[features_df["ID"] == fid, "ParentIDs"].iloc[0]) == "."]
+
+    return tree, top_ids
+
+
+# ---------------------------------------------------------------------------
+# Parallel execution wrapper
+# ---------------------------------------------------------------------------
+
+def analyze_section_wrapper(args_tuple):
+    """Wrapper for analyze_section to be used with ProcessPoolExecutor.
+
+    Args:
+        args_tuple: (df_source, sample_cols, sample_info, outdir, section_name, result_key, stat_test, bmk_filter_cols, max_bmks)
+        where df_source can be:
+        - dict: legacy format, converted back to DataFrame
+        - tuple: ('pickled', bytes_data) for memory-efficient transfer
+
+    Returns:
+        (result_key, results_dict)
+    """
+    import traceback
+    import signal
+    import gc
+    import pickle
+    import random
+    import time as time_module
+    df_source, sample_cols, sample_info, outdir, section_name, result_key, stat_test, bmk_filter_cols, max_bmks = args_tuple
+
+    # OPTIMIZATION: Add startup jitter to avoid synchronized worker restarts.
+    # With spawn + max_tasks_per_child=20, all workers start together and may
+    # finish their 20th task at the same time → 6 workers restart simultaneously
+    # → 6 × 275MB = 1.65GB spike. Jitter spreads restarts over 2 seconds.
+ time_module.sleep(random.uniform(0, 2)) + + # Log at start with worker PID and memory + pid = os.getpid() + mem_start = None + if psutil: + try: + proc = psutil.Process(pid) + mem_start = proc.memory_info().rss / (1024 * 1024) # MB + log.info(f" ▶ START [PID {pid}, {mem_start:.0f}MB]: {section_name}") + except: + log.info(f" ▶ START [PID {pid}]: {section_name}") + else: + log.info(f" ▶ START [PID {pid}]: {section_name}") + + # Setup timeout alarm (120 seconds max per task) + def timeout_handler(signum, frame): + raise TimeoutError(f"Task exceeded 120 seconds: {section_name}") + + try: + # Set alarm for 120 seconds + signal.signal(signal.SIGALRM, timeout_handler) + signal.alarm(120) + + # Reconstruct DataFrame from source (dict or pickled bytes) + if isinstance(df_source, tuple) and df_source[0] == 'pickled': + # Memory-efficient: unpickle compressed data + import pickle + df = pickle.loads(df_source[1]) + log.info(f" ▶ [PID {pid}] DataFrame unpickled: {len(df)} rows") + else: + # Legacy: dict format + df = pd.DataFrame(df_source) + log.info(f" ▶ [PID {pid}] DataFrame reconstructed: {len(df)} rows") + + # Run analysis + results = analyze_section(df, sample_cols, sample_info, outdir, section_name, stat_test=stat_test, bmk_filter_cols=bmk_filter_cols, max_bmks=max_bmks) + + # Cancel alarm + signal.alarm(0) + + # MEMORY FIX #2: Return slim results dict (only file paths, not full data) + # Full results dict is ~500KB per task × 1491 tasks = 745MB accumulated in all_results + # global_ranking() only needs the differential.table path (50 bytes) + slim_results = { + "differential_table": os.path.join(outdir, "differential", "differential_results.csv") + } + + # Explicitly clean up to help garbage collector (critical for parallel execution) + del df + del df_source + del results # MEMORY FIX #3: Immediate cleanup of full results dict in worker + gc.collect() + + # Also force matplotlib to clean up any lingering figures + plt.close('all') + + # Log completion with memory + 
if psutil and mem_start: + try: + proc = psutil.Process(pid) + mem_end = proc.memory_info().rss / (1024 * 1024) + log.info(f" ✓ DONE [PID {pid}, {mem_end:.0f}MB, Δ{mem_end-mem_start:+.0f}MB]: {section_name}") + except: + log.info(f" ✓ DONE [PID {pid}]: {section_name}") + else: + log.info(f" ✓ DONE [PID {pid}]: {section_name}") + + return result_key, slim_results + + except TimeoutError as e: + signal.alarm(0) # Cancel alarm + log.error(f" ✗ TIMEOUT [PID {pid}]: {section_name} - {e}") + # Clean up even on error + try: + del df + del df_source + plt.close('all') + gc.collect() + except: + pass + raise + except Exception as e: + signal.alarm(0) # Cancel alarm + # Log full error details before crash + log.error(f" ✗ CRASH [PID {pid}]: {section_name}") + log.error(f" Error: {type(e).__name__}: {e}") + log.error(f" Stacktrace:\n{traceback.format_exc()}") + # Clean up even on error + try: + del df + del df_source + plt.close('all') + gc.collect() + except: + pass + raise + + +# --------------------------------------------------------------------------- +# Global Biomarker Ranking (across all sections) +# --------------------------------------------------------------------------- + +def global_ranking(all_results, outdir): + """Aggregate rankings across sections to produce a global ranking.""" + safe_mkdir(outdir) + + # Collect differential results + diff_files = [] + for vtype, vdata in all_results.items(): + for mtype, mdata in vdata.items(): + for section, sdata in mdata.items(): + # MEMORY FIX: sdata is now a slim dict with only "differential_table" key + diff_path = sdata.get("differential_table") + if diff_path and os.path.exists(diff_path): + ddf = pd.read_csv(diff_path) + # Defragment DataFrame before adding columns to avoid PerformanceWarning + ddf = ddf.copy() + # Add metadata columns + ddf["value_type"] = vtype + ddf["mtype"] = mtype + ddf["section"] = section + diff_files.append(ddf) + + if diff_files: + all_diff = pd.concat(diff_files, ignore_index=True) + # 
Rank by kruskal_padj + if "kruskal_padj" in all_diff.columns: + ranked = all_diff.dropna(subset=["kruskal_padj"]).sort_values("kruskal_padj") + ranked.to_csv(os.path.join(outdir, "global_ranking_kruskal.csv"), index=False) + + # Top 50 plot + top_n = min(50, len(ranked)) + top = ranked.head(top_n) + fig, ax = plt.subplots(figsize=(8, max(4, top_n * 0.3))) + neg_log_p = -np.log10(top["kruskal_padj"].clip(lower=1e-300)) + labels = top["ID"].astype(str) + " [" + top["section"].astype(str) + "]" + ax.barh(range(top_n), neg_log_p.values[::-1]) + ax.set_yticks(range(top_n)) + ax.set_yticklabels(labels.values[::-1], fontsize=6) + ax.set_xlabel("-log10(adjusted p-value)") + ax.set_title("Global BMK Ranking by Significance (top 50)") + + # Add reference lines for p-value thresholds + ax.axvline(-np.log10(0.05), color='red', linestyle='--', linewidth=1, alpha=0.7, label='p=0.05') + ax.axvline(-np.log10(0.1), color='orange', linestyle='--', linewidth=1, alpha=0.7, label='p=0.1') + ax.legend(loc='lower right', fontsize=8) + + save_fig(fig, os.path.join(outdir, "global_ranking_plot.png")) + + return outdir + + +# --------------------------------------------------------------------------- +# Main pipeline +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser( + description="barometer Biomarker Analysis Pipeline", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +STATISTICAL TESTS: + The pipeline ALWAYS computes ALL differential tests for each biomarker: + + Global tests (comparing all groups): + • Kruskal-Wallis → kruskal_pval, kruskal_padj + • ANOVA → anova_pval, anova_padj + • Welch ANOVA → welch_pval, welch_padj + + Pairwise tests (comparing 2 groups X vs Y): + • Mann-Whitney U (= Wilcoxon rank-sum) → mwu_pval_X_vs_Y, mwu_padj_X_vs_Y + • Student t-test → student_pval_X_vs_Y, student_padj_X_vs_Y + • Welch t-test → welch_pval_X_vs_Y, welch_padj_X_vs_Y + + Assumption tests (only with 
--stat-test auto): + • Shapiro-Wilk → shapiro_pvals (per group) + • Bartlett → bartlett_pval + + All tests receive FDR correction (Benjamini-Hochberg) → *_padj columns + All results are saved in differential_results.csv + + OPTION 1: --stat-test (controls which test is emphasized in plots) + This creates primary_pval and primary_padj columns that COPY the selected test: + + nonparametric : primary_* = kruskal_* (DEFAULT, robust) + parametric : primary_* = anova_* + welch : primary_* = welch_* + auto : primary_* = auto-selected test (varies per BMK) + kruskal : alias for nonparametric + + Example: --stat-test welch creates primary_padj as a copy of welch_padj + Volcano plots use the primary_* columns for visualization. + + OPTION 2: --bmk-filter (cascade priority for selecting significant BMKs) + Accepts a PRIORITY LIST of *_padj columns to collect significant BMKs + (padj < 0.05) for RandomForest and classification. + + The algorithm works in TWO-PASS CASCADE: + PASS 1 (adjusted p-values, strongest evidence): + 1. Take all BMKs with 1st column *_padj < 0.05 + 2. If < --max-bmks, add BMKs from 2nd column *_padj < 0.05 (not already selected) + 3. Continue until reaching --max-bmks or exhausting all *_padj columns + + PASS 2 (raw p-values, suggestive evidence, only if PASS 1 insufficient): + 4. Convert columns to *_pval (e.g., kruskal_padj → kruskal_pval) + 5. Add BMKs with *_pval < 0.05 (not already selected) + 6. Continue until reaching --max-bmks + + PASS 3 (if still insufficient): + 7. 
Complete with top variable BMKs (by variance) + + Default cascade: primary_padj → kruskal_padj → welch_padj → anova_padj + Then (if needed): primary_pval → kruskal_pval → welch_pval → anova_pval + + Available columns (choose your priority order): + • primary_padj : controlled by --stat-test (recommended) + • kruskal_padj : Kruskal-Wallis (robust, non-parametric) + • anova_padj : ANOVA (parametric, equal variances) + • welch_padj : Welch (parametric, unequal variances) + • mwu_padj_X_vs_Y : Mann-Whitney U for specific pair + • student_padj_X_vs_Y : Student t-test for specific pair + + Examples: + --bmk-filter kruskal_padj : only use Kruskal (padj then pval) + --bmk-filter primary_padj kruskal_padj : cascade from primary to kruskal + --bmk-filter welch_padj anova_padj : prioritize Welch, fallback to ANOVA + + OPTION 3: --max-bmks (maximum BMKs for ML models) + Sets the hard limit for RandomForest and classification (default: 500). + Prevents memory crashes with large datasets (e.g., 40,000+ BMKs). + Higher values = more accurate but slower and more memory-intensive. 
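The two-pass cascade described above can be sketched in a few lines. Note this is an illustrative sketch: the function name `cascade_select`, the `ID` column, and the hard-coded 0.05 threshold mirror the help text but are assumptions, not the pipeline's actual implementation (PASS 3, the top-variance fill, is omitted):

```python
import pandas as pd

def cascade_select(diff_df, filter_cols, max_bmks=500, alpha=0.05):
    """Collect significant biomarker IDs in cascade priority order.

    PASS 1 walks the *_padj columns in the given priority order;
    PASS 2 retries with the matching *_pval columns (weaker evidence).
    """
    selected = []
    pval_cols = [c.replace("_padj", "_pval") for c in filter_cols]
    for col in filter_cols + pval_cols:  # PASS 1, then PASS 2
        if len(selected) >= max_bmks:
            break
        if col not in diff_df.columns:
            continue
        # NaN p-values compare False, so they are skipped automatically
        hits = diff_df.loc[diff_df[col] < alpha, "ID"]
        for bmk in hits:
            if len(selected) >= max_bmks:
                break
            if bmk not in selected:
                selected.append(bmk)
    return selected
```

With `--bmk-filter kruskal_padj`, a biomarker significant only by raw p-value is picked up in the second pass, and `--max-bmks` caps the list as soon as it is reached.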
+ +EXAMPLES: + # Default: cascade adjusted + raw p-values, max 500 BMKs: + ./barometer_analyze.py -a data.tsv -o results/ -j 4 + + # Use only Kruskal-Wallis (padj then pval if needed): + ./barometer_analyze.py -a data.tsv -o results/ --bmk-filter kruskal_padj + + # Prioritize Welch, fallback to Kruskal, use 1000 BMKs: + ./barometer_analyze.py -a data.tsv -o results/ \ + --stat-test welch --bmk-filter welch_padj kruskal_padj --max-bmks 1000 + + # Conservative: only most significant BMKs (200 max): + ./barometer_analyze.py -a data.tsv -o results/ --max-bmks 200 + + # Aggressive: maximize BMK collection with high limit: + ./barometer_analyze.py -a data.tsv -o results/ --max-bmks 1500 + """) + parser.add_argument("-a", "--aggregates", default=None, help="Aggregates TSV file (optional)") + parser.add_argument("-f", "--features", default=None, help="Features TSV file (optional)") + parser.add_argument("-o", "--outdir", default="barometer_results", help="Output directory") + parser.add_argument("-j", "--jobs", type=int, default=1, help="Number of parallel jobs (default: 1, use -1 for all CPUs)") + parser.add_argument("-v", "--value-types", nargs="+", default=None, help="Value types to analyze (e.g., espf espr). If not specified, all value types are analyzed.") + parser.add_argument("--agg-levels", nargs="+", default=None, help="Aggregate levels to analyze: global, sequence, feature. If not specified, all levels are analyzed.") + parser.add_argument("--feature-types", nargs="+", default=None, help="Feature types to analyze (e.g., gene exon RNA). 
If not specified, all feature types are analyzed.") + parser.add_argument("--stat-test", default="nonparametric", + choices=["auto", "parametric", "nonparametric", "welch", "kruskal"], + help="Statistical test selection: auto (test assumptions), parametric (Student/ANOVA), nonparametric (Mann-Whitney/Kruskal-Wallis, default), welch (Welch t-test/ANOVA)") + parser.add_argument("--bmk-filter", nargs="+", default=["primary_padj", "kruskal_padj", "welch_padj", "anova_padj"], + help="Priority list of column names for filtering significant BMKs (default: primary_padj kruskal_padj welch_padj anova_padj). BMKs are collected in cascade until --max-bmks is reached.") + parser.add_argument("--max-bmks", type=int, default=500, + help="Maximum number of BMKs to use for RandomForest and classification (default: 500). Prevents memory crashes with large datasets.") + args = parser.parse_args() + + # Determine number of workers + if args.jobs == -1: + n_jobs = os.cpu_count() + elif args.jobs < 1: + parser.error("--jobs must be >= 1 or -1 for all CPUs") + else: + n_jobs = args.jobs + + log.info(f"Using {n_jobs} parallel job(s)") + log.info(f"Statistical test method: {args.stat_test}") + log.info(f"BMK filtering cascade: {' → '.join(args.bmk_filter)}") + log.info(f"Max BMKs for ML models: {args.max_bmks}") + + # Check that at least one input file is provided + if not args.aggregates and not args.features: + parser.error("At least one of --aggregates or --features must be provided") + + outdir = args.outdir + safe_mkdir(outdir) + + # Load data + log.info("Loading data...") + agg_df = None + feat_df = None + + if args.aggregates: + if not os.path.exists(args.aggregates): + log.error(f"Aggregates file not found: {args.aggregates}") + sys.exit(1) + log.info(f"Loading aggregates from {args.aggregates}...") + agg_df = pd.read_csv(args.aggregates, sep="\t", dtype={"SeqID": str, "Start": str, "End": str, "Strand": str}) + agg_df.columns = agg_df.columns.str.strip() # Remove leading/trailing 
whitespace from column names + # Strip whitespace from string columns + for col in agg_df.select_dtypes(include=['object', 'string']).columns: + agg_df[col] = agg_df[col].str.strip() if agg_df[col].dtype in ['object', 'string'] else agg_df[col] + log.info(f" Aggregates: {len(agg_df)} rows") + else: + log.info("Skipping aggregates (no file provided)") + + if args.features: + if not os.path.exists(args.features): + log.error(f"Features file not found: {args.features}") + sys.exit(1) + log.info(f"Loading features from {args.features}...") + feat_df = pd.read_csv(args.features, sep="\t", dtype={"SeqID": str, "Start": str, "End": str, "Strand": str}) + feat_df.columns = feat_df.columns.str.strip() # Remove leading/trailing whitespace from column names + # Strip whitespace from string columns + for col in feat_df.select_dtypes(include=['object', 'string']).columns: + feat_df[col] = feat_df[col].str.strip() if feat_df[col].dtype in ['object', 'string'] else feat_df[col] + log.info(f" Features: {len(feat_df)} rows") + else: + log.info("Skipping features (no file provided)") + + # Parse sample information from whichever file is available + sample_df = agg_df if agg_df is not None else feat_df + sample_info = parse_sample_columns(sample_df.columns) + all_value_types = get_value_types(sample_info) + + # Filter value types if specified + if args.value_types: + value_types = [vt for vt in args.value_types if vt in all_value_types] + # Warn about non-existent value types + missing = [vt for vt in args.value_types if vt not in all_value_types] + if missing: + log.warning(f"Requested value types not found in data: {missing}") + if not value_types: + log.error(f"None of the requested value types {args.value_types} found in data. 
Available: {all_value_types}") + sys.exit(1) + else: + value_types = all_value_types + + # Count unique samples (group, sample, rep combinations) + unique_samples = len(set((s["group"], s["sample"], s["rep"]) for s in sample_info)) + log.info(f"Available value types: {all_value_types}") + if args.value_types: + log.info(f"Analyzing value types: {value_types}") + else: + log.info(f"Analyzing all value types: {value_types}") + log.info(f"Samples: {unique_samples} unique samples across {len(all_value_types)} value types ({len(sample_info)} total columns)") + + all_results = {} + + for vtype in value_types: + log.info(f"\n{'='*60}") + log.info(f"VALUE TYPE: {vtype}") + log.info(f"{'='*60}") + + vtype_dir = os.path.join(outdir, vtype) + safe_mkdir(vtype_dir) + vcols = cols_for_vtype(sample_info, vtype) + v_sample_info = sample_info_for_vtype(sample_info, vtype) + all_results[vtype] = {"aggregate": {}, "feature": {}} + + # Collect all analysis tasks for this value_type + tasks = [] + + # =============================================================== + # AGGREGATES + # =============================================================== + if agg_df is not None: + log.info(f"\n--- AGGREGATES for {vtype} ---") + agg_data = agg_df[agg_df["Mtype"] == "aggregate"].copy() + + agg_dir = os.path.join(vtype_dir, "aggregate") + + # Filter aggregate levels if specified + agg_levels = args.agg_levels if args.agg_levels else ["global", "sequence", "feature"] + log.info(f"Analyzing aggregate levels: {agg_levels}") + + # --- 1. 
Global aggregates (Type == global) --- + if "global" in agg_levels: + glob_agg = agg_data[agg_data["Type"] == "global"] + else: + glob_agg = pd.DataFrame() # Empty dataframe if global is not requested + + if not glob_agg.empty: + # All BMKs together (no Ptype/Ctype/Mode filter) + tasks.append(( + prepare_df_for_task(glob_agg), vcols, v_sample_info, + os.path.join(agg_dir, "global", "all_global_bmks"), + "Global - All BMKs", + ("aggregate", "global_all_bmks"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # all sites (Mode == all_sites) + section = glob_agg[glob_agg["Mode"] == "all_sites"] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "global", "all_sites"), + "Global - All Sites", + ("aggregate", "global_all_sites"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # All sites by Ctype (Ptype == "." and 3 modes) + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = glob_agg[(glob_agg["Ptype"] == ".") & (glob_agg["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "global", f"by_ctype_{mode}"), + f"Global - By Ctype - {mode}", + ("aggregate", f"global_ctype_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # All sites by Ptype (Ptype != "." and 3 modes) + ptypes = [p for p in glob_agg["Ptype"].unique() if p != "."] + for ptype in ptypes: + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = glob_agg[(glob_agg["Ptype"] == ptype) & (glob_agg["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "global", f"by_ptype_{ptype}_{mode}"), + f"Global - Ptype={ptype} - {mode}", + ("aggregate", f"global_ptype_{ptype}_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # --- 2. 
Chromosome/Sequence aggregates (Type == sequence) --- + if "sequence" in agg_levels: + seq_agg = agg_data[agg_data["Type"] == "sequence"] + else: + seq_agg = pd.DataFrame() # Empty dataframe if chr is not requested + + if not seq_agg.empty: + chromosomes = seq_agg["SeqID"].unique() + for chrom in chromosomes: + chr_data = seq_agg[seq_agg["SeqID"] == chrom] + + # All sites + section = chr_data[chr_data["Mode"] == "all_sites"] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, f"sequence/{chrom}", "all_sites"), + f"Chr {chrom} - All Sites", + ("aggregate", f"chr{chrom}_all_sites"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # By Ctype + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = chr_data[(chr_data["Ptype"] == ".") & (chr_data["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, f"sequence/{chrom}", f"by_ctype_{mode}"), + f"Chr {chrom} - By Ctype - {mode}", + ("aggregate", f"chr{chrom}_ctype_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # By Ptype + local_ptypes = [p for p in chr_data["Ptype"].unique() if p != "."] + for ptype in local_ptypes: + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = chr_data[(chr_data["Ptype"] == ptype) & (chr_data["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, f"sequence/{chrom}", f"by_ptype_{ptype}_{mode}"), + f"Chr {chrom} - Ptype={ptype} - {mode}", + ("aggregate", f"chr{chrom}_ptype_{ptype}_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # --- All sequences combined (all chromosomes pooled) --- + # all_sites + section = seq_agg[seq_agg["Mode"] == "all_sites"] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "sequence", "all_sequence_bmks", "all_sites"), + "All Sequences - All Sites", + ("aggregate", 
"allseq_all_sites"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + # by_ctype + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = seq_agg[(seq_agg["Ptype"] == ".") & (seq_agg["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "sequence", "all_sequence_bmks", f"by_ctype_{mode}"), + f"All Sequences - By Ctype - {mode}", + ("aggregate", f"allseq_ctype_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + # by_ptype + all_seq_ptypes = [p for p in seq_agg["Ptype"].unique() if p != "."] + for ptype in all_seq_ptypes: + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = seq_agg[(seq_agg["Ptype"] == ptype) & (seq_agg["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "sequence", "all_sequence_bmks", f"by_ptype_{ptype}_{mode}"), + f"All Sequences - Ptype={ptype} - {mode}", + ("aggregate", f"allseq_ptype_{ptype}_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # --- 3. 
Feature aggregates (Type == feature) --- + if "feature" in agg_levels: + feat_agg = agg_data[agg_data["Type"] == "feature"] + else: + feat_agg = pd.DataFrame() # Empty dataframe if feature is not requested + + if not feat_agg.empty: + fa_ptypes = feat_agg["Ptype"].unique() + fa_ctypes = feat_agg["Ctype"].unique() + fa_modes = feat_agg["Mode"].unique() + + # --- all_feature_together: all chromosomes pooled --- + for ptype in fa_ptypes: + for ctype in fa_ctypes: + for mode in fa_modes: + section = feat_agg[ + (feat_agg["Ptype"] == ptype) & + (feat_agg["Ctype"] == ctype) & + (feat_agg["Mode"] == mode) + ] + safe_name = f"{ptype}_{ctype}_{mode}".replace(".", "all").replace("-", "_") + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "feature", "all_feature_together", safe_name), + f"Feature Agg - Ptype={ptype}, Ctype={ctype}, Mode={mode}", + ("aggregate", f"featagg_{safe_name}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # --- by_sequence: features grouped by chromosome --- + fa_chroms = [c for c in feat_agg["SeqID"].unique() if c != "."] + for chrom in fa_chroms: + chr_feat = feat_agg[feat_agg["SeqID"] == chrom] + chr_ptypes = chr_feat["Ptype"].unique() + chr_ctypes = chr_feat["Ctype"].unique() + chr_modes = chr_feat["Mode"].unique() + for ptype in chr_ptypes: + for ctype in chr_ctypes: + for mode in chr_modes: + section = chr_feat[ + (chr_feat["Ptype"] == ptype) & + (chr_feat["Ctype"] == ctype) & + (chr_feat["Mode"] == mode) + ] + safe_name = f"{ptype}_{ctype}_{mode}".replace(".", "all").replace("-", "_") + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "feature", "by_sequence", str(chrom), safe_name), + f"Feature Agg - Chr {chrom} - Ptype={ptype}, Ctype={ctype}, Mode={mode}", + ("aggregate", f"featagg_chr{chrom}_{safe_name}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + else: + log.info(f"\n--- Skipping AGGREGATES for {vtype} (no aggregates 
file) ---") + + # =============================================================== + # FEATURES + # =============================================================== + if feat_df is not None: + log.info(f"\n--- FEATURES for {vtype} ---") + feat_data = feat_df.copy() # Mtype is always "feature" in features file + feat_dir = os.path.join(vtype_dir, "feature") + + # Build hierarchy + tree, top_ids = build_feature_tree(feat_data) + + # Group top features by Type + top_features = feat_data[feat_data["ParentIDs"] == "."] + available_feature_types = top_features["Type"].unique().tolist() + + # Filter feature types if specified + if args.feature_types: + feature_types = [ft for ft in args.feature_types if ft in available_feature_types] + missing = [ft for ft in args.feature_types if ft not in available_feature_types] + if missing: + log.warning(f"Requested feature types not found in data: {missing}") + if not feature_types: + log.error(f"None of the requested feature types {args.feature_types} found in data. 
Available: {available_feature_types}") + feature_types = [] # Will skip all features + else: + feature_types = available_feature_types + + if feature_types: + log.info(f"Available feature types: {available_feature_types}") + log.info(f"Analyzing feature types: {feature_types}") + + for ttype in feature_types: + type_dir = os.path.join(feat_dir, ttype.replace(" ", "_")) + type_features = top_features[top_features["Type"] == ttype] + + # Analyze all features of this type together + all_ids_of_type = [] + for _, row in type_features.iterrows(): + fid = row["ID"] + # Collect this feature and all descendants + def collect_ids(node_id): + ids = [node_id] + if node_id in tree: + for child in tree[node_id]["children"]: + ids.extend(collect_ids(child)) + return ids + all_ids_of_type.extend(collect_ids(fid)) + + type_all_df = feat_data[feat_data["ID"].isin(all_ids_of_type)] + tasks.append(( + prepare_df_for_task(type_all_df), vcols, v_sample_info, + os.path.join(type_dir, "_all"), + f"Features - {ttype} (all)", + ("feature", f"type_{ttype}_all"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # Per top-feature analysis + for _, row in type_features.iterrows(): + fid = row["ID"] + def collect_ids(node_id): + ids = [node_id] + if node_id in tree: + for child in tree[node_id]["children"]: + ids.extend(collect_ids(child)) + return ids + sub_ids = collect_ids(fid) + sub_df = feat_data[feat_data["ID"].isin(sub_ids)] + safe_fid = fid.replace(":", "_").replace("/", "_") + tasks.append(( + prepare_df_for_task(sub_df), vcols, v_sample_info, + os.path.join(type_dir, safe_fid), + f"Feature: {fid}", + ("feature", f"feature_{safe_fid}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + else: + log.info(f"\n--- Skipping FEATURES for {vtype} (no features file) ---") + + # Execute tasks (parallel or sequential) + log.info(f"Executing {len(tasks)} analysis tasks...") + if n_jobs > 1 and len(tasks) > 1: + # Parallel execution with worker recycling to prevent memory 
leaks + log.info(f"Submitting {len(tasks)} tasks to {n_jobs} workers...") + + # MEMORY FIX #1: Use spawn context to prevent fork Copy-on-Write memory inheritance + # With fork (default on Linux), each worker inherits ALL parent memory via CoW + # → 6 workers × 5GB parent = 30GB total (even if psutil shows less per process) + # With spawn, workers start clean and only receive their task data via pickle + mp_context = multiprocessing.get_context('spawn') + log.info(" Using 'spawn' context (not fork) to avoid CoW memory inheritance") + + # Use max_tasks_per_child=20 to balance performance and memory + # Higher than fork (was 5) because spawn workers are clean but have import overhead (~275MB) + log.info(" Worker recycling: Each worker will process max 20 tasks before restart") + + # CRITICAL: Use lazy submission to avoid holding every task (dict copies) in memory at once + # Submit only max_pending_tasks at a time, then submit new ones as they complete + max_pending_tasks = n_jobs * 3 # Keep 3x workers worth of tasks in flight + log.info(f" Lazy submission: Max {max_pending_tasks} tasks in memory at once (was {len(tasks)})") + + with ProcessPoolExecutor(max_workers=n_jobs, max_tasks_per_child=20, mp_context=mp_context) as executor: + future_to_key = {} + task_iter = iter(tasks) + completed = 0 + failed = 0 + submitted_count = 0 + last_progress_time = time.time() + stall_timeout = 90 # Increased from 40s for spawn startup delay (imports take ~2-5s per worker) + last_worker_check = time.time() + worker_check_interval = 15 # Check worker health every 15 seconds + + # Initial submission of first batch + initial_batch = min(max_pending_tasks, len(tasks)) + log.info(f" Submitting initial batch of {initial_batch} tasks...") + for _ in range(initial_batch): + try: + # Unpack into task_outdir so the top-level outdir is not shadowed + df_source, cols, info, task_outdir, name, key, stat_test, bmk_filter, max_bmks = next(task_iter) + future = executor.submit(analyze_section_wrapper, (df_source, cols, info, task_outdir, name, key, stat_test, 
bmk_filter, max_bmks)) + future_to_key[future] = (key, name) + submitted_count += 1 + # Explicitly free the df_source reference to help GC + del df_source + except StopIteration: + break + + log.info(f" Initial batch submitted. Processing with {n_jobs} workers...") + log.info(" Note: Tasks have a 2-minute timeout. Anti-deadlock active.") + + # Process futures with stall detection + pending_futures = set(future_to_key.keys()) + + while pending_futures: + # Periodically check if workers have died + current_time = time.time() + if psutil is not None and (current_time - last_worker_check) > worker_check_interval: + try: + process = psutil.Process() + children = process.children(recursive=True) + n_active_workers = len(children) + + # If more than half of workers are dead, we have a problem + if n_active_workers < (n_jobs / 2) and len(pending_futures) > n_active_workers * 2: + log.error(f" ✗ WORKER DEATH DETECTED: Only {n_active_workers}/{n_jobs} workers alive with {len(pending_futures)} pending tasks") + log.error(" Most likely cause: Out-Of-Memory (OOM) killed workers") + log.error(" Cancelling all pending tasks to prevent infinite hang") + log.error(" TIP: Restart with fewer workers (-j 2 or -j 3)") + + # Cancel all pending futures + for future in list(pending_futures): + future.cancel() + key, section_name = future_to_key[future] + log.error(f" ✗ CANCELLED (orphaned): {section_name}") + failed += 1 + break + + last_worker_check = current_time + except Exception as e: + log.debug(f"Worker health check failed: {e}") + + # Use short timeout on as_completed to check for stalls + try: + done_iter = as_completed(pending_futures, timeout=5) + for future in done_iter: + pending_futures.discard(future) + key, section_name = future_to_key[future] + + # Process the completed future + try: + result_key, results = future.result(timeout=1) # Short timeout since already done + mtype, section_key = result_key + all_results[vtype][mtype][section_key] = results + 
completed += 1 + last_progress_time = time.time() # Reset stall timer + + # MEMORY FIX #3: Immediate cleanup after storing (results is slim dict with only paths) + del results + + # Force garbage collection every 10 tasks to free memory in main process + if completed % 10 == 0: + gc.collect() + + # Progress update every 5 completions or at key milestones + if completed % 5 == 0 or completed in [1, 10, 25, 50, 100]: + log.info(f" ✓ Progress: {completed}/{len(tasks)} completed, {failed} failed, {submitted_count - completed - failed} submitted pending") + except TimeoutError: + failed += 1 + last_progress_time = time.time() + log.error(f" ✗ TIMEOUT ({completed+failed}/{len(tasks)}): {section_name} - exceeded 2 minutes") + except Exception as e: + failed += 1 + last_progress_time = time.time() + # Check if it's a worker crash (common patterns in error message) + if "process" in str(e).lower() and ("terminate" in str(e).lower() or "crash" in str(e).lower() or "abrupt" in str(e).lower()): + log.error(f" ✗ CRASH ({completed+failed}/{len(tasks)}): {section_name} - worker OOM or crash") + else: + log.error(f" ✗ ERROR ({completed+failed}/{len(tasks)}): {section_name} - {type(e).__name__}") + + # CRITICAL: Submit next task to maintain max_pending_tasks in flight + # This lazy submission keeps memory usage constant regardless of total tasks + if len(pending_futures) < max_pending_tasks: + try: + # Unpack into task_outdir so the top-level outdir is not shadowed + df_source, cols, info, task_outdir, name, key, stat_test, bmk_filter, max_bmks = next(task_iter) + new_future = executor.submit(analyze_section_wrapper, (df_source, cols, info, task_outdir, name, key, stat_test, bmk_filter, max_bmks)) + future_to_key[new_future] = (key, name) + pending_futures.add(new_future) + submitted_count += 1 + # Explicitly free the df_source reference to help GC + del df_source + # Log every 50 submissions + if submitted_count % 50 == 0: + log.info(f" → Submitted up to {submitted_count}/{len(tasks)} tasks (lazy mode)") + except StopIteration: + pass # No more tasks to 
submit + + # Break inner loop to check stall timeout + break + except TimeoutError: + # No futures completed in 5 seconds, check for stall + elapsed_since_progress = time.time() - last_progress_time + if elapsed_since_progress > stall_timeout: + log.error(f" ✗ DEADLOCK DETECTED: No progress for {elapsed_since_progress:.0f}s") + log.error(f" Cancelling {len(pending_futures)} remaining tasks to prevent infinite hang") + # Cancel all pending futures + for future in pending_futures: + future.cancel() + key, section_name = future_to_key[future] + log.error(f" ✗ CANCELLED: {section_name}") + failed += 1 + break + + if failed > 0: + log.warning(f" {failed} tasks failed or cancelled out of {len(tasks)} total") + else: + # Sequential execution + import pickle # Rebuilds DataFrames passed in pickled form + for df_source, cols, info, task_outdir, name, key, stat_test, bmk_filter, max_bmks in tasks: + # Reconstruct DataFrame from pickled format; task_outdir avoids shadowing the top-level outdir + if isinstance(df_source, tuple) and df_source[0] == 'pickled': + df = pickle.loads(df_source[1]) + else: + df = pd.DataFrame(df_source) + results = analyze_section(df, cols, info, task_outdir, name, stat_test=stat_test, bmk_filter_cols=bmk_filter, max_bmks=max_bmks) + + # Create slim results dict (same as parallel mode for consistency) + slim_results = { + "differential_table": os.path.join(task_outdir, "differential", "differential_results.csv") + } + + mtype, section_key = key + all_results[vtype][mtype][section_key] = slim_results + + # =============================================================== + # GLOBAL RANKING + # =============================================================== + log.info("\n--- GLOBAL RANKING ---") + global_ranking(all_results, os.path.join(outdir, "global_ranking")) + + # Save manifest + manifest = { + "value_types": value_types, + "outdir": outdir, + "n_aggregates": len(agg_df) if agg_df is not None else 0, + "n_features": len(feat_df) if feat_df is not None else 0, + "sample_info": sample_info, + } + with open(os.path.join(outdir, "manifest.json"), "w") as f: + 
json.dump(manifest, f, indent=2, default=str) + + log.info(f"\nAnalysis complete. Results saved to {outdir}/") + + + if __name__ == "__main__": + main() diff --git a/bin/barometer_report.py b/bin/barometer_report.py new file mode 100755 index 0000000..ecca6bc --- /dev/null +++ b/bin/barometer_report.py @@ -0,0 +1,1133 @@ +#!/usr/bin/env python3 +""" +barometer_report.py – Generate an interactive HTML report from barometer_analyze.py results. + +Assembles all tables, figures, and analyses into a multi-tab HTML report using +Jinja2 templating with interactive DataTables, Plotly for image zoom, and +collapsible sections for navigation. + +Usage: + python barometer_report.py [--results DIR] [--output FILE] +""" + +import argparse +import base64 +import csv +import glob +import json +import logging +import os +import re +import sys +from collections import OrderedDict +from pathlib import Path + +import pandas as pd +from jinja2 import Template + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +EMBED_IMAGES = False +RESULTS_DIR = "barometer_results" + + +def img_src(path): + """Return an image src: base64 data URI if embedding, else relative path.""" + if not os.path.exists(path): + return "" + if EMBED_IMAGES: + with open(path, "rb") as f: + data = f.read() + ext = path.rsplit(".", 1)[-1].lower() + mime = {"png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg", + "svg": "image/svg+xml"}.get(ext, "image/png") + return f"data:{mime};base64,{base64.b64encode(data).decode()}" + else: + # Return relative path from report file location + return path + + +def csv_to_html_table(path, max_rows=500, table_id=None): + """Read a CSV and return an HTML table string.""" + if not os.path.exists(path): + return "" + try: + df = pd.read_csv(path) + 
except Exception: + return "" + if len(df) > max_rows: + df = df.head(max_rows) + tid = f' id="{table_id}"' if table_id else "" + classes = "display compact stripe hover" + html = df.to_html(index=False, classes=classes, border=0, na_rep="NA", + float_format=lambda x: f"{x:.4g}" if abs(x) > 1e-4 else f"{x:.2e}") + html = html.replace("<table", f"<table{tid}", 1) + return html + + +def find_csvs(d): + """Return sorted CSV paths directly inside directory d.""" + return sorted(glob.glob(os.path.join(d, "*.csv"))) + + +def find_images(d): + """Return sorted PNG paths directly inside directory d.""" + return sorted(glob.glob(os.path.join(d, "*.png"))) + + +# Human-readable labels for the analysis subdirectories (numbered to fix their order in the report) +analysis_labels = { + "visualization": "13. Visualization – Overview Heatmap", +} + + +def build_section_content(section_dir, section_name, table_counter): + """Build the HTML fragment (tables and figures) for one results section.""" + content = "" + subdirs = sorted(os.listdir(section_dir)) + + # Section-level overview images (e.g. the heatmap), if present + for hm_path in find_images(section_dir): + content += f'<img src="{img_src(hm_path)}" alt="Heatmap">\n' + + for subdir in subdirs: + sub_path = os.path.join(section_dir, subdir) + if not os.path.isdir(sub_path): + continue + + label = analysis_labels.get(subdir, subdir.title()) + content += f'<h4>{label}</h4>\n' + + # Tables + for csv_path in find_csvs(sub_path): + fname = os.path.basename(csv_path).replace(".csv", "") + table_counter[0] += 1 + tid = f"dt_{table_counter[0]}" + content += f'<details><summary>{fname.replace("_", " ").title()}</summary>\n' + content += f'<div class="table-wrap">{csv_to_html_table(csv_path, table_id=tid)}</div>\n' + content += '</details>\n' + + # Images + for img_path in find_images(sub_path): + fname = os.path.basename(img_path).replace(".png", "").replace("_", " ").title() + b64 = img_src(img_path) + if b64: + content += f'<figure><figcaption>{fname}</figcaption>' + content += f'<img src="{b64}" alt="{fname}">\n' + content += '</figure>
\n' + + return content + + +def build_report_data(results_dir): + """Walk the results directory and build the report structure.""" + manifest_path = os.path.join(results_dir, "manifest.json") + manifest = {} + if os.path.exists(manifest_path): + with open(manifest_path) as f: + manifest = json.load(f) + + value_types = manifest.get("value_types", []) + if not value_types: + # Discover from directory + for d in sorted(os.listdir(results_dir)): + dp = os.path.join(results_dir, d) + if os.path.isdir(dp) and d not in ("global_ranking",): + value_types.append(d) + + report = OrderedDict() + table_counter = [0] + + for vtype in value_types: + vtype_dir = os.path.join(results_dir, vtype) + if not os.path.isdir(vtype_dir): + continue + + report[vtype] = OrderedDict() + + # Aggregate and Feature tabs + for mtype in ["aggregate", "feature"]: + mtype_dir = os.path.join(vtype_dir, mtype) + if not os.path.isdir(mtype_dir): + continue + + report[vtype][mtype] = OrderedDict() + # Walk all section directories + for root, dirs, files in os.walk(mtype_dir): + # Only process leaf directories that contain actual results + if any(f.endswith(".csv") or f.endswith(".png") for f in files): + rel = os.path.relpath(root, mtype_dir) + parts = rel.split(os.sep) + first = parts[0] + group_name = prettify_group_name(first) + + # Detect subgroup / sub-subgroup + ssg_key = None + ssg_name = None + if first == "global" and len(parts) >= 2: + sec_dir = parts[1] + remainder = os.sep.join(parts[1:]) + if "by_ptype_" not in sec_dir: + # all_sites / by_ctype_* → "All" subgroup + sg_key = "all" + sg_name = "All" + section_name = prettify_section_name(remainder) + else: + # by_ptype__ → subgroup per ptype + ptype_raw, ctype_raw = _extract_ptype(sec_dir) + if ptype_raw: + sg_key = f"ptype_{ptype_raw}" + sg_name = _prettify_ptype(ptype_raw) + section_name = _prettify_mode(ctype_raw) if ctype_raw else prettify_section_name(remainder) + else: + sg_key = "" + sg_name = "" + section_name = 
prettify_section_name(remainder) + elif first == "sequence" and len(parts) >= 3: + sg_raw = parts[1] # e.g. "21" or "all_sequence_bmks" + sg_name = "All Sequences" if sg_raw == "all_sequence_bmks" else f"Chr {sg_raw}" + sg_key = sg_raw + sec_dir = parts[2] + remainder = os.sep.join(parts[2:]) + if "by_ptype_" not in sec_dir: + # all_sites / by_ctype_* → "All" sub-subgroup + ssg_key = "all" + ssg_name = "All" + section_name = prettify_section_name(remainder) + else: + # by_ptype__ → sub-subgroup per ptype + ptype_raw, ctype_raw = _extract_ptype(sec_dir) + if ptype_raw: + ssg_key = f"ptype_{ptype_raw}" + ssg_name = _prettify_ptype(ptype_raw) + section_name = _prettify_mode(ctype_raw) if ctype_raw else prettify_section_name(remainder) + else: + ssg_key = None + ssg_name = None + section_name = prettify_section_name(remainder) + elif first == "feature" and len(parts) >= 3 and parts[1] == "all_feature_together": + # all_feature_together/{safe_name} — all chromosomes pooled + sec_dir = parts[2] + section_dir = os.path.join(mtype_dir, first, parts[1], sec_dir) + meta = _read_section_meta(section_dir) + if meta: + ptype_raw = meta["Ptype"] + ctype_raw = meta["Ctype"] + mode = meta["Mode"] + sg_key = ptype_raw + sg_name = _prettify_ptype(ptype_raw) + ssg_key = ctype_raw + ssg_name = ctype_raw.replace("_", " ").replace("-", " ").title() + section_name = _prettify_mode(mode) + else: + sg_key = "all_feature_together" + sg_name = "All Features Together" + ssg_key = None + ssg_name = None + section_name = prettify_section_name(parts[2]) + elif first == "feature" and len(parts) >= 4 and parts[1] == "by_sequence": + # by_sequence/{chrom}/{safe_name} — features grouped by chromosome + chrom = parts[2] + sec_dir = parts[3] + section_dir = os.path.join(mtype_dir, first, parts[1], chrom, sec_dir) + meta = _read_section_meta(section_dir) + sg_key = f"seq_{chrom}" + sg_name = f"Chr {chrom}" + if meta: + ptype_raw = meta["Ptype"] + ctype_raw = meta["Ctype"] + mode = meta["Mode"] + 
ssg_key = ctype_raw + ssg_name = ctype_raw.replace("_", " ").replace("-", " ").title() + section_name = f"{_prettify_ptype(ptype_raw)} – {_prettify_mode(mode)}" + else: + ssg_key = None + ssg_name = None + section_name = prettify_section_name(parts[3]) + elif first == "feature" and len(parts) >= 2: + # Flat fallback (old structure) + sec_dir = parts[1] + section_dir = os.path.join(mtype_dir, first, sec_dir) + meta = _read_section_meta(section_dir) + if meta: + ptype_raw = meta["Ptype"] + ctype_raw = meta["Ctype"] + mode = meta["Mode"] + sg_key = ptype_raw + sg_name = _prettify_ptype(ptype_raw) + ssg_key = ctype_raw + ssg_name = ctype_raw.replace("_", " ").replace("-", " ").title() + section_name = _prettify_mode(mode) + else: + sg_key = "" + sg_name = "" + section_name = prettify_section_name(os.sep.join(parts[1:])) + else: + sg_key = "" + sg_name = "" + remainder = os.sep.join(parts[1:]) if len(parts) > 1 else "" + section_name = prettify_section_name(remainder) if remainder else prettify_section_name(first) + + # Ensure group exists + if group_name not in report[vtype][mtype]: + report[vtype][mtype][group_name] = OrderedDict() + + # Ensure subgroup bucket exists + # Structure: report[vtype][mtype][group_name][sg_key] = { + # "name": sg_name, "sections": OrderedDict(), "subgroups": OrderedDict()} + grp = report[vtype][mtype][group_name] + if sg_key not in grp: + grp[sg_key] = {"name": sg_name, "sections": OrderedDict(), "subgroups": OrderedDict()} + + content = build_section_content(root, section_name, table_counter) + if content.strip(): + if (first in ("sequence", "feature")) and ssg_key: + # Store in a sub-subgroup (All / Gene / …) inside the chr subgroup + ssg = grp[sg_key]["subgroups"] + if ssg_key not in ssg: + ssg[ssg_key] = {"name": ssg_name, "sections": OrderedDict()} + ssg[ssg_key]["sections"][rel] = {"name": section_name, "content": content} + else: + grp[sg_key]["sections"][rel] = {"name": section_name, "content": content} + + # Global ranking + gr_dir 
= os.path.join(results_dir, "global_ranking") + global_content = "" + if os.path.isdir(gr_dir): + for csv_path in find_csvs(gr_dir): + fname = os.path.basename(csv_path).replace(".csv", "") + table_counter[0] += 1 + tid = f"dt_{table_counter[0]}" + global_content += f'
{fname.replace("_", " ").title()}
\n' + global_content += f'
{csv_to_html_table(csv_path, table_id=tid)}
\n' + global_content += f'\n' + global_content += '
\n' + + for img_path in find_images(gr_dir): + fname = os.path.basename(img_path).replace(".png", "").replace("_", " ").title() + b64 = img_src(img_path) + if b64: + global_content += f'
{fname}
' + global_content += f'{fname}
\n' + + return report, manifest, global_content + + +# --------------------------------------------------------------------------- +# HTML Template +# --------------------------------------------------------------------------- + +HTML_TEMPLATE = """ + + + + +barometer Biomarker Analysis Report + + + + + + + + + + + + +
barometer Biomarker Analysis Report
+ + + +
+ + +
+

Report Summary

+

Aggregates: {{ manifest.get('n_aggregates', 'N/A') }} rows  |  + Features: {{ manifest.get('n_features', 'N/A') }} rows  |  + Value types: {{ manifest.get('value_types', [])|join(', ') }}

+

Samples:

+
    + {% for s in manifest.get('sample_info', [])[:20] %} +
  • {{ s.col }} (group: {{ s.group }}, sample: {{ s.sample }}, rep: {{ s.rep }})
  • + {% endfor %} +
+

Analyses performed: QC, Descriptive Statistics, Differential Editing (Mann-Whitney U, Kruskal-Wallis, ANOVA), + Multiple Testing Correction (FDR-BH), PCA, Hierarchical Clustering, Correlation Analysis, Random Forest Feature + Importance, Classification (RF, GBT, LDA), Stability/Robustness (CV, replicate concordance), Batch Effect Detection.

+
+ + +
+

🏆 Global Biomarker Ranking

+ {{ global_content|safe }} +
+ + + + +
+ {% for vtype, mtype_data in report.items() %} +
+ +
+ + + + +
+ {% for mtype, groups in mtype_data.items() %} +
+ +
+ + {% for group_name, subgroups in groups.items() %} +
+

{{ group_name }}

+ +
+ {% endfor %} + +
+ {% endfor %} +
+ +
+ {% endfor %} +
+ +
+ + +
+ Zoomed +
+ + + + + +""" + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="Generate barometer Biomarker HTML Report") + parser.add_argument("-r", "--results", default="barometer_results", help="Results directory from barometer_analyze.py") + parser.add_argument("-o", "--output", default="barometer_report.html", help="Output HTML report file") + parser.add_argument("-e", "--embed-images", action="store_true", default=False, + help="Embed images as base64 (larger file but self-contained)") + args = parser.parse_args() + # Pass embed flag as module-level so helpers can use it + global EMBED_IMAGES, RESULTS_DIR + EMBED_IMAGES = args.embed_images + RESULTS_DIR = args.results + + if not os.path.isdir(args.results): + log.error(f"Results directory not found: {args.results}") + sys.exit(1) + + log.info(f"Building report from {args.results}...") + report, manifest, global_content = build_report_data(args.results) + + log.info(f"Rendering HTML...") + template = Template(HTML_TEMPLATE) + html = template.render( + report=report, + manifest=manifest, + global_content=global_content, + sanitize=sanitize_id, + ) + + with open(args.output, "w", encoding="utf-8") as f: + f.write(html) + + log.info(f"Report written to {args.output}") + log.info(f"Open in browser to view the interactive report.") + + +if __name__ == "__main__": + main() diff --git a/bin/drip.py b/bin/drip.py index 04cebac..108ae42 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -1,8 +1,25 @@ #!/usr/bin/env python3 -import pandas as pd +import os import sys + +# Limit implicit BLAS/LAPACK multi-threading to respect SLURM CPU allocation. +# Will be set to 1 by default and updated based on --threads CLI argument. +# This ensures total CPU usage = implicit threads + explicit Pool threads <= allocated CPUs. 
+os.environ.setdefault('OMP_NUM_THREADS', '1') +os.environ.setdefault('OPENBLAS_NUM_THREADS', '1') +os.environ.setdefault('MKL_NUM_THREADS', '1') +os.environ.setdefault('VECLIB_MAXIMUM_THREADS', '1') +os.environ.setdefault('NUMEXPR_NUM_THREADS', '1') + +import pandas as pd +import numpy as np +import multiprocessing +import gc from pathlib import Path +import shutil +import tempfile +import re def print_help(): """Print help message explaining the script usage and calculations.""" @@ -10,71 +27,112 @@ def print_help(): DRIP - RNA Editing Analysis Tool DESCRIPTION: - This script analyzes RNA editing from standardized puviometer files. It calculates - two key metrics for all 16 genome-variant base pair combinations across multiple + This script analyzes RNA editing from standardized pluviometer files. It calculates + two key metrics for all 12 genome-variant base pair combinations across multiple samples and combines them into a unified matrix format. USAGE: - ./drip.py OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 FILE2:GROUP2:SAMPLE2:REP2 [...] [--with-file-id] + ./drip.py --output OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 FILE2:GROUP2:SAMPLE2:REP2 [...] ./drip.py --help | -h ARGUMENTS: - OUTPUT_PREFIX Prefix for the output TSV file (will create OUTPUT_PREFIX.tsv) + --output OUTPUT_PREFIX, -o OUTPUT_PREFIX + Prefix for the output directories (required). + Creates OUTPUT_PREFIX_espf/ and OUTPUT_PREFIX_espr/, + each containing 12 TSV files (AC.tsv, AG.tsv, …, TG.tsv). FILEn:GROUPn:SAMPLEn:REPn Input file path, group name, sample name, and replicate ID separated by colons. All four components are required. --with-file-id Include file ID in column names (default: omit file ID) + --report-non-qualified-features + Include rows where all metric values are NA (default: omit them). + By default, rows where every sample has NA for the given base pair + are skipped (not covered or not qualified in any sample). 
+ --min-samples-pct X + Keep a row only if at least X% of all samples have a qualified value + (non-NA, non-zero) for the base pair. A sample "has a value" when + its metric for the file being written (espf or espr) is non-NA and + non-zero; espf and espr rows are filtered independently. + Applied as an OR with --min-group-pct when both are provided. + Ignored when --report-non-qualified-features is set. + --min-group-pct Y + Keep a row only if at least one group has at least Y% of its + samples with a qualified value. Applied as an OR with + --min-samples-pct when both are provided. + Ignored when --report-non-qualified-features is set. + --min-cov N Minimum read coverage threshold (default: 1). Positions with a + denominator (genome base count for espf, read count for espr) + strictly below this value are reported as NA instead of 0. + --threads N, -t N + Number of parallel worker processes for per-SeqID processing + (default: 1, sequential). Max useful value is the number of + distinct SeqIDs in the inputs. + --decimals N, -d N + Number of decimal places for output values (default: 4). + Reduces file size by rounding espf/espr metrics. 
--help, -h Display this help message +NA BEHAVIOR: +| NA source | Mechanism | Result +| Coverage = 0 in sampleA | np.where(mask, ..., np.nan) | NA ✅ +| Coverage < min_cov in sampleA | same mask | NA ✅ +| Row absent from sampleB | how='outer' → NaN | NA ✅ +| Coverage OK, 0 editing events | ratio = 0.0 | 0.0 ✅ + INPUT FILE FORMAT: The input files must be TSV files with the following columns: - - GenomeBases: Frequencies of bases in the reference genome (order: A, C, G, T) - - SiteBasePairings: Number of sites with each genome-variant base pairing + - ObservedBases: Frequencies of bases in the reference genome (order: A, C, G, T) + - SiteBasePairingsQualified: Number of sites with each genome-variant base pairing (qualified, filtered by cov + edit thresholds) (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) - - ReadBasePairings: Frequencies of genome-variant base pairings in reads + - ReadBasePairingsQualified: Frequencies of genome-variant base pairings in reads (filtered by cov + edit thresholds) (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) CALCULATED METRICS: - For each line, the script calculates metrics for all 16 base pair combinations: + For each line, the script calculates metrics for all 12 base pair combinations: For each combination XY (where X = genome base, Y = read base): 1. XY_espf (edited_sites_proportion_feature) - Proportion of XY sites in the DNA feature: - Formula: XY_SiteBasePairings / X_GenomeBases - This represents the proportion of genomic X positions that show X-to-Y variation in the feature. + Formula: XY_SiteBasePairingsQualified / X_QualifiedBases + This represents the proportion of qualified X positions that show X-to-Y variation in the feature. 2. XY_espr (edited_sites_proportion_reads) - Proportion of XY pairing in reads: - Formula: XY_ReadBasePairings / (XA + XC + XG + XT)_ReadBasePairings + Formula: XY_ReadBasePairingsQualified / (XA + XC + XG + XT)_ReadBasePairingsQualified This represents the proportion of X-position reads that show Y in the reads. 
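The NA rules in the table above can be condensed into a standalone sketch. This is illustrative only (`ratio_with_na` is a hypothetical helper, not part of drip.py); it mirrors the `np.where` masking pattern the patch uses for both espf and espr:

```python
import numpy as np
import pandas as pd

def ratio_with_na(numerator, denominator, min_cov=1):
    """Divide numerator by denominator per position, reporting NA (NaN)
    where the denominator is below min_cov, and 0.0 where coverage is
    sufficient but no editing events were observed."""
    num = pd.Series(numerator, dtype='float64')
    den = pd.Series(denominator, dtype='float64')
    mask = den >= min_cov
    # den.where(mask, 1) avoids dividing by zero; masked rows become NaN anyway.
    return pd.Series(np.where(mask, num / den.where(mask, 1), np.nan))

vals = ratio_with_na([0, 2, 3], [10, 0, 5], min_cov=1)
# Row 0: covered, zero edits -> 0.0; row 1: coverage 0 -> NaN; row 2: 3/5 -> 0.6
```

A row missing from one sample is the remaining NA source: the `how='outer'` merge fills it with NaN before any of this arithmetic runs.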
- All 16 combinations are calculated: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT + All 12 combinations are calculated: AC, AG, AT, CA, CG, CT, GA, GC, GT, TA, TC, TG OUTPUT FORMAT: - Multiple TSV files (one per base pair combination) with aggregates as rows: - - OUTPUT_PREFIX_AA.tsv - - OUTPUT_PREFIX_AC.tsv - - OUTPUT_PREFIX_AG.tsv - - ... (16 files total, one for each XY combination) - + Two output directories, one per metric type, each containing 12 TSV files + (one per base pair combination). Row filtering (--min-samples-pct, etc.) + is applied independently per metric: a row may appear in the espf output + but not in the espr output and vice versa. + + Directories created: + - OUTPUT_PREFIX_espf/ espf metric (proportion of edited sites in feature) + - OUTPUT_PREFIX_espr/ espr metric (proportion of edited reads) + + Files in each directory (12 per metric, identity pairs excluded): + - AC.tsv, AG.tsv, AT.tsv + - CA.tsv, CG.tsv, CT.tsv + - GA.tsv, GC.tsv, GT.tsv + - TA.tsv, TC.tsv, TG.tsv + Each file contains: - - Metadata columns (first 6 columns): + + Metadata columns: - SeqID: Sequence/chromosome identifier - ParentIDs: Parent feature identifiers - ID: Unique identifier + - Mtype: Type of feature - Ptype: Type of Parent feature - - Type: Type of feature + - Type: Aggregate type (feature / sequence / global) - Ctype: Type of Children feature - - Mode: Mode of aggregation used if any (e.g., 'all_sites', 'edited_sites', 'edited_reads') - Metric columns (for each sample): - - GROUP::SAMPLE::REPLICATE::espf: XY sites proportion in feature (XY sites / X bases) - - GROUP::SAMPLE::REPLICATE::espr: XY sites proportion in reads (XY reads / all X reads) - - Or with --with-file-id option: - - GROUP::SAMPLE::REPLICATE::FILE_ID::espf - - GROUP::SAMPLE::REPLICATE::FILE_ID::espr - + - Mode: Mode of aggregation (e.g., 'all_sites', 'edited_sites', 'edited_reads') + - Start, End, Strand + + Metric columns (one per sample): + - 
GROUP::SAMPLE::REPLICATE::METRIC (without --with-file-id) + - GROUP::SAMPLE::REPLICATE::FILE_ID::METRIC (with --with-file-id) + Where: - GROUP: Group/condition name provided in arguments - SAMPLE: Sample name provided in arguments @@ -83,29 +141,26 @@ def print_help(): - The '::' separator allows easy splitting to retrieve all components EXAMPLE: - ./drip.py results \\ + ./drip.py --output results \\ sample1_aggregates.tsv:control:sample1:rep1 \\ sample2_aggregates.tsv:control:sample2:rep2 \\ sample3_aggregates.tsv:treated:sample1:rep1 - This creates 16 files (one per base pair combination): - - results_AA.tsv, results_AC.tsv, results_AG.tsv, results_AT.tsv, - - results_CA.tsv, results_CC.tsv, results_CG.tsv, results_CT.tsv, - - results_GA.tsv, results_GC.tsv, results_GG.tsv, results_GT.tsv, - - results_TA.tsv, results_TC.tsv, results_TG.tsv, results_TT.tsv - - Each file has columns: - SeqID, ParentIDs, ID, Ptype, Ctype, Mode, - control::sample1::rep1::rain_sample1::espf, control::sample1::rep1::rain_sample1::espr, - control::sample2::rep2::rain_sample2::espf, control::sample2::rep2::rain_sample2::espr, - treated::sample1::rep1::rain_sample3::espf, treated::sample1::rep1::rain_sample3::espr - - Column headers use format: GROUP::SAMPLE::REPLICATE::FILE_ID::METRIC + Creates two directories: + - results_espf/ with AC.tsv, AG.tsv, …, TG.tsv (espf metric) + - results_espr/ with AC.tsv, AG.tsv, …, TG.tsv (espr metric) + + Example columns in results_espf/AG.tsv: + SeqID, ParentIDs, ID, Mtype, Ptype, Type, Ctype, Mode, Start, End, Strand, + control::sample1::rep1::espf, + control::sample2::rep2::espf, + treated::sample1::rep1::espf + + Column headers use format: GROUP::SAMPLE::REPLICATE::METRIC - GROUP: The group/condition name provided - SAMPLE: The sample name provided - REPLICATE: Replicate ID (rep1, rep2, etc.) 
- - FILE_ID: Input filename without extension and last '_' suffix - - METRIC: espf or espr + - METRIC: espf or espr (same for all columns in a given file) - Separator '::' allows easy splitting to retrieve all components AUTHORS: @@ -115,158 +170,375 @@ def print_help(): print(help_text) sys.exit(0) -def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id=False): - """Parse a single TSV file and extract editing metrics for all base pair combinations.""" - df = pd.read_csv(filepath, sep='\t') - - # DO NOT filter out rows where ID is '.' - # These are special aggregate rows (e.g., all_sites) that should be kept - # Base pair combinations in order - base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', - 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] - bases = ['A', 'C', 'G', 'T'] - - # Parse GenomeBases (order: A, C, G, T) - for i, base in enumerate(bases): - df[f'{base}_count'] = df['GenomeBases'].str.split(',').str[i].astype(int) +def _parse_comma_col_at(series, idx): + """Extract the integer at comma-based index `idx` from a packed string column.""" + return series.str.split(',', expand=True)[idx].astype(np.int64) + + +def parse_tsv_file_for_bp(filepath, bp, group_name, sample_name, replicate, file_id, + include_file_id=False, min_cov=1, decimals=4): + """Parse one TSV file and compute espf/espr for a single base pair `bp` only. + + Loads only the three packed data columns + metadata (via usecols) and discards + them immediately after extracting the two needed integer vectors. Returns a + slim DataFrame with 11 metadata columns + 2 float metric columns. + Peak RAM per call ≈ raw CSV in memory + a few integer Series. 
+ """ + ALL_BPS = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', + 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] + BASES = ['A', 'C', 'G', 'T'] + + genome_base = bp[0] + bp_idx = ALL_BPS.index(bp) + gb_idx = BASES.index(genome_base) # 0–3 + gb_offset = gb_idx * 4 # first XA/XC/XG/XT index in 12-value vector + + metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', + 'Ctype', 'Mode', 'Start', 'End', 'Strand'] + needed_cols = metadata_cols + ['QualifiedBases', + 'SiteBasePairingsQualified', + 'ReadBasePairingsQualified'] + mixed_dtypes = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} + + df = pd.read_csv(filepath, sep='\t', dtype=mixed_dtypes, usecols=needed_cols) + + # espf numerator / denominator + bp_sites = _parse_comma_col_at(df['SiteBasePairingsQualified'], bp_idx) + x_count = _parse_comma_col_at(df['QualifiedBases'], gb_idx) + + # espr: expand ReadBasePairingsQualified once, pick 5 values, discard the rest + reads_split = df['ReadBasePairingsQualified'].str.split(',', expand=True) + bp_reads = reads_split[bp_idx].astype(np.int64) + total_reads = (reads_split[gb_offset ].astype(np.int64) + + reads_split[gb_offset + 1].astype(np.int64) + + reads_split[gb_offset + 2].astype(np.int64) + + reads_split[gb_offset + 3].astype(np.int64)) + del reads_split + + # Keep only metadata columns; drop the three packed columns now + result = df[metadata_cols].copy() + del df + + col_prefix = (f'{group_name}::{sample_name}::{replicate}::{file_id}' + if include_file_id + else f'{group_name}::{sample_name}::{replicate}') + + mask_f = x_count >= min_cov + result[f'{col_prefix}::espf'] = np.where( + mask_f, bp_sites / x_count.where(mask_f, 1), np.nan + ) + result[f'{col_prefix}::espf'] = result[f'{col_prefix}::espf'].round(decimals) + + mask_r = total_reads >= min_cov + result[f'{col_prefix}::espr'] = np.where( + mask_r, bp_reads / total_reads.where(mask_r, 1), np.nan + ) + result[f'{col_prefix}::espr'] = result[f'{col_prefix}::espr'].round(decimals) - 
# Parse SiteBasePairings (all 16 combinations) - for i, bp in enumerate(base_pairs): - df[f'{bp}_sites'] = df['SiteBasePairings'].str.split(',').str[i].astype(int) - - # Parse ReadBasePairings (all 16 combinations) - for i, bp in enumerate(base_pairs): - df[f'{bp}_reads'] = df['ReadBasePairings'].str.split(',').str[i].astype(int) - - # Calculate metrics for each base pair combination - metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - result_cols = metadata_cols.copy() - - # Create column prefix with group::sample::replicate::file_id or group::sample::replicate - if include_file_id: - col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' - else: - col_prefix = f'{group_name}::{sample_name}::{replicate}' - - for bp in base_pairs: - genome_base = bp[0] # First letter is the genome base - - # Calculate espf: XY_sites / X_count - espf_col = f'{col_prefix}::{bp}::espf' - df[espf_col] = df.apply( - lambda row: row[f'{bp}_sites'] / row[f'{genome_base}_count'] - if row[f'{genome_base}_count'] > 0 else 0, - axis=1 - ) - result_cols.append(espf_col) - - # Calculate espr: XY_reads / (XA + XC + XG + XT) - # Calculate total reads for this genome base - total_reads_col = f'{genome_base}_total_reads' - if total_reads_col not in df.columns: - df[total_reads_col] = ( - df[f'{genome_base}A_reads'] + - df[f'{genome_base}C_reads'] + - df[f'{genome_base}G_reads'] + - df[f'{genome_base}T_reads'] - ) - - espr_col = f'{col_prefix}::{bp}::espr' - df[espr_col] = df.apply( - lambda row: row[f'{bp}_reads'] / row[total_reads_col] - if row[total_reads_col] > 0 else 0, - axis=1 - ) - result_cols.append(espr_col) - - # Return dataframe with metadata and all metrics - result = df[result_cols].copy() - return result -def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False): - """Merge data from multiple samples and create output matrices - one file per base pair combination.""" - - 
all_data = [] - file_id_list = [] - group_name_list = [] - sample_name_list = [] - replicate_list = [] - - # Base pair combinations in order - base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', - 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] - + +def _compute_bp_from_df(df, bp, bp_idx, gb_idx, gb_offset, col_prefix, min_cov, decimals): + """Compute espf/espr for one BP from an already-loaded DataFrame. + + Returns a slim DataFrame (11 metadata cols + 2 metric cols). + `df` must already contain the pre-parsed columns: + SiteBasePairingsQualified, QualifiedBases, ReadBasePairingsQualified + as well as the 11 metadata columns. + """ + metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', + 'Ctype', 'Mode', 'Start', 'End', 'Strand'] + + bp_sites = _parse_comma_col_at(df['SiteBasePairingsQualified'], bp_idx) + x_count = _parse_comma_col_at(df['QualifiedBases'], gb_idx) + + reads_split = df['ReadBasePairingsQualified'].str.split(',', expand=True) + bp_reads = reads_split[bp_idx].astype(np.int64) + total_reads = (reads_split[gb_offset ].astype(np.int64) + + reads_split[gb_offset + 1].astype(np.int64) + + reads_split[gb_offset + 2].astype(np.int64) + + reads_split[gb_offset + 3].astype(np.int64)) + del reads_split + + result = df[metadata_cols].copy() + + mask_f = x_count >= min_cov + result[f'{col_prefix}::espf'] = np.where( + mask_f, bp_sites / x_count.where(mask_f, 1), np.nan + ).round(decimals) + + mask_r = total_reads >= min_cov + result[f'{col_prefix}::espr'] = np.where( + mask_r, bp_reads / total_reads.where(mask_r, 1), np.nan + ).round(decimals) + + return result + + +def _write_one_bp(bp, accumulated, output_prefix, metadata_cols, report_non_qualified): + """Sort, filter and write the accumulated DataFrame for one BP.""" + merged = accumulated.sort_values(['SeqID', 'ParentIDs', 'Mode']) + + if not report_non_qualified: + metric_cols = [c for c in merged.columns if c not in metadata_cols] + merged = merged[ + 
(merged[metric_cols].notna() & (merged[metric_cols] != 0)).any(axis=1) + ] + + output_file = f'{output_prefix}_{bp}.tsv' + merged.to_csv(output_file, sep='\t', index=False, na_rep='NA') + n_rows = len(merged) + del merged + gc.collect() + print(f' {output_file} ({n_rows} rows)') + return output_file + +def _safe_seqid(seqid: str) -> str: + """Convert a SeqID to a filesystem-safe directory name.""" + return re.sub(r'[^A-Za-z0-9._-]', '_', seqid) or 'EMPTY' + + +def _split_file_by_seqid(filepath, temp_dir, sample_idx, needed_cols, mixed_dtypes): + """Read one TSV file once and write per-SeqID chunk files. + + Returns a dict {seqid: chunk_file_path}. Each chunk file contains the + same columns as the original (needed_cols), filtered to one SeqID. + """ + df = pd.read_csv(filepath, sep='\t', dtype=mixed_dtypes, usecols=needed_cols) + seqid_to_path = {} + for seqid, group in df.groupby('SeqID', sort=False): + chunk_dir = os.path.join(temp_dir, _safe_seqid(seqid)) + os.makedirs(chunk_dir, exist_ok=True) + path = os.path.join(chunk_dir, f'{sample_idx}.tsv') + group.to_csv(path, sep='\t', index=False) + seqid_to_path[seqid] = path + del df + gc.collect() + return seqid_to_path + + +def _process_seqid_chunk(seqid, chunk_paths_by_sample, col_prefixes, temp_out_dir, + metadata_cols, bp_meta, all_bps, min_cov, decimals, + report_non_qualified, min_samples_pct=None, min_group_pct=None): + """Compute the 12 BPs for one SeqID and write per-BP, per-metric chunk files. + + Each element in chunk_paths_by_sample is a path (str) or None when that + sample has no rows for this SeqID — the outer join fills NaN for it. + Returns (seqid, {(bp, metric): out_path, ...}); only (bp, metric) pairs + with at least one surviving row are included. 
+ """ + mixed_dtypes_local = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} + bp_accumulators = [None] * len(all_bps) + + for col_prefix, chunk_path in zip(col_prefixes, chunk_paths_by_sample): + if chunk_path is None: + continue + df = pd.read_csv(chunk_path, sep='\t', dtype=mixed_dtypes_local) + for i in range(len(all_bps)): + bp_idx, gb_idx, gb_offset = bp_meta[i] + bp_data = _compute_bp_from_df( + df, all_bps[i], bp_idx, gb_idx, gb_offset, col_prefix, min_cov, decimals + ) + if bp_accumulators[i] is None: + bp_accumulators[i] = bp_data + else: + bp_accumulators[i] = bp_accumulators[i].merge( + bp_data, on=metadata_cols, how='outer' + ) + del bp_data + del df + gc.collect() + + out_paths = {} + safe = _safe_seqid(seqid) + for i, bp in enumerate(all_bps): + acc = bp_accumulators[i] + if acc is None: + continue + acc = acc.sort_values(['ParentIDs', 'Mode']) + + # Process espf and espr independently: separate output files, separate + # row-filtering. A row that passes the espf threshold but not the espr + # threshold (or vice versa) will appear in only one of the two outputs. + for metric in ('espf', 'espr'): + # Each column for this metric type maps 1:1 to one sample. + metric_cols = [c for c in acc.columns + if c not in metadata_cols and c.endswith(f'::{metric}')] + if not metric_cols: + continue + acc_metric = acc[metadata_cols + metric_cols] + + if not report_non_qualified: + # Step 1 — cell-level: "has a value" = non-NA AND non-zero. + # NA means the position was not covered (or below min_cov, or + # absent from this sample via the outer join). + # 0.0 means covered but no editing event observed. + has_value = acc_metric[metric_cols].notna() & (acc_metric[metric_cols] != 0) + + # Step 2 — row-level decision, applied independently per metric + # type (espf rows and espr rows are filtered separately): + # + # Default: keep if ANY sample has a value for this metric. 
+ # + # --min-samples-pct X: keep if proportion of samples with a + # value >= X/100 (across all samples globally). + # Example: 3/5 samples with espf > 0 → 60%; X=50 → keep. + # + # --min-group-pct Y: keep if at least one group has >= Y/100 + # of its samples with a value. Groups are identified by + # the first :: component of the column name (GROUP name). + # Example: "ctrl" group has 2/3 espf values → 67%; Y=60 → keep. + # + # Both flags → OR: keep if either condition is satisfied. + if min_samples_pct is None and min_group_pct is None: + keep = has_value.any(axis=1) + else: + keep = pd.Series(False, index=acc_metric.index) + if min_samples_pct is not None: + n = len(metric_cols) + keep |= has_value.sum(axis=1) / n >= min_samples_pct / 100.0 + if min_group_pct is not None: + groups: dict[str, list[str]] = {} + for c in metric_cols: + groups.setdefault(c.split('::')[0], []).append(c) + for g_cols in groups.values(): + keep |= ( + has_value[g_cols].sum(axis=1) / len(g_cols) + >= min_group_pct / 100.0 + ) + acc_metric = acc_metric[keep] + + if len(acc_metric) > 0: + out_path = os.path.join(temp_out_dir, metric, f'{metric}_{bp}_{safe}.tsv') + acc_metric.to_csv(out_path, sep='\t', index=False, na_rep='NA') + out_paths[(bp, metric)] = out_path + + bp_accumulators[i] = None + + gc.collect() + return seqid, out_paths + + +def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, + min_cov=1, threads=1, decimals=4, report_non_qualified=False, + min_samples_pct=None, min_group_pct=None): + """Produce one output file per base pair combination. + + Memory strategy — three phases: + Phase 1 (Split): each input file is read once and split by SeqID into + temporary chunk files. Peak RAM = one full input file. + Phase 2 (Process): each SeqID is processed independently (all 12 BPs, + across all samples) and results written to temp chunks. + Peak RAM per worker ≈ num_samples × one-SeqID slice. + Workers run in parallel when threads > 1. 
+ Phase 3 (Concat): per-SeqID temp chunks are appended in order to the + final output files (12 per metric). Peak RAM = one chunk at a time. + """ + ALL_BPS = ['AC', 'AG', 'AT', 'CA', 'CG', 'CT', + 'GA', 'GC', 'GT', 'TA', 'TC', 'TG'] + BASES = ['A', 'C', 'G', 'T'] + metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', + 'Ctype', 'Mode', 'Start', 'End', 'Strand'] + needed_cols = metadata_cols + ['QualifiedBases', + 'SiteBasePairingsQualified', + 'ReadBasePairingsQualified'] + mixed_dtypes = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} + # bp_idx must index the 16-value packed columns (order AA..TT), so use the + # full base-pair list here even though only the 12 non-identity pairs are + # computed; indexing with ALL_BPS would read the wrong comma field. + FULL_BPS = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', + 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] + bp_meta = [(FULL_BPS.index(bp), BASES.index(bp[0]), BASES.index(bp[0]) * 4) + for bp in ALL_BPS] + + sample_info = [] for filepath, (group_name, sample_name, replicate) in file_group_sample_replicate_dict.items(): - # Extract file ID (everything before the last underscore in filename) filename_stem = Path(filepath).stem file_id = '_'.join(filename_stem.split('_')[:-1]) - print(f"Processing {group_name}::{sample_name} (replicate {replicate}) from {filepath} (file_id: {file_id})...") - data = parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id) - all_data.append(data) - file_id_list.append(file_id) - group_name_list.append(group_name) - sample_name_list.append(sample_name) - replicate_list.append(replicate) - - # Merge all samples based on metadata columns - metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - merged = all_data[0] - for data in all_data[1:]: - merged = merged.merge(data, on=metadata_cols, how='outer') - - # Fill NA values with 0 for metrics - merged = merged.fillna(0) - - # Sort by SeqID, then ParentIDs, then Mode - merged = merged.sort_values(['SeqID', 'ParentIDs', 'Mode']) - - metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - - # Create one file per base pair combination - output_files = [] - for bp in base_pairs: - # Select columns for this 
base pair combination - bp_cols = metadata_cols.copy() - for group_name, sample_name, replicate, file_id in zip(group_name_list, sample_name_list, replicate_list, file_id_list): - if include_file_id: - col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' - else: - col_prefix = f'{group_name}::{sample_name}::{replicate}' - espf_col = f'{col_prefix}::{bp}::espf' - espr_col = f'{col_prefix}::{bp}::espr' - if espf_col in merged.columns: - bp_cols.append(espf_col) - if espr_col in merged.columns: - bp_cols.append(espr_col) - - # Create result for this base pair - bp_result = merged[bp_cols].copy() - - # Rename columns to remove the bp suffix (cleaner output) - rename_dict = {} - for group_name, sample_name, replicate, file_id in zip(group_name_list, sample_name_list, replicate_list, file_id_list): - if include_file_id: - col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' - else: - col_prefix = f'{group_name}::{sample_name}::{replicate}' - rename_dict[f'{col_prefix}::{bp}::espf'] = f'{col_prefix}::espf' - rename_dict[f'{col_prefix}::{bp}::espr'] = f'{col_prefix}::espr' - bp_result = bp_result.rename(columns=rename_dict) - - # Save to file - output_file = f"{output_prefix}_{bp}.tsv" - bp_result.to_csv(output_file, sep='\t', index=False) - output_files.append(output_file) - - print(f"\nOutput files created:") - for output_file in output_files: - print(f" - {output_file}") - print(f" - {len(merged)} aggregates per file") - print(f" - {len(sample_name_list)} samples: {', '.join(sample_name_list)}") - print(f" - {len(base_pairs)} files (one per base pair combination)") - - return merged + col_prefix = (f'{group_name}::{sample_name}::{replicate}::{file_id}' + if include_file_id + else f'{group_name}::{sample_name}::{replicate}') + sample_info.append((filepath, col_prefix)) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_split_dir = os.path.join(temp_dir, 'split') + temp_out_dir = os.path.join(temp_dir, 'out') + 
os.makedirs(temp_split_dir) + os.makedirs(temp_out_dir) + os.makedirs(os.path.join(temp_out_dir, 'espf')) + os.makedirs(os.path.join(temp_out_dir, 'espr')) + + # ── Phase 1: split each file by SeqID ───────────────────────────────── + print('Phase 1/3 — Splitting input files by SeqID...') + all_seqids_seen: dict[str, int] = {} # seqid → first-appearance order + sample_seqid_paths: list[dict[str, str]] = [] + + for sample_idx, (filepath, _col_prefix) in enumerate(sample_info): + print(f' [{sample_idx + 1}/{len(sample_info)}] {filepath}') + seqid_to_path = _split_file_by_seqid( + filepath, temp_split_dir, sample_idx, needed_cols, mixed_dtypes + ) + sample_seqid_paths.append(seqid_to_path) + for seqid in seqid_to_path: + if seqid not in all_seqids_seen: + all_seqids_seen[seqid] = len(all_seqids_seen) + + # Sort SeqIDs: real sequences first (lexicographic), "." (globals) last + all_seqids = sorted( + all_seqids_seen.keys(), + key=lambda s: (s == '.', s) + ) + print(f' Found {len(all_seqids)} unique SeqIDs across all samples.') + + # ── Phase 2: process each SeqID (parallel if threads > 1) ───────────── + col_prefixes = [col_prefix for _, col_prefix in sample_info] + worker_args = [ + (seqid, + [ssp.get(seqid) for ssp in sample_seqid_paths], + col_prefixes, temp_out_dir, + metadata_cols, bp_meta, ALL_BPS, min_cov, decimals, report_non_qualified, + min_samples_pct, min_group_pct) + for seqid in all_seqids + ] + + mode = f'{min(threads, len(all_seqids))} workers' if threads > 1 else 'sequential' + print(f'Phase 2/3 — Processing {len(all_seqids)} SeqID chunks ({mode})...') + + if threads > 1: + with multiprocessing.Pool(processes=min(threads, len(all_seqids))) as pool: + results = pool.starmap(_process_seqid_chunk, worker_args) + else: + results = [_process_seqid_chunk(*a) for a in worker_args] + + # Restore stable SeqID order (parallel mode may return out of order) + seqid_order = {s: i for i, s in enumerate(all_seqids)} + results.sort(key=lambda r: seqid_order[r[0]]) + + 
# ── Phase 3: concatenate per-SeqID chunks into final output files ───────── + # Two output directories (one per metric type), each with 12 BP files. + print('Phase 3/3 — Writing final output files...') + output_files = [] + for metric in ('espf', 'espr'): + out_dir = f'{output_prefix}_{metric}' + os.makedirs(out_dir, exist_ok=True) + for bp in ALL_BPS: + bp_chunks = [ + out_paths[(bp, metric)] + for _, out_paths in results + if (bp, metric) in out_paths + ] + if not bp_chunks: + continue + + out_path = os.path.join(out_dir, f'{output_prefix}_{metric}_{bp}.tsv') + with open(out_path, 'w') as fout: + with open(bp_chunks[0]) as first: + shutil.copyfileobj(first, fout) # includes header + for chunk_path in bp_chunks[1:]: + with open(chunk_path) as f: + next(f) # skip header + shutil.copyfileobj(f, fout) + output_files.append(out_path) + print(f' {out_path}') + + print(f'\nDone. {len(sample_info)} samples, {len(all_seqids)} SeqIDs, ' + f'{len(output_files)} output files written ' + f'in {output_prefix}_espf/ and {output_prefix}_espr/.') + # Example usage if __name__ == "__main__": @@ -275,24 +547,150 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ print_help() # Validate arguments - if len(sys.argv) < 3: + if len(sys.argv) < 2: print("ERROR: Insufficient arguments provided.\n", file=sys.stderr) - print("Usage: ./drip.py OUTPUT_PREFIX FILE1:SAMPLE1 FILE2:SAMPLE2 [...]\n", file=sys.stderr) + print("Usage: ./drip.py --output OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 [...]\n", file=sys.stderr) print("For detailed help, run: ./drip.py --help\n", file=sys.stderr) sys.exit(1) # Parse command line arguments - output_prefix = sys.argv[1] + output_prefix = None file_group_sample_replicate_dict = {} group_sample_rep_counts = {} include_file_id = False # Default: omit file_id from column names - - for arg in sys.argv[2:]: + min_cov = 1 # Default: NA when denominator is 0; treat 0 coverage as non-observed + threads = 1 # Default: sequential writing 
+ decimals = 4 # Default: round to 4 decimal places + report_non_qualified = False # Default: skip rows where all metric values are NA + min_samples_pct = None # Default: no global-sample-% filter + min_group_pct = None # Default: no per-group-% filter + + args_iter = iter(range(1, len(sys.argv))) + for i in args_iter: + arg = sys.argv[i] # Check for --output/-o flag (supports --output=PREFIX, --output PREFIX, -o PREFIX) + if arg.startswith('--output') or arg == '-o': + if arg.startswith('--output') and '=' in arg: + output_prefix = arg.split('=', 1)[1] + if not output_prefix: + print("ERROR: --output requires a value", file=sys.stderr) + sys.exit(1) + else: + try: + next_i = next(args_iter) + output_prefix = sys.argv[next_i] + except StopIteration: + print("ERROR: --output/-o requires a value", file=sys.stderr) + sys.exit(1) + continue # Check for --with-file-id flag if arg == '--with-file-id': include_file_id = True continue - + + # Check for --report-non-qualified-features flag + if arg == '--report-non-qualified-features': + report_non_qualified = True + continue + + # Check for --min-samples-pct flag + if arg.startswith('--min-samples-pct'): + if '=' in arg: + val = arg.split('=', 1)[1] + else: + try: + next_i = next(args_iter) + val = sys.argv[next_i] + except StopIteration: + print('ERROR: --min-samples-pct requires a value', file=sys.stderr) + sys.exit(1) + try: + min_samples_pct = float(val) + except ValueError: + print('ERROR: --min-samples-pct requires a numeric value (0–100)', file=sys.stderr) + sys.exit(1) + if not (0.0 <= min_samples_pct <= 100.0): + print('ERROR: --min-samples-pct must be between 0 and 100', file=sys.stderr) + sys.exit(1) + continue + + # Check for --min-group-pct flag + if arg.startswith('--min-group-pct'): + if '=' in arg: + val = arg.split('=', 1)[1] + else: + try: + next_i = next(args_iter) + val = sys.argv[next_i] + except StopIteration: + print('ERROR: --min-group-pct requires a value', file=sys.stderr) + sys.exit(1) + try: + 
min_group_pct = float(val) + except ValueError: + print('ERROR: --min-group-pct requires a numeric value (0–100)', file=sys.stderr) + sys.exit(1) + if not (0.0 <= min_group_pct <= 100.0): + print('ERROR: --min-group-pct must be between 0 and 100', file=sys.stderr) + sys.exit(1) + continue + + # Check for --min-cov flag (supports both --min-cov=N and --min-cov N) + if arg.startswith('--min-cov'): + if '=' in arg: + try: + min_cov = int(arg.split('=', 1)[1]) + except ValueError: + print(f"ERROR: --min-cov requires an integer value", file=sys.stderr) + sys.exit(1) + else: + try: + next_i = next(args_iter) + min_cov = int(sys.argv[next_i]) + except (StopIteration, ValueError): + print(f"ERROR: --min-cov requires an integer value", file=sys.stderr) + sys.exit(1) + continue + + # Check for --threads/-t flag (supports --threads=N, --threads N, -t N) + if arg.startswith('--threads') or arg == '-t': + if arg.startswith('--threads') and '=' in arg: + try: + threads = int(arg.split('=', 1)[1]) + except ValueError: + print(f"ERROR: --threads requires an integer value", file=sys.stderr) + sys.exit(1) + else: + try: + next_i = next(args_iter) + threads = int(sys.argv[next_i]) + except (StopIteration, ValueError): + print(f"ERROR: --threads/-t requires an integer value", file=sys.stderr) + sys.exit(1) + if threads < 1: + print(f"ERROR: --threads must be at least 1", file=sys.stderr) + sys.exit(1) + continue + + # Check for --decimals/-d flag (supports --decimals=N, --decimals N, -d N) + if arg.startswith('--decimals') or arg == '-d': + if arg.startswith('--decimals') and '=' in arg: + try: + decimals = int(arg.split('=', 1)[1]) + except ValueError: + print(f"ERROR: --decimals requires an integer value", file=sys.stderr) + sys.exit(1) + else: + try: + next_i = next(args_iter) + decimals = int(sys.argv[next_i]) + except (StopIteration, ValueError): + print(f"ERROR: --decimals/-d requires an integer value", file=sys.stderr) + sys.exit(1) + if decimals < 0: + print(f"ERROR: 
--decimals must be non-negative", file=sys.stderr) + sys.exit(1) + continue + if ':' not in arg: print(f"ERROR: Invalid argument format '{arg}'", file=sys.stderr) print("Expected format: FILE:GROUP:SAMPLE:REPLICATE", file=sys.stderr) @@ -325,7 +723,21 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ file_group_sample_replicate_dict[filepath] = (group_name, sample_name, replicate) + # Validate that output_prefix was provided + if output_prefix is None: + print("ERROR: --output/-o argument is required.\n", file=sys.stderr) + print("Usage: ./drip.py --output OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 [...]\n", file=sys.stderr) + print("For detailed help, run: ./drip.py --help\n", file=sys.stderr) + sys.exit(1) + + # Validate that at least one input file was provided + if len(file_group_sample_replicate_dict) == 0: + print("ERROR: At least one input file is required.\n", file=sys.stderr) + print("Usage: ./drip.py --output OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 [...]\n", file=sys.stderr) + print("For detailed help, run: ./drip.py --help\n", file=sys.stderr) + sys.exit(1) + # Process all samples - result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id) + result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads, decimals, report_non_qualified, min_samples_pct, min_group_pct) print("\nAnalysis complete!") \ No newline at end of file diff --git a/bin/pluviometer/README.md b/bin/pluviometer/README.md index 2e7cf2a..5c67d85 100644 --- a/bin/pluviometer/README.md +++ b/bin/pluviometer/README.md @@ -49,3 +49,22 @@ The following files have been moved from `bin/` to `bin/pluviometer/`: - utils.py → pluviometer/utils.py All internal imports have been converted to relative imports to function as a Python package. 
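The relative-import requirement mentioned above (the package only works when invoked as a module) can be demonstrated with a throwaway package. Everything here (`demo_pkg`, `utils.greet`) is illustrative, not part of pluviometer:

```python
import subprocess
import sys
import tempfile
import textwrap
from pathlib import Path

# Build a minimal package whose __main__.py uses a relative import of a
# sibling module, mirroring the pluviometer layout.
with tempfile.TemporaryDirectory() as tmp:
    pkg = Path(tmp) / "demo_pkg"
    pkg.mkdir()
    (pkg / "__init__.py").write_text("")
    (pkg / "utils.py").write_text("def greet():\n    return 'ok'\n")
    (pkg / "__main__.py").write_text(textwrap.dedent("""\
        from .utils import greet   # relative import: only valid inside a package
        print(greet())
    """))
    # `python -m demo_pkg` resolves the relative import; running the file
    # directly (`python demo_pkg/__main__.py`) would fail with ImportError.
    out = subprocess.run(
        [sys.executable, "-m", "demo_pkg"],
        cwd=tmp, capture_output=True, text=True, check=True,
    )
    print(out.stdout.strip())  # -> ok
```

This is why the README instructs `cd bin && python -m pluviometer ...` rather than executing the script file directly.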
+ +## Test + +Call as a Python module (recommended): + +cd bin +python -m pluviometer --sites SITES --gff GFF [OPTIONS] + +## Location length calculation + +Cases handled correctly: + +1. Simple features (exon, CDS): SimpleLocation → len = actual length +2. Parent features (gene, mRNA) from the GFF: SimpleLocation(start, end) → len = full span (including introns). This is the correct GFF semantics - a gene "covers" its whole region +3. Chimaera aggregates created by the code: CompoundLocation via location_union() → len = sum of the parts without gaps ✓ +4. All sites: TotalSites is reported as "." because it cannot be determined without the FASTA. + +To check that everything works: + +PYTHONPATH=/workspaces/rain/bin python3 -m unittest pluviometer.test_site_counts -v \ No newline at end of file diff --git a/bin/pluviometer/SeqFeature_extensions.py b/bin/pluviometer/SeqFeature_extensions.py index bb8abe4..6b0e9b0 100644 --- a/bin/pluviometer/SeqFeature_extensions.py +++ b/bin/pluviometer/SeqFeature_extensions.py @@ -67,7 +67,7 @@ def make_chimaeras(self: SeqFeature, record_id: str) -> None: chimaera: SeqFeature = SeqFeature( location=location, id=f"{self.id}-{key}-chimaera", - type=key+"-chimaera", + type=key, qualifiers={"Parent": self.id} ) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 85d203e..08295d0 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -20,10 +20,14 @@ from BCBio import GFF from os import remove import progressbar +import subprocess import tempfile import argparse import logging +import pickle import math +import io +import gc logger = logging.getLogger(__name__) @@ -92,6 +96,7 @@ def __init__( aggregate_writer: AggregateFileWriter, filter: SiteFilter, use_progress_bar: bool, + report_non_qualified: bool = False, ): self.aggregate_writer: AggregateFileWriter = aggregate_writer """Aggregate counting output is written to a temporary file through this object""" @@ -99,6 +104,9 @@ def __init__( self.feature_writer:
FeatureFileWriter = feature_writer """Feature counting output is written to a temporary file through this object""" + self.report_non_qualified: bool = report_non_qualified + """If True, write feature rows even when no qualified sites are found""" + self.filter: SiteFilter = filter """Filter applied to the reads. Filters contain a mutable state, so they should not be shared between concurrent RecordCountingContext instances""" @@ -117,6 +125,16 @@ def __init__( ) """Dictionary of counters for the record-level all isoforms aggregates. Keys correspond to aggregate feature types""" + # Positions for aggregate counters + self.longest_isoform_aggregate_positions: dict[str, AggregatePositions] = {} + """Dictionary of positions for the record-level longest isoform aggregates""" + + self.chimaera_aggregate_positions: dict[str, AggregatePositions] = {} + """Dictionary of positions for the record-level chimaeric aggregates""" + + self.all_isoforms_aggregate_positions: dict[str, AggregatePositions] = {} + """Dictionary of positions for the record-level all isoforms aggregates""" + # New: Aggregate counters by (parent_type, aggregate_type) for intermediate-level aggregation self.longest_isoform_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter] = defaultdict( DefaultMultiCounterFactory(self.filter) @@ -319,16 +337,19 @@ def state_update_cycle(self, new_position: int) -> None: visited_positions: int = 0 while ( - len(self.action_queue) > 0 and self.action_queue[0][0] < new_position - ): # Use < instead of <= because of Python's right-exclusive indfgexing + len(self.action_queue) > 0 and self.action_queue[0][0] <= new_position + ): # Use <= to process actions at the current position before updating counters _, actions = self.action_queue.popleft() visited_positions += 1 for feature in actions.activate: + logging.debug(f"Activating feature: {feature.id} (type: {feature.type}, level: {feature.level})") self.active_features[feature.id] = feature for
feature in actions.deactivate: + logging.debug(f"Deactivating feature: {feature.id} (type: {feature.type}, level: {feature.level})") if feature.level == 1: + logging.debug(f"Calling checkout on level-1 feature: {feature.id}") self.checkout(feature, None) self.active_features.pop(feature.id, None) @@ -368,35 +389,57 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> Checkout should be triggered after a level-1 feature is completely cleared. Aggregation of the level-1 feature and all its descentants proceeds. """ + + logging.debug( + f"checkout() called: feature.id={feature.id}, feature.type={feature.type}, " + f"feature.level={feature.level}, is_chimaera={getattr(feature, 'is_chimaera', 'NOT_SET')}, " + f"has_counter={feature.id in self.counters}" + ) # Deactivate the feature self.active_features.pop(feature.id, None) # Counter for the feature itself - feature_counter: Optional[MultiCounter] = self.counters.get(feature.id, None) + feature_counter: Optional[MultiCounter] = self.counters.pop(feature.id, None) assert self.record.id if feature_counter: if feature.is_chimaera: assert parent_feature # A chimaera must always have a parent feature (a gene) - self.aggregate_writer.write_row_chimaera_with_data( - self.record.id, feature, parent_feature, feature_counter - ) + if self.report_non_qualified or feature_counter.filtered_base_freqs.sum() > 0: + logging.debug(f"Writing chimaera with data: {feature.id}") + self.aggregate_writer.write_row_chimaera_with_data( + self.record.id, feature, parent_feature, feature_counter + ) self.chimaera_aggregate_counters[feature.type].merge(feature_counter) # Also track by parent_type self.chimaera_aggregate_counters_by_parent_type[(parent_feature.type, feature.type)].merge(feature_counter) + + # Track positions for chimaera + if feature.type not in self.chimaera_aggregate_positions: + self.chimaera_aggregate_positions[feature.type] = AggregatePositions( + strand=feature.location.strand if 
hasattr(feature.location, 'strand') else 0 + ) + self.chimaera_aggregate_positions[feature.type].update_from_feature(feature) else: - self.feature_writer.write_row_with_data(self.record.id, feature, feature_counter) - del self.counters[feature.id] + logging.debug(f"Writing feature with data: {feature.id}, type: {feature.type}") + if self.report_non_qualified or feature_counter.filtered_base_freqs.sum() > 0: + self.feature_writer.write_row_with_data(self.record.id, feature, feature_counter) + # Explicitly delete counter after use to free memory + del feature_counter else: if feature.is_chimaera: assert parent_feature - self.aggregate_writer.write_row_chimaera_without_data( - self.record.id, feature, parent_feature - ) + if self.report_non_qualified: + logging.debug(f"Writing chimaera without data: {feature.id}") + self.aggregate_writer.write_row_chimaera_without_data( + self.record.id, feature, parent_feature + ) else: - self.feature_writer.write_row_without_data(self.record.id, feature) + logging.debug(f"Writing feature without data: {feature.id}, type: {feature.type}") + if self.report_non_qualified: + self.feature_writer.write_row_without_data(self.record.id, feature) # all_isoforms_aggregation_counters: Optional[defaultdict[str, MultiCounter]] = None @@ -413,6 +456,18 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> merge_aggregation_counter_dicts( self.all_isoforms_aggregate_counters, level1_all_isoforms_aggregation_counters ) + + # Merge positions for all_isoforms at record level + for aggregate_type, positions in level1_all_isoforms_aggregation_positions.items(): + if aggregate_type not in self.all_isoforms_aggregate_positions: + self.all_isoforms_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + self.all_isoforms_aggregate_positions[aggregate_type].merge(positions) + + # Merge positions for longest_isoform at record level + for aggregate_type, positions in 
level1_longest_isoform_aggregation_positions.items(): + if aggregate_type not in self.longest_isoform_aggregate_positions: + self.longest_isoform_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + self.longest_isoform_aggregate_positions[aggregate_type].merge(positions) # Also merge into by-parent-type dictionaries parent_type = feature.type @@ -425,6 +480,8 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> aggregate_type, aggregate_counter, ) in level1_longest_isoform_aggregation_counters.items(): + if not self.report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue # Get positions for this aggregate type start_str, end_str, strand_str = ".", ".", "." if aggregate_type in level1_longest_isoform_aggregation_positions: @@ -441,12 +498,26 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> end=end_str, strand=strand_str, ) - self.aggregate_writer.write_counter_data(aggregate_counter) + # Calculate TotalSites from positions + total_sites_str = "." + if aggregate_type in level1_longest_isoform_aggregation_positions: + pos = level1_longest_isoform_aggregation_positions[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + self.aggregate_writer.write_data( + total_sites_str, + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), + ) for ( aggregate_type, aggregate_counter, ) in level1_all_isoforms_aggregation_counters.items(): + if not self.report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue # Get positions for this aggregate type start_str, end_str, strand_str = ".", ".", "." 
if aggregate_type in level1_all_isoforms_aggregation_positions: @@ -463,10 +534,24 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> end=end_str, strand=strand_str, ) - self.aggregate_writer.write_counter_data(aggregate_counter) + # Calculate TotalSites from positions + total_sites_str = "." + if aggregate_type in level1_all_isoforms_aggregation_positions: + pos = level1_all_isoforms_aggregation_positions[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + self.aggregate_writer.write_data( + total_sites_str, + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), + ) else: feature_aggregation_counters, feature_aggregation_positions = self.aggregate_children(feature) for aggregate_type, aggregate_counter in feature_aggregation_counters.items(): + if not self.report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue # Get positions for this aggregate type start_str, end_str, strand_str = ".", ".", "." if aggregate_type in feature_aggregation_positions: @@ -483,11 +568,27 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> end=end_str, strand=strand_str, ) - self.aggregate_writer.write_counter_data(aggregate_counter) + # Calculate TotalSites from positions + total_sites_str = "." 
+ if aggregate_type in feature_aggregation_positions: + pos = feature_aggregation_positions[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + self.aggregate_writer.write_data( + total_sites_str, + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), + ) # Recursively check-out children for child in feature.sub_features: self.checkout(child, feature) + + # After processing all children, clear sub_features to free memory + if hasattr(feature, 'sub_features'): + feature.sub_features.clear() return None @@ -634,16 +735,32 @@ def update_active_counters(self, site_data: RNASiteVariantData) -> None: """ Update the multicounters matching the ID of features in the `active_features` set. A new multicounter is created if no matching ID is found. 
+ + Strand compatibility logic: + - Feature unstranded (0): accepts sites from any strand + - Feature stranded: only accepts sites from the same strand + - Site unstranded: only counted for unstranded features (covered by the equality check below; counting it for all features would cause a combinatorial explosion) """ for feature_key, feature in self.active_features.items(): + # Exact strand match first (most common case; also covers an unstranded site on an unstranded feature) if feature.location.strand == site_data.strand: counter: MultiCounter = self.counters[feature_key] counter.update(site_data) + # Feature is unstranded: accept stranded sites too + elif feature.location.strand == 0 and site_data.strand != 0: + counter: MultiCounter = self.counters[feature_key] + counter.update(site_data) return None def is_finished(self) -> bool: - return len(self.action_queue) > 0 or len(self.active_features) > 0 + """Return True if there is no remaining work (queues are empty and no active features)""" + return len(self.action_queue) == 0 and len(self.active_features) == 0 def launch_counting(self, reader: RNASiteVariantReader) -> None: next_svdata: Optional[RNASiteVariantData] = reader.seek_record(self.record.id) @@ -670,6 +787,35 @@ def launch_counting(self, reader: RNASiteVariantReader) -> None: return None + def cleanup(self) -> None: + """ + Clean up memory after processing a record. Clear all data structures.
+ """ + self.active_features.clear() + self.action_queue.clear() + self.counters.clear() + + # Clear aggregate counters + self.longest_isoform_aggregate_counters.clear() + self.chimaera_aggregate_counters.clear() + self.all_isoforms_aggregate_counters.clear() + self.longest_isoform_aggregate_counters_by_parent_type.clear() + self.chimaera_aggregate_counters_by_parent_type.clear() + self.all_isoforms_aggregate_counters_by_parent_type.clear() + + # Clear record reference + if hasattr(self, 'record'): + # Clear features from record to help garbage collection + if hasattr(self.record, 'features'): + self.record.features.clear() + delattr(self, 'record') + + # Clear progress bar reference + if hasattr(self, 'progbar'): + delattr(self, 'progbar') + + return None + def parse_cli_input() -> argparse.Namespace: """Parse command line input""" @@ -708,7 +854,7 @@ def parse_cli_input() -> argparse.Namespace: "--cov", "-c", type=int, - default=0, + default=1, help="Site coverage threshold for counting editions", ) parser.add_argument( @@ -736,6 +882,31 @@ def parse_cli_input() -> argparse.Namespace: parser.add_argument( "--progress", action="store_true", default=False, help="Display progress bar" ) + parser.add_argument( + "--gff_feature_types", + type=str, + default=None, + help="Comma-separated list of GFF feature types to load (e.g., 'gene,mRNA,exon,CDS'). Reduces memory usage by filtering during parsing. Leave empty to load all types.", + ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level (default: INFO). 
Use DEBUG to see detailed site filtering information.", + ) + parser.add_argument( + "--report-non-qualified-features", + action="store_true", + default=False, + help="Report features with no qualified sites (default: skip them).", + ) + parser.add_argument( + "--log-to-console", + action="store_true", + default=False, + help="Output logs to console/terminal in addition to log file.", + ) return parser.parse_args() @@ -764,13 +935,79 @@ def init_worker(args_value, reader_factory_value): reader_factory = reader_factory_value -def run_job(record: SeqRecord) -> dict[str, Any]: +def _scan_gff_record_ids(gff_path: str) -> list[str]: + """Return a natsorted list of unique sequence IDs present in the GFF file. + + For bgzip+tabix files (.gz) uses ``tabix -l`` (fast, no decompression of + data lines). For plain-text files falls back to a lightweight linear scan. + """ + if gff_path.endswith('.gz'): + result = subprocess.run( + ['tabix', '-l', gff_path], + capture_output=True, text=True, check=True, + ) + return natsorted(result.stdout.splitlines()) + + seen: set[str] = set() + ordered: list[str] = [] + with open(gff_path) as fh: + for line in fh: + if not line or line[0] == '#': + continue + fields = line.split('\t', 2) + if len(fields) < 2: + # Not a GFF data line (e.g. FASTA section, wrapped attribute line) + continue + seqid = fields[0] + if seqid not in seen: + seen.add(seqid) + ordered.append(seqid) + return natsorted(ordered) + + +def _open_gff_for_record( + gff_path: str, + record_id: str, + gff_limit: Optional[dict], +) -> Optional[SeqRecord]: + """Load a single chromosome from a GFF3 file and return its SeqRecord. + + Strategy differs by file type: + + * **bgzip+tabix (.gz)**: call ``tabix gff_path record_id`` to extract only + the relevant lines and feed them to BCBio via an ``io.StringIO`` buffer + (which is seekable, so BCBio can do its two-pass parent-child resolution). 
+ We intentionally do *not* pass a ``gff_id`` filter here — the stream + already contains only the target chromosome, so there is nothing to + filter. The ``gff_id`` filter in BCBio can bypass the two-pass logic, + causing "parent not found" warnings for out-of-order features. + + * **Plain text**: use BCBio's ``gff_id`` filter on the open (seekable) file + handle so BCBio can seek back for its second pass when needed. + """ + if gff_path.endswith('.gz'): + proc = subprocess.run( + ['tabix', gff_path, record_id], + capture_output=True, text=True, check=True, + ) + gff_stream = io.StringIO('##gff-version 3\n' + proc.stdout) + return next(GFF.parse(gff_stream, limit_info=gff_limit), None) + + # Plain text: seekable file handle → BCBio two-pass works + limit: dict = {'gff_id': [record_id]} + if gff_limit: + limit.update(gff_limit) + with open(gff_path) as gff_handle: + return next(GFF.parse(gff_handle, limit_info=limit), None) + + +def _do_counting(record: SeqRecord) -> dict[str, Any]: """ - A wrapper function for performing counting parallelized by record. The return value is a dict containing all the information needed for integrating - the output of all records after the computations are finished. + Core counting logic for a single SeqRecord. Returns the result dict. + Called directly in sequential mode and via run_job in parallel mode. """ - assert record.id # Stupid assertion for pylance - logging.info(f"Record {record.id} · Record parsed. 
Counting beings.") + assert record.id # Placate pylance + logging.info(f"Record {record.id} · Counting begins.") tmp_feature_output_file: str = tempfile.mkstemp()[1] tmp_aggregate_output_file: str = tempfile.mkstemp()[1] @@ -788,7 +1025,8 @@ def run_job(record: SeqRecord) -> dict[str, Any]: filter: SiteFilter = SiteFilter(cov_threshold=args.cov, edit_threshold=args.edit_threshold) record_ctx: RecordCountingContext = RecordCountingContext( - feature_writer, aggregate_writer, filter, args.progress + feature_writer, aggregate_writer, filter, args.progress, + report_non_qualified=args.report_non_qualified_features, ) # Count @@ -804,6 +1042,8 @@ def run_job(record: SeqRecord) -> dict[str, Any]: ".", "longest_isoform", record_ctx.longest_isoform_aggregate_counters, + record_ctx.longest_isoform_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data( record.id, @@ -812,6 +1052,8 @@ def run_job(record: SeqRecord) -> dict[str, Any]: ".", "all_isoforms", record_ctx.all_isoforms_aggregate_counters, + record_ctx.all_isoforms_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data( record.id, @@ -820,6 +1062,8 @@ def run_job(record: SeqRecord) -> dict[str, Any]: ".", "chimaera", record_ctx.chimaera_aggregate_counters, + record_ctx.chimaera_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) # Write aggregate counter data by parent type @@ -829,6 +1073,8 @@ def run_job(record: SeqRecord) -> dict[str, Any]: ".", "longest_isoform", record_ctx.longest_isoform_aggregate_counters_by_parent_type, + record_ctx.longest_isoform_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data_by_parent_type( record.id, @@ -836,6 +1082,8 @@ def run_job(record: SeqRecord) -> dict[str, Any]: ".", "all_isoforms", record_ctx.all_isoforms_aggregate_counters_by_parent_type, + 
record_ctx.all_isoforms_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data_by_parent_type( record.id, @@ -843,6 +1091,8 @@ def run_job(record: SeqRecord) -> dict[str, Any]: ".", "chimaera", record_ctx.chimaera_aggregate_counters_by_parent_type, + record_ctx.chimaera_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) # Write the total counter data of the record. A dummy dict needs to be created to use the `write_rows_with_data` method @@ -851,22 +1101,163 @@ def run_job(record: SeqRecord) -> dict[str, Any]: ) total_counter_dict["."] = record_ctx.total_counter aggregate_writer.write_rows_with_data( - record.id, ["."], ".", ".", "all_sites", total_counter_dict + record.id, ["."], ".", ".", "all_sites", total_counter_dict, + report_non_qualified=args.report_non_qualified_features, ) - return { + # Extract data needed for return before cleanup + result = { "record_id": record.id, "tmp_feature_output_file": tmp_feature_output_file, "tmp_aggregate_output_file": tmp_aggregate_output_file, - "chimaera_aggregate_counters": record_ctx.chimaera_aggregate_counters, - "longest_isoform_aggregate_counters": record_ctx.longest_isoform_aggregate_counters, - "all_isoforms_aggregate_counters": record_ctx.all_isoforms_aggregate_counters, - "chimaera_aggregate_counters_by_parent_type": record_ctx.chimaera_aggregate_counters_by_parent_type, - "longest_isoform_aggregate_counters_by_parent_type": record_ctx.longest_isoform_aggregate_counters_by_parent_type, - "all_isoforms_aggregate_counters_by_parent_type": record_ctx.all_isoforms_aggregate_counters_by_parent_type, + "chimaera_aggregate_counters": record_ctx.chimaera_aggregate_counters.copy(), + "longest_isoform_aggregate_counters": record_ctx.longest_isoform_aggregate_counters.copy(), + "all_isoforms_aggregate_counters": record_ctx.all_isoforms_aggregate_counters.copy(), + "chimaera_aggregate_counters_by_parent_type": 
record_ctx.chimaera_aggregate_counters_by_parent_type.copy(), + "longest_isoform_aggregate_counters_by_parent_type": record_ctx.longest_isoform_aggregate_counters_by_parent_type.copy(), + "all_isoforms_aggregate_counters_by_parent_type": record_ctx.all_isoforms_aggregate_counters_by_parent_type.copy(), "total_counter": record_ctx.total_counter, + "longest_isoform_aggregate_positions": record_ctx.longest_isoform_aggregate_positions.copy(), + "all_isoforms_aggregate_positions": record_ctx.all_isoforms_aggregate_positions.copy(), + "chimaera_aggregate_positions": record_ctx.chimaera_aggregate_positions.copy(), } + # Clean up record context to free memory immediately + record_ctx.cleanup() + + return result + + +def run_job(record_id: str) -> str: + """ + Worker entry point for parallel processing. Loads only the chromosome matching + record_id from the GFF (minimal RAM footprint per worker), performs counting, + serializes the result dict to a pickle temp file and returns its path. + Only a tiny string crosses the IPC channel. 
+ """ + gff_limit: Optional[dict] = None + if args.gff_feature_types: + gff_limit = {'gff_type': [t.strip() for t in args.gff_feature_types.split(',')]} + + logging.info(f"Record {record_id} · Loading GFF features.") + record: Optional[SeqRecord] = _open_gff_for_record(args.gff, record_id, gff_limit) + + if record is None: + raise RuntimeError(f"Record '{record_id}' not found in {args.gff}") + + result = _do_counting(record) + record.features.clear() + del record + + tmp_pickle_fd, tmp_pickle_path = tempfile.mkstemp(suffix='.pkl') + with open(tmp_pickle_fd, 'wb') as f: + pickle.dump(result, f) + del result + + return tmp_pickle_path + + +def process_and_write_record_data( + record_data: dict[str, Any], + feature_output_handle: TextIO, + aggregate_output_handle: TextIO, + genome_longest_isoform_aggregate_counters: defaultdict[str, MultiCounter], + genome_all_isoforms_aggregate_counters: defaultdict[str, MultiCounter], + genome_chimaera_aggregate_counters: defaultdict[str, MultiCounter], + genome_longest_isoform_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter], + genome_all_isoforms_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter], + genome_chimaera_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter], + genome_total_counter: MultiCounter, + genome_longest_isoform_aggregate_positions: dict[str, AggregatePositions], + genome_all_isoforms_aggregate_positions: dict[str, AggregatePositions], + genome_chimaera_aggregate_positions: dict[str, AggregatePositions], +) -> None: + """ + Process a single record's data: write its output and merge counters into genome totals. + This function writes immediately to output files and only keeps genome-level totals in memory. 
+ """ + logging.info(f"Record {record_data['record_id']} · Writing output and merging totals...") + + # Write temporary files to final output immediately + with open(record_data["tmp_feature_output_file"]) as tmp_output_handle: + feature_output_handle.write(tmp_output_handle.read()) + remove(record_data["tmp_feature_output_file"]) + + with open(record_data["tmp_aggregate_output_file"]) as tmp_output_handle: + aggregate_output_handle.write(tmp_output_handle.read()) + remove(record_data["tmp_aggregate_output_file"]) + + # Update the genome's aggregate counters from the record data aggregate counters + for record_aggregate_type, record_aggregate_counter in record_data[ + "longest_isoform_aggregate_counters" + ].items(): + genome_longest_isoform_aggregate_counters[record_aggregate_type].merge( + record_aggregate_counter + ) + + for record_aggregate_type, record_aggregate_counter in record_data[ + "chimaera_aggregate_counters" + ].items(): + genome_chimaera_aggregate_counters[record_aggregate_type].merge( + record_aggregate_counter + ) + + merge_aggregation_counter_dicts( + genome_all_isoforms_aggregate_counters, + record_data["all_isoforms_aggregate_counters"], + ) + + # Update the genome's by-parent-type aggregate counters + for (parent_type, aggregate_type), record_aggregate_counter in record_data[ + "longest_isoform_aggregate_counters_by_parent_type" + ].items(): + genome_longest_isoform_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( + record_aggregate_counter + ) + + for (parent_type, aggregate_type), record_aggregate_counter in record_data[ + "all_isoforms_aggregate_counters_by_parent_type" + ].items(): + genome_all_isoforms_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( + record_aggregate_counter + ) + + for (parent_type, aggregate_type), record_aggregate_counter in record_data[ + "chimaera_aggregate_counters_by_parent_type" + ].items(): + genome_chimaera_aggregate_counters_by_parent_type[(parent_type, 
aggregate_type)].merge( + record_aggregate_counter + ) + + # Update the genome's total counter from the record data total counter + genome_total_counter.merge(record_data["total_counter"]) + + # Merge positions for aggregates + for aggregate_type, positions in record_data.get("longest_isoform_aggregate_positions", {}).items(): + if aggregate_type not in genome_longest_isoform_aggregate_positions: + genome_longest_isoform_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + genome_longest_isoform_aggregate_positions[aggregate_type].merge(positions) + + for aggregate_type, positions in record_data.get("all_isoforms_aggregate_positions", {}).items(): + if aggregate_type not in genome_all_isoforms_aggregate_positions: + genome_all_isoforms_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + genome_all_isoforms_aggregate_positions[aggregate_type].merge(positions) + + for aggregate_type, positions in record_data.get("chimaera_aggregate_positions", {}).items(): + if aggregate_type not in genome_chimaera_aggregate_positions: + genome_chimaera_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + genome_chimaera_aggregate_positions[aggregate_type].merge(positions) + + # Clear record_data counters to free memory immediately after merging + record_data["chimaera_aggregate_counters"].clear() + record_data["longest_isoform_aggregate_counters"].clear() + record_data["all_isoforms_aggregate_counters"].clear() + record_data["chimaera_aggregate_counters_by_parent_type"].clear() + record_data["longest_isoform_aggregate_counters_by_parent_type"].clear() + record_data["all_isoforms_aggregate_counters_by_parent_type"].clear() + + return None + def main(): global args @@ -875,31 +1266,59 @@ def main(): args = parse_cli_input() LOGGING_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" log_filename: str = args.output + "_pluviometer.log" if args.output else "pluviometer.log" - 
logging.basicConfig(filename=log_filename, level=logging.INFO, format=LOGGING_FORMAT) + log_level = getattr(logging, args.log_level.upper()) + + # Configure logging with file handler + logger = logging.getLogger() + logger.setLevel(log_level) + + # File handler + file_handler = logging.FileHandler(log_filename, mode='w') + file_handler.setLevel(log_level) + file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT)) + logger.addHandler(file_handler) + + # Console handler if requested + if args.log_to_console: + console_handler = logging.StreamHandler() + console_handler.setLevel(log_level) + console_handler.setFormatter(logging.Formatter(LOGGING_FORMAT)) + logger.addHandler(console_handler) + logging.info(f"Pluviometer started. Log file: {log_filename}") + logging.info("Options:") + logging.info(f" Sites file: {args.sites}") + logging.info(f" GFF file: {args.gff}") + logging.info(f" Output prefix: {args.output if args.output else 'default'}") + logging.info(f" Format: {args.format}") + logging.info(f" Coverage threshold: {args.cov}") + logging.info(f" Edit threshold: {args.edit_threshold}") + logging.info(f" Aggregation mode: {args.aggregation_mode}") + logging.info(f" Threads: {args.threads}") + logging.info(f" Progress bar: {args.progress}") + logging.info(f" GFF feature types filter: {args.gff_feature_types if args.gff_feature_types else 'none (all types)'}") + logging.info(f" Log level: {args.log_level}") feature_output_filename: str = args.output + "_features.tsv" if args.output else "features.tsv" aggregate_output_filename: str = ( args.output + "_aggregates.tsv" if args.output else "aggregates.tsv" ) + # Determine reader factory based on format + match args.format: + case "reditools2": + reader_factory = Reditools2Reader + case "reditools3": + reader_factory = Reditools3Reader + case "jacusa2": + reader_factory = Jacusa2Reader + case _: + raise Exception(f'Unimplemented format "{args.format}"') + with ( - open(args.gff) as gff_handle, 
open(feature_output_filename, "w") as feature_output_handle, open(aggregate_output_filename, "w") as aggregate_output_handle, ): - match args.format: - case "reditools2": - reader_factory = Reditools2Reader - case "reditools3": - reader_factory = Reditools3Reader - case "jacusa2": - reader_factory = Jacusa2Reader - case _: - raise Exception(f'Unimplemented format "{args.format}"') - - logging.info("Parsing GFF3 file...") - records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle) - + # Initialize genome-level counters (only these will be kept in memory) genome_filter: SiteFilter = SiteFilter( cov_threshold=args.cov, edit_threshold=args.edit_threshold ) @@ -914,6 +1333,11 @@ def main(): genome_chimaera_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( lambda: MultiCounter(genome_filter) ) + + # Positions for genome-level aggregates + genome_longest_isoform_aggregate_positions: dict[str, AggregatePositions] = {} + genome_all_isoforms_aggregate_positions: dict[str, AggregatePositions] = {} + genome_chimaera_aggregate_positions: dict[str, AggregatePositions] = {} # By-parent-type genome-level aggregates genome_longest_isoform_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter] = defaultdict( @@ -926,105 +1350,131 @@ def main(): lambda: MultiCounter(genome_filter) ) + # Write headers once at the beginning feature_writer: FeatureFileWriter = FeatureFileWriter(feature_output_handle) aggregate_writer: AggregateFileWriter = AggregateFileWriter(aggregate_output_handle) - feature_writer.write_header() aggregate_writer.write_header() - with multiprocessing.Pool(processes=args.threads, initializer=init_worker, initargs=(args, reader_factory)) as pool: - # Run the jobs and save the result of each record in a dict stored in the `record_data` list - record_data_list: list[dict[str, Any]] = pool.map(run_job, records) - - # Sort record results in "natural" order - record_data_list = natsorted(record_data_list, key=lambda x: 
x["record_id"]) - - with ( - open(feature_output_filename, "a") as feature_output_handle, - open(aggregate_output_filename, "a") as aggregate_output_handle, - ): - for record_data in record_data_list: - logging.info(f"Record {record_data['record_id']} · Merging temporary output files...") - - with open(record_data["tmp_feature_output_file"]) as tmp_output_handle: - feature_output_handle.write(tmp_output_handle.read()) - remove(record_data["tmp_feature_output_file"]) - - with open(record_data["tmp_aggregate_output_file"]) as tmp_output_handle: - aggregate_output_handle.write(tmp_output_handle.read()) - remove(record_data["tmp_aggregate_output_file"]) - - # Update the genome's aggregate counters from the record data aggregate counters - for record_aggregate_type, record_aggregate_counter in record_data[ - "longest_isoform_aggregate_counters" - ].items(): - genome_aggregate_counter: MultiCounter = genome_longest_isoform_aggregate_counters[ - record_aggregate_type - ] - genome_aggregate_counter.merge(record_aggregate_counter) - - for record_aggregate_type, record_aggregate_counter in record_data[ - "chimaera_aggregate_counters" - ].items(): - genome_aggregate_counter: MultiCounter = genome_chimaera_aggregate_counters[ - record_aggregate_type - ] - genome_aggregate_counter.merge(record_aggregate_counter) - - merge_aggregation_counter_dicts( - genome_all_isoforms_aggregate_counters, - record_data["all_isoforms_aggregate_counters"], - ) - - # Update the genome's by-parent-type aggregate counters - for (parent_type, aggregate_type), record_aggregate_counter in record_data[ - "longest_isoform_aggregate_counters_by_parent_type" - ].items(): - genome_longest_isoform_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( - record_aggregate_counter - ) - - for (parent_type, aggregate_type), record_aggregate_counter in record_data[ - "all_isoforms_aggregate_counters_by_parent_type" - ].items(): - 
genome_all_isoforms_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( - record_aggregate_counter + # Process records and write output immediately + if args.threads > 1: + # Parallel processing: use imap_unordered so each worker's result is + # written to a pickle temp file as soon as it is ready, regardless of + # chromosome order. Only the tiny file paths cross the IPC channel, + # keeping the main process RAM near zero while workers are running. + record_ids: list[str] = _scan_gff_record_ids(args.gff) + logging.info(f"Processing {len(record_ids)} records with {args.threads} threads...") + pickle_paths: list[str] = [] + with multiprocessing.Pool(processes=args.threads, initializer=init_worker, initargs=(args, reader_factory)) as pool: + for pickle_path in pool.imap_unordered(run_job, record_ids, chunksize=1): + pickle_paths.append(pickle_path) + + # Sort pickle files by their record_id (natural chromosome order) + def _record_id_from_pickle(path: str) -> str: + with open(path, 'rb') as _f: + return pickle.load(_f)['record_id'] + + pickle_paths = natsorted(pickle_paths, key=_record_id_from_pickle) + + # Process in natural order: unpickle → write → delete → GC + for pickle_path in pickle_paths: + with open(pickle_path, 'rb') as f: + record_data = pickle.load(f) + remove(pickle_path) + process_and_write_record_data( + record_data, + feature_output_handle, + aggregate_output_handle, + genome_longest_isoform_aggregate_counters, + genome_all_isoforms_aggregate_counters, + genome_chimaera_aggregate_counters, + genome_longest_isoform_aggregate_counters_by_parent_type, + genome_all_isoforms_aggregate_counters_by_parent_type, + genome_chimaera_aggregate_counters_by_parent_type, + genome_total_counter, + genome_longest_isoform_aggregate_positions, + genome_all_isoforms_aggregate_positions, + genome_chimaera_aggregate_positions, ) + del record_data + # Force flush to disk after each chromosome + feature_output_handle.flush() + 
aggregate_output_handle.flush() - for (parent_type, aggregate_type), record_aggregate_counter in record_data[ - "chimaera_aggregate_counters_by_parent_type" - ].items(): - genome_chimaera_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( - record_aggregate_counter + # Force garbage collection to free memory immediately + gc.collect() + + else: + # Sequential processing: load one chromosome at a time, write immediately. + # Uses the same _open_gff_for_record helper as the parallel mode so that + # both plain and tabix-indexed files are handled identically. + logging.info("Processing records sequentially...") + gff_limit: Optional[dict] = None + if args.gff_feature_types: + gff_limit = {'gff_type': [t.strip() for t in args.gff_feature_types.split(',')]} + for record_id in _scan_gff_record_ids(args.gff): + record = _open_gff_for_record(args.gff, record_id, gff_limit) + if record is None: + logging.warning(f"No features found for record '{record_id}', skipping.") + continue + result = _do_counting(record) + record.features.clear() + del record + process_and_write_record_data( + result, + feature_output_handle, + aggregate_output_handle, + genome_longest_isoform_aggregate_counters, + genome_all_isoforms_aggregate_counters, + genome_chimaera_aggregate_counters, + genome_longest_isoform_aggregate_counters_by_parent_type, + genome_all_isoforms_aggregate_counters_by_parent_type, + genome_chimaera_aggregate_counters_by_parent_type, + genome_total_counter, + genome_longest_isoform_aggregate_positions, + genome_all_isoforms_aggregate_positions, + genome_chimaera_aggregate_positions, ) + del result + feature_output_handle.flush() + aggregate_output_handle.flush() + gc.collect() - # Update the genome's total counter from the record data total counter - genome_total_counter.merge(record_data["total_counter"]) - + # Write genome-level totals at the end (only genome totals are kept in memory) logging.info("Writing genome totals...") - aggregate_writer: 
AggregateFileWriter = AggregateFileWriter(aggregate_output_handle) - # Write genomic counts aggregate_writer.write_rows_with_data( - ".", ["."], ".", ".", "longest_isoform", genome_longest_isoform_aggregate_counters + ".", ["."], ".", ".", "longest_isoform", genome_longest_isoform_aggregate_counters, + genome_longest_isoform_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data( - ".", ["."], ".", ".", "all_isoforms", genome_all_isoforms_aggregate_counters + ".", ["."], ".", ".", "all_isoforms", genome_all_isoforms_aggregate_counters, + genome_all_isoforms_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data( - ".", ["."], ".", ".", "chimaera", genome_chimaera_aggregate_counters + ".", ["."], ".", ".", "chimaera", genome_chimaera_aggregate_counters, + genome_chimaera_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) # Write genomic counts by parent type aggregate_writer.write_rows_with_data_by_parent_type( - ".", ["."], ".", "longest_isoform", genome_longest_isoform_aggregate_counters_by_parent_type + ".", ["."], ".", "longest_isoform", genome_longest_isoform_aggregate_counters_by_parent_type, + genome_longest_isoform_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data_by_parent_type( - ".", ["."], ".", "all_isoforms", genome_all_isoforms_aggregate_counters_by_parent_type + ".", ["."], ".", "all_isoforms", genome_all_isoforms_aggregate_counters_by_parent_type, + genome_all_isoforms_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data_by_parent_type( - ".", ["."], ".", "chimaera", genome_chimaera_aggregate_counters_by_parent_type + ".", ["."], ".", "chimaera", genome_chimaera_aggregate_counters_by_parent_type, + genome_chimaera_aggregate_positions, + 
report_non_qualified=args.report_non_qualified_features, ) # Write the genomic total. A dummy dict needs to be created to use the `write_rows_with_data` method @@ -1033,7 +1483,8 @@ def main(): ) genomic_total_counter_dict["."] = genome_total_counter aggregate_writer.write_rows_with_data( - ".", ["."], ".", ".", "all_sites", genomic_total_counter_dict + ".", ["."], ".", ".", "all_sites", genomic_total_counter_dict, + report_non_qualified=args.report_non_qualified_features, ) logging.info("Program finished") diff --git a/bin/pluviometer/multi_counter.py b/bin/pluviometer/multi_counter.py index d11f147..5f5eef1 100644 --- a/bin/pluviometer/multi_counter.py +++ b/bin/pluviometer/multi_counter.py @@ -15,9 +15,12 @@ def __init__(self, site_filter: SiteFilter) -> None: Row-columns corresponding to the same base (e.g. (0,0) -> (A,A)) represent reads where the base is unchanged """ self.edit_read_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) - self.edit_site_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) + self.edit_qualified_site_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) # Non-ref sites passing cov + edit thresholds + self.edit_qualified_read_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) # Non-ref reads passing cov + edit thresholds self.genome_base_freqs: NDArray[np.int64] = np.zeros(5, dtype=np.int64) + self.filtered_base_freqs: NDArray[np.int64] = np.zeros(5, dtype=np.int64) # Bases for qualified sites + self.filtered_sites_count: int = 0 # Number of sites that pass the coverage filter self.filter = site_filter @@ -30,7 +33,16 @@ def update(self, variant_data: RNASiteVariantData) -> None: self.edit_read_freqs[i, :] += variant_data.frequencies self.filter.apply(variant_data) - self.edit_site_freqs[i, :] += self.filter.frequencies + # The edit_threshold is not applied to the self-pairing (X→X): coverage alone is sufficient. 
+ if variant_data.coverage >= self.filter.cov_threshold: + self.filter.frequencies[i] = variant_data.frequencies[i] + + # A site qualifies if it passes the coverage threshold. + # Self-pairings are included unconditionally; non-ref pairings only if they also pass edit_threshold. + self.filtered_sites_count += 1 + self.filtered_base_freqs[i] += 1 + self.edit_qualified_site_freqs[i, :] += (self.filter.frequencies > 0).astype(np.int64) + self.edit_qualified_read_freqs[i, :] += self.filter.frequencies self.genome_base_freqs[i] += 1 @@ -41,8 +53,11 @@ def merge(self, other_counter: "MultiCounter") -> None: Add to this counter the values of another. """ self.edit_read_freqs[:] += other_counter.edit_read_freqs - self.edit_site_freqs[:] += other_counter.edit_site_freqs + self.edit_qualified_site_freqs[:] += other_counter.edit_qualified_site_freqs + self.edit_qualified_read_freqs[:] += other_counter.edit_qualified_read_freqs self.genome_base_freqs[:] += other_counter.genome_base_freqs + self.filtered_base_freqs[:] += other_counter.filtered_base_freqs + self.filtered_sites_count += other_counter.filtered_sites_count return None diff --git a/bin/pluviometer/rain_file_writers.py b/bin/pluviometer/rain_file_writers.py index 89686d4..eb6386a 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -3,7 +3,10 @@ from .multi_counter import MultiCounter from Bio.SeqFeature import SeqFeature from collections import defaultdict -from typing import TextIO +from typing import TextIO, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from __main__ import AggregatePositions def make_parent_path(parent_list: list[str]) -> str: @@ -80,10 +83,6 @@ def write_counter_data(self, counter: MultiCounter) -> int: b: int = self.handle.write(str(counter.genome_base_freqs.sum())) b += self.handle.write('\t') b += self.handle.write(",".join(map(str, counter.genome_base_freqs[0:4].flat))) - b += self.handle.write('\t') - b += self.handle.write(",".join(map(str, 
counter.edit_site_freqs[0:4, 0:4].flat))) - b += self.handle.write('\t') - b += self.handle.write(",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat))) b += self.handle.write('\n') return b @@ -94,18 +93,23 @@ class FeatureFileWriter(RainFileWriter): metadata_fields: list[str] = [ "SeqID", "ParentIDs", - "FeatureID", + "ID", + "Mtype", + "Ptype", "Type", + "Ctype", + "Mode", "Start", "End", "Strand", ] data_fields: list[str] = [ - "CoveredSites", - "GenomeBases", - "SiteBasePairings", - "ReadBasePairings" + "TotalSites", + "ObservedBases", + "QualifiedBases", + "SiteBasePairingsQualified", + "ReadBasePairingsQualified" ] def __init__(self, handle: TextIO): @@ -123,7 +127,11 @@ def write_metadata(self, record_id: str, feature: SeqFeature) -> int: record_id, make_parent_path(feature.parent_list), feature.id, + "feature", + ".", feature.type, + ".", + ".", str(feature.location.parts[0].start + ExactPosition(1)), str(feature.location.parts[-1].end), str(feature.location.strand), @@ -134,12 +142,13 @@ def write_row_with_data( ) -> int: """Write the data fields (coverage, read frequency, &c) of an output line, taken from the informations in a counter object""" return self.write_metadata(record_id, feature) + self.write_data( - str(counter.genome_base_freqs.sum()), + str(len(feature.location)), # TotalSites: total positions in the feature # Subindexing is used below because counter matrices are 5x5 because they contain pairings with N, which we don't print # Use the flat attribute of the NumPy arrays to flatten the matrix row-wise to attain the desired base pairing order in the output - ",".join(map(str, counter.genome_base_freqs[0:4].flat)), - ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), + ",".join(map(str, counter.genome_base_freqs[0:4].flat)), # ObservedBases: base distribution for all observations + ",".join(map(str, counter.filtered_base_freqs[0:4].flat)), # QualifiedBases: base 
distribution for qualified sites + ",".join(map(str, counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified + ",".join(map(str, counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: @@ -148,7 +157,7 @@ def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: This is faster than creating dummy counter objects with zero observations. """ return self.write_metadata(record_id, feature) + self.write_data( - "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + str(len(feature.location)), self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS ) @@ -156,20 +165,23 @@ class AggregateFileWriter(RainFileWriter): metadata_fields: list[str] = [ "SeqID", "ParentIDs", - "AggregateID", - "ParentType", - "AggregateType", - "AggregationMode", + "ID", + "Mtype", + "Ptype", + "Type", + "Ctype", + "Mode", "Start", "End", "Strand", ] data_fields: list[str] = [ - "CoveredSites", - "GenomeBases", - "SiteBasePairings", - "ReadBasePairings", + "TotalSites", + "ObservedBases", + "QualifiedBases", + "SiteBasePairingsQualified", + "ReadBasePairingsQualified", ] def __init__(self, handle: TextIO): @@ -177,6 +189,16 @@ def __init__(self, handle: TextIO): return None + @staticmethod + def _agg_type(seq_id: str, aggregate_id: str) -> str: + """Compute the Type field for an aggregate row.""" + if aggregate_id not in (".", ""): + return "feature" + elif seq_id != ".": + return "sequence" + else: + return "global" + def write_metadata( self, seq_id: str, @@ -190,26 +212,36 @@ def write_metadata( strand: str = "." 
) -> int: """Write metadata fields of an aggregate""" - b: int = self.handle.write(seq_id) - b += self.handle.write('\t') - b += self.handle.write(parent_ids) - b += self.handle.write('\t') - b += self.handle.write(aggregate_id) - b += self.handle.write('\t') - b += self.handle.write(parent_type) - b += self.handle.write('\t') - b += self.handle.write(aggregate_type) - b += self.handle.write('\t') - b += self.handle.write(aggregation_mode) - b += self.handle.write('\t') - b += self.handle.write(start) - b += self.handle.write('\t') - b += self.handle.write(end) - b += self.handle.write('\t') - b += self.handle.write(strand) - b += self.handle.write('\t') - - return b + + # Determine the Type field before modifying aggregate_id + original_aggregate_id = aggregate_id + agg_type = self._agg_type(seq_id, original_aggregate_id) + + # Create a dynamic ID for aggregates without explicit IDs (sequence/global level) + if aggregate_id == ".": + # Build ID as Type_Ptype_Ctype_Mode, omitting Ptype and/or Ctype if they are "." + id_parts = [agg_type] # "sequence" or "global" + if parent_type != ".": + id_parts.append(parent_type) + if aggregate_type != ".": + id_parts.append(aggregate_type) + id_parts.append(aggregation_mode) + + aggregate_id = "_".join(id_parts) + + return super().write_metadata( + seq_id, + parent_ids, + aggregate_id, + "aggregate", + parent_type, + agg_type, # Use the pre-computed Type instead of recalculating it + aggregate_type, + aggregation_mode, + start, + end, + strand, + ) # Case like that we will add an empty start end strand # 21 . . . 
exon longest_isoform @@ -224,28 +256,37 @@ def write_rows_with_data( feature_type: str, aggregation_mode: str, counter_dict: defaultdict[str, MultiCounter], + positions_dict: Optional[dict[str, 'AggregatePositions']] = None, + report_non_qualified: bool = True, ) -> int: """Write metadata and data fields of multiple counters of the same aggregate feature""" b: int = 0 for aggregate_type, aggregate_counter in counter_dict.items(): - b += super().write_metadata( + if not report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue + b += self.write_metadata( record_id, make_parent_path(parent_list), aggregate_id, feature_type, aggregate_type, aggregation_mode, - ".", - ".", - ".", ) + # Calculate TotalSites from positions if available + total_sites_str = "." + if positions_dict and aggregate_type in positions_dict: + pos = positions_dict[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + b += self.write_data( - str(aggregate_counter.genome_base_freqs.sum()), - ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), - ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + total_sites_str, # TotalSites: calculated from aggregate positions + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), # ObservedBases + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), # QualifiedBases + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) return b @@ -276,10 +317,11 @@ def write_row_chimaera_with_data( strand=strand_str, ) b += self.write_data( - str(counter.genome_base_freqs.sum()), - ",".join(map(str, counter.genome_base_freqs[0:4].flat)), - ",".join(map(str, counter.edit_site_freqs[0:4, 
0:4].flat)), - ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), + str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites + ",".join(map(str, counter.genome_base_freqs[0:4].flat)), # ObservedBases + ",".join(map(str, counter.filtered_base_freqs[0:4].flat)), # QualifiedBases + ",".join(map(str, counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified + ",".join(map(str, counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) return b @@ -309,11 +351,14 @@ def write_row_chimaera_without_data( end=end_str, strand=strand_str, ) - b += self.write_data("0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS) + b += self.write_data( + str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites + self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + ) return b - # Case like that we will add empty start end strand + # Rows like the following will have empty start/end/strand fields # 21 . . gene exon longest_isoform # 21 . . pseudogene exon longest_isoform # . . . 
pseudogene exon longest_isoform @@ -325,28 +370,37 @@ def write_rows_with_data_by_parent_type( aggregate_id: str, aggregation_mode: str, counter_dict: defaultdict[tuple[str, str], MultiCounter], + positions_dict: Optional[dict[str, 'AggregatePositions']] = None, + report_non_qualified: bool = True, ) -> int: """Write metadata and data fields of multiple counters grouped by (parent_type, aggregate_type)""" b: int = 0 for (parent_type, aggregate_type), aggregate_counter in counter_dict.items(): - b += super().write_metadata( + if not report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue + b += self.write_metadata( record_id, make_parent_path(parent_list), aggregate_id, parent_type, aggregate_type, aggregation_mode, - ".", - ".", - ".", ) + # Calculate TotalSites from positions if available + total_sites_str = "." + if positions_dict and aggregate_type in positions_dict: + pos = positions_dict[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + b += self.write_data( - str(aggregate_counter.genome_base_freqs.sum()), - ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), - ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + total_sites_str, # TotalSites: calculated from aggregate positions + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), # ObservedBases + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), # QualifiedBases + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) return b diff --git a/bin/pluviometer/rna_site_variant_readers.py b/bin/pluviometer/rna_site_variant_readers.py index 7437afa..dcd9f92 100755 --- a/bin/pluviometer/rna_site_variant_readers.py +++ 
b/bin/pluviometer/rna_site_variant_readers.py @@ -234,11 +234,11 @@ def parse_strand(self) -> int: strand = int(self.parts[REDITOOLS_FIELD_INDEX["Strand"]]) match strand: case 0: - return -1 + return 0 # Unstranded case 1: - return 1 + return 1 # Forward strand case 2: - return 0 + return -1 # Reverse strand (first-strand oriented) case _: raise Exception(f"Invalid strand value: {strand}") diff --git a/bin/pluviometer/site_filter.py b/bin/pluviometer/site_filter.py index b8e3ee8..745c3c8 100644 --- a/bin/pluviometer/site_filter.py +++ b/bin/pluviometer/site_filter.py @@ -1,6 +1,9 @@ from .utils import RNASiteVariantData import numpy as np from numpy.typing import NDArray +import logging + +logger = logging.getLogger(__name__) class SiteFilter: def __init__(self, cov_threshold: int, edit_threshold: int) -> None: @@ -9,13 +12,32 @@ def __init__(self, cov_threshold: int, edit_threshold: int) -> None: self.frequencies: NDArray[np.int32] = np.zeros(5, np.int32) def apply(self, variant_data: RNASiteVariantData) -> None: + logger.debug( + f"SiteFilter.apply() BEFORE - seqid={variant_data.seqid}, " + f"position={variant_data.position}, reference={variant_data.reference}, " + f"strand={variant_data.strand}, coverage={variant_data.coverage}, " + f"mean_quality={variant_data.mean_quality:.2f}, " + f"frequencies={variant_data.frequencies}, score={variant_data.score:.2f}, " + f"cov_threshold={self.cov_threshold}, edit_threshold={self.edit_threshold}, " + f"self.frequencies[before]={self.frequencies}" + ) + + # The coverage threshold must be passed first if variant_data.coverage >= self.cov_threshold: np.copyto( self.frequencies, - variant_data.frequencies * variant_data.frequencies - >= self.edit_threshold, + variant_data.frequencies ) + self.frequencies[self.frequencies < self.edit_threshold] = 0 + else: + logger.debug("Coverage below threshold; clearing frequencies") self.frequencies.fill(0) + + logger.debug( + 
f"SiteFilter.apply() AFTER - position={variant_data.position}, " + f"coverage_check={variant_data.coverage >= self.cov_threshold}, " + f"self.frequencies[after]={self.frequencies}" + ) return None diff --git a/bin/pluviometer/test_site_counts.py b/bin/pluviometer/test_site_counts.py new file mode 100644 index 0000000..d29f60f --- /dev/null +++ b/bin/pluviometer/test_site_counts.py @@ -0,0 +1,433 @@ +#!/usr/bin/env python3 +""" +Tests for the TotalSites, ObservedSites and QualifiedSites columns + +To run: + cd /workspaces/rain/bin + PYTHONPATH=/workspaces/rain/bin python3 -m unittest pluviometer.test_site_counts -v +""" + +import unittest +import numpy as np +from Bio.SeqFeature import SeqFeature, SimpleLocation, CompoundLocation + +# Import the pluviometer package +from pluviometer.multi_counter import MultiCounter +from pluviometer.site_filter import SiteFilter +from pluviometer.utils import RNASiteVariantData +from pluviometer.__main__ import AggregatePositions + + +class TestAggregatePositions(unittest.TestCase): + """Tests for the AggregatePositions class""" + + def test_initial_state(self): + """Test the initial state""" + pos = AggregatePositions(strand=1) + self.assertEqual(pos.start, float('inf')) + self.assertEqual(pos.end, 0) + self.assertEqual(pos.strand, 1) + + def test_update_from_simple_feature(self): + """Test updating with a SimpleLocation""" + pos = AggregatePositions(strand=1) + feature = SeqFeature(location=SimpleLocation(10, 20), id="test", type="exon") + pos.update_from_feature(feature) + + self.assertEqual(pos.start, 10) + self.assertEqual(pos.end, 20) + self.assertEqual(pos.strand, 1) + + def test_update_from_multiple_features(self): + """Test updating with multiple features""" + pos = AggregatePositions(strand=1) + + feature1 = SeqFeature(location=SimpleLocation(10, 20), id="exon1", type="exon") + feature2 = SeqFeature(location=SimpleLocation(5, 15), id="exon2", type="exon") + feature3 = SeqFeature(location=SimpleLocation(30, 
40), id="exon3", type="exon") + + pos.update_from_feature(feature1) + pos.update_from_feature(feature2) + pos.update_from_feature(feature3) + + # Start must be the minimum (5), end the maximum (40) + self.assertEqual(pos.start, 5) + self.assertEqual(pos.end, 40) + + def test_merge_positions(self): + """Test merging two AggregatePositions""" + pos1 = AggregatePositions(strand=1) + pos1.start = 10 + pos1.end = 20 + + pos2 = AggregatePositions(strand=1) + pos2.start = 5 + pos2.end = 30 + + pos1.merge(pos2) + + self.assertEqual(pos1.start, 5) + self.assertEqual(pos1.end, 30) + + def test_to_strings_valid(self): + """Test conversion to strings with valid positions""" + pos = AggregatePositions(strand=1) + pos.start = 10 + pos.end = 20 + + start_str, end_str, strand_str = pos.to_strings() + + # Start converted to 1-based + self.assertEqual(start_str, "11") + self.assertEqual(end_str, "20") + self.assertEqual(strand_str, "1") + + def test_to_strings_empty(self): + """Test conversion to strings with uninitialized positions""" + pos = AggregatePositions(strand=0) + + start_str, end_str, strand_str = pos.to_strings() + + self.assertEqual(start_str, ".") + self.assertEqual(end_str, ".") + self.assertEqual(strand_str, ".") + + +class TestMultiCounterSiteCounts(unittest.TestCase): + """Tests for the site counters in MultiCounter""" + + def test_site_base_pairings_counts_raw_reads(self): + """Test that edit_read_freqs counts raw (unfiltered) reads""" + # Reproduces the minimal test case: + # pos1: A ref, cov=10, [10,0,0,0] -> AA pairing passes (10>=5) + # pos2: C ref, cov=12, [0,9,0,3] -> CC passes (9>=5), CT filtered (3<5) + # pos4: T ref, cov=2 -> everything filtered (cov<5) + filter = SiteFilter(cov_threshold=5, edit_threshold=5) + counter = MultiCounter(filter) + + counter.update(RNASiteVariantData("21", 0, 0, 1, 10, 37.0, np.array([10,0,0,0,0], dtype=np.int32), 0.0)) # A site + counter.update(RNASiteVariantData("21", 1, 1, 1, 12, 37.0, np.array([0,9,0,3,0], 
dtype=np.int32), 0.0)) # C site + counter.update(RNASiteVariantData("21", 3, 3, 1, 2, 37.0, np.array([0,0,0,2,0], dtype=np.int32), 0.0)) # T site, cov<5 + + # edit_read_freqs: counts raw (unfiltered) reads + self.assertEqual(counter.edit_read_freqs[0, 0], 10) # AA: 10 reads + self.assertEqual(counter.edit_read_freqs[1, 1], 9) # CC: 9 reads + self.assertEqual(counter.edit_read_freqs[1, 3], 3) # CT: 3 reads (unfiltered) + self.assertEqual(counter.edit_read_freqs[3, 3], 2) # TT: 2 reads + + def test_observed_sites_count(self): + """Test that ObservedSites counts all observations""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + counter = MultiCounter(filter) + + # Add 3 sites with different coverages + site1 = RNASiteVariantData( + seqid="chr1", position=10, reference=0, # A + strand=1, coverage=10, mean_quality=30.0, + frequencies=np.array([8, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + site2 = RNASiteVariantData( + seqid="chr1", position=20, reference=1, # C + strand=1, coverage=3, mean_quality=30.0, + frequencies=np.array([1, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + site3 = RNASiteVariantData( + seqid="chr1", position=30, reference=2, # G + strand=1, coverage=8, mean_quality=30.0, + frequencies=np.array([0, 3, 5, 0, 0], dtype=np.int32), + score=0.0 + ) + + counter.update(site1) + counter.update(site2) + counter.update(site3) + + # ObservedSites = sum of genome_base_freqs = 3 sites + observed = counter.genome_base_freqs.sum() + self.assertEqual(observed, 3) + + # Check the per-base distribution + self.assertEqual(counter.genome_base_freqs[0], 1) # A + self.assertEqual(counter.genome_base_freqs[1], 1) # C + self.assertEqual(counter.genome_base_freqs[2], 1) # G + + def test_qualified_sites_count(self): + """Test that QualifiedSites only counts sites passing the coverage filter""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + counter = MultiCounter(filter) + + # Site 1: coverage 10 >= 5 → 
qualified (reference A) + site1 = RNASiteVariantData( + seqid="chr1", position=10, reference=0, # A + strand=1, coverage=10, mean_quality=30.0, + frequencies=np.array([8, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + # Site 2: coverage 3 < 5 → not qualified (reference C) + site2 = RNASiteVariantData( + seqid="chr1", position=20, reference=1, # C + strand=1, coverage=3, mean_quality=30.0, + frequencies=np.array([1, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + # Site 3: coverage 8 >= 5 → qualified (reference G) + site3 = RNASiteVariantData( + seqid="chr1", position=30, reference=2, # G + strand=1, coverage=8, mean_quality=30.0, + frequencies=np.array([0, 3, 5, 0, 0], dtype=np.int32), + score=0.0 + ) + + counter.update(site1) + counter.update(site2) + counter.update(site3) + + # QualifiedSites = 2 (site1 and site3) + self.assertEqual(counter.filtered_sites_count, 2) + + # ObservedSites = 3 (all) + self.assertEqual(counter.genome_base_freqs.sum(), 3) + + # QualifiedBases: only A and G (site1 and site3) + self.assertEqual(counter.filtered_base_freqs[0], 1) # A + self.assertEqual(counter.filtered_base_freqs[1], 0) # C (not qualified) + self.assertEqual(counter.filtered_base_freqs[2], 1) # G + self.assertEqual(counter.filtered_base_freqs[3], 0) # T + + def test_merge_preserves_counts(self): + """Test that merge preserves the site and base counters""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + + counter1 = MultiCounter(filter) + site1 = RNASiteVariantData( + seqid="chr1", position=10, reference=0, # A, coverage 10 >= 5 + strand=1, coverage=10, mean_quality=30.0, + frequencies=np.array([8, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + counter1.update(site1) + + counter2 = MultiCounter(filter) + site2 = RNASiteVariantData( + seqid="chr1", position=20, reference=1, # C, coverage 3 < 5 + strand=1, coverage=3, mean_quality=30.0, + frequencies=np.array([1, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + site3 = RNASiteVariantData( + seqid="chr1", 
position=30, reference=0, # A, coverage 6 >= 5 + strand=1, coverage=6, mean_quality=30.0, + frequencies=np.array([4, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + counter2.update(site2) + counter2.update(site3) + + # Merge counter2 into counter1 + counter1.merge(counter2) + + # Check the totals + self.assertEqual(counter1.genome_base_freqs.sum(), 3) # 1 + 2 = 3 observed sites + self.assertEqual(counter1.filtered_sites_count, 2) # 1 + 1 = 2 qualified sites (site1 and site3) + + # Check the observed bases + self.assertEqual(counter1.genome_base_freqs[0], 2) # A: site1 + site3 + self.assertEqual(counter1.genome_base_freqs[1], 1) # C: site2 + + # Check the qualified bases + self.assertEqual(counter1.filtered_base_freqs[0], 2) # A: site1 + site3 (both qualified) + self.assertEqual(counter1.filtered_base_freqs[1], 0) # C: site2 not qualified + self.assertEqual(counter1.filtered_base_freqs.sum(), 2) # Total = 2 + + def test_qualified_bases_distribution(self): + """Test that QualifiedBases correctly reflects the distribution of qualified bases""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + counter = MultiCounter(filter) + + # Add several sites with different bases and coverages + sites = [ + (0, 10), # A, qualified + (0, 6), # A, qualified + (1, 3), # C, not qualified + (1, 7), # C, qualified + (2, 2), # G, not qualified + (2, 5), # G, qualified + (3, 8), # T, qualified + ] + + for ref, cov in sites: + # Use frequencies with 2 non-ref reads so edit_threshold=2 is met when cov >= 5 + non_ref = 2 if cov >= 5 else 0 + freqs = np.zeros(5, dtype=np.int32) + freqs[ref] = cov - non_ref + # Add a non-ref pairing at index (ref+1)%4 + freqs[(ref + 1) % 4] = non_ref + site = RNASiteVariantData( + seqid="chr1", + position=0, + reference=ref, + strand=1, + coverage=cov, + mean_quality=30.0, + frequencies=freqs, + score=0.0 + ) + counter.update(site) + + # ObservedBases: all occurrences + self.assertEqual(counter.genome_base_freqs[0], 2) # A: 2 
+ self.assertEqual(counter.genome_base_freqs[1], 2) # C: 2 + self.assertEqual(counter.genome_base_freqs[2], 2) # G: 2 + self.assertEqual(counter.genome_base_freqs[3], 1) # T: 1 + self.assertEqual(counter.genome_base_freqs.sum(), 7) + + # QualifiedBases: only with cov >= 5 + self.assertEqual(counter.filtered_base_freqs[0], 2) # A: 2 qualified (10, 6) + self.assertEqual(counter.filtered_base_freqs[1], 1) # C: 1 qualified (7) + self.assertEqual(counter.filtered_base_freqs[2], 1) # G: 1 qualified (5) + self.assertEqual(counter.filtered_base_freqs[3], 1) # T: 1 qualified (8) + self.assertEqual(counter.filtered_base_freqs.sum(), 5) + self.assertEqual(counter.filtered_sites_count, 5) + + +class TestFeatureTotalSites(unittest.TestCase): + """Tests for computing TotalSites of features""" + + def test_simple_feature_total_sites(self): + """Test TotalSites for a simple feature""" + feature = SeqFeature( + location=SimpleLocation(0, 100), + id="exon1", + type="exon" + ) + + # TotalSites = len(location) + total_sites = len(feature.location) + self.assertEqual(total_sites, 100) + + def test_compound_feature_total_sites(self): + """Test TotalSites for a feature with a CompoundLocation (chimaera)""" + # Two exons: [0-50] and [100-150] + location = CompoundLocation([ + SimpleLocation(0, 50), + SimpleLocation(100, 150) + ]) + feature = SeqFeature( + location=location, + id="chimaera", + type="exon" + ) + + # TotalSites = sum of lengths without gaps + total_sites = len(feature.location) + self.assertEqual(total_sites, 100) # 50 + 50 + + def test_gene_total_sites(self): + """Test TotalSites for a gene (full span including introns)""" + # Gene from 1 to 1000 + feature = SeqFeature( + location=SimpleLocation(0, 1000), + id="gene1", + type="gene" + ) + + # TotalSites = full span + total_sites = len(feature.location) + self.assertEqual(total_sites, 1000) + + +class TestAggregateTotalSites(unittest.TestCase): + """Tests for computing TotalSites of aggregates""" + + 
def test_aggregate_from_continuous_features(self): + """Test TotalSites for an aggregate of contiguous features""" + pos = AggregatePositions(strand=1) + + # Three adjacent exons + exon1 = SeqFeature(location=SimpleLocation(0, 100), id="exon1", type="exon") + exon2 = SeqFeature(location=SimpleLocation(100, 200), id="exon2", type="exon") + exon3 = SeqFeature(location=SimpleLocation(200, 300), id="exon3", type="exon") + + pos.update_from_feature(exon1) + pos.update_from_feature(exon2) + pos.update_from_feature(exon3) + + # TotalSites = end - start = 300 - 0 = 300 + total_sites = pos.end - pos.start + self.assertEqual(total_sites, 300) + + def test_aggregate_from_separated_features(self): + """Test TotalSites for an aggregate of separated features""" + pos = AggregatePositions(strand=1) + + # Two exons with a gap + exon1 = SeqFeature(location=SimpleLocation(0, 100), id="exon1", type="exon") + exon2 = SeqFeature(location=SimpleLocation(200, 300), id="exon2", type="exon") + + pos.update_from_feature(exon1) + pos.update_from_feature(exon2) + + # TotalSites = end - start = 300 - 0 = 300 (includes the gap) + total_sites = pos.end - pos.start + self.assertEqual(total_sites, 300) + + def test_aggregate_empty(self): + """Test TotalSites for an empty aggregate""" + pos = AggregatePositions(strand=0) + + # No feature added + start_str, end_str, strand_str = pos.to_strings() + + # Should return "." 
+ self.assertEqual(start_str, ".") + self.assertEqual(end_str, ".") + + +class TestIntegrationSiteCounts(unittest.TestCase): + """Integration tests for the three columns""" + + def test_all_three_columns(self): + """Test that the three columns are consistent""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + counter = MultiCounter(filter) + + # 5 sites with different coverages + coverages = [10, 3, 8, 4, 6] + for i, cov in enumerate(coverages): + site = RNASiteVariantData( + seqid="chr1", + position=i * 10, + reference=0, + strand=1, + coverage=cov, + mean_quality=30.0, + frequencies=np.array([cov-2, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + counter.update(site) + + # Feature from 0 to 50 (50 bases) + feature = SeqFeature( + location=SimpleLocation(0, 50), + id="exon1", + type="exon" + ) + + total_sites = len(feature.location) + observed_sites = counter.genome_base_freqs.sum() + qualified_sites = counter.filtered_sites_count + + # Checks + self.assertEqual(total_sites, 50) # Feature size + self.assertEqual(observed_sites, 5) # 5 observations + self.assertEqual(qualified_sites, 3) # 3 with cov >= 5 (10, 8, 6) + + # Logical relations + self.assertLessEqual(observed_sites, total_sites) + self.assertLessEqual(qualified_sites, observed_sites) + + +if __name__ == '__main__': + unittest.main() diff --git a/bin/pluviometer_wrapper.py b/bin/pluviometer_wrapper.py index 3774710..5020175 100755 --- a/bin/pluviometer_wrapper.py +++ b/bin/pluviometer_wrapper.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 """Wrapper to run pluviometer module""" import sys diff --git a/build_containers.sh b/build_containers.sh index 6a5ffed..bec143f 100755 --- a/build_containers.sh +++ b/build_containers.sh @@ -129,7 +129,6 @@ if [ "$build_docker" = true ]; then image_list=( ) for dir in containers/docker/*; do - cd "${dir}" imgname=$(echo $dir | rev | cut -d/ -f1 | rev) image_list+=(${imgname}) @@ -146,10 +145,8 @@ if [ "$build_docker" = 
true ]; then fi fi - docker build ${docker_arch_option} -t ${imgname} . - - # Back to the original working directory - cd "$wd" + # Use containers/common/ as build context so shared files (e.g. env_*.yml) are accessible + docker build ${docker_arch_option} -f "${dir}/Dockerfile" -t ${imgname} containers/common/ done if [[ ${github_action_mode} == 'github_action' ]]; then diff --git a/config/resources/hpc.config b/config/resources/hpc.config index 2900223..33b9db3 100644 --- a/config/resources/hpc.config +++ b/config/resources/hpc.config @@ -12,6 +12,10 @@ process { memory = '4GB' time = '2d' // Long timeout to wait for submitted job completion } + withName: 'drip' { + cpus = 16 + time = '6h' + } withLabel: 'fastqc' { cpus = 8 time = '6h' @@ -21,21 +25,19 @@ process { time = '2d' } withLabel: 'jacusa2' { - cpus = 4 - time = '4h' + cpus = 16 + time = '1d' } withLabel: 'pigz' { cpus = 8 time = '4h' } withLabel: 'pluviometer' { - cpus = 4 + cpus = 8 time = '6h' - memory = '48GB' - errorStrategy = 'terminate' } withLabel: 'reditools3' { - cpus = 6 + cpus = 16 time = '1d' } withLabel: 'samtools' { diff --git a/config/resources/local.config b/config/resources/local.config index 3752d38..1a88460 100644 --- a/config/resources/local.config +++ b/config/resources/local.config @@ -4,7 +4,11 @@ process { maxForks = 8 shell = ['/bin/bash', '-euo', 'pipefail'] stageOutMode = 'rsync' - + + withName: 'drip' { + cpus = 4 + time = '6h' + } withLabel: 'fastqc' { cpus = 2 time = '6h' @@ -27,6 +31,10 @@ process { cpus = 4 memory = '16GB' time = '2d' + } + withLabel: 'drip' { + cpus = 4 + time = '6h' } withLabel: 'hisat2_index' { cpus = 4 diff --git a/containers/docker/sapin/constraints.txt b/containers/common/constraints_sapin.txt similarity index 100% rename from containers/docker/sapin/constraints.txt rename to containers/common/constraints_sapin.txt diff --git a/containers/docker/pluviometer/env_pluviometer.yml b/containers/common/env_pluviometer.yml similarity index 100% rename from 
containers/docker/pluviometer/env_pluviometer.yml rename to containers/common/env_pluviometer.yml diff --git a/containers/docker/reditools2/env_reditools2.yml b/containers/common/env_reditools2.yml similarity index 100% rename from containers/docker/reditools2/env_reditools2.yml rename to containers/common/env_reditools2.yml diff --git a/containers/docker/jacusa2/Dockerfile b/containers/docker/jacusa2/Dockerfile index 5d431f9..34196b1 100644 --- a/containers/docker/jacusa2/Dockerfile +++ b/containers/docker/jacusa2/Dockerfile @@ -1,4 +1,4 @@ -FROM openjdk:8u102 +FROM quay.io/biocontainers/openjdk:11.0.1--2 # Download the tool RUN wget https://github.com/dieterich-lab/JACUSA2/releases/download/v2.0.4/JACUSA_v2.0.4.jar -O /usr/local/bin/JACUSA_v2.0.4.jar diff --git a/containers/docker/sapin/Dockerfile b/containers/docker/sapin/Dockerfile index 4c68470..a283fa7 100644 --- a/containers/docker/sapin/Dockerfile +++ b/containers/docker/sapin/Dockerfile @@ -10,9 +10,9 @@ RUN apt install -y git RUN git clone --depth=1 https://github.com/Juke34/SAPiN.git # Use a pip constraint file to pin dependencies for reproducibility -COPY constraints.txt SAPiN/ +COPY constraints_sapin.txt SAPiN/ RUN cd SAPiN && \ - pip install --constraint constraints.txt . + pip install --constraint constraints_sapin.txt . 
CMD [ "bash" ] diff --git a/containers/singularity/pluviometer/env_pluviometer.yml b/containers/singularity/pluviometer/env_pluviometer.yml deleted file mode 100644 index 47b0cb0..0000000 --- a/containers/singularity/pluviometer/env_pluviometer.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: rain-stats-scripts -channels: - - bioconda - - conda-forge -dependencies: - - _libgcc_mutex=0.1=conda_forge - - _openmp_mutex=4.5=2_gnu - - bcbio-gff=0.7.1=pyh7e72e81_2 - - biopython=1.85=py312h66e93f0_1 - - bx-python=0.13.0=py312h5e9d817_1 - - bzip2=1.0.8=h4bc722e_7 - - ca-certificates=2025.7.14=hbd8a1cb_0 - - frozendict=2.4.6=py312h66e93f0_0 - - ld_impl_linux-64=2.43=h712a8e2_4 - - libblas=3.9.0=31_h59b9bed_openblas - - libcblas=3.9.0=31_he106b2a_openblas - - libexpat=2.7.0=h5888daf_0 - - libffi=3.4.6=h2dba641_1 - - libgcc=14.2.0=h767d61c_2 - - libgcc-ng=14.2.0=h69a702a_2 - - libgfortran=14.2.0=h69a702a_2 - - libgfortran5=14.2.0=hf1ad2bd_2 - - libgomp=14.2.0=h767d61c_2 - - liblapack=3.9.0=31_h7ac8fdf_openblas - - liblzma=5.8.1=hb9d3cd8_0 - - libnsl=2.0.1=hd590300_0 - - libopenblas=0.3.29=pthreads_h94d23a6_0 - - libsqlite=3.49.1=hee588c1_2 - - libstdcxx=14.2.0=h8f9b012_2 - - libstdcxx-ng=14.2.0=h4852527_2 - - libunwind=1.6.2=h9c3ff4c_0 - - libuuid=2.38.1=h0b41bf4_0 - - libxcrypt=4.4.36=hd590300_1 - - libzlib=1.3.1=hb9d3cd8_2 - - natsort=8.4.0=pyh29332c3_1 - - ncurses=6.5=h2d0b736_3 - - numpy=2.2.4=py312h72c5963_0 - - openssl=3.5.1=h7b32b05_0 - - pip=25.0.1=pyh8b19718_0 - - progressbar2=4.5.0=pyhd8ed1ab_1 - - py-spy=0.4.0=h4c5a871_1 - - pyparsing=3.2.3=pyhd8ed1ab_1 - - python=3.12.10=h9e4cc4f_0_cpython - - python-utils=3.9.1=pyhff2d567_1 - - python_abi=3.12=6_cp312 - - readline=8.2=h8c095d6_2 - - setuptools=78.1.0=pyhff2d567_0 - - six=1.17.0=pyhd8ed1ab_0 - - sortedcontainers=2.4.0=pyhd8ed1ab_1 - - tk=8.6.13=noxft_h4845f30_101 - - typing_extensions=4.13.2=pyh29332c3_0 - - tzdata=2025b=h78e105d_0 - - wheel=0.45.1=pyhd8ed1ab_1 - - pysam=0.23.3=py312h47d5410_1 - - pip: - - flameprof==0.4 
-prefix: /home/earx/miniforge3/envs/rain-stats-scripts diff --git a/containers/singularity/pluviometer/pluviometer.def b/containers/singularity/pluviometer/pluviometer.def index 9b4473a..bbcfa3d 100644 --- a/containers/singularity/pluviometer/pluviometer.def +++ b/containers/singularity/pluviometer/pluviometer.def @@ -2,7 +2,7 @@ Bootstrap: docker From: condaforge/miniforge3:24.7.1-2 %files - containers/singularity/pluviometer/env_pluviometer.yml /env_pluviometer.yml + containers/common/env_pluviometer.yml /env_pluviometer.yml %post conda env update --name base --file env_pluviometer.yml diff --git a/containers/singularity/reditools2/env_reditools2.yml b/containers/singularity/reditools2/env_reditools2.yml deleted file mode 100644 index 5897beb..0000000 --- a/containers/singularity/reditools2/env_reditools2.yml +++ /dev/null @@ -1,15 +0,0 @@ -channels: - - bioconda - - conda-forge -dependencies: - - _openmp_mutex=4.5=2_gnu - - pip=20.1.1=pyh9f0ad1d_0 - - psutil=5.7.0=py27hdf8410d_1 - - pysam=0.20.0=py27h7835474_0 - - python=2.7.15=h5a48372_1011_cpython - - python_abi=2.7=1_cp27mu - - samtools=1.18=hd87286a_0 - - tabix=1.11=hdfd78af_0 - - sortedcontainers - - netifaces - diff --git a/containers/singularity/reditools2/reditools2.def b/containers/singularity/reditools2/reditools2.def index 88f20e9..1a792e0 100644 --- a/containers/singularity/reditools2/reditools2.def +++ b/containers/singularity/reditools2/reditools2.def @@ -2,7 +2,7 @@ Bootstrap: docker From: condaforge/miniforge3:24.7.1-2 %files - containers/singularity/reditools2/env_reditools2.yml ./env_reditools2.yml + containers/common/env_reditools2.yml ./env_reditools2.yml %post # Activer conda dans l'environnement shell diff --git a/containers/singularity/sapin/constraints.txt b/containers/singularity/sapin/constraints.txt deleted file mode 100644 index 3733d3d..0000000 --- a/containers/singularity/sapin/constraints.txt +++ /dev/null @@ -1,20 +0,0 @@ -argcomplete==3.6.2 -argh==0.31.3 -contourpy==1.3.1 
-cycler==0.12.1 -fonttools==4.57.0 -gffutils==0.13 -importlib_metadata==8.6.1 -kiwisolver==1.4.8 -matplotlib==3.10.1 -numpy==2.2.4 -packaging==24.2 -pillow==11.1.0 -pyfaidx==0.8.1.3 -pyparsing==3.2.3 -pysam==0.23.0 -python-dateutil==2.9.0.post0 -SAPiN @ file:///SAPiN -simplejson==3.20.1 -six==1.17.0 -zipp==3.21.0 diff --git a/containers/singularity/sapin/sapin.def b/containers/singularity/sapin/sapin.def index ab020a7..1f2fffb 100644 --- a/containers/singularity/sapin/sapin.def +++ b/containers/singularity/sapin/sapin.def @@ -2,7 +2,7 @@ Bootstrap: docker From: python:slim-bullseye %files - containers/singularity/sapin/constraints.txt /constraints.txt + containers/common/constraints_sapin.txt /constraints.txt %post apt-get update && apt-get install -y \ diff --git a/modules/aline.nf b/modules/aline.nf index 6c8e9fd..9fbff88 100644 --- a/modules/aline.nf +++ b/modules/aline.nf @@ -1,12 +1,11 @@ /* Adapted from from https://github.com/mahesh-panchal/nf-cascade -Auto-detects HPC/local environment and adapts execution strategy */ process AliNe { tag "$pipeline_name" label 'aline' - publishDir "${params.outdir}", mode: 'copy' + publishDir "${output_dir}", mode: 'copy' errorStrategy 'terminate' // to avoid any retry maxRetries 0 // Override global retry config - do not retry this process @@ -21,116 +20,43 @@ process AliNe { val aligner val library_type val annotation - val cache_dir // String - val use_slurm // Boolean - whether parent is using slurm + val cache_dir // cache directory + val output_dir // output directory when: task.ext.when == null || task.ext.when - script: - def nxf_cmd = "nextflow run ${pipeline_name} ${profile} ${config} --reads ${reads} --reference ${genome} ${read_type} ${aligner} ${library_type} --annotation ${annotation} --data_type rna --outdir \$WORK_DIR/AliNe" - """ - echo "[AliNe] Process started at \$(date '+%Y-%m-%d %H:%M:%S')" - - # Save absolute work directory before changing context - WORK_DIR=\$(pwd) - - # Create cache 
directory for resume AliNe run made from different working directory - mkdir -p "${cache_dir}" - cd "${cache_dir}" - - # Save command for reference/debugging - echo "${nxf_cmd}" > \$WORK_DIR/nf-cmd.sh - - # Detect execution environment and run AliNe accordingly - if ${use_slurm} && command -v sbatch >/dev/null 2>&1; then - echo "[AliNe] Detected HPC environment - submitting AliNe as separate SLURM job" - - # Create sbatch script for AliNe - cat > \$WORK_DIR/aline_job.sh < \$WORK_DIR/aline_job_id.txt - - # Wait for job to appear in scheduler queue - # Simple wait for job to appear or to be started - echo "[AliNe] Waiting for job \$JOB_ID to appear in queue or to be started..." - while true; do - # Check if job is in queue - if squeue -j \$JOB_ID 2>/dev/null | grep -q \$JOB_ID; then - echo "[AliNe] Job \$JOB_ID is now visible in queue" - break - else - echo "[AliNe] Job \$JOB_ID not yet visible in queue \$(date '+%Y-%m-%d %H:%M:%S')" - fi - # Check job state (crash test: if job has started or finished) - JOB_STATE=\$(sacct -j \$JOB_ID --format=State --noheader | head -1 | tr -d ' ') - if [[ "\$JOB_STATE" =~ ^(RUNNING|COMPLETED|FAILED|CANCELLED|TIMEOUT|PREEMPTED|NODE_FAIL|OUT_OF_MEMORY)\$ ]]; then - echo "[AliNe] Job \$JOB_ID has state: \$JOB_STATE (not visible in queue, but started or finished)" - break - fi - echo "[AliNe] Job not yet visible, waiting..." - sleep 5 - done # Simple wait for job to appear or to be started - - # Wait for job completion - echo "[AliNe] Waiting for job \$JOB_ID to complete..." 
- while squeue -j \$JOB_ID 2>/dev/null | grep -q \$JOB_ID; do - sleep 30 - done - - # Check job exit status - JOB_STATE=\$(sacct -j \$JOB_ID --format=State --noheader | head -1 | tr -d ' ') - echo "[AliNe] Job \$JOB_ID finished with state: \$JOB_STATE at \$(date '+%Y-%m-%d %H:%M:%S')" - - if [[ "\$JOB_STATE" != "COMPLETED" ]]; then - echo "[AliNe] ERROR: Job failed with state \$JOB_STATE at \$(date '+%Y-%m-%d %H:%M:%S')" >&2 - echo "With message (100 last lines)": >&2 - tail -n 100 \$WORK_DIR/aline_\$JOB_ID.out >&2 - exit 1 - fi - - # Copy log for reference - if [ -f .nextflow.log ]; then - cp .nextflow.log \$WORK_DIR/nextflow.log - fi - echo "[AliNe] Pipeline completed successfully via SLURM at \$(date '+%Y-%m-%d %H:%M:%S')" - else - echo "[AliNe] Detected local/standard environment - running AliNe directly" - - # Run nextflow command directly - ${nxf_cmd} || { - echo "[AliNe] ERROR: Pipeline failed at \$(date '+%Y-%m-%d %H:%M:%S')" >&2 - exit 1 - } - - echo "[AliNe] Pipeline completed successfully (direct execution) at \$(date '+%Y-%m-%d %H:%M:%S')" - fi - - echo "[AliNe] Process finished at \$(date '+%Y-%m-%d %H:%M:%S')" -""" + exec: + def cache_path = file(cache_dir) + assert cache_path.mkdirs() + // construct nextflow command + def nxf_cmd = [ + 'nextflow run', + pipeline_name, + profile, + config, + "--reads ${reads}", + "--reference ${genome}", + read_type, + aligner, + library_type, + "--annotation ${annotation}", + "--data_type rna", + "--outdir $task.workDir/AliNe", + ].join(" ") + // Copy command to shell script in work dir for reference/debugging. 
+ file("$task.workDir/nf-cmd.sh").text = nxf_cmd + // Run nextflow command locally in cache directory + def process = nxf_cmd.execute(null, cache_path.toFile()) + // Print process output to stdout and stderr + process.consumeProcessOutput(System.out, System.err) + process.waitFor() + stdout = process.text + // Copy nextflow log to work directory + cache_path.resolve(".nextflow.log").copyTo("${task.workDir}/nextflow.log") + assert process.exitValue() == 0: stdout output: path "AliNe" , emit: output + val stdout, emit: log } diff --git a/modules/bash.nf b/modules/bash.nf index 534e397..7f21323 100644 --- a/modules/bash.nf +++ b/modules/bash.nf @@ -126,7 +126,7 @@ process create_aline_csv_he { */ process collect_aline_csv { label 'bash' - publishDir("${params.outdir}/${output_dir}", mode:"copy", pattern: "*.csv") + publishDir("${output_dir}", mode:"copy", pattern: "*.csv") input: val all_csv // List of tuples (meta, fastq_files) @@ -175,189 +175,6 @@ process recreate_csv_with_abs_paths { """ } -process standardize_pluvio_aggregates { - label 'bash' - tag "${tsv.baseName}" - publishDir("${params.outdir}/pluviometer/${tool_name}/standardized", mode:"copy", pattern: "*_standardized.tsv") - - input: - tuple(val(meta), path(tsv)) - - output: - tuple val(meta), path("*_standardized.tsv"), emit: standardized_tsv - - script: - def basename = tsv.baseName - tool_name = basename.split('_')[-2] - """ - awk 'BEGIN { - FS=OFS="\\t" - } - NR==1 { - # Store original header and find column indices - for(i=1; i<=NF; i++) { - col[\$i] = i - } - - # Print new header - printf "SeqID\\tParentIDs\\tID\\tMtype\\tPtype\\tType\\tCtype\\tMode\\tStart\\tEnd\\tStrand" - - # Print remaining data columns (after Strand) - for(i=col["Strand"]+1; i<=NF; i++) { - printf "\\t%s", \$i - } - printf "\\n" - next - } - { - # SeqID - seqid = \$col["SeqID"] - printf "%s\\t", seqid - - # ParentIDs - printf "%s\\t", \$col["ParentIDs"] - - # ID (from AggregateID) - aggregate_id = \$col["AggregateID"] - printf 
"%s\\t", aggregate_id - - # Mtype (always "aggregate") - printf "aggregate\\t" - - # Ptype (from ParentType) - printf "%s\\t", \$col["ParentType"] - - # Type (logic: feature if AggregateID has info, chr if SeqID not ".", global if SeqID is ".") - if(aggregate_id != "." && aggregate_id != "") { - type = "feature_agg" - } else if(seqid != ".") { - type = "chr_agg" - } else { - type = "global_agg" - } - printf "%s\\t", type - - # Ctype (from AggregateType) - printf "%s\\t", \$col["AggregateType"] - - # Mode (from AggregationMode) - printf "%s\\t", \$col["AggregationMode"] - - # Start, End, Strand (use existing columns if present, otherwise ".") - if("Start" in col) { - printf "%s\\t", \$col["Start"] - } else { - printf ".\\t" - } - - if("End" in col) { - printf "%s\\t", \$col["End"] - } else { - printf ".\\t" - } - - if("Strand" in col) { - printf "%s", \$col["Strand"] - } else { - printf "." - } - - # Print remaining data columns - for(i=col["Strand"]+1; i<=NF; i++) { - printf "\\t%s", \$i - } - printf "\\n" - }' ${tsv} > ${basename}_standardized.tsv - - echo "Standardization complete: ${basename}_standardized.tsv" - """ -} - -/* - * Standardize drip_features output format - * Transforms feature-level data into standardized column structure - * Output: SeqID | ParentIDs | ID | MType | Ptype | Type | Ctype | Mode | Start | End | Strand | [data...] 
- */ -process standardize_pluvio_features { - label 'bash' - tag "${tsv.baseName}" - publishDir("${params.outdir}/pluviometer/${tool_name}/standardized", mode:"copy", pattern: "*_standardized.tsv") - - input: - tuple(val(meta), path(tsv)) - - output: - tuple val(meta), path("*_standardized.tsv"), emit: standardized_tsv - - script: - def basename = tsv.baseName - tool_name = basename.split('_')[-2] - """ - awk 'BEGIN { - FS=OFS="\\t" - } - NR==1 { - # Store original header and find column indices - for(i=1; i<=NF; i++) { - col[\$i] = i - } - - # Print new header - printf "SeqID\\tParentIDs\\tID\\tMtype\\tPtype\\tType\\tCtype\\tMode\\tStart\\tEnd\\tStrand" - - # Print remaining data columns (after Strand if exists, or after Strand) - start_data_col = ("Strand" in col) ? col["Strand"]+1 : col["Strand"]+1 - for(i=start_data_col; i<=NF; i++) { - printf "\\t%s", \$i - } - printf "\\n" - next - } - { - # SeqID - printf "%s\\t", \$col["SeqID"] - - # ParentIDs - printf "%s\\t", \$col["ParentIDs"] - - # ID (from FeatureID) - printf "%s\\t", \$col["FeatureID"] - - # Mtype (always "feature") - printf "feature\\t" - - # Ptype (always ".") - printf ".\\t" - - # Type (from original Type column) - printf "%s\\t", \$col["Type"] - - # Ctype (always ".") - printf ".\\t" - - # Mode (always ".") - printf ".\\t" - - # Start - printf "%s\\t", \$col["Start"] - - # End - printf "%s\\t", \$col["End"] - - # Strand - printf "%s", \$col["Strand"] - - # Print remaining data columns - for(i=col["Strand"]+1; i<=NF; i++) { - printf "\\t%s", \$i - } - printf "\\n" - }' ${tsv} > ${basename}_standardized.tsv - - echo "Standardization complete: ${basename}_standardized.tsv" - """ -} - /* * Filter drip output files by AggregationMode * Creates separate files for each unique AggregationMode value diff --git a/modules/jacusa2.nf b/modules/jacusa2.nf index 15727e9..c40632c 100644 --- a/modules/jacusa2.nf +++ b/modules/jacusa2.nf @@ -8,7 +8,7 @@ process jacusa2 { tuple(path(genome), path(fastaindex)) 
output: - tuple(val(meta), path("*.site_edits_jacusa2.txt"), emit: tuple_sample_jacusa2_table) + tuple(val(meta), val("jacusa2"), path("*.site_edits_jacusa2.txt"), emit: tuple_sample_jacusa2_table) path("*.filtered") script: diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index efa5a5b..67b2a80 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -1,29 +1,28 @@ process pluviometer { label "pluviometer" - publishDir("${params.outdir}/pluviometer/${tool_format}/raw", mode: "copy") - tag "${meta.uid}" + publishDir("${params.outdir}/pluviometer/${tool_format}/", mode: "copy") + tag "${meta.uid}_${tool_format}" input: - tuple(val(meta), path(site_edits)) + tuple(val(meta), val(tool_format), path(site_edits)) path(gff) - val(tool_format) output: - tuple(val(meta), path("*features.tsv"), emit: tuple_sample_feature) - tuple(val(meta), path("*aggregates.tsv"), emit: tuple_sample_aggregate) - tuple(val(meta), path("*pluviometer.log"), emit: tuple_sample_log) + tuple(val(meta), val(tool_format), path("*features.tsv"), emit: tuple_sample_feature) + tuple(val(meta), val(tool_format), path("*aggregates.tsv"), emit: tuple_sample_aggregate) + tuple(val(meta), val(tool_format), path("*pluviometer.log"), emit: tuple_sample_log) script: base_name = site_edits.BaseName - """ + """ pluviometer_wrapper.py \ --sites ${site_edits} \ --gff ${gff} \ --format ${tool_format} \ - --cov 1 \ + --cov ${params.cov_threshold} \ --edit_threshold ${params.edit_threshold} \ --threads ${task.cpus} \ --aggregation_mode ${params.aggregation_mode} \ - --output "${meta.uid}_${tool_format}" + --output "${meta.uid}_${tool_format}" """ -} \ No newline at end of file +} diff --git a/modules/python.nf b/modules/python.nf index 256325a..229af8c 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -28,34 +28,36 @@ process restore_original_sequences { process drip { label "pluviometer" - tag "drip" - publishDir("${params.outdir}/drip/${prefix}", mode:"copy", pattern: "*.tsv") + tag 
"drip_${tool}" + publishDir("${params.outdir}/drip/${prefix}", mode:"copy", pattern: "*/*") input: - val(meta_tsv) + tuple(val(tool), val(meta_tsv)) val prefix + val samples_pct + val group_pct output: - path("*_AG.tsv"), emit: editing_ag - path("*.tsv"), emit: editing_all + path("*_espr/*.tsv"), emit: editing_all_espr + path("*_espf/*.tsv"), emit: editing_all_espf script: def list = meta_tsv def args = [] - // Process list by pairs: [meta, file, meta, file, ...] - for (int i = 0; i < list.size(); i += 2) { - def m = list[i] // Don't reuse 'meta' variable name - def file = list[i + 1] + // Process list of [meta, file] pairs from groupTuple + list.each { pair -> + def m = pair[0] // meta dictionary + def file = pair[1] // file path def group = m.group ?: "group_unknown" def sample = m.sample ?: "sample_unknown" - def replicate = m.rep ?: "rep1" // Note: it's 'rep' not 'replicate' in meta + def replicate = m.rep ?: "rep1" args.add("${file}:${group}:${sample}:${replicate}") } def args_str = args.join(" ") - """ - drip.py drip_${prefix} ${args_str} + """ + drip.py --threads ${task.cpus} --min-samples-pct ${samples_pct} --min-group-pct ${group_pct} --output drip_${prefix} ${args_str} """ } \ No newline at end of file diff --git a/modules/reditools2.nf b/modules/reditools2.nf index 468ba40..6065015 100644 --- a/modules/reditools2.nf +++ b/modules/reditools2.nf @@ -9,12 +9,13 @@ process reditools2 { val region output: - tuple(val(meta), path("*site_edits_reditools2.txt"), emit: tuple_sample_serial_table) + tuple(val(meta), val("reditools2"), path("*site_edits_reditools2.txt"), emit: tuple_sample_serial_table) script: // Set the strand orientation parameter from the library type parameter // Terms explained in https://salmon.readthedocs.io/en/latest/library_type.html + // https://github.com/BioinfoUNIBA/REDItools2?tab=readme-ov-file: -s STRAND, --strand STRAND Strand: this can be 0 (unstranded), 1 (secondstrand oriented) or 2 (firststrand oriented) if (meta.strandedness in 
["ISR", "SR"]) {
         // First-strand oriented
         strand_orientation = "2"
diff --git a/modules/reditools3.nf b/modules/reditools3.nf
index 50e7a4d..120eb5a 100644
--- a/modules/reditools3.nf
+++ b/modules/reditools3.nf
@@ -8,7 +8,7 @@ process reditools3 {
         path genome

     output:
-        tuple(val(meta), path("${base_name}.site_edits_reditools3.txt"), emit: tuple_sample_serial_table)
+        tuple(val(meta), val("reditools3"), path("${base_name}.site_edits_reditools3.txt"), emit: tuple_sample_serial_table)
         path("${base_name}.reditools3.log", emit: log)

     script:
@@ -31,7 +31,7 @@ process reditools3 {
     }
     base_name = bam.BaseName
-    """
+    """
     python -m reditools analyze ${bam} --reference ${genome} --strand ${strand_orientation} --output-file ${base_name}.site_edits_reditools3.txt --threads ${task.cpus} --verbose &> ${base_name}.reditools3.log
     """
 }
diff --git a/modules/samtools.nf b/modules/samtools.nf
index 89a6bd4..d3576f1 100644
--- a/modules/samtools.nf
+++ b/modules/samtools.nf
@@ -151,4 +151,27 @@ process samtools_merge_bams {
     """
     samtools merge -@ ${task.cpus} ${meta.uid}_merged.bam ${bam1} ${bam2}
     """
+}
+
+/*
+ * Calculate MD tag for BAM file
+ * The MD tag is required by some tools like REDItools
+ * It encodes the reference bases that differ from the aligned read bases
+ * This tag can be missing or incorrect in BAM files that have been processed (e.g. with Picard MarkDuplicates, bamutil) and needs to be recomputed to ensure accurate editing site detection.
+ */
+process samtools_calmd {
+    label "samtools"
+    tag "${meta.uid}"
+
+    input:
+        tuple val(meta), path(bam)
+        path(reference)
+
+    output:
+        tuple val(meta), path("*_md.bam"), emit: tuple_sample_bam_with_md
+
+    script:
+        """
+        samtools calmd -@ ${task.cpus} -b ${bam} ${reference} > ${bam.baseName}_md.bam
+        """
 }
\ No newline at end of file
diff --git a/rain.nf b/rain.nf
index b79b964..ea97e9f 100644
--- a/rain.nf
+++ b/rain.nf
@@ -22,8 +22,12 @@ params.clean_duplicate = true
 // Edit counting params
 edit_site_tools = ["reditools2", "reditools3", "jacusa2", "sapin"]
 params.edit_site_tool = "reditools3"
-params.edit_threshold = 1
-params.aggregation_mode = "all"
+params.edit_threshold = 1 // Minimal number of edited reads required to count a site as edited
+params.cov_threshold = 10 // Minimal coverage to consider a site for editing detection
+// When both thresholds below are provided they combine as OR: a site is kept if either condition is satisfied.
+params.min_samples_pct = 50 // Minimal percentage of samples in which a site must be edited to be kept in the analysis (drip filtering)
+params.min_group_pct = 75 // Minimal percentage of groups in which a site must be edited to be kept in the analysis (drip filtering)
+params.aggregation_mode = "all" // Aggregation mode used by pluviometer
 params.skip_hyper_editing = false // Skip hyper-editing detection
 // Report params
 params.multiqc_config = "$baseDir/config/multiqc_config.yaml" // MultiQC config file
@@ -35,7 +39,6 @@ params.region = "" // e.g.
chr21 - Used to limit the analysis to a specific regi params.help = null params.monochrome_logs = false // if true, no color in logs params.debug = false // Enable debug output -params.use_slurm_for_aline = false // Whether to submit AliNe as a separate SLURM job when using an HPC environment // -------------------------------------------------- /* ---- Params shared between RAIN and AliNe ---- */ @@ -62,7 +65,7 @@ params.trimming_fastp = false align_tools = [ 'bbmap', 'bowtie', 'bowtie2', 'bwaaln', 'bwamem', 'bwamem2', 'bwasw', 'dragmap', 'graphmap2', 'hisat2', 'kallisto', 'last', 'minimap2', 'novoalign', 'nucmer', 'ngmlr', 'salmon', 'star', 'subread', 'sublong' ] params.aligner = 'hisat2' // AliNe version -params.aline_version = 'v1.6.2' +params.aline_version = 'v1.6.3' //************************************************* // STEP 1 - HELP //************************************************* @@ -128,6 +131,7 @@ def helpMSG() { --aligner Aligner to use [default: $params.aligner] --clean_duplicate Remove PCR duplicates from BAM files using GATK MarkDuplicates. [default: $params.clean_duplicate] --clip_overlap Clip overlapping sequences in read pairs to avoid double counting. [default: $params.clipoverlap] + --cov_threshold Minimal coverage to consider a site for editing detection [default: $params.cov_threshold] --debug Enable debug output for troubleshooting. [default: $params.debug] --edit_site_tool Tool used for detecting edited sites. 
[default: $params.edit_site_tool] --edit_threshold Minimal number of edited reads to count a site as edited [default: $params.edit_threshold] @@ -160,6 +164,7 @@ Alignment Parameters Edited Site Detection Parameters + cov_threshold : ${params.cov_threshold} edit_site_tool : ${params.edit_site_tool} edit_threshold : ${params.edit_threshold} region : ${params.region} @@ -176,8 +181,7 @@ Report Parameters //************************************************* include { AliNe as ALIGNMENT } from "./modules/aline.nf" include {normalize_gxf} from "./modules/agat.nf" -include { extract_libtype; recreate_csv_with_abs_paths; collect_aline_csv; filter_drip_by_aggregation_mode; filter_drip_features_by_type - standardize_pluvio_aggregates; standardize_pluvio_features} from "./modules/bash.nf" +include { extract_libtype; recreate_csv_with_abs_paths; collect_aline_csv; filter_drip_by_aggregation_mode; filter_drip_features_by_type} from "./modules/bash.nf" include {bamutil_clipoverlap} from './modules/bamutil.nf' include {fastp} from './modules/fastp.nf' include {fastqc as fastqc_ali; fastqc as fastqc_dup; fastqc as fastqc_clip} from './modules/fastqc.nf' @@ -185,10 +189,10 @@ include {gatk_markduplicates } from './modules/gatk.nf' include {jacusa2} from "./modules/jacusa2.nf" include {multiqc} from './modules/multiqc.nf' include {fasta_unzip} from "$baseDir/modules/pigz.nf" -include {samtools_index; samtools_fasta_index; samtools_sort_bam as samtools_sort_bam_raw; samtools_sort_bam as samtools_sort_bam_merged; samtools_split_mapped_unmapped; samtools_merge_bams} from './modules/samtools.nf' +include {samtools_index; samtools_fasta_index; samtools_sort_bam as samtools_sort_bam_raw; samtools_sort_bam as samtools_sort_bam_merged; samtools_split_mapped_unmapped; samtools_merge_bams; samtools_calmd} from './modules/samtools.nf' include {reditools2} from "./modules/reditools2.nf" include {reditools3} from "./modules/reditools3.nf" -include {pluviometer as pluviometer_jacusa2; pluviometer 
as pluviometer_reditools2; pluviometer as pluviometer_reditools3; pluviometer as pluviometer_sapin} from "./modules/pluviometer.nf" +include {pluviometer} from "./modules/pluviometer.nf" include {drip as drip_aggregates; drip as drip_features} from "./modules/python.nf" include {sapin} from "./modules/sapin.nf" @@ -208,7 +212,6 @@ else { exit 1, "No executer selected: please use a profile activating docker or // check AliNE profile def aline_profile_list=[] -def use_slurm_for_aline = params.use_slurm_for_aline str_list = workflow.profile.tokenize(',') str_list.each { if ( it in aline_profile_allowed ){ @@ -524,7 +527,7 @@ workflow { if ( via_csv ){ aline_data_in_ch = recreate_csv_with_abs_paths(csv_ch) - aline_data_in_ch = collect_aline_csv(aline_data_in_ch.collect(), "AliNe") + aline_data_in_ch = collect_aline_csv(aline_data_in_ch.collect(), params.outdir) aline_data_in = aline_data_in_ch } else { @@ -593,7 +596,7 @@ workflow { "--strandedness ${params.strandedness}", clean_annotation, workflow.workDir.resolve('Juke34/AliNe').toUriString(), - use_slurm_for_aline // Pass info about whether to use slurm submission + params.outdir, ) // GET TUPLE [ID, BAM] FILES @@ -670,7 +673,6 @@ workflow { samtools_split_mapped_unmapped.out.unmapped_bam, genome, aline_profile, - use_slurm_for_aline, clean_annotation, 30, // quality threshold "${params.outdir}/hyper_editing", @@ -720,7 +722,7 @@ workflow { // Clip overlap if (params.clip_overlap) { - bamutil_clipoverlap(gatk_markduplicates.out.tuple_sample_dedupbam) + bamutil_clipoverlap(all_bam_sorted_dedup) all_bam_sorted_dedup_clip = bamutil_clipoverlap.out.tuple_sample_clipoverbam // stat on bam with overlap clipped if(params.fastqc){ @@ -728,45 +730,98 @@ workflow { logs.concat(fastqc_clip.out).set{logs} // save log } } else { - all_bam_sorted_dedup_clip = gatk_markduplicates.out.tuple_sample_dedupbam + all_bam_sorted_dedup_clip = all_bam_sorted_dedup } + // Recompute MD tag + if (params.clip_overlap || params.clean_duplicate) 
{ + samtools_calmd(all_bam_sorted_dedup_clip, genome.collect()) + all_bam_sorted_dedup_clip_md = samtools_calmd.out.tuple_sample_bam_with_md + } else { + all_bam_sorted_dedup_clip_md = all_bam_sorted_dedup_clip + } + // index mapped bam - samtools_index(all_bam_sorted_dedup_clip) + samtools_index(all_bam_sorted_dedup_clip_md) final_bam_for_editing = samtools_index.out.tuple_sample_bam_bamindex // ------------------------------------------------------- // ----------------- DETECT EDITING SITES ---------------- // ------------------------------------------------------- - + Channel.empty().set{editing_analysis} // Select site detection tool if ( "jacusa2" in edit_site_tool_list ){ // Create a fasta index file of the reference genome samtools_fasta_index(genome.collect()) jacusa2(final_bam_for_editing, samtools_fasta_index.out.tuple_fasta_fastaindex.collect()) - pluviometer_jacusa2(jacusa2.out.tuple_sample_jacusa2_table, clean_annotation.collect(), "jacusa2") + editing_analysis = editing_analysis.mix(jacusa2.out.tuple_sample_jacusa2_table) } if ( "sapin" in edit_site_tool_list ){ sapin(tuple_sample_bam_processed, genome.collect()) } if ( "reditools2" in edit_site_tool_list ){ reditools2(final_bam_for_editing, genome.collect(), params.region) - pluviometer_reditools2(reditools2.out.tuple_sample_serial_table, clean_annotation.collect(), "reditools2") + editing_analysis = editing_analysis.mix(reditools2.out.tuple_sample_serial_table) } if ( "reditools3" in edit_site_tool_list ){ reditools3(final_bam_for_editing, genome.collect()) - pluviometer_reditools3(reditools3.out.tuple_sample_serial_table, clean_annotation.collect(), "reditools3") - // standardize the feature and aggregate output of pluviometer to be post process in a single way in down processes - standardize_pluvio_aggregates(pluviometer_reditools3.out.tuple_sample_aggregate) - standardize_pluvio_features(pluviometer_reditools3.out.tuple_sample_feature) - if(via_csv){ - // drip - compute espn, espf, merge 
different sample in one, and output by type of mutation (AG, AC, etc..)
-            drip_aggregates(standardize_pluvio_aggregates.out.standardized_tsv.collect(), "aggregates")
-            drip_features(standardize_pluvio_features.out.standardized_tsv.collect(), "features")
-
-            //christalize(drip_features.out.editing_ag, "AG")
+        editing_analysis = editing_analysis.mix(reditools3.out.tuple_sample_serial_table)
+    }
+
+    // Run pluviometer on the editing analysis results to get aggregate and feature values
+    pluviometer(editing_analysis, clean_annotation.collect())
+
+    if(via_csv){
+        // Collect pluviometer outputs by tool (group by element at index 1 = tool name)
+        aggregates_by_tool = pluviometer.out.tuple_sample_aggregate.map { meta, tool, file -> tuple(tool, [meta, file]) }.groupTuple()
+        features_by_tool = pluviometer.out.tuple_sample_feature.map { meta, tool, file -> tuple(tool, [meta, file]) }.groupTuple()
+
+        // drip - compute espn, espf, merge the different samples into one, and output by type of mutation (AG, AC, etc.)
+        drip_aggregates(aggregates_by_tool, "aggregates", params.min_samples_pct, params.min_group_pct)
+        drip_features(features_by_tool, "features", params.min_samples_pct, params.min_group_pct)
+
+        // ------------------- ESPF JOIN AGGREGATES AND FEATURES -----------------
+        drip_aggregates.out.editing_all_espf
+            .flatten()
+            .map { file ->
+                // Extract editing type from filename (e.g., "drip_aggregates_espf_AC.tsv" -> "AC")
+                def editType = file.baseName.tokenize('_').last()
+                tuple(editType, file)
+            }
+            .join(
+                drip_features.out.editing_all_espf
+                    .flatten()
+                    .map { file ->
+                        // Extract editing type from filename (e.g., "drip_features_espf_AC.tsv" -> "AC")
+                        def editType = file.baseName.tokenize('_').last()
+                        tuple(editType, file)
+                    }
+            )
+            .set { features_espf_by_edit_type }
+        features_espf_by_edit_type.view()
+
+        // ------------------- ESPR JOIN AGGREGATES AND FEATURES -----------------
+        drip_aggregates.out.editing_all_espr
+            .flatten()
+            .map { file ->
+                // Extract editing type from filename (e.g., "drip_aggregates_espr_AC.tsv" -> "AC")
+                def editType = file.baseName.tokenize('_').last()
+                tuple(editType, file)
+            }
+            .join(
+                drip_features.out.editing_all_espr
+                    .flatten()
+                    .map { file ->
+                        // Extract editing type from filename (e.g., "drip_features_espr_AC.tsv" -> "AC")
+                        def editType = file.baseName.tokenize('_').last()
+                        tuple(editType, file)
+                    }
+            )
+            .set { features_espr_by_edit_type }
+        features_espr_by_edit_type.view()
+
+        // READY for barometer analysis
-        }
     }

     // ------------------- MULTIQC -----------------
diff --git a/subworkflows/hyper-editing.nf b/subworkflows/hyper-editing.nf
index 1531d7f..2260bf2 100644
--- a/subworkflows/hyper-editing.nf
+++ b/subworkflows/hyper-editing.nf
@@ -27,7 +27,6 @@ workflow HYPER_EDITING {
         unmapped_bams // Unmapped read chunks from primary alignment
         genome // Genomic reference sequence
         aline_profile // AliNe profile in comma-separated format
-        use_slurm_for_aline // Boolean - whether to use slurm for AliNe
clean_annotation // Annotation file for AliNe quality_threshold // Quality score filter threshold output_he // output directory path ier @@ -48,7 +47,7 @@ workflow HYPER_EDITING { // Stage 4: Create CSV file for AliNe with all converted reads aline_csv = create_aline_csv_he(converted_reads).collect() - aline_csv = collect_aline_csv(aline_csv,output_he) + aline_csv = collect_aline_csv(aline_csv, output_he) // Stage 5: Build alignment index for converted reference alignment_index = samtools_fasta_index(converted_reference) @@ -69,7 +68,7 @@ workflow HYPER_EDITING { "--strandedness ${params.strandedness}", clean_annotation, workflow.workDir.resolve('Juke34/AliNe').toUriString(), - use_slurm_for_aline // Pass info about whether to use slurm submission + output_he, ) // GET TUPLE [ID, BAM] FILES