diff --git a/docs/cli-docs/deRIP_cmd_line.md b/docs/cli-docs/deRIP_cmd_line.md index ae77a12..f9ef5bf 100644 --- a/docs/cli-docs/deRIP_cmd_line.md +++ b/docs/cli-docs/deRIP_cmd_line.md @@ -6,7 +6,7 @@ For aligned sequences in 'mintest.fa': - Any column with >= 70% gap positions will not be corrected and a gap inserted in corrected sequence. - Bases in column must be >= 80% C/T or G/A -- At least 50% bases in a column must be in RIP dinucleotide context (C/T as CpA / TpA) for correction. +- At least 50% bases in a column must be in RIP dinucleotide context (C/T as CpA or TpA) for correction. - Default: Inherit all remaining uncorrected positions from the least RIP'd sequence. - Mask all substrate and product motifs from corrected columns as ambiguous bases (i.e. CpA to TpA --> YpA) @@ -14,6 +14,7 @@ For aligned sequences in 'mintest.fa': ```bash derip2 -i tests/data/mintest.fa \ + --threads 1 \ --max-gaps 0.7 \ --max-snp-noise 0.2 \ --min-rip-like 0.5 \ @@ -34,6 +35,7 @@ The `--plot` option will create a visualization of the alignment with RIP markup ```bash derip2 -i tests/data/mintest.fa \ + --threads 1 \ --max-gaps 0.7 \ --max-snp-noise 0.2 \ --min-rip-like 0.5 \ @@ -57,6 +59,7 @@ By default uncorrected positions in the output sequence are filled from the sequ ```bash derip2 -i tests/data/mintest.fa \ + --threads 1 \ --max-gaps 0.7 \ --max-snp-noise 0.2 \ --min-rip-like 0.5 \ @@ -76,6 +79,7 @@ Non-RIP deamination events are also highlighted. ```bash derip2 -i tests/data/mintest.fa \ + --threads 1 \ --max-gaps 0.7 \ --reaminate \ -d results \ @@ -97,6 +101,8 @@ derip2 -i tests/data/mintest.fa \ ```code --version Show the version and exit. -i, --input TEXT Multiple sequence alignment. [required] + -t, --threads INTEGER Number of threads to use for processing. + Default: 1. [default: 1] -g, --max-gaps FLOAT Maximum proportion of gapped positions in column to be tolerated before forcing a gap in final deRIP sequence. 
[default: 0.7] diff --git a/docs/tutorial.ipynb b/docs/tutorial.ipynb index fda58a4..4d49827 100644 --- a/docs/tutorial.ipynb +++ b/docs/tutorial.ipynb @@ -600,8 +600,11 @@ "import os\n", "from derip2.derip import DeRIP\n", "\n", + "# Check number of CPU cores\n", + "num_cores = os.cpu_count()\n", + "\n", "# Load the alignment from gzipped file\n", - "sahana = DeRIP(gzipped_alignment_file)\n", + "sahana = DeRIP(gzipped_alignment_file, num_threads=num_cores)\n", "\n", "# Inspect alignment stats\n", "print(\n", diff --git a/src/derip2/aln_ops.py b/src/derip2/aln_ops.py index ebb93c2..d6ef7d1 100644 --- a/src/derip2/aln_ops.py +++ b/src/derip2/aln_ops.py @@ -22,6 +22,7 @@ from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord from Bio.SeqUtils import gc_fraction +import numpy as np # Add numpy import from tqdm import tqdm from derip2.utils.checks import isfile @@ -31,6 +32,178 @@ ) +def MSAToArray(alignment: MultipleSeqAlignment) -> np.ndarray: + """ + Convert a Biopython MultipleSeqAlignment to a NumPy array for efficient processing. + + This function converts the alignment to a 2D NumPy array where each row represents + a sequence and each column represents an alignment position. This allows for much + faster column-wise operations compared to the native Biopython alignment access. + + Parameters + ---------- + alignment : Bio.Align.MultipleSeqAlignment + The sequence alignment to convert. + + Returns + ------- + np.ndarray + 2D NumPy array of shape (n_sequences, alignment_length) containing + single character strings representing nucleotides and gaps. + """ + # Convert each sequence to a list of characters, then stack into 2D array + return np.array([list(str(record.seq)) for record in alignment], dtype='U1') + + +def find_numpy( + column: np.ndarray, targets: Union[str, List[str], Set[str]] +) -> np.ndarray: + """ + Find indices of elements in a NumPy array that match specified characters. 
+ + Parameters + ---------- + column : np.ndarray + 1D NumPy array of characters to search through. + targets : Union[str, List[str], Set[str]] + Character or collection of characters to find in the array. + + Returns + ------- + np.ndarray + Array of indices where matching characters were found. + """ + if isinstance(targets, str): + targets = [targets] + + # Use vectorized operations for much faster searching + mask = np.isin(column, targets) + return np.where(mask)[0] + + +def hasBoth_numpy(column: np.ndarray, a: str, b: str) -> bool: + """ + Check if a NumPy array contains at least one instance of each of two characters. + + Parameters + ---------- + column : np.ndarray + 1D NumPy array of characters to search through. + a : str + First character to find. + b : str + Second character to find. + + Returns + ------- + bool + True if both characters are present, False otherwise. + """ + return np.any(column == a) and np.any(column == b) + + +def nextBase_numpy( + arr: np.ndarray, colIdx: int, motif: str +) -> Tuple[np.ndarray, np.ndarray]: + """ + Find rows where a base is followed by a specific nucleotide using NumPy operations. + + This function identifies all rows in an alignment array where the column at index colIdx + contains the first base of a specified dinucleotide motif, and the next non-gap + position contains the second base of the motif. + + Parameters + ---------- + arr : np.ndarray + 2D NumPy array representing the alignment. + colIdx : int + Column index to check for the first base of the motif. + motif : str + Dinucleotide motif (e.g., 'CA' or 'TG'). + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + A tuple containing: + - Array of row indices where the specified pattern was found. + - Array of corresponding offsets (distance to the next non-gap position). 
+ """ + n_rows, n_cols = arr.shape + + # Find rows where current column matches first base of motif + first_base_rows = find_numpy(arr[:, colIdx], motif[0]) + + matching_rows = [] + offsets = [] + + # For each row that has the first base + for row_idx in first_base_rows: + # Look for the next non-gap position + for offset in range(1, n_cols - colIdx): + next_col = colIdx + offset + if next_col >= n_cols: + break + + next_base = arr[row_idx, next_col] + if next_base != '-': # First non-gap position + if next_base == motif[1]: + matching_rows.append(row_idx) + offsets.append(offset) + break + + return np.array(matching_rows, dtype=int), np.array(offsets, dtype=int) + + +def lastBase_numpy( + arr: np.ndarray, colIdx: int, motif: str +) -> Tuple[np.ndarray, np.ndarray]: + """ + Find rows where a base is preceded by a specific nucleotide using NumPy operations. + + This function identifies all rows in an alignment array where the column at index colIdx + contains the second base of a specified dinucleotide motif, and the previous non-gap + position contains the first base of the motif. + + Parameters + ---------- + arr : np.ndarray + 2D NumPy array representing the alignment. + colIdx : int + Column index to check for the second base of the motif. + motif : str + Dinucleotide motif (e.g., 'CA' or 'TG'). + + Returns + ------- + Tuple[np.ndarray, np.ndarray] + A tuple containing: + - Array of row indices where the specified pattern was found. + - Array of corresponding offsets (distance to the previous non-gap position). 
+ """ + # Find rows where current column matches second base of motif + second_base_rows = find_numpy(arr[:, colIdx], motif[1]) + + matching_rows = [] + offsets = [] + + # For each row that has the second base + for row_idx in second_base_rows: + # Look for the previous non-gap position + for offset in range(1, colIdx + 1): + prev_col = colIdx - offset + if prev_col < 0: + break + + prev_base = arr[row_idx, prev_col] + if prev_base != '-': # First non-gap position going backwards + if prev_base == motif[0]: + matching_rows.append(row_idx) + offsets.append(-offset) # Negative offset for backwards + break + + return np.array(matching_rows, dtype=int), np.array(offsets, dtype=int) + + def checkUniqueID(align: MultipleSeqAlignment) -> None: """ Validate that all sequence IDs in an alignment are unique. @@ -721,6 +894,8 @@ def correctRIP( min_rip_like: float = 0.1, reaminate: bool = True, mask: bool = False, + num_threads: Optional[int] = None, + min_columns_for_threading: int = 100, ) -> Tuple[ Dict[int, NamedTuple], Dict[int, NamedTuple], @@ -738,6 +913,8 @@ def correctRIP( 2. Updates the consensus sequence tracker with the ancestral (pre-RIP) base 3. Optionally masks the corrected positions in the output alignment + The analysis is optimized using NumPy arrays and can be parallelized for large alignments. + RIP signatures as observed in the + sense strand, with RIP targeting CpA motifs on either the +/- strand: @@ -764,6 +941,12 @@ def correctRIP( If True, also correct C→T or G→A transitions not in RIP context (default: True). mask : bool, optional If True, mask corrected positions in the alignment output (default: False). + num_threads : int, optional + Number of threads to use for parallel processing. If None, uses the number + of CPU cores available (default: None). + min_columns_for_threading : int, optional + Minimum number of alignment columns required to use parallel processing. + For smaller alignments, sequential processing is used (default: 100). 
Returns ------- @@ -778,7 +961,19 @@ def correctRIP( 'rip_substrate': Positions containing unmutated nucleotides in RIP context 'non_rip_deamination': Positions with C→T or G→A outside of RIP context """ + import concurrent.futures + from multiprocessing import cpu_count + import threading + import time + logging.debug('Correcting RIP-like mutations in the consensus sequence...') + + # Convert alignment to NumPy array for efficient processing + start_conversion = time.time() + arr = MSAToArray(align) + conversion_time = time.time() - start_conversion + logging.info(f'Converted alignment to NumPy array in {conversion_time:.3f} seconds') + # Create deep copies of input objects to avoid modifying the originals tracker = deepcopy(tracker) RIPcounts = deepcopy(RIPcounts) @@ -787,275 +982,477 @@ def correctRIP( # Store colIdx for each position that was corrected in the tracker corrected_positions = [] - # markupdict : Dict[str, List[RIPPosition]], optional - # Dictionary with RIP categories as keys and lists of position tuples as values. - # Categories are 'rip_product', 'rip_substrate', and 'non_rip_deamination'. - # Each position is a named tuple with (colIdx, rowIdx, offset). - # i.e. 
RIPPosition = NamedTuple('RIPPosition', [('colIdx', int), ('rowIdx', int), ('base',str),('offset', int)]) - # Initialize dictionary to store RIP categories for each position markupdict = {'rip_product': [], 'rip_substrate': [], 'non_rip_deamination': []} - # Process each column in the alignment with progress bar - for colIdx in tqdm( - range(align.get_alignment_length()), - desc='Scanning for RIP mutations', - unit='column', - ncols=80, - ): - # Track if we revert T→C or A→G in this column - modC = False - modG = False + # Threading locks for shared data structures + tracker_lock = threading.Lock() + ripcounts_lock = threading.Lock() + markup_lock = threading.Lock() + corrpos_lock = threading.Lock() + maskedalign_lock = threading.Lock() + + # Set number of threads to use + if num_threads is None: + num_threads = cpu_count() + + # If num_threads > available CPU cores, set to available cores + if num_threads > cpu_count(): + logging.warning( + f'Requested {num_threads} threads, but only {cpu_count()} available. Using {cpu_count()} threads.' + ) + num_threads = cpu_count() + + # Determine if we should use threading based on alignment size + alignment_length = align.get_alignment_length() + use_threading = alignment_length >= min_columns_for_threading and num_threads > 1 + + if use_threading: + logging.info( + f'Using {num_threads} threads to process {alignment_length} columns' + ) + else: + if alignment_length < min_columns_for_threading: + logging.info( + f'Alignment too small ({alignment_length} columns) for threading, using sequential processing' + ) + else: + logging.info('Using sequential processing with a single thread') + + # Function to analyze a single column using NumPy operations + def _analyze_column_numpy(colIdx): + """ + Analyze a single alignment column for RIP-like mutations using NumPy operations. 
+ + This function examines a column of the alignment for patterns consistent with + Repeat-Induced Point (RIP) mutations, using vectorized NumPy operations for + much better performance than the original Biopython-based approach. + + Parameters + ---------- + colIdx : int + Column index in the alignment (0-based) to analyze. + + Returns + ------- + dict + Dictionary containing analysis results with the same structure as the + original _analyze_column function. + """ + # Local results to collect before applying to shared structures + column_results = { + 'modC': False, + 'modG': False, + 'corrected_positions': [], + 'markupdict_updates': { + 'rip_product': [], + 'rip_substrate': [], + 'non_rip_deamination': [], + }, + 'ripcounts_updates': [], # List of (row_idx, addRev, addFwd, addNonRIP) tuples + 'masking_updates': [], # List of (colIdx, targetRows, newbase) tuples + } + + # Get the column as a NumPy array + column = arr[:, colIdx] # Count total number of nucleotide bases (excluding gaps) - baseCount = len(find(align[:, colIdx], ['A', 'T', 'G', 'C'])) + nucleotide_mask = np.isin(column, ['A', 'T', 'G', 'C']) + baseCount = np.sum(nucleotide_mask) # Skip columns with no bases - if baseCount: - # Identify rows containing C or T in this column - CTinCol = find(align[:, colIdx], ['C', 'T']) - # Identify rows containing G or A in this column - GAinCol = find(align[:, colIdx], ['G', 'A']) - - # Calculate proportion of C/T and G/A bases - CTprop = len(CTinCol) / baseCount - GAprop = len(GAinCol) / baseCount - - # FORWARD STRAND RIP DETECTION (C→T) - # Check if column has sufficient C/T content - if CTprop >= max_snp_noise: - # Find rows where C is followed by A (RIP substrate) - # Even if whole column is C, we can still have RIP substrate - CArows, _CA_nextbase_offsets = nextBase(align, colIdx, motif='CA') - # Record forward strand RIP substrate for CA rows - for rowCA, offset in zip(CArows, _CA_nextbase_offsets): - markupdict = updateMarkupDict( - 'rip_substrate', - 
markupdict, - colIdx, - base='C', - row_idx=rowCA, - offset=offset, + if baseCount == 0: + return column_results + + # Find rows containing C or T in this column using NumPy + CTinCol = find_numpy(column, ['C', 'T']) + # Find rows containing G or A in this column using NumPy + GAinCol = find_numpy(column, ['G', 'A']) + + # Calculate proportion of C/T and G/A bases + CTprop = len(CTinCol) / baseCount + GAprop = len(GAinCol) / baseCount + + # FORWARD STRAND RIP DETECTION (C→T) + # Check if column has sufficient C/T content + if CTprop >= max_snp_noise: + # Find rows where C is followed by A (RIP substrate) + CArows, _CA_nextbase_offsets = nextBase_numpy(arr, colIdx, motif='CA') + # Record forward strand RIP substrate for CA rows + for rowCA, offset in zip(CArows, _CA_nextbase_offsets): + column_results['markupdict_updates']['rip_substrate'].append( + RIPPosition( + colIdx=colIdx, rowIdx=int(rowCA), base='C', offset=int(offset) ) - # Check if C/T content is higher than G/A content and both C and T are present - if CTprop > GAprop and hasBoth(align[:, colIdx], 'C', 'T'): - # Find rows where C/T is followed by A (potential RIP context) - TArows, _TA_nextbase_offsets = nextBase( - align, colIdx, motif='TA' - ) # T followed by A (mutated) - CArows, _CA_nextbase_offsets = nextBase( - align, colIdx, motif='CA' - ) # C followed by A (ancestral) - - # Get rows with T in this column - TinCol = find(align[:, colIdx], ['T']) - - # If we have both CA and TA context (indicating RIP transition) - if CArows and TArows: - # Calculate proportion of C/T positions in a RIP-like context - propRIPlike = (len(TArows) + len(CArows)) / len(CTinCol) - - # Record forward strand RIP substrate for CA rows - for rowCA, offset in zip(CArows, _CA_nextbase_offsets): - markupdict = updateMarkupDict( - 'rip_substrate', - markupdict, - colIdx, + ) + + # Check if C/T content is higher than G/A content and both C and T are present + if CTprop > GAprop and hasBoth_numpy(column, 'C', 'T'): + # Find rows 
where C/T is followed by A (potential RIP context) + TArows, _TA_nextbase_offsets = nextBase_numpy( + arr, colIdx, motif='TA' + ) # T followed by A (mutated) + CArows, _CA_nextbase_offsets = nextBase_numpy( + arr, colIdx, motif='CA' + ) # C followed by A (ancestral) + + # Get rows with T in this column + TinCol = find_numpy(column, ['T']) + + # If we have both CA and TA context (indicating RIP transition) + if len(CArows) > 0 and len(TArows) > 0: + # Calculate proportion of C/T positions in a RIP-like context + propRIPlike = (len(TArows) + len(CArows)) / len(CTinCol) + + # Record forward strand RIP substrate for CA rows + for rowCA, offset in zip(CArows, _CA_nextbase_offsets): + column_results['markupdict_updates']['rip_substrate'].append( + RIPPosition( + colIdx=colIdx, + rowIdx=int(rowCA), base='C', - row_idx=rowCA, - offset=offset, + offset=int(offset), ) - - # Record forward strand RIP events for TA rows - for rowTA in set(TArows): - RIPcounts = updateRIPCount(rowTA, RIPcounts, addFwd=1) - - for rowTA, offset in zip(TArows, _TA_nextbase_offsets): - markupdict = updateMarkupDict( - 'rip_product', - markupdict, - colIdx, + ) + + # Record forward strand RIP events for TA rows + for rowTA in np.unique(TArows): + column_results['ripcounts_updates'].append( + (int(rowTA), 0, 1, 0) + ) # (rowIdx, addRev, addFwd, addNonRIP) + + for rowTA, offset in zip(TArows, _TA_nextbase_offsets): + column_results['markupdict_updates']['rip_product'].append( + RIPPosition( + colIdx=colIdx, + rowIdx=int(rowTA), base='T', - row_idx=rowTA, - offset=offset, + offset=int(offset), ) - - # Record non-RIP deamination for T's not in TA context - for TnonRIP in {x for x in TinCol if x not in TArows}: - RIPcounts = updateRIPCount(TnonRIP, RIPcounts, addNonRIP=1) - markupdict = updateMarkupDict( - 'non_rip_deamination', - markupdict, - colIdx, + ) + + # Record non-RIP deamination for T's not in TA context + TnonRIP = np.setdiff1d(TinCol, TArows) + for TnonRIP_row in TnonRIP: + 
column_results['ripcounts_updates'].append( + (int(TnonRIP_row), 0, 0, 1) + ) + column_results['markupdict_updates'][ + 'non_rip_deamination' + ].append( + RIPPosition( + colIdx=colIdx, + rowIdx=int(TnonRIP_row), base='T', - row_idx=TnonRIP, offset=0, ) - - # If sufficient mutations are in RIP context, correct to ancestral C - if propRIPlike >= min_rip_like: - tracker = updateTracker(colIdx, 'C', tracker, force=False) - modC = True - # Log corrected position in tracker - corrected_positions.append(colIdx) - # Otherwise correct if reaminate option is enabled - elif reaminate: - tracker = updateTracker(colIdx, 'C', tracker, force=False) - modC = True - # Log corrected position in tracker - corrected_positions.append(colIdx) - - # If C and T present but not in RIP context - else: - # If reaminate flag is on, correct to C anyway - if reaminate: - tracker = updateTracker(colIdx, 'C', tracker, force=False) - modC = True - # Log corrected position in tracker - corrected_positions.append(colIdx) - - # Log all T's as non-RIP deamination events - for TnonRIP in TinCol: - RIPcounts = updateRIPCount(TnonRIP, RIPcounts, addNonRIP=1) - markupdict = updateMarkupDict( - 'non_rip_deamination', - markupdict, - colIdx, + ) + + # If sufficient mutations are in RIP context, mark for correction to ancestral C + if propRIPlike >= min_rip_like: + column_results['modC'] = True + column_results['corrected_positions'].append(colIdx) + # Otherwise correct if reaminate option is enabled + elif reaminate: + column_results['modC'] = True + column_results['corrected_positions'].append(colIdx) + + # If C and T present but not in RIP context + else: + # If reaminate flag is on, mark for correction anyway + if reaminate: + column_results['modC'] = True + column_results['corrected_positions'].append(colIdx) + + # Log all T's as non-RIP deamination events + for TnonRIP_row in TinCol: + column_results['ripcounts_updates'].append( + (int(TnonRIP_row), 0, 0, 1) + ) + column_results['markupdict_updates'][ 
+ 'non_rip_deamination' + ].append( + RIPPosition( + colIdx=colIdx, + rowIdx=int(TnonRIP_row), base='T', - row_idx=TnonRIP, offset=0, ) - - # REVERSE STRAND RIP DETECTION (G→A) - # Check if column has sufficient G/A content and both G and A are present - if GAprop >= max_snp_noise: - # Find rows where G is followed by T (RIP substrate) - # Even if whole column is G, we can still have RIP substrate - TGrows, _TG_lastbase_offsets = lastBase(align, colIdx, motif='TG') - # Record forward strand RIP substrate for TG rows - for rowTG, offset in zip(TGrows, _TG_lastbase_offsets): - markupdict = updateMarkupDict( - 'rip_substrate', - markupdict, - colIdx, - base='G', - row_idx=rowTG, - offset=offset, + ) + + # REVERSE STRAND RIP DETECTION (G→A) + # Check if column has sufficient G/A content + if GAprop >= max_snp_noise: + # Find rows where G is preceded by T (RIP substrate) + TGrows, _TG_lastbase_offsets = lastBase_numpy(arr, colIdx, motif='TG') + # Record reverse strand RIP substrate for TG rows + for rowTG, offset in zip(TGrows, _TG_lastbase_offsets): + column_results['markupdict_updates']['rip_substrate'].append( + RIPPosition( + colIdx=colIdx, rowIdx=int(rowTG), base='G', offset=int(offset) ) - # Check if G/A content is higher than C/T content and both G and A are present - if GAprop > CTprop and hasBoth(align[:, colIdx], 'G', 'A'): - # Find rows where G/A is preceded by T (potential RIP context) - TGrows, _TG_lastbase_offsets = lastBase( - align, colIdx, motif='TG' - ) # T followed by G (ancestral) - TArows, _TA_lastbase_offsets = lastBase( - align, colIdx, motif='TA' - ) # T followed by A (mutated) - - # Get rows with A in this column - AinCol = find(align[:, colIdx], ['A']) - - # If we have both TG and TA context (indicating RIP transition) - if TGrows and TArows: - # Calculate proportion of G/A positions in a RIP-like context - propRIPlike = (len(TGrows) + len(TArows)) / len(GAinCol) - - # Record forward strand RIP substrate for TG rows - for rowTG, offset in 
zip(TGrows, _TG_lastbase_offsets): - markupdict = updateMarkupDict( - 'rip_substrate', - markupdict, - colIdx, + ) + + # Check if G/A content is higher than C/T content and both G and A are present + if GAprop > CTprop and hasBoth_numpy(column, 'G', 'A'): + # Find rows where G/A is preceded by T (potential RIP context) + TGrows, _TG_lastbase_offsets = lastBase_numpy( + arr, colIdx, motif='TG' + ) # T followed by G (ancestral) + TArows, _TA_lastbase_offsets = lastBase_numpy( + arr, colIdx, motif='TA' + ) # T followed by A (mutated) + + # Get rows with A in this column + AinCol = find_numpy(column, ['A']) + + # If we have both TG and TA context (indicating RIP transition) + if len(TGrows) > 0 and len(TArows) > 0: + # Calculate proportion of G/A positions in a RIP-like context + propRIPlike = (len(TGrows) + len(TArows)) / len(GAinCol) + + # Record reverse strand RIP substrate for TG rows + for rowTG, offset in zip(TGrows, _TG_lastbase_offsets): + column_results['markupdict_updates']['rip_substrate'].append( + RIPPosition( + colIdx=colIdx, + rowIdx=int(rowTG), base='G', - row_idx=rowTG, - offset=offset, + offset=int(offset), ) - - # Record reverse strand RIP events for TA rows - for rowTA in set(TArows): - RIPcounts = updateRIPCount(rowTA, RIPcounts, addRev=1) - - for rowTA, offset in zip(TArows, _TA_lastbase_offsets): - markupdict = updateMarkupDict( - 'rip_product', - markupdict, - colIdx, + ) + + # Record reverse strand RIP events for TA rows + for rowTA in np.unique(TArows): + column_results['ripcounts_updates'].append( + (int(rowTA), 1, 0, 0) + ) + + for rowTA, offset in zip(TArows, _TA_lastbase_offsets): + column_results['markupdict_updates']['rip_product'].append( + RIPPosition( + colIdx=colIdx, + rowIdx=int(rowTA), base='A', - row_idx=rowTA, - offset=offset, + offset=int(offset), ) - - # Record non-RIP deamination for A's not in TA context - for AnonRIP in {x for x in AinCol if x not in TArows}: - RIPcounts = updateRIPCount(AnonRIP, RIPcounts, addNonRIP=1) - 
markupdict = updateMarkupDict( - 'non_rip_deamination', - markupdict, - colIdx, + ) + + # Record non-RIP deamination for A's not in TA context + AnonRIP = np.setdiff1d(AinCol, TArows) + for AnonRIP_row in AnonRIP: + column_results['ripcounts_updates'].append( + (int(AnonRIP_row), 0, 0, 1) + ) + column_results['markupdict_updates'][ + 'non_rip_deamination' + ].append( + RIPPosition( + colIdx=colIdx, + rowIdx=int(AnonRIP_row), base='A', - row_idx=AnonRIP, offset=0, ) - - # If sufficient mutations are in RIP context, correct to ancestral G - if propRIPlike >= min_rip_like: - tracker = updateTracker(colIdx, 'G', tracker, force=False) - modG = True - # Log corrected position in tracker - corrected_positions.append(colIdx) - # Otherwise correct if reaminate option is enabled - elif reaminate: - tracker = updateTracker(colIdx, 'G', tracker, force=False) - modG = True - # Log corrected position in tracker - corrected_positions.append(colIdx) - - # If G and A present but not in RIP context - else: - # If reaminate flag is on, correct to G anyway - if reaminate: - tracker = updateTracker(colIdx, 'G', tracker, force=False) - modG = True - # Log corrected position in tracker - corrected_positions.append(colIdx) - - # Log all A's as non-RIP deamination events - for AnonRIP in AinCol: - RIPcounts = updateRIPCount(AnonRIP, RIPcounts, addNonRIP=1) - markupdict = updateMarkupDict( - 'non_rip_deamination', - markupdict, - colIdx, + ) + + # If sufficient mutations are in RIP context, mark for correction to ancestral G + if propRIPlike >= min_rip_like: + column_results['modG'] = True + column_results['corrected_positions'].append(colIdx) + # Otherwise correct if reaminate option is enabled + elif reaminate: + column_results['modG'] = True + column_results['corrected_positions'].append(colIdx) + + # If G and A present but not in RIP context + else: + # If reaminate flag is on, mark for correction anyway + if reaminate: + column_results['modG'] = True + 
column_results['corrected_positions'].append(colIdx) + + # Log all A's as non-RIP deamination events + for AnonRIP_row in AinCol: + column_results['ripcounts_updates'].append( + (int(AnonRIP_row), 0, 0, 1) + ) + column_results['markupdict_updates'][ + 'non_rip_deamination' + ].append( + RIPPosition( + colIdx=colIdx, + rowIdx=int(AnonRIP_row), base='A', - row_idx=AnonRIP, offset=0, ) + ) + + # Mark masking operations for C→T corrections if needed + if column_results['modC']: + if reaminate: + # If reaminating all C→T transitions, mask all T positions in column + targetRows = find_numpy(column, ['T']) + else: + # Otherwise only mask 'T' positions in TpA context where C→T occurred + targetRows = TArows + + if len(targetRows) > 0: # Only add if there are rows to update + column_results['masking_updates'].append( + (colIdx, targetRows.tolist(), 'Y') + ) + + # Mark masking operations for G→A corrections if needed + if column_results['modG']: + if reaminate: + # If reaminating all G→A transitions, mask all A positions + targetRows = find_numpy(column, ['A']) + else: + # Otherwise only mask 'A' positions in TpA context where G→A occurred + targetRows = TArows + + if len(targetRows) > 0: # Only add if there are rows to update + column_results['masking_updates'].append( + (colIdx, targetRows.tolist(), 'R') + ) + + return column_results + + # Apply column results to shared data structures with appropriate locking + def _apply_column_results(results): + """ + Apply column analysis results to shared data structures with thread safety. + + This function updates the various shared data structures based on the analysis + results for a single column. It handles the following updates: + 1. Updates consensus tracker for detected C→T or G→A mutations + 2. Adds positions to the corrected_positions list + 3. Updates RIP counts for affected sequences + 4. Updates the markup dictionary for visualization + 5. 
Applies masking to the alignment if requested + + Parameters + ---------- + results : dict + Dictionary containing analysis results for a column with the following keys: + - 'modC' : bool + Whether this column should be corrected to 'C' in the consensus. + - 'modG' : bool + Whether this column should be corrected to 'G' in the consensus. + - 'corrected_positions' : list + Column indices that should be corrected in the consensus sequence. + - 'markupdict_updates' : dict + Dictionary with keys 'rip_product', 'rip_substrate', and + 'non_rip_deamination', each containing a list of RIPPosition + objects to be added to the markup dictionary. + - 'ripcounts_updates' : list + List of tuples (row_idx, addRev, addFwd, addNonRIP) representing + increments to various RIP counters for specific sequences. + - 'masking_updates' : list + List of tuples (colIdx, targetRows, newbase) for applying + masking operations to the alignment if requested. + + Notes + ----- + This function accesses several variables from the outer scope using the + nonlocal declaration: + - tracker : Dict[int, NamedTuple] + Dictionary tracking the consensus sequence state for each column. + - RIPcounts : Dict[int, NamedTuple] + Dictionary tracking RIP mutation counts for each sequence. + - maskedAlign : Bio.Align.MultipleSeqAlignment + Alignment object that may be modified with masking operations. + - corrected_positions : List[int] + List of column indices that were corrected in the consensus. + - markupdict : Dict[str, List[RIPPosition]] + Dictionary mapping RIP categories to positions for visualization. + + All updates to shared data structures are protected by appropriate locks + to ensure thread safety during parallel processing. 
+ """ + # Declare variables from outer scope + nonlocal tracker, RIPcounts, maskedAlign, corrected_positions, markupdict + + # Update tracker if needed + if results['modC'] or results['modG']: + with tracker_lock: + for col_idx in results['corrected_positions']: + if results['modC']: + tracker = updateTracker(col_idx, 'C', tracker, force=False) + if results['modG']: + tracker = updateTracker(col_idx, 'G', tracker, force=False) + + # Update corrected positions + if results['corrected_positions']: + with corrpos_lock: + corrected_positions.extend(results['corrected_positions']) + + # Update RIP counts + if results['ripcounts_updates']: + with ripcounts_lock: + for row_idx, addRev, addFwd, addNonRIP in results['ripcounts_updates']: + RIPcounts = updateRIPCount( + row_idx, RIPcounts, addRev, addFwd, addNonRIP + ) - # Apply masking for C→T corrections if requested - if modC: - if reaminate: - # If reaminating all C→T transitions, mask all T positions in column - targetRows = find(align[:, colIdx], ['T']) - else: - # Otherwise only mask 'T' positions in TpA context where C→T occurred - targetRows = TArows - # substrate_rows = CArows - - # Replace target positions with IUPAC ambiguity code Y (C or T) - maskedAlign = replaceBase(maskedAlign, colIdx, targetRows, 'Y') - - # Apply masking for G→A corrections if requested - if modG: - if reaminate: - # If reaminating all G→A transitions, mask all positions - targetRows = find(align[:, colIdx], ['A']) - else: - # Otherwise only mask 'A' positions in TpA context where G→A occurred - targetRows = TArows - # substrate_rows = TGrows + # Update markup dictionary + with markup_lock: + for category in results['markupdict_updates']: + for position in results['markupdict_updates'][category]: + if position not in markupdict[category]: + markupdict[category].append(position) + + # Apply masking updates + if mask and results['masking_updates']: + with maskedalign_lock: + for col_idx, target_rows, new_base in results['masking_updates']: + 
maskedAlign = replaceBase( + maskedAlign, col_idx, target_rows, new_base + ) + + # Start timing + start_time = time.time() + + # Choose between parallel and sequential processing + if use_threading: + # Process columns in parallel using ThreadPoolExecutor with progress bar + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + # Submit all column analysis jobs + future_to_col = { + executor.submit(_analyze_column_numpy, colIdx): colIdx + for colIdx in range(alignment_length) + } - # Replace target positions with IUPAC ambiguity code R (G or A) - maskedAlign = replaceBase(maskedAlign, colIdx, targetRows, 'R') + # Process results with a progress bar + for future in tqdm( + concurrent.futures.as_completed(future_to_col), + desc='Scanning for RIP mutations', + total=alignment_length, + unit='column', + ncols=80, + colour='green', # Use color to make the bar more visible + leave=True, # Keep the progress bar after completion + ): + col_result = future.result() + _apply_column_results(col_result) + else: + # Sequential processing with progress bar + for colIdx in tqdm( + range(alignment_length), + desc='Scanning for RIP mutations', + unit='column', + ncols=80, + colour='green', # Use color to make the bar more visible + leave=True, # Keep the progress bar after completion + ): + col_result = _analyze_column_numpy(colIdx) + _apply_column_results(col_result) + + # Calculate elapsed time + elapsed_time = time.time() - start_time + logging.info( + f'RIP detection completed in {elapsed_time:.2f} seconds ({alignment_length} columns)' + ) + logging.info(f'Detected {len(corrected_positions)} columns with RIP-like mutations') return (tracker, RIPcounts, maskedAlign, corrected_positions, markupdict) diff --git a/src/derip2/app.py b/src/derip2/app.py index 5eca170..0e255dd 100644 --- a/src/derip2/app.py +++ b/src/derip2/app.py @@ -19,6 +19,7 @@ import logging from os import path import sys +import time import click @@ -38,6 +39,15 @@ @click.option( 
'-i', '--input', required=True, type=str, help='Multiple sequence alignment.' ) +# Threads option +@click.option( + '-t', + '--threads', + type=int, + default=1, + show_default=True, + help='Number of threads to use for processing. Default: 1.', +) # Algorithm parameters @click.option( '-g', @@ -139,6 +149,7 @@ @click.option('--logfile', default=None, help='Log file path.') def main( input, + threads, max_gaps, reaminate, max_snp_noise, @@ -170,6 +181,8 @@ def main( ---------- input : str Path to multiple sequence alignment file. + threads : int + Number of threads to use for processing. Default: 1. max_gaps : float Maximum proportion of gapped positions in column to be tolerated before forcing a gap in final deRIP sequence. Default: 0.7. @@ -217,6 +230,9 @@ def main( None Does not return any values, but writes output files and logs to the console. """ + # Record start time for execution timing + start_time = time.time() + # ---------- Setup ---------- # Print full command line call print(f'Command line call: {colored.green(" ".join(sys.argv))}\n') @@ -247,6 +263,7 @@ def main( fill_index=fill_index, fill_max_gc=fill_max_gc, max_gaps=max_gaps, + num_threads=threads, ) # Report alignment summary @@ -266,7 +283,16 @@ def main( logging.info(f'RIP summary by row:\n\033[0m{derip_obj.rip_summary()}\n') # Print colourized alignment + consensus - logging.info(f'Corrected alignment:\n\033[0m{derip_obj}\n') + # If the alignment is too large, skip printing + if derip_obj.alignment.get_alignment_length() > 200: + logging.info('Alignment is too large to print. 
Skipping alignment printout.\n') + else: + logging.info( + f'Alignment with RIP mutations highlighted:\n\033[0m{derip_obj.colored_alignment}\n' + ) + logging.info( + f'Consensus sequence with RIP mutations highlighted:\n\033[0m{derip_obj.colored_consensus}\n' + ) # ---------- Output Results ---------- # Report deRIP'd sequence to stdout @@ -325,6 +351,10 @@ def main( else: logging.warning('Failed to create RIP visualization') + # Calculate and log total execution time + elapsed_time = time.time() - start_time + logging.info(f'Total execution time: {elapsed_time:.2f} seconds') + if __name__ == '__main__': main() diff --git a/src/derip2/derip.py b/src/derip2/derip.py index be6fe4e..6051564 100644 --- a/src/derip2/derip.py +++ b/src/derip2/derip.py @@ -42,6 +42,8 @@ class DeRIP: max_gaps : float, optional Maximum proportion of gaps in a column before considering it a gap in consensus (default: 0.7). + num_threads : int, optional + Number of threads to use for parallel processing (default: 1). Attributes ---------- @@ -76,6 +78,7 @@ def __init__( fill_index: Optional[int] = None, fill_max_gc: bool = False, max_gaps: float = 0.7, + num_threads: int = 1, ) -> None: """ Initialize DeRIP with an alignment file or MultipleSeqAlignment object and parameters. @@ -101,6 +104,8 @@ def __init__( max_gaps : float, optional Maximum proportion of gaps in a column before considering it a gap in consensus (default: 0.7). + num_threads : int, optional + Number of threads to use for parallel processing (default: 1). 
""" # Store parameters self.max_snp_noise = max_snp_noise @@ -109,6 +114,7 @@ def __init__( self.fill_index = fill_index self.fill_max_gc = fill_max_gc self.max_gaps = max_gaps + self.num_threads = num_threads # Initialize attributes self.alignment = None @@ -257,6 +263,7 @@ def calculate_rip(self, label: str = 'deRIPseq') -> None: min_rip_like=self.min_rip_like, reaminate=self.reaminate, mask=True, # Always mask so we have the masked alignment available + num_threads=self.num_threads, ) ) @@ -744,6 +751,7 @@ def plot_alignment( show_rip=show_rip, highlight_corrected=highlight_corrected, flag_corrected=flag_corrected, + num_threads=self.num_threads, **kwargs, # Pass any additional customization options ) diff --git a/src/derip2/plotting/minialign.py b/src/derip2/plotting/minialign.py index 934e709..9845f76 100644 --- a/src/derip2/plotting/minialign.py +++ b/src/derip2/plotting/minialign.py @@ -178,14 +178,13 @@ def MSAToArray( # Define valid nucleotide characters for DNA sequences valid_chars: Set[str] = {'A', 'G', 'C', 'T', 'N', '-'} - # Extract sequences from the alignment object + # Extract sequences from the alignment object using vectorized operations for record in alignment: nams.append(record.id) - # Convert sequence to uppercase and replace invalid characters with gaps - seq = [ - base if base.upper() in valid_chars else '-' - for base in str(record.seq).upper() - ] + # Convert sequence to uppercase string once + seq_str = str(record.seq).upper() + # Use list comprehension for faster character processing + seq = [base if base in valid_chars else '-' for base in seq_str] seqs.append(seq) # Check if we have enough sequences for an alignment @@ -193,28 +192,25 @@ def MSAToArray( if seq_len <= 1: return None, None, None - # Verify all sequences have the same length (proper alignment) - # This should always be true for a Biopython MSA object, but check anyway - seq_lengths = {len(seq) for seq in seqs} - if len(seq_lengths) > 1: + # Convert list of sequences to 
numpy array in one operation + # This is much faster than building the array incrementally + arr = np.array(seqs, dtype='U1') # Unicode string of length 1 + + # Verify all sequences have the same length (should be true for MSA) + if arr.shape[1] == 0 or len({len(seq) for seq in seqs}) > 1: raise ValueError( 'ERROR: The sequences in the alignment have different lengths. This should not happen with a MultipleSeqAlignment.' ) - # Convert list of sequences to numpy array - arr = np.array(seqs) return arr, nams, seq_len -def arrNumeric( +def arrNumeric_optimized( arr: np.ndarray, palette: str = 'colorblind' ) -> Tuple[np.ndarray, matplotlib.colors.ListedColormap]: """ Convert sequence array into a numerical matrix with a color map for visualization. - - This function transforms the sequence data into a format that matplotlib - can interpret as an image. The sequence array is flipped vertically so the - output image has rows in the same order as the input alignment. + Optimized version using vectorized NumPy operations. 
Parameters ---------- @@ -237,34 +233,402 @@ def arrNumeric( # Select the appropriate color pattern or default to colorblind color_pattern = get_color_palette(palette) - # Get dimensions of the alignment - ali_height, ali_width = np.shape(arr) + # Find unique nucleotides in the array using NumPy operations + unique_nucleotides = np.unique(arr) # Create mapping from nucleotides to numeric values - keys = list(color_pattern.keys()) nD = {} # Dictionary mapping nucleotides to integers colours = [] # List of colors for the colormap - # Build the mapping and color list for the specific nucleotides in the alignment - i = 0 - for key in keys: - if key in arr: - nD[key] = i - colours.append(color_pattern[key]) - i += 1 - - # Create the numeric representation of the alignment - arr2 = np.empty([ali_height, ali_width]) - for x in range(ali_width): - for y in range(ali_height): - # Convert each nucleotide to its corresponding integer - arr2[y, x] = nD[arr[y, x]] + # Build the mapping and color list for nucleotides present in the alignment + for i, nucleotide in enumerate(unique_nucleotides): + if nucleotide in color_pattern: + nD[nucleotide] = i + colours.append(color_pattern[nucleotide]) + + # Create the numeric representation using vectorized operations + arr2 = np.zeros_like(arr, dtype=int) + for nucleotide, value in nD.items(): + arr2[arr == nucleotide] = value # Create the colormap for visualization cmap = matplotlib.colors.ListedColormap(colours) return arr2, cmap +def getHighlightedPositions_optimized( + markupdict: Dict[str, List[RIPPosition]], + ali_height: int, + arr: np.ndarray = None, + reaminate: bool = False, + num_threads: int = None, + min_items_for_threading: int = 1000, +) -> Set[Tuple[int, int]]: + """ + Optimized version using NumPy operations to get highlighted positions. + + Parameters + ---------- + markupdict : Dict[str, List[RIPPosition]] + Dictionary with categories as keys and lists of position tuples as values. 
+ ali_height : int + Height of the alignment (number of rows). + arr : np.ndarray, optional + The original alignment array, used to check for gap positions. + reaminate : bool, optional + Whether to include non-RIP deamination positions. + num_threads : int, optional + Number of threads to use for parallel processing. + min_items_for_threading : int, optional + Minimum number of positions required to use parallel processing. + + Returns + ------- + Set[Tuple[int, int]] + Set of (col_idx, flipped_y) tuples for all highlighted positions. + """ + import time + + start_time = time.time() + + # Collect all positions to process + all_positions = [] + for category, positions in markupdict.items(): + if category != 'non_rip_deamination' or reaminate: + all_positions.extend(positions) + + if not all_positions: + return set() + + # Convert to structured arrays for vectorized processing + positions_array = np.array( + [ + (pos.colIdx, pos.rowIdx, pos.offset if pos.offset is not None else 0) + for pos in all_positions + ], + dtype=[('col', int), ('row', int), ('offset', int)], + ) + + # Pre-compute flipped y coordinates + flipped_ys = ali_height - positions_array['row'] - 1 + + # Initialize result set + highlighted_positions = set() + + # Process positions with no offset first (most common case) + no_offset_mask = positions_array['offset'] == 0 + if np.any(no_offset_mask): + no_offset_cols = positions_array['col'][no_offset_mask] + no_offset_ys = flipped_ys[no_offset_mask] + + # Add all no-offset positions at once + highlighted_positions.update(zip(no_offset_cols, no_offset_ys)) + + # Process positions with offsets + offset_mask = positions_array['offset'] != 0 + if np.any(offset_mask): + offset_positions = positions_array[offset_mask] + offset_ys = flipped_ys[offset_mask] + + # Group by offset value for batch processing + unique_offsets = np.unique(offset_positions['offset']) + + for offset_val in unique_offsets: + offset_val_mask = offset_positions['offset'] == offset_val + 
cols = offset_positions['col'][offset_val_mask] + ys = offset_ys[offset_val_mask] + + if offset_val < 0: + # Negative offset: positions to the left + for col, y in zip(cols, ys): + valid_cols = np.arange(max(0, col + offset_val), col) + if arr is not None: + # Check for gaps using vectorized operations + row_data = arr[ali_height - y - 1, valid_cols] + valid_cols = valid_cols[row_data != '-'] + highlighted_positions.update(zip(valid_cols, [y] * len(valid_cols))) + else: + # Positive offset: positions to the right + for col, y in zip(cols, ys): + max_col = arr.shape[1] if arr is not None else col + offset_val + 1 + valid_cols = np.arange(col + 1, min(max_col, col + offset_val + 1)) + if arr is not None and len(valid_cols) > 0: + # Check for gaps using vectorized operations + row_data = arr[ali_height - y - 1, valid_cols] + valid_cols = valid_cols[row_data != '-'] + highlighted_positions.update(zip(valid_cols, [y] * len(valid_cols))) + + elapsed_time = time.time() - start_time + logging.info( + f'Optimized position highlighting completed in {elapsed_time:.2f} seconds' + ) + logging.info(f'Found {len(highlighted_positions)} positions to highlight') + + return highlighted_positions + + +def markupRIPBases_optimized( + a: plt.Axes, + markupdict: Dict[str, List[RIPPosition]], + ali_height: int, + arr: np.ndarray = None, + reaminate: bool = False, + palette: str = 'derip2', + draw_boxes: bool = True, + num_threads: int = None, + min_items_for_threading: int = 100, +) -> Tuple[Set[Tuple[int, int]], Set[Tuple[int, int]]]: + """ + Optimized version of markupRIPBases using NumPy operations for better performance. + + This function visualizes different categories of RIP mutations by adding colored + rectangles to the matplotlib axes. Target bases (primary mutation sites) are drawn + with full opacity and black borders, while offset bases (context around mutations) + are drawn with reduced opacity. 
Uses vectorized NumPy operations and optional + parallel processing for improved performance on large datasets. + + Parameters + ---------- + a : plt.Axes + The matplotlib axes object where the alignment is being plotted. + markupdict : Dict[str, List[RIPPosition]] + Dictionary containing RIP positions to highlight, with categories as keys: + - 'rip_product': Positions where RIP mutations have occurred (typically T from C→T) + - 'rip_substrate': Positions with unmutated nucleotides in RIP context + - 'non_rip_deamination': Positions with deamination events not in RIP context + + Each value is a list of RIPPosition named tuples with fields: + - colIdx: column index in alignment (int) + - rowIdx: row index in alignment (int) + - base: nucleotide base at this position (str) + - offset: context range around the mutation, negative=left, positive=right (int or None) + ali_height : int + Height of the alignment in rows (number of sequences). + arr : np.ndarray, optional + Original alignment array, needed to get base identities for offset positions. + Shape should be (ali_height, alignment_width). If None, base colors will be + determined from the RIPPosition base field (default: None). + reaminate : bool, optional + Whether to include non-RIP deamination highlights (default: False). + palette : str, optional + Color palette to use for base highlighting. Options include 'colorblind', + 'bright', 'tetrimmer', 'basegrey', or 'derip2' (default: 'derip2'). + draw_boxes : bool, optional + Whether to draw black borders around highlighted bases (default: True). + num_threads : int, optional + Number of threads to use for parallel processing. If None, uses the number + of CPU cores available. Parallel processing is only used if the number of + positions exceeds min_items_for_threading (default: None). + min_items_for_threading : int, optional + Minimum number of positions required to use parallel processing. 
+ For smaller datasets, sequential processing is used for better efficiency + (default: 100). + + Returns + ------- + highlighted_positions : Set[Tuple[int, int]] + Set of all (col_idx, y_coord) positions that received highlighting, + including both target bases and offset positions. Coordinates are in + matplotlib format where y-axis is flipped (0 at bottom, increasing upward). + target_positions : Set[Tuple[int, int]] + Set of only the primary mutation site (col_idx, y_coord) positions, + excluding offset positions. Used for text coloring in other functions. + Coordinates are in matplotlib format. + + Notes + ----- + - Target bases are drawn with full opacity and black borders (if draw_boxes=True) + - Offset bases (context) are drawn with 70% opacity to distinguish from targets + - Uses structured NumPy arrays for efficient batch processing of positions + - num_threads and min_items_for_threading are accepted but not referenced in the function body; positions are processed sequentially + - All drawing operations are queued and executed sequentially for thread safety + - Performance scales well with dataset size due to vectorized operations + + Examples + -------- + >>> import matplotlib.pyplot as plt + >>> fig, ax = plt.subplots() + >>> markup = {'rip_product': [RIPPosition(10, 0, 'T', 0)]} + >>> highlighted, targets = markupRIPBases_optimized(ax, markup, 5) + >>> len(highlighted) >= len(targets) + True + """ + import time + + start_time = time.time() + + # Collect all positions and convert to NumPy arrays for vectorized processing + all_positions = [] + all_categories = [] + + for category, positions in markupdict.items(): + if category != 'non_rip_deamination' or reaminate: + all_positions.extend(positions) + all_categories.extend([category] * len(positions)) + + if not all_positions: + return set(), set() + + # Convert to structured NumPy array for efficient processing + positions_data = np.array( + [ + (i, pos.colIdx, pos.rowIdx, pos.offset if 
pos.offset is not None else 0) + for i, pos in enumerate(all_positions) + ], + 
dtype=[('idx', int), ('col', int), ('row', int), ('offset', int)], + ) + + # Pre-compute flipped y coordinates + flipped_ys = ali_height - positions_data['row'] - 1 + + # Initialize result sets + highlighted_positions = set() + target_positions = set() + + # Get nucleotide colors + nuc_colors = get_color_palette(palette) + + # Styling parameters + border_thickness = 2.5 + inset = 0.05 + + # Store drawing instructions + drawing_queue = [] + + # Process positions in batches based on offset values + unique_offsets = np.unique(positions_data['offset']) + + total_positions = len(all_positions) + + with tqdm( + total=total_positions, desc='Processing RIP positions', unit='pos' + ) as pbar: + for offset_val in unique_offsets: + offset_mask = positions_data['offset'] == offset_val + batch_positions = positions_data[offset_mask] + batch_ys = flipped_ys[offset_mask] + + # Process this batch of positions + for _i, (pos_data, y) in enumerate(zip(batch_positions, batch_ys)): + pos_idx, col_idx, row_idx, offset = pos_data + original_pos = all_positions[pos_idx] + category = all_categories[pos_idx] + base = original_pos.base + + # Add target position + highlighted_positions.add((col_idx, y)) + target_positions.add((col_idx, y)) + + if offset == 0: + # Single position + if base in nuc_colors: + color = nuc_colors[base] + drawing_queue.append( + ( + 'rect', + { + 'xy': (col_idx - 0.5, y - 0.5), + 'width': 1.0, + 'height': 1.0, + 'facecolor': color, + 'edgecolor': 'none', + 'linewidth': 0, + 'zorder': 50, + }, + ) + ) + + if draw_boxes: + drawing_queue.append( + ( + 'rect', + { + 'xy': (col_idx - 0.5 + inset, y - 0.5 + inset), + 'width': 1.0 - 2 * inset, + 'height': 1.0 - 2 * inset, + 'facecolor': 'none', + 'edgecolor': 'black', + 'linewidth': border_thickness, + 'zorder': 150, + }, + ) + ) + else: + # Multiple positions with offset - use vectorized operations + if offset < 0: + valid_cols = np.arange(max(0, col_idx + offset), col_idx) + else: + max_col = ( + arr.shape[1] if arr is 
not None else col_idx + offset + 1 + ) + valid_cols = np.arange( + col_idx, min(max_col, col_idx + offset + 1) + ) + + # Filter out gaps using vectorized operations + if arr is not None and len(valid_cols) > 0: + row_data = arr[ali_height - y - 1, valid_cols] + valid_cols = valid_cols[row_data != '-'] + + # Add all valid positions + for col in valid_cols: + highlighted_positions.add((col, y)) + + # Get base color + cell_base = ( + arr[ali_height - y - 1, col] if arr is not None else base + ) + cell_color = nuc_colors.get(cell_base, '#CCCCCC') + cell_alpha = 0.7 if col != col_idx else 1.0 + + drawing_queue.append( + ( + 'rect', + { + 'xy': (col - 0.5, y - 0.5), + 'width': 1.0, + 'height': 1.0, + 'facecolor': cell_color, + 'edgecolor': 'none', + 'linewidth': 0, + 'alpha': cell_alpha, + 'zorder': 50, + }, + ) + ) + + # Add border around the group + if len(valid_cols) > 0 and draw_boxes: + start_col, end_col = valid_cols[0], valid_cols[-1] + drawing_queue.append( + ( + 'rect', + { + 'xy': (start_col - 0.5 + inset, y - 0.5 + inset), + 'width': (end_col - start_col + 1) - 2 * inset, + 'height': 1.0 - 2 * inset, + 'facecolor': 'none', + 'edgecolor': 'black', + 'linewidth': border_thickness, + 'zorder': 150, + }, + ) + ) + + pbar.update(1) + + # Execute all drawing operations + for draw_type, params in drawing_queue: + if draw_type == 'rect': + a.add_patch(matplotlib.patches.Rectangle(**params)) + + elapsed_time = time.time() - start_time + logging.info(f'Optimized RIP highlighting completed in {elapsed_time:.2f} seconds') + + return highlighted_positions, target_positions + + +# Update the main function to use optimized versions def drawMiniAlignment( alignment: MultipleSeqAlignment, outfile: str, @@ -284,9 +648,11 @@ def drawMiniAlignment( corrected_positions: Optional[List[int]] = None, reaminate: bool = False, reference_seq_index: Optional[int] = None, - show_rip: str = 'both', # 'substrate', 'product', or 'both' + show_rip: str = 'both', highlight_corrected: bool = 
True, flag_corrected: bool = False, + num_threads: int = None, + min_items_for_threading: int = 500, ) -> Union[str, bool]: """ Generate a visualization of a DNA sequence alignment with optional RIP markup. @@ -340,6 +706,12 @@ def drawMiniAlignment( If True, only corrected positions in the consensus will be colored, all others will be gray (default: True). flag_corrected : bool, optional If True, corrected positions will be marked with a large asterisk above the consensus (default: False). + num_threads : int, optional + Number of threads to use for parallel processing. If None, uses the number + of CPU cores available (default: None). + min_items_for_threading : int, optional + Minimum number of cells/characters required to use parallel processing. + For smaller alignments, sequential processing is used (default: 500). Returns ------- @@ -362,15 +734,37 @@ def drawMiniAlignment( - RIP substrates are highlighted in blue - Non-RIP deamination events are highlighted in orange """ - # DEBUG: Print function parameters for troubleshooting + import concurrent.futures + from multiprocessing import cpu_count + import threading + import time + + start_time = time.time() + + # Set number of threads to use if not specified + if num_threads is None: + num_threads = cpu_count() + + # If num_threads > available CPU cores, set to available cores + if num_threads > cpu_count(): + logging.warning( + f'Requested {num_threads} threads, but only {cpu_count()} available. Using {cpu_count()} threads.' 
+ ) + num_threads = cpu_count() + + # Log function call with important parameters logging.debug( - f'drawMiniAlignment: outfile={outfile}, dpi={dpi}, title={title}, width={width}, height={height}, orig_nams={orig_nams}, keep_numbers={keep_numbers}, force_numbers={force_numbers}, palette={palette}, markupdict={markupdict}, column_ranges={column_ranges}, show_chars={show_chars}, consensus_seq={consensus_seq}, corrected_positions={corrected_positions}, reaminate={reaminate}, reference_seq_index={reference_seq_index}, show_rip={show_rip}, highlight_corrected={highlight_corrected}' + f'drawMiniAlignment: outfile={outfile}, dpi={dpi}, title={title}, width={width}, height={height}, ' + f'show_chars={show_chars}, consensus_seq={"provided" if consensus_seq else "None"}, ' + f'corrected_positions={"provided" if corrected_positions else "None"}, ' + f'num_threads={num_threads}, min_items_for_threading={min_items_for_threading}' ) + # Handle default value for orig_nams if orig_nams is None: orig_nams = [] - # Convert the MSA object to a numpy array + # Convert the MSA object to a numpy array (already optimized) arr, nams, seq_len = MSAToArray(alignment) # Return False if only one sequence was found @@ -386,6 +780,15 @@ def drawMiniAlignment( # Get alignment dimensions ali_height, ali_width = np.shape(arr) + # Calculate total cells for threading decisions + total_cells = ali_height * ali_width + use_threading_for_chars = ( + show_chars + and total_cells >= min_items_for_threading + and num_threads > 1 + and ali_width < 500 + ) + # Define plot styling parameters fontsize = 14 @@ -399,9 +802,6 @@ def drawMiniAlignment( else: tickint = 100 - # The rest of the function is identical to drawMiniAlignment, - # continuing with the same plotting logic - # Calculate line weights based on alignment dimensions lineweight_h = 10 / ali_height # Horizontal grid lines lineweight_v = 10 / ali_width # Vertical grid lines @@ -442,58 +842,76 @@ def drawMiniAlignment( a.set_xlim(-0.5, ali_width - 
0.5) a.set_ylim(-0.5, ali_height - 0.5) - # Convert alignment to numeric form and get color map - arr2, cm = arrNumeric(arr, palette='basegrey') + # Convert alignment to numeric form using optimized function + arr2, cm = arrNumeric_optimized(arr, palette='basegrey') - # Process markup if provided + # Process markup if provided using optimized functions if markupdict: + markup_start_time = time.time() + # Filter the markup dictionary based on show_rip parameter filtered_markup = {} - - # Always include non-RIP deamination if specified (controlled by reaminate parameter) if 'non_rip_deamination' in markupdict: filtered_markup['non_rip_deamination'] = markupdict['non_rip_deamination'] - - # Include RIP substrates if requested if show_rip in ['substrate', 'both'] and 'rip_substrate' in markupdict: filtered_markup['rip_substrate'] = markupdict['rip_substrate'] - - # Include RIP products if requested if show_rip in ['product', 'both'] and 'rip_product' in markupdict: filtered_markup['rip_product'] = markupdict['rip_product'] - # Get all positions that will be highlighted without drawing them - positions_to_highlight = getHighlightedPositions( - filtered_markup, ali_height, arr, reaminate + # Get highlighted positions using optimized function + positions_to_highlight = getHighlightedPositions_optimized( + filtered_markup, + ali_height, + arr, + reaminate, + num_threads=num_threads, + min_items_for_threading=min_items_for_threading, ) - # Create a mask where highlighted positions are True + # Create mask using vectorized operations mask = np.zeros_like(arr2, dtype=bool) - for x, y in positions_to_highlight: - if 0 <= x < ali_width and 0 <= y < ali_height: - mask[y, x] = True - - # Create masked array where highlighted positions are transparent - masked_arr2 = np.ma.array(arr2, mask=mask) + if positions_to_highlight: + # Convert positions to arrays for vectorized indexing + highlight_coords = list(positions_to_highlight) + if highlight_coords: + x_coords, y_coords = 
zip(*highlight_coords) + # Use advanced indexing for fast mask creation + valid_mask = ( + (np.array(x_coords) >= 0) + & (np.array(x_coords) < ali_width) + & (np.array(y_coords) >= 0) + & (np.array(y_coords) < ali_height) + ) + valid_x = np.array(x_coords)[valid_mask] + valid_y = np.array(y_coords)[valid_mask] + mask[valid_y, valid_x] = True # Draw the alignment with highlighted positions masked out + masked_arr2 = np.ma.array(arr2, mask=mask) a.imshow( masked_arr2, cmap=cm, aspect='auto', interpolation='nearest', zorder=10 ) - # Draw the colored highlights on top - highlighted_positions, target_positions = markupRIPBases( - a, filtered_markup, ali_height, arr, reaminate, palette, draw_boxes + # Draw the colored highlights using optimized function + highlighted_positions, target_positions = markupRIPBases_optimized( + a, + filtered_markup, + ali_height, + arr, + reaminate, + palette, + draw_boxes, + num_threads=num_threads, + min_items_for_threading=min_items_for_threading, ) + + markup_time = time.time() - markup_start_time + logging.info(f'Optimized RIP markup processing took {markup_time:.2f} seconds') else: # No markup, just draw the regular alignment a.imshow(arr2, cmap=cm, aspect='auto', interpolation='nearest', zorder=10) - _highlighted_positions = set() target_positions = set() - # Continue with the rest of the plotting code from drawMiniAlignment... - # (Including grid lines, reference marker, labels, text, etc.) 
- # Add grid lines a.hlines( np.arange(-0.5, ali_height), @@ -587,41 +1005,119 @@ def drawMiniAlignment( # Display sequence characters if requested and alignment isn't too large if show_chars and ali_width < 500: # Limit for performance reasons + chars_start_time = time.time() + # Increase font size for better visibility char_fontsize = min( 14, 18000 / (ali_width * ali_height) ) # Adjusted for larger font - # Don't show characters if they'll be too small + # Skip character rendering if they'll be too small if char_fontsize >= 4: - for y in range(ali_height): - for x in range(ali_width): - # Flip y-coordinate to match alignment orientation - flipped_y = ali_height - y - 1 - - # Get the character at this position - char = arr[y, x] + # Use multithreaded character rendering for large alignments + if use_threading_for_chars: + logging.info( + f'Using {num_threads} threads to render {total_cells} characters' + ) - # Determine text color based on whether position is a target position - text_color = ( - 'black' if (x, flipped_y) in target_positions else '#777777' - ) # Lighter grey for non-target bases (including offsets) + # Text commands to be executed sequentially (for matplotlib thread safety) + text_commands = [] + text_lock = threading.Lock() + + # Process a chunk of rows + def process_rows(row_range): + nonlocal arr, ali_width, ali_height, target_positions + local_commands = [] + for y in row_range: + for x in range(ali_width): + flipped_y = ali_height - y - 1 + char = arr[y, x] + text_color = ( + 'black' + if (x, flipped_y) in target_positions + else '#777777' + ) - # Add character as text annotation + # Store text parameters rather than rendering directly + local_commands.append((x, flipped_y, char, text_color)) + return local_commands + + # Create chunks of rows for parallel processing + chunk_size = max( + 1, ali_height // (num_threads * 2) + ) # Ensure at least 1 row per chunk + row_chunks = [ + range(i, min(i + chunk_size, ali_height)) + for i in range(0, 
ali_height, chunk_size) + ] + + # Process row chunks in parallel + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads + ) as executor: + # Submit all row chunk processing jobs + future_to_chunk = { + executor.submit(process_rows, chunk): i + for i, chunk in enumerate(row_chunks) + } + + # Process results as they complete + with tqdm( + total=len(row_chunks), desc='Rendering text', unit='chunk' + ) as pbar: + for future in concurrent.futures.as_completed(future_to_chunk): + try: + commands = future.result() + with text_lock: + text_commands.extend(commands) + pbar.update(1) + except Exception as e: + logging.error(f'Error processing text chunk: {e}') + + # Render all text commands sequentially (for matplotlib thread safety) + for x, y, char, color in text_commands: a.text( x, - flipped_y, + y, char, ha='center', va='center', fontsize=char_fontsize, - color=text_color, + color=color, fontweight='bold', - zorder=200, # Make sure characters are on top of everything + zorder=200, ) + else: + # Sequential processing for smaller alignments + for y in tqdm(range(ali_height), desc='Rendering text', unit='row'): + for x in range(ali_width): + flipped_y = ali_height - y - 1 + char = arr[y, x] + text_color = ( + 'black' if (x, flipped_y) in target_positions else '#777777' + ) + + a.text( + x, + flipped_y, + char, + ha='center', + va='center', + fontsize=char_fontsize, + color=text_color, + fontweight='bold', + zorder=200, # Make sure characters are on top of everything + ) + + chars_time = time.time() - chars_start_time + logging.info( + f'Character rendering took {chars_time:.2f} seconds for {total_cells} cells' + ) # If consensus sequence is provided, add it to the second subplot if consensus_seq is not None and consensus_ax is not None: + consensus_start_time = time.time() + # Determine colors for each nucleotide nuc_colors = get_color_palette(palette) @@ -655,78 +1151,207 @@ def drawMiniAlignment( zorder=100, ) - # Plot each base in the consensus as a 
colored cell with character - for i, base in enumerate(consensus_seq): - # Determine cell color based on whether this is a corrected position - if highlight_corrected and i not in corrected_set: - # Use gray for non-corrected positions when highlight_corrected is True - color = '#c7d1d0' # Standard gray color - else: - # Use the regular color palette for this base - color = nuc_colors.get( - base.upper(), '#CCCCCC' - ) # Default to gray for unknown bases - - # Create colored rectangle for this base - consensus_ax.add_patch( - matplotlib.patches.Rectangle( - (i - 0.5, -0.5), # bottom left corner - 1, - 1, # width, height - color=color, - zorder=10, - ) + # Determine if we should use threading for consensus visualization + use_threading_for_consensus = ( + len(consensus_seq) >= min_items_for_threading and num_threads > 1 + ) + + if use_threading_for_consensus: + logging.info( + f'Using {num_threads} threads to render consensus sequence ({len(consensus_seq)} bases)' ) - # Add the character as text with increased font size - if show_chars: - # Determine text color - use black for all characters for better readability - text_color = 'black' - - consensus_ax.text( - i, - 0, - base, - ha='center', - va='center', - fontsize=min( - 18, 30 - len(consensus_seq) / 100 - ), # Further increased font size - color=text_color, - fontweight='bold', - zorder=20, + # Queue to collect drawing operations + drawing_queue = [] + drawing_lock = threading.Lock() + + # Function to process a chunk of the consensus sequence + def process_consensus_chunk(base_range): + local_drawing_ops = [] + + for i in base_range: + base = consensus_seq[i] + + # Determine cell color based on whether this is a corrected position + if highlight_corrected and i not in corrected_set: + color = '#c7d1d0' # Standard gray for non-corrected + else: + color = nuc_colors.get( + base.upper(), '#CCCCCC' + ) # Default to gray for unknown + + # Add colored rectangle for this base + local_drawing_ops.append( + ( + 'rect', + 
i, + { + 'xy': (i - 0.5, -0.5), + 'width': 1, + 'height': 1, + 'color': color, + 'zorder': 10, + }, + ) + ) + + # Add character as text if requested + if show_chars: + text_color = 'black' # For better readability + fontsize = min(18, 30 - len(consensus_seq) / 100) + + local_drawing_ops.append( + ( + 'text', + i, + { + 'x': i, + 'y': 0, + 'text': base, + 'ha': 'center', + 'va': 'center', + 'fontsize': fontsize, + 'color': text_color, + 'fontweight': 'bold', + 'zorder': 20, + }, + ) + ) + + # Add marker for corrected positions if requested + if corrected_positions and flag_corrected and i in corrected_set: + asterisk_fontsize = min( + 24, max(14, 40 - len(consensus_seq) / 50) + ) + + local_drawing_ops.append( + ( + 'text', + i, + { + 'x': i, + 'y': 1.0, + 'text': '*', + 'ha': 'center', + 'va': 'center', + 'fontsize': asterisk_fontsize, + 'color': 'red', + 'fontweight': 'bold', + 'zorder': 30, + }, + ) + ) + + return local_drawing_ops + + # Create chunks of consensus sequence bases for parallel processing + chunk_size = max(1, len(consensus_seq) // (num_threads * 2)) + base_chunks = [ + range(i, min(i + chunk_size, len(consensus_seq))) + for i in range(0, len(consensus_seq), chunk_size) + ] + + # Process chunks in parallel + with concurrent.futures.ThreadPoolExecutor( + max_workers=num_threads + ) as executor: + future_to_chunk = { + executor.submit(process_consensus_chunk, chunk): i + for i, chunk in enumerate(base_chunks) + } + + # Process results as they complete + with tqdm( + total=len(base_chunks), desc='Rendering consensus', unit='chunk' + ) as pbar: + for future in concurrent.futures.as_completed(future_to_chunk): + try: + local_ops = future.result() + with drawing_lock: + drawing_queue.extend(local_ops) + pbar.update(1) + except Exception as e: + logging.error(f'Error processing consensus chunk: {e}') + + # Sort operations by index for consistent rendering + drawing_queue.sort(key=lambda x: x[1]) + + # Execute all drawing operations sequentially on the 
consensus axis + for op_type, _, params in drawing_queue: + if op_type == 'rect': + consensus_ax.add_patch(matplotlib.patches.Rectangle(**params)) + elif op_type == 'text': + consensus_ax.text(**params) + else: + # Sequential processing for consensus visualization + for i in tqdm( + range(len(consensus_seq)), desc='Rendering consensus', unit='base' + ): + base = consensus_seq[i] + + # Determine cell color + if highlight_corrected and i not in corrected_set: + color = '#c7d1d0' # Standard gray for non-corrected + else: + color = nuc_colors.get(base.upper(), '#CCCCCC') + + # Add colored rectangle + consensus_ax.add_patch( + matplotlib.patches.Rectangle( + (i - 0.5, -0.5), # bottom left corner + 1, + 1, # width, height + color=color, + zorder=10, + ) ) - # Add markers for corrected positions if provided - if corrected_positions and flag_corrected: - for pos in corrected_positions: - if 0 <= pos < len(consensus_seq): - # Calculate appropriate font size based on sequence length - # Scale inversely with sequence length to fit within cells - # Use a slightly larger size than the base characters to stand out - asterisk_fontsize = min(24, max(14, 40 - len(consensus_seq) / 50)) + # Add character text if requested + if show_chars: + consensus_ax.text( + i, + 0, + base, + ha='center', + va='center', + fontsize=min(18, 30 - len(consensus_seq) / 100), + color='black', + fontweight='bold', + zorder=20, + ) - # Draw a large asterisk centered in the space above each corrected position - # Size now scales with the cell dimensions + # Add asterisk for corrected positions if requested + if corrected_positions and flag_corrected and i in corrected_set: + asterisk_fontsize = min(24, max(14, 40 - len(consensus_seq) / 50)) consensus_ax.text( - pos, # x position - 1.0, # y position (centered in new space above sequence) - '*', # asterisk character - ha='center', # horizontally centered - va='center', # vertically centered - fontsize=asterisk_fontsize, # dynamically scaled font size - 
color='red', # red color - fontweight='bold', # bold for emphasis - zorder=30, # ensure it's on top + i, + 1.0, + '*', + ha='center', + va='center', + fontsize=asterisk_fontsize, + color='red', + fontweight='bold', + zorder=30, ) + consensus_time = time.time() - consensus_start_time + logging.info( + f'Consensus visualization took {consensus_time:.2f} seconds for {len(consensus_seq)} bases' + ) + # Save the plot as a PNG image + logging.info(f'Saving figure to {outfile}') f.savefig(outfile, format='png') # Clean up resources plt.close() del arr, arr2, nams + # Report total processing time + total_time = time.time() - start_time + logging.info(f'Total visualization time: {total_time:.2f} seconds') + return outfile @@ -738,6 +1363,8 @@ def markupRIPBases( reaminate: bool = False, palette: str = 'derip2', draw_boxes: bool = True, + num_threads: int = None, + min_items_for_threading: int = 100, ) -> Tuple[Set[Tuple[int, int]], Set[Tuple[int, int]]]: """ Highlight RIP-related bases in the alignment plot with color coding and borders. @@ -773,6 +1400,12 @@ def markupRIPBases( Color palette to use for base highlighting (default: 'derip2'). draw_boxes : bool, optional Whether to draw black borders around highlighted bases (default: True). + num_threads : int, optional + Number of threads to use for parallel processing. If None, uses the number + of CPU cores available (default: None). + min_items_for_threading : int, optional + Minimum number of positions required to use parallel processing. + For smaller datasets, sequential processing is used (default: 100). 
Returns ------- @@ -791,26 +1424,61 @@ def markupRIPBases( - Coordinates in returned sets are in matplotlib coordinates, where y-axis is flipped compared to the alignment array (0 at bottom, increasing upward) """ + import concurrent.futures + from multiprocessing import cpu_count + import threading + import time + + start_time = time.time() + # DEBUG: Print function parameters for troubleshooting logging.debug( - f'markupRIPBases: markupdict={markupdict}, ali_height={ali_height}, arr={arr}, reaminate={reaminate}, palette={palette}, draw_boxes={draw_boxes}' + f'markupRIPBases: markupdict={markupdict}, ali_height={ali_height}, arr={arr}, ' + f'reaminate={reaminate}, palette={palette}, draw_boxes={draw_boxes}, ' + f'num_threads={num_threads}, min_items_for_threading={min_items_for_threading}' ) + # Initialize result sets with thread locks highlighted_positions = set() - target_positions = set() # Track primary target positions separately + target_positions = set() + highlighted_lock = threading.Lock() + target_lock = threading.Lock() + + # Styling parameters border_thickness = 2.5 # Border thickness inset = 0.05 # Smaller inset for borders to reduce gap with grid lines # Define colors for nucleotide bases nuc_colors = get_color_palette(palette) - # Count total positions to process for progress bar + # Count total positions to process and determine if threading should be used total_positions = sum( len(positions) for category, positions in markupdict.items() if category != 'non_rip_deamination' or reaminate ) + # Set number of threads to use + if num_threads is None: + num_threads = cpu_count() + + # If num_threads > available CPU cores, set to available cores + if num_threads > cpu_count(): + logging.warning( + f'Requested {num_threads} threads, but only {cpu_count()} available. Using {cpu_count()} threads.' 
+ ) + num_threads = cpu_count() + + # Determine if we should use threading based on dataset size + use_threading = (total_positions >= min_items_for_threading) and (num_threads > 1) + + if use_threading: + logging.info( + f'Using {num_threads} threads to process {total_positions} RIP positions' + ) + else: + logging.info(f'Using sequential processing for {total_positions} RIP positions') + # Create one progress bar for all positions pbar = tqdm( total=total_positions, @@ -820,142 +1488,268 @@ def markupRIPBases( leave=False, ) - # Process all positions in the markup dictionary - for category, positions in markupdict.items(): - # Skip non-RIP deamination if reaminate is False - if category == 'non_rip_deamination' and not reaminate: - continue - - # Update progress bar description to show current category - pbar.set_description(f'Highlighting {category}') - - # Process each position with progress tracking - for pos in positions: - col_idx, row_idx, base, offset = pos - y = ali_height - row_idx - 1 - highlighted_positions.add((col_idx, y)) - target_positions.add( - (col_idx, y) - ) # Add only the target position to target set - - # Case 1: Single base (no offset or offset=0) - if offset is None or offset == 0: - if base in nuc_colors: - color = nuc_colors[base] - - # Draw the base with full-size colored rectangle - use EXACT cell dimensions - a.add_patch( - matplotlib.patches.Rectangle( - (col_idx - 0.5, y - 0.5), # Exact cell boundaries - 1.0, # Full width - 1.0, # Full height - facecolor=color, - edgecolor='none', - linewidth=0, - zorder=50, # Above base image (10), below grid lines (100) - ) + # Store drawing instructions to handle matplotlib's thread-safety issues + # We'll collect all drawing instructions and execute them sequentially + drawing_queue = [] + drawing_lock = threading.Lock() + + def process_position(category, pos): + """ + Process a single RIP position and generate drawing instructions. 
+ + This function handles the visualization of a single RIP-related position, + generating appropriate drawing instructions for both the target nucleotide + and any context nucleotides (specified by offset). It creates rectangles + with appropriate colors and, if requested, black borders around the highlighted + regions. + + Parameters + ---------- + category : str + The category of RIP mutation ('rip_product', 'rip_substrate', + or 'non_rip_deamination'). + pos : RIPPosition + A named tuple containing: + - colIdx : int + Column index in the alignment matrix. + - rowIdx : int + Row index in the alignment matrix. + - base : str + Nucleotide base at this position. + - offset : int or None + Context range around the mutation, negative=left, positive=right. + + Returns + ------- + tuple + A tuple containing three elements: + - local_highlighted : set + Set of (x, y) coordinate tuples for all highlighted positions. + - local_target : set + Set of (x, y) coordinate tuples for primary mutation sites. + - local_draw_instructions : list + List of drawing instruction tuples, each containing: + ('rect', {rectangle_parameters}) for matplotlib rendering. + + Notes + ----- + This function accesses several variables from the outer scope: + - ali_height : int + Height of the alignment in rows. + - arr : np.ndarray + Original alignment array for accessing base information. + - nuc_colors : dict + Dictionary mapping nucleotides to color codes. + - draw_boxes : bool + Whether to draw borders around highlighted positions. + - inset : float + Inset amount for drawing borders. + - border_thickness : float + Thickness of the border lines. 
+ """ + col_idx, row_idx, base, offset = pos + y = ali_height - row_idx - 1 + + # Local results to collect + local_highlighted = set() + local_target = set() + local_draw_instructions = [] + + # Add target position to highlighted set + local_highlighted.add((col_idx, y)) + local_target.add((col_idx, y)) + + # Case 1: Single base (no offset or offset=0) + if offset is None or offset == 0: + if base in nuc_colors: + color = nuc_colors[base] + + # Queue drawing instruction for the base rectangle + local_draw_instructions.append( + ( + 'rect', + { + 'xy': (col_idx - 0.5, y - 0.5), + 'width': 1.0, + 'height': 1.0, + 'facecolor': color, + 'edgecolor': 'none', + 'linewidth': 0, + 'zorder': 50, + }, ) + ) - # Draw black border with smaller inset and thinner line for cleaner appearance - if draw_boxes: - a.add_patch( - matplotlib.patches.Rectangle( - ( - col_idx - 0.5 + inset, - y - 0.5 + inset, - ), # Inset from cell edge - 1.0 - 2 * inset, # Width with minimal inset - 1.0 - 2 * inset, # Height with minimal inset - facecolor='none', - edgecolor='black', - linewidth=border_thickness, - zorder=150, # Above grid lines (100) - ) + # Queue drawing instruction for the border if needed + if draw_boxes: + local_draw_instructions.append( + ( + 'rect', + { + 'xy': (col_idx - 0.5 + inset, y - 0.5 + inset), + 'width': 1.0 - 2 * inset, + 'height': 1.0 - 2 * inset, + 'facecolor': 'none', + 'edgecolor': 'black', + 'linewidth': border_thickness, + 'zorder': 150, + }, ) - - # Case 2: Multiple positions (with offset) - elif offset != 0: - # Process range and get valid cells as before - if offset < 0: # Positions to the left - start_idx = max(0, col_idx + offset) - end_idx = col_idx - else: # Positions to the right - start_idx = col_idx - end_idx = ( - min(arr.shape[1] - 1, col_idx + offset) - if arr is not None - else col_idx + offset ) + # Case 2: Multiple positions (with offset) + elif offset != 0: + # Process range and get valid cell indices + if offset < 0: # Positions to the left + 
start_idx = max(0, col_idx + offset) + end_idx = col_idx + else: # Positions to the right + start_idx = col_idx + end_idx = ( + min(arr.shape[1] - 1, col_idx + offset) + if arr is not None + else col_idx + offset + ) - # Skip gaps and out-of-bounds positions - valid_indices = [] - for i in range(start_idx, end_idx + 1): - if i < 0 or ( - arr is not None - and (i >= arr.shape[1] or arr[ali_height - y - 1, i] == '-') - ): - continue - valid_indices.append(i) - - if not valid_indices: - pbar.update(1) # Update progress bar even if skipping + # Collect valid indices (excluding gaps and out-of-bounds) + valid_indices = [] + for i in range(start_idx, end_idx + 1): + if i < 0 or ( + arr is not None + and (i >= arr.shape[1] or arr[ali_height - y - 1, i] == '-') + ): continue - - # Fill cells with appropriate colors (full-size) - use EXACT cell dimensions - for i in valid_indices: - # Add to highlighted positions - highlighted_positions.add((i, y)) - # Note: We don't add offset positions to target_positions - - # Get color for this base - cell_base = arr[ali_height - y - 1, i] if arr is not None else base - cell_color = nuc_colors.get(cell_base, '#CCCCCC') - - # Determine alpha value based on whether this is the target cell or an offset cell - # For cells with offset > 0, make offset cells semi-transparent - cell_alpha = 1.0 - if ( - offset > 0 or offset < 0 - ) and i != col_idx: # This is an offset cell - cell_alpha = 0.7 - - # Draw full-size colored cell with appropriate transparency - a.add_patch( - matplotlib.patches.Rectangle( - (i - 0.5, y - 0.5), # Exact cell boundaries - 1.0, # Full width - 1.0, # Full height - facecolor=cell_color, - edgecolor='none', - linewidth=0, - alpha=cell_alpha, # Apply transparency to offset cells - zorder=50, # Above base image, below grid - ) + valid_indices.append(i) + + if not valid_indices: + return local_highlighted, local_target, local_draw_instructions + + # Queue drawing instructions for each valid cell + for i in valid_indices: + 
# Add position to highlighted set + local_highlighted.add((i, y)) + + # Get color for this base + cell_base = arr[ali_height - y - 1, i] if arr is not None else base + cell_color = nuc_colors.get(cell_base, '#CCCCCC') + + # Determine transparency based on whether it's target or context + cell_alpha = 1.0 + if (offset > 0 or offset < 0) and i != col_idx: + cell_alpha = 0.7 + + # Queue drawing instruction + local_draw_instructions.append( + ( + 'rect', + { + 'xy': (i - 0.5, y - 0.5), + 'width': 1.0, + 'height': 1.0, + 'facecolor': cell_color, + 'edgecolor': 'none', + 'linewidth': 0, + 'alpha': cell_alpha, + 'zorder': 50, + }, ) + ) - # Draw border with smaller inset - if valid_indices and draw_boxes: - start_i = min(valid_indices) - end_i = max(valid_indices) - - a.add_patch( - matplotlib.patches.Rectangle( - (start_i - 0.5 + inset, y - 0.5 + inset), - (end_i - start_i + 1) - 2 * inset, - 1.0 - 2 * inset, - facecolor='none', - edgecolor='black', - linewidth=border_thickness, - zorder=150, # Above grid lines - ) + # Queue drawing instruction for the border around the whole group + if valid_indices and draw_boxes: + start_i = min(valid_indices) + end_i = max(valid_indices) + + local_draw_instructions.append( + ( + 'rect', + { + 'xy': (start_i - 0.5 + inset, y - 0.5 + inset), + 'width': (end_i - start_i + 1) - 2 * inset, + 'height': 1.0 - 2 * inset, + 'facecolor': 'none', + 'edgecolor': 'black', + 'linewidth': border_thickness, + 'zorder': 150, + }, ) + ) - # Update progress bar - pbar.update(1) + return local_highlighted, local_target, local_draw_instructions + + # Process positions either in parallel or sequentially + if use_threading: + # Prepare all tasks for processing + all_tasks = [] + for category, positions in markupdict.items(): + # Skip non-RIP deamination if reaminate is False + if category == 'non_rip_deamination' and not reaminate: + continue + for pos in positions: + all_tasks.append((category, pos)) + + # Process in parallel using ThreadPoolExecutor + 
with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + # Submit all position processing jobs + future_results = { + executor.submit(process_position, category, pos): (category, pos) + for category, pos in all_tasks + } + + # Process results as they complete + for future in concurrent.futures.as_completed(future_results): + try: + local_highlighted, local_target, local_draw = future.result() + + # Synchronize results with locks + with highlighted_lock: + highlighted_positions.update(local_highlighted) + with target_lock: + target_positions.update(local_target) + with drawing_lock: + drawing_queue.extend(local_draw) + + # Update progress bar + pbar.update(1) + + except Exception as e: + logging.error(f'Error processing position: {e}') + else: + # Sequential processing + for category, positions in markupdict.items(): + # Skip non-RIP deamination if reaminate is False + if category == 'non_rip_deamination' and not reaminate: + continue + + # Update progress bar description to show current category + pbar.set_description(f'Highlighting {category}') + + # Process each position sequentially + for pos in positions: + local_highlighted, local_target, local_draw = process_position( + category, pos + ) + highlighted_positions.update(local_highlighted) + target_positions.update(local_target) + drawing_queue.extend(local_draw) + pbar.update(1) # Close the progress bar pbar.close() + # Execute all drawing operations sequentially (for matplotlib thread safety) + for draw_type, params in drawing_queue: + if draw_type == 'rect': + a.add_patch(matplotlib.patches.Rectangle(**params)) + + # Report processing time + elapsed_time = time.time() - start_time + logging.info( + f'RIP highlighting completed in {elapsed_time:.2f} seconds ({total_positions} positions)' + ) + logging.info( + f'Highlighted {len(highlighted_positions)} total positions, {len(target_positions)} target positions' + ) + return highlighted_positions, target_positions @@ -1014,6 +1808,8 @@ 
def getHighlightedPositions( ali_height: int, arr: np.ndarray = None, reaminate: bool = False, + num_threads: int = None, + min_items_for_threading: int = 1000, ) -> Set[Tuple[int, int]]: """ Get all positions that should be highlighted based on the markup dictionary. @@ -1028,43 +1824,168 @@ def getHighlightedPositions( The original alignment array, used to check for gap positions. reaminate : bool, optional Whether to include non-RIP deamination positions. + num_threads : int, optional + Number of threads to use for parallel processing. If None, uses the number + of CPU cores available (default: None). + min_items_for_threading : int, optional + Minimum number of positions required to use parallel processing. + For smaller datasets, sequential processing is used (default: 1000). Returns ------- Set[Tuple[int, int]] Set of (col_idx, flipped_y) tuples for all highlighted positions. """ + import concurrent.futures + from multiprocessing import cpu_count + import threading + import time + + start_time = time.time() + + # Shared set for collecting results with thread safety highlighted_positions = set() + positions_lock = threading.Lock() - for category, positions in markupdict.items(): - # Skip non-RIP deamination if reaminate is False - if category == 'non_rip_deamination' and not reaminate: - continue - - for pos in positions: - col_idx, row_idx, base, offset = pos - - # Convert row index to matplotlib coordinates (flipped) - y = ali_height - row_idx - 1 - - # Add target position to highlighted set - highlighted_positions.add((col_idx, y)) - - # Handle offset positions - if offset is not None: - if offset < 0: - # Negative offset means positions to the left - for i in range(col_idx + offset, col_idx): - if i >= 0 and ( - arr is None or arr[ali_height - y - 1, i] != '-' - ): - highlighted_positions.add((i, y)) - elif offset > 0: - # Positive offset means positions to the right - for i in range(col_idx + 1, col_idx + offset + 1): - if i <= arr.shape[1] - 1 and ( - arr
is None or arr[ali_height - y - 1, i] != '-' - ): - highlighted_positions.add((i, y)) + # Set number of threads to use if not specified + if num_threads is None: + num_threads = cpu_count() + + # If num_threads > available CPU cores, set to available cores + if num_threads > cpu_count(): + logging.warning( + f'Requested {num_threads} threads, but only {cpu_count()} available. Using {cpu_count()} threads.' + ) + num_threads = cpu_count() + + # Count total positions to determine if threading should be used + total_positions = sum( + len(positions) + for category, positions in markupdict.items() + if category != 'non_rip_deamination' or reaminate + ) + + # Determine if threading should be used + use_threading = total_positions >= min_items_for_threading and num_threads > 1 + + # Function to process a single markup position + def process_position(pos): + """ + Process a single position and collect its coordinates for highlighting. + + This function takes a RIP position and determines all coordinates that should + be highlighted in the visualization. It handles both the target position itself + and any context positions specified by the offset parameter, taking into account + alignment boundaries and gaps. + + Parameters + ---------- + pos : RIPPosition + A named tuple containing: + - colIdx : int + Column index in the alignment matrix. + - rowIdx : int + Row index in the alignment matrix. + - base : str + Nucleotide base at this position. + - offset : int or None + Context range around the mutation, negative=left, positive=right. + + Returns + ------- + set + Set of (col_idx, y) coordinate tuples for all positions that should be + highlighted, where y is the flipped row index in matplotlib coordinates. + + Notes + ----- + This function accesses several variables from the outer scope: + - ali_height : int + Height of the alignment in rows. + - arr : np.ndarray + Original alignment array for checking gap positions. 
+ """ + col_idx, row_idx, base, offset = pos + local_positions = set() + + # Convert row index to matplotlib coordinates (flipped) + y = ali_height - row_idx - 1 + + # Add target position to highlighted set + local_positions.add((col_idx, y)) + + # Handle offset positions + if offset is not None: + if offset < 0: + # Negative offset means positions to the left + for i in range(col_idx + offset, col_idx): + if i >= 0 and (arr is None or arr[ali_height - y - 1, i] != '-'): + local_positions.add((i, y)) + elif offset > 0: + # Positive offset means positions to the right + for i in range(col_idx + 1, col_idx + offset + 1): + if arr is None or ( + i < arr.shape[1] and arr[ali_height - y - 1, i] != '-' + ): + local_positions.add((i, y)) + + return local_positions + + if use_threading: + logging.debug( + f'Using {num_threads} threads to process {total_positions} positions for highlighting' + ) + + # Collect all positions to process + positions_to_process = [] + for category, positions in markupdict.items(): + # Skip non-RIP deamination if reaminate is False + if category == 'non_rip_deamination' and not reaminate: + continue + positions_to_process.extend(positions) + + # Process positions in parallel + with concurrent.futures.ThreadPoolExecutor(max_workers=num_threads) as executor: + # Create progress bar for parallel processing + with tqdm( + total=total_positions, desc='Finding highlight positions', unit='pos' + ) as pbar: + # Submit all positions for processing + future_to_pos = { + executor.submit(process_position, pos): pos + for pos in positions_to_process + } + + # Process results as they complete + for future in concurrent.futures.as_completed(future_to_pos): + try: + local_positions = future.result() + # Synchronize results + with positions_lock: + highlighted_positions.update(local_positions) + pbar.update(1) + except Exception as e: + logging.error(f'Error processing position: {e}') + else: + # Sequential processing + for category, positions in 
markupdict.items(): + # Skip non-RIP deamination if reaminate is False + if category == 'non_rip_deamination' and not reaminate: + continue + + # Process each position sequentially with progress tracking + with tqdm( + total=len(positions), desc=f'Finding {category} positions', unit='pos' + ) as pbar: + for pos in positions: + local_positions = process_position(pos) + highlighted_positions.update(local_positions) + pbar.update(1) + + elapsed_time = time.time() - start_time + logging.info( + f'Position highlighting calculation completed in {elapsed_time:.2f} seconds' + ) + logging.info(f'Found {len(highlighted_positions)} positions to highlight') return highlighted_positions diff --git a/tests/test_minialign.py b/tests/test_minialign.py index fe06eff..c0297ff 100644 --- a/tests/test_minialign.py +++ b/tests/test_minialign.py @@ -18,7 +18,11 @@ MSAToArray, RIPPosition, addColumnRangeMarkers, - arrNumeric, +) +from derip2.plotting.minialign import ( + arrNumeric_optimized as arrNumeric, # Import optimized version with alias +) +from derip2.plotting.minialign import ( drawMiniAlignment, markupRIPBases, ) @@ -241,8 +245,10 @@ def test_drawMiniAlignment_single_sequence(mock_savefig, single_seq_alignment): @patch('matplotlib.figure.Figure.savefig') -@patch('derip2.plotting.minialign.markupRIPBases') -@patch('derip2.plotting.minialign.getHighlightedPositions') +@patch('derip2.plotting.minialign.markupRIPBases_optimized') # Fixed function name +@patch( + 'derip2.plotting.minialign.getHighlightedPositions_optimized' +) # Fixed function name def test_drawMiniAlignment_with_markup( mock_get_highlighted, mock_markup, mock_savefig, simple_alignment, rip_positions ): @@ -292,8 +298,10 @@ def test_drawMiniAlignment_with_custom_dimensions(simple_alignment): mock_arr = np.full((3, 4), 'A') mock_msa_to_array.return_value = (mock_arr, ['seq1', 'seq2', 'seq3'], 3) - # Also mock arrNumeric to avoid string/numeric conversion issues - with patch('derip2.plotting.minialign.arrNumeric') as 
mock_arrnumeric: + # Also mock arrNumeric_optimized to avoid string/numeric conversion issues + with patch( + 'derip2.plotting.minialign.arrNumeric_optimized' + ) as mock_arrnumeric: # Fixed function name mock_arrnumeric.return_value = ( np.zeros((3, 4)), matplotlib.colors.ListedColormap(['#000000', '#FFFFFF']), @@ -349,8 +357,8 @@ def test_drawMiniAlignment_different_palettes(mock_savefig, simple_alignment): # Reset the mock to clear call history between loop iterations mock_savefig.reset_mock() - # Fix: Create a proper mock for arrNumeric that returns valid objects - with patch('derip2.plotting.minialign.arrNumeric') as mock_arrNumeric: + # Fix: Create a proper mock for arrNumeric_optimized that returns valid objects + with patch('derip2.plotting.minialign.arrNumeric_optimized') as mock_arrNumeric: # Create a simple numeric array and a valid colormap numeric_arr = np.zeros((3, 4)) cmap = matplotlib.colors.ListedColormap(['#ffffff', '#000000'])