diff --git a/.github/workflows/pre-commit.yml b/.github/workflows/pre-commit.yml new file mode 100644 index 0000000..5da8d71 --- /dev/null +++ b/.github/workflows/pre-commit.yml @@ -0,0 +1,43 @@ +name: pre-commit + +on: + pull_request: + push: + branches: [main] + +jobs: + pre-commit: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up Python + id: setup-python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: "pip" + + - uses: actions/cache@v4 + with: + path: ~/.cache/pre-commit + key: > + ${{ format('pre-commit-{0}-{1}', + steps.setup-python.outputs.python-version, + hashFiles('.pre-commit-config.yaml') + ) }} + + - name: Install pre-commit + run: | + pip install --upgrade pip + pip install pre-commit + pre-commit install + + - name: Run pre-commit hooks + working-directory: ${{ inputs.working-directory }} + run: | + git ls-files | xargs pre-commit run \ + --show-diff-on-failure \ + --color=always \ + --files diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..f35a8aa --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,13 @@ +repos: + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 + hooks: + - id: trailing-whitespace + - id: end-of-file-fixer + - id: check-yaml + - id: check-added-large-files + - repo: https://github.com/google/pyink + rev: 24.10.1 + hooks: + - id: pyink + args: ["--pyink-indentation=2", "--line-length=80"] diff --git a/docs/plugins/main.py b/docs/plugins/main.py index 4199ce2..c3982e5 100644 --- a/docs/plugins/main.py +++ b/docs/plugins/main.py @@ -8,7 +8,9 @@ from mkdocs.structure.pages import Page -def on_page_markdown(markdown: str, page: Page, config: Config, files: Files) -> str: +def on_page_markdown( + markdown: str, page: Page, config: Config, files: Files +) -> str: """Called on each file after it is read and before it is converted to HTML.""" markdown = eastr_print_help(markdown, page) return markdown diff --git a/mkdocs.yml b/mkdocs.yml index 3bcfe84..d53319d 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -10,7 +10,7 @@ theme: # Palette toggle for light mode - media: "(prefers-color-scheme: light)" - scheme: default + scheme: default toggle: icon: material/brightness-7 name: Switch to dark mode diff --git a/setup.py b/setup.py index 6da46ad..60240ef 100644 --- a/setup.py +++ b/setup.py @@ -7,22 +7,27 @@ from setuptools import find_packages, setup, Extension from setuptools.command.build_ext import build_ext + class CMakeExtension(Extension): - def __init__(self, name, sourcedir=''): + + def __init__(self, name, sourcedir=""): Extension.__init__(self, name, sources=[]) self.sourcedir = os.path.abspath(sourcedir) + class CMakeBuild(build_ext): """CMake build extension.""" + def run(self): - extensions = ', '.join(e.name for e in self.extensions) - deps = ['cmake', 'make'] + extensions = ", ".join(e.name for e in self.extensions) + deps = ["cmake", "make"] for dep in deps: try: - subprocess.check_output([dep, '--version']) + subprocess.check_output([dep, "--version"]) except OSError as e: raise RuntimeError( - f'{dep} must be installed to build the following extensions: {extensions}') from e + f"{dep} must be installed to build the following extensions: {extensions}" + ) from e for ext in self.extensions: self.build_extension(ext) @@ -30,67 +35,70 @@ def run(self): def build_extension(self, ext): extdir = os.path.abspath(os.path.dirname(self.get_ext_fullpath(ext.name))) extdir = os.path.join(extdir, ext.name) - cmake_args = ['-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=' + extdir] - cmake_args += ['-DEXECUTABLE_OUTPUT_PATH=' + extdir] + cmake_args = ["-DCMAKE_LIBRARY_OUTPUT_DIRECTORY=" + extdir] + cmake_args += ["-DEXECUTABLE_OUTPUT_PATH=" + extdir] - cfg = 'Debug' if self.debug else 'Release' - build_args = ['--config', cfg] + cfg = "Debug" if self.debug else "Release" + build_args = ["--config", cfg] - if platform.system() == 'Windows': - cmake_args += ['-DCMAKE_GENERATOR_PLATFORM=x64'] - cmake_args += ['-DCMAKE_BUILD_TYPE=' + cfg] - build_args += ['--', '/m'] + if platform.system() == "Windows": + cmake_args += ["-DCMAKE_GENERATOR_PLATFORM=x64"] + cmake_args += ["-DCMAKE_BUILD_TYPE=" + cfg] + build_args += ["--", "/m"] else: - cmake_args += ['-DCMAKE_BUILD_TYPE=Release .' + cfg] + cmake_args += ["-DCMAKE_BUILD_TYPE=Release ." + cfg] env = os.environ.copy() - env['CXXFLAGS'] = '{} -DVERSION_INFO=\"{}\"'.format( - env.get('CXXFLAGS', ''), - self.distribution.get_version()) + env["CXXFLAGS"] = '{} -DVERSION_INFO="{}"'.format( + env.get("CXXFLAGS", ""), self.distribution.get_version() + ) if not os.path.exists(self.build_temp): os.makedirs(self.build_temp) - subprocess.check_call(['cmake', ext.sourcedir] + cmake_args, - cwd=self.build_temp, env=env) - subprocess.check_call(['cmake', '--build', '.', *build_args], - cwd=self.build_temp) + subprocess.check_call( + ["cmake", ext.sourcedir] + cmake_args, cwd=self.build_temp, env=env + ) + subprocess.check_call( + ["cmake", "--build", ".", *build_args], cwd=self.build_temp + ) + -desc = 'Tool for emending alignments of spuriously spliced transcript reads' +desc = "Tool for emending alignments of spuriously spliced transcript reads" -with open('./README.md', 'r', encoding='utf-8') as fh: +with open("./README.md", "r", encoding="utf-8") as fh: long_description = fh.read() -with open('./LICENSE', 'r', encoding='utf-8') as fh: +with open("./LICENSE", "r", encoding="utf-8") as fh: license_str = fh.read() setup( - name='eastr', - version='1.1.2', - author='Ida Shinder', - author_email='ishinde1@jhmi.edu', + name="eastr", + version="1.1.2", + author="Ida Shinder", + author_email="ishinde1@jhmi.edu", description=desc, long_description=long_description, - long_description_content_type='text/markdown', - url='https://github.com/ishinder/EASTR', + long_description_content_type="text/markdown", + url="https://github.com/ishinder/EASTR", cmdclass=dict(build_ext=CMakeBuild), - ext_modules=[CMakeExtension('eastr', 'utils')], + ext_modules=[CMakeExtension("eastr", "utils")], install_requires=[ - 'biopython>=1.81,<2.0', - 'mappy>=2.26,<3.0', - 'numpy>=1.26.1', - 'pandas>=2.1.2,<2.3', - 'pysam>=0.22.0,<0.23', + "biopython>=1.81,<2.0", + "mappy>=2.26,<3.0", + "numpy>=1.26.1", + "pandas>=2.1.2,<2.3", + "pysam>=0.22.0,<0.23", ], packages=find_packages( - where='src', - include=['eastr'], + where="src", + include=["eastr"], ), - package_dir={'': 'src'}, + package_dir={"": "src"}, entry_points={ - 'console_scripts': [ - 'eastr = eastr.run_eastr:main', - ] + "console_scripts": [ + "eastr = eastr.run_eastr:main", + ] }, - python_requires='>=3.10', - license='MIT', - license_files=('LICENSE',), + python_requires=">=3.10", + license="MIT", + license_files=("LICENSE",), ) diff --git a/src/eastr/alignment_utils.py b/src/eastr/alignment_utils.py index a9da037..77e2991 100644 --- a/src/eastr/alignment_utils.py +++ b/src/eastr/alignment_utils.py @@ -6,91 +6,100 @@ import mappy as mp -def get_seq(chrom,start,end,pysam_fa): - seq = pysam_fa.fetch(region=chrom,start=start,end=end) - return seq +def get_seq(chrom, start, end, pysam_fa): + seq = pysam_fa.fetch(region=chrom, start=start, end=end) + return seq -def align_seq_pair(rseq:str, qseq:str, scoring:list, k:int, w:int, m:int, best_n=1): - #TODO ambiguous bases not working - a = mp.Aligner(seq=rseq,k=k,w=w,best_n=best_n,scoring=scoring,min_chain_score=m) - itr = list(a.map(qseq,MD=True,cs=True)) - if best_n > 1: - return itr +def align_seq_pair( + rseq: str, qseq: str, scoring: list, k: int, w: int, m: int, best_n=1 +): + # TODO ambiguous bases not working + a = mp.Aligner( + seq=rseq, k=k, w=w, best_n=best_n, scoring=scoring, min_chain_score=m + ) + itr = list(a.map(qseq, MD=True, cs=True)) - if not itr: - return + if best_n > 1: + return itr - for hit in itr: - #TODO if hit.r_st==intron_len <- look for additional alignments? - #or it is not needed because the longest alignment is the "best" one? + if not itr: + return - if hit.strand != 1: - continue + for hit in itr: + # TODO if hit.r_st==intron_len <- look for additional alignments? + # or it is not needed because the longest alignment is the "best" one? - if not hit.is_primary: - continue + if hit.strand != 1: + continue - else: - return hit #returns the first primary hit + if not hit.is_primary: + continue -def calc_alignment_score(hit,scoring): - if hit is None: - return None + else: + return hit # returns the first primary hit - # TODO verify alignment_score calc - matches = hit.mlen - gap_penalty = 0 - cs = hit.cs - #gaps - p = re.compile('[\\-\\+]([atgc]+)') - m = p.findall(cs) - gaps = len(m) - for gap in m: - gap_len = len(gap) - gap_penalty += min(scoring[2] + (gap_len - 1) * scoring[3], - scoring[4] + (gap_len - 1) * scoring[5]) +def calc_alignment_score(hit, scoring): + if hit is None: + return None - #mismatches - p = re.compile('\\*([atgc]+)') - m = p.findall(cs) - mismatches = len(m) + # TODO verify alignment_score calc + matches = hit.mlen + gap_penalty = 0 + cs = hit.cs - alignment_score = matches*scoring[0] - (mismatches)*scoring[1] - gap_penalty + # gaps + p = re.compile("[\\-\\+]([atgc]+)") + m = p.findall(cs) + for gap in m: + gap_len = len(gap) + gap_penalty += min( + scoring[2] + (gap_len - 1) * scoring[3], + scoring[4] + (gap_len - 1) * scoring[5], + ) - return alignment_score + # mismatches + p = re.compile("\\*([atgc]+)") + m = p.findall(cs) + mismatches = len(m) -def get_alignment(chrom, jstart, jend, overhang, pysam_fa, max_length, scoring, k, w, m): + alignment_score = ( + matches * scoring[0] - (mismatches) * scoring[1] - gap_penalty + ) + return alignment_score - intron_len = jend - jstart - rstart = max(jstart - overhang, 0) - rend = min(jstart + overhang, max_length) - qstart = max(jend - overhang, 0) - qend = min(jend + overhang, max_length) +def get_alignment( + chrom, jstart, jend, overhang, pysam_fa, max_length, scoring, k, w, m +): + intron_len = jend - jstart - rseq = get_seq(chrom, rstart, rend, pysam_fa) - qseq = get_seq(chrom, qstart, qend, pysam_fa ) - hit = align_seq_pair(rseq, qseq, scoring,k,w,m) + rstart = max(jstart - overhang, 0) + rend = min(jstart + overhang, max_length) + qstart = max(jend - overhang, 0) + qend = min(jend + overhang, max_length) - if hit: - #check if the alignment is in the overlap region of short introns - if overhang * 2 >= intron_len: + rseq = get_seq(chrom, rstart, rend, pysam_fa) + qseq = get_seq(chrom, qstart, qend, pysam_fa) + hit = align_seq_pair(rseq, qseq, scoring, k, w, m) - if hit.r_st == intron_len: - return None + if hit: + # check if the alignment is in the overlap region of short introns + if overhang * 2 >= intron_len: + if hit.r_st == intron_len: + return None - if overhang > intron_len: - if hit.r_st - hit.q_st == intron_len: - return None + if overhang > intron_len: + if hit.r_st - hit.q_st == intron_len: + return None - if hit.r_st >= overhang and hit.q_en <= overhang: - return None + if hit.r_st >= overhang and hit.q_en <= overhang: + return None - return hit + return hit # def get_self_alignment(chrom, start, end, max_length, scoring, k, w, m, pysam_fa, overhang): @@ -98,57 +107,59 @@ def get_alignment(chrom, jstart, jend, overhang, pysam_fa, max_length, scoring, # return rseq,qseq,hit -def get_flanking_subsequences(introns,chrom_sizes,overhang,ref_fa): - tmp_regions = tempfile.NamedTemporaryFile(mode='a',dir=os.getcwd(),delete=False) - tmp_fa = tempfile.NamedTemporaryFile(dir=os.getcwd(),delete=False) - - seen = set() - for key in list(introns.keys()): - chrom = key[0] - jstart = key[1] - jend = key[2] - max_length = chrom_sizes[chrom] - rstart = max(jstart - overhang + 1, 1) #1 based - rend = min(jstart + overhang, max_length) - qstart = max(jend - overhang + 1, 1) #1 based - qend = min(jend + overhang, max_length) - r1 = f'{chrom}:{rstart}-{rend}' - r2 = f'{chrom}:{qstart}-{qend}' - introns[key]['jstart'] = r1 - introns[key]['jend'] = r2 - - if rstart> max_length or qstart > max_length: - #remove key from introns dict - del introns[key] - print(f'Warning: intron {key} from the GTF file is out of range in FASTA file. Skipping...') - continue - - for r in [r1,r2]: - if r not in seen: - t=tmp_regions.write(f'{r}\n') - seen.add(r) - - - tmp_regions.close() - tmp_fa.close() - - cmd1 = f"samtools faidx {ref_fa} -r {tmp_regions.name} -o {tmp_fa.name}" - p1 = subprocess.run(shlex.split(cmd1), check=True) - - - seqs = {} - with open(tmp_fa.name,'r') as f: - for line in f: - line = line.strip() - if line.startswith(">"): - seq_name = line[1:] - if seq_name not in seqs: - seqs[seq_name] = '' - continue - sequence = line - seqs[seq_name]=seqs[seq_name] + sequence - - os.unlink(tmp_regions.name) - os.unlink(tmp_fa.name) - - return seqs +def get_flanking_subsequences(introns, chrom_sizes, overhang, ref_fa): + tmp_regions = tempfile.NamedTemporaryFile( + mode="a", dir=os.getcwd(), delete=False + ) + tmp_fa = tempfile.NamedTemporaryFile(dir=os.getcwd(), delete=False) + + seen = set() + for key in list(introns.keys()): + chrom = key[0] + jstart = key[1] + jend = key[2] + max_length = chrom_sizes[chrom] + rstart = max(jstart - overhang + 1, 1) # 1 based + rend = min(jstart + overhang, max_length) + qstart = max(jend - overhang + 1, 1) # 1 based + qend = min(jend + overhang, max_length) + r1 = f"{chrom}:{rstart}-{rend}" + r2 = f"{chrom}:{qstart}-{qend}" + introns[key]["jstart"] = r1 + introns[key]["jend"] = r2 + + if rstart > max_length or qstart > max_length: + # remove key from introns dict + del introns[key] + print( + f"Warning: intron {key} from the GTF file is out of range in FASTA file. Skipping..." + ) + continue + + for r in [r1, r2]: + if r not in seen: + tmp_regions.write(f"{r}\n") + seen.add(r) + + tmp_regions.close() + tmp_fa.close() + + cmd1 = f"samtools faidx {ref_fa} -r {tmp_regions.name} -o {tmp_fa.name}" + subprocess.run(shlex.split(cmd1), check=True) + + seqs = {} + with open(tmp_fa.name, "r") as f: + for line in f: + line = line.strip() + if line.startswith(">"): + seq_name = line[1:] + if seq_name not in seqs: + seqs[seq_name] = "" + continue + sequence = line + seqs[seq_name] = seqs[seq_name] + sequence + + os.unlink(tmp_regions.name) + os.unlink(tmp_fa.name) + + return seqs diff --git a/src/eastr/extract_junctions.py b/src/eastr/extract_junctions.py index 9eace1e..c7cdb60 100644 --- a/src/eastr/extract_junctions.py +++ b/src/eastr/extract_junctions.py @@ -9,141 +9,162 @@ this_directory = pathlib.Path(__file__).resolve().parent # This should exist with source after compilation. -JUNCTION_CMD = os.path.join(this_directory, 'junction_extractor') +JUNCTION_CMD = os.path.join(this_directory, "junction_extractor") def get_junctions_from_bed(bed_path: str) -> dict: - junctions = {} - with open(bed_path, 'r') as f: - for line in f: - line = line.strip() - if line.startswith('#') or line.startswith('track'): - continue - fields = line.split('\t') - if len(fields) >= 6: - chrom, start, end, name, score, strand = fields[:6] - else: - print(f"Error in file: {bed_path}") - print(f"Offending line: {line}") - raise ValueError("Invalid BED format: Expected at least 6 columns.") - - start, end = int(start), int(end) - score = int(score) - if start > end: - raise Exception("Start of region cannot be greater than end of region for:\n", line) - junctions[(chrom, start, end, strand)] = (name, score) - return junctions - -def get_junctions_multi_bed(bed_list:list, p) -> dict: - with multiprocessing.Pool(p) as pool: - results = pool.map(get_junctions_from_bed, bed_list) - - dd = collections.defaultdict(dict) - for i, d in enumerate(results): - name2 = os.path.splitext(os.path.basename(bed_list[i]))[0] - for key, (name, score) in d.items(): - if key not in dd: - dd[key]['samples']= set() - dd[key]['samples'].add((name, name2, score)) - dd[key]['score'] = score - else: - dd[key]['samples'].add((name, name2, score)) - dd[key]['score'] = dd[key]['score'] + score - - return dd - -def junction_extractor(bam_path:str, out_path:str) -> dict: - check_for_dependency() - name = os.path.splitext(os.path.basename(bam_path))[0] - cmd = f"{JUNCTION_CMD} -o {out_path} {bam_path}" - a = subprocess.Popen(shlex.split(cmd), stdout=subprocess.DEVNULL) - b = a.communicate() - - with open(out_path, 'r') as f: - first_line = f.readline().strip() - if 'track name=junctions' in first_line: - skip = 1 + junctions = {} + with open(bed_path, "r") as f: + for line in f: + line = line.strip() + if line.startswith("#") or line.startswith("track"): + continue + fields = line.split("\t") + if len(fields) >= 6: + chrom, start, end, name, score, strand = fields[:6] + else: + print(f"Error in file: {bed_path}") + print(f"Offending line: {line}") + raise ValueError("Invalid BED format: Expected at least 6 columns.") + + start, end = int(start), int(end) + score = int(score) + if start > end: + raise Exception( + "Start of region cannot be greater than end of region for:\n", line + ) + junctions[(chrom, start, end, strand)] = (name, score) + return junctions + + +def get_junctions_multi_bed(bed_list: list, p) -> dict: + with multiprocessing.Pool(p) as pool: + results = pool.map(get_junctions_from_bed, bed_list) + + dd = collections.defaultdict(dict) + for i, d in enumerate(results): + name2 = os.path.splitext(os.path.basename(bed_list[i]))[0] + for key, (name, score) in d.items(): + if key not in dd: + dd[key]["samples"] = set() + dd[key]["samples"].add((name, name2, score)) + dd[key]["score"] = score + else: + dd[key]["samples"].add((name, name2, score)) + dd[key]["score"] = dd[key]["score"] + score + + return dd + + +def junction_extractor(bam_path: str, out_path: str) -> dict: + check_for_dependency() + name = os.path.splitext(os.path.basename(bam_path))[0] + cmd = f"{JUNCTION_CMD} -o {out_path} {bam_path}" + a = subprocess.Popen(shlex.split(cmd), stdout=subprocess.DEVNULL) + a.communicate() + + with open(out_path, "r") as f: + first_line = f.readline().strip() + if "track name=junctions" in first_line: + skip = 1 + else: + skip = 0 + + df = pd.read_csv(out_path, sep="\t", header=None, skiprows=skip, comment="#") + + junctions = {} + for _, row in df.iterrows(): + junctions[(row[0], row[1], row[2], row[5])] = (name, row[4]) + return junctions + + +def junction_extractor_multi_bam( + bam_list: list, out_original_junctions: list, p: int +) -> dict: + with multiprocessing.Pool(p) as pool: + results = pool.starmap( + junction_extractor, zip(bam_list, out_original_junctions) + ) + + dd = collections.defaultdict(dict) + for d in results: + for key, value in d.items(): + if key not in dd: + dd[key]["samples"] = set() + dd[key]["samples"].add(value) + dd[key]["score"] = value[1] + else: + dd[key]["samples"].add(value) + dd[key]["score"] = dd[key]["score"] + value[1] + + return dd + + +def extract_splice_sites_gtf(gtf_path: str) -> dict: + trans = {} + gtf = open(gtf_path, "r") + + for line in gtf: + line = line.strip() + if line.startswith("#"): + continue + chrom, source, feature, start, end, score, strand, frame, attributes = ( + line.split("\t") + ) + + start, end = int(start), int(end) + + if feature != "exon": + continue + + if start > end: + raise Exception( + "Start of region can not be greater than end of region for:\n", line + ) + + values_dict = {} + for attr in attributes.split(";"): + if attr: + attr, _, val = attr.strip().partition(" ") + values_dict[attr] = val.strip('"') + + if "gene_id" not in values_dict: + values_dict["gene_id"] = "NA" + + if "transcript_id" not in values_dict: + raise Exception("Exon does not contain transcript ID\n") + + transcript_id = values_dict["transcript_id"] + gene_id = values_dict["gene_id"] + + if transcript_id not in trans: + trans[transcript_id] = [chrom, strand, gene_id, [[start, end]]] else: - skip = 0 + trans[transcript_id][3].append([start, end]) + + for tran, [chrom, strand, gene_id, exons] in trans.items(): + exons.sort() + + junctions = collections.defaultdict(dict) + for tran, (chrom, strand, gene_id, exons) in trans.items(): + for i in range(1, len(exons)): + if ( + "transcripts" + not in junctions[(chrom, exons[i - 1][1], exons[i][0] - 1, strand)] + ): + junctions[(chrom, exons[i - 1][1], exons[i][0] - 1, strand)][ + "transcripts" + ] = [gene_id, [tran]] + else: + junctions[(chrom, exons[i - 1][1], exons[i][0] - 1, strand)][ + "transcripts" + ][1].append( + tran + ) # intron bed coordinates + + return junctions - df = pd.read_csv(out_path, sep='\t', header=None, skiprows=skip, comment='#') - - junctions = {} - for _, row in df.iterrows(): - junctions[(row[0],row[1],row[2],row[5])] = (name,row[4]) - return junctions - -def junction_extractor_multi_bam(bam_list:list, out_original_junctions:list, p:int) -> dict: - with multiprocessing.Pool(p) as pool: - results = pool.starmap(junction_extractor, zip(bam_list,out_original_junctions)) - - dd = collections.defaultdict(dict) - for d in results: - for key, value in d.items(): - if key not in dd: - dd[key]['samples']= set() - dd[key]['samples'].add(value) - dd[key]['score'] = value[1] - else: - dd[key]['samples'].add(value) - dd[key]['score'] = dd[key]['score'] + value[1] - - return dd - -def extract_splice_sites_gtf(gtf_path:str) -> dict: - trans = {} - gtf = open(gtf_path, "r") - - for line in gtf: - line = line.strip() - if line.startswith('#'): - continue - chrom, source, feature, start, end, score, \ - strand, frame, attributes = line.split('\t') - - start, end = int(start), int(end) - - if feature != 'exon': - continue - - if start > end: - raise Exception("Start of region can not be greater than end of region for:\n",line) - - values_dict = {} - for attr in attributes.split(';'): - if attr: - attr, _, val = attr.strip().partition(' ') - values_dict[attr] = val.strip('"') - - if 'gene_id' not in values_dict: - values_dict['gene_id'] = "NA" - - if 'transcript_id' not in values_dict: - raise Exception("Exon does not contain transcript ID\n") - - transcript_id = values_dict['transcript_id'] - gene_id = values_dict['gene_id'] - - if transcript_id not in trans: - trans[transcript_id] = [chrom, strand, gene_id, [[start, end]]] - else: - trans[transcript_id][3].append([start, end]) - - - for tran, [chrom, strand, gene_id, exons] in trans.items(): - exons.sort() - - - junctions = collections.defaultdict(dict) - for tran, (chrom, strand, gene_id, exons) in trans.items(): - for i in range(1, len(exons)): - if 'transcripts' not in junctions[(chrom, exons[i-1][1], exons[i][0]-1, strand)]: - junctions[(chrom, exons[i-1][1], exons[i][0]-1, strand)]['transcripts'] = [gene_id, [tran]] - else: - junctions[(chrom, exons[i-1][1], exons[i][0]-1, strand)]['transcripts'][1].append(tran) #intron bed coordinates - - return junctions def check_for_dependency(): - if not os.path.exists(JUNCTION_CMD): - raise RuntimeError(f"{JUNCTION_CMD} not found.") + if not os.path.exists(JUNCTION_CMD): + raise RuntimeError(f"{JUNCTION_CMD} not found.") diff --git a/src/eastr/get_spurious_introns.py b/src/eastr/get_spurious_introns.py index a2a0a5b..a15c671 100644 --- a/src/eastr/get_spurious_introns.py +++ b/src/eastr/get_spurious_introns.py @@ -13,237 +13,301 @@ def get_self_aligned_introns(introns, seqs, overhang, k, w, m, scoring): - self_introns = {} - for key, value in introns.items(): - rseq = seqs[introns[key]['jstart']] - qseq = seqs[introns[key]['jend']] - hit = alignment_utils.align_seq_pair(rseq, qseq, scoring,k,w,m) - if hit: - intron_len = key[2] - key[1] - if overhang * 2 >= intron_len: - if hit.r_st == intron_len: - hit = None - elif overhang > intron_len: - if hit.r_st - hit.q_st == intron_len: - hit = None - elif hit.r_st >= overhang and hit.q_en <= overhang: - hit = None - if hit: - self_introns[key] = value - self_introns[key]['hit'] = hit - return self_introns + self_introns = {} + for key, value in introns.items(): + rseq = seqs[introns[key]["jstart"]] + qseq = seqs[introns[key]["jend"]] + hit = alignment_utils.align_seq_pair(rseq, qseq, scoring, k, w, m) + if hit: + intron_len = key[2] - key[1] + if overhang * 2 >= intron_len: + if hit.r_st == intron_len: + hit = None + elif overhang > intron_len: + if hit.r_st - hit.q_st == intron_len: + hit = None + elif hit.r_st >= overhang and hit.q_en <= overhang: + hit = None + if hit: + self_introns[key] = value + self_introns[key]["hit"] = hit + return self_introns def linear_distance(string1, string2): - if len(string1)!=len(string2): - raise ValueError("strings must be of equal length") - distance = 0 - for i in range(len(string1)): - if string1[i] != string2[i]: - distance += 1 - return distance - -def get_middle_seq(len_,seq): - if len(seq) > len_: - middle_start = (len(seq) - len_) // 2 - middle_end = middle_start + len_ - seq = seq[middle_start:middle_end] - return seq - - -def bowtie2_align_self_introns_to_ref (introns_to_align, seqs, bt2_index, overhang, p=1, len_=15, bt2_k=10): - bt2_k = bt2_k + 1 - tmp_sam = tempfile.NamedTemporaryFile(dir=os.getcwd(),delete=False) - tmp_fa = tempfile.NamedTemporaryFile(mode='a',dir=os.getcwd(),delete=False) - for key,value in introns_to_align.items(): - read_name = ','.join([str(i) for i in key]) - hit = value['hit'] - rseq = seqs[value['jstart']] - qseq = seqs[value['jend']] - seq1 = get_middle_seq(len_*2, rseq[hit.r_st:hit.r_en]) - seq2 = get_middle_seq(len_*2, qseq[hit.q_st:hit.q_en]) - seqh = rseq[overhang-len_:overhang] + qseq[overhang:overhang+len_] - - x = tmp_fa.write(f'>{read_name},seq1\n{seq1}\n' + \ - f'>{read_name},seq2\n{seq2}\n' + \ - f'>{read_name},seqh\n{seqh}\n') - tmp_fa.close() - tmp_sam.close() - - #TODO -R {bt2_k} -N 2? - cmd = f"bowtie2 -p {p} --end-to-end -k {bt2_k} -D 20 -R 5 -L 20 -N 1 -i S,1,0.50 -x {bt2_index} -f {tmp_fa.name} -S {tmp_sam.name}" - subprocess.run(shlex.split(cmd),stderr=subprocess.DEVNULL, check=True) - samfile = pysam.AlignmentFile(tmp_sam.name,'r') - - d = collections.defaultdict(list) - for alignment in samfile.fetch(until_eof=True): - qname= alignment.qname.split(',') - qname[1] = int(qname[1]) - qname[2] = int(qname[2]) - qname = tuple(qname) - d[qname].append(alignment) - - if qname[4] == 'seqh': - if alignment.is_unmapped: - introns_to_align[qname[0:4]][qname[4]] = 0 - continue - - if qname[4] not in introns_to_align[qname[0:4]]: - introns_to_align[qname[0:4]][qname[4]] = 1 - - else: - introns_to_align[qname[0:4]][qname[4]] += 1 - - os.unlink(tmp_sam.name) - os.unlink(tmp_fa.name) - - return d #TODO return? - -def is_two_anchor_alignment(hit,overhang,anchor): - c1 = (hit.r_en < overhang + anchor - 1) - c2 = (hit.r_st > overhang - anchor) - c3 = (hit.q_en < overhang + anchor - 1) - c4 = (hit.q_st > overhang - anchor) - c5 = (abs(hit.r_st - hit.q_st) > anchor * 2) - if (c1 or c2 or c3 or c4 or c5): - return False - else: - return True - - -def is_spurious_alignment(key, value, seqs, overhang, min_duplicate_exon_length, bt2_k=10, anchor=7): - is_two_anchor = is_two_anchor_alignment(value['hit'],overhang,anchor) - hit = value['hit'] - - if is_two_anchor: - #check unique alignment - if (value['seq1'] == 1) or (value['seq2'] == 1): - if value['seqh'] == 0: - # print("unique alignment") - # print(key) - return False - + if len(string1) != len(string2): + raise ValueError("strings must be of equal length") + distance = 0 + for i in range(len(string1)): + if string1[i] != string2[i]: + distance += 1 + return distance + + +def get_middle_seq(len_, seq): + if len(seq) > len_: + middle_start = (len(seq) - len_) // 2 + middle_end = middle_start + len_ + seq = seq[middle_start:middle_end] + return seq + + +def bowtie2_align_self_introns_to_ref( + introns_to_align, seqs, bt2_index, overhang, p=1, len_=15, bt2_k=10 +): + bt2_k = bt2_k + 1 + tmp_sam = tempfile.NamedTemporaryFile(dir=os.getcwd(), delete=False) + tmp_fa = tempfile.NamedTemporaryFile(mode="a", dir=os.getcwd(), delete=False) + for key, value in introns_to_align.items(): + read_name = ",".join([str(i) for i in key]) + hit = value["hit"] + rseq = seqs[value["jstart"]] + qseq = seqs[value["jend"]] + seq1 = get_middle_seq(len_ * 2, rseq[hit.r_st : hit.r_en]) + seq2 = get_middle_seq(len_ * 2, qseq[hit.q_st : hit.q_en]) + seqh = rseq[overhang - len_ : overhang] + qseq[overhang : overhang + len_] + + tmp_fa.write( + f">{read_name},seq1\n{seq1}\n" + + f">{read_name},seq2\n{seq2}\n" + + f">{read_name},seqh\n{seqh}\n" + ) + tmp_fa.close() + tmp_sam.close() + + # TODO -R {bt2_k} -N 2? + cmd = f"bowtie2 -p {p} --end-to-end -k {bt2_k} -D 20 -R 5 -L 20 -N 1 -i S,1,0.50 -x {bt2_index} -f {tmp_fa.name} -S {tmp_sam.name}" + subprocess.run(shlex.split(cmd), stderr=subprocess.DEVNULL, check=True) + samfile = pysam.AlignmentFile(tmp_sam.name, "r") + + d = collections.defaultdict(list) + for alignment in samfile.fetch(until_eof=True): + qname = alignment.qname.split(",") + qname[1] = int(qname[1]) + qname[2] = int(qname[2]) + qname = tuple(qname) + d[qname].append(alignment) + + if qname[4] == "seqh": + if alignment.is_unmapped: + introns_to_align[qname[0:4]][qname[4]] = 0 + continue + + if qname[4] not in introns_to_align[qname[0:4]]: + introns_to_align[qname[0:4]][qname[4]] = 1 else: - #if duplicated exon: - if hit.q_st - hit.r_st >= min_duplicate_exon_length: - if value['seqh'] == 0: - # print("duplicated exon") - # print(key) - return False - - - if (value['seq1'] < bt2_k ) or (value['seq2'] < bt2_k): - if value['seqh'] == 0: - # print("partial alignment - no hybrid sequence found") - # print(key) - return False - - rseq = seqs[value['jstart']] - qseq = seqs[value['jend']] - e5 = rseq[overhang - anchor:overhang] - i5 = rseq[overhang:overhang + anchor] - i3 = qseq[overhang - anchor:overhang] - e3 = qseq[overhang:overhang + anchor] - distance_e5i3 = linear_distance(e5, i3) - distance_i5e3 = linear_distance(i5, e3) - if distance_e5i3 + distance_i5e3 > 2: - # print("partial alignment - no overhang") - # print(key) - return False - else: - # print("partial alignment - has overhang") - # print(key) - return True - - return True - - -def get_spurious_introns(self_introns, seqs, bt2_index, overhang, min_duplicate_exon_length, - anchor=7, min_junc_score=1, p=1, is_bam=True, bt2_k=10): - introns_to_align = {} - spurious = {} - for k,v in self_introns.items(): - - if is_bam and v['score'] <= min_junc_score: - spurious[k] = v - continue - - introns_to_align[k] = v - - bw2_alignments = bowtie2_align_self_introns_to_ref(introns_to_align, seqs, bt2_index, overhang, p=p, len_=15, bt2_k=bt2_k) + introns_to_align[qname[0:4]][qname[4]] += 1 - for k, v in introns_to_align.items(): - if is_spurious_alignment(k, v, seqs, overhang, min_duplicate_exon_length, anchor=anchor): - spurious[k] = v + os.unlink(tmp_sam.name) + os.unlink(tmp_fa.name) - return spurious + return d # TODO return? -def get_spurious_junctions(scoring, k, w, m, overhang, min_duplicate_exon_length, bt2_index, bt2_k, ref_fa, p, anchor, - min_junc_score, bam_list, gtf_path, - bed_path, trusted_bed, out_original_junctions, verbose): - chrom_sizes = utils.get_chroms_list_from_fasta(ref_fa) - - if bam_list: - is_bam = True - - #make tmp files if not given out_original_junctions - remove_tmp=False - utils.make_dir('tmp') - if out_original_junctions is None: - remove_tmp = True - out_original_junctions = [tempfile.NamedTemporaryFile(delete=False, dir='tmp', suffix='.bed') for i in range(len(bam_list))] - #close all the temporary files in out_original_junctions: - for f in out_original_junctions: - f.close() - out_original_junctions = [f.name for f in out_original_junctions] - - - if verbose: - start_extr = time.time() - print('extracting junctions from bam files...') - - introns = extract_junctions.junction_extractor_multi_bam(bam_list,out_original_junctions,p) +def is_two_anchor_alignment(hit, overhang, anchor): + c1 = hit.r_en < overhang + anchor - 1 + c2 = hit.r_st > overhang - anchor + c3 = hit.q_en < overhang + anchor - 1 + c4 = hit.q_st > overhang - anchor + c5 = abs(hit.r_st - hit.q_st) > anchor * 2 + if c1 or c2 or c3 or c4 or c5: + return False + else: + return True - if verbose: - end_extr = time.time() - print('extracting junctions took {} seconds'.format(end_extr-start_extr)) - #delete all the temporary files in out_original_junctions: - if remove_tmp: - for f in out_original_junctions: - os.remove(f) +def is_spurious_alignment( + key, value, seqs, overhang, min_duplicate_exon_length, bt2_k=10, anchor=7 +): + is_two_anchor = is_two_anchor_alignment(value["hit"], overhang, anchor) + hit = value["hit"] - elif gtf_path: - is_bam = False - introns = extract_junctions.extract_splice_sites_gtf(gtf_path) + if is_two_anchor: + # check unique alignment + if (value["seq1"] == 1) or (value["seq2"] == 1): + if value["seqh"] == 0: + # print("unique alignment") + # print(key) + return False - elif bed_path: - is_bam = False - if len(bed_path) == 1: - introns = extract_junctions.get_junctions_multi_bed(bed_path,p) + else: + # if duplicated exon: + if hit.q_st - hit.r_st >= min_duplicate_exon_length: + if value["seqh"] == 0: + # print("duplicated exon") + # print(key) + return False + if (value["seq1"] < bt2_k) or (value["seq2"] < bt2_k): + if value["seqh"] == 0: + # print("partial alignment - no hybrid sequence found") + # print(key) + return False - else: - raise ValueError('No input file given') + rseq = seqs[value["jstart"]] + qseq = seqs[value["jend"]] + e5 = rseq[overhang - anchor : overhang] + i5 = rseq[overhang : overhang + anchor] + i3 = qseq[overhang - anchor : overhang] + e3 = qseq[overhang : overhang + anchor] + distance_e5i3 = linear_distance(e5, i3) + distance_i5e3 = linear_distance(i5, e3) + if distance_e5i3 + distance_i5e3 > 2: + # print("partial alignment - no overhang") + # print(key) + return False + else: + # print("partial alignment - has overhang") + # print(key) + return True - if trusted_bed: - trusted_introns = extract_junctions.get_junctions_from_bed(trusted_bed) - introns = {k:v for k,v in introns.items() if k not in trusted_introns} + return True + + +def get_spurious_introns( + self_introns, + seqs, + bt2_index, + overhang, + min_duplicate_exon_length, + anchor=7, + min_junc_score=1, + p=1, + is_bam=True, + bt2_k=10, +): + introns_to_align = {} + spurious = {} + for k, v in self_introns.items(): + if is_bam and v["score"] <= min_junc_score: + spurious[k] = v + continue + + introns_to_align[k] = v + + bowtie2_align_self_introns_to_ref( + introns_to_align, seqs, bt2_index, overhang, p=p, len_=15, bt2_k=bt2_k + ) + + for k, v in introns_to_align.items(): + if is_spurious_alignment( + k, v, seqs, overhang, min_duplicate_exon_length, anchor=anchor + ): + spurious[k] = v + + return spurious + + +def get_spurious_junctions( + scoring, + k, + w, + m, + overhang, + min_duplicate_exon_length, + bt2_index, + bt2_k, + ref_fa, + p, + anchor, + min_junc_score, + bam_list, + gtf_path, + bed_path, + trusted_bed, + out_original_junctions, + verbose, +): + chrom_sizes = utils.get_chroms_list_from_fasta(ref_fa) + + if bam_list: + is_bam = True + + # make tmp files if not given out_original_junctions + remove_tmp = False + utils.make_dir("tmp") + if out_original_junctions is None: + remove_tmp = True + out_original_junctions = [ + tempfile.NamedTemporaryFile(delete=False, dir="tmp", suffix=".bed") + for i in range(len(bam_list)) + ] + # close all the temporary files in out_original_junctions: + for f in out_original_junctions: + f.close() + out_original_junctions = [f.name for f in out_original_junctions] if verbose: - print('Getting spurious junctions...') - start_spur = time.time() + start_extr = time.time() + print("extracting junctions from bam files...") - seqs = alignment_utils.get_flanking_subsequences(introns, chrom_sizes, overhang, ref_fa) - self_introns = get_self_aligned_introns(introns, seqs, overhang, k, w, m, scoring) - spurious_dict = get_spurious_introns(self_introns, seqs, bt2_index, overhang, min_duplicate_exon_length, anchor=anchor, - min_junc_score=min_junc_score, p=p, is_bam=is_bam, bt2_k=bt2_k) - spurious_dict = dict(sorted(spurious_dict.items(), key=lambda x: (x[0][1], x[0][1], x[0][2], x[0][3]))) + introns = extract_junctions.junction_extractor_multi_bam( + bam_list, out_original_junctions, p + ) if verbose: - end_spur = time.time() - print('Getting spurious junctions took {} seconds'.format(end_spur - start_spur)) - - return spurious_dict + end_extr = time.time() + print( + "extracting junctions took {} seconds".format(end_extr - start_extr) + ) + + # delete all the temporary files in out_original_junctions: + if remove_tmp: + for f in out_original_junctions: + os.remove(f) + + elif gtf_path: + is_bam = False + introns = extract_junctions.extract_splice_sites_gtf(gtf_path) + + elif bed_path: + is_bam = False + if len(bed_path) == 1: + introns = extract_junctions.get_junctions_multi_bed(bed_path, p) + + else: + raise ValueError("No input file given") + + if trusted_bed: + trusted_introns = extract_junctions.get_junctions_from_bed(trusted_bed) + introns = {k: v for k, v in introns.items() if k not in trusted_introns} + + if verbose: + print("Getting spurious junctions...") + start_spur = time.time() + + seqs = alignment_utils.get_flanking_subsequences( + introns, chrom_sizes, overhang, ref_fa + ) + self_introns = get_self_aligned_introns( + introns, seqs, overhang, k, w, m, scoring + ) + spurious_dict = get_spurious_introns( + self_introns, + seqs, + bt2_index, + overhang, + min_duplicate_exon_length, + anchor=anchor, + min_junc_score=min_junc_score, + p=p, + is_bam=is_bam, + bt2_k=bt2_k, + ) + spurious_dict = dict( + sorted( + spurious_dict.items(), + key=lambda x: (x[0][1], x[0][1], x[0][2], x[0][3]), + ) + ) + + if verbose: + end_spur = time.time() + print( + "Getting spurious junctions took {} seconds".format( + end_spur - start_spur + ) + ) + + return spurious_dict diff --git a/src/eastr/output_utils.py b/src/eastr/output_utils.py index f2df037..06918a2 100644 --- a/src/eastr/output_utils.py +++ b/src/eastr/output_utils.py @@ -14,280 +14,368 @@ this_directory = pathlib.Path(__file__).resolve().parent # This should exist with source after compilation. -VACUUM_CMD = os.path.join(this_directory, 'vacuum') - -def out_junctions_filelist(bam_list:list, gtf_path, bed_list, out_junctions, suffix="") -> Union[List[str], None, str]: - if out_junctions is None: - return None - - if gtf_path: - if utils.check_directory_or_file(out_junctions) == 'dir': - out_junctions= out_junctions + "/" + os.path.splitext(os.path.basename(gtf_path)[0]) + suffix + ".bed" - return out_junctions - - if bed_list: - if suffix in ["_original_junctions", ""]: - return None - - if len(bed_list) == 1: - if utils.check_directory_or_file(out_junctions) == 'dir': - out_junctions = f"{out_junctions}/{os.path.splitext(os.path.basename(gtf_path))[0]}{suffix}.bed" - path = os.path.dirname(out_junctions) - utils.make_dir(path) - return [out_junctions] - - if utils.check_directory_or_file(out_junctions) == 'file': - print("ERROR: the path provided for the output bed files is a file path, not a directory") - sys.exit(1) - - utils.make_dir(out_junctions) - result = [] - for bed in bed_list: - result.append(f"{out_junctions}/{os.path.splitext(os.path.basename(bed))[0]}{suffix}.bed") - return result - - if len(bam_list) == 1: - if utils.check_directory_or_file(out_junctions) == 'dir': - out_junctions = f"{out_junctions}/{os.path.splitext(os.path.basename(bam_list[0]))[0]}{suffix}.bed" - path = os.path.dirname(out_junctions) - utils.make_dir(path) - return [out_junctions] - - if utils.check_directory_or_file(out_junctions) == 'file': - print("ERROR: the path provided for the output bed files is a file path, not a directory") - sys.exit(1) +VACUUM_CMD = os.path.join(this_directory, "vacuum") + + +def out_junctions_filelist( + bam_list: list, gtf_path, bed_list, out_junctions, suffix="" +) -> Union[List[str], None, str]: + if out_junctions is None: + return None + + if gtf_path: + if utils.check_directory_or_file(out_junctions) == "dir": + out_junctions = ( + out_junctions + + "/" + + os.path.splitext(os.path.basename(gtf_path)[0]) + + suffix + + ".bed" + ) + return out_junctions + + if bed_list: + if suffix in ["_original_junctions", ""]: + return None + + if len(bed_list) == 1: + if utils.check_directory_or_file(out_junctions) == "dir": + out_junctions = f"{out_junctions}/{os.path.splitext(os.path.basename(gtf_path))[0]}{suffix}.bed" + path = os.path.dirname(out_junctions) + utils.make_dir(path) + return [out_junctions] + + if utils.check_directory_or_file(out_junctions) == "file": + print( + "ERROR: the path provided for the output bed files is a file path, not a directory" + ) + sys.exit(1) utils.make_dir(out_junctions) result = [] - for bam in bam_list: - result.append(out_junctions + "/" + os.path.splitext(os.path.basename(bam))[0] + suffix + ".bed") + for bed in bed_list: + result.append( + f"{out_junctions}/{os.path.splitext(os.path.basename(bed))[0]}{suffix}.bed" + ) return result -def out_filtered_bam_filelist(bam_list:list, out_filtered_bam, suffix="_EASTR_filtered") -> Union[List[str], None]: - result = None - if bam_list is None or out_filtered_bam is None: - return - - if len(bam_list) == 1: - if utils.check_directory_or_file(out_filtered_bam) == 'dir': - out_filtered_bam = out_filtered_bam + "/" + os.path.splitext(os.path.basename(bam_list[0]))[0] + suffix + ".bam" - path = os.path.dirname(out_filtered_bam) - utils.make_dir(path) - result = [out_filtered_bam] - - else: - if utils.check_directory_or_file(out_filtered_bam) == 'file': - print("ERROR: the path provided for the output file is a file, not a directory") - sys.exit(1) - utils.make_dir(out_filtered_bam) - result = [] - for bam in bam_list: - result.append(out_filtered_bam + "/" + os.path.splitext(os.path.basename(bam))[0] + suffix + ".bam") + if len(bam_list) == 1: + if utils.check_directory_or_file(out_junctions) == "dir": + out_junctions = f"{out_junctions}/{os.path.splitext(os.path.basename(bam_list[0]))[0]}{suffix}.bed" + path = os.path.dirname(out_junctions) + utils.make_dir(path) + return [out_junctions] + + if utils.check_directory_or_file(out_junctions) == "file": + print( + "ERROR: the path provided for the output bed files is a file path, not a directory" + ) + sys.exit(1) + + utils.make_dir(out_junctions) + result = [] + for bam in bam_list: + result.append( + out_junctions + + "/" + + os.path.splitext(os.path.basename(bam))[0] + + suffix + + ".bed" + ) + return result + + +def out_filtered_bam_filelist( + bam_list: list, out_filtered_bam, suffix="_EASTR_filtered" +) -> Union[List[str], None]: + result = None + if bam_list is None or out_filtered_bam is None: + return + + if len(bam_list) == 1: + if utils.check_directory_or_file(out_filtered_bam) == "dir": + out_filtered_bam = ( + out_filtered_bam + + "/" + + os.path.splitext(os.path.basename(bam_list[0]))[0] + + suffix + + ".bam" + ) + path = os.path.dirname(out_filtered_bam) + utils.make_dir(path) + result = [out_filtered_bam] + + else: + if utils.check_directory_or_file(out_filtered_bam) == "file": + print( + "ERROR: the path provided for the output file is a file, not a directory" + ) + sys.exit(1) + utils.make_dir(out_filtered_bam) + result = [] + for bam in bam_list: + result.append( + out_filtered_bam + + "/" + + os.path.splitext(os.path.basename(bam))[0] + + suffix + + ".bam" + ) + + return result - return result def writer_spurious_dict_bam_to_bed(spurious_dict, named_keys, scoring, writer): - for key, value in spurious_dict.items(): - chrom, start, end, strand = key - name = named_keys[key] - score = value['score'] - samples = value['samples'] - score2 = alignment_utils.calc_alignment_score(value['hit'],scoring) - name2 = ';'.join([f"{score2},{sample_id}" for score2, sample_id in samples]) - writer.writerow([chrom, start, end, name, score, strand, score2, name2]) + for key, value in spurious_dict.items(): + chrom, start, end, strand = key + name = named_keys[key] + score = value["score"] + samples = value["samples"] + score2 = alignment_utils.calc_alignment_score(value["hit"], scoring) + name2 = ";".join([f"{score2},{sample_id}" for score2, sample_id in samples]) + writer.writerow([chrom, start, end, name, score, strand, score2, name2]) def writer_spurious_dict_gtf_to_bed(spurious_dict, named_keys, scoring, writer): - for key, value in spurious_dict.items(): - chrom, start, end, strand = key - gene_id = value['transcripts'][0] - transcripts = value['transcripts'][1] - name = named_keys[key] - score = '.' - score2 = alignment_utils.calc_alignment_score(value['hit'],scoring) - name2 = ';'.join([f"{gene_id}", *transcripts]) - writer.writerow([chrom, start, end, name, score, strand, score2, name2]) + for key, value in spurious_dict.items(): + chrom, start, end, strand = key + gene_id = value["transcripts"][0] + transcripts = value["transcripts"][1] + name = named_keys[key] + score = "." + score2 = alignment_utils.calc_alignment_score(value["hit"], scoring) + name2 = ";".join([f"{gene_id}", *transcripts]) + writer.writerow([chrom, start, end, name, score, strand, score2, name2]) + def writer_spurious_dict_bed_to_bed(spurious_dict, named_keys, scoring, writer): - for key, value in spurious_dict.items(): - chrom, start, end, strand = key - name = named_keys[key] - samples = value['samples'] - score = value['score'] - name2 = ';'.join([f"{name2},{score2}" for _, name2, score2 in samples]) - writer.writerow([chrom, start, end, name, score, strand, name2]) - - -def spurious_dict_all_to_bed(spurious_dict,scoring,fileout,gtf_path, bed_list, bam_list): - sorted_keys = spurious_dict.keys() - named_keys = {} - for i, key in enumerate(sorted_keys): - name = "JUNC{}".format(i+1) - named_keys[key] = name - out = io.StringIO() - writer = csv.writer(out, delimiter='\t') - - if gtf_path is not None: - writer_spurious_dict_gtf_to_bed(spurious_dict, named_keys, scoring, writer) - elif bed_list is not None: - writer_spurious_dict_bed_to_bed(spurious_dict, named_keys, scoring, writer) - else: - writer_spurious_dict_bam_to_bed(spurious_dict, named_keys, scoring, writer) - - if fileout is None: - print(out.getvalue()) - else: - with open(fileout, 'w') as out_file: - out_file.write(out.getvalue()) + for key, value in spurious_dict.items(): + chrom, start, end, strand = key + name = named_keys[key] + samples = value["samples"] + score = value["score"] + name2 = ";".join([f"{name2},{score2}" for _, name2, score2 in samples]) + writer.writerow([chrom, start, end, name, score, strand, name2]) + + +def spurious_dict_all_to_bed( + spurious_dict, scoring, fileout, gtf_path, bed_list, bam_list +): + sorted_keys = spurious_dict.keys() + named_keys = {} + for i, key in enumerate(sorted_keys): + name = "JUNC{}".format(i + 1) + named_keys[key] = name + out = io.StringIO() + writer = csv.writer(out, delimiter="\t") + + if gtf_path is not None: + writer_spurious_dict_gtf_to_bed(spurious_dict, named_keys, scoring, writer) + elif bed_list is not None: + writer_spurious_dict_bed_to_bed(spurious_dict, named_keys, scoring, writer) + else: + writer_spurious_dict_bam_to_bed(spurious_dict, named_keys, scoring, writer) + + if fileout is None: + print(out.getvalue()) + else: + with open(fileout, "w") as out_file: + out_file.write(out.getvalue()) def create_sample_to_bed_dict(sample_names, out_removed_junctions_filelist): - sample_to_bed = {} - + sample_to_bed = {} + + for sample in sample_names: + shortest_match = None + for sample_id in out_removed_junctions_filelist: + if sample in sample_id: + if shortest_match is None or len(sample_id) < len(shortest_match): + shortest_match = sample_id + if shortest_match is not None: # Only update if a match was found + file_path = out_removed_junctions_filelist[ + out_removed_junctions_filelist.index(shortest_match) + ] + file_obj = open(file_path, mode="w+b") + sample_to_bed[sample] = file_obj + + return sample_to_bed + + +def spurious_dict_bed_by_sample_to_bed( + spurious_dict, bed_list, out_removed_junctions_filelist, scoring +): + sample_names = [ + os.path.splitext(os.path.basename(bed_path))[0] for bed_path in bed_list + ] + sample_to_bed = create_sample_to_bed_dict( + sample_names, out_removed_junctions_filelist + ) + + if out_removed_junctions_filelist is None: + return + + sorted_keys = sorted(spurious_dict.keys()) + num_digits = len(str(len(sorted_keys))) + + for i, key in enumerate(sorted_keys): + chrom, start, end, strand = key + samples = spurious_dict[key]["samples"] + score2 = alignment_utils.calc_alignment_score( + spurious_dict[key]["hit"], scoring + ) + name = "JUNC{:0{}d}".format(i + 1, num_digits) + + for _, name2, score in samples: + sample_to_bed[name2].write( + f"{chrom}\t{start}\t{end}\t{name}\t{score}\t{strand}\t{name2}\t{score2}\n".encode() + ) + + for sample in sample_names: + sample_to_bed[sample].close() + sample_to_bed[sample] = sample_to_bed[sample].name + + return sample_to_bed + + +def spurious_dict_bam_by_sample_to_bed( + spurious_dict, bam_list, out_removed_junctions_filelist, scoring +): + + # get sample name from bam_list + sample_names = [ + os.path.splitext(os.path.basename(bam_path))[0] for bam_path in bam_list + ] + + # dictionary where the key is the sample name and the value is a file path + sample_to_bed = {} + + if out_removed_junctions_filelist is None: for sample in sample_names: - shortest_match = None - for sample_id in out_removed_junctions_filelist: - if sample in sample_id: - if shortest_match is None or len(sample_id) < len(shortest_match): - shortest_match = sample_id - if shortest_match is not None: # Only update if a match was found - file_path = out_removed_junctions_filelist[out_removed_junctions_filelist.index(shortest_match)] - file_obj = open(file_path, mode='w+b') - sample_to_bed[sample] = file_obj - - return sample_to_bed - - - -def spurious_dict_bed_by_sample_to_bed(spurious_dict, bed_list, out_removed_junctions_filelist, scoring): - sample_names = [os.path.splitext(os.path.basename(bed_path))[0] for bed_path in bed_list] - sample_to_bed = create_sample_to_bed_dict(sample_names, out_removed_junctions_filelist) - - if out_removed_junctions_filelist is None: - return - - sorted_keys = sorted(spurious_dict.keys()) - num_digits = len(str(len(sorted_keys))) - - for i, key in enumerate(sorted_keys): - chrom, start, end, strand = key - samples = spurious_dict[key]['samples'] - score2 = alignment_utils.calc_alignment_score(spurious_dict[key]['hit'], scoring) - name = "JUNC{:0{}d}".format(i+1, num_digits) - - for _, name2, score in samples: - sample_to_bed[name2].write(f'{chrom}\t{start}\t{end}\t{name}\t{score}\t{strand}\t{name2}\t{score2}\n'.encode()) - - for sample in sample_names: - sample_to_bed[sample].close() - sample_to_bed[sample] = sample_to_bed[sample].name - - return sample_to_bed - - -def spurious_dict_bam_by_sample_to_bed(spurious_dict, bam_list, out_removed_junctions_filelist, scoring): - - #get sample name from bam_list - sample_names = [os.path.splitext(os.path.basename(bam_path))[0] for bam_path in bam_list] - - #dictionary where the key is the sample name and the value is a file path - sample_to_bed = {} - - if out_removed_junctions_filelist is None: - for sample in sample_names: - sample_to_bed[sample] = tempfile.NamedTemporaryFile(delete=False, dir='tmp', suffix='.bed') - - else: - sample_to_bed = create_sample_to_bed_dict(sample_names, out_removed_junctions_filelist) - - sorted_keys = sorted(spurious_dict.keys()) - named_keys = {} - num_digits = len(str(len(sorted_keys))) - for i, key in enumerate(sorted_keys): - name = "JUNC{:0{}d}".format(i+1, num_digits) - named_keys[(name,) + key] = spurious_dict[key] - named_keys[(name,) + key]['score2'] = alignment_utils.calc_alignment_score(spurious_dict[key]['hit'], scoring) - - for (name, chrom, start, end, strand), value in named_keys.items(): - samples = value['samples'] - score2 = value['score2'] - for (sample, score) in samples: - sample_to_bed[sample].write(f'{chrom}\t{start}\t{end}\t{name}\t{score}\t{strand}\t{score2}\n'.encode()) - - for sample in sample_names: - sample_to_bed[sample].close() - sample_to_bed[sample] = sample_to_bed[sample].name - - return sample_to_bed - - -def filter_bam_with_vacuum(bam_path, spurious_junctions_bed, out_bam_path, verbose, removed_alignments_bam): - check_for_dependency() - vacuum_cmd = f"{VACUUM_CMD} --remove_mate " - if verbose: - vacuum_cmd = f"{vacuum_cmd} -V" - if removed_alignments_bam: - out_bam_name = os.path.splitext(out_bam_path)[0] - vacuum_cmd = f'{vacuum_cmd} -r {out_bam_name}_removed_alignments.bam' - vacuum_cmd = f'{vacuum_cmd} -o {out_bam_path} {bam_path} {spurious_junctions_bed}' - vacuum_cmd = shlex.split(vacuum_cmd) - - #use subprocess to run vacuum_cmd and return stdout and stderr - process = subprocess.Popen((vacuum_cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE) - out, err = process.communicate() - if process.returncode != 0: - print(f"Error running vacuum. Return code: {process.returncode}") - print(f"stdout: {out.decode('utf-8')}") - print(f"stderr: {err.decode('utf-8')}") - sys.exit(1) - - return out.decode() - -def filter_multi_bam_with_vacuum(bam_list, sample_to_bed, out_bam_list, p, verbose, removed_alignments_bam): - #if verbose is true, make a vector of True values for each bam file - if verbose: - verbose = [True for bam in bam_list] - else: - verbose = [False for bam in bam_list] - - if removed_alignments_bam: - removed_alignments_bam = [True for bam in bam_list] - else: - removed_alignments_bam = [False for bam in bam_list] - - - sample_names = [os.path.splitext(os.path.basename(bam_path))[0] for bam_path in bam_list] - #run filter_bam_with_vacuum in parallel with multiprocessing starmap - pool = multiprocessing.Pool(processes=p) - with pool: - outs = pool.starmap(filter_bam_with_vacuum, zip(bam_list, [sample_to_bed[sample] for sample in sample_names], - out_bam_list, verbose, removed_alignments_bam)) - if verbose: - for out in outs: - print(out) - -#def write_gtf_to_bed(spurious_dict, out_removed_junctions_filelist, scoring): - # sorted_keys = spurious_dict.keys() - # named_keys = {} - # for i, key in enumerate(sorted_keys): - # name = "JUNC{}".format(i+1) - # named_keys[key] = name - - # out_io_dict = {} - # for i,sample in enumerate(sample_names): - # out_io_dict[sample] = [StringIO(), out_removed_junctions_filelist[i]] - - # for key, value in spurious_dict.items(): - # name = named_keys[key] - # samples = value['samples'] - # score2 = alignment_utils.calc_alignment_score(value['hit'],scoring) - # for i, sample in enumerate(samples): - # score, sample_id = sample - # out_io_dict[sample_id][0].write(f"{key[0]}\t{key[1]}\t{key[2]}\t{name}\t{score}\t{key[3]}\t{score2}") - - # for (out,filepath) in out_io_dict.values(): - # with open(filepath, 'w') as out_removed_junctions: - # out_removed_junctions.write(out.getvalue()) + sample_to_bed[sample] = tempfile.NamedTemporaryFile( + delete=False, dir="tmp", suffix=".bed" + ) + + else: + sample_to_bed = create_sample_to_bed_dict( + sample_names, out_removed_junctions_filelist + ) + + sorted_keys = sorted(spurious_dict.keys()) + named_keys = {} + num_digits = len(str(len(sorted_keys))) + for i, key in enumerate(sorted_keys): + name = "JUNC{:0{}d}".format(i + 1, num_digits) + named_keys[(name,) + key] = spurious_dict[key] + named_keys[(name,) + key]["score2"] = alignment_utils.calc_alignment_score( + spurious_dict[key]["hit"], scoring + ) + + for (name, chrom, start, end, strand), value in named_keys.items(): + samples = value["samples"] + score2 = value["score2"] + for sample, score in samples: + sample_to_bed[sample].write( + f"{chrom}\t{start}\t{end}\t{name}\t{score}\t{strand}\t{score2}\n".encode() + ) + + for sample in sample_names: + sample_to_bed[sample].close() + sample_to_bed[sample] = sample_to_bed[sample].name + + return sample_to_bed + + +def filter_bam_with_vacuum( + bam_path, + spurious_junctions_bed, + out_bam_path, + verbose, + removed_alignments_bam, +): + check_for_dependency() + vacuum_cmd = f"{VACUUM_CMD} --remove_mate " + if verbose: + vacuum_cmd = f"{vacuum_cmd} -V" + if removed_alignments_bam: + out_bam_name = os.path.splitext(out_bam_path)[0] + vacuum_cmd = f"{vacuum_cmd} -r {out_bam_name}_removed_alignments.bam" + vacuum_cmd = ( + f"{vacuum_cmd} -o {out_bam_path} {bam_path} {spurious_junctions_bed}" + ) + vacuum_cmd = shlex.split(vacuum_cmd) + + # use subprocess to run vacuum_cmd and return stdout and stderr + process = subprocess.Popen( + (vacuum_cmd), stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + out, err = process.communicate() + if process.returncode != 0: + print(f"Error running vacuum. Return code: {process.returncode}") + print(f"stdout: {out.decode('utf-8')}") + print(f"stderr: {err.decode('utf-8')}") + sys.exit(1) + + return out.decode() + + +def filter_multi_bam_with_vacuum( + bam_list, sample_to_bed, out_bam_list, p, verbose, removed_alignments_bam +): + # if verbose is true, make a vector of True values for each bam file + if verbose: + verbose = [True for bam in bam_list] + else: + verbose = [False for bam in bam_list] + + if removed_alignments_bam: + removed_alignments_bam = [True for bam in bam_list] + else: + removed_alignments_bam = [False for bam in bam_list] + + sample_names = [ + os.path.splitext(os.path.basename(bam_path))[0] for bam_path in bam_list + ] + # run filter_bam_with_vacuum in parallel with multiprocessing starmap + pool = multiprocessing.Pool(processes=p) + with pool: + outs = pool.starmap( + filter_bam_with_vacuum, + zip( + bam_list, + [sample_to_bed[sample] for sample in sample_names], + out_bam_list, + verbose, + removed_alignments_bam, + ), + ) + if verbose: + for out in outs: + print(out) + + +# def write_gtf_to_bed(spurious_dict, out_removed_junctions_filelist, scoring): +# sorted_keys = spurious_dict.keys() +# named_keys = {} +# for i, key in enumerate(sorted_keys): +# name = "JUNC{}".format(i+1) +# named_keys[key] = name + +# out_io_dict = {} +# for i,sample in enumerate(sample_names): +# out_io_dict[sample] = [StringIO(), out_removed_junctions_filelist[i]] + +# for key, value in spurious_dict.items(): +# name = named_keys[key] +# samples = value['samples'] +# score2 = alignment_utils.calc_alignment_score(value['hit'],scoring) +# for i, sample in enumerate(samples): +# score, sample_id = sample +# out_io_dict[sample_id][0].write(f"{key[0]}\t{key[1]}\t{key[2]}\t{name}\t{score}\t{key[3]}\t{score2}") + +# for (out,filepath) in out_io_dict.values(): +# with open(filepath, 'w') as out_removed_junctions: +# out_removed_junctions.write(out.getvalue()) # if __name__ == '__main__': # # import time @@ -296,7 +384,8 @@ def filter_multi_bam_with_vacuum(bam_list, sample_to_bed, out_bam_list, p, verbo # # end = time.time() # # print(f"took {(end-start)/60} mins")) + def check_for_dependency(): - """Check if runtime dependency exists.""" - if not os.path.exists(VACUUM_CMD): - raise RuntimeError(f"{VACUUM_CMD} not found.") + """Check if runtime dependency exists.""" + if not os.path.exists(VACUUM_CMD): + raise RuntimeError(f"{VACUUM_CMD} not found.") diff --git a/src/eastr/run_eastr.py b/src/eastr/run_eastr.py index abaa267..c63f893 100644 --- a/src/eastr/run_eastr.py +++ b/src/eastr/run_eastr.py @@ -5,280 +5,364 @@ from eastr import output_utils from eastr import utils -def parse_args(): - parser = argparse.ArgumentParser( - prog="eastr", - description="eastr: Emending alignments of spuriously spliced transcript reads. " - "The script takes GTF, BED, or BAM files as input and processes them using " - "the provided reference genome and BowTie2 index. It identifies spurious junctions " - "and filters the input data accordingly." - ) - - #required args - group_reqd = parser.add_mutually_exclusive_group(required=True) - group_reqd.add_argument('--gtf', help='Input GTF file containing transcript annotations') - group_reqd.add_argument('--bed', help='Input BED file with intron coordinates') - group_reqd.add_argument('--bam', help='Input BAM file or a TXT file containing a list of BAM files with read alignments') - parser.add_argument("-r", "--reference", required=True, help="reference FASTA genome used in alignment") - parser.add_argument('-i','--bowtie2_index', required=True, help='Path to Bowtie2 index for the reference genome') - - #bt2 args: - parser.add_argument( - "--bt2_k", - help="Minimum number of distinct alignments found by bowtie2 for a junction to be \ +def parse_args(): + parser = argparse.ArgumentParser( + prog="eastr", + description="eastr: Emending alignments of spuriously spliced transcript reads. " + "The script takes GTF, BED, or BAM files as input and processes them using " + "the provided reference genome and BowTie2 index. It identifies spurious junctions " + "and filters the input data accordingly.", + ) + + # required args + group_reqd = parser.add_mutually_exclusive_group(required=True) + group_reqd.add_argument( + "--gtf", help="Input GTF file containing transcript annotations" + ) + group_reqd.add_argument( + "--bed", help="Input BED file with intron coordinates" + ) + group_reqd.add_argument( + "--bam", + help="Input BAM file or a TXT file containing a list of BAM files with read alignments", + ) + parser.add_argument( + "-r", + "--reference", + required=True, + help="reference FASTA genome used in alignment", + ) + parser.add_argument( + "-i", + "--bowtie2_index", + required=True, + help="Path to Bowtie2 index for the reference genome", + ) + + # bt2 args: + parser.add_argument( + "--bt2_k", + help="Minimum number of distinct alignments found by bowtie2 for a junction to be \ considered spurious. Default: 10", - default=10, - type=int - ) - #EASTR args - parser.add_argument( - "-o", - help="Length of the overhang on either side of the splice junction. Default = 50", - default=50, - type=int) - - parser.add_argument( - "--min_duplicate_exon_length", - help="Minimum length of the duplicated exon. Default = 27", - default=27, - type=int - ) - - parser.add_argument( - "-a", - help="Minimum required anchor length in each of the two exons, default = 7", - default=7, - type=int - ) - - parser.add_argument( - "--min_junc_score", - help=" Minimum number of supporting spliced reads required per junction. " - "Junctions with fewer supporting reads in all samples are filtered out " - "if the flanking regions are similar (based on mappy scoring matrix). Default: 1", - default=1, - type=int) - - parser.add_argument( - "--trusted_bed", - help="Path to a BED file path with trusted junctions, which will not be removed by EASTR." - ) - - parser.add_argument( - "--verbose", default=False, action="store_true", - help="Display additional information during BAM filtering, " - "including the count of total spliced alignments and removed alignments") - - - parser.add_argument( #TODO: directory instead of store_true - "--removed_alignments_bam", default=False, action="store_true", - help="Write removed alignments to a BAM file") - - #minimap2 args - group_mm2 = parser.add_argument_group('Minimap2 parameters') - - group_mm2.add_argument( - "-A", - help="Matching score. Default = 3", - default=3, - type=int) - - group_mm2.add_argument( - "-B", - help="Mismatching penalty. Default = 4", - default=4, - type=int) - - group_mm2.add_argument( - "-O", - nargs=2, - type=int, - help="Gap open penalty. Default = [12, 32]", - default=[12,32]) - - group_mm2.add_argument( - "-E", - nargs=2, - type=int, - help="Gap extension penalty. A gap of length k costs min(O1+k*E1, O2+k*E2). Default = [2, 1]", - default=[2,1]) - - group_mm2.add_argument( - "-k", - help="K-mer length for alignment. Default=3", - default=3, - type=int - ) - - group_mm2.add_argument( - "--scoreN", - help="Score of a mismatch involving ambiguous bases. Default=1", - default=1, - type=int - ) - - group_mm2.add_argument( - "-w", - help="Minimizer window size. Default=2", - default=2, - type=int - ) + default=10, + type=int, + ) + # EASTR args + parser.add_argument( + "-o", + help="Length of the overhang on either side of the splice junction. Default = 50", + default=50, + type=int, + ) + + parser.add_argument( + "--min_duplicate_exon_length", + help="Minimum length of the duplicated exon. Default = 27", + default=27, + type=int, + ) + + parser.add_argument( + "-a", + help="Minimum required anchor length in each of the two exons, default = 7", + default=7, + type=int, + ) + + parser.add_argument( + "--min_junc_score", + help=" Minimum number of supporting spliced reads required per junction. " + "Junctions with fewer supporting reads in all samples are filtered out " + "if the flanking regions are similar (based on mappy scoring matrix). Default: 1", + default=1, + type=int, + ) + + parser.add_argument( + "--trusted_bed", + help="Path to a BED file path with trusted junctions, which will not be removed by EASTR.", + ) + + parser.add_argument( + "--verbose", + default=False, + action="store_true", + help="Display additional information during BAM filtering, " + "including the count of total spliced alignments and removed alignments", + ) + + parser.add_argument( # TODO: directory instead of store_true + "--removed_alignments_bam", + default=False, + action="store_true", + help="Write removed alignments to a BAM file", + ) + + # minimap2 args + group_mm2 = parser.add_argument_group("Minimap2 parameters") + + group_mm2.add_argument( + "-A", help="Matching score. Default = 3", default=3, type=int + ) + + group_mm2.add_argument( + "-B", help="Mismatching penalty. Default = 4", default=4, type=int + ) + + group_mm2.add_argument( + "-O", + nargs=2, + type=int, + help="Gap open penalty. Default = [12, 32]", + default=[12, 32], + ) + + group_mm2.add_argument( + "-E", + nargs=2, + type=int, + help="Gap extension penalty. A gap of length k costs min(O1+k*E1, O2+k*E2). Default = [2, 1]", + default=[2, 1], + ) + + group_mm2.add_argument( + "-k", help="K-mer length for alignment. Default=3", default=3, type=int + ) + + group_mm2.add_argument( + "--scoreN", + help="Score of a mismatch involving ambiguous bases. Default=1", + default=1, + type=int, + ) + + group_mm2.add_argument( + "-w", help="Minimizer window size. Default=2", default=2, type=int + ) + + group_mm2.add_argument( + "-m", + help="Discard chains with chaining score. Default=25.", + default=25, + type=int, + ) + + # output args + group_out = parser.add_argument_group("Output") + + group_out.add_argument( + "--out_original_junctions", + default=None, + metavar="OUT", + help="Write original junctions to the OUT file or directory", + ) + + group_out.add_argument( + "--out_removed_junctions", + default="stdout", + metavar="OUT", + help="Write removed junctions to OUT file or directory; the default output is to terminal", + ) + + group_out.add_argument( + "--out_filtered_bam", + metavar="OUT", + default=None, + help="Write filtered bams to OUT file or directory", + ) + + group_out.add_argument( + "--filtered_bam_suffix", + metavar="STR", + default="_EASTR_filtered", + help="Suffix added to the name of the output BAM files. Default='_EASTR_filtered'", + ) + + # other args + parser.add_argument( + "-p", help="Number of parallel processes, default=1", default=1, type=int + ) + + return parser.parse_args() - group_mm2.add_argument( - "-m", - help="Discard chains with chaining score. Default=25.", - default=25, - type=int - ) - #output args - group_out = parser.add_argument_group('Output') - - group_out.add_argument("--out_original_junctions", default=None, metavar='OUT', - help="Write original junctions to the OUT file or directory") - - group_out.add_argument("--out_removed_junctions", default='stdout', metavar='OUT', - help="Write removed junctions to OUT file or directory; the default output is to terminal") +def minimap_scoring(args): + gap_open_penalty = args.O + gap_ext_penalty = args.E + mismatch_penalty = args.B + match_score = args.A + ambiguous_score = args.scoreN - group_out.add_argument("--out_filtered_bam", metavar='OUT',default=None, - help="Write filtered bams to OUT file or directory") + scoring = [ + match_score, + mismatch_penalty, + gap_open_penalty[0], + gap_ext_penalty[0], + gap_open_penalty[1], + gap_ext_penalty[1], + ambiguous_score, + ] - group_out.add_argument("--filtered_bam_suffix", metavar='STR',default="_EASTR_filtered", - help="Suffix added to the name of the output BAM files. Default='_EASTR_filtered'") + return scoring - #other args - parser.add_argument( - "-p", - help="Number of parallel processes, default=1", - default=1, - type=int +def main(): + args = parse_args() + + # required input args + bam_list = args.bam + gtf_path = args.gtf + bed_list = args.bed + ref_fa = args.reference + bt2_index = args.bowtie2_index + bt2_k = args.bt2_k + + # EASTR variables + overhang = args.o + min_duplicate_exon_length = args.min_duplicate_exon_length + min_junc_score = args.min_junc_score + anchor = args.a + trusted_bed = args.trusted_bed + verbose = args.verbose + removed_alignments_bam = args.removed_alignments_bam + + # mm2 variables + scoring = minimap_scoring(args) + k = args.k + w = args.w + m = args.m + + # output args + suffix = args.filtered_bam_suffix + out_original_junctions = args.out_original_junctions + out_removed_junctions = args.out_removed_junctions + out_filtered_bam = args.out_filtered_bam + + # other args + p = args.p + + # index reference fasta if it's not indexed + utils.index_fasta(ref_fa) + + # check if the input is a bam file or a list of bam files + is_bam = False + if bam_list: + is_bam = True + + # if a single bam file is provided + extension = os.path.splitext(os.path.basename(bam_list))[1] + if extension in [".bam", ".cram", ".sam"]: + bam_list = [bam_list] + + else: + with open(bam_list) as file: + bam_list = [line.rstrip() for line in file] + for bam in bam_list: + if not os.path.isfile(bam): + raise ValueError( + "input must be a bam file or a file containing a list of bam files" + ) + + elif bed_list: + extension = os.path.splitext(os.path.basename(bed_list))[1] + if extension in [".bed"]: + bed_list = [bed_list] + + else: + with open(bed_list) as file: + bed_list = [line.rstrip() for line in file] + for bed in bed_list: + if not os.path.isfile(bed): + raise ValueError( + "input must be a bed file or a file containing a list of bed files" + ) + + original_junctions_filelist = output_utils.out_junctions_filelist( + bam_list, + gtf_path, + bed_list, + out_original_junctions, + suffix="_original_junctions", + ) + removed_junctions_filelist = output_utils.out_junctions_filelist( + bam_list, + gtf_path, + bed_list, + out_removed_junctions, + suffix="_removed_junctions", + ) + filtered_bam_filelist = output_utils.out_filtered_bam_filelist( + bam_list, out_filtered_bam, suffix=suffix + ) + + spurious_dict = get_spurious_introns.get_spurious_junctions( + scoring, + k, + w, + m, + overhang, + min_duplicate_exon_length, + bt2_index, + bt2_k, + ref_fa, + p, + anchor, + min_junc_score, + bam_list, + gtf_path, + bed_list, + trusted_bed, + original_junctions_filelist, + verbose, + ) + + if is_bam: + if filtered_bam_filelist: + sample_to_bed = output_utils.spurious_dict_bam_by_sample_to_bed( + spurious_dict, bam_list, removed_junctions_filelist, scoring=scoring + ) + output_utils.filter_multi_bam_with_vacuum( + bam_list, + sample_to_bed, + filtered_bam_filelist, + p, + verbose, + removed_alignments_bam, + ) + if removed_junctions_filelist is None: + for _, sample in sample_to_bed.items(): + os.remove(sample) + elif removed_junctions_filelist: + sample_to_bed = output_utils.spurious_dict_bam_by_sample_to_bed( + spurious_dict, bam_list, removed_junctions_filelist, scoring=scoring + ) + + else: + output_utils.spurious_dict_all_to_bed( + spurious_dict, scoring, None, gtf_path, bed_list, bam_list + ) + + if gtf_path: + output_utils.spurious_dict_all_to_bed( + spurious_dict, + scoring, + removed_junctions_filelist, + gtf_path, + bed_list, + bam_list, ) + elif bed_list: + if removed_junctions_filelist: + output_utils.spurious_dict_bed_by_sample_to_bed( + spurious_dict, bed_list, removed_junctions_filelist, scoring + ) + else: + output_utils.spurious_dict_all_to_bed( + spurious_dict, scoring, None, gtf_path, bed_list, bam_list + ) - return parser.parse_args() - -def minimap_scoring(args): - gap_open_penalty = args.O - gap_ext_penalty = args.E - mismatch_penalty = args.B - match_score = args.A - ambiguous_score = args.scoreN - - scoring=[ - match_score, - mismatch_penalty, - gap_open_penalty[0], - gap_ext_penalty[0], - gap_open_penalty[1], - gap_ext_penalty[1], - ambiguous_score] - return scoring - - -def main(): - args = parse_args() - - #required input args - bam_list = args.bam - gtf_path = args.gtf - bed_list = args.bed - ref_fa = args.reference - bt2_index = args.bowtie2_index - bt2_k = args.bt2_k - - #EASTR variables - overhang = args.o - min_duplicate_exon_length = args.min_duplicate_exon_length - min_junc_score = args.min_junc_score - anchor = args.a - trusted_bed = args.trusted_bed - verbose = args.verbose - removed_alignments_bam = args.removed_alignments_bam - - #mm2 variables - scoring = minimap_scoring(args) - k = args.k - w = args.w - m = args.m - - #output args - suffix = args.filtered_bam_suffix - out_original_junctions = args.out_original_junctions - out_removed_junctions = args.out_removed_junctions - out_filtered_bam = args.out_filtered_bam - - - #other args - p = args.p - - #index reference fasta if it's not indexed - utils.index_fasta(ref_fa) - - #check if the input is a bam file or a list of bam files - is_bam = False - if bam_list: - is_bam = True - - #if a single bam file is provided - extension = os.path.splitext(os.path.basename(bam_list))[1] - if extension in ['.bam','.cram','.sam']: - bam_list = [bam_list] - - else: - with open(bam_list) as file: - bam_list = [line.rstrip() for line in file] - for bam in bam_list: - if not os.path.isfile(bam): - raise ValueError('input must be a bam file or a file containing a list of bam files') - - - elif bed_list: - extension = os.path.splitext(os.path.basename(bed_list))[1] - if extension in ['.bed']: - bed_list = [bed_list] - - else: - with open(bed_list) as file: - bed_list = [line.rstrip() for line in file] - for bed in bed_list: - if not os.path.isfile(bed): - raise ValueError('input must be a bed file or a file containing a list of bed files') - - - original_junctions_filelist = output_utils.out_junctions_filelist(bam_list, gtf_path, bed_list, out_original_junctions, suffix="_original_junctions") - removed_junctions_filelist = output_utils.out_junctions_filelist(bam_list, gtf_path, bed_list, out_removed_junctions, suffix="_removed_junctions") - filtered_bam_filelist = output_utils.out_filtered_bam_filelist(bam_list, out_filtered_bam, suffix=suffix) - - spurious_dict = get_spurious_introns.get_spurious_junctions(scoring, k, w, m, overhang, min_duplicate_exon_length, bt2_index, bt2_k, - ref_fa, p, anchor, min_junc_score, bam_list, gtf_path, - bed_list, trusted_bed, original_junctions_filelist, verbose ) - - - if is_bam: - if filtered_bam_filelist: - sample_to_bed = output_utils.spurious_dict_bam_by_sample_to_bed( - spurious_dict, bam_list, removed_junctions_filelist, scoring=scoring) - output_utils.filter_multi_bam_with_vacuum(bam_list, sample_to_bed, filtered_bam_filelist, p, verbose, removed_alignments_bam) - if removed_junctions_filelist is None: - for _, sample in sample_to_bed.items(): - os.remove(sample) - elif removed_junctions_filelist: - sample_to_bed = output_utils.spurious_dict_bam_by_sample_to_bed(spurious_dict, bam_list, removed_junctions_filelist, scoring=scoring) - - else: - output_utils.spurious_dict_all_to_bed(spurious_dict, scoring, None, gtf_path, bed_list, bam_list) - - if gtf_path: - output_utils.spurious_dict_all_to_bed(spurious_dict, scoring, removed_junctions_filelist, gtf_path, bed_list, bam_list) - - elif bed_list: - if removed_junctions_filelist: - output_utils.spurious_dict_bed_by_sample_to_bed(spurious_dict, bed_list, removed_junctions_filelist, scoring) - else: - output_utils.spurious_dict_all_to_bed(spurious_dict, scoring, None, gtf_path, bed_list, bam_list) - -if __name__ == '__main__': - main() +if __name__ == "__main__": + main() diff --git a/src/eastr/utils.py b/src/eastr/utils.py index 565ffeb..a09f0d0 100644 --- a/src/eastr/utils.py +++ b/src/eastr/utils.py @@ -5,25 +5,27 @@ def index_fasta(ref_fa): - if not os.path.exists(f"{ref_fa}.fai"): - pysam.faidx(ref_fa) + if not os.path.exists(f"{ref_fa}.fai"): + pysam.faidx(ref_fa) -#Make a new directory + +# Make a new directory def make_dir(path): - directory = os.path.join(path) - os.makedirs(directory,exist_ok=True) + directory = os.path.join(path) + os.makedirs(directory, exist_ok=True) def get_chroms_list_from_fasta(ref_fa): - fasta=pysam.FastaFile(ref_fa) - chroms = list(fasta.references) - chrom_sizes = collections.defaultdict(int) - for chrom in chroms: - chrom_sizes[chrom] = fasta.get_reference_length(chrom) - return chrom_sizes - -def check_directory_or_file(path:str) -> str: - if os.path.splitext(os.path.basename(path))[1]!='': - return 'file' - else: - return 'dir' + fasta = pysam.FastaFile(ref_fa) + chroms = list(fasta.references) + chrom_sizes = collections.defaultdict(int) + for chrom in chroms: + chrom_sizes[chrom] = fasta.get_reference_length(chrom) + return chrom_sizes + + +def check_directory_or_file(path: str) -> str: + if os.path.splitext(os.path.basename(path))[1] != "": + return "file" + else: + return "dir" diff --git a/tests/align_hisat2.sh b/tests/align_hisat2.sh index 4d4a217..d47c9b9 100755 --- a/tests/align_hisat2.sh +++ b/tests/align_hisat2.sh @@ -17,7 +17,7 @@ for i in $(ls "$wd"/fastq/*_1.fastq); do -1 "$wd"/fastq/"${name}"_1.fastq \ -2 "$wd"/fastq/"${name}"_2.fastq \ -S "$TEMPLOC"/"${name}".sam - + samtools view -bS "$TEMPLOC"/"${name}".sam | \ samtools sort -@ $NCPU - -o "$ALIGNLOC"/"${name}"_hisat.bam rm "$TEMPLOC"/"${name}".sam