From 6fadbc104220106f623bed93a27654ff76393e39 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 13 Mar 2026 17:57:40 +0100 Subject: [PATCH 01/61] minimize memory footprint --- bin/pluviometer/__main__.py | 294 +++++++++++++++++++++++++----------- 1 file changed, 205 insertions(+), 89 deletions(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 85d203e..9dac558 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -373,7 +373,7 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> self.active_features.pop(feature.id, None) # Counter for the feature itself - feature_counter: Optional[MultiCounter] = self.counters.get(feature.id, None) + feature_counter: Optional[MultiCounter] = self.counters.pop(feature.id, None) assert self.record.id @@ -388,7 +388,8 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> self.chimaera_aggregate_counters_by_parent_type[(parent_feature.type, feature.type)].merge(feature_counter) else: self.feature_writer.write_row_with_data(self.record.id, feature, feature_counter) - del self.counters[feature.id] + # Explicitly delete counter after use to free memory + del feature_counter else: if feature.is_chimaera: assert parent_feature @@ -488,6 +489,10 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> # Recursively check-out children for child in feature.sub_features: self.checkout(child, feature) + + # After processing all children, clear sub_features to free memory + if hasattr(feature, 'sub_features'): + feature.sub_features.clear() return None @@ -670,6 +675,35 @@ def launch_counting(self, reader: RNASiteVariantReader) -> None: return None + def cleanup(self) -> None: + """ + Clean up memory after processing a record. Clear all data structures. 
+ """ + self.active_features.clear() + self.action_queue.clear() + self.counters.clear() + + # Clear aggregate counters + self.longest_isoform_aggregate_counters.clear() + self.chimaera_aggregate_counters.clear() + self.all_isoforms_aggregate_counters.clear() + self.longest_isoform_aggregate_counters_by_parent_type.clear() + self.chimaera_aggregate_counters_by_parent_type.clear() + self.all_isoforms_aggregate_counters_by_parent_type.clear() + + # Clear record reference + if hasattr(self, 'record'): + # Clear features from record to help garbage collection + if hasattr(self.record, 'features'): + self.record.features.clear() + delattr(self, 'record') + + # Clear progress bar reference + if hasattr(self, 'progbar'): + delattr(self, 'progbar') + + return None + def parse_cli_input() -> argparse.Namespace: """Parse command line input""" @@ -854,18 +888,112 @@ def run_job(record: SeqRecord) -> dict[str, Any]: record.id, ["."], ".", ".", "all_sites", total_counter_dict ) - return { + # Extract data needed for return before cleanup + result = { "record_id": record.id, "tmp_feature_output_file": tmp_feature_output_file, "tmp_aggregate_output_file": tmp_aggregate_output_file, - "chimaera_aggregate_counters": record_ctx.chimaera_aggregate_counters, - "longest_isoform_aggregate_counters": record_ctx.longest_isoform_aggregate_counters, - "all_isoforms_aggregate_counters": record_ctx.all_isoforms_aggregate_counters, - "chimaera_aggregate_counters_by_parent_type": record_ctx.chimaera_aggregate_counters_by_parent_type, - "longest_isoform_aggregate_counters_by_parent_type": record_ctx.longest_isoform_aggregate_counters_by_parent_type, - "all_isoforms_aggregate_counters_by_parent_type": record_ctx.all_isoforms_aggregate_counters_by_parent_type, + "chimaera_aggregate_counters": record_ctx.chimaera_aggregate_counters.copy(), + "longest_isoform_aggregate_counters": record_ctx.longest_isoform_aggregate_counters.copy(), + "all_isoforms_aggregate_counters": 
record_ctx.all_isoforms_aggregate_counters.copy(), + "chimaera_aggregate_counters_by_parent_type": record_ctx.chimaera_aggregate_counters_by_parent_type.copy(), + "longest_isoform_aggregate_counters_by_parent_type": record_ctx.longest_isoform_aggregate_counters_by_parent_type.copy(), + "all_isoforms_aggregate_counters_by_parent_type": record_ctx.all_isoforms_aggregate_counters_by_parent_type.copy(), "total_counter": record_ctx.total_counter, } + + # Clean up record context to free memory immediately + record_ctx.cleanup() + + # Clear record features to help garbage collection + if hasattr(record, 'features'): + record.features.clear() + del record + + return result + + +def process_and_write_record_data( + record_data: dict[str, Any], + feature_output_handle: TextIO, + aggregate_output_handle: TextIO, + genome_longest_isoform_aggregate_counters: defaultdict[str, MultiCounter], + genome_all_isoforms_aggregate_counters: defaultdict[str, MultiCounter], + genome_chimaera_aggregate_counters: defaultdict[str, MultiCounter], + genome_longest_isoform_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter], + genome_all_isoforms_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter], + genome_chimaera_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter], + genome_total_counter: MultiCounter, +) -> None: + """ + Process a single record's data: write its output and merge counters into genome totals. + This function writes immediately to output files and only keeps genome-level totals in memory. 
+ """ + logging.info(f"Record {record_data['record_id']} · Writing output and merging totals...") + + # Write temporary files to final output immediately + with open(record_data["tmp_feature_output_file"]) as tmp_output_handle: + feature_output_handle.write(tmp_output_handle.read()) + remove(record_data["tmp_feature_output_file"]) + + with open(record_data["tmp_aggregate_output_file"]) as tmp_output_handle: + aggregate_output_handle.write(tmp_output_handle.read()) + remove(record_data["tmp_aggregate_output_file"]) + + # Update the genome's aggregate counters from the record data aggregate counters + for record_aggregate_type, record_aggregate_counter in record_data[ + "longest_isoform_aggregate_counters" + ].items(): + genome_longest_isoform_aggregate_counters[record_aggregate_type].merge( + record_aggregate_counter + ) + + for record_aggregate_type, record_aggregate_counter in record_data[ + "chimaera_aggregate_counters" + ].items(): + genome_chimaera_aggregate_counters[record_aggregate_type].merge( + record_aggregate_counter + ) + + merge_aggregation_counter_dicts( + genome_all_isoforms_aggregate_counters, + record_data["all_isoforms_aggregate_counters"], + ) + + # Update the genome's by-parent-type aggregate counters + for (parent_type, aggregate_type), record_aggregate_counter in record_data[ + "longest_isoform_aggregate_counters_by_parent_type" + ].items(): + genome_longest_isoform_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( + record_aggregate_counter + ) + + for (parent_type, aggregate_type), record_aggregate_counter in record_data[ + "all_isoforms_aggregate_counters_by_parent_type" + ].items(): + genome_all_isoforms_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( + record_aggregate_counter + ) + + for (parent_type, aggregate_type), record_aggregate_counter in record_data[ + "chimaera_aggregate_counters_by_parent_type" + ].items(): + genome_chimaera_aggregate_counters_by_parent_type[(parent_type, 
aggregate_type)].merge( + record_aggregate_counter + ) + + # Update the genome's total counter from the record data total counter + genome_total_counter.merge(record_data["total_counter"]) + + # Clear record_data counters to free memory immediately after merging + record_data["chimaera_aggregate_counters"].clear() + record_data["longest_isoform_aggregate_counters"].clear() + record_data["all_isoforms_aggregate_counters"].clear() + record_data["chimaera_aggregate_counters_by_parent_type"].clear() + record_data["longest_isoform_aggregate_counters_by_parent_type"].clear() + record_data["all_isoforms_aggregate_counters_by_parent_type"].clear() + + return None def main(): @@ -882,24 +1010,26 @@ def main(): args.output + "_aggregates.tsv" if args.output else "aggregates.tsv" ) + # Determine reader factory based on format + match args.format: + case "reditools2": + reader_factory = Reditools2Reader + case "reditools3": + reader_factory = Reditools3Reader + case "jacusa2": + reader_factory = Jacusa2Reader + case _: + raise Exception(f'Unimplemented format "{args.format}"') + with ( open(args.gff) as gff_handle, open(feature_output_filename, "w") as feature_output_handle, open(aggregate_output_filename, "w") as aggregate_output_handle, ): - match args.format: - case "reditools2": - reader_factory = Reditools2Reader - case "reditools3": - reader_factory = Reditools3Reader - case "jacusa2": - reader_factory = Jacusa2Reader - case _: - raise Exception(f'Unimplemented format "{args.format}"') - logging.info("Parsing GFF3 file...") records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle) + # Initialize genome-level counters (only these will be kept in memory) genome_filter: SiteFilter = SiteFilter( cov_threshold=args.cov, edit_threshold=args.edit_threshold ) @@ -926,85 +1056,71 @@ def main(): lambda: MultiCounter(genome_filter) ) + # Write headers once at the beginning feature_writer: FeatureFileWriter = FeatureFileWriter(feature_output_handle) aggregate_writer: 
AggregateFileWriter = AggregateFileWriter(aggregate_output_handle) - feature_writer.write_header() aggregate_writer.write_header() - with multiprocessing.Pool(processes=args.threads, initializer=init_worker, initargs=(args, reader_factory)) as pool: - # Run the jobs and save the result of each record in a dict stored in the `record_data` list - record_data_list: list[dict[str, Any]] = pool.map(run_job, records) - - # Sort record results in "natural" order - record_data_list = natsorted(record_data_list, key=lambda x: x["record_id"]) - - with ( - open(feature_output_filename, "a") as feature_output_handle, - open(aggregate_output_filename, "a") as aggregate_output_handle, - ): - for record_data in record_data_list: - logging.info(f"Record {record_data['record_id']} · Merging temporary output files...") - - with open(record_data["tmp_feature_output_file"]) as tmp_output_handle: - feature_output_handle.write(tmp_output_handle.read()) - remove(record_data["tmp_feature_output_file"]) - - with open(record_data["tmp_aggregate_output_file"]) as tmp_output_handle: - aggregate_output_handle.write(tmp_output_handle.read()) - remove(record_data["tmp_aggregate_output_file"]) - - # Update the genome's aggregate counters from the record data aggregate counters - for record_aggregate_type, record_aggregate_counter in record_data[ - "longest_isoform_aggregate_counters" - ].items(): - genome_aggregate_counter: MultiCounter = genome_longest_isoform_aggregate_counters[ - record_aggregate_type - ] - genome_aggregate_counter.merge(record_aggregate_counter) - - for record_aggregate_type, record_aggregate_counter in record_data[ - "chimaera_aggregate_counters" - ].items(): - genome_aggregate_counter: MultiCounter = genome_chimaera_aggregate_counters[ - record_aggregate_type - ] - genome_aggregate_counter.merge(record_aggregate_counter) - - merge_aggregation_counter_dicts( - genome_all_isoforms_aggregate_counters, - record_data["all_isoforms_aggregate_counters"], - ) - - # Update the 
genome's by-parent-type aggregate counters - for (parent_type, aggregate_type), record_aggregate_counter in record_data[ - "longest_isoform_aggregate_counters_by_parent_type" - ].items(): - genome_longest_isoform_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( - record_aggregate_counter - ) - - for (parent_type, aggregate_type), record_aggregate_counter in record_data[ - "all_isoforms_aggregate_counters_by_parent_type" - ].items(): - genome_all_isoforms_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( - record_aggregate_counter - ) - - for (parent_type, aggregate_type), record_aggregate_counter in record_data[ - "chimaera_aggregate_counters_by_parent_type" - ].items(): - genome_chimaera_aggregate_counters_by_parent_type[(parent_type, aggregate_type)].merge( - record_aggregate_counter + # Process records and write output immediately + if args.threads > 1: + # Parallel processing: use imap (ordered) to maintain chromosome order + # imap returns results in the same order as input, allowing immediate writing + logging.info(f"Processing records with {args.threads} threads...") + with multiprocessing.Pool(processes=args.threads, initializer=init_worker, initargs=(args, reader_factory)) as pool: + # imap (not imap_unordered) processes records in order + # Results are yielded as soon as they're ready, maintaining input order + record_data_iterator = pool.imap(run_job, records, chunksize=1) + + # Write each result immediately as it comes (no accumulation!) 
+ for record_data in record_data_iterator: + process_and_write_record_data( + record_data, + feature_output_handle, + aggregate_output_handle, + genome_longest_isoform_aggregate_counters, + genome_all_isoforms_aggregate_counters, + genome_chimaera_aggregate_counters, + genome_longest_isoform_aggregate_counters_by_parent_type, + genome_all_isoforms_aggregate_counters_by_parent_type, + genome_chimaera_aggregate_counters_by_parent_type, + genome_total_counter, + ) + # Force flush to disk after each chromosome + feature_output_handle.flush() + aggregate_output_handle.flush() + + else: + # Sequential processing: write immediately without storing results + logging.info("Processing records sequentially...") + for record in records: + record_data = run_job(record) + + # Write output and merge totals immediately (no accumulation) + process_and_write_record_data( + record_data, + feature_output_handle, + aggregate_output_handle, + genome_longest_isoform_aggregate_counters, + genome_all_isoforms_aggregate_counters, + genome_chimaera_aggregate_counters, + genome_longest_isoform_aggregate_counters_by_parent_type, + genome_all_isoforms_aggregate_counters_by_parent_type, + genome_chimaera_aggregate_counters_by_parent_type, + genome_total_counter, ) + + # Force flush to disk after each chromosome + feature_output_handle.flush() + aggregate_output_handle.flush() + + # Explicitly delete the record to free memory + del record + del record_data - # Update the genome's total counter from the record data total counter - genome_total_counter.merge(record_data["total_counter"]) - + # Write genome-level totals at the end (only genome totals are kept in memory) logging.info("Writing genome totals...") - aggregate_writer: AggregateFileWriter = AggregateFileWriter(aggregate_output_handle) - # Write genomic counts aggregate_writer.write_rows_with_data( ".", ["."], ".", ".", "longest_isoform", genome_longest_isoform_aggregate_counters From 11b5ffc412d873efb45b1ccd99ee548ba5a7964d Mon Sep 17 
00:00:00 2001 From: Jacques Dainat Date: Mon, 16 Mar 2026 11:51:02 +0100 Subject: [PATCH 02/61] replace BCBio by parse_gff3_streaming to process line by and get rid of useless attributes. No SQLite --- bin/pluviometer/__main__.py | 90 ++++++++++++++++++++++++++++++++++--- 1 file changed, 84 insertions(+), 6 deletions(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 9dac558..6077225 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -6,8 +6,8 @@ from collections import deque, defaultdict from dataclasses import dataclass, field from .multi_counter import MultiCounter -from Bio.SeqRecord import SeqRecord from .site_filter import SiteFilter +from types import SimpleNamespace from .rna_site_variant_readers import ( RNASiteVariantReader, Reditools2Reader, @@ -17,7 +17,6 @@ from .utils import RNASiteVariantData from natsort import natsorted import multiprocessing -from BCBio import GFF from os import remove import progressbar import tempfile @@ -183,7 +182,7 @@ def update_aggregate_counters( return None - def set_record(self, record: SeqRecord) -> None: + def set_record(self, record: Any) -> None: """ Assign a specific record to the RecordCountingContext. If the context was already assigned to another record, the queues and counters are cleared before proceeding (this was used in a previous version that recycled the record counting context, but it isn't used since parallelization was implemented). 
@@ -204,7 +203,7 @@ def set_record(self, record: SeqRecord) -> None: self.action_queue.clear() self.counters.clear() - self.record: SeqRecord = record + self.record: Any = record # Map positions to activation feature and deactivation actions logging.info("Pre-processing features") @@ -798,7 +797,7 @@ def init_worker(args_value, reader_factory_value): reader_factory = reader_factory_value -def run_job(record: SeqRecord) -> dict[str, Any]: +def run_job(record: Any) -> dict[str, Any]: """ A wrapper function for performing counting parallelized by record. The return value is a dict containing all the information needed for integrating the output of all records after the computations are finished. @@ -995,7 +994,86 @@ def process_and_write_record_data( return None +def parse_gff3_streaming(file_handle: TextIO) -> Generator[Any, None, None]: + """ + Memory-efficient streaming GFF3 parser. Processes one chromosome at a time, + loading only the current chromosome's features into memory. + + Replaces BCBio GFF.parse() to eliminate SQLite overhead and per-feature qualifier + storage. Only retains the minimal data needed: ID, type, location, and parent-child + relationships. + + Assumptions: + - The GFF3 file is grouped by chromosome (all features of one chromosome are contiguous). + - Parents appear before their children within each group (normalized GFF3). + """ + current_seqid: Optional[str] = None + features_by_id: dict[str, SeqFeature] = {} + top_level: list[SeqFeature] = [] + + for raw_line in file_handle: + # Skip directive lines (##...) and comment lines (#...) 
+ if raw_line[0] == '#' or raw_line == '\n': + continue + + cols = raw_line.rstrip('\n').split('\t') + if len(cols) < 9: + continue + + seqid = cols[0] + ftype = cols[2] + strand_ch = cols[6] + attrs = cols[8] + + # Extract only ID and Parent from the attributes column + feature_id: Optional[str] = None + parent_ids: list[str] = [] + for attr in attrs.split(';'): + if attr.startswith('ID='): + feature_id = attr[3:].strip() + elif attr.startswith('Parent='): + parent_ids = attr[7:].strip().split(',') + + if feature_id is None: + continue # Features without ID cannot be referenced as parents — skip + + # New chromosome: yield the completed previous record + if seqid != current_seqid: + if current_seqid is not None and top_level: + yield SimpleNamespace(id=current_seqid, features=top_level) + features_by_id = {} + top_level = [] + current_seqid = seqid + + strand_int: int = 1 if strand_ch == '+' else (-1 if strand_ch == '-' else 0) + location = SimpleLocation(int(cols[3]) - 1, int(cols[4]), strand=strand_int) + + feature = SeqFeature( + location=location, + id=feature_id, + type=ftype, + qualifiers={'ID': [feature_id]}, # Only the ID qualifier is needed + ) + feature.sub_features = [] + features_by_id[feature_id] = feature + if not parent_ids: + top_level.append(feature) + else: + for pid in parent_ids: + parent = features_by_id.get(pid) + if parent is not None: + parent.sub_features.append(feature) + else: + logging.warning( + f"GFF3 parser: parent '{pid}' not found for '{feature_id}' " + f"— file may not be normalized (parents before children). Skipping feature." 
+ ) + + # Yield the final chromosome + if current_seqid is not None and top_level: + yield SimpleNamespace(id=current_seqid, features=top_level) + def main(): global args global reader_factory @@ -1027,7 +1105,7 @@ def main(): open(aggregate_output_filename, "w") as aggregate_output_handle, ): logging.info("Parsing GFF3 file...") - records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle) + records: Generator[Any, None, None] = parse_gff3_streaming(gff_handle) # Initialize genome-level counters (only these will be kept in memory) genome_filter: SiteFilter = SiteFilter( From 4c2697599bcd2ce833e0f6b8f1e0e963dc0461e8 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 16 Mar 2026 11:53:14 +0100 Subject: [PATCH 03/61] fix bug when paired-end where strand is *. We count alignment both strands --- bin/pluviometer/__main__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 6077225..9f627a5 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -640,7 +640,7 @@ def update_active_counters(self, site_data: RNASiteVariantData) -> None: A new multicounter is created if no matching ID is found. """ for feature_key, feature in self.active_features.items(): - if feature.location.strand == site_data.strand: + if site_data.strand == 0 or feature.location.strand == 0 or feature.location.strand == site_data.strand: counter: MultiCounter = self.counters[feature_key] counter.update(site_data) From 2f6d435dceb5539ac805eb419e95627438b9cba4 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 16 Mar 2026 12:02:23 +0100 Subject: [PATCH 04/61] Revert "replace BCBio by parse_gff3_streaming to process line by and get rid of useless attributes. No SQLite" This reverts commit 11b5ffc412d873efb45b1ccd99ee548ba5a7964d. 
--- bin/pluviometer/__main__.py | 90 +++---------------------------------- 1 file changed, 6 insertions(+), 84 deletions(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 9f627a5..87ef876 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -6,8 +6,8 @@ from collections import deque, defaultdict from dataclasses import dataclass, field from .multi_counter import MultiCounter +from Bio.SeqRecord import SeqRecord from .site_filter import SiteFilter -from types import SimpleNamespace from .rna_site_variant_readers import ( RNASiteVariantReader, Reditools2Reader, @@ -17,6 +17,7 @@ from .utils import RNASiteVariantData from natsort import natsorted import multiprocessing +from BCBio import GFF from os import remove import progressbar import tempfile @@ -182,7 +183,7 @@ def update_aggregate_counters( return None - def set_record(self, record: Any) -> None: + def set_record(self, record: SeqRecord) -> None: """ Assign a specific record to the RecordCountingContext. If the context was already assigned to another record, the queues and counters are cleared before proceeding (this was used in a previous version that recycled the record counting context, but it isn't used since parallelization was implemented). @@ -203,7 +204,7 @@ def set_record(self, record: Any) -> None: self.action_queue.clear() self.counters.clear() - self.record: Any = record + self.record: SeqRecord = record # Map positions to activation feature and deactivation actions logging.info("Pre-processing features") @@ -797,7 +798,7 @@ def init_worker(args_value, reader_factory_value): reader_factory = reader_factory_value -def run_job(record: Any) -> dict[str, Any]: +def run_job(record: SeqRecord) -> dict[str, Any]: """ A wrapper function for performing counting parallelized by record. The return value is a dict containing all the information needed for integrating the output of all records after the computations are finished. 
@@ -994,86 +995,7 @@ def process_and_write_record_data( return None -def parse_gff3_streaming(file_handle: TextIO) -> Generator[Any, None, None]: - """ - Memory-efficient streaming GFF3 parser. Processes one chromosome at a time, - loading only the current chromosome's features into memory. - - Replaces BCBio GFF.parse() to eliminate SQLite overhead and per-feature qualifier - storage. Only retains the minimal data needed: ID, type, location, and parent-child - relationships. - - Assumptions: - - The GFF3 file is grouped by chromosome (all features of one chromosome are contiguous). - - Parents appear before their children within each group (normalized GFF3). - """ - current_seqid: Optional[str] = None - features_by_id: dict[str, SeqFeature] = {} - top_level: list[SeqFeature] = [] - - for raw_line in file_handle: - # Skip directive lines (##...) and comment lines (#...) - if raw_line[0] == '#' or raw_line == '\n': - continue - - cols = raw_line.rstrip('\n').split('\t') - if len(cols) < 9: - continue - - seqid = cols[0] - ftype = cols[2] - strand_ch = cols[6] - attrs = cols[8] - - # Extract only ID and Parent from the attributes column - feature_id: Optional[str] = None - parent_ids: list[str] = [] - for attr in attrs.split(';'): - if attr.startswith('ID='): - feature_id = attr[3:].strip() - elif attr.startswith('Parent='): - parent_ids = attr[7:].strip().split(',') - - if feature_id is None: - continue # Features without ID cannot be referenced as parents — skip - - # New chromosome: yield the completed previous record - if seqid != current_seqid: - if current_seqid is not None and top_level: - yield SimpleNamespace(id=current_seqid, features=top_level) - features_by_id = {} - top_level = [] - current_seqid = seqid - - strand_int: int = 1 if strand_ch == '+' else (-1 if strand_ch == '-' else 0) - location = SimpleLocation(int(cols[3]) - 1, int(cols[4]), strand=strand_int) - - feature = SeqFeature( - location=location, - id=feature_id, - type=ftype, - 
qualifiers={'ID': [feature_id]}, # Only the ID qualifier is needed - ) - feature.sub_features = [] - features_by_id[feature_id] = feature - if not parent_ids: - top_level.append(feature) - else: - for pid in parent_ids: - parent = features_by_id.get(pid) - if parent is not None: - parent.sub_features.append(feature) - else: - logging.warning( - f"GFF3 parser: parent '{pid}' not found for '{feature_id}' " - f"— file may not be normalized (parents before children). Skipping feature." - ) - - # Yield the final chromosome - if current_seqid is not None and top_level: - yield SimpleNamespace(id=current_seqid, features=top_level) - def main(): global args global reader_factory @@ -1105,7 +1027,7 @@ def main(): open(aggregate_output_filename, "w") as aggregate_output_handle, ): logging.info("Parsing GFF3 file...") - records: Generator[Any, None, None] = parse_gff3_streaming(gff_handle) + records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle) # Initialize genome-level counters (only these will be kept in memory) genome_filter: SiteFilter = SiteFilter( From facf65edec66e503ebb6b6fb40b69507c9b3ebf7 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 16 Mar 2026 12:20:07 +0100 Subject: [PATCH 05/61] add gc and possibility to filter featurs --- bin/pluviometer/__main__.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 87ef876..6c94572 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -24,6 +24,7 @@ import argparse import logging import math +import gc logger = logging.getLogger(__name__) @@ -770,6 +771,12 @@ def parse_cli_input() -> argparse.Namespace: parser.add_argument( "--progress", action="store_true", default=False, help="Display progress bar" ) + parser.add_argument( + "--gff_feature_types", + type=str, + default=None, + help="Comma-separated list of GFF feature types to load (e.g., 'gene,mRNA,exon,CDS'). 
Reduces memory usage by filtering during parsing. Leave empty to load all types.", + ) return parser.parse_args() @@ -1027,7 +1034,16 @@ def main(): open(aggregate_output_filename, "w") as aggregate_output_handle, ): logging.info("Parsing GFF3 file...") - records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle) + + # Configure BCBio GFF parser with memory-efficient options + limit_info = None + if args.gff_feature_types: + feature_types = [t.strip() for t in args.gff_feature_types.split(',')] + limit_info = {"gff_type": feature_types} + logging.info(f"Filtering GFF to load only these feature types: {feature_types}") + logging.info("This will reduce memory usage significantly!") + + records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle, limit_info=limit_info) # Initialize genome-level counters (only these will be kept in memory) genome_filter: SiteFilter = SiteFilter( @@ -1089,6 +1105,9 @@ def main(): # Force flush to disk after each chromosome feature_output_handle.flush() aggregate_output_handle.flush() + + # Force garbage collection to free memory immediately + gc.collect() else: # Sequential processing: write immediately without storing results @@ -1117,6 +1136,9 @@ def main(): # Explicitly delete the record to free memory del record del record_data + + # Force garbage collection to free memory immediately + gc.collect() # Write genome-level totals at the end (only genome totals are kept in memory) logging.info("Writing genome totals...") From 85446f4348119929eead203dbe8201ff5e6f098b Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 16 Mar 2026 15:23:15 +0100 Subject: [PATCH 06/61] space to re-run --- modules/pluviometer.nf | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index efa5a5b..a821950 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -16,6 +16,7 @@ process pluviometer { script: base_name = site_edits.BaseName """ + pluviometer_wrapper.py \ 
--sites ${site_edits} \ --gff ${gff} \ @@ -24,6 +25,6 @@ process pluviometer { --edit_threshold ${params.edit_threshold} \ --threads ${task.cpus} \ --aggregation_mode ${params.aggregation_mode} \ - --output "${meta.uid}_${tool_format}" + --output "${meta.uid}_${tool_format}" """ -} \ No newline at end of file +} From 0825ddc81a4fb6195ad59396dacb9f347291ca06 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 16 Mar 2026 15:39:05 +0100 Subject: [PATCH 07/61] handle NA correctly --- bin/drip.py | 72 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 50 insertions(+), 22 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index 04cebac..09dd6d0 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -1,6 +1,7 @@ #!/usr/bin/env python3 import pandas as pd +import numpy as np import sys from pathlib import Path @@ -10,7 +11,7 @@ def print_help(): DRIP - RNA Editing Analysis Tool DESCRIPTION: - This script analyzes RNA editing from standardized puviometer files. It calculates + This script analyzes RNA editing from standardized pluviometer files. It calculates two key metrics for all 16 genome-variant base pair combinations across multiple samples and combines them into a unified matrix format. @@ -24,8 +25,18 @@ def print_help(): Input file path, group name, sample name, and replicate ID separated by colons. All four components are required. --with-file-id Include file ID in column names (default: omit file ID) + --min-cov N Minimum read coverage threshold (default: 1). Positions with a + denominator (genome base count for espf, read count for espr) + strictly below this value are reported as NA instead of 0. 
--help, -h Display this help message +NA BEHAVIOR: +| Source du NA | Mécanisme | Résultat +| Couverture = 0 dans sampleA | np.where(mask, ..., np.nan) | NA ✅ +| Couverture < min_cov dans sampleA | même masque | NA ✅ +| Ligne absente de sampleB | how='outer' → NaN | NA ✅ +| Couverture OK, 0 éditions | ratio = 0.0 | 0.0 ✅ + INPUT FILE FORMAT: The input files must be TSV files with the following columns: - GenomeBases: Frequencies of bases in the reference genome (order: A, C, G, T) @@ -115,7 +126,7 @@ def print_help(): print(help_text) sys.exit(0) -def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id=False): +def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id=False, min_cov=1): """Parse a single TSV file and extract editing metrics for all base pair combinations.""" df = pd.read_csv(filepath, sep='\t') @@ -152,16 +163,15 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ genome_base = bp[0] # First letter is the genome base # Calculate espf: XY_sites / X_count + # NA when genome base count < min_cov (position not covered or not in feature) espf_col = f'{col_prefix}::{bp}::espf' - df[espf_col] = df.apply( - lambda row: row[f'{bp}_sites'] / row[f'{genome_base}_count'] - if row[f'{genome_base}_count'] > 0 else 0, - axis=1 - ) + denom_espf = df[f'{genome_base}_count'] + mask_espf = denom_espf >= min_cov + df[espf_col] = np.where(mask_espf, df[f'{bp}_sites'] / denom_espf.where(mask_espf, 1), np.nan) result_cols.append(espf_col) # Calculate espr: XY_reads / (XA + XC + XG + XT) - # Calculate total reads for this genome base + # NA when total read coverage < min_cov (position not sequenced) total_reads_col = f'{genome_base}_total_reads' if total_reads_col not in df.columns: df[total_reads_col] = ( @@ -172,11 +182,9 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ ) espr_col = f'{col_prefix}::{bp}::espr' - df[espr_col] = df.apply( - 
lambda row: row[f'{bp}_reads'] / row[total_reads_col] - if row[total_reads_col] > 0 else 0, - axis=1 - ) + denom_espr = df[total_reads_col] + mask_espr = denom_espr >= min_cov + df[espr_col] = np.where(mask_espr, df[f'{bp}_reads'] / denom_espr.where(mask_espr, 1), np.nan) result_cols.append(espr_col) # Return dataframe with metadata and all metrics @@ -184,7 +192,7 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ return result -def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False): +def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1): """Merge data from multiple samples and create output matrices - one file per base pair combination.""" all_data = [] @@ -202,7 +210,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ filename_stem = Path(filepath).stem file_id = '_'.join(filename_stem.split('_')[:-1]) print(f"Processing {group_name}::{sample_name} (replicate {replicate}) from {filepath} (file_id: {file_id})...") - data = parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id) + data = parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id, min_cov) all_data.append(data) file_id_list.append(file_id) group_name_list.append(group_name) @@ -215,8 +223,8 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ for data in all_data[1:]: merged = merged.merge(data, on=metadata_cols, how='outer') - # Fill NA values with 0 for metrics - merged = merged.fillna(0) + # Do NOT fill NA with 0: NA means the position was not covered (or below min_cov), + # which is distinct from 0 (position covered but no editing observed) # Sort by SeqID, then ParentIDs, then Mode merged = merged.sort_values(['SeqID', 'ParentIDs', 'Mode']) @@ -256,7 +264,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ # Save to file 
output_file = f"{output_prefix}_{bp}.tsv" - bp_result.to_csv(output_file, sep='\t', index=False) + bp_result.to_csv(output_file, sep='\t', index=False, na_rep='NA') output_files.append(output_file) print(f"\nOutput files created:") @@ -286,13 +294,33 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ file_group_sample_replicate_dict = {} group_sample_rep_counts = {} include_file_id = False # Default: omit file_id from column names - - for arg in sys.argv[2:]: + min_cov = 1 # Default: NA when denominator is 0; treat 0 coverage as non-observed + + args_iter = iter(range(2, len(sys.argv))) + for i in args_iter: + arg = sys.argv[i] # Check for --with-file-id flag if arg == '--with-file-id': include_file_id = True continue - + + # Check for --min-cov flag (supports both --min-cov=N and --min-cov N) + if arg.startswith('--min-cov'): + if '=' in arg: + try: + min_cov = int(arg.split('=', 1)[1]) + except ValueError: + print(f"ERROR: --min-cov requires an integer value", file=sys.stderr) + sys.exit(1) + else: + try: + next_i = next(args_iter) + min_cov = int(sys.argv[next_i]) + except (StopIteration, ValueError): + print(f"ERROR: --min-cov requires an integer value", file=sys.stderr) + sys.exit(1) + continue + if ':' not in arg: print(f"ERROR: Invalid argument format '{arg}'", file=sys.stderr) print("Expected format: FILE:GROUP:SAMPLE:REPLICATE", file=sys.stderr) @@ -326,6 +354,6 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ file_group_sample_replicate_dict[filepath] = (group_name, sample_name, replicate) # Process all samples - result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id) + result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov) print("\nAnalysis complete!") \ No newline at end of file From 6db54c84961d42cb7b4f3c99fe44dcd668ed910d Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 16 Mar 2026 15:46:50 +0100 
Subject: [PATCH 08/61] centralize env for container building --- .../pluviometer => }/env_pluviometer.yml | 0 .../pluviometer/env_pluviometer.yml | 57 ------------------- .../singularity/pluviometer/pluviometer.def | 2 +- 3 files changed, 1 insertion(+), 58 deletions(-) rename containers/{docker/pluviometer => }/env_pluviometer.yml (100%) delete mode 100644 containers/singularity/pluviometer/env_pluviometer.yml diff --git a/containers/docker/pluviometer/env_pluviometer.yml b/containers/env_pluviometer.yml similarity index 100% rename from containers/docker/pluviometer/env_pluviometer.yml rename to containers/env_pluviometer.yml diff --git a/containers/singularity/pluviometer/env_pluviometer.yml b/containers/singularity/pluviometer/env_pluviometer.yml deleted file mode 100644 index 47b0cb0..0000000 --- a/containers/singularity/pluviometer/env_pluviometer.yml +++ /dev/null @@ -1,57 +0,0 @@ -name: rain-stats-scripts -channels: - - bioconda - - conda-forge -dependencies: - - _libgcc_mutex=0.1=conda_forge - - _openmp_mutex=4.5=2_gnu - - bcbio-gff=0.7.1=pyh7e72e81_2 - - biopython=1.85=py312h66e93f0_1 - - bx-python=0.13.0=py312h5e9d817_1 - - bzip2=1.0.8=h4bc722e_7 - - ca-certificates=2025.7.14=hbd8a1cb_0 - - frozendict=2.4.6=py312h66e93f0_0 - - ld_impl_linux-64=2.43=h712a8e2_4 - - libblas=3.9.0=31_h59b9bed_openblas - - libcblas=3.9.0=31_he106b2a_openblas - - libexpat=2.7.0=h5888daf_0 - - libffi=3.4.6=h2dba641_1 - - libgcc=14.2.0=h767d61c_2 - - libgcc-ng=14.2.0=h69a702a_2 - - libgfortran=14.2.0=h69a702a_2 - - libgfortran5=14.2.0=hf1ad2bd_2 - - libgomp=14.2.0=h767d61c_2 - - liblapack=3.9.0=31_h7ac8fdf_openblas - - liblzma=5.8.1=hb9d3cd8_0 - - libnsl=2.0.1=hd590300_0 - - libopenblas=0.3.29=pthreads_h94d23a6_0 - - libsqlite=3.49.1=hee588c1_2 - - libstdcxx=14.2.0=h8f9b012_2 - - libstdcxx-ng=14.2.0=h4852527_2 - - libunwind=1.6.2=h9c3ff4c_0 - - libuuid=2.38.1=h0b41bf4_0 - - libxcrypt=4.4.36=hd590300_1 - - libzlib=1.3.1=hb9d3cd8_2 - - natsort=8.4.0=pyh29332c3_1 - - 
ncurses=6.5=h2d0b736_3 - - numpy=2.2.4=py312h72c5963_0 - - openssl=3.5.1=h7b32b05_0 - - pip=25.0.1=pyh8b19718_0 - - progressbar2=4.5.0=pyhd8ed1ab_1 - - py-spy=0.4.0=h4c5a871_1 - - pyparsing=3.2.3=pyhd8ed1ab_1 - - python=3.12.10=h9e4cc4f_0_cpython - - python-utils=3.9.1=pyhff2d567_1 - - python_abi=3.12=6_cp312 - - readline=8.2=h8c095d6_2 - - setuptools=78.1.0=pyhff2d567_0 - - six=1.17.0=pyhd8ed1ab_0 - - sortedcontainers=2.4.0=pyhd8ed1ab_1 - - tk=8.6.13=noxft_h4845f30_101 - - typing_extensions=4.13.2=pyh29332c3_0 - - tzdata=2025b=h78e105d_0 - - wheel=0.45.1=pyhd8ed1ab_1 - - pysam=0.23.3=py312h47d5410_1 - - pip: - - flameprof==0.4 -prefix: /home/earx/miniforge3/envs/rain-stats-scripts diff --git a/containers/singularity/pluviometer/pluviometer.def b/containers/singularity/pluviometer/pluviometer.def index 9b4473a..bbcfa3d 100644 --- a/containers/singularity/pluviometer/pluviometer.def +++ b/containers/singularity/pluviometer/pluviometer.def @@ -2,7 +2,7 @@ Bootstrap: docker From: condaforge/miniforge3:24.7.1-2 %files - containers/singularity/pluviometer/env_pluviometer.yml /env_pluviometer.yml + containers/env_pluviometer.yml /env_pluviometer.yml %post conda env update --name base --file env_pluviometer.yml From dcde2d216a3d32f42fcbb820cdbc98de18f5f7c5 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 16 Mar 2026 15:48:54 +0100 Subject: [PATCH 09/61] centralize env for container building --- build_containers.sh | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/build_containers.sh b/build_containers.sh index 6a5ffed..ad289fd 100755 --- a/build_containers.sh +++ b/build_containers.sh @@ -129,7 +129,6 @@ if [ "$build_docker" = true ]; then image_list=( ) for dir in containers/docker/*; do - cd "${dir}" imgname=$(echo $dir | rev | cut -d/ -f1 | rev) image_list+=(${imgname}) @@ -146,10 +145,8 @@ if [ "$build_docker" = true ]; then fi fi - docker build ${docker_arch_option} -t ${imgname} . 
- - # Back to the original working directory - cd "$wd" + # Use containers/ as build context so shared files (e.g. env_*.yml) are accessible + docker build ${docker_arch_option} -f "${dir}/Dockerfile" -t ${imgname} containers/ done if [[ ${github_action_mode} == 'github_action' ]]; then From 37faa0d5141afc9d7673d5c33cda6ad8db99f72f Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 17 Mar 2026 15:32:31 +0100 Subject: [PATCH 10/61] use pickle to write to a temporary file instead of keeping data in memory, to decrease memory footprint --- bin/pluviometer/__main__.py | 77 +++++++++++++++++++++++-------------- 1 file changed, 49 insertions(+), 28 deletions(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 6c94572..380f677 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -23,6 +23,7 @@ import tempfile import argparse import logging +import pickle import math import gc @@ -916,8 +917,15 @@ def run_job(record: SeqRecord) -> dict[str, Any]: if hasattr(record, 'features'): record.features.clear() del record - - return result + + # Serialize result to a temp pickle file so the worker frees RAM immediately. + # Only the tiny file path is sent back through the IPC channel. + tmp_pickle_fd, tmp_pickle_path = tempfile.mkstemp(suffix='.pkl') + with open(tmp_pickle_fd, 'wb') as f: + pickle.dump(result, f) + del result + + return tmp_pickle_path def process_and_write_record_data( @@ -1080,34 +1088,47 @@ def main(): # Process records and write output immediately if args.threads > 1: - # Parallel processing: use imap (ordered) to maintain chromosome order - # imap returns results in the same order as input, allowing immediate writing + # Parallel processing: use imap_unordered so each worker's result is + # written to a pickle temp file as soon as it is ready, regardless of + # chromosome order. Only the tiny file paths cross the IPC channel, + # keeping the main process RAM near zero while workers are running. 
logging.info(f"Processing records with {args.threads} threads...") + pickle_paths: list[str] = [] with multiprocessing.Pool(processes=args.threads, initializer=init_worker, initargs=(args, reader_factory)) as pool: - # imap (not imap_unordered) processes records in order - # Results are yielded as soon as they're ready, maintaining input order - record_data_iterator = pool.imap(run_job, records, chunksize=1) - - # Write each result immediately as it comes (no accumulation!) - for record_data in record_data_iterator: - process_and_write_record_data( - record_data, - feature_output_handle, - aggregate_output_handle, - genome_longest_isoform_aggregate_counters, - genome_all_isoforms_aggregate_counters, - genome_chimaera_aggregate_counters, - genome_longest_isoform_aggregate_counters_by_parent_type, - genome_all_isoforms_aggregate_counters_by_parent_type, - genome_chimaera_aggregate_counters_by_parent_type, - genome_total_counter, - ) - # Force flush to disk after each chromosome - feature_output_handle.flush() - aggregate_output_handle.flush() - - # Force garbage collection to free memory immediately - gc.collect() + for pickle_path in pool.imap_unordered(run_job, records, chunksize=1): + pickle_paths.append(pickle_path) + + # Sort pickle files by their record_id (natural chromosome order) + def _record_id_from_pickle(path: str) -> str: + with open(path, 'rb') as _f: + return pickle.load(_f)['record_id'] + + pickle_paths = natsorted(pickle_paths, key=_record_id_from_pickle) + + # Process in natural order: unpickle → write → delete → GC + for pickle_path in pickle_paths: + with open(pickle_path, 'rb') as f: + record_data = pickle.load(f) + remove(pickle_path) + process_and_write_record_data( + record_data, + feature_output_handle, + aggregate_output_handle, + genome_longest_isoform_aggregate_counters, + genome_all_isoforms_aggregate_counters, + genome_chimaera_aggregate_counters, + genome_longest_isoform_aggregate_counters_by_parent_type, + 
genome_all_isoforms_aggregate_counters_by_parent_type, + genome_chimaera_aggregate_counters_by_parent_type, + genome_total_counter, + ) + del record_data + # Force flush to disk after each chromosome + feature_output_handle.flush() + aggregate_output_handle.flush() + + # Force garbage collection to free memory immediately + gc.collect() else: # Sequential processing: write immediately without storing results From 54af45e1e526ff909635089dfb975e6831bb1877 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 17 Mar 2026 16:26:58 +0100 Subject: [PATCH 11/61] use dict of SeqID to provide to BCBio GFF that will load only the sequence of interest by worker. Now the needed RAM is really low --- bin/pluviometer/__main__.py | 157 ++++++++++++++++++++++++++---------- 1 file changed, 116 insertions(+), 41 deletions(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 380f677..f41a42e 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -20,11 +20,13 @@ from BCBio import GFF from os import remove import progressbar +import subprocess import tempfile import argparse import logging import pickle import math +import io import gc logger = logging.getLogger(__name__) @@ -806,13 +808,79 @@ def init_worker(args_value, reader_factory_value): reader_factory = reader_factory_value -def run_job(record: SeqRecord) -> dict[str, Any]: +def _scan_gff_record_ids(gff_path: str) -> list[str]: + """Return a natsorted list of unique sequence IDs present in the GFF file. + + For bgzip+tabix files (.gz) uses ``tabix -l`` (fast, no decompression of + data lines). For plain-text files falls back to a lightweight linear scan. """ - A wrapper function for performing counting parallelized by record. The return value is a dict containing all the information needed for integrating - the output of all records after the computations are finished. 
+ if gff_path.endswith('.gz'): + result = subprocess.run( + ['tabix', '-l', gff_path], + capture_output=True, text=True, check=True, + ) + return natsorted(result.stdout.splitlines()) + + seen: set[str] = set() + ordered: list[str] = [] + with open(gff_path) as fh: + for line in fh: + if not line or line[0] == '#': + continue + fields = line.split('\t', 2) + if len(fields) < 2: + # Not a GFF data line (e.g. FASTA section, wrapped attribute line) + continue + seqid = fields[0] + if seqid not in seen: + seen.add(seqid) + ordered.append(seqid) + return natsorted(ordered) + + +def _open_gff_for_record( + gff_path: str, + record_id: str, + gff_limit: Optional[dict], +) -> Optional[SeqRecord]: + """Load a single chromosome from a GFF3 file and return its SeqRecord. + + Strategy differs by file type: + + * **bgzip+tabix (.gz)**: call ``tabix gff_path record_id`` to extract only + the relevant lines and feed them to BCBio via an ``io.StringIO`` buffer + (which is seekable, so BCBio can do its two-pass parent-child resolution). + We intentionally do *not* pass a ``gff_id`` filter here — the stream + already contains only the target chromosome, so there is nothing to + filter. The ``gff_id`` filter in BCBio can bypass the two-pass logic, + causing "parent not found" warnings for out-of-order features. + + * **Plain text**: use BCBio's ``gff_id`` filter on the open (seekable) file + handle so BCBio can seek back for its second pass when needed. """ - assert record.id # Stupid assertion for pylance - logging.info(f"Record {record.id} · Record parsed. 
Counting beings.") + if gff_path.endswith('.gz'): + proc = subprocess.run( + ['tabix', gff_path, record_id], + capture_output=True, text=True, check=True, + ) + gff_stream = io.StringIO('##gff-version 3\n' + proc.stdout) + return next(GFF.parse(gff_stream, limit_info=gff_limit), None) + + # Plain text: seekable file handle → BCBio two-pass works + limit: dict = {'gff_id': [record_id]} + if gff_limit: + limit.update(gff_limit) + with open(gff_path) as gff_handle: + return next(GFF.parse(gff_handle, limit_info=limit), None) + + +def _do_counting(record: SeqRecord) -> dict[str, Any]: + """ + Core counting logic for a single SeqRecord. Returns the result dict. + Called directly in sequential mode and via run_job in parallel mode. + """ + assert record.id # Placate pylance + logging.info(f"Record {record.id} · Counting begins.") tmp_feature_output_file: str = tempfile.mkstemp()[1] tmp_aggregate_output_file: str = tempfile.mkstemp()[1] @@ -909,17 +977,34 @@ def run_job(record: SeqRecord) -> dict[str, Any]: "all_isoforms_aggregate_counters_by_parent_type": record_ctx.all_isoforms_aggregate_counters_by_parent_type.copy(), "total_counter": record_ctx.total_counter, } - + # Clean up record context to free memory immediately record_ctx.cleanup() - - # Clear record features to help garbage collection - if hasattr(record, 'features'): - record.features.clear() + + return result + + +def run_job(record_id: str) -> str: + """ + Worker entry point for parallel processing. Loads only the chromosome matching + record_id from the GFF (minimal RAM footprint per worker), performs counting, + serializes the result dict to a pickle temp file and returns its path. + Only a tiny string crosses the IPC channel. 
+ """ + gff_limit: Optional[dict] = None + if args.gff_feature_types: + gff_limit = {'gff_type': [t.strip() for t in args.gff_feature_types.split(',')]} + + logging.info(f"Record {record_id} · Loading GFF features.") + record: Optional[SeqRecord] = _open_gff_for_record(args.gff, record_id, gff_limit) + + if record is None: + raise RuntimeError(f"Record '{record_id}' not found in {args.gff}") + + result = _do_counting(record) + record.features.clear() del record - # Serialize result to a temp pickle file so the worker frees RAM immediately. - # Only the tiny file path is sent back through the IPC channel. tmp_pickle_fd, tmp_pickle_path = tempfile.mkstemp(suffix='.pkl') with open(tmp_pickle_fd, 'wb') as f: pickle.dump(result, f) @@ -1037,22 +1122,9 @@ def main(): raise Exception(f'Unimplemented format "{args.format}"') with ( - open(args.gff) as gff_handle, open(feature_output_filename, "w") as feature_output_handle, open(aggregate_output_filename, "w") as aggregate_output_handle, ): - logging.info("Parsing GFF3 file...") - - # Configure BCBio GFF parser with memory-efficient options - limit_info = None - if args.gff_feature_types: - feature_types = [t.strip() for t in args.gff_feature_types.split(',')] - limit_info = {"gff_type": feature_types} - logging.info(f"Filtering GFF to load only these feature types: {feature_types}") - logging.info("This will reduce memory usage significantly!") - - records: Generator[SeqRecord, None, None] = GFF.parse(gff_handle, limit_info=limit_info) - # Initialize genome-level counters (only these will be kept in memory) genome_filter: SiteFilter = SiteFilter( cov_threshold=args.cov, edit_threshold=args.edit_threshold @@ -1092,10 +1164,11 @@ def main(): # written to a pickle temp file as soon as it is ready, regardless of # chromosome order. Only the tiny file paths cross the IPC channel, # keeping the main process RAM near zero while workers are running. 
- logging.info(f"Processing records with {args.threads} threads...") + record_ids: list[str] = _scan_gff_record_ids(args.gff) + logging.info(f"Processing {len(record_ids)} records with {args.threads} threads...") pickle_paths: list[str] = [] with multiprocessing.Pool(processes=args.threads, initializer=init_worker, initargs=(args, reader_factory)) as pool: - for pickle_path in pool.imap_unordered(run_job, records, chunksize=1): + for pickle_path in pool.imap_unordered(run_job, record_ids, chunksize=1): pickle_paths.append(pickle_path) # Sort pickle files by their record_id (natural chromosome order) @@ -1131,14 +1204,23 @@ def _record_id_from_pickle(path: str) -> str: gc.collect() else: - # Sequential processing: write immediately without storing results + # Sequential processing: load one chromosome at a time, write immediately. + # Uses the same _open_gff_for_record helper as the parallel mode so that + # both plain and tabix-indexed files are handled identically. logging.info("Processing records sequentially...") - for record in records: - record_data = run_job(record) - - # Write output and merge totals immediately (no accumulation) + gff_limit: Optional[dict] = None + if args.gff_feature_types: + gff_limit = {'gff_type': [t.strip() for t in args.gff_feature_types.split(',')]} + for record_id in _scan_gff_record_ids(args.gff): + record = _open_gff_for_record(args.gff, record_id, gff_limit) + if record is None: + logging.warning(f"No features found for record '{record_id}', skipping.") + continue + result = _do_counting(record) + record.features.clear() + del record process_and_write_record_data( - record_data, + result, feature_output_handle, aggregate_output_handle, genome_longest_isoform_aggregate_counters, @@ -1149,16 +1231,9 @@ def _record_id_from_pickle(path: str) -> str: genome_chimaera_aggregate_counters_by_parent_type, genome_total_counter, ) - - # Force flush to disk after each chromosome + del result feature_output_handle.flush() 
aggregate_output_handle.flush() - - # Explicitly delete the record to free memory - del record - del record_data - - # Force garbage collection to free memory immediately gc.collect() # Write genome-level totals at the end (only genome totals are kept in memory) From 2a57b5a7a038549a613738e055fc2f9c08104082 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 17 Mar 2026 16:28:07 +0100 Subject: [PATCH 12/61] change resources for pluviometer --- config/resources/hpc.config | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/config/resources/hpc.config b/config/resources/hpc.config index 2900223..40807e1 100644 --- a/config/resources/hpc.config +++ b/config/resources/hpc.config @@ -29,10 +29,8 @@ process { time = '4h' } withLabel: 'pluviometer' { - cpus = 4 + cpus = 6 time = '6h' - memory = '48GB' - errorStrategy = 'terminate' } withLabel: 'reditools3' { cpus = 6 From db67f96c53753579cb2d3e7d7ae2f4b099ecbd9e Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 17 Mar 2026 21:39:19 +0100 Subject: [PATCH 13/61] avoid DtypeWarning: Columns (0: SeqID, 1: Start, 2: End, 3: Strand) have mixed types. Specify dtype option on import or set low_memory=False. --- bin/drip.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/bin/drip.py b/bin/drip.py index 09dd6d0..b48a184 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -128,7 +128,10 @@ def print_help(): def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id=False, min_cov=1): """Parse a single TSV file and extract editing metrics for all base pair combinations.""" - df = pd.read_csv(filepath, sep='\t') + # SeqID, Start, End, Strand contain mixed values ("." and actual numbers/strings) + # → force them to string to avoid DtypeWarning and preserve "." as-is + mixed_cols = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} + df = pd.read_csv(filepath, sep='\t', dtype=mixed_cols) # DO NOT filter out rows where ID is '.'
# These are special aggregate rows (e.g., all_sites) that should be kept From 5c757708b2cad884f0c93ad6733b9792aa8e60d7 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 17 Mar 2026 21:48:00 +0100 Subject: [PATCH 14/61] decrease memory - process sequentially by base pair --- bin/drip.py | 97 ++++++++++++++++++++++------------------------------- 1 file changed, 41 insertions(+), 56 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index b48a184..2ebd0d7 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -3,6 +3,7 @@ import pandas as pd import numpy as np import sys +import gc from pathlib import Path def print_help(): @@ -190,56 +191,49 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ df[espr_col] = np.where(mask_espr, df[f'{bp}_reads'] / denom_espr.where(mask_espr, 1), np.nan) result_cols.append(espr_col) - # Return dataframe with metadata and all metrics - result = df[result_cols].copy() - + # Select only needed columns — list indexing creates a new DataFrame, + # so we can free the large intermediate df immediately. 
+ result = df[result_cols] + del df return result def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1): """Merge data from multiple samples and create output matrices - one file per base pair combination.""" - - all_data = [] - file_id_list = [] - group_name_list = [] - sample_name_list = [] - replicate_list = [] - - # Base pair combinations in order - base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', + + base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] - + metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] + + # Collect sample metadata without loading data yet + sample_info = [] # (filepath, group_name, sample_name, replicate, file_id) for filepath, (group_name, sample_name, replicate) in file_group_sample_replicate_dict.items(): - # Extract file ID (everything before the last underscore in filename) filename_stem = Path(filepath).stem file_id = '_'.join(filename_stem.split('_')[:-1]) + sample_info.append((filepath, group_name, sample_name, replicate, file_id)) + + # Incremental merge: load one sample at a time and free it immediately after merging. + # Peak RAM = merged_so_far + one_new_sample (instead of N samples simultaneously). 
+ merged = None + for filepath, group_name, sample_name, replicate, file_id in sample_info: print(f"Processing {group_name}::{sample_name} (replicate {replicate}) from {filepath} (file_id: {file_id})...") data = parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id, min_cov) - all_data.append(data) - file_id_list.append(file_id) - group_name_list.append(group_name) - sample_name_list.append(sample_name) - replicate_list.append(replicate) - - # Merge all samples based on metadata columns - metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - merged = all_data[0] - for data in all_data[1:]: - merged = merged.merge(data, on=metadata_cols, how='outer') - - # Do NOT fill NA with 0: NA means the position was not covered (or below min_cov), - # which is distinct from 0 (position covered but no editing observed) - - # Sort by SeqID, then ParentIDs, then Mode + if merged is None: + merged = data + else: + merged = merged.merge(data, on=metadata_cols, how='outer') + del data + gc.collect() + + # Do NOT fill NA with 0: NA means not covered / below min_cov, + # which is distinct from 0 (covered but no editing observed). merged = merged.sort_values(['SeqID', 'ParentIDs', 'Mode']) - - metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - - # Create one file per base pair combination + + # Write one file per base pair combination — no intermediate copy, rename in-place. 
output_files = [] for bp in base_pairs: - # Select columns for this base pair combination bp_cols = metadata_cols.copy() - for group_name, sample_name, replicate, file_id in zip(group_name_list, sample_name_list, replicate_list, file_id_list): + rename_dict = {} + for _, group_name, sample_name, replicate, file_id in sample_info: if include_file_id: col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' else: @@ -248,35 +242,26 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ espr_col = f'{col_prefix}::{bp}::espr' if espf_col in merged.columns: bp_cols.append(espf_col) + rename_dict[espf_col] = f'{col_prefix}::espf' if espr_col in merged.columns: bp_cols.append(espr_col) - - # Create result for this base pair - bp_result = merged[bp_cols].copy() - - # Rename columns to remove the bp suffix (cleaner output) - rename_dict = {} - for group_name, sample_name, replicate, file_id in zip(group_name_list, sample_name_list, replicate_list, file_id_list): - if include_file_id: - col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' - else: - col_prefix = f'{group_name}::{sample_name}::{replicate}' - rename_dict[f'{col_prefix}::{bp}::espf'] = f'{col_prefix}::espf' - rename_dict[f'{col_prefix}::{bp}::espr'] = f'{col_prefix}::espr' - bp_result = bp_result.rename(columns=rename_dict) - - # Save to file + rename_dict[espr_col] = f'{col_prefix}::espr' + output_file = f"{output_prefix}_{bp}.tsv" - bp_result.to_csv(output_file, sep='\t', index=False, na_rep='NA') + # Column selection already creates a new DataFrame; rename + write without extra copy. 
+ merged[bp_cols].rename(columns=rename_dict).to_csv( + output_file, sep='\t', index=False, na_rep='NA' + ) output_files.append(output_file) - + gc.collect() + print(f"\nOutput files created:") for output_file in output_files: print(f" - {output_file}") print(f" - {len(merged)} aggregates per file") - print(f" - {len(sample_name_list)} samples: {', '.join(sample_name_list)}") + print(f" - {len(sample_info)} samples: {', '.join(si[2] for si in sample_info)}") print(f" - {len(base_pairs)} files (one per base pair combination)") - + return merged # Example usage From 794c40b1c2bbc2c4e648d02d7cbf1008905e4f6d Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 17 Mar 2026 22:27:50 +0100 Subject: [PATCH 15/61] remove extension chimaera from AggregateType --- bin/pluviometer/SeqFeature_extensions.py | 2 +- modules/pluviometer.nf | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/bin/pluviometer/SeqFeature_extensions.py b/bin/pluviometer/SeqFeature_extensions.py index bb8abe4..6b0e9b0 100644 --- a/bin/pluviometer/SeqFeature_extensions.py +++ b/bin/pluviometer/SeqFeature_extensions.py @@ -67,7 +67,7 @@ def make_chimaeras(self: SeqFeature, record_id: str) -> None: chimaera: SeqFeature = SeqFeature( location=location, id=f"{self.id}-{key}-chimaera", - type=key+"-chimaera", + type=key, qualifiers={"Parent": self.id} ) diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index a821950..9f90f3e 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -26,5 +26,6 @@ process pluviometer { --threads ${task.cpus} \ --aggregation_mode ${params.aggregation_mode} \ --output "${meta.uid}_${tool_format}" + """ } From b5de7ebf1c4976a6d5046490e5742bf88add2204 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Wed, 18 Mar 2026 09:57:20 +0100 Subject: [PATCH 16/61] parallelize drip.py --- bin/drip.py | 96 +++++++++++++++++++++++++---------- config/resources/hpc.config | 8 ++- config/resources/local.config | 4 ++ modules/python.nf | 3 +- 4 files 
changed, 80 insertions(+), 31 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index 2ebd0d7..86cc9cf 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -2,6 +2,7 @@ import pandas as pd import numpy as np +import multiprocessing import sys import gc from pathlib import Path @@ -29,6 +30,9 @@ def print_help(): --min-cov N Minimum read coverage threshold (default: 1). Positions with a denominator (genome base count for espf, read count for espr) strictly below this value are reported as NA instead of 0. + --threads N, -t N + Number of parallel threads to use for writing output files + (default: 1, sequential). Max useful value is 16 (one per base pair). --help, -h Display this help message NA BEHAVIOR: @@ -197,7 +201,34 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ del df return result -def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1): +def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id): + """Worker function to write a single base pair combination file. + + This function is designed to be called in parallel for each base pair. 
+ """ + bp_cols = metadata_cols.copy() + rename_dict = {} + for _, group_name, sample_name, replicate, file_id in sample_info: + if include_file_id: + col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' + else: + col_prefix = f'{group_name}::{sample_name}::{replicate}' + espf_col = f'{col_prefix}::{bp}::espf' + espr_col = f'{col_prefix}::{bp}::espr' + if espf_col in merged.columns: + bp_cols.append(espf_col) + rename_dict[espf_col] = f'{col_prefix}::espf' + if espr_col in merged.columns: + bp_cols.append(espr_col) + rename_dict[espr_col] = f'{col_prefix}::espr' + + output_file = f"{output_prefix}_{bp}.tsv" + merged[bp_cols].rename(columns=rename_dict).to_csv( + output_file, sep='\t', index=False, na_rep='NA' + ) + return output_file + +def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1, threads=1): """Merge data from multiple samples and create output matrices - one file per base pair combination.""" base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', @@ -228,32 +259,20 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ # which is distinct from 0 (covered but no editing observed). merged = merged.sort_values(['SeqID', 'ParentIDs', 'Mode']) - # Write one file per base pair combination — no intermediate copy, rename in-place. 
- output_files = [] - for bp in base_pairs: - bp_cols = metadata_cols.copy() - rename_dict = {} - for _, group_name, sample_name, replicate, file_id in sample_info: - if include_file_id: - col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' - else: - col_prefix = f'{group_name}::{sample_name}::{replicate}' - espf_col = f'{col_prefix}::{bp}::espf' - espr_col = f'{col_prefix}::{bp}::espr' - if espf_col in merged.columns: - bp_cols.append(espf_col) - rename_dict[espf_col] = f'{col_prefix}::espf' - if espr_col in merged.columns: - bp_cols.append(espr_col) - rename_dict[espr_col] = f'{col_prefix}::espr' - - output_file = f"{output_prefix}_{bp}.tsv" - # Column selection already creates a new DataFrame; rename + write without extra copy. - merged[bp_cols].rename(columns=rename_dict).to_csv( - output_file, sep='\t', index=False, na_rep='NA' - ) - output_files.append(output_file) - gc.collect() + # Write one file per base pair combination. + # Parallelize if threads > 1: each base pair file is written by a separate worker. 
+ if threads > 1: + print(f"Writing {len(base_pairs)} output files using {threads} threads...") + with multiprocessing.Pool(processes=threads) as pool: + args = [(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id) for bp in base_pairs] + output_files = pool.starmap(write_base_pair_file, args) + else: + print(f"Writing {len(base_pairs)} output files sequentially...") + output_files = [] + for bp in base_pairs: + output_file = write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id) + output_files.append(output_file) + gc.collect() print(f"\nOutput files created:") for output_file in output_files: @@ -283,6 +302,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ group_sample_rep_counts = {} include_file_id = False # Default: omit file_id from column names min_cov = 1 # Default: NA when denominator is 0; treat 0 coverage as non-observed + threads = 1 # Default: sequential writing args_iter = iter(range(2, len(sys.argv))) for i in args_iter: @@ -309,6 +329,26 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ sys.exit(1) continue + # Check for --threads/-t flag (supports --threads=N, --threads N, -t N) + if arg.startswith('--threads') or arg == '-t': + if arg.startswith('--threads') and '=' in arg: + try: + threads = int(arg.split('=', 1)[1]) + except ValueError: + print(f"ERROR: --threads requires an integer value", file=sys.stderr) + sys.exit(1) + else: + try: + next_i = next(args_iter) + threads = int(sys.argv[next_i]) + except (StopIteration, ValueError): + print(f"ERROR: --threads/-t requires an integer value", file=sys.stderr) + sys.exit(1) + if threads < 1: + print(f"ERROR: --threads must be at least 1", file=sys.stderr) + sys.exit(1) + continue + if ':' not in arg: print(f"ERROR: Invalid argument format '{arg}'", file=sys.stderr) print("Expected format: FILE:GROUP:SAMPLE:REPLICATE", file=sys.stderr) @@ -342,6 +382,6 @@ def 
merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ file_group_sample_replicate_dict[filepath] = (group_name, sample_name, replicate) # Process all samples - result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov) + result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads) print("\nAnalysis complete!") \ No newline at end of file diff --git a/config/resources/hpc.config b/config/resources/hpc.config index 40807e1..58e5c4d 100644 --- a/config/resources/hpc.config +++ b/config/resources/hpc.config @@ -12,6 +12,10 @@ process { memory = '4GB' time = '2d' // Long timeout to wait for submitted job completion } + withLabel: 'drip' { + cpus = 8 + time = '6h' + } withLabel: 'fastqc' { cpus = 8 time = '6h' @@ -29,11 +33,11 @@ process { time = '4h' } withLabel: 'pluviometer' { - cpus = 6 + cpus = 8 time = '6h' } withLabel: 'reditools3' { - cpus = 6 + cpus = 8 time = '1d' } withLabel: 'samtools' { diff --git a/config/resources/local.config b/config/resources/local.config index 3752d38..8b608e0 100644 --- a/config/resources/local.config +++ b/config/resources/local.config @@ -27,6 +27,10 @@ process { cpus = 4 memory = '16GB' time = '2d' + } + withLabel: 'drip' { + cpus = 4 + time = '6h' } withLabel: 'hisat2_index' { cpus = 4 diff --git a/modules/python.nf b/modules/python.nf index 256325a..477cf6b 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -56,6 +56,7 @@ process drip { def args_str = args.join(" ") """ - drip.py drip_${prefix} ${args_str} + + drip.py --threads ${task.cpu} drip_${prefix} ${args_str} """ } \ No newline at end of file From c14da8a9261afcf774951cb6e0e9448e5d988ff0 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Wed, 18 Mar 2026 10:14:13 +0100 Subject: [PATCH 17/61] BLAS was using a number of CPUs out of the scope of --threads (even before we parallelize by pair of nucleotides). We now avoid that by forcing the maximum number of CPUs used by numpy. 
--- bin/drip.py | 13 ++++++++++++- 1 file changed, 12 insertions(+), 1 deletion(-) diff --git a/bin/drip.py b/bin/drip.py index 86cc9cf..1b98b93 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -1,9 +1,20 @@ #!/usr/bin/env python3 +import os +import sys + +# Limit implicit BLAS/LAPACK multi-threading to respect SLURM CPU allocation. +# Will be set to 1 by default and updated based on --threads CLI argument. +# This ensures total CPU usage = implicit threads + explicit Pool threads <= allocated CPUs. +os.environ.setdefault('OMP_NUM_THREADS', '1') +os.environ.setdefault('OPENBLAS_NUM_THREADS', '1') +os.environ.setdefault('MKL_NUM_THREADS', '1') +os.environ.setdefault('VECLIB_MAXIMUM_THREADS', '1') +os.environ.setdefault('NUMEXPR_NUM_THREADS', '1') + import pandas as pd import numpy as np import multiprocessing -import sys import gc from pathlib import Path From 7c090a883030ea7bfb8b755f9c0d8362cee33f70 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Wed, 18 Mar 2026 11:21:38 +0100 Subject: [PATCH 18/61] round value to minimize file --- bin/drip.py | 43 +++++++++++++++++++++++++++++++++++++------ 1 file changed, 37 insertions(+), 6 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index 1b98b93..33bcf65 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -44,6 +44,9 @@ def print_help(): --threads N, -t N Number of parallel threads to use for writing output files (default: 1, sequential). Max useful value is 16 (one per base pair). + --decimals N, -d N + Number of decimal places for output values (default: 4). + Reduces file size by rounding espf/espr metrics. 
--help, -h Display this help message NA BEHAVIOR: @@ -212,13 +215,14 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ del df return result -def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id): +def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, decimals): """Worker function to write a single base pair combination file. This function is designed to be called in parallel for each base pair. """ bp_cols = metadata_cols.copy() rename_dict = {} + metric_cols = [] # Track metric columns to round for _, group_name, sample_name, replicate, file_id in sample_info: if include_file_id: col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' @@ -229,17 +233,23 @@ def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, if espf_col in merged.columns: bp_cols.append(espf_col) rename_dict[espf_col] = f'{col_prefix}::espf' + metric_cols.append(espf_col) if espr_col in merged.columns: bp_cols.append(espr_col) rename_dict[espr_col] = f'{col_prefix}::espr' + metric_cols.append(espr_col) output_file = f"{output_prefix}_{bp}.tsv" - merged[bp_cols].rename(columns=rename_dict).to_csv( + # Select columns, round metric columns, rename, and write + bp_df = merged[bp_cols].copy() + for col in metric_cols: + bp_df[col] = bp_df[col].round(decimals) + bp_df.rename(columns=rename_dict).to_csv( output_file, sep='\t', index=False, na_rep='NA' ) return output_file -def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1, threads=1): +def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1, threads=1, decimals=4): """Merge data from multiple samples and create output matrices - one file per base pair combination.""" base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', @@ -275,13 +285,13 @@ def merge_samples(file_group_sample_replicate_dict, 
output_prefix, include_file_ if threads > 1: print(f"Writing {len(base_pairs)} output files using {threads} threads...") with multiprocessing.Pool(processes=threads) as pool: - args = [(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id) for bp in base_pairs] + args = [(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, decimals) for bp in base_pairs] output_files = pool.starmap(write_base_pair_file, args) else: print(f"Writing {len(base_pairs)} output files sequentially...") output_files = [] for bp in base_pairs: - output_file = write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id) + output_file = write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, decimals) output_files.append(output_file) gc.collect() @@ -314,6 +324,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ include_file_id = False # Default: omit file_id from column names min_cov = 1 # Default: NA when denominator is 0; treat 0 coverage as non-observed threads = 1 # Default: sequential writing + decimals = 4 # Default: round to 4 decimal places args_iter = iter(range(2, len(sys.argv))) for i in args_iter: @@ -360,6 +371,26 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ sys.exit(1) continue + # Check for --decimals/-d flag (supports --decimals=N, --decimals N, -d N) + if arg.startswith('--decimals') or arg == '-d': + if arg.startswith('--decimals') and '=' in arg: + try: + decimals = int(arg.split('=', 1)[1]) + except ValueError: + print(f"ERROR: --decimals requires an integer value", file=sys.stderr) + sys.exit(1) + else: + try: + next_i = next(args_iter) + decimals = int(sys.argv[next_i]) + except (StopIteration, ValueError): + print(f"ERROR: --decimals/-d requires an integer value", file=sys.stderr) + sys.exit(1) + if decimals < 0: + print(f"ERROR: --decimals must be non-negative", file=sys.stderr) + 
sys.exit(1) + continue + if ':' not in arg: print(f"ERROR: Invalid argument format '{arg}'", file=sys.stderr) print("Expected format: FILE:GROUP:SAMPLE:REPLICATE", file=sys.stderr) @@ -393,6 +424,6 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ file_group_sample_replicate_dict[filepath] = (group_name, sample_name, replicate) # Process all samples - result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads) + result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads, decimals) print("\nAnalysis complete!") \ No newline at end of file From 95a656469d76645906ab6659032dfef066fef77c Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Wed, 18 Mar 2026 11:39:45 +0100 Subject: [PATCH 19/61] round when computing instead of writing to save memory while merging --- bin/drip.py | 26 +++++++++++++------------- 1 file changed, 13 insertions(+), 13 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index 33bcf65..25ce7e8 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -145,7 +145,7 @@ def print_help(): print(help_text) sys.exit(0) -def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id=False, min_cov=1): +def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id=False, min_cov=1, decimals=4): """Parse a single TSV file and extract editing metrics for all base pair combinations.""" # SeqID, Start, End, Strand contain mixed values ("." and actual numbers/strings) # → force them to string to avoid DtypeWarning and preserve "." 
as-is @@ -190,6 +190,8 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ denom_espf = df[f'{genome_base}_count'] mask_espf = denom_espf >= min_cov df[espf_col] = np.where(mask_espf, df[f'{bp}_sites'] / denom_espf.where(mask_espf, 1), np.nan) + # Round immediately to reduce RAM footprint during merge + df[espf_col] = df[espf_col].round(decimals) result_cols.append(espf_col) # Calculate espr: XY_reads / (XA + XC + XG + XT) @@ -207,6 +209,8 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ denom_espr = df[total_reads_col] mask_espr = denom_espr >= min_cov df[espr_col] = np.where(mask_espr, df[f'{bp}_reads'] / denom_espr.where(mask_espr, 1), np.nan) + # Round immediately to reduce RAM footprint during merge + df[espr_col] = df[espr_col].round(decimals) result_cols.append(espr_col) # Select only needed columns — list indexing creates a new DataFrame, @@ -215,14 +219,14 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ del df return result -def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, decimals): +def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id): """Worker function to write a single base pair combination file. This function is designed to be called in parallel for each base pair. + Note: Values are already rounded in parse_tsv_file() to reduce RAM usage. 
""" bp_cols = metadata_cols.copy() rename_dict = {} - metric_cols = [] # Track metric columns to round for _, group_name, sample_name, replicate, file_id in sample_info: if include_file_id: col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' @@ -233,18 +237,13 @@ def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, if espf_col in merged.columns: bp_cols.append(espf_col) rename_dict[espf_col] = f'{col_prefix}::espf' - metric_cols.append(espf_col) if espr_col in merged.columns: bp_cols.append(espr_col) rename_dict[espr_col] = f'{col_prefix}::espr' - metric_cols.append(espr_col) output_file = f"{output_prefix}_{bp}.tsv" - # Select columns, round metric columns, rename, and write - bp_df = merged[bp_cols].copy() - for col in metric_cols: - bp_df[col] = bp_df[col].round(decimals) - bp_df.rename(columns=rename_dict).to_csv( + # Values are already rounded, just select, rename and write + merged[bp_cols].rename(columns=rename_dict).to_csv( output_file, sep='\t', index=False, na_rep='NA' ) return output_file @@ -268,7 +267,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ merged = None for filepath, group_name, sample_name, replicate, file_id in sample_info: print(f"Processing {group_name}::{sample_name} (replicate {replicate}) from {filepath} (file_id: {file_id})...") - data = parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id, min_cov) + data = parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id, min_cov, decimals) if merged is None: merged = data else: @@ -282,16 +281,17 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ # Write one file per base pair combination. # Parallelize if threads > 1: each base pair file is written by a separate worker. + # Note: values are already rounded in parse_tsv_file() to reduce RAM during merge. 
if threads > 1: print(f"Writing {len(base_pairs)} output files using {threads} threads...") with multiprocessing.Pool(processes=threads) as pool: - args = [(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, decimals) for bp in base_pairs] + args = [(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id) for bp in base_pairs] output_files = pool.starmap(write_base_pair_file, args) else: print(f"Writing {len(base_pairs)} output files sequentially...") output_files = [] for bp in base_pairs: - output_file = write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, decimals) + output_file = write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id) output_files.append(output_file) gc.collect() From 436c994095a7b0246c07d0972a31bd49c985d575 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Wed, 18 Mar 2026 19:48:17 +0100 Subject: [PATCH 20/61] update info --- README.md | 6 ++++++ bin/README | 12 ++---------- 2 files changed, 8 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index 286871b..4c1d585 100644 --- a/README.md +++ b/README.md @@ -355,3 +355,9 @@ Jacques Dainat (@Juke34) ## Contributing Contributions from the community are welcome ! See the [Contributing guidelines](https://github.com/Juke34/rain/blob/main/CONTRIBUTING.md) + +## TODO + +update pluviometer to set NA for start end and strand instead of . to be able to use column as int64 in drip and barometer e.g. dtype={"SeqID": str, "Start": "Int64", "End": "Int64", "Strand": str} + +Le feature type seq_agg must be called sequence_agg. Maybe _agg can be removed because know by another column. 
\ No newline at end of file diff --git a/bin/README b/bin/README index d48d843..8caf0ac 100644 --- a/bin/README +++ b/bin/README @@ -20,20 +20,12 @@ python -m pluviometer --sites SITES --gff GFF [OPTIONS] python pluviometer_wrapper.py --sites SITES --gff GFF [OPTIONS] ``` -### drip_features.py +### drip.py Post-processing tool for pluviometer feature output. Analyzes RNA editing from feature TSV files, calculating editing metrics (espf and espr) for all 16 genome-variant base pair combinations across multiple samples. Combines data into unified matrix format. **Usage:** ```bash -./drip_features.py OUTPUT_PREFIX FILE1:SAMPLE1 FILE2:SAMPLE2 [...] -``` - -### drip_aggregates.py -Post-processing tool for pluviometer aggregate output. Similar to drip_features.py but operates on aggregate-level data, calculating editing metrics for aggregated genomic regions across samples. - -**Usage:** -```bash -./drip_aggregates.py OUTPUT_PREFIX FILE1:SAMPLE1 FILE2:SAMPLE2 [...] +./drip.py OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REPLICATE1 FILE2:GROUP1:SAMPLE2:REPLICATE1 [...] 
``` ### restore_sequences.py From f418b564f2db1ad96a551894dd28bfc7958345c0 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Wed, 18 Mar 2026 19:48:34 +0100 Subject: [PATCH 21/61] increase drip CPU --- config/resources/hpc.config | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/config/resources/hpc.config b/config/resources/hpc.config index 58e5c4d..bd57363 100644 --- a/config/resources/hpc.config +++ b/config/resources/hpc.config @@ -13,7 +13,7 @@ process { time = '2d' // Long timeout to wait for submitted job completion } withLabel: 'drip' { - cpus = 8 + cpus = 16 time = '6h' } withLabel: 'fastqc' { From 47ca6e81dc0e4b4fad6533b4023408fc7ab20092 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Wed, 18 Mar 2026 19:49:30 +0100 Subject: [PATCH 22/61] add coverage parameter set to 10 by default --- modules/pluviometer.nf | 6 ++---- rain.nf | 5 ++++- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index 9f90f3e..3f91bf9 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -16,16 +16,14 @@ process pluviometer { script: base_name = site_edits.BaseName """ - pluviometer_wrapper.py \ --sites ${site_edits} \ --gff ${gff} \ --format ${tool_format} \ - --cov 1 \ + --cov ${params.cov_threshold} \ --edit_threshold ${params.edit_threshold} \ --threads ${task.cpus} \ --aggregation_mode ${params.aggregation_mode} \ - --output "${meta.uid}_${tool_format}" - + --output "${meta.uid}_${tool_format}" """ } diff --git a/rain.nf b/rain.nf index b79b964..5699763 100644 --- a/rain.nf +++ b/rain.nf @@ -22,7 +22,8 @@ params.clean_duplicate = true // Edit counting params edit_site_tools = ["reditools2", "reditools3", "jacusa2", "sapin"] params.edit_site_tool = "reditools3" -params.edit_threshold = 1 +params.edit_threshold = 1 // Minimal number of edited reads to count a site as edited +params.cov_threshold = 10 // Minimal coverage to consider a site for editing detection params.aggregation_mode 
= "all" params.skip_hyper_editing = false // Skip hyper-editing detection // Report params @@ -128,6 +129,7 @@ def helpMSG() { --aligner Aligner to use [default: $params.aligner] --clean_duplicate Remove PCR duplicates from BAM files using GATK MarkDuplicates. [default: $params.clean_duplicate] --clip_overlap Clip overlapping sequences in read pairs to avoid double counting. [default: $params.clipoverlap] + --cov_threshold Minimal coverage to consider a site for editing detection [default: $params.cov_threshold] --debug Enable debug output for troubleshooting. [default: $params.debug] --edit_site_tool Tool used for detecting edited sites. [default: $params.edit_site_tool] --edit_threshold Minimal number of edited reads to count a site as edited [default: $params.edit_threshold] @@ -160,6 +162,7 @@ Alignment Parameters Edited Site Detection Parameters + cov_threshold : ${params.cov_threshold} edit_site_tool : ${params.edit_site_tool} edit_threshold : ${params.edit_threshold} region : ${params.region} From e65bf692818bc23e42ee3de333ad30f0a90f7439 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Thu, 19 Mar 2026 16:44:05 +0100 Subject: [PATCH 23/61] fix call for threads --- bin/drip.py | 46 ++++++++++++++++++++++++++++++++++++++-------- modules/python.nf | 2 +- 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index 25ce7e8..d55b3a2 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -29,11 +29,13 @@ def print_help(): samples and combines them into a unified matrix format. USAGE: - ./drip.py OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 FILE2:GROUP2:SAMPLE2:REP2 [...] [--with-file-id] + ./drip.py --output OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 FILE2:GROUP2:SAMPLE2:REP2 [...] ./drip.py --help | -h ARGUMENTS: - OUTPUT_PREFIX Prefix for the output TSV file (will create OUTPUT_PREFIX.tsv) + --output OUTPUT_PREFIX, -o OUTPUT_PREFIX + Prefix for the output TSV files (required). + Will create OUTPUT_PREFIX_AA.tsv, OUTPUT_PREFIX_AC.tsv, etc. 
FILEn:GROUPn:SAMPLEn:REPn Input file path, group name, sample name, and replicate ID separated by colons. All four components are required. @@ -113,7 +115,7 @@ def print_help(): - The '::' separator allows easy splitting to retrieve all components EXAMPLE: - ./drip.py results \\ + ./drip.py --output results \\ sample1_aggregates.tsv:control:sample1:rep1 \\ sample2_aggregates.tsv:control:sample2:rep2 \\ sample3_aggregates.tsv:treated:sample1:rep1 @@ -311,14 +313,14 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ print_help() # Validate arguments - if len(sys.argv) < 3: + if len(sys.argv) < 2: print("ERROR: Insufficient arguments provided.\n", file=sys.stderr) - print("Usage: ./drip.py OUTPUT_PREFIX FILE1:SAMPLE1 FILE2:SAMPLE2 [...]\n", file=sys.stderr) + print("Usage: ./drip.py --output OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 [...]\n", file=sys.stderr) print("For detailed help, run: ./drip.py --help\n", file=sys.stderr) sys.exit(1) # Parse command line arguments - output_prefix = sys.argv[1] + output_prefix = None file_group_sample_replicate_dict = {} group_sample_rep_counts = {} include_file_id = False # Default: omit file_id from column names @@ -326,9 +328,23 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ threads = 1 # Default: sequential writing decimals = 4 # Default: round to 4 decimal places - args_iter = iter(range(2, len(sys.argv))) + args_iter = iter(range(1, len(sys.argv))) for i in args_iter: - arg = sys.argv[i] + arg = sys.argv[i] # Check for --output/-o flag (supports --output=PREFIX, --output PREFIX, -o PREFIX) + if arg.startswith('--output') or arg == '-o': + if arg.startswith('--output') and '=' in arg: + output_prefix = arg.split('=', 1)[1] + if not output_prefix: + print("ERROR: --output requires a value", file=sys.stderr) + sys.exit(1) + else: + try: + next_i = next(args_iter) + output_prefix = sys.argv[next_i] + except StopIteration: + print("ERROR: --output/-o requires a 
value", file=sys.stderr) + sys.exit(1) + continue # Check for --with-file-id flag if arg == '--with-file-id': include_file_id = True @@ -423,6 +439,20 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ file_group_sample_replicate_dict[filepath] = (group_name, sample_name, replicate) + # Validate that output_prefix was provided + if output_prefix is None: + print("ERROR: --output/-o argument is required.\n", file=sys.stderr) + print("Usage: ./drip.py --output OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 [...]\n", file=sys.stderr) + print("For detailed help, run: ./drip.py --help\n", file=sys.stderr) + sys.exit(1) + + # Validate that at least one input file was provided + if len(file_group_sample_replicate_dict) == 0: + print("ERROR: At least one input file is required.\n", file=sys.stderr) + print("Usage: ./drip.py --output OUTPUT_PREFIX FILE1:GROUP1:SAMPLE1:REP1 [...]\n", file=sys.stderr) + print("For detailed help, run: ./drip.py --help\n", file=sys.stderr) + sys.exit(1) + # Process all samples result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads, decimals) diff --git a/modules/python.nf b/modules/python.nf index 477cf6b..d5076f2 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -57,6 +57,6 @@ process drip { """ - drip.py --threads ${task.cpu} drip_${prefix} ${args_str} + drip.py --threads ${task.cpus} --output drip_${prefix} ${args_str} """ } \ No newline at end of file From 081d6babee84a619131bf1e59b7b32126860ed1f Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Thu, 19 Mar 2026 16:46:57 +0100 Subject: [PATCH 24/61] update output and AliNe version --- modules/aline.nf | 140 ++++++++-------------------- rain.nf | 9 +-- subworkflows/hyper-editing.nf | 5 +- 3 files changed, 38 insertions(+), 116 deletions(-) diff --git a/modules/aline.nf b/modules/aline.nf index 6c8e9fd..9fbff88 100644 --- a/modules/aline.nf +++ b/modules/aline.nf @@ -1,12 +1,11 @@ /* Adapted from 
from https://github.com/mahesh-panchal/nf-cascade -Auto-detects HPC/local environment and adapts execution strategy */ process AliNe { tag "$pipeline_name" label 'aline' - publishDir "${params.outdir}", mode: 'copy' + publishDir "${output_dir}", mode: 'copy' errorStrategy 'terminate' // to avoid any retry maxRetries 0 // Override global retry config - do not retry this process @@ -21,116 +20,43 @@ process AliNe { val aligner val library_type val annotation - val cache_dir // String - val use_slurm // Boolean - whether parent is using slurm + val cache_dir // cache directory + val output_dir // output directory when: task.ext.when == null || task.ext.when - script: - def nxf_cmd = "nextflow run ${pipeline_name} ${profile} ${config} --reads ${reads} --reference ${genome} ${read_type} ${aligner} ${library_type} --annotation ${annotation} --data_type rna --outdir \$WORK_DIR/AliNe" - """ - echo "[AliNe] Process started at \$(date '+%Y-%m-%d %H:%M:%S')" - - # Save absolute work directory before changing context - WORK_DIR=\$(pwd) - - # Create cache directory for resume AliNe run made from different working directory - mkdir -p "${cache_dir}" - cd "${cache_dir}" - - # Save command for reference/debugging - echo "${nxf_cmd}" > \$WORK_DIR/nf-cmd.sh - - # Detect execution environment and run AliNe accordingly - if ${use_slurm} && command -v sbatch >/dev/null 2>&1; then - echo "[AliNe] Detected HPC environment - submitting AliNe as separate SLURM job" - - # Create sbatch script for AliNe - cat > \$WORK_DIR/aline_job.sh < \$WORK_DIR/aline_job_id.txt - - # Wait for job to appear in scheduler queue - # Simple wait for job to appear or to be started - echo "[AliNe] Waiting for job \$JOB_ID to appear in queue or to be started..." 
- while true; do - # Check if job is in queue - if squeue -j \$JOB_ID 2>/dev/null | grep -q \$JOB_ID; then - echo "[AliNe] Job \$JOB_ID is now visible in queue" - break - else - echo "[AliNe] Job \$JOB_ID not yet visible in queue \$(date '+%Y-%m-%d %H:%M:%S')" - fi - # Check job state (crash test: if job has started or finished) - JOB_STATE=\$(sacct -j \$JOB_ID --format=State --noheader | head -1 | tr -d ' ') - if [[ "\$JOB_STATE" =~ ^(RUNNING|COMPLETED|FAILED|CANCELLED|TIMEOUT|PREEMPTED|NODE_FAIL|OUT_OF_MEMORY)\$ ]]; then - echo "[AliNe] Job \$JOB_ID has state: \$JOB_STATE (not visible in queue, but started or finished)" - break - fi - echo "[AliNe] Job not yet visible, waiting..." - sleep 5 - done # Simple wait for job to appear or to be started - - # Wait for job completion - echo "[AliNe] Waiting for job \$JOB_ID to complete..." - while squeue -j \$JOB_ID 2>/dev/null | grep -q \$JOB_ID; do - sleep 30 - done - - # Check job exit status - JOB_STATE=\$(sacct -j \$JOB_ID --format=State --noheader | head -1 | tr -d ' ') - echo "[AliNe] Job \$JOB_ID finished with state: \$JOB_STATE at \$(date '+%Y-%m-%d %H:%M:%S')" - - if [[ "\$JOB_STATE" != "COMPLETED" ]]; then - echo "[AliNe] ERROR: Job failed with state \$JOB_STATE at \$(date '+%Y-%m-%d %H:%M:%S')" >&2 - echo "With message (100 last lines)": >&2 - tail -n 100 \$WORK_DIR/aline_\$JOB_ID.out >&2 - exit 1 - fi - - # Copy log for reference - if [ -f .nextflow.log ]; then - cp .nextflow.log \$WORK_DIR/nextflow.log - fi - echo "[AliNe] Pipeline completed successfully via SLURM at \$(date '+%Y-%m-%d %H:%M:%S')" - else - echo "[AliNe] Detected local/standard environment - running AliNe directly" - - # Run nextflow command directly - ${nxf_cmd} || { - echo "[AliNe] ERROR: Pipeline failed at \$(date '+%Y-%m-%d %H:%M:%S')" >&2 - exit 1 - } - - echo "[AliNe] Pipeline completed successfully (direct execution) at \$(date '+%Y-%m-%d %H:%M:%S')" - fi - - echo "[AliNe] Process finished at \$(date '+%Y-%m-%d %H:%M:%S')" -""" + exec: 
+ def cache_path = file(cache_dir) + assert cache_path.mkdirs() + // construct nextflow command + def nxf_cmd = [ + 'nextflow run', + pipeline_name, + profile, + config, + "--reads ${reads}", + "--reference ${genome}", + read_type, + aligner, + library_type, + "--annotation ${annotation}", + "--data_type rna", + "--outdir $task.workDir/AliNe", + ].join(" ") + // Copy command to shell script in work dir for reference/debugging. + file("$task.workDir/nf-cmd.sh").text = nxf_cmd + // Run nextflow command locally in cache directory + def process = nxf_cmd.execute(null, cache_path.toFile()) + // Print process output to stdout and stderr + process.consumeProcessOutput(System.out, System.err) + process.waitFor() + stdout = process.text + // Copy nextflow log to work directory + cache_path.resolve(".nextflow.log").copyTo("${task.workDir}/nextflow.log") + assert process.exitValue() == 0: stdout output: path "AliNe" , emit: output + val stdout, emit: log } diff --git a/rain.nf b/rain.nf index 5699763..331674e 100644 --- a/rain.nf +++ b/rain.nf @@ -36,7 +36,6 @@ params.region = "" // e.g. 
chr21 - Used to limit the analysis to a specific regi params.help = null params.monochrome_logs = false // if true, no color in logs params.debug = false // Enable debug output -params.use_slurm_for_aline = false // Whether to submit AliNe as a separate SLURM job when using an HPC environment // -------------------------------------------------- /* ---- Params shared between RAIN and AliNe ---- */ @@ -63,7 +62,7 @@ params.trimming_fastp = false align_tools = [ 'bbmap', 'bowtie', 'bowtie2', 'bwaaln', 'bwamem', 'bwamem2', 'bwasw', 'dragmap', 'graphmap2', 'hisat2', 'kallisto', 'last', 'minimap2', 'novoalign', 'nucmer', 'ngmlr', 'salmon', 'star', 'subread', 'sublong' ] params.aligner = 'hisat2' // AliNe version -params.aline_version = 'v1.6.2' +params.aline_version = 'v1.6.3' //************************************************* // STEP 1 - HELP //************************************************* @@ -211,7 +210,6 @@ else { exit 1, "No executer selected: please use a profile activating docker or // check AliNE profile def aline_profile_list=[] -def use_slurm_for_aline = params.use_slurm_for_aline str_list = workflow.profile.tokenize(',') str_list.each { if ( it in aline_profile_allowed ){ @@ -527,7 +525,7 @@ workflow { if ( via_csv ){ aline_data_in_ch = recreate_csv_with_abs_paths(csv_ch) - aline_data_in_ch = collect_aline_csv(aline_data_in_ch.collect(), "AliNe") + aline_data_in_ch = collect_aline_csv(aline_data_in_ch.collect(), params.outdir) aline_data_in = aline_data_in_ch } else { @@ -596,7 +594,7 @@ workflow { "--strandedness ${params.strandedness}", clean_annotation, workflow.workDir.resolve('Juke34/AliNe').toUriString(), - use_slurm_for_aline // Pass info about whether to use slurm submission + params.outdir, ) // GET TUPLE [ID, BAM] FILES @@ -673,7 +671,6 @@ workflow { samtools_split_mapped_unmapped.out.unmapped_bam, genome, aline_profile, - use_slurm_for_aline, clean_annotation, 30, // quality threshold "${params.outdir}/hyper_editing", diff --git 
a/subworkflows/hyper-editing.nf b/subworkflows/hyper-editing.nf index 1531d7f..2260bf2 100644 --- a/subworkflows/hyper-editing.nf +++ b/subworkflows/hyper-editing.nf @@ -27,7 +27,6 @@ workflow HYPER_EDITING { unmapped_bams // Unmapped read chunks from primary alignment genome // Genomic reference sequence aline_profile // AliNe profile in coma-separated format - use_slurm_for_aline // Boolean - whether to use slurm for AliNe clean_annotation // Annotation file for AliNe quality_threshold // Quality score filter threshold output_he // output directory path ier @@ -48,7 +47,7 @@ workflow HYPER_EDITING { // Stage 4: Create CSV file for AliNe with all converted reads aline_csv = create_aline_csv_he(converted_reads).collect() - aline_csv = collect_aline_csv(aline_csv,output_he) + aline_csv = collect_aline_csv(aline_csv, output_he) // Stage 5: Build alignment index for converted reference alignment_index = samtools_fasta_index(converted_reference) @@ -69,7 +68,7 @@ workflow HYPER_EDITING { "--strandedness ${params.strandedness}", clean_annotation, workflow.workDir.resolve('Juke34/AliNe').toUriString(), - use_slurm_for_aline // Pass info about whether to use slurm submission + output_he, ) // GET TUPLE [ID, BAM] FILES From 922804e76d2e96cba6e039cc181d1078197d5a7b Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Thu, 19 Mar 2026 16:47:19 +0100 Subject: [PATCH 25/61] mend --- bin/pluviometer/README.md | 6 ++++++ bin/pluviometer/__main__.py | 13 ++++++++++++- config/resources/hpc.config | 2 +- config/resources/local.config | 6 +++++- modules/bash.nf | 2 +- 5 files changed, 25 insertions(+), 4 deletions(-) diff --git a/bin/pluviometer/README.md b/bin/pluviometer/README.md index 2e7cf2a..b6b9d64 100644 --- a/bin/pluviometer/README.md +++ b/bin/pluviometer/README.md @@ -49,3 +49,9 @@ The following files have been moved from `bin/` to `bin/pluviometer/`: - utils.py → pluviometer/utils.py All internal imports have been converted to relative imports to function as a Python 
package. + +## Test + +# Call as a python module (recommended) +cd bin +python -m pluviometer --sites SITES --gff GFF [OPTIONS] \ No newline at end of file diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index f41a42e..af1596b 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -746,7 +746,7 @@ def parse_cli_input() -> argparse.Namespace: "--cov", "-c", type=int, - default=0, + default=1, help="Site coverage threshold for counting editions", ) parser.add_argument( @@ -1105,6 +1105,17 @@ def main(): log_filename: str = args.output + "_pluviometer.log" if args.output else "pluviometer.log" logging.basicConfig(filename=log_filename, level=logging.INFO, format=LOGGING_FORMAT) logging.info(f"Pluviometer started. Log file: {log_filename}") + logging.info("Options:") + logging.info(f" Sites file: {args.sites}") + logging.info(f" GFF file: {args.gff}") + logging.info(f" Output prefix: {args.output if args.output else 'default'}") + logging.info(f" Format: {args.format}") + logging.info(f" Coverage threshold: {args.cov}") + logging.info(f" Edit threshold: {args.edit_threshold}") + logging.info(f" Aggregation mode: {args.aggregation_mode}") + logging.info(f" Threads: {args.threads}") + logging.info(f" Progress bar: {args.progress}") + logging.info(f" GFF feature types filter: {args.gff_feature_types if args.gff_feature_types else 'none (all types)'}") feature_output_filename: str = args.output + "_features.tsv" if args.output else "features.tsv" aggregate_output_filename: str = ( args.output + "_aggregates.tsv" if args.output else "aggregates.tsv" diff --git a/config/resources/hpc.config b/config/resources/hpc.config index bd57363..b003a4e 100644 --- a/config/resources/hpc.config +++ b/config/resources/hpc.config @@ -12,7 +12,7 @@ process { memory = '4GB' time = '2d' // Long timeout to wait for submitted job completion } - withLabel: 'drip' { + withName: 'drip' { cpus = 16 time = '6h' } diff --git 
a/config/resources/local.config b/config/resources/local.config index 8b608e0..1a88460 100644 --- a/config/resources/local.config +++ b/config/resources/local.config @@ -4,7 +4,11 @@ process { maxForks = 8 shell = ['/bin/bash', '-euo', 'pipefail'] stageOutMode = 'rsync' - + + withName: 'drip' { + cpus = 4 + time = '6h' + } withLabel: 'fastqc' { cpus = 2 time = '6h' diff --git a/modules/bash.nf b/modules/bash.nf index 534e397..1cc194b 100644 --- a/modules/bash.nf +++ b/modules/bash.nf @@ -126,7 +126,7 @@ process create_aline_csv_he { */ process collect_aline_csv { label 'bash' - publishDir("${params.outdir}/${output_dir}", mode:"copy", pattern: "*.csv") + publishDir("${output_dir}", mode:"copy", pattern: "*.csv") input: val all_csv // List of tuples (meta, fastq_files) From 753fe46db7c9e252f8a9163046305a22fbb32c5e Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Thu, 19 Mar 2026 17:07:55 +0000 Subject: [PATCH 26/61] add debug, fix first feature position to activate feature via state_update_cycle(). 
--- bin/pluviometer/__main__.py | 54 ++++++++++++++++++++++++++++++++++--- 1 file changed, 50 insertions(+), 4 deletions(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index af1596b..d4dca99 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -323,16 +323,19 @@ def state_update_cycle(self, new_position: int) -> None: visited_positions: int = 0 while ( - len(self.action_queue) > 0 and self.action_queue[0][0] < new_position - ): # Use < instead of <= because of Python's right-exclusive indfgexing + len(self.action_queue) > 0 and self.action_queue[0][0] <= new_position + ): # Use <= to process actions at the current position before updating counters _, actions = self.action_queue.popleft() visited_positions += 1 for feature in actions.activate: + logging.debug(f"Activating feature: {feature.id} (type: {feature.type}, level: {feature.level})") self.active_features[feature.id] = feature for feature in actions.deactivate: + logging.debug(f"Deactivating feature: {feature.id} (type: {feature.type}, level: {feature.level})") if feature.level == 1: + logging.debug(f"Calling checkout on level-1 feature: {feature.id}") self.checkout(feature, None) self.active_features.pop(feature.id, None) @@ -372,6 +375,12 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> Checkout should be triggered after a level-1 feature is completely cleared. Aggregation of the level-1 feature and all its descentants proceeds. 
""" + + logging.debug( + f"checkout() called: feature.id={feature.id}, feature.type={feature.type}, " + f"feature.level={feature.level}, is_chimaera={getattr(feature, 'is_chimaera', 'NOT_SET')}, " + f"has_counter={feature.id in self.counters}" + ) # Deactivate the feature self.active_features.pop(feature.id, None) @@ -384,6 +393,7 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> if feature_counter: if feature.is_chimaera: assert parent_feature # A chimaera must always have a parent feature (a gene) + logging.debug(f"Writing chimaera with data: {feature.id}") self.aggregate_writer.write_row_chimaera_with_data( self.record.id, feature, parent_feature, feature_counter ) @@ -391,16 +401,19 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> # Also track by parent_type self.chimaera_aggregate_counters_by_parent_type[(parent_feature.type, feature.type)].merge(feature_counter) else: + logging.debug(f"Writing feature with data: {feature.id}, type: {feature.type}") self.feature_writer.write_row_with_data(self.record.id, feature, feature_counter) # Explicitly delete counter after use to free memory del feature_counter else: if feature.is_chimaera: assert parent_feature + logging.debug(f"Writing chimaera without data: {feature.id}") self.aggregate_writer.write_row_chimaera_without_data( self.record.id, feature, parent_feature ) else: + logging.debug(f"Writing feature without data: {feature.id}, type: {feature.type}") self.feature_writer.write_row_without_data(self.record.id, feature) # all_isoforms_aggregation_counters: Optional[defaultdict[str, MultiCounter]] = None @@ -652,7 +665,8 @@ def update_active_counters(self, site_data: RNASiteVariantData) -> None: return None def is_finished(self) -> bool: - return len(self.action_queue) > 0 or len(self.active_features) > 0 + """Return True if there is no remaining work (queues are empty and no active features)""" + return len(self.action_queue) == 0 and 
len(self.active_features) == 0 def launch_counting(self, reader: RNASiteVariantReader) -> None: next_svdata: Optional[RNASiteVariantData] = reader.seek_record(self.record.id) @@ -780,6 +794,19 @@ def parse_cli_input() -> argparse.Namespace: default=None, help="Comma-separated list of GFF feature types to load (e.g., 'gene,mRNA,exon,CDS'). Reduces memory usage by filtering during parsing. Leave empty to load all types.", ) + parser.add_argument( + "--log-level", + type=str, + default="INFO", + choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], + help="Logging level (default: INFO). Use DEBUG to see detailed site filtering information.", + ) + parser.add_argument( + "--log-to-console", + action="store_true", + default=False, + help="Output logs to console/terminal in addition to log file.", + ) return parser.parse_args() @@ -1103,7 +1130,25 @@ def main(): args = parse_cli_input() LOGGING_FORMAT = "%(asctime)s - %(levelname)s - %(message)s" log_filename: str = args.output + "_pluviometer.log" if args.output else "pluviometer.log" - logging.basicConfig(filename=log_filename, level=logging.INFO, format=LOGGING_FORMAT) + log_level = getattr(logging, args.log_level.upper()) + + # Configure logging with file handler + logger = logging.getLogger() + logger.setLevel(log_level) + + # File handler + file_handler = logging.FileHandler(log_filename, mode='w') + file_handler.setLevel(log_level) + file_handler.setFormatter(logging.Formatter(LOGGING_FORMAT)) + logger.addHandler(file_handler) + + # Console handler if requested + if args.log_to_console: + console_handler = logging.StreamHandler() + console_handler.setLevel(log_level) + console_handler.setFormatter(logging.Formatter(LOGGING_FORMAT)) + logger.addHandler(console_handler) + logging.info(f"Pluviometer started. 
Log file: {log_filename}") logging.info("Options:") logging.info(f" Sites file: {args.sites}") @@ -1116,6 +1161,7 @@ def main(): logging.info(f" Threads: {args.threads}") logging.info(f" Progress bar: {args.progress}") logging.info(f" GFF feature types filter: {args.gff_feature_types if args.gff_feature_types else 'none (all types)'}") + logging.info(f" Log level: {args.log_level}") feature_output_filename: str = args.output + "_features.tsv" if args.output else "features.tsv" aggregate_output_filename: str = ( args.output + "_aggregates.tsv" if args.output else "aggregates.tsv" From 2bca65e2f386907c5294576bb79f13bc459edef1 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 09:32:55 +0000 Subject: [PATCH 27/61] CoveredSites becomes ObservedSites. We add now TotalSites column, QualifiedSites columns and tests --- bin/pluviometer/README.md | 15 ++- bin/pluviometer/__main__.py | 131 +++++++++++++++++++++++++-- bin/pluviometer/multi_counter.py | 6 ++ bin/pluviometer/rain_file_writers.py | 52 +++++++++-- bin/pluviometer/site_filter.py | 27 +++++- 5 files changed, 210 insertions(+), 21 deletions(-) diff --git a/bin/pluviometer/README.md b/bin/pluviometer/README.md index b6b9d64..5c67d85 100644 --- a/bin/pluviometer/README.md +++ b/bin/pluviometer/README.md @@ -54,4 +54,17 @@ All internal imports have been converted to relative imports to function as a Py # Call as a python module (recommended) cd bin -python -m pluviometer --sites SITES --gff GFF [OPTIONS] \ No newline at end of file +python -m pluviometer --sites SITES --gff GFF [OPTIONS] + +## calcul location: +Cas gérés correctement : + +1. Features simples (exon, CDS) : SimpleLocation → len = longueur réelle +2. Features parent (gene, mRNA) du GFF : SimpleLocation(start, end) → len = span complet (incluant introns) +C'est la sémantique correcte du GFF - un gène "couvre" toute sa région +3. 
Aggregates chimaera créés par le code : CompoundLocation via location_union() → len = somme des parties sans gaps ✓ +4. All sites - for total site it is "." because cannot guess without fasta. + +To test everything goes well: + +PYTHONPATH=/workspaces/rain/bin python3 -m unittest pluviometer.test_site_counts -v \ No newline at end of file diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index d4dca99..60d0bff 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -121,6 +121,16 @@ def __init__( ) """Dictionary of counters for the record-level all isoforms aggregates. Keys correspond to aggregate feature types""" + # Positions for aggregate counters + self.longest_isoform_aggregate_positions: dict[str, AggregatePositions] = {} + """Dictionary of positions for the record-level longest isoform aggregates""" + + self.chimaera_aggregate_positions: dict[str, AggregatePositions] = {} + """Dictionary of positions for the record-level chimaeric aggregates""" + + self.all_isoforms_aggregate_positions: dict[str, AggregatePositions] = {} + """Dictionary of positions for the record-level all isoforms aggregates""" + # New: Aggregate counters by (parent_type, aggregate_type) for intermediate-level aggregation self.longest_isoform_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter] = defaultdict( DefaultMultiCounterFactory(self.filter) ) @@ -400,6 +410,13 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> self.chimaera_aggregate_counters[feature.type].merge(feature_counter) # Also track by parent_type self.chimaera_aggregate_counters_by_parent_type[(parent_feature.type, feature.type)].merge(feature_counter) + + # Track positions for chimaera + if feature.type not in self.chimaera_aggregate_positions: + self.chimaera_aggregate_positions[feature.type] = AggregatePositions( + strand=feature.location.strand if hasattr(feature.location, 'strand') else 0 + ) +
self.chimaera_aggregate_positions[feature.type].update_from_feature(feature) else: logging.debug(f"Writing feature with data: {feature.id}, type: {feature.type}") self.feature_writer.write_row_with_data(self.record.id, feature, feature_counter) @@ -431,6 +448,18 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> merge_aggregation_counter_dicts( self.all_isoforms_aggregate_counters, level1_all_isoforms_aggregation_counters ) + + # Merge positions for all_isoforms at record level + for aggregate_type, positions in level1_all_isoforms_aggregation_positions.items(): + if aggregate_type not in self.all_isoforms_aggregate_positions: + self.all_isoforms_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + self.all_isoforms_aggregate_positions[aggregate_type].merge(positions) + + # Merge positions for longest_isoform at record level + for aggregate_type, positions in level1_longest_isoform_aggregation_positions.items(): + if aggregate_type not in self.longest_isoform_aggregate_positions: + self.longest_isoform_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + self.longest_isoform_aggregate_positions[aggregate_type].merge(positions) # Also merge into by-parent-type dictionaries parent_type = feature.type @@ -459,7 +488,20 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> end=end_str, strand=strand_str, ) - self.aggregate_writer.write_counter_data(aggregate_counter) + # Calculate TotalSites from positions + total_sites_str = "." 
+ if aggregate_type in level1_longest_isoform_aggregation_positions: + pos = level1_longest_isoform_aggregation_positions[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + self.aggregate_writer.write_data( + total_sites_str, + str(aggregate_counter.genome_base_freqs.sum()), + str(aggregate_counter.filtered_sites_count), + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ) for ( aggregate_type, @@ -481,7 +523,20 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> end=end_str, strand=strand_str, ) - self.aggregate_writer.write_counter_data(aggregate_counter) + # Calculate TotalSites from positions + total_sites_str = "." + if aggregate_type in level1_all_isoforms_aggregation_positions: + pos = level1_all_isoforms_aggregation_positions[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + self.aggregate_writer.write_data( + total_sites_str, + str(aggregate_counter.genome_base_freqs.sum()), + str(aggregate_counter.filtered_sites_count), + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ) else: feature_aggregation_counters, feature_aggregation_positions = self.aggregate_children(feature) for aggregate_type, aggregate_counter in feature_aggregation_counters.items(): @@ -501,7 +556,20 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> end=end_str, strand=strand_str, ) - self.aggregate_writer.write_counter_data(aggregate_counter) + # Calculate TotalSites from positions + total_sites_str = "." 
+ if aggregate_type in feature_aggregation_positions: + pos = feature_aggregation_positions[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + self.aggregate_writer.write_data( + total_sites_str, + str(aggregate_counter.genome_base_freqs.sum()), + str(aggregate_counter.filtered_sites_count), + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ) # Recursively check-out children for child in feature.sub_features: @@ -941,6 +1009,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: ".", "longest_isoform", record_ctx.longest_isoform_aggregate_counters, + record_ctx.longest_isoform_aggregate_positions, ) aggregate_writer.write_rows_with_data( record.id, @@ -949,6 +1018,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: ".", "all_isoforms", record_ctx.all_isoforms_aggregate_counters, + record_ctx.all_isoforms_aggregate_positions, ) aggregate_writer.write_rows_with_data( record.id, @@ -957,6 +1027,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: ".", "chimaera", record_ctx.chimaera_aggregate_counters, + record_ctx.chimaera_aggregate_positions, ) # Write aggregate counter data by parent type @@ -966,6 +1037,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: ".", "longest_isoform", record_ctx.longest_isoform_aggregate_counters_by_parent_type, + record_ctx.longest_isoform_aggregate_positions, ) aggregate_writer.write_rows_with_data_by_parent_type( record.id, @@ -973,6 +1045,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: ".", "all_isoforms", record_ctx.all_isoforms_aggregate_counters_by_parent_type, + record_ctx.all_isoforms_aggregate_positions, ) aggregate_writer.write_rows_with_data_by_parent_type( record.id, @@ -980,6 +1053,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: ".", "chimaera", 
record_ctx.chimaera_aggregate_counters_by_parent_type, + record_ctx.chimaera_aggregate_positions, ) # Write the total counter data of the record. A dummy dict needs to be created to use the `write_rows_with_data` method @@ -1003,6 +1077,9 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: "longest_isoform_aggregate_counters_by_parent_type": record_ctx.longest_isoform_aggregate_counters_by_parent_type.copy(), "all_isoforms_aggregate_counters_by_parent_type": record_ctx.all_isoforms_aggregate_counters_by_parent_type.copy(), "total_counter": record_ctx.total_counter, + "longest_isoform_aggregate_positions": record_ctx.longest_isoform_aggregate_positions.copy(), + "all_isoforms_aggregate_positions": record_ctx.all_isoforms_aggregate_positions.copy(), + "chimaera_aggregate_positions": record_ctx.chimaera_aggregate_positions.copy(), } # Clean up record context to free memory immediately @@ -1051,6 +1128,9 @@ def process_and_write_record_data( genome_all_isoforms_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter], genome_chimaera_aggregate_counters_by_parent_type: defaultdict[tuple[str, str], MultiCounter], genome_total_counter: MultiCounter, + genome_longest_isoform_aggregate_positions: dict[str, AggregatePositions], + genome_all_isoforms_aggregate_positions: dict[str, AggregatePositions], + genome_chimaera_aggregate_positions: dict[str, AggregatePositions], ) -> None: """ Process a single record's data: write its output and merge counters into genome totals. 
@@ -1112,6 +1192,22 @@ def process_and_write_record_data( # Update the genome's total counter from the record data total counter genome_total_counter.merge(record_data["total_counter"]) + # Merge positions for aggregates + for aggregate_type, positions in record_data.get("longest_isoform_aggregate_positions", {}).items(): + if aggregate_type not in genome_longest_isoform_aggregate_positions: + genome_longest_isoform_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + genome_longest_isoform_aggregate_positions[aggregate_type].merge(positions) + + for aggregate_type, positions in record_data.get("all_isoforms_aggregate_positions", {}).items(): + if aggregate_type not in genome_all_isoforms_aggregate_positions: + genome_all_isoforms_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + genome_all_isoforms_aggregate_positions[aggregate_type].merge(positions) + + for aggregate_type, positions in record_data.get("chimaera_aggregate_positions", {}).items(): + if aggregate_type not in genome_chimaera_aggregate_positions: + genome_chimaera_aggregate_positions[aggregate_type] = AggregatePositions(strand=positions.strand) + genome_chimaera_aggregate_positions[aggregate_type].merge(positions) + # Clear record_data counters to free memory immediately after merging record_data["chimaera_aggregate_counters"].clear() record_data["longest_isoform_aggregate_counters"].clear() @@ -1197,6 +1293,11 @@ def main(): genome_chimaera_aggregate_counters: defaultdict[str, MultiCounter] = defaultdict( lambda: MultiCounter(genome_filter) ) + + # Positions for genome-level aggregates + genome_longest_isoform_aggregate_positions: dict[str, AggregatePositions] = {} + genome_all_isoforms_aggregate_positions: dict[str, AggregatePositions] = {} + genome_chimaera_aggregate_positions: dict[str, AggregatePositions] = {} # By-parent-type genome-level aggregates genome_longest_isoform_aggregate_counters_by_parent_type: defaultdict[tuple[str, 
str], MultiCounter] = defaultdict( @@ -1251,6 +1352,9 @@ def _record_id_from_pickle(path: str) -> str: genome_all_isoforms_aggregate_counters_by_parent_type, genome_chimaera_aggregate_counters_by_parent_type, genome_total_counter, + genome_longest_isoform_aggregate_positions, + genome_all_isoforms_aggregate_positions, + genome_chimaera_aggregate_positions, ) del record_data # Force flush to disk after each chromosome @@ -1287,6 +1391,9 @@ def _record_id_from_pickle(path: str) -> str: genome_all_isoforms_aggregate_counters_by_parent_type, genome_chimaera_aggregate_counters_by_parent_type, genome_total_counter, + genome_longest_isoform_aggregate_positions, + genome_all_isoforms_aggregate_positions, + genome_chimaera_aggregate_positions, ) del result feature_output_handle.flush() @@ -1298,24 +1405,30 @@ def _record_id_from_pickle(path: str) -> str: # Write genomic counts aggregate_writer.write_rows_with_data( - ".", ["."], ".", ".", "longest_isoform", genome_longest_isoform_aggregate_counters + ".", ["."], ".", ".", "longest_isoform", genome_longest_isoform_aggregate_counters, + genome_longest_isoform_aggregate_positions ) aggregate_writer.write_rows_with_data( - ".", ["."], ".", ".", "all_isoforms", genome_all_isoforms_aggregate_counters + ".", ["."], ".", ".", "all_isoforms", genome_all_isoforms_aggregate_counters, + genome_all_isoforms_aggregate_positions ) aggregate_writer.write_rows_with_data( - ".", ["."], ".", ".", "chimaera", genome_chimaera_aggregate_counters + ".", ["."], ".", ".", "chimaera", genome_chimaera_aggregate_counters, + genome_chimaera_aggregate_positions ) # Write genomic counts by parent type aggregate_writer.write_rows_with_data_by_parent_type( - ".", ["."], ".", "longest_isoform", genome_longest_isoform_aggregate_counters_by_parent_type + ".", ["."], ".", "longest_isoform", genome_longest_isoform_aggregate_counters_by_parent_type, + genome_longest_isoform_aggregate_positions ) aggregate_writer.write_rows_with_data_by_parent_type( - ".", ["."], 
".", "all_isoforms", genome_all_isoforms_aggregate_counters_by_parent_type + ".", ["."], ".", "all_isoforms", genome_all_isoforms_aggregate_counters_by_parent_type, + genome_all_isoforms_aggregate_positions ) aggregate_writer.write_rows_with_data_by_parent_type( - ".", ["."], ".", "chimaera", genome_chimaera_aggregate_counters_by_parent_type + ".", ["."], ".", "chimaera", genome_chimaera_aggregate_counters_by_parent_type, + genome_chimaera_aggregate_positions ) # Write the genomic total. A dummy dict needs to be created to use the `write_rows_with_data` method diff --git a/bin/pluviometer/multi_counter.py b/bin/pluviometer/multi_counter.py index d11f147..378ff37 100644 --- a/bin/pluviometer/multi_counter.py +++ b/bin/pluviometer/multi_counter.py @@ -18,6 +18,7 @@ def __init__(self, site_filter: SiteFilter) -> None: self.edit_site_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) self.genome_base_freqs: NDArray[np.int64] = np.zeros(5, dtype=np.int64) + self.filtered_sites_count: int = 0 # Number of sites that pass the coverage filter self.filter = site_filter @@ -32,6 +33,10 @@ def update(self, variant_data: RNASiteVariantData) -> None: self.filter.apply(variant_data) self.edit_site_freqs[i, :] += self.filter.frequencies + # Count sites that pass the coverage filter + if variant_data.coverage >= self.filter.cov_threshold: + self.filtered_sites_count += 1 + self.genome_base_freqs[i] += 1 return None @@ -43,6 +48,7 @@ def merge(self, other_counter: "MultiCounter") -> None: self.edit_read_freqs[:] += other_counter.edit_read_freqs self.edit_site_freqs[:] += other_counter.edit_site_freqs self.genome_base_freqs[:] += other_counter.genome_base_freqs + self.filtered_sites_count += other_counter.filtered_sites_count return None diff --git a/bin/pluviometer/rain_file_writers.py b/bin/pluviometer/rain_file_writers.py index 89686d4..de3b25a 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -3,7 +3,10 @@ from .multi_counter 
import MultiCounter from Bio.SeqFeature import SeqFeature from collections import defaultdict -from typing import TextIO +from typing import TextIO, Optional, TYPE_CHECKING + +if TYPE_CHECKING: + from __main__ import AggregatePositions def make_parent_path(parent_list: list[str]) -> str: @@ -102,7 +105,9 @@ class FeatureFileWriter(RainFileWriter): ] data_fields: list[str] = [ - "CoveredSites", + "TotalSites", + "ObservedSites", + "QualifiedSites", "GenomeBases", "SiteBasePairings", "ReadBasePairings" @@ -134,7 +139,9 @@ def write_row_with_data( ) -> int: """Write the data fields (coverage, read frequency, &c) of an output line, taken from the informations in a counter object""" return self.write_metadata(record_id, feature) + self.write_data( - str(counter.genome_base_freqs.sum()), + str(len(feature.location)), # TotalSites: total positions in the feature + str(counter.genome_base_freqs.sum()), # ObservedSites: sites with REDItools observations + str(counter.filtered_sites_count), # QualifiedSites: sites passing coverage filter # Subindexing is used below because counter matrices are 5x5 because they contain pairings with N, which we don't print # Use the flat attribute of the NumPy arrays to flatten the matrix row-wise to attain the desired base pairing order in the output ",".join(map(str, counter.genome_base_freqs[0:4].flat)), @@ -148,7 +155,7 @@ def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: This is faster than creating dummy counter objects with zero observations. 
""" return self.write_metadata(record_id, feature) + self.write_data( - "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + str(len(feature.location)), "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS ) @@ -166,7 +173,9 @@ class AggregateFileWriter(RainFileWriter): ] data_fields: list[str] = [ - "CoveredSites", + "TotalSites", + "ObservedSites", + "QualifiedSites", "GenomeBases", "SiteBasePairings", "ReadBasePairings", @@ -224,6 +233,7 @@ def write_rows_with_data( feature_type: str, aggregation_mode: str, counter_dict: defaultdict[str, MultiCounter], + positions_dict: Optional[dict[str, 'AggregatePositions']] = None, ) -> int: """Write metadata and data fields of multiple counters of the same aggregate feature""" b: int = 0 @@ -241,8 +251,17 @@ def write_rows_with_data( ".", ) + # Calculate TotalSites from positions if available + total_sites_str = "." + if positions_dict and aggregate_type in positions_dict: + pos = positions_dict[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + b += self.write_data( - str(aggregate_counter.genome_base_freqs.sum()), + total_sites_str, # TotalSites: calculated from aggregate positions + str(aggregate_counter.genome_base_freqs.sum()), # ObservedSites + str(aggregate_counter.filtered_sites_count), # QualifiedSites ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), @@ -276,7 +295,9 @@ def write_row_chimaera_with_data( strand=strand_str, ) b += self.write_data( - str(counter.genome_base_freqs.sum()), + str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites + str(counter.genome_base_freqs.sum()), # ObservedSites + str(counter.filtered_sites_count), # QualifiedSites ",".join(map(str, 
counter.genome_base_freqs[0:4].flat)), ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), @@ -309,7 +330,10 @@ def write_row_chimaera_without_data( end=end_str, strand=strand_str, ) - b += self.write_data("0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS) + b += self.write_data( + str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites + "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + ) return b @@ -325,6 +349,7 @@ def write_rows_with_data_by_parent_type( aggregate_id: str, aggregation_mode: str, counter_dict: defaultdict[tuple[str, str], MultiCounter], + positions_dict: Optional[dict[str, 'AggregatePositions']] = None, ) -> int: """Write metadata and data fields of multiple counters grouped by (parent_type, aggregate_type)""" b: int = 0 @@ -342,8 +367,17 @@ def write_rows_with_data_by_parent_type( ".", ) + # Calculate TotalSites from positions if available + total_sites_str = "." 
+ if positions_dict and aggregate_type in positions_dict: + pos = positions_dict[aggregate_type] + if pos.start != float('inf') and pos.end > 0: + total_sites_str = str(pos.end - pos.start) + b += self.write_data( - str(aggregate_counter.genome_base_freqs.sum()), + total_sites_str, # TotalSites: calculated from aggregate positions + str(aggregate_counter.genome_base_freqs.sum()), # ObservedSites + str(aggregate_counter.filtered_sites_count), # QualifiedSites ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), diff --git a/bin/pluviometer/site_filter.py b/bin/pluviometer/site_filter.py index b8e3ee8..745c3c8 100644 --- a/bin/pluviometer/site_filter.py +++ b/bin/pluviometer/site_filter.py @@ -1,6 +1,9 @@ from .utils import RNASiteVariantData import numpy as np from numpy.typing import NDArray +import logging + +logger = logging.getLogger(__name__) class SiteFilter: def __init__(self, cov_threshold: int, edit_threshold: int) -> None: @@ -9,13 +12,33 @@ def __init__(self, cov_threshold: int, edit_threshold: int) -> None: self.frequencies: NDArray[np.int32] = np.zeros(5, np.int32) def apply(self, variant_data: RNASiteVariantData) -> None: + logger.debug( + f"SiteFilter.apply() BEFORE - seqid={variant_data.seqid}, " + f"position={variant_data.position}, reference={variant_data.reference}, " + f"strand={variant_data.strand}, coverage={variant_data.coverage}, " + f"mean_quality={variant_data.mean_quality:.2f}, " + f"frequencies={variant_data.frequencies}, score={variant_data.score:.2f}, " + f"cov_threshold={self.cov_threshold}, edit_threshold={self.edit_threshold}, " + f"self.frequencies[before]={self.frequencies}, " + f"self.cov_threshold={self.cov_threshold}, self.edit_threshold={self.edit_threshold}" + ) + + # Have to pass the coverage threshold first if variant_data.coverage >= self.cov_threshold: np.copyto( 
self.frequencies, - variant_data.frequencies * variant_data.frequencies - >= self.edit_threshold, + variant_data.frequencies ) + self.frequencies[self.frequencies < self.edit_threshold] = 0 + else: + logger.debug("set to 0!") self.frequencies.fill(0) + + logger.debug( + f"SiteFilter.apply() AFTER - position={variant_data.position}, " + f"coverage_check={variant_data.coverage >= self.cov_threshold}, " + f"self.frequencies[after]={self.frequencies}" + ) return None From 68895d9acf09e7cc8d8b7844f6555e8ebc42986f Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 09:43:26 +0000 Subject: [PATCH 28/61] GenomesBases becomes ObservedBases and we add QualifiedBases --- bin/pluviometer/multi_counter.py | 5 ++++- bin/pluviometer/rain_file_writers.py | 19 ++++++++++++------- bin/pluviometer_wrapper.py | 2 +- 3 files changed, 17 insertions(+), 9 deletions(-) diff --git a/bin/pluviometer/multi_counter.py b/bin/pluviometer/multi_counter.py index 378ff37..aa1fb2b 100644 --- a/bin/pluviometer/multi_counter.py +++ b/bin/pluviometer/multi_counter.py @@ -18,6 +18,7 @@ def __init__(self, site_filter: SiteFilter) -> None: self.edit_site_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) self.genome_base_freqs: NDArray[np.int64] = np.zeros(5, dtype=np.int64) + self.filtered_base_freqs: NDArray[np.int64] = np.zeros(5, dtype=np.int64) # Bases for qualified sites self.filtered_sites_count: int = 0 # Number of sites that pass the coverage filter self.filter = site_filter @@ -33,9 +34,10 @@ def update(self, variant_data: RNASiteVariantData) -> None: self.filter.apply(variant_data) self.edit_site_freqs[i, :] += self.filter.frequencies - # Count sites that pass the coverage filter + # Count sites and bases that pass the coverage filter if variant_data.coverage >= self.filter.cov_threshold: self.filtered_sites_count += 1 + self.filtered_base_freqs[i] += 1 self.genome_base_freqs[i] += 1 @@ -48,6 +50,7 @@ def merge(self, other_counter: "MultiCounter") -> None: 
self.edit_read_freqs[:] += other_counter.edit_read_freqs self.edit_site_freqs[:] += other_counter.edit_site_freqs self.genome_base_freqs[:] += other_counter.genome_base_freqs + self.filtered_base_freqs[:] += other_counter.filtered_base_freqs self.filtered_sites_count += other_counter.filtered_sites_count return None diff --git a/bin/pluviometer/rain_file_writers.py b/bin/pluviometer/rain_file_writers.py index de3b25a..847e582 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -108,7 +108,8 @@ class FeatureFileWriter(RainFileWriter): "TotalSites", "ObservedSites", "QualifiedSites", - "GenomeBases", + "ObservedBases", + "QualifiedBases", "SiteBasePairings", "ReadBasePairings" ] @@ -144,7 +145,8 @@ def write_row_with_data( str(counter.filtered_sites_count), # QualifiedSites: sites passing coverage filter # Subindexing is used below because counter matrices are 5x5 because they contain pairings with N, which we don't print # Use the flat attribute of the NumPy arrays to flatten the matrix row-wise to attain the desired base pairing order in the output - ",".join(map(str, counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, counter.genome_base_freqs[0:4].flat)), # ObservedBases: base distribution for all observations + ",".join(map(str, counter.filtered_base_freqs[0:4].flat)), # QualifiedBases: base distribution for qualified sites ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), ) @@ -155,7 +157,7 @@ def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: This is faster than creating dummy counter objects with zero observations. 
""" return self.write_metadata(record_id, feature) + self.write_data( - str(len(feature.location)), "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + str(len(feature.location)), "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS ) @@ -176,7 +178,8 @@ class AggregateFileWriter(RainFileWriter): "TotalSites", "ObservedSites", "QualifiedSites", - "GenomeBases", + "ObservedBases", + "QualifiedBases", "SiteBasePairings", "ReadBasePairings", ] @@ -262,7 +265,8 @@ def write_rows_with_data( total_sites_str, # TotalSites: calculated from aggregate positions str(aggregate_counter.genome_base_freqs.sum()), # ObservedSites str(aggregate_counter.filtered_sites_count), # QualifiedSites - ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), # ObservedBases + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), # QualifiedBases ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), ) @@ -298,7 +302,8 @@ def write_row_chimaera_with_data( str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites str(counter.genome_base_freqs.sum()), # ObservedSites str(counter.filtered_sites_count), # QualifiedSites - ",".join(map(str, counter.genome_base_freqs[0:4].flat)), + ",".join(map(str, counter.genome_base_freqs[0:4].flat)), # ObservedBases + ",".join(map(str, counter.filtered_base_freqs[0:4].flat)), # QualifiedBases ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), ) @@ -332,7 +337,7 @@ def write_row_chimaera_without_data( ) b += self.write_data( str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites - "0", "0", self.STR_ZERO_BASE_FREQS, 
self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS ) return b diff --git a/bin/pluviometer_wrapper.py b/bin/pluviometer_wrapper.py index 3774710..5020175 100755 --- a/bin/pluviometer_wrapper.py +++ b/bin/pluviometer_wrapper.py @@ -1,4 +1,4 @@ -#!/usr/bin/env python +#!/usr/bin/env python3 """Wrapper to run pluviometer module""" import sys From 1eb0bd86c28c58eef8f0f3f71352776fe93144c9 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 09:43:35 +0000 Subject: [PATCH 29/61] add tests --- bin/pluviometer/test_site_counts.py | 408 ++++++++++++++++++++++++++++ 1 file changed, 408 insertions(+) create mode 100644 bin/pluviometer/test_site_counts.py diff --git a/bin/pluviometer/test_site_counts.py b/bin/pluviometer/test_site_counts.py new file mode 100644 index 0000000..7caeba1 --- /dev/null +++ b/bin/pluviometer/test_site_counts.py @@ -0,0 +1,408 @@ +#!/usr/bin/env python3 +""" +Tests pour les colonnes TotalSites, ObservedSites et QualifiedSites + +Pour exécuter: + cd /workspaces/rain/bin + PYTHONPATH=/workspaces/rain/bin python3 -m unittest pluviometer.test_site_counts -v +""" + +import unittest +import numpy as np +from Bio.SeqFeature import SeqFeature, SimpleLocation, CompoundLocation + +# Import du package pluviometer +from pluviometer.multi_counter import MultiCounter +from pluviometer.site_filter import SiteFilter +from pluviometer.utils import RNASiteVariantData +from pluviometer.__main__ import AggregatePositions + + +class TestAggregatePositions(unittest.TestCase): + """Tests pour la classe AggregatePositions""" + + def test_initial_state(self): + """Test l'état initial""" + pos = AggregatePositions(strand=1) + self.assertEqual(pos.start, float('inf')) + self.assertEqual(pos.end, 0) + self.assertEqual(pos.strand, 1) + + def test_update_from_simple_feature(self): + """Test mise à jour avec une SimpleLocation""" + pos = 
AggregatePositions(strand=1) + feature = SeqFeature(location=SimpleLocation(10, 20), id="test", type="exon") + pos.update_from_feature(feature) + + self.assertEqual(pos.start, 10) + self.assertEqual(pos.end, 20) + self.assertEqual(pos.strand, 1) + + def test_update_from_multiple_features(self): + """Test mise à jour avec plusieurs features""" + pos = AggregatePositions(strand=1) + + feature1 = SeqFeature(location=SimpleLocation(10, 20), id="exon1", type="exon") + feature2 = SeqFeature(location=SimpleLocation(5, 15), id="exon2", type="exon") + feature3 = SeqFeature(location=SimpleLocation(30, 40), id="exon3", type="exon") + + pos.update_from_feature(feature1) + pos.update_from_feature(feature2) + pos.update_from_feature(feature3) + + # Start doit être le minimum (5), end le maximum (40) + self.assertEqual(pos.start, 5) + self.assertEqual(pos.end, 40) + + def test_merge_positions(self): + """Test fusion de deux AggregatePositions""" + pos1 = AggregatePositions(strand=1) + pos1.start = 10 + pos1.end = 20 + + pos2 = AggregatePositions(strand=1) + pos2.start = 5 + pos2.end = 30 + + pos1.merge(pos2) + + self.assertEqual(pos1.start, 5) + self.assertEqual(pos1.end, 30) + + def test_to_strings_valid(self): + """Test conversion en strings avec positions valides""" + pos = AggregatePositions(strand=1) + pos.start = 10 + pos.end = 20 + + start_str, end_str, strand_str = pos.to_strings() + + # Start converti en 1-based + self.assertEqual(start_str, "11") + self.assertEqual(end_str, "20") + self.assertEqual(strand_str, "1") + + def test_to_strings_empty(self): + """Test conversion en strings avec positions non initialisées""" + pos = AggregatePositions(strand=0) + + start_str, end_str, strand_str = pos.to_strings() + + self.assertEqual(start_str, ".") + self.assertEqual(end_str, ".") + self.assertEqual(strand_str, ".") + + +class TestMultiCounterSiteCounts(unittest.TestCase): + """Tests pour les compteurs de sites dans MultiCounter""" + + def test_observed_sites_count(self): + 
"""Test que ObservedSites compte toutes les observations""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + counter = MultiCounter(filter) + + # Ajouter 3 sites avec différentes couvertures + site1 = RNASiteVariantData( + seqid="chr1", position=10, reference=0, # A + strand=1, coverage=10, mean_quality=30.0, + frequencies=np.array([8, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + site2 = RNASiteVariantData( + seqid="chr1", position=20, reference=1, # C + strand=1, coverage=3, mean_quality=30.0, + frequencies=np.array([1, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + site3 = RNASiteVariantData( + seqid="chr1", position=30, reference=2, # G + strand=1, coverage=8, mean_quality=30.0, + frequencies=np.array([0, 3, 5, 0, 0], dtype=np.int32), + score=0.0 + ) + + counter.update(site1) + counter.update(site2) + counter.update(site3) + + # ObservedSites = sum de genome_base_freqs = 3 sites + observed = counter.genome_base_freqs.sum() + self.assertEqual(observed, 3) + + # Vérifier la distribution par base + self.assertEqual(counter.genome_base_freqs[0], 1) # A + self.assertEqual(counter.genome_base_freqs[1], 1) # C + self.assertEqual(counter.genome_base_freqs[2], 1) # G + + def test_qualified_sites_count(self): + """Test que QualifiedSites compte seulement les sites passant le filtre de couverture""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + counter = MultiCounter(filter) + + # Site 1: couverture 10 >= 5 → qualifié (référence A) + site1 = RNASiteVariantData( + seqid="chr1", position=10, reference=0, # A + strand=1, coverage=10, mean_quality=30.0, + frequencies=np.array([8, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + # Site 2: couverture 3 < 5 → non qualifié (référence C) + site2 = RNASiteVariantData( + seqid="chr1", position=20, reference=1, # C + strand=1, coverage=3, mean_quality=30.0, + frequencies=np.array([1, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + # Site 3: couverture 8 >= 5 → qualifié (référence G) + site3 = 
RNASiteVariantData( + seqid="chr1", position=30, reference=2, # G + strand=1, coverage=8, mean_quality=30.0, + frequencies=np.array([0, 3, 5, 0, 0], dtype=np.int32), + score=0.0 + ) + + counter.update(site1) + counter.update(site2) + counter.update(site3) + + # QualifiedSites = 2 (site1 et site3) + self.assertEqual(counter.filtered_sites_count, 2) + + # ObservedSites = 3 (tous) + self.assertEqual(counter.genome_base_freqs.sum(), 3) + + # QualifiedBases: seulement A et G (site1 et site3) + self.assertEqual(counter.filtered_base_freqs[0], 1) # A + self.assertEqual(counter.filtered_base_freqs[1], 0) # C (non qualifié) + self.assertEqual(counter.filtered_base_freqs[2], 1) # G + self.assertEqual(counter.filtered_base_freqs[3], 0) # T + + def test_merge_preserves_counts(self): + """Test que merge préserve les compteurs de sites et de bases""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + + counter1 = MultiCounter(filter) + site1 = RNASiteVariantData( + seqid="chr1", position=10, reference=0, # A, coverage 10 >= 5 + strand=1, coverage=10, mean_quality=30.0, + frequencies=np.array([8, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + counter1.update(site1) + + counter2 = MultiCounter(filter) + site2 = RNASiteVariantData( + seqid="chr1", position=20, reference=1, # C, coverage 3 < 5 + strand=1, coverage=3, mean_quality=30.0, + frequencies=np.array([1, 2, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + site3 = RNASiteVariantData( + seqid="chr1", position=30, reference=0, # A, coverage 6 >= 5 + strand=1, coverage=6, mean_quality=30.0, + frequencies=np.array([5, 1, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + counter2.update(site2) + counter2.update(site3) + + # Merge counter2 into counter1 + counter1.merge(counter2) + + # Vérifier les totaux + self.assertEqual(counter1.genome_base_freqs.sum(), 3) # 1 + 2 = 3 sites observés + self.assertEqual(counter1.filtered_sites_count, 2) # 1 + 1 = 2 sites qualifiés (site1 et site3) + + # Vérifier les bases observées + 
self.assertEqual(counter1.genome_base_freqs[0], 2) # A: site1 + site3 + self.assertEqual(counter1.genome_base_freqs[1], 1) # C: site2 + + # Vérifier les bases qualifiées + self.assertEqual(counter1.filtered_base_freqs[0], 2) # A: site1 + site3 (tous deux qualifiés) + self.assertEqual(counter1.filtered_base_freqs[1], 0) # C: site2 non qualifié + self.assertEqual(counter1.filtered_base_freqs.sum(), 2) # Total = 2 + + def test_qualified_bases_distribution(self): + """Test que QualifiedBases reflète correctement la distribution des bases qualifiées""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + counter = MultiCounter(filter) + + # Ajouter plusieurs sites avec différentes bases et couvertures + sites = [ + (0, 10), # A, qualifié + (0, 6), # A, qualifié + (1, 3), # C, non qualifié + (1, 7), # C, qualifié + (2, 2), # G, non qualifié + (2, 5), # G, qualifié + (3, 8), # T, qualifié + ] + + for ref, cov in sites: + site = RNASiteVariantData( + seqid="chr1", + position=0, + reference=ref, + strand=1, + coverage=cov, + mean_quality=30.0, + frequencies=np.array([1, 0, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + counter.update(site) + + # ObservedBases: toutes les occurrences + self.assertEqual(counter.genome_base_freqs[0], 2) # A: 2 + self.assertEqual(counter.genome_base_freqs[1], 2) # C: 2 + self.assertEqual(counter.genome_base_freqs[2], 2) # G: 2 + self.assertEqual(counter.genome_base_freqs[3], 1) # T: 1 + self.assertEqual(counter.genome_base_freqs.sum(), 7) + + # QualifiedBases: seulement avec cov >= 5 + self.assertEqual(counter.filtered_base_freqs[0], 2) # A: 2 qualifiés (10, 6) + self.assertEqual(counter.filtered_base_freqs[1], 1) # C: 1 qualifié (7) + self.assertEqual(counter.filtered_base_freqs[2], 1) # G: 1 qualifié (5) + self.assertEqual(counter.filtered_base_freqs[3], 1) # T: 1 qualifié (8) + self.assertEqual(counter.filtered_base_freqs.sum(), 5) + self.assertEqual(counter.filtered_sites_count, 5) + + +class TestFeatureTotalSites(unittest.TestCase): + 
"""Tests pour le calcul de TotalSites des features""" + + def test_simple_feature_total_sites(self): + """Test TotalSites pour une feature simple""" + feature = SeqFeature( + location=SimpleLocation(0, 100), + id="exon1", + type="exon" + ) + + # TotalSites = len(location) + total_sites = len(feature.location) + self.assertEqual(total_sites, 100) + + def test_compound_feature_total_sites(self): + """Test TotalSites pour une feature avec CompoundLocation (chimaera)""" + # Deux exons: [0-50] et [100-150] + location = CompoundLocation([ + SimpleLocation(0, 50), + SimpleLocation(100, 150) + ]) + feature = SeqFeature( + location=location, + id="chimaera", + type="exon" + ) + + # TotalSites = somme des longueurs sans gaps + total_sites = len(feature.location) + self.assertEqual(total_sites, 100) # 50 + 50 + + def test_gene_total_sites(self): + """Test TotalSites pour un gene (span complet incluant introns)""" + # Gene de 1 à 1000 + feature = SeqFeature( + location=SimpleLocation(0, 1000), + id="gene1", + type="gene" + ) + + # TotalSites = span complet + total_sites = len(feature.location) + self.assertEqual(total_sites, 1000) + + +class TestAggregateTotalSites(unittest.TestCase): + """Tests pour le calcul de TotalSites des aggregates""" + + def test_aggregate_from_continuous_features(self): + """Test TotalSites pour aggregate de features contiguës""" + pos = AggregatePositions(strand=1) + + # Trois exons adjacents + exon1 = SeqFeature(location=SimpleLocation(0, 100), id="exon1", type="exon") + exon2 = SeqFeature(location=SimpleLocation(100, 200), id="exon2", type="exon") + exon3 = SeqFeature(location=SimpleLocation(200, 300), id="exon3", type="exon") + + pos.update_from_feature(exon1) + pos.update_from_feature(exon2) + pos.update_from_feature(exon3) + + # TotalSites = end - start = 300 - 0 = 300 + total_sites = pos.end - pos.start + self.assertEqual(total_sites, 300) + + def test_aggregate_from_separated_features(self): + """Test TotalSites pour aggregate de features 
séparées""" + pos = AggregatePositions(strand=1) + + # Deux exons avec un gap + exon1 = SeqFeature(location=SimpleLocation(0, 100), id="exon1", type="exon") + exon2 = SeqFeature(location=SimpleLocation(200, 300), id="exon2", type="exon") + + pos.update_from_feature(exon1) + pos.update_from_feature(exon2) + + # TotalSites = end - start = 300 - 0 = 300 (inclut le gap) + total_sites = pos.end - pos.start + self.assertEqual(total_sites, 300) + + def test_aggregate_empty(self): + """Test TotalSites pour aggregate vide""" + pos = AggregatePositions(strand=0) + + # Aucune feature ajoutée + start_str, end_str, strand_str = pos.to_strings() + + # Devrait retourner "." + self.assertEqual(start_str, ".") + self.assertEqual(end_str, ".") + + +class TestIntegrationSiteCounts(unittest.TestCase): + """Tests d'intégration pour les trois colonnes""" + + def test_all_three_columns(self): + """Test que les trois colonnes sont cohérentes""" + filter = SiteFilter(cov_threshold=5, edit_threshold=2) + counter = MultiCounter(filter) + + # 5 sites avec différentes couvertures + coverages = [10, 3, 8, 4, 6] + for i, cov in enumerate(coverages): + site = RNASiteVariantData( + seqid="chr1", + position=i * 10, + reference=0, + strand=1, + coverage=cov, + mean_quality=30.0, + frequencies=np.array([cov-1, 1, 0, 0, 0], dtype=np.int32), + score=0.0 + ) + counter.update(site) + + # Feature de 0 à 50 (50 bases) + feature = SeqFeature( + location=SimpleLocation(0, 50), + id="exon1", + type="exon" + ) + + total_sites = len(feature.location) + observed_sites = counter.genome_base_freqs.sum() + qualified_sites = counter.filtered_sites_count + + # Vérifications + self.assertEqual(total_sites, 50) # Taille de la feature + self.assertEqual(observed_sites, 5) # 5 observations + self.assertEqual(qualified_sites, 3) # 3 avec cov >= 5 (10, 8, 6) + + # Relations logiques + self.assertLessEqual(observed_sites, total_sites) + self.assertLessEqual(qualified_sites, observed_sites) + + +if __name__ == '__main__': + 
unittest.main() From 7e080588c7b12594fcf518f02b20a70b44a36dbd Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 10:35:43 +0000 Subject: [PATCH 30/61] add SiteBasePairingsQualified ReadBasePairingsQualified --- bin/pluviometer/multi_counter.py | 25 ++++++++++++--- bin/pluviometer/rain_file_writers.py | 47 ++++++++++++++++++---------- bin/pluviometer/test_site_counts.py | 26 +++++++++++++++ 3 files changed, 77 insertions(+), 21 deletions(-) diff --git a/bin/pluviometer/multi_counter.py b/bin/pluviometer/multi_counter.py index aa1fb2b..1e88c0c 100644 --- a/bin/pluviometer/multi_counter.py +++ b/bin/pluviometer/multi_counter.py @@ -16,6 +16,9 @@ def __init__(self, site_filter: SiteFilter) -> None: """ self.edit_read_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) self.edit_site_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) + self.edit_filtered_read_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) # Reads filtered by cov + edit thresholds + self.edit_qualified_site_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) # Non-ref sites passing cov + edit thresholds + self.edit_qualified_read_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) # Non-ref reads passing cov + edit thresholds self.genome_base_freqs: NDArray[np.int64] = np.zeros(5, dtype=np.int64) self.filtered_base_freqs: NDArray[np.int64] = np.zeros(5, dtype=np.int64) # Bases for qualified sites @@ -32,10 +35,21 @@ def update(self, variant_data: RNASiteVariantData) -> None: self.edit_read_freqs[i, :] += variant_data.frequencies self.filter.apply(variant_data) - self.edit_site_freqs[i, :] += self.filter.frequencies - - # Count sites and bases that pass the coverage filter - if variant_data.coverage >= self.filter.cov_threshold: + # edit_site_freqs counts the number of SITES where each pairing is found (not reads): + # add 1 for each pairing that passes the filter at this site + self.edit_site_freqs[i, :] += (self.filter.frequencies > 
0).astype(np.int64) + # edit_filtered_read_freqs counts reads filtered by both cov_threshold and edit_threshold + self.edit_filtered_read_freqs[i, :] += self.filter.frequencies + + # Qualified: non-reference pairings only (mask out self-pairing i→i) + non_ref_filtered = self.filter.frequencies.copy() + non_ref_filtered[i] = 0 + self.edit_qualified_site_freqs[i, :] += (non_ref_filtered > 0).astype(np.int64) + self.edit_qualified_read_freqs[i, :] += non_ref_filtered + + # Count sites and bases that are qualified: at least one NON-REFERENCE base + # must pass both cov_threshold and edit_threshold + if non_ref_filtered.any(): self.filtered_sites_count += 1 self.filtered_base_freqs[i] += 1 @@ -49,6 +63,9 @@ def merge(self, other_counter: "MultiCounter") -> None: """ self.edit_read_freqs[:] += other_counter.edit_read_freqs self.edit_site_freqs[:] += other_counter.edit_site_freqs + self.edit_filtered_read_freqs[:] += other_counter.edit_filtered_read_freqs + self.edit_qualified_site_freqs[:] += other_counter.edit_qualified_site_freqs + self.edit_qualified_read_freqs[:] += other_counter.edit_qualified_read_freqs self.genome_base_freqs[:] += other_counter.genome_base_freqs self.filtered_base_freqs[:] += other_counter.filtered_base_freqs self.filtered_sites_count += other_counter.filtered_sites_count diff --git a/bin/pluviometer/rain_file_writers.py b/bin/pluviometer/rain_file_writers.py index 847e582..2e910cb 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -86,7 +86,7 @@ def write_counter_data(self, counter: MultiCounter) -> int: b += self.handle.write('\t') b += self.handle.write(",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat))) b += self.handle.write('\t') - b += self.handle.write(",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat))) + b += self.handle.write(",".join(map(str, counter.edit_filtered_read_freqs[0:4, 0:4].flat))) b += self.handle.write('\n') return b @@ -110,8 +110,10 @@ class 
FeatureFileWriter(RainFileWriter): "QualifiedSites", "ObservedBases", "QualifiedBases", - "SiteBasePairings", - "ReadBasePairings" + "SiteBasePairingsObserved", + "ReadBasePairingsObserved", + "SiteBasePairingsQualified", + "ReadBasePairingsQualified" ] def __init__(self, handle: TextIO): @@ -147,8 +149,10 @@ def write_row_with_data( # Use the flat attribute of the NumPy arrays to flatten the matrix row-wise to attain the desired base pairing order in the output ",".join(map(str, counter.genome_base_freqs[0:4].flat)), # ObservedBases: base distribution for all observations ",".join(map(str, counter.filtered_base_freqs[0:4].flat)), # QualifiedBases: base distribution for qualified sites - ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), + ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsObserved + ",".join(map(str, counter.edit_filtered_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsObserved + ",".join(map(str, counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified + ",".join(map(str, counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: @@ -157,7 +161,7 @@ def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: This is faster than creating dummy counter objects with zero observations. 
""" return self.write_metadata(record_id, feature) + self.write_data( - str(len(feature.location)), "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + str(len(feature.location)), "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS ) @@ -180,8 +184,10 @@ class AggregateFileWriter(RainFileWriter): "QualifiedSites", "ObservedBases", "QualifiedBases", - "SiteBasePairings", - "ReadBasePairings", + "SiteBasePairingsObserved", + "ReadBasePairingsObserved", + "SiteBasePairingsQualified", + "ReadBasePairingsQualified", ] def __init__(self, handle: TextIO): @@ -267,8 +273,10 @@ def write_rows_with_data( str(aggregate_counter.filtered_sites_count), # QualifiedSites ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), # ObservedBases ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), # QualifiedBases - ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsObserved + ",".join(map(str, aggregate_counter.edit_filtered_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsObserved + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) return b @@ -304,8 +312,10 @@ def write_row_chimaera_with_data( str(counter.filtered_sites_count), # QualifiedSites ",".join(map(str, counter.genome_base_freqs[0:4].flat)), # ObservedBases ",".join(map(str, counter.filtered_base_freqs[0:4].flat)), # QualifiedBases - ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, counter.edit_read_freqs[0:4, 0:4].flat)), + 
",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsObserved + ",".join(map(str, counter.edit_filtered_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsObserved + ",".join(map(str, counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified + ",".join(map(str, counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) return b @@ -337,12 +347,12 @@ def write_row_chimaera_without_data( ) b += self.write_data( str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites - "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS ) return b - # Case like that we will add empty start end strand + # Case like that will add empty start end strand # 21 . . gene exon longest_isoform # 21 . . pseudogene exon longest_isoform # . . . 
pseudogene exon longest_isoform @@ -383,9 +393,12 @@ def write_rows_with_data_by_parent_type( total_sites_str, # TotalSites: calculated from aggregate positions str(aggregate_counter.genome_base_freqs.sum()), # ObservedSites str(aggregate_counter.filtered_sites_count), # QualifiedSites - ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), - ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), # ObservedBases + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), # QualifiedBases + ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsObserved + ",".join(map(str, aggregate_counter.edit_filtered_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsObserved + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) return b diff --git a/bin/pluviometer/test_site_counts.py b/bin/pluviometer/test_site_counts.py index 7caeba1..215426f 100644 --- a/bin/pluviometer/test_site_counts.py +++ b/bin/pluviometer/test_site_counts.py @@ -96,6 +96,32 @@ def test_to_strings_empty(self): class TestMultiCounterSiteCounts(unittest.TestCase): """Tests pour les compteurs de sites dans MultiCounter""" + def test_site_base_pairings_counts_sites_not_reads(self): + """Test que SiteBasePairingsObserved compte les SITES (pas les reads) avec chaque pairing""" + # Reproduit le cas du test minimal: + # pos1: A ref, cov=10, [10,0,0,0] -> AA pairing passes (10>=5) + # pos2: C ref, cov=12, [0,9,0,3] -> CC passes (9>=5), CT filtered (3<5) + # pos4: T ref, cov=2 -> tout filtré (cov<5) + filter = SiteFilter(cov_threshold=5, edit_threshold=5) + counter = MultiCounter(filter) + + counter.update(RNASiteVariantData("21", 0, 0, 1, 10, 
37.0, np.array([10,0,0,0,0], dtype=np.int32), 0.0)) # A site + counter.update(RNASiteVariantData("21", 1, 1, 1, 12, 37.0, np.array([0,9,0,3,0], dtype=np.int32), 0.0)) # C site + counter.update(RNASiteVariantData("21", 3, 3, 1, 2, 37.0, np.array([0,0,0,2,0], dtype=np.int32), 0.0)) # T site, cov<5 + + # SiteBasePairingsObserved: 1 site AA, 1 site CC, 0 CT (filtré), 0 TT (cov<5) + self.assertEqual(counter.edit_site_freqs[0, 0], 1) # AA: 1 site + self.assertEqual(counter.edit_site_freqs[0, 1], 0) # AC: 0 + self.assertEqual(counter.edit_site_freqs[1, 1], 1) # CC: 1 site + self.assertEqual(counter.edit_site_freqs[1, 3], 0) # CT: 0 (3 reads < edit_threshold) + self.assertEqual(counter.edit_site_freqs[3, 3], 0) # TT: 0 (cov < cov_threshold) + + # ReadBasePairings: compte les reads bruts (non filtrés) + self.assertEqual(counter.edit_read_freqs[0, 0], 10) # AA: 10 reads + self.assertEqual(counter.edit_read_freqs[1, 1], 9) # CC: 9 reads + self.assertEqual(counter.edit_read_freqs[1, 3], 3) # CT: 3 reads (non filtrés dans ReadBasePairings) + self.assertEqual(counter.edit_read_freqs[3, 3], 2) # TT: 2 reads + def test_observed_sites_count(self): """Test que ObservedSites compte toutes les observations""" filter = SiteFilter(cov_threshold=5, edit_threshold=2) From e4d3805ad07b52999ec353e985840cbb4421b399 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 16:53:06 +0000 Subject: [PATCH 31/61] skip standardize step and do it directly within pluviometer --- README.md | 20 +-- bin/drip.py | 20 +-- bin/pluviometer/__main__.py | 40 +++--- bin/pluviometer/multi_counter.py | 27 ++-- bin/pluviometer/rain_file_writers.py | 105 ++++++--------- bin/pluviometer/test_site_counts.py | 27 ++-- modules/bash.nf | 183 --------------------------- modules/pluviometer.nf | 2 +- rain.nf | 10 +- 9 files changed, 113 insertions(+), 321 deletions(-) diff --git a/README.md b/README.md index 4c1d585..ab13173 100644 --- a/README.md +++ b/README.md @@ -221,13 +221,14 @@ The two output formats 
are tables of comma-separated values with a header.
 | Start | Positive integer | Starting position of the feature (inclusive) |
 | End | Positive integer | Ending position of the feature (inclusive) |
 | Strand | `1` or `-1` | Whether the features is located on the positive (5'->3') or negative (3'->5') strand |
-| CoveredSites | Positive integer | Number of sites in the feature that satisfy the minimum level of coverage |
-| GenomeBases | Comma-separated positive integers | Frequencies of the bases in the feature in the reference genome (order: A, C, G, T) |
-| SiteBasePairings | Comma-separated positive integers | Number of sites in which each genome-variant base pairings is found in the feature (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) |
-| ReadBasePairings | Comma-separated positive integers | Frequencies of genome-variant base pairings in the feature (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) |
+| TotalSites | Positive integer | Number of sites in the feature |
+| ObservedBases | Comma-separated positive integers | Number and type of the bases in the feature in the reference genome (order: A, C, G, T) observed. The total of the 4 values corresponds to the total observed sites (reported by the editing tools e.g. Reditools3) |
+| QualifiedBases | Comma-separated positive integers | Number and type of the bases in the feature in the reference genome (order: A, C, G, T) that satisfy the minimum level of coverage and editing.
The total of the 4 values corresponds to the total qualified sites (> cov) |
+| SiteBasePairingsQualified | Comma-separated positive integers | Number of sites in which each genome-variant base pairings is found at reference level in the feature (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) that satisfy the minimum level of coverage and editing |
+| ReadBasePairingsQualified | Comma-separated positive integers | Number of sites in which each genome-variant base pairings is found at reads level in the feature (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) that satisfy the minimum level of coverage and editing |
 
 > [!note]
-> The number of **CoveredSites** can be higher than the sum of **SiteBasePairings** because of the presence of ambiguous bases (e.g. N)
+> The number of **QualifiedBases** can differ from the sum of AA,CC,GG,TT from **SiteBasePairingsQualified** because we can have a site 100% edited that will not fall into one of these categories.
 
 An example of the feature output format is shown below, with some alterations to make the text line up in columns.
@@ -275,10 +276,11 @@ This hierarchical information is provided in the same manner in the aggregate fi
 | ParentType | String | Type of the parent of the feature under which the aggregation was done |
 | AggregateType | String | Type of the features that are aggregated |
 | AggregationMode | `all_isoforms`, `longest_isoform`, `chimaera`, `feature` or `all-sites` | Way in which the aggregation was performed |
-| CoveredSites | Positive integer | Number of sites in the aggregated features that satisfy the minimum level of coverage |
-| GenomeBases | Comma-separated positive integers | Frequencies of the bases in the aggregated features in the reference genome (order: A, C, G, T) |
-| SiteBasePairings | Comma-separated positive integers | Number of sites in which each genome-variant base pairings is found in the aggregated features (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) |
-| ReadBasePairings | Comma-separated positive integers | Frequencies of genome-variant base pairings in the aggregated features (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) |
+| TotalSites | Positive integer | Number of sites in the aggregated features |
+| ObservedBases | Comma-separated positive integers | Number and type of the bases in the aggregated features in the reference genome (order: A, C, G, T) observed. The total of the 4 values corresponds to the total observed sites (reported by the editing tools e.g. Reditools3) |
+| QualifiedBases | Comma-separated positive integers | Number and type of the bases in the aggregated features in the reference genome (order: A, C, G, T) that satisfy the minimum level of coverage and editing.
The total of the 4 values corresponds to the total qualified sites (> cov) |
+| SiteBasePairingsQualified | Comma-separated positive integers | Number of sites in which each genome-variant base pairings is found at reference level in the aggregated features (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) observed |
+| ReadBasePairingsQualified | Comma-separated positive integers | Number of sites in which each genome-variant base pairings is found at reads level in the aggregated features (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) that satisfy the minimum level of coverage and editing |
 
 In the output of Pluviometer, **aggregation** is the sum of counts from several features of the same type at some feature level. For instance, exons can be aggregated at transcript level, gene level, chromosome level, and genome level.
 
diff --git a/bin/drip.py b/bin/drip.py
index d55b3a2..0b71aba 100755
--- a/bin/drip.py
+++ b/bin/drip.py
@@ -60,10 +60,10 @@ def print_help():
 
 INPUT FILE FORMAT:
     The input files must be TSV files with the following columns:
-    - GenomeBases: Frequencies of bases in the reference genome (order: A, C, G, T)
-    - SiteBasePairings: Number of sites with each genome-variant base pairing
+    - ObservedBases: Frequencies of bases in the reference genome (order: A, C, G, T)
+    - SiteBasePairingsQualified: Number of sites with each genome-variant base pairing (qualified, filtered by cov + edit thresholds)
       (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT)
-    - ReadBasePairings: Frequencies of genome-variant base pairings in reads
+    - ReadBasePairingsQualified: Frequencies of genome-variant base pairings in reads (filtered by cov + edit thresholds)
       (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT)
 
 CALCULATED METRICS:
@@ -72,7 +72,7 @@ def print_help():
     For each combination XY (where X = genome base, Y = read base):
 
     1.
XY_espf (edited_sites_proportion_feature) - Proportion of XY sites in the DNA feature: - Formula: XY_SiteBasePairings / X_GenomeBases + Formula: XY_SiteBasePairingsQualified / X_ObservedBases This represents the proportion of genomic X positions that show X-to-Y variation in the feature. 2. XY_espr (edited_sites_proportion_reads) - Proportion of XY pairing in reads: @@ -161,17 +161,17 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] bases = ['A', 'C', 'G', 'T'] - # Parse GenomeBases (order: A, C, G, T) + # Parse ObservedBases (order: A, C, G, T) for i, base in enumerate(bases): - df[f'{base}_count'] = df['GenomeBases'].str.split(',').str[i].astype(int) + df[f'{base}_count'] = df['ObservedBases'].str.split(',').str[i].astype(int) - # Parse SiteBasePairings (all 16 combinations) + # Parse SiteBasePairingsQualified (all 16 combinations) for i, bp in enumerate(base_pairs): - df[f'{bp}_sites'] = df['SiteBasePairings'].str.split(',').str[i].astype(int) + df[f'{bp}_sites'] = df['SiteBasePairingsQualified'].str.split(',').str[i].astype(int) - # Parse ReadBasePairings (all 16 combinations) + # Parse ReadBasePairingsQualified (all 16 combinations) for i, bp in enumerate(base_pairs): - df[f'{bp}_reads'] = df['ReadBasePairings'].str.split(',').str[i].astype(int) + df[f'{bp}_reads'] = df['ReadBasePairingsQualified'].str.split(',').str[i].astype(int) # Calculate metrics for each base pair combination metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 60d0bff..c6cbc34 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -96,6 +96,7 @@ def __init__( aggregate_writer: AggregateFileWriter, filter: SiteFilter, use_progress_bar: bool, + report_non_qualified: bool = False, ): self.aggregate_writer: AggregateFileWriter = aggregate_writer 
"""Aggregate counting output is written to a temporary file through this object""" @@ -103,6 +104,9 @@ def __init__( self.feature_writer: FeatureFileWriter = feature_writer """Feature counting output is written to a temporary file through this object""" + self.report_non_qualified: bool = report_non_qualified + """If True, write feature rows even when no qualified sites are found""""" + self.filter: SiteFilter = filter """Filter applied to the reads. Filters contain a mutable state, so they should not be shared between concurrent RecordCountingContext instances""" @@ -419,7 +423,8 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> self.chimaera_aggregate_positions[feature.type].update_from_feature(feature) else: logging.debug(f"Writing feature with data: {feature.id}, type: {feature.type}") - self.feature_writer.write_row_with_data(self.record.id, feature, feature_counter) + if self.report_non_qualified or feature_counter.filtered_base_freqs.sum() > 0: + self.feature_writer.write_row_with_data(self.record.id, feature, feature_counter) # Explicitly delete counter after use to free memory del feature_counter else: @@ -431,7 +436,8 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> ) else: logging.debug(f"Writing feature without data: {feature.id}, type: {feature.type}") - self.feature_writer.write_row_without_data(self.record.id, feature) + if self.report_non_qualified: + self.feature_writer.write_row_without_data(self.record.id, feature) # all_isoforms_aggregation_counters: Optional[defaultdict[str, MultiCounter]] = None @@ -496,11 +502,10 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> total_sites_str = str(pos.end - pos.start) self.aggregate_writer.write_data( total_sites_str, - str(aggregate_counter.genome_base_freqs.sum()), - str(aggregate_counter.filtered_sites_count), ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), - ",".join(map(str, 
aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), ) for ( @@ -531,11 +536,10 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> total_sites_str = str(pos.end - pos.start) self.aggregate_writer.write_data( total_sites_str, - str(aggregate_counter.genome_base_freqs.sum()), - str(aggregate_counter.filtered_sites_count), ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), - ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), ) else: feature_aggregation_counters, feature_aggregation_positions = self.aggregate_children(feature) @@ -564,11 +568,10 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> total_sites_str = str(pos.end - pos.start) self.aggregate_writer.write_data( total_sites_str, - str(aggregate_counter.genome_base_freqs.sum()), - str(aggregate_counter.filtered_sites_count), ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), - ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), - ",".join(map(str, aggregate_counter.edit_read_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), + ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), ) # Recursively check-out children @@ -869,6 
+872,12 @@ def parse_cli_input() -> argparse.Namespace: choices=["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"], help="Logging level (default: INFO). Use DEBUG to see detailed site filtering information.", ) + parser.add_argument( + "--report-non-qualified-features", + action="store_true", + default=False, + help="Report features with no qualified sites (default: skip them).", + ) parser.add_argument( "--log-to-console", action="store_true", @@ -993,7 +1002,8 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: filter: SiteFilter = SiteFilter(cov_threshold=args.cov, edit_threshold=args.edit_threshold) record_ctx: RecordCountingContext = RecordCountingContext( - feature_writer, aggregate_writer, filter, args.progress + feature_writer, aggregate_writer, filter, args.progress, + report_non_qualified=args.report_non_qualified_features, ) # Count diff --git a/bin/pluviometer/multi_counter.py b/bin/pluviometer/multi_counter.py index 1e88c0c..5f5eef1 100644 --- a/bin/pluviometer/multi_counter.py +++ b/bin/pluviometer/multi_counter.py @@ -15,8 +15,6 @@ def __init__(self, site_filter: SiteFilter) -> None: Row-columns corresponding to the same base (e.g. 
(0,0) -> (A,A)) represent reads where the base is unchanged """ self.edit_read_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) - self.edit_site_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) - self.edit_filtered_read_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) # Reads filtered by cov + edit thresholds self.edit_qualified_site_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) # Non-ref sites passing cov + edit thresholds self.edit_qualified_read_freqs: NDArray[np.int64] = np.zeros((5, 5), dtype=np.int64) # Non-ref reads passing cov + edit thresholds @@ -35,23 +33,16 @@ def update(self, variant_data: RNASiteVariantData) -> None: self.edit_read_freqs[i, :] += variant_data.frequencies self.filter.apply(variant_data) - # edit_site_freqs counts the number of SITES where each pairing is found (not reads): - # add 1 for each pairing that passes the filter at this site - self.edit_site_freqs[i, :] += (self.filter.frequencies > 0).astype(np.int64) - # edit_filtered_read_freqs counts reads filtered by both cov_threshold and edit_threshold - self.edit_filtered_read_freqs[i, :] += self.filter.frequencies - - # Qualified: non-reference pairings only (mask out self-pairing i→i) - non_ref_filtered = self.filter.frequencies.copy() - non_ref_filtered[i] = 0 - self.edit_qualified_site_freqs[i, :] += (non_ref_filtered > 0).astype(np.int64) - self.edit_qualified_read_freqs[i, :] += non_ref_filtered - - # Count sites and bases that are qualified: at least one NON-REFERENCE base - # must pass both cov_threshold and edit_threshold - if non_ref_filtered.any(): + # The edit_threshold is not applied to the self-pairing (X→X): coverage alone is sufficient. + if variant_data.coverage >= self.filter.cov_threshold: + self.filter.frequencies[i] = variant_data.frequencies[i] + + # A site qualifies if it passes the coverage threshold. + # Self-pairings are included unconditionally; non-ref pairings only if they also pass edit_threshold. 
self.filtered_sites_count += 1 self.filtered_base_freqs[i] += 1 + self.edit_qualified_site_freqs[i, :] += (self.filter.frequencies > 0).astype(np.int64) + self.edit_qualified_read_freqs[i, :] += self.filter.frequencies self.genome_base_freqs[i] += 1 @@ -62,8 +53,6 @@ def merge(self, other_counter: "MultiCounter") -> None: Add to this counter the values of another. """ self.edit_read_freqs[:] += other_counter.edit_read_freqs - self.edit_site_freqs[:] += other_counter.edit_site_freqs - self.edit_filtered_read_freqs[:] += other_counter.edit_filtered_read_freqs self.edit_qualified_site_freqs[:] += other_counter.edit_qualified_site_freqs self.edit_qualified_read_freqs[:] += other_counter.edit_qualified_read_freqs self.genome_base_freqs[:] += other_counter.genome_base_freqs diff --git a/bin/pluviometer/rain_file_writers.py b/bin/pluviometer/rain_file_writers.py index 2e910cb..19d5f0e 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -83,10 +83,6 @@ def write_counter_data(self, counter: MultiCounter) -> int: b: int = self.handle.write(str(counter.genome_base_freqs.sum())) b += self.handle.write('\t') b += self.handle.write(",".join(map(str, counter.genome_base_freqs[0:4].flat))) - b += self.handle.write('\t') - b += self.handle.write(",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat))) - b += self.handle.write('\t') - b += self.handle.write(",".join(map(str, counter.edit_filtered_read_freqs[0:4, 0:4].flat))) b += self.handle.write('\n') return b @@ -97,8 +93,12 @@ class FeatureFileWriter(RainFileWriter): metadata_fields: list[str] = [ "SeqID", "ParentIDs", - "FeatureID", + "ID", + "Mtype", + "Ptype", "Type", + "Ctype", + "Mode", "Start", "End", "Strand", @@ -106,12 +106,8 @@ class FeatureFileWriter(RainFileWriter): data_fields: list[str] = [ "TotalSites", - "ObservedSites", - "QualifiedSites", "ObservedBases", "QualifiedBases", - "SiteBasePairingsObserved", - "ReadBasePairingsObserved", "SiteBasePairingsQualified", 
"ReadBasePairingsQualified" ] @@ -131,7 +127,11 @@ def write_metadata(self, record_id: str, feature: SeqFeature) -> int: record_id, make_parent_path(feature.parent_list), feature.id, + "feature", + ".", feature.type, + ".", + ".", str(feature.location.parts[0].start + ExactPosition(1)), str(feature.location.parts[-1].end), str(feature.location.strand), @@ -143,14 +143,10 @@ def write_row_with_data( """Write the data fields (coverage, read frequency, &c) of an output line, taken from the informations in a counter object""" return self.write_metadata(record_id, feature) + self.write_data( str(len(feature.location)), # TotalSites: total positions in the feature - str(counter.genome_base_freqs.sum()), # ObservedSites: sites with REDItools observations - str(counter.filtered_sites_count), # QualifiedSites: sites passing coverage filter # Subindexing is used below because counter matrices are 5x5 because they contain pairings with N, which we don't print # Use the flat attribute of the NumPy arrays to flatten the matrix row-wise to attain the desired base pairing order in the output ",".join(map(str, counter.genome_base_freqs[0:4].flat)), # ObservedBases: base distribution for all observations ",".join(map(str, counter.filtered_base_freqs[0:4].flat)), # QualifiedBases: base distribution for qualified sites - ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsObserved - ",".join(map(str, counter.edit_filtered_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsObserved ",".join(map(str, counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified ",".join(map(str, counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) @@ -161,7 +157,7 @@ def write_row_without_data(self, record_id: str, feature: SeqFeature) -> int: This is faster than creating dummy counter objects with zero observations. 
""" return self.write_metadata(record_id, feature) + self.write_data( - str(len(feature.location)), "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + str(len(feature.location)), self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS ) @@ -169,10 +165,12 @@ class AggregateFileWriter(RainFileWriter): metadata_fields: list[str] = [ "SeqID", "ParentIDs", - "AggregateID", - "ParentType", - "AggregateType", - "AggregationMode", + "ID", + "Mtype", + "Ptype", + "Type", + "Ctype", + "Mode", "Start", "End", "Strand", @@ -180,12 +178,8 @@ class AggregateFileWriter(RainFileWriter): data_fields: list[str] = [ "TotalSites", - "ObservedSites", - "QualifiedSites", "ObservedBases", "QualifiedBases", - "SiteBasePairingsObserved", - "ReadBasePairingsObserved", "SiteBasePairingsQualified", "ReadBasePairingsQualified", ] @@ -195,6 +189,16 @@ def __init__(self, handle: TextIO): return None + @staticmethod + def _agg_type(seq_id: str, aggregate_id: str) -> str: + """Compute the Type field for an aggregate row.""" + if aggregate_id not in (".", ""): + return "feature_agg" + elif seq_id != ".": + return "chr_agg" + else: + return "global_agg" + def write_metadata( self, seq_id: str, @@ -208,26 +212,19 @@ def write_metadata( strand: str = "." 
) -> int: """Write metadata fields of an aggregate""" - b: int = self.handle.write(seq_id) - b += self.handle.write('\t') - b += self.handle.write(parent_ids) - b += self.handle.write('\t') - b += self.handle.write(aggregate_id) - b += self.handle.write('\t') - b += self.handle.write(parent_type) - b += self.handle.write('\t') - b += self.handle.write(aggregate_type) - b += self.handle.write('\t') - b += self.handle.write(aggregation_mode) - b += self.handle.write('\t') - b += self.handle.write(start) - b += self.handle.write('\t') - b += self.handle.write(end) - b += self.handle.write('\t') - b += self.handle.write(strand) - b += self.handle.write('\t') - - return b + return super().write_metadata( + seq_id, + parent_ids, + aggregate_id, + "aggregate", + parent_type, + self._agg_type(seq_id, aggregate_id), + aggregate_type, + aggregation_mode, + start, + end, + strand, + ) # Case like that we will add an empty start end strand # 21 . . . exon longest_isoform @@ -248,16 +245,13 @@ def write_rows_with_data( b: int = 0 for aggregate_type, aggregate_counter in counter_dict.items(): - b += super().write_metadata( + b += self.write_metadata( record_id, make_parent_path(parent_list), aggregate_id, feature_type, aggregate_type, aggregation_mode, - ".", - ".", - ".", ) # Calculate TotalSites from positions if available @@ -269,12 +263,8 @@ def write_rows_with_data( b += self.write_data( total_sites_str, # TotalSites: calculated from aggregate positions - str(aggregate_counter.genome_base_freqs.sum()), # ObservedSites - str(aggregate_counter.filtered_sites_count), # QualifiedSites ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), # ObservedBases ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), # QualifiedBases - ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsObserved - ",".join(map(str, aggregate_counter.edit_filtered_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsObserved ",".join(map(str, 
aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) @@ -308,12 +298,8 @@ def write_row_chimaera_with_data( ) b += self.write_data( str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites - str(counter.genome_base_freqs.sum()), # ObservedSites - str(counter.filtered_sites_count), # QualifiedSites ",".join(map(str, counter.genome_base_freqs[0:4].flat)), # ObservedBases ",".join(map(str, counter.filtered_base_freqs[0:4].flat)), # QualifiedBases - ",".join(map(str, counter.edit_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsObserved - ",".join(map(str, counter.edit_filtered_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsObserved ",".join(map(str, counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified ",".join(map(str, counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) @@ -347,7 +333,7 @@ def write_row_chimaera_without_data( ) b += self.write_data( str(len(feature.location)) if hasattr(feature, 'location') and feature.location else ".", # TotalSites - "0", "0", self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS + self.STR_ZERO_BASE_FREQS, self.STR_ZERO_BASE_FREQS, self.STR_ZERO_PAIRING_FREQS, self.STR_ZERO_PAIRING_FREQS ) return b @@ -370,16 +356,13 @@ def write_rows_with_data_by_parent_type( b: int = 0 for (parent_type, aggregate_type), aggregate_counter in counter_dict.items(): - b += super().write_metadata( + b += self.write_metadata( record_id, make_parent_path(parent_list), aggregate_id, parent_type, aggregate_type, aggregation_mode, - ".", - ".", - ".", ) # Calculate TotalSites from positions if available @@ -391,12 +374,8 @@ def write_rows_with_data_by_parent_type( b += self.write_data( total_sites_str, 
# TotalSites: calculated from aggregate positions - str(aggregate_counter.genome_base_freqs.sum()), # ObservedSites - str(aggregate_counter.filtered_sites_count), # QualifiedSites ",".join(map(str, aggregate_counter.genome_base_freqs[0:4].flat)), # ObservedBases ",".join(map(str, aggregate_counter.filtered_base_freqs[0:4].flat)), # QualifiedBases - ",".join(map(str, aggregate_counter.edit_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsObserved - ",".join(map(str, aggregate_counter.edit_filtered_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsObserved ",".join(map(str, aggregate_counter.edit_qualified_site_freqs[0:4, 0:4].flat)), # SiteBasePairingsQualified ",".join(map(str, aggregate_counter.edit_qualified_read_freqs[0:4, 0:4].flat)), # ReadBasePairingsQualified ) diff --git a/bin/pluviometer/test_site_counts.py b/bin/pluviometer/test_site_counts.py index 215426f..d29f60f 100644 --- a/bin/pluviometer/test_site_counts.py +++ b/bin/pluviometer/test_site_counts.py @@ -96,8 +96,8 @@ def test_to_strings_empty(self): class TestMultiCounterSiteCounts(unittest.TestCase): """Tests pour les compteurs de sites dans MultiCounter""" - def test_site_base_pairings_counts_sites_not_reads(self): - """Test que SiteBasePairingsObserved compte les SITES (pas les reads) avec chaque pairing""" + def test_site_base_pairings_counts_raw_reads(self): + """Test que edit_read_freqs compte les reads bruts (non filtrés)""" # Reproduit le cas du test minimal: # pos1: A ref, cov=10, [10,0,0,0] -> AA pairing passes (10>=5) # pos2: C ref, cov=12, [0,9,0,3] -> CC passes (9>=5), CT filtered (3<5) @@ -109,17 +109,10 @@ def test_site_base_pairings_counts_sites_not_reads(self): counter.update(RNASiteVariantData("21", 1, 1, 1, 12, 37.0, np.array([0,9,0,3,0], dtype=np.int32), 0.0)) # C site counter.update(RNASiteVariantData("21", 3, 3, 1, 2, 37.0, np.array([0,0,0,2,0], dtype=np.int32), 0.0)) # T site, cov<5 - # SiteBasePairingsObserved: 1 site AA, 1 site CC, 0 CT (filtré), 0 TT (cov<5) - 
self.assertEqual(counter.edit_site_freqs[0, 0], 1) # AA: 1 site - self.assertEqual(counter.edit_site_freqs[0, 1], 0) # AC: 0 - self.assertEqual(counter.edit_site_freqs[1, 1], 1) # CC: 1 site - self.assertEqual(counter.edit_site_freqs[1, 3], 0) # CT: 0 (3 reads < edit_threshold) - self.assertEqual(counter.edit_site_freqs[3, 3], 0) # TT: 0 (cov < cov_threshold) - - # ReadBasePairings: compte les reads bruts (non filtrés) + # edit_read_freqs: compte les reads bruts (non filtrés) self.assertEqual(counter.edit_read_freqs[0, 0], 10) # AA: 10 reads self.assertEqual(counter.edit_read_freqs[1, 1], 9) # CC: 9 reads - self.assertEqual(counter.edit_read_freqs[1, 3], 3) # CT: 3 reads (non filtrés dans ReadBasePairings) + self.assertEqual(counter.edit_read_freqs[1, 3], 3) # CT: 3 reads (non filtrés) self.assertEqual(counter.edit_read_freqs[3, 3], 2) # TT: 2 reads def test_observed_sites_count(self): @@ -226,7 +219,7 @@ def test_merge_preserves_counts(self): site3 = RNASiteVariantData( seqid="chr1", position=30, reference=0, # A, coverage 6 >= 5 strand=1, coverage=6, mean_quality=30.0, - frequencies=np.array([5, 1, 0, 0, 0], dtype=np.int32), + frequencies=np.array([4, 2, 0, 0, 0], dtype=np.int32), score=0.0 ) counter2.update(site2) @@ -265,6 +258,12 @@ def test_qualified_bases_distribution(self): ] for ref, cov in sites: + # Use frequencies with 2 non-ref reads so edit_threshold=2 is met when cov >= 5 + non_ref = 2 if cov >= 5 else 0 + freqs = np.zeros(5, dtype=np.int32) + freqs[ref] = cov - non_ref + # Add a non-ref pairing at index (ref+1)%4 + freqs[(ref + 1) % 4] = non_ref site = RNASiteVariantData( seqid="chr1", position=0, @@ -272,7 +271,7 @@ def test_qualified_bases_distribution(self): strand=1, coverage=cov, mean_quality=30.0, - frequencies=np.array([1, 0, 0, 0, 0], dtype=np.int32), + frequencies=freqs, score=0.0 ) counter.update(site) @@ -404,7 +403,7 @@ def test_all_three_columns(self): strand=1, coverage=cov, mean_quality=30.0, - frequencies=np.array([cov-1, 1, 0, 0, 
0], dtype=np.int32), + frequencies=np.array([cov-2, 2, 0, 0, 0], dtype=np.int32), score=0.0 ) counter.update(site) diff --git a/modules/bash.nf b/modules/bash.nf index 1cc194b..7f21323 100644 --- a/modules/bash.nf +++ b/modules/bash.nf @@ -175,189 +175,6 @@ process recreate_csv_with_abs_paths { """ } -process standardize_pluvio_aggregates { - label 'bash' - tag "${tsv.baseName}" - publishDir("${params.outdir}/pluviometer/${tool_name}/standardized", mode:"copy", pattern: "*_standardized.tsv") - - input: - tuple(val(meta), path(tsv)) - - output: - tuple val(meta), path("*_standardized.tsv"), emit: standardized_tsv - - script: - def basename = tsv.baseName - tool_name = basename.split('_')[-2] - """ - awk 'BEGIN { - FS=OFS="\\t" - } - NR==1 { - # Store original header and find column indices - for(i=1; i<=NF; i++) { - col[\$i] = i - } - - # Print new header - printf "SeqID\\tParentIDs\\tID\\tMtype\\tPtype\\tType\\tCtype\\tMode\\tStart\\tEnd\\tStrand" - - # Print remaining data columns (after Strand) - for(i=col["Strand"]+1; i<=NF; i++) { - printf "\\t%s", \$i - } - printf "\\n" - next - } - { - # SeqID - seqid = \$col["SeqID"] - printf "%s\\t", seqid - - # ParentIDs - printf "%s\\t", \$col["ParentIDs"] - - # ID (from AggregateID) - aggregate_id = \$col["AggregateID"] - printf "%s\\t", aggregate_id - - # Mtype (always "aggregate") - printf "aggregate\\t" - - # Ptype (from ParentType) - printf "%s\\t", \$col["ParentType"] - - # Type (logic: feature if AggregateID has info, chr if SeqID not ".", global if SeqID is ".") - if(aggregate_id != "." 
&& aggregate_id != "") { - type = "feature_agg" - } else if(seqid != ".") { - type = "chr_agg" - } else { - type = "global_agg" - } - printf "%s\\t", type - - # Ctype (from AggregateType) - printf "%s\\t", \$col["AggregateType"] - - # Mode (from AggregationMode) - printf "%s\\t", \$col["AggregationMode"] - - # Start, End, Strand (use existing columns if present, otherwise ".") - if("Start" in col) { - printf "%s\\t", \$col["Start"] - } else { - printf ".\\t" - } - - if("End" in col) { - printf "%s\\t", \$col["End"] - } else { - printf ".\\t" - } - - if("Strand" in col) { - printf "%s", \$col["Strand"] - } else { - printf "." - } - - # Print remaining data columns - for(i=col["Strand"]+1; i<=NF; i++) { - printf "\\t%s", \$i - } - printf "\\n" - }' ${tsv} > ${basename}_standardized.tsv - - echo "Standardization complete: ${basename}_standardized.tsv" - """ -} - -/* - * Standardize drip_features output format - * Transforms feature-level data into standardized column structure - * Output: SeqID | ParentIDs | ID | MType | Ptype | Type | Ctype | Mode | Start | End | Strand | [data...] - */ -process standardize_pluvio_features { - label 'bash' - tag "${tsv.baseName}" - publishDir("${params.outdir}/pluviometer/${tool_name}/standardized", mode:"copy", pattern: "*_standardized.tsv") - - input: - tuple(val(meta), path(tsv)) - - output: - tuple val(meta), path("*_standardized.tsv"), emit: standardized_tsv - - script: - def basename = tsv.baseName - tool_name = basename.split('_')[-2] - """ - awk 'BEGIN { - FS=OFS="\\t" - } - NR==1 { - # Store original header and find column indices - for(i=1; i<=NF; i++) { - col[\$i] = i - } - - # Print new header - printf "SeqID\\tParentIDs\\tID\\tMtype\\tPtype\\tType\\tCtype\\tMode\\tStart\\tEnd\\tStrand" - - # Print remaining data columns (after Strand if exists, or after Strand) - start_data_col = ("Strand" in col) ? 
col["Strand"]+1 : col["Strand"]+1 - for(i=start_data_col; i<=NF; i++) { - printf "\\t%s", \$i - } - printf "\\n" - next - } - { - # SeqID - printf "%s\\t", \$col["SeqID"] - - # ParentIDs - printf "%s\\t", \$col["ParentIDs"] - - # ID (from FeatureID) - printf "%s\\t", \$col["FeatureID"] - - # Mtype (always "feature") - printf "feature\\t" - - # Ptype (always ".") - printf ".\\t" - - # Type (from original Type column) - printf "%s\\t", \$col["Type"] - - # Ctype (always ".") - printf ".\\t" - - # Mode (always ".") - printf ".\\t" - - # Start - printf "%s\\t", \$col["Start"] - - # End - printf "%s\\t", \$col["End"] - - # Strand - printf "%s", \$col["Strand"] - - # Print remaining data columns - for(i=col["Strand"]+1; i<=NF; i++) { - printf "\\t%s", \$i - } - printf "\\n" - }' ${tsv} > ${basename}_standardized.tsv - - echo "Standardization complete: ${basename}_standardized.tsv" - """ -} - /* * Filter drip output files by AggregationMode * Creates separate files for each unique AggregationMode value diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index 3f91bf9..bc81c4a 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -1,6 +1,6 @@ process pluviometer { label "pluviometer" - publishDir("${params.outdir}/pluviometer/${tool_format}/raw", mode: "copy") + publishDir("${params.outdir}/pluviometer/${tool_format}/", mode: "copy") tag "${meta.uid}" input: diff --git a/rain.nf b/rain.nf index 331674e..093aab0 100644 --- a/rain.nf +++ b/rain.nf @@ -178,8 +178,7 @@ Report Parameters //************************************************* include { AliNe as ALIGNMENT } from "./modules/aline.nf" include {normalize_gxf} from "./modules/agat.nf" -include { extract_libtype; recreate_csv_with_abs_paths; collect_aline_csv; filter_drip_by_aggregation_mode; filter_drip_features_by_type - standardize_pluvio_aggregates; standardize_pluvio_features} from "./modules/bash.nf" +include { extract_libtype; recreate_csv_with_abs_paths; collect_aline_csv; 
filter_drip_by_aggregation_mode; filter_drip_features_by_type} from "./modules/bash.nf" include {bamutil_clipoverlap} from './modules/bamutil.nf' include {fastp} from './modules/fastp.nf' include {fastqc as fastqc_ali; fastqc as fastqc_dup; fastqc as fastqc_clip} from './modules/fastqc.nf' @@ -756,13 +755,10 @@ workflow { if ( "reditools3" in edit_site_tool_list ){ reditools3(final_bam_for_editing, genome.collect()) pluviometer_reditools3(reditools3.out.tuple_sample_serial_table, clean_annotation.collect(), "reditools3") - // standardize the feature and aggregate output of pluviometer to be post process in a single way in down processes - standardize_pluvio_aggregates(pluviometer_reditools3.out.tuple_sample_aggregate) - standardize_pluvio_features(pluviometer_reditools3.out.tuple_sample_feature) if(via_csv){ // drip - compute espn, espf, merge different sample in one, and output by type of mutation (AG, AC, etc..) - drip_aggregates(standardize_pluvio_aggregates.out.standardized_tsv.collect(), "aggregates") - drip_features(standardize_pluvio_features.out.standardized_tsv.collect(), "features") + drip_aggregates(pluviometer_reditools3.out.tuple_sample_aggregate.collect(), "aggregates") + drip_features(pluviometer_reditools3.out.tuple_sample_feature.collect(), "features") //christalize(drip_features.out.editing_ag, "AG") From 8b350c3562bbef0abe9e6b2009724398cf6f1d11 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 18:45:12 +0000 Subject: [PATCH 32/61] fix tiny things --- README.md | 13 +++++++++++++ bin/drip.py | 14 +++++++------- bin/pluviometer/rain_file_writers.py | 6 +++--- 3 files changed, 23 insertions(+), 10 deletions(-) diff --git a/README.md b/README.md index ab13173..fc2be9a 100644 --- a/README.md +++ b/README.md @@ -346,6 +346,19 @@ $$ AG\ editing\ level = \sum_{i=0}^{n} \dfrac{AG_i}{AA_i + AC_i + AG_i + AT_i} $$ + +## Drip + +### espf (edited sites proportion in feature): + +denom_espf = df[f'{genome_base}_count'] # X_QualifiedBases 
(e.g. C_count) +df[espf_col] = df[f'{bp}_sites'] / denom_espf # XY_SiteBasePairingsQualified / X_QualifiedBases + +### espr (edited sites proportion in reads): + +df[total_reads_col] = XA_reads + XC_reads + XG_reads + XT_reads # all reads at X positions +df[espr_col] = df[f'{bp}_reads'] / df[total_reads_col] # XY_reads / sum(X*_reads) + diff --git a/bin/drip.py b/bin/drip.py index 0b71aba..66c2397 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -72,8 +72,8 @@ def print_help(): For each combination XY (where X = genome base, Y = read base): 1. XY_espf (edited_sites_proportion_feature) - Proportion of XY sites in the DNA feature: - Formula: XY_SiteBasePairingsQualified / X_ObservedBases - This represents the proportion of genomic X positions that show X-to-Y variation in the feature. + Formula: XY_SiteBasePairingsQualified / X_QualifiedBases + This represents the proportion of qualified X positions that show X-to-Y variation in the feature. 2. XY_espr (edited_sites_proportion_reads) - Proportion of XY pairing in reads: Formula: XY_ReadBasePairings / (XA + XC + XG + XT)_ReadBasePairings @@ -161,9 +161,9 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] bases = ['A', 'C', 'G', 'T'] - # Parse ObservedBases (order: A, C, G, T) + # Parse QualifiedBases (order: A, C, G, T) for i, base in enumerate(bases): - df[f'{base}_count'] = df['ObservedBases'].str.split(',').str[i].astype(int) + df[f'{base}_count'] = df['QualifiedBases'].str.split(',').str[i].astype(int) # Parse SiteBasePairingsQualified (all 16 combinations) for i, bp in enumerate(base_pairs): @@ -186,8 +186,8 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ for bp in base_pairs: genome_base = bp[0] # First letter is the genome base - # Calculate espf: XY_sites / X_count - # NA when genome base count < min_cov (position not covered or not in feature) + # Calculate espf: XY_sites / X_count - edited sites 
proportion in feature + # NA when qualified base count < min_cov (position not covered or not qualified) espf_col = f'{col_prefix}::{bp}::espf' denom_espf = df[f'{genome_base}_count'] mask_espf = denom_espf >= min_cov @@ -196,7 +196,7 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ df[espf_col] = df[espf_col].round(decimals) result_cols.append(espf_col) - # Calculate espr: XY_reads / (XA + XC + XG + XT) + # Calculate espr: XY_reads / (XA + XC + XG + XT) - edited sites proportion in reads # NA when total read coverage < min_cov (position not sequenced) total_reads_col = f'{genome_base}_total_reads' if total_reads_col not in df.columns: diff --git a/bin/pluviometer/rain_file_writers.py b/bin/pluviometer/rain_file_writers.py index 19d5f0e..c8ce341 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -193,11 +193,11 @@ def __init__(self, handle: TextIO): def _agg_type(seq_id: str, aggregate_id: str) -> str: """Compute the Type field for an aggregate row.""" if aggregate_id not in (".", ""): - return "feature_agg" + return "feature" elif seq_id != ".": - return "chr_agg" + return "chr" else: - return "global_agg" + return "global" def write_metadata( self, From 1142ff8f92ab10d21d4d5d5acc4aabddd45816db Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 20:04:15 +0000 Subject: [PATCH 33/61] skip NA only lines --- bin/drip.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index 66c2397..39595f9 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -40,6 +40,10 @@ def print_help(): Input file path, group name, sample name, and replicate ID separated by colons. All four components are required. --with-file-id Include file ID in column names (default: omit file ID) + --report-non-qualified-features + Include rows where all metric values are NA (default: omit them). 
+ By default, rows where every sample has NA for the given base pair + are skipped (not covered or not qualified in any sample). --min-cov N Minimum read coverage threshold (default: 1). Positions with a denominator (genome base count for espf, read count for espr) strictly below this value are reported as NA instead of 0. @@ -221,7 +225,7 @@ def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, includ del df return result -def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id): +def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, report_non_qualified=False): """Worker function to write a single base pair combination file. This function is designed to be called in parallel for each base pair. @@ -245,12 +249,14 @@ def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, output_file = f"{output_prefix}_{bp}.tsv" # Values are already rounded, just select, rename and write - merged[bp_cols].rename(columns=rename_dict).to_csv( - output_file, sep='\t', index=False, na_rep='NA' - ) + out_df = merged[bp_cols].rename(columns=rename_dict) + if not report_non_qualified: + metric_cols = [c for c in out_df.columns if c not in metadata_cols] + out_df = out_df[out_df[metric_cols].notna().any(axis=1)] + out_df.to_csv(output_file, sep='\t', index=False, na_rep='NA') return output_file -def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1, threads=1, decimals=4): +def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1, threads=1, decimals=4, report_non_qualified=False): """Merge data from multiple samples and create output matrices - one file per base pair combination.""" base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', @@ -287,13 +293,13 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ if threads > 1: 
print(f"Writing {len(base_pairs)} output files using {threads} threads...") with multiprocessing.Pool(processes=threads) as pool: - args = [(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id) for bp in base_pairs] + args = [(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, report_non_qualified) for bp in base_pairs] output_files = pool.starmap(write_base_pair_file, args) else: print(f"Writing {len(base_pairs)} output files sequentially...") output_files = [] for bp in base_pairs: - output_file = write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id) + output_file = write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, report_non_qualified) output_files.append(output_file) gc.collect() @@ -327,6 +333,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ min_cov = 1 # Default: NA when denominator is 0; treat 0 coverage as non-observed threads = 1 # Default: sequential writing decimals = 4 # Default: round to 4 decimal places + report_non_qualified = False # Default: skip rows where all metric values are NA args_iter = iter(range(1, len(sys.argv))) for i in args_iter: @@ -350,6 +357,11 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ include_file_id = True continue + # Check for --report-non-qualified-features flag + if arg == '--report-non-qualified-features': + report_non_qualified = True + continue + # Check for --min-cov flag (supports both --min-cov=N and --min-cov N) if arg.startswith('--min-cov'): if '=' in arg: @@ -454,6 +466,6 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ sys.exit(1) # Process all samples - result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads, decimals) + result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads, 
decimals, report_non_qualified) print("\nAnalysis complete!") \ No newline at end of file From c7ebec29598d9239e39f7e4c6ae08595fea7d739 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 20:05:24 +0000 Subject: [PATCH 34/61] fi report only qualified features --- bin/pluviometer/__main__.py | 54 +++++++++++++++++++--------- bin/pluviometer/rain_file_writers.py | 6 ++++ 2 files changed, 44 insertions(+), 16 deletions(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index c6cbc34..0f3a3d7 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -407,10 +407,11 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> if feature_counter: if feature.is_chimaera: assert parent_feature # A chimaera must always have a parent feature (a gene) - logging.debug(f"Writing chimaera with data: {feature.id}") - self.aggregate_writer.write_row_chimaera_with_data( - self.record.id, feature, parent_feature, feature_counter - ) + if self.report_non_qualified or feature_counter.filtered_base_freqs.sum() > 0: + logging.debug(f"Writing chimaera with data: {feature.id}") + self.aggregate_writer.write_row_chimaera_with_data( + self.record.id, feature, parent_feature, feature_counter + ) self.chimaera_aggregate_counters[feature.type].merge(feature_counter) # Also track by parent_type self.chimaera_aggregate_counters_by_parent_type[(parent_feature.type, feature.type)].merge(feature_counter) @@ -430,10 +431,11 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> else: if feature.is_chimaera: assert parent_feature - logging.debug(f"Writing chimaera without data: {feature.id}") - self.aggregate_writer.write_row_chimaera_without_data( - self.record.id, feature, parent_feature - ) + if self.report_non_qualified: + logging.debug(f"Writing chimaera without data: {feature.id}") + self.aggregate_writer.write_row_chimaera_without_data( + self.record.id, feature, parent_feature + ) 
else: logging.debug(f"Writing feature without data: {feature.id}, type: {feature.type}") if self.report_non_qualified: @@ -478,6 +480,8 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> aggregate_type, aggregate_counter, ) in level1_longest_isoform_aggregation_counters.items(): + if not self.report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue # Get positions for this aggregate type start_str, end_str, strand_str = ".", ".", "." if aggregate_type in level1_longest_isoform_aggregation_positions: @@ -512,6 +516,8 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> aggregate_type, aggregate_counter, ) in level1_all_isoforms_aggregation_counters.items(): + if not self.report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue # Get positions for this aggregate type start_str, end_str, strand_str = ".", ".", "." if aggregate_type in level1_all_isoforms_aggregation_positions: @@ -544,6 +550,8 @@ def checkout(self, feature: SeqFeature, parent_feature: Optional[SeqFeature]) -> else: feature_aggregation_counters, feature_aggregation_positions = self.aggregate_children(feature) for aggregate_type, aggregate_counter in feature_aggregation_counters.items(): + if not self.report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue # Get positions for this aggregate type start_str, end_str, strand_str = ".", ".", "." 
if aggregate_type in feature_aggregation_positions: @@ -1020,6 +1028,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: "longest_isoform", record_ctx.longest_isoform_aggregate_counters, record_ctx.longest_isoform_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data( record.id, @@ -1029,6 +1038,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: "all_isoforms", record_ctx.all_isoforms_aggregate_counters, record_ctx.all_isoforms_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data( record.id, @@ -1038,6 +1048,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: "chimaera", record_ctx.chimaera_aggregate_counters, record_ctx.chimaera_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) # Write aggregate counter data by parent type @@ -1048,6 +1059,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: "longest_isoform", record_ctx.longest_isoform_aggregate_counters_by_parent_type, record_ctx.longest_isoform_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data_by_parent_type( record.id, @@ -1056,6 +1068,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: "all_isoforms", record_ctx.all_isoforms_aggregate_counters_by_parent_type, record_ctx.all_isoforms_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data_by_parent_type( record.id, @@ -1064,6 +1077,7 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: "chimaera", record_ctx.chimaera_aggregate_counters_by_parent_type, record_ctx.chimaera_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) # Write the total counter data of the record. 
A dummy dict needs to be created to use the `write_rows_with_data` method @@ -1072,7 +1086,8 @@ def _do_counting(record: SeqRecord) -> dict[str, Any]: ) total_counter_dict["."] = record_ctx.total_counter aggregate_writer.write_rows_with_data( - record.id, ["."], ".", ".", "all_sites", total_counter_dict + record.id, ["."], ".", ".", "all_sites", total_counter_dict, + report_non_qualified=args.report_non_qualified_features, ) # Extract data needed for return before cleanup @@ -1416,29 +1431,35 @@ def _record_id_from_pickle(path: str) -> str: # Write genomic counts aggregate_writer.write_rows_with_data( ".", ["."], ".", ".", "longest_isoform", genome_longest_isoform_aggregate_counters, - genome_longest_isoform_aggregate_positions + genome_longest_isoform_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data( ".", ["."], ".", ".", "all_isoforms", genome_all_isoforms_aggregate_counters, - genome_all_isoforms_aggregate_positions + genome_all_isoforms_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data( ".", ["."], ".", ".", "chimaera", genome_chimaera_aggregate_counters, - genome_chimaera_aggregate_positions + genome_chimaera_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) # Write genomic counts by parent type aggregate_writer.write_rows_with_data_by_parent_type( ".", ["."], ".", "longest_isoform", genome_longest_isoform_aggregate_counters_by_parent_type, - genome_longest_isoform_aggregate_positions + genome_longest_isoform_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) aggregate_writer.write_rows_with_data_by_parent_type( ".", ["."], ".", "all_isoforms", genome_all_isoforms_aggregate_counters_by_parent_type, - genome_all_isoforms_aggregate_positions + genome_all_isoforms_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) 
aggregate_writer.write_rows_with_data_by_parent_type( ".", ["."], ".", "chimaera", genome_chimaera_aggregate_counters_by_parent_type, - genome_chimaera_aggregate_positions + genome_chimaera_aggregate_positions, + report_non_qualified=args.report_non_qualified_features, ) # Write the genomic total. A dummy dict needs to be created to use the `write_rows_with_data` method @@ -1447,7 +1468,8 @@ def _record_id_from_pickle(path: str) -> str: ) genomic_total_counter_dict["."] = genome_total_counter aggregate_writer.write_rows_with_data( - ".", ["."], ".", ".", "all_sites", genomic_total_counter_dict + ".", ["."], ".", ".", "all_sites", genomic_total_counter_dict, + report_non_qualified=args.report_non_qualified_features, ) logging.info("Program finished") diff --git a/bin/pluviometer/rain_file_writers.py b/bin/pluviometer/rain_file_writers.py index c8ce341..ece26db 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -240,11 +240,14 @@ def write_rows_with_data( aggregation_mode: str, counter_dict: defaultdict[str, MultiCounter], positions_dict: Optional[dict[str, 'AggregatePositions']] = None, + report_non_qualified: bool = True, ) -> int: """Write metadata and data fields of multiple counters of the same aggregate feature""" b: int = 0 for aggregate_type, aggregate_counter in counter_dict.items(): + if not report_non_qualified and aggregate_counter.filtered_base_freqs.sum() == 0: + continue b += self.write_metadata( record_id, make_parent_path(parent_list), @@ -351,11 +354,14 @@ def write_rows_with_data_by_parent_type( aggregation_mode: str, counter_dict: defaultdict[tuple[str, str], MultiCounter], positions_dict: Optional[dict[str, 'AggregatePositions']] = None, + report_non_qualified: bool = True, ) -> int: """Write metadata and data fields of multiple counters grouped by (parent_type, aggregate_type)""" b: int = 0 for (parent_type, aggregate_type), aggregate_counter in counter_dict.items(): + if not report_non_qualified and 
aggregate_counter.filtered_base_freqs.sum() == 0: + continue b += self.write_metadata( record_id, make_parent_path(parent_list), From 87562e4eeb4412e1dd62265ef71cbd790d33c54b Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 20:09:55 +0000 Subject: [PATCH 35/61] just to re-run this part --- modules/pluviometer.nf | 2 +- modules/python.nf | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index bc81c4a..3ee06b3 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -15,7 +15,7 @@ process pluviometer { script: base_name = site_edits.BaseName - """ + """ pluviometer_wrapper.py \ --sites ${site_edits} \ --gff ${gff} \ diff --git a/modules/python.nf b/modules/python.nf index d5076f2..7a571d4 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -57,6 +57,6 @@ process drip { """ - drip.py --threads ${task.cpus} --output drip_${prefix} ${args_str} + drip.py --threads ${task.cpus} --output drip_${prefix} ${args_str} """ } \ No newline at end of file From 17f09006b8a0a35a768b2c0d46bc5cc8ed370735 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 22:03:13 +0000 Subject: [PATCH 36/61] change chr by sequence to be more inclusive --- bin/pluviometer/rain_file_writers.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/pluviometer/rain_file_writers.py b/bin/pluviometer/rain_file_writers.py index ece26db..22f9c9a 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -195,7 +195,7 @@ def _agg_type(seq_id: str, aggregate_id: str) -> str: if aggregate_id not in (".", ""): return "feature" elif seq_id != ".": - return "chr" + return "sequence" else: return "global" From c230b7d1e714f0d1b04fd9ac5dc05016dcb506b9 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 22:05:28 +0000 Subject: [PATCH 37/61] add space to re-run this step --- modules/pluviometer.nf | 2 +- modules/python.nf | 2 +- 2 
files changed, 2 insertions(+), 2 deletions(-) diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index 3ee06b3..683af48 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -15,7 +15,7 @@ process pluviometer { script: base_name = site_edits.BaseName - """ + """ pluviometer_wrapper.py \ --sites ${site_edits} \ --gff ${gff} \ diff --git a/modules/python.nf b/modules/python.nf index 7a571d4..b14119e 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -56,7 +56,7 @@ process drip { def args_str = args.join(" ") """ - + drip.py --threads ${task.cpus} --output drip_${prefix} ${args_str} """ } \ No newline at end of file From 3db00a773ea0038591bd3a39ab1712cd5732f998 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Fri, 20 Mar 2026 23:35:09 +0100 Subject: [PATCH 38/61] remove row if value are only 0.0 and or NA --- README.md | 2 ++ bin/drip.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index fc2be9a..5844972 100644 --- a/README.md +++ b/README.md @@ -359,6 +359,8 @@ df[espf_col] = df[f'{bp}_sites'] / denom_espf # XY_SiteBasePairingsQualified df[total_reads_col] = XA_reads + XC_reads + XG_reads + XT_reads # all reads at X positions df[espr_col] = df[f'{bp}_reads'] / df[total_reads_col] # XY_reads / sum(X*_reads) +Drip retains a line only if at least one metric value is neither NA nor zero (i.e., at least one edit has been detected somewhere). Lines containing only NA values, only 0.0 values, or a mix of both are removed by default. 
+ diff --git a/bin/drip.py b/bin/drip.py index 39595f9..d865117 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -252,7 +252,7 @@ def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, out_df = merged[bp_cols].rename(columns=rename_dict) if not report_non_qualified: metric_cols = [c for c in out_df.columns if c not in metadata_cols] - out_df = out_df[out_df[metric_cols].notna().any(axis=1)] + out_df = out_df[(out_df[metric_cols].notna() & (out_df[metric_cols] != 0)).any(axis=1)] out_df.to_csv(output_file, sep='\t', index=False, na_rep='NA') return output_file From 65bc84d13e2df158928940c5bda5183c4524e78a Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Sat, 21 Mar 2026 18:51:25 +0100 Subject: [PATCH 39/61] add ressource to reditools3 --- config/resources/hpc.config | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/config/resources/hpc.config b/config/resources/hpc.config index b003a4e..33b9db3 100644 --- a/config/resources/hpc.config +++ b/config/resources/hpc.config @@ -25,8 +25,8 @@ process { time = '2d' } withLabel: 'jacusa2' { - cpus = 4 - time = '4h' + cpus = 16 + time = '1d' } withLabel: 'pigz' { cpus = 8 @@ -37,7 +37,7 @@ process { time = '6h' } withLabel: 'reditools3' { - cpus = 8 + cpus = 16 time = '1d' } withLabel: 'samtools' { From e13b448188d43fa198cf40ec8c5e03b77fff57bc Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Sun, 22 Mar 2026 16:18:28 +0100 Subject: [PATCH 40/61] try to catch end of reditools3 not well catched --- modules/reditools3.nf | 29 ++++++++++++++++++++++++++++- 1 file changed, 28 insertions(+), 1 deletion(-) diff --git a/modules/reditools3.nf b/modules/reditools3.nf index 50e7a4d..5e62247 100644 --- a/modules/reditools3.nf +++ b/modules/reditools3.nf @@ -32,6 +32,33 @@ process reditools3 { base_name = bam.BaseName """ - python -m reditools analyze ${bam} --reference ${genome} --strand ${strand_orientation} --output-file ${base_name}.site_edits_reditools3.txt --threads ${task.cpus} 
--verbose &> ${base_name}.reditools3.log + # Trap and catch up to avoid reditools not stopping properly + # [ERROR] () MD tag not present + # [ERROR] Killing job + + trap "kill 0" EXIT + + LOG=${base_name}.reditools3.log + OUT=${base_name}.site_edits_reditools3.txt + + python -m reditools analyze ${bam} \ + --reference ${genome} \ + --strand ${strand_orientation} \ + --output-file \$OUT \ + --threads ${task.cpus} \ + --verbose \ + &> \$LOG || exit 1 + + # 🔥 Détection erreurs silencieuses REDItools + if grep -qE "\\[ERROR\\]|Killing job" \$LOG; then + echo "REDItools a échoué pour $bam" + exit 1 + fi + + # 🔍 Vérification output non vide (optionnel mais recommandé) + if [ ! -s "\$OUT" ]; then + echo "Output vide → erreur probable" + exit 1 + fi """ } From f9bdb26510d46a18251e5deefdc9c466e8f118fe Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Sun, 22 Mar 2026 18:01:12 +0100 Subject: [PATCH 41/61] back to previous --- modules/reditools3.nf | 29 +---------------------------- 1 file changed, 1 insertion(+), 28 deletions(-) diff --git a/modules/reditools3.nf b/modules/reditools3.nf index 5e62247..50e7a4d 100644 --- a/modules/reditools3.nf +++ b/modules/reditools3.nf @@ -32,33 +32,6 @@ process reditools3 { base_name = bam.BaseName """ - # Trap and catch up to avoid reditools not stopping properly - # [ERROR] () MD tag not present - # [ERROR] Killing job - - trap "kill 0" EXIT - - LOG=${base_name}.reditools3.log - OUT=${base_name}.site_edits_reditools3.txt - - python -m reditools analyze ${bam} \ - --reference ${genome} \ - --strand ${strand_orientation} \ - --output-file \$OUT \ - --threads ${task.cpus} \ - --verbose \ - &> \$LOG || exit 1 - - # 🔥 Détection erreurs silencieuses REDItools - if grep -qE "\\[ERROR\\]|Killing job" \$LOG; then - echo "REDItools a échoué pour $bam" - exit 1 - fi - - # 🔍 Vérification output non vide (optionnel mais recommandé) - if [ ! 
-s "\$OUT" ]; then - echo "Output vide → erreur probable" - exit 1 - fi + python -m reditools analyze ${bam} --reference ${genome} --strand ${strand_orientation} --output-file ${base_name}.site_edits_reditools3.txt --threads ${task.cpus} --verbose &> ${base_name}.reditools3.log """ } From 80294f31e8802a2e8a6ec1b904e7b2b91b105a47 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Sun, 22 Mar 2026 18:01:46 +0100 Subject: [PATCH 42/61] add calmd to fix deteriorated MD tags --- modules/samtools.nf | 23 +++++++++++++++++++++++ rain.nf | 16 ++++++++++++---- 2 files changed, 35 insertions(+), 4 deletions(-) diff --git a/modules/samtools.nf b/modules/samtools.nf index 89a6bd4..d3576f1 100644 --- a/modules/samtools.nf +++ b/modules/samtools.nf @@ -151,4 +151,27 @@ process samtools_merge_bams { """ samtools merge -@ ${task.cpus} ${meta.uid}_merged.bam ${bam1} ${bam2} """ +} + +/* + * Calculate MD tag for BAM file + * The MD tag is required by some tools like REDItools +* It encodes the reference bases that differ from the aligned read bases + * This tag can be missing or incorrect in BAM files that have been processed (e.g. with Picard MarkDuplicates, bamutil) and needs to be recomputed to ensure accurate editing site detection. 
+ */ +process samtools_calmd { + label "samtools" + tag "${meta.uid}" + + input: + tuple val(meta), path(bam) + path(reference) + + output: + tuple val(meta), path("*_md.bam"), emit: tuple_sample_bam_with_md + + script: + """ + samtools calmd -@ ${task.cpus} -b ${bam} ${reference} > ${bam.baseName}_md.bam + """ } \ No newline at end of file diff --git a/rain.nf b/rain.nf index 093aab0..0a103c8 100644 --- a/rain.nf +++ b/rain.nf @@ -186,7 +186,7 @@ include {gatk_markduplicates } from './modules/gatk.nf' include {jacusa2} from "./modules/jacusa2.nf" include {multiqc} from './modules/multiqc.nf' include {fasta_unzip} from "$baseDir/modules/pigz.nf" -include {samtools_index; samtools_fasta_index; samtools_sort_bam as samtools_sort_bam_raw; samtools_sort_bam as samtools_sort_bam_merged; samtools_split_mapped_unmapped; samtools_merge_bams} from './modules/samtools.nf' +include {samtools_index; samtools_fasta_index; samtools_sort_bam as samtools_sort_bam_raw; samtools_sort_bam as samtools_sort_bam_merged; samtools_split_mapped_unmapped; samtools_merge_bams; samtools_calmd} from './modules/samtools.nf' include {reditools2} from "./modules/reditools2.nf" include {reditools3} from "./modules/reditools3.nf" include {pluviometer as pluviometer_jacusa2; pluviometer as pluviometer_reditools2; pluviometer as pluviometer_reditools3; pluviometer as pluviometer_sapin} from "./modules/pluviometer.nf" @@ -719,7 +719,7 @@ workflow { // Clip overlap if (params.clip_overlap) { - bamutil_clipoverlap(gatk_markduplicates.out.tuple_sample_dedupbam) + bamutil_clipoverlap(all_bam_sorted_dedup) all_bam_sorted_dedup_clip = bamutil_clipoverlap.out.tuple_sample_clipoverbam // stat on bam with overlap clipped if(params.fastqc){ @@ -727,11 +727,19 @@ workflow { logs.concat(fastqc_clip.out).set{logs} // save log } } else { - all_bam_sorted_dedup_clip = gatk_markduplicates.out.tuple_sample_dedupbam + all_bam_sorted_dedup_clip = all_bam_sorted_dedup } + // Recompute MD tag + if (params.clip_overlap || 
params.clean_duplicate) { + samtools_calmd(all_bam_sorted_dedup_clip, genome.collect()) + all_bam_sorted_dedup_clip_md = samtools_calmd.out.tuple_sample_bam_with_md + } else { + all_bam_sorted_dedup_clip_md = all_bam_sorted_dedup_clip + } + // index mapped bam - samtools_index(all_bam_sorted_dedup_clip) + samtools_index(all_bam_sorted_dedup_clip_md) final_bam_for_editing = samtools_index.out.tuple_sample_bam_bamindex // ------------------------------------------------------- From 031334853911c4146335497eb118e8570ad4aebf Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Sun, 22 Mar 2026 20:10:37 +0000 Subject: [PATCH 43/61] try to decrease memory footprint --- bin/drip.py | 329 +++++++++++++++++++++++++++++----------------------- 1 file changed, 185 insertions(+), 144 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index d865117..d4ce6e6 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -151,166 +151,207 @@ def print_help(): print(help_text) sys.exit(0) -def parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id=False, min_cov=1, decimals=4): - """Parse a single TSV file and extract editing metrics for all base pair combinations.""" - # SeqID, Start, End, Strand contain mixed values ("." and actual numbers/strings) - # → force them to string to avoid DtypeWarning and preserve "." as-is - mixed_cols = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} - df = pd.read_csv(filepath, sep='\t', dtype=mixed_cols) - - # DO NOT filter out rows where ID is '.' 
- # These are special aggregate rows (e.g., all_sites) that should be kept - # Base pair combinations in order - base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', - 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] - bases = ['A', 'C', 'G', 'T'] - - # Parse QualifiedBases (order: A, C, G, T) - for i, base in enumerate(bases): - df[f'{base}_count'] = df['QualifiedBases'].str.split(',').str[i].astype(int) +def _parse_comma_col_at(series, idx): + """Extract the integer at comma-based index `idx` from a packed string column.""" + return series.str.split(',', expand=True)[idx].astype(np.int64) - # Parse SiteBasePairingsQualified (all 16 combinations) - for i, bp in enumerate(base_pairs): - df[f'{bp}_sites'] = df['SiteBasePairingsQualified'].str.split(',').str[i].astype(int) - - # Parse ReadBasePairingsQualified (all 16 combinations) - for i, bp in enumerate(base_pairs): - df[f'{bp}_reads'] = df['ReadBasePairingsQualified'].str.split(',').str[i].astype(int) - - # Calculate metrics for each base pair combination - metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - result_cols = metadata_cols.copy() - - # Create column prefix with group::sample::replicate::file_id or group::sample::replicate - if include_file_id: - col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' - else: - col_prefix = f'{group_name}::{sample_name}::{replicate}' - - for bp in base_pairs: - genome_base = bp[0] # First letter is the genome base - - # Calculate espf: XY_sites / X_count - edited sites proportion in feature - # NA when qualified base count < min_cov (position not covered or not qualified) - espf_col = f'{col_prefix}::{bp}::espf' - denom_espf = df[f'{genome_base}_count'] - mask_espf = denom_espf >= min_cov - df[espf_col] = np.where(mask_espf, df[f'{bp}_sites'] / denom_espf.where(mask_espf, 1), np.nan) - # Round immediately to reduce RAM footprint during merge - df[espf_col] = 
df[espf_col].round(decimals) - result_cols.append(espf_col) - - # Calculate espr: XY_reads / (XA + XC + XG + XT) - edited sites proportion in reads - # NA when total read coverage < min_cov (position not sequenced) - total_reads_col = f'{genome_base}_total_reads' - if total_reads_col not in df.columns: - df[total_reads_col] = ( - df[f'{genome_base}A_reads'] + - df[f'{genome_base}C_reads'] + - df[f'{genome_base}G_reads'] + - df[f'{genome_base}T_reads'] - ) - - espr_col = f'{col_prefix}::{bp}::espr' - denom_espr = df[total_reads_col] - mask_espr = denom_espr >= min_cov - df[espr_col] = np.where(mask_espr, df[f'{bp}_reads'] / denom_espr.where(mask_espr, 1), np.nan) - # Round immediately to reduce RAM footprint during merge - df[espr_col] = df[espr_col].round(decimals) - result_cols.append(espr_col) - - # Select only needed columns — list indexing creates a new DataFrame, - # so we can free the large intermediate df immediately. - result = df[result_cols] + +def parse_tsv_file_for_bp(filepath, bp, group_name, sample_name, replicate, file_id, + include_file_id=False, min_cov=1, decimals=4): + """Parse one TSV file and compute espf/espr for a single base pair `bp` only. + + Loads only the three packed data columns + metadata (via usecols) and discards + them immediately after extracting the two needed integer vectors. Returns a + slim DataFrame with 11 metadata columns + 2 float metric columns. + Peak RAM per call ≈ raw CSV in memory + a few integer Series. 
+ """ + ALL_BPS = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', + 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] + BASES = ['A', 'C', 'G', 'T'] + + genome_base = bp[0] + bp_idx = ALL_BPS.index(bp) + gb_idx = BASES.index(genome_base) # 0–3 + gb_offset = gb_idx * 4 # first XA/XC/XG/XT index in 16-value vector + + metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', + 'Ctype', 'Mode', 'Start', 'End', 'Strand'] + needed_cols = metadata_cols + ['QualifiedBases', + 'SiteBasePairingsQualified', + 'ReadBasePairingsQualified'] + mixed_dtypes = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} + + df = pd.read_csv(filepath, sep='\t', dtype=mixed_dtypes, usecols=needed_cols) + + # espf numerator / denominator + bp_sites = _parse_comma_col_at(df['SiteBasePairingsQualified'], bp_idx) + x_count = _parse_comma_col_at(df['QualifiedBases'], gb_idx) + + # espr: expand ReadBasePairingsQualified once, pick 5 values, discard the rest + reads_split = df['ReadBasePairingsQualified'].str.split(',', expand=True) + bp_reads = reads_split[bp_idx].astype(np.int64) + total_reads = (reads_split[gb_offset ].astype(np.int64) + + reads_split[gb_offset + 1].astype(np.int64) + + reads_split[gb_offset + 2].astype(np.int64) + + reads_split[gb_offset + 3].astype(np.int64)) + del reads_split + + # Keep only metadata columns; drop the three packed columns now + result = df[metadata_cols].copy() del df + + col_prefix = (f'{group_name}::{sample_name}::{replicate}::{file_id}' + if include_file_id + else f'{group_name}::{sample_name}::{replicate}') + + mask_f = x_count >= min_cov + result[f'{col_prefix}::espf'] = np.where( + mask_f, bp_sites / x_count.where(mask_f, 1), np.nan + ) + result[f'{col_prefix}::espf'] = result[f'{col_prefix}::espf'].round(decimals) + + mask_r = total_reads >= min_cov + result[f'{col_prefix}::espr'] = np.where( + mask_r, bp_reads / total_reads.where(mask_r, 1), np.nan + ) + result[f'{col_prefix}::espr'] = result[f'{col_prefix}::espr'].round(decimals) + 
return result -def write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, report_non_qualified=False): - """Worker function to write a single base pair combination file. - - This function is designed to be called in parallel for each base pair. - Note: Values are already rounded in parse_tsv_file() to reduce RAM usage. + +def _compute_bp_from_df(df, bp, bp_idx, gb_idx, gb_offset, col_prefix, min_cov, decimals): + """Compute espf/espr for one BP from an already-loaded DataFrame. + + Returns a slim DataFrame (11 metadata cols + 2 metric cols). + `df` must already contain the pre-parsed columns: + SiteBasePairingsQualified, QualifiedBases, ReadBasePairingsQualified + as well as the 11 metadata columns. """ - bp_cols = metadata_cols.copy() - rename_dict = {} - for _, group_name, sample_name, replicate, file_id in sample_info: - if include_file_id: - col_prefix = f'{group_name}::{sample_name}::{replicate}::{file_id}' - else: - col_prefix = f'{group_name}::{sample_name}::{replicate}' - espf_col = f'{col_prefix}::{bp}::espf' - espr_col = f'{col_prefix}::{bp}::espr' - if espf_col in merged.columns: - bp_cols.append(espf_col) - rename_dict[espf_col] = f'{col_prefix}::espf' - if espr_col in merged.columns: - bp_cols.append(espr_col) - rename_dict[espr_col] = f'{col_prefix}::espr' - - output_file = f"{output_prefix}_{bp}.tsv" - # Values are already rounded, just select, rename and write - out_df = merged[bp_cols].rename(columns=rename_dict) + metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', + 'Ctype', 'Mode', 'Start', 'End', 'Strand'] + + bp_sites = _parse_comma_col_at(df['SiteBasePairingsQualified'], bp_idx) + x_count = _parse_comma_col_at(df['QualifiedBases'], gb_idx) + + reads_split = df['ReadBasePairingsQualified'].str.split(',', expand=True) + bp_reads = reads_split[bp_idx].astype(np.int64) + total_reads = (reads_split[gb_offset ].astype(np.int64) + + reads_split[gb_offset + 1].astype(np.int64) + + 
reads_split[gb_offset + 2].astype(np.int64) + + reads_split[gb_offset + 3].astype(np.int64)) + del reads_split + + result = df[metadata_cols].copy() + + mask_f = x_count >= min_cov + result[f'{col_prefix}::espf'] = np.where( + mask_f, bp_sites / x_count.where(mask_f, 1), np.nan + ).round(decimals) + + mask_r = total_reads >= min_cov + result[f'{col_prefix}::espr'] = np.where( + mask_r, bp_reads / total_reads.where(mask_r, 1), np.nan + ).round(decimals) + + return result + + +def _write_one_bp(bp, accumulated, output_prefix, metadata_cols, report_non_qualified): + """Sort, filter and write the accumulated DataFrame for one BP.""" + merged = accumulated.sort_values(['SeqID', 'ParentIDs', 'Mode']) + if not report_non_qualified: - metric_cols = [c for c in out_df.columns if c not in metadata_cols] - out_df = out_df[(out_df[metric_cols].notna() & (out_df[metric_cols] != 0)).any(axis=1)] - out_df.to_csv(output_file, sep='\t', index=False, na_rep='NA') + metric_cols = [c for c in merged.columns if c not in metadata_cols] + merged = merged[ + (merged[metric_cols].notna() & (merged[metric_cols] != 0)).any(axis=1) + ] + + output_file = f'{output_prefix}_{bp}.tsv' + merged.to_csv(output_file, sep='\t', index=False, na_rep='NA') + n_rows = len(merged) + del merged + gc.collect() + print(f' {output_file} ({n_rows} rows)') return output_file -def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1, threads=1, decimals=4, report_non_qualified=False): - """Merge data from multiple samples and create output matrices - one file per base pair combination.""" - - base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', - 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] - metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] +def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, + min_cov=1, threads=1, decimals=4, report_non_qualified=False): + 
"""Produce one output file per base pair combination. - # Collect sample metadata without loading data yet - sample_info = [] # (filepath, group_name, sample_name, replicate, file_id) + Memory strategy: each input file is read exactly once. All 16 BP + accumulators are updated in a single pass over the inputs, so peak RAM is: + one raw input file + 16 × (11 + 2·N cols) × nrows + vs the old approach which was one huge (11 + 32·N cols) DF picklied 16×. + After accumulation the 16 output files are written (optionally in parallel) + and each accumulator is freed immediately. + """ + ALL_BPS = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', + 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] + BASES = ['A', 'C', 'G', 'T'] + metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', + 'Ctype', 'Mode', 'Start', 'End', 'Strand'] + needed_cols = metadata_cols + ['QualifiedBases', + 'SiteBasePairingsQualified', + 'ReadBasePairingsQualified'] + mixed_dtypes = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} + + # Pre-compute per-BP constants once + bp_meta = [] + for bp in ALL_BPS: + gb_idx = BASES.index(bp[0]) + bp_meta.append((ALL_BPS.index(bp), gb_idx, gb_idx * 4)) + + sample_info = [] for filepath, (group_name, sample_name, replicate) in file_group_sample_replicate_dict.items(): filename_stem = Path(filepath).stem file_id = '_'.join(filename_stem.split('_')[:-1]) - sample_info.append((filepath, group_name, sample_name, replicate, file_id)) - - # Incremental merge: load one sample at a time and free it immediately after merging. - # Peak RAM = merged_so_far + one_new_sample (instead of N samples simultaneously). 
- merged = None - for filepath, group_name, sample_name, replicate, file_id in sample_info: - print(f"Processing {group_name}::{sample_name} (replicate {replicate}) from {filepath} (file_id: {file_id})...") - data = parse_tsv_file(filepath, group_name, sample_name, replicate, file_id, include_file_id, min_cov, decimals) - if merged is None: - merged = data - else: - merged = merged.merge(data, on=metadata_cols, how='outer') - del data - gc.collect() + col_prefix = (f'{group_name}::{sample_name}::{replicate}::{file_id}' + if include_file_id + else f'{group_name}::{sample_name}::{replicate}') + sample_info.append((filepath, col_prefix)) + + # One pass: read each file once, accumulate into 16 BP DataFrames + accumulators = [None] * len(ALL_BPS) + + for filepath, col_prefix in sample_info: + print(f'Reading {filepath}...') + df = pd.read_csv(filepath, sep='\t', dtype=mixed_dtypes, usecols=needed_cols) + + for i, bp in enumerate(ALL_BPS): + bp_idx, gb_idx, gb_offset = bp_meta[i] + bp_data = _compute_bp_from_df( + df, bp, bp_idx, gb_idx, gb_offset, col_prefix, min_cov, decimals + ) + if accumulators[i] is None: + accumulators[i] = bp_data + else: + accumulators[i] = accumulators[i].merge(bp_data, on=metadata_cols, how='outer') + del bp_data + + del df + gc.collect() - # Do NOT fill NA with 0: NA means not covered / below min_cov, - # which is distinct from 0 (covered but no editing observed). - merged = merged.sort_values(['SeqID', 'ParentIDs', 'Mode']) + print(f'\nWriting {len(ALL_BPS)} output files...') + write_args = [ + (ALL_BPS[i], accumulators[i], output_prefix, metadata_cols, report_non_qualified) + for i in range(len(ALL_BPS)) + ] - # Write one file per base pair combination. - # Parallelize if threads > 1: each base pair file is written by a separate worker. - # Note: values are already rounded in parse_tsv_file() to reduce RAM during merge. 
if threads > 1: - print(f"Writing {len(base_pairs)} output files using {threads} threads...") - with multiprocessing.Pool(processes=threads) as pool: - args = [(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, report_non_qualified) for bp in base_pairs] - output_files = pool.starmap(write_base_pair_file, args) + n_workers = min(threads, len(ALL_BPS)) + with multiprocessing.Pool(processes=n_workers) as pool: + output_files = pool.starmap(_write_one_bp, write_args) else: - print(f"Writing {len(base_pairs)} output files sequentially...") - output_files = [] - for bp in base_pairs: - output_file = write_base_pair_file(merged, bp, metadata_cols, sample_info, output_prefix, include_file_id, report_non_qualified) - output_files.append(output_file) - gc.collect() - - print(f"\nOutput files created:") - for output_file in output_files: - print(f" - {output_file}") - print(f" - {len(merged)} aggregates per file") - print(f" - {len(sample_info)} samples: {', '.join(si[2] for si in sample_info)}") - print(f" - {len(base_pairs)} files (one per base pair combination)") - - return merged + output_files = [_write_one_bp(*a) for a in write_args] + + # Free accumulators after writing + for i in range(len(ALL_BPS)): + accumulators[i] = None + gc.collect() + + print(f'\nDone.') + print(f' {len(sample_info)} samples processed.') + print(f' {len(ALL_BPS)} output files written.') # Example usage if __name__ == "__main__": From ab91c901353b58349a6faf5cfe8e761e0884180c Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 08:03:05 +0000 Subject: [PATCH 44/61] try to improve RAM --- bin/drip.py | 137 ++++++++++++++++------------------------------------ 1 file changed, 41 insertions(+), 96 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index d4ce6e6..b174706 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -219,46 +219,30 @@ def parse_tsv_file_for_bp(filepath, bp, group_name, sample_name, replicate, file return result -def _compute_bp_from_df(df, bp, 
bp_idx, gb_idx, gb_offset, col_prefix, min_cov, decimals): - """Compute espf/espr for one BP from an already-loaded DataFrame. - - Returns a slim DataFrame (11 metadata cols + 2 metric cols). - `df` must already contain the pre-parsed columns: - SiteBasePairingsQualified, QualifiedBases, ReadBasePairingsQualified - as well as the 11 metadata columns. +def _process_one_bp(bp, sample_info, output_prefix, metadata_cols, + include_file_id, min_cov, decimals, report_non_qualified): + """Load, merge, filter and write the output file for one base-pair combination. + + Memory model: at most one raw input file + the cumulative merged result + (11 + 2·N columns) live in RAM simultaneously. The previous approach kept + 11 + 32·N columns in RAM across all base pairs before writing any file. + In parallel mode each worker runs this function independently — no shared + DataFrame is pickled across processes. """ - metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', - 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - - bp_sites = _parse_comma_col_at(df['SiteBasePairingsQualified'], bp_idx) - x_count = _parse_comma_col_at(df['QualifiedBases'], gb_idx) - - reads_split = df['ReadBasePairingsQualified'].str.split(',', expand=True) - bp_reads = reads_split[bp_idx].astype(np.int64) - total_reads = (reads_split[gb_offset ].astype(np.int64) - + reads_split[gb_offset + 1].astype(np.int64) - + reads_split[gb_offset + 2].astype(np.int64) - + reads_split[gb_offset + 3].astype(np.int64)) - del reads_split - - result = df[metadata_cols].copy() - - mask_f = x_count >= min_cov - result[f'{col_prefix}::espf'] = np.where( - mask_f, bp_sites / x_count.where(mask_f, 1), np.nan - ).round(decimals) - - mask_r = total_reads >= min_cov - result[f'{col_prefix}::espr'] = np.where( - mask_r, bp_reads / total_reads.where(mask_r, 1), np.nan - ).round(decimals) - - return result - + merged = None + for filepath, group_name, sample_name, replicate, file_id in sample_info: + data = 
parse_tsv_file_for_bp( + filepath, bp, group_name, sample_name, replicate, + file_id, include_file_id, min_cov, decimals + ) + if merged is None: + merged = data + else: + merged = merged.merge(data, on=metadata_cols, how='outer') + del data + gc.collect() -def _write_one_bp(bp, accumulated, output_prefix, metadata_cols, report_non_qualified): - """Sort, filter and write the accumulated DataFrame for one BP.""" - merged = accumulated.sort_values(['SeqID', 'ParentIDs', 'Mode']) + merged = merged.sort_values(['SeqID', 'ParentIDs', 'Mode']) if not report_non_qualified: metric_cols = [c for c in merged.columns if c not in metadata_cols] @@ -278,80 +262,41 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ min_cov=1, threads=1, decimals=4, report_non_qualified=False): """Produce one output file per base pair combination. - Memory strategy: each input file is read exactly once. All 16 BP - accumulators are updated in a single pass over the inputs, so peak RAM is: - one raw input file + 16 × (11 + 2·N cols) × nrows - vs the old approach which was one huge (11 + 32·N cols) DF picklied 16×. - After accumulation the 16 output files are written (optionally in parallel) - and each accumulator is freed immediately. + Memory strategy: process one base pair at a time. Each iteration reads + every input file once but keeps only 11 + 2·N columns in RAM (vs the old + 11 + 32·N columns for all base pairs at once). In parallel mode each + worker owns its data entirely — nothing large is pickled between processes. + Trade-off: input files are read 16 times instead of once. 
""" - ALL_BPS = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', - 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] - BASES = ['A', 'C', 'G', 'T'] + base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', + 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - needed_cols = metadata_cols + ['QualifiedBases', - 'SiteBasePairingsQualified', - 'ReadBasePairingsQualified'] - mixed_dtypes = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} - - # Pre-compute per-BP constants once - bp_meta = [] - for bp in ALL_BPS: - gb_idx = BASES.index(bp[0]) - bp_meta.append((ALL_BPS.index(bp), gb_idx, gb_idx * 4)) sample_info = [] for filepath, (group_name, sample_name, replicate) in file_group_sample_replicate_dict.items(): filename_stem = Path(filepath).stem file_id = '_'.join(filename_stem.split('_')[:-1]) - col_prefix = (f'{group_name}::{sample_name}::{replicate}::{file_id}' - if include_file_id - else f'{group_name}::{sample_name}::{replicate}') - sample_info.append((filepath, col_prefix)) - - # One pass: read each file once, accumulate into 16 BP DataFrames - accumulators = [None] * len(ALL_BPS) - - for filepath, col_prefix in sample_info: - print(f'Reading {filepath}...') - df = pd.read_csv(filepath, sep='\t', dtype=mixed_dtypes, usecols=needed_cols) - - for i, bp in enumerate(ALL_BPS): - bp_idx, gb_idx, gb_offset = bp_meta[i] - bp_data = _compute_bp_from_df( - df, bp, bp_idx, gb_idx, gb_offset, col_prefix, min_cov, decimals - ) - if accumulators[i] is None: - accumulators[i] = bp_data - else: - accumulators[i] = accumulators[i].merge(bp_data, on=metadata_cols, how='outer') - del bp_data - - del df - gc.collect() + sample_info.append((filepath, group_name, sample_name, replicate, file_id)) - print(f'\nWriting {len(ALL_BPS)} output files...') - write_args = [ - (ALL_BPS[i], accumulators[i], output_prefix, metadata_cols, report_non_qualified) - for i in 
range(len(ALL_BPS)) + worker_args = [ + (bp, sample_info, output_prefix, metadata_cols, + include_file_id, min_cov, decimals, report_non_qualified) + for bp in base_pairs ] if threads > 1: - n_workers = min(threads, len(ALL_BPS)) + n_workers = min(threads, len(base_pairs)) + print(f'Processing {len(base_pairs)} base pairs with {n_workers} parallel workers...') with multiprocessing.Pool(processes=n_workers) as pool: - output_files = pool.starmap(_write_one_bp, write_args) + output_files = pool.starmap(_process_one_bp, worker_args) else: - output_files = [_write_one_bp(*a) for a in write_args] - - # Free accumulators after writing - for i in range(len(ALL_BPS)): - accumulators[i] = None - gc.collect() + print(f'Processing {len(base_pairs)} base pairs sequentially...') + output_files = [_process_one_bp(*a) for a in worker_args] print(f'\nDone.') - print(f' {len(sample_info)} samples processed.') - print(f' {len(ALL_BPS)} output files written.') + print(f' {len(sample_info)} samples: {chr(44)+chr(32).join(si[2] for si in sample_info)}') + print(f' {len(base_pairs)} output files written.') # Example usage if __name__ == "__main__": From c567507bfcc8ebf8b60175d625f86d5e91ba0ff6 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 08:44:22 +0000 Subject: [PATCH 45/61] Improve RAM usage by spliting by seqid --- bin/drip.py | 266 ++++++++++++++++++++++++++++++++++++++++++---------- 1 file changed, 218 insertions(+), 48 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index b174706..737c0a2 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -17,6 +17,9 @@ import multiprocessing import gc from pathlib import Path +import shutil +import tempfile +import re def print_help(): """Print help message explaining the script usage and calculations.""" @@ -219,30 +222,46 @@ def parse_tsv_file_for_bp(filepath, bp, group_name, sample_name, replicate, file return result -def _process_one_bp(bp, sample_info, output_prefix, metadata_cols, - include_file_id, min_cov, decimals, 
report_non_qualified): - """Load, merge, filter and write the output file for one base-pair combination. +def _compute_bp_from_df(df, bp, bp_idx, gb_idx, gb_offset, col_prefix, min_cov, decimals): + """Compute espf/espr for one BP from an already-loaded DataFrame. - Memory model: at most one raw input file + the cumulative merged result - (11 + 2·N columns) live in RAM simultaneously. The previous approach kept - 11 + 32·N columns in RAM across all base pairs before writing any file. - In parallel mode each worker runs this function independently — no shared - DataFrame is pickled across processes. + Returns a slim DataFrame (11 metadata cols + 2 metric cols). + `df` must already contain the pre-parsed columns: + SiteBasePairingsQualified, QualifiedBases, ReadBasePairingsQualified + as well as the 11 metadata columns. """ - merged = None - for filepath, group_name, sample_name, replicate, file_id in sample_info: - data = parse_tsv_file_for_bp( - filepath, bp, group_name, sample_name, replicate, - file_id, include_file_id, min_cov, decimals - ) - if merged is None: - merged = data - else: - merged = merged.merge(data, on=metadata_cols, how='outer') - del data - gc.collect() + metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', + 'Ctype', 'Mode', 'Start', 'End', 'Strand'] - merged = merged.sort_values(['SeqID', 'ParentIDs', 'Mode']) + bp_sites = _parse_comma_col_at(df['SiteBasePairingsQualified'], bp_idx) + x_count = _parse_comma_col_at(df['QualifiedBases'], gb_idx) + + reads_split = df['ReadBasePairingsQualified'].str.split(',', expand=True) + bp_reads = reads_split[bp_idx].astype(np.int64) + total_reads = (reads_split[gb_offset ].astype(np.int64) + + reads_split[gb_offset + 1].astype(np.int64) + + reads_split[gb_offset + 2].astype(np.int64) + + reads_split[gb_offset + 3].astype(np.int64)) + del reads_split + + result = df[metadata_cols].copy() + + mask_f = x_count >= min_cov + result[f'{col_prefix}::espf'] = np.where( + mask_f, bp_sites / 
x_count.where(mask_f, 1), np.nan + ).round(decimals) + + mask_r = total_reads >= min_cov + result[f'{col_prefix}::espr'] = np.where( + mask_r, bp_reads / total_reads.where(mask_r, 1), np.nan + ).round(decimals) + + return result + + +def _write_one_bp(bp, accumulated, output_prefix, metadata_cols, report_non_qualified): + """Sort, filter and write the accumulated DataFrame for one BP.""" + merged = accumulated.sort_values(['SeqID', 'ParentIDs', 'Mode']) if not report_non_qualified: metric_cols = [c for c in merged.columns if c not in metadata_cols] @@ -258,45 +277,196 @@ def _process_one_bp(bp, sample_info, output_prefix, metadata_cols, print(f' {output_file} ({n_rows} rows)') return output_file +def _safe_seqid(seqid: str) -> str: + """Convert a SeqID to a filesystem-safe directory name.""" + return re.sub(r'[^A-Za-z0-9._-]', '_', seqid) or 'EMPTY' + + +def _split_file_by_seqid(filepath, temp_dir, sample_idx, needed_cols, mixed_dtypes): + """Read one TSV file once and write per-SeqID chunk files. + + Returns a dict {seqid: chunk_file_path}. Each chunk file contains the + same columns as the original (needed_cols), filtered to one SeqID. + """ + df = pd.read_csv(filepath, sep='\t', dtype=mixed_dtypes, usecols=needed_cols) + seqid_to_path = {} + for seqid, group in df.groupby('SeqID', sort=False): + chunk_dir = os.path.join(temp_dir, _safe_seqid(seqid)) + os.makedirs(chunk_dir, exist_ok=True) + path = os.path.join(chunk_dir, f'{sample_idx}.tsv') + group.to_csv(path, sep='\t', index=False) + seqid_to_path[seqid] = path + del df + gc.collect() + return seqid_to_path + + +def _process_seqid_chunk(seqid, chunk_paths_by_sample, col_prefixes, temp_out_dir, + metadata_cols, bp_meta, all_bps, min_cov, decimals, + report_non_qualified): + """Compute 16 BPs for one SeqID and write 16 per-BP chunk files. + + Each element in chunk_paths_by_sample is a path (str) or None when that + sample has no rows for this SeqID — the outer join fills NaN for it. 
+ Returns (seqid, {bp: out_path, ...}) — only BPs with at least one row. + """ + mixed_dtypes_local = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} + bp_accumulators = [None] * len(all_bps) + + for col_prefix, chunk_path in zip(col_prefixes, chunk_paths_by_sample): + if chunk_path is None: + continue + df = pd.read_csv(chunk_path, sep='\t', dtype=mixed_dtypes_local) + for i in range(len(all_bps)): + bp_idx, gb_idx, gb_offset = bp_meta[i] + bp_data = _compute_bp_from_df( + df, all_bps[i], bp_idx, gb_idx, gb_offset, col_prefix, min_cov, decimals + ) + if bp_accumulators[i] is None: + bp_accumulators[i] = bp_data + else: + bp_accumulators[i] = bp_accumulators[i].merge( + bp_data, on=metadata_cols, how='outer' + ) + del bp_data + del df + gc.collect() + + out_paths = {} + safe = _safe_seqid(seqid) + for i, bp in enumerate(all_bps): + acc = bp_accumulators[i] + if acc is None: + continue + acc = acc.sort_values(['ParentIDs', 'Mode']) + if not report_non_qualified: + metric_cols = [c for c in acc.columns if c not in metadata_cols] + if metric_cols: + acc = acc[ + (acc[metric_cols].notna() & (acc[metric_cols] != 0)).any(axis=1) + ] + if len(acc) > 0: + out_path = os.path.join(temp_out_dir, f'{bp}_{safe}.tsv') + acc.to_csv(out_path, sep='\t', index=False, na_rep='NA') + out_paths[bp] = out_path + bp_accumulators[i] = None + + gc.collect() + return seqid, out_paths + + def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, min_cov=1, threads=1, decimals=4, report_non_qualified=False): """Produce one output file per base pair combination. - Memory strategy: process one base pair at a time. Each iteration reads - every input file once but keeps only 11 + 2·N columns in RAM (vs the old - 11 + 32·N columns for all base pairs at once). In parallel mode each - worker owns its data entirely — nothing large is pickled between processes. - Trade-off: input files are read 16 times instead of once. 
+ Memory strategy — three phases: + Phase 1 (Split): each input file is read once and split by SeqID into + temporary chunk files. Peak RAM = one full input file. + Phase 2 (Process): each SeqID is processed independently (all 16 BPs, + across all samples) and results written to temp chunks. + Peak RAM per worker ≈ num_samples × one-SeqID slice. + Workers run in parallel when threads > 1. + Phase 3 (Concat): per-SeqID temp chunks are appended in order to the + 16 final output files. Peak RAM = one chunk at a time. """ - base_pairs = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', - 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] + ALL_BPS = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', + 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] + BASES = ['A', 'C', 'G', 'T'] metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] + needed_cols = metadata_cols + ['QualifiedBases', + 'SiteBasePairingsQualified', + 'ReadBasePairingsQualified'] + mixed_dtypes = {'SeqID': str, 'Start': str, 'End': str, 'Strand': str} + bp_meta = [(ALL_BPS.index(bp), BASES.index(bp[0]), BASES.index(bp[0]) * 4) + for bp in ALL_BPS] sample_info = [] for filepath, (group_name, sample_name, replicate) in file_group_sample_replicate_dict.items(): filename_stem = Path(filepath).stem file_id = '_'.join(filename_stem.split('_')[:-1]) - sample_info.append((filepath, group_name, sample_name, replicate, file_id)) - - worker_args = [ - (bp, sample_info, output_prefix, metadata_cols, - include_file_id, min_cov, decimals, report_non_qualified) - for bp in base_pairs - ] - - if threads > 1: - n_workers = min(threads, len(base_pairs)) - print(f'Processing {len(base_pairs)} base pairs with {n_workers} parallel workers...') - with multiprocessing.Pool(processes=n_workers) as pool: - output_files = pool.starmap(_process_one_bp, worker_args) - else: - print(f'Processing {len(base_pairs)} base pairs sequentially...') - output_files = [_process_one_bp(*a) 
for a in worker_args] - - print(f'\nDone.') - print(f' {len(sample_info)} samples: {chr(44)+chr(32).join(si[2] for si in sample_info)}') - print(f' {len(base_pairs)} output files written.') + col_prefix = (f'{group_name}::{sample_name}::{replicate}::{file_id}' + if include_file_id + else f'{group_name}::{sample_name}::{replicate}') + sample_info.append((filepath, col_prefix)) + + with tempfile.TemporaryDirectory() as temp_dir: + temp_split_dir = os.path.join(temp_dir, 'split') + temp_out_dir = os.path.join(temp_dir, 'out') + os.makedirs(temp_split_dir) + os.makedirs(temp_out_dir) + + # ── Phase 1: split each file by SeqID ───────────────────────────────── + print('Phase 1/3 — Splitting input files by SeqID...') + all_seqids_seen: dict[str, int] = {} # seqid → first-appearance order + sample_seqid_paths: list[dict[str, str]] = [] + + for sample_idx, (filepath, _col_prefix) in enumerate(sample_info): + print(f' [{sample_idx + 1}/{len(sample_info)}] {filepath}') + seqid_to_path = _split_file_by_seqid( + filepath, temp_split_dir, sample_idx, needed_cols, mixed_dtypes + ) + sample_seqid_paths.append(seqid_to_path) + for seqid in seqid_to_path: + if seqid not in all_seqids_seen: + all_seqids_seen[seqid] = len(all_seqids_seen) + + # Sort SeqIDs: real sequences first (lexicographic), "." 
(globals) last + all_seqids = sorted( + all_seqids_seen.keys(), + key=lambda s: (s == '.', s) + ) + print(f' Found {len(all_seqids)} unique SeqIDs across all samples.') + + # ── Phase 2: process each SeqID (parallel if threads > 1) ───────────── + col_prefixes = [col_prefix for _, col_prefix in sample_info] + worker_args = [ + (seqid, + [ssp.get(seqid) for ssp in sample_seqid_paths], + col_prefixes, temp_out_dir, + metadata_cols, bp_meta, ALL_BPS, min_cov, decimals, report_non_qualified) + for seqid in all_seqids + ] + + mode = f'{min(threads, len(all_seqids))} workers' if threads > 1 else 'sequential' + print(f'Phase 2/3 — Processing {len(all_seqids)} SeqID chunks ({mode})...') + + if threads > 1: + with multiprocessing.Pool(processes=min(threads, len(all_seqids))) as pool: + results = pool.starmap(_process_seqid_chunk, worker_args) + else: + results = [_process_seqid_chunk(*a) for a in worker_args] + + # Restore stable SeqID order (parallel mode may return out of order) + seqid_order = {s: i for i, s in enumerate(all_seqids)} + results.sort(key=lambda r: seqid_order[r[0]]) + + # ── Phase 3: concatenate per-SeqID chunks into 16 final files ────────── + print('Phase 3/3 — Writing final output files...') + output_files = [] + for bp in ALL_BPS: + bp_chunks = [ + out_paths[bp] + for _, out_paths in results + if bp in out_paths + ] + if not bp_chunks: + continue + + out_path = f'{output_prefix}_{bp}.tsv' + with open(out_path, 'w') as fout: + with open(bp_chunks[0]) as first: + shutil.copyfileobj(first, fout) # includes header + for chunk_path in bp_chunks[1:]: + with open(chunk_path) as f: + next(f) # skip header + shutil.copyfileobj(f, fout) + output_files.append(out_path) + print(f' {out_path}') + + print(f'\nDone. 
{len(sample_info)} samples, {len(all_seqids)} SeqIDs, ' + f'{len(output_files)} output files written.') + # Example usage if __name__ == "__main__": From ca33d34460fd4dddf922f32ecad2d03c74db8b52 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 08:54:21 +0000 Subject: [PATCH 46/61] add --min-samples-pct --min-group-pct to filter rows --- bin/drip.py | 95 +++++++++++++++++++++++++++++++++++++++++++++++++---- 1 file changed, 88 insertions(+), 7 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index 737c0a2..c6815dd 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -47,6 +47,17 @@ def print_help(): Include rows where all metric values are NA (default: omit them). By default, rows where every sample has NA for the given base pair are skipped (not covered or not qualified in any sample). + --min-samples-pct X + Keep a row only if at least X% of all samples have a qualified value + (non-NA, non-zero) for the base pair. A sample "has a value" when + at least one of its espf/espr metrics is non-NA and non-zero. + Applied as an OR with --min-group-pct when both are provided. + Ignored when --report-non-qualified-features is set. + --min-group-pct Y + Keep a row only if at least one group has at least Y% of its + samples with a qualified value. Applied as an OR with + --min-samples-pct when both are provided. + Ignored when --report-non-qualified-features is set. --min-cov N Minimum read coverage threshold (default: 1). Positions with a denominator (genome base count for espf, read count for espr) strictly below this value are reported as NA instead of 0. @@ -303,7 +314,7 @@ def _split_file_by_seqid(filepath, temp_dir, sample_idx, needed_cols, mixed_dtyp def _process_seqid_chunk(seqid, chunk_paths_by_sample, col_prefixes, temp_out_dir, metadata_cols, bp_meta, all_bps, min_cov, decimals, - report_non_qualified): + report_non_qualified, min_samples_pct=None, min_group_pct=None): """Compute 16 BPs for one SeqID and write 16 per-BP chunk files. 
Each element in chunk_paths_by_sample is a path (str) or None when that @@ -342,9 +353,33 @@ def _process_seqid_chunk(seqid, chunk_paths_by_sample, col_prefixes, temp_out_di if not report_non_qualified: metric_cols = [c for c in acc.columns if c not in metadata_cols] if metric_cols: - acc = acc[ - (acc[metric_cols].notna() & (acc[metric_cols] != 0)).any(axis=1) - ] + has_value = acc[metric_cols].notna() & (acc[metric_cols] != 0) + # One boolean column per sample (pair espf/espr → any of the two) + sample_prefixes = sorted(set(c.rsplit('::', 1)[0] for c in metric_cols)) + sample_hv = pd.DataFrame( + {p: has_value[ + [c for c in metric_cols if c.rsplit('::', 1)[0] == p] + ].any(axis=1) + for p in sample_prefixes}, + index=acc.index, + ) + if min_samples_pct is None and min_group_pct is None: + keep = sample_hv.any(axis=1) + else: + keep = pd.Series(False, index=acc.index) + if min_samples_pct is not None: + n = len(sample_prefixes) + keep |= sample_hv.sum(axis=1) / n >= min_samples_pct / 100.0 + if min_group_pct is not None: + groups: dict[str, list[str]] = {} + for p in sample_prefixes: + groups.setdefault(p.split('::')[0], []).append(p) + for g_cols in groups.values(): + keep |= ( + sample_hv[g_cols].sum(axis=1) / len(g_cols) + >= min_group_pct / 100.0 + ) + acc = acc[keep] if len(acc) > 0: out_path = os.path.join(temp_out_dir, f'{bp}_{safe}.tsv') acc.to_csv(out_path, sep='\t', index=False, na_rep='NA') @@ -356,7 +391,8 @@ def _process_seqid_chunk(seqid, chunk_paths_by_sample, col_prefixes, temp_out_di def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id=False, - min_cov=1, threads=1, decimals=4, report_non_qualified=False): + min_cov=1, threads=1, decimals=4, report_non_qualified=False, + min_samples_pct=None, min_group_pct=None): """Produce one output file per base pair combination. 
Memory strategy — three phases: @@ -424,7 +460,8 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ (seqid, [ssp.get(seqid) for ssp in sample_seqid_paths], col_prefixes, temp_out_dir, - metadata_cols, bp_meta, ALL_BPS, min_cov, decimals, report_non_qualified) + metadata_cols, bp_meta, ALL_BPS, min_cov, decimals, report_non_qualified, + min_samples_pct, min_group_pct) for seqid in all_seqids ] @@ -490,6 +527,8 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ threads = 1 # Default: sequential writing decimals = 4 # Default: round to 4 decimal places report_non_qualified = False # Default: skip rows where all metric values are NA + min_samples_pct = None # Default: no global-sample-% filter + min_group_pct = None # Default: no per-group-% filter args_iter = iter(range(1, len(sys.argv))) for i in args_iter: @@ -518,6 +557,48 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ report_non_qualified = True continue + # Check for --min-samples-pct flag + if arg.startswith('--min-samples-pct'): + if '=' in arg: + val = arg.split('=', 1)[1] + else: + try: + next_i = next(args_iter) + val = sys.argv[next_i] + except StopIteration: + print('ERROR: --min-samples-pct requires a value', file=sys.stderr) + sys.exit(1) + try: + min_samples_pct = float(val) + except ValueError: + print('ERROR: --min-samples-pct requires a numeric value (0–100)', file=sys.stderr) + sys.exit(1) + if not (0.0 <= min_samples_pct <= 100.0): + print('ERROR: --min-samples-pct must be between 0 and 100', file=sys.stderr) + sys.exit(1) + continue + + # Check for --min-group-pct flag + if arg.startswith('--min-group-pct'): + if '=' in arg: + val = arg.split('=', 1)[1] + else: + try: + next_i = next(args_iter) + val = sys.argv[next_i] + except StopIteration: + print('ERROR: --min-group-pct requires a value', file=sys.stderr) + sys.exit(1) + try: + min_group_pct = float(val) + except ValueError: + print('ERROR: 
--min-group-pct requires a numeric value (0–100)', file=sys.stderr) + sys.exit(1) + if not (0.0 <= min_group_pct <= 100.0): + print('ERROR: --min-group-pct must be between 0 and 100', file=sys.stderr) + sys.exit(1) + continue + # Check for --min-cov flag (supports both --min-cov=N and --min-cov N) if arg.startswith('--min-cov'): if '=' in arg: @@ -622,6 +703,6 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ sys.exit(1) # Process all samples - result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads, decimals, report_non_qualified) + result = merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_id, min_cov, threads, decimals, report_non_qualified, min_samples_pct, min_group_pct) print("\nAnalysis complete!") \ No newline at end of file From 000f8741351efa06318c17ffaba76b5614331181 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 10:02:51 +0100 Subject: [PATCH 47/61] add filter in sample annd group minimal present --- modules/python.nf | 4 +++- rain.nf | 6 ++++-- 2 files changed, 7 insertions(+), 3 deletions(-) diff --git a/modules/python.nf b/modules/python.nf index b14119e..3b299ba 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -34,6 +34,8 @@ process drip { input: val(meta_tsv) val prefix + val samples_pct + val group_pct output: path("*_AG.tsv"), emit: editing_ag @@ -57,6 +59,6 @@ process drip { """ - drip.py --threads ${task.cpus} --output drip_${prefix} ${args_str} + drip.py --threads ${task.cpus} --min-samples-pct ${samples_pct} --min-group-pct ${group_pct} --output drip_${prefix} ${args_str} """ } \ No newline at end of file diff --git a/rain.nf b/rain.nf index 0a103c8..b90cb83 100644 --- a/rain.nf +++ b/rain.nf @@ -24,6 +24,8 @@ edit_site_tools = ["reditools2", "reditools3", "jacusa2", "sapin"] params.edit_site_tool = "reditools3" params.edit_threshold = 1 // Minimal number of edited reads to count a site as edited 
params.cov_threshold = 10 // Minimal coverage to consider a site for editing detection +params.min_samples_pct = 50 // Minimal percentage of samples in which a site must be edited to be kept in the analysis (drip filtering) +params.min_group_pct = 75 // Minimal percentage of groups in which a site must be edited to be kept in the analysis (drip filtering) params.aggregation_mode = "all" params.skip_hyper_editing = false // Skip hyper-editing detection // Report params @@ -765,8 +767,8 @@ workflow { pluviometer_reditools3(reditools3.out.tuple_sample_serial_table, clean_annotation.collect(), "reditools3") if(via_csv){ // drip - compute espn, espf, merge different sample in one, and output by type of mutation (AG, AC, etc..) - drip_aggregates(pluviometer_reditools3.out.tuple_sample_aggregate.collect(), "aggregates") - drip_features(pluviometer_reditools3.out.tuple_sample_feature.collect(), "features") + drip_aggregates(pluviometer_reditools3.out.tuple_sample_aggregate.collect(), "aggregates", params.min_samples_pct, params.min_group_pct) + drip_features(pluviometer_reditools3.out.tuple_sample_feature.collect(), "features", params.min_samples_pct, params.min_group_pct) //christalize(drip_features.out.editing_ag, "AG") From 63b136aa1bb8ba647882e2f380a821b966f56218 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 10:44:51 +0000 Subject: [PATCH 48/61] output one folder by value type computed espr espf --- bin/drip.py | 199 ++++++++++++++++++++++++++++++---------------------- 1 file changed, 117 insertions(+), 82 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index c6815dd..757a502 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -37,8 +37,9 @@ def print_help(): ARGUMENTS: --output OUTPUT_PREFIX, -o OUTPUT_PREFIX - Prefix for the output TSV files (required). - Will create OUTPUT_PREFIX_AA.tsv, OUTPUT_PREFIX_AC.tsv, etc. + Prefix for the output directories (required). 
+ Creates OUTPUT_PREFIX_espf/ and OUTPUT_PREFIX_espr/, + each containing 16 TSV files (AA.tsv, AC.tsv, …, TT.tsv). FILEn:GROUPn:SAMPLEn:REPn Input file path, group name, sample name, and replicate ID separated by colons. All four components are required. @@ -100,31 +101,38 @@ def print_help(): All 16 combinations are calculated: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT OUTPUT FORMAT: - Multiple TSV files (one per base pair combination) with aggregates as rows: - - OUTPUT_PREFIX_AA.tsv - - OUTPUT_PREFIX_AC.tsv - - OUTPUT_PREFIX_AG.tsv - - ... (16 files total, one for each XY combination) - + Two output directories, one per metric type, each containing 16 TSV files + (one per base pair combination). Row filtering (--min-samples-pct, etc.) + is applied independently per metric: a row may appear in the espf output + but not in the espr output and vice versa. + + Directories created: + - OUTPUT_PREFIX_espf/ espf metric (proportion of edited sites in feature) + - OUTPUT_PREFIX_espr/ espr metric (proportion of edited reads) + + Files in each directory (16 per metric): + - AA.tsv, AC.tsv, AG.tsv, AT.tsv + - CA.tsv, CC.tsv, CG.tsv, CT.tsv + - GA.tsv, GC.tsv, GG.tsv, GT.tsv + - TA.tsv, TC.tsv, TG.tsv, TT.tsv + Each file contains: - - Metadata columns (first 6 columns): + + Metadata columns: - SeqID: Sequence/chromosome identifier - ParentIDs: Parent feature identifiers - ID: Unique identifier + - Mtype: Type of feature - Ptype: Type of Parent feature - - Type: Type of feature + - Type: Aggregate type (feature / sequence / global) - Ctype: Type of Children feature - - Mode: Mode of aggregation used if any (e.g., 'all_sites', 'edited_sites', 'edited_reads') - - Metric columns (for each sample): - - GROUP::SAMPLE::REPLICATE::espf: XY sites proportion in feature (XY sites / X bases) - - GROUP::SAMPLE::REPLICATE::espr: XY sites proportion in reads (XY reads / all X reads) - - Or with --with-file-id option: - - GROUP::SAMPLE::REPLICATE::FILE_ID::espf - - 
GROUP::SAMPLE::REPLICATE::FILE_ID::espr - + - Mode: Mode of aggregation (e.g., 'all_sites', 'edited_sites', 'edited_reads') + - Start, End, Strand + + Metric columns (one per sample): + - GROUP::SAMPLE::REPLICATE:: (without --with-file-id) + - GROUP::SAMPLE::REPLICATE::FILE_ID:: (with --with-file-id) + Where: - GROUP: Group/condition name provided in arguments - SAMPLE: Sample name provided in arguments @@ -138,24 +146,21 @@ def print_help(): sample2_aggregates.tsv:control:sample2:rep2 \\ sample3_aggregates.tsv:treated:sample1:rep1 - This creates 16 files (one per base pair combination): - - results_AA.tsv, results_AC.tsv, results_AG.tsv, results_AT.tsv, - - results_CA.tsv, results_CC.tsv, results_CG.tsv, results_CT.tsv, - - results_GA.tsv, results_GC.tsv, results_GG.tsv, results_GT.tsv, - - results_TA.tsv, results_TC.tsv, results_TG.tsv, results_TT.tsv - - Each file has columns: - SeqID, ParentIDs, ID, Ptype, Ctype, Mode, - control::sample1::rep1::rain_sample1::espf, control::sample1::rep1::rain_sample1::espr, - control::sample2::rep2::rain_sample2::espf, control::sample2::rep2::rain_sample2::espr, - treated::sample1::rep1::rain_sample3::espf, treated::sample1::rep1::rain_sample3::espr - - Column headers use format: GROUP::SAMPLE::REPLICATE::FILE_ID::METRIC + Creates two directories: + - results_espf/ with AA.tsv, AC.tsv, …, TT.tsv (espf metric) + - results_espr/ with AA.tsv, AC.tsv, …, TT.tsv (espr metric) + + Example columns in results_espf/AG.tsv: + SeqID, ParentIDs, ID, Mtype, Ptype, Type, Ctype, Mode, Start, End, Strand, + control::sample1::rep1::espf, + control::sample2::rep2::espf, + treated::sample1::rep1::espf + + Column headers use format: GROUP::SAMPLE::REPLICATE::METRIC - GROUP: The group/condition name provided - SAMPLE: The sample name provided - REPLICATE: Replicate ID (rep1, rep2, etc.) 
- - FILE_ID: Input filename without extension and last '_' suffix - - METRIC: espf or espr + - METRIC: espf or espr (same for all columns in a given file) - Separator '::' allows easy splitting to retrieve all components AUTHORS: @@ -350,40 +355,63 @@ def _process_seqid_chunk(seqid, chunk_paths_by_sample, col_prefixes, temp_out_di if acc is None: continue acc = acc.sort_values(['ParentIDs', 'Mode']) - if not report_non_qualified: - metric_cols = [c for c in acc.columns if c not in metadata_cols] - if metric_cols: - has_value = acc[metric_cols].notna() & (acc[metric_cols] != 0) - # One boolean column per sample (pair espf/espr → any of the two) - sample_prefixes = sorted(set(c.rsplit('::', 1)[0] for c in metric_cols)) - sample_hv = pd.DataFrame( - {p: has_value[ - [c for c in metric_cols if c.rsplit('::', 1)[0] == p] - ].any(axis=1) - for p in sample_prefixes}, - index=acc.index, - ) + + # Process espf and espr independently: separate output files, separate + # row-filtering. A row that passes the espf threshold but not the espr + # threshold (or vice versa) will appear in only one of the two outputs. + for metric in ('espf', 'espr'): + # Each column for this metric type maps 1:1 to one sample. + metric_cols = [c for c in acc.columns + if c not in metadata_cols and c.endswith(f'::{metric}')] + if not metric_cols: + continue + acc_metric = acc[metadata_cols + metric_cols] + + if not report_non_qualified: + # Step 1 — cell-level: "has a value" = non-NA AND non-zero. + # NA means the position was not covered (or below min_cov, or + # absent from this sample via the outer join). + # 0.0 means covered but no editing event observed. + has_value = acc_metric[metric_cols].notna() & (acc_metric[metric_cols] != 0) + + # Step 2 — row-level decision, applied independently per metric + # type (espf rows and espr rows are filtered separately): + # + # Default: keep if ANY sample has a value for this metric. 
+ # + # --min-samples-pct X: keep if proportion of samples with a + # value >= X/100 (across all samples globally). + # Example: 3/5 samples with espf > 0 → 60%; X=50 → keep. + # + # --min-group-pct Y: keep if at least one group has >= Y/100 + # of its samples with a value. Groups are identified by + # the first :: component of the column name (GROUP name). + # Example: "ctrl" group has 2/3 espf values → 67%; Y=60 → keep. + # + # Both flags → OR: keep if either condition is satisfied. if min_samples_pct is None and min_group_pct is None: - keep = sample_hv.any(axis=1) + keep = has_value.any(axis=1) else: - keep = pd.Series(False, index=acc.index) + keep = pd.Series(False, index=acc_metric.index) if min_samples_pct is not None: - n = len(sample_prefixes) - keep |= sample_hv.sum(axis=1) / n >= min_samples_pct / 100.0 + n = len(metric_cols) + keep |= has_value.sum(axis=1) / n >= min_samples_pct / 100.0 if min_group_pct is not None: groups: dict[str, list[str]] = {} - for p in sample_prefixes: - groups.setdefault(p.split('::')[0], []).append(p) + for c in metric_cols: + groups.setdefault(c.split('::')[0], []).append(c) for g_cols in groups.values(): keep |= ( - sample_hv[g_cols].sum(axis=1) / len(g_cols) + has_value[g_cols].sum(axis=1) / len(g_cols) >= min_group_pct / 100.0 ) - acc = acc[keep] - if len(acc) > 0: - out_path = os.path.join(temp_out_dir, f'{bp}_{safe}.tsv') - acc.to_csv(out_path, sep='\t', index=False, na_rep='NA') - out_paths[bp] = out_path + acc_metric = acc_metric[keep] + + if len(acc_metric) > 0: + out_path = os.path.join(temp_out_dir, metric, f'{bp}_{safe}.tsv') + acc_metric.to_csv(out_path, sep='\t', index=False, na_rep='NA') + out_paths[(bp, metric)] = out_path + bp_accumulators[i] = None gc.collect() @@ -431,6 +459,8 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ temp_out_dir = os.path.join(temp_dir, 'out') os.makedirs(temp_split_dir) os.makedirs(temp_out_dir) + os.makedirs(os.path.join(temp_out_dir, 'espf')) + 
os.makedirs(os.path.join(temp_out_dir, 'espr')) # ── Phase 1: split each file by SeqID ───────────────────────────────── print('Phase 1/3 — Splitting input files by SeqID...') @@ -478,31 +508,36 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ seqid_order = {s: i for i, s in enumerate(all_seqids)} results.sort(key=lambda r: seqid_order[r[0]]) - # ── Phase 3: concatenate per-SeqID chunks into 16 final files ────────── + # ── Phase 3: concatenate per-SeqID chunks into final output files ───────── + # Two output directories (one per metric type), each with 16 BP files. print('Phase 3/3 — Writing final output files...') output_files = [] - for bp in ALL_BPS: - bp_chunks = [ - out_paths[bp] - for _, out_paths in results - if bp in out_paths - ] - if not bp_chunks: - continue - - out_path = f'{output_prefix}_{bp}.tsv' - with open(out_path, 'w') as fout: - with open(bp_chunks[0]) as first: - shutil.copyfileobj(first, fout) # includes header - for chunk_path in bp_chunks[1:]: - with open(chunk_path) as f: - next(f) # skip header - shutil.copyfileobj(f, fout) - output_files.append(out_path) - print(f' {out_path}') + for metric in ('espf', 'espr'): + out_dir = f'{output_prefix}_{metric}' + os.makedirs(out_dir, exist_ok=True) + for bp in ALL_BPS: + bp_chunks = [ + out_paths[(bp, metric)] + for _, out_paths in results + if (bp, metric) in out_paths + ] + if not bp_chunks: + continue + + out_path = os.path.join(out_dir, f'{bp}.tsv') + with open(out_path, 'w') as fout: + with open(bp_chunks[0]) as first: + shutil.copyfileobj(first, fout) # includes header + for chunk_path in bp_chunks[1:]: + with open(chunk_path) as f: + next(f) # skip header + shutil.copyfileobj(f, fout) + output_files.append(out_path) + print(f' {out_path}') print(f'\nDone. 
{len(sample_info)} samples, {len(all_seqids)} SeqIDs, ' - f'{len(output_files)} output files written.') + f'{len(output_files)} output files written ' + f'in {output_prefix}_espf/ and {output_prefix}_espr/.') # Example usage From 93871893a1be92e1bea42b504c795ed852026098 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 11:55:10 +0100 Subject: [PATCH 49/61] separate espf and espr --- bin/drip.py | 32 ++++++++++++++++---------------- modules/python.nf | 7 ++++--- rain.nf | 5 +++-- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/bin/drip.py b/bin/drip.py index 757a502..94682f8 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -28,7 +28,7 @@ def print_help(): DESCRIPTION: This script analyzes RNA editing from standardized pluviometer files. It calculates - two key metrics for all 16 genome-variant base pair combinations across multiple + two key metrics for all 12 genome-variant base pair combinations across multiple samples and combines them into a unified matrix format. USAGE: @@ -39,7 +39,7 @@ def print_help(): --output OUTPUT_PREFIX, -o OUTPUT_PREFIX Prefix for the output directories (required). Creates OUTPUT_PREFIX_espf/ and OUTPUT_PREFIX_espr/, - each containing 16 TSV files (AA.tsv, AC.tsv, …, TT.tsv). + each containing 12 TSV files (AA.tsv, AC.tsv, …, TT.tsv). FILEn:GROUPn:SAMPLEn:REPn Input file path, group name, sample name, and replicate ID separated by colons. All four components are required. @@ -64,7 +64,7 @@ def print_help(): strictly below this value are reported as NA instead of 0. --threads N, -t N Number of parallel threads to use for writing output files - (default: 1, sequential). Max useful value is 16 (one per base pair). + (default: 1, sequential). Max useful value is 12 (one per base pair). --decimals N, -d N Number of decimal places for output values (default: 4). Reduces file size by rounding espf/espr metrics. 
@@ -86,7 +86,7 @@ def print_help(): (order: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT) CALCULATED METRICS: - For each line, the script calculates metrics for all 16 base pair combinations: + For each line, the script calculates metrics for all 12 base pair combinations: For each combination XY (where X = genome base, Y = read base): @@ -98,10 +98,10 @@ def print_help(): Formula: XY_ReadBasePairings / (XA + XC + XG + XT)_ReadBasePairings This represents the proportion of X-position reads that show Y in the reads. - All 16 combinations are calculated: AA, AC, AG, AT, CA, CC, CG, CT, GA, GC, GG, GT, TA, TC, TG, TT + All 12 combinations are calculated: AC, AG, AT, CA, CG, CT, GA, GC, GT, TA, TC, TG OUTPUT FORMAT: - Two output directories, one per metric type, each containing 16 TSV files + Two output directories, one per metric type, each containing 12 TSV files (one per base pair combination). Row filtering (--min-samples-pct, etc.) is applied independently per metric: a row may appear in the espf output but not in the espr output and vice versa. 
@@ -110,7 +110,7 @@ def print_help():
     - OUTPUT_PREFIX_espf/   espf metric (proportion of edited sites in feature)
     - OUTPUT_PREFIX_espr/   espr metric (proportion of edited reads)
 
-    Files in each directory (16 per metric):
+    Files in each directory (12 per metric):
     - AA.tsv, AC.tsv, AG.tsv, AT.tsv
     - CA.tsv, CC.tsv, CG.tsv, CT.tsv
     - GA.tsv, GC.tsv, GG.tsv, GT.tsv
     - TA.tsv, TC.tsv, TG.tsv, TT.tsv
@@ -191,7 +191,7 @@ def parse_tsv_file_for_bp(filepath, bp, group_name, sample_name, replicate, file
     genome_base = bp[0]
     bp_idx = ALL_BPS.index(bp)
     gb_idx = BASES.index(genome_base)  # 0–3
-    gb_offset = gb_idx * 4             # first XA/XC/XG/XT index in 16-value vector
+    gb_offset = gb_idx * 4             # first XA/XC/XG/XT index in 16-value vector (input keeps all 16 counts; only 12 BPs are reported)
 
     metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype',
                      'Mode', 'Start', 'End', 'Strand']
@@ -320,7 +320,7 @@ def _split_file_by_seqid(filepath, temp_dir, sample_idx, needed_cols, mixed_dtyp
 def _process_seqid_chunk(seqid, chunk_paths_by_sample, col_prefixes, temp_out_dir,
                          metadata_cols, bp_meta, all_bps, min_cov, decimals,
                          report_non_qualified, min_samples_pct=None, min_group_pct=None):
-    """Compute 16 BPs for one SeqID and write 16 per-BP chunk files.
+    """Compute 12 BPs for one SeqID and write 12 per-BP chunk files.
 
     Each element in chunk_paths_by_sample is a path (str) or None when that
     sample has no rows for this SeqID — the outer join fills NaN for it. 
@@ -408,7 +408,7 @@ def _process_seqid_chunk(seqid, chunk_paths_by_sample, col_prefixes, temp_out_di acc_metric = acc_metric[keep] if len(acc_metric) > 0: - out_path = os.path.join(temp_out_dir, metric, f'{bp}_{safe}.tsv') + out_path = os.path.join(temp_out_dir, metric, f'{metric}_{bp}_{safe}.tsv') acc_metric.to_csv(out_path, sep='\t', index=False, na_rep='NA') out_paths[(bp, metric)] = out_path @@ -426,15 +426,15 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ Memory strategy — three phases: Phase 1 (Split): each input file is read once and split by SeqID into temporary chunk files. Peak RAM = one full input file. - Phase 2 (Process): each SeqID is processed independently (all 16 BPs, + Phase 2 (Process): each SeqID is processed independently (all 12 BPs, across all samples) and results written to temp chunks. Peak RAM per worker ≈ num_samples × one-SeqID slice. Workers run in parallel when threads > 1. Phase 3 (Concat): per-SeqID temp chunks are appended in order to the - 16 final output files. Peak RAM = one chunk at a time. + 12 final output files. Peak RAM = one chunk at a time. """ - ALL_BPS = ['AA', 'AC', 'AG', 'AT', 'CA', 'CC', 'CG', 'CT', - 'GA', 'GC', 'GG', 'GT', 'TA', 'TC', 'TG', 'TT'] + ALL_BPS = ['AC', 'AG', 'AT', 'CA', 'CG', 'CT', + 'GA', 'GC', 'GT', 'TA', 'TC', 'TG'] BASES = ['A', 'C', 'G', 'T'] metadata_cols = ['SeqID', 'ParentIDs', 'ID', 'Mtype', 'Ptype', 'Type', 'Ctype', 'Mode', 'Start', 'End', 'Strand'] @@ -509,7 +509,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ results.sort(key=lambda r: seqid_order[r[0]]) # ── Phase 3: concatenate per-SeqID chunks into final output files ───────── - # Two output directories (one per metric type), each with 16 BP files. + # Two output directories (one per metric type), each with 12 BP files. 
print('Phase 3/3 — Writing final output files...') output_files = [] for metric in ('espf', 'espr'): @@ -524,7 +524,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ if not bp_chunks: continue - out_path = os.path.join(out_dir, f'{bp}.tsv') + out_path = os.path.join(out_dir, f'{metric}_{bp}.tsv') with open(out_path, 'w') as fout: with open(bp_chunks[0]) as first: shutil.copyfileobj(first, fout) # includes header diff --git a/modules/python.nf b/modules/python.nf index 3b299ba..b779102 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -38,8 +38,10 @@ process drip { val group_pct output: - path("*_AG.tsv"), emit: editing_ag - path("*.tsv"), emit: editing_all + path("*_espr/*_AG.tsv"), emit: editing_ag_espr + path("*_espr/*.tsv"), emit: editing_all_espr + path("*_espf/*_AG.tsv"), emit: editing_ag_espf + path("*_espf/*.tsv"), emit: editing_all_espf script: def list = meta_tsv @@ -58,7 +60,6 @@ process drip { def args_str = args.join(" ") """ - drip.py --threads ${task.cpus} --min-samples-pct ${samples_pct} --min-group-pct ${group_pct} --output drip_${prefix} ${args_str} """ } \ No newline at end of file diff --git a/rain.nf b/rain.nf index b90cb83..df387b2 100644 --- a/rain.nf +++ b/rain.nf @@ -24,9 +24,10 @@ edit_site_tools = ["reditools2", "reditools3", "jacusa2", "sapin"] params.edit_site_tool = "reditools3" params.edit_threshold = 1 // Minimal number of edited reads to count a site as edited params.cov_threshold = 10 // Minimal coverage to consider a site for editing detection +// When both flags below are provided → OR: keep if either condition is satisfied independently. 
params.min_samples_pct = 50 // Minimal percentage of samples in which a site must be edited to be kept in the analysis (drip filtering) params.min_group_pct = 75 // Minimal percentage of groups in which a site must be edited to be kept in the analysis (drip filtering) -params.aggregation_mode = "all" +params.aggregation_mode = "all" // used by pluviometer params.skip_hyper_editing = false // Skip hyper-editing detection // Report params params.multiqc_config = "$baseDir/config/multiqc_config.yaml" // MultiQC config file @@ -770,7 +771,7 @@ workflow { drip_aggregates(pluviometer_reditools3.out.tuple_sample_aggregate.collect(), "aggregates", params.min_samples_pct, params.min_group_pct) drip_features(pluviometer_reditools3.out.tuple_sample_feature.collect(), "features", params.min_samples_pct, params.min_group_pct) - //christalize(drip_features.out.editing_ag, "AG") + //barometer (drip_features.out.editing_ag_espr, "AG") } } From 85dc09e2a006cd3675a88f4d3e3c6ff599d1a639 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 12:14:37 +0100 Subject: [PATCH 50/61] fix publisdir --- modules/python.nf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/python.nf b/modules/python.nf index b779102..537cd37 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -29,7 +29,7 @@ process restore_original_sequences { process drip { label "pluviometer" tag "drip" - publishDir("${params.outdir}/drip/${prefix}", mode:"copy", pattern: "*.tsv") + publishDir("${params.outdir}/drip/${prefix}", mode:"copy", pattern: "*") input: val(meta_tsv) From 386d97e3147ee1143ef3f8802795d0cda9dc21bf Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 12:16:57 +0100 Subject: [PATCH 51/61] fix path outpu --- modules/python.nf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/python.nf b/modules/python.nf index 537cd37..846be6e 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -29,7 +29,7 @@ process 
restore_original_sequences { process drip { label "pluviometer" tag "drip" - publishDir("${params.outdir}/drip/${prefix}", mode:"copy", pattern: "*") + publishDir("${params.outdir}/drip/${prefix}", mode:"copy", pattern: "*/*") input: val(meta_tsv) From 43b3bdedc559f9d8b7316bc2def055ac2a6951ae Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 11:26:45 +0000 Subject: [PATCH 52/61] better output naming --- bin/drip.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/bin/drip.py b/bin/drip.py index 94682f8..108ae42 100755 --- a/bin/drip.py +++ b/bin/drip.py @@ -524,7 +524,7 @@ def merge_samples(file_group_sample_replicate_dict, output_prefix, include_file_ if not bp_chunks: continue - out_path = os.path.join(out_dir, f'{metric}_{bp}.tsv') + out_path = os.path.join(out_dir, f'{output_prefix}_{metric}_{bp}.tsv') with open(out_path, 'w') as fout: with open(bp_chunks[0]) as first: shutil.copyfileobj(first, fout) # includes header From e2f571d993518ac4289b3dcc0642dc4b6a42236e Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 12:27:20 +0100 Subject: [PATCH 53/61] to rerun --- modules/python.nf | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modules/python.nf b/modules/python.nf index 846be6e..0371325 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -59,7 +59,7 @@ process drip { def args_str = args.join(" ") - """ + """ drip.py --threads ${task.cpus} --min-samples-pct ${samples_pct} --min-group-pct ${group_pct} --output drip_${prefix} ${args_str} """ } \ No newline at end of file From 0257b090a0e6715ce21a05f701b3dee319f02053 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Mon, 23 Mar 2026 21:26:08 +0100 Subject: [PATCH 54/61] create an ID for convenience for sequence and global bmks --- bin/pluviometer/rain_file_writers.py | 19 ++++++++++++++++++- modules/pluviometer.nf | 2 +- 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/bin/pluviometer/rain_file_writers.py 
b/bin/pluviometer/rain_file_writers.py index 22f9c9a..eb6386a 100644 --- a/bin/pluviometer/rain_file_writers.py +++ b/bin/pluviometer/rain_file_writers.py @@ -212,13 +212,30 @@ def write_metadata( strand: str = "." ) -> int: """Write metadata fields of an aggregate""" + + # Determine the Type field before modifying aggregate_id + original_aggregate_id = aggregate_id + agg_type = self._agg_type(seq_id, original_aggregate_id) + + # Create a dynamic ID for aggregates without explicit IDs (sequence/global level) + if aggregate_id == ".": + # Build ID as Type_Ptype_Ctype_Mode, omitting Ptype and/or Ctype if they are "." + id_parts = [agg_type] # "sequence" or "global" + if parent_type != ".": + id_parts.append(parent_type) + if aggregate_type != ".": + id_parts.append(aggregate_type) + id_parts.append(aggregation_mode) + + aggregate_id = "_".join(id_parts) + return super().write_metadata( seq_id, parent_ids, aggregate_id, "aggregate", parent_type, - self._agg_type(seq_id, aggregate_id), + agg_type, # Use the pre-computed Type instead of recalculating it aggregate_type, aggregation_mode, start, diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index 683af48..d56a90f 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -15,7 +15,7 @@ process pluviometer { script: base_name = site_edits.BaseName - """ + """ pluviometer_wrapper.py \ --sites ${site_edits} \ --gff ${gff} \ From c3d0a51b5d815c6027d546241bd6695220d4881f Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 24 Mar 2026 13:33:24 +0000 Subject: [PATCH 55/61] fix to avoid slowlyness when unstranded --- bin/pluviometer/__main__.py | 17 ++++++++++++++++- 1 file changed, 16 insertions(+), 1 deletion(-) diff --git a/bin/pluviometer/__main__.py b/bin/pluviometer/__main__.py index 0f3a3d7..08295d0 100755 --- a/bin/pluviometer/__main__.py +++ b/bin/pluviometer/__main__.py @@ -735,9 +735,24 @@ def update_active_counters(self, site_data: RNASiteVariantData) -> None: """ Update the multicounters 
matching the ID of features in the `active_features` set. A new multicounter is created if no matching ID is found. + + Strand compatibility logic: + - Feature unstranded (0): accepts sites from any strand + - Feature stranded: only accepts sites from same strand or unstranded sites + - Site unstranded: only counted for unstranded features (to avoid combinatorial explosion) """ for feature_key, feature in self.active_features.items(): - if site_data.strand == 0 or feature.location.strand == 0 or feature.location.strand == site_data.strand: + # Optimized order: check exact match first (most common case) + if feature.location.strand == site_data.strand: + counter: MultiCounter = self.counters[feature_key] + counter.update(site_data) + # Feature is unstranded: accept stranded sites too + elif feature.location.strand == 0 and site_data.strand != 0: + counter: MultiCounter = self.counters[feature_key] + counter.update(site_data) + # Site is unstranded: only count for unstranded features + # (counting for all features would cause combinatorial explosion) + elif site_data.strand == 0 and feature.location.strand == 0: counter: MultiCounter = self.counters[feature_key] counter.update(site_data) From 25629c0a58c3ddcd7518ce8dec03723b6574146d Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 24 Mar 2026 13:33:45 +0000 Subject: [PATCH 56/61] fix 2 must be interpreted as unstranded --- bin/pluviometer/rna_site_variant_readers.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bin/pluviometer/rna_site_variant_readers.py b/bin/pluviometer/rna_site_variant_readers.py index 7437afa..dcd9f92 100755 --- a/bin/pluviometer/rna_site_variant_readers.py +++ b/bin/pluviometer/rna_site_variant_readers.py @@ -234,11 +234,11 @@ def parse_strand(self) -> int: strand = int(self.parts[REDITOOLS_FIELD_INDEX["Strand"]]) match strand: case 0: - return -1 + return 0 # Unstranded case 1: - return 1 + return 1 # Forward strand case 2: - return 0 + return -1 # Reverse 
strand (first-strand oriented) case _: raise Exception(f"Invalid strand value: {strand}") From c4394417a6db9040323171045048d92a7330b728 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 24 Mar 2026 15:18:47 +0100 Subject: [PATCH 57/61] serealize by editing tool analyzer and fix strand for reditools --- modules/jacusa2.nf | 2 +- modules/pluviometer.nf | 11 +++---- modules/python.nf | 16 ++++------ modules/reditools2.nf | 3 +- modules/reditools3.nf | 4 +-- rain.nf | 72 +++++++++++++++++++++++++++++++++++------- 6 files changed, 77 insertions(+), 31 deletions(-) diff --git a/modules/jacusa2.nf b/modules/jacusa2.nf index 15727e9..c40632c 100644 --- a/modules/jacusa2.nf +++ b/modules/jacusa2.nf @@ -8,7 +8,7 @@ process jacusa2 { tuple(path(genome), path(fastaindex)) output: - tuple(val(meta), path("*.site_edits_jacusa2.txt"), emit: tuple_sample_jacusa2_table) + tuple(val(meta), val("jacusa2"), path("*.site_edits_jacusa2.txt"), emit: tuple_sample_jacusa2_table) path("*.filtered") script: diff --git a/modules/pluviometer.nf b/modules/pluviometer.nf index d56a90f..67b2a80 100644 --- a/modules/pluviometer.nf +++ b/modules/pluviometer.nf @@ -1,17 +1,16 @@ process pluviometer { label "pluviometer" publishDir("${params.outdir}/pluviometer/${tool_format}/", mode: "copy") - tag "${meta.uid}" + tag "${meta.uid}_${tool_format}" input: - tuple(val(meta), path(site_edits)) + tuple(val(meta), val(tool_format), path(site_edits)) path(gff) - val(tool_format) output: - tuple(val(meta), path("*features.tsv"), emit: tuple_sample_feature) - tuple(val(meta), path("*aggregates.tsv"), emit: tuple_sample_aggregate) - tuple(val(meta), path("*pluviometer.log"), emit: tuple_sample_log) + tuple(val(meta), val(tool_format), path("*features.tsv"), emit: tuple_sample_feature) + tuple(val(meta), val(tool_format), path("*aggregates.tsv"), emit: tuple_sample_aggregate) + tuple(val(meta), val(tool_format), path("*pluviometer.log"), emit: tuple_sample_log) script: base_name = site_edits.BaseName 
diff --git a/modules/python.nf b/modules/python.nf index 0371325..229af8c 100644 --- a/modules/python.nf +++ b/modules/python.nf @@ -28,32 +28,30 @@ process restore_original_sequences { process drip { label "pluviometer" - tag "drip" + tag "drip_${tool}" publishDir("${params.outdir}/drip/${prefix}", mode:"copy", pattern: "*/*") input: - val(meta_tsv) + tuple(val(tool), val(meta_tsv)) val prefix val samples_pct val group_pct output: - path("*_espr/*_AG.tsv"), emit: editing_ag_espr path("*_espr/*.tsv"), emit: editing_all_espr - path("*_espf/*_AG.tsv"), emit: editing_ag_espf path("*_espf/*.tsv"), emit: editing_all_espf script: def list = meta_tsv def args = [] - // Process list by pairs: [meta, file, meta, file, ...] - for (int i = 0; i < list.size(); i += 2) { - def m = list[i] // Don't reuse 'meta' variable name - def file = list[i + 1] + // Process list of [meta, file] pairs from groupTuple + list.each { pair -> + def m = pair[0] // meta dictionary + def file = pair[1] // file path def group = m.group ?: "group_unknown" def sample = m.sample ?: "sample_unknown" - def replicate = m.rep ?: "rep1" // Note: it's 'rep' not 'replicate' in meta + def replicate = m.rep ?: "rep1" args.add("${file}:${group}:${sample}:${replicate}") } diff --git a/modules/reditools2.nf b/modules/reditools2.nf index 468ba40..6065015 100644 --- a/modules/reditools2.nf +++ b/modules/reditools2.nf @@ -9,12 +9,13 @@ process reditools2 { val region output: - tuple(val(meta), path("*site_edits_reditools2.txt"), emit: tuple_sample_serial_table) + tuple(val(meta), val("reditools2"), path("*site_edits_reditools2.txt"), emit: tuple_sample_serial_table) script: // Set the strand orientation parameter from the library type parameter // Terms explained in https://salmon.readthedocs.io/en/latest/library_type.html + // https://github.com/BioinfoUNIBA/REDItools2?tab=readme-ov-file: -s STRAND, --strand STRAND Strand: this can be 0 (unstranded), 1 (secondstrand oriented) or 2 (firststrand oriented) if 
(meta.strandedness in ["ISR", "SR"]) { // First-strand oriented strand_orientation = "2" diff --git a/modules/reditools3.nf b/modules/reditools3.nf index 50e7a4d..120eb5a 100644 --- a/modules/reditools3.nf +++ b/modules/reditools3.nf @@ -8,7 +8,7 @@ process reditools3 { path genome output: - tuple(val(meta), path("${base_name}.site_edits_reditools3.txt"), emit: tuple_sample_serial_table) + tuple(val(meta), val("reditools3"), path("${base_name}.site_edits_reditools3.txt"), emit: tuple_sample_serial_table) path("${base_name}.reditools3.log", emit: log) script: @@ -31,7 +31,7 @@ process reditools3 { } base_name = bam.BaseName - """ + """ python -m reditools analyze ${bam} --reference ${genome} --strand ${strand_orientation} --output-file ${base_name}.site_edits_reditools3.txt --threads ${task.cpus} --verbose &> ${base_name}.reditools3.log """ } diff --git a/rain.nf b/rain.nf index df387b2..ea97e9f 100644 --- a/rain.nf +++ b/rain.nf @@ -192,7 +192,7 @@ include {fasta_unzip} from "$baseDir/modules/pigz.nf" include {samtools_index; samtools_fasta_index; samtools_sort_bam as samtools_sort_bam_raw; samtools_sort_bam as samtools_sort_bam_merged; samtools_split_mapped_unmapped; samtools_merge_bams; samtools_calmd} from './modules/samtools.nf' include {reditools2} from "./modules/reditools2.nf" include {reditools3} from "./modules/reditools3.nf" -include {pluviometer as pluviometer_jacusa2; pluviometer as pluviometer_reditools2; pluviometer as pluviometer_reditools3; pluviometer as pluviometer_sapin} from "./modules/pluviometer.nf" +include {pluviometer} from "./modules/pluviometer.nf" include {drip as drip_aggregates; drip as drip_features} from "./modules/python.nf" include {sapin} from "./modules/sapin.nf" @@ -748,32 +748,80 @@ workflow { // ------------------------------------------------------- // ----------------- DETECT EDITING SITES ---------------- // ------------------------------------------------------- - + Channel.empty().set{editing_analysis} // Select site 
detection tool if ( "jacusa2" in edit_site_tool_list ){ // Create a fasta index file of the reference genome samtools_fasta_index(genome.collect()) jacusa2(final_bam_for_editing, samtools_fasta_index.out.tuple_fasta_fastaindex.collect()) - pluviometer_jacusa2(jacusa2.out.tuple_sample_jacusa2_table, clean_annotation.collect(), "jacusa2") + editing_analysis = editing_analysis.mix(jacusa2.out.tuple_sample_jacusa2_table) } if ( "sapin" in edit_site_tool_list ){ sapin(tuple_sample_bam_processed, genome.collect()) } if ( "reditools2" in edit_site_tool_list ){ reditools2(final_bam_for_editing, genome.collect(), params.region) - pluviometer_reditools2(reditools2.out.tuple_sample_serial_table, clean_annotation.collect(), "reditools2") + editing_analysis = editing_analysis.mix(reditools2.out.tuple_sample_serial_table) } if ( "reditools3" in edit_site_tool_list ){ reditools3(final_bam_for_editing, genome.collect()) - pluviometer_reditools3(reditools3.out.tuple_sample_serial_table, clean_annotation.collect(), "reditools3") - if(via_csv){ - // drip - compute espn, espf, merge different sample in one, and output by type of mutation (AG, AC, etc..) 
- drip_aggregates(pluviometer_reditools3.out.tuple_sample_aggregate.collect(), "aggregates", params.min_samples_pct, params.min_group_pct) - drip_features(pluviometer_reditools3.out.tuple_sample_feature.collect(), "features", params.min_samples_pct, params.min_group_pct) - - //barometer (drip_features.out.editing_ag_espr, "AG") + editing_analysis = editing_analysis.mix(reditools3.out.tuple_sample_serial_table) + } + + // Run pluviometer on editing analysis results to get aggregates and features values + pluviometer(editing_analysis, clean_annotation.collect()) + + if(via_csv){ + // Collect pluviometer outputs by tool (group by element at index 1 = tool name) + aggregates_by_tool = pluviometer.out.tuple_sample_aggregate.map { meta, tool, file -> tuple(tool, [meta, file]) }.groupTuple() + features_by_tool = pluviometer.out.tuple_sample_feature.map { meta, tool, file -> tuple(tool, [meta, file]) }.groupTuple() + + // drip - compute espn, espf, merge different sample in one, and output by type of mutation (AG, AC, etc..) 
+ drip_aggregates(aggregates_by_tool, "aggregates", params.min_samples_pct, params.min_group_pct) + drip_features(features_by_tool, "features", params.min_samples_pct, params.min_group_pct) + + // ------------------- ESPF JOIN AGGREGATES AND FEATURES ----------------- + drip_aggregates.out.editing_all_espf + .flatten() + .map { file -> + // Extract editing type from filename (e.g., "drip_aggregates_espf_AC.tsv" -> "AC") + def editType = file.baseName.tokenize('_').last() + tuple(editType, file) + } + .join( + drip_features.out.editing_all_espf + .flatten() + .map { file -> + // Extract editing type from filename (e.g., "drip_aggregates_espf_AC.tsv" -> "AC") + def editType = file.baseName.tokenize('_').last() + tuple(editType, file) + } + ) + .set { features_espf_by_edit_type } + features_espf_by_edit_type.view() + + // ------------------- ESPR JOIN AGGREGATES AND FEATURES ----------------- + drip_aggregates.out.editing_all_espr + .flatten() + .map { file -> + // Extract editing type from filename (e.g., "drip_aggregates_espr_AC.tsv" -> "AC") + def editType = file.baseName.tokenize('_').last() + tuple(editType, file) + } + .join( + drip_features.out.editing_all_espr + .flatten() + .map { file -> + // Extract editing type from filename (e.g., "drip_aggregates_espr_AC.tsv" -> "AC") + def editType = file.baseName.tokenize('_').last() + tuple(editType, file) + } + ) + .set { features_espr_by_edit_type } + features_espr_by_edit_type.view() + + // READY for barometer analysis - } } // ------------------- MULTIQC ----------------- From e305c319d47e980c89993a6a82b2b12d93681eeb Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 24 Mar 2026 21:37:45 +0100 Subject: [PATCH 58/61] add barometer scripts for last steps --- README.md | 4 +- bin/barometer_analyze.py | 2297 ++++++++++++++++++++++++++++++++++++++ bin/barometer_report.py | 1133 +++++++++++++++++++ 3 files changed, 3431 insertions(+), 3 deletions(-) create mode 100755 bin/barometer_analyze.py create mode 100755 
bin/barometer_report.py diff --git a/README.md b/README.md index 5844972..9691046 100644 --- a/README.md +++ b/README.md @@ -375,6 +375,4 @@ Contributions from the community are welcome ! See the [Contributing guidelines] ## TODO -update pluviometer to set NA for start end and strand instead of . to be able to use column as int64 in drip and barometer e.g. dtype={"SeqID": str, "Start": "Int64", "End": "Int64", "Strand": str} - -Le feature type seq_agg must be called sequence_agg. Maybe _agg can be removed because know by another column. \ No newline at end of file +update pluviometer to set NA for start end and strand instead of . to be able to use column as int64 in drip and barometer e.g. dtype={"SeqID": str, "Start": "Int64", "End": "Int64", "Strand": str} \ No newline at end of file diff --git a/bin/barometer_analyze.py b/bin/barometer_analyze.py new file mode 100755 index 0000000..2d4f5b2 --- /dev/null +++ b/bin/barometer_analyze.py @@ -0,0 +1,2297 @@ +#!/usr/bin/env python3 +""" +barometer_analyze.py – Exhaustive analysis pipeline for rain biomarker data. + +Loads barometer_aggregates_AG.tsv and barometer_features_AG.tsv, performs QC, +descriptive statistics, differential editing analysis, multivariate analysis, +correlation/network analysis, feature selection / biomarker ranking, +classification, stability analysis, and generates all tables + matplotlib figures. + +Results are saved into an output directory (default: barometer_results/) structured +by value_type → mtype → section. 
+ +Usage: + python barometer_analyze.py [--aggregates FILE] [--features FILE] [--outdir DIR] +""" + +import argparse +import gc +import json +import logging +import multiprocessing +import os +import sys +import time +import warnings +from collections import defaultdict +from concurrent.futures import ProcessPoolExecutor, as_completed +from itertools import combinations +from pathlib import Path + +try: + import psutil +except ImportError: + psutil = None + +import matplotlib +matplotlib.use("Agg") +import matplotlib.pyplot as plt +import numpy as np +import pandas as pd +import seaborn as sns +from scipy import stats +from scipy.cluster.hierarchy import linkage, dendrogram, fcluster +from scipy.spatial.distance import pdist +from sklearn.decomposition import PCA +from sklearn.discriminant_analysis import LinearDiscriminantAnalysis +from sklearn.ensemble import RandomForestClassifier, GradientBoostingClassifier +from sklearn.model_selection import cross_val_score, StratifiedKFold, LeaveOneOut +from sklearn.preprocessing import StandardScaler, LabelEncoder +from sklearn.metrics import classification_report +import statsmodels.api as sm +from statsmodels.stats.multicomp import pairwise_tukeyhsd +from statsmodels.stats.multitest import multipletests + +warnings.filterwarnings("ignore", category=FutureWarning) +warnings.filterwarnings("ignore", category=UserWarning) +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def safe_mkdir(path): + Path(path).mkdir(parents=True, exist_ok=True) + + +def prepare_df_for_task(df): + """Convert DataFrame to a memory-efficient format for task submission. + + Uses pickle which is much faster than .to_dict('list') for large DataFrames. 
+ Returns a tuple ('pickled', bytes_data) that can be unpickled in the worker. + """ + import pickle + return ('pickled', pickle.dumps(df, protocol=pickle.HIGHEST_PROTOCOL)) + + +def parse_sample_columns(columns): + """Parse sample column names like 'test1::rain_chr21_small::rep1::espf'. + + Returns a list of dicts with keys: col, group, sample, rep, value_type. + """ + parsed = [] + meta_cols = {"SeqID", "ParentIDs", "ID", "Mtype", "Ptype", "Type", "Ctype", "Mode", "Start", "End", "Strand"} + for c in columns: + if c in meta_cols: + continue + parts = c.split("::") + if len(parts) == 4: + parsed.append({ + "col": c, + "group": parts[0], + "sample": parts[1], + "rep": parts[2], + "value_type": parts[3], + }) + return parsed + + +def get_value_types(sample_info): + return sorted(set(s["value_type"] for s in sample_info)) + + +def cols_for_vtype(sample_info, vtype): + return [s["col"] for s in sample_info if s["value_type"] == vtype] + + +def group_for_col(sample_info, col): + for s in sample_info: + if s["col"] == col: + return s["group"] + return None + + +def sample_info_for_vtype(sample_info, vtype): + return [s for s in sample_info if s["value_type"] == vtype] + + +def numeric_df(df, cols): + """Return a copy with the given columns cast to numeric (coerce errors).""" + out = df.copy() + for c in cols: + out[c] = pd.to_numeric(out[c], errors="coerce") + return out + + +def save_fig(fig, path, dpi=150): + fig.savefig(path, dpi=dpi, bbox_inches="tight") + plt.close(fig) + + +# --------------------------------------------------------------------------- +# Quality Control +# --------------------------------------------------------------------------- + +def qc_analysis(df, sample_cols, outdir): + """Basic quality control: missing values, distributions, outliers.""" + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # Missing value counts + log.info("Missing value counts") + missing = ndf[sample_cols].isnull().sum() + total = len(ndf) + 
missing_pct = (missing / total * 100).round(2) + miss_df = pd.DataFrame({"missing_count": missing, "missing_pct": missing_pct}) + miss_df.to_csv(os.path.join(outdir, "missing_values.csv")) + results["missing"] = miss_df.to_dict() + + # Distribution summary per sample + log.info("Distribution summary per sample") + desc = ndf[sample_cols].describe().T + desc.to_csv(os.path.join(outdir, "distribution_summary.csv")) + + # Box plot of all samples + log.info("Box plot of all samples") + fig, ax = plt.subplots(figsize=(max(6, len(sample_cols) * 0.8), 5)) + ndf[sample_cols].boxplot(ax=ax, rot=90) + ax.set_title("Distribution per sample") + ax.set_ylabel("Value") + save_fig(fig, os.path.join(outdir, "boxplot_samples.png")) + + # Heatmap of missing values (limit to avoid matplotlib memory errors with large datasets) + log.info("Heatmap of missing values") + max_heatmap_rows = 10000 # limit to prevent memory issues + heatmap_data = ndf[sample_cols].isnull().astype(int) + + if len(heatmap_data) > max_heatmap_rows: + # Sample: prioritize BMKs with most missing values + missing_counts = heatmap_data.sum(axis=1) + top_missing_idx = missing_counts.nlargest(max_heatmap_rows).index + heatmap_data = heatmap_data.loc[top_missing_idx] + log.info(f" Limiting missing values heatmap to {max_heatmap_rows} BMKs with most missing (from {len(ndf)} total)") + + fig_height = min(20, max(4, len(heatmap_data) * 0.02)) # Cap at 20 inches + try: + fig, ax = plt.subplots(figsize=(max(6, len(sample_cols) * 0.6), fig_height)) + sns.heatmap(heatmap_data, cbar=False, ax=ax, yticklabels=False) + ax.set_title(f"Missing values heatmap ({len(heatmap_data)} BMKs)") + save_fig(fig, os.path.join(outdir, "missing_heatmap.png")) + log.info(" Missing values heatmap saved") + except Exception as e: + log.warning(f" Missing values heatmap failed: {e}") + + results["desc"] = desc.to_dict() + return results + + +# --------------------------------------------------------------------------- +# Descriptive Statistics 
+# --------------------------------------------------------------------------- + +def descriptive_stats(df, sample_cols, sample_info, outdir): + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # Per-group statistics + groups = sorted(set(s["group"] for s in sample_info)) + group_stats = {} + for g in groups: + gcols = [s["col"] for s in sample_info if s["group"] == g] + vals = ndf[gcols].values.flatten() + vals = vals[~np.isnan(vals)] + if len(vals) == 0: + continue + group_stats[g] = { + "n": int(len(vals)), + "mean": float(np.mean(vals)), + "median": float(np.median(vals)), + "std": float(np.std(vals, ddof=1)) if len(vals) > 1 else 0.0, + "min": float(np.min(vals)), + "max": float(np.max(vals)), + "q25": float(np.percentile(vals, 25)), + "q75": float(np.percentile(vals, 75)), + } + gs_df = pd.DataFrame(group_stats).T + gs_df.to_csv(os.path.join(outdir, "group_statistics.csv")) + results["group_stats"] = group_stats + + # Per-BMK mean by group + bmk_means = {} + for g in groups: + gcols = [s["col"] for s in sample_info if s["group"] == g] + bmk_means[g] = ndf[gcols].mean(axis=1) + bmk_mean_df = pd.DataFrame(bmk_means, index=df.index) + if "ID" in df.columns: + bmk_mean_df.index = df["ID"] + bmk_mean_df.to_csv(os.path.join(outdir, "bmk_mean_by_group.csv")) + + # Violin plot by group (optimized with pd.melt) + sample_cols_list = [s["col"] for s in sample_info] + if sample_cols_list: + # Create group mapping for columns + col_to_group = {s["col"]: s["group"] for s in sample_info} + # Melt the dataframe efficiently + melted = ndf[sample_cols_list].melt(var_name="sample_col", value_name="value") + melted["group"] = melted["sample_col"].map(col_to_group) + melted = melted.dropna(subset=["value"]) + + if len(melted) > 0: + fig, ax = plt.subplots(figsize=(max(6, len(groups) * 2), 5)) + sns.violinplot(data=melted, x="group", y="value", ax=ax, inner="box") + ax.set_title("Value distribution by group") + save_fig(fig, os.path.join(outdir, 
"violin_by_group.png")) + + return results + + +# --------------------------------------------------------------------------- +# Differential Editing Analysis +# --------------------------------------------------------------------------- + +def test_normality_and_homogeneity(group_values): + """Test for normality (Shapiro-Wilk) and homoscedasticity (Bartlett). + + Returns: + dict with keys: is_normal, is_homogeneous, shapiro_pvals, bartlett_pval + """ + results = { + "is_normal": False, + "is_homogeneous": False, + "shapiro_pvals": [], + "bartlett_pval": None + } + + non_empty = [gv for gv in group_values.values() if len(gv) >= 3] # Shapiro needs n>=3 + if len(non_empty) < 2: + return results + + # Test normality per group (Shapiro-Wilk) + shapiro_pvals = [] + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + for gv in non_empty: + if len(gv) >= 3: + try: + # Check for near-constant data before testing + if np.std(gv) < 1e-10: + continue # Skip constant data + _, p = stats.shapiro(gv) + shapiro_pvals.append(p) + except Exception: + pass + + results["shapiro_pvals"] = shapiro_pvals + results["is_normal"] = all(p > 0.05 for p in shapiro_pvals) if shapiro_pvals else False + + # Test homogeneity of variances (Bartlett) + if len(non_empty) >= 2: + with warnings.catch_warnings(): + warnings.filterwarnings("ignore") + try: + # Check all groups have variance + if all(np.std(gv) > 1e-10 for gv in non_empty): + _, p = stats.bartlett(*non_empty) + results["bartlett_pval"] = p + results["is_homogeneous"] = p > 0.05 + except Exception: + pass + + return results + + +def differential_analysis(df, sample_cols, sample_info, outdir, stat_test="auto"): + """Pairwise and global differential tests between groups. 
+ + Args: + stat_test: 'auto', 'parametric', 'nonparametric', 'welch', 'kruskal' + - auto: test normality/homogeneity and choose best test + - parametric: Student t-test / ANOVA (assumes normality + equal variances) + - nonparametric: Mann-Whitney U / Kruskal-Wallis (no assumptions) + - welch: Welch t-test / Welch ANOVA (assumes normality, unequal variances OK) + - kruskal: alias for nonparametric (backward compatibility) + + Note: Expects df to already be pre-filtered for variable BMKs (done in analyze_section) + """ + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + groups = sorted(set(s["group"] for s in sample_info)) + results = {"pairwise": {}, "global": {}, "stat_test_used": stat_test} + + if len(groups) < 2: + log.warning("Less than 2 groups, skipping differential analysis.") + return results + + if len(ndf) == 0: + log.warning("No BMKs remaining for differential analysis") + return results + + # Normalize test names + if stat_test == "kruskal": + stat_test = "nonparametric" + + # OPTIMIZATION: Pre-compute group columns to avoid O(N_samples) per BMK per group + group_cols = {g: [s["col"] for s in sample_info if s["group"] == g] for g in groups} + + rows = [] + + # Suppress scipy warnings about numerical issues (we handle them with try/except) + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", message="Precision loss occurred") + warnings.filterwarnings("ignore", message="invalid value encountered") + warnings.filterwarnings("ignore", message="divide by zero encountered") + + for idx in ndf.index: + row_data = {"index": idx} + if "ID" in df.columns: + row_data["ID"] = df.loc[idx, "ID"] + + group_values = {} + for g in groups: + gcols = group_cols[g] # Use pre-computed dict + vals = ndf.loc[idx, gcols].dropna().values.astype(float) + group_values[g] = vals + row_data[f"mean_{g}"] = np.mean(vals) if len(vals) > 0 else np.nan + row_data[f"n_{g}"] = len(vals) + + non_empty = [gv for gv in group_values.values() if len(gv) > 0] + + # 
Determine which test to use + test_choice = stat_test + if stat_test == "auto" and len(non_empty) >= 2: + # Test assumptions + assumptions = test_normality_and_homogeneity(group_values) + row_data["shapiro_pvals"] = ",".join([f"{p:.4f}" for p in assumptions["shapiro_pvals"]]) if assumptions["shapiro_pvals"] else "" + row_data["bartlett_pval"] = assumptions["bartlett_pval"] + + # Choose test based on assumptions + if assumptions["is_normal"] and assumptions["is_homogeneous"]: + test_choice = "parametric" + elif assumptions["is_normal"] and not assumptions["is_homogeneous"]: + test_choice = "welch" + else: + test_choice = "nonparametric" + + row_data["test_selected"] = test_choice + + # Global tests (comparing all groups) + if len(non_empty) >= 2 and all(len(v) >= 1 for v in non_empty): + # Kruskal-Wallis (always compute for backward compatibility) + try: + stat, pval = stats.kruskal(*non_empty) + row_data["kruskal_stat"] = stat + row_data["kruskal_pval"] = pval + except Exception: + row_data["kruskal_stat"] = np.nan + row_data["kruskal_pval"] = np.nan + + # ANOVA (parametric) + valid_for_anova = [gv for gv in group_values.values() if len(gv) >= 2] + if len(valid_for_anova) >= 2: + try: + stat, pval = stats.f_oneway(*valid_for_anova) + row_data["anova_stat"] = stat + row_data["anova_pval"] = pval + except Exception: + row_data["anova_stat"] = np.nan + row_data["anova_pval"] = np.nan + else: + row_data["anova_stat"] = np.nan + row_data["anova_pval"] = np.nan + + # Welch ANOVA (parametric with unequal variances) + if len(valid_for_anova) >= 2: + try: + # Welch ANOVA using scipy (one-way test with equal_var=False not available directly) + # We'll use a workaround: for 2 groups use Welch t-test, for 3+ use Kruskal as fallback + if len(groups) == 2: + g1_vals = group_values[groups[0]] + g2_vals = group_values[groups[1]] + if len(g1_vals) >= 2 and len(g2_vals) >= 2: + stat, pval = stats.ttest_ind(g1_vals, g2_vals, equal_var=False) + row_data["welch_stat"] = stat + 
+                        # For 3+ groups scipy has no direct Welch ANOVA;
+                        # reuse the Kruskal-Wallis result as a robust (nonparametric) fallback
stat, pval = stats.ttest_ind(v1, v2, equal_var=True) + row_data[f"student_stat_{pair_key}"] = stat + row_data[f"student_pval_{pair_key}"] = pval + except Exception: + row_data[f"student_stat_{pair_key}"] = np.nan + row_data[f"student_pval_{pair_key}"] = np.nan + else: + row_data[f"student_stat_{pair_key}"] = np.nan + row_data[f"student_pval_{pair_key}"] = np.nan + + # Welch t-test (parametric, unequal variances) + if len(v1) >= 2 and len(v2) >= 2: + try: + stat, pval = stats.ttest_ind(v1, v2, equal_var=False) + row_data[f"welch_stat_{pair_key}"] = stat + row_data[f"welch_pval_{pair_key}"] = pval + except Exception: + row_data[f"welch_stat_{pair_key}"] = np.nan + row_data[f"welch_pval_{pair_key}"] = np.nan + else: + row_data[f"welch_stat_{pair_key}"] = np.nan + row_data[f"welch_pval_{pair_key}"] = np.nan + + # Effect size (rank-biserial for Mann-Whitney) + n1, n2 = len(v1), len(v2) + if n1 * n2 > 0 and not np.isnan(row_data.get(f"mwu_stat_{pair_key}", np.nan)): + row_data[f"effect_size_{pair_key}"] = 1 - 2 * row_data[f"mwu_stat_{pair_key}"] / (n1 * n2) + + # Log2 fold change of means + m1, m2 = np.mean(v1), np.mean(v2) + if m2 != 0 and m1 != 0: + row_data[f"log2fc_{pair_key}"] = np.log2(m1 / m2) if m1 > 0 and m2 > 0 else np.nan + row_data[f"diff_{pair_key}"] = m1 - m2 + else: + row_data[f"mwu_stat_{pair_key}"] = np.nan + row_data[f"mwu_pval_{pair_key}"] = np.nan + row_data[f"student_stat_{pair_key}"] = np.nan + row_data[f"student_pval_{pair_key}"] = np.nan + row_data[f"welch_stat_{pair_key}"] = np.nan + row_data[f"welch_pval_{pair_key}"] = np.nan + + rows.append(row_data) + + # Convert rows to DataFrame + res_df = pd.DataFrame(rows) + + # Multiple testing correction (FDR Benjamini-Hochberg) for ALL p-value columns + for col in res_df.columns: + if col.endswith("_pval"): + pvals = res_df[col].values + mask = ~np.isnan(pvals) + if mask.sum() > 0: + _, corrected, _, _ = multipletests(pvals[mask], method="fdr_bh") + adj_col = col.replace("_pval", "_padj") + 
res_df[adj_col] = np.nan + res_df.loc[mask, adj_col] = corrected + + res_df.to_csv(os.path.join(outdir, "differential_results.csv"), index=False) + results["table"] = os.path.join(outdir, "differential_results.csv") + results["stat_test_method"] = stat_test + + # Volcano-like plot for each pairwise comparison (using primary test) + for g1, g2 in combinations(groups, 2): + pair_key = f"{g1}_vs_{g2}" + diff_col = f"diff_{pair_key}" + + # Choose p-value column based on test method + if stat_test == "parametric": + pval_col = f"student_pval_{pair_key}" + padj_col = f"student_padj_{pair_key}" + test_label = "Student" + elif stat_test == "welch": + pval_col = f"welch_pval_{pair_key}" + padj_col = f"welch_padj_{pair_key}" + test_label = "Welch" + else: # nonparametric or auto (default to nonparametric) + pval_col = f"mwu_pval_{pair_key}" + padj_col = f"mwu_padj_{pair_key}" + test_label = "Mann-Whitney U" + if diff_col in res_df.columns and padj_col in res_df.columns: + pdf = res_df[[diff_col, padj_col]].dropna() + if len(pdf) > 0: + fig, ax = plt.subplots(figsize=(7, 5)) + neg_log_p = -np.log10(pdf[padj_col].clip(lower=1e-300)) + colors = ["red" if p < 0.05 else "grey" for p in pdf[padj_col]] + ax.scatter(pdf[diff_col], neg_log_p, c=colors, alpha=0.6, s=20) + ax.axhline(-np.log10(0.05), color="blue", linestyle="--", alpha=0.5) + ax.set_xlabel(f"Difference ({g1} - {g2})") + ax.set_ylabel("-log10(adjusted p-value)") + ax.set_title(f"Volcano plot: {g1} vs {g2} ({test_label})") + save_fig(fig, os.path.join(outdir, f"volcano_{pair_key}.png")) + + return results + + +# --------------------------------------------------------------------------- +# Multivariate Analysis (PCA, hierarchical clustering) +# --------------------------------------------------------------------------- + +def multivariate_analysis(df, sample_cols, sample_info, outdir): + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # Transpose: samples as rows, BMKs as columns + mat = 
ndf[sample_cols].T.copy() + mat.columns = df["ID"].values if "ID" in df.columns else range(len(df)) + mat = mat.dropna(axis=1, how="all").fillna(0) + + if mat.shape[1] < 2 or mat.shape[0] < 2: + log.warning("Not enough data for multivariate analysis") + return results + + # PCA + scaler = StandardScaler() + scaled = scaler.fit_transform(mat) + n_components = min(mat.shape[0], mat.shape[1], 10) + pca = PCA(n_components=n_components) + coords = pca.fit_transform(scaled) + var_exp = pca.explained_variance_ratio_ + + labels = [s["group"] for s in sample_info if s["col"] in mat.index] + sample_labels = [s["sample"] for s in sample_info if s["col"] in mat.index] + + pca_df = pd.DataFrame(coords[:, :min(5, n_components)], + columns=[f"PC{i+1}" for i in range(min(5, n_components))], + index=mat.index) + pca_df["group"] = labels + pca_df["sample"] = sample_labels + pca_df.to_csv(os.path.join(outdir, "pca_coordinates.csv")) + + # Variance explained + var_df = pd.DataFrame({"PC": [f"PC{i+1}" for i in range(len(var_exp))], + "variance_explained": var_exp, + "cumulative": np.cumsum(var_exp)}) + var_df.to_csv(os.path.join(outdir, "pca_variance.csv"), index=False) + + # Scree plot + fig, ax = plt.subplots(figsize=(6, 4)) + ax.bar(range(1, len(var_exp) + 1), var_exp * 100, alpha=0.7, label="Individual") + ax.plot(range(1, len(var_exp) + 1), np.cumsum(var_exp) * 100, "ro-", label="Cumulative") + ax.set_xlabel("Principal Component") + ax.set_ylabel("Variance Explained (%)") + ax.set_title("PCA Scree Plot") + ax.legend() + save_fig(fig, os.path.join(outdir, "pca_scree.png")) + + # PCA biplot PC1 vs PC2 + if n_components >= 2: + fig, ax = plt.subplots(figsize=(8, 6)) + unique_groups = sorted(set(labels)) + colors = plt.cm.Set1(np.linspace(0, 1, max(len(unique_groups), 1))) + for i, g in enumerate(unique_groups): + mask = [l == g for l in labels] + ax.scatter(coords[mask, 0], coords[mask, 1], c=[colors[i]], label=g, s=80, alpha=0.8) + for j, m in enumerate(mask): + if m: + 
ax.annotate(sample_labels[j], (coords[j, 0], coords[j, 1]), + fontsize=7, alpha=0.7) + ax.set_xlabel(f"PC1 ({var_exp[0]*100:.1f}%)") + ax.set_ylabel(f"PC2 ({var_exp[1]*100:.1f}%)") + ax.set_title("PCA - Samples") + ax.legend() + save_fig(fig, os.path.join(outdir, "pca_biplot.png")) + + # Hierarchical clustering on samples + if mat.shape[0] >= 2: + try: + dist = pdist(scaled, metric="euclidean") + Z = linkage(dist, method="ward") + fig, ax = plt.subplots(figsize=(max(6, len(sample_cols) * 0.8), 5)) + dendrogram(Z, labels=[s.split("::")[1] for s in mat.index], ax=ax, leaf_rotation=90) + ax.set_title("Hierarchical Clustering of Samples") + save_fig(fig, os.path.join(outdir, "dendrogram_samples.png")) + except Exception as e: + log.warning(f"Dendrogram failed: {e}") + + # Sample correlation heatmap + corr = mat.T.corr() + corr.to_csv(os.path.join(outdir, "sample_correlation.csv")) + fig, ax = plt.subplots(figsize=(max(6, len(sample_cols) * 0.8), max(5, len(sample_cols) * 0.6))) + short_labels = [s.split("::")[0] + "::" + s.split("::")[1] for s in corr.index] + sns.heatmap(corr, annot=True, fmt=".2f", cmap="RdBu_r", center=0, ax=ax, + xticklabels=short_labels, yticklabels=short_labels) + ax.set_title("Sample Correlation Matrix") + save_fig(fig, os.path.join(outdir, "correlation_heatmap.png")) + + results["pca_variance"] = var_df.to_dict(orient="records") + return results + + +# --------------------------------------------------------------------------- +# Correlation / Network Analysis +# --------------------------------------------------------------------------- + +def correlation_network(df, sample_cols, outdir, max_bmks=50): + """Compute BMK-BMK correlation matrix and heatmap. + + MEMORY FIX: Reduced max_bmks from 200 to 50 to prevent OOM. 
+ - Correlation matrix: N×N float64 = N² × 8 bytes + - With N=200: 320 KB per section × 1491 sections = 477 MB accumulated + - With N=50: 20 KB per section × 1491 sections = 30 MB accumulated (16x reduction) + - Heatmap figure with N=50: ~4 MB instead of 64 MB per image + """ + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # BMK-BMK correlation (limit to top variable BMKs) + mat = ndf[sample_cols].copy() + mat.index = df["ID"].values if "ID" in df.columns else range(len(df)) + mat = mat.dropna(axis=0, how="all") + + # Early exit if too few BMKs + if len(mat) < 3: + log.info(f" Skipping correlation analysis: only {len(mat)} BMKs (need at least 3)") + return results + + var = mat.var(axis=1) + top_idx = var.nlargest(min(max_bmks, len(var))).index + sub = mat.loc[top_idx] + + if len(sub) >= 3: + # Compute correlation matrix (memory intensive: N×N) + corr = sub.T.corr() + corr.to_csv(os.path.join(outdir, "bmk_correlation.csv")) + + # Limit figure size to prevent excessive memory usage + fig_width = min(15, max(6, len(sub) * 0.15)) # Cap at 15 inches + fig_height = min(12, max(5, len(sub) * 0.12)) # Cap at 12 inches + + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + sns.heatmap(corr, cmap="RdBu_r", center=0, ax=ax, + xticklabels=len(sub) < 50, yticklabels=len(sub) < 50) + ax.set_title(f"BMK Correlation (top {len(sub)} by variance)") + save_fig(fig, os.path.join(outdir, "bmk_correlation_heatmap.png")) + results["n_bmks_corr"] = len(sub) + + return results + + +# --------------------------------------------------------------------------- +# Feature Selection / Biomarker Ranking +# --------------------------------------------------------------------------- + +def feature_ranking(df, sample_cols, sample_info, outdir, max_bmks=500, bmk_filter_cols=None): + """Feature ranking using RandomForest importance. 
+ + Args: + bmk_filter_cols: list of column names to try in order for filtering significant BMKs + """ + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + groups = sorted(set(s["group"] for s in sample_info)) + results = {} + + if len(groups) < 2: + return results + + # Build matrix: rows = samples, cols = BMKs + mat = ndf[sample_cols].T.copy() + bmk_ids = df["ID"].values if "ID" in df.columns else [str(i) for i in range(len(df))] + mat.columns = bmk_ids + mat = mat.fillna(0) + + labels = [group_for_col(sample_info, c) for c in sample_cols] + le = LabelEncoder() + y = le.fit_transform(labels) + + # Default filter column priority if not specified + if bmk_filter_cols is None: + bmk_filter_cols = ["primary_padj", "kruskal_padj", "welch_padj", "anova_padj"] + + # Filter BMKs: cascade through filter columns to collect significant ones + # PASS 1: adjusted p-values (*_padj < 0.05) + # PASS 2: raw p-values (*_pval < 0.05) if not enough + diff_file = os.path.join(os.path.dirname(outdir), "differential", "differential_results.csv") + selected_bmks = set() + n_bmks_total = len(mat.columns) + filter_log = [] + + if os.path.exists(diff_file): + diff_df = pd.read_csv(diff_file) + + # PASS 1: Cascade through adjusted p-values (strong evidence) + for filter_col in bmk_filter_cols: + if len(selected_bmks) >= max_bmks: + break + + if filter_col in diff_df.columns and "ID" in diff_df.columns: + sig_bmks = diff_df[diff_df[filter_col] < 0.05]["ID"].values + valid_sig_bmks = [b for b in sig_bmks if b in mat.columns and b not in selected_bmks] + remaining = max_bmks - len(selected_bmks) + to_add = valid_sig_bmks[:remaining] + + if to_add: + selected_bmks.update(to_add) + filter_log.append(f"{filter_col}: +{len(to_add)}") + + # PASS 2: If not enough, cascade through raw p-values (suggestive evidence) + if len(selected_bmks) < max_bmks: + filter_log.append("|") + for filter_col in bmk_filter_cols: + if len(selected_bmks) >= max_bmks: + break + + # Convert *_padj to *_pval + 
pval_col = filter_col.replace("_padj", "_pval") + if pval_col != filter_col and pval_col in diff_df.columns and "ID" in diff_df.columns: + sig_bmks = diff_df[diff_df[pval_col] < 0.05]["ID"].values + valid_sig_bmks = [b for b in sig_bmks if b in mat.columns and b not in selected_bmks] + remaining = max_bmks - len(selected_bmks) + to_add = valid_sig_bmks[:remaining] + + if to_add: + selected_bmks.update(to_add) + filter_log.append(f"{pval_col}: +{len(to_add)}") + + if selected_bmks: + selected_bmks = list(selected_bmks) + log.info(f" Collected {len(selected_bmks)} significant BMKs from {n_bmks_total} total for ranking") + log.info(f" Cascade: {' → '.join(filter_log)}") + else: + selected_bmks = None + + # If not enough significant BMKs, complete with top variable ones + if selected_bmks is None or len(selected_bmks) < 10: + if selected_bmks is None: + selected_bmks = [] + var_scores = mat.var(axis=0) + # Exclude already selected + var_scores = var_scores[[b for b in var_scores.index if b not in selected_bmks]] + remaining = max_bmks - len(selected_bmks) + top_var = var_scores.nlargest(min(remaining, len(var_scores))).index.tolist() + selected_bmks.extend(top_var) + log.info(f" Completed with {len(top_var)} top variable BMKs (total: {len(selected_bmks)} from {n_bmks_total})") + + # Filter matrix to selected BMKs + mat_filtered = mat[selected_bmks].copy() + + # Random Forest importance + if len(set(y)) >= 2 and mat_filtered.shape[0] >= 4 and mat_filtered.shape[1] >= 2: + try: + rf = RandomForestClassifier(n_estimators=100, random_state=42, n_jobs=1, max_depth=10) + rf.fit(mat_filtered, y) + importances = rf.feature_importances_ + imp_df = pd.DataFrame({"bmk": mat_filtered.columns, "importance": importances}) + imp_df = imp_df.sort_values("importance", ascending=False) + imp_df.to_csv(os.path.join(outdir, "rf_importance.csv"), index=False) + results["rf_top10"] = imp_df.head(10).to_dict(orient="records") + + # Plot top 30 + top_n = min(30, len(imp_df)) + top = 
imp_df.head(top_n) + fig, ax = plt.subplots(figsize=(8, max(4, top_n * 0.3))) + ax.barh(range(top_n), top["importance"].values[::-1]) + ax.set_yticks(range(top_n)) + ax.set_yticklabels(top["bmk"].values[::-1], fontsize=7) + ax.set_xlabel("Importance") + ax.set_title(f"Random Forest Feature Importance (top 30 from {len(selected_bmks)} BMKs)") + save_fig(fig, os.path.join(outdir, "rf_importance.png")) + except Exception as e: + log.warning(f"RF importance failed: {e}") + + # Variance-based ranking + var_scores = ndf[sample_cols].var(axis=1) + var_df = pd.DataFrame({"bmk": bmk_ids, "variance": var_scores.values}) + var_df = var_df.sort_values("variance", ascending=False) + var_df.to_csv(os.path.join(outdir, "variance_ranking.csv"), index=False) + + # Kruskal-Wallis based ranking (from differential analysis if available) + diff_file = os.path.join(os.path.dirname(outdir), "differential", "differential_results.csv") + if os.path.exists(diff_file): + diff_df = pd.read_csv(diff_file) + if "kruskal_padj" in diff_df.columns: + rank_df = diff_df[["ID", "kruskal_padj"]].dropna().sort_values("kruskal_padj") + rank_df.to_csv(os.path.join(outdir, "kruskal_ranking.csv"), index=False) + results["kruskal_top10"] = rank_df.head(10).to_dict(orient="records") + + return results + + +# --------------------------------------------------------------------------- +# Classification / Predictive Modeling +# --------------------------------------------------------------------------- + +def classification_analysis(df, sample_cols, sample_info, outdir, max_bmks=500, bmk_filter_cols=None): + """Classification analysis with multiple classifiers. 
+ + Args: + bmk_filter_cols: list of column names to try in order for filtering significant BMKs + """ + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + groups = sorted(set(s["group"] for s in sample_info)) + results = {} + + if len(groups) < 2: + return results + + mat = ndf[sample_cols].T.fillna(0).copy() + bmk_ids = df["ID"].values if "ID" in df.columns else [str(i) for i in range(len(df))] + mat.columns = bmk_ids + labels = [group_for_col(sample_info, c) for c in sample_cols] + le = LabelEncoder() + y = le.fit_transform(labels) + + n_samples = len(y) + n_classes = len(set(y)) + n_bmks_total = mat.shape[1] + + if n_samples < 4 or n_classes < 2: + log.warning("Not enough samples for classification") + return results + + # Default filter column priority if not specified + if bmk_filter_cols is None: + bmk_filter_cols = ["primary_padj", "kruskal_padj", "welch_padj", "anova_padj"] + + # Filter BMKs: cascade through filter columns to collect significant ones + diff_file = os.path.join(os.path.dirname(outdir), "differential", "differential_results.csv") + selected_bmks = set() + filter_log = [] + + if os.path.exists(diff_file): + diff_df = pd.read_csv(diff_file) + + # Cascade through filter columns in priority order + for filter_col in bmk_filter_cols: + if len(selected_bmks) >= max_bmks: + break # Stop if we have enough + + if filter_col in diff_df.columns and "ID" in diff_df.columns: + # Get significant BMKs from this column + sig_bmks = diff_df[diff_df[filter_col] < 0.05]["ID"].values + valid_sig_bmks = [b for b in sig_bmks if b in mat.columns and b not in selected_bmks] + + # Add up to remaining slots + remaining = max_bmks - len(selected_bmks) + to_add = valid_sig_bmks[:remaining] + + if to_add: + selected_bmks.update(to_add) + filter_log.append(f"{filter_col}: +{len(to_add)} BMKs") + + if selected_bmks: + selected_bmks = list(selected_bmks) + log.info(f" Collected {len(selected_bmks)} significant BMKs from {n_bmks_total} total for classification") + 
log.info(f" Cascade: {' → '.join(filter_log)}") + else: + selected_bmks = None + + # If not enough significant BMKs, complete with top variable ones + if selected_bmks is None or len(selected_bmks) < 10: + if selected_bmks is None: + selected_bmks = [] + var_scores = mat.var(axis=0) + # Exclude already selected + var_scores = var_scores[[b for b in var_scores.index if b not in selected_bmks]] + remaining = max_bmks - len(selected_bmks) + top_var = var_scores.nlargest(min(remaining, len(var_scores))).index.tolist() + selected_bmks.extend(top_var) + log.info(f" Completed with {len(top_var)} top variable BMKs (total: {len(selected_bmks)} from {n_bmks_total})") + + # Filter matrix to selected BMKs + mat = mat[selected_bmks].copy() + n_bmks = mat.shape[1] + + if n_bmks < max(2, n_classes): + log.warning(f"Not enough biomarkers for classification ({n_bmks} BMKs after filtering, need at least {max(2, n_classes)})") + return results + + # Use LOO or stratified k-fold depending on sample count + if n_samples < 10: + cv = LeaveOneOut() + cv_name = "LOO" + else: + cv = StratifiedKFold(n_splits=min(5, n_samples), shuffle=True, random_state=42) + cv_name = "StratifiedKFold" + + classifiers = {} + classifiers["RandomForest"] = RandomForestClassifier(n_estimators=50, random_state=42, n_jobs=1, max_depth=10) + + if n_classes == 2: + classifiers["GradientBoosting"] = GradientBoostingClassifier(n_estimators=50, random_state=42) + + # LDA only if feasible + if n_samples > n_classes and mat.shape[1] > 0: + try: + classifiers["LDA"] = LinearDiscriminantAnalysis() + except Exception: + pass + + clf_results = {} + for name, clf in classifiers.items(): + try: + scores = cross_val_score(clf, mat, y, cv=cv, scoring="accuracy") + clf_results[name] = { + "mean_accuracy": float(np.mean(scores)), + "std_accuracy": float(np.std(scores)), + "cv_method": cv_name, + "scores": scores.tolist(), + "n_bmks_used": n_bmks, + } + except Exception as e: + log.warning(f"Classifier {name} failed: {e}") + + 
results["classifiers"] = clf_results + clf_df = pd.DataFrame([ + {"classifier": k, "mean_accuracy": v["mean_accuracy"], + "std_accuracy": v["std_accuracy"], "cv_method": v["cv_method"], + "n_bmks_used": v.get("n_bmks_used", n_bmks)} + for k, v in clf_results.items() + ]) + clf_df.to_csv(os.path.join(outdir, "classification_results.csv"), index=False) + + # Plot + if clf_results: + fig, ax = plt.subplots(figsize=(6, 4)) + names = list(clf_results.keys()) + means = [clf_results[n]["mean_accuracy"] for n in names] + stds = [clf_results[n]["std_accuracy"] for n in names] + ax.bar(names, means, yerr=stds, alpha=0.7, capsize=5) + ax.set_ylabel("Accuracy") + ax.set_title(f"Classification Accuracy ({cv_name}, {n_bmks} BMKs)") + ax.set_ylim(0, 1.1) + save_fig(fig, os.path.join(outdir, "classification_accuracy.png")) + + return results + + +# --------------------------------------------------------------------------- +# Stability / Robustness (replicate concordance) +# --------------------------------------------------------------------------- + +def stability_analysis(df, sample_cols, sample_info, outdir): + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # Group by (group, sample) and check replicate correlation + rep_groups = defaultdict(list) + for s in sample_info: + key = (s["group"], s["sample"]) + rep_groups[key].append(s["col"]) + + rep_corrs = [] + for key, cols in rep_groups.items(): + if len(cols) >= 2: + for c1, c2 in combinations(cols, 2): + v1 = pd.to_numeric(ndf[c1], errors="coerce") + v2 = pd.to_numeric(ndf[c2], errors="coerce") + mask = v1.notna() & v2.notna() + if mask.sum() >= 3: + r, p = stats.pearsonr(v1[mask], v2[mask]) + rep_corrs.append({ + "group": key[0], "sample": key[1], + "rep1": c1, "rep2": c2, + "pearson_r": r, "pval": p + }) + + if rep_corrs: + rc_df = pd.DataFrame(rep_corrs) + rc_df.to_csv(os.path.join(outdir, "replicate_correlation.csv"), index=False) + results["replicate_correlations"] = 
rc_df.to_dict(orient="records") + + # Coefficient of variation per BMK across all samples + cv_vals = ndf[sample_cols].apply(lambda row: row.std() / row.mean() if row.mean() != 0 else np.nan, axis=1) + cv_df = pd.DataFrame({"bmk": df["ID"].values if "ID" in df.columns else range(len(df)), + "cv": cv_vals.values}) + cv_df = cv_df.sort_values("cv") + cv_df.to_csv(os.path.join(outdir, "coefficient_variation.csv"), index=False) + + # CV histogram + fig, ax = plt.subplots(figsize=(6, 4)) + ax.hist(cv_df["cv"].dropna(), bins=30, alpha=0.7, edgecolor="black") + ax.set_xlabel("Coefficient of Variation") + ax.set_ylabel("Count") + ax.set_title("Distribution of CV across BMKs") + save_fig(fig, os.path.join(outdir, "cv_histogram.png")) + + return results + + +# --------------------------------------------------------------------------- +# Batch Effect Detection +# --------------------------------------------------------------------------- + +def batch_effect_analysis(df, sample_cols, sample_info, outdir): + safe_mkdir(outdir) + ndf = numeric_df(df, sample_cols) + results = {} + + # Check if samples cluster by batch (sample) rather than group + # Use PCA and check if samples separate by sample-id vs group + mat = ndf[sample_cols].T.fillna(0) + if mat.shape[0] < 3 or mat.shape[1] < 2: + return results + + scaler = StandardScaler() + scaled = scaler.fit_transform(mat) + pca = PCA(n_components=min(3, mat.shape[0], mat.shape[1])) + coords = pca.fit_transform(scaled) + + groups = [group_for_col(sample_info, c) for c in sample_cols] + samples = [s["sample"] for s in sample_info if s["col"] in sample_cols[:len(groups)]] + + # Plot colored by sample (batch) + fig, axes = plt.subplots(1, 2, figsize=(14, 5)) + unique_groups = sorted(set(groups)) + colors_g = plt.cm.Set1(np.linspace(0, 1, max(len(unique_groups), 1))) + for i, g in enumerate(unique_groups): + mask = [l == g for l in groups] + axes[0].scatter(coords[mask, 0], coords[mask, 1], c=[colors_g[i]], label=g, s=80) + 
+    # With 200k+ BMKs, matplotlib cannot create images (Agg pixel limit: 2^16 per dimension)
+            # Ptype as PREFIX (Mode is appended as the suffix below)
rest alphabetically + if mode == "all_sites": + priority = 0 + elif not ptype or ptype == ".": + priority = 1 + else: + priority = 2 + sort_keys.append((priority, label, idx)) + else: + label = base_label + sort_keys.append((0, label, idx)) + + ylabels.append(label) + + # Sort rows if affixes were added (mixed BMK types) + if add_affixes and len(sort_keys) > 1: + # Sort by (priority, label alphabetically) + sorted_items = sorted(sort_keys, key=lambda x: (x[0], x[1])) + sorted_indices = [item[2] for item in sorted_items] + mat = mat.loc[sorted_indices] + ylabels = [item[1] for item in sorted_items] + + mat.index = ylabels + + mat = mat.dropna(how="all") + if len(mat) == 0: + return + + if len(mat) > max_rows: + # Keep most variable + var = mat.var(axis=1) + mat = mat.loc[var.nlargest(max_rows).index] + + fig_h = max(4, len(mat) * 0.25) + fig_w = max(6, len(sorted_cols) * 0.8) + fig, ax = plt.subplots(figsize=(fig_w, fig_h)) + + # Create labels: group::sample::rep for x-axis + xlabels = [f"{s['group']}::{s['sample']}::{s['rep']}" for s in sample_order if s['col'] in sorted_cols] + + # Determine y-axis label display and font size + show_ylabels = len(mat) < 200 # Show labels up to 200 rows + yticklabel_fontsize = max(4, min(8, 500 / len(mat))) # Adaptive font size + + sns.heatmap(mat.astype(float), cmap="YlOrRd", ax=ax, + xticklabels=xlabels, + yticklabels=show_ylabels, + linewidths=0.5 if len(mat) < 30 else 0, + cbar_kws={'label': 'Value'}) + + if show_ylabels: + ax.set_yticklabels(ax.get_yticklabels(), fontsize=yticklabel_fontsize, rotation=0) + + ax.set_xlabel("Sample (Group::Sample::Replicate)", fontsize=10) + ax.set_ylabel("Biomarker Type" if show_ylabels else f"Biomarker ({len(mat)} total)", fontsize=10) + ax.set_title(title or "Heatmap") + plt.xticks(rotation=90, fontsize=8) + save_fig(fig, os.path.join(outdir, "heatmap.png")) + + +# --------------------------------------------------------------------------- +# Top-level orchestration per section +# 
def _run_analysis_step(results, key, func, fail_label, log_prefix):
    """Run one analysis step without letting a failure abort the section.

    Args:
        results: Section results dict; the step's return value is stored
            under ``key`` on success (skipped when ``key`` is None).
        key: Results-dict key for this step, or None for side-effect-only steps.
        func: Zero-argument callable performing the step.
        fail_label: Human-readable step name used in the failure warning.
        log_prefix: Section log prefix for attribution in parallel logs.

    Returns:
        True when the step succeeded, False when it raised (and was logged).
    """
    try:
        out = func()
    except Exception as e:
        log.warning(f"{log_prefix} {fail_label} failed: {e}")
        return False
    if key is not None:
        results[key] = out
    return True


def analyze_section(df, sample_cols, sample_info, outdir, section_name, stat_test="auto", bmk_filter_cols=None, max_bmks=500):
    """Run all analyses on a filtered dataframe for a given report section.

    QC and descriptive stats see ALL BMKs (so constants are still reported);
    every later analysis runs on the variable-only subset. Each step is
    fault-isolated: a failure is logged and the remaining steps still run.

    Args:
        df: Section DataFrame.
        sample_cols: Per-sample value column names.
        sample_info: Parsed sample metadata dicts.
        outdir: Output directory for this section.
        section_name: Human-readable section name (used in logs/plots).
        stat_test: Statistical test selection passed to differential analysis.
        bmk_filter_cols: Priority list of p-value columns for BMK selection.
        max_bmks: Hard BMK cap for the ML models.

    Returns:
        dict of per-step results (empty dict when the section has no rows).
    """
    len_df = len(df)
    len_sample_cols = len(sample_cols)
    if len_df == 0:
        log.info(f" Section '{section_name}' is empty, skipping.")
        return {}

    # Short prefix so parallel log lines can be attributed to their section
    log_prefix = f"[{section_name[:50]}]"  # limit to 50 chars

    log.info(f"{log_prefix} Analyzing section: {section_name} ({len_df} BMKs, {len_sample_cols} samples)")
    safe_mkdir(outdir)
    results = {"n_bmks": len_df, "n_samples": len_sample_cols, "section": section_name}

    # EARLY PRE-FILTER: very large sections (all_sequence_bmks / all_global_bmks)
    # are cut to the top-variance BMKs to keep parallel workers from crashing.
    MAX_BMKS_FOR_ANALYSIS = 3000  # reduced from 5000 to prevent hangs
    if len_df > MAX_BMKS_FOR_ANALYSIS:
        log.info(f"{log_prefix} {len_df} BMKs exceeds {MAX_BMKS_FOR_ANALYSIS} limit")
        log.info(f"{log_prefix} Pre-filtering to top {MAX_BMKS_FOR_ANALYSIS} BMKs by variance to prevent memory issues...")
        ndf_prefilter = numeric_df(df, sample_cols)
        variance_prefilter = ndf_prefilter[sample_cols].var(axis=1)
        top_indices = variance_prefilter.nlargest(MAX_BMKS_FOR_ANALYSIS).index
        df = df.loc[top_indices].reset_index(drop=True).copy()
        len_df = len(df)
        log.info(f"{log_prefix} Pre-filtered to {len_df} BMKs")
        results["n_bmks_prefiltered"] = len_df

    # Persist the (possibly pre-filtered) section data
    df.to_csv(os.path.join(outdir, "data.csv"), index=False)

    # 1. QC — runs on ALL BMKs so constant biomarkers are still identified
    log.info(f"{log_prefix} QC analysis...")
    _run_analysis_step(results, "qc",
                       lambda: qc_analysis(df.reset_index(drop=True), sample_cols, os.path.join(outdir, "qc")),
                       "QC", log_prefix)

    # 2. Descriptive stats — also on ALL BMKs
    log.info(f"{log_prefix} Descriptive stats...")
    _run_analysis_step(results, "descriptive",
                       lambda: descriptive_stats(df.reset_index(drop=True), sample_cols, sample_info, os.path.join(outdir, "descriptive")),
                       "Descriptive stats", log_prefix)

    # PRE-FILTER: drop BMKs with near-zero variance before the statistical/ML
    # steps — they cause numerical issues in differential tests, PCA,
    # correlation, heatmaps and the ML models.
    ndf_tmp = numeric_df(df, sample_cols)
    variance = ndf_tmp[sample_cols].var(axis=1)
    variable_mask = variance > 1e-10
    n_filtered_out = (~variable_mask).sum()

    if n_filtered_out > 0:
        log.info(f"{log_prefix} Removing {n_filtered_out} BMKs with near-zero variance (keeping {variable_mask.sum()} variable)")
        df_variable = df[variable_mask].reset_index(drop=True).copy()
        n_bmks_variable = len(df_variable)
    else:
        df_variable = df.copy()
        n_bmks_variable = len_df

    # Nothing variable left → nothing more to analyze
    if n_bmks_variable == 0:
        log.warning(f"{log_prefix} No variable BMKs remaining after pre-filtering")
        return results

    results["n_bmks_variable"] = n_bmks_variable

    # 3-9. Remaining analyses share identical fault-isolation boilerplate;
    # drive them from one table: (results key, start label, fail label, call).
    steps = [
        ("differential", "Differential analysis", "Differential analysis",
         lambda: differential_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "differential"), stat_test=stat_test)),
        ("multivariate", "Multivariate analysis", "Multivariate analysis",
         lambda: multivariate_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "multivariate"))),
        ("correlation", "Correlation / Network analysis", "Correlation analysis",
         lambda: correlation_network(df_variable, sample_cols, os.path.join(outdir, "correlation"))),
        ("ranking", "Feature ranking", "Feature ranking",
         lambda: feature_ranking(df_variable, sample_cols, sample_info, os.path.join(outdir, "ranking"), max_bmks=max_bmks, bmk_filter_cols=bmk_filter_cols)),
        ("classification", "Classification", "Classification",
         lambda: classification_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "classification"), max_bmks=max_bmks, bmk_filter_cols=bmk_filter_cols)),
        ("stability", "Stability analysis", "Stability analysis",
         lambda: stability_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "stability"))),
        ("batch", "Batch effect analysis", "Batch effect analysis",
         lambda: batch_effect_analysis(df_variable, sample_cols, sample_info, os.path.join(outdir, "batch"))),
    ]
    for key, start_label, fail_label, func in steps:
        log.info(f"{log_prefix} {start_label} ({n_bmks_variable} BMKs)...")
        _run_analysis_step(results, key, func, fail_label, log_prefix)

    # 10. Heatmap — side effect only (no results entry), reports completion
    log.info(f"{log_prefix} Heatmap generation ({n_bmks_variable} BMKs)...")
    if _run_analysis_step(results, None,
                          lambda: section_heatmap(df_variable, sample_cols, sample_info, outdir, title=section_name),
                          "Heatmap", log_prefix):
        log.info(f"{log_prefix} Heatmap generation completed")

    log.info(f"{log_prefix} Section analysis completed successfully")
    return results
def build_feature_tree(features_df):
    """Build a parent→children tree for features.

    Each row's direct parent is the last element of the comma-separated
    ``ParentIDs`` column (the first element is always '.'); rows whose
    ``ParentIDs`` is exactly '.' are top-level.

    Args:
        features_df: DataFrame with at least ``ID`` and ``ParentIDs`` columns.

    Returns:
        tuple: ``(tree, top_ids)`` where ``tree`` maps feature ID to
        ``{"data": row, "children": [child IDs]}`` and ``top_ids`` lists the
        top-level feature IDs in first-appearance order.
    """
    tree = {}  # id -> {"data": row, "children": []}
    # ParentIDs string of the FIRST row seen per ID. Duplicate IDs keep the
    # first parent string (matching the original .loc[...].iloc[0] lookup)
    # while the tree node keeps the last row's data, as before.
    first_parents = {}
    for _, row in features_df.iterrows():
        fid = row["ID"]
        tree[fid] = {"data": row, "children": []}
        first_parents.setdefault(fid, str(row["ParentIDs"]))

    for _, row in features_df.iterrows():
        parents = str(row["ParentIDs"])
        if parents == ".":
            continue  # top-level feature
        parts = [p.strip() for p in parents.split(",") if p.strip() != "."]
        if parts:
            direct_parent = parts[-1]
            # Parents missing from the table are silently ignored
            if direct_parent in tree:
                tree[direct_parent]["children"].append(row["ID"])

    # O(1) dict lookup per feature instead of the original full-DataFrame
    # scan per ID (accidental O(n^2) on large feature tables).
    top_ids = [fid for fid in tree if first_parents[fid] == "."]

    return tree, top_ids
def analyze_section_wrapper(args_tuple):
    """Wrapper around analyze_section() for use with ProcessPoolExecutor.

    Args:
        args_tuple: (df_source, sample_cols, sample_info, outdir, section_name,
            result_key, stat_test, bmk_filter_cols, max_bmks) where df_source is
            either a dict (legacy format, converted back to a DataFrame) or a
            tuple ('pickled', bytes_data) for memory-efficient transfer.

    Returns:
        (result_key, results_dict): results_dict is a slim dict holding only
        the differential-results file path, to keep the parent's accumulated
        results small.

    Raises:
        TimeoutError: when the task exceeds the 120-second SIGALRM limit.
    """
    import traceback
    import signal
    import gc
    import pickle
    import random
    import time as time_module
    df_source, sample_cols, sample_info, outdir, section_name, result_key, stat_test, bmk_filter_cols, max_bmks = args_tuple

    # OPTIMIZATION: startup jitter avoids synchronized worker restarts.
    # With spawn + max_tasks_per_child=20, all workers start together and may
    # finish their 20th task at the same time → 6 workers restart simultaneously
    # → 6 × 275MB = 1.65GB spike. Jitter spreads restarts over 2 seconds.
    time_module.sleep(random.uniform(0, 2))

    # Log start with worker PID and (when psutil is available) RSS in MB
    pid = os.getpid()
    mem_start = None
    if psutil:
        try:
            proc = psutil.Process(pid)
            mem_start = proc.memory_info().rss / (1024 * 1024)  # MB
            log.info(f" ▶ START [PID {pid}, {mem_start:.0f}MB]: {section_name}")
        except Exception:  # narrowed from bare except: don't swallow SystemExit
            log.info(f" ▶ START [PID {pid}]: {section_name}")
    else:
        log.info(f" ▶ START [PID {pid}]: {section_name}")

    # Hard per-task timeout (120 seconds) via SIGALRM
    def timeout_handler(signum, frame):
        raise TimeoutError(f"Task exceeded 120 seconds: {section_name}")

    try:
        signal.signal(signal.SIGALRM, timeout_handler)
        signal.alarm(120)

        # Reconstruct the DataFrame from its transfer format
        if isinstance(df_source, tuple) and df_source[0] == 'pickled':
            # Memory-efficient path; pickle is already imported at function top
            # (the original re-imported it here redundantly).
            df = pickle.loads(df_source[1])
            log.info(f" ▶ [PID {pid}] DataFrame unpickled: {len(df)} rows")
        else:
            # Legacy dict format
            df = pd.DataFrame(df_source)
            log.info(f" ▶ [PID {pid}] DataFrame reconstructed: {len(df)} rows")

        # Run the actual analysis
        results = analyze_section(df, sample_cols, sample_info, outdir, section_name, stat_test=stat_test, bmk_filter_cols=bmk_filter_cols, max_bmks=max_bmks)

        # Cancel the timeout
        signal.alarm(0)

        # MEMORY FIX #2: return a slim results dict (only file paths, not data).
        # Full results dicts are ~500KB per task × 1491 tasks = 745MB accumulated
        # in all_results; global_ranking() only needs the differential table path.
        slim_results = {
            "differential_table": os.path.join(outdir, "differential", "differential_results.csv")
        }

        # Explicitly free the big objects to help the garbage collector
        # (critical for long-lived parallel workers)
        del df
        del df_source
        del results  # MEMORY FIX #3: immediate cleanup of the full results dict
        gc.collect()

        # Also force matplotlib to release any lingering figures
        plt.close('all')

        # Log completion with memory delta when psutil is available
        if psutil and mem_start:
            try:
                proc = psutil.Process(pid)
                mem_end = proc.memory_info().rss / (1024 * 1024)
                log.info(f" ✓ DONE [PID {pid}, {mem_end:.0f}MB, Δ{mem_end-mem_start:+.0f}MB]: {section_name}")
            except Exception:
                log.info(f" ✓ DONE [PID {pid}]: {section_name}")
        else:
            log.info(f" ✓ DONE [PID {pid}]: {section_name}")

        return result_key, slim_results

    except TimeoutError as e:
        signal.alarm(0)  # cancel alarm
        log.error(f" ✗ TIMEOUT [PID {pid}]: {section_name} - {e}")
        # Best-effort cleanup; df/df_source may be unbound if we failed early
        try:
            del df
            del df_source
            plt.close('all')
            gc.collect()
        except Exception:
            pass
        raise
    except Exception as e:
        signal.alarm(0)  # cancel alarm
        # Log full error details before propagating the crash
        log.error(f" ✗ CRASH [PID {pid}]: {section_name}")
        log.error(f" Error: {type(e).__name__}: {e}")
        log.error(f" Stacktrace:\n{traceback.format_exc()}")
        # Best-effort cleanup; df/df_source may be unbound if we failed early
        try:
            del df
            del df_source
            plt.close('all')
            gc.collect()
        except Exception:
            pass
        raise
def global_ranking(all_results, outdir):
    """Aggregate per-section differential tables into a global BMK ranking.

    Reads each section's slim result dict (which only holds the path to its
    ``differential_results.csv``), concatenates the tables, ranks by
    Kruskal-Wallis adjusted p-value and writes a CSV plus a top-50 bar plot.

    Args:
        all_results: Nested dict value_type -> mtype -> section -> slim dict.
        outdir: Output directory for the ranking CSV and plot.

    Returns:
        The output directory path.
    """
    safe_mkdir(outdir)

    # Collect every section's differential table that exists on disk
    diff_files = []
    for vtype, vdata in all_results.items():
        for mtype, mdata in vdata.items():
            for section, sdata in mdata.items():
                # MEMORY FIX: sdata is a slim dict with only "differential_table"
                diff_path = sdata.get("differential_table")
                if diff_path and os.path.exists(diff_path):
                    ddf = pd.read_csv(diff_path)
                    # Defragment before adding columns (avoids PerformanceWarning)
                    ddf = ddf.copy()
                    # Tag rows with their origin
                    ddf["value_type"] = vtype
                    ddf["mtype"] = mtype
                    ddf["section"] = section
                    diff_files.append(ddf)

    if diff_files:
        all_diff = pd.concat(diff_files, ignore_index=True)
        # Rank by Kruskal-Wallis adjusted p-value when the column is present
        if "kruskal_padj" in all_diff.columns:
            ranked = all_diff.dropna(subset=["kruskal_padj"]).sort_values("kruskal_padj")
            ranked.to_csv(os.path.join(outdir, "global_ranking_kruskal.csv"), index=False)

            # Top-50 plot; skip entirely when dropna left nothing to plot
            # (the original built a degenerate empty figure in that case)
            top_n = min(50, len(ranked))
            if top_n > 0:
                top = ranked.head(top_n)
                fig, ax = plt.subplots(figsize=(8, max(4, top_n * 0.3)))
                # clip avoids -inf when padj underflows to 0
                neg_log_p = -np.log10(top["kruskal_padj"].clip(lower=1e-300))
                labels = top["ID"].astype(str) + " [" + top["section"].astype(str) + "]"
                # Reverse so the most significant BMK sits at the top of the plot
                ax.barh(range(top_n), neg_log_p.values[::-1])
                ax.set_yticks(range(top_n))
                ax.set_yticklabels(labels.values[::-1], fontsize=6)
                ax.set_xlabel("-log10(adjusted p-value)")
                ax.set_title("Global BMK Ranking by Significance (top 50)")

                # Reference lines for common significance thresholds
                ax.axvline(-np.log10(0.05), color='red', linestyle='--', linewidth=1, alpha=0.7, label='p=0.05')
                ax.axvline(-np.log10(0.1), color='orange', linestyle='--', linewidth=1, alpha=0.7, label='p=0.1')
                ax.legend(loc='lower right', fontsize=8)

                save_fig(fig, os.path.join(outdir, "global_ranking_plot.png"))

    return outdir
--stat-test auto): + • Shapiro-Wilk → shapiro_pvals (per group) + • Bartlett → bartlett_pval + + All tests receive FDR correction (Benjamini-Hochberg) → *_padj columns + All results are saved in differential_results.csv + + OPTION 1: --stat-test (controls which test is emphasized in plots) + This creates primary_pval and primary_padj columns that COPY the selected test: + + nonparametric : primary_* = kruskal_* (DEFAULT, robust) + parametric : primary_* = anova_* + welch : primary_* = welch_* + auto : primary_* = auto-selected test (varies per BMK) + kruskal : alias for nonparametric + + Example: --stat-test welch creates primary_padj as a copy of welch_padj + Volcano plots use the primary_* columns for visualization. + + OPTION 2: --bmk-filter (cascade priority for selecting significant BMKs) + Accepts a PRIORITY LIST of *_padj columns to collect significant BMKs + (padj < 0.05) for RandomForest and classification. + + The algorithm works in TWO-PASS CASCADE: + PASS 1 (adjusted p-values, strongest evidence): + 1. Take all BMKs with 1st column *_padj < 0.05 + 2. If < --max-bmks, add BMKs from 2nd column *_padj < 0.05 (not already selected) + 3. Continue until reaching --max-bmks or exhausting all *_padj columns + + PASS 2 (raw p-values, suggestive evidence, only if PASS 1 insufficient): + 4. Convert columns to *_pval (e.g., kruskal_padj → kruskal_pval) + 5. Add BMKs with *_pval < 0.05 (not already selected) + 6. Continue until reaching --max-bmks + + PASS 3 (if still insufficient): + 7. 
Complete with top variable BMKs (by variance) + + Default cascade: primary_padj → kruskal_padj → welch_padj → anova_padj + Then (if needed): primary_pval → kruskal_pval → welch_pval → anova_pval + + Available columns (choose your priority order): + • primary_padj : controlled by --stat-test (recommended) + • kruskal_padj : Kruskal-Wallis (robust, non-parametric) + • anova_padj : ANOVA (parametric, equal variances) + • welch_padj : Welch (parametric, unequal variances) + • mwu_padj_X_vs_Y : Mann-Whitney U for specific pair + • student_padj_X_vs_Y : Student t-test for specific pair + + Examples: + --bmk-filter kruskal_padj : only use Kruskal (padj then pval) + --bmk-filter primary_padj kruskal_padj : cascade from primary to kruskal + --bmk-filter welch_padj anova_padj : prioritize Welch, fallback to ANOVA + + OPTION 3: --max-bmks (maximum BMKs for ML models) + Sets the hard limit for RandomForest and classification (default: 500). + Prevents memory crashes with large datasets (e.g., 40,000+ BMKs). + Higher values = more accurate but slower and more memory-intensive. 
+ +EXAMPLES: + # Default: cascade adjusted + raw p-values, max 500 BMKs: + ./barometer_analyze.py -a data.tsv -o results/ -j 4 + + # Use only Kruskal-Wallis (padj then pval if needed): + ./barometer_analyze.py -a data.tsv -o results/ --bmk-filter kruskal_padj + + # Prioritize Welch, fallback to Kruskal, use 1000 BMKs: + ./barometer_analyze.py -a data.tsv -o results/ \ + --stat-test welch --bmk-filter welch_padj kruskal_padj --max-bmks 1000 + + # Conservative: only most significant BMKs (200 max): + ./barometer_analyze.py -a data.tsv -o results/ --max-bmks 200 + + # Aggressive: maximize BMK collection with high limit: + ./barometer_analyze.py -a data.tsv -o results/ --max-bmks 1500 + """) + parser.add_argument("-a", "--aggregates", default=None, help="Aggregates TSV file (optional)") + parser.add_argument("-f", "--features", default=None, help="Features TSV file (optional)") + parser.add_argument("-o", "--outdir", default="barometer_results", help="Output directory") + parser.add_argument("-j", "--jobs", type=int, default=1, help="Number of parallel jobs (default: 1, use -1 for all CPUs)") + parser.add_argument("-v", "--value-types", nargs="+", default=None, help="Value types to analyze (e.g., espf espr). If not specified, all value types are analyzed.") + parser.add_argument("--agg-levels", nargs="+", default=None, help="Aggregate levels to analyze: global, sequence, feature. If not specified, all levels are analyzed.") + parser.add_argument("--feature-types", nargs="+", default=None, help="Feature types to analyze (e.g., gene exon RNA). 
If not specified, all feature types are analyzed.") + parser.add_argument("--stat-test", default="nonparametric", + choices=["auto", "parametric", "nonparametric", "welch", "kruskal"], + help="Statistical test selection: auto (test assumptions), parametric (Student/ANOVA), nonparametric (Mann-Whitney/Kruskal-Wallis, default), welch (Welch t-test/ANOVA)") + parser.add_argument("--bmk-filter", nargs="+", default=["primary_padj", "kruskal_padj", "welch_padj", "anova_padj"], + help="Priority list of column names for filtering significant BMKs (default: primary_padj kruskal_padj welch_padj anova_padj). BMKs are collected in cascade until --max-bmks is reached.") + parser.add_argument("--max-bmks", type=int, default=500, + help="Maximum number of BMKs to use for RandomForest and classification (default: 500). Prevents memory crashes with large datasets.") + args = parser.parse_args() + + # Determine number of workers + if args.jobs == -1: + n_jobs = os.cpu_count() + elif args.jobs < 1: + parser.error("--jobs must be >= 1 or -1 for all CPUs") + else: + n_jobs = args.jobs + + log.info(f"Using {n_jobs} parallel job(s)") + log.info(f"Statistical test method: {args.stat_test}") + log.info(f"BMK filtering cascade: {' → '.join(args.bmk_filter)}") + log.info(f"Max BMKs for ML models: {args.max_bmks}") + + # Check that at least one input file is provided + if not args.aggregates and not args.features: + parser.error("At least one of --aggregates or --features must be provided") + + outdir = args.outdir + safe_mkdir(outdir) + + # Load data + log.info("Loading data...") + agg_df = None + feat_df = None + + if args.aggregates: + if not os.path.exists(args.aggregates): + log.error(f"Aggregates file not found: {args.aggregates}") + sys.exit(1) + log.info(f"Loading aggregates from {args.aggregates}...") + agg_df = pd.read_csv(args.aggregates, sep="\t", dtype={"SeqID": str, "Start": str, "End": str, "Strand": str}) + agg_df.columns = agg_df.columns.str.strip() # Remove leading/trailing 
whitespace from column names + # Strip whitespace from string columns + for col in agg_df.select_dtypes(include=['object', 'string']).columns: + agg_df[col] = agg_df[col].str.strip() if agg_df[col].dtype in ['object', 'string'] else agg_df[col] + log.info(f" Aggregates: {len(agg_df)} rows") + else: + log.info("Skipping aggregates (no file provided)") + + if args.features: + if not os.path.exists(args.features): + log.error(f"Features file not found: {args.features}") + sys.exit(1) + log.info(f"Loading features from {args.features}...") + feat_df = pd.read_csv(args.features, sep="\t", dtype={"SeqID": str, "Start": str, "End": str, "Strand": str}) + feat_df.columns = feat_df.columns.str.strip() # Remove leading/trailing whitespace from column names + # Strip whitespace from string columns + for col in feat_df.select_dtypes(include=['object', 'string']).columns: + feat_df[col] = feat_df[col].str.strip() if feat_df[col].dtype in ['object', 'string'] else feat_df[col] + log.info(f" Features: {len(feat_df)} rows") + else: + log.info("Skipping features (no file provided)") + + # Parse sample information from whichever file is available + sample_df = agg_df if agg_df is not None else feat_df + sample_info = parse_sample_columns(sample_df.columns) + all_value_types = get_value_types(sample_info) + + # Filter value types if specified + if args.value_types: + value_types = [vt for vt in args.value_types if vt in all_value_types] + # Warn about non-existent value types + missing = [vt for vt in args.value_types if vt not in all_value_types] + if missing: + log.warning(f"Requested value types not found in data: {missing}") + if not value_types: + log.error(f"None of the requested value types {args.value_types} found in data. 
Available: {all_value_types}") + sys.exit(1) + else: + value_types = all_value_types + + # Count unique samples (group, sample, rep combinations) + unique_samples = len(set((s["group"], s["sample"], s["rep"]) for s in sample_info)) + log.info(f"Available value types: {all_value_types}") + if args.value_types: + log.info(f"Analyzing value types: {value_types}") + else: + log.info(f"Analyzing all value types: {value_types}") + log.info(f"Samples: {unique_samples} unique samples across {len(all_value_types)} value types ({len(sample_info)} total columns)") + + all_results = {} + + for vtype in value_types: + log.info(f"\n{'='*60}") + log.info(f"VALUE TYPE: {vtype}") + log.info(f"{'='*60}") + + vtype_dir = os.path.join(outdir, vtype) + safe_mkdir(vtype_dir) + vcols = cols_for_vtype(sample_info, vtype) + v_sample_info = sample_info_for_vtype(sample_info, vtype) + all_results[vtype] = {"aggregate": {}, "feature": {}} + + # Collect all analysis tasks for this value_type + tasks = [] + + # =============================================================== + # AGGREGATES + # =============================================================== + if agg_df is not None: + log.info(f"\n--- AGGREGATES for {vtype} ---") + agg_data = agg_df[agg_df["Mtype"] == "aggregate"].copy() + + agg_dir = os.path.join(vtype_dir, "aggregate") + + # Filter aggregate levels if specified + agg_levels = args.agg_levels if args.agg_levels else ["global", "sequence", "feature"] + log.info(f"Analyzing aggregate levels: {agg_levels}") + + # --- 1. 
Global aggregates (Type == global) --- + if "global" in agg_levels: + glob_agg = agg_data[agg_data["Type"] == "global"] + else: + glob_agg = pd.DataFrame() # Empty dataframe if global is not requested + + if not glob_agg.empty: + # All BMKs together (no Ptype/Ctype/Mode filter) + tasks.append(( + prepare_df_for_task(glob_agg), vcols, v_sample_info, + os.path.join(agg_dir, "global", "all_global_bmks"), + "Global - All BMKs", + ("aggregate", "global_all_bmks"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # all sites (Mode == all_sites) + section = glob_agg[glob_agg["Mode"] == "all_sites"] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "global", "all_sites"), + "Global - All Sites", + ("aggregate", "global_all_sites"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # All sites by Ctype (Ptype == "." and 3 modes) + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = glob_agg[(glob_agg["Ptype"] == ".") & (glob_agg["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "global", f"by_ctype_{mode}"), + f"Global - By Ctype - {mode}", + ("aggregate", f"global_ctype_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # All sites by Ptype (Ptype != "." and 3 modes) + ptypes = [p for p in glob_agg["Ptype"].unique() if p != "."] + for ptype in ptypes: + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = glob_agg[(glob_agg["Ptype"] == ptype) & (glob_agg["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "global", f"by_ptype_{ptype}_{mode}"), + f"Global - Ptype={ptype} - {mode}", + ("aggregate", f"global_ptype_{ptype}_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # --- 2. 
Chromosome/Sequence aggregates (Type == sequence) --- + if "sequence" in agg_levels: + seq_agg = agg_data[agg_data["Type"] == "sequence"] + else: + seq_agg = pd.DataFrame() # Empty dataframe if chr is not requested + + if not seq_agg.empty: + chromosomes = seq_agg["SeqID"].unique() + for chrom in chromosomes: + chr_data = seq_agg[seq_agg["SeqID"] == chrom] + + # All sites + section = chr_data[chr_data["Mode"] == "all_sites"] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, f"sequence/{chrom}", "all_sites"), + f"Chr {chrom} - All Sites", + ("aggregate", f"chr{chrom}_all_sites"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # By Ctype + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = chr_data[(chr_data["Ptype"] == ".") & (chr_data["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, f"sequence/{chrom}", f"by_ctype_{mode}"), + f"Chr {chrom} - By Ctype - {mode}", + ("aggregate", f"chr{chrom}_ctype_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # By Ptype + local_ptypes = [p for p in chr_data["Ptype"].unique() if p != "."] + for ptype in local_ptypes: + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = chr_data[(chr_data["Ptype"] == ptype) & (chr_data["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, f"sequence/{chrom}", f"by_ptype_{ptype}_{mode}"), + f"Chr {chrom} - Ptype={ptype} - {mode}", + ("aggregate", f"chr{chrom}_ptype_{ptype}_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # --- All sequences combined (all chromosomes pooled) --- + # all_sites + section = seq_agg[seq_agg["Mode"] == "all_sites"] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "sequence", "all_sequence_bmks", "all_sites"), + "All Sequences - All Sites", + ("aggregate", 
"allseq_all_sites"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + # by_ctype + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = seq_agg[(seq_agg["Ptype"] == ".") & (seq_agg["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "sequence", "all_sequence_bmks", f"by_ctype_{mode}"), + f"All Sequences - By Ctype - {mode}", + ("aggregate", f"allseq_ctype_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + # by_ptype + all_seq_ptypes = [p for p in seq_agg["Ptype"].unique() if p != "."] + for ptype in all_seq_ptypes: + for mode in ["all_isoforms", "chimaera", "longest_isoform"]: + section = seq_agg[(seq_agg["Ptype"] == ptype) & (seq_agg["Mode"] == mode)] + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "sequence", "all_sequence_bmks", f"by_ptype_{ptype}_{mode}"), + f"All Sequences - Ptype={ptype} - {mode}", + ("aggregate", f"allseq_ptype_{ptype}_{mode}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # --- 3. 
Feature aggregates (Type == feature) --- + if "feature" in agg_levels: + feat_agg = agg_data[agg_data["Type"] == "feature"] + else: + feat_agg = pd.DataFrame() # Empty dataframe if feature is not requested + + if not feat_agg.empty: + fa_ptypes = feat_agg["Ptype"].unique() + fa_ctypes = feat_agg["Ctype"].unique() + fa_modes = feat_agg["Mode"].unique() + + # --- all_feature_together: all chromosomes pooled --- + for ptype in fa_ptypes: + for ctype in fa_ctypes: + for mode in fa_modes: + section = feat_agg[ + (feat_agg["Ptype"] == ptype) & + (feat_agg["Ctype"] == ctype) & + (feat_agg["Mode"] == mode) + ] + safe_name = f"{ptype}_{ctype}_{mode}".replace(".", "all").replace("-", "_") + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "feature", "all_feature_together", safe_name), + f"Feature Agg - Ptype={ptype}, Ctype={ctype}, Mode={mode}", + ("aggregate", f"featagg_{safe_name}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # --- by_sequence: features grouped by chromosome --- + fa_chroms = [c for c in feat_agg["SeqID"].unique() if c != "."] + for chrom in fa_chroms: + chr_feat = feat_agg[feat_agg["SeqID"] == chrom] + chr_ptypes = chr_feat["Ptype"].unique() + chr_ctypes = chr_feat["Ctype"].unique() + chr_modes = chr_feat["Mode"].unique() + for ptype in chr_ptypes: + for ctype in chr_ctypes: + for mode in chr_modes: + section = chr_feat[ + (chr_feat["Ptype"] == ptype) & + (chr_feat["Ctype"] == ctype) & + (chr_feat["Mode"] == mode) + ] + safe_name = f"{ptype}_{ctype}_{mode}".replace(".", "all").replace("-", "_") + tasks.append(( + prepare_df_for_task(section), vcols, v_sample_info, + os.path.join(agg_dir, "feature", "by_sequence", str(chrom), safe_name), + f"Feature Agg - Chr {chrom} - Ptype={ptype}, Ctype={ctype}, Mode={mode}", + ("aggregate", f"featagg_chr{chrom}_{safe_name}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + else: + log.info(f"\n--- Skipping AGGREGATES for {vtype} (no aggregates 
file) ---") + + # =============================================================== + # FEATURES + # =============================================================== + if feat_df is not None: + log.info(f"\n--- FEATURES for {vtype} ---") + feat_data = feat_df.copy() # Mtype is always "feature" in features file + feat_dir = os.path.join(vtype_dir, "feature") + + # Build hierarchy + tree, top_ids = build_feature_tree(feat_data) + + # Group top features by Type + top_features = feat_data[feat_data["ParentIDs"] == "."] + available_feature_types = top_features["Type"].unique().tolist() + + # Filter feature types if specified + if args.feature_types: + feature_types = [ft for ft in args.feature_types if ft in available_feature_types] + missing = [ft for ft in args.feature_types if ft not in available_feature_types] + if missing: + log.warning(f"Requested feature types not found in data: {missing}") + if not feature_types: + log.error(f"None of the requested feature types {args.feature_types} found in data. 
Available: {available_feature_types}") + feature_types = [] # Will skip all features + else: + feature_types = available_feature_types + + if feature_types: + log.info(f"Available feature types: {available_feature_types}") + log.info(f"Analyzing feature types: {feature_types}") + + for ttype in feature_types: + type_dir = os.path.join(feat_dir, ttype.replace(" ", "_")) + type_features = top_features[top_features["Type"] == ttype] + + # Analyze all features of this type together + all_ids_of_type = [] + for _, row in type_features.iterrows(): + fid = row["ID"] + # Collect this feature and all descendants + def collect_ids(node_id): + ids = [node_id] + if node_id in tree: + for child in tree[node_id]["children"]: + ids.extend(collect_ids(child)) + return ids + all_ids_of_type.extend(collect_ids(fid)) + + type_all_df = feat_data[feat_data["ID"].isin(all_ids_of_type)] + tasks.append(( + prepare_df_for_task(type_all_df), vcols, v_sample_info, + os.path.join(type_dir, "_all"), + f"Features - {ttype} (all)", + ("feature", f"type_{ttype}_all"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + + # Per top-feature analysis + for _, row in type_features.iterrows(): + fid = row["ID"] + def collect_ids(node_id): + ids = [node_id] + if node_id in tree: + for child in tree[node_id]["children"]: + ids.extend(collect_ids(child)) + return ids + sub_ids = collect_ids(fid) + sub_df = feat_data[feat_data["ID"].isin(sub_ids)] + safe_fid = fid.replace(":", "_").replace("/", "_") + tasks.append(( + prepare_df_for_task(sub_df), vcols, v_sample_info, + os.path.join(type_dir, safe_fid), + f"Feature: {fid}", + ("feature", f"feature_{safe_fid}"), + args.stat_test, + args.bmk_filter, + args.max_bmks + )) + else: + log.info(f"\n--- Skipping FEATURES for {vtype} (no features file) ---") + + # Execute tasks (parallel or sequential) + log.info(f"Executing {len(tasks)} analysis tasks...") + if n_jobs > 1 and len(tasks) > 1: + # Parallel execution with worker recycling to prevent memory 
leaks + log.info(f"Submitting {len(tasks)} tasks to {n_jobs} workers...") + + # MEMORY FIX #1: Use spawn context to prevent fork Copy-on-Write memory inheritance + # With fork (default on Linux), each worker inherits ALL parent memory via CoW + # → 6 workers × 5GB parent = 30GB total (even if psutil shows less per process) + # With spawn, workers start clean and only receive their task data via pickle + mp_context = multiprocessing.get_context('spawn') + log.info(f" Using 'spawn' context (not fork) to avoid CoW memory inheritance") + + # Use max_tasks_per_child=20 to balance performance and memory + # Higher than fork (was 5) because spawn workers are clean but have import overhead (~275MB) + log.info(f" Worker recycling: Each worker will process max 20 tasks before restart") + + # CRITICAL: Use lazy submission to avoid loading all 1491 tasks (dict copies) in memory at once + # Submit only max_pending_tasks at a time, then submit new ones as they complete + max_pending_tasks = n_jobs * 3 # Keep 3x workers worth of tasks in flight + log.info(f" Lazy submission: Max {max_pending_tasks} tasks in memory at once (was {len(tasks)})") + + with ProcessPoolExecutor(max_workers=n_jobs, max_tasks_per_child=20, mp_context=mp_context) as executor: + future_to_key = {} + task_iter = iter(tasks) + completed = 0 + failed = 0 + submitted_count = 0 + last_progress_time = time.time() + stall_timeout = 90 # Increased from 40s for spawn startup delay (imports take ~2-5s per worker) + last_worker_check = time.time() + worker_check_interval = 15 # Check worker health every 15 seconds + + # Initial submission of first batch + initial_batch = min(max_pending_tasks, len(tasks)) + log.info(f" Submitting initial batch of {initial_batch} tasks...") + for _ in range(initial_batch): + try: + df_source, cols, info, outdir, name, key, stat_test, bmk_filter, max_bmks = next(task_iter) + future = executor.submit(analyze_section_wrapper, (df_source, cols, info, outdir, name, key, stat_test, 
bmk_filter, max_bmks)) + future_to_key[future] = (key, name) + submitted_count += 1 + # Explicitly free the df_source reference to help GC + del df_source + except StopIteration: + break + + log.info(f" Initial batch submitted. Processing with {n_jobs} workers...") + log.info(f" Note: Tasks have a 2-minute timeout. Anti-deadlock active.") + + # Process futures with stall detection + pending_futures = set(future_to_key.keys()) + + while pending_futures: + # Periodically check if workers have died + current_time = time.time() + if psutil is not None and (current_time - last_worker_check) > worker_check_interval: + try: + process = psutil.Process() + children = process.children(recursive=True) + n_active_workers = len(children) + + # If more than half of workers are dead, we have a problem + if n_active_workers < (n_jobs / 2) and len(pending_futures) > n_active_workers * 2: + log.error(f" ✗ WORKER DEATH DETECTED: Only {n_active_workers}/{n_jobs} workers alive with {len(pending_futures)} pending tasks") + log.error(f" Most likely cause: Out-Of-Memory (OOM) killed workers") + log.error(f" Cancelling all pending tasks to prevent infinite hang") + log.error(f" TIP: Restart with fewer workers (--n-jobs 2 or --n-jobs 3)") + + # Cancel all pending futures + for future in list(pending_futures): + future.cancel() + key, section_name = future_to_key[future] + log.error(f" ✗ CANCELLED (orphaned): {section_name}") + failed += 1 + break + + last_worker_check = current_time + except Exception as e: + log.debug(f"Worker health check failed: {e}") + + # Use short timeout on as_completed to check for stalls + try: + done_iter = as_completed(pending_futures, timeout=5) + for future in done_iter: + pending_futures.discard(future) + key, section_name = future_to_key[future] + + # Process the completed future + try: + result_key, results = future.result(timeout=1) # Short timeout since already done + mtype, section_key = result_key + all_results[vtype][mtype][section_key] = results + 
completed += 1 + last_progress_time = time.time() # Reset stall timer + + # MEMORY FIX #3: Immediate cleanup after storing (results is slim dict with only paths) + del results + + # Force garbage collection every 10 tasks to free memory in main process + if completed % 10 == 0: + gc.collect() + + # Progress update every 5 completions or at key milestones + if completed % 5 == 0 or completed in [1, 10, 25, 50, 100]: + log.info(f" ✓ Progress: {completed}/{len(tasks)} completed, {failed} failed, {submitted_count - completed - failed} submitted pending") + except TimeoutError: + failed += 1 + last_progress_time = time.time() + log.error(f" ✗ TIMEOUT ({completed+failed}/{len(tasks)}): {section_name} - exceeded 2 minutes") + except Exception as e: + failed += 1 + last_progress_time = time.time() + # Check if it's a worker crash (common patterns in error message) + if "process" in str(e).lower() and ("terminate" in str(e).lower() or "crash" in str(e).lower() or "abrupt" in str(e).lower()): + log.error(f" ✗ CRASH ({completed+failed}/{len(tasks)}): {section_name} - worker OOM or crash") + else: + log.error(f" ✗ ERROR ({completed+failed}/{len(tasks)}): {section_name} - {type(e).__name__}") + + # CRITICAL: Submit next task to maintain max_pending_tasks in flight + # This lazy submission keeps memory usage constant regardless of total tasks + if len(pending_futures) < max_pending_tasks: + try: + df_source, cols, info, outdir, name, key, stat_test, bmk_filter, max_bmks = next(task_iter) + new_future = executor.submit(analyze_section_wrapper, (df_source, cols, info, outdir, name, key, stat_test, bmk_filter, max_bmks)) + future_to_key[new_future] = (key, name) + pending_futures.add(new_future) + submitted_count += 1 + # Explicitly free the df_source reference to help GC + del df_source + # Log every 50 submissions + if submitted_count % 50 == 0: + log.info(f" → Submitted up to {submitted_count}/{len(tasks)} tasks (lazy mode)") + except StopIteration: + pass # No more tasks to 
submit + + # Break inner loop to check stall timeout + break + except TimeoutError: + # No futures completed in 5 seconds, check for stall + elapsed_since_progress = time.time() - last_progress_time + if elapsed_since_progress > stall_timeout: + log.error(f" ✗ DEADLOCK DETECTED: No progress for {elapsed_since_progress:.0f}s") + log.error(f" Cancelling {len(pending_futures)} remaining tasks to prevent infinite hang") + # Cancel all pending futures + for future in pending_futures: + future.cancel() + key, section_name = future_to_key[future] + log.error(f" ✗ CANCELLED: {section_name}") + failed += 1 + break + + if failed > 0: + log.warning(f" {failed} tasks failed or cancelled out of {len(tasks)} total") + else: + # Sequential execution + for df_source, cols, info, outdir, name, key, stat_test, bmk_filter, max_bmks in tasks: + # Reconstruct DataFrame from pickled format + if isinstance(df_source, tuple) and df_source[0] == 'pickled': + import pickle + df = pickle.loads(df_source[1]) + else: + df = pd.DataFrame(df_source) + results = analyze_section(df, cols, info, outdir, name, stat_test=stat_test, bmk_filter_cols=bmk_filter, max_bmks=max_bmks) + + # Create slim results dict (same as parallel mode for consistency) + slim_results = { + "differential_table": os.path.join(outdir, "differential", "differential_results.csv") + } + + mtype, section_key = key + all_results[vtype][mtype][section_key] = slim_results + + # =============================================================== + # GLOBAL RANKING + # =============================================================== + log.info("\n--- GLOBAL RANKING ---") + global_ranking(all_results, os.path.join(outdir, "global_ranking")) + + # Save manifest + manifest = { + "value_types": value_types, + "outdir": outdir, + "n_aggregates": len(agg_df) if agg_df is not None else 0, + "n_features": len(feat_df) if feat_df is not None else 0, + "sample_info": sample_info, + } + with open(os.path.join(outdir, "manifest.json"), "w") as f: + 
json.dump(manifest, f, indent=2, default=str) + + log.info(f"\nAnalysis complete. Results saved to {outdir}/") + + +if __name__ == "__main__": + main() diff --git a/bin/barometer_report.py b/bin/barometer_report.py new file mode 100755 index 0000000..ecca6bc --- /dev/null +++ b/bin/barometer_report.py @@ -0,0 +1,1133 @@ +#!/usr/bin/env python3 +""" +barometer_report.py – Generate an interactive HTML report from barometer_analyze.py results. + +Assembles all tables, figures, and analyses into a multi-tab HTML report using +Jinja2 templating with interactive DataTables, Plotly for image zoom, and +collapsible sections for navigation. + +Usage: + python barometer_report.py [--results DIR] [--output FILE] +""" + +import argparse +import base64 +import csv +import glob +import json +import logging +import os +import re +import sys +from collections import OrderedDict +from pathlib import Path + +import pandas as pd +from jinja2 import Template + +logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") +log = logging.getLogger(__name__) + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +EMBED_IMAGES = False +RESULTS_DIR = "llmeter_results" + + +def img_src(path): + """Return an image src: base64 data URI if embedding, else relative path.""" + if not os.path.exists(path): + return "" + if EMBED_IMAGES: + with open(path, "rb") as f: + data = f.read() + ext = path.rsplit(".", 1)[-1].lower() + mime = {"png": "image/png", "jpg": "image/jpeg", "jpeg": "image/jpeg", + "svg": "image/svg+xml"}.get(ext, "image/png") + return f"data:{mime};base64,{base64.b64encode(data).decode()}" + else: + # Return relative path from report file location + return path + + +def csv_to_html_table(path, max_rows=500, table_id=None): + """Read a CSV and return an HTML table string.""" + if not os.path.exists(path): + return "" + try: + df = pd.read_csv(path) + 
except Exception: + return "" + if len(df) > max_rows: + df = df.head(max_rows) + tid = f' id="{table_id}"' if table_id else "" + classes = "display compact stripe hover" + html = df.to_html(index=False, classes=classes, border=0, na_rep="NA", + float_format=lambda x: f"{x:.4g}" if abs(x) > 1e-4 else f"{x:.2e}") + html = html.replace("
13. Visualization – Overview Heatmap
' + content += f'Heatmap\n' + + for subdir in subdirs: + sub_path = os.path.join(section_dir, subdir) + if not os.path.isdir(sub_path): + continue + + label = analysis_labels.get(subdir, subdir.title()) + content += f'
{label}
\n' + + # Tables + for csv_path in find_csvs(sub_path): + fname = os.path.basename(csv_path).replace(".csv", "") + table_counter[0] += 1 + tid = f"dt_{table_counter[0]}" + content += f'
{fname.replace("_", " ").title()}
\n' + content += f'
{csv_to_html_table(csv_path, table_id=tid)}
\n' + content += f'\n' + content += '
\n' + + # Images + for img_path in find_images(sub_path): + fname = os.path.basename(img_path).replace(".png", "").replace("_", " ").title() + b64 = img_src(img_path) + if b64: + content += f'
{fname}
' + content += f'{fname}
\n' + + content += '
\n' + + return content + + +def build_report_data(results_dir): + """Walk the results directory and build the report structure.""" + manifest_path = os.path.join(results_dir, "manifest.json") + manifest = {} + if os.path.exists(manifest_path): + with open(manifest_path) as f: + manifest = json.load(f) + + value_types = manifest.get("value_types", []) + if not value_types: + # Discover from directory + for d in sorted(os.listdir(results_dir)): + dp = os.path.join(results_dir, d) + if os.path.isdir(dp) and d not in ("global_ranking",): + value_types.append(d) + + report = OrderedDict() + table_counter = [0] + + for vtype in value_types: + vtype_dir = os.path.join(results_dir, vtype) + if not os.path.isdir(vtype_dir): + continue + + report[vtype] = OrderedDict() + + # Aggregate and Feature tabs + for mtype in ["aggregate", "feature"]: + mtype_dir = os.path.join(vtype_dir, mtype) + if not os.path.isdir(mtype_dir): + continue + + report[vtype][mtype] = OrderedDict() + # Walk all section directories + for root, dirs, files in os.walk(mtype_dir): + # Only process leaf directories that contain actual results + if any(f.endswith(".csv") or f.endswith(".png") for f in files): + rel = os.path.relpath(root, mtype_dir) + parts = rel.split(os.sep) + first = parts[0] + group_name = prettify_group_name(first) + + # Detect subgroup / sub-subgroup + ssg_key = None + ssg_name = None + if first == "global" and len(parts) >= 2: + sec_dir = parts[1] + remainder = os.sep.join(parts[1:]) + if "by_ptype_" not in sec_dir: + # all_sites / by_ctype_* → "All" subgroup + sg_key = "all" + sg_name = "All" + section_name = prettify_section_name(remainder) + else: + # by_ptype__ → subgroup per ptype + ptype_raw, ctype_raw = _extract_ptype(sec_dir) + if ptype_raw: + sg_key = f"ptype_{ptype_raw}" + sg_name = _prettify_ptype(ptype_raw) + section_name = _prettify_mode(ctype_raw) if ctype_raw else prettify_section_name(remainder) + else: + sg_key = "" + sg_name = "" + section_name = 
prettify_section_name(remainder) + elif first == "sequence" and len(parts) >= 3: + sg_raw = parts[1] # e.g. "21" or "all_sequence_bmks" + sg_name = "All Sequences" if sg_raw == "all_sequence_bmks" else f"Chr {sg_raw}" + sg_key = sg_raw + sec_dir = parts[2] + remainder = os.sep.join(parts[2:]) + if "by_ptype_" not in sec_dir: + # all_sites / by_ctype_* → "All" sub-subgroup + ssg_key = "all" + ssg_name = "All" + section_name = prettify_section_name(remainder) + else: + # by_ptype__ → sub-subgroup per ptype + ptype_raw, ctype_raw = _extract_ptype(sec_dir) + if ptype_raw: + ssg_key = f"ptype_{ptype_raw}" + ssg_name = _prettify_ptype(ptype_raw) + section_name = _prettify_mode(ctype_raw) if ctype_raw else prettify_section_name(remainder) + else: + ssg_key = None + ssg_name = None + section_name = prettify_section_name(remainder) + elif first == "feature" and len(parts) >= 3 and parts[1] == "all_feature_together": + # all_feature_together/{safe_name} — all chromosomes pooled + sec_dir = parts[2] + section_dir = os.path.join(mtype_dir, first, parts[1], sec_dir) + meta = _read_section_meta(section_dir) + if meta: + ptype_raw = meta["Ptype"] + ctype_raw = meta["Ctype"] + mode = meta["Mode"] + sg_key = ptype_raw + sg_name = _prettify_ptype(ptype_raw) + ssg_key = ctype_raw + ssg_name = ctype_raw.replace("_", " ").replace("-", " ").title() + section_name = _prettify_mode(mode) + else: + sg_key = "all_feature_together" + sg_name = "All Features Together" + ssg_key = None + ssg_name = None + section_name = prettify_section_name(parts[2]) + elif first == "feature" and len(parts) >= 4 and parts[1] == "by_sequence": + # by_sequence/{chrom}/{safe_name} — features grouped by chromosome + chrom = parts[2] + sec_dir = parts[3] + section_dir = os.path.join(mtype_dir, first, parts[1], chrom, sec_dir) + meta = _read_section_meta(section_dir) + sg_key = f"seq_{chrom}" + sg_name = f"Chr {chrom}" + if meta: + ptype_raw = meta["Ptype"] + ctype_raw = meta["Ctype"] + mode = meta["Mode"] + 
ssg_key = ctype_raw + ssg_name = ctype_raw.replace("_", " ").replace("-", " ").title() + section_name = f"{_prettify_ptype(ptype_raw)} – {_prettify_mode(mode)}" + else: + ssg_key = None + ssg_name = None + section_name = prettify_section_name(parts[3]) + elif first == "feature" and len(parts) >= 2: + # Flat fallback (old structure) + sec_dir = parts[1] + section_dir = os.path.join(mtype_dir, first, sec_dir) + meta = _read_section_meta(section_dir) + if meta: + ptype_raw = meta["Ptype"] + ctype_raw = meta["Ctype"] + mode = meta["Mode"] + sg_key = ptype_raw + sg_name = _prettify_ptype(ptype_raw) + ssg_key = ctype_raw + ssg_name = ctype_raw.replace("_", " ").replace("-", " ").title() + section_name = _prettify_mode(mode) + else: + sg_key = "" + sg_name = "" + section_name = prettify_section_name(os.sep.join(parts[1:])) + else: + sg_key = "" + sg_name = "" + remainder = os.sep.join(parts[1:]) if len(parts) > 1 else "" + section_name = prettify_section_name(remainder) if remainder else prettify_section_name(first) + + # Ensure group exists + if group_name not in report[vtype][mtype]: + report[vtype][mtype][group_name] = OrderedDict() + + # Ensure subgroup bucket exists + # Structure: report[vtype][mtype][group_name][sg_key] = { + # "name": sg_name, "sections": OrderedDict(), "subgroups": OrderedDict()} + grp = report[vtype][mtype][group_name] + if sg_key not in grp: + grp[sg_key] = {"name": sg_name, "sections": OrderedDict(), "subgroups": OrderedDict()} + + content = build_section_content(root, section_name, table_counter) + if content.strip(): + if (first in ("sequence", "feature")) and ssg_key: + # Store in a sub-subgroup (All / Gene / …) inside the chr subgroup + ssg = grp[sg_key]["subgroups"] + if ssg_key not in ssg: + ssg[ssg_key] = {"name": ssg_name, "sections": OrderedDict()} + ssg[ssg_key]["sections"][rel] = {"name": section_name, "content": content} + else: + grp[sg_key]["sections"][rel] = {"name": section_name, "content": content} + + # Global ranking + gr_dir 
= os.path.join(results_dir, "global_ranking") + global_content = "" + if os.path.isdir(gr_dir): + for csv_path in find_csvs(gr_dir): + fname = os.path.basename(csv_path).replace(".csv", "") + table_counter[0] += 1 + tid = f"dt_{table_counter[0]}" + global_content += f'
{fname.replace("_", " ").title()}
\n' + global_content += f'
{csv_to_html_table(csv_path, table_id=tid)}
\n' + global_content += f'\n' + global_content += '
\n' + + for img_path in find_images(gr_dir): + fname = os.path.basename(img_path).replace(".png", "").replace("_", " ").title() + b64 = img_src(img_path) + if b64: + global_content += f'
{fname}
' + global_content += f'{fname}
\n' + + return report, manifest, global_content + + +# --------------------------------------------------------------------------- +# HTML Template +# --------------------------------------------------------------------------- + +HTML_TEMPLATE = """ + + + + +barometer Biomarker Analysis Report + + + + + + + + + + + + +
barometer Biomarker Analysis Report
+ + + +
+ + +
+

Report Summary

+

Aggregates: {{ manifest.get('n_aggregates', 'N/A') }} rows  |  + Features: {{ manifest.get('n_features', 'N/A') }} rows  |  + Value types: {{ manifest.get('value_types', [])|join(', ') }}

+

Samples:

+
    + {% for s in manifest.get('sample_info', [])[:20] %} +
  • {{ s.col }} (group: {{ s.group }}, sample: {{ s.sample }}, rep: {{ s.rep }})
  • + {% endfor %} +
+

Analyses performed: QC, Descriptive Statistics, Differential Editing (Mann-Whitney U, Kruskal-Wallis, ANOVA), + Multiple Testing Correction (FDR-BH), PCA, Hierarchical Clustering, Correlation Analysis, Random Forest Feature + Importance, Classification (RF, GBT, LDA), Stability/Robustness (CV, replicate concordance), Batch Effect Detection.

+
+ + +
+

🏆 Global Biomarker Ranking

+ {{ global_content|safe }} +
+ + + + +
+ {% for vtype, mtype_data in report.items() %} +
+ +
+ + + + +
+ {% for mtype, groups in mtype_data.items() %} +
+ +
+ + {% for group_name, subgroups in groups.items() %} +
+

{{ group_name }}

+ +
+ {% endfor %} + +
+ {% endfor %} +
+ +
+ {% endfor %} +
+ +
+ + +
+ Zoomed +
+ + + + + +""" + + +# --------------------------------------------------------------------------- +# Main +# --------------------------------------------------------------------------- + +def main(): + parser = argparse.ArgumentParser(description="Generate barometer Biomarker HTML Report") + parser.add_argument("-r", "--results", default="barometer_results", help="Results directory from barometer_analyze.py") + parser.add_argument("-o", "--output", default="barometer_report.html", help="Output HTML report file") + parser.add_argument("-e", "--embed-images", action="store_true", default=False, + help="Embed images as base64 (larger file but self-contained)") + args = parser.parse_args() + # Pass embed flag as module-level so helpers can use it + global EMBED_IMAGES, RESULTS_DIR + EMBED_IMAGES = args.embed_images + RESULTS_DIR = args.results + + if not os.path.isdir(args.results): + log.error(f"Results directory not found: {args.results}") + sys.exit(1) + + log.info(f"Building report from {args.results}...") + report, manifest, global_content = build_report_data(args.results) + + log.info(f"Rendering HTML...") + template = Template(HTML_TEMPLATE) + html = template.render( + report=report, + manifest=manifest, + global_content=global_content, + sanitize=sanitize_id, + ) + + with open(args.output, "w", encoding="utf-8") as f: + f.write(html) + + log.info(f"Report written to {args.output}") + log.info(f"Open in browser to view the interactive report.") + + +if __name__ == "__main__": + main() From 1ba2ebc2e2dbd619120911c64aa54730691f1c58 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 24 Mar 2026 21:40:34 +0100 Subject: [PATCH 59/61] fix certiicate for download --- containers/docker/jacusa2/Dockerfile | 1 + 1 file changed, 1 insertion(+) diff --git a/containers/docker/jacusa2/Dockerfile b/containers/docker/jacusa2/Dockerfile index 5d431f9..79dc722 100644 --- a/containers/docker/jacusa2/Dockerfile +++ b/containers/docker/jacusa2/Dockerfile @@ -1,5 +1,6 @@ FROM 
openjdk:8u102 +RUN apt-get update && apt-get install -y ca-certificates # Download the tool RUN wget https://github.com/dieterich-lab/JACUSA2/releases/download/v2.0.4/JACUSA_v2.0.4.jar -O /usr/local/bin/JACUSA_v2.0.4.jar From ac979055a8f69bc120abd604520b4d20f7814f18 Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 24 Mar 2026 21:54:13 +0100 Subject: [PATCH 60/61] standardize and fix containers --- build_containers.sh | 2 +- .../sapin/constraints.txt => common/constraints_sapin.txt} | 0 containers/{ => common}/env_pluviometer.yml | 0 containers/{docker/reditools2 => common}/env_reditools2.yml | 0 containers/docker/jacusa2/Dockerfile | 3 +-- containers/docker/sapin/Dockerfile | 4 ++-- 6 files changed, 4 insertions(+), 5 deletions(-) rename containers/{docker/sapin/constraints.txt => common/constraints_sapin.txt} (100%) rename containers/{ => common}/env_pluviometer.yml (100%) rename containers/{docker/reditools2 => common}/env_reditools2.yml (100%) diff --git a/build_containers.sh b/build_containers.sh index ad289fd..bec143f 100755 --- a/build_containers.sh +++ b/build_containers.sh @@ -146,7 +146,7 @@ if [ "$build_docker" = true ]; then fi # Use containers/ as build context so shared files (e.g. 
env_*.yml) are accessible - docker build ${docker_arch_option} -f "${dir}/Dockerfile" -t ${imgname} containers/ + docker build ${docker_arch_option} -f "${dir}/Dockerfile" -t ${imgname} containers/common/ done if [[ ${github_action_mode} == 'github_action' ]]; then diff --git a/containers/docker/sapin/constraints.txt b/containers/common/constraints_sapin.txt similarity index 100% rename from containers/docker/sapin/constraints.txt rename to containers/common/constraints_sapin.txt diff --git a/containers/env_pluviometer.yml b/containers/common/env_pluviometer.yml similarity index 100% rename from containers/env_pluviometer.yml rename to containers/common/env_pluviometer.yml diff --git a/containers/docker/reditools2/env_reditools2.yml b/containers/common/env_reditools2.yml similarity index 100% rename from containers/docker/reditools2/env_reditools2.yml rename to containers/common/env_reditools2.yml diff --git a/containers/docker/jacusa2/Dockerfile b/containers/docker/jacusa2/Dockerfile index 79dc722..34196b1 100644 --- a/containers/docker/jacusa2/Dockerfile +++ b/containers/docker/jacusa2/Dockerfile @@ -1,6 +1,5 @@ -FROM openjdk:8u102 +FROM quay.io/biocontainers/openjdk:11.0.1--2 -RUN apt-get update && apt-get install -y ca-certificates # Download the tool RUN wget https://github.com/dieterich-lab/JACUSA2/releases/download/v2.0.4/JACUSA_v2.0.4.jar -O /usr/local/bin/JACUSA_v2.0.4.jar diff --git a/containers/docker/sapin/Dockerfile b/containers/docker/sapin/Dockerfile index 4c68470..a283fa7 100644 --- a/containers/docker/sapin/Dockerfile +++ b/containers/docker/sapin/Dockerfile @@ -10,9 +10,9 @@ RUN apt install -y git RUN git clone --depth=1 https://github.com/Juke34/SAPiN.git # Use a pip constraint file to pin dependencies for reproducibility -COPY constraints.txt SAPiN/ +COPY constraints_sapin.txt SAPiN/ RUN cd SAPiN && \ - pip install --constraint constraints.txt . + pip install --constraint constraints_sapin.txt . 
CMD [ "bash" ] From 7263e8a2d6847b89a92f1dd3f784a7e07d54b31d Mon Sep 17 00:00:00 2001 From: Jacques Dainat Date: Tue, 24 Mar 2026 21:56:35 +0100 Subject: [PATCH 61/61] standardize for singu --- .../singularity/reditools2/env_reditools2.yml | 15 -------------- .../singularity/reditools2/reditools2.def | 2 +- containers/singularity/sapin/constraints.txt | 20 ------------------- containers/singularity/sapin/sapin.def | 2 +- 4 files changed, 2 insertions(+), 37 deletions(-) delete mode 100644 containers/singularity/reditools2/env_reditools2.yml delete mode 100644 containers/singularity/sapin/constraints.txt diff --git a/containers/singularity/reditools2/env_reditools2.yml b/containers/singularity/reditools2/env_reditools2.yml deleted file mode 100644 index 5897beb..0000000 --- a/containers/singularity/reditools2/env_reditools2.yml +++ /dev/null @@ -1,15 +0,0 @@ -channels: - - bioconda - - conda-forge -dependencies: - - _openmp_mutex=4.5=2_gnu - - pip=20.1.1=pyh9f0ad1d_0 - - psutil=5.7.0=py27hdf8410d_1 - - pysam=0.20.0=py27h7835474_0 - - python=2.7.15=h5a48372_1011_cpython - - python_abi=2.7=1_cp27mu - - samtools=1.18=hd87286a_0 - - tabix=1.11=hdfd78af_0 - - sortedcontainers - - netifaces - diff --git a/containers/singularity/reditools2/reditools2.def b/containers/singularity/reditools2/reditools2.def index 88f20e9..1a792e0 100644 --- a/containers/singularity/reditools2/reditools2.def +++ b/containers/singularity/reditools2/reditools2.def @@ -2,7 +2,7 @@ Bootstrap: docker From: condaforge/miniforge3:24.7.1-2 %files - containers/singularity/reditools2/env_reditools2.yml ./env_reditools2.yml + containers/singularity/common/env_reditools2.yml ./env_reditools2.yml %post # Activer conda dans l'environnement shell diff --git a/containers/singularity/sapin/constraints.txt b/containers/singularity/sapin/constraints.txt deleted file mode 100644 index 3733d3d..0000000 --- a/containers/singularity/sapin/constraints.txt +++ /dev/null @@ -1,20 +0,0 @@ -argcomplete==3.6.2 
-argh==0.31.3 -contourpy==1.3.1 -cycler==0.12.1 -fonttools==4.57.0 -gffutils==0.13 -importlib_metadata==8.6.1 -kiwisolver==1.4.8 -matplotlib==3.10.1 -numpy==2.2.4 -packaging==24.2 -pillow==11.1.0 -pyfaidx==0.8.1.3 -pyparsing==3.2.3 -pysam==0.23.0 -python-dateutil==2.9.0.post0 -SAPiN @ file:///SAPiN -simplejson==3.20.1 -six==1.17.0 -zipp==3.21.0 diff --git a/containers/singularity/sapin/sapin.def b/containers/singularity/sapin/sapin.def index ab020a7..1f2fffb 100644 --- a/containers/singularity/sapin/sapin.def +++ b/containers/singularity/sapin/sapin.def @@ -2,7 +2,7 @@ Bootstrap: docker From: python:slim-bullseye %files - containers/singularity/sapin/constraints.txt /constraints.txt + containers/singularity/common/constraints_sapin.txt /constraints.txt %post apt-get update && apt-get install -y \