diff --git a/Dockerfile b/Dockerfile index ac8b542d..aa88185f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,7 +1,7 @@ FROM ubuntu:22.04 LABEL maintainer="cschu1981@gmail.com" -LABEL version="2.18.0" +LABEL version="2.19.0" LABEL description="gffquant - functional profiling of metagenomic/transcriptomic wgs samples" diff --git a/gffquant/__init__.py b/gffquant/__init__.py index 128d76bf..31f4177f 100644 --- a/gffquant/__init__.py +++ b/gffquant/__init__.py @@ -5,7 +5,7 @@ from enum import Enum, auto, unique -__version__ = "2.18.0" +__version__ = "2.19.0" __tool__ = "gffquant" diff --git a/gffquant/alignment/__init__.py b/gffquant/alignment/__init__.py index 9c07f61f..77c59f0a 100644 --- a/gffquant/alignment/__init__.py +++ b/gffquant/alignment/__init__.py @@ -6,6 +6,7 @@ from .aln_group import AlignmentGroup from .pysam_alignment_processor import AlignmentProcessor +from .reference_hit import ReferenceHit from .samflags import SamFlags from .cigarops import CigarOps diff --git a/gffquant/alignment/aln_group.py b/gffquant/alignment/aln_group.py index 057b5a3f..b465df46 100644 --- a/gffquant/alignment/aln_group.py +++ b/gffquant/alignment/aln_group.py @@ -79,6 +79,9 @@ def get_all_hits(self, as_ambiguous=False): except TypeError as err: raise TypeError(f"Cannot derive sequencing library from tags: {aln.tags}") from err + # in region mode, there can be more hits + # (if the alignment overlaps multiple features of the target sequence) + # in gene mode, each alignment is a hit, i.e. there is at most 1 hit / alignment yield aln.hits, n_aln def get_ambig_align_counts(self): diff --git a/gffquant/alignment/reference_hit.py b/gffquant/alignment/reference_hit.py new file mode 100644 index 00000000..968da602 --- /dev/null +++ b/gffquant/alignment/reference_hit.py @@ -0,0 +1,38 @@ +# pylint: disable=R0902 + +""" module docstring """ + +from dataclasses import dataclass, asdict + + +@dataclass(slots=True) +class ReferenceHit: + rid: int = None + start: int = None + end: int = None + rev_strand: bool = None + cov_start: int = None + cov_end: int = None + has_annotation: bool = None + n_aln: int = None + is_ambiguous: bool = None + library_mod: int = None + mate_id: int = None + + def __hash__(self): + return hash(tuple(asdict(self).values())) + + def __eq__(self, other): + return all( + item[0][1] == item[1][1] + for item in zip( + sorted(asdict(self).items()), + sorted(asdict(other).items()) + ) + ) + + def __str__(self): + return "\t".join(map(str, asdict(self).values())) + + def __repr__(self): + return str(self) diff --git a/gffquant/annotation/__init__.py b/gffquant/annotation/__init__.py index 4ae04f44..2f8e1c0c 100644 --- a/gffquant/annotation/__init__.py +++ b/gffquant/annotation/__init__.py @@ -2,5 +2,8 @@ """ module docstring """ -from .count_annotator import GeneCountAnnotator, RegionCountAnnotator +# from .count_annotator import GeneCountAnnotator, RegionCountAnnotator +from .count_annotator import CountAnnotator from .count_writer import CountWriter +from .genecount_annotator import GeneCountAnnotator +from .regioncount_annotator import RegionCountAnnotator diff --git a/gffquant/annotation/count_annotator.py b/gffquant/annotation/count_annotator.py index 760250dc..601eb865 100644 --- a/gffquant/annotation/count_annotator.py +++ b/gffquant/annotation/count_annotator.py @@ -95,6 +95,7 @@ def calc_scaling_factor(raw, normed, default=0): return (raw / normed) if normed else default total_uniq, total_uniq_normed, total_ambi, total_ambi_normed = self.total_counts + # total_uniq, total_ambi, 
total_uniq_normed, total_ambi_normed = self.total_counts logger.info( "TOTAL COUNTS: uraw=%s unorm=%s araw=%s anorm=%s", total_uniq, total_uniq_normed, total_ambi, total_ambi_normed @@ -108,19 +109,20 @@ def calc_scaling_factor(raw, normed, default=0): total_ambi, total_ambi_normed, default_scaling_factor ) - total_uniq, total_uniq_normed, total_ambi, total_ambi_normed = self.total_gene_counts - logger.info( - "TOTAL GENE COUNTS: uraw=%s unorm=%s araw=%s anorm=%s", - total_uniq, total_uniq_normed, total_ambi, total_ambi_normed - ) + # total_uniq, total_uniq_normed, total_ambi, total_ambi_normed = self.total_gene_counts + # total_uniq, total_ambi, total_uniq_normed, total_ambi_normed = self.total_gene_counts + # logger.info( + # "TOTAL GENE COUNTS: uraw=%s unorm=%s araw=%s anorm=%s", + # total_uniq, total_uniq_normed, total_ambi, total_ambi_normed + # ) - self.scaling_factors["total_gene_uniq"] = calc_scaling_factor( - total_uniq, total_uniq_normed, default_scaling_factor - ) + # self.scaling_factors["total_gene_uniq"] = calc_scaling_factor( + # total_uniq, total_uniq_normed, default_scaling_factor + # ) - self.scaling_factors["total_gene_ambi"] = calc_scaling_factor( - total_ambi, total_ambi_normed, default_scaling_factor - ) + # self.scaling_factors["total_gene_ambi"] = calc_scaling_factor( + # total_ambi, total_ambi_normed, default_scaling_factor + # ) fc_items = self.feature_count_sums.items() for category, ( @@ -189,140 +191,3 @@ def compute_count_vector( counts[1::2] /= float(length) return counts - - -class RegionCountAnnotator(CountAnnotator): - """ CountAnnotator subclass for contig/region-based counting. """ - - def __init__(self, strand_specific, report_scaling_factors=True): - CountAnnotator.__init__(self, strand_specific, report_scaling_factors=report_scaling_factors) - - # pylint: disable=R0914,W0613 - def annotate(self, refmgr, db, count_manager, gene_group_db=False): - """ - Annotate a set of region counts via db-lookup. - input: - - bam: bamr.BamFile to use as lookup table for reference names - - db: GffDatabaseManager holding functional annotation database - - count_manager: count_data - """ - for rid in set(count_manager.uniq_regioncounts).union( - count_manager.ambig_regioncounts - ): - ref = refmgr.get(rid[0] if isinstance(rid, tuple) else rid)[0] - - for region in count_manager.get_regions(rid): - if self.strand_specific: - (start, end), rev_strand = region - else: - (start, end), rev_strand = region, None - # the region_annotation is a tuple of key-value pairs: - # (strand, func_category1: subcategories, func_category2: subcategories, ...) 
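# A minimal sketch of the payload this comment describes, assuming the
# (strand, gene_id, ((category_id, feature_ids), ...)) layout that the
# unpacking below relies on; all ids are made up:
def mock_query_sequence(ref, start=None, end=None):
    return (
        "+",                                     # region strand
        "gene_0001",                             # gene/feature id (may be None)
        (("1", ("10", "11")), ("2", ("7",))),    # (category, subcategories) pairs
    )

region_strand, feature_id, annotation = mock_query_sequence("contig_1", 100, 400)
for category_id, features in annotation:
    assert int(category_id) >= 0 and all(int(f) >= 0 for f in features)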
- # the first is the strand, the second is the gene id, the rest are the features - - region_annotation = db.query_sequence(ref, start=start, end=end) - if region_annotation is not None: - region_strand, feature_id, region_annotation = region_annotation - if feature_id is None: - feature_id = ref - - on_other_strand = (region_strand == "+" and rev_strand) \ - or (region_strand == "-" and not rev_strand) - - antisense_region = self.strand_specific and on_other_strand - - uniq_counts, ambig_counts = count_manager.get_counts( - (rid, start, end), region_counts=True, strand_specific=self.strand_specific - ) - - if self.strand_specific: - # if the region is antisense, 'sense-counts' (relative to the) region come from the - # negative strand and 'antisense-counts' from the positive strand - # vice-versa for a sense-region - strand_specific_counts = ( - (count_manager.MINUS_STRAND, count_manager.PLUS_STRAND) - if antisense_region - else (count_manager.PLUS_STRAND, count_manager.MINUS_STRAND) - ) - else: - strand_specific_counts = None - - region_length = end - start + 1 - counts = self.compute_count_vector( - uniq_counts, - ambig_counts, - region_length, - strand_specific_counts=strand_specific_counts, - region_counts=True, - ) - - self.distribute_feature_counts(counts, region_annotation) - - gcounts = self.gene_counts.setdefault( - feature_id, np.zeros(self.bins) - ) - gcounts += counts - self.total_gene_counts += counts[:4] - - self.calculate_scaling_factors() - - -class GeneCountAnnotator(CountAnnotator): - """ CountAnnotator subclass for gene-based counting. """ - - def __init__(self, strand_specific, report_scaling_factors=True): - CountAnnotator.__init__(self, strand_specific, report_scaling_factors=report_scaling_factors) - - def annotate(self, refmgr, db, count_manager, gene_group_db=False): - """ - Annotate a set of gene counts via db-iteration. 
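# The scaling factors above all follow calc_scaling_factor(); a tiny worked
# example (numbers made up): multiplying the length-normalised column by
# raw_sum / lnorm_sum makes the scaled column sum to the raw total again.
def calc_scaling_factor(raw, normed, default=0):
    return (raw / normed) if normed else default

lnorm = [1.5, 1.0]                              # length-normalised counts
sf = calc_scaling_factor(1000.0, sum(lnorm))    # 1000 / 2.5 -> 400.0
assert sum(c * sf for c in lnorm) == 1000.0
assert calc_scaling_factor(0.0, 0.0) == 0       # guarded against empty input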
- input: - - bam: bamr.BamFile to use as reverse lookup table for reference ids - - db: GffDatabaseManager holding functional annotation database - - count_manager: count_data - """ - strand_specific_counts = ( - (count_manager.PLUS_STRAND, count_manager.MINUS_STRAND) - if self.strand_specific else None - ) - - for rid in set(count_manager.uniq_seqcounts).union( - count_manager.ambig_seqcounts - ): - ref, region_length = refmgr.get(rid[0] if isinstance(rid, tuple) else rid) - - uniq_counts, ambig_counts = count_manager.get_counts( - rid, region_counts=False, strand_specific=self.strand_specific - ) - - counts = self.compute_count_vector( - uniq_counts, - ambig_counts, - region_length, - strand_specific_counts=strand_specific_counts, - ) - - if gene_group_db: - ref_tokens = ref.split(".") - gene_id, ggroup_id = ".".join(ref_tokens[:-1]), ref_tokens[-1] - else: - ggroup_id, gene_id = ref, ref - - gcounts = self.gene_counts.setdefault(gene_id, np.zeros(self.bins)) - gcounts += counts - self.total_gene_counts += counts[:4] - - region_annotation = db.query_sequence(ggroup_id) - if region_annotation is not None: - _, _, region_annotation = region_annotation - logger.info( - "GCAnnotator: Distributing counts of Gene %s (group=%s) %s %s", - gene_id, ggroup_id, counts[0], counts[2], - ) - self.distribute_feature_counts(counts, region_annotation) - - else: - logger.info("GCAnnotator: Gene %s (group=%s) has no information in database.", gene_id, ggroup_id) - self.unannotated_counts += counts[:4] - - self.calculate_scaling_factors() diff --git a/gffquant/annotation/count_writer.py b/gffquant/annotation/count_writer.py index fa67c3fc..155bac39 100644 --- a/gffquant/annotation/count_writer.py +++ b/gffquant/annotation/count_writer.py @@ -1,4 +1,4 @@ -# pylint: disable=C0103,W1514,R0913,R0917 +# pylint: disable=C0103,W1514,R0913,R0917,R0914 """ module docstring """ @@ -8,6 +8,9 @@ import numpy as np +from ..counters import AlignmentCounter +from ..counters.count_matrix import CountMatrix + logger = logging.getLogger(__name__) @@ -74,6 +77,7 @@ def compile_block(raw, lnorm, scaling_factors): p, row = 0, [] rpkm_factor = 1e9 / self.filtered_readcount + # unique counts row += compile_block(*counts[p:p + 2], (scaling_factor, rpkm_factor,)) p += 2 @@ -107,69 +111,92 @@ def compile_block(raw, lnorm, scaling_factors): def write_row(header, data, stream=sys.stdout): print(header, *(f"{c:.5f}" for c in data), flush=True, sep="\t", file=stream) - # pylint: disable=R0914 - def write_feature_counts(self, db, featcounts, unannotated_reads=None, report_unseen=True): - for category_id, counts in sorted(featcounts.items()): - scaling_factor, ambig_scaling_factor = featcounts.scaling_factors[ - category_id - ] - category = db.query_category(category_id).name - if "scaled" in self.publish_reports: - logger.info( - "SCALING FACTORS %s %s %s", - category, scaling_factor, ambig_scaling_factor + def write_category( + self, + category_id, + category_name, + category_sum, + counts, + # feature_names, + features, + unannotated_reads=None, + report_unseen=True, + ): + with gzip.open(f"{self.out_prefix}.{category_name}.txt.gz", "wt") as feat_out: + header = self.get_header() + print("feature", *header, sep="\t", file=feat_out) + + if unannotated_reads is not None: + print("unannotated", unannotated_reads, sep="\t", file=feat_out) + + if "total_readcount" in self.publish_reports: + CountWriter.write_row( + "total_reads", + np.zeros(len(header)) + self.total_readcount, + stream=feat_out, ) - with 
gzip.open(f"{self.out_prefix}.{category}.txt.gz", "wt") as feat_out: - header = self.get_header() - print("feature", *header, sep="\t", file=feat_out) - - if unannotated_reads is not None: - print("unannotated", unannotated_reads, sep="\t", file=feat_out) - - if "total_readcount" in self.publish_reports: - CountWriter.write_row( - "total_reads", - np.zeros(len(header)) + self.total_readcount, - stream=feat_out, - ) - if "filtered_readcount" in self.publish_reports: - CountWriter.write_row( - "filtered_reads", - np.zeros(len(header)) + self.filtered_readcount, - stream=feat_out, - ) + if "filtered_readcount" in self.publish_reports: + CountWriter.write_row( + "filtered_reads", + np.zeros(len(header)) + self.filtered_readcount, + stream=feat_out, + ) - if "category" in self.publish_reports: - cat_counts = counts.get(f"cat:::{category_id}") - if cat_counts is not None: - cat_row = self.compile_output_row( - cat_counts, - scaling_factor=featcounts.scaling_factors["total_uniq"], - ambig_scaling_factor=featcounts.scaling_factors["total_ambi"], - ) - CountWriter.write_row("category", cat_row, stream=feat_out) - - for feature in db.get_features(category_id): - f_counts = counts.get(str(feature.id), np.zeros(len(header))) - if report_unseen or f_counts.sum(): - out_row = self.compile_output_row( - f_counts, - scaling_factor=scaling_factor, - ambig_scaling_factor=ambig_scaling_factor, - ) - CountWriter.write_row(feature.name, out_row, stream=feat_out) - - def write_gene_counts(self, gene_counts, uniq_scaling_factor, ambig_scaling_factor): - if "scaled" in self.publish_reports: - logger.info("SCALING_FACTORS %s %s", uniq_scaling_factor, ambig_scaling_factor) + if "category" in self.publish_reports: + # cat_counts = counts[0] + cat_counts = category_sum + logger.info("CAT %s: %s", category_name, str(cat_counts)) + if cat_counts is not None: + CountWriter.write_row("category", category_sum, stream=feat_out) + + # for item in counts: + # if not isinstance(item[0], tuple): + # logger.info("ITEM: %s", str(item)) + # raise TypeError(f"Weird key: {str(item)}") + # (cid, fid), fcounts = item + # if (report_unseen or fcounts.sum()) and cid == category_id: + # CountWriter.write_row(feature_names[fid], fcounts, stream=feat_out,) + + empty_row = np.zeros(6, dtype=CountMatrix.NUMPY_DTYPE) + for feature in features: + key = (category_id, feature.id) + if counts.has_record(key): + row = counts[key] + else: + row = empty_row + if (report_unseen or row.sum()): + CountWriter.write_row(feature.name, row, stream=feat_out,) + + + # for (cid, fid), fcounts in counts: + # if (report_unseen or fcounts.sum()) and cid == category_id: + # CountWriter.write_row(feature_names[fid], fcounts, stream=feat_out,) + + def write_gene_counts( + self, + gene_counts: AlignmentCounter, + refmgr, + gene_group_db=False, + ): with gzip.open(f"{self.out_prefix}.gene_counts.txt.gz", "wt") as gene_out: print("gene", *self.get_header(), sep="\t", file=gene_out, flush=True) - for gene, g_counts in sorted(gene_counts.items()): - out_row = self.compile_output_row( - g_counts, - scaling_factor=uniq_scaling_factor, - ambig_scaling_factor=ambig_scaling_factor + ref_stream = ( + ( + refmgr.get(rid[0] if isinstance(rid, tuple) else rid)[0], + rid, ) - CountWriter.write_row(gene, out_row, stream=gene_out) + for rid, _ in gene_counts + ) + + for ref, rid in sorted(ref_stream): + counts = gene_counts[rid] + # if gene_group_db: + # ref_tokens = ref.split(".") + # gene_id, _ = ".".join(ref_tokens[:-1]), ref_tokens[-1] + # else: + # gene_id = ref + gene_id = 
ref + + CountWriter.write_row(gene_id, counts, stream=gene_out,) diff --git a/gffquant/annotation/genecount_annotator.py b/gffquant/annotation/genecount_annotator.py new file mode 100644 index 00000000..c7697536 --- /dev/null +++ b/gffquant/annotation/genecount_annotator.py @@ -0,0 +1,77 @@ +# pylint: disable=R0914 + +""" module docstring """ +import logging + +import numpy as np + +from .count_annotator import CountAnnotator +from ..counters import AlignmentCounter +from ..counters.count_matrix import CountMatrix +from ..db.annotation_db import AnnotationDatabaseManager + + +logger = logging.getLogger(__name__) + + +class GeneCountAnnotator(CountAnnotator): + """ CountAnnotator subclass for gene-based counting. """ + + def __init__(self, strand_specific, report_scaling_factors=True): + """ __init__() """ + CountAnnotator.__init__(self, strand_specific, report_scaling_factors=report_scaling_factors) + + def annotate_gene_counts( + self, + refmgr, + db: AnnotationDatabaseManager, + counter: AlignmentCounter, + gene_group_db=False + ): + categories = list(db.get_categories()) + category_sums = np.zeros((len(categories), 6)) + functional_counts = CountMatrix(6) + + # for category in categories: + # features = ((feature.name, feature) for feature in db.get_features(category.id)) + # for _, feature in sorted(features, key=lambda x: x[0]): + # _ = functional_counts[(category.id, feature.id)] + + for rid, counts in counter: + # counts = counter[rid] + if gene_group_db: + ggroup_id = rid + else: + ref, _ = refmgr.get(rid[0] if isinstance(rid, tuple) else rid) + ggroup_id = ref + + region_annotation = db.query_sequence(ggroup_id) + if region_annotation is not None: + _, _, region_annotation = region_annotation + for category_id, features in region_annotation: + category_id = int(category_id) + category_sums[category_id] += counts + for feature_id in features: + feature_id = int(feature_id) + functional_counts[(category_id, feature_id)] += counts + + functional_counts.drop_unindexed() + + for i, category in enumerate(categories): + u_sf, c_sf = ( + CountMatrix.calculate_scaling_factor(*category_sums[i][0:2]), + CountMatrix.calculate_scaling_factor(*category_sums[i][3:5]), + ) + + rows = tuple( + key[0] == category.id + for key, _ in functional_counts + ) + + functional_counts.scale_column(1, u_sf, rows=rows) + functional_counts.scale_column(4, c_sf, rows=rows) + + category_sums[i, 2] = category_sums[i, 1] * u_sf + category_sums[i, 5] = category_sums[i, 4] * c_sf + + return functional_counts, category_sums diff --git a/gffquant/annotation/regioncount_annotator.py b/gffquant/annotation/regioncount_annotator.py new file mode 100644 index 00000000..57931cff --- /dev/null +++ b/gffquant/annotation/regioncount_annotator.py @@ -0,0 +1,80 @@ +""" module docstring """ + +import numpy as np + +from . import CountAnnotator +from ..counters import AlignmentCounter + + +class RegionCountAnnotator(CountAnnotator): + """ CountAnnotator subclass for contig/region-based counting. """ + + def __init__(self, strand_specific, report_scaling_factors=True): + raise NotImplementedError() + CountAnnotator.__init__(self, strand_specific, report_scaling_factors=report_scaling_factors) + + # pylint: disable=R0914,W0613 + def annotate(self, refmgr, db, counter: AlignmentCounter, gene_group_db=False): + """ + Annotate a set of region counts via db-lookup. 
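# Sketch of the per-category rescaling step performed in
# annotate_gene_counts() above; the 6-column layout is uniq_raw, uniq_lnorm,
# uniq_scaled, comb_raw, comb_lnorm, comb_scaled, and all numbers are made up:
import numpy as np

def calculate_scaling_factor(raw, norm):
    # mirrors CountMatrix.calculate_scaling_factor
    return 1.0 if norm == 0.0 else raw / norm

category_sum = np.array([10.0, 0.5, 0.0, 14.0, 0.7, 0.0])
u_sf = calculate_scaling_factor(*category_sum[0:2])    # 10 / 0.5 -> 20.0
c_sf = calculate_scaling_factor(*category_sum[3:5])    # 14 / 0.7 -> 20.0

# rows belonging to this category inside the functional count matrix
counts = np.array([
    [6.0, 0.3, 0.0, 8.0, 0.4, 0.0],
    [4.0, 0.2, 0.0, 6.0, 0.3, 0.0],
])
counts[:, 2] = counts[:, 1] * u_sf     # scale_column(1, u_sf, rows=rows)
counts[:, 5] = counts[:, 4] * c_sf     # scale_column(4, c_sf, rows=rows)
category_sum[2] = category_sum[1] * u_sf
category_sum[5] = category_sum[4] * c_sf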
+ input: + - bam: bamr.BamFile to use as lookup table for reference names + - db: GffDatabaseManager holding functional annotation database + """ + for rid in counter.get_all_regions(region_counts=True): + ref = refmgr.get(rid[0] if isinstance(rid, tuple) else rid)[0] + + for region in counter.get_regions(rid): + if self.strand_specific: + (start, end), rev_strand = region + else: + (start, end), rev_strand = region, None + # the region_annotation is a tuple: + # (strand, gene_id, ((func_category, subcategories), ...)) + # i.e. the strand first, then the gene id, then the category: feature pairs + + region_annotation = db.query_sequence(ref, start=start, end=end) + if region_annotation is not None: + region_strand, feature_id, region_annotation = region_annotation + if feature_id is None: + feature_id = ref + + on_other_strand = (region_strand == "+" and rev_strand) \ + or (region_strand == "-" and not rev_strand) + + antisense_region = self.strand_specific and on_other_strand + + uniq_counts, ambig_counts = counter.get_counts( + (rid, start, end), region_counts=True, strand_specific=self.strand_specific + ) + + if self.strand_specific: + # if the region is antisense, 'sense-counts' (relative to the region) come from the + # negative strand and 'antisense-counts' from the positive strand + # vice versa for a sense region + strand_specific_counts = ( + (counter.MINUS_STRAND, counter.PLUS_STRAND) + if antisense_region + else (counter.PLUS_STRAND, counter.MINUS_STRAND) + ) + else: + strand_specific_counts = None + + region_length = end - start + 1 + counts = self.compute_count_vector( + uniq_counts, + ambig_counts, + region_length, + strand_specific_counts=strand_specific_counts, + region_counts=True, + ) + + self.distribute_feature_counts(counts, region_annotation) + + gcounts = self.gene_counts.setdefault( + feature_id, np.zeros(self.bins) + ) + gcounts += counts + self.total_gene_counts += counts[:4] + + self.calculate_scaling_factors() diff --git a/gffquant/counters/__init__.py b/gffquant/counters/__init__.py index 7641c957..2b5426c3 100644 --- a/gffquant/counters/__init__.py +++ b/gffquant/counters/__init__.py @@ -4,6 +4,5 @@ """module docstring""" from .alignment_counter import AlignmentCounter +# from .count_matrix import CountMatrix from .region_counter import RegionCounter -from .seq_counter import UniqueSeqCounter, AmbiguousSeqCounter -from .count_manager import CountManager diff --git a/gffquant/counters/alignment_counter.py b/gffquant/counters/alignment_counter.py index 3c42b254..3f716955 100644 --- a/gffquant/counters/alignment_counter.py +++ b/gffquant/counters/alignment_counter.py @@ -1,18 +1,26 @@ -# pylint: disable=W0223 -# pylint: disable=C0103 -# pylint: disable=W1514 +# pylint: disable=R0902 -"""module docstring""" +""" module docstring """ import gzip +import logging -from collections import Counter +import numpy as np +from .count_matrix import CountMatrix from ..
import DistributionMode -class AlignmentCounter(Counter): - COUNT_HEADER_ELEMENTS = ["raw", "lnorm", "scaled"] +logger = logging.getLogger(__name__) + + +class AlignmentCounter: + COUNT_HEADER_ELEMENTS = ("raw", "lnorm", "scaled") + INITIAL_SIZE = 1000 + # this may be counter-intuitive + # but originates from the samflags 0x10, 0x20, + # which explicitly identify the reverse-strandness of the read + PLUS_STRAND, MINUS_STRAND = False, True @staticmethod def normalise_counts(counts, feature_len, scaling_factor): @@ -21,32 +29,140 @@ def normalise_counts(counts, feature_len, scaling_factor): scaled = normalised * scaling_factor return counts, normalised, scaled - def get_increment(self, n_aln, increment): + @staticmethod + def get_increment(n_aln, increment, distribution_mode): # 1overN = lavern. Maya <3 - return (increment / n_aln) if self.distribution_mode == DistributionMode.ONE_OVER_N else increment + return (increment, (increment / n_aln))[distribution_mode == DistributionMode.ONE_OVER_N] + + def toggle_single_read_handling(self, unmarked_orphans): + # precalculate count-increment for single-end, paired-end reads + # for mixed input (i.e., paired-end data with single-end reads = orphans from preprocessing), + # properly attribute fractional counts to the orphans + # Increments: + # alignment from single end library read: 1 + # alignment from paired-end library read: 0.5 / mate (pe_count = 1) or 1 / mate (pe_count = 2) + # alignment from paired-end library orphan: 0.5 (pe_count = 1) or 1 (pe_count = 2) + + # old code: + # increment = 1 if (not pair or self.paired_end_count == 2) else 0.5 + + # if pair: + # increment = 1 if self.paired_end_count == 2 else 0.5 + # else: + # increment = 0.5 if self.unmarked_orphans else 1 + self.increments = ( + (self.paired_end_count / 2.0) if unmarked_orphans else 1.0, + self.paired_end_count / 2.0, + ) - def __init__(self, distribution_mode=DistributionMode.ONE_OVER_N, strand_specific=False): - Counter.__init__(self) + def __init__( + self, + distribution_mode=DistributionMode.ONE_OVER_N, + strand_specific=False, + paired_end_count=1, + ): self.distribution_mode = distribution_mode self.strand_specific = strand_specific + self.paired_end_count = paired_end_count + self.increments = (1.0, 1.0,) + self.increments_auto_detect = (1.0, self.paired_end_count / 2.0,) self.unannotated_reads = 0 + self.counts = CountMatrix(2, nrows=AlignmentCounter.INITIAL_SIZE) + def dump(self, prefix, refmgr): + raise NotImplementedError() with gzip.open(f"{prefix}.{self.__class__.__name__}.txt.gz", "wt") as _out: - for k, v in self.items(): - ref, reflen = refmgr.get(k[0] if isinstance(k, tuple) else k) - print(k, ref, reflen, v, sep="\t", file=_out) + for key, key_index in self.index.items(): + ref, reflen = refmgr.get(key[0] if isinstance(key, tuple) else key) + print(key, ref, reflen, self.counts[key_index], sep="\t", file=_out) + # for k, v in self.items(): + # ref, reflen = refmgr.get(k[0] if isinstance(k, tuple) else k) + # print(k, ref, reflen, v, sep="\t", file=_out) + + def has_ambig_counts(self): + # return bool(self.counts[:, 1].sum() != 0) + return bool(self.counts.colsum(1) != 0) + + def __iter__(self): + yield from self.counts + + def __getitem__(self, key): + return self.counts[key] + + def update(self, count_stream, ambiguous_counts=False, pair=False, pe_library=None,): + if pe_library is not None: + # this is the case when the alignment has a read group tag + # if pe_library is True (RG tag '2') -> take paired-end increment (also for orphans) + # else (RG tag 
'1') -> take single-end increment + increment = self.increments_auto_detect[pe_library] + else: + # if the alignment has no (appropriate) read group tag + # use the paired-end information instead + # if orphan reads are present in the input sam/bam, + # the flag `--unmarked_orphans` should be set + # otherwise orphan reads will be assigned a count of 1. + increment = self.increments[pair] + + contributed_counts = self.update_counts(count_stream, increment=increment, ambiguous_counts=ambiguous_counts,) - def update_counts(self, count_stream, increment=1): + return contributed_counts + + def get_unannotated_reads(self): + # return self.counts["c591b65a0f4cd46d5125745a40c8c056"][0] + # return self.counts["c591b65a0f4cd"][0] + return self.counts["00000000"][0] + + def update_counts(self, count_stream, increment=1, ambiguous_counts=False): contributed_counts = 0 for hits, aln_count in count_stream: hit = hits[0] - inc = increment if aln_count == 1 else self.get_increment(aln_count, increment) - if self.strand_specific: - self[(hit.rid, hit.rev_strand)] += inc - else: - self[hit.rid] += inc + inc = ( + ( + AlignmentCounter.get_increment(aln_count, increment, self.distribution_mode), + increment, + ) + )[aln_count == 1] + key = ( + ( + hit.rid, + (hit.rid, hit.rev_strand), + ) + )[self.strand_specific] + self.counts[key][int(ambiguous_counts)] += inc contributed_counts += inc return contributed_counts + + def generate_gene_count_matrix(self, refmgr): + # transform 2-column uniq/ambig count matrix + # into 4 columns + # uniq_raw, combined_raw, uniq_lnorm, combined_lnorm + + # obtain gene lengths + gene_lengths = np.array( + tuple( + (refmgr.get(key[0] if isinstance(key, tuple) else key))[1] + for key, _ in self.counts + ) + ) + + self.counts = self.counts.generate_gene_counts(gene_lengths) + + return self.counts.sum() + + @staticmethod + def calculate_scaling_factor(raw, norm): + if norm == 0.0: + return 1.0 + return raw / norm + + def group_gene_count_matrix(self, refmgr): + + ggroups = ( + (refmgr.get(key[0] if isinstance(key, tuple) else key))[0].split(".")[0] + for key, _ in self.counts + ) + + self.counts = self.counts.group_gene_counts(ggroups) diff --git a/gffquant/counters/count_manager.py b/gffquant/counters/count_manager.py deleted file mode 100644 index 40ae72a6..00000000 --- a/gffquant/counters/count_manager.py +++ /dev/null @@ -1,162 +0,0 @@ -"""count_manager""" - -from collections import Counter - -from .. 
import DistributionMode -from .alignment_counter import AlignmentCounter -from .region_counter import RegionCounter - - -# pylint: disable=R0902 -class CountManager: - # this may be counter-intuitive - # but originates from the samflags 0x10, 0x20, - # which also identify the reverse-strandness of the read - # and not the forward-strandness - PLUS_STRAND, MINUS_STRAND = False, True - - def toggle_single_read_handling(self, unmarked_orphans): - # precalculate count-increment for single-end, paired-end reads - # for mixed input (i.e., paired-end data with single-end reads = orphans from preprocessing), - # properly attribute fractional counts to the orphans - # Increments: - # alignment from single end library read: 1 - # alignment from paired-end library read: 0.5 / mate (pe_count = 1) or 1 / mate (pe_count = 2) - # alignment from paired-end library orphan: 0.5 (pe_count = 1) or 1 (pe_count = 2) - - # old code: - # increment = 1 if (not pair or self.paired_end_count == 2) else 0.5 - - # if pair: - # increment = 1 if self.paired_end_count == 2 else 0.5 - # else: - # increment = 0.5 if self.unmarked_orphans else 1 - self.increments = [ - (self.paired_end_count / 2.0) if unmarked_orphans else 1.0, - self.paired_end_count / 2.0 - ] - - def __init__( - # pylint: disable=W0613,R0913 - self, - distribution_mode=DistributionMode.ONE_OVER_N, - region_counts=True, - strand_specific=False, - paired_end_count=1, - ): - self.distribution_mode = distribution_mode - self.strand_specific = strand_specific - self.paired_end_count = paired_end_count - self.increments = [1.0, 1.0] - self.increments_auto_detect = [1.0, self.paired_end_count / 2.0] - - self.uniq_seqcounts, self.ambig_seqcounts = None, None - self.uniq_regioncounts, self.ambig_regioncounts = None, None - - if region_counts: - self.uniq_regioncounts = RegionCounter(strand_specific=strand_specific) - self.ambig_regioncounts = RegionCounter( - strand_specific=strand_specific, - distribution_mode=distribution_mode, - ) - - else: - self.uniq_seqcounts = AlignmentCounter(strand_specific=strand_specific) - self.ambig_seqcounts = AlignmentCounter( - strand_specific=strand_specific, - distribution_mode=distribution_mode - ) - - def has_ambig_counts(self): - return self.ambig_regioncounts or self.ambig_seqcounts - - def update_counts(self, count_stream, ambiguous_counts=False, pair=False, pe_library=None): - seq_counter, region_counter = ( - (self.uniq_seqcounts, self.uniq_regioncounts) - if not ambiguous_counts - else (self.ambig_seqcounts, self.ambig_regioncounts) - ) - - if pe_library is not None: - # this is the case when the alignment has a read group tag - # if pe_library is True (RG tag '2') -> take paired-end increment (also for orphans) - # else (RG tag '1') -> take single-end increment - increment = self.increments_auto_detect[pe_library] - else: - # if the alignment has no (appropriate) read group tag - # use the paired-end information instead - # if orphan reads are present in the input sam/bam, - # the flag `--unmarked_orphans` should be set - # otherwise orphan reads will be assigned a count of 1. 
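# Worked example of the increment table above; the values follow directly
# from increments = [pec / 2 if unmarked_orphans else 1.0, pec / 2] indexed
# by the pair flag (pec = paired_end_count):
for pec, orphans, pair, expected in [
    (1, False, False, 1.0),   # single-end read
    (1, False, True, 0.5),    # paired-end mate, pe_count = 1
    (1, True, False, 0.5),    # unmarked orphan, pe_count = 1
    (2, True, True, 1.0),     # paired-end mate, pe_count = 2
]:
    increments = [(pec / 2.0) if orphans else 1.0, pec / 2.0]
    assert increments[pair] == expected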
- increment = self.increments[pair] - - contributed_counts = 0 - if seq_counter is not None: - contributed_counts = seq_counter.update_counts(count_stream, increment=increment) - elif region_counter is not None: - contributed_counts = region_counter.update_counts(count_stream, increment=increment) - - return contributed_counts - - def dump_raw_counters(self, prefix, refmgr): - if self.uniq_seqcounts is not None: - self.uniq_seqcounts.dump(prefix, refmgr) - if self.ambig_seqcounts is not None: - self.ambig_seqcounts.dump(prefix, refmgr) - if self.uniq_regioncounts is not None: - self.uniq_regioncounts.dump(prefix, refmgr) - if self.ambig_regioncounts is not None: - self.ambig_regioncounts.dump(prefix, refmgr) - - def get_unannotated_reads(self): - unannotated_reads = 0 - - if self.uniq_regioncounts is not None: - unannotated_reads += self.uniq_regioncounts.unannotated_reads - if self.ambig_regioncounts is not None: - unannotated_reads += self.ambig_regioncounts.unannotated_reads - if self.uniq_seqcounts is not None: - unannotated_reads += self.uniq_seqcounts.unannotated_reads - if self.ambig_seqcounts is not None: - unannotated_reads += self.ambig_seqcounts.unannotated_reads - - return unannotated_reads - - def get_counts(self, seqid, region_counts=False, strand_specific=False): - if region_counts: - rid, seqid = seqid[0], seqid[1:] - uniq_counter = self.uniq_regioncounts.get(rid, Counter()) - ambig_counter = self.ambig_regioncounts.get(rid, Counter()) - - # pylint: disable=R1720 - if strand_specific: - raise NotImplementedError - else: - return [uniq_counter[seqid]], [ambig_counter[seqid]] - - else: - uniq_counter, ambig_counter = self.uniq_seqcounts, self.ambig_seqcounts - - if strand_specific: - uniq_counts, ambig_counts = [0.0, 0.0], [0.0, 0.0] - uniq_counts[seqid[1]] = uniq_counter[seqid] - ambig_counts[seqid[1]] = ambig_counter[seqid] - - # rid = seqid[0] if isinstance(seqid, tuple) else seqid - # uniq_counts = [ - # uniq_counter[(rid, CountManager.PLUS_STRAND)], - # uniq_counter[(rid, CountManager.MINUS_STRAND)], - # ] - # ambig_counts = [ - # ambig_counter[(rid, CountManager.PLUS_STRAND)], - # ambig_counter[(rid, CountManager.MINUS_STRAND)], - # ] - else: - uniq_counts, ambig_counts = [uniq_counter[seqid]], [ambig_counter[seqid]] - - return uniq_counts, ambig_counts - - def get_regions(self, rid): - return set(self.uniq_regioncounts.get(rid, set())).union( - self.ambig_regioncounts.get(rid, set()) - ) diff --git a/gffquant/counters/count_matrix.py b/gffquant/counters/count_matrix.py new file mode 100644 index 00000000..dda3d8a6 --- /dev/null +++ b/gffquant/counters/count_matrix.py @@ -0,0 +1,175 @@ +""" module docstring """ + +import logging + +import numpy as np + + +logger = logging.getLogger(__name__) + + +class CountMatrix: + NUMPY_DTYPE = 'float64' # float16 causes some overflow issue during testing + + @classmethod + def from_count_matrix(cls, cmatrix, rows=None): + if rows is None: + counts = np.array(cmatrix.counts) + index = dict(cmatrix.index.items()) + else: + counts = cmatrix.counts[rows, :] + index = {} + for (key, _), keep in zip(cmatrix.index.items(), rows): + if keep: + index[key] = len(index) + # index = { + # key: value + # for (key, value), keep in zip(cmatrix.index.items(), rows) + # if keep + # } + return cls(index=index, counts=counts) + + @staticmethod + def calculate_scaling_factor(raw, norm): + if norm == 0.0: + return 1.0 + return raw / norm + + def __init__(self, ncols=2, nrows=1000, index=None, counts=None,): + if index is not None and counts is not None: 
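# branch taken by from_count_matrix(): wrap an existing index/counts pair,
# e.g. a filtered per-category view; otherwise start from a zero matrix
# whose row space _resize() grows on demand as new keys are indexed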
+ self.index = dict(index.items()) + self.counts = counts + else: + self.index = {} + self.counts = np.zeros( + (nrows, ncols,), + dtype=CountMatrix.NUMPY_DTYPE, + ) + + def has_record(self, key): + return self.index.get(key) is not None + + def _resize(self): + nrows = self.counts.shape[0] + if len(self.index) == nrows: + self.counts = np.pad( + self.counts, + ((0, nrows + 1000), (0, 0),), + ) + return len(self.index) + + def __getitem__(self, key): + key_index = self.index.get(key) + if key_index is None: + key_index = self.index[key] = self._resize() + return self.counts[key_index] + + def __setitem__(self, key, value): + key_index = self.index.get(key) + if key_index is None: + key_index = self.index[key] = self._resize() + self.counts[key_index] = value + + def __iter__(self): + yield from zip(self.index.keys(), self.counts) + + def sum(self): + return self.counts.sum(axis=0) + + def scale_column(self, col_index, factor, rows=None): + # apply scaling factors + if rows is None: + self.counts[:, col_index + 1] = self.counts[:, col_index] * factor + else: + self.counts[rows, col_index + 1] = self.counts[rows, col_index] * factor + + def drop_unindexed(self): + self.counts = self.counts[0:len(self.index), :] + + def generate_gene_counts(self, lengths): + logger.info("LENGTHS ARRAY = %s", lengths.shape) + logger.info("INDEX SIZE = %s", len(self.index)) + + # remove the un-indexed rows + counts = self.counts[0:len(self.index), :] + + # calculate combined_raw + counts[:, 1:2] += counts[:, 0:1] + + # duplicate the raw counts + counts = np.column_stack( + ( + counts[:, 0], counts[:, 0], counts[:, 0], # 0, 1, 2 + counts[:, 1], counts[:, 1], counts[:, 1], # 3, 4, 5 + ), + ) + + # length-normalise the lnorm columns + counts[:, 1::3] /= lengths[:, None] + + count_sums = counts.sum(axis=0) + + uniq_scaling_factor, combined_scaling_factor = ( + CountMatrix.calculate_scaling_factor(*count_sums[0:2]), + CountMatrix.calculate_scaling_factor(*count_sums[3:5]), + ) + + logger.info( + "AC:: TOTAL GENE COUNTS: uraw=%s unorm=%s craw=%s cnorm=%s => SF: %s %s", + count_sums[0], count_sums[1], count_sums[3], count_sums[4], + uniq_scaling_factor, combined_scaling_factor, + ) + + # apply scaling factors + counts[:, 2] = counts[:, 1] * uniq_scaling_factor + counts[:, 5] = counts[:, 4] * combined_scaling_factor + + self.counts = counts + + return self + + def dump(self, state="genes", labels=None,): + with open(f"CountMatrix.{state}.txt", "wt") as _out: + if labels is None: + for index, counts in self: + print(index, *counts, sep="\t", file=_out) + else: + for (index, counts), label in zip(self, labels): + print(label, *counts, sep="\t", file=_out) + + + def group_gene_counts(self, ggroups): + + ggroup_counts = CountMatrix(ncols=6) + for (_, gene_counts), ggroup_id in zip(self, ggroups): + ggroup_counts[ggroup_id] += gene_counts + + return ggroup_counts + + + + ggroup_index = {} + # for gene_id, gene_counts in self: + # ggroup_id = gene_id.split(".")[-1] + # g_key_index = ggroup_index.get(ggroup_id) + for (_, gene_counts), ggroup_id in zip(self, ggroups): + g_key_index = ggroup_index.get(ggroup_id) + # gene_counts = self.counts[self.index[key]] + if g_key_index is None: + g_key_index = ggroup_index[ggroup_id] = len(ggroup_index) + self.counts[g_key_index] = gene_counts + # logger.info("CM.group_gene_counts: Adding %s to new group %s (%s).", str(gene_counts), ggroup_id, g_key_index) + else: + self.counts[g_key_index] += gene_counts + # logger.info("CM.group_gene_counts: Adding %s to group %s (%s).", 
str(gene_counts), ggroup_id, g_key_index) + + # replace index with grouped index + self.index = ggroup_index + + # remove the un-indexed (ungrouped) rows + self.counts = self.counts[0:len(self.index), :] + + return self + + def colsum(self, col): + return self.counts[:, col].sum() diff --git a/gffquant/counters/region_counter.py b/gffquant/counters/region_counter.py index 7a617056..ffab7718 100644 --- a/gffquant/counters/region_counter.py +++ b/gffquant/counters/region_counter.py @@ -8,6 +8,21 @@ from .alignment_counter import AlignmentCounter +# from count_manager.get_counts() +# if region_counts: +# raise NotImplementedError() +# rid, seqid = seqid[0], seqid[1:] + +# uniq_counter = self.uniq_regioncounts.get(rid, Counter()) +# ambig_counter = self.ambig_regioncounts.get(rid, Counter()) + +# # pylint: disable=R1720 +# if strand_specific: +# raise NotImplementedError +# else: +# return [uniq_counter[seqid]], [ambig_counter[seqid]] + + class RegionCounter(AlignmentCounter): """This counter class can be used in overlap mode, i.e. when reads are aligned against long references (e.g. contigs) @@ -27,75 +42,10 @@ def _update_region(self, region_id, ostart, oend, rev_strand, cstart=None, cend= def update_counts(self, count_stream, increment=1): contributed_counts = 0 for hits, aln_count in count_stream: - inc = increment if aln_count == 1 else self.get_increment(aln_count, increment) + inc = increment if aln_count == 1 else AlignmentCounter.get_increment(aln_count, increment, self.distribution_mode) for hit in hits: self._update_region( hit.rid, hit.start, hit.end, hit.rev_strand, increment=inc, ) contributed_counts += inc return contributed_counts - - -class UniqueRegionCounter(RegionCounter): - """This counter class can be used in overlap mode, i.e. - when reads are aligned against long references (e.g. contigs) - with multiple regions of interest (features). - """ - - def __init__(self, distribution_mode=DistributionMode.ONE_OVER_N, strand_specific=False): - RegionCounter.__init__( - self, distribution_mode=distribution_mode, strand_specific=strand_specific, - ) - - # pylint: disable=W0613 - def update_counts(self, count_stream, increment=1): - """Update counter with alignments against the same reference. - - input: count_stream - - counts: set of overlaps with the reference - - aln_count: 1 if overlaps else 0 - - unaligned: 1 - aln_count - (redundant input due to streamlining uniq/ambig dataflows) - """ - for counts, aln_count, unaligned in count_stream: - if aln_count: - for rid, hits in counts.items(): - for hit in hits: - self._update_region( - rid, *hit, increment=increment - ) - else: - self.unannotated_reads += unaligned - - -class AmbiguousRegionCounter(RegionCounter): - """This counter class can be used in overlap mode, i.e. - when reads are aligned against long references (e.g. contigs) - with multiple regions of interest (features). - """ - - def __init__(self, distribution_mode=DistributionMode.ONE_OVER_N, strand_specific=False): - RegionCounter.__init__( - self, distribution_mode=distribution_mode, strand_specific=strand_specific, - ) - - # pylint: disable=W0613 - def update_counts(self, count_stream, increment=1): - """Update counter with alignments against the same reference. 
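# Sketch of the 1/N distribution that get_increment() applies: under
# DistributionMode.ONE_OVER_N, a read with n_aln ambiguous alignments
# contributes increment / n_aln per target, so its total stays at
# `increment`. The enum here is a stand-in for gffquant's DistributionMode
# (only ONE_OVER_N is taken from the source; OTHER is hypothetical).
from enum import Enum, auto

class DistributionMode(Enum):
    ONE_OVER_N = auto()
    OTHER = auto()

def get_increment(n_aln, increment, distribution_mode):
    return (increment, increment / n_aln)[distribution_mode == DistributionMode.ONE_OVER_N]

assert get_increment(4, 1.0, DistributionMode.ONE_OVER_N) == 0.25
assert get_increment(4, 1.0, DistributionMode.OTHER) == 1.0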
- - input: count_stream - - counts: set of overlaps with the reference - - aln_count: 1 if overlaps else 0 - - unaligned: 1 - aln_count - (redundant input due to streamlining uniq/ambig dataflows) - """ - for counts, aln_count, unaligned in count_stream: - if aln_count: - inc = self.get_increment(aln_count, increment) - for rid, hits in counts.items(): - for hit in hits: - self._update_region( - rid, *hit, increment=inc - ) - else: - self.unannotated_reads += unaligned diff --git a/gffquant/counters/seq_counter.py b/gffquant/counters/seq_counter.py deleted file mode 100644 index bc71c7fb..00000000 --- a/gffquant/counters/seq_counter.py +++ /dev/null @@ -1,63 +0,0 @@ -# pylint: disable=W0223 - -""" module docstring """ - -from .. import DistributionMode -from .alignment_counter import AlignmentCounter - - -class UniqueSeqCounter(AlignmentCounter): - def __init__(self, strand_specific=False): - AlignmentCounter.__init__(self, strand_specific=strand_specific) - - def get_counts(self, seq_ids): - """ - Given a list of sequence ids, return the total number of reads that mapped to each of those - sequences - - :param seq_ids: a list of sequence ids to count - :return: A list of counts for each sequence ID. - """ - if self.strand_specific: - return sum( - self[(seq_id, strand)] for seq_id in seq_ids for strand in (True, False) - ) - return sum(self[seq_id] for seq_id in seq_ids) - - def update_counts(self, count_stream, increment=1): - for counts, _, _ in count_stream: - - for rid, hits in counts.items(): - - if self.strand_specific: - strands = tuple(int(strand) for _, _, strand, _, _ in hits) - - self[(rid, True)] += sum(strands) * increment - self[(rid, False)] += (len(hits) - sum(strands)) * increment - - else: - self[rid] += len(hits) * increment - - -class AmbiguousSeqCounter(AlignmentCounter): - def __init__(self, strand_specific=False, distribution_mode=DistributionMode.ONE_OVER_N): - AlignmentCounter.__init__( - self, distribution_mode=distribution_mode, strand_specific=strand_specific - ) - - def update_counts(self, count_stream, increment=1): - - for counts, aln_count, _ in count_stream: - - inc = self.get_increment(aln_count, increment) - - for rid, hits in counts.items(): - - if self.strand_specific: - strands = tuple(int(strand) for _, _, strand, _, _ in hits) - - self[(rid, True)] += sum(strands) * inc - self[(rid, False)] += (len(hits) - sum(strands)) * inc - - else: - self[rid] += len(hits) * inc diff --git a/gffquant/profilers/feature_quantifier.py b/gffquant/profilers/feature_quantifier.py index bd5c5c96..53a5c2e6 100644 --- a/gffquant/profilers/feature_quantifier.py +++ b/gffquant/profilers/feature_quantifier.py @@ -10,12 +10,14 @@ from abc import ABC from collections import Counter -from dataclasses import dataclass, asdict + +import numpy as np from .panda_coverage_profiler import PandaCoverageProfiler -from ..alignment import AlignmentGroup, AlignmentProcessor, SamFlags +from ..alignment import AlignmentGroup, AlignmentProcessor, ReferenceHit, SamFlags from ..annotation import GeneCountAnnotator, RegionCountAnnotator, CountWriter -from ..counters import CountManager +from ..counters import AlignmentCounter +from ..counters.count_matrix import CountMatrix from ..db.annotation_db import AnnotationDatabaseManager from .. 
import __tool__, DistributionMode, RunMode @@ -24,39 +26,6 @@ logger = logging.getLogger(__name__) -@dataclass(slots=True) -class ReferenceHit: - rid: int = None - start: int = None - end: int = None - rev_strand: bool = None - cov_start: int = None - cov_end: int = None - has_annotation: bool = None - n_aln: int = None - is_ambiguous: bool = None - library_mod: int = None - mate_id: int = None - - def __hash__(self): - return hash(tuple(asdict(self).values())) - - def __eq__(self, other): - return all( - item[0][1] == item[1][1] - for item in zip( - sorted(asdict(self).items()), - sorted(asdict(other).items()) - ) - ) - - def __str__(self): - return "\t".join(map(str, asdict(self).values())) - - def __repr__(self): - return str(self) - - class FeatureQuantifier(ABC): """ Three groups of alignments: @@ -93,17 +62,15 @@ def __init__( self.db = db self.adm = None self.run_mode = run_mode - self.count_manager = CountManager( + self.counter = AlignmentCounter( distribution_mode=distribution_mode, - region_counts=run_mode.overlap_required, - strand_specific=strand_specific and not run_mode.overlap_required, + strand_specific=strand_specific, paired_end_count=paired_end_count, ) self.out_prefix = out_prefix self.distribution_mode = distribution_mode self.reference_manager = {} self.strand_specific = strand_specific - # self.coverage_counter = {} self.debug = debug self.panda_cv = PandaCoverageProfiler(dump_dataframes=self.debug) if calculate_coverage else None @@ -158,17 +125,16 @@ def process_counters( self.adm = AnnotationDatabaseManager.from_db(self.db, in_memory=in_memory) if dump_counters: - self.count_manager.dump_raw_counters(self.out_prefix, self.reference_manager) + self.counter.dump(self.out_prefix, self.reference_manager,) report_scaling_factors = restrict_reports is None or "scaled" in restrict_reports Annotator = (GeneCountAnnotator, RegionCountAnnotator)[self.run_mode.overlap_required] count_annotator = Annotator(self.strand_specific, report_scaling_factors=report_scaling_factors) - count_annotator.annotate(self.reference_manager, self.adm, self.count_manager, gene_group_db=gene_group_db,) count_writer = CountWriter( self.out_prefix, - has_ambig_counts=self.count_manager.has_ambig_counts(), + has_ambig_counts=self.counter.has_ambig_counts(), strand_specific=self.strand_specific, restrict_reports=restrict_reports, report_category=report_category, @@ -176,21 +142,126 @@ def process_counters( filtered_readcount=self.aln_counter["filtered_read_count"], ) - unannotated_reads = self.count_manager.get_unannotated_reads() - unannotated_reads += self.aln_counter["unannotated_ambig"] + total_gene_counts = self.counter.generate_gene_count_matrix(self.reference_manager) + logger.info("TOTAL_GENE_COUNTS = %s", total_gene_counts) - count_writer.write_feature_counts( - self.adm, - count_annotator, - (None, unannotated_reads)[report_unannotated], + count_writer.write_gene_counts( + self.counter, + self.reference_manager, + gene_group_db=gene_group_db, ) - count_writer.write_gene_counts( - count_annotator.gene_counts, - count_annotator.scaling_factors["total_gene_uniq"], - count_annotator.scaling_factors["total_gene_ambi"] + ggroups = tuple( + (self.reference_manager.get(key[0] if isinstance(key, tuple) else key))[0] # .split(".")[0] + for key, _ in self.counter ) + + self.counter.counts.dump(labels=ggroups) + + self.counter.group_gene_count_matrix(self.reference_manager) + unannotated_reads = self.counter.get_unannotated_reads() + self.aln_counter["unannotated_ambig"] + + 
self.counter.counts.dump(state="ggroup") + + # categories = self.adm.get_categories() + + # for category in categories: + # logger.info("PROCESSING CATEGORY=%s", category.name) + # category_sum = np.zeros(6, dtype='float64') + # category_counts = CountMatrix(ncols=6) + # for rid, counts in self.counter: + # if gene_group_db: + # ggroup_id = rid + # logger.info("GGROUP %s: %s", ggroup_id, str(counts)) + # else: + # ref, _ = self.reference_manager.get(rid[0] if isinstance(rid, tuple) else rid) + # ggroup_id = ref + + # region_annotation = self.adm.query_sequence(ggroup_id) + # if region_annotation is not None: + # _, _, region_annotation = region_annotation + # for category_id, features in region_annotation: + # if int(category_id) == category.id: + # category_sum += counts + # for feature_id in features: + # category_counts[(category.id, int(feature_id))] += counts + # break + + # u_sf, c_sf = ( + # CountMatrix.calculate_scaling_factor(*category_sum[0:2]), + # CountMatrix.calculate_scaling_factor(*category_sum[3:5]), + # ) + + # category_counts.scale_column(1, u_sf) + # category_counts.scale_column(4, c_sf) + + # category_sum[2] = category_sum[1] / u_sf + # category_sum[5] = category_sum[4] / c_sf + + # features = tuple(self.adm.get_features(category.id)) + # count_writer.write_category( + # category.id, + # category.name, + # category_sum, + # category_counts, + # features, + # unannotated_reads=(None, unannotated_reads)[report_unannotated], + # ) + + + + + + functional_counts, category_sums = count_annotator.annotate_gene_counts( + self.reference_manager, + self.adm, + self.counter, + gene_group_db=gene_group_db, + ) + + logger.info("FC-index: %s", str(list(functional_counts.index.keys())[:10])) + logger.info("FC-counts: %s", str(functional_counts.counts[0:10, :])) + + categories = self.adm.get_categories() + for category, category_sum in zip(categories, category_sums): + features = tuple(self.adm.get_features(category.id)) + feature_names = { + feature.id: feature.name + for feature in features + } + rows = tuple( + key[0] == category.id + for key, _ in functional_counts + ) + + cat_counts = CountMatrix.from_count_matrix(functional_counts, rows=rows) + # cat_counts = CountMatrix(ncols=6, nrows=len(feature_names)) + # for feature in features: + # key = (category.id, feature.id) + # if functional_counts.has_record(key): + # cat_counts[key] += functional_counts[key] + # else: + # _ = cat_counts[key] + + # for category in categories: + # features = ((feature.name, feature) for feature in db.get_features(category.id)) + # for _, feature in sorted(features, key=lambda x: x[0]): + # _ = functional_counts[(category.id, feature.id)] + + + logger.info("PROCESSING CATEGORY=%s", category.name) + count_writer.write_category( + category.id, + category.name, + category_sum, + # functional_counts, + cat_counts, + # feature_names, + features, + unannotated_reads=(None, unannotated_reads)[report_unannotated], + ) + self.adm.clear_caches() def register_reference(self, rid, aln_reader): @@ -226,7 +297,7 @@ def process_alignments( filtered_sam=debug_samfile, ) - self.count_manager.toggle_single_read_handling(unmarked_orphans) + self.counter.toggle_single_read_handling(unmarked_orphans) ac = self.aln_counter read_count = 0 @@ -447,7 +518,7 @@ def process_alignment_group(self, aln_group, aln_reader): ) ) - contributed_counts = self.count_manager.update_counts( + contributed_counts = self.counter.update( count_stream, ambiguous_counts=is_ambiguous_group, pair=aln_group.is_paired(), diff --git 
a/gffquant/profilers/panda_profiler.py b/gffquant/profilers/panda_profiler.py index 53dacee6..4a5ee057 100644 --- a/gffquant/profilers/panda_profiler.py +++ b/gffquant/profilers/panda_profiler.py @@ -28,7 +28,6 @@ def __init__( self._buffer_size = 0 self._max_buffer_size = 400_000_000 - def get_gene_coords(self): if self.with_overlap: for rid, start, end in zip( @@ -283,7 +282,6 @@ def add_records(self, hits, last_update=False): self._buffer += hits self._buffer_size += hits_size - def merge_dataframes(self): print("BUFFER:", len(self._buffer), self._buffer[:1]) hits_df = pd.DataFrame(self._buffer) @@ -319,7 +317,6 @@ def merge_dataframes(self): .groupby(by=self.index_columns, as_index=False) \ .sum(numeric_only=True) - def add_records_old(self, hits): # [2024-02-08 14:51:17,846] count_stream:
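# End-to-end sketch of the CountMatrix-based flow this diff introduces:
# two columns (unique, ambiguous) keyed by reference id grow on demand,
# then generate_gene_counts() expands them into the six-column
# uniq_raw/uniq_lnorm/uniq_scaled/comb_raw/comb_lnorm/comb_scaled layout.
# The rids and gene lengths are made up; assumes gffquant is importable.
import numpy as np

from gffquant.counters.count_matrix import CountMatrix

cm = CountMatrix(ncols=2)
cm[0][0] += 2.0                      # rid 0: two unique alignments
cm[0][1] += 0.5                      # rid 0: fractional ambiguous count
cm[1][0] += 1.0                      # rid 1: one unique alignment

lengths = np.array([200.0, 100.0])   # reference lengths for rid 0, rid 1
cm.generate_gene_counts(lengths)     # trims, combines, length-normalises, scales
for rid, row in cm:
    print(rid, *row, sep="\t")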