diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml new file mode 100644 index 00000000..ad3ec43f --- /dev/null +++ b/.github/workflows/CI.yml @@ -0,0 +1,181 @@ +# This file is autogenerated by maturin v1.7.4 +# To update, run +# +# maturin generate-ci github +# +name: CI + +on: + push: + branches: + - main + - master + tags: + - '*' + pull_request: + workflow_dispatch: + +permissions: + contents: read + +jobs: + linux: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: ubuntu-latest + target: x86_64 + - runner: ubuntu-latest + target: x86 + - runner: ubuntu-latest + target: aarch64 + - runner: ubuntu-latest + target: armv7 + - runner: ubuntu-latest + target: s390x + - runner: ubuntu-latest + target: ppc64le + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + manylinux: auto + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-linux-${{ matrix.platform.target }} + path: dist + + musllinux: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: ubuntu-latest + target: x86_64 + - runner: ubuntu-latest + target: x86 + - runner: ubuntu-latest + target: aarch64 + - runner: ubuntu-latest + target: armv7 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + manylinux: musllinux_1_2 + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-musllinux-${{ matrix.platform.target }} + path: dist + + windows: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: windows-latest + target: x64 + - runner: windows-latest + target: x86 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + architecture: ${{ matrix.platform.target }} + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-windows-${{ matrix.platform.target }} + path: dist + + macos: + runs-on: ${{ matrix.platform.runner }} + strategy: + matrix: + platform: + - runner: macos-12 + target: x86_64 + - runner: macos-14 + target: aarch64 + steps: + - uses: actions/checkout@v4 + - uses: actions/setup-python@v5 + with: + python-version: 3.x + - name: Build wheels + uses: PyO3/maturin-action@v1 + with: + target: ${{ matrix.platform.target }} + args: --release --out dist --find-interpreter + sccache: 'true' + - name: Upload wheels + uses: actions/upload-artifact@v4 + with: + name: wheels-macos-${{ matrix.platform.target }} + path: dist + + sdist: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: Build sdist + uses: PyO3/maturin-action@v1 + with: + command: sdist + args: --out dist + - name: Upload sdist + uses: actions/upload-artifact@v4 + with: + name: wheels-sdist + path: dist + + release: + name: Release + runs-on: ubuntu-latest + if: ${{ startsWith(github.ref, 'refs/tags/') || github.event_name == 'workflow_dispatch' }} + needs: [linux, musllinux, windows, macos, sdist] + permissions: + # 
Use to sign the release artifacts + id-token: write + # Used to upload release artifacts + contents: write + # Used to generate artifact attestation + attestations: write + steps: + - uses: actions/download-artifact@v4 + - name: Generate artifact attestation + uses: actions/attest-build-provenance@v1 + with: + subject-path: 'wheels-*/*' + - name: Publish to PyPI + if: "startsWith(github.ref, 'refs/tags/')" + uses: PyO3/maturin-action@v1 + env: + MATURIN_PYPI_TOKEN: ${{ secrets.PYPI_API_TOKEN }} + with: + command: upload + args: --non-interactive --skip-existing wheels-*/* diff --git a/.gitignore b/.gitignore index 7ea9ea8f..4184eca0 100644 --- a/.gitignore +++ b/.gitignore @@ -109,3 +109,15 @@ ENV/ .DS_Store results *old + +# python generated files +__pycache__/ +*.py[oc] +build/ +dist/ +wheels/ +*.egg-info + +# venv +.venv +test_data diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 00000000..aa9638c1 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,437 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "ahash" +version = "0.8.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" +dependencies = [ + "cfg-if", + "getrandom", + "once_cell", + "version_check", + "zerocopy", +] + +[[package]] +name = "autocfg" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" + +[[package]] +name = "bitflags" +version = "2.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "cc" +version = "1.1.29" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "58e804ac3194a48bb129643eb1d62fcc20d18c6b8c181704489353d13120bcd1" +dependencies = [ + "shlex", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "csv" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac574ff4d437a7b5ad237ef331c17ccca63c46479e5b5453eb8e10bb99a759fe" +dependencies = [ + "csv-core", + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "csv-core" +version = "0.1.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5efa2b3d7902f4b634a20cae3c9c4e6209dc4779feb6863329607560143efa70" +dependencies = [ + "memchr", +] + +[[package]] +name = "fallible-iterator" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2acce4a10f12dc2fb14a218589d4f1f62ef011b2d0cc4b3cb1bba8e94da14649" + +[[package]] +name = "fallible-streaming-iterator" +version = "0.1.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7360491ce676a36bf9bb3c56c1aa791658183a54d2744120f27285738d90465a" + +[[package]] +name = "funmap_lib" +version = "0.1.0" +dependencies = [ + "ahash", + "csv", + "pyo3", + "rusqlite", + "serde-pickle", +] + +[[package]] +name = "getrandom" +version = "0.2.15" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "hashbrown" +version = "0.14.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] + +[[package]] +name = "hashlink" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ba4ff7128dee98c7dc9794b6a411377e1404dba1c97deb8d1a55297bd25d8af" +dependencies = [ + "hashbrown", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "indoc" +version = "2.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" + +[[package]] +name = "iter-read" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c397ca3ea05ad509c4ec451fea28b4771236a376ca1c69fd5143aae0cf8f93c4" + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "libc" +version = "0.2.159" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "561d97a539a36e26a9a5fad1ea11a3039a67714694aaa379433e580854bc3dc5" + +[[package]] +name = "libsqlite3-sys" +version = "0.30.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e99fb7a497b1e3339bc746195567ed8d3e24945ecd636e3619d20b9de9e9149" +dependencies = [ + "cc", + "pkg-config", + "vcpkg", +] + +[[package]] +name = "memchr" +version = "2.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" + +[[package]] +name = "memoffset" +version = "0.9.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "488016bfae457b036d996092f6cb448677611ce4449e970ceaf42695203f218a" +dependencies = [ + "autocfg", +] + +[[package]] +name = "num-bigint" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +dependencies = [ + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "pkg-config" +version = "0.3.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "953ec861398dccce10c670dfeaf3ec4911ca479e9c02154b3a215178c5f566f2" + +[[package]] +name = "portable-atomic" +version = "1.8.0" +source = 
"registry+https://github.com/rust-lang/crates.io-index" +checksum = "d30538d42559de6b034bc76fd6dd4c38961b1ee5c6c56e3808c50128fdbc22ce" + +[[package]] +name = "proc-macro2" +version = "1.0.86" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pyo3" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15ee168e30649f7f234c3d49ef5a7a6cbf5134289bc46c29ff3155fa3221c225" +dependencies = [ + "cfg-if", + "indoc", + "libc", + "memoffset", + "once_cell", + "portable-atomic", + "pyo3-build-config", + "pyo3-ffi", + "pyo3-macros", + "unindent", +] + +[[package]] +name = "pyo3-build-config" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e61cef80755fe9e46bb8a0b8f20752ca7676dcc07a5277d8b7768c6172e529b3" +dependencies = [ + "once_cell", + "target-lexicon", +] + +[[package]] +name = "pyo3-ffi" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67ce096073ec5405f5ee2b8b31f03a68e02aa10d5d4f565eca04acc41931fa1c" +dependencies = [ + "libc", + "pyo3-build-config", +] + +[[package]] +name = "pyo3-macros" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2440c6d12bc8f3ae39f1e775266fa5122fd0c8891ce7520fa6048e683ad3de28" +dependencies = [ + "proc-macro2", + "pyo3-macros-backend", + "quote", + "syn", +] + +[[package]] +name = "pyo3-macros-backend" +version = "0.22.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1be962f0e06da8f8465729ea2cb71a416d2257dff56cbe40a70d3e62a93ae5d1" +dependencies = [ + "heck", + "proc-macro2", + "pyo3-build-config", + "quote", + "syn", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rusqlite" +version = "0.32.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7753b721174eb8ff87a9a0e799e2d7bc3749323e773db92e0984debb00019d6e" +dependencies = [ + "bitflags", + "fallible-iterator", + "fallible-streaming-iterator", + "hashlink", + "libsqlite3-sys", + "smallvec", +] + +[[package]] +name = "ryu" +version = "1.0.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" + +[[package]] +name = "serde" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde-pickle" +version = "1.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c762ad136a26407c6a80825813600ceeab5e613660d93d79a41f0ec877171e71" +dependencies = [ + "byteorder", + "iter-read", + "num-bigint", + "num-traits", + "serde", +] + +[[package]] +name = "serde_derive" +version = "1.0.210" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = 
"0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + +[[package]] +name = "smallvec" +version = "1.13.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" + +[[package]] +name = "syn" +version = "2.0.77" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9f35bcdf61fd8e7be6caf75f429fdca8beb3ed76584befb503b1569faee373ed" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "target-lexicon" +version = "0.12.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61c41af27dd6d1e27b1b16b489db798443478cef1f06a660c96db617ba5de3b1" + +[[package]] +name = "unicode-ident" +version = "1.0.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91b56cd4cadaeb79bbf1a5645f6b4f8dc5bde8834ad5894a8db35fda9efa1fe" + +[[package]] +name = "unindent" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" + +[[package]] +name = "vcpkg" +version = "0.2.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" + +[[package]] +name = "version_check" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "zerocopy" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +dependencies = [ + "zerocopy-derive", +] + +[[package]] +name = "zerocopy-derive" +version = "0.7.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 00000000..e3c902b7 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "funmap_lib" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[lib] +name = "funmap" +crate-type = ["cdylib"] + +[dependencies] +ahash = "0.8.11" +csv = "1.3.0" +pyo3 = "0.22.0" +rusqlite = { version = "0.32.1", features = ["bundled"] } +serde-pickle = "1.1.1" diff --git a/funmap/cli.py b/funmap/cli.py deleted file mode 100644 index 0e44f77d..00000000 --- a/funmap/cli.py +++ /dev/null @@ -1,297 +0,0 @@ -import os -import click -import pandas as pd -import numpy as np -import gzip -import pickle -from pathlib import Path -from funmap.funmap import compute_llr, predict_all_pairs, dataset_llr, predict_all_pairs -from funmap.plotting import explore_data, plot_results, merge_and_delete -from funmap.funmap import train_ml_model, prepare_gs_data, get_cutoff, get_ppi_feature -from funmap.funmap import compute_features, predict_network -from funmap.data_urls import misc_urls as urls -from funmap.logger import setup_logging, setup_logger -from funmap.utils import setup_experiment, cleanup_experiment, check_gold_standard_file -from funmap.utils import 
check_extra_feature_file -from funmap import __version__ - -log = setup_logger(__name__) - -@click.group(help='funmap command line interface') -@click.version_option(version=f'{__version__}') -def cli(): - """ - Command line interface for funmap. - """ - click.echo("====== funmap =======") - -@cli.command(help='check the data quality') -@click.option('--config-file', '-c', required=True, type=click.Path(exists=True), - help='path to experiment configuration yaml file') -@click.option('--force-rerun', '-f', is_flag=True, default=False, - help='if set, will remove results from previous run first') -def qc(config_file, force_rerun): - if force_rerun: - while True: - confirmation = input("Do you want to remove results from previous run? (y/n): ") - if confirmation.lower() == 'y': - click.echo('Removing results from previous run') - cleanup_experiment(config_file) - break - elif confirmation.lower() == 'n': - click.echo('Not removing results from previous run') - break - else: - click.echo("Invalid input. Please enter 'y' or 'n'.") - - setup_logging(config_file) - log.info(f'Running QC...') - cfg = setup_experiment(config_file) - all_fig_names = [] - figure_dir = Path(cfg['results_dir']) / cfg['subdirs']['figure_dir'] - min_sample_count = cfg['min_sample_count'] - fig_names = explore_data(cfg, min_sample_count, figure_dir) - all_fig_names.extend(fig_names) - merge_and_delete(figure_dir, all_fig_names, 'qc.pdf') - log.info('figure qc.pdf saved to {}'.format(figure_dir)) - log.info('QC complete') - - -@cli.command(help='run funmap') -@click.option('--config-file', '-c', required=True, type=click.Path(exists=True), - help='path to experiment configuration yaml file') -@click.option('--force-rerun', '-f', is_flag=True, default=False, - help='if set, will remove results from previous run first') -def run(config_file, force_rerun): - click.echo('Running funmap...') - if force_rerun: - while True: - confirmation = input("Do you want to remove results from previous run? (y/n): ") - if confirmation.lower() == 'y': - click.echo('Removing results from previous run') - cleanup_experiment(config_file) - break - elif confirmation.lower() == 'n': - click.echo('Not removing results from previous run') - break - else: - click.echo("Invalid input. 
Please enter 'y' or 'n'.") - - setup_logging(config_file) - cfg = setup_experiment(config_file) - extra_feature_file = cfg['extra_feature_file'] - if (extra_feature_file is not None) and (not check_extra_feature_file(extra_feature_file)): - return - gs_file = cfg['gs_file'] - if (gs_file is not None) and (not check_gold_standard_file(gs_file)): - return - - task = cfg['task'] - seed = cfg['seed'] - np.random.seed(seed) - ml_type = cfg['ml_type'] - feature_type = cfg['feature_type'] - # min_feature_count = cfg['min_feature_count'] - min_sample_count = cfg['min_sample_count'] - # filter_before_prediction = cfg['filter_before_prediction'] - test_size = cfg['test_size'] - # filter_after_prediction = cfg['filter_after_prediction'] - # filter_criterion = cfg['filter_criterion'] - # filter_threshold = cfg['filter_threshold'] - # filter_blacklist = cfg['filter_blacklist'] - n_jobs = cfg['n_jobs'] - lr_cutoff = cfg['lr_cutoff'] - max_num_edges = cfg['max_num_edges'] - step_size = cfg['step_size'] - start_edge_num = cfg['start_edge_num'] - - results_dir = Path(cfg['results_dir']) - saved_data_dir = results_dir / cfg['subdirs']['saved_data_dir'] - model_dir = results_dir / cfg['subdirs']['saved_model_dir'] - prediction_dir = results_dir / cfg['subdirs']['saved_predictions_dir'] - network_dir = results_dir / cfg['subdirs']['network_dir'] - figure_dir = results_dir / cfg['subdirs']['figure_dir'] - - if cfg['task'] == 'protein_func': - feature_mapping = ['ex', 'ei'] - else: - feature_mapping = ['ex'] - # here the file stored a dictionary of ml models - ml_model_file = {feature: model_dir / f'model_{feature}.pkl.gz' - for feature in feature_mapping } - predicted_all_pairs_file = {feature: prediction_dir / f'predicted_all_pairs_{feature}.parquet' - for feature in feature_mapping } - llr_res_file = {feature: results_dir / f'llr_results_{feature}.tsv' - for feature in feature_mapping } - edge_list_file = {feature: network_dir/ f'funmap_{feature}.tsv' - for feature in feature_mapping } - # gold standard data include specified feature (cc or mr) and ppi feature (if applicable) - # and extra feature if applicable - gs_df_file = saved_data_dir / 'gold_standard_data.h5' - # blacklist_file = urls['funmap_blacklist'] - # llr obtained with each invividual dataset - llr_dataset_file = results_dir / 'llr_dataset.tsv' - gs_train = gs_test = None - cutoff_p = cutoff_llr = None - ml_model_dict = {} - - # compute and save cc, mr results - cc_dict, mr_dict, all_valid_ids = compute_features(cfg, feature_type, min_sample_count, - saved_data_dir) - gs_args = { - 'task': task, - 'saved_data_dir': saved_data_dir, - 'cc_dict': cc_dict, - 'mr_dict': mr_dict, - 'feature_type': feature_type, - 'gs_file': gs_file, - 'extra_feature_file': extra_feature_file, - 'valid_id_list': all_valid_ids, - 'test_size': test_size, - 'seed': seed - } - - all_edge_list_exist = all(os.path.exists(file_path) for file_path in edge_list_file.values()) - - if all_edge_list_exist: - log.info('Fumap network(s) already exists. Skipping model training and prediction.') - else: - all_model_exist = all(os.path.exists(file_path) for file_path in ml_model_file.values()) - if all_model_exist: - log.info(f'Trained model(s) exists. Loading model(s) ...') - ml_model_dict = {} - # feature: ex or ei - for feature in ml_model_file: - with gzip.open(ml_model_file[feature], 'rb') as fh: - ml_model = pickle.load(fh) - ml_model_dict[feature] = ml_model - log.info('Loading model(s) ... 
done') - if not gs_df_file.exists(): - log.error(f'Trained models found but gold standard data file {gs_df_file} ' - f'does not exist.') - return - with pd.HDFStore(gs_df_file, mode='r') as store: - gs_train = store['train'] - gs_test = store['test'] - else: - gs_train, gs_test = prepare_gs_data(**gs_args) - with pd.HDFStore(gs_df_file, mode='w') as store: - store.put('train', gs_train) - store.put('test', gs_test) - ml_model_dict = train_ml_model(gs_train, ml_type, seed, n_jobs, feature_mapping, - model_dir) - - all_predicted_all_pairs_exist = all(os.path.exists(file_path) for file_path in - predicted_all_pairs_file.values()) - if all_predicted_all_pairs_exist: - log.info('Predicted all pairs already exists. Skipping prediction.') - else: - log.info('Predicting all pairs ...') - if task == 'protein_func': - ppi_feature = get_ppi_feature() - else: - ppi_feature = None - pred_all_pairs_args = { - 'model_dict': ml_model_dict, - 'all_ids': all_valid_ids, - 'feature_type': feature_type, - 'ppi_feature': ppi_feature, - 'cc_dict': cc_dict, - 'mr_dict': mr_dict, - 'extra_feature_file': extra_feature_file, - 'prediction_dir': prediction_dir, - 'output_file': predicted_all_pairs_file, - 'n_jobs': n_jobs - } - predict_all_pairs(**pred_all_pairs_args) - log.info('Predicting all pairs ... done') - - cutoff_p, cutoff_llr = get_cutoff(ml_model_dict, gs_test, lr_cutoff) - log.info(f'cutoff probability: {cutoff_p}') - log.info(f'cutoff llr: {cutoff_llr}') - predict_network(predicted_all_pairs_file, cutoff_p, edge_list_file) - - if not gs_df_file.exists(): - gs_train, gs_test = prepare_gs_data(**gs_args) - with pd.HDFStore(gs_df_file, mode='w') as store: - store.put('train', gs_train) - store.put('test', gs_test) - else: - if gs_test is None: - with pd.HDFStore(gs_df_file, mode='r') as store: - gs_train = store['train'] - gs_test = store['test'] - - all_llr_res_exist = all(os.path.exists(file_path) for file_path in llr_res_file.values()) - all_edge_list_exist = all(os.path.exists(file_path) for file_path in edge_list_file.values()) - if all_llr_res_exist and all_edge_list_exist: - log.info('LLR results already exist.') - else: - for ft in feature_mapping: - log.info(f'Computing LLR for {ft} ...') - if not predicted_all_pairs_file[ft].exists(): - log.error(f'Predicted all pairs file {predicted_all_pairs_file[ft]} does not exist.') - return - predicted_all_pairs = pd.read_parquet(predicted_all_pairs_file[ft]) - # also save the llr results to file - compute_llr(predicted_all_pairs, llr_res_file[ft], start_edge_num, max_num_edges, step_size, - gs_test) - log.info(f'Computing LLR for {ft} ... done') - - validation_res = {} - for ft in feature_mapping: - validation_res[ft] = { - 'llr_res_path': llr_res_file[ft], - 'edge_list_path': edge_list_file[ft] - } - if not llr_dataset_file.exists(): - log.info('Computing LLR for each dataset ...') - # feature_dict = cc_dict if feature_type == 'cc' else mr_dict - # use CC features for individual dataset - llr_ds = dataset_llr(all_valid_ids, cc_dict, 'cc', gs_test, start_edge_num, - max_num_edges, step_size, llr_dataset_file) - log.info('Done.') - else: - llr_ds = pd.read_csv(llr_dataset_file, sep='\t') - - if not ml_model_dict: - log.info('Trained model(s) exists. Loading model(s) ...') - for feature in ml_model_file: - with gzip.open(ml_model_file[feature], 'rb') as fh: - ml_model = pickle.load(fh) - ml_model_dict[feature] = ml_model - log.info('Loading model(s) ... 
done') - - all_fig_names = [] - if cutoff_llr is None: - cutoff_p, cutoff_llr = get_cutoff(ml_model_dict, gs_test, lr_cutoff) - - gs_dict = {} - gs_dict[feature_type.upper()] = gs_train - if task == 'protein_func' and feature_type.upper() == 'MR' and 'rp_pairs' in cfg: - # extract gs data for CC and MR for plotting - gs_args = { - 'task': task, - 'saved_data_dir': saved_data_dir, - 'cc_dict': cc_dict, - 'mr_dict': mr_dict, - 'feature_type': 'cc', - 'gs_file': gs_file, - # no extra feature for plotting - 'extra_feature_file': None, - 'valid_id_list': all_valid_ids, - 'test_size': test_size, - 'seed': seed - } - gs_train, gs_test = prepare_gs_data(**gs_args) - gs_dict['CC'] = gs_train - - fig_names = plot_results(cfg, validation_res, llr_ds, gs_dict, cutoff_llr, - figure_dir) - all_fig_names.extend(fig_names) - - merge_and_delete(figure_dir, all_fig_names, 'results.pdf') - - -if __name__ == '__main__': - cli() diff --git a/funmap/data_urls.py b/funmap/data_urls.py deleted file mode 100644 index 222ce369..00000000 --- a/funmap/data_urls.py +++ /dev/null @@ -1,17 +0,0 @@ -misc_urls = { - 'reactome_gold_standard': 'https://figshare.com/ndownloader/files/38647601', - 'reactome_gold_standard_md5': '671942763b6a7c32506cba1ed9900fe6', - 'funmap_blacklist': 'https://figshare.com/ndownloader/files/39033977', - 'mapping_file': 'https://figshare.com/ndownloader/files/39033971' -} - -# the information about other networks is fixed for now -network_info = { - 'name': ['BioGRID', 'BioPlex', 'HI-union', 'STRING'], - 'type': ['BioGRID', 'BioPlex', 'HI', 'STRING'], - 'url': ['https://figshare.com/ndownloader/files/39125054', - 'https://figshare.com/ndownloader/files/39125051', - 'https://figshare.com/ndownloader/files/39125093', - 'https://figshare.com/ndownloader/files/39125090' - ] -} diff --git a/funmap/funmap.py b/funmap/funmap.py deleted file mode 100644 index 4aa4528b..00000000 --- a/funmap/funmap.py +++ /dev/null @@ -1,613 +0,0 @@ -import os -import glob -import math -import h5py -import gzip -import pickle -import pyarrow as pa -import pyarrow.parquet as pq -from concurrent.futures import ThreadPoolExecutor -from tqdm import tqdm -from sklearn.utils import resample -import itertools -from typing import List -from pathlib import Path -from collections import defaultdict, Counter -import pandas as pd -import numpy as np -from sklearn.model_selection import GridSearchCV -from sklearn.model_selection import train_test_split -from sklearn.model_selection import StratifiedKFold -import xgboost as xgb -from funmap.utils import get_data_dict, is_url_scheme, read_csv_with_md5_check -from funmap.data_urls import network_info, misc_urls as urls -from funmap.logger import setup_logger - -log = setup_logger(__name__) - -def get_valid_gs_data(gs_path: str, valid_gene_list: List[str], md5=None): - log.info(f'Loading gold standard feature file "{gs_path}" ...') - if is_url_scheme(gs_path): - gs_edge_df = read_csv_with_md5_check(gs_path, expected_md5=md5, - local_path='download_gs_file', sep='\t') - if gs_edge_df is None: - raise ValueError('Failed to download gold standard file') - else: - gs_edge_df = pd.read_csv(gs_path, sep='\t') - - log.info('Done loading gold standard feature file') - gs_edge_df = gs_edge_df.rename(columns={gs_edge_df.columns[0]: 'P1', - gs_edge_df.columns[1]: 'P2'}) - gs_edge_df = gs_edge_df[gs_edge_df['P1'].isin(valid_gene_list) & - gs_edge_df['P2'].isin(valid_gene_list) & - (gs_edge_df['P1'] != gs_edge_df['P2'])] - m = ~pd.DataFrame(np.sort(gs_edge_df[['P1','P2']], axis=1)).duplicated() - 
gs_edge_df = gs_edge_df[list(m)] - gs_edge_df.reset_index(drop=True, inplace=True) - # rename the last column name to 'label' - gs_edge_df.rename(columns={gs_edge_df.columns[-1]: 'label'}, inplace=True) - - return gs_edge_df - - -def pairwise_mutual_rank(pcc_matrix): - """ - Calculate the pairwise mutual rank matrix based on the given Pearson correlation coefficient matrix. - - Parameters: - ----------- - pcc_matrix : numpy.ndarray - The Pearson correlation coefficient matrix. It should be a square matrix where - pcc_matrix[i, j] represents the correlation coefficient between variables i and j. - - Returns: - -------- - numpy.ndarray - A matrix containing the pairwise mutual ranks between variables based on the - provided Pearson correlation coefficient matrix. The matrix has the same shape - as the input pcc_matrix. - - Mutual Rank Calculation: - ------------------------ - The mutual rank between two variables A and B, based on their Pearson correlation coefficients, - is a measure of their relative rankings within their respective groups of correlated variables. - The formula for calculating the mutual rank is given by: - - mr_{AB} = sqrt((r_{AB} / n_B) * (r_{BA} / n_A)) - - Where: - - mr_{AB} is the mutual rank between variables A and B. - - r_{AB} is the rank of the correlation coefficient between A and B among all other correlation - coefficients involving A (excluding NaN values). - - n_B is the number of valid (non-NaN) correlation coefficients involving variable B. - - r_{BA} is the rank of the correlation coefficient between B and A among all other correlation - coefficients involving B (excluding NaN values). - - n_A is the number of valid (non-NaN) correlation coefficients involving variable A. - - Steps: - - For each variable pair (A, B): - - Calculate the rank of the correlation coefficient between A and B among all other correlation - coefficients involving A. This rank is denoted as r_{AB}. - - Calculate the rank of the correlation coefficient between B and A among all other correlation - coefficients involving B. This rank is denoted as r_{BA}. - - For each variable pair (A, B): - - Determine the number of valid (non-NaN) correlation coefficients involving variable B, denoted as n_B. - - Determine the number of valid (non-NaN) correlation coefficients involving variable A, denoted as n_A. - - For each variable pair (A, B): - - Compute the mutual rank mr_{AB} using the formula mentioned earlier. - - Populate the mutual rank matrix: - - Create a new matrix with the same shape as the input correlation coefficient matrix, - initialized with NaN values. - - For each valid variable pair (A, B), assign the corresponding mutual rank mr_{AB} - to the matrix at the appropriate indices. - - The resulting matrix contains the mutual ranks between all pairs of variables based on their - Pearson correlation coefficients. Higher mutual rank values indicate stronger and more consistent - correlations between variables. 
- """ - valid_a = ~np.isnan(pcc_matrix) - valid_b = valid_a.T - - rank_ab = np.argsort(pcc_matrix, axis=1).argsort(axis=1, kind='mergesort') + 1 # Start ranks from 1 - rank_ba = np.argsort(pcc_matrix, axis=0).argsort(axis=0, kind='mergesort') + 1 # Start ranks from 1 - - n_a = np.sum(valid_a, axis=1) - n_b = np.sum(valid_b, axis=0) - - valid_indices_a, valid_indices_b = np.where(valid_a) - - mr_values = np.sqrt((rank_ab[valid_indices_a, valid_indices_b] / n_b[valid_indices_b]) * - (rank_ba[valid_indices_a, valid_indices_b] / n_a[valid_indices_a])) - - mr_matrix = np.full_like(pcc_matrix, np.nan) - mr_matrix[valid_indices_a, valid_indices_b] = mr_values - - return mr_matrix - - -def compute_features(cfg, feature_type, min_sample_count, output_dir): - """Compute the pearson correlation coefficient for each edge in the list of edges and for each - """ - data_dict, all_valid_ids = get_data_dict(cfg, min_sample_count) - cc_dict = {} - for i in data_dict: - cc_file = os.path.join(output_dir, f'cc_{i}.h5') - cc_dict[i] = cc_file - - mr_dict = {} - for i in data_dict: - mr_file = os.path.join(output_dir, f'mr_{i}.h5') - mr_dict[i] = mr_file - - all_cc_exist = all(os.path.exists(file_path) for file_path in cc_dict.values()) - if all_cc_exist: - log.info("All cc files exist. Skipping feature computation.") - if feature_type == 'cc': - return cc_dict, mr_dict, all_valid_ids - - if feature_type == 'mr': - all_mr_exist = all(os.path.exists(file_path) for file_path in mr_dict.values()) - if all_cc_exist and all_mr_exist: - log.info("All mr files exist. Skipping feature computation.") - return cc_dict, mr_dict, all_valid_ids - - log.debug(f"Computing {feature_type} features") - for i in data_dict: - cc_file = cc_dict[i] - if os.path.exists(cc_file): - continue - log.info(f"Computing pearson correlation coefficient matrix for {i}") - x = data_dict[i].values.astype(np.float32) - df = pd.DataFrame(x) - corr_matrix = df.corr(method='pearson', min_periods=min_sample_count) - arr = corr_matrix.values - upper_indices = np.triu_indices(arr.shape[0]) - with h5py.File(cc_dict[i], 'w') as hf: - # only store the upper triangle part - hf.create_dataset('cc', data=arr[upper_indices]) - hf.create_dataset('ids', data=data_dict[i].columns.values.astype('S')) - cc_dict[i] = cc_file - - # compute pairwise mutual rank features - if feature_type == 'mr': - log.info(f"Computing mutual rank matrix for {i}") - arr_mr = pairwise_mutual_rank(arr) - upper_indices = np.triu_indices(arr_mr.shape[0]) - with h5py.File(mr_dict[i], 'w') as hf: - # only store the upper triangle part - hf.create_dataset('mr', data=arr_mr[upper_indices]) - hf.create_dataset('ids', data=data_dict[i].columns.values.astype('S')) - - return cc_dict, mr_dict, all_valid_ids - - -def balance_classes(df, random_state=42): - class_column = df.columns[-1] # Assuming class column is the last column - - class_values = df[class_column].unique() - if len(class_values) != 2: - raise ValueError("The class column should have exactly 2 unique values.") - - class_0 = df[df[class_column] == class_values[0]] - class_1 = df[df[class_column] == class_values[1]] - minority_class = class_0 if len(class_0) < len(class_1) else class_1 - majority_class = class_1 if minority_class is class_0 else class_0 - - majority_class_undersampled = resample(majority_class, replace=False, - n_samples=len(minority_class), random_state=random_state) - - balanced_df = pd.concat([minority_class, majority_class_undersampled]) - # Shuffle the rows in the balanced DataFrame - balanced_df = 
balanced_df.sample(frac=1, random_state=random_state).reset_index(drop=True) - - return balanced_df - -def assemble_feature_df(h5_file_mapping, df, dataset='cc'): - df.reset_index(drop=True, inplace=True) - # Initialize feature_df with columns for HDF5 file keys and 'label' - file_keys = list(h5_file_mapping.keys()) - feature_df = pd.DataFrame(columns=file_keys + ['label']) - - def get_1d_indices(i_array, j_array, n): - # Ensure i and j are within bounds - mask = (i_array < n) & (j_array < n) - i = np.minimum(i_array[mask], j_array[mask]) - j = np.maximum(i_array[mask], j_array[mask]) - - # Calculate 1D indices - return i * n - i * (i - 1) // 2 + (j - i) - - # Iterate over HDF5 files and load feature values - for key, file_path in h5_file_mapping.items(): - with h5py.File(file_path, 'r') as h5_file: - gene_ids = h5_file['ids'][:] - gene_to_index = {gene.astype(str): idx for idx, gene in enumerate(gene_ids)} - - # Get gene indices for P1 and P2 - p1_indices = np.array([gene_to_index.get(gene, -1) for gene in df.iloc[:, 0]]) - p2_indices = np.array([gene_to_index.get(gene, -1) for gene in df.iloc[:, 1]]) - - f_dataset = h5_file[dataset] - f_values = np.empty(len(df), dtype=float) - valid_indices = (p1_indices != -1) & (p2_indices != -1) - - linear_indices = get_1d_indices(p1_indices[valid_indices], p2_indices[valid_indices], len(gene_ids)) - - f_values[valid_indices] = f_dataset[:][linear_indices] - f_values[~valid_indices] = np.nan - - # Add feature values to the feature_df - feature_df[key] = f_values - - # if the last column is 'label', assign it to feature_df - if df.columns[-1] == 'label': - feature_df['label'] = df[df.columns[-1]] - else: - # delete the 'label' column from feature_df - del feature_df['label'] - - return feature_df - - -def extract_features(df, feature_type, cc_dict, ppi_feature=None, extra_feature=None, mr_dict=None): - if feature_type == 'mr': - if not mr_dict: - raise ValueError('mr dict is empty') - - feature_dict = cc_dict if feature_type == 'cc' else mr_dict - feature_df = assemble_feature_df(feature_dict, df, feature_type) - if ppi_feature is not None: - ppi_dict = {key: set(value) for key, value in ppi_feature.items()} - for ppi_source, ppi_tuples in ppi_dict.items(): - feature_df[ppi_source] = df.apply( - lambda row: 1 if (row['P1'], row['P2']) in ppi_tuples else 0, axis=1) - - # TODO: add extra features if provided - if extra_feature is not None: - pass - - # move 'label' column to the end of the dataframe if it exists - if 'label' in feature_df.columns: - feature_df = feature_df[[col for col in feature_df.columns if col != 'label'] + ['label']] - - return feature_df - - -def get_ppi_feature(): - """ - Returns a dictionary of protein-protein interaction (PPI) features. - - The PPI features are extracted from data in the "network_info" dictionary and are specified by the - "feature_names" list. The URLs of the relevant data are extracted from "network_info" and read - using the Pandas library. The resulting PPI data is stored in the "ppi_features" dictionary and - returned by the function. - - Returns: - ppi_features: dict - A dictionary with PPI features, where the keys are the feature names and the values are lists of tuples - representing the protein interactions. 
- """ - feature_names = ['BioGRID', 'BioPlex', 'HI-union'] - urls = [network_info['url'][i] for i in range(len(network_info['name'])) - if network_info['name'][i] in feature_names] - - ppi_features = {} - # use pandas to read the file - for (i, url) in enumerate(urls): - data = pd.read_csv(url, sep='\t', header=None) - data = data.apply(lambda x: tuple(sorted(x)), axis=1) - ppi_name = f'{feature_names[i]}_PPI' - ppi_features[ppi_name] = data.tolist() - - return ppi_features - - -def train_ml_model(data_df, ml_type, seed, n_jobs, feature_mapping, model_dir): - assert ml_type == 'xgboost', 'ML model must be xgboost' - models = train_model(data_df.iloc[:, :-1], data_df.iloc[:, -1], seed, n_jobs, feature_mapping, - model_dir) - - return models - - -def train_model(X, y, seed, n_jobs, feature_mapping, model_dir): - model_params = { - 'n_estimators': [10, 20, 50, 100], - 'max_depth': [1, 2, 3, 4, 5], - 'learning_rate': [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0] - } - - models = {} - for ft in feature_mapping: - # use only mutual rank - log.info(f'Training model for {ft} ...') - xgb_model = xgb.XGBClassifier(random_state=seed, n_jobs=n_jobs) - cv = StratifiedKFold(n_splits=5, random_state=seed, shuffle=True) - clf = GridSearchCV(xgb_model, model_params, scoring='roc_auc', cv=cv, - n_jobs=1, verbose=2) - if ft == 'ex': - # exclude ppi features - Xtrain = X.loc[:, ~X.columns.str.endswith('_PPI')] - else: - Xtrain = X - model = clf.fit(Xtrain, y) - models[ft] = model - ml_model_file = model_dir / f'model_{ft}.pkl.gz' - with gzip.open(ml_model_file, 'wb') as fh: - pickle.dump(model, fh) - - log.info(f'Training model for {ft} ... done') - - return models - - -def compute_llr(predicted_all_pairs, llr_res_file, start_edge_num, max_num_edges, step_size, - gs_test): - # make sure max_num_edges is smaller than the number of non-NA values - assert max_num_edges < np.count_nonzero(~np.isnan(predicted_all_pairs.iloc[:, -1].values)), \ - 'max_num_edges should be smaller than the number of non-NA values' - - cur_col_name = 'prediction' - cur_results = predicted_all_pairs.nlargest(max_num_edges, cur_col_name) - selected_edges_all = cur_results[['P1', 'P2']].apply(lambda row: tuple(sorted({row['P1'], row['P2']})), axis=1) - - gs_test_pos_set = set(gs_test[gs_test['label'] == 1][['P1', 'P2']].apply(lambda row: tuple(sorted({row['P1'], row['P2']})), axis=1)) - gs_test_neg_set = set(gs_test[gs_test['label'] == 0][['P1', 'P2']].apply(lambda row: tuple(sorted({row['P1'], row['P2']})), axis=1)) - n_gs_test_pos_set = len(gs_test_pos_set) - n_gs_test_neg_set = len(gs_test_neg_set) - - result_dict = defaultdict(list) - # llr_res_dict only save maximum of max_steps data points for downstream - # analysis / plotting - total = math.ceil((max_num_edges - start_edge_num) / step_size) + 1 - for k in tqdm(range(start_edge_num, max_num_edges+step_size, step_size), total=total, ascii=' >='): - selected_edges = set(selected_edges_all[:k]) - all_nodes = set(itertools.chain.from_iterable(selected_edges)) - common_pos_edges = selected_edges & gs_test_pos_set - common_neg_edges = selected_edges & gs_test_neg_set - try: - lr = len(common_pos_edges) / len(common_neg_edges) / (n_gs_test_pos_set / n_gs_test_neg_set) - except ZeroDivisionError: - lr = 0 - llr = np.log(lr) if lr > 0 else np.nan - n_node = len(all_nodes) - result_dict['k'].append(k) - result_dict['n'].append(n_node) - result_dict['llr'].append(llr) - - llr_res = pd.DataFrame(result_dict) - if llr_res_file is not None: - llr_res.to_csv(llr_res_file, sep='\t', index=False) - - 
return llr_res - - -def prepare_gs_data(**kwargs): - task = kwargs['task'] - cc_dict = kwargs['cc_dict'] - mr_dict = kwargs['mr_dict'] - gs_file = kwargs['gs_file'] - gs_file_md5 = None - feature_type = kwargs['feature_type'] - extra_feature_file = kwargs['extra_feature_file'] - valid_id_list = kwargs['valid_id_list'] - test_size = kwargs['test_size'] - seed = kwargs['seed'] - - if gs_file is None: - gs_file = urls['reactome_gold_standard'] - gs_file_md5 = urls['reactome_gold_standard_md5'] - - log.info('Preparing gold standard data') - gs_df = get_valid_gs_data(gs_file, valid_id_list, md5=gs_file_md5) - gs_df_balanced = balance_classes(gs_df, random_state=seed) - del gs_df - gs_train, gs_test = train_test_split(gs_df_balanced, - test_size=test_size, random_state=seed, - stratify=gs_df_balanced.iloc[:, -1]) - if task == 'protein_func': - ppi_feature = get_ppi_feature() - else: - ppi_feature = None - gs_train_df = extract_features(gs_train, feature_type, cc_dict, ppi_feature, - extra_feature_file, mr_dict) - gs_test_df = extract_features(gs_test, feature_type, cc_dict, ppi_feature, - extra_feature_file, mr_dict) - - # store both the ids with gs_test_df for later use - # add the first two column of gs_test to gs_test_df at the beginning - gs_test_df = pd.concat([gs_test.iloc[:, :2], gs_test_df], axis=1) - log.info('Preparing gs data ... done') - return gs_train_df, gs_test_df - - -def extract_dataset_feature(all_pairs, feature_file, feature_type='cc'): - # convert all_pairs to a dataframe - df = pd.DataFrame(all_pairs, columns=['P1', 'P2']) - - def get_1d_indices(i_array, j_array, n): - # Ensure i and j are within bounds - mask = (i_array < n) & (j_array < n) - i = np.minimum(i_array[mask], j_array[mask]) - j = np.maximum(i_array[mask], j_array[mask]) - - # Calculate 1D indices - return i * n - i * (i - 1) // 2 + (j - i) - - with h5py.File(feature_file, 'r') as h5_file: - gene_ids = h5_file['ids'][:] - gene_to_index = {gene.astype(str): idx for idx, gene in enumerate(gene_ids)} - - # Get gene indices for P1 and P2 - p1_indices = np.array([gene_to_index.get(gene, -1) for gene in df.iloc[:, 0]]) - p2_indices = np.array([gene_to_index.get(gene, -1) for gene in df.iloc[:, 1]]) - - f_dataset = h5_file[feature_type] - f_values = np.empty(len(df), dtype=float) - valid_indices = (p1_indices != -1) & (p2_indices != -1) - - linear_indices = get_1d_indices(p1_indices[valid_indices], p2_indices[valid_indices], len(gene_ids)) - - f_values[valid_indices] = f_dataset[:][linear_indices] - f_values[~valid_indices] = np.nan - # extracted feature is the 'prediction' column - df['prediction'] = f_values - - return df - - -def dataset_llr(all_ids, feature_dict, feature_type, gs_test, start_edge_num, - max_num_edge, step_size, llr_dataset_file): - llr_ds = pd.DataFrame() - all_ids_sorted = sorted(all_ids) - all_pairs = list(itertools.combinations(all_ids_sorted, 2)) - all_ds_pred = None - - for dataset in feature_dict: - log.info(f'Calculating llr for {dataset} ...') - feature_file = feature_dict[dataset] - predicted_all_pairs = extract_dataset_feature(all_pairs, feature_file, feature_type) - if all_ds_pred is None: - all_ds_pred = predicted_all_pairs['prediction'].values - else: - all_ds_pred = np.vstack((all_ds_pred, predicted_all_pairs['prediction'].values)) - - cur_llr_res = compute_llr(predicted_all_pairs, None, start_edge_num, - max_num_edge, step_size, gs_test) - cur_llr_res['dataset'] = dataset - llr_ds = pd.concat([llr_ds, cur_llr_res], axis=0, ignore_index=True) - llr_ds.to_csv(llr_dataset_file, 
sep='\t', index=False) - log.info(f'Calculating llr for {dataset} ... done') - - # calculate llr for all datasets based on the average prediction - log.info('Calculating llr for all datasets average ...') - all_ds_pred_df = pd.DataFrame(all_pairs, columns=['P1', 'P2']) - if all_ds_pred.ndim == 1: - all_ds_pred_avg = all_ds_pred - elif all_ds_pred.ndim == 2: - all_ds_pred_avg = np.nanmean(all_ds_pred, axis=0) - else: - raise ValueError(f'Invalid dimension for all_ds_pred: {all_ds_pred.ndim}') - all_ds_pred_df['prediction'] = all_ds_pred_avg - cur_llr_res = compute_llr(all_ds_pred_df, None, start_edge_num, max_num_edge, step_size, gs_test) - cur_llr_res['dataset'] = 'all_average' - llr_ds = pd.concat([llr_ds, cur_llr_res], axis=0, ignore_index=True) - log.info('Calculating llr for all datasets average ... done') - llr_ds.to_csv(llr_dataset_file, sep='\t', index=False) - - return llr_ds - - -def get_cutoff(model_dict, gs_test, lr_cutoff): - cutoff_p_dict = {} - cutoff_llr_dict = {} - for ft in model_dict: - log.info(f'Calculating cutoff prob for {ft} ...') - model = model_dict[ft] - if ft == 'ex': - gs_test_df = gs_test.loc[:, ~gs_test.columns.str.endswith('_PPI')] - else: - gs_test_df = gs_test - prob = model.predict_proba(gs_test_df.iloc[:, 2:-1]) - prob = prob[:, 1] - pred_df = pd.DataFrame(prob, columns=['prob']) - pred_df = pd.concat([pred_df, gs_test_df.iloc[:, -1]], axis=1) - pred_df = pred_df.sort_values(by='prob', ascending=False) - - P = pred_df['label'].sum() - N = len(pred_df) - P - cumulative_pp = np.cumsum(pred_df['label']) - cumulative_pn = np.arange(len(pred_df)) + 1 - cumulative_pp - llr_values = np.log((cumulative_pp / cumulative_pn) / (P / N)) - pred_df['llr'] = llr_values - - # find the first prob that has llr >= lr_cutoff - cutoff = np.log(lr_cutoff) - cutoff_prob = None - for _, row in pred_df[::-1].iterrows(): - if not np.isinf(row['llr']) and row['llr'] >= cutoff: - cutoff_prob = row['prob'] - cutoff_llr = row['llr'] - break - - # if cutoff_prob is None, it means that the lr_cutoff is too high - # and we cannot find a cutoff prob that has llr >= lr_cutoff - if cutoff_prob is None: - log.error(f'Cannot find cutoff prob for {ft}, lower lr_cutoff and try again') - import sys; sys.exit(1) - - cutoff_p_dict[ft] = cutoff_prob - cutoff_llr_dict[ft] = cutoff_llr - - return cutoff_p_dict, cutoff_llr_dict - - -def predict_network(predict_results_file, cutoff_p, output_file): - for ft in predict_results_file: - log.info(f'Predicting network for {ft} ...') - predicted_df = pd.read_parquet(predict_results_file[ft]) - filtered_df = predicted_df[predicted_df['prediction'] > cutoff_p[ft]] - cur_file = output_file[ft] - filtered_df[['P1', 'P2']].to_csv(cur_file, sep='\t', index=False, header=None) - directory, file_name = os.path.split(cur_file) - base_name, extension = os.path.splitext(file_name) - new_file_name = f'{base_name}_with_p{extension}' - new_file = os.path.join(directory, new_file_name) - filtered_df.to_csv(new_file, sep='\t', index=False) - num_edges = len(filtered_df) - num_nodes = len(set(filtered_df['P1']) | set(filtered_df['P2'])) - log.info(f'Number of edges: {num_edges}') - log.info(f'Number of nodes: {num_nodes}') - log.info(f'Predicting network for {ft} ... 
done') - - -def predict_all_pairs(model_dict, all_ids, feature_type, ppi_feature, cc_dict, - mr_dict, extra_feature_file, prediction_dir, - output_file, n_jobs=1): - chunk_size = 1000000 - log.info('Genearating all pairs ...') - all_ids = sorted(all_ids) - all_pairs = list(itertools.combinations(all_ids, 2)) - log.info('Genearating all pairs ... done') - log.info(f'Number of valid ids {format(len(all_ids), ",")}') - # remove all "chunk_*.parquet" files in prediction_dir if they exist - pattern = os.path.join(prediction_dir, 'chunk_*.parquet') - matching_files = glob.glob(pattern) - for file in matching_files: - os.remove(file) - - for ft in model_dict: - log.info(f'Predicting all pairs ({format(len(all_pairs), ",")}) for {ft} ...') - model = model_dict[ft] - def process_and_save_chunk(start_idx, chunk_id): - chunk = all_pairs[start_idx:start_idx + chunk_size] - chunk_df = pd.DataFrame(chunk, columns=['P1', 'P2']) - if ft == 'ex': - cur_ppi_feature = None - else: - cur_ppi_feature = ppi_feature - feature_df = extract_features(chunk_df, feature_type, cc_dict, cur_ppi_feature, - extra_feature_file, mr_dict) - predictions = model.predict_proba(feature_df) - prediction_df = pd.DataFrame(predictions[:, 1], columns=['prediction']) - prediction_df['P1'] = chunk_df['P1'] - prediction_df['P2'] = chunk_df['P2'] - prediction_df = prediction_df[['P1', 'P2', 'prediction']] - prediction_df['prediction'] = prediction_df['prediction'].astype('float32') - table = pa.Table.from_pandas(prediction_df) - chunk_id = str(chunk_id).zfill(6) - output_file = f'{prediction_dir}/chunk_{chunk_id}.parquet' - pq.write_table(table, output_file) - - with ThreadPoolExecutor(max_workers=n_jobs) as executor: - for chunk_id, chunk_start in enumerate(range(0, len(all_pairs), chunk_size)): - executor.submit(process_and_save_chunk, chunk_start, chunk_id) - - pattern = os.path.join(prediction_dir, 'chunk_*.parquet') - matching_files = glob.glob(pattern) - matching_files.sort() - pq.write_table(pa.concat_tables([pq.read_table(file) for file in matching_files]), output_file[ft]) - for file in matching_files: - os.remove(file) - - log.info(f'Predicting all {format(len(all_pairs), ",")} pairs for {ft} done.') diff --git a/funmap/plotting.py b/funmap/plotting.py deleted file mode 100644 index 3a46a970..00000000 --- a/funmap/plotting.py +++ /dev/null @@ -1,789 +0,0 @@ -from typing import List, Dict -import os -from pathlib import Path -import pandas as pd -import numpy as np -from funmap.utils import get_data_dict, get_node_edge_overlap -import matplotlib -import matplotlib.pyplot as plt -from matplotlib.cbook import flatten -from matplotlib.ticker import MaxNLocator -import matplotlib.ticker as mticker -from matplotlib.lines import Line2D -from matplotlib.patches import Patch -import seaborn as sns -import PyPDF2 -from matplotlib_venn import venn2, venn2_circles -import networkx as nx -import warnings -import powerlaw -from funmap.logger import setup_logging, setup_logger - -log = setup_logger(__name__) - -def edge_number(x, pos): - """ - Formatter function to format the x-axis tick labels - - Parameters - ---------- - x : float - The value to be formatted. - pos : float - The tick position. - - Returns - ------- - s : str - The formatted string of the value. 
- """ - if x >= 1e6: - s = '{:1.1f}M'.format(x*1e-6) - elif x == 0: - s = '0' - else: - s = '{:1.0f}K'.format(x*1e-3) - return s - - -def plot_llr_comparison(cfg, validation_results, llr_ds, output_file='llr_comparison.pdf'): - name_type_dict = {item['name']: item['type'] for item in cfg['data_files']} - name_type_dict['all_average'] = 'other' - datasets = sorted(llr_ds['dataset'].unique().tolist()) - fig, ax = plt.subplots(figsize=(20, 16)) - - for ds in datasets: - start = -1 - cur_df = llr_ds[llr_ds['dataset'] == ds] - if name_type_dict[ds].upper() == 'RNA': - ltype = '--' - elif name_type_dict[ds].upper() == 'PROTEIN': - ltype = ':' - else: - ltype = '-' - ax.plot(cur_df['k'], cur_df['llr'], linestyle=ltype, label=ds) - if start == -1: - start = cur_df['k'].iloc[0] - - # plot llr_res with the same start point - for ft in validation_results: - llr_res = pd.read_csv(validation_results[ft]['llr_res_path'], sep='\t') - llr_res = llr_res[llr_res['k'] >= start] - ax.plot(llr_res['k'], llr_res['llr'], label=f'funmap_{ft}', linewidth=3) - - line_styles = [Line2D([0], [0], linestyle='--', color='black', label='RNA'), - Line2D([0], [0], linestyle=':', color='black', label='Protein'), - Line2D([0], [0], linestyle='-', color='black', label='Other')] - color_legend = ax.legend(handles=[Patch(color=line.get_color(), label=line.get_label()) for line in ax.lines], - bbox_to_anchor=(1.05, 1), title='Data type', fontsize=16, - bbox_transform=ax.transAxes) - linestyle_legend = ax.legend(handles=line_styles, - bbox_to_anchor=(1.05, 0.2), title='Model', fontsize=16, - bbox_transform=ax.transAxes) - ax.add_artist(color_legend) - ax.add_artist(linestyle_legend) - ax.xaxis.set_major_formatter(edge_number) - ax.set_xlabel('number of pairs', fontsize=16) - ax.set_ylabel('log likelihood ratio', fontsize=16) - ax.yaxis.grid(color = 'gainsboro', linestyle = 'dotted') - plt.tight_layout() - plt.box(on=None) - plt.savefig(output_file, bbox_inches='tight', bbox_extra_artists=[color_legend, linestyle_legend]) - plt.close(fig) - return output_file - - -def explore_data(data_config: Path, - min_sample_count: int, - output_dir: Path): - """ - Generate plots to explore and visualize data - - Parameters - ---------- - data_config: Path - Path to the data configuration file - min_sample_count: int - The minimum number of samples required to consider a dataset - output_dir: Path - The directory to save the output plots - - Returns - ------- - A list of file names of the generated plots - - """ - data_dict, _ = get_data_dict(data_config, min_sample_count) - fig_names = [] - - # sample wise median expression plot for each dataset - data = [] - data_keys = [] - - max_col_to_plot = 100 - - log.info('Generating plots to explore and visualize data ...') - for ds in data_dict: - log.info(f'... 
{ds}') - data_df = data_dict[ds] - fig, ax = plt.subplots(1, 2, figsize=(20, 5)) - cur_data = data_df.T - cur_data.dropna(inplace=True) - if cur_data.shape[1] > max_col_to_plot: - cur_data = cur_data.sample(max_col_to_plot, axis=1) - ax[0].boxplot(cur_data) - ax[0].set_ylabel('expression') - if data_df.shape[0] > max_col_to_plot: - ax[0].set_xlabel(f'random selected {max_col_to_plot} samples (total n={data_df.shape[0]})') - ax[0].set_xticklabels([]) - ax[0].set_xticks([]) - else: - ax[0].set_xlabel('sample') - ax[0].set_xticklabels(cur_data.columns, rotation=45, ha='right') - - # density plot for each sample in each dataset - for i in range(data_df.shape[0]): - sns.kdeplot(data_df.iloc[i, :], linewidth=1, ax=ax[1]) - locator=MaxNLocator(60) - ax[1].xaxis.set_major_locator(locator) - ax[1].set_xlabel('values') - ax[1].set_ylabel('density') - ticks_loc = ax[1].get_xticks().tolist() - ax[1].xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) - ax[1].set_xticklabels(ax[1].get_xticklabels(), rotation=45, ha='right') - - # set title for the figu - fig.suptitle(f'{ds}', fontsize=16) - fig.tight_layout() - cur_file_name = f'{ds}_sample_plot.pdf' - fig_names.append(cur_file_name) - log.info(f'Saving figure {cur_file_name} ...') - fig.savefig(output_dir / cur_file_name) - plt.close(fig) - data_keys.append(ds) - data.append(data_df.median(axis=1).values) - - fig, ax = plt.subplots(figsize=(10, 5)) - ax.boxplot(data) - ax.set_xticklabels(data_keys, rotation=45) - ax.yaxis.grid(color = 'gainsboro', linestyle = 'dotted') - - ax.set_xlabel('dataset') - ax.set_ylabel('median expression') - plt.box(on=None) - fig.tight_layout() - file_name = 'data_box_plot.pdf' - fig_names.append(file_name) - log.info(f'Saving figure {file_name} ...') - fig.savefig(output_dir / file_name) - plt.close(fig) - - # boxplot of the number of samples and genes - sample_count = pd.DataFrame( - {'count':[data_dict[ds].shape[0] for ds in data_dict], - 'dataset': [ds for ds in data_dict]}) - gene_count = pd.DataFrame({ - 'count':[data_dict[ds].shape[1] for ds in data_dict], - 'dataset': [ds for ds in data_dict]}) - - fig, ax = plt.subplots(1, 2, figsize=(10, 5)) - - bars0 = ax[0].barh(sample_count['dataset'], sample_count['count'], color='#774FA0') - bars1 = ax[1].barh(gene_count['dataset'], gene_count['count'], color='#7DC462') - - ax[0].spines['top'].set_visible(False) - ax[0].spines['left'].set_visible(False) - ax[0].spines['right'].set_visible(False) - ax[0].tick_params(axis='both', which='major', labelsize=12) - ax[0].bar_label(bars0, label_type='edge', fontsize=10) - ax[0].set_xlabel('number of samples') - - ax[1].spines['top'].set_visible(False) - ax[1].spines['left'].set_visible(False) - ax[1].spines['right'].set_visible(False) - ax[1].set_yticklabels([]) - ax[1].tick_params(axis='x', which='major', labelsize=12) - ax[1].tick_params(axis='y', which='both', left=False, right=False, labelleft=False) - - ax[1].bar_label(bars1, label_type='edge', fontsize=10) - ax[1].set_xlabel('number of genes') - - fig.tight_layout() - file_name = 'sample_gene_count.pdf' - fig.savefig(output_dir / file_name) - log.info(f'Saving figure {file_name} ...') - plt.close(fig) - fig_names.append(file_name) - return fig_names - - -def plot_results(cfg, validation_results, llr_ds, gs_dict, cutoff_llr, figure_dir): - fig_names = [] - file_name = 'llr_compare_dataset.pdf' - plot_llr_comparison(cfg, validation_results, llr_ds, output_file=figure_dir / file_name) - fig_names.append(file_name) - - if 'rp_pairs' in cfg: - file_names = 
plot_pair_llr(gs_dict, cfg['feature_type'], output_dir=figure_dir, - rp_pairs=cfg['rp_pairs']) - fig_names.extend(file_names) - - if cfg['task'] == 'protein_func': - file_name = 'llr_compare_networks.pdf' - # if gs_file is not specified, plot the llr comparison between networks - # because some the numbers are pre-computed based on default gs_file - if cfg['gs_file'] is None: - plot_llr_compare_networks(validation_results, cfg['lr_cutoff'], cutoff_llr, - output_file=figure_dir / file_name) - fig_names.append(file_name) - - # the information about other networks is fixed for now - other_network_info = { - 'name': ['BioGRID', 'BioPlex', 'HI-union', 'STRING'], - 'type': ['BioGRID', 'BioPlex', 'HI', 'STRING'], - 'url': ['https://figshare.com/ndownloader/files/39125054', - 'https://figshare.com/ndownloader/files/39125051', - 'https://figshare.com/ndownloader/files/39125093', - 'https://figshare.com/ndownloader/files/39125090' - ] - } - # convert the info to a data frame where the url is read as a dataframe - network_info = pd.DataFrame(other_network_info) - network_info['el'] = network_info['url'].apply(lambda x: pd.read_csv(x, - sep='\t', header=None)) - network_info = network_info.drop(columns=['url']) - - # for each funmap, create a dataframe - for ft in validation_results: - edge_file_path = validation_results[ft]['edge_list_path'] - funmap_el = pd.read_csv(edge_file_path, sep='\t', header=None) - funmap_df = pd.DataFrame({'name': ['FunMap'], 'type': ['FunMap'], 'el': [funmap_el]}) - all_network_info = pd.concat([network_info, funmap_df], ignore_index=True) - overlap_info = get_node_edge_overlap(all_network_info) - node_color, edge_color = '#7DC462', '#774FA0' - for (type, color) in zip(['node', 'edge'], [node_color, edge_color]): - fig_name = plot_overlap_venn(f'funmap_{ft}', overlap_info[type], type, color, figure_dir) - fig_names.append(fig_name) - - fig_name = plot_network_stats(all_network_info, ft, figure_dir) - fig_names.append(fig_name) - - return fig_names - - -def plot_1d_llr(ax, feature_df, feature_name, feature_type, data_type, n_bins): - """ - Plot the 1D histogram of the likelihood ratio for each feature - - Parameters - ---------- - ax : matplotlib.axes._subplots.AxesSubplot - The subplot where the histogram is to be plotted. - feature_df : pd.DataFrame - DataFrame containing all features and their values. - feature_name : str - The name of the feature for which histogram is to be plotted. - feature_type : str - The type of the feature, either 'CC' or 'MR'. - data_type : str - The type of data, either 'RNA' or 'PRO'. - n_bins : int - The number of bins for the histogram. 
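    A minimal illustrative call with synthetic data, just to make the expected
    inputs concrete (the column name and values are made up, and the import
    assumes plot_1d_llr remains exposed from funmap.plotting in the
    reorganised layout):

        import numpy as np
        import pandas as pd
        import matplotlib.pyplot as plt
        from funmap.plotting import plot_1d_llr  # assumed import path

        fig, ax = plt.subplots()
        toy = pd.DataFrame({'RNA_CC': np.random.uniform(-1, 1, size=1000)})
        plot_1d_llr(ax, toy, feature_name='RNA_CC', feature_type='CC',
                    data_type='RNA', n_bins=20)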
- - Returns - ------- - None - """ - df = feature_df.loc[:, [feature_name]] - cur_df = df.dropna() - cur_df_vals = cur_df.values.reshape(-1) - clr = '#bcbddc' - data_range = {'CC': (-1, 1), 'MR': (0, 1)} - if data_type == 'PRO': - ax.hist(cur_df_vals, bins=n_bins, range=data_range[feature_type], color=clr, - orientation='horizontal', - density=True) - ax.text(0.95, 0.95, data_type, - verticalalignment='top', horizontalalignment='right', - transform=ax.transAxes, - rotation=-90, - color='black', fontsize=16) - ax.set_ylim(data_range[feature_type]) - ax.set_xlim(0,2.5) - else: - ax.hist(cur_df_vals, bins=n_bins, range=data_range[feature_type], color=clr, - density=True) - ax.text(0.02, 0.9, data_type, - verticalalignment='top', horizontalalignment='left', - transform=ax.transAxes, - color='black', fontsize=16) - ax.set_ylabel('density', fontsize=16) - ax.set_xlim(data_range[feature_type]) - ax.set_ylim(0,2.5) - ax.tick_params(axis='x', labelsize=16) - ax.tick_params(axis='y', labelsize=16) - - -def plot_2d_llr(ax, feature_df, feature_type, pair_name, rna_feature, pro_feature, n_bins): - """ - Plots a 2D log likelihood ratio between two features in a scatter plot. - - Parameters - ---------- - ax : matplotlib.axes._subplots.AxesSubplot - The subplot to plot the log likelihood ratio on. - feature_df : pandas.DataFrame - DataFrame containing all features and target variables. - feature_type : str - Type of feature, either "CC" (correlation coefficient) or "MR" (mutual rank). - pair_name : str - Name of the feature pair. - rna_feature : str - Name of the RNA feature in `feature_df`. - pro_feature : str - Name of the protein feature in `feature_df`. - n_bins : int - Number of bins in the 2D histogram. - - Returns - ------- - fig : matplotlib.collections.QuadMesh - The mesh plot of the log likelihood ratio. 
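    Notes
    -----
    The colour of each 2D bin encodes a smoothed log likelihood ratio, roughly

        log( (n_pos_bin / N_pos) / (n_neg_bin / N_neg) )

    where n_pos_bin and n_neg_bin are the numbers of gold-standard positive and
    negative pairs falling in the bin, and N_pos and N_neg are the total
    positive and negative pair counts. The implementation adds small
    pseudocounts (+1 to the positive bin count and +N_neg/N_pos to the negative
    bin count) so that an empty bin maps to a ratio of 1 (log-ratio 0) instead
    of a division by zero.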
- - """ - data_types = ['RNA', 'PRO'] - label_mapping = {'PRO': 'Protein', 'RNA': 'mRNA'} - feature_label_mapping = {'CC': 'correlation coefficient', - 'MR': 'mutual rank'} - cnt = {} - cnt_pos_neg = {} - max_density = -1 - - data_range = {'CC': (-1, 1), 'MR': (0, 1)} - - for label in [0, 1]: - df = feature_df.loc[:, ['label', rna_feature, pro_feature]] - df = df.dropna() - cur_df = df.loc[df['label'] == label, [rna_feature, pro_feature]] - hist_density, _, _ = np.histogram2d(cur_df[rna_feature].values, - cur_df[pro_feature].values, bins=n_bins, - range=np.array([data_range[feature_type], - data_range[feature_type]]), - density=True) - max_density = max(max_density, np.max(hist_density)) - - for label in [0, 1]: - df = feature_df.loc[:, ['label', rna_feature, pro_feature]] - df = df.dropna() - cur_df = df.loc[df['label'] == label, [rna_feature, pro_feature]] - cnt[label] = cur_df.shape[0] - hist, _, _ = np.histogram2d(cur_df[rna_feature].values, - cur_df[pro_feature].values, bins=n_bins, - range=np.array([data_range[feature_type], - data_range[feature_type]])) - cnt_pos_neg[label] = hist - hh = ax.hist2d(cur_df[rna_feature].values, - cur_df[pro_feature].values, - bins=n_bins, - range=np.array([data_range[feature_type], - data_range[feature_type]]), - vmin=0, - vmax=max_density, - density=True) - - llr_vals = ((cnt_pos_neg[1] + 1)/(cnt_pos_neg[0] + cnt[0]/cnt[1]))/(cnt[1]/cnt[0]) - if feature_type == 'MR': - vmin = np.percentile(llr_vals, 5) - vmax = np.percentile(llr_vals, 95) - symmetric_max = max(abs(vmin), abs(vmax)) - vmin = -symmetric_max - vmax = symmetric_max - else: - vmin = -4 - vmax = 4 - cmap = plt.cm.RdBu_r - fig = ax.pcolormesh(hh[1], hh[2], np.transpose(np.log(llr_vals)), vmin=vmin, vmax=vmax, - cmap=cmap) - ax.text(0.02, 0.01, pair_name, - verticalalignment='bottom', horizontalalignment='left', - transform=ax.transAxes, - color='gray', fontsize=24, fontweight='bold') - ax.set_xlabel(f'{label_mapping[data_types[0]]}\n{feature_label_mapping[feature_type]}', fontsize=16) - ax.set_ylabel(f'{label_mapping[data_types[1]]}\n{feature_label_mapping[feature_type]}', fontsize=16) - if feature_type == 'CC': - ax.set_xticks([-1, -0.5, 0, 0.5, 1]) - ax.set_yticks([-1, -0.5, 0, 0.5, 1]) - else: - ax.set_xticks([0, 0.25, 0.5, 0.75, 1]) - ax.set_yticks([0, 0.25, 0.5, 0.75, 1]) - ax.tick_params(axis='x', labelsize=16) - ax.tick_params(axis='y', labelsize=16) - return fig - - -def plot_pair_llr(gs_dict, feature_type, output_dir, rp_pairs): - n_bins = 20 - file_names = [] - feature_type = ['CC', 'MR'] - # plot for each rna-protein pair using CC or MR as feature - for rp_pair in rp_pairs: - for ft in feature_type: - ft = ft.upper() - feature_df = gs_dict[ft] - fig, ax = plt.subplots(2, 2, figsize=(10, 10), - gridspec_kw={'width_ratios': [4, 1], - 'height_ratios': [1, 4]}) - rna_feature = rp_pair['rna'] - pro_feature = rp_pair['protein'] - plot_1d_llr(ax[0,0], feature_df, rna_feature, ft, 'RNA', n_bins) - ax[0, 0].xaxis.set_ticks_position('none') - ax[0, 0].set_xticklabels([]) - - ax[0,1].axis('off') - - heatmap2d = plot_2d_llr(ax[1,0], feature_df, ft, rp_pair['name'], - rna_feature, pro_feature, n_bins) - - plot_1d_llr(ax[1, 1], feature_df, pro_feature, ft, 'PRO', n_bins) - ax[1, 1].yaxis.set_ticks_position('none') - ax[1, 1].set_yticklabels([]) - - # add colorbar to the right of the plot - cax = fig.add_axes([1.05, 0.25, 0.03, 0.5]) - fig.colorbar(heatmap2d, cax=cax) - - # plt.tight_layout() - plt.box(on=None) - file_name = f"{rp_pair['name']}_rna_pro_{ft}_llr.pdf" - 
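            # For reference: each rp_pair consumed in this loop comes from the
            # experiment config's 'rp_pairs' section and is expected to provide
            # 'name', 'rna' and 'protein' keys, where 'rna' and 'protein' name
            # entries of 'data_files'. A hypothetical example:
            #
            #     rp_pairs = [
            #         {'name': 'Tumor', 'rna': 'RNA_tumor', 'protein': 'Protein_tumor'},
            #     ]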
file_names.append(file_name) - plt.savefig(output_dir / file_name, bbox_inches='tight') - plt.close(fig) - - return file_names - - -def plot_llr_compare_networks(validaton_results, cutoff, cutoff_llr, output_file): - all_networks = [] - - for ft in validaton_results: - edge_list = pd.read_csv(validaton_results[ft]['edge_list_path'], sep='\t', header=None) - n_edge = len(edge_list) - n_node = len(set(edge_list.iloc[:, 0]) | set(edge_list.iloc[:, 1])) - all_networks.extend([ - (f'FunMap_{ft}', 'FunMap', n_node, n_edge, cutoff_llr[ft], np.exp(cutoff_llr[ft])) - ]) - - # these are pre-computed values - all_networks.extend( - [ - # name, type, n, e, llr, lr - ('HuRI', 'HI', 8272, 52548, 2.3139014130648827, 10.11), - ('HI-union', 'HI', 9094, 64006, 2.298975841813893, 9.96), - ('ProHD', 'ProHD', 2680, 61580, 4.039348296, 56.78), - # this is combined_score_700 - ('STRING', 'STRING', 16351, 240314, 5.229377563059293, 186.676572955849), - ('BioGRID', 'BioGRID', 17259, 654490, 2.6524642147396182, 14.18896024552041), - ('BioPlex', 'BioPlex', 13854, 154428, 3.3329858940888375, 28.021887299660047) - ] - ) - log.info(all_networks) - - cols = ['name', 'group', 'n', 'e', 'llr', 'lr'] - network_data = pd.DataFrame(all_networks, columns=cols) - x = np.array(network_data['n']) - y = np.array(network_data['llr']) - e = np.array(network_data['e']) - z = network_data['name'] - - fig, ax = plt.subplots(figsize=(10, 10)) - ax.set_axisbelow(True) - ax.xaxis.grid(color = 'gainsboro', linestyle = 'dotted') - ax.yaxis.grid(color = 'gainsboro', linestyle = 'dotted') - ax.get_ygridlines()[4].set_color('salmon') - ax.get_yticklabels()[4].set_color('red') - ax2 = ax.twinx() - - # we have 6 groups, so we need 6 colors - mycmap = matplotlib.colors.ListedColormap(['#de2d26', '#8B6CAF', '#0D95D0', - '#69A953', '#F1C36B', '#DC6C43']) - # group 0 is FunMap, group 1 is HI, group 2 is ProHD, group 3 is STRING, - # group 4 is BioGRID, group 5 is BioPlex - # the length of gro - color_group = [0] * len(validaton_results) + [1] * 2 + [2] * 1 + [3] * 1 + [4] * 1 + [5] * 1 - scatter = ax.scatter(x, y, c=color_group, cmap=mycmap, - s=e/1000) - ax.set_ylim(2, 6) - ax2.set_ylim(np.exp(2.0), np.exp(6)) - ax.set_xlabel('number of genes') - ax.set_yticks([2.0, 2.5, 3, 3.5, np.log(cutoff), 4, 4.5, 5, 5.5, 6]) - ax.set_ylabel('log likelihood ratio') - ax.spines['top'].set_visible(False) - ax.spines['left'].set_visible(False) - ax.spines['right'].set_visible(False) - ax2.set_ylabel('likelihood ratio') - ax2.set_yscale('log', base=np.e) - ax2.set_yticks([np.exp(2), np.exp(2.5), np.exp(3), np.exp(3.5), - cutoff, np.exp(4), np.exp(4.5), np.exp(5), - np.exp(5.5), np.exp(6)]) - ax.tick_params(axis='x', labelsize=12) - ax.tick_params(axis='y', labelsize=12) - ax2.tick_params(axis='y', labelsize=12) - ax2.get_yticklabels()[4].set_color('red') - ax2.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) - ax2.spines['top'].set_visible(False) - ax2.spines['left'].set_visible(False) - ax2.spines['right'].set_visible(False) - - handles1, labels1 = scatter.legend_elements(prop='sizes', num=4, alpha=1, - fmt='{x:.0f} K') - legend1 = ax.legend(handles1, labels1, - loc='upper left', labelspacing=1.8, borderpad=1.0, - title='number of pairs', frameon=True) - - ax.add_artist(legend1) - leg = ax.get_legend() - for i in range(len(leg.legendHandles)): - leg.legendHandles[i].set_color('gray') - - ax.xaxis.set_major_formatter(edge_number) - for i, txt in enumerate(z): - if txt == 'STRING': - ax.annotate(txt, (x[i]-500, y[i]+0.18), color='gray', 
fontsize=10) - elif txt == 'HI-union': - ax.annotate(txt, (x[i]-100, y[i]+0.1), color='gray', fontsize=10) - elif txt == 'BioGRID': - ax.annotate(txt, (x[i]-800, y[i]-0.25), color='gray', fontsize=10) - else: - ax.annotate(txt, (x[i]-100, y[i]-0.2), color='gray', fontsize=10) - - fig.tight_layout() - fig.savefig(output_file, bbox_inches='tight') - plt.close(fig) - - -def plot_overlap_venn(network_name, overlap, node_or_edge, color, output_dir): - """ - Plot the Venn diagrams for the overlap between different datasets. - - Parameters - ---------- - network_name : str - The name of the network to plot the overlap for. - overlap : dict - A dictionary containing the overlap between the datasets. - The keys are the names of the datasets, and the values are the sets - representing the overlap. - node_or_edge : str - A string indicating whether to plot the overlap of nodes or edges. - Must be one of 'node' or 'edge'. - color : str - The color to use for the FunMap dataset in the Venn diagrams. - output_dir : path-like - The directory to save the output figure in. - - Returns - ------- - file_name : str - The name of the file that the figure was saved as. - - """ - data = [] - for nw in overlap: - data.append(overlap[nw]) - max_area = max(map(sum, data)) - - def set_venn_scale(ax, true_area, reference_area=max_area): - s = np.sqrt(float(reference_area)/true_area) - ax.set_xlim(-s, s) - ax.set_ylim(-s, s) - - all_axes = [] - - n_plot = len(overlap) - fig, ax = plt.subplots(1, n_plot, figsize=(5*n_plot, 5)) - - for i, nw in enumerate(overlap): - cur_ax = ax[i] - all_axes.append(cur_ax) - labels = ('FunMap', nw) - out = venn2(overlap[nw], - set_labels=labels, alpha=1.0, - ax=cur_ax, set_colors=[color, 'white']) - venn2_circles(overlap[nw], ax=cur_ax, linestyle='solid', - color='gray', - linewidth=1) - if out.set_labels: - for text in out.set_labels: - text.set_fontsize(12) - - for text in out.subset_labels: - text.set_fontsize(10) - - # add title to the figure - name = 'genes' if node_or_edge == 'node' else 'edges' - fig.suptitle(f'Overlap of {name} ({network_name})', fontsize=16) - - for a, d in zip(flatten(ax), data): - set_venn_scale(a, sum(d)*1.5) - - file_name = f'{network_name}_overlap_{node_or_edge}.pdf' - fig.savefig(output_dir / file_name, bbox_inches='tight') - plt.close(fig) - return file_name - - -def plot_network_stats(network_info, feature_type, output_dir): - fig, ax = plt.subplots(1, 4, figsize=(20, 5)) - density = {} - average_shortest_path = {} - # these are pre-calculated since they take a long time to compute and - # the network is fixed - average_shortest_path ={ - 'BioGRID': 2.74, - 'BioPlex': 3.60, - 'HI-union': 3.70, - 'STRING': 3.95 - } - # if you want to recompute the average shortest path length, - # add the network name to this list - network_list = ['FunMap'] - for n in network_list: - network_el = network_info.loc[network_info['name'] == n, 'el'].values[0] - cur_network = nx.from_pandas_edgelist(network_el, source=0, target=1) - cur_density = nx.density(cur_network) - density[n] = cur_density - largest_cc = max(nx.connected_components(cur_network), key=len) - cur_cc = cur_network.subgraph(largest_cc).copy() - cur_average_shortest_path = nx.average_shortest_path_length(cur_cc) - average_shortest_path[n] = cur_average_shortest_path - cur_degrees = [val for (_, val) in cur_network.degree()] - if n == 'FunMap': # only fit for FunMap - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - fit = powerlaw.Fit(cur_degrees, discrete=True, xmax=250, 
estimate_discrete=False) - powerlaw.plot_pdf(cur_degrees, linear_bins=True, linestyle='None', marker='o', - markerfacecolor='None', color='#de2d26', - linewidth=3, ax=ax[0]) - # not plotting the power law fit - # fit.power_law.plot_pdf(linestyle='--',color='black', ax=ax[0]) - - # all the networks in network_info minus FunMap - other_networks = list(set(network_info['name'].tolist()) - set(['FunMap'])) - for n in other_networks: - network_el = network_info.loc[network_info['name'] == n, 'el'].values[0] - cur_network = nx.from_pandas_edgelist(network_el, source=0, target=1) - cur_density = nx.density(cur_network) - density[n] = cur_density - - ax[0].set_xlabel('degree') - ax[0].set_ylabel('p(x)') - ax[0].spines['top'].set_visible(False) - ax[0].spines['right'].set_visible(False) - ax[0].yaxis.grid(color = 'gainsboro', linestyle = 'dotted') - ax[0].set_axisbelow(True) - - # global average clustering coefficient - # these are pre-calculated since they take a long time to compute and - # the network is fixed - avg_cc = { - 'BioGRID': 0.125, - 'BioPlex': 0.103, - 'HI-union': 0.06, - 'STRING': 0.335 - } - for n in network_list: - network_el = network_info.loc[network_info['name'] == n, 'el'].values[0] - cur_network = nx.from_pandas_edgelist(network_el, source=0, target=1) - cur_cc = nx.average_clustering(cur_network) - avg_cc[n] = cur_cc - - network_list = network_info['name'].tolist() - ax[1].bar(network_list, [avg_cc[i] for i in network_list], width=0.5, align='center', - color='#E4C89A') - ax[1].spines['left'].set_position(('outward', 8)) - ax[1].spines['bottom'].set_position(('outward', 5)) - ax[1].spines['top'].set_visible(False) - ax[1].spines['left'].set_visible(False) - ax[1].spines['right'].set_visible(False) - ax[1].set_ylabel('Average clustering coefficient') - ticks_loc = ax[1].get_xticks() - ax[1].xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) - ax[1].set_xticklabels(network_list, rotation=45, ha='right') - ax[1].yaxis.grid(color = 'gainsboro', linestyle = 'dotted') - ax[1].set_axisbelow(True) - - ax[2].bar(network_list, [density[i] for i in network_list], width=0.5, align='center', - color = '#D8B2C6' - ) - ax[2].spines['left'].set_position(('outward', 8)) - ax[2].spines['bottom'].set_position(('outward', 5)) - ax[2].spines['top'].set_visible(False) - ax[2].spines['left'].set_visible(False) - ax[2].spines['right'].set_visible(False) - ax[2].set_ylabel('Density') - ticks_loc = ax[2].get_xticks() - ax[2].xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) - ax[2].set_xticklabels(network_list, rotation=45, ha='right') - ax[2].yaxis.grid(color = 'gainsboro', linestyle = 'dotted') - ax[2].set_axisbelow(True) - - ax[3].bar(network_list, [average_shortest_path[i] for i in network_list], width=0.5, align='center', - color = '#B6D8A6' - ) - ax[3].spines['left'].set_position(('outward', 8)) - ax[3].spines['bottom'].set_position(('outward', 5)) - ax[3].spines['top'].set_visible(False) - ax[3].spines['left'].set_visible(False) - ax[3].spines['right'].set_visible(False) - ax[3].set_ylabel('Average shortest path length') - ticks_loc = ax[3].get_xticks() - ax[3].xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) - ax[3].set_xticklabels(network_list, rotation=45, ha='right') - ax[3].yaxis.grid(color = 'gainsboro', linestyle = 'dotted') - ax[3].set_axisbelow(True) - - fig.suptitle(f'Network properties of Funmap ({feature_type})', fontsize=16) - file_name = f'funmap_{feature_type}_network_properties.pdf' - fig.savefig(output_dir / file_name, bbox_inches='tight') - plt.close(fig) 
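    # Optional, illustrative only: the powerlaw Fit object created above also
    # exposes the estimated parameters, which could be logged alongside the
    # plot (assumes the fit succeeded; uses powerlaw's documented Fit API):
    #
    #     alpha = fit.power_law.alpha   # estimated scaling exponent
    #     xmin = fit.power_law.xmin     # lower cutoff used for the fit
    #     R, p = fit.distribution_compare('power_law', 'exponential')
    #     log.info(f'FunMap degree distribution: alpha={alpha:.2f}, xmin={xmin}, '
    #              f'LR vs. exponential={R:.2f} (p={p:.3f})')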
- return file_name - - -def merge_and_delete(fig_dir, file_list, output_file): - """ - Merge multiple PDF files into one and delete the original files. - - Parameters - ---------- - fig_dir : Path - The directory where the PDF files are located. - file_list : list of str - The list of file names to be merged. - output_file : str or Path - The name of the output file. - - Returns - ------- - None - - """ - pdf_writer = PyPDF2.PdfWriter() - - total_page_num = 0 - for file in file_list: - pdf_reader = PyPDF2.PdfReader(fig_dir / file) - cur_page_num = len(pdf_reader.pages) - for page in range(cur_page_num): - pdf_writer.add_page(pdf_reader.pages[page]) - pdf_writer.add_outline_item(os.path.splitext(file)[0], total_page_num) - total_page_num = total_page_num + cur_page_num - - with open(fig_dir / output_file, 'wb') as fh: - pdf_writer.write(fh) - log.info('figures have been merged.') - - for filename in file_list: - try: - os.remove(fig_dir / filename) - except: - log.error(f'{filename} could not be deleted.') diff --git a/funmap/utils.py b/funmap/utils.py deleted file mode 100644 index 1f769274..00000000 --- a/funmap/utils.py +++ /dev/null @@ -1,507 +0,0 @@ -import csv -import yaml -import os -import tarfile -import re -import hashlib -import urllib -from urllib.parse import urlparse -from pathlib import Path -import pandas as pd -import shutil -from funmap.data_urls import misc_urls as urls -from funmap.logger import setup_logger - -log = setup_logger(__name__) - - -def is_url_scheme(path): - parsed = urlparse(path) - if parsed.scheme == 'file': - return False - - return parsed.scheme != '' - -def read_csv_with_md5_check(url, expected_md5=None, local_path='downloaded_file.csv', **kwargs): - try: - response = urllib.request.urlopen(url) - content = response.read() - - if expected_md5: - md5_hash = hashlib.md5(content).hexdigest() - if md5_hash != expected_md5: - log.error('gold standard file: MD5 hash mismatch, file may be corrupted.') - raise ValueError("MD5 hash mismatch, file may be corrupted.") - - # Save the content to a local file - with open(local_path, 'wb') as f: - f.write(content) - - df = pd.read_csv(local_path, **kwargs) - os.remove(local_path) - return df - except Exception as e: - return None - - -def check_gs_files_exist(file_dict, key='CC'): - if key.upper() == 'CC': - paths = file_dict.get('CC') - if paths is not None and os.path.exists(paths): - return True - elif key.upper() == 'MR': - paths = file_dict.get('MR') - if paths is not None and all(os.path.exists(p) for p in paths): - return True - else: - log.error(f"'{key}' is not a valid feature type (cc or mr).") - - return False - - -def normalize_filename(filename): - # Remove any characters that are not allowed in filenames - cleaned_filename = re.sub(r'[^\w\s.-]', '', filename) - # Replace spaces with underscores - cleaned_filename = cleaned_filename.replace(' ', '_') - return cleaned_filename - - -def get_data_dict(config, min_sample_count=15): - """ - Returns a dictionary of data from the provided data configuration, filtered to only include genes that are - coding and have at least `min_sample_count` samples. - - Returns - ------- - data_dict : dict - A dictionary where the keys are the names of the data files and the values are pandas DataFrames containing - the data from the corresponding file. 
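    all_valid_ids : list
        Sorted list of gene identifiers with at least min_sample_count
        non-missing values in at least one dataset; the returned DataFrames are
        restricted to these columns and are oriented samples x genes (the input
        matrices are transposed on loading).

    A typical call (the dataset name is hypothetical; cfg is the dict returned
    by get_config):

        data_dict, all_valid_ids = get_data_dict(cfg, min_sample_count=20)
        rna = data_dict['RNA']   # samples x genes DataFrame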
- - """ - data_file = config['data_path'] - data_dict = {} - if 'filter_noncoding_genes' in config and config['filter_noncoding_genes']: - mapping = pd.read_csv(urls['mapping_file'], sep='\t') - # extract the data file from the tar.gz file - tmp_dir = 'tmp_data' - if not os.path.exists(tmp_dir): - os.mkdir(tmp_dir) - os.system(f'tar -xzf {data_file} --strip-components=1 -C {tmp_dir}') - # gene ids are gene symbols - for dt in config['data_files']: - log.info(f"processing ... {dt['name']}") - cur_feature = dt['name'] - # cur_file = get_obj_from_tgz(data_file, dt['path']) - # extract the data file from the tar.gz file - # cur_data= get_obj_from_tgz(data_file, dt['path']) - cur_data = pd.read_csv(os.path.join(tmp_dir, dt['path']), sep='\t', index_col=0, - header=0) - if cur_data.shape[1] < min_sample_count: - log.info(f"... {dt['name']} ... not enough samples, skipped") - continue - cur_data = cur_data.T - # exclude cohort with sample number < min_sample_count - # remove noncoding genes first - if config['filter_noncoding_genes']: - coding = mapping.loc[mapping['coding'] == 'coding', ['gene_name']] - coding_genes = list(set(coding['gene_name'].to_list())) - cur_data = cur_data[[c for c in cur_data.columns if c in coding_genes]] - # duplicated columns, for now select the last column - cur_data = cur_data.loc[:,~cur_data.columns.duplicated(keep='last')] - data_dict[cur_feature] = cur_data - - shutil.rmtree(tmp_dir) - - log.info('filtering out non valid ids ...') - all_valid_ids = set() - for i in data_dict: - cur_data = data_dict[i] - is_valid = cur_data.notna().sum() >= min_sample_count - # valid_count = np.sum(is_valid) - valid_p = cur_data.columns[is_valid].values - all_valid_ids = all_valid_ids.union(set(valid_p)) - - all_valid_ids = list(all_valid_ids) - all_valid_ids.sort() - log.info(f'total number of valid ids: {len(all_valid_ids)}') - - # filter out columns that are not in all_valid_ids - for i in data_dict: - cur_data = data_dict[i] - selected_columns = cur_data.columns.intersection(all_valid_ids) - cur_data = cur_data[selected_columns] - # it is possible the entire column is nan, remove it - cur_data = cur_data.dropna(axis=1, how='all') - data_dict[i] = cur_data - log.info(f'{i} -- ') - log.info(f' samples x ids: {cur_data.shape}') - - return data_dict, all_valid_ids - - -def get_node_edge(edge_list): - """ - Calculate the number of nodes and edges, and the ratio of edges per node, - and return the results in a dictionary format. - - Parameters - ---------- - edge_list : pandas DataFrame - The input DataFrame containing edge information. - - Returns - ------- - dict - A dictionary containing the number of nodes, the number of edges, - the ratio of edges per node, a list of nodes, and the edge_list. - - The keys of the dictionary are: - * n_node: int - The number of nodes in the network. - - * n_edge: int - The number of edges in the network. - - * edge_per_node: float - The ratio of edges per node in the network. - - * nodes: list - A list of nodes in the network. - - * edges: pandas DataFrame - The edge_list input DataFrame. - - """ - # remove duplidated rows in edge_list - edge_list = edge_list.drop_duplicates() - n_edge = len(edge_list) - nodes = set(edge_list.iloc[:,0].to_list()) | set(edge_list.iloc[:,1].to_list()) - return(dict(n_node = len(nodes), n_edge = n_edge, - edge_per_node = n_edge / len(nodes), - nodes = list(nodes), - edges = edge_list)) - - -def get_node_edge_overlap(network_info): - """ - Computes the node and edge overlap between networks. 
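    The overlap is always computed against the FunMap network: for every other
    network the value is a 3-tuple (FunMap-only, other-only, shared), which is
    the subset order expected by matplotlib_venn.venn2. A sketch of the
    returned structure (the counts are made up):

        {
            'node': {'BioGRID': (1200, 3400, 5600), ...},
            'edge': {'BioGRID': (25000, 610000, 44000), ...},
        }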
- - Parameters - ---------- - network_info : pandas DataFrame - A DataFrame with information about the networks. - - Returns - ------- - overlap : dict - A dictionary with the node and edge overlap between networks. - """ - networks = pd.DataFrame(columns = ['name', 'type', 'n_node', - 'n_edge', 'edge_per_node', - 'nodes', 'edges']) - - for _, row in network_info.iterrows(): - network_name = row['name'] - network_type = row['type'] - network_el = row['el'] - res = get_node_edge(network_el) - cur_df = pd.DataFrame({'name': [network_name], - 'type': [network_type], - 'n_node': [int(res['n_node'])], - 'n_edge': [int(res['n_edge'])], - 'edge_per_node': [res['edge_per_node']], - 'nodes': [res['nodes']], - 'edges': [res['edges']]}) - networks = pd.concat([networks, cur_df], ignore_index=True) - # overlap of nodes and edges - overlap = {} - - # node overlap - target = 'FunMap' - cur_res = {} - target_node_set = set(networks.loc[networks['name'] == target, - 'nodes'].tolist()[0]) - target_size = len(target_node_set) - for _, row in networks.iterrows(): - if row['name'] == target: - continue - cur_node_set = set(row['nodes']) - cur_size = len(cur_node_set) - overlap_size = len(target_node_set & cur_node_set) - cur_res[row['name']] = tuple([ - target_size - overlap_size, - cur_size - overlap_size, - overlap_size]) - - overlap['node'] = cur_res - - # edge overlap - cur_res = {} - target_edge_df = networks.loc[networks['name'] == target, 'edges'].tolist()[0] - target_edge_set = set(tuple(sorted(x)) for x in zip(target_edge_df.pop(0), - target_edge_df.pop(1))) - target_size = len(target_edge_set) - - for _, row in networks.iterrows(): - if row['name'] == target: - continue - edge_df = row['edges'] - edges = [tuple(sorted(x)) for x in zip(edge_df.pop(0), edge_df.pop(1))] - cur_edge_set = set(edges) - cur_size = len(cur_edge_set) - overlap_size = len(target_edge_set & cur_edge_set) - cur_res[row['name']] = tuple([ - target_size - overlap_size, - cur_size - overlap_size, - overlap_size]) - - overlap['edge'] = cur_res - return overlap - -def cleanup_experiment(config_file): - cfg = get_config(config_file) - results_dir = Path(cfg['results_dir']) - shutil.rmtree(results_dir, ignore_errors=True) - -def setup_experiment(config_file): - cfg = get_config(config_file) - results_dir = Path(cfg['results_dir']) - # create folders - folder_dict = cfg['subdirs'] - folders = [results_dir / Path(folder_dict[folder_name]) for folder_name in folder_dict] - for folder in folders: - folder.mkdir(parents=True, exist_ok=True) - - # save configuration to results folder - with open(str( results_dir / 'config.yml'), 'w') as fh: - yaml.dump(cfg, fh, sort_keys=False) - - return cfg - -def get_config(cfg_file: Path): - cfg = { - 'task': 'protein_func', - 'results_dir': 'results', - # the following directories are relative to the results_dir - 'subdirs': { - 'saved_model_dir': 'saved_models', - 'saved_data_dir': 'saved_data', - 'saved_predictions_dir': 'saved_predictions', - 'figure_dir': 'figures', - 'network_dir': 'networks', - }, - 'seed': 42, - 'feature_type': 'cc', - 'test_size': 0.2, - 'ml_type': 'xgboost', - 'gs_file': None, - 'extra_feature_file': None, - # 'filter_before_prediction': True, - # 'min_feature_count': 1, - 'min_sample_count': 20, - 'filter_noncoding_genes': False, - # 'filter_after_prediction': True, - # 'filter_criterion': 'max', - # 'filter_threshold': 0.95, - # 'filter_blacklist': False, - 'n_jobs': os.cpu_count(), - 'start_edge_num': 1000, - 'max_num_edges': 250000, - 'step_size': 1000, - 'lr_cutoff': 50, 
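        # Many of the defaults above can be overridden from the YAML file; only
        # 'name', 'data_path' and 'data_files' are required. A minimal,
        # hypothetical cfg_dict as it would come back from yaml.load below:
        #
        #     cfg_dict = {
        #         'name': 'my_experiment',
        #         'data_path': 'my_omics_data.tar.gz',
        #         'data_files': [
        #             {'name': 'RNA', 'type': 'rna', 'path': 'rna.tsv'},
        #             {'name': 'Protein', 'type': 'protein', 'path': 'protein.tsv'},
        #         ],
        #         'feature_type': 'cc',
        #         'min_sample_count': 20,
        #     }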
- } - - with open(cfg_file, 'r') as fh: - cfg_dict = yaml.load(fh, Loader=yaml.FullLoader) - - # use can change the following parameters in the config file - if 'task' in cfg_dict: - cfg['task'] = cfg_dict['task'] - assert cfg['task'] in ['protein_func', 'kinase_func'] - - if 'seed' in cfg_dict: - cfg['seed'] = cfg_dict['seed'] - - if 'feature_type' in cfg_dict: - cfg['feature_type'] = cfg_dict['feature_type'] - assert cfg['feature_type'] in ['cc', 'mr'] - - if 'extra_feature_file' in cfg_dict: - cfg['extra_feature_file'] = cfg_dict['extra_feature_file'] - - if 'gs_file' in cfg_dict: - cfg['gs_file'] = cfg_dict['gs_file'] - - if 'min_sample_count' in cfg_dict: - cfg['min_sample_count'] = cfg_dict['min_sample_count'] - - if 'n_jobs' in cfg_dict: - cfg['n_jobs'] = cfg_dict['n_jobs'] - - if 'start_edge_num' in cfg_dict: - cfg['start_edge_num'] = cfg_dict['start_edge_num'] - - if 'max_num_edges' in cfg_dict: - cfg['max_num_edges'] = cfg_dict['max_num_edges'] - - if 'step_size' in cfg_dict: - cfg['step_size'] = cfg_dict['step_size'] - - if 'lr_cutoff' in cfg_dict: - cfg['lr_cutoff'] = cfg_dict['lr_cutoff'] - - if 'name' in cfg_dict: - cfg['name'] = cfg_dict['name'] - else: - raise ValueError('name not specified in config file') - - if 'data_path' in cfg_dict: - cfg['data_path'] = cfg_dict['data_path'] - else: - raise ValueError('data_path not specified in config file') - - if cfg['task'] == 'protein_func': - if 'filter_noncoding_genes' in cfg_dict: - cfg['filter_noncoding_genes'] = cfg_dict['filter_noncoding_genes'] - else: - # ignore filter_noncoding_genes for kinase_func - cfg['filter_noncoding_genes'] = False - log.info('ignoring filter_noncoding_genes for kinase_func') - - if 'results_dir' in cfg_dict: - cfg['results_dir'] = cfg_dict['results_dir'] + '/' + cfg_dict['name'] - - if 'data_files' not in cfg_dict: - raise ValueError('data_files not specified in config file') - - # Check all files listed under data_files are also in the tar.gz file - data_files = cfg_dict['data_files'] - # List all the files in the tar.gz file - with tarfile.open(cfg['data_path'], "r:gz") as tar: - tar_files = {Path(file).name for file in tar.getnames()} - - # Check if all files in data_files are in tar_files - if not all(file['path'] in tar_files for file in data_files): - print('Files listed under data_files are not in the tar.gz file!') - raise ValueError('Files listed under data_files are not in the tar.gz file!') - - cfg['data_files'] = cfg_dict['data_files'] - - if 'rp_pairs' in cfg_dict: - cfg['rp_pairs'] = cfg_dict['rp_pairs'] - - return cfg - - -def check_gold_standard_file(file_path, min_count=10000): - """ - min_threshold : int - The minimum threshold for the lesser of '0' and '1' counts in the 'Class' column. - - """ - try: - with open(file_path, 'r', newline='') as tsv_file: - dialect = csv.Sniffer().sniff(tsv_file.read(2048)) - if dialect.delimiter != '\t': - print("Error: Incorrect TSV format. TSV files should be tab-separated.") - return False - except csv.Error as e: - print(f"CSV Error: {e}") - return False - - # Check data format and Class values - class_values = [] - with open(file_path, 'r', newline='') as tsv_file: - reader = csv.reader(tsv_file, delimiter='\t') - next(reader) # Skip header - for row_num, row in enumerate(reader, start=2): # Add row_num for better error reporting - if len(row) != 3: # Assuming each row should have 3 columns - log.error(f'Invalid row format in row {row_num}. 
Each row should have 3 columns.') - return False - class_value = row[2].strip() - if not class_value.isdigit() or int(class_value) not in (0, 1): - log.error(f'Invalid "Class" value in row {row_num}. Must be 0 or 1.') - return False - class_values.append(int(class_value)) - - # Check Class value counts and ratio - count_0 = class_values.count(0) - count_1 = class_values.count(1) - lesser_count = min(count_0, count_1) - - if lesser_count < min_count: - log.error(f"The lesser of 0 and 1 occurrences ({lesser_count}) does not meet the threshold. " - f"Expected at least {min_count}.") - return False - - return True - - -def check_extra_feature_file(file_path, missing_value='NA'): - """ - Notes - ----- - This function checks the following criteria for the measurement TSV file: - - The file must have a header row. - - There must be at least 3 columns. - - The first two columns are gene/protein IDs. - - The data type for each additional column (excluding the first two columns) must be consistent - and can be either integer, float, or the specified missing_value. - - If any of the checks fail, the function will print informative error messages and return False. - - The TSV file should be tab-separated. - """ - try: - with open(file_path, 'r', newline='') as tsv_file: - dialect = csv.Sniffer().sniff(tsv_file.read(1024)) - if dialect.delimiter != '\t': - log.error("Incorrect TSV format. TSV files should be tab-separated.") - return False - except csv.Error: - log.error("Unable to read TSV file.") - return False - - # Check header and number of columns - with open(file_path, 'r', newline='') as tsv_file: - reader = csv.reader(tsv_file, delimiter='\t') - header = next(reader, None) - if header is None: - log.error("The TSV file must have a header row.") - return False - - num_columns = len(header) - if num_columns < 3: - log.error("The TSV file must have at least 3 columns.") - return False - - # Check data type consistency for additional columns - with open(file_path, 'r', newline='') as tsv_file: - reader = csv.DictReader(tsv_file, delimiter='\t') - column_data_types = {} - for row_num, row in enumerate(reader, start=2): # Add row_num for better error reporting - for column_name, value in row.items(): - if column_name not in header[:2]: # Skip the first two columns (Protein_1 and Protein_2) - if value == missing_value: - continue - try: - float_value = float(value) - if float_value.is_integer(): - value = int(float_value) - except ValueError: - log.error("Invalid data type in row %d, column '%s'. " - "The value '%s' should be either an integer, a float, or '%s' (missing value).", - row_num, column_name, value, missing_value) - return False - # Store the data type of each column (float or integer) - data_type = float if '.' in value else int - if column_name not in column_data_types: - column_data_types[column_name] = data_type - elif column_data_types[column_name] != data_type: - log.error("Inconsistent data type in column '%s'. 
" - "Expected a consistent data type (integer, float, or '%s') for all rows.", - column_name, missing_value) - return False - - return True diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..043ed966 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,50 @@ +[project] +name = "funmap" +version = "0.2.0" +description = "generate gene co-function networks using omics data" +authors = [{ name = "Zhiao Shi", email = "zhiao.shi@gmail.com" }] +readme = "README.md" +license = { text = "MIT license" } +requires-python = ">=3.8" +classifiers = [ + "Programming Language :: Rust", + "Programming Language :: Python :: Implementation :: CPython", + "Programming Language :: Python :: Implementation :: PyPy", +] +dependencies = [ + "h5py>=3.12.1", + "numpy>=2.1.1", + "seaborn>=0.13.2", + "tqdm>=4.66.5", + "xgboost>=2.1.1", + "pyyaml>=6.0.2", + "scipy>=1.14.1", + "pyarrow>=17.0.0", + "pandas>=2.2.3", + "joblib>=1.4.2", + "matplotlib>=3.9.2", + "scikit-learn>=1.5.2", + "imbalanced-learn>=0.12.3", + "pypdf2>=3.0.1", + "matplotlib-venn>=1.1.1", + "networkx>=3.3", + "powerlaw>=1.5", + "click>=8.1.7", + "tables>=3.10.1", +] + +[project.scripts] +funmap = "funmap.cli:cli" + +[build-system] +requires = ["maturin>=1,<2"] +build-backend = "maturin" + +[tool.maturin] +features = ["pyo3/extension-module"] +python-source = "python" +module-name = "funmap._lib" + +[tool.rye] +managed = true +dev-dependencies = ["maturin>=1.7.4"] diff --git a/funmap/__init__.py b/python/funmap/__init__.py similarity index 100% rename from funmap/__init__.py rename to python/funmap/__init__.py diff --git a/python/funmap/cli.py b/python/funmap/cli.py new file mode 100644 index 00000000..9b762b24 --- /dev/null +++ b/python/funmap/cli.py @@ -0,0 +1,403 @@ +import gzip +import os +import pickle +from pathlib import Path + +import click +import numpy as np +import pandas as pd + +from funmap import __version__ +from funmap import _lib as funmap_lib +from funmap.data_urls import misc_urls as urls +from funmap.funmap import ( + compute_features, + compute_llr, + dataset_llr, + get_cutoff, + get_ppi_feature, + predict_all_pairs, + predict_network, + prepare_gs_data, + train_ml_model, +) +from funmap.logger import setup_logger, setup_logging +from funmap.plotting import explore_data, merge_and_delete, plot_results +from funmap.utils import ( + check_extra_feature_file, + check_gold_standard_file, + cleanup_experiment, + new_extra_feature, + setup_experiment, +) + +log = setup_logger(__name__) + + +@click.group(help="funmap command line interface") +@click.version_option(version=f"{__version__}") +def cli(): + """ + Command line interface for funmap. + """ + click.echo("====== funmap =======") + + +@cli.command(help="check the data quality") +@click.option( + "--config-file", + "-c", + required=True, + type=click.Path(exists=True), + help="path to experiment configuration yaml file", +) +@click.option( + "--force-rerun", + "-f", + is_flag=True, + default=False, + help="if set, will remove results from previous run first", +) +def qc(config_file, force_rerun): + if force_rerun: + while True: + confirmation = input( + "Do you want to remove results from previous run? (y/n): " + ) + if confirmation.lower() == "y": + click.echo("Removing results from previous run") + cleanup_experiment(config_file) + break + elif confirmation.lower() == "n": + click.echo("Not removing results from previous run") + break + else: + click.echo("Invalid input. 
Please enter 'y' or 'n'.") + + setup_logging(config_file) + log.info(f"Running QC...") + cfg = setup_experiment(config_file) + all_fig_names = [] + figure_dir = Path(cfg["results_dir"]) / cfg["subdirs"]["figure_dir"] + min_sample_count = cfg["min_sample_count"] + fig_names = explore_data(cfg, min_sample_count, figure_dir) + all_fig_names.extend(fig_names) + merge_and_delete(figure_dir, all_fig_names, "qc.pdf") + log.info("figure qc.pdf saved to {}".format(figure_dir)) + log.info("QC complete") + + +@cli.command() +def rust(): + mapping = pd.read_csv(urls["mapping_file"], sep="\t") + coding = mapping.loc[mapping["coding"] == "coding", ["gene_name"]] + coding_genes = list(set(coding["gene_name"].to_list())) + funmap_lib.process_files( + [ + "dummy/dia.tsv", + "dummy/methyl.tsv", + "dummy/RNAseq.tsv", + "dummy/tmt_abundance.tsv", + ], + ["dummy/all_no_methyl.tsv"], + "test_data", + coding_genes, + ) + + +@cli.command(help="run funmap") +@click.option( + "--config-file", + "-c", + required=True, + type=click.Path(exists=True), + help="path to experiment configuration yaml file", +) +@click.option( + "--force-rerun", + "-f", + is_flag=True, + default=False, + help="if set, will remove results from previous run first", +) +def run(config_file, force_rerun): + click.echo("Running funmap...") + if force_rerun: + while True: + confirmation = input( + "Do you want to remove results from previous run? (y/n): " + ) + if confirmation.lower() == "y": + click.echo("Removing results from previous run") + cleanup_experiment(config_file) + break + elif confirmation.lower() == "n": + click.echo("Not removing results from previous run") + break + else: + click.echo("Invalid input. Please enter 'y' or 'n'.") + + setup_logging(config_file) + cfg = setup_experiment(config_file) + extra_feature_folder = cfg["extra_feature_folder"] + extra_feature_df = None + uniq_gene = None + if extra_feature_folder is not None: + log.info("Loading extra feature file into dataframe") + (uniq_gene, extra_feature_df) = new_extra_feature(extra_feature_folder) + gs_file = cfg["gs_file"] + if (gs_file is not None) and (not check_gold_standard_file(gs_file)): + return + + task = cfg["task"] + seed = cfg["seed"] + np.random.seed(seed) + ml_type = cfg["ml_type"] + feature_type = cfg["feature_type"] + # min_feature_count = cfg['min_feature_count'] + min_sample_count = cfg["min_sample_count"] + # filter_before_prediction = cfg['filter_before_prediction'] + test_size = cfg["test_size"] + # filter_after_prediction = cfg['filter_after_prediction'] + # filter_criterion = cfg['filter_criterion'] + # filter_threshold = cfg['filter_threshold'] + # filter_blacklist = cfg['filter_blacklist'] + n_jobs = cfg["n_jobs"] + lr_cutoff = cfg["lr_cutoff"] + max_num_edges = cfg["max_num_edges"] + step_size = cfg["step_size"] + start_edge_num = cfg["start_edge_num"] + only_extra_features = cfg["only_extra_features"] + results_dir = Path(cfg["results_dir"]) + saved_data_dir = results_dir / cfg["subdirs"]["saved_data_dir"] + model_dir = results_dir / cfg["subdirs"]["saved_model_dir"] + prediction_dir = results_dir / cfg["subdirs"]["saved_predictions_dir"] + network_dir = results_dir / cfg["subdirs"]["network_dir"] + figure_dir = results_dir / cfg["subdirs"]["figure_dir"] + + if cfg["task"] == "protein_func": + feature_mapping = ["ex", "ei"] + else: + feature_mapping = ["ex"] + # here the file stored a dictionary of ml models + ml_model_file = { + feature: model_dir / f"model_{feature}.pkl.gz" for feature in feature_mapping + } + predicted_all_pairs_file = { + 
feature: prediction_dir / f"predicted_all_pairs_{feature}.parquet" + for feature in feature_mapping + } + llr_res_file = { + feature: results_dir / f"llr_results_{feature}.tsv" + for feature in feature_mapping + } + edge_list_file = { + feature: network_dir / f"funmap_{feature}.tsv" for feature in feature_mapping + } + # gold standard data include specified feature (cc or mr) and ppi feature (if applicable) + # and extra feature if applicable + gs_df_file = saved_data_dir / "gold_standard_data.h5" + # blacklist_file = urls['funmap_blacklist'] + # llr obtained with each invividual dataset + llr_dataset_file = results_dir / "llr_dataset.tsv" + gs_train = gs_test = None + cutoff_p = cutoff_llr = None + ml_model_dict = {} + + # compute and save cc, mr results + cc_dict, mr_dict, all_valid_ids = compute_features( + cfg, feature_type, min_sample_count, saved_data_dir + ) + + if uniq_gene is not None: + all_valid_ids = uniq_gene + + gs_args = { + "task": task, + "saved_data_dir": saved_data_dir, + "cc_dict": cc_dict, + "mr_dict": mr_dict, + "feature_type": feature_type, + "gs_file": gs_file, + "extra_feature_df": extra_feature_df, + "valid_id_list": all_valid_ids, + "test_size": test_size, + "seed": seed, + } + + all_edge_list_exist = all( + os.path.exists(file_path) for file_path in edge_list_file.values() + ) + + if all_edge_list_exist: + log.info( + "Fumap network(s) already exists. Skipping model training and prediction." + ) + else: + all_model_exist = all( + os.path.exists(file_path) for file_path in ml_model_file.values() + ) + if all_model_exist: + log.info("Trained model(s) exists. Loading model(s) ...") + ml_model_dict = {} + # feature: ex or ei + for feature in ml_model_file: + with gzip.open(ml_model_file[feature], "rb") as fh: + ml_model = pickle.load(fh) + ml_model_dict[feature] = ml_model + log.info("Loading model(s) ... done") + if not gs_df_file.exists(): + log.error( + f"Trained models found but gold standard data file {gs_df_file} " + f"does not exist." + ) + return + with pd.HDFStore(gs_df_file, mode="r") as store: + gs_train = store["train"] + gs_test = store["test"] + else: + gs_train, gs_test = prepare_gs_data(**gs_args) + with pd.HDFStore(gs_df_file, mode="w") as store: + store.put("train", gs_train) + store.put("test", gs_test) + ml_model_dict = train_ml_model( + gs_train, ml_type, seed, n_jobs, feature_mapping, model_dir + ) + + all_predicted_all_pairs_exist = all( + os.path.exists(file_path) for file_path in predicted_all_pairs_file.values() + ) + if all_predicted_all_pairs_exist: + log.info("Predicted all pairs already exists. Skipping prediction.") + else: + log.info("Predicting all pairs ...") + if task == "protein_func": + ppi_feature = get_ppi_feature() + else: + ppi_feature = None + pred_all_pairs_args = { + "model_dict": ml_model_dict, + "all_ids": all_valid_ids, + "feature_type": feature_type, + "ppi_feature": ppi_feature, + "cc_dict": cc_dict, + "mr_dict": mr_dict, + "extra_feature_df": extra_feature_df, + "prediction_dir": prediction_dir, + "output_file": predicted_all_pairs_file, + "n_jobs": n_jobs, + } + predict_all_pairs(**pred_all_pairs_args) + log.info("Predicting all pairs ... 
done") + + cutoff_p, cutoff_llr = get_cutoff(ml_model_dict, gs_test, lr_cutoff) + log.info(f"cutoff probability: {cutoff_p}") + log.info(f"cutoff llr: {cutoff_llr}") + predict_network(predicted_all_pairs_file, cutoff_p, edge_list_file) + + if not gs_df_file.exists(): + gs_train, gs_test = prepare_gs_data(**gs_args) + with pd.HDFStore(gs_df_file, mode="w") as store: + store.put("train", gs_train) + store.put("test", gs_test) + else: + if gs_test is None: + with pd.HDFStore(gs_df_file, mode="r") as store: + gs_train = store["train"] + gs_test = store["test"] + + all_llr_res_exist = all( + os.path.exists(file_path) for file_path in llr_res_file.values() + ) + all_edge_list_exist = all( + os.path.exists(file_path) for file_path in edge_list_file.values() + ) + if all_llr_res_exist and all_edge_list_exist: + log.info("LLR results already exist.") + else: + for ft in feature_mapping: + log.info(f"Computing LLR for {ft} ...") + if not predicted_all_pairs_file[ft].exists(): + log.error( + f"Predicted all pairs file {predicted_all_pairs_file[ft]} does not exist." + ) + return + predicted_all_pairs = pd.read_parquet(predicted_all_pairs_file[ft]) + # also save the llr results to file + compute_llr( + predicted_all_pairs, + llr_res_file[ft], + start_edge_num, + max_num_edges, + step_size, + gs_test, + ) + log.info(f"Computing LLR for {ft} ... done") + + validation_res = {} + for ft in feature_mapping: + validation_res[ft] = { + "llr_res_path": llr_res_file[ft], + "edge_list_path": edge_list_file[ft], + } + if not llr_dataset_file.exists(): + log.info("Computing LLR for each dataset ...") + # feature_dict = cc_dict if feature_type == 'cc' else mr_dict + # use CC features for individual dataset + llr_ds = dataset_llr( + all_valid_ids, + cc_dict, + "cc", + gs_test, + start_edge_num, + max_num_edges, + step_size, + llr_dataset_file, + extra_feature_df, + ) + log.info("Done.") + else: + llr_ds = pd.read_csv(llr_dataset_file, sep="\t") + + if not ml_model_dict: + log.info("Trained model(s) exists. Loading model(s) ...") + for feature in ml_model_file: + with gzip.open(ml_model_file[feature], "rb") as fh: + ml_model = pickle.load(fh) + ml_model_dict[feature] = ml_model + log.info("Loading model(s) ... 
done") + + all_fig_names = [] + if cutoff_llr is None: + cutoff_p, cutoff_llr = get_cutoff(ml_model_dict, gs_test, lr_cutoff) + + gs_dict = {} + gs_dict[feature_type.upper()] = gs_train + if task == "protein_func" and feature_type.upper() == "MR" and "rp_pairs" in cfg: + # extract gs data for CC and MR for plotting + gs_args = { + "task": task, + "saved_data_dir": saved_data_dir, + "cc_dict": cc_dict, + "mr_dict": mr_dict, + "feature_type": "cc", + "gs_file": gs_file, + # no extra feature for plotting + "extra_feature_df": extra_feature_df, + "valid_id_list": all_valid_ids, + "test_size": test_size, + "seed": seed, + } + gs_train, gs_test = prepare_gs_data(**gs_args) + gs_dict["CC"] = gs_train + + fig_names = plot_results( + cfg, validation_res, llr_ds, gs_dict, cutoff_llr, figure_dir + ) + all_fig_names.extend(fig_names) + + merge_and_delete(figure_dir, all_fig_names, "results.pdf") + + +if __name__ == "__main__": + cli() diff --git a/python/funmap/data_urls.py b/python/funmap/data_urls.py new file mode 100644 index 00000000..e0c6915c --- /dev/null +++ b/python/funmap/data_urls.py @@ -0,0 +1,18 @@ +misc_urls = { + "reactome_gold_standard": "https://figshare.com/ndownloader/files/38647601", + "reactome_gold_standard_md5": "671942763b6a7c32506cba1ed9900fe6", + "funmap_blacklist": "https://figshare.com/ndownloader/files/39033977", + "mapping_file": "https://figshare.com/ndownloader/files/39033971", +} + +# the information about other networks is fixed for now +network_info = { + "name": ["BioGRID", "BioPlex", "HI-union", "STRING"], + "type": ["BioGRID", "BioPlex", "HI", "STRING"], + "url": [ + "https://figshare.com/ndownloader/files/39125054", + "https://figshare.com/ndownloader/files/39125051", + "https://figshare.com/ndownloader/files/39125093", + "https://figshare.com/ndownloader/files/39125090", + ], +} diff --git a/python/funmap/funmap.py b/python/funmap/funmap.py new file mode 100644 index 00000000..fecb1108 --- /dev/null +++ b/python/funmap/funmap.py @@ -0,0 +1,797 @@ +import glob +import gzip +import itertools +import math +import os +import pickle +from collections import Counter, defaultdict +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path +from typing import List + +import h5py +import numpy as np +import pandas as pd +import pyarrow as pa +import pyarrow.parquet as pq +import xgboost as xgb +from sklearn.model_selection import GridSearchCV, StratifiedKFold, train_test_split +from sklearn.utils import resample +from tqdm import tqdm + +from funmap.data_urls import misc_urls as urls +from funmap.data_urls import network_info +from funmap.logger import setup_logger +from funmap.utils import ( + get_data_dict, + is_url_scheme, + read_csv_with_md5_check, +) + +log = setup_logger(__name__) + + +def get_valid_gs_data( + gs_path: str, valid_gene_list: List[str], md5=None +): # TODO: Add error message if there is no overlap between valid_gene_list and the gold standard set + log.info(f'Loading gold standard feature file "{gs_path}" ...') + if is_url_scheme(gs_path): + gs_edge_df = read_csv_with_md5_check( + gs_path, expected_md5=md5, local_path="download_gs_file", sep="\t" + ) + if gs_edge_df is None: + raise ValueError("Failed to download gold standard file") + else: + gs_edge_df = pd.read_csv(gs_path, sep="\t") + + log.info("Done loading gold standard feature file") + gs_edge_df = gs_edge_df.rename( + columns={gs_edge_df.columns[0]: "P1", gs_edge_df.columns[1]: "P2"} + ) + gs_edge_df = gs_edge_df[ + gs_edge_df["P1"].isin(valid_gene_list) + & 
gs_edge_df["P2"].isin(valid_gene_list) + & (gs_edge_df["P1"] != gs_edge_df["P2"]) + ] + m = ~pd.DataFrame(np.sort(gs_edge_df[["P1", "P2"]], axis=1)).duplicated() + gs_edge_df = gs_edge_df[list(m)] + gs_edge_df.reset_index(drop=True, inplace=True) + # rename the last column name to 'label' + gs_edge_df.rename(columns={gs_edge_df.columns[-1]: "label"}, inplace=True) + + return gs_edge_df + + +def pairwise_mutual_rank(pcc_matrix): + """ + Calculate the pairwise mutual rank matrix based on the given Pearson correlation coefficient matrix. + + Parameters: + ----------- + pcc_matrix : numpy.ndarray + The Pearson correlation coefficient matrix. It should be a square matrix where + pcc_matrix[i, j] represents the correlation coefficient between variables i and j. + + Returns: + -------- + numpy.ndarray + A matrix containing the pairwise mutual ranks between variables based on the + provided Pearson correlation coefficient matrix. The matrix has the same shape + as the input pcc_matrix. + + Mutual Rank Calculation: + ------------------------ + The mutual rank between two variables A and B, based on their Pearson correlation coefficients, + is a measure of their relative rankings within their respective groups of correlated variables. + The formula for calculating the mutual rank is given by: + + mr_{AB} = sqrt((r_{AB} / n_B) * (r_{BA} / n_A)) + + Where: + - mr_{AB} is the mutual rank between variables A and B. + - r_{AB} is the rank of the correlation coefficient between A and B among all other correlation + coefficients involving A (excluding NaN values). + - n_B is the number of valid (non-NaN) correlation coefficients involving variable B. + - r_{BA} is the rank of the correlation coefficient between B and A among all other correlation + coefficients involving B (excluding NaN values). + - n_A is the number of valid (non-NaN) correlation coefficients involving variable A. + + Steps: + - For each variable pair (A, B): + - Calculate the rank of the correlation coefficient between A and B among all other correlation + coefficients involving A. This rank is denoted as r_{AB}. + - Calculate the rank of the correlation coefficient between B and A among all other correlation + coefficients involving B. This rank is denoted as r_{BA}. + - For each variable pair (A, B): + - Determine the number of valid (non-NaN) correlation coefficients involving variable B, denoted as n_B. + - Determine the number of valid (non-NaN) correlation coefficients involving variable A, denoted as n_A. + - For each variable pair (A, B): + - Compute the mutual rank mr_{AB} using the formula mentioned earlier. + - Populate the mutual rank matrix: + - Create a new matrix with the same shape as the input correlation coefficient matrix, + initialized with NaN values. + - For each valid variable pair (A, B), assign the corresponding mutual rank mr_{AB} + to the matrix at the appropriate indices. + + The resulting matrix contains the mutual ranks between all pairs of variables based on their + Pearson correlation coefficients. Higher mutual rank values indicate stronger and more consistent + correlations between variables. 
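    A toy, illustrative example (the values are arbitrary; the import path
    follows the new package layout):

        import numpy as np
        from funmap.funmap import pairwise_mutual_rank

        pcc = np.array([[1.0, 0.9, 0.2],
                        [0.9, 1.0, 0.4],
                        [0.2, 0.4, 1.0]])
        mr = pairwise_mutual_rank(pcc)
        # mr has the same shape as pcc, with entries in (0, 1]; the strongly
        # correlated pair (0, 1) receives a higher mutual rank than the weakly
        # correlated pair (0, 2).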
+ """ + valid_a = ~np.isnan(pcc_matrix) + valid_b = valid_a.T + + rank_ab = ( + np.argsort(pcc_matrix, axis=1).argsort(axis=1, kind="mergesort") + 1 + ) # Start ranks from 1 + rank_ba = ( + np.argsort(pcc_matrix, axis=0).argsort(axis=0, kind="mergesort") + 1 + ) # Start ranks from 1 + + n_a = np.sum(valid_a, axis=1) + n_b = np.sum(valid_b, axis=0) + + valid_indices_a, valid_indices_b = np.where(valid_a) + + mr_values = np.sqrt( + (rank_ab[valid_indices_a, valid_indices_b] / n_b[valid_indices_b]) + * (rank_ba[valid_indices_a, valid_indices_b] / n_a[valid_indices_a]) + ) + + mr_matrix = np.full_like(pcc_matrix, np.nan) + mr_matrix[valid_indices_a, valid_indices_b] = mr_values + + return mr_matrix + + +def compute_features(cfg, feature_type, min_sample_count, output_dir): + """Compute the pearson correlation coefficient for each edge in the list of edges and for each""" + data_dict, all_valid_ids = get_data_dict(cfg, min_sample_count) + cc_dict = {} + mr_dict = {} + if not cfg["only_extra_features"]: + for i in data_dict: + cc_file = os.path.join(output_dir, f"cc_{i}.h5") + cc_dict[i] = cc_file + + mr_dict = {} + for i in data_dict: + mr_file = os.path.join(output_dir, f"mr_{i}.h5") + mr_dict[i] = mr_file + + all_cc_exist = all(os.path.exists(file_path) for file_path in cc_dict.values()) + if all_cc_exist: + log.info("All cc files exist. Skipping feature computation.") + if feature_type == "cc": + return cc_dict, mr_dict, all_valid_ids + + if feature_type == "mr": + all_mr_exist = all( + os.path.exists(file_path) for file_path in mr_dict.values() + ) + if all_cc_exist and all_mr_exist: + log.info("All mr files exist. Skipping feature computation.") + return cc_dict, mr_dict, all_valid_ids + + log.debug(f"Computing {feature_type} features") + for i in data_dict: + cc_file = cc_dict[i] + if os.path.exists(cc_file): + continue + log.info(f"Computing pearson correlation coefficient matrix for {i}") + x = data_dict[i].values.astype(np.float32) + df = pd.DataFrame(x) + corr_matrix = df.corr(method="pearson", min_periods=min_sample_count) + arr = corr_matrix.values + upper_indices = np.triu_indices(arr.shape[0]) + with h5py.File(cc_dict[i], "w") as hf: + # only store the upper triangle part + hf.create_dataset("cc", data=arr[upper_indices]) + hf.create_dataset("ids", data=data_dict[i].columns.values.astype("S")) + cc_dict[i] = cc_file + + # compute pairwise mutual rank features + if feature_type == "mr": + log.info(f"Computing mutual rank matrix for {i}") + arr_mr = pairwise_mutual_rank(arr) + upper_indices = np.triu_indices(arr_mr.shape[0]) + with h5py.File(mr_dict[i], "w") as hf: + # only store the upper triangle part + hf.create_dataset("mr", data=arr_mr[upper_indices]) + hf.create_dataset( + "ids", data=data_dict[i].columns.values.astype("S") + ) + + return cc_dict, mr_dict, all_valid_ids + + +def balance_classes(df, random_state=42): + class_column = df.columns[-1] # Assuming class column is the last column + + class_values = df[class_column].unique() + if len(class_values) != 2: + raise ValueError("The class column should have exactly 2 unique values.") + + class_0 = df[df[class_column] == class_values[0]] + class_1 = df[df[class_column] == class_values[1]] + minority_class = class_0 if len(class_0) < len(class_1) else class_1 + majority_class = class_1 if minority_class is class_0 else class_0 + + majority_class_undersampled = resample( + majority_class, + replace=False, + n_samples=len(minority_class), + random_state=random_state, + ) + + balanced_df = pd.concat([minority_class, 
majority_class_undersampled]) + # Shuffle the rows in the balanced DataFrame + balanced_df = balanced_df.sample(frac=1, random_state=random_state).reset_index( + drop=True + ) + + return balanced_df + + +def assemble_feature_df(h5_file_mapping, df, dataset="cc"): + df.reset_index(drop=True, inplace=True) + # Initialize feature_df with columns for HDF5 file keys and 'label' + file_keys = list(h5_file_mapping.keys()) + feature_df = pd.DataFrame(columns=file_keys + ["label"]) + + def get_1d_indices(i_array, j_array, n): + # Ensure i and j are within bounds + mask = (i_array < n) & (j_array < n) + i = np.minimum(i_array[mask], j_array[mask]) + j = np.maximum(i_array[mask], j_array[mask]) + + # Calculate 1D indices + return i * n - i * (i - 1) // 2 + (j - i) + + # Iterate over HDF5 files and load feature values + for key, file_path in h5_file_mapping.items(): + with h5py.File(file_path, "r") as h5_file: + gene_ids = h5_file["ids"][:] + gene_to_index = {gene.astype(str): idx for idx, gene in enumerate(gene_ids)} + + # Get gene indices for P1 and P2 + p1_indices = np.array( + [gene_to_index.get(gene, -1) for gene in df.iloc[:, 0]] + ) + p2_indices = np.array( + [gene_to_index.get(gene, -1) for gene in df.iloc[:, 1]] + ) + + f_dataset = h5_file[dataset] + f_values = np.empty(len(df), dtype=float) + valid_indices = (p1_indices != -1) & (p2_indices != -1) + + linear_indices = get_1d_indices( + p1_indices[valid_indices], p2_indices[valid_indices], len(gene_ids) + ) + + f_values[valid_indices] = f_dataset[:][linear_indices] + f_values[~valid_indices] = np.nan + + # Add feature values to the feature_df + feature_df[key] = f_values + + # if the last column is 'label', assign it to feature_df + if df.columns[-1] == "label": + feature_df["label"] = df[df.columns[-1]] + else: + # delete the 'label' column from feature_df + del feature_df["label"] + + return feature_df + + +def extract_features( + df, + feature_type, + cc_dict, + ppi_feature=None, + extra_feature_df=None, + mr_dict=None, + chunk_data: tuple[int, int] | None = None, +): + """ + extract_features - creates the final feature `pandas` dataframe used by xgboost + + Params: + df: pd.DataFrame + feature_type: str - Type of feature used. Currently supporting pearson correlation or mutual rank. Accepted values: 'cc' or 'mr'. + cc_dict: dict[str, str] - dictionary containing paths to the pre-calculated features. + ... 
+ extra_feature: str - path for file containing extra features + + Extra Features: + + This is a TSV file with the following format: + + Gene A Gene B Feature X Feature Y + -------- -------- ----------- ----------- + ABC1 DEF2 0.12 34.5 + GHI3 JKLM4 -0.6 NA + … … … … + """ + if feature_type == "mr": + if not mr_dict: + raise ValueError("mr dict is empty") + + feature_dict = cc_dict if feature_type == "cc" else mr_dict + feature_df = assemble_feature_df(feature_dict, df, feature_type) + if ppi_feature is not None: + ppi_dict = {key: set(value) for key, value in ppi_feature.items()} + for ppi_source, ppi_tuples in ppi_dict.items(): + feature_df[ppi_source] = df.apply( + lambda row: 1 if (row["P1"], row["P2"]) in ppi_tuples else 0, axis=1 + ) + + # TODO: add extra features if provided + if extra_feature_df is not None: + if chunk_data is not None: + (start_idx, chunk_size) = chunk_data + feature_df = pd.concat( + [feature_df, extra_feature_df[start_idx : start_idx + chunk_size]], + axis=1, + ) + else: + feature_df = pd.concat( + [feature_df, extra_feature_df], + axis=1, + ) + # move 'label' column to the end of the dataframe if it exists + if "label" in feature_df.columns: + feature_df = feature_df[ + [col for col in feature_df.columns if col != "label"] + ["label"] + ] + return feature_df + + +def get_ppi_feature(): + """ + Returns a dictionary of protein-protein interaction (PPI) features. + + The PPI features are extracted from data in the "network_info" dictionary and are specified by the + "feature_names" list. The URLs of the relevant data are extracted from "network_info" and read + using the Pandas library. The resulting PPI data is stored in the "ppi_features" dictionary and + returned by the function. + + Returns: + ppi_features: dict + A dictionary with PPI features, where the keys are the feature names and the values are lists of tuples + representing the protein interactions. 
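+
+    Example:
+    --------
+    Illustrative shape of the returned dictionary; the gene symbols shown
+    below are placeholders, not actual entries from the downloaded networks::
+
+        {
+            "BioGRID_PPI": [("GENE_A", "GENE_B"), ...],
+            "BioPlex_PPI": [("GENE_C", "GENE_D"), ...],
+            "HI-union_PPI": [("GENE_E", "GENE_F"), ...],
+        }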
+ """ + feature_names = ["BioGRID", "BioPlex", "HI-union"] + urls = [ + network_info["url"][i] + for i in range(len(network_info["name"])) + if network_info["name"][i] in feature_names + ] + + ppi_features = {} + # use pandas to read the file + for i, url in enumerate(urls): + data = pd.read_csv(url, sep="\t", header=None) + data = data.apply(lambda x: tuple(sorted(x)), axis=1) + ppi_name = f"{feature_names[i]}_PPI" + ppi_features[ppi_name] = data.tolist() + + return ppi_features + + +def train_ml_model(data_df, ml_type, seed, n_jobs, feature_mapping, model_dir): + assert ml_type == "xgboost", "ML model must be xgboost" + models = train_model( + data_df.iloc[:, :-1], + data_df.iloc[:, -1], + seed, + n_jobs, + feature_mapping, + model_dir, + ) + + return models + + +def train_model(X, y, seed, n_jobs, feature_mapping, model_dir): + model_params = { + "n_estimators": [10, 20, 50, 100], + "max_depth": [1, 2, 3, 4, 5], + "learning_rate": [0.01, 0.02, 0.05, 0.1, 0.2, 0.5, 1.0], + } + + models = {} + for ft in feature_mapping: + # use only mutual rank + log.info(f"Training model for {ft} ...") + xgb_model = xgb.XGBClassifier(random_state=seed, n_jobs=n_jobs) + cv = StratifiedKFold(n_splits=5, random_state=seed, shuffle=True) + clf = GridSearchCV( + xgb_model, model_params, scoring="roc_auc", cv=cv, n_jobs=1, verbose=2 + ) + if ft == "ex": + # exclude ppi features + Xtrain = X.loc[:, ~X.columns.str.endswith("_PPI")] + else: + Xtrain = X + model = clf.fit(Xtrain, y) + models[ft] = model + ml_model_file = model_dir / f"model_{ft}.pkl.gz" + with gzip.open(ml_model_file, "wb") as fh: + pickle.dump(model, fh) + + log.info(f"Training model for {ft} ... done") + + return models + + +def compute_llr( + predicted_all_pairs, + llr_res_file, + start_edge_num, + max_num_edges, + step_size, + gs_test, + is_extra_feat=False, +): + # make sure max_num_edges is smaller than the number of non-NA values + assert is_extra_feat or max_num_edges < np.count_nonzero( + ~np.isnan(predicted_all_pairs.iloc[:, -1].values) + ), "max_num_edges should be smaller than the number of non-NA values" + + cur_col_name = "prediction" + cur_results = predicted_all_pairs.nlargest(max_num_edges, cur_col_name) + selected_edges_all = cur_results[["P1", "P2"]].apply( + lambda row: tuple(sorted({row["P1"], row["P2"]})), axis=1 + ) + + gs_test_pos_set = set( + gs_test[gs_test["label"] == 1][["P1", "P2"]].apply( + lambda row: tuple(sorted({row["P1"], row["P2"]})), axis=1 + ) + ) + gs_test_neg_set = set( + gs_test[gs_test["label"] == 0][["P1", "P2"]].apply( + lambda row: tuple(sorted({row["P1"], row["P2"]})), axis=1 + ) + ) + n_gs_test_pos_set = len(gs_test_pos_set) + n_gs_test_neg_set = len(gs_test_neg_set) + + result_dict = defaultdict(list) + # llr_res_dict only save maximum of max_steps data points for downstream + # analysis / plotting + total = math.ceil((max_num_edges - start_edge_num) / step_size) + 1 + for k in tqdm( + range(start_edge_num, max_num_edges + step_size, step_size), + total=total, + ascii=" >=", + ): + selected_edges = set(selected_edges_all[:k]) + all_nodes = set(itertools.chain.from_iterable(selected_edges)) + common_pos_edges = selected_edges & gs_test_pos_set + common_neg_edges = selected_edges & gs_test_neg_set + try: + lr = ( + len(common_pos_edges) + / len(common_neg_edges) + / (n_gs_test_pos_set / n_gs_test_neg_set) + ) + except ZeroDivisionError: + lr = 0 + llr = np.log(lr) if lr > 0 else np.nan + n_node = len(all_nodes) + result_dict["k"].append(k) + result_dict["n"].append(n_node) + 
result_dict["llr"].append(llr) + + llr_res = pd.DataFrame(result_dict) + if llr_res_file is not None: + llr_res.to_csv(llr_res_file, sep="\t", index=False) + + return llr_res + + +def prepare_gs_data(**kwargs): + task = kwargs["task"] + cc_dict = kwargs["cc_dict"] + mr_dict = kwargs["mr_dict"] + gs_file = kwargs["gs_file"] + gs_file_md5 = None + feature_type = kwargs["feature_type"] + extra_feature_df = kwargs["extra_feature_df"] + valid_id_list = kwargs["valid_id_list"] + test_size = kwargs["test_size"] + seed = kwargs["seed"] + + if gs_file is None: + gs_file = urls["reactome_gold_standard"] + gs_file_md5 = urls["reactome_gold_standard_md5"] + + log.info("Preparing gold standard data") + gs_df = get_valid_gs_data(gs_file, valid_id_list, md5=gs_file_md5) + gs_df_balanced = balance_classes(gs_df, random_state=seed) + del gs_df + gs_train, gs_test = train_test_split( + gs_df_balanced, + test_size=test_size, + random_state=seed, + stratify=gs_df_balanced.iloc[:, -1], + ) + if task == "protein_func": + ppi_feature = get_ppi_feature() + else: + ppi_feature = None + gs_train_df = extract_features( + gs_train, feature_type, cc_dict, ppi_feature, extra_feature_df, mr_dict, None + ) + gs_test_df = extract_features( + gs_test, feature_type, cc_dict, ppi_feature, extra_feature_df, mr_dict, None + ) + + # store both the ids with gs_test_df for later use + # add the first two column of gs_test to gs_test_df at the beginning + gs_test_df = pd.concat([gs_test.iloc[:, :2], gs_test_df], axis=1) + + log.info("Preparing gs data ... done") + return gs_train_df, gs_test_df + + +def extract_dataset_feature(all_pairs, feature_file, feature_type="cc"): + # convert all_pairs to a dataframe + df = pd.DataFrame(all_pairs, columns=["P1", "P2"]) + + def get_1d_indices(i_array, j_array, n): + # Ensure i and j are within bounds + mask = (i_array < n) & (j_array < n) + i = np.minimum(i_array[mask], j_array[mask]) + j = np.maximum(i_array[mask], j_array[mask]) + + # Calculate 1D indices + return i * n - i * (i - 1) // 2 + (j - i) + + df = pd.DataFrame(all_pairs, columns=["P1", "P2"]) + with h5py.File(feature_file, "r") as h5_file: + gene_ids = h5_file["ids"][:] + gene_to_index = {gene.astype(str): idx for idx, gene in enumerate(gene_ids)} + + # Get gene indices for P1 and P2 + p1_indices = np.array([gene_to_index.get(gene, -1) for gene in df.iloc[:, 0]]) + p2_indices = np.array([gene_to_index.get(gene, -1) for gene in df.iloc[:, 1]]) + + f_dataset = h5_file[feature_type] + f_values = np.empty(len(df), dtype=float) + valid_indices = (p1_indices != -1) & (p2_indices != -1) + + linear_indices = get_1d_indices( + p1_indices[valid_indices], p2_indices[valid_indices], len(gene_ids) + ) + + f_values[valid_indices] = f_dataset[:][linear_indices] + f_values[~valid_indices] = np.nan + # extracted feature is the 'prediction' column + df["prediction"] = f_values + + return df + + +def extract_extra_features(all_pairs, ef_df): + # convert all_pairs to a dataframe + df = pd.DataFrame(all_pairs, columns=["P1", "P2"]) + return pd.merge(df, ef_df, on=["P1", "P2"], how="left") + + +def dataset_llr( + all_ids, + feature_dict, + feature_type, + gs_test, + start_edge_num, + max_num_edge, + step_size, + llr_dataset_file, + extra_feature_df, +): + llr_ds = pd.DataFrame() + all_ids_sorted = sorted(all_ids) + all_pairs = list(itertools.combinations(all_ids_sorted, 2)) + all_ds_pred = None + + for dataset in feature_dict: + log.info(f"Calculating llr for {dataset} ...") + feature_file = feature_dict[dataset] + predicted_all_pairs = 
extract_dataset_feature(
+            all_pairs, feature_file, feature_type
+        )
+        if all_ds_pred is None:
+            all_ds_pred = predicted_all_pairs["prediction"].values
+        else:
+            all_ds_pred = np.vstack(
+                (all_ds_pred, predicted_all_pairs["prediction"].values)
+            )
+
+        cur_llr_res = compute_llr(
+            predicted_all_pairs, None, start_edge_num, max_num_edge, step_size, gs_test
+        )
+        cur_llr_res["dataset"] = dataset
+        llr_ds = pd.concat([llr_ds, cur_llr_res], axis=0, ignore_index=True)
+        llr_ds.to_csv(llr_dataset_file, sep="\t", index=False)
+        log.info(f"Calculating llr for {dataset} ... done")
+
+    # calculate llr for all datasets based on the average prediction
+    log.info("Calculating llr for all datasets average ...")
+    all_ds_pred_df = pd.DataFrame(all_pairs, columns=["P1", "P2"])
+    if all_ds_pred.ndim == 1:
+        all_ds_pred_avg = all_ds_pred
+    elif all_ds_pred.ndim == 2:
+        all_ds_pred_avg = np.nanmean(all_ds_pred, axis=0)
+    else:
+        raise ValueError(f"Invalid dimension for all_ds_pred: {all_ds_pred.ndim}")
+    all_ds_pred_df["prediction"] = all_ds_pred_avg
+    cur_llr_res = compute_llr(
+        all_ds_pred_df, None, start_edge_num, max_num_edge, step_size, gs_test
+    )
+    cur_llr_res["dataset"] = "all_average"
+    llr_ds = pd.concat([llr_ds, cur_llr_res], axis=0, ignore_index=True)
+    log.info("Calculating llr for all datasets average ... done")
+    llr_ds.to_csv(llr_dataset_file, sep="\t", index=False)
+    if extra_feature_df is not None:
+        log.info("Calculating LLR for extra features")
+        extra_feature_df = extract_extra_features(
+            all_pairs, extra_feature_df
+        )  # filter out unused pairs
+        features = extra_feature_df.columns.values[2:]
+        for f in features:
+            # select the pair columns plus the current feature and rename the
+            # feature column to 'prediction' (mutating Index.values in place
+            # is unreliable)
+            subset_df = extra_feature_df[["P1", "P2", f]].rename(
+                columns={f: "prediction"}
+            )
+            log.info(f"Calculating llr for extra feature {f} ...")
+            cur_llr_res = compute_llr(
+                subset_df, None, start_edge_num, max_num_edge, step_size, gs_test, True
+            )
+            cur_llr_res["dataset"] = f + "_EXTRAFEAT"
+            llr_ds = pd.concat([llr_ds, cur_llr_res], axis=0, ignore_index=True)
+            llr_ds.to_csv(llr_dataset_file, sep="\t", index=False)
+            log.info(f"Calculating llr for {f} ... 
done") + return llr_ds + + +def get_cutoff(model_dict, gs_test, lr_cutoff): + cutoff_p_dict = {} + cutoff_llr_dict = {} + for ft in model_dict: + log.info(f"Calculating cutoff prob for {ft} ...") + model = model_dict[ft] + if ft == "ex": + gs_test_df = gs_test.loc[:, ~gs_test.columns.str.endswith("_PPI")] + else: + gs_test_df = gs_test + prob = model.predict_proba(gs_test_df.iloc[:, 2:-1]) + prob = prob[:, 1] + pred_df = pd.DataFrame(prob, columns=["prob"]) + pred_df = pd.concat([pred_df, gs_test_df.iloc[:, -1]], axis=1) + pred_df = pred_df.sort_values(by="prob", ascending=False) + + P = pred_df["label"].sum() + N = len(pred_df) - P + cumulative_pp = np.cumsum(pred_df["label"]) + cumulative_pn = np.arange(len(pred_df)) + 1 - cumulative_pp + llr_values = np.log((cumulative_pp / cumulative_pn) / (P / N)) + pred_df["llr"] = llr_values + + # find the first prob that has llr >= lr_cutoff + cutoff = np.log(lr_cutoff) + cutoff_prob = None + for _, row in pred_df[::-1].iterrows(): + if not np.isinf(row["llr"]) and row["llr"] >= cutoff: + cutoff_prob = row["prob"] + cutoff_llr = row["llr"] + break + + # if cutoff_prob is None, it means that the lr_cutoff is too high + # and we cannot find a cutoff prob that has llr >= lr_cutoff + if cutoff_prob is None: + log.error( + f"Cannot find cutoff prob for {ft}, lower lr_cutoff and try again" + ) + import sys + + sys.exit(1) + + cutoff_p_dict[ft] = cutoff_prob + cutoff_llr_dict[ft] = cutoff_llr + + return cutoff_p_dict, cutoff_llr_dict + + +def predict_network(predict_results_file, cutoff_p, output_file): + for ft in predict_results_file: + log.info(f"Predicting network for {ft} ...") + predicted_df = pd.read_parquet(predict_results_file[ft]) + filtered_df = predicted_df[predicted_df["prediction"] > cutoff_p[ft]] + cur_file = output_file[ft] + filtered_df[["P1", "P2"]].to_csv(cur_file, sep="\t", index=False, header=None) + directory, file_name = os.path.split(cur_file) + base_name, extension = os.path.splitext(file_name) + new_file_name = f"{base_name}_with_p{extension}" + new_file = os.path.join(directory, new_file_name) + filtered_df.to_csv(new_file, sep="\t", index=False) + num_edges = len(filtered_df) + num_nodes = len(set(filtered_df["P1"]) | set(filtered_df["P2"])) + log.info(f"Number of edges: {num_edges}") + log.info(f"Number of nodes: {num_nodes}") + log.info(f"Predicting network for {ft} ... done") + + +def predict_all_pairs( + model_dict, + all_ids, + feature_type, + ppi_feature, + cc_dict, + mr_dict, + extra_feature_df, + prediction_dir, + output_file, + n_jobs=1, +): + chunk_size = 1000000 + log.info("Genearating all pairs ...") + all_ids = sorted(all_ids) + all_pairs = list(itertools.combinations(all_ids, 2)) + log.info("Genearating all pairs ... 
done") + log.info(f'Number of valid ids {format(len(all_ids), ",")}') + # remove all "chunk_*.parquet" files in prediction_dir if they exist + pattern = os.path.join(prediction_dir, "chunk_*.parquet") + matching_files = glob.glob(pattern) + for file in matching_files: + os.remove(file) + + for ft in model_dict: + log.info(f'Predicting all pairs ({format(len(all_pairs), ",")}) for {ft} ...') + model = model_dict[ft] + + def process_and_save_chunk(start_idx, chunk_id): + chunk = all_pairs[start_idx : start_idx + chunk_size] + chunk_df = pd.DataFrame(chunk, columns=["P1", "P2"]) + if ft == "ex": + cur_ppi_feature = None + else: + cur_ppi_feature = ppi_feature + feature_df = extract_features( + chunk_df, + feature_type, + cc_dict, + cur_ppi_feature, + extra_feature_df, + mr_dict, + (start_idx, chunk_size), + ) + predictions = model.predict_proba(feature_df) + prediction_df = pd.DataFrame(predictions[:, 1], columns=["prediction"]) + prediction_df["P1"] = chunk_df["P1"] + prediction_df["P2"] = chunk_df["P2"] + prediction_df = prediction_df[["P1", "P2", "prediction"]] + prediction_df["prediction"] = prediction_df["prediction"].astype("float32") + table = pa.Table.from_pandas(prediction_df) + chunk_id = str(chunk_id).zfill(6) + output_file = f"{prediction_dir}/chunk_{chunk_id}.parquet" + pq.write_table(table, output_file) + + with ThreadPoolExecutor(max_workers=n_jobs) as executor: + for chunk_id, chunk_start in enumerate( + range(0, len(all_pairs), chunk_size) + ): + executor.submit(process_and_save_chunk, chunk_start, chunk_id) + + pattern = os.path.join(prediction_dir, "chunk_*.parquet") + matching_files = glob.glob(pattern) + matching_files.sort() + pq.write_table( + pa.concat_tables([pq.read_table(file) for file in matching_files]), + output_file[ft], + ) + for file in matching_files: + os.remove(file) + + log.info(f'Predicting all {format(len(all_pairs), ",")} pairs for {ft} done.') diff --git a/funmap/logger.py b/python/funmap/logger.py similarity index 92% rename from funmap/logger.py rename to python/funmap/logger.py index ef53b84f..d994cdf9 100644 --- a/funmap/logger.py +++ b/python/funmap/logger.py @@ -7,6 +7,7 @@ LOG_LEVEL = logging.INFO + def setup_logging(run_config, log_config="logging.yml") -> None: """ Setup ``logging.config`` @@ -44,7 +45,7 @@ def setup_logging(run_config, log_config="logging.yml") -> None: handler_critical.setLevel(logging.CRITICAL) # Create formatters (if needed) - formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s') + formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(message)s") handler_debug.setFormatter(formatter) handler_info.setFormatter(formatter) handler_warning.setFormatter(formatter) @@ -69,7 +70,9 @@ def setup_logging(run_config, log_config="logging.yml") -> None: logger.addHandler(console_handler) # Log a warning message to the console - logger.warning(f'"{log_config}" not found. Using basicConfig with custom log files.') + logger.warning( + f'"{log_config}" not found. Using basicConfig with custom log files.' 
+ ) return with open(log_config, "rt") as f: @@ -85,6 +88,6 @@ def setup_logging(run_config, log_config="logging.yml") -> None: def setup_logger(name): - log = logging.getLogger(f'funmap.{name}') + log = logging.getLogger(f"funmap.{name}") log.setLevel(LOG_LEVEL) return log diff --git a/python/funmap/plotting.py b/python/funmap/plotting.py new file mode 100644 index 00000000..57019304 --- /dev/null +++ b/python/funmap/plotting.py @@ -0,0 +1,954 @@ +from typing import List, Dict +import os +from pathlib import Path +import pandas as pd +import numpy as np +from funmap.utils import get_data_dict, get_node_edge_overlap +import matplotlib +import matplotlib.pyplot as plt +from matplotlib.cbook import flatten +from matplotlib.ticker import MaxNLocator +import matplotlib.ticker as mticker +from matplotlib.lines import Line2D +from matplotlib.patches import Patch +import seaborn as sns +import PyPDF2 +from matplotlib_venn import venn2, venn2_circles +import networkx as nx +import warnings +import powerlaw +from funmap.logger import setup_logging, setup_logger + +log = setup_logger(__name__) + + +def edge_number(x, pos): + """ + Formatter function to format the x-axis tick labels + + Parameters + ---------- + x : float + The value to be formatted. + pos : float + The tick position. + + Returns + ------- + s : str + The formatted string of the value. + """ + if x >= 1e6: + s = "{:1.1f}M".format(x * 1e-6) + elif x == 0: + s = "0" + else: + s = "{:1.0f}K".format(x * 1e-3) + return s + + +def plot_llr_comparison( + cfg, validation_results, llr_ds, output_file="llr_comparison.pdf" +): + name_type_dict = {item["name"]: item["type"] for item in cfg["data_files"]} + name_type_dict["all_average"] = "other" + datasets = sorted(llr_ds["dataset"].unique().tolist()) + fig, ax = plt.subplots(figsize=(20, 16)) + + for ds in datasets: + start = -1 + cur_df = llr_ds[llr_ds["dataset"] == ds] + if ds.endswith("_EXTRAFEAT"): + ds = ds.replace("_EXTRAFEAT", "") + ltype = "-." 
+ elif name_type_dict[ds].upper() == "RNA": + ltype = "--" + elif name_type_dict[ds].upper() == "PROTEIN": + ltype = ":" + else: + ltype = "-" + ax.plot(cur_df["k"], cur_df["llr"], linestyle=ltype, label=ds) + if start == -1: + start = cur_df["k"].iloc[0] + + # plot llr_res with the same start point + for ft in validation_results: + llr_res = pd.read_csv(validation_results[ft]["llr_res_path"], sep="\t") + llr_res = llr_res[llr_res["k"] >= start] + ax.plot(llr_res["k"], llr_res["llr"], label=f"funmap_{ft}", linewidth=3) + + line_styles = [ + Line2D([0], [0], linestyle="--", color="black", label="RNA"), + Line2D([0], [0], linestyle=":", color="black", label="Protein"), + Line2D([0], [0], linestyle="-.", color="black", label="Extra Feature"), + Line2D([0], [0], linestyle="-", color="black", label="Other"), + ] + color_legend = ax.legend( + handles=[ + Patch(color=line.get_color(), label=line.get_label()) for line in ax.lines + ], + bbox_to_anchor=(1.05, 1), + title="Data type", + fontsize=16, + bbox_transform=ax.transAxes, + ) + linestyle_legend = ax.legend( + handles=line_styles, + bbox_to_anchor=(1.05, 0.2), + title="Model", + fontsize=16, + bbox_transform=ax.transAxes, + ) + ax.add_artist(color_legend) + ax.add_artist(linestyle_legend) + ax.xaxis.set_major_formatter(edge_number) + ax.set_xlabel("number of pairs", fontsize=16) + ax.set_ylabel("log likelihood ratio", fontsize=16) + ax.yaxis.grid(color="gainsboro", linestyle="dotted") + plt.tight_layout() + plt.box(on=None) + plt.savefig( + output_file, + bbox_inches="tight", + bbox_extra_artists=[color_legend, linestyle_legend], + ) + plt.close(fig) + return output_file + + +def explore_data(data_config: Path, min_sample_count: int, output_dir: Path): + """ + Generate plots to explore and visualize data + + Parameters + ---------- + data_config: Path + Path to the data configuration file + min_sample_count: int + The minimum number of samples required to consider a dataset + output_dir: Path + The directory to save the output plots + + Returns + ------- + A list of file names of the generated plots + + """ + data_dict, _ = get_data_dict(data_config, min_sample_count) + fig_names = [] + + # sample wise median expression plot for each dataset + data = [] + data_keys = [] + + max_col_to_plot = 100 + + log.info("Generating plots to explore and visualize data ...") + for ds in data_dict: + log.info(f"... 
{ds}") + data_df = data_dict[ds] + fig, ax = plt.subplots(1, 2, figsize=(20, 5)) + cur_data = data_df.T + cur_data.dropna(inplace=True) + if cur_data.shape[1] > max_col_to_plot: + cur_data = cur_data.sample(max_col_to_plot, axis=1) + ax[0].boxplot(cur_data) + ax[0].set_ylabel("expression") + if data_df.shape[0] > max_col_to_plot: + ax[0].set_xlabel( + f"random selected {max_col_to_plot} samples (total n={data_df.shape[0]})" + ) + ax[0].set_xticklabels([]) + ax[0].set_xticks([]) + else: + ax[0].set_xlabel("sample") + ax[0].set_xticklabels(cur_data.columns, rotation=45, ha="right") + + # density plot for each sample in each dataset + for i in range(data_df.shape[0]): + sns.kdeplot(data_df.iloc[i, :], linewidth=1, ax=ax[1]) + locator = MaxNLocator(60) + ax[1].xaxis.set_major_locator(locator) + ax[1].set_xlabel("values") + ax[1].set_ylabel("density") + ticks_loc = ax[1].get_xticks().tolist() + ax[1].xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) + ax[1].set_xticklabels(ax[1].get_xticklabels(), rotation=45, ha="right") + + # set title for the figu + fig.suptitle(f"{ds}", fontsize=16) + fig.tight_layout() + cur_file_name = f"{ds}_sample_plot.pdf" + fig_names.append(cur_file_name) + log.info(f"Saving figure {cur_file_name} ...") + fig.savefig(output_dir / cur_file_name) + plt.close(fig) + data_keys.append(ds) + data.append(data_df.median(axis=1).values) + + fig, ax = plt.subplots(figsize=(10, 5)) + ax.boxplot(data) + ax.set_xticklabels(data_keys, rotation=45) + ax.yaxis.grid(color="gainsboro", linestyle="dotted") + + ax.set_xlabel("dataset") + ax.set_ylabel("median expression") + plt.box(on=None) + fig.tight_layout() + file_name = "data_box_plot.pdf" + fig_names.append(file_name) + log.info(f"Saving figure {file_name} ...") + fig.savefig(output_dir / file_name) + plt.close(fig) + + # boxplot of the number of samples and genes + sample_count = pd.DataFrame( + { + "count": [data_dict[ds].shape[0] for ds in data_dict], + "dataset": [ds for ds in data_dict], + } + ) + gene_count = pd.DataFrame( + { + "count": [data_dict[ds].shape[1] for ds in data_dict], + "dataset": [ds for ds in data_dict], + } + ) + + fig, ax = plt.subplots(1, 2, figsize=(10, 5)) + + bars0 = ax[0].barh(sample_count["dataset"], sample_count["count"], color="#774FA0") + bars1 = ax[1].barh(gene_count["dataset"], gene_count["count"], color="#7DC462") + + ax[0].spines["top"].set_visible(False) + ax[0].spines["left"].set_visible(False) + ax[0].spines["right"].set_visible(False) + ax[0].tick_params(axis="both", which="major", labelsize=12) + ax[0].bar_label(bars0, label_type="edge", fontsize=10) + ax[0].set_xlabel("number of samples") + + ax[1].spines["top"].set_visible(False) + ax[1].spines["left"].set_visible(False) + ax[1].spines["right"].set_visible(False) + ax[1].set_yticklabels([]) + ax[1].tick_params(axis="x", which="major", labelsize=12) + ax[1].tick_params(axis="y", which="both", left=False, right=False, labelleft=False) + + ax[1].bar_label(bars1, label_type="edge", fontsize=10) + ax[1].set_xlabel("number of genes") + + fig.tight_layout() + file_name = "sample_gene_count.pdf" + fig.savefig(output_dir / file_name) + log.info(f"Saving figure {file_name} ...") + plt.close(fig) + fig_names.append(file_name) + return fig_names + + +def plot_results(cfg, validation_results, llr_ds, gs_dict, cutoff_llr, figure_dir): + fig_names = [] + file_name = "llr_compare_dataset.pdf" + plot_llr_comparison( + cfg, validation_results, llr_ds, output_file=figure_dir / file_name + ) + fig_names.append(file_name) + + if "rp_pairs" in cfg: + 
file_names = plot_pair_llr( + gs_dict, + cfg["feature_type"], + output_dir=figure_dir, + rp_pairs=cfg["rp_pairs"], + ) + fig_names.extend(file_names) + + if cfg["task"] == "protein_func": + file_name = "llr_compare_networks.pdf" + # if gs_file is not specified, plot the llr comparison between networks + # because some the numbers are pre-computed based on default gs_file + if cfg["gs_file"] is None: + plot_llr_compare_networks( + validation_results, + cfg["lr_cutoff"], + cutoff_llr, + output_file=figure_dir / file_name, + ) + fig_names.append(file_name) + + # the information about other networks is fixed for now + other_network_info = { + "name": ["BioGRID", "BioPlex", "HI-union", "STRING"], + "type": ["BioGRID", "BioPlex", "HI", "STRING"], + "url": [ + "https://figshare.com/ndownloader/files/39125054", + "https://figshare.com/ndownloader/files/39125051", + "https://figshare.com/ndownloader/files/39125093", + "https://figshare.com/ndownloader/files/39125090", + ], + } + # convert the info to a data frame where the url is read as a dataframe + network_info = pd.DataFrame(other_network_info) + network_info["el"] = network_info["url"].apply( + lambda x: pd.read_csv(x, sep="\t", header=None) + ) + network_info = network_info.drop(columns=["url"]) + + # for each funmap, create a dataframe + for ft in validation_results: + edge_file_path = validation_results[ft]["edge_list_path"] + funmap_el = pd.read_csv(edge_file_path, sep="\t", header=None) + funmap_df = pd.DataFrame( + {"name": ["FunMap"], "type": ["FunMap"], "el": [funmap_el]} + ) + all_network_info = pd.concat([network_info, funmap_df], ignore_index=True) + overlap_info = get_node_edge_overlap(all_network_info) + node_color, edge_color = "#7DC462", "#774FA0" + for type, color in zip(["node", "edge"], [node_color, edge_color]): + fig_name = plot_overlap_venn( + f"funmap_{ft}", overlap_info[type], type, color, figure_dir + ) + fig_names.append(fig_name) + + fig_name = plot_network_stats(all_network_info, ft, figure_dir) + fig_names.append(fig_name) + + return fig_names + + +def plot_1d_llr(ax, feature_df, feature_name, feature_type, data_type, n_bins): + """ + Plot the 1D histogram of the likelihood ratio for each feature + + Parameters + ---------- + ax : matplotlib.axes._subplots.AxesSubplot + The subplot where the histogram is to be plotted. + feature_df : pd.DataFrame + DataFrame containing all features and their values. + feature_name : str + The name of the feature for which histogram is to be plotted. + feature_type : str + The type of the feature, either 'CC' or 'MR'. + data_type : str + The type of data, either 'RNA' or 'PRO'. + n_bins : int + The number of bins for the histogram. 
+ + Returns + ------- + None + """ + df = feature_df.loc[:, [feature_name]] + cur_df = df.dropna() + cur_df_vals = cur_df.values.reshape(-1) + clr = "#bcbddc" + data_range = {"CC": (-1, 1), "MR": (0, 1)} + if data_type == "PRO": + ax.hist( + cur_df_vals, + bins=n_bins, + range=data_range[feature_type], + color=clr, + orientation="horizontal", + density=True, + ) + ax.text( + 0.95, + 0.95, + data_type, + verticalalignment="top", + horizontalalignment="right", + transform=ax.transAxes, + rotation=-90, + color="black", + fontsize=16, + ) + ax.set_ylim(data_range[feature_type]) + ax.set_xlim(0, 2.5) + else: + ax.hist( + cur_df_vals, + bins=n_bins, + range=data_range[feature_type], + color=clr, + density=True, + ) + ax.text( + 0.02, + 0.9, + data_type, + verticalalignment="top", + horizontalalignment="left", + transform=ax.transAxes, + color="black", + fontsize=16, + ) + ax.set_ylabel("density", fontsize=16) + ax.set_xlim(data_range[feature_type]) + ax.set_ylim(0, 2.5) + ax.tick_params(axis="x", labelsize=16) + ax.tick_params(axis="y", labelsize=16) + + +def plot_2d_llr( + ax, feature_df, feature_type, pair_name, rna_feature, pro_feature, n_bins +): + """ + Plots a 2D log likelihood ratio between two features in a scatter plot. + + Parameters + ---------- + ax : matplotlib.axes._subplots.AxesSubplot + The subplot to plot the log likelihood ratio on. + feature_df : pandas.DataFrame + DataFrame containing all features and target variables. + feature_type : str + Type of feature, either "CC" (correlation coefficient) or "MR" (mutual rank). + pair_name : str + Name of the feature pair. + rna_feature : str + Name of the RNA feature in `feature_df`. + pro_feature : str + Name of the protein feature in `feature_df`. + n_bins : int + Number of bins in the 2D histogram. + + Returns + ------- + fig : matplotlib.collections.QuadMesh + The mesh plot of the log likelihood ratio. 
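+
+    Notes
+    -----
+    Sketch of the quantity rendered on the diverging colormap, matching the
+    computation in the function body: for each 2D bin, a pseudo-count-smoothed
+    likelihood ratio of gold-standard positive versus negative pairs is
+    computed and its natural log is plotted:
+
+        log( ((n_pos_bin + 1) / (n_neg_bin + N_neg / N_pos)) / (N_pos / N_neg) )
+
+    where n_pos_bin and n_neg_bin are the per-bin counts of positive and
+    negative pairs and N_pos and N_neg are the total counts of each class.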
+ + """ + data_types = ["RNA", "PRO"] + label_mapping = {"PRO": "Protein", "RNA": "mRNA"} + feature_label_mapping = {"CC": "correlation coefficient", "MR": "mutual rank"} + cnt = {} + cnt_pos_neg = {} + max_density = -1 + + data_range = {"CC": (-1, 1), "MR": (0, 1)} + + for label in [0, 1]: + df = feature_df.loc[:, ["label", rna_feature, pro_feature]] + df = df.dropna() + cur_df = df.loc[df["label"] == label, [rna_feature, pro_feature]] + hist_density, _, _ = np.histogram2d( + cur_df[rna_feature].values, + cur_df[pro_feature].values, + bins=n_bins, + range=np.array([data_range[feature_type], data_range[feature_type]]), + density=True, + ) + max_density = max(max_density, np.max(hist_density)) + + for label in [0, 1]: + df = feature_df.loc[:, ["label", rna_feature, pro_feature]] + df = df.dropna() + cur_df = df.loc[df["label"] == label, [rna_feature, pro_feature]] + cnt[label] = cur_df.shape[0] + hist, _, _ = np.histogram2d( + cur_df[rna_feature].values, + cur_df[pro_feature].values, + bins=n_bins, + range=np.array([data_range[feature_type], data_range[feature_type]]), + ) + cnt_pos_neg[label] = hist + hh = ax.hist2d( + cur_df[rna_feature].values, + cur_df[pro_feature].values, + bins=n_bins, + range=np.array([data_range[feature_type], data_range[feature_type]]), + vmin=0, + vmax=max_density, + density=True, + ) + + llr_vals = ((cnt_pos_neg[1] + 1) / (cnt_pos_neg[0] + cnt[0] / cnt[1])) / ( + cnt[1] / cnt[0] + ) + if feature_type == "MR": + vmin = np.percentile(llr_vals, 5) + vmax = np.percentile(llr_vals, 95) + symmetric_max = max(abs(vmin), abs(vmax)) + vmin = -symmetric_max + vmax = symmetric_max + else: + vmin = -4 + vmax = 4 + cmap = plt.cm.RdBu_r + fig = ax.pcolormesh( + hh[1], hh[2], np.transpose(np.log(llr_vals)), vmin=vmin, vmax=vmax, cmap=cmap + ) + ax.text( + 0.02, + 0.01, + pair_name, + verticalalignment="bottom", + horizontalalignment="left", + transform=ax.transAxes, + color="gray", + fontsize=24, + fontweight="bold", + ) + ax.set_xlabel( + f"{label_mapping[data_types[0]]}\n{feature_label_mapping[feature_type]}", + fontsize=16, + ) + ax.set_ylabel( + f"{label_mapping[data_types[1]]}\n{feature_label_mapping[feature_type]}", + fontsize=16, + ) + if feature_type == "CC": + ax.set_xticks([-1, -0.5, 0, 0.5, 1]) + ax.set_yticks([-1, -0.5, 0, 0.5, 1]) + else: + ax.set_xticks([0, 0.25, 0.5, 0.75, 1]) + ax.set_yticks([0, 0.25, 0.5, 0.75, 1]) + ax.tick_params(axis="x", labelsize=16) + ax.tick_params(axis="y", labelsize=16) + return fig + + +def plot_pair_llr(gs_dict, feature_type, output_dir, rp_pairs): + n_bins = 20 + file_names = [] + feature_type = ["CC", "MR"] + # plot for each rna-protein pair using CC or MR as feature + for rp_pair in rp_pairs: + for ft in feature_type: + ft = ft.upper() + feature_df = gs_dict[ft] + fig, ax = plt.subplots( + 2, + 2, + figsize=(10, 10), + gridspec_kw={"width_ratios": [4, 1], "height_ratios": [1, 4]}, + ) + rna_feature = rp_pair["rna"] + pro_feature = rp_pair["protein"] + plot_1d_llr(ax[0, 0], feature_df, rna_feature, ft, "RNA", n_bins) + ax[0, 0].xaxis.set_ticks_position("none") + ax[0, 0].set_xticklabels([]) + + ax[0, 1].axis("off") + + heatmap2d = plot_2d_llr( + ax[1, 0], + feature_df, + ft, + rp_pair["name"], + rna_feature, + pro_feature, + n_bins, + ) + + plot_1d_llr(ax[1, 1], feature_df, pro_feature, ft, "PRO", n_bins) + ax[1, 1].yaxis.set_ticks_position("none") + ax[1, 1].set_yticklabels([]) + + # add colorbar to the right of the plot + cax = fig.add_axes([1.05, 0.25, 0.03, 0.5]) + fig.colorbar(heatmap2d, cax=cax) + + # plt.tight_layout() + 
plt.box(on=None) + file_name = f"{rp_pair['name']}_rna_pro_{ft}_llr.pdf" + file_names.append(file_name) + plt.savefig(output_dir / file_name, bbox_inches="tight") + plt.close(fig) + + return file_names + + +def plot_llr_compare_networks(validaton_results, cutoff, cutoff_llr, output_file): + all_networks = [] + + for ft in validaton_results: + edge_list = pd.read_csv( + validaton_results[ft]["edge_list_path"], sep="\t", header=None + ) + n_edge = len(edge_list) + n_node = len(set(edge_list.iloc[:, 0]) | set(edge_list.iloc[:, 1])) + all_networks.extend( + [ + ( + f"FunMap_{ft}", + "FunMap", + n_node, + n_edge, + cutoff_llr[ft], + np.exp(cutoff_llr[ft]), + ) + ] + ) + + # these are pre-computed values + all_networks.extend( + [ + # name, type, n, e, llr, lr + ("HuRI", "HI", 8272, 52548, 2.3139014130648827, 10.11), + ("HI-union", "HI", 9094, 64006, 2.298975841813893, 9.96), + ("ProHD", "ProHD", 2680, 61580, 4.039348296, 56.78), + # this is combined_score_700 + ("STRING", "STRING", 16351, 240314, 5.229377563059293, 186.676572955849), + ( + "BioGRID", + "BioGRID", + 17259, + 654490, + 2.6524642147396182, + 14.18896024552041, + ), + ( + "BioPlex", + "BioPlex", + 13854, + 154428, + 3.3329858940888375, + 28.021887299660047, + ), + ] + ) + log.info(all_networks) + + cols = ["name", "group", "n", "e", "llr", "lr"] + network_data = pd.DataFrame(all_networks, columns=cols) + x = np.array(network_data["n"]) + y = np.array(network_data["llr"]) + e = np.array(network_data["e"]) + z = network_data["name"] + + fig, ax = plt.subplots(figsize=(10, 10)) + ax.set_axisbelow(True) + ax.xaxis.grid(color="gainsboro", linestyle="dotted") + ax.yaxis.grid(color="gainsboro", linestyle="dotted") + ax.get_ygridlines()[4].set_color("salmon") + ax.get_yticklabels()[4].set_color("red") + ax2 = ax.twinx() + + # we have 6 groups, so we need 6 colors + mycmap = matplotlib.colors.ListedColormap( + ["#de2d26", "#8B6CAF", "#0D95D0", "#69A953", "#F1C36B", "#DC6C43"] + ) + # group 0 is FunMap, group 1 is HI, group 2 is ProHD, group 3 is STRING, + # group 4 is BioGRID, group 5 is BioPlex + # the length of gro + color_group = ( + [0] * len(validaton_results) + [1] * 2 + [2] * 1 + [3] * 1 + [4] * 1 + [5] * 1 + ) + scatter = ax.scatter(x, y, c=color_group, cmap=mycmap, s=e / 1000) + ax.set_ylim(2, 6) + ax2.set_ylim(np.exp(2.0), np.exp(6)) + ax.set_xlabel("number of genes") + ax.set_yticks([2.0, 2.5, 3, 3.5, np.log(cutoff), 4, 4.5, 5, 5.5, 6]) + ax.set_ylabel("log likelihood ratio") + ax.spines["top"].set_visible(False) + ax.spines["left"].set_visible(False) + ax.spines["right"].set_visible(False) + ax2.set_ylabel("likelihood ratio") + ax2.set_yscale("log", base=np.e) + ax2.set_yticks( + [ + np.exp(2), + np.exp(2.5), + np.exp(3), + np.exp(3.5), + cutoff, + np.exp(4), + np.exp(4.5), + np.exp(5), + np.exp(5.5), + np.exp(6), + ] + ) + ax.tick_params(axis="x", labelsize=12) + ax.tick_params(axis="y", labelsize=12) + ax2.tick_params(axis="y", labelsize=12) + ax2.get_yticklabels()[4].set_color("red") + ax2.get_yaxis().set_major_formatter(matplotlib.ticker.ScalarFormatter()) + ax2.spines["top"].set_visible(False) + ax2.spines["left"].set_visible(False) + ax2.spines["right"].set_visible(False) + + handles1, labels1 = scatter.legend_elements( + prop="sizes", num=4, alpha=1, fmt="{x:.0f} K" + ) + legend1 = ax.legend( + handles1, + labels1, + loc="upper left", + labelspacing=1.8, + borderpad=1.0, + title="number of pairs", + frameon=True, + ) + + ax.add_artist(legend1) + leg = ax.get_legend() + for i in range(len(leg.legendHandles)): + 
leg.legendHandles[i].set_color("gray") + + ax.xaxis.set_major_formatter(edge_number) + for i, txt in enumerate(z): + if txt == "STRING": + ax.annotate(txt, (x[i] - 500, y[i] + 0.18), color="gray", fontsize=10) + elif txt == "HI-union": + ax.annotate(txt, (x[i] - 100, y[i] + 0.1), color="gray", fontsize=10) + elif txt == "BioGRID": + ax.annotate(txt, (x[i] - 800, y[i] - 0.25), color="gray", fontsize=10) + else: + ax.annotate(txt, (x[i] - 100, y[i] - 0.2), color="gray", fontsize=10) + + fig.tight_layout() + fig.savefig(output_file, bbox_inches="tight") + plt.close(fig) + + +def plot_overlap_venn(network_name, overlap, node_or_edge, color, output_dir): + """ + Plot the Venn diagrams for the overlap between different datasets. + + Parameters + ---------- + network_name : str + The name of the network to plot the overlap for. + overlap : dict + A dictionary containing the overlap between the datasets. + The keys are the names of the datasets, and the values are the sets + representing the overlap. + node_or_edge : str + A string indicating whether to plot the overlap of nodes or edges. + Must be one of 'node' or 'edge'. + color : str + The color to use for the FunMap dataset in the Venn diagrams. + output_dir : path-like + The directory to save the output figure in. + + Returns + ------- + file_name : str + The name of the file that the figure was saved as. + + """ + data = [] + for nw in overlap: + data.append(overlap[nw]) + max_area = max(map(sum, data)) + + def set_venn_scale(ax, true_area, reference_area=max_area): + s = np.sqrt(float(reference_area) / true_area) + ax.set_xlim(-s, s) + ax.set_ylim(-s, s) + + all_axes = [] + + n_plot = len(overlap) + fig, ax = plt.subplots(1, n_plot, figsize=(5 * n_plot, 5)) + + for i, nw in enumerate(overlap): + cur_ax = ax[i] + all_axes.append(cur_ax) + labels = ("FunMap", nw) + out = venn2( + overlap[nw], + set_labels=labels, + alpha=1.0, + ax=cur_ax, + set_colors=[color, "white"], + ) + venn2_circles( + overlap[nw], ax=cur_ax, linestyle="solid", color="gray", linewidth=1 + ) + if out.set_labels: + for text in out.set_labels: + text.set_fontsize(12) + + for text in out.subset_labels: + text.set_fontsize(10) + + # add title to the figure + name = "genes" if node_or_edge == "node" else "edges" + fig.suptitle(f"Overlap of {name} ({network_name})", fontsize=16) + + for a, d in zip(flatten(ax), data): + set_venn_scale(a, sum(d) * 1.5) + + file_name = f"{network_name}_overlap_{node_or_edge}.pdf" + fig.savefig(output_dir / file_name, bbox_inches="tight") + plt.close(fig) + return file_name + + +def plot_network_stats(network_info, feature_type, output_dir): + fig, ax = plt.subplots(1, 4, figsize=(20, 5)) + density = {} + average_shortest_path = {} + # these are pre-calculated since they take a long time to compute and + # the network is fixed + average_shortest_path = { + "BioGRID": 2.74, + "BioPlex": 3.60, + "HI-union": 3.70, + "STRING": 3.95, + } + # if you want to recompute the average shortest path length, + # add the network name to this list + network_list = ["FunMap"] + for n in network_list: + network_el = network_info.loc[network_info["name"] == n, "el"].values[0] + cur_network = nx.from_pandas_edgelist(network_el, source=0, target=1) + cur_density = nx.density(cur_network) + density[n] = cur_density + largest_cc = max(nx.connected_components(cur_network), key=len) + cur_cc = cur_network.subgraph(largest_cc).copy() + cur_average_shortest_path = nx.average_shortest_path_length(cur_cc) + average_shortest_path[n] = cur_average_shortest_path + cur_degrees = 
[val for (_, val) in cur_network.degree()] + if n == "FunMap": # only fit for FunMap + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + fit = powerlaw.Fit( + cur_degrees, discrete=True, xmax=250, estimate_discrete=False + ) + powerlaw.plot_pdf( + cur_degrees, + linear_bins=True, + linestyle="None", + marker="o", + markerfacecolor="None", + color="#de2d26", + linewidth=3, + ax=ax[0], + ) + # not plotting the power law fit + # fit.power_law.plot_pdf(linestyle='--',color='black', ax=ax[0]) + + # all the networks in network_info minus FunMap + other_networks = list(set(network_info["name"].tolist()) - set(["FunMap"])) + for n in other_networks: + network_el = network_info.loc[network_info["name"] == n, "el"].values[0] + cur_network = nx.from_pandas_edgelist(network_el, source=0, target=1) + cur_density = nx.density(cur_network) + density[n] = cur_density + + ax[0].set_xlabel("degree") + ax[0].set_ylabel("p(x)") + ax[0].spines["top"].set_visible(False) + ax[0].spines["right"].set_visible(False) + ax[0].yaxis.grid(color="gainsboro", linestyle="dotted") + ax[0].set_axisbelow(True) + + # global average clustering coefficient + # these are pre-calculated since they take a long time to compute and + # the network is fixed + avg_cc = {"BioGRID": 0.125, "BioPlex": 0.103, "HI-union": 0.06, "STRING": 0.335} + for n in network_list: + network_el = network_info.loc[network_info["name"] == n, "el"].values[0] + cur_network = nx.from_pandas_edgelist(network_el, source=0, target=1) + cur_cc = nx.average_clustering(cur_network) + avg_cc[n] = cur_cc + + network_list = network_info["name"].tolist() + ax[1].bar( + network_list, + [avg_cc[i] for i in network_list], + width=0.5, + align="center", + color="#E4C89A", + ) + ax[1].spines["left"].set_position(("outward", 8)) + ax[1].spines["bottom"].set_position(("outward", 5)) + ax[1].spines["top"].set_visible(False) + ax[1].spines["left"].set_visible(False) + ax[1].spines["right"].set_visible(False) + ax[1].set_ylabel("Average clustering coefficient") + ticks_loc = ax[1].get_xticks() + ax[1].xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) + ax[1].set_xticklabels(network_list, rotation=45, ha="right") + ax[1].yaxis.grid(color="gainsboro", linestyle="dotted") + ax[1].set_axisbelow(True) + + ax[2].bar( + network_list, + [density[i] for i in network_list], + width=0.5, + align="center", + color="#D8B2C6", + ) + ax[2].spines["left"].set_position(("outward", 8)) + ax[2].spines["bottom"].set_position(("outward", 5)) + ax[2].spines["top"].set_visible(False) + ax[2].spines["left"].set_visible(False) + ax[2].spines["right"].set_visible(False) + ax[2].set_ylabel("Density") + ticks_loc = ax[2].get_xticks() + ax[2].xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) + ax[2].set_xticklabels(network_list, rotation=45, ha="right") + ax[2].yaxis.grid(color="gainsboro", linestyle="dotted") + ax[2].set_axisbelow(True) + + ax[3].bar( + network_list, + [average_shortest_path[i] for i in network_list], + width=0.5, + align="center", + color="#B6D8A6", + ) + ax[3].spines["left"].set_position(("outward", 8)) + ax[3].spines["bottom"].set_position(("outward", 5)) + ax[3].spines["top"].set_visible(False) + ax[3].spines["left"].set_visible(False) + ax[3].spines["right"].set_visible(False) + ax[3].set_ylabel("Average shortest path length") + ticks_loc = ax[3].get_xticks() + ax[3].xaxis.set_major_locator(mticker.FixedLocator(ticks_loc)) + ax[3].set_xticklabels(network_list, rotation=45, ha="right") + ax[3].yaxis.grid(color="gainsboro", linestyle="dotted") + 
ax[3].set_axisbelow(True) + + fig.suptitle(f"Network properties of Funmap ({feature_type})", fontsize=16) + file_name = f"funmap_{feature_type}_network_properties.pdf" + fig.savefig(output_dir / file_name, bbox_inches="tight") + plt.close(fig) + return file_name + + +def merge_and_delete(fig_dir, file_list, output_file): + """ + Merge multiple PDF files into one and delete the original files. + + Parameters + ---------- + fig_dir : Path + The directory where the PDF files are located. + file_list : list of str + The list of file names to be merged. + output_file : str or Path + The name of the output file. + + Returns + ------- + None + + """ + pdf_writer = PyPDF2.PdfWriter() + + total_page_num = 0 + for file in file_list: + pdf_reader = PyPDF2.PdfReader(fig_dir / file) + cur_page_num = len(pdf_reader.pages) + for page in range(cur_page_num): + pdf_writer.add_page(pdf_reader.pages[page]) + pdf_writer.add_outline_item(os.path.splitext(file)[0], total_page_num) + total_page_num = total_page_num + cur_page_num + + with open(fig_dir / output_file, "wb") as fh: + pdf_writer.write(fh) + log.info("figures have been merged.") + + for filename in file_list: + try: + os.remove(fig_dir / filename) + except: + log.error(f"{filename} could not be deleted.") diff --git a/funmap/saving.py b/python/funmap/saving.py similarity index 90% rename from funmap/saving.py rename to python/funmap/saving.py index a92eac8c..fbbff116 100644 --- a/funmap/saving.py +++ b/python/funmap/saving.py @@ -3,6 +3,7 @@ LOG_DIR = "logs" + def ensure_exists(p: Path) -> Path: """ Helper to ensure a directory exists. @@ -11,14 +12,16 @@ def ensure_exists(p: Path) -> Path: p.mkdir(parents=True, exist_ok=True) return p + def arch_path(config_file) -> Path: """ Construct a path based on the name of a configuration file eg. 'saved/EfficientNet' """ config = yaml.safe_load(open(config_file)) - p = Path(config['results_dir']) / config['name'] + p = Path(config["results_dir"]) / config["name"] return ensure_exists(p) + def log_path(config_file) -> Path: p = arch_path(config_file) / LOG_DIR return ensure_exists(p) diff --git a/python/funmap/utils.py b/python/funmap/utils.py new file mode 100644 index 00000000..9a843796 --- /dev/null +++ b/python/funmap/utils.py @@ -0,0 +1,639 @@ +import csv +import hashlib +import os +import re +import shutil +import tarfile +import urllib +from pathlib import Path +from urllib.parse import urlparse +import glob +import pickle +import pandas as pd +import yaml + +from funmap.data_urls import misc_urls as urls +from funmap.logger import setup_logger + +log = setup_logger(__name__) + + +def is_url_scheme(path): + parsed = urlparse(path) + if parsed.scheme == "file": + return False + + return parsed.scheme != "" + + +def read_csv_with_md5_check( + url, expected_md5=None, local_path="downloaded_file.csv", **kwargs +): + try: + response = urllib.request.urlopen(url) + content = response.read() + + if expected_md5: + md5_hash = hashlib.md5(content).hexdigest() + if md5_hash != expected_md5: + log.error( + "gold standard file: MD5 hash mismatch, file may be corrupted." 
+ ) + raise ValueError("MD5 hash mismatch, file may be corrupted.") + + # Save the content to a local file + with open(local_path, "wb") as f: + f.write(content) + + df = pd.read_csv(local_path, **kwargs) + os.remove(local_path) + return df + except Exception as e: + return None + + +def check_gs_files_exist(file_dict, key="CC"): + if key.upper() == "CC": + paths = file_dict.get("CC") + if paths is not None and os.path.exists(paths): + return True + elif key.upper() == "MR": + paths = file_dict.get("MR") + if paths is not None and all(os.path.exists(p) for p in paths): + return True + else: + log.error(f"'{key}' is not a valid feature type (cc or mr).") + + return False + + +def normalize_filename(filename): + # Remove any characters that are not allowed in filenames + cleaned_filename = re.sub(r"[^\w\s.-]", "", filename) + # Replace spaces with underscores + cleaned_filename = cleaned_filename.replace(" ", "_") + return cleaned_filename + + +def get_data_dict(config, min_sample_count=15): + """ + Returns a dictionary of data from the provided data configuration, filtered to only include genes that are + coding and have at least `min_sample_count` samples. + + Returns + ------- + data_dict : dict + A dictionary where the keys are the names of the data files and the values are pandas DataFrames containing + the data from the corresponding file. + + """ + all_valid_ids = set() + data_dict = {} + if not config["only_extra_features"]: + data_file = config["data_path"] + if "filter_noncoding_genes" in config and config["filter_noncoding_genes"]: + mapping = pd.read_csv(urls["mapping_file"], sep="\t") + # extract the data file from the tar.gz file + tmp_dir = "tmp_data" + if not os.path.exists(tmp_dir): + os.mkdir(tmp_dir) + os.system(f"tar -xzf {data_file} --strip-components=1 -C {tmp_dir}") + # gene ids are gene symbols + for dt in config["data_files"]: + log.info(f"processing ... {dt['name']}") + cur_feature = dt["name"] + # cur_file = get_obj_from_tgz(data_file, dt['path']) + # extract the data file from the tar.gz file + # cur_data= get_obj_from_tgz(data_file, dt['path']) + cur_data = pd.read_csv( + os.path.join(tmp_dir, dt["path"]), sep="\t", index_col=0, header=0 + ) + if cur_data.shape[1] < min_sample_count: + log.info(f"... {dt['name']} ... 
not enough samples, skipped") + continue + cur_data = cur_data.T + # exclude cohort with sample number < min_sample_count + # remove noncoding genes first + if config["filter_noncoding_genes"]: + coding = mapping.loc[mapping["coding"] == "coding", ["gene_name"]] + coding_genes = list(set(coding["gene_name"].to_list())) + cur_data = cur_data[[c for c in cur_data.columns if c in coding_genes]] + # duplicated columns, for now select the last column + cur_data = cur_data.loc[:, ~cur_data.columns.duplicated(keep="last")] + data_dict[cur_feature] = cur_data + + shutil.rmtree(tmp_dir) + + log.info("filtering out non valid ids ...") + for i in data_dict: + cur_data = data_dict[i] + is_valid = cur_data.notna().sum() >= min_sample_count + # valid_count = np.sum(is_valid) + valid_p = cur_data.columns[is_valid].values + all_valid_ids = all_valid_ids.union(set(valid_p)) + if config["extra_feature_file"] is not None: + log.info("Importing extra feature file") + # TODO: Import extra feature file + extra_feature_df = pd.read_csv(config["extra_feature_file"], sep="\t") + if config["filter_noncoding_genes"]: + col1 = extra_feature_df.columns.values[0] + col2 = extra_feature_df.columns.values[1] + coding = mapping.loc[mapping["coding"] == "coding", ["gene_name"]] + coding_genes = list(set(coding["gene_name"].to_list())) + extra_feature_df = extra_feature_df[ + extra_feature_df[col1].isin(coding_genes) + & extra_feature_df[col2].isin(coding_genes) + ] + genes = set(extra_feature_df.iloc[:, :2].values.flatten().tolist()) + all_valid_ids = all_valid_ids.union(genes) + all_valid_ids = list(all_valid_ids) + all_valid_ids.sort() + log.info(f"total number of valid ids: {len(all_valid_ids)}") + + # filter out columns that are not in all_valid_ids + for i in data_dict: + cur_data = data_dict[i] + selected_columns = cur_data.columns.intersection(all_valid_ids) + cur_data = cur_data[selected_columns] + # it is possible the entire column is nan, remove it + cur_data = cur_data.dropna(axis=1, how="all") + data_dict[i] = cur_data + log.info(f"{i} -- ") + log.info(f" samples x ids: {cur_data.shape}") + + return data_dict, all_valid_ids + + +def get_node_edge(edge_list): + """ + Calculate the number of nodes and edges, and the ratio of edges per node, + and return the results in a dictionary format. + + Parameters + ---------- + edge_list : pandas DataFrame + The input DataFrame containing edge information. + + Returns + ------- + dict + A dictionary containing the number of nodes, the number of edges, + the ratio of edges per node, a list of nodes, and the edge_list. + + The keys of the dictionary are: + * n_node: int + The number of nodes in the network. + + * n_edge: int + The number of edges in the network. + + * edge_per_node: float + The ratio of edges per node in the network. + + * nodes: list + A list of nodes in the network. + + * edges: pandas DataFrame + The edge_list input DataFrame. + + """ + # remove duplidated rows in edge_list + edge_list = edge_list.drop_duplicates() + n_edge = len(edge_list) + nodes = set(edge_list.iloc[:, 0].to_list()) | set(edge_list.iloc[:, 1].to_list()) + return dict( + n_node=len(nodes), + n_edge=n_edge, + edge_per_node=n_edge / len(nodes), + nodes=list(nodes), + edges=edge_list, + ) + + +def get_node_edge_overlap(network_info): + """ + Computes the node and edge overlap between networks. + + Parameters + ---------- + network_info : pandas DataFrame + A DataFrame with information about the networks. 
+ + Returns + ------- + overlap : dict + A dictionary with the node and edge overlap between networks. + """ + networks = pd.DataFrame( + columns=["name", "type", "n_node", "n_edge", "edge_per_node", "nodes", "edges"] + ) + + for _, row in network_info.iterrows(): + network_name = row["name"] + network_type = row["type"] + network_el = row["el"] + res = get_node_edge(network_el) + cur_df = pd.DataFrame( + { + "name": [network_name], + "type": [network_type], + "n_node": [int(res["n_node"])], + "n_edge": [int(res["n_edge"])], + "edge_per_node": [res["edge_per_node"]], + "nodes": [res["nodes"]], + "edges": [res["edges"]], + } + ) + networks = pd.concat([networks, cur_df], ignore_index=True) + # overlap of nodes and edges + overlap = {} + + # node overlap + target = "FunMap" + cur_res = {} + target_node_set = set(networks.loc[networks["name"] == target, "nodes"].tolist()[0]) + target_size = len(target_node_set) + for _, row in networks.iterrows(): + if row["name"] == target: + continue + cur_node_set = set(row["nodes"]) + cur_size = len(cur_node_set) + overlap_size = len(target_node_set & cur_node_set) + cur_res[row["name"]] = tuple( + [target_size - overlap_size, cur_size - overlap_size, overlap_size] + ) + + overlap["node"] = cur_res + + # edge overlap + cur_res = {} + target_edge_df = networks.loc[networks["name"] == target, "edges"].tolist()[0] + target_edge_set = set( + tuple(sorted(x)) for x in zip(target_edge_df.pop(0), target_edge_df.pop(1)) + ) + target_size = len(target_edge_set) + + for _, row in networks.iterrows(): + if row["name"] == target: + continue + edge_df = row["edges"] + edges = [tuple(sorted(x)) for x in zip(edge_df.pop(0), edge_df.pop(1))] + cur_edge_set = set(edges) + cur_size = len(cur_edge_set) + overlap_size = len(target_edge_set & cur_edge_set) + cur_res[row["name"]] = tuple( + [target_size - overlap_size, cur_size - overlap_size, overlap_size] + ) + + overlap["edge"] = cur_res + return overlap + + +def cleanup_experiment(config_file): + cfg = get_config(config_file) + results_dir = Path(cfg["results_dir"]) + shutil.rmtree(results_dir, ignore_errors=True) + + +def setup_experiment(config_file): + cfg = get_config(config_file) + results_dir = Path(cfg["results_dir"]) + # create folders + folder_dict = cfg["subdirs"] + folders = [ + results_dir / Path(folder_dict[folder_name]) for folder_name in folder_dict + ] + for folder in folders: + folder.mkdir(parents=True, exist_ok=True) + + # save configuration to results folder + with open(str(results_dir / "config.yml"), "w") as fh: + yaml.dump(cfg, fh, sort_keys=False) + + return cfg + + +def get_config(cfg_file: Path): + cfg = { + "task": "protein_func", + "results_dir": "results", + # the following directories are relative to the results_dir + "subdirs": { + "saved_model_dir": "saved_models", + "saved_data_dir": "saved_data", + "saved_predictions_dir": "saved_predictions", + "figure_dir": "figures", + "network_dir": "networks", + }, + "seed": 42, + "feature_type": "cc", + "test_size": 0.2, + "ml_type": "xgboost", + "gs_file": None, + "extra_feature_file": None, + # 'filter_before_prediction': True, + # 'min_feature_count': 1, + "min_sample_count": 20, + "filter_noncoding_genes": False, + # 'filter_after_prediction': True, + # 'filter_criterion': 'max', + # 'filter_threshold': 0.95, + # 'filter_blacklist': False, + "n_jobs": os.cpu_count(), + "start_edge_num": 1000, + "max_num_edges": 250000, + "step_size": 1000, + "lr_cutoff": 50, + "only_extra_features": False, + "extra_feature_folder": None, + } + + with 
open(cfg_file, "r") as fh: + cfg_dict = yaml.load(fh, Loader=yaml.FullLoader) + + # the user can change the following parameters in the config file + if "task" in cfg_dict: + cfg["task"] = cfg_dict["task"] + assert cfg["task"] in ["protein_func", "kinase_func"] + + if "only_extra_features" in cfg_dict: + cfg["only_extra_features"] = cfg_dict["only_extra_features"] + if cfg["only_extra_features"] and "extra_feature_file" not in cfg_dict: + msg = "Extra feature file is not defined when only_extra_features = True" + print(msg) + raise ValueError(msg) + if "seed" in cfg_dict: + cfg["seed"] = cfg_dict["seed"] + + if "feature_type" in cfg_dict: + cfg["feature_type"] = cfg_dict["feature_type"] + assert cfg["feature_type"] in ["cc", "mr"] + + if "extra_feature_file" in cfg_dict: + cfg["extra_feature_file"] = cfg_dict["extra_feature_file"] + if "extra_feature_folder" in cfg_dict: + cfg["extra_feature_folder"] = cfg_dict["extra_feature_folder"] + + if "gs_file" in cfg_dict: + cfg["gs_file"] = cfg_dict["gs_file"] + + if "min_sample_count" in cfg_dict: + cfg["min_sample_count"] = cfg_dict["min_sample_count"] + + if "n_jobs" in cfg_dict: + cfg["n_jobs"] = cfg_dict["n_jobs"] + + if "start_edge_num" in cfg_dict: + cfg["start_edge_num"] = cfg_dict["start_edge_num"] + + if "max_num_edges" in cfg_dict: + cfg["max_num_edges"] = cfg_dict["max_num_edges"] + + if "step_size" in cfg_dict: + cfg["step_size"] = cfg_dict["step_size"] + + if "lr_cutoff" in cfg_dict: + cfg["lr_cutoff"] = cfg_dict["lr_cutoff"] + + if "name" in cfg_dict: + cfg["name"] = cfg_dict["name"] + else: + raise ValueError("name not specified in config file") + + if "data_path" in cfg_dict: + cfg["data_path"] = cfg_dict["data_path"] + else: + raise ValueError("data_path not specified in config file") + + if cfg["task"] == "protein_func": + if "filter_noncoding_genes" in cfg_dict: + cfg["filter_noncoding_genes"] = cfg_dict["filter_noncoding_genes"] + else: + # ignore filter_noncoding_genes for kinase_func + cfg["filter_noncoding_genes"] = False + log.info("ignoring filter_noncoding_genes for kinase_func") + + if "results_dir" in cfg_dict: + cfg["results_dir"] = cfg_dict["results_dir"] + "/" + cfg_dict["name"] + + if "data_files" not in cfg_dict and not cfg.get("only_extra_features", False): + raise ValueError("data_files not specified in config file") + + # Check all files listed under data_files are also in the tar.gz file + if not cfg["only_extra_features"]: + data_files = cfg_dict["data_files"] + # List all the files in the tar.gz file + with tarfile.open(cfg["data_path"], "r:gz") as tar: + tar_files = {Path(file).name for file in tar.getnames()} + + # Check if all files in data_files are in tar_files + if not all(file["path"] in tar_files for file in data_files): + print("Files listed under data_files are not in the tar.gz file!") + raise ValueError( + "Files listed under data_files are not in the tar.gz file!" + ) + + cfg["data_files"] = cfg_dict["data_files"] + + if "rp_pairs" in cfg_dict: + cfg["rp_pairs"] = cfg_dict["rp_pairs"] + + return cfg + + +def check_gold_standard_file(file_path, min_count=10000): + """ + min_count : int + The minimum threshold for the lesser of '0' and '1' counts in the 'Class' column. + + """ + try: + with open(file_path, "r", newline="") as tsv_file: + dialect = csv.Sniffer().sniff(tsv_file.read(2048)) + if dialect.delimiter != "\t": + print("Error: Incorrect TSV format. 
TSV files should be tab-separated.") + return False + except csv.Error as e: + print(f"CSV Error: {e}") + return False + + # Check data format and Class values + class_values = [] + with open(file_path, "r", newline="") as tsv_file: + reader = csv.reader(tsv_file, delimiter="\t") + next(reader) # Skip header + for row_num, row in enumerate( + reader, start=2 + ): # Add row_num for better error reporting + if len(row) != 3: # Assuming each row should have 3 columns + log.error( + f"Invalid row format in row {row_num}. Each row should have 3 columns." + ) + return False + class_value = row[2].strip() + if not class_value.isdigit() or int(class_value) not in (0, 1): + log.error(f'Invalid "Class" value in row {row_num}. Must be 0 or 1.') + return False + class_values.append(int(class_value)) + + # Check Class value counts and ratio + count_0 = class_values.count(0) + count_1 = class_values.count(1) + lesser_count = min(count_0, count_1) + + if lesser_count < min_count: + log.error( + f"The lesser of 0 and 1 occurrences ({lesser_count}) does not meet the threshold. " + f"Expected at least {min_count}." + ) + return False + + return True + + +def check_extra_feature_file(file_path, missing_value="NA"): + """ + Notes + ----- + This function checks the following criteria for the measurement TSV file: + - The file must have a header row. + - There must be at least 3 columns. + - The first two columns are gene/protein IDs. + - The data type for each additional column (excluding the first two columns) must be consistent + and can be either integer, float, or the specified missing_value. + + If any of the checks fail, the function will print informative error messages and return False. + + The TSV file should be tab-separated. + """ + try: + with open(file_path, "r", newline="") as tsv_file: + dialect = csv.Sniffer().sniff(tsv_file.read(1024)) + if dialect.delimiter != "\t": + log.error("Incorrect TSV format. TSV files should be tab-separated.") + return False + except csv.Error: + log.error("Unable to read TSV file.") + return False + + # Check header and number of columns + with open(file_path, "r", newline="") as tsv_file: + reader = csv.reader(tsv_file, delimiter="\t") + header = next(reader, None) + if header is None: + log.error("The TSV file must have a header row.") + return False + + num_columns = len(header) + if num_columns < 3: + log.error("The TSV file must have at least 3 columns.") + return False + + # Check data type consistency for additional columns + with open(file_path, "r", newline="") as tsv_file: + reader = csv.DictReader(tsv_file, delimiter="\t") + column_data_types = {} + for row_num, row in enumerate( + reader, start=2 + ): # Add row_num for better error reporting + for column_name, value in row.items(): + if ( + column_name not in header[:2] + ): # Skip the first two columns (Protein_1 and Protein_2) + if value == missing_value: + continue + try: + float_value = float(value) + except ValueError: + log.error( + "Invalid data type in row %d, column '%s'. " + "The value '%s' should be either an integer, a float, or '%s' (missing value).", + row_num, + column_name, + value, + missing_value, + ) + return False + # Classify the column value as float or integer based on its string form + data_type = float if "." in value else int + if column_name not in column_data_types: + column_data_types[column_name] = data_type + elif column_data_types[column_name] != data_type: + log.error( + "Inconsistent data type in column '%s'. 
" + "Expected a consistent data type (integer, float, or '%s') for all rows.", + column_name, + missing_value, + ) + return False + + return True + + +def read_files_to_dataframe(file_paths): + # Read each file into a DataFrame and concatenate + dataframes = [] + for file_path in file_paths: + # Read file into DataFrame; each line becomes a row in a single column + df = pd.read_csv( + file_path, + header=None, + names=[os.path.basename(file_path).replace(".col", "")], + ) + dataframes.append(df) + + # Concatenate all DataFrames into one + final_df = pd.concat(dataframes, ignore_index=True) + return final_df + + +def reorder_dataframe(final_df, index_file_path, max_length): + reordered_df = pd.DataFrame(index=range(max_length), columns=final_df.columns) + # Read the index file + with open(index_file_path, "r") as f: + indices = [int(line.strip()) for line in f.readlines()] + + # Ensure the length of the indices is less than or equal to max_length + if len(indices) > max_length: + raise ValueError("More indices provided than the max_length.") + + # Fill in the reordered DataFrame based on the provided indices + for i, index in enumerate(indices): + if index < max_length: + reordered_df.iloc[index] = final_df.iloc[i] + + return reordered_df + + +def new_extra_feature(extra_feature_folder): + index_file = list(glob.glob(f"{extra_feature_folder}/*.index"))[0] + features = list(glob.glob(f"{extra_feature_folder}/*.col")) + with open(f"{extra_feature_folder}/uniq_gene.pkl", "rb") as r: + uniq_gene = pickle.load(r) + i = j = len(uniq_gene) - 1 + n = i + 1 + max_len = i * n - i * (i - 1) // 2 + (j - i) + + df = read_files_to_dataframe(features) + log.info(df.head()) + df = reorder_dataframe(df, index_file, max_len) + return (uniq_gene, df) + + +def process_extra_feature(extra_feature_file) -> pd.DataFrame: + extra_feature_df = pd.read_csv(extra_feature_file, sep="\t") + extra_feature_df.columns.values[0] = "P1" + extra_feature_df.columns.values[1] = "P2" + extra_feature_df[["P1", "P2"]] = extra_feature_df.apply( + lambda row: sorted([row["P1"], row["P2"]]) + if row["P1"] > row["P2"] + else [row["P1"], row["P2"]], + axis=1, + result_type="expand", + ) + extra_feature_df = extra_feature_df.drop_duplicates( + subset=["P1", "P2"], keep="last" + ) + return extra_feature_df diff --git a/requirements-dev.lock b/requirements-dev.lock new file mode 100644 index 00000000..8ff4123e --- /dev/null +++ b/requirements-dev.lock @@ -0,0 +1,121 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. 
+blosc2==2.7.1 + # via tables +click==8.1.7 + # via funmap +contourpy==1.3.0 + # via matplotlib +cycler==0.12.1 + # via matplotlib +fonttools==4.54.1 + # via matplotlib +h5py==3.12.1 + # via funmap +imbalanced-learn==0.12.3 + # via funmap +joblib==1.4.2 + # via funmap + # via imbalanced-learn + # via scikit-learn +kiwisolver==1.4.7 + # via matplotlib +matplotlib==3.9.2 + # via funmap + # via matplotlib-venn + # via powerlaw + # via seaborn +matplotlib-venn==1.1.1 + # via funmap +maturin==1.7.4 +mpmath==1.3.0 + # via powerlaw +msgpack==1.1.0 + # via blosc2 +ndindex==1.9.2 + # via blosc2 +networkx==3.3 + # via funmap +numexpr==2.10.1 + # via blosc2 + # via tables +numpy==2.1.1 + # via blosc2 + # via contourpy + # via funmap + # via h5py + # via imbalanced-learn + # via matplotlib + # via matplotlib-venn + # via numexpr + # via pandas + # via powerlaw + # via pyarrow + # via scikit-learn + # via scipy + # via seaborn + # via tables + # via xgboost +packaging==24.1 + # via matplotlib + # via tables +pandas==2.2.3 + # via funmap + # via seaborn +pillow==10.4.0 + # via matplotlib +powerlaw==1.5 + # via funmap +py-cpuinfo==9.0.0 + # via blosc2 + # via tables +pyarrow==17.0.0 + # via funmap +pyparsing==3.1.4 + # via matplotlib +pypdf2==3.0.1 + # via funmap +python-dateutil==2.9.0.post0 + # via matplotlib + # via pandas +pytz==2024.2 + # via pandas +pyyaml==6.0.2 + # via funmap +scikit-learn==1.5.2 + # via funmap + # via imbalanced-learn +scipy==1.14.1 + # via funmap + # via imbalanced-learn + # via matplotlib-venn + # via powerlaw + # via scikit-learn + # via xgboost +seaborn==0.13.2 + # via funmap +six==1.16.0 + # via python-dateutil +tables==3.10.1 + # via funmap +threadpoolctl==3.5.0 + # via imbalanced-learn + # via scikit-learn +tqdm==4.66.5 + # via funmap +typing-extensions==4.12.2 + # via tables +tzdata==2024.2 + # via pandas +xgboost==2.1.1 + # via funmap diff --git a/requirements.lock b/requirements.lock new file mode 100644 index 00000000..43e3816b --- /dev/null +++ b/requirements.lock @@ -0,0 +1,120 @@ +# generated by rye +# use `rye lock` or `rye sync` to update this lockfile +# +# last locked with the following flags: +# pre: false +# features: [] +# all-features: false +# with-sources: false +# generate-hashes: false +# universal: false + +-e file:. 
+blosc2==2.7.1 + # via tables +click==8.1.7 + # via funmap +contourpy==1.3.0 + # via matplotlib +cycler==0.12.1 + # via matplotlib +fonttools==4.54.1 + # via matplotlib +h5py==3.12.1 + # via funmap +imbalanced-learn==0.12.3 + # via funmap +joblib==1.4.2 + # via funmap + # via imbalanced-learn + # via scikit-learn +kiwisolver==1.4.7 + # via matplotlib +matplotlib==3.9.2 + # via funmap + # via matplotlib-venn + # via powerlaw + # via seaborn +matplotlib-venn==1.1.1 + # via funmap +mpmath==1.3.0 + # via powerlaw +msgpack==1.1.0 + # via blosc2 +ndindex==1.9.2 + # via blosc2 +networkx==3.3 + # via funmap +numexpr==2.10.1 + # via blosc2 + # via tables +numpy==2.1.1 + # via blosc2 + # via contourpy + # via funmap + # via h5py + # via imbalanced-learn + # via matplotlib + # via matplotlib-venn + # via numexpr + # via pandas + # via powerlaw + # via pyarrow + # via scikit-learn + # via scipy + # via seaborn + # via tables + # via xgboost +packaging==24.1 + # via matplotlib + # via tables +pandas==2.2.3 + # via funmap + # via seaborn +pillow==10.4.0 + # via matplotlib +powerlaw==1.5 + # via funmap +py-cpuinfo==9.0.0 + # via blosc2 + # via tables +pyarrow==17.0.0 + # via funmap +pyparsing==3.1.4 + # via matplotlib +pypdf2==3.0.1 + # via funmap +python-dateutil==2.9.0.post0 + # via matplotlib + # via pandas +pytz==2024.2 + # via pandas +pyyaml==6.0.2 + # via funmap +scikit-learn==1.5.2 + # via funmap + # via imbalanced-learn +scipy==1.14.1 + # via funmap + # via imbalanced-learn + # via matplotlib-venn + # via powerlaw + # via scikit-learn + # via xgboost +seaborn==0.13.2 + # via funmap +six==1.16.0 + # via python-dateutil +tables==3.10.1 + # via funmap +threadpoolctl==3.5.0 + # via imbalanced-learn + # via scikit-learn +tqdm==4.66.5 + # via funmap +typing-extensions==4.12.2 + # via tables +tzdata==2024.2 + # via pandas +xgboost==2.1.1 + # via funmap diff --git a/setup.py b/setup.py index f5d50147..62b477cd 100644 --- a/setup.py +++ b/setup.py @@ -2,51 +2,51 @@ """The setup script.""" -from setuptools import setup, find_packages +from setuptools import find_packages, setup requirements = [ - 'pyyaml==6.0.1', - 'xgboost==2.0.0', - 'numpy==1.24.4', - 'scipy==1.10.1', - 'pyarrow==13.0.0', - 'pandas==2.0.3', - 'joblib==1.3.2', - 'matplotlib==3.7.3', - 'seaborn==0.13.0', - 'scikit-learn==1.3.2', - 'imbalanced-learn==0.11.0', - 'tqdm==4.66.1', - 'PyPDF2==3.0.1', - 'matplotlib_venn==0.11.9', - 'networkx==3.1', - 'powerlaw==1.5', - 'click==8.1.7', - 'h5py==3.10.0', - 'tables==3.8.0' - ] + "pyyaml==6.0.1", + "xgboost==2.0.0", + "numpy==1.24.4", + "scipy==1.10.1", + "pyarrow==13.0.0", + "pandas==2.0.3", + "joblib==1.3.2", + "matplotlib==3.7.3", + "seaborn==0.13.0", + "scikit-learn==1.3.2", + "imbalanced-learn==0.11.0", + "tqdm==4.66.1", + "PyPDF2==3.0.1", + "matplotlib_venn==0.11.9", + "networkx==3.1", + "powerlaw==1.5", + "click==8.1.7", + "h5py==3.10.0", + "tables==3.8.0", +] -test_requirements = [ ] +test_requirements = [] setup( author="Zhiao Shi", - author_email='zhiao.shi@gmail.com', - python_requires='>=3.7', + author_email="zhiao.shi@gmail.com", + python_requires=">=3.7", description="generate gene co-function networks using omics data", entry_points={ - 'console_scripts': [ - 'funmap=funmap.cli:cli', + "console_scripts": [ + "funmap=funmap.cli:cli", ], }, install_requires=requirements, - license='MIT license', + license="MIT license", include_package_data=True, - keywords=['funmap', 'bioinformatics', 'biological-network'], - name='funmap', - packages=find_packages(include=['funmap', 'funmap.*']), - 
test_suite='tests', + keywords=["funmap", "bioinformatics", "biological-network"], + name="funmap", + packages=find_packages(include=["funmap", "funmap.*"]), + test_suite="tests", + tests_require=test_requirements, - url='https://github.com/bzhanglab/funmap', - version='0.2.0', + url="https://github.com/bzhanglab/funmap", + version="0.2.0", zip_safe=False, ) diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 00000000..14090ecb --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,322 @@ +use ahash::{AHashMap, AHashSet, HashSet, HashSetExt}; +use csv::ReaderBuilder; +use pyo3::{exceptions::PyValueError, prelude::*}; +use rusqlite::{params, params_from_iter, Connection, Result}; +use serde_pickle::SerOptions; +use std::{ + fs::File, + io::{BufRead, BufReader, BufWriter, Write}, + path::Path, +}; + +/// process_files +/// +/// Supports the addition of extra features in the format of gene pairs +/// +/// Files can reach up to 10 GB, so Rust is used to speed up and optimize the merging process. +/// +/// # Step 1: Identify all unique genes +/// +/// This step identifies all the unique genes found in both the data and extra-feature files. +/// +/// This allows for a universal order to genes that will be referenced in other steps +/// +/// Input: All data files and extra feature files +/// Output: .pkl file containing the unique genes found in alphabetical order (A before B) +/// +/// # Step 2: Align extra features +/// +/// For each extra feature file, re-index the rows according to the order of genes in Step 1. +/// +/// Then save each column as a separate pkl file +/// +/// Input: unique genes from Step 1, extra feature files. +/// Output: One pkl file for each feature across all extra feature files +/// +/// Function +/// Output: list of all output feature pkl files +/// +/// TODO: Protein-coding gene filtering +#[pyfunction] +#[pyo3(signature = (expression_paths, extra_feature_paths, output_folder, valid_ids=None))] +fn process_files( + expression_paths: Vec<String>, + extra_feature_paths: Vec<String>, + output_folder: String, + valid_ids: Option<Vec<String>>, +) -> PyResult<bool> { + // Step 1: Identify all unique genes + // Across both expression and extra_feature_paths + // Create final unique_gene pkl file + + let mut uniq_gene = HashSet::new(); + + // Read expression data where first column is gene information; + for file_path in expression_paths.iter() { + let file = File::open(file_path).expect("Could not read expression data"); + let reader = BufReader::new(file); + let mut has_header = true; + for line in reader.lines() { + let line = line?; + if has_header { + // skip header + has_header = false; + let row: Vec<&str> = line.split('\t').collect(); + if row.len() < 2 { + return Err(PyValueError::new_err(format!( + "Expression data file at {} does not have enough columns.", + file_path + ))); + } + continue; + } + let row: Vec<&str> = line.split('\t').collect(); + if row.len() > 1 { + uniq_gene.insert(row[0].to_string()); // add gene to set + } + } + } + + // Read extra feature data, where first and second column are genes + for file_path in extra_feature_paths.iter() { + let file = File::open(file_path).expect("Could not read extra feature file"); + let reader = BufReader::new(file); + let mut has_header = true; + for line in reader.lines() { + let line = line?; + if has_header { + // skip header + has_header = false; + let row: Vec<&str> = line.split('\t').collect(); + if row.len() < 3 { + return Err(PyValueError::new_err(format!( + "Extra feature file at {} does not have enough columns.", + file_path + ))); + } + 
continue; + } + let row: Vec<&str> = line.split('\t').collect(); + if row.len() > 2 { + uniq_gene.insert(row[0].to_string()); // add first gene to set + uniq_gene.insert(row[1].to_string()); // add second gene to set + } + } + } + + let mut uniq_gene: Vec<String> = if let Some(valid_ids) = valid_ids { + let valid_ids = AHashSet::from_iter(valid_ids); + uniq_gene.union(&valid_ids).cloned().collect() + } else { + uniq_gene.iter().cloned().collect() + }; + + // Save to pickle + // TODO: Look at other file formats + + // Sort genes alphabetically + uniq_gene.sort(); + let folder_path = Path::new(&output_folder); + let uniq_gene_file_path = folder_path.join("uniq_gene.pkl"); + let mut w = File::create(uniq_gene_file_path).expect("Could not create uniq_gene.pkl"); + serde_pickle::to_writer(&mut w, &uniq_gene, SerOptions::default()).unwrap(); + let n = uniq_gene.len(); + // Re-align each file + // Create a HashMap to store the indices of each string + let mut gene_index_map: AHashMap<&String, usize> = AHashMap::new(); + for (index, gene) in uniq_gene.iter().enumerate() { + gene_index_map.insert(gene, index); + } + + for file_path in extra_feature_paths { + new(&file_path, &gene_index_map, n as i32, folder_path).expect("Error aligning file"); + } + // One column of indices, and one column of values. Separated by file + Ok(true) +} +// Function to parse a string into a float, returning None for invalid or NaN values +fn safe_parse_float(s: &str) -> Option<f64> { + match s.parse::<f64>() { + Ok(f) if f.is_nan() => None, // Explicitly keep NaNs as None + Ok(f) => Some(f), + Err(_) => None, // If it can't be parsed, treat as None (null) + } +} +fn new( + path: &String, + uniq_gene: &AHashMap<&String, usize>, + n: i32, + output_folder: &Path, +) -> PyResult<bool> { + // Create SQLite connection + let conn = Connection::open("db.sqlite").unwrap(); + + // Open the TSV file + let file_path = path; + let mut rdr = ReaderBuilder::new() + .delimiter(b'\t') + .from_path(file_path) + .unwrap(); + + // Get the headers to determine feature names dynamically + let headers = rdr.headers().unwrap().clone(); + + // Features are all columns after the first two columns + let feature_names: Vec<&str> = headers + .iter() + .skip(2) // Skip the first two columns + .collect(); + + // Dynamically create the SQL table with the appropriate number of features, all as FLOAT + let feature_columns: Vec<String> = feature_names + .iter() + .map(|name| format!("{} FLOAT", name)) + .collect(); + + let create_table_query = format!( + "CREATE TABLE IF NOT EXISTS gene_data ( + index_id TEXT PRIMARY KEY, + {} + )", + feature_columns.join(", ") + ); + conn.execute(&create_table_query, []).unwrap(); + + // Prepare to write features to a separate file (feature names only) + let feature_file = File::create("features.txt")?; + let mut feature_writer = BufWriter::new(feature_file); + + // Write feature names to the file + for feature in &feature_names { + writeln!(feature_writer, "{}", feature)?; + } + + // Read each record from TSV and insert it into the database + for result in rdr.records() { + let record = result.unwrap(); + let column1 = &record[0]; + let column2 = &record[1]; + let index1 = uniq_gene.get(&column1.to_string()); + let index2 = uniq_gene.get(&column2.to_string()); + if let (Some(index1), Some(index2)) = (index1, index2) { + let (i, j) = if index1 <= index2 { + (*index1 as i32, *index2 as i32) + } else { + (*index2 as i32, *index1 as i32) + }; + let index_id = (i * n - i * (i - 1) / 2 + (j - i)) as usize; + + // Extract feature values, converting them to 
Option<f64> to handle nulls/NaNs + let feature_values: Vec<Option<f64>> = record + .iter() + .skip(2) // Skip the first two columns + .map(safe_parse_float) // Handle nulls and invalid values + .collect(); + + // Insert into SQLite database using dynamic query + let insert_query = format!( + "INSERT OR REPLACE INTO gene_data (index_id, {}) + VALUES (?1, {})", + feature_names.join(", "), + feature_values + .iter() + .enumerate() + .map(|(i, _)| format!("?{}", i + 2)) + .collect::<Vec<String>>() + .join(", ") + ); + + // Collect the parameters for the query, using None for null/NaN values + let mut params_vec: Vec<&(dyn rusqlite::ToSql + Sync)> = vec![&index_id]; + for val in &feature_values { + match val { + Some(v) => params_vec.push(v), + None => params_vec.push(&rusqlite::types::Null), + } + } + + conn.execute(&insert_query, params_from_iter(params_vec.iter())) + .unwrap(); + } + } + Ok(true) +} + +fn align_file( + path: &String, + uniq_gene: &AHashMap<&String, usize>, + n: i32, + output_folder: &Path, +) -> PyResult<bool> { + // Read extra feature data, where first and second column are genes + let file = File::open(path)?; + let reader = BufReader::new(file); + let mut has_header = true; + let mut indices: Vec<usize> = Vec::new(); + let mut feature_count = 0; + let mut writers = Vec::new(); + let original_file = Path::new(path).file_name().unwrap().to_str().unwrap(); + let index_file_path = output_folder.join(format!("{}.index", original_file)); + let f = File::create(index_file_path)?; + let bf = BufWriter::new(f); + writers.push(bf); + + for line in reader.lines() { + let line = line?; + if has_header { + // skip header + has_header = false; + let row: Vec<&str> = line.split('\t').collect(); + if row.len() < 3 { + return Err(PyValueError::new_err(format!( + "Extra feature file at {} does not have enough columns.", + path + ))); + } + feature_count = row.len() - 2; + for i in 0..feature_count { + let file_path = output_folder.join(format!("{}.col", row[i + 2])); + let f = File::create(file_path)?; + let bf = BufWriter::new(f); + writers.push(bf); + } + continue; + } + let row: Vec<&str> = line.split('\t').collect(); + if row.len() > 2 { + let index1 = uniq_gene.get(&row[0].to_string()); + let index2 = uniq_gene.get(&row[1].to_string()); + if let (Some(index1), Some(index2)) = (index1, index2) { + let (i, j) = if index1 <= index2 { + (*index1 as i32, *index2 as i32) + } else { + (*index2 as i32, *index1 as i32) + }; + let new_index = (i * n - i * (i - 1) / 2 + (j - i)) as usize; + indices.push(new_index); + writers[0].write_all(new_index.to_string().as_bytes())?; + writers[0].write_all(b"\n")?; + for i in 0..feature_count { + writers[i + 1].write_all(row[i + 2].as_bytes())?; + writers[i + 1].write_all(b"\n")?; + } + } + } + } + Ok(true) +} + +/// Formats the sum of two numbers as string. +#[pyfunction] +fn sum_as_string(a: usize, b: usize) -> PyResult<String> { + Ok((a + b).to_string()) +} + +/// A Python module implemented in Rust. +#[pymodule] +#[pyo3(name = "_lib")] // module name is _lib (imported as funmap._lib); the leading underscore keeps these functions out of the public API +fn funmap_lib(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_function(wrap_pyfunction!(sum_as_string, m)?)?; + m.add_function(wrap_pyfunction!(process_files, m)?)?; + Ok(()) +}
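+
+// A small self-check (sketch) of the pair-to-index mapping used by `new` and `align_file`
+// above: for n genes in sorted order and a pair (i, j) with i <= j, the closed form
+// i * n - i * (i - 1) / 2 + (j - i) enumerates the n * (n + 1) / 2 unordered pairs
+// (self-pairs included) as consecutive indices starting at 0. The module and test names
+// below are illustrative only; they restate the formula rather than define an API.
+#[cfg(test)]
+mod pair_index_tests {
+    #[test]
+    fn pair_index_enumerates_upper_triangle() {
+        let n: i32 = 4;
+        let idx = |i: i32, j: i32| i * n - i * (i - 1) / 2 + (j - i);
+        // Walking the upper triangle row by row yields 0, 1, 2, ... without gaps.
+        let mut expected = 0;
+        for i in 0..n {
+            for j in i..n {
+                assert_eq!(idx(i, j), expected);
+                expected += 1;
+            }
+        }
+        // The last pair (n - 1, n - 1) maps to n * (n + 1) / 2 - 1.
+        assert_eq!(idx(n - 1, n - 1), n * (n + 1) / 2 - 1);
+    }
+}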