From 2519852b62fbb87d9724287c32623811b459c4d5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 3 Apr 2025 16:27:13 +0200 Subject: [PATCH 01/36] add protein-related code from https://github.com/ChEB-AI/python-chebai --- .gitattributes | 10 + .github/workflows/black.yml | 10 + .github/workflows/export_constants.py | 22 + .github/workflows/test.yml | 38 + .github/workflows/token_consistency.yaml | 110 + .github/workflows/verify_constants.yml | 116 + .gitignore | 169 + .pre-commit-config.yaml | 25 + LICENSE | 661 ++ chebai/__init__.py | 30 + chebai/__main__.py | 10 + chebai/callbacks.py | 86 + chebai/callbacks/__init__.py | 0 chebai/callbacks/epoch_metrics.py | 180 + chebai/callbacks/model_checkpoint.py | 95 + chebai/callbacks/prediction_callback.py | 55 + chebai/cli.py | 75 + chebai/loggers/__init__.py | 0 chebai/loggers/custom.py | 127 + chebai/loss/__init__.py | 0 chebai/loss/bce_weighted.py | 98 + chebai/loss/mixed.py | 40 + chebai/loss/pretraining.py | 48 + chebai/loss/semantic.py | 532 ++ chebai/models/__init__.py | 2 + chebai/models/base.py | 372 + chebai/models/chemberta.py | 77 + chebai/models/chemyk.py | 63 + chebai/models/electra.py | 535 ++ chebai/models/external/__init__.py | 0 chebai/models/ffn.py | 153 + chebai/models/lnn_model.py | 40 + chebai/models/lstm.py | 34 + chebai/models/recursive.py | 97 + chebai/models/strontex.py | 14 + chebai/preprocessing/__init__.py | 0 .../bin/protein_token/tokens.txt | 21 + .../bin/protein_token_3_gram/tokens.txt | 8359 +++++++++++++++++ chebai/preprocessing/collate.py | 137 + chebai/preprocessing/collect_all.py | 226 + chebai/preprocessing/datasets/__init__.py | 4 + chebai/preprocessing/datasets/base.py | 1184 +++ .../preprocessing/datasets/deepGO/__init__.py | 0 .../datasets/deepGO/go_uniprot.py | 1007 ++ .../datasets/deepGO/protein_pretraining.py | 279 + .../preprocessing/datasets/scope/__init__.py | 0 chebai/preprocessing/datasets/scope/scope.py | 972 ++ chebai/preprocessing/migration/__init__.py | 0 .../migration/deep_go/__init__.py | 0 .../deep_go/migrate_deep_go_1_data.py | 316 + .../deep_go/migrate_deep_go_2_data.py | 366 + chebai/preprocessing/reader.py | 514 + chebai/preprocessing/structures.py | 141 + chebai/result/__init__.py | 0 chebai/result/analyse_sem.py | 721 ++ chebai/result/base.py | 105 + chebai/result/classification.py | 105 + chebai/result/evaluate_predictions.py | 108 + chebai/result/molplot.py | 506 + chebai/result/prediction_json.py | 26 + chebai/result/pretraining.py | 65 + chebai/result/utils.py | 235 + chebai/trainer/CustomTrainer.py | 149 + chebai/trainer/__init__.py | 0 configs/data/deepGO/deepgo2_esm2.yml | 5 + .../data/deepGO/deepgo_1_migrated_data.yml | 4 + .../data/deepGO/deepgo_2_migrated_data.yml | 5 + configs/data/deepGO/go250.yml | 3 + configs/data/deepGO/go50.yml | 1 + configs/data/scope/scope2000.yml | 3 + configs/data/scope/scope50.yml | 3 + configs/default_prediction_callback.yml | 4 + configs/loss/bce.yml | 1 + configs/loss/electra_pre_loss.yml | 1 + configs/loss/semantic_loss.yml | 10 + configs/metrics/balanced-accuracy.yml | 5 + configs/metrics/micro-macro-f1.yml | 9 + configs/metrics/single-class-f1.yml | 5 + configs/model/electra-for-pretraining.yml | 20 + configs/model/electra.yml | 11 + configs/model/electra_pretraining.yml | 18 + configs/model/ffn.yml | 5 + configs/training/csv_logger.yml | 3 + configs/training/default_callbacks.yml | 12 + configs/training/default_trainer.yml | 5 + configs/training/early_stop_callbacks.yml | 19 + configs/training/pretraining_callbacks.yml | 12 + 
configs/training/pretraining_trainer.yml | 7 + configs/training/single_class_callbacks.yml | 13 + configs/training/wandb_logger.yml | 6 + docs/source/experiment.rst | 1 + docs/source/model.rst | 1 + setup.cfg | 7 + setup.py | 57 + tests/__init__.py | 0 tests/unit/__init__.py | 4 + tests/unit/collators/__init__.py | 0 tests/unit/collators/testDefaultCollator.py | 65 + tests/unit/collators/testRaggedCollator.py | 204 + tests/unit/dataset_classes/__init__.py | 0 .../dataset_classes/testDynamicDataset.py | 372 + .../testGOUniProDataExtractor.py | 229 + .../dataset_classes/testGoUniProtOverX.py | 140 + .../testProteinPretrainingData.py | 76 + .../dataset_classes/testXYBaseDataModule.py | 92 + tests/unit/mock_data/__init__.py | 0 tests/unit/mock_data/ontology_mock_data.py | 521 + tests/unit/readers/__init__.py | 0 tests/unit/readers/testDataReader.py | 56 + tests/unit/readers/testProteinDataReader.py | 139 + tutorials/data_exploration_go.ipynb | 1341 +++ tutorials/data_exploration_scope.ipynb | 1182 +++ 112 files changed, 24147 insertions(+) create mode 100644 .gitattributes create mode 100644 .github/workflows/black.yml create mode 100644 .github/workflows/export_constants.py create mode 100644 .github/workflows/test.yml create mode 100644 .github/workflows/token_consistency.yaml create mode 100644 .github/workflows/verify_constants.yml create mode 100644 .gitignore create mode 100644 .pre-commit-config.yaml create mode 100644 LICENSE create mode 100644 chebai/__init__.py create mode 100644 chebai/__main__.py create mode 100644 chebai/callbacks.py create mode 100644 chebai/callbacks/__init__.py create mode 100644 chebai/callbacks/epoch_metrics.py create mode 100644 chebai/callbacks/model_checkpoint.py create mode 100644 chebai/callbacks/prediction_callback.py create mode 100644 chebai/cli.py create mode 100644 chebai/loggers/__init__.py create mode 100644 chebai/loggers/custom.py create mode 100644 chebai/loss/__init__.py create mode 100644 chebai/loss/bce_weighted.py create mode 100644 chebai/loss/mixed.py create mode 100644 chebai/loss/pretraining.py create mode 100644 chebai/loss/semantic.py create mode 100644 chebai/models/__init__.py create mode 100644 chebai/models/base.py create mode 100644 chebai/models/chemberta.py create mode 100644 chebai/models/chemyk.py create mode 100644 chebai/models/electra.py create mode 100644 chebai/models/external/__init__.py create mode 100644 chebai/models/ffn.py create mode 100644 chebai/models/lnn_model.py create mode 100644 chebai/models/lstm.py create mode 100644 chebai/models/recursive.py create mode 100644 chebai/models/strontex.py create mode 100644 chebai/preprocessing/__init__.py create mode 100644 chebai/preprocessing/bin/protein_token/tokens.txt create mode 100644 chebai/preprocessing/bin/protein_token_3_gram/tokens.txt create mode 100644 chebai/preprocessing/collate.py create mode 100644 chebai/preprocessing/collect_all.py create mode 100644 chebai/preprocessing/datasets/__init__.py create mode 100644 chebai/preprocessing/datasets/base.py create mode 100644 chebai/preprocessing/datasets/deepGO/__init__.py create mode 100644 chebai/preprocessing/datasets/deepGO/go_uniprot.py create mode 100644 chebai/preprocessing/datasets/deepGO/protein_pretraining.py create mode 100644 chebai/preprocessing/datasets/scope/__init__.py create mode 100644 chebai/preprocessing/datasets/scope/scope.py create mode 100644 chebai/preprocessing/migration/__init__.py create mode 100644 chebai/preprocessing/migration/deep_go/__init__.py create mode 100644 
chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py create mode 100644 chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py create mode 100644 chebai/preprocessing/reader.py create mode 100644 chebai/preprocessing/structures.py create mode 100644 chebai/result/__init__.py create mode 100644 chebai/result/analyse_sem.py create mode 100644 chebai/result/base.py create mode 100644 chebai/result/classification.py create mode 100644 chebai/result/evaluate_predictions.py create mode 100644 chebai/result/molplot.py create mode 100644 chebai/result/prediction_json.py create mode 100644 chebai/result/pretraining.py create mode 100644 chebai/result/utils.py create mode 100644 chebai/trainer/CustomTrainer.py create mode 100644 chebai/trainer/__init__.py create mode 100644 configs/data/deepGO/deepgo2_esm2.yml create mode 100644 configs/data/deepGO/deepgo_1_migrated_data.yml create mode 100644 configs/data/deepGO/deepgo_2_migrated_data.yml create mode 100644 configs/data/deepGO/go250.yml create mode 100644 configs/data/deepGO/go50.yml create mode 100644 configs/data/scope/scope2000.yml create mode 100644 configs/data/scope/scope50.yml create mode 100644 configs/default_prediction_callback.yml create mode 100644 configs/loss/bce.yml create mode 100644 configs/loss/electra_pre_loss.yml create mode 100644 configs/loss/semantic_loss.yml create mode 100644 configs/metrics/balanced-accuracy.yml create mode 100644 configs/metrics/micro-macro-f1.yml create mode 100644 configs/metrics/single-class-f1.yml create mode 100644 configs/model/electra-for-pretraining.yml create mode 100644 configs/model/electra.yml create mode 100644 configs/model/electra_pretraining.yml create mode 100644 configs/model/ffn.yml create mode 100644 configs/training/csv_logger.yml create mode 100644 configs/training/default_callbacks.yml create mode 100644 configs/training/default_trainer.yml create mode 100644 configs/training/early_stop_callbacks.yml create mode 100644 configs/training/pretraining_callbacks.yml create mode 100644 configs/training/pretraining_trainer.yml create mode 100644 configs/training/single_class_callbacks.yml create mode 100644 configs/training/wandb_logger.yml create mode 100644 docs/source/experiment.rst create mode 100644 docs/source/model.rst create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/__init__.py create mode 100644 tests/unit/__init__.py create mode 100644 tests/unit/collators/__init__.py create mode 100644 tests/unit/collators/testDefaultCollator.py create mode 100644 tests/unit/collators/testRaggedCollator.py create mode 100644 tests/unit/dataset_classes/__init__.py create mode 100644 tests/unit/dataset_classes/testDynamicDataset.py create mode 100644 tests/unit/dataset_classes/testGOUniProDataExtractor.py create mode 100644 tests/unit/dataset_classes/testGoUniProtOverX.py create mode 100644 tests/unit/dataset_classes/testProteinPretrainingData.py create mode 100644 tests/unit/dataset_classes/testXYBaseDataModule.py create mode 100644 tests/unit/mock_data/__init__.py create mode 100644 tests/unit/mock_data/ontology_mock_data.py create mode 100644 tests/unit/readers/__init__.py create mode 100644 tests/unit/readers/testDataReader.py create mode 100644 tests/unit/readers/testProteinDataReader.py create mode 100644 tutorials/data_exploration_go.ipynb create mode 100644 tutorials/data_exploration_scope.ipynb diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..a8cf84d --- /dev/null +++ b/.gitattributes @@ -0,0 +1,10 @@ 
+# Jupyter notebooks contain images, tables, and parsed text, +# which blows up the total language fraction unrealistically; +# Jupyter notebooks then suddenly appear to be a major part of the repo's languages. + +# As GitHub Linguist will not parse notebooks more accurately +# (won't-fix: https://github.com/github/linguist/issues/3496), +# simply exclude these files from the language statistics: + +notebooks/*.ipynb linguist-generated=true +stream_viz/tutorial/*.ipynb linguist-generated=true diff --git a/.github/workflows/black.yml b/.github/workflows/black.yml new file mode 100644 index 0000000..b04fb15 --- /dev/null +++ b/.github/workflows/black.yml @@ -0,0 +1,10 @@ +name: Lint + +on: [push, pull_request] + +jobs: + lint: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - uses: psf/black@stable diff --git a/.github/workflows/export_constants.py b/.github/workflows/export_constants.py new file mode 100644 index 0000000..6421498 --- /dev/null +++ b/.github/workflows/export_constants.py @@ -0,0 +1,22 @@ +import json + +from chebai.preprocessing.reader import ( + CLS_TOKEN, + EMBEDDING_OFFSET, + MASK_TOKEN_INDEX, + PADDING_TOKEN_INDEX, +) + +# Define the constants you want to export +# Any change to the key names here must also be mirrored in the verify_constants.yml code +constants = { + "EMBEDDING_OFFSET": EMBEDDING_OFFSET, + "CLS_TOKEN": CLS_TOKEN, + "PADDING_TOKEN_INDEX": PADDING_TOKEN_INDEX, + "MASK_TOKEN_INDEX": MASK_TOKEN_INDEX, +} + +if __name__ == "__main__": + # Write constants to a JSON file + with open("constants.json", "w") as f: + json.dump(constants, f) diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 0000000..0ad2115 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,38 @@ +name: Unittests + +on: [pull_request] + +jobs: + build: + + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools wheel + python -m pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cpu + python -m pip install -e . + + - name: Display Python & Installed Packages + run: | + python --version + pip freeze + + - name: Run Unit Tests + run: python -m unittest discover -s tests/unit -v + env: + ACTIONS_STEP_DEBUG: true # Enable debug logs + ACTIONS_RUNNER_DEBUG: true # Additional debug logs from GitHub Actions itself diff --git a/.github/workflows/token_consistency.yaml b/.github/workflows/token_consistency.yaml new file mode 100644 index 0000000..df9d8b6 --- /dev/null +++ b/.github/workflows/token_consistency.yaml @@ -0,0 +1,110 @@ +name: Check consistency of tokens.txt file + +# Define the file paths under `paths` to trigger this check only when specific files are modified. +# This script will then execute checks only on files that have changed, rather than all files listed in `paths`. 
+ +# **Note**: To add a new token file for checks, include its path in: +# - `on` -> `push` and `pull_request` sections +# - `jobs` -> `check_tokens` -> `steps` -> Set global variable for multiple tokens.txt paths -> `TOKENS_FILES` + +on: + push: + paths: + - "chebai/preprocessing/bin/protein_token/tokens.txt" + - "chebai/preprocessing/bin/protein_token_3_gram/tokens.txt" + pull_request: + paths: + - "chebai/preprocessing/bin/protein_token/tokens.txt" + - "chebai/preprocessing/bin/protein_token_3_gram/tokens.txt" + +jobs: + check_tokens: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v2 + + - name: Get list of changed files + id: changed_files + run: | + git fetch origin dev + + # Get the list of changed files compared to origin/dev and save them to a file + git diff --name-only origin/dev > changed_files.txt + + # Print the names of changed files on separate lines + echo "Changed files:" + while read -r line; do + echo "Changed File name : $line" + done < changed_files.txt + + - name: Set global variable for multiple tokens.txt paths + run: | + # All token files that need to be checked must be included here too, same as in `paths`. + TOKENS_FILES=( + "chebai/preprocessing/bin/protein_token/tokens.txt" + "chebai/preprocessing/bin/protein_token_3_gram/tokens.txt" + ) + echo "TOKENS_FILES=${TOKENS_FILES[*]}" >> $GITHUB_ENV + + - name: Process only changed tokens.txt files + run: | + # Convert the TOKENS_FILES environment variable into an array + TOKENS_FILES=(${TOKENS_FILES}) + + # Iterate over each token file path + for TOKENS_FILE_PATH in "${TOKENS_FILES[@]}"; do + # Check if the current token file path is in the list of changed files + if grep -q "$TOKENS_FILE_PATH" changed_files.txt; then + echo "----------------------- Processing $TOKENS_FILE_PATH -----------------------" + + # Get previous tokens.txt version + git fetch origin dev + git diff origin/dev -- $TOKENS_FILE_PATH > tokens_diff.txt || echo "No previous tokens.txt found for $TOKENS_FILE_PATH" + + # Check for deleted or added lines in tokens.txt + if [ -f tokens_diff.txt ]; then + + # Check for deleted lines (lines starting with '-') + deleted_lines=$(grep '^-' tokens_diff.txt | grep -v '^---' | sed 's/^-//' || true) + if [ -n "$deleted_lines" ]; then + echo "Error: Lines have been deleted from $TOKENS_FILE_PATH." + echo -e "Deleted Lines: \n$deleted_lines" + exit 1 + fi + + # Check for added lines (lines starting with '+') + added_lines=$(grep '^+' tokens_diff.txt | grep -v '^+++' | sed 's/^+//' || true) + if [ -n "$added_lines" ]; then + + # Count how many lines have been added + num_added_lines=$(echo "$added_lines" | wc -l) + + # Get last `n` lines (equal to num_added_lines) of tokens.txt + last_lines=$(tail -n "$num_added_lines" $TOKENS_FILE_PATH) + + # Check if the added lines are at the end of the file + if [ "$added_lines" != "$last_lines" ]; then + + # Find lines that were added but not appended at the end of the file + non_appended_lines=$(diff <(echo "$added_lines") <(echo "$last_lines") | grep '^<' | sed 's/^< //') + + echo "Error: New lines have been added to $TOKENS_FILE_PATH, but they are not at the end of the file." + echo -e "Added lines that are not at the end of the file: \n$non_appended_lines" + exit 1 + fi + fi + + if [ "$added_lines" == "" ]; then + echo "$TOKENS_FILE_PATH validation successful: No lines were deleted, and no new lines were added." 
+ else + echo "$TOKENS_FILE_PATH validation successful: No lines were deleted, and new lines were correctly appended at the end." + fi + else + echo "No previous version of $TOKENS_FILE_PATH found." + fi + else + echo "$TOKENS_FILE_PATH was not changed, skipping." + fi + done diff --git a/.github/workflows/verify_constants.yml b/.github/workflows/verify_constants.yml new file mode 100644 index 0000000..3246f64 --- /dev/null +++ b/.github/workflows/verify_constants.yml @@ -0,0 +1,116 @@ +name: Verify Constants + +# Define the file paths under `paths` to trigger this check only when specific files are modified. +# This script will then execute checks only on files that have changed, rather than all files listed in `paths`. + +# **Note**: To add a new file for checks, include its path in: +# - `on` -> `push` and `pull_request` sections +# - `jobs` -> `verify-constants` -> `steps` -> Verify constants -> add a new if/else branch for your file, with the check logic inside it. + + +on: + push: + paths: + - "chebai/preprocessing/reader.py" + pull_request: + paths: + - "chebai/preprocessing/reader.py" + +jobs: + verify-constants: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: [ +# Only use 3.10 as of now +# "3.9", + "3.10", +# "3.11" + ] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + + - name: Set PYTHONPATH + run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV + + - name: Get list of changed files + id: changed_files + run: | + git fetch origin dev + + # Get the list of changed files compared to origin/dev and save them to a file + git diff --name-only origin/dev > changed_files.txt + + # Print the names of changed files on separate lines + echo "Changed files:" + while read -r line; do + echo "Changed File name : $line" + done < changed_files.txt + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + # Setting a fixed version for torch due to an error with the latest version (2.5.1): + # ImportError: cannot import name 'T_co' from 'torch.utils.data.dataset' + run: | + python -m pip install --upgrade pip + python -m pip install --upgrade pip setuptools wheel + python -m pip install torch==2.4.1 --index-url https://download.pytorch.org/whl/cpu + python -m pip install -e . 
+ + - name: Export constants + run: python .github/workflows/export_constants.py + + - name: Load constants into environment variables + id: load_constants + # "E_" is prefixed to every constant, to protect against overwriting other system env variables with the same name + run: | + constants=$(cat constants.json) + echo "$constants" | jq -r 'to_entries|map("E_\(.key)=\(.value|tostring)")|.[]' >> $GITHUB_ENV + + - name: Print all environment variables + run: printenv + + - name: Verify constants + run: | + file_name="chebai/preprocessing/reader.py" + if grep -q "$file_name" changed_files.txt; then + echo "----------------------- Checking file : $file_name ----------------------- " + + # Define expected values for constants + exp_embedding_offset="10" + exp_cls_token="2" + exp_padding_token_index="0" + exp_mask_token_index="1" + + # Debugging output to check environment variables + echo "Current Environment Variables:" + echo "E_EMBEDDING_OFFSET = $E_EMBEDDING_OFFSET" + echo "Expected: $exp_embedding_offset" + + # Verify constants match expected values + if [ "$E_EMBEDDING_OFFSET" != "$exp_embedding_offset" ]; then + echo "EMBEDDING_OFFSET ($E_EMBEDDING_OFFSET) does not match expected value ($exp_embedding_offset)!" + exit 1 + fi + if [ "$E_CLS_TOKEN" != "$exp_cls_token" ]; then + echo "CLS_TOKEN ($E_CLS_TOKEN) does not match expected value ($exp_cls_token)!" + exit 1 + fi + if [ "$E_PADDING_TOKEN_INDEX" != "$exp_padding_token_index" ]; then + echo "PADDING_TOKEN_INDEX ($E_PADDING_TOKEN_INDEX) does not match expected value ($exp_padding_token_index)!" + exit 1 + fi + if [ "$E_MASK_TOKEN_INDEX" != "$exp_mask_token_index" ]; then + echo "MASK_TOKEN_INDEX ($E_MASK_TOKEN_INDEX) does not match expected value ($exp_mask_token_index)!" + exit 1 + fi + else + echo "$file_name not found in changed_files.txt; skipping check." + fi diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..f9cb175 --- /dev/null +++ b/.gitignore @@ -0,0 +1,169 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ +docs/build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +#.idea/ + +# configs/ # commented as new configs can be added as a part of a feature + +/.idea +/data +/logs +/results_buffer +electra_pretrained.ckpt diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000..108b91d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,25 @@ +repos: +- repo: https://github.com/psf/black + rev: "24.2.0" + hooks: + - id: black + - id: black-jupyter # for formatting jupyter-notebook + +- repo: https://github.com/pycqa/isort + rev: 5.13.2 + hooks: + - id: isort + name: isort (python) + args: ["--profile=black"] + +- repo: https://github.com/asottile/seed-isort-config + rev: v2.2.0 + hooks: + - id: seed-isort-config + +- repo: https://github.com/pre-commit/pre-commit-hooks + rev: v4.6.0 + hooks: + - id: check-yaml + - id: end-of-file-fixer + - id: trailing-whitespace diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000..0ad25db --- /dev/null +++ b/LICENSE @@ -0,0 +1,661 @@ + GNU AFFERO GENERAL PUBLIC LICENSE + Version 3, 19 November 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU Affero General Public License is a free, copyleft license for +software and other kinds of works, specifically designed to ensure +cooperation with the community in the case of network server software. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. 
By contrast, +our General Public Licenses are intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + Developers that use our General Public Licenses protect your rights +with two steps: (1) assert copyright on the software, and (2) offer +you this License which gives you legal permission to copy, distribute +and/or modify the software. + + A secondary benefit of defending all users' freedom is that +improvements made in alternate versions of the program, if they +receive widespread use, become available for other developers to +incorporate. Many developers of free software are heartened and +encouraged by the resulting cooperation. However, in the case of +software used on network servers, this result may fail to come about. +The GNU General Public License permits making a modified version and +letting the public access it on a server without ever releasing its +source code to the public. + + The GNU Affero General Public License is designed specifically to +ensure that, in such cases, the modified source code becomes available +to the community. It requires the operator of a network server to +provide the source code of the modified version running there to the +users of that server. Therefore, public use of a modified version, on +a publicly accessible server, gives the public access to the source +code of the modified version. + + An older license, called the Affero General Public License and +published by Affero, was designed to accomplish similar goals. This is +a different license, not a version of the Affero GPL, but Affero has +released a new version of the Affero GPL which permits relicensing under +this license. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU Affero General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. 
Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. 
You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. 
+ + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. 
In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. 
+ + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. 
If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). 
To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Remote Network Interaction; Use with the GNU General Public License. 
+ + Notwithstanding any other provision of this License, if you modify the +Program, your modified version must prominently offer all users +interacting with it remotely through a computer network (if your version +supports such interaction) an opportunity to receive the Corresponding +Source of your version by providing access to the Corresponding Source +from a network server at no charge, through some standard or customary +means of facilitating copying of software. This Corresponding Source +shall include the Corresponding Source for any work covered by version 3 +of the GNU General Public License that is incorporated pursuant to the +following paragraph. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the work with which it is combined will remain governed by version +3 of the GNU General Public License. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU Affero General Public License from time to time. Such new versions +will be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU Affero General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU Affero General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU Affero General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. 
+ + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU Affero General Public License as published + by the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Affero General Public License for more details. + + You should have received a copy of the GNU Affero General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If your software can interact with users remotely through a computer +network, you should also make sure that it provides a way for users to +get its source. For example, if your program is a web application, its +interface could display a "Source" link that leads users to an archive +of the code. There are many ways you could offer source, and different +solutions will be better for different programs; see section 13 for the +specific requirements. + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU AGPL, see +. diff --git a/chebai/__init__.py b/chebai/__init__.py new file mode 100644 index 0000000..9f508aa --- /dev/null +++ b/chebai/__init__.py @@ -0,0 +1,30 @@ +import os +from typing import Any + +import torch + +# Get the absolute path of the current file's directory +MODULE_PATH = os.path.abspath(os.path.dirname(__file__)) + + +class CustomTensor(torch.Tensor): + """ + A custom tensor class inheriting from `torch.Tensor`. + + This class allows for the creation of tensors using the provided data. 
+ + Attributes: + data (Any): The data to be converted into a tensor. + """ + + def __new__(cls, data: Any) -> "CustomTensor": + """ + Creates a new instance of CustomTensor. + + Args: + data (Any): The data to be converted into a tensor. + + Returns: + CustomTensor: A tensor containing the provided data. + """ + return torch.tensor(data) diff --git a/chebai/__main__.py b/chebai/__main__.py new file mode 100644 index 0000000..0afee8e --- /dev/null +++ b/chebai/__main__.py @@ -0,0 +1,10 @@ +from chebai.cli import cli + +if __name__ == "__main__": + """ + Entry point for the CLI application. + + This script calls the `cli` function from the `chebai.cli` module + when executed as the main program. + """ + cli() diff --git a/chebai/callbacks.py b/chebai/callbacks.py new file mode 100644 index 0000000..764db44 --- /dev/null +++ b/chebai/callbacks.py @@ -0,0 +1,86 @@ +import json +import os +from typing import Any, Dict, List, Literal, Union + +import torch +from lightning.pytorch.callbacks import BasePredictionWriter + + +class ChebaiPredictionWriter(BasePredictionWriter): + """ + A custom prediction writer for saving batch and epoch predictions during model training. + + This class inherits from `BasePredictionWriter` and is designed to save predictions + in a specified output directory at specified intervals. + + Args: + output_dir (str): The directory where predictions will be saved. + write_interval (str): The interval at which predictions will be written. + target_file (str): The name of the file where epoch predictions will be saved (default: "predictions.json"). + """ + + def __init__( + self, + output_dir: str, + write_interval: Literal["batch", "epoch", "batch_and_epoch"], + target_file: str = "predictions.json", + ) -> None: + super().__init__(write_interval) + self.output_dir = output_dir + self.target_file = target_file + + def write_on_batch_end( + self, + trainer: Any, + pl_module: Any, + prediction: Union[torch.Tensor, List[torch.Tensor]], + batch_indices: List[int], + batch: Any, + batch_idx: int, + dataloader_idx: int, + ) -> None: + """ + Saves batch predictions at the end of each batch. + + Args: + trainer (Any): The trainer instance. + pl_module (Any): The LightningModule instance. + prediction (Union[torch.Tensor, List[torch.Tensor]]): The prediction output from the model. + batch_indices (List[int]): The indices of the batch. + batch (Any): The current batch. + batch_idx (int): The index of the batch. + dataloader_idx (int): The index of the dataloader. + """ + outpath = os.path.join(self.output_dir, str(dataloader_idx), f"{batch_idx}.pt") + os.makedirs(os.path.dirname(outpath), exist_ok=True) + torch.save(prediction, outpath) + + def write_on_epoch_end( + self, + trainer: Any, + pl_module: Any, + predictions: List[Dict[str, Any]], + batch_indices: List[int], + ) -> None: + """ + Saves all predictions at the end of each epoch in a JSON file. + + Args: + trainer (Any): The trainer instance. + pl_module (Any): The LightningModule instance. + predictions (List[Dict[str, Any]]): The list of prediction outputs from the model. + batch_indices (List[int]): The indices of the batches. 
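+
+ Example (illustrative sketch; identifiers, labels and logits are made up, only
+ the nesting of keys mirrors the loop below):
+
+ >>> import torch
+ >>> predictions = [{
+ ...     "data": {"idents": [101, 102], "labels": torch.tensor([[1, 0], [0, 1]])},
+ ...     "output": {"logits": torch.tensor([[2.0, -1.0], [-0.5, 1.5]])},
+ ... }]
+
+ Each such entry yields one JSON record per identifier, e.g.
+ {"ident": 101, "labels": [1, 0], "predictions": [0.88, 0.27]} (sigmoids, rounded).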
+ """ + pred_list = [] + for p in predictions: + idents = p["data"]["idents"] + labels = p["data"]["labels"] + if labels is not None: + labels = labels.tolist() + else: + labels = [None for _ in idents] + output = torch.sigmoid(p["output"]["logits"]).tolist() + for i, l, o in zip(idents, labels, output): + pred_list.append(dict(ident=i, labels=l, predictions=o)) + with open(os.path.join(self.output_dir, self.target_file), "wt") as fout: + json.dump(pred_list, fout) diff --git a/chebai/callbacks/__init__.py b/chebai/callbacks/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py new file mode 100644 index 0000000..c1cf7bd --- /dev/null +++ b/chebai/callbacks/epoch_metrics.py @@ -0,0 +1,180 @@ +import torch +import torchmetrics + + +def custom_reduce_fx(input: torch.Tensor) -> torch.Tensor: + """ + Custom reduction function for distributed training. + + Args: + input (torch.Tensor): The input tensor to be reduced. + + Returns: + torch.Tensor: The reduced tensor. + """ + print(f"called reduce (device: {input.device})") + return torch.sum(input, dim=0) + + +class MacroF1(torchmetrics.Metric): + """ + Computes the Macro F1 score, which is the unweighted mean of F1 scores for each class. + This implementation differs from torchmetrics.classification.MultilabelF1Score in the behaviour for undefined + values (i.e., classes where TP+FN=0). The torchmetrics implementation sets these classes to a default value. + Here, the mean is only taken over classes which have at least one positive sample. + + Args: + num_labels (int): Number of classes/labels. + dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward + before returning the value at the step. Default: False. + threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions. + Default: 0.5. + """ + + def __init__( + self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5 + ): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.add_state( + "true_positives", + default=torch.zeros(num_labels, dtype=torch.int), + dist_reduce_fx="sum", + ) + self.add_state( + "positive_predictions", + default=torch.zeros(num_labels, dtype=torch.int), + dist_reduce_fx="sum", + ) + self.add_state( + "positive_labels", + default=torch.zeros(num_labels, dtype=torch.int), + dist_reduce_fx="sum", + ) + self.threshold = threshold + + def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None: + """ + Update the state (TPs, Positive Predictions, Positive labels) with the current batch of predictions and labels. + + Args: + preds (torch.Tensor): Predictions from the model. + labels (torch.Tensor): Ground truth labels. + """ + tps = torch.sum( + torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0 + ) + self.true_positives += tps + self.positive_predictions += torch.sum(preds > self.threshold, dim=0) + self.positive_labels += torch.sum(labels, dim=0) + + def compute(self) -> torch.Tensor: + """ + Compute the Macro F1 score. + + Returns: + torch.Tensor: The computed Macro F1 score. 
+ """ + + # ignore classes without positive labels + # classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0), + # which is propagated to the classwise_f1 and then turned into 0 + mask = self.positive_labels != 0 + precision = self.true_positives[mask] / self.positive_predictions[mask] + recall = self.true_positives[mask] / self.positive_labels[mask] + classwise_f1 = 2 * precision * recall / (precision + recall) + # if (precision and recall are 0) or (precision is nan), set f1 to 0 + classwise_f1 = classwise_f1.nan_to_num() + return torch.mean(classwise_f1) + + +class BalancedAccuracy(torchmetrics.Metric): + """ + Computes the Balanced Accuracy, which is the average of true positive rate (TPR) and true negative rate (TNR). + Useful for imbalanced datasets. + Balanced Accuracy = (TPR + TNR)/2 = (TP/(TP + FN) + (TN)/(TN + FP))/2 + + Args: + num_labels (int): Number of classes/labels. + dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward + before returning the value at the step. Default: False. + threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions. + Default: 0.5. + """ + + def __init__( + self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5 + ): + super().__init__(dist_sync_on_step=dist_sync_on_step) + + self.add_state( + "true_positives", + default=torch.zeros(num_labels, dtype=torch.int), + dist_reduce_fx="sum", + ) + + self.add_state( + "false_positives", + default=torch.zeros(num_labels, dtype=torch.int), + dist_reduce_fx="sum", + ) + + self.add_state( + "true_negatives", + default=torch.zeros(num_labels, dtype=torch.int), + dist_reduce_fx="sum", + ) + + self.add_state( + "false_negatives", + default=torch.zeros(num_labels, dtype=torch.int), + dist_reduce_fx="sum", + ) + + self.threshold = threshold + + def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None: + """ + Update the state (TPs, TNs, FPs, FNs) with the current batch of predictions and labels. + + Args: + preds (torch.Tensor): Predictions from the model. + labels (torch.Tensor): Ground truth labels. + """ + + # Size: Batch_size x Num_of_Classes; + # summing over 1st dimension (dim=0), gives us the True positives per class + tps = torch.sum( + torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0 + ) + fps = torch.sum( + torch.logical_and(preds > self.threshold, ~labels.to(torch.bool)), dim=0 + ) + tns = torch.sum( + torch.logical_and(preds <= self.threshold, ~labels.to(torch.bool)), dim=0 + ) + fns = torch.sum( + torch.logical_and(preds <= self.threshold, labels.to(torch.bool)), dim=0 + ) + + # Size: Num_of_Classes; + self.true_positives += tps + self.false_positives += fps + self.true_negatives += tns + self.false_negatives += fns + + def compute(self) -> torch.Tensor: + """ + Compute the Balanced Accuracy. + + Returns: + torch.Tensor: The computed Balanced Accuracy. 
+ """ + tpr = self.true_positives / (self.true_positives + self.false_negatives) + tnr = self.true_negatives / (self.true_negatives + self.false_positives) + # Convert the nan values to 0 + tpr = tpr.nan_to_num() + tnr = tnr.nan_to_num() + + balanced_acc = (tpr + tnr) / 2 + return torch.mean(balanced_acc) diff --git a/chebai/callbacks/model_checkpoint.py b/chebai/callbacks/model_checkpoint.py new file mode 100644 index 0000000..dbdbab1 --- /dev/null +++ b/chebai/callbacks/model_checkpoint.py @@ -0,0 +1,95 @@ +import os + +from lightning.fabric.utilities.cloud_io import _is_dir +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint +from lightning.pytorch.loggers import WandbLogger +from lightning.pytorch.utilities.rank_zero import rank_zero_info +from lightning_utilities.core.rank_zero import rank_zero_warn + + +class CustomModelCheckpoint(ModelCheckpoint): + """ + Custom checkpoint class that resolves checkpoint paths to ensure checkpoints are saved in the same directory + as other logs when using CustomLogger. + Inherits from PyTorch Lightning's ModelCheckpoint class. + """ + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: + """ + Setup the directory path for saving checkpoints. If the directory path is not set, it resolves the checkpoint + directory using the custom logger's directory. + + Note: + Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir + + Args: + trainer (Trainer): The Trainer instance. + pl_module (LightningModule): The LightningModule instance. + stage (str): The stage of training (e.g., 'fit'). + """ + if self.dirpath is not None: + self.dirpath = None + dirpath = self.__resolve_ckpt_dir(trainer) + dirpath = trainer.strategy.broadcast(dirpath) + self.dirpath = dirpath + if trainer.is_global_zero and stage == "fit": + self.__warn_if_dir_not_empty(self.dirpath) + + def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: + """ + Warn if the checkpoint directory is not empty. + + Note: + Same as in parent class, duplicated because method in parent class is not accessible + + Args: + dirpath (_PATH): The path to the checkpoint directory. + """ + if ( + self.save_top_k != 0 + and _is_dir(self._fs, dirpath, strict=True) + and len(self._fs.ls(dirpath)) > 0 + ): + rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") + + def __resolve_ckpt_dir(self, trainer: Trainer) -> _PATH: + """ + Resolve the checkpoint directory path, ensuring compatibility with WandbLogger by saving checkpoints + in the same directory as Wandb logs. + + Note: + Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs + + Args: + trainer (Trainer): The Trainer instance. + + Returns: + _PATH: The resolved checkpoint directory path. 
+ """ + rank_zero_info(f"Resolving checkpoint dir (custom)") + if self.dirpath is not None: + # short circuit if dirpath was passed to ModelCheckpoint + return self.dirpath + if len(trainer.loggers) > 0: + if trainer.loggers[0].save_dir is not None: + save_dir = trainer.loggers[0].save_dir + else: + save_dir = trainer.default_root_dir + name = trainer.loggers[0].name + version = trainer.loggers[0].version + version = version if isinstance(version, str) else f"version_{version}" + logger = trainer.loggers[0] + if isinstance(logger, WandbLogger) and isinstance( + logger.experiment.dir, str + ): + ckpt_path = os.path.join(logger.experiment.dir, "checkpoints") + else: + ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") + else: + # if no loggers, use default_root_dir + ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") + + rank_zero_info(f"Now using checkpoint path {ckpt_path}") + return ckpt_path diff --git a/chebai/callbacks/prediction_callback.py b/chebai/callbacks/prediction_callback.py new file mode 100644 index 0000000..b36197d --- /dev/null +++ b/chebai/callbacks/prediction_callback.py @@ -0,0 +1,55 @@ +import os +import pickle +from typing import Any, Literal, Sequence + +import torch +from lightning.pytorch import LightningModule, Trainer +from lightning.pytorch.callbacks import BasePredictionWriter + + +class PredictionWriter(BasePredictionWriter): + """ + Custom callback for writing predictions to a file at the end of each epoch. + + Args: + output_dir (str): The directory where prediction files will be saved. + write_interval (str): When to write predictions. Options are "batch" or "epoch". + """ + + def __init__( + self, + output_dir: str, + write_interval: Literal["batch", "epoch", "batch_and_epoch"], + ): + super().__init__(write_interval) + self.output_dir = output_dir + self.prediction_file_name = "predictions.pkl" + + def write_on_epoch_end( + self, + trainer: Trainer, + pl_module: LightningModule, + predictions: Sequence[Any], + batch_indices: Sequence[Any], + ) -> None: + """ + Writes the predictions to a file at the end of the epoch. + + Args: + trainer (Trainer): The Trainer instance. + pl_module (LightningModule): The LightningModule instance. + predictions (Sequence[Any]): Any sequence of predictions for the epoch. + batch_indices (Sequence[Any]): Any sequence of batch indices. + """ + results = [ + dict( + ident=row["data"]["idents"][0], + predictions=torch.sigmoid(row["output"]["logits"]).numpy(), + labels=row["labels"][0].numpy() if row["labels"] is not None else None, + ) + for row in predictions + ] + with open( + os.path.join(self.output_dir, self.prediction_file_name), "wb" + ) as fout: + pickle.dump(results, fout) diff --git a/chebai/cli.py b/chebai/cli.py new file mode 100644 index 0000000..b7e78d1 --- /dev/null +++ b/chebai/cli.py @@ -0,0 +1,75 @@ +from typing import Dict, Set + +from lightning.pytorch.cli import LightningArgumentParser, LightningCLI + +from chebai.trainer.CustomTrainer import CustomTrainer + + +class ChebaiCLI(LightningCLI): + """ + Custom CLI subclass for Chebai project based on PyTorch Lightning's LightningCLI. + + Args: + save_config_kwargs (dict): Keyword arguments for saving configuration. + parser_kwargs (dict): Keyword arguments for parser configuration. + + Attributes: + save_config_kwargs (dict): Configuration options for saving. + parser_kwargs (dict): Configuration options for the argument parser. 
+ """ + + def __init__(self, *args, **kwargs): + """ + Initialize ChebaiCLI with custom trainer and configure parser settings. + + Args: + args (list): List of arguments for LightningCLI. + kwargs (dict): Keyword arguments for LightningCLI. + save_config_kwargs (dict): Keyword arguments for saving configuration. + parser_kwargs (dict): Keyword arguments for parser configuration. + """ + super().__init__(trainer_class=CustomTrainer, *args, **kwargs) + + def add_arguments_to_parser(self, parser: LightningArgumentParser): + """ + Link input parameters that are used by different classes (e.g. number of labels) + see https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_expert.html#argument-linking + + Args: + parser (LightningArgumentParser): Argument parser instance. + """ + for kind in ("train", "val", "test"): + for average in ("micro-f1", "macro-f1", "balanced-accuracy"): + parser.link_arguments( + "model.init_args.out_dim", + f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels", + ) + parser.link_arguments( + "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels" + ) + + @staticmethod + def subcommands() -> Dict[str, Set[str]]: + """ + Defines the list of available subcommands and the arguments to skip. + + Returns: + Dict[str, Set[str]]: Dictionary where keys are subcommands and values are sets of arguments to skip. + """ + return { + "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, + "validate": {"model", "dataloaders", "datamodule"}, + "test": {"model", "dataloaders", "datamodule"}, + "predict": {"model", "dataloaders", "datamodule"}, + "predict_from_file": {"model"}, + } + + +def cli(): + """ + Main function to instantiate and run the ChebaiCLI. + """ + r = ChebaiCLI( + save_config_kwargs={"config_filename": "lightning_config.yaml"}, + parser_kwargs={"parser_mode": "omegaconf"}, + ) diff --git a/chebai/loggers/__init__.py b/chebai/loggers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/loggers/custom.py b/chebai/loggers/custom.py new file mode 100644 index 0000000..d1b4282 --- /dev/null +++ b/chebai/loggers/custom.py @@ -0,0 +1,127 @@ +import os +from datetime import datetime +from typing import List, Literal, Optional, Union + +import wandb +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.callbacks import ModelCheckpoint +from lightning.pytorch.loggers import WandbLogger + + +class CustomLogger(WandbLogger): + """ + A custom logger that extends WandbLogger to add support for custom naming of runs and cross-validation. + + Args: + save_dir (_PATH): Directory where logs are saved. + name (str): Name of the logging run. + version (Optional[Union[int, str]]): Version of the logging run. + prefix (str): Prefix for logging. + fold (Optional[int]): Cross-validation fold number. + project (Optional[str]): Wandb project name. + entity (Optional[str]): Wandb entity name. + offline (bool): Whether to log offline. + log_model (Union[Literal["all"], bool]): Whether to log the model. + verbose_hyperparameters (bool): Whether to log hyperparameters verbosely. + tags (Optional[List[str]]): List of tags for the run. + **kwargs: Additional keyword arguments for WandbLogger. 
+ """ + + def __init__( + self, + save_dir: _PATH, + name: str = "logs", + version: Optional[Union[int, str]] = None, + prefix: str = "", + fold: Optional[int] = None, + project: Optional[str] = None, + entity: Optional[str] = None, + offline: bool = False, + log_model: Union[Literal["all"], bool] = False, + verbose_hyperparameters: bool = False, + tags: Optional[List[str]] = None, + **kwargs, + ): + if version is None: + version = f"{datetime.now():%y%m%d-%H%M}" + self._version = version + self._name = name + self._fold = fold + self.verbose_hyperparameters = verbose_hyperparameters + super().__init__( + name=self.name, + save_dir=save_dir, + version=None, + prefix=prefix, + log_model=log_model, + entity=entity, + project=project, + offline=offline, + **kwargs, + ) + if tags: + self.experiment.tags += tuple(tags) + + @property + def name(self) -> Optional[str]: + """ + Returns the name of the logging run, including the version and fold number if applicable. + """ + name = f"{self._name}_{self.version}" + if self._fold is not None: + name += f"_fold{self._fold}" + return name + + @property + def version(self) -> Optional[str]: + """ + Returns the version of the logging run. + """ + return self._version + + @property + def root_dir(self) -> Optional[str]: + """ + Returns the root directory for saving logs. + """ + return os.path.join(self.save_dir, self.name) + + @property + def log_dir(self) -> str: + """ + Returns the directory for saving logs, including the version and fold number if applicable. + """ + version = ( + self.version if isinstance(self.version, str) else f"version_{self.version}" + ) + if self._fold is None: + return os.path.join(self.root_dir, version) + return os.path.join(self.root_dir, version, f"fold_{self._fold}") + + def set_fold(self, fold: int) -> None: + """ + Sets the fold number and restarts the Wandb experiment with the new fold number. + + Args: + fold (int): Cross-validation fold number. + """ + if fold != self._fold: + self._fold = fold + # Start new experiment + wandb.finish() + self._wandb_init["name"] = self.name + self._experiment = None + _ = self.experiment + + @property + def fold(self) -> Optional[int]: + """ + Returns the current fold number. + """ + return self._fold + + def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: + """ + Override method to prevent saving checkpoints as Wandb artifacts. + """ + pass diff --git a/chebai/loss/__init__.py b/chebai/loss/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py new file mode 100644 index 0000000..b4fb863 --- /dev/null +++ b/chebai/loss/bce_weighted.py @@ -0,0 +1,98 @@ +import os +from typing import Optional + +import pandas as pd +import torch + +from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor +from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + + +class BCEWeighted(torch.nn.BCEWithLogitsLoss): + """ + BCEWithLogitsLoss with weights automatically computed according to the beta parameter. + If beta is None or data_extractor is None, the loss is unweighted. + + This class computes weights based on the formula from the paper by Cui et al. (2019): + https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf + + Args: + beta (float, optional): The beta parameter for weight calculation. Default is None. 
+ data_extractor (XYBaseDataModule, optional): The data extractor for loading the dataset. Default is None. + """ + + def __init__( + self, + beta: Optional[float] = None, + data_extractor: Optional[XYBaseDataModule] = None, + **kwargs, + ): + self.beta = beta + if isinstance(data_extractor, LabeledUnlabeledMixed): + data_extractor = data_extractor.labeled + self.data_extractor = data_extractor + assert ( + isinstance(self.data_extractor, _ChEBIDataExtractor) + or self.data_extractor is None + ) + super().__init__(**kwargs) + + def set_pos_weight(self, input: torch.Tensor) -> None: + """ + Sets the positive weights for the loss function based on the input tensor. + + Args: + input (torch.Tensor): The input tensor for which to set the positive weights. + """ + if ( + self.beta is not None + and self.data_extractor is not None + and all( + os.path.exists( + os.path.join(self.data_extractor.processed_dir, file_name) + ) + for file_name in self.data_extractor.processed_file_names + ) + and self.pos_weight is None + ): + print( + f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})" + ) + complete_labels = torch.concat( + [ + torch.stack( + [ + torch.Tensor(row["labels"]) + for row in self.data_extractor.load_processed_data( + filename=file_name + ) + ] + ) + for file_name in self.data_extractor.processed_file_names + ] + ) + value_counts = complete_labels.sum(dim=0) + weights = [ + (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts + ] + mean = sum(weights) / len(weights) + self.pos_weight = torch.tensor( + [w / mean for w in weights], device=input.device + ) + + def forward( + self, input: torch.Tensor, target: torch.Tensor, **kwargs + ) -> torch.Tensor: + """ + Forward pass for the loss calculation. + + Args: + input (torch.Tensor): The input tensor (predictions). + target (torch.Tensor): The target tensor (labels). + + Returns: + torch.Tensor: The computed loss. + """ + self.set_pos_weight(input) + return super().forward(input, target) diff --git a/chebai/loss/mixed.py b/chebai/loss/mixed.py new file mode 100644 index 0000000..edfc5cb --- /dev/null +++ b/chebai/loss/mixed.py @@ -0,0 +1,40 @@ +import torch +from torch import nn + + +class MixedDataLoss(nn.Module): + """ + A wrapper for applying a base loss function to a subset of input data. + + This class allows for selective application of a loss function based on the provided + non-null labels. + + Args: + base_loss (nn.Module): The base loss function to be applied. + """ + + def __init__(self, base_loss: nn.Module): + super().__init__() + self.base_loss = base_loss + + def forward( + self, input: torch.Tensor, target: torch.Tensor, **kwargs + ) -> torch.Tensor: + """ + Forward pass for applying the base loss function. + + Args: + input (torch.Tensor): The input tensor (predictions). + target (torch.Tensor): The target tensor (labels). + **kwargs: Additional keyword arguments. The 'non_null_labels' key can be used + to specify the indices of the non-null labels. + + Returns: + torch.Tensor: The computed loss. + """ + nnl = kwargs.pop("non_null_labels", None) + if nnl: + inp = input[nnl] + else: + inp = input + return self.base_loss(inp, target, **kwargs) diff --git a/chebai/loss/pretraining.py b/chebai/loss/pretraining.py new file mode 100644 index 0000000..e2f51da --- /dev/null +++ b/chebai/loss/pretraining.py @@ -0,0 +1,48 @@ +import torch + + +class ElectraPreLoss(torch.nn.Module): + """ + Custom loss module for pre-training ELECTRA-like models. 
+ + This module computes a combined loss from two CrossEntropyLosses: + one for generator predictions and another for discriminator predictions. + + Attributes: + ce (torch.nn.CrossEntropyLoss): Cross entropy loss function. + + Methods: + forward(input, target, **loss_kwargs): + Computes the combined loss for generator and discriminator predictions. + + """ + + def __init__(self): + """ + Initializes the ElectraPreLoss module. + """ + super().__init__() + self.ce = torch.nn.CrossEntropyLoss() + + def forward(self, input, target, **loss_kwargs): + """ + Forward pass for computing the combined loss. + + Args: + input (tuple): A tuple containing generator predictions (gen_pred, disc_pred). + target (tuple): A tuple containing generator targets (gen_tar, disc_tar). + **loss_kwargs: Additional keyword arguments. + + Returns: + torch.Tensor: Combined loss of generator and discriminator predictions. + """ + t, p = input + gen_pred, disc_pred = t + gen_tar, disc_tar = p + + # Compute losses for generator and discriminator + gen_loss = self.ce(target=torch.argmax(gen_tar.int(), dim=-1), input=gen_pred) + disc_loss = self.ce( + target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred + ) + return gen_loss + disc_loss diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py new file mode 100644 index 0000000..271c312 --- /dev/null +++ b/chebai/loss/semantic.py @@ -0,0 +1,532 @@ +import csv +import math +import os +import pickle +from typing import List, Literal, Union + +import torch + +from chebai.loss.bce_weighted import BCEWeighted +from chebai.preprocessing.datasets import XYBaseDataModule +from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor +from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed + + +class ImplicationLoss(torch.nn.Module): + """ + Implication Loss module. + + Args: + data_extractor _ChEBIDataExtractor: Data extractor for labels. + base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. + fuzzy_implication (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product". + impl_loss_weight (float, optional): Weight of implication loss relative to base loss. Defaults to 0.1. + pos_scalar (int, optional): Positive scalar exponent. Defaults to 1. + pos_epsilon (float, optional): Epsilon value for numerical stability. Defaults to 0.01. + multiply_by_softmax (bool, optional): Whether to multiply by softmax. Defaults to False. + use_sigmoidal_implication (bool, optional): Whether to use the sigmoidal fuzzy implication based on the + specified fuzzy_implication (as defined by van Krieken et al., 2022: Analyzing Differentiable Fuzzy Logic + Operators). Defaults to False. + weight_epoch_dependent (Union[bool, tuple[int, int]], optional): Whether to weight the implication loss + depending on the current epoch with the sigmoid function sigmoid((epoch-c)/s). If True, c=50 and s=10, + otherwise, a tuple of integers (c,s) can be supplied. Defaults to False. + start_at_epoch (int, optional): Epoch at which to start applying the loss. Defaults to 0. + violations_per_cls_aggregator (Literal["sum", "max"], optional): How to aggregate violations for each class. + If a class is involved in several implications / disjointnesses, the loss value for this class will be + aggregated with this method. Defaults to "sum". 
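+
+ Example (conceptual sketch, using the default "reichenbach" implication; numbers
+ are made up): for an implication A -> B, a prediction pair contributes roughly
+ pred(A) * (1 - pred(B)), so pred(A)=0.9 with pred(B)=0.2 is penalised with 0.72,
+ while pred(A)=0.2 with pred(B)=0.9 contributes only 0.02. A runnable toy example
+ is included in the __main__ block at the bottom of this file.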
+ """ + + def __init__( + self, + data_extractor: XYBaseDataModule, + base_loss: torch.nn.Module = None, + fuzzy_implication: Literal[ + "reichenbach", + "rc", + "lukasiewicz", + "lk", + "xu19", + "kleene_dienes", + "kd", + "goedel", + "g", + "reverse-goedel", + "rg", + "binary", + "b", + ] = "reichenbach", + impl_loss_weight: float = 0.1, + pos_scalar: Union[int, float] = 1, + pos_epsilon: float = 0.01, + multiply_by_softmax: bool = False, + use_sigmoidal_implication: bool = False, + weight_epoch_dependent: Union[bool | tuple[int, int]] = False, + start_at_epoch: int = 0, + violations_per_cls_aggregator: Literal[ + "sum", "max", "mean", "log-sum", "log-max", "log-mean" + ] = "sum", + multiply_with_base_loss: bool = True, + no_grads: bool = False, + ): + super().__init__() + # automatically choose labeled subset for implication filter in case of mixed dataset + if isinstance(data_extractor, LabeledUnlabeledMixed): + data_extractor = data_extractor.labeled + assert isinstance(data_extractor, _ChEBIDataExtractor) + self.data_extractor = data_extractor + # propagate data_extractor to base loss + if isinstance(base_loss, BCEWeighted): + base_loss.data_extractor = self.data_extractor + base_loss.reduction = ( + "none" # needed to multiply fuzzy loss with base loss for each sample + ) + self.base_loss = base_loss + self.implication_cache_file = f"implications_{self.data_extractor.name}.cache" + self.label_names = _load_label_names( + os.path.join(data_extractor.processed_dir_main, "classes.txt") + ) + self.hierarchy = self._load_implications( + os.path.join(data_extractor.raw_dir, "chebi.obo") + ) + implication_filter_dense = _build_dense_filter( + _build_implication_filter(self.label_names, self.hierarchy), + len(self.label_names), + ) + self.implication_filter_l = implication_filter_dense + self.implication_filter_r = self.implication_filter_l.transpose(0, 1) + self.fuzzy_implication = fuzzy_implication + self.impl_weight = impl_loss_weight + self.pos_scalar = pos_scalar + self.eps = pos_epsilon + self.multiply_by_softmax = multiply_by_softmax + self.use_sigmoidal_implication = use_sigmoidal_implication + self.weight_epoch_dependent = weight_epoch_dependent + self.start_at_epoch = start_at_epoch + self.violations_per_cls_aggregator = violations_per_cls_aggregator + self.multiply_with_base_loss = multiply_with_base_loss + self.no_grads = no_grads + + def _calculate_unaggregated_fuzzy_loss( + self, + pred, + target: torch.Tensor, + weight, + filter_l, + filter_r, + mode="impl", + **kwargs, + ): + # for each batch, get all pairwise losses: [a1, a2, a3] -> [[a1*a1, a1*a2, a1*a3],[a2*a1,...],[a3*a1,...]] + preds_expanded1 = pred.unsqueeze(1).expand(-1, pred.shape[1], -1) + preds_expanded2 = pred.unsqueeze(2).expand(-1, -1, pred.shape[1]) + # filter by implication relations and labels + + label_filter = target.unsqueeze(2).expand(-1, -1, pred.shape[1]) + filter_l = filter_l.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) + filter_r = filter_r.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) + if mode == "impl": + all_implications = self._calculate_implication_loss( + preds_expanded2, preds_expanded1 + ) + else: + all_implications = self._calculate_implication_loss( + preds_expanded2, 1 - preds_expanded1 + ) + loss_impl_l = all_implications * filter_l * (1 - label_filter) + if mode == "impl": + loss_impl_r = all_implications.transpose(1, 2) * filter_r * label_filter + loss_impl_sum = loss_impl_l + loss_impl_r + else: + loss_impl_sum = loss_impl_l + + if 
self.violations_per_cls_aggregator.startswith("log-"): + loss_impl_sum = -torch.log(1 - loss_impl_sum) + violations_per_cls_aggregator = self.violations_per_cls_aggregator[4:] + else: + violations_per_cls_aggregator = self.violations_per_cls_aggregator + if violations_per_cls_aggregator == "sum": + loss_by_cls = loss_impl_sum.sum(dim=-1) + elif violations_per_cls_aggregator == "max": + loss_by_cls = loss_impl_sum.max(dim=-1).values + elif violations_per_cls_aggregator == "mean": + loss_by_cls = loss_impl_sum.mean(dim=-1) + else: + raise NotImplementedError( + f"Unknown violations_per_cls_aggregator {self.violations_per_cls_aggregator}" + ) + + unweighted_mean = loss_by_cls.mean() + implication_loss_weighted = loss_by_cls + if "current_epoch" in kwargs and self.weight_epoch_dependent: + sigmoid_center = ( + self.weight_epoch_dependent[0] + if isinstance(self.weight_epoch_dependent, tuple) + else 50 + ) + sigmoid_spread = ( + self.weight_epoch_dependent[1] + if isinstance(self.weight_epoch_dependent, tuple) + else 10 + ) + # sigmoid function centered around epoch 50 + implication_loss_weighted = implication_loss_weighted / ( + 1 + + math.exp(-(kwargs["current_epoch"] - sigmoid_center) / sigmoid_spread) + ) + implication_loss_weighted *= weight + weighted_mean = implication_loss_weighted.mean() + + return implication_loss_weighted, unweighted_mean, weighted_mean + + def _calculate_unaggregated_base_loss(self, input, target, **kwargs): + nnl = kwargs.pop("non_null_labels", None) + labeled_input = input[nnl] if nnl else input + + if target is not None and self.base_loss is not None: + return self.base_loss(labeled_input, target.float()) + else: + return torch.zeros(input.shape, device=input.device) + + def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: + """ + Forward pass of the implication loss module. + + Args: + input (torch.Tensor): Input tensor. + target (torch.Tensor): Target tensor. + **kwargs: Additional arguments. + + Returns: + tuple: Tuple containing total loss, base loss, and implication loss. + """ + base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) + loss_components = {"base_loss": base_loss.mean()} + + if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: + return base_loss.mean(), loss_components + + pred = torch.sigmoid(input) + fuzzy_loss, unweighted_fuzzy_mean, weighted_fuzzy_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.impl_weight, + self.implication_filter_l, + self.implication_filter_r, + **kwargs, + ) + ) + if self.no_grads: + fuzzy_loss = fuzzy_loss.detach() + loss_components["unweighted_fuzzy_loss"] = unweighted_fuzzy_mean + loss_components["weighted_fuzzy_loss"] = weighted_fuzzy_mean + if self.base_loss is None or target is None: + total_loss = self.impl_weight * fuzzy_loss + elif self.multiply_with_base_loss: + total_loss = base_loss * (1 + self.impl_weight * fuzzy_loss) + else: + total_loss = base_loss + self.impl_weight * fuzzy_loss + return total_loss.mean(), loss_components + + def _calculate_implication_loss( + self, l: torch.Tensor, r: torch.Tensor + ) -> torch.Tensor: + """ + Calculate implication loss based on T-norm and other parameters. + + Args: + l (torch.Tensor): Left part of implication. + r (torch.Tensor): Right part of implication. + + Returns: + torch.Tensor: Calculated implication loss. 
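+
+ Example (illustrative numbers, assuming the default pos_scalar=1 so that
+ one_min_r = 1 - r): for l=0.8, r=0.3 the per-pair loss is 0.56 for
+ "reichenbach" (0.8 * 0.7), 0.5 for "lukasiewicz" (relu(0.8 + 0.7 - 1)),
+ 0.7 for "kleene_dienes" and "goedel", and 1.0 for "binary"; whenever l <= r,
+ the Goedel and binary variants are exactly 0.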
+ """ + assert not l.isnan().any(), ( + f"l contains NaN values - l.shape: {l.shape}, l.isnan().sum(): {l.isnan().sum()}, " + f"l: {l}" + ) + assert not r.isnan().any(), ( + f"r contains NaN values - r.shape: {r.shape}, r.isnan().sum(): {r.isnan().sum()}, " + f"r: {r}" + ) + if self.pos_scalar != 1: + l = ( + torch.pow(l + self.eps, 1 / self.pos_scalar) + - math.pow(self.eps, 1 / self.pos_scalar) + ) / ( + math.pow(1 + self.eps, 1 / self.pos_scalar) + - math.pow(self.eps, 1 / self.pos_scalar) + ) + one_min_r = ( + torch.pow(1 - r + self.eps, 1 / self.pos_scalar) + - math.pow(self.eps, 1 / self.pos_scalar) + ) / ( + math.pow(1 + self.eps, 1 / self.pos_scalar) + - math.pow(self.eps, 1 / self.pos_scalar) + ) + else: + one_min_r = 1 - r + # for each implication I, calculate 1 - I(l, 1-one_min_r) + # for S-implications, this is equivalent to the t-norm + if self.fuzzy_implication in ["reichenbach", "rc"]: + individual_loss = l * one_min_r + # xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similar to the Reichenbach + # implication + elif self.fuzzy_implication == "xu19": + individual_loss = -torch.log(1 - l * one_min_r) + elif self.fuzzy_implication in ["lukasiewicz", "lk"]: + individual_loss = torch.relu(l + one_min_r - 1) + elif self.fuzzy_implication in ["kleene_dienes", "kd"]: + individual_loss = torch.min(l, 1 - r) + elif self.fuzzy_implication in ["goedel", "g"]: + individual_loss = torch.where(l <= r, 0, one_min_r) + elif self.fuzzy_implication in ["reverse-goedel", "rg"]: + individual_loss = torch.where(l <= r, 0, l) + elif self.fuzzy_implication in ["binary", "b"]: + individual_loss = torch.where(l <= r, 0, 1).to(dtype=l.dtype) + else: + raise NotImplementedError( + f"Unknown fuzzy implication {self.fuzzy_implication}" + ) + + if self.use_sigmoidal_implication: + # formula by van Krieken, 2022, applied to fuzzy implication with default parameters: b_0 = 0.5, s = 9 + # parts that only depend on b_0 and s are pre-calculated + implication = 1 - individual_loss + sigmoidal_implication = 0.01123379 * ( + 91.0171 * torch.sigmoid(9 * (implication - 0.5)) - 1 + ) + individual_loss = 1 - sigmoidal_implication + + if self.multiply_by_softmax: + individual_loss = individual_loss * individual_loss.softmax(dim=-1) + + return individual_loss + + def _load_implications(self, path_to_chebi: str) -> dict: + """ + Load class hierarchy implications. + + Args: + path_to_chebi (str): Path to the ChEBI ontology file. + + Returns: + dict: Loaded hierarchy of implications. + """ + if os.path.isfile(self.implication_cache_file): + with open(self.implication_cache_file, "rb") as fin: + hierarchy = pickle.load(fin) + else: + hierarchy = self.data_extractor.extract_class_hierarchy(path_to_chebi) + with open(self.implication_cache_file, "wb") as fout: + pickle.dump(hierarchy, fout) + return hierarchy + + +class DisjointLoss(ImplicationLoss): + """ + Disjoint Loss module, extending ImplicationLoss. + + Args: + path_to_disjointness (str): Path to the disjointness data file (a csv file containing pairs of disjoint classes) + data_extractor (Union[_ChEBIDataExtractor, LabeledUnlabeledMixed]): Data extractor for labels. + base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. + disjoint_loss_weight (float, optional): Weight of disjointness loss. Defaults to 100. + **kwargs: Additional arguments. 
+ """ + + def __init__( + self, + path_to_disjointness: str, + data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed], + base_loss: torch.nn.Module = None, + disjoint_loss_weight: float = 100, + **kwargs, + ): + super().__init__(data_extractor, base_loss, **kwargs) + self.disjoint_filter_l, self.disjoint_filter_r = _build_disjointness_filter( + path_to_disjointness, self.label_names, self.hierarchy + ) + self.disjoint_weight = disjoint_loss_weight + + def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: + """ + Forward pass of the disjoint loss module. + + Args: + input (torch.Tensor): Input tensor. + target (torch.Tensor): Target tensor. + **kwargs: Additional arguments. + + Returns: + tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. + """ + base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) + loss_components = {"base_loss": base_loss.mean()} + + if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: + return base_loss.mean(), loss_components + + pred = torch.sigmoid(input) + impl_loss, unweighted_impl_mean, weighted_impl_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.impl_weight, + self.implication_filter_l, + self.implication_filter_r, + **kwargs, + ) + ) + if self.no_grads: + impl_loss = impl_loss.detach() + loss_components["unweighted_implication_loss"] = unweighted_impl_mean + loss_components["weighted_implication_loss"] = weighted_impl_mean + + disj_loss, unweighted_disj_mean, weighted_disj_mean = ( + self._calculate_unaggregated_fuzzy_loss( + pred, + target, + self.disjoint_weight, + self.disjoint_filter_l, + self.disjoint_filter_r, + mode="disj", + **kwargs, + ) + ) + if self.no_grads: + disj_loss = disj_loss.detach() + loss_components["unweighted_disjointness_loss"] = unweighted_disj_mean + loss_components["weighted_disjointness_loss"] = weighted_disj_mean + + if self.base_loss is None or target is None: + total_loss = self.impl_weight * impl_loss + self.disjoint_weight * disj_loss + elif self.multiply_with_base_loss: + total_loss = base_loss * ( + 1 + self.impl_weight * impl_loss + self.disjoint_weight * disj_loss + ) + else: + total_loss = ( + base_loss + + self.impl_weight * impl_loss + + self.disjoint_weight * disj_loss + ) + return total_loss.mean(), loss_components + + +def _load_label_names(path_to_label_names: str) -> List: + """ + Load label names from a file. + + Args: + path_to_label_names (str): Path to the label names file. + + Returns: + list: List of label names. + """ + with open(path_to_label_names) as fin: + label_names = [int(line.strip()) for line in fin] + return label_names + + +def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tensor: + """ + Build implication filter based on label names and hierarchy. Results in list of pairs (A,B) for each implication + A->B (including indirect implications). + + Args: + label_names (list): List of label names. + hierarchy (dict): Hierarchy of implications. + + Returns: + torch.Tensor: Tensor representing implication filter. 
+ """ + return torch.tensor( + [ + (i1, i2) + for i1, l1 in enumerate(label_names) + for i2, l2 in enumerate(label_names) + if l2 in hierarchy.pred[l1] + ] + ) + + +def _build_dense_filter(sparse_filter: torch.Tensor, n_labels: int) -> torch.Tensor: + res = torch.zeros((n_labels, n_labels), dtype=torch.bool) + for l, r in sparse_filter: + res[l, r] = True + return res + + +def _build_disjointness_filter( + path_to_disjointness: str, label_names: List, hierarchy: dict +) -> tuple: + """ + Build disjointness filter based on disjointness data and hierarchy. + + Args: + path_to_disjointness (str): Path to the disjointness data file. + label_names (list): List of label names. + hierarchy (dict): Hierarchy of implications. + + Returns: + tuple: Tuple containing tensors representing disjointness filter. + """ + disjoints = set() + label_dict = dict(map(reversed, enumerate(label_names))) + + with open(path_to_disjointness, "rt") as fin: + reader = csv.reader(fin) + for l1_raw, r1_raw in reader: + l1 = int(l1_raw) + r1 = int(r1_raw) + if l1 == 36233 and r1 == 63353: + # ignore disaccharide-disaccharide derivative disjointness axiom + continue + disjoints.update( + { + (label_dict[l2], label_dict[r2]) + for r2 in list(hierarchy.succ[r1]) + [r1] + if r2 in label_names + for l2 in list(hierarchy.succ[l1]) + [l1] + if l2 in label_names + } + ) + + dis_filter = torch.tensor(list(disjoints)) + dense = _build_dense_filter(dis_filter, len(label_names)) + dense_r = dense.transpose(0, 1) + return dense, dense_r + + +if __name__ == "__main__": + loss = DisjointLoss( + os.path.join("data", "disjoint.csv"), + ChEBIOver100(chebi_version=231), + base_loss=BCEWeighted(), + impl_loss_weight=1, + disjoint_loss_weight=1, + ) + random_preds = torch.randn(10, 997) + random_labels = torch.randint(0, 2, (10, 997)) + for agg in ["sum", "max", "mean", "log-mean"]: + loss.violations_per_cls_aggregator = agg + l = loss(random_preds, random_labels) + print(f"Loss with {agg} aggregation for random input:", l) + + # simplified example for ontology with 4 classes, A -> B, B -> C, D -> C, B and D disjoint + loss.implication_filter_l = torch.tensor( + [[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 1, 0]] + ) + loss.implication_filter_r = loss.implication_filter_l.transpose(0, 1) + loss.disjoint_filter_l = torch.tensor( + [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0], [0, 1, 0, 0]] + ) + loss.disjoint_filter_r = loss.disjoint_filter_l.transpose(0, 1) + # expected result: first sample: moderately high loss for B disj D, otherwise low, second sample: high loss for A -> B (applied to A), otherwise low + preds = torch.tensor([[0.1, 0.3, 0.7, 0.4], [0.5, 0.2, 0.9, 0.1]]) + labels = [[0, 1, 1, 0], [0, 0, 1, 1]] + for agg in ["sum", "max", "mean", "log-mean"]: + loss.violations_per_cls_aggregator = agg + l = loss(preds, torch.tensor(labels)) + print(f"Loss with {agg} aggregation for simple input:", l) diff --git a/chebai/models/__init__.py b/chebai/models/__init__.py new file mode 100644 index 0000000..e3122d5 --- /dev/null +++ b/chebai/models/__init__.py @@ -0,0 +1,2 @@ +from chebai.models.base import * +from chebai.models.electra import * diff --git a/chebai/models/base.py b/chebai/models/base.py new file mode 100644 index 0000000..4ba27bb --- /dev/null +++ b/chebai/models/base.py @@ -0,0 +1,372 @@ +import logging +from typing import Any, Dict, Optional, Union, Iterable + +import torch +from lightning.pytorch.core.module import LightningModule +from torchmetrics import Metric + +from chebai.preprocessing.structures import XYData + 
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL) + +_MODEL_REGISTRY = dict() + + +class ChebaiBaseNet(LightningModule): + """ + Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule. + + Args: + criterion (torch.nn.Module, optional): The loss criterion for the model. Defaults to None. + out_dim (int, optional): The output dimension of the model. Defaults to None. + train_metrics (torch.nn.Module, optional): The metrics to be used during training. Defaults to None. + val_metrics (torch.nn.Module, optional): The metrics to be used during validation. Defaults to None. + test_metrics (torch.nn.Module, optional): The metrics to be used during testing. Defaults to None. + pass_loss_kwargs (bool, optional): Whether to pass loss kwargs to the criterion. Defaults to True. + optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. Defaults to None. + **kwargs: Additional keyword arguments. + + Attributes: + NAME (str): The name of the model. + """ + + NAME = None + + def __init__( + self, + criterion: torch.nn.Module = None, + out_dim: Optional[int] = None, + train_metrics: Optional[torch.nn.Module] = None, + val_metrics: Optional[torch.nn.Module] = None, + test_metrics: Optional[torch.nn.Module] = None, + pass_loss_kwargs: bool = True, + optimizer_kwargs: Optional[Dict[str, Any]] = None, + exclude_hyperparameter_logging: Optional[Iterable[str]] = None, + **kwargs, + ): + super().__init__() + if exclude_hyperparameter_logging is None: + exclude_hyperparameter_logging = tuple() + self.criterion = criterion + self.save_hyperparameters( + ignore=[ + "criterion", + "train_metrics", + "val_metrics", + "test_metrics", + *exclude_hyperparameter_logging, + ] + ) + self.out_dim = out_dim + if optimizer_kwargs: + self.optimizer_kwargs = optimizer_kwargs + else: + self.optimizer_kwargs = dict() + self.train_metrics = train_metrics + self.validation_metrics = val_metrics + self.test_metrics = test_metrics + self.pass_loss_kwargs = pass_loss_kwargs + + def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: + # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a + # different loss) + if "criterion.base_loss.pos_weight" in checkpoint["state_dict"]: + del checkpoint["state_dict"]["criterion.base_loss.pos_weight"] + if "criterion.pos_weight" in checkpoint["state_dict"]: + del checkpoint["state_dict"]["criterion.pos_weight"] + + def __init_subclass__(cls, **kwargs): + """ + Automatically registers subclasses in the model registry to prevent duplicates. + + Args: + **kwargs: Additional keyword arguments. + """ + if cls.NAME in _MODEL_REGISTRY: + raise ValueError(f"Model {cls.NAME} does already exist") + else: + _MODEL_REGISTRY[cls.NAME] = cls + + def _get_prediction_and_labels( + self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor + ) -> (torch.Tensor, torch.Tensor): + """ + Gets the predictions and labels from the model output. + + Args: + data (Dict[str, Any]): The processed batch data. + labels (torch.Tensor): The true labels. + output (torch.Tensor): The model output. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Predictions and labels. + """ + return output, labels + + def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor: + """ + Processes the labels in the batch. + + Args: + batch (XYData): The input batch of data. + + Returns: + torch.Tensor: The processed labels. 
+ """ + return batch.y.float() + + def _process_batch(self, batch: XYData, batch_idx: int) -> Dict[str, Any]: + """ + Processes the batch data. + + Args: + batch (XYData): The input batch of data. + batch_idx (int): The index of the current batch. + + Returns: + Dict[str, Any]: Processed batch data. + """ + return dict( + features=batch.x, + labels=self._process_labels_in_batch(batch), + model_kwargs=batch.additional_fields["model_kwargs"], + loss_kwargs=batch.additional_fields["loss_kwargs"], + idents=batch.additional_fields["idents"], + ) + + def _process_for_loss( + self, + model_output: torch.Tensor, + labels: torch.Tensor, + loss_kwargs: Dict[str, Any], + ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]): + """ + Processes the data for loss computation. + + Args: + model_output (torch.Tensor): The model output. + labels (torch.Tensor): The true labels. + loss_kwargs (Dict[str, Any]): Additional keyword arguments for the loss function. + + Returns: + Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: Model output, labels, and loss kwargs. + """ + return model_output, labels, loss_kwargs + + def training_step( + self, batch: XYData, batch_idx: int + ) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Defines the training step. + + Args: + batch (XYData): The input batch of data. + batch_idx (int): The index of the current batch. + + Returns: + Dict[str, Union[torch.Tensor, Any]]: The result of the training step. + """ + return self._execute( + batch, batch_idx, self.train_metrics, prefix="train_", sync_dist=True + ) + + def validation_step( + self, batch: XYData, batch_idx: int + ) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Defines the validation step. + + Args: + batch (XYData): The input batch of data. + batch_idx (int): The index of the current batch. + + Returns: + Dict[str, Union[torch.Tensor, Any]]: The result of the validation step. + """ + return self._execute( + batch, batch_idx, self.validation_metrics, prefix="val_", sync_dist=True + ) + + def test_step( + self, batch: XYData, batch_idx: int + ) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Defines the test step. + + Args: + batch (XYData): The input batch of data. + batch_idx (int): The index of the current batch. + + Returns: + Dict[str, Union[torch.Tensor, Any]]: The result of the test step. + """ + return self._execute( + batch, batch_idx, self.test_metrics, prefix="test_", sync_dist=True + ) + + def predict_step( + self, batch: XYData, batch_idx: int, **kwargs + ) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Defines the prediction step. + + Args: + batch (XYData): The input batch of data. + batch_idx (int): The index of the current batch. + **kwargs: Additional keyword arguments. + + Returns: + Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step. + """ + return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) + + def _execute( + self, + batch: XYData, + batch_idx: int, + metrics: Optional[torch.nn.Module] = None, + prefix: Optional[str] = "", + log: Optional[bool] = True, + sync_dist: Optional[bool] = False, + ) -> Dict[str, Union[torch.Tensor, Any]]: + """ + Executes the model on a batch of data and returns the model output and predictions. + + Args: + batch (XYData): The input batch of data. + batch_idx (int): The index of the current batch. + metrics (torch.nn.Module): A dictionary of metrics to track. + prefix (str, optional): A prefix to add to the metric names. Defaults to "". + log (bool, optional): Whether to log the metrics. Defaults to True. 
+ sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False. + + Returns: + Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels, model output, + predictions, and loss (if applicable). + """ + assert isinstance(batch, XYData) + batch = batch.to(self.device) + data = self._process_batch(batch, batch_idx) + labels = data["labels"] + model_output = self(data, **data.get("model_kwargs", dict())) + pr, tar = self._get_prediction_and_labels(data, labels, model_output) + d = dict(data=data, labels=labels, output=model_output, preds=pr) + if log: + if self.criterion is not None: + loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss( + model_output, labels, data.get("loss_kwargs", dict()) + ) + loss_kwargs = dict() + if self.pass_loss_kwargs: + loss_kwargs = loss_kwargs_candidates + loss_kwargs["current_epoch"] = self.trainer.current_epoch + loss = self.criterion(loss_data, loss_labels, **loss_kwargs) + if isinstance(loss, tuple): + unnamed_loss_index = 1 + if isinstance(loss[1], dict): + unnamed_loss_index = 2 + for key, value in loss[1].items(): + self.log( + key, + value if isinstance(value, int) else value.item(), + batch_size=len(batch), + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=sync_dist, + ) + loss_additional = loss[unnamed_loss_index:] + for i, loss_add in enumerate(loss_additional): + self.log( + f"{prefix}loss_{i}", + loss_add if isinstance(loss_add, int) else loss_add.item(), + batch_size=len(batch), + on_step=True, + on_epoch=True, + prog_bar=False, + logger=True, + sync_dist=sync_dist, + ) + loss = loss[0] + + d["loss"] = loss + self.log( + f"{prefix}loss", + loss.item(), + batch_size=len(batch), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + sync_dist=sync_dist, + ) + if metrics and labels is not None: + for metric_name, metric in metrics.items(): + metric.update(pr, tar) + self._log_metrics(prefix, metrics, len(batch)) + return d + + def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int): + """ + Logs the metrics for the given prefix. + + Args: + prefix (str): The prefix to be added to the metric names. + metrics (torch.nn.Module): A dictionary containing the metrics to be logged. + batch_size (int): The batch size used for logging. + + Returns: + None + """ + # don't use sync_dist=True if the metric is a torchmetrics-metric + # (see https://github.com/Lightning-AI/pytorch-lightning/discussions/6501#discussioncomment-569757) + for metric_name, metric in metrics.items(): + m = None # m = metric.compute() + if isinstance(m, dict): + # todo: is this case needed? it requires logging values directly which does not give accurate results + # with the current metric-setup + for k, m2 in m.items(): + self.log( + f"{prefix}{metric_name}{k}", + m2, + batch_size=batch_size, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + else: + self.log( + f"{prefix}{metric_name}", + metric, + batch_size=batch_size, + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + + def forward(self, x: Dict[str, Any]) -> torch.Tensor: + """ + Defines the forward pass. + + Args: + x (Dict[str, Any]): The input data. + + Returns: + torch.Tensor: The model output. + """ + raise NotImplementedError + + def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer: + """ + Configures the optimizers. + + Args: + **kwargs: Additional keyword arguments. + + Returns: + torch.optim.Optimizer: The optimizer. 
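+
+        Example:
+            A minimal sketch (for illustration only; ``FFN`` is the feed-forward subclass
+            defined in ``chebai/models/ffn.py``, and the hyperparameter values are arbitrary):
+
+            ```python
+            import torch
+
+            from chebai.models.ffn import FFN
+
+            # optimizer_kwargs are forwarded verbatim to torch.optim.Adamax
+            model = FFN(input_size=1024, out_dim=50, optimizer_kwargs={"lr": 1e-3})
+            optimizer = model.configure_optimizers()
+            assert isinstance(optimizer, torch.optim.Adamax)
+            ```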
+ """ + return torch.optim.Adamax(self.parameters(), **self.optimizer_kwargs) diff --git a/chebai/models/chemberta.py b/chebai/models/chemberta.py new file mode 100644 index 0000000..b601542 --- /dev/null +++ b/chebai/models/chemberta.py @@ -0,0 +1,77 @@ +import logging +import random +from tempfile import TemporaryDirectory + +import torch +from torch import nn +from torch.nn.functional import one_hot +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +from transformers import ( + RobertaConfig, + RobertaForMaskedLM, + RobertaModel, + RobertaTokenizer, +) + +from chebai.models.base import ChebaiBaseNet + +logging.getLogger("pysmiles").setLevel(logging.CRITICAL) +MAX_LEN = 1800 + + +class ChembertaPre(ChebaiBaseNet): + NAME = "ChembertaPre" + + def __init__(self, p=0.2, **kwargs): + super().__init__(**kwargs) + self._p = p + self.config = RobertaConfig(**kwargs["config"]) + self.model = RobertaForMaskedLM(self.config) + + def _process_batch(self, batch, batch_idx): + masked = ( + torch.rand([batch.x.shape[0]], device=self.device) + * torch.tensor(batch.lens, device=self.device) + ).long() + labels = one_hot( + torch.gather(batch.x, 1, masked.unsqueeze(-1)).squeeze(-1), + self.config.vocab_size, + ) + features = 1 + batch.x + features = features * (1 - one_hot(masked, batch.x.shape[-1])) + return features, labels + + def forward(self, data): + x = self.model(data) + return {"logits": torch.sum(x.logits, dim=1)} + + +class Chemberta(ChebaiBaseNet): + NAME = "Chemberta" + + def __init__(self, **kwargs): + # Remove this property in order to prevent it from being stored as a + # hyper parameter + pretrained_checkpoint = ( + kwargs.pop("pretrained_checkpoint") + if "pretrained_checkpoint" in kwargs + else None + ) + super().__init__(**kwargs) + self.config = RobertaConfig( + **kwargs["config"], output_attentions=True, num_labels=self.out_dim + ) + + if pretrained_checkpoint: + elpre = RobertaModel.load_from_checkpoint(pretrained_checkpoint) + with TemporaryDirectory() as td: + elpre.electra.save_pretrained(td) + self.electra = RobertaModel.from_pretrained(td, config=self.config) + in_d = elpre.config.hidden_size + else: + self.electra = RobertaModel(config=self.config) + in_d = self.config.hidden_size + + def forward(self, data): + electra = self.electra(data) + return dict(logits=electra.logits, attentions=electra.attentions) diff --git a/chebai/models/chemyk.py b/chebai/models/chemyk.py new file mode 100644 index 0000000..13bbea7 --- /dev/null +++ b/chebai/models/chemyk.py @@ -0,0 +1,63 @@ +import logging +import os +import pickle +import sys + +import networkx as nx +import torch +from torch import nn +from torch.nn import functional as F +from torch.nn.functional import pad + +from chebai.models.base import ChebaiBaseNet + +logging.getLogger("pysmiles").setLevel(logging.CRITICAL) + + +class ChemYK(ChebaiBaseNet): + NAME = "ChemYK" + + def __init__(self, in_d, out_d, num_classes, **kwargs): + super().__init__(num_classes, **kwargs) + d_internal = in_d + self.d_internal = d_internal + self.embedding = nn.Embedding(800, d_internal) + self.s = nn.Linear(d_internal, 1) + self.a_l = nn.Linear(d_internal, 1) + self.a_r = nn.Linear(d_internal, 1) + self.w_l = nn.Linear(d_internal, d_internal) + self.w_r = nn.Linear(d_internal, d_internal) + self.norm = nn.LayerNorm(d_internal) + self.output = nn.Sequential( + nn.Linear(in_d, in_d), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(in_d, num_classes), + ) + + def forward(self, data, *args, **kwargs): + m = self.embedding(data.x) + 
max_width = m.shape[1] + h = [m] # torch.zeros(emb.shape[0], max_width, *emb.shape[1:]) + # h[:, 0] = emb + for width in range(1, max_width): + l = torch.stack(tuple(h[i][:, : (max_width - width)] for i in range(width))) + r = torch.stack( + tuple(h[i][:, (width - i) :] for i in range(0, width)) + ).flip(0) + m = self.merge(l, r) + h.append(m) + return self.output(m).squeeze(1) + + def merge(self, l, r): + x = torch.stack([self.a_l(l), self.a_r(r)]) + beta = torch.softmax(x, 0) + return F.leaky_relu( + self.attention( + torch.sum(beta * torch.stack([self.w_l(l), self.w_r(r)]), dim=0) + ) + ) + + def attention(self, parts): + at = torch.softmax(self.s(parts), 1) + return torch.sum(at * parts, dim=0) diff --git a/chebai/models/electra.py b/chebai/models/electra.py new file mode 100644 index 0000000..dc6c719 --- /dev/null +++ b/chebai/models/electra.py @@ -0,0 +1,535 @@ +import logging +from math import pi +from tempfile import TemporaryDirectory +from typing import Any, Dict, Optional, Tuple + +import torch +from torch import Tensor, nn +from torch.nn.utils.rnn import pad_sequence +from transformers import ( + ElectraConfig, + ElectraForMaskedLM, + ElectraForPreTraining, + ElectraModel, +) + +from chebai.loss.pretraining import ElectraPreLoss # noqa +from chebai.models.base import ChebaiBaseNet +from chebai.preprocessing.reader import CLS_TOKEN, MASK_TOKEN_INDEX + +logging.getLogger("pysmiles").setLevel(logging.CRITICAL) + +from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss # noqa + + +class ElectraPre(ChebaiBaseNet): + """ + ElectraPre class represents an Electra model for pre-training inherited from ChebaiBaseNet. + + Args: + config (dict): Configuration parameters for the Electra model. + **kwargs: Additional keyword arguments (passed to parent class). + + Attributes: + NAME (str): Name of the ElectraPre model. + generator_config (ElectraConfig): Configuration for the generator model. + generator (ElectraForMaskedLM): Generator model for masked language modeling. + discriminator_config (ElectraConfig): Configuration for the discriminator model. + discriminator (ElectraForPreTraining): Discriminator model for pre-training. + replace_p (float): Probability of replacing tokens during training. + """ + + NAME = "ElectraPre" + + def __init__(self, config: Dict[str, Any] = None, **kwargs: Any): + super().__init__(config=config, **kwargs) + self.generator_config = ElectraConfig(**config["generator"]) + self.generator = ElectraForMaskedLM(self.generator_config) + self.discriminator_config = ElectraConfig(**config["discriminator"]) + self.discriminator = ElectraForPreTraining(self.discriminator_config) + self.replace_p = 0.1 + + @property + def as_pretrained(self) -> ElectraForPreTraining: + """ + Returns the discriminator model as a pre-trained model. + + Returns: + ElectraForPreTraining: The discriminator model. + """ + return self.discriminator + + def _process_labels_in_batch(self, batch: Dict[str, Any]) -> None: + """ + Processes the labels in the batch. + + Args: + batch (Dict[str, Any]): The input batch of data. + + Returns: + torch.Tensor: The processed labels. + """ + return None + + def forward( + self, data: Dict[str, Any], **kwargs: Any + ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]: + """ + Forward pass of the ElectraPre model. + + Args: + data (dict): Input data. + **kwargs: Additional keyword arguments. + + Returns: + tuple: A tuple containing the raw generator output and discriminator output. 
+ The generator output is a tensor of shape (batch_size, max_seq_len, vocab_size). + The discriminator output is a tensor of shape (batch_size, max_seq_len). + """ + features = data["features"] + features = features.long() + self.batch_size = batch_size = features.shape[0] + max_seq_len = features.shape[1] + + mask = kwargs["mask"] + with torch.no_grad(): + dis_tar = ( + torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1) + ).int() + disc_tar_one_hot = torch.eq( + torch.arange(max_seq_len, device=self.device)[None, :], dis_tar[:, None] + ) + gen_tar = features[disc_tar_one_hot] + gen_tar_one_hot = torch.eq( + torch.arange(self.generator_config.vocab_size, device=self.device)[ + None, : + ], + gen_tar[:, None], + ) + + raw_gen_out = torch.mean( + self.generator( + (features * ~disc_tar_one_hot) + MASK_TOKEN_INDEX * disc_tar_one_hot, + attention_mask=mask, + ).logits, + dim=1, + ) + + with torch.no_grad(): + gen_best_guess = raw_gen_out.argmax(dim=-1) + correct_mask = features[disc_tar_one_hot] == gen_best_guess + random_tokens = torch.randint( + self.generator_config.vocab_size, (batch_size,), device=self.device + ) + replacements = gen_best_guess * ~correct_mask + random_tokens * correct_mask + + disc_out = self.discriminator( + features * ~disc_tar_one_hot + replacements[:, None] * disc_tar_one_hot, + attention_mask=mask, + ).logits + return (raw_gen_out, disc_out), (gen_tar_one_hot, disc_tar_one_hot) + + def _get_prediction_and_labels( + self, batch: Dict[str, Any], labels: Tensor, output: Tensor + ) -> Tuple[Tensor, Tensor]: + """ + Gets the predictions and labels from the model output. + + Args: + data (Dict[str, Any]): The processed batch data. + labels (torch.Tensor): The true labels. + output (torch.Tensor): The model output. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: Predictions and labels. + """ + return torch.softmax(output[0][1], dim=-1), output[1][1].int() + + +def filter_dict(d: Dict[str, Any], filter_key: str) -> Dict[str, Any]: + """ + Filters a dictionary by a given key prefix. + + Args: + d (dict): The dictionary to filter. + filter_key (str): The key prefix to filter by. + + Returns: + dict: A dictionary containing only the key-value pairs where the key starts with the given prefix. + """ + return { + str(k)[len(filter_key) :]: v + for k, v in d.items() + if str(k).startswith(filter_key) + } + + +class Electra(ChebaiBaseNet): + """ + Electra model implementation inherited from ChebaiBaseNet. + + Args: + config (Dict[str, Any], optional): Configuration parameters for the Electra model. Defaults to None. + pretrained_checkpoint (str, optional): Path to the pretrained checkpoint file. Defaults to None. + load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None. + **kwargs: Additional keyword arguments. + + Attributes: + NAME (str): Name of the Electra model. + """ + + NAME = "Electra" + + def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]: + """ + Process a batch of data. + + Args: + batch (Dict[str, Any]): The input batch of data. + batch_idx (int): The index of the batch (not used). + + Returns: + dict: A dictionary containing the processed batch, keys are `features`, `labels`, `model_kwargs`, + `loss_kwargs` and `idents`. 
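+
+        Note:
+            A ``CLS_TOKEN`` is prepended to every sequence in ``features``; if ``lens`` is
+            present in the batch's ``model_kwargs``, an ``attention_mask`` covering each
+            sequence plus the CLS position is derived from it.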
+ """ + model_kwargs = dict() + loss_kwargs = batch.additional_fields["loss_kwargs"] + if "lens" in batch.additional_fields["model_kwargs"]: + model_kwargs["attention_mask"] = pad_sequence( + [ + torch.ones(l + 1, device=self.device) + for l in batch.additional_fields["model_kwargs"]["lens"] + ], + batch_first=True, + ) + cls_tokens = ( + torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( + -1 + ) + * CLS_TOKEN + ) + return dict( + features=torch.cat((cls_tokens, batch.x), dim=1), + labels=batch.y, + model_kwargs=model_kwargs, + loss_kwargs=loss_kwargs, + idents=batch.additional_fields["idents"], + ) + + @property + def as_pretrained(self) -> ElectraModel: + """ + Get the pretrained Electra model. + + Returns: + ElectraModel: The pretrained Electra model. + """ + return self.electra.electra + + def __init__( + self, + config: Optional[Dict[str, Any]] = None, + pretrained_checkpoint: Optional[str] = None, + load_prefix: Optional[str] = None, + **kwargs: Any, + ): + # Remove this property in order to prevent it from being stored as a + # hyper parameter + + super().__init__(**kwargs) + if config is None: + config = dict() + if not "num_labels" in config and self.out_dim is not None: + config["num_labels"] = self.out_dim + self.config = ElectraConfig(**config, output_attentions=True) + self.word_dropout = nn.Dropout(config.get("word_dropout", 0)) + + in_d = self.config.hidden_size + self.output = nn.Sequential( + nn.Dropout(self.config.hidden_dropout_prob), + nn.Linear(in_d, in_d), + nn.GELU(), + nn.Dropout(self.config.hidden_dropout_prob), + nn.Linear(in_d, self.config.num_labels), + ) + + # Load pretrained checkpoint if provided + if pretrained_checkpoint: + with open(pretrained_checkpoint, "rb") as fin: + model_dict = torch.load( + fin, map_location=self.device, weights_only=False + ) + if load_prefix: + state_dict = filter_dict(model_dict["state_dict"], load_prefix) + else: + state_dict = model_dict["state_dict"] + self.electra = ElectraModel.from_pretrained( + None, state_dict=state_dict, config=self.config + ) + else: + self.electra = ElectraModel(config=self.config) + + def _process_for_loss( + self, + model_output: Dict[str, Tensor], + labels: Tensor, + loss_kwargs: Dict[str, Any], + ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: + """ + Process the model output for calculating the loss. + + Args: + model_output (Dict[str, Tensor]): The output of the model. + labels (Tensor): The target labels. + loss_kwargs (Dict[str, Any]): Additional loss arguments. + + Returns: + tuple: A tuple containing the processed model output, labels, and loss arguments. + """ + kwargs_copy = dict(loss_kwargs) + if labels is not None: + labels = labels.float() + return model_output["logits"], labels, kwargs_copy + + def _get_prediction_and_labels( + self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor] + ) -> Tuple[Tensor, Tensor]: + """ + Get the predictions and labels from the model output. Applies a sigmoid to the model output. + + Args: + data (Dict[str, Any]): The input data. + labels (Tensor): The target labels. + model_output (Dict[str, Tensor]): The output of the model. + + Returns: + tuple: A tuple containing the predictions and labels. 
+ """ + d = model_output["logits"] + loss_kwargs = data.get("loss_kwargs", dict()) + if "non_null_labels" in loss_kwargs: + n = loss_kwargs["non_null_labels"] + d = d[n] + return torch.sigmoid(d), labels.int() if labels is not None else None + + def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: + """ + Forward pass of the Electra model. + + Args: + data (Dict[str, Tensor]): The input data (expects a key `features`). + **kwargs: Additional keyword arguments for `self.electra`. + + Returns: + dict: A dictionary containing the model output (logits and attentions). + """ + self.batch_size = data["features"].shape[0] + try: + inp = self.electra.embeddings.forward(data["features"].int()) + except RuntimeError as e: + print(f"RuntimeError at forward: {e}") + print(f'data[features]: {data["features"]}') + raise e + inp = self.word_dropout(inp) + electra = self.electra(inputs_embeds=inp, **kwargs) + d = electra.last_hidden_state[:, 0, :] + return dict( + logits=self.output(d), + attentions=electra.attentions, + ) + + +class ElectraLegacy(ChebaiBaseNet): + NAME = "ElectraLeg" + + def __init__(self, **kwargs): + super().__init__(**kwargs) + self.config = ElectraConfig(**kwargs["config"], output_attentions=True) + + if "pretrained_checkpoint" in kwargs: + elpre = ElectraPre.load_from_checkpoint(kwargs["pretrained_checkpoint"]) + with TemporaryDirectory() as td: + elpre.electra.save_pretrained(td) + self.electra = ElectraModel.from_pretrained(td, config=self.config) + in_d = elpre.config.hidden_size + else: + self.electra = ElectraModel(config=self.config) + in_d = self.config.hidden_size + + self.output = nn.Sequential( + nn.Linear(in_d, in_d), + nn.ReLU(), + nn.Linear(in_d, in_d), + nn.ReLU(), + nn.Linear(in_d, in_d), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(in_d, 500), + ) + + def forward(self, data): + electra = self.electra(data) + d = torch.sum(electra.last_hidden_state, dim=1) + return dict(logits=self.output(d), attentions=electra.attentions) + + +class ConeElectra(ChebaiBaseNet): + NAME = "ConeElectra" + + def _process_batch(self, batch, batch_idx): + mask = pad_sequence( + [torch.ones(l + 1, device=self.device) for l in batch.lens], + batch_first=True, + ) + cls_tokens = ( + torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( + -1 + ) + * CLS_TOKEN + ) + return dict( + features=torch.cat((cls_tokens, batch.x), dim=1), + labels=batch.y, + model_kwargs=dict(attention_mask=mask), + ) + + @property + def as_pretrained(self): + return self.electra.electra + + def __init__(self, cone_dimensions=20, **kwargs): + # Remove this property in order to prevent it from being stored as a + # hyper parameter + pretrained_checkpoint = ( + kwargs.pop("pretrained_checkpoint") + if "pretrained_checkpoint" in kwargs + else None + ) + + self.cone_dimensions = cone_dimensions + + super().__init__(**kwargs) + if not "num_labels" in kwargs["config"] and self.out_dim is not None: + kwargs["config"]["num_labels"] = self.out_dim + self.config = ElectraConfig(**kwargs["config"], output_attentions=True) + self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0)) + model_prefix = kwargs.get("load_prefix", None) + if pretrained_checkpoint: + with open(pretrained_checkpoint, "rb") as fin: + model_dict = torch.load( + fin, map_location=self.device, weights_only=False + ) + if model_prefix: + state_dict = { + str(k)[len(model_prefix) :]: v + for k, v in model_dict["state_dict"].items() + if str(k).startswith(model_prefix) + } + else: + state_dict = 
model_dict["state_dict"] + self.electra = ElectraModel.from_pretrained( + None, state_dict=state_dict, config=self.config + ) + else: + self.electra = ElectraModel(config=self.config) + + in_d = self.config.hidden_size + + self.line_embedding = nn.Sequential( + nn.Dropout(self.config.hidden_dropout_prob), + nn.Linear(in_d, in_d), + nn.GELU(), + nn.Dropout(self.config.hidden_dropout_prob), + nn.Linear(in_d, self.cone_dimensions), + ) + + self.cone_axes = nn.Parameter( + 2 * pi * torch.rand((1, self.config.num_labels, self.cone_dimensions)) + ) + self.cone_arcs = nn.Parameter( + pi * (1 - 2 * torch.rand((1, self.config.num_labels, self.cone_dimensions))) + ) + + def _get_data_for_loss(self, model_output, labels): + d = model_output["predicted_vectors"] + return dict( + input=dict( + predicted_vectors=d, cone_axes=self.cone_axes, cone_arcs=self.cone_arcs + ), + target=labels.float(), + ) + + def _get_prediction_and_labels(self, data, labels, model_output): + d = model_output["predicted_vectors"].unsqueeze(1) + + d = in_cone_parts(d, self.cone_axes, self.cone_arcs) + + return torch.mean(d, dim=-1), labels.int() + + def forward(self, data, **kwargs): + self.batch_size = data["features"].shape[0] + inp = self.electra.embeddings.forward(data["features"]) + inp = self.word_dropout(inp) + electra = self.electra(inputs_embeds=inp, **kwargs) + d = electra.last_hidden_state[:, 0, :] + return dict( + predicted_vectors=self.line_embedding(d), + attentions=electra.attentions, + ) + + +def softabs(x, eps=0.01): + return (x**2 + eps) ** 0.5 - eps**0.5 + + +def anglify(x): + return torch.tanh(x) * pi + + +def turn(vector, angle): + v = vector - angle + return v - (v > pi) * 2 * pi + (v < -pi) * 2 * pi + + +def in_cone_parts(vectors, cone_axes, cone_arcs): + """ + # trap between -pi and pi + cone_ax_ang = anglify(cone_axes) + v = anglify(vectors) + + # trap between 0 and pi + cone_arc_ang = (torch.tanh(cone_arcs)+1)*pi/2 + theta_L = cone_ax_ang - cone_arc_ang/2 + #theta_L = theta_L - (theta_L > 2*pi) * 2 * pi + (theta_L < 0) *2*pi + theta_R = cone_ax_ang + cone_arc_ang/2 + #theta_R = theta_R - (theta_R > 2 * pi) * 2 * pi + (theta_R < 0) * 2 * pi + dis = (torch.abs(turn(v, theta_L)) + torch.abs(turn(v, theta_R)) - cone_arc_ang)/(2*pi-cone_arc_ang) + return dis + """ + a = cone_axes - cone_arcs**2 + b = cone_axes + cone_arcs**2 + bigger_than_a = torch.sigmoid(vectors - a) + smaller_than_b = torch.sigmoid(b - vectors) + return bigger_than_a * smaller_than_b + + +class ConeLoss: + def __init__(self, center_scaling=0.1): + self.center_scaling = center_scaling + + def negate(self, ax, arc): + offset = pi * torch.ones_like(ax) + offset[ax >= 0] *= -1 + return ax + offset, pi - arc + + def __call__(self, target, input): + predicted_vectors = input["predicted_vectors"].unsqueeze(1) + cone_axes = input["cone_axes"] + cone_arcs = input["cone_arcs"] + memberships = (1 - 1e-6) * ( + in_cone_parts(predicted_vectors, cone_axes, cone_arcs) + ) + loss = torch.nn.functional.binary_cross_entropy( + memberships, target.unsqueeze(-1).expand(-1, -1, 20) + ) + return loss diff --git a/chebai/models/external/__init__.py b/chebai/models/external/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py new file mode 100644 index 0000000..c9c6f91 --- /dev/null +++ b/chebai/models/ffn.py @@ -0,0 +1,153 @@ +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import Tensor, nn + +from chebai.models import ChebaiBaseNet + + +class FFN(ChebaiBaseNet): + # 
Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139 + + NAME = "FFN" + + def __init__( + self, + input_size: int, + hidden_layers: List[int] = [ + 1024, + ], + **kwargs + ): + super().__init__(**kwargs) + + layers = [] + current_layer_input_size = input_size + for hidden_dim in hidden_layers: + layers.append(MLPBlock(current_layer_input_size, hidden_dim)) + layers.append(Residual(MLPBlock(hidden_dim, hidden_dim))) + current_layer_input_size = hidden_dim + + layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim)) + layers.append(nn.Sigmoid()) + self.model = nn.Sequential(*layers) + + def _get_prediction_and_labels(self, data, labels, model_output): + d = model_output["logits"] + loss_kwargs = data.get("loss_kwargs", dict()) + if "non_null_labels" in loss_kwargs: + n = loss_kwargs["non_null_labels"] + d = d[n] + return torch.sigmoid(d), labels.int() if labels is not None else None + + def _process_for_loss( + self, + model_output: Dict[str, Tensor], + labels: Tensor, + loss_kwargs: Dict[str, Any], + ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: + """ + Process the model output for calculating the loss. + + Args: + model_output (Dict[str, Tensor]): The output of the model. + labels (Tensor): The target labels. + loss_kwargs (Dict[str, Any]): Additional loss arguments. + + Returns: + tuple: A tuple containing the processed model output, labels, and loss arguments. + """ + kwargs_copy = dict(loss_kwargs) + if labels is not None: + labels = labels.float() + return model_output["logits"], labels, kwargs_copy + + def forward(self, data, **kwargs): + x = data["features"] + return {"logits": self.model(x)} + + +class Residual(nn.Module): + """ + A residual layer that adds the output of a function to its input. + + Args: + fn (nn.Module): The function to be applied to the input. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L6-L35 + """ + + def __init__(self, fn): + """ + Initialize the Residual layer with a given function. + + Args: + fn (nn.Module): The function to be applied to the input. + """ + super().__init__() + self.fn = fn + + def forward(self, x): + """ + Forward pass of the Residual layer. + + Args: + x: Input tensor. + + Returns: + torch.Tensor: The input tensor added to the result of applying the function `fn` to it. + """ + return x + self.fn(x) + + +class MLPBlock(nn.Module): + """ + A basic Multi-Layer Perceptron (MLP) block with one fully connected layer. + + Args: + in_features (int): The number of input features. + output_size (int): The number of output features. + bias (boolean): Add bias to the linear layer + layer_norm (boolean): Apply layer normalization + dropout (float): The dropout value + activation (nn.Module): The activation function to be applied after each fully connected layer. 
+
+    References:
+        https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L38-L73
+
+    Example:
+        ```python
+        # Create an MLP block (Linear -> ReLU -> LayerNorm -> Dropout)
+        mlp_block = MLPBlock(in_features=64, out_features=10, activation=nn.ReLU)
+
+        # Apply the MLP block to an input tensor
+        input_tensor = torch.randn(32, 64)
+        output = mlp_block(input_tensor)  # shape: (32, 10)
+        ```
+    """
+
+    def __init__(
+        self,
+        in_features,
+        out_features,
+        bias=True,
+        layer_norm=True,
+        dropout=0.1,
+        activation=nn.ReLU,
+    ):
+        super().__init__()
+        self.linear = nn.Linear(in_features, out_features, bias)
+        self.activation = activation()
+        self.layer_norm: Optional[nn.LayerNorm] = (
+            nn.LayerNorm(out_features) if layer_norm else None
+        )
+        self.dropout: Optional[nn.Dropout] = nn.Dropout(dropout) if dropout else None
+
+    def forward(self, x):
+        x = self.activation(self.linear(x))
+        if self.layer_norm:
+            x = self.layer_norm(x)
+        if self.dropout:
+            x = self.dropout(x)
+        return x
diff --git a/chebai/models/lnn_model.py b/chebai/models/lnn_model.py
new file mode 100644
index 0000000..3d61c5a
--- /dev/null
+++ b/chebai/models/lnn_model.py
@@ -0,0 +1,40 @@
+import fastobo
+import pyhornedowl
+import tqdm
+from lnn import Implies, Model, Not, Predicate, Variable, World
+from owlready2 import get_ontology
+
+
+def get_name(iri: str):
+    return iri.split("/")[-1]
+
+
+if __name__ == "__main__":
+    formulae = []
+
+    # Load disjointness axioms
+    # onto_dis = pyhornedowl.open_ontology("/data/ontologies/chebi-disjoints.owl")
+    # print("Process disjointness relation")
+    # formulae += [Implies(predicates[get_name(c)](x), Not(predicates[get_name(d)](x))) for _, c,d in (ax for ax in onto_dis.get_axioms() if ax[0] == "AxiomKind::SubClassOf" and isinstance(ax[-1], str))]
+
+    model = Model()
+    x = Variable("x")
+    y = Variable("y")
+
+    onto = pyhornedowl.open_ontology("/data/ontologies/chebi.owl")
+
+    print("Process classes")
+    predicates = {get_name(c): Predicate(get_name(c)) for c in onto.get_classes()}
+
+    print("Process subsumption relation")
+    formulae += [
+        Implies(predicates[get_name(c)](x), predicates[get_name(d)](x))
+        for _, c, d in (
+            ax
+            for ax in onto.get_axioms()
+            if ax[0] == "AxiomKind::SubClassOf" and isinstance(ax[-1], str)
+        )
+    ]
+
+    model.add_knowledge(*formulae, world=World.AXIOM)
+    model.print()
diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py
new file mode 100644
index 0000000..c706d6a
--- /dev/null
+++ b/chebai/models/lstm.py
@@ -0,0 +1,34 @@
+import logging
+import sys
+
+from torch import nn
+from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
+
+from chebai.models.base import ChebaiBaseNet
+
+logging.getLogger("pysmiles").setLevel(logging.CRITICAL)
+
+
+class ChemLSTM(ChebaiBaseNet):
+    NAME = "LSTM"
+
+    def __init__(self, in_d, out_d, num_classes, **kwargs):
+        super().__init__(num_classes, **kwargs)
+        self.lstm = nn.LSTM(in_d, out_d, batch_first=True)
+        self.embedding = nn.Embedding(800, 100)
+        self.output = nn.Sequential(
+            nn.Linear(out_d, in_d),
+            nn.ReLU(),
+            nn.Dropout(0.2),
+            nn.Linear(in_d, num_classes),
+        )
+
+    def forward(self, data):
+        x = data.x
+        x_lens = data.lens
+        x = self.embedding(x)
+        x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False)
+        x = self.lstm(x)[1][0]
+        # = pad_packed_sequence(x, batch_first=True)[0]
+        x = self.output(x)
+        return x.squeeze(0)
diff --git a/chebai/models/recursive.py b/chebai/models/recursive.py
new file mode 100644
index 0000000..fb40803
--- /dev/null
+++ b/chebai/models/recursive.py
@@ -0,0 +1,97 @@ +import logging + +import networkx as nx +import torch +import torch.nn.functional as F +from torch import exp, nn, tensor + +from chebai.models.base import ChebaiBaseNet + +logging.getLogger("pysmiles").setLevel(logging.CRITICAL) + + +class Recursive(ChebaiBaseNet): + NAME = "REC" + + def __init__(self, in_d, out_d, num_classes, **kwargs): + super().__init__(num_classes, **kwargs) + mem_len = in_d + self.internal_dimension = in_d + self.embedding = nn.Embedding(800, 100) + + self.input_post = nn.Linear(in_d, in_d) + + self.input_attention = nn.MultiheadAttention(in_d, 5) + self.hidden_attention = nn.MultiheadAttention(in_d, 5) + self.merge_attention = nn.MultiheadAttention(in_d, 5) + + self.hidden_post = nn.Linear(in_d, in_d) + + self.merge_post = nn.Linear(in_d, in_d) + + self.post = nn.Linear(in_d, in_d) + + self.children_attention = nn.MultiheadAttention(in_d, 5) + + self.input_norm_1 = nn.LayerNorm(in_d) + self.input_norm_2 = nn.LayerNorm(in_d) + self.hidden_norm_1 = nn.LayerNorm(in_d) + self.merge_norm_1 = nn.LayerNorm(in_d) + self.merge_norm_2 = nn.LayerNorm(in_d) + + self.base = torch.nn.parameter.Parameter(torch.empty((in_d,))) + self.base_memory = torch.nn.parameter.Parameter(torch.empty((mem_len,))) + self.output = nn.Sequential( + nn.Linear(in_d, in_d), + nn.ReLU(), + nn.Dropout(0.2), + nn.Linear(in_d, num_classes), + ) + + def forward(self, batch): + result = [] + for row in batch: + graph = row[0] + c = nx.center(graph)[0] + d = nx.single_source_shortest_path(graph, c) + if graph.edges: + digraph = nx.DiGraph( + (a, b) if d[a] > d[b] else (b, a) for (a, b) in graph.edges + ) + else: + digraph = nx.DiGraph(graph) + child_results = {} + x = None + for node in nx.topological_sort(digraph): + child_values = child_results.pop(node, []) + inp = self.embedding(graph.nodes[node]["x"]) + if not child_values: + hidden_state = self.base_memory + else: + hidden_state = self.merge_childen(child_values, inp) + x = self.input(inp, hidden_state) + for s in digraph.successors(node): + child_results[s] = child_results.get(s, []) + [x] + result.append(self.output(x)) + return torch.stack(result) + + def merge_childen(self, child_values, x): + stack = torch.stack(child_values).unsqueeze(0).transpose(1, 0) + att = self.children_attention( + x.expand(1, stack.shape[1], -1).transpose(1, 0), stack, stack + )[0] + return torch.sum(att.squeeze(0), dim=0) + + def input(self, x0, hidden): + x = x0.unsqueeze(0).unsqueeze(0) + a = self.input_norm_1(x + self.input_attention(x, x, x)[0]) + a = self.input_norm_2(a + F.relu(self.input_post(a))) + + h0 = hidden.unsqueeze(0).unsqueeze(0) + b = self.hidden_norm_1(h0 + self.input_attention(h0, h0, h0)[0]) + # b = self.norm(b + self.hidden_post(b)) + + c = self.merge_norm_1(b + self.merge_attention(a, b, b)[0]) + c = self.merge_norm_2(c + F.relu(self.merge_post(c))) + + return self.post(c).squeeze(0).squeeze(0) diff --git a/chebai/models/strontex.py b/chebai/models/strontex.py new file mode 100644 index 0000000..c22e72c --- /dev/null +++ b/chebai/models/strontex.py @@ -0,0 +1,14 @@ +import abc +import typing + +import networkx as nx +import numpy as np +import torch + +FeatureType = typing.TypeVar("FeatureType") +LabelType = typing.TypeVar("LabelType") + + +class StrOntEx(torch.Module): + def __init__(self, computation_graph): + pass diff --git a/chebai/preprocessing/__init__.py b/chebai/preprocessing/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt 
b/chebai/preprocessing/bin/protein_token/tokens.txt new file mode 100644 index 0000000..c31c5b7 --- /dev/null +++ b/chebai/preprocessing/bin/protein_token/tokens.txt @@ -0,0 +1,21 @@ +M +S +I +G +A +T +R +L +Q +N +D +K +Y +P +C +F +W +E +V +H +X diff --git a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt new file mode 100644 index 0000000..534e5db --- /dev/null +++ b/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt @@ -0,0 +1,8359 @@ +MAT +ATP +TPG +PGA +GAS +ASS +SSA +SAR +ARD +RDE +DEF +EFV +FVY +VYM +YMA +MAK +AKL +KLA +LAE +AEQ +EQA +QAE +AER +ERY +RYE +YEE +EEM +EMV +MVE +VEF +EFM +FME +MEK +EKV +KVA +VAK +AKA +KAV +AVD +VDK +DKD +KDE +DEL +ELT +LTV +TVE +VEE +EER +ERN +RNL +NLL +LLS +LSV +SVA +VAY +AYK +YKN +KNV +NVI +VIG +IGA +GAR +ARR +RRA +RAS +ASW +SWR +WRI +RII +IIS +ISS +SSI +SIE +IEQ +EQK +QKE +KEE +EES +ESR +SRG +RGN +GND +NDD +DDH +DHV +HVS +VSL +SLI +LIR +IRD +RDY +DYR +YRS +RSK +SKI +KIE +IET +ETE +TEL +ELS +LSD +SDI +DIC +ICD +CDG +DGI +GIL +ILK +LKL +KLL +LLD +LDT +DTI +TIL +ILV +LVP +VPA +PAA +AAA +AAS +ASG +SGD +GDS +DSK +SKV +KVF +VFY +FYL +YLK +LKM +KMK +MKG +KGD +GDY +DYH +YHR +HRY +RYL +YLA +AEF +EFK +FKS +KSG +SGQ +GQE +QER +ERK +RKD +KDA +DAA +AAE +AEH +EHT +HTL +TLT +LTA +TAY +YKA +KAA +AAQ +AQD +QDI +DIA +IAN +ANS +NSE +SEL +ELA +LAP +APT +PTH +THP +HPI +PIR +IRL +RLG +LGL +GLA +LAL +ALN +LNF +NFS +FSV +SVF +FYY +YYE +YEI +EIL +ILN +LNS +NSP +SPD +PDR +DRA +RAC +ACN +CNL +NLA +LAK +AKQ +KQA +QAF +AFD +FDE +DEA +EAI +AIA +IAE +AEL +ELD +DTL +TLG +LGE +GEE +ESY +SYK +YKD +KDS +DST +STL +TLI +LIM +IMQ +MQL +QLL +LLR +LRD +RDN +DNL +NLT +LTL +TLW +LWT +WTS +TSD +SDM +DMQ +MQD +QDD +DDV +DVA +VAD +ADD +DDI +DIK +IKE +KEA +EAA +AAP +APA +AAK +AKP +KPA +PAD +ADE +DEQ +EQQ +QQS +MSD +SDT +DTV +EEL +ELV +LVQ +VQR +QRA +RAK +RYD +YDD +DDM +DMA +MAA +AAM +AMK +MKK +KKV +KVT +VTE +TEQ +EQG +QGQ +QEL +LSN +SNE +NEE +NVV +VVG +VGA +RRS +RSS +SSW +WRV +RVI +VIS +QKT +KTE +TEG +EGS +GSE +SEK +EKK +KKQ +KQQ +QQL +QLA +AKE +KEY +EYR +YRV +RVK +VKV +KVE +VEQ +EQE +ELN +LND +NDI +ICQ +CQD +QDV +DVL +VLK +LDE +EFL +FLI +LIV +IVK +VKA +KAG +AGA +GAA +AES +ESK +DYY +YYR +YRY +AEV +EVA +VAS +ASE +SED +EDR +RAA +AAV +AVV +VVE +VEK +EKS +KSQ +SQK +QKA +KAY +AYQ +YQE +QEA +EAL +ALD +LDI +IAK +AKD +KDK +DKM +KMQ +MQP +QPT +LNT +NTP +TPE +PEH +EHA +HAC +ACQ +CQL +FDD +DDA +DAI +TLN +LNE +NED +EDS +DSY +SDV +DVG +GAE +AED +EDQ +DQE +QEQ +QEG +EGN +GNQ +NQE +EAG +AGN +MAS +ASA +SAE +LSR +SRE +REE +EEN +ENV +NVY +AKT +KTV +TVD +VDS +DSE +SEE +EEG +EGR +GRG +GNE +DRV +RVT +VTL +LIK +IKD +KDY +YRG +RGK +GKI +LTK +TKI +KIC +LLE +LET +ETH +THL +HLV +VPS +PSS +SST +STA +TAP +APE +PES +FKT +KTG +TGA +AEN +ENT +NTM +TMV +MVA +IAL +ALA +ACS +CSL +SLA +AIS +ISE +TLS +LSE +DIS +EDP +DPA +PAE +AEE +EEI +EIR +IRE +REA +EAP +APK +PKR +KRD +RDS +DSS +SSE +SEG +EGQ +LES +ESH +SHL +LLH +LHD +HDN +PKH +KHD +HDL +DLS +MST +STR +TRE +VDV +DVE +SVE +SKG +KGN +EDH +HVA +VAI +AII +IIK +IES +ESE +LSK +LNV +NVL +VLE +LEA +EAH +AHL +HLI +LIP +IPS +PSA +SAS +ASP +SPA +FKA +RKE +EST +TLV +LVA +YKS +KSA +ASD +IAT +ATA +TAE +DMT +MTD +TDE +AGD +GDE +DEI +EIK +EAS +ASK +SKP +KPD +PDG +DGA +MAE +RED +EDC +DCV +CVF +VFL +FLS +SKL +EQS +QSE +SER +YDE +DEM +MVQ +VQY +QYM +YMK +MKQ +KQV +QVA +VAA +AAL +NTE +IGS +GSR +SRR +IIT +ITS +TSL +SLE +LEQ +KEQ +QAK +AKG +NDK +DKH +KHV +HVE +VEI +EII +IKG +KGY +GYR +YRA +AKI +IED +EDE +AKY +KYC +YCD +CDD +LKV +KVI +VIK +KEN +ENL +LLP +LPN +PNA +NAS +AST +STS +TSE +SES +FYK +YKK +KKM +KME +MEG +EGD +RYY +YYA +YAE 
+EFT +FTV +VDE +DEK +EKR +KRQ +RQE +QEV +ADK +DKS +KSL +LAA +AAY +AYT +YTE +TEA +EAT +ATE +TEI +EIS +ISN +SNA +NAD +ADL +DLA +EIM +IMN +MND +NDA +DAD +DKA +KAC +DDS +DSI +SIA +KLD +DEV +EVP +VPE +ESS +SSY +DTA +TAD +DEE +AAT +ATL +LGR +GRD +RDQ +DQY +QYV +YVY +VQF +QFM +MEQ +EQL +QLV +LVT +VTG +GAT +TPA +GSL +SLR +LRA +AAW +AWR +RIV +IVS +VSS +SRK +RKN +KND +NDE +DEH +EHV +SLV +LVK +VKD +VES +LSS +SSV +SVC +VCS +CSG +SGI +LDS +DSH +SAG +RYM +DER +RKT +KTA +TAA +EDT +DTM +TML +MLA +LAY +IAA +AAD +ADM +MAP +NSS +SSD +SDK +CNM +NMA +AFE +FEE +EEA +MQE +EQM +QMD +MDE +ATT +TTL +SRD +LVS +VSG +SGA +PAG +AGE +GEL +KNE +EEH +VET +SIC +ICS +ILR +LRL +RLL +SAT +TAS +TMI +MIA +IAY +VAV +AVA +EKA +CSM +SMA +MTM +TMD +MDK +KSE +VQK +KAK +MKA +AVT +QGH +GHE +HEL +TER +RNE +NEK +QQM +QMG +MGK +GKE +YRE +REK +EKI +IEA +EAE +ELQ +LQD +ICN +CND +NDV +LEL +ELL +LDK +DKY +KYL +YLI +IPN +NAT +ATQ +TQP +QPE +DYF +YFR +FRY +YLS +SEV +GDN +DNK +NKQ +KQT +QTT +TTV +TVS +VSN +SNS +NSQ +SQQ +QQA +QAY +EAF +FEI +ISK +SKK +KKE +KEM +EMQ +SPE +PEK +TAF +SEN +ENQ +NQG +QGD +DEG +GDA +DAG +GEG +EGE +GEN +LIL +LNA +TQA +SGE +ENK +CSD +ATH +THA +HAE +MTE +ERE +REN +ENN +NNV +VYK +VEA +EAM +ASM +SMD +MDV +VEL +TSI +NKG +KGA +EEK +EKL +KLE +LEM +EMI +MIK +IKT +KTY +TYR +RGQ +GQV +QVE +EKE +KEL +ELR +RDI +DIL +LEK +EKH +KHL +IPC +PCA +CAT +ATS +TSG +GES +YYK +YKM +EFA +FAT +ATG +TGS +GSD +SDR +DRK +ENS +NSL +LIA +IAM +AMN +NDL +DLP +LPP +PPT +ACR +CRL +RLA +AAF +MQA +EEV +EVD +VDP +DPN +NAG +GDG +DGE +GEP +EPK +PKE +EQI +QIQ +IQD +VED +DQD +DVS +MDD +DDR +DRE +EDL +DLV +LVY +VYQ +YQA +ESM +SMK +VAG +AGM +GMD +KGG +GGE +GED +EDK +DKL +KLK +KMI +MIR +REY +YRQ +RQM +QMV +ELK +KLI +LIC +ICC +CCD +CDI +ILD +LDV +VLD +IPA +AAN +ANT +NTG +TGE +TGN +NDR +AMT +ELP +MQG +EEQ +EQN +QNK +NKE +ALQ +DEN +MGD +GDR +REQ +LLQ +LQR +RAR +ARL +SAM +NEP +EPL +PLS +DRN +KTM +TMA +MAD +ADG +DGN +KKL +KVK +AYR +IEK +ELE +ETV +TVC +VCN +VLS +LSL +SLL +DKF +KFL +IKN +KNC +NCN +NDF +DFQ +FQY +QYE +YES +GEK +KKN +KNS +NSV +SVV +SEA +YKE +SKE +QMQ +EIQ +IQN +QNA +NAP +PEQ +QAC +ACL +CLL +LLA +SDQ +DQQ +QQD +QDE +VLA +ALL +KEH +EHM +HMQ +MVD +VDR +KAR +MKN +NVT +KTS +TSA +SAD +KKI +IEM +MVR +VRA +RAY +EAV +AVC +VCQ +LDN +DNY +NYL +NCS +CSE +SET +ETQ +TQY +VAT +KRA +RAT +ATV +TVV +AYS +YSE +AHE +HEI +LNY +NYS +YSV +ACH +CHL +HLA +DDD +DDG +DGG +GNN +MER +ERA +ASL +LIQ +IQK +YED +EDM +AFM +FMK +MKS +SAV +AVE +EKG +KGE +LSC +SCE +CEE +VGG +GGQ +GQR +RVL +QKS +KSN +KGP +GPE +PEV +EVK +VKE +LRG +RGV +GVC +VCD +CDT +TVL +VLG +GLL +GAG +DAE +SRV +RVF +TGD +GDD +DDK +DKK +KKR +KRI +IID +IDS +DSA +ARS +RSA +SAY +AMD +MDI +EMP +MPP +PTN +TNP +NPI +VFH +FHY +HYE +EIA +PEE +ISL +KTT +TTF +TFD +AMA +DLH +LHT +WTA +ADS +EGG +GEA +EEP +EPQ +PQS +EKT +ELI +ATC +TCM +CMK +QGA +GGR +GRR +SAW +KTD +TDT +DTS +KLQ +LQL +QLI +LRS +RSI +ICT +CTT +ANA +ATN +NPE +VAC +ACG +CGD +RKQ +QTI +TID +IDN +DNS +SQG +GAY +FDI +LNN +NNP +PEL +LAC +ACT +CTL +TLA +SDS +EEC +ECD +CDA +AEG +EGA +TIE +IEN +STV +DKE +MAQ +AQA +QAM +KSV +SVT +TET +ETG +TGV +GVE +ARK +LAR +ARE +RER +ERV +RVE +LRE +REI +EIC +ICY +CYE +YEV +EVL +IPK +PKA +KAS +ASN +SNP +DAR +ARN +RNT +NTV +VVD +VDD +DSQ +SQT +QTA +YQD +QDA +DAF +KGK +GKM +PDK +DTQ +TQG +AEP +PQE +GGD +DKN +NEL +AAC +ACM +RVV +VVS +AEK +QMA +MAR +EKF +ASQ +SQA +AAG +KKG +KGI +GIV +IVD +VDQ +DQS +QSQ +AEA +SQP +MPA +PAS +ASR +DSV +SVY +VYL +VEN +ENM +NMK +SSG +EAK +NES +ESQ +SQV +VAL +ALI +ICE +CED +EDI +ILS +SVL +SDH +DHL +LIT +SAQ +AQT +QTG +FAI +KRK +EAY +DAV +DLE +ETL +WTD +TDL +TEE +QQQ +QSS +SSQ +QAP +AQP +PTE +EGK +GKA +KAD +ADQ +MTR +VAE +NEN 
+ENH +NHV +HVK +VKK +KIK +EYK +YKC +KCK +CKV +LTD +TDI +ILE +LEV +GNP +NPR +PRK +SSL +IAV +DVH +VHN +HNM +NME +EKN +KNQ +NQD +QDG +DGD +DDQ +DQN +QNE +EPG +PGM +AFT +FTR +EDY +DYV +YVF +VFM +FMA +AQL +QLN +ENA +NAE +ETM +TMR +MRK +RKI +KIS +ISG +SGM +GME +KER +IGP +GPR +PRR +KEK +KGR +GRQ +RQK +QKP +KPN +NAK +AKR +RIE +QIR +IRV +RVY +VYR +QKI +LQE +EQF +QFV +FVP +VPR +PRS +RST +STN +TNA +ADA +DAK +AKV +AEY +EYS +YSS +KIA +IAG +AGS +GSA +SAL +NAY +AYN +YNS +NSA +SAF +ISQ +QLP +ILA +LAS +ACE +CEL +RKA +KAF +FDA +AAI +AIT +ITD +DLD +KLT +LTE +NLN +LNL +NLW +LWV +WVT +VTD +TDS +DDN +DNA +NEA +ALS +VLN +DNF +NFL +NCG +CGE +GET +TQH +QHE +HES +KSY +SYS +DDE +MVS +VSQ +QVV +VVA +EKP +KPQ +PQL +KKA +AGC +GCN +CNS +NSH +SHG +HGQ +GQD +QDS +SYF +YFL +FLG +LGW +GWQ +WQE +QEY +EYE +YEK +KNP +NPF +PFD +FDP +DPV +PVS +NPS +PSG +GII +IIQ +IQM +MGL +NQL +QLS +LSF +SFD +FDL +DLL +LEE +EEW +EWL +WLE +NPH +PHA +HAL +ALG +GLR +LRR +RRE +REG +GGG +GGA +ASV +VFR +FRE +REL +ALF +LFQ +FQD +QDY +YHG +HGL +GLP +LPA +PAF +AFK +FKN +KNA +NAL +ARF +RFM +FMS +MSE +SEQ +EQR +QRG +RGY +GYK +YKV +KVV +VVF +VFD +DPS +PSN +SNI +NIV +IVL +VLT +TAG +SAN +ANE +ALM +LMF +MFC +FCL +CLA +LAD +ADH +DHG +HGD +AFL +IPT +PTP +TPY +PYY +YYP +YPG +PGF +GFD +FDR +DRD +RDL +DLK +LKW +KWR +WRT +RTG +AEI +EIV +IVP +VPV +PVH +VHC +HCA +CAS +ANG +NGF +GFR +FRV +VTR +TRP +RPA +PAL +LDD +DAY +YRR +RAQ +AQK +QKR +KRR +RRL +RLR +LRV +VKG +KGV +GVL +VLI +ITN +NPL +PLG +LGT +GTA +SPR +PRA +RAD +ETI +TIV +VDF +DFV +FVA +GIH +IHL +LIS +ISD +SDE +EIY +IYA +YAG +AGT +AFA +FAE +EPP +PPA +AGF +GFV +FVS +VSA +ALE +EVV +AGR +RDG +GAD +ADV +VSD +RVH +VHV +HVV +VVY +VYS +YSL +SLS +SKD +KDL +DLG +LPG +RVG +GAI +AIY +IYS +YSA +NAA +SAA +ATK +TKM +KMS +MSS +SSF +SFG +FGL +GLV +QTQ +QYL +YLL +LLG +LGD +RDF +DFT +TRS +RSY +SYV +YVA +NKR +RRI +RIK +ERH +RHD +HDQ +DQL +LVD +VDG +DGL +EIG +IGI +GIG +IGC +GCL +CLP +LPS +AGL +GLF +LFC +FCW +CWV +WVD +VDM +DMS +MSH +HLM +LMR +MRS +RSR +SRS +RSF +SFA +FAG +GEM +EME +MEL +ELW +LWK +WKK +VFE +FEV +EVG +VGL +GLN +LNI +NIS +ISP +SPG +PGS +GSS +SSC +SCH +CHC +HCR +CRE +REP +PGW +GWF +WFR +RVC +VCF +CFA +FAN +ANM +NMS +MSA +SAK +KTL +TLD +VAM +AMQ +MQR +QRL +SFV +FVD +TGG +ALR +AVP +PVR +VRS +RSV +SVS +VSC +SCP +CPL +PLA +LAI +AIK +IKW +KWA +WAL +RLT +LTP +TPS +PSI +IAD +ADR +KAE +MAY +YQG +QGI +GID +IDL +LST +STK +TKA +HGE +YFD +FDG +DGW +GWK +WKA +AYD +YDT +DTN +DLR +LRH +RHN +HNR +NRG +RGG +GGV +GVI +VIQ +SLD +LDL +DLI +LIE +IEE +EWS +WSK +SKN +KNH +NHP +HPE +PEA +ASI +CTP +PEG +EGV +GVS +SQF +QFK +FKR +RIA +ANF +NFQ +LPE +PEF +EFR +FRK +KAM +AQF +FMG +MGQ +QVR +VRG +GGK +KAT +ATF +DPD +VVM +VMS +MSG +SGG +GAQ +AQE +QET +LAF +AFC +LAN +ANP +NPG +PGE +FLV +VPT +YPA +RDC +DCC +CCW +CWR +WRS +RSG +GIK +IKL +LPI +PIE +IEC +ECH +CHS +HSF +SFN +FND +DFR +FRL +TKE +ALV +YDG +RRQ +RQG +GIS +ISV +SVK +ILI +GTI +TIT +TDR +RDT +LAM +AML +LAT +TFA +TEH +EHR +HRV +VHL +LVC +CDE +GSV +VFA +PEY +EYV +YVS +VSI +EVI +VIE +IER +ERD +RDV +DVP +VPW +PWC +WCN +CNR +NRD +LIH +IHV +KDF +DFG +VGI +IIY +YSY +SYN +YND +AAR +RRM +RMS +QYF +FLA +ARM +RML +MLS +EEF +EFI +FIG +IGR +GRF +RFL +FLQ +QES +SKC +KCR +RLV +VAR +ARH +RHE +HER +ERF +RFT +FTS +SGL +REV +CLR +GNA +LFS +FSW +SWM +WMD +MDL +MLR +LWR +VIV +IVH +VHQ +HQV +QVK +VKL +KLN +NVS +VSP +PGT +GTS +TSF +SFH +FHC +VCH +CHA +HAN +NMD +DET +TME +MEV +GRI +RIH +IHD +HDF +FVR +VRQ +RQH +QHQ +HQQ +QQR +QRR +RRV +ERW +RWA +WAA +ANR +NRQ +RQL +QLR +RLS +SLP +LPH +PHH +HHH +HHL +HLS +LSP +PAH +SSP +SPL +SPQ +QSP +SPM +PMV +KQL +TKV +VTS +TSN +SNG +NGH +GHG +GWE +WEE +EEY +NPY +PYD +NPN +PNG +NGM +GMI 
+MIQ +QLC +LCF +CFD +ESW +SWL +WLT +TKN +NPD +PDA +SLK +LKR +KRN +RNG +NGQ +GQS +QSI +SIF +IFR +HGM +GMP +MPE +FKK +MEE +IRG +GNR +NRV +VTF +DPK +PKK +KIV +GST +NET +TLM +PGD +FLL +LPT +VPI +PIH +IHC +HCS +CSS +SSS +SSN +GFQ +FQI +QIT +ITE +TES +ESA +LQQ +YQQ +QAQ +QKL +VLV +VTN +TAL +ALT +LTR +TRR +LLV +DFI +FIT +TSK +KNI +NIH +YSG +SGT +GTM +TMF +MFG +FGF +GFE +FEQ +QFI +FIS +SVM +VMD +LKD +LED +DTE +TEV +EVS +VSK +SKR +KRV +YSN +SND +MIV +LSA +KKF +KFT +TSQ +SQY +YLE +NQK +KRL +RLK +LKS +KSR +SRQ +RQR +GLE +AGI +GIT +ITC +TCL +RSN +DMR +MRH +RHL +HLL +TNT +NTF +TFE +FEA +DLW +IVY +VYN +YNV +NVK +HCT +CTE +TEP +ALK +LKT +KTF +TFV +FVE +STD +TDC +DCG +CGR +GRM +RMI +MIS +ISR +SSH +SHE +ERL +LRK +RKK +KKT +SNW +NWV +WVF +RVS +VSW +SWT +RVP +VPD +PDE +VAF +TEK +KQD +QDL +DLN +IAS +DGH +AYE +ENP +PFH +FHP +PID +IDR +DRP +RPD +DGV +LCG +GDL +DLM +RKW +KWV +WVL +LKH +KHP +CTS +GVN +VNQ +NQF +QFS +FSD +IAI +AIF +IFQ +FRQ +RQA +QAV +AKF +KFM +KTR +TRN +RNN +NNK +NKV +VKF +KFD +DRI +IVM +GAH +HET +TVA +DGF +GFL +LRW +RWR +VNL +NLV +PVT +VTC +TCH +HSS +GFK +FKI +KIT +ITV +YEN +NAR +RKS +NIP +IPV +PVK +KGL +GTT +LDR +REC +ECL +CLK +LVN +VNF +NFT +FTN +TND +DKG +YAA +TFG +FGQ +SEF +EIE +DCN +IHI +HIV +KDM +DMG +PGL +VVQ +VQI +QIA +IAR +RKM +QHL +AKM +KML +FIR +RES +KLR +RHA +EIT +ITT +TTG +TGL +GLD +LDG +GLG +LGI +IGW +GWL +WLK +LKA +LFL +FLW +LWM +LRN +LLK +TAT +FDS +PGG +GGS +GSF +HCH +CHE +HEP +MDH +DHK +HKT +MET +ETA +LER +ERI +RIR +VFT +SQL +QLE +EEE +EET +ETK +TKP +KPM +PMA +TTM +TMM +MMA +AKK +KKK +KKC +KCW +CWQ +WQS +QSN +SNL +NLR +SFS +DTR +RRF +RFD +GFF +FFS +FSP +SPH +PHS +HSP +SPV +PVP +VPP +PPS +PSP +PLV +LVR +RKV +NAH +AHG +NGI +ETW +TWL +WLA +AKN +GLK +LKK +KKD +KDG +DGQ +IFK +FKE +KAL +PSK +MLT +GTV +TVF +VFG +VSV +KNL +NLE +LEN +VHI +MVV +TST +STY +TYL +YLD +LKI +KIR +IRQ +QKK +KLV +VYD +YDV +DVK +MKR +LKE +YVE +DSR +SKS +KSS +SHD +HDR +IKS +RKR +KRT +RTV +MHG +HGS +GSG +SGH +GHS +HSL +SLT +LTG +GAP +APH +PHQ +HQI +QIP +IPP +PPP +PPR +PRT +RTQ +GQQ +TAN +ANQ +DKI +KID +IDP +DPF +FHN +HNK +KRG +RGT +TSR +LRI +RIN +INN +NNS +SSR +SRY +RYN +NVD +VQL +KDT +NEQ +EQP +QPA +LVI +VQC +QCQ +CQH +QHV +HVF +FDF +DFY +FYD +YDP +PVA +VAQ +QLK +LKC +CKE +KEI +IKR +LID +IDH +DHI +HIT +TKG +AIV +IVE +TIY +IYP +PAV +AVI +IKM +KMV +NIF +VLP +PSE +ENC +NCE +CEF +EFD +DPE +EED +DEP +EPT +PTL +TLE +SWP +WPH +PHL +HLQ +VYE +YEL +ELF +FLR +LRF +FLE +ESP +PDF +FQA +QAS +SIG +IGK +GKK +KKY +KYI +YID +IDQ +DQR +QRF +RFV +FVL +DLF +LFD +DPR +PRE +DFL +FLK +VLH +LHR +HRI +RIY +IYG +YGK +GKF +RAF +AFI +IRK +RKH +KHI +HIN +NNM +NMF +MFL +YET +ETD +DSF +FNG +NGV +GVG +VGE +LEI +ILG +LGS +GSI +SII +IIN +ING +GFA +FAL +ALP +LPL +PLK +LKQ +KQE +QEH +EHK +HKV +KVL +VLL +PLH +LHK +HKP +KPK +PKC +KCL +CLS +SLY +LYH +YHA +HAQ +AYC +YCV +CVV +FIE +EKD +TPQ +PQV +QVF +LKF +KFW +FWP +WPR +RTC +TCS +SSK +KEV +EVM +VMF +GEV +EVE +DII +IIE +IEP +EPE +KII +DPL +PLF +LFR +AKC +KCV +CVS +PHF +HFQ +FQV +RAL +ALY +LYF +YFW +FWN +WNN +NNE +NEY +EYI +YIL +TSS +LVM +VMP +MPI +PIM +IMF +MFP +FPA +LYR +YRI +RIS +EHW +HWN +WNQ +NQT +IVA +TFM +MEM +EMN +MNG +NGK +GKL +KLF +LTS +TYK +YKG +GER +EKQ +KQR +QRE +KDR +RDA +AFW +FWK +MEA +LNP +NPP +EVT +VTP +PSL +SLF +LFP +FPE +TDY +DYL +DGP +GPN +PNM +NMT +MTP +TPL +PLP +LPV +AGG +GDK +KSP +SPS +PSV +VVK +KKS +STG +ETT +TTT +TTP +PAK +TKL +KLP +STP +TPT +PTS +TSP +GLS +PPD +DKV +KVD +GFS +FSR +RSL +ARP +RPR +RSH +SHS +QFR +RYQ +YQS +SNQ +NQQ +QQE +PLL +KDV +ELH +LHE +RKL +LAQ +AQC +QCG +CGV +GVM +MFD +FLD +LDC +CVA +LKG +VKR +LVE +VEC +ECV +CVG +VGS +TRG +EPV +PVY +VYP +YPD +PDI +IIR +IRM +SVN 
+VNI +FRT +RTL +TLP +EPN +PNL +LEP +EPS +PSW +YEF +EFF +FFL +FQP +QPS +KRY +RYV +YVD +DQK +QKF +KFV +VLM +LML +MLL +EYL +KTI +ILH +VYG +AYI +YIR +KQC +QCN +CNH +NHI +HIF +IFL +RFI +FIY +IYE +LEH +EHF +HFN +GVA +HKQ +KQF +QFL +VRV +IPL +LHS +HSV +VKS +FHA +DAT +HVI +VIR +RGL +LKY +KYW +YWP +WPK +PKT +KTC +TCT +CTQ +TQK +DVI +PSQ +FVK +VKI +KIQ +IQE +QEP +LFK +FKQ +ARC +RCV +EDN +DNC +NCH +CHT +HTV +AVF +FGT +GTL +TLY +LYQ +YQV +QVS +LIY +IYN +ASY +YKL +QQK +KAQ +ERQ +WRG +RLQ +LQG +QGT +GTQ +GAK +APV +PRP +RPT +MPY +PYK +KEP +PPK +PKV +KCT +CTA +TAK +KPS +SGK +GKD +EAQ +QPQ +PQP +PQA +AQS +QPP +SNK +KRP +RPS +NST +TPP +PTQ +TQL +IKY +KYS +GGP +GPQ +PQI +QIV +ERR +RQS +SRF +RFN +FNL +NLS +KNR +NRE +LQK +DSP +SPT +TQE +LFI +FIQ +LRQ +RQC +QCC +CCV +CVL +VLF +SDP +SDL +KFK +RAG +NEM +VEY +YIT +ITH +THS +HSR +DVV +VVT +YPE +VTM +MFS +NLF +NPT +PTG +AWP +QPN +PNI +NIA +IRR +RQI +QIN +INH +IFY +FYR +YRF +EHH +HHN +HNG +GIA +HKM +KMF +VYH +YHP +HPQ +KES +PVI +IVG +KTH +SPK +FLN +EFS +FSK +KVM +VME +MEP +LYY +YYW +YWN +YIM +IMS +MSL +SDN +ARV +YRN +RNS +NSK +KSH +SHW +WNK +NKT +TIH +IHG +GLI +YNA +LFM +MNQ +DDC +DCT +TQQ +QQY +QYK +KQK +QKG +RFR +FRM +RMK +MKE +EMW +MWQ +WQK +RLN +NPQ +PQY +QYP +YPM +PMF +MFR +FRA +RAP +APP +PPL +PPV +YSM +SME +ETP +PTA +DIQ +IQL +AVQ +VQM +QML +MLK +KDI +IKK +RRK +LPQ +PQD +DVY +VYT +YTI +TIK +IKA +AHK +HKR +RAE +FLT +SQE +MMR +MRG +RGF +RLI +STT +TTS +KKP +HGT +TTH +GSK +KST +TTE +GKQ +KQS +QSG +SGS +SVP +QGK +GKH +KHH +HHS +SKT +KTK +TKT +VSR +TKK +RKG +KGQ +QSK +SKQ +QQP +SQS +QKQ +KQG +QGS +AIM +MNP +TPV +PVL +TVT +VTK +TKD +KDD +DHA +HAH +AHP +HPT +TLL +LGA +GAV +AVS +SPI +PIS +TAV +ENG +NGN +GNS +NSN +SNN +NNN +NMN +MNI +NIN +INT +NTS +SNT +NTQ +TQD +DAN +ANH +NHA +HAS +SID +IDI +DIP +IPR +SFE +FER +RLP +PTK +PDT +DTD +KTP +PQR +QRH +RHS +RFE +FEP +PSR +RYT +YTP +PLT +PNF +NFN +FNE +NEV +RIP +FIA +DQC +CNT +DFN +NDP +PSF +IQG +KRS +IEF +TNR +NRF +FTY +TYT +YTN +TNE +EMY +MYA +YAH +AHV +VVN +VNM +MFK +KIN +INL +FRP +RPI +PIP +PVN +VNP +NPV +PVG +VGD +GDI +DIY +IYD +DED +VNE +LAW +PHM +AVY +FNH +NHQ +KQY +QYI +QDF +FIL +DIR +DCL +TLH +SFI +RSM +SMN +MNN +NNI +LQF +KFN +VRI +RIL +KVR +VRC +RCL +YCI +CIV +IVQ +KDP +LLT +VMG +LRY +RYW +PKI +INS +NEI +DIF +IFE +PLE +LEF +FIK +IKV +VEV +VPL +LFV +FVQ +KCI +CIS +LSY +SYW +EYF +NLC +LCI +CIE +VIL +ILP +PII +IIF +IFP +LYE +NGE +SIS +DPY +PYM +YML +MLV +QAI +AIN +NSG +GSW +SWN +WNR +NRA +RAI +AIH +IHA +HAM +MAF +KIF +ETN +VLY +CNA +LYL +KET +QRK +KVQ +ENW +NWS +YVK +VKN +NND +KDQ +QYT +NSF +FNT +NTA +NNT +NTL +ENE +END +NDC +DCD +CDS +SEI +IKQ +KQI +QIF +IFG +FGK +LPR +RKP +SHN +HND +NDS +DSN +VNS +NSY +SYY +YYI +YIP +PNS +NGA +GAN +NGT +TVI +VIA +IAP +APS +SNR +NRT +RTN +TNQ +NQV +QVN +VNG +GVY +YEA +SFR +FRD +KLS +LSM +SMC +MCC +RQT +QTL +VDY +DYI +YIA +VST +SDA +QEI +RTF +TFP +FPS +NHE +KIL +DVD +EPA +PAW +LQV +LLL +PMT +TDA +RYI +DHS +FMV +MVH +VHR +HRP +RPF +PFI +KAI +FIF +FET +KHN +HKL +IRA +RPK +KCA +AYH +YHQ +SYC +DFK +FKL +ADT +WPV +TNS +QAA +EFQ +FQR +QRC +RCM +CMV +MVP +CLN +SHF +LWN +NDH +HIR +IRN +NLI +ITQ +TQN +QNH +NHK +VIM +IMP +PIV +IVF +VFP +PAM +AME +NTR +RGH +GHW +NQA +VQS +QSL +NVR +VRK +VMA +AET +TDQ +DQI +QIL +ILF +DEC +KFQ +FQE +QED +EAN +KRE +ATW +TWK +WKL +AVL +PRF +RFS +FSS +TGK +GKT +LTC +TCN +CNK +NKA +SRM +RMV +VDA +NGP +GPF +PFQ +QPV +PVV +VVL +LHI +QEK +KWK +WKE +SEM +THN +NRN +RNV +VIT +EPI +PIY +VVH +VHM +HMF +MFA +FAV +AVN +VLQ +HKI +MAL +KIM +IME +THW +QQF +EAW +AWV +WVK +KAN +YTV +TVY +YSQ +STM +TMS +MSI +SIP +TDG +GPL +LFE +FED +EDV +DVQ +TVK +AHQ +HQA +QKD +RPL +QDP +DPH +PHT +HTK 
+AHC +CRA +SQD +DGR +MSV +ATD +TDD +DAL +LYP +YPI +PIA +IDE +DVT +TLR +NSI +SIR +STI +TIA +LGV +VER +ERT +RTR +IQF +LVL +QLG +LGN +GNF +FTP +LVG +GPD +PDH +HVH +HCL +VVR +VRD +RDK +ESL +KHS +HFV +VPM +PML +GDW +DWF +WFT +SRT +RTS +SAC +CGL +YPR +PRV +PAI +KSM +SMF +TLC +LCR +CRD +RDD +DDT +DTP +TPM +VRR +KLG +GEF +FAK +FEK +IEG +EGL +GLH +LHV +HVD +EQD +SVR +VRL +SAI +IAF +AFG +ANK +NKK +PIL +IEL +KSW +RVR +VRY +YMV +IEI +QNV +DMD +MDT +DTT +NMY +MYT +TNL +EVR +RCA +CAA +TQR +QEF +NLP +PED +DKR +RQN +QNI +NII +IIC +LLN +NVA +LAG +AGV +IMG +APL +PLI +LIG +EQT +QTV +VSE +IYM +YMQ +NDQ +DQT +QTP +KVN +EDG +DGK +GKW +FMP +MPL +LGQ +FFD +PLC +LCL +LNW +NWL +TDH +VFS +FSI +IMK +LTQ +KFG +FGG +GQW +QWA +WAS +TNI +VPK +PKM +MQK +TNY +YLQ +QRM +RMT +MTC +CLF +MTQ +EDD +VPN +PNV +VRF +FNA +AKS +RIG +GKN +PST +VKP +KPL +LGK +DSD +SDF +DFD +FDV +DVR +RYF +YFS +FSE +SLG +SVD +DSL +LKN +SIK +RSE +IPF +PFL +FAM +AMY +MYL +LRT +EHS +HSA +EIH +VVP +TLQ +VCY +CYP +VTQ +RAN +NFR +KLC +LCQ +NKL +TEY +KSD +NFV +LAV +EAC +ACV +IAQ +VEH +EHL +QCA +VDL +DLQ +AVG +VGP +PEI +ITR +TRV +RVD +AFQ +DFC +FCA +CAN +ANL +NLD +QVQ +QII +IIL +SIL +LPY +PYV +YVR +PNP +PHV +SVI +MLG +YQT +ECP +CPE +CVN +VND +GIQ +IQQ +LSQ +SKW +IEY +EYM +YMP +AGQ +GQL +FDQ +GLC +LCM +CMG +MGW +WLN +HVY +VYA +YAI +AIR +LNM +QFG +FGA +APW +PWA +WAE +IIP +IPM +PMI +MIL +MSR +SRN +RNK +NKN +KNY +YLH +HRM +EVC +VCG +CGT +GTD +DIT +TTK +PTV +ADP +VAN +ANV +FNV +SPF +VID +IDA +DAQ +AQV +KPT +NTD +TDV +VKH +KHF +HFA +FAA +LPF +GTF +TFT +FTT +YVH +ISH +HEH +PSD +AHF +AVK +RQY +FRN +LCS +SDD +DNV +FSN +MPT +FTE +ITK +FQN +QNL +NLM +LMK +MKD +KDC +DCE +CEA +ASH +SHK +KEF +EFC +FCE +CEN +ADC +DCR +MSQ +SQI +LPC +PCI +CIK +NQH +KDN +DNT +NTI +IEH +GIR +EDA +AKW +SLC +CMA +MAW +AWL +WLV +VDH +NLK +KEW +EWA +WAH +AHA +HAT +ATI +TII +AMS +GDP +PNY +MTT +TLF +FCI +CIN +INV +CGQ +TKH +KHM +HML +MLP +VLR +LRM +RMA +MAG +SLQ +KIG +GPI +LQS +KPI +QDQ +VKY +KYF +YFA +FAQ +TTA +YPL +LLM +LMD +HDD +LGP +PER +EVF +VPY +PYI +YIG +IGG +QYA +YAT +ILL +VRE +SLN +QLF +ADW +WFS +KVS +IVR +NIL +MVK +RAV +VGK +NLG +EDW +DWD +WDY +YIS +FQK +IND +NDN +DNQ +VDC +CLI +ISI +KFF +FFN +DES +SHT +HTQ +IGD +DRF +VQP +QPF +LCE +DNE +NEG +GDV +SGF +LNK +NKI +VQN +TVR +NKD +DQV +QVI +VIN +NNF +FLP +NML +EFP +FPD +PDV +IIA +GIE +DVN +VNW +NWR +VRM +MAI +IPI +LGM +GMQ +MQF +QFF +DLC +LSW +WLW +LWD +WDT +YSI +VNN +NNL +EIF +FGS +SDW +DWC +WCR +SRL +ENF +FTI +LTT +GVP +NIR +IRF +SYA +YAV +KYD +YDA +KNT +LQT +AEC +ECQ +CQE +MVM +SQN +QNQ +NQP +AND +FDM +EGP +ETF +PVD +INW +NWK +WKF +FNQ +GNI +NID +VHT +HTE +EAD +ISC +SCV +CVE +FSH +HDG +GEY +GRV +VVI +VIF +QRD +GKY +KYV +GVR +EYN +YST +STF +TFQ +FQS +QSH +FDY +EID +INQ +NQI +IRW +RWL +NFI +DKT +KLW +WKI +DAW +AWN +WNL +NRI +FRG +RGR +GRL +LQI +SIV +PME +YGN +AHT +HTY +TYH +YHV +HVN +NSD +TFL +DDL +RVN +ESF +FNI +VDI +IKP +PAN +ITA +EFH +TQC +CNW +NWF +WFV +KGS +RLC +LCD +CDM +MRD +RDR +ALC +AYA +YAK +DPQ +QSR +SFF +KFS +NGR +GRY +TRD +YLT +KVW +VWD +WDL +MES +PVE +ETY +TYP +YPV +HNY +YLR +RTK +LCA +CAL +IFD +FDK +KFE +FEC +CDW +DWS +WSG +HIL +ILT +GSY +SYH +YHN +HNL +FRS +YAR +ARG +NNQ +KTW +TWE +WEA +EAR +RPQ +EPH +HSQ +FVV +QLQ +QFD +HTA +TAW +AWH +WHP +HPK +PKD +DNI +TNN +NLY +LYI +YIF +IFS +MGR +GRW +RWG +WGR +PDP +PQM +MQT +FMR +MRQ +SIT +IGN +GNM +MLN +TAI +INI +SWC +WCF +CFS +FSQ +QIK +GAL +ADI +EFN +NHD +RDP +SKA +RRG +RGE +INK +WLQ +QKN +VHF +HFL +WKV +KSF +GGY +GYN +YNT +NTK +NGL +PQN +VTA +VKQ +RRT +YHI +LWH +WHL +HLE +NQS +QSY +YNI +TNM +TEC +ECN +CNV +NVF +VFV +KGT +TIR +CDR +DRH +HSK +QFE +PEN +NRS +SGR +YMI +LSI +LHM +HME +VHE 
+HEY +DCI +CIF +ECC +CWN +WNG +SIM +IMT +MTG +YNN +NFF +FFR +LKP +KPR +KVC +VCT +CTG +GKR +CLD +LDF +FNK +ENI +QDK +DID +IDT +TRK +SFL +RDH +HSY +IST +NHT +HTG +QVH +HRR +WLP +PQQ +QQN +AYF +RPE +EGY +YNL +PAT +LRP +RPM +PMD +LMV +TPR +SDY +DYE +TYM +YMS +WNF +NFE +QSF +HPH +HHC +HCN +MRA +RHT +TKF +FFE +HSG +MEN +ENR +NRP +RPV +TYQ +VHD +HDY +CVW +VWN +NGS +RMF +TKR +AIL +VCV +DFS +HPS +MRF +RFC +FCV +AWF +WFF +FFP +FPN +NTT +TTR +VFW +FWD +WDA +AFS +SNF +FTG +TGC +GCH +CHH +HHG +GQN +GLY +YFQ +RFG +FGY +GYI +IPE +PET +TFS +FSG +SGN +FTD +DDF +ELY +QTN +TNF +LDA +LTI +TIQ +IQH +QHI +IVI +VIP +PRC +RCG +CGN +SLM +LMH +HGG +EVN +RTH +HLH +LHA +HAV +YTL +FPG +EPR +PRW +RWP +PRN +RNR +NRR +RRD +DLT +LTY +TYA +YAF +PKN +SRA +FGR +RWS +WSD +FTL +FST +ITI +TIG +IGF +GFY +FYT +YTG +GDH +EPF +LAH +HAF +SPP +KFH +FHL +HLD +WVV +ESV +AVH +IGH +GHL +LGH +ESI +IMY +MYP +YPT +PTI +LTN +VEG +EGI +IQY +YLY +LYG +YGA +KHQ +HQR +DTG +GGF +FSA +RID +IDG +DGS +TVG +VLW +LWF +WFL +MGS +PLR +KPG +TSW +WNS +VRT +TQV +EYG +YGC +GCF +CFE +KGH +LNG +GNK +NKP +KPE +EYD +GFT +EGM +GMG +MGV +VGR +RIT +LMW +MWP +WPE +CET +SYG +KRM +KMM +MMV +MVF +FES +FGM +HFD +SFC +CES +LHF +HFM +MRY +QPG +PGK +GRS +RSP +SLH +HKD +KSI +IVN +NQN +QND +EFE +GEW +EWI +WIL +ADN +DNH +GDC +DCF +CFM +AWS +WSN +RLH +QAR +FSF +SFP +FPK +EHP +HPL +LLF +LFN +FNP +PFE +YCF +CFT +FTK +KEG +CDL +PAQ +PFR +FRI +QGP +ERP +RQQ +QQC +QCS +CSQ +SQR +QRI +RIQ +QGE +NQC +QCR +CRS +RSQ +SQM +QSC +SCC +CCQ +LQN +NVE +EQC +CQC +MPG +GWS +WSC +SCL +CLV +FVG +VGQ +VQE +QTK +MLE +LEG +AQY +CQG +VIH +IHT +IDV +VSH +SHV +HVL +PRQ +IYC +YCS +CST +AGP +HEE +HHE +STW +TWS +AYP +YPY +PYS +YSK +KNG +NGG +GGT +HTC +TCA +PMY +MYI +YIY +YGE +ERS +VMI +KNK +VYV +YVG +VGN +GNV +VAW +AWA +AHI +NVQ +VQG +GQF +QFY +TPH +HQS +SYD +LNC +NCT +EWG +WGL +RLD +SWS +WSL +LLY +LYW +YWL +VSF +PFY +FYN +YNY +NYR +YRP +RPP +PPF +PFN +FNC +SKF +FTF +FSY +AQR +LGY +GYV +YVP +SWE +SEW +WIG +IGT +EQH +QHR +HRE +RET +DTK +TKS +GGL +AFR +QNR +TAC +ACI +CII +DVF +FGV +GVT +VTH +THR +MNV +NVN +VNV +CVQ +VQA +PVF +VFI +IYT +YTS +IEV +QNG +NTW +TWP +WPT +PYP +NGW +GWN +NGD +GDT +LYT +YTC +PTY +TYI +SIN +INE +NNG +SVG +TVN +KAP +YDN +NYI +EFG +SRW +LMY +MYW +YWI +SYQ +YQP +FNR +NRH +YKP +PLY +LYS +YSW +VEW +EWV +WVG +RHK +HKE +TLK +KSK +KTQ +YRT +KHK +VTV +RGD +DIV +QGM +GMS +VII +IIH +DAC +TFH +FHT +MVN +VNR +KNN +KRH +SIQ +NYT +WGF +GFC +MVT +VTI +TIS +ISY +GYE +YEP +QVP +YLV +GGC +GCG +CGF +GEH +EHI +LEW +EWE +WEP +PRL +LHL +TGP +GPV +PVQ +VQV +QVT +AIQ +QAH +HEV +GSH +IHK +VQT +TGT +GTR +TRL +SSM +GHP +HPF +PYE +IHR +HRH +RHP +HPY +YPC +PCS +CSK +GRK +RLF +AIP +EHG +HGR +AWM +WMH +MHI +LMG +MGG +QVY +VYF +YFC +FCY +CYD +YDK +SPY +SYE +EDF +FNM +MEF +SPC +PCG +GTH +PYW +WLL +LQW +QWL +PYT +TNK +RHF +HFG +ART +RTI +IHW +HWV +WVQ +RMG +DAS +ELG +VTT +DRG +WVR +DVC +VCA +TIF +IFH +ELM +DEY +QRS +NVG +GTE +TEN +HAG +GVQ +YTD +DLY +AQN +GVD +DGM +GML +CAI +IRP +GIW +IWG +WGN +GNG +GDQ +QTM +GHV +HGF +GFI +AAH +DGT +APG +PGQ +GQA +YFI +FIN +PIN +INM +MFE +FEF +FAR +QRW +KMR +MRI +SGP +GPA +AVR +VRW +RWV +WVM +VMT +TGW +WQR +HFR +FRF +GFP +PAP +RLY +NYF +LFT +TTQ +QAL +YYV +QMK +ARA +MMK +QLH +RMR +GRT +RTP +RLE +AHN +HNI +LQA +CLQ +PLM +LMA +SFK +LDP +PDS +SMG +EMS +MSC +SCA +ARI +FEM +EMT +MTL +LQP +QPL +HKK +DWN +WNT +QAT +QGL +LGG +GSP +HSH +HTT +MAN +YHF +FVT +KED +YAN +ANY +IQA +QAD +ADY +NHG +PSM +SMT +MTA +THF +HFP +FPR +YGV +GRE +CVM +VMM +MML +GMK +FCS +SYL +PEP +LMT +MTF +LYD +DDW +DWM +WMR +CSR +PPE +YLM +MKF +VNK +NKM +KMT +LLW +LWP +WPP +DQA +QLD +IQV +VGV +GVV +IQS +QSA +DIN +INF +QDT 
+DRL +RTE +PAR +PTM +TMP +PPQ +PPG +GTP +TVP +PGP +NPA +QVD +SGV +QPR +HNV +NVH +VHK +TAM +PLN +LNR +NRL +HTH +THM +HMA +QCK +CKD +HFS +YFT +FTH +HRK +NHS +APF +PFS +QEE +MTS +ALH +HDV +QEN +FNN +GIF +APQ +QQV +MTV +LPK +PKP +PTD +VGT +PCP +CPA +SNM +NMP +DQG +TED +GGH +HPP +PRG +EMH +MHW +HWP +PMK +AIG +LTM +AGY +GYL +KWP +WPL +FVI +KRC +CVY +VYY +YYF +YFK +PQG +GAF +FSL +LSG +SGY +YNR +RVM +VMR +FPF +PFK +HIS +KKH +KHR +HRT +RTW +TWF +WMA +GHF +HFH +FHE +HEK +PLD +SFY +FYG +TDN +YEH +EHD +EPD +PGR +MHP +PAY +YPP +DMP +MPR +RAH +AHS +SFT +GPG +KHG +LPD +LCP +CPR +EPC +DPP +KPP +PPC +PCF +CFR +EPW +PWT +WTP +PGH +HGA +GAC +IMA +RNC +NCD +CDK +RGP +GPP +SEP +PKF +AMP +VAP +APR +RQP +KVP +FVN +VNT +ESC +CEV +LYC +CIR +GKV +LVV +VVW +WDE +ETS +VRN +RNY +RIF +KFY +GSM +SMV +EHY +HYH +YHT +THV +PSH +SHQ +PYG +YGY +GYT +IQI +QIE +EIN +TFR +GNC +NCI +RPY +AQI +CQK +HAA +MSN +HEW +EWQ +WQF +FDN +NAW +AWQ +QEM +EML +LNH +QKV +MDA +DCH +EHQ +FRR +NKS +SRP +PYF +YFE +QVC +TYS +DIH +HRQ +GDF +DFP +FPT +PGV +FQL +EKC +KCD +CDY +DYP +YPS +GSQ +QMS +ACD +DYD +VRP +DVW +VWE +WEH +EHE +LDH +LMM +QQT +STE +QRP +RHC +HCD +CDV +TSC +HHQ +HQL +NHL +TPI +PIK +VSM +SMR +MRE +DRS +RRR +PRI +LNQ +QST +INR +ARQ +KFR +KPY +YWE +RVA +RQF +QRV +LVH +ARY +AMG +FEL +KYY +YVQ +KMA +IHE +MGP +RGC +TSV +DSC +SCS +CSN +TQS +QSV +GPT +MPD +PDQ +DQF +QFP +RPG +GMM +MMF +FPV +SEC +ECS +PEC +ECE +ERG +ANN +NNR +NRM +LQC +QIG +ISA +REH +HKA +LQM +GKS +TRM +GCD +GVK +YHS +HSN +WDD +YGD +HAD +IGE +IFN +FNS +QLW +WMV +VDN +FQT +QTE +YWS +WSE +LGF +LHG +HGY +FEH +HFK +FKD +DQM +QFT +FTA +NDT +QTR +VFN +AFP +KFA +AYL +YRW +RWH +WHS +SYI +TPD +FHS +QCL +CLW +WRW +RWW +WWK +WGC +GCP +LTF +TFI +IRH +RHR +EFY +IDM +DMV +VKT +DMY +MYD +DTF +KRW +RWD +WDP +MVL +EMA +QGR +AEW +WIA +TGY +PTF +FEN +GHR +QPI +PFP +FPH +HHI +ILQ +IDF +NDY +DYA +YAC +CSI +TRC +RCY +CYK +ASC +SCT +SCY +CYM +STQ +MIE +NWE +WEF +PDN +DNN +NNA +API +KHA +AFN +LHH +HHF +HFY +YRD +DGY +GYS +LDY +QFA +SVQ +VQQ +CVK +AQW +QWI +SCI +DNP +DMI +YMR +LIN +CLG +GSC +SCN +DFA +CGY +GYA +IVC +CFW +HSD +GQK +III +GGI +RGA +YER +GLQ +GPH +PHG +HGW +GWR +WRM +SWG +LDQ +IVV +YLP +FQQ +QQH +QHY +HYG +YGG +HRS +RSD +KLH +LHN +DIE +IHS +DAP +AEM +EMK +IGY +HFI +QRY +RTA +DWG +YNH +NHC +CDP +QDR +WRN +NNW +NWW +WWQ +WQM +HAP +PLQ +LQY +AVM +MAM +MED +LFA +GNL +LDW +DWE +RRP +RCS +SRI +IQT +RFW +FWG +WGE +WHV +EGT +TAR +WFI +YAD +DWL +LWG +WGY +GYD +HIA +MPQ +EWR +WRY +RYA +YAL +NWQ +WQP +PPY +YDW +WSW +WML +IPD +CNP +PGC +GCV +CVD +QGV +QLY +YIC +ICF +CFP +LPM +MTI +TIP +IPG +MKT +QTF +PGI +RWT +RGW +WQA +PDD +DDY +RFP +GMT +RRY +RWK +WKP +KPW +PWR +HIW +IWY +WYT +EGW +QPD +RIC +ICV +LFF +FFA +FAP +RNA +NPW +PWN +AGK +LYM +FQH +QHF +NAV +VEM +MYQ +YQR +QRN +RNF +TMH +MHS +RFH +KHY +HYS +YSF +TRW +RWE +FYS +GPM +PMR +MRT +TGH +NWI +WIV +IRT +TGR +TTD +DSG +SDG +QYY +FWI +WII +FLY +YDL +ACW +CWA +WAP +LFG +IWI +WIP +NYD +YDQ +GYM +CVR +RGM +GMA +AYV +SKM +GIP +IPY +PYR +RAM +KYA +YPH +PHI +HIE +RTM +MDP +MRP +PGN +HSM +SML +GIM +IML +YPW +DRR +MWC +VQD +QRQ +QQI +INA +RNQ +EMR +YLN +PTR +NPC +QYG +DAH +AHR +HRA +QAW +GRA +AHH +HGC +GCS +SRH +GVH +VHG +AWI +ASF +QNP +NPM +PMG +LMP +VYW +YWK +WKG +RRW +KIW +IWR +WRA +EYA +GGN +DRY +YYG +FYA +YAM +AMR +MRL +RLW +WPG +GEI +GTK +FAF +MVG +GKP +MFY +FYM +YMT +TGQ +VVV +GMV +HQG +PHY +GVW +VWI +PNN +RKY +HAI +IIG +DTY +PEM +LCW +WVP +VPG +PGY +YSD +VEP +KPF +PDL +PMN +MNM +NMV +VMQ +MQQ +HPR +KVG +TWG +WGK +VGM +IGL +LYV +GIY +IYV +RHG +HGV +EHN +HNE +QMR +MRV +KYQ +PIT +TEW +EWT +WTV +LME +AWW +WWG +WGP +PWF +WFA +IIV +KRF +FMN +MNE +SMP +HHM 
+HMY +MYG +GQY +YGQ +GQG +WLI +LIF +QYR +IFA +KWL +ESG +DFH +FHR +HRG +YDR +DPT +IKH +HGP +RTD +LYA +PVM +MGH +GHT +TVQ +RTY +HGI +KHT +HTP +KMC +MCW +GRP +AYG +MKV +TMW +MWA +WAK +HEA +CGG +LVF +RYR +WLD +NAF +VGH +SAP +QAG +QDW +DWT +YTA +AQG +GLT +TTI +SIW +IWL +RQD +NIE +PDY +RMD +INP +DIG +GRC +CTK +DRM +MIG +QNF +NFA +PRY +MHA +FEG +AIW +IWS +WSM +GPS +ATR +RRN +VPQ +TSH +CSP +DNG +SFM +FMI +MIF +DCP +CPP +AQH +QHC +CRK +RCR +AFF +FFC +FCP +PPN +AIE +AID +GNT +FYP +AMV +SYR +QDM +MIC +CYN +YNQ +PTT +GQC +QCY +DHR +GCA +CAC +ACP +CPN +CCS +KCN +YKT +TCP +LCY +MFM +GCI +CID +CPK +YVC +VCC +CCN +DRC +RCN +VCL +KCY +CYV +TQT +QTC +CEK +EKY +VSY +YFH +FHD +YEC +ECT +CHR +GPY +PYN +NVC +LCN +MGE +THT +HTI +HTS +HLN +KFI +ITY +EIP +NAN +LII +DFF +FCN +TSM +TYF +LLC +LCT +CTF +FLH +HHP +LHQ +HQT +FPL +PMS +LFY +YRK +KTN +TNV +YKH +NMR +YGP +LSH +PHD +HDT +HEC +FLC +CFG +AQQ +SGC +GCR +CRF +LWL +EMD +EGF +VGF +TWV +PQK +HDA +THC +HCG +CGW +WSS +GWP +MPM +IYI +HLP +RPC +PCL +NNH +HIY +YTY +TIM +IMI +FVF +MGA +YLG +ACF +CFV +VIC +ICI +EGC +CIH +IHF +HDI +QSD +PKG +VML +LTH +THK +HKG +YMH +HSE +LMC +MCV +LFH +FHI +QFC +KYK +PFV +PPI +TVM +IKF +KFP +QGY +GYG +YGM +AMC +MCL +MKI +QIM +TRT +IDK +WLH +DND +FIV +KGF +HPN +AFV +YKR +VFF +FFV +PKS +TKQ +PNH +QIY +NSR +IQP +QPK +IVT +CHV +CLH +QAN +NEH +YIH +DVM +MLC +LCV +IQR +RYK +WLY +IDD +TFY +FID +SPN +VVC +TTN +EMM +TGM +MSK +SHR +HRN +GEQ +FIC +CTV +MFH +YGL +YGS +MHE +HEM +MMS +SMH +MHT +VLC +KYP +TGI +RYG +YGT +GQI +GPK +PKQ +GYF +WLR +CYI +IFV +GYQ +FPM +YVV +APY +IKI +MDS +HAR +PNT +HVT +VNH +HPD +NIK +ESD +FHV +TFW +WPD +PDM +DMK +KYN +TWY +IHQ +SHP +EYP +YPK +PKL +IRS +RSC +CSA +LMS +PHK +KPV +VCI +FGW +WFH +FDT +KYG +INC +NCA +CAV +FCK +CKK +FKV +DYS +TRI +RKF +FLM +MEC +ECR +CRN +PRD +PPM +HLR +GHQ +HQP +DYC +YCT +PCH +MIT +DPI +PIQ +QMP +EVY +RGS +SNV +NVP +PSC +TPF +RKC +CVP +QFQ +MDR +KCP +CPH +PHR +YTK +YDS +EKW +KWH +WHA +KDH +HRL +REF +FGD +FGE +RND +SYT +PEW +EWF +TGF +CNG +NEF +VPC +SMI +RWF +PHE +QNS +GNY +MLQ +PFM +GDM +TMK +MSP +GQP +GLM +VFQ +TRA +PIG +FQG +GMR +TAQ +AQM +NFY +FYQ +GFG +DRT +KMY +MYE +NRY +VPH +HVP +LHP +HPG +VHP +PQH +SHA +HMH +KWF +WFG +LEY +DYK +APM +PMH +NPK +QTS +THQ +MPH +SHC +HCV +SDC +CVT +KQP +QPM +MNA +GWV +LFW +FWL +WLG +QKW +KWW +WWH +WHT +HKN +QTD +QID +NFG +TPN +NSW +SWF +VDT +EFW +WQN +NIT +LLI +GTN +ESN +NRW +RWC +WCS +CSW +YQL +QLM +MLF +MLW +DPG +RHW +HWD +WDQ +NER +HEG +FPY +PYA +QMN +MNL +KLY +FAD +TKC +KCH +QKH +YKI +NDM +MVI +SHI +HIQ +ECK +CKY +KYE +RQV +KLM +MKL +YVT +VKM +DHY +HYA +DME +VFC +CIT +PIF +IFF +FFF +KIP +WFK +KSC +SCK +CKG +CAY +CKS +LQH +QHP +PWV +WVE +MRM +MLH +SHM +NSM +QGN +GYY +YYD +KGW +RYP +YSP +PND +ITP +IFC +NAC +QVL +NKW +KWT +WTL +TCD +LCC +CCT +HLC +YWA +WAI +TDP +IDY +YVN +LTW +CTI +AFY +FYI +YGR +TRH +RNW +WRL +EVH +TPC +CAP +IIM +MGT +ILC +CWL +PFF +FFI +PFC +CHM +HMP +VIY +YAY +YFN +IKC +CKF +KFC +FCR +CRQ +WNI +WRR +RRC +RCP +CPV +YQI +FGN +CVI +IFI +ITW +TWI +CRI +ILM +DTC +VHH +HHY +HYV +LHC +HCK +CKP +ETC +IQC +HNC +IYQ +LPW +PWK +ITL +TMY +CDF +DFW +WLS +TCC +IMH +MHL +TPK +LVW +VWV +FFW +FWR +WRQ +PNK +VCW +FII +PIC +ICK +CWF +FHM +FNW +YTM +AFH +FHK +RFK +FKC +PNQ +GAW +AWD +YTT +TWN +DIW +IWV +WVS +AGH +GHA +AMI +AVW +TAH +QIS +STC +TCG +CGA +ILY +ITG +ICW +ICR +SCW +CWI +WIH +IHP +HPA +FFT +FTW +TNC +EKM +MLI +ICM +CMT +YIV +DRW +EVW +VWL +CTC +NAI +LMI +TVW +VWT +WTI +ISM +SQC +QCT +QHD +HDH +IYH +YQK +FAS +CKL +TFC +TEF +IRI +DHP +SIY +ADF +IRC +MPS +NWT +CEG +KNW +WSA +MAV +DML +MPV +WIY +IYL +IHH +VFK +FIP +IMV +SIH +TMQ +MQS +ACK +FLF +VMW +WCP 
+CPF +NIM +CNE +FVW +YIQ +CQY +KQH +QHS +LMQ +DCS +MEI +PTC +NRK +QVG +CTD +FVH +LHW +HWA +FVM +AMW +WLF +NQY +QYN +TCV +DFM +FML +VTY +MLD +MRR +CNQ +CNY +KIY +IYF +RNP +FFK +GIN +EQV +SEH +CDH +TVH +VHW +IWP +PHN +TCE +HRD +NDG +RTT +VNA +VGY +NYQ +KCS +YFG +MFQ +NWA +VMV +DWP +CPI +PIW +CIA +NYP +PHP +ITF +SNH +IMM +MMI +MII +KVY +SYP +TMG +SMQ +CGP +GPC +PCD +ANI +RLM +SWV +TRY +THI +LGC +NCK +CKQ +VHS +VWQ +FKF +QNW +NWP +WPA +HNA +YDY +YVW +VWP +YLC +PVW +WIS +IVW +VWA +IGV +TTC +TYC +YCL +CLT +YVL +HGH +KCC +CCK +CKR +IMW +IYR +SNY +LRC +NYK +NIY +YRH +HTN +MQV +TFN +FQF +NCC +CCC +APN +CGK +SKY +RCD +GKG +HNS +RDW +DWR +WRK +TTY +YIW +WYR +QFW +FWT +QWN +WNP +VRH +RHQ +EVQ +QNY +YNF +NFP +QNC +SLW +WEL +FYV +YFV +VCM +PLW +QLT +MGN +GNH +MCG +ETR +LWS +WSV +SVW +VWH +WHY +QYW +YWT +VYI +FSM +NYY +IAW +CSH +HMG +DFE +WSI +IWQ +WQY +CIP +IPQ +HST +QWT +GVF +PDW +MAH +HVG +IWH +LWA +WAC +CIL +DTH +THH +YHL +QKY +KYH +YHK +YNW +WTK +VHA +VWY +WYQ +WND +IWA +APD +ENY +TFK +HDK +QFN +RRH +PNC +NCR +HFF +TIC +HDE +VWS +WSQ +YLF +ALW +WGG +TRQ +QYH +CVH +SMW +MWY +RNH +QTY +WQL +GCT +APC +CAE +FWF +WFQ +LCH +CHF +WSR +GGW +TIN +DQH +QHG +PFT +ISF +WDN +NWN +KEC +DKP +FYF +YFP +DSM +WEI +MSM +YII +MMN +PMP +RCC +CCP +CPT +SGW +GWT +NCP +CPG +GQH +IWN +WNC +NCY +CYS +YSR +HTF +HHA +RPW +PWH +WHN +HNQ +VQW +ECG +SMS +ANW +RAW +SFQ +QIH +HHR +REW +YEQ +CRP +FKM +GQM +WTR +ICL +MVW +ISW +IRY +LDM +RNI +QTH +THG +YVI +VPF +DAM +CWD +WDR +DRQ +MPF +PFG +CCI +AIC +CQP +GCW +WVI +QGW +NIG +CSV +AYY +STH +MGC +CFC +CLC +DYT +QVW +YIN +GQT +QWE +WES +QCH +CHP +VCR +WAY +EMF +MFI +TWC +WCV +FCF +SRC +RCH +HLT +TFF +IWF +WFY +NHN +PQT +QDN +TNH +FAW +RWQ +LWI +IAC +WNV +RHM +MEY +EYT +NVM +TWA +GWG +CQV +PSY +IYK +DTW +TWR +WRE +CSC +SCD +IWK +WKS +NYN +KNF +PNR +SQH +WFP +SWD +WDI +AWG +WVA +VMC +GWH +WHE +YCR +GMF +MFF +VTW +CDC +GYC +YCN +TMN +PCV +HCP +CCL +TQI +FCC +CLE +PRH +WST +MMD +YGH +SWA +PTW +TWD +INY +NYG +NCL +KWI +WIF +FGH +GKC +KCM +GWA +WAQ +FMY +MYY +YYQ +AKH +HKF +ECA +KHE +ICG +GMH +CTR +DHH +HNW +MIH +HPV +WMP +RVQ +DHC +GHD +AWE +WET +CHI +HIG +NTY +QNN +GAM +HLY +GWM +WMI +TCI +RYS +HNT +NQW +QWS +DPW +YSC +MPK +EWK +YAQ +AYW +MLY +CAR +KYM +YME +EYQ +MTN +NHY +ATM +FFQ +RPH +QMF +YGI +IFT +MMP +TDF +CHN +EWY +HSI +NFM +FMD +MDF +FKG +NKH +YMD +YDI +MCP +QYQ +TIW +VAH +HEF +FPI +YPF +HGN +SCR +NMG +MIN +HTD +MFW +WNH +MCI +WEN +SVH +CRV +KMD +FQM +QMI +HDS +CHG +WAV +PKW +PQW +QWP +TSY +QWR +TCR +SEY +SHY +HYQ +SFW +FWA +WAM +GFN +VMK +DYW +YWY +WYS +HWQ +WQT +VMH +MHF +PKY +VKW +WWS +HMM +MMT +MTY +KWD +TKW +KWS +HPM +MKP +KPH +HNH +QHA +MVY +YYT +EYY +WDF +GYH +WVY +YGF +GFH +HDP +NEW +WYE +YKY +YIE +EWN +GTG +YAS +CSY +GMN +RIW +IMD +TYE +DHF +HIK +PCC +CCE +CEW +GHY +DIM +TDK +KWN +QGF +GFM +ATY +TYG +YGW +MWR +NKF +WTG +PYL +YQN +QNT +RGI +GHI +HID +REM +YRC +QYS +DYG +CRY +GWD +GMC +MCA +CAF +EMC +MCK +DMM +MID +MFV +VGC +FGP +KAW +HLK +DWV +KYR +WTC +TCF +DSW +GTW +WVH +TNW +ICA +CAK +HIM +YCA +MLM +FRW +WGI +QFH +FHH +HHT +SCF +CFL +NHH +IDW +DWA +WAR +ARW +WHF +FWV +NTN +KKW +SAH +CGS +VWF +LWW +WWR +WRP +HRF +SHH +HGK +NYE +YEG +FSC +NAM +MKM +KMG +IGM +TWT +WTM +SWQ +WQD +PMM +FPP +FIH +HHD +RFA +NQM +CIG +IGQ +YSH +IIW +IPW +PWY +WYL +HKY +LWE +WEG +YDF +FGI +KGM +HSC +GTC +WPF +WWL +HQD +WGD +FKP +NYV +YMM +WGH +HQF +GGM +YDM +QGG +QDC +CDN +FDC +EQY +MQM +GYP +WLM +FEW +FAY +HLG +DGC +HLF +AHD +GFW +FWW +YTH +THY +HYK +TQF +FMH +DMF +YFF +VWG +WGV +TEM +RMP +FNF +RHY +HYC +YCE +MQI +FFG +YCK +CKH +YYN +PPH +CPS +GHH +KQN +FTC +GCK +CRG +QHH +QYC +MGM +DPM +MGI +EPY 
+QCM +CMQ +NIC +EYC +CRT +YPN +TRF +RDM +WPY +LYK +MNC +NCV +SCQ +CQA +GHN +HNN +AMF +LYN +NIW +ECF +FAC +ACA +RHV +EWM +WME +FEY +DDP +TMT +MTW +IPH +CNI +NIQ +NVW +VWK +DHE +RWN +FCT +PCW +DMW +YWF +FHG +YMG +DHW +HWK +VCK +SWK +RFF +CAG +MFT +VMY +MYN +PRM +RMC +CPM +THD +CKI +FWQ +WQV +MDY +RME +WGA +AYM +YYL +FFM +KNM +YCP +CPD +WDG +HNF +MGY +LEC +CLY +SWI +NLQ +GSN +PQC +QCV +PNW +IYY +NMI +EGH +HIP +WQG +MWS +HII +YNG +RMH +WKR +RHI +YNK +HQE +RWM +DHT +WEV +KCE +IWT +YRL +KCG +WSP +KGC +CKA +HDM +MPN +DKC +DNR +LCK +YRM +RFQ +FMM +WMS +QIW +HVC +NWD +MDG +FIW +DPC +HED +KWQ +WQQ +PVC +HLW +TYN +GHK +EWD +WDS +PHC +HCI +CIQ +IWD +WDV +AWC +IDC +DCA +CRR +WPM +QYD +HYR +HCW +CWS +WHI +WRD +KMP +VIW +AWK +WKH +HYP +FFY +QCI +CIY +FAH +NWG +FYE +TDW +DWK +CEP +HMV +DKW +YFY +VDW +HKS +MQH +FDH +CCM +MMM +DYN +WKQ +WCL +IHN +NNY +NYH +VNY +HTR +MNR +DMH +MHY +GTY +IFM +QRT +GCE +CEI +YEM +RMM +YTR +YAP +YMN +WGQ +WNM +GHM +WQI +NFD +WEY +FKY +HYD +HVW +AMM +RMY +QWV +MMQ +MSF +HFE +WER +HQM +VQH +YPQ +PQF +GHC +HMN +MNT +FFH +MMH +FIM +MIY +IYW +YWH +FMC +MCS +DWY +WYA +MWK +CMR +HYN +GWI +WIW +KWC +PWQ +RYC +THE +YQF +ERM +EWP +SWY +WYN +WKY +WEC +ECM +CME +RVW +VWC +WCK +RFY +NHM +KHC +GWC +HRW +RWI +WIK +FRC +HIH +RCW +CWP +CGI +FVC +VMN +KDW +AMH +MHQ +NQR +TCY +CYT +YTQ +HHV +AHY +QSM +LMN +FMT +MTH +GRN +NMQ +NGY +AWT +FCG +NMH +MHM +YCQ +NWC +CKT +VCE +HWH +NLH +VWW +PCR +RWY +WYF +YCM +QVM +QHT +HVR +RMN +QPW +FRH +HQK +YKF +MQN +KWE +TYV +HMR +ICH +KYT +TDM +CEY +CVC +PAC +NFK +KCF +YNC +QSW +WEW +WPW +YQH +NFH +MSY +YNP +DQP +HKH +MTK +KAH +VKC +YKW +GWW +WWP +MWG +VYC +YCG +HSW +WNE +CFI +CLM +CHK +RCQ +TCQ +PFA +NNC +QGC +MNY +NYM +KQM +QME +NCF +PDC +WAN +RPN +VCP +WIN +PPW +PWL +CRH +PWD +SYM +FGC +YIK +VNC +YTF +SNC +QHM +MEH +CQT +ITM +EYH +CQF +DYM +SMM +QMH +CYA +MAC +WVN +WAT +FWM +WMT +CCG +CYG +WAF +EPM +MVC +HWG +ELC +RCI +WQH +FWH +QWQ +AGW +NWY +WYC +CRW +CQS +LIW +CAQ +QMW +MWT +CER +ERC +VGW +IAH +NAQ +WIM +MKC +FQC +MWE +TQM +YHW +HWS +NYA +WMM +MMW +MWN +WNW +NWM +YEY +PCQ +HFW +FNY +NHR +NSC +TNG +HVM +HQW +EYW +IWE +HCE +PYH +YHD +YKQ +SWH +HAY +QMY +KIH +WFN +CSF +RCE +YCH +GRH +YNE +HQN +QPH +HYL +MHV +WIT +SCG +SPW +FHF +CIW +WAG +CTW +YAW +RHH +NFW +MNK +GEC +AHM +CYY +HEQ +MWV +IMR +FCD +HQC +CYF +MHC +PMC +HQY +WTH +QKC +HRC +HYF +CYL +HKC +WPS +WDC +FMQ +QHK +CFK +NEC +DNM +CQM +QMT +MDN +DCK +WDW +LHY +TKY +FPC +MDM +QWF +MDW +DWW +WWE +GLW +TWM +MSW +WEQ +WKN +PMQ +WAW +WMQ +DCY +CYR +CFH +HMS +IWW +WWI +PFW +WVC +ACY +MNS +CGC +GCM +TYY +YYS +MIM +MKW +HMI +FWE +MKH +MEW +SMY +MYH +HYI +CKN +NMM +RIM +SKH +YEW +CQR +RYH +HTM +WKT +KMN +FKH +TCK +WYI +HNP +NGC +MRN +FHW +EIW +KVH +WFE +YCY +AHW +TYW +YWR +WNA +EMG +CFF +HYT +FHQ +NKY +HHK +PCE +FCM +CMY +DHM +QQW +QWY +WYM +MRW +FPQ +MME +MYR +LWQ +GWY +WYD +HPW +YWD +CAH +EQW +QWK +WSH +NMC +PNE +FYH +QKM +HWE +WHD +RQW +SWW +WWA +MYS +KQW +WWT +CPQ +WIE +ACC +CCH +WEK +GMY +HFT +WTY +MMG +WTN +YYM +NTH +YCC +CCF +DYQ +WEM +WGT +NHF +CMS +WGS +MIW +YQM +IHM +QDH +TWQ +CAD +GNW +NWH +YYH +YYY +YFM +TPW +WED +MCR +YNM +WWD +MYV +YWM +SCM +CMM +NRC +RCT +CTN +YHM +QWC +WCT +TTW +TWW +WWY +WMG +YYC +WID +YVM +WIR +FYC +FWS +FYW +WTW +RCF +QQG +HMD +HEN +CKM +MKY +HCF +SQW +TYD +GIC +FQW +IFW +YQY +CCY +WAD +WSF +MYK +NDW +MIP +QWG +TCW +CWW +YLW +TQW +IHY +MQC +QCD +WTQ +MWW +VWM +WMK +GMW +MQW +NCQ +CQI +MRC +PWP +WTF +HVQ +HMC +DWQ +ILW +PWS +YHH +CPC +YHE +HAK +RNM +CEH +CMF +QHN +QCE +MDQ +DHQ +YTW +WLC +MCF +WFC +CFQ +YCW +CWE +MPW +WYK +MGF +FTM +CWK +HWF +PCT +MHN +HKW +WYV +DCW +CYQ +CAW +HWC 
+HWR +RSW +PYC +FKW +WFW +FMF +YMY +DCM +YDH +LWY +WKD +WRF +DKQ +QEC +WTE +CEM +GCY +MNH +CEQ +HYY +PYQ +QIC +GPW +PWW +MCD +WHR +NYW +QWM +CQQ +YHC +FCH +CHQ +QCF +NFC +PCN +PWG +CMI +CTM +QCP +WWN +TMC +CYW +EHC +CCR +FTQ +CNF +FDW +DWI +PWM +YWG +KMH +PWE +KWG +WGM +WHM +WPQ +CHY +VWR +WRH +CYC +AWY +DHN +CIC +CPW +ICP +QWD +CQW +CTY +WRC +WYW +MWL +CGH +HPC +PCY +EWH +QNM +PCM +QMM +WMY +WPN +WCE +HQH +CNN +CMW +PCK +QWH +NTC +HIC +CMC +MCQ +KHW +KCQ +MHK +CWG +HMT +WFM +IWC +CML +HWT +MHR +DQW +IQW +WVW +WPC +WHG +WYH +IEW +VHY +YQW +WDH +CHD +QPY +WKC +YDC +NHW +WDM +QPC +CKW +KWY +NCM +CQN +MYF +YMW +MMC +KMW +MWI +MHD +ECI +CMD +WCI +CGM +GCQ +MCE +WWF +WTT +HDC +FCQ +DMN +PWI +RMQ +WGW +WYP +MYM +HCC +CDQ +MNW +CMP +RCK +MWD +FPW +QTW +WNY +MCT +MHH +IWM +CFY +HYW +PHW +HWW +CFN +MWF +HCM +MWH +GYW +HAW +DWH +YWV +NMW +QEW +CNC +WDK +NKC +GCC +MPC +MCN +CCA +KWM +MCM +HWL +WSY +CKC +WMF +CWY +HCQ +WCA +HMK +DHD +YHY +DNW +WCD +WPI +WFD +WHW +WHC +HCY +WHQ +IMC +KPC +YMC +CRC +MCY +ECY +MCH +HWI +DCQ +PMW +LWC +CRM +DMC +MNF +HWY +YWW +YWC +WYY +EWC +FWC +FWY +WMN +WWV +EWW +WCM +CAM +WKM +WHH +YMF +WCQ +WIQ +MFN +ANC +ECW +WCG +CIM +WQC +CMH +MYC +CTH +HHW +QWW +WIC +CPY +MDC +NYC +CMN +WHK +MMY +DEW +QHW +WQW +CEC +TWH +HFC +WKW +HWM +MQY +HDW +WYG +CWM +CYH +HYM +QMC +QCW +NCW +YQC +FMW +WMC +WWW +HMW +RMW +CHW +WCW +HTW +CWC +WCY +YWQ +WMW +CWT +CWH +MWM +WWC +WCC +WCH +WWM +TAX +AXD +XDR +IEX +EXV +QAX +AXX +XXE +XES +MXN +XNF +NRX +RXX +XXX +XXR +XRI +SAX +AXG +XGG +PRX +RXR +XRX +RXE +XEF +QEX +EXQ +XQR +REX +EXR +RXQ +XQQ +DRX +RXP +XPG +QMX +MXT +XTX +TXR +XRM +APX +PXX +XXG +XGI +NLX +LXX +XXM +XMA +LNX +NXE +XEA +GTX +TXN +XND +LIX +IXI +XIM +MVX +VXX +XXK +XKT +GLX +LXP +XPP +QGX +GXD +XDL +XAP +QNX +NXM +XMN +VAX +XGV +IKX +KXY +KEX +EXL +XLY +GQX +QXE +XEP +PLX +XKC +PVX +XKE +RXI +XIR +AXL +XLN +LLX +LXD +XDA +AXE +XEL +GGX +GXG +KAX +XXA +XAG +XWS +SPX +PXC +XCD +GWX +WXH +XHF +MPX +ESX +SXN +XNK +DLX +LXN +XNS +QXG +XGD +ITX +XRG +NEX +EXA +XAL +LDX +DXI +XII +TPX +PXM +XMR +NXG +XGY +ASX +SXV +XVE +TKX +KXA +KRX +XXT +XTL +IDX +DXX +XXL +XLV +AKX +KXX +QHX +HXV +XVN +NSX +SXX +XKX +XDP +DAX +AXK +XKQ +PIX +IXX +XXF +VLX +XDI +DIX +IXL +XLK +LKX +KXV +XVA +DNX +NXD +ILX +LXK +XKV +VYX +YXE +XEI +RXS +XSH +KGX +XGF +AVX +VXY +XYG +HVX +XXI +XID +TVX +XXS +XSA +ENX +NXX +XMD +IIX +XMQ +AEX +EXX +XME +PGX +GXP +XPR +SKX +KXF +XFT +HRX +XSW +PQX +XGR +QQX +VTX +XRP +PSX +SXP +XPL +VGX +GXY +RSX +SXS +XSL +VSX +XST +AXV +XVL +AGX +GXX +XTK +KLX +LXR +XRV +AHX +HXC +XCS +LVX +VXN +XNR +NGX +GXL +TSX +SXQ +XQN +KXL +XLL +VIX +IXG +XGA +GFX +FXG +XGL +PTX +TXT +XTS +EMX +MXQ +SXY +XYA +IQX +QXY +XYR +TXK +IGX +XPS +PXT +XTG +NXQ +VKX +KXS +XSN +GVX +VXE +GRX +XRE +YKX +KXE +XEE +EEX +EXT +XTI +EHX +HXN +XNL +NDX +DXD +IAX +KSX +SXL +RRX +XRK +DDX +DXE +RXG +VXL +XLS +DTX +TXG +VXF +XFA +XIG +VXT +XTA +ISX +SXR +XRY +VQX +QXP +XPC +LGX +GXS +HGX +XGH +XXD +XDD +KKX +XXV +PKX +XLT +XSP +XLD +RAX +AXS +XSI +IYX +YXX +XXP +XPI +MSX +SXT +GEX +XHP +LFX +FXX +VXI +XIW +QTX +TXX +XXQ +XQA +FLX +DXN +XNC +MXS +XSR +YLX +EQX +QXS +TMX +MXC +XCY +NXA +XAV +EXE +XEQ +HPX +PXP +LMX +MXX +KTX +XKK +XXH +XHS +MKX +XIH +WRX +XKS +EXY +XYQ +QKX diff --git a/chebai/preprocessing/collate.py b/chebai/preprocessing/collate.py new file mode 100644 index 0000000..ecbcb87 --- /dev/null +++ b/chebai/preprocessing/collate.py @@ -0,0 +1,137 @@ +from typing import Dict, List, Tuple, Union + +import torch +from torch.nn.utils.rnn import pad_sequence + +from chebai.preprocessing.structures import 
XYData + + +class Collator: + """Base class for collating data samples into a batch.""" + + def __init__(self, **kwargs): + pass + + def __call__(self, data: List[Dict]) -> XYData: + """Collate a list of data samples into a batch. + + Args: + data (List[Dict]): List of data samples. + + Returns: + XYData: Batched data. + """ + raise NotImplementedError + + +class DefaultCollator(Collator): + """Default collator that extracts features and labels.""" + + def __call__(self, data: List[Dict]) -> XYData: + """Collate data samples by extracting features and labels. + + Args: + data (List[Dict]): List of data samples. + + Returns: + XYData: Batched data. + """ + x, y = zip(*((d["features"], d["labels"]) for d in data)) + return XYData(x, y) + + +class RaggedCollator(Collator): + """ + Collator for handling ragged data samples, designed to support scenarios where some labels may be missing (None). + + This class is specifically designed for preparing batches of "ragged" data, where the samples may have varying sizes, + such as molecular representations or variable-length protein sequences. Additionally, it supports cases where some + of the data samples might be partially labeled, which is useful for certain loss functions that allow training + with incomplete or fuzzy data (e.g., fuzzy loss). + + During batching, the class pads the data samples to a uniform length, applies appropriate masks to differentiate + between valid and padded elements, and ensures that label misalignment is handled by filtering out unlabelled + data points. The indices of valid labels are stored in the `non_null_labels` field, which can be used later for + metrics computation such as F1-score or MSE, especially in cases where some data points lack labels. + + Reference: https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 + """ + + def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: + """ + Collate ragged data samples (i.e., samples of unequal size, such as molecular sequences) into a batch. + + Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices + of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for + unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method + ensures alignment between features and labels. + + Args: + data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple + with 'features', 'labels', and 'ident'. + + Returns: + XYData: A batch of padded sequences and labels, including masks for valid positions and indices of + non-null labels for metric computation. + """ + model_kwargs: Dict = dict() + # Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs. 
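+        # Illustration (hypothetical shapes): for two samples whose feature lists have lengths
+        # 3 and 5, the padded feature tensor built below has shape (2, 5), `lens` becomes
+        # tensor([3, 5]), and `mask[i, j]` is True exactly for the first `lens[i]` positions
+        # of row i, so padded positions can be ignored downstream.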
+ loss_kwargs: Dict = dict() + + if isinstance(data[0], tuple): + # For legacy data + x, y, idents = zip(*data) + else: + x, y, idents = zip( + *((d["features"], d["labels"], d.get("ident")) for d in data) + ) + if any(x is not None for x in y): + # If any label is not None: (None, None, `1`, None) + if any(x is None for x in y): + # If any label is None: (`None`, `None`, 1, `None`) + non_null_labels = [i for i, r in enumerate(y) if r is not None] + y = self.process_label_rows( + tuple(ye for i, ye in enumerate(y) if i in non_null_labels) + ) + loss_kwargs["non_null_labels"] = non_null_labels + else: + # If all labels are not None: (`0`, `2`, `1`, `3`) + y = self.process_label_rows(y) + else: + # If all labels are None : (`None`, `None`, `None`, `None`) + y = None + loss_kwargs["non_null_labels"] = [] + + # Calculate the lengths of each sequence, create a binary mask for valid (non-padded) positions + lens = torch.tensor(list(map(len, x))) + model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None] + model_kwargs["lens"] = lens + + return XYData( + pad_sequence([torch.tensor(a) for a in x], batch_first=True), + y, + model_kwargs=model_kwargs, + loss_kwargs=loss_kwargs, + idents=idents, + ) + + def process_label_rows(self, labels: Tuple) -> torch.Tensor: + """ + Process label rows by padding sequences to ensure uniform shape across the batch. + + This method pads the label rows, converting sequences of labels of different lengths into a uniform tensor. + It ensures that `None` values in the labels are handled by substituting them with a default value(e.g.,`False`). + + Args: + labels (Tuple): Tuple of label rows. + + Returns: + torch.Tensor: Padded label sequences. + """ + return pad_sequence( + [ + torch.tensor([v if v is not None else False for v in row]) + for row in labels + ], + batch_first=True, + ) diff --git a/chebai/preprocessing/collect_all.py b/chebai/preprocessing/collect_all.py new file mode 100644 index 0000000..62e140f --- /dev/null +++ b/chebai/preprocessing/collect_all.py @@ -0,0 +1,226 @@ +import logging +import os +import sys + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from pytorch_lightning import loggers as pl_loggers +from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning.metrics import F1 +from sklearn.metrics import f1_score +from torch import nn +from torch_geometric import nn as tgnn +from torch_geometric.data import DataLoader + +from data import ClassificationData, JCIClassificationData + +logging.getLogger("pysmiles").setLevel(logging.CRITICAL) + + +class PartOfNet(pl.LightningModule): + def __init__(self, in_length, loops=10): + super().__init__() + self.loops = loops + self.left_graph_net = tgnn.GATConv(in_length, in_length) + self.right_graph_net = tgnn.GATConv(in_length, in_length) + self.attention = nn.Linear(in_length, 1) + self.global_attention = tgnn.GlobalAttention(self.attention) + self.output_net = nn.Sequential( + nn.Linear(2 * in_length, 2 * in_length), + nn.Linear(2 * in_length, in_length), + nn.Linear(in_length, 500), + ) + self.f1 = F1(1, threshold=0.5) + + def _execute(self, batch, batch_idx): + pred = self(batch) + loss = F.binary_cross_entropy_with_logits(pred, batch.label) + f1 = self.f1(batch.label, torch.sigmoid(pred)) + return loss, f1 + + def training_step(self, *args, **kwargs): + loss, f1 = self._execute(*args, **kwargs) + self.log( + "train_loss", + loss.detach().item(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + self.log( + 
"train_f1", + f1.item(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss + + def validation_step(self, *args, **kwargs): + with torch.no_grad(): + loss, f1 = self._execute(*args, **kwargs) + self.log( + "val_loss", + loss.detach().item(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + self.log( + "val_f1", + f1.item(), + on_step=True, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss + + def forward(self, x): + a = self.left_graph_net(x.x_s, x.edge_index_s.long()) + b = self.right_graph_net(x.x_t, x.edge_index_t.long()) + return self.output_net( + torch.cat( + [ + self.global_attention(a, x.x_s_batch), + self.global_attention(b, x.x_t_batch), + ], + dim=1, + ) + ) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters()) + return optimizer + + +class JCINet(pl.LightningModule): + def __init__(self, in_length, hidden_length, num_classes, loops=10): + super().__init__() + self.loops = loops + + self.node_net = nn.Sequential( + nn.Linear(self.loops * in_length, hidden_length), nn.ReLU() + ) + self.embedding = torch.nn.Embedding(800, in_length) + self.left_graph_net = tgnn.GATConv(in_length, in_length, dropout=0.1) + self.final_graph_net = tgnn.GATConv(in_length, hidden_length, dropout=0.1) + self.attention = nn.Linear(hidden_length, 1) + self.global_attention = tgnn.GlobalAttention(self.attention) + self.output_net = nn.Sequential( + nn.Linear(hidden_length, hidden_length), + nn.Linear(hidden_length, num_classes), + ) + self.f1 = F1(num_classes, threshold=0.5) + + def _execute(self, batch, batch_idx): + pred = self(batch) + labels = batch.label.float() + loss = F.binary_cross_entropy_with_logits(pred, labels) + f1 = f1_score( + labels.cpu() > 0.5, torch.sigmoid(pred).cpu() > 0.5, average="micro" + ) + return loss, f1 + + def training_step(self, *args, **kwargs): + loss, f1 = self._execute(*args, **kwargs) + self.log( + "train_loss", + loss.detach().item(), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + self.log( + "train_f1", + f1.item(), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss + + def validation_step(self, *args, **kwargs): + with torch.no_grad(): + loss, f1 = self._execute(*args, **kwargs) + self.log( + "val_loss", + loss.detach().item(), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + self.log( + "val_f1", + f1.item(), + on_step=False, + on_epoch=True, + prog_bar=True, + logger=True, + ) + return loss + + def forward(self, x): + a = self.embedding(x.x) + l = [] + for _ in range(self.loops): + a = self.left_graph_net(a, x.edge_index.long()) + l.append(a) + at = self.global_attention(self.node_net(torch.cat(l, dim=1)), x.x_batch) + return self.output_net(at) + + def configure_optimizers(self): + optimizer = torch.optim.Adam(self.parameters()) + return optimizer + + +def train(train_loader, validation_loader): + if torch.cuda.is_available(): + trainer_kwargs = dict(gpus=-1, accelerator="ddp") + else: + trainer_kwargs = dict(gpus=0) + net = JCINet(100, 100, 500) + tb_logger = pl_loggers.CSVLogger("../../logs/") + checkpoint_callback = ModelCheckpoint( + dirpath=os.path.join(tb_logger.log_dir, "checkpoints"), + filename="{epoch}-{step}-{val_loss:.7f}", + save_top_k=5, + save_last=True, + verbose=True, + monitor="val_loss", + mode="min", + ) + trainer = pl.Trainer( + logger=tb_logger, + callbacks=[checkpoint_callback], + replace_sampler_ddp=False, + **trainer_kwargs + ) + trainer.fit(net, train_loader, 
val_dataloaders=validation_loader) + + +if __name__ == "__main__": + batch_size = int(sys.argv[1]) + # vl = ClassificationData("data/full_chebi", split="validation") + # tr = ClassificationData("data/full_chebi", split="train") + tr = JCIClassificationData("data/JCI_data", split="train") + vl = JCIClassificationData("data/JCI_data", split="validation") + + train_loader = DataLoader( + tr, + shuffle=True, + batch_size=batch_size, + follow_batch=["x", "edge_index", "label"], + ) + validation_loader = DataLoader( + vl, batch_size=batch_size, follow_batch=["x", "edge_index", "label"] + ) + + train(train_loader, validation_loader) diff --git a/chebai/preprocessing/datasets/__init__.py b/chebai/preprocessing/datasets/__init__.py new file mode 100644 index 0000000..d09b21c --- /dev/null +++ b/chebai/preprocessing/datasets/__init__.py @@ -0,0 +1,4 @@ +from .base import XYBaseDataModule +from .chebi import * +from .pubchem import * +from .tox21 import * diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py new file mode 100644 index 0000000..817bc1d --- /dev/null +++ b/chebai/preprocessing/datasets/base.py @@ -0,0 +1,1184 @@ +import os +import random +from abc import ABC, abstractmethod +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import lightning as pl +import networkx as nx +import pandas as pd +import torch +import tqdm +from iterstrat.ml_stratifiers import ( + MultilabelStratifiedKFold, + MultilabelStratifiedShuffleSplit, +) +from lightning.pytorch.core.datamodule import LightningDataModule +from lightning_utilities.core.rank_zero import rank_zero_info +from sklearn.model_selection import StratifiedShuffleSplit +from torch.utils.data import DataLoader + +from chebai.preprocessing import reader as dr + + +class XYBaseDataModule(LightningDataModule): + """ + Base class for data modules. + + This class provides a base implementation for loading and preprocessing datasets. + It inherits from `LightningDataModule` and defines common properties and methods for data loading and processing. + + Args: + batch_size (int): The batch size for data loading. Default is 1. + train_split (float): The ratio of training data to total data and of test data to (validation + test) data. Default is 0.85. + reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None. + prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test". + data_limit (Optional[int]): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None. + label_filter (Optional[int]): The index of the label to filter. Default is None. + balance_after_filter (Optional[float]): The ratio of negative samples to positive samples after filtering. Default is None. + num_workers (int): The number of worker processes for data loading. Default is 1. + chebi_version (int): The version of ChEBI to use. Default is 200. + inner_k_folds (int): The number of folds for inner cross-validation. Use -1 to disable inner cross-validation. Default is -1. + fold_index (Optional[int]): The index of the fold to use for training and validation. Default is None. + base_dir (Optional[str]): The base directory for storing processed and raw data. Default is None. + **kwargs: Additional keyword arguments. + + Attributes: + READER (DataReader): The data reader class to use. + reader (DataReader): An instance of the data reader class. 
+ train_split (float): The ratio of training data to total data. + batch_size (int): The batch size for data loading. + prediction_kind (str): The kind of prediction to be performed. + data_limit (Optional[int]): The maximum number of data samples to load. + label_filter (Optional[int]): The index of the label to filter. + balance_after_filter (Optional[float]): The ratio of negative samples to positive samples after filtering. + num_workers (int): The number of worker processes for data loading. + chebi_version (int): The version of ChEBI to use. + inner_k_folds (int): The number of folds for inner cross-validation. If it is less than to, no cross-validation will be performed. + fold_index (Optional[int]): The index of the fold to use for training and validation (only relevant for cross-validation). + _base_dir (Optional[str]): The base directory for storing processed and raw data. + raw_dir (str): The directory for storing raw data. + processed_dir (str): The directory for storing processed data. + fold_dir (str): The name of the directory where the folds from inner cross-validation are stored. + _name (str): The name of the data module. + + """ + + READER = dr.DataReader + + def __init__( + self, + batch_size: int = 1, + train_split: float = 0.85, + reader_kwargs: Optional[dict] = None, + prediction_kind: str = "test", + data_limit: Optional[int] = None, + label_filter: Optional[int] = None, + balance_after_filter: Optional[float] = None, + num_workers: int = 1, + chebi_version: int = 200, + inner_k_folds: int = -1, # use inner cross-validation if > 1 + fold_index: Optional[int] = None, + base_dir: Optional[str] = None, + **kwargs, + ): + super().__init__() + if reader_kwargs is None: + reader_kwargs = dict() + self.reader = self.READER(**reader_kwargs) + self.train_split = train_split + self.batch_size = batch_size + self.prediction_kind = prediction_kind + self.data_limit = data_limit + self.label_filter = label_filter + assert (balance_after_filter is not None) or ( + self.label_filter is None + ), "Filter balancing requires a filter" + self.balance_after_filter = balance_after_filter + self.num_workers = num_workers + self.chebi_version = chebi_version + assert type(inner_k_folds) is int + self.inner_k_folds = inner_k_folds + self.use_inner_cross_validation = ( + inner_k_folds > 1 + ) # only use cv if there are at least 2 folds + assert ( + fold_index is None or self.use_inner_cross_validation is not None + ), "fold_index can only be set if cross validation is used" + if fold_index is not None and self.inner_k_folds is not None: + assert ( + fold_index < self.inner_k_folds + ), "fold_index can't be larger than the total number of folds" + self.fold_index = fold_index + self._base_dir = base_dir + os.makedirs(self.raw_dir, exist_ok=True) + os.makedirs(self.processed_dir, exist_ok=True) + if self.use_inner_cross_validation: + os.makedirs(os.path.join(self.raw_dir, self.fold_dir), exist_ok=True) + os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True) + self.save_hyperparameters() + + @property + def identifier(self) -> tuple: + """Identifier for the dataset.""" + return (self.reader.name(),) + + @property + def full_identifier(self) -> tuple: + """Full identifier for the dataset.""" + return (self._name, *self.identifier) + + @property + def base_dir(self) -> str: + """Common base directory for processed and raw directories.""" + if self._base_dir is not None: + return self._base_dir + return os.path.join("data", self._name) + + @property + def 
processed_dir_main(self) -> str: + """Name of the directory where processed (but not tokenized) data is stored.""" + return os.path.join(self.base_dir, "processed") + + @property + def processed_dir(self) -> str: + """Name of the directory where the processed and tokenized data is stored.""" + return os.path.join(self.processed_dir_main, *self.identifier) + + @property + def raw_dir(self) -> str: + """Name of the directory where the raw data is stored.""" + return os.path.join(self.base_dir, "raw") + + @property + def fold_dir(self) -> str: + """Name of the directory where the folds from inner cross-validation (i.e., the train and val sets) are stored.""" + return f"cv_{self.inner_k_folds}_fold" + + @property + @abstractmethod + def _name(self) -> str: + """ + Abstract property representing the name of the data module. + + This property should be implemented in subclasses to provide a unique name for the data module. + The name is used to create subdirectories within the base directory or `processed_dir_main` + for storing relevant data associated with this module. + + Returns: + str: The name of the data module. + """ + pass + + def _filter_labels(self, row: dict) -> dict: + """ + Filter labels based on `label_filter`. + This method selects specific labels from the `labels` list within the row dictionary + according to the index or indices provided by the `label_filter` attribute of the class. + + Args: + row (dict): A dictionary containing the row data. + + Returns: + dict: The filtered row data. + """ + row["labels"] = [row["labels"][self.label_filter]] + return row + + def load_processed_data( + self, kind: Optional[str] = None, filename: Optional[str] = None + ) -> List: + """ + Load processed data from a file. Either the kind or the filename has to be provided. If both are provided, the + filename is used. + + Args: + kind (str, optional): The kind of dataset to load such as "train", "val" or "test". Defaults to None. + filename (str, optional): The name of the file to load the dataset from. Defaults to None. + + Returns: + List: The loaded processed data. + + Raises: + ValueError: If both kind and filename are None. + """ + if kind is None and filename is None: + raise ValueError( + "Either kind or filename is required to load the correct dataset, both are None" + ) + # if both kind and filename are given, use filename + if kind is not None and filename is None: + try: + # processed_file_names_dict is only implemented for _ChEBIDataExtractor + if self.use_inner_cross_validation and kind != "test": + filename = self.processed_file_names_dict[ + f"fold_{self.fold_index}_{kind}" + ] + else: + filename = self.processed_file_names_dict[kind] + except NotImplementedError: + filename = f"{kind}.pt" + return torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) + + def dataloader(self, kind: str, **kwargs) -> DataLoader: + """ + Returns a DataLoader object for the specified kind (train, val or test) of data. + + Args: + kind (str): The kind indicates whether it is a train, val or test data to load. + **kwargs: Additional keyword arguments. + + Returns: + DataLoader: A DataLoader object. 
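+
+        Example:
+            A minimal sketch, assuming `data_module` is an instance of a concrete subclass:
+
+                loader = data_module.dataloader("train", shuffle=True)
+                batch = next(iter(loader))  # collated by `self.reader.collator`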
+ """ + dataset = self.load_processed_data(kind) + if "ids" in kwargs: + ids = kwargs.pop("ids") + _dataset = [] + for i in range(len(dataset)): + if i in ids: + _dataset.append(dataset[i]) + dataset = _dataset + if self.label_filter is not None: + original_len = len(dataset) + dataset = [self._filter_labels(r) for r in dataset] + positives = [r for r in dataset if r["labels"][0]] + negatives = [r for r in dataset if not r["labels"][0]] + if self.balance_after_filter is not None: + negative_length = min( + original_len, int(len(positives) * self.balance_after_filter) + ) + dataset = positives + negatives[:negative_length] + else: + dataset = positives + negatives + random.shuffle(dataset) + if self.data_limit is not None: + dataset = dataset[: self.data_limit] + return DataLoader( + dataset, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) + + @staticmethod + def _load_dict( + input_file_path: str, + ) -> Generator[Dict[str, Any], None, None]: + """ + Load data from a file and return a dictionary. + + Args: + input_file_path (str): The path to the input file. + + Yields: + dict: A dictionary containing the features and labels. + """ + with open(input_file_path, "r") as input_file: + for row in input_file: + smiles, labels = row.split("\t") + yield dict(features=smiles, labels=labels) + + @staticmethod + def _get_data_size(input_file_path: str) -> int: + """ + Get the number of lines in a file. + + Args: + input_file_path (str): The path to the input file. + + Returns: + int: The number of lines in the file. + """ + with open(input_file_path, "r") as f: + return sum(1 for _ in f) + + def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: + """ + Load data from a file and return a list of dictionaries. + + Args: + path (str): The path to the input file. + + Returns: + List: A list of dictionaries containing the features and labels. + """ + lines = self._get_data_size(path) + print(f"Processing {lines} lines...") + data = [ + self.reader.to_data(d) + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + # filter for missing features in resulting data + data = [val for val in data if val["features"] is not None] + + return data + + def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + """ + Returns the train DataLoader. + + Args: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + DataLoader: A DataLoader object for training data. + """ + return self.dataloader( + "train", + shuffle=True, + num_workers=self.num_workers, + persistent_workers=True, + **kwargs, + ) + + def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + """ + Returns the validation DataLoader. + + Args: + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments, passed to dataloader(). + + Returns: + Union[DataLoader, List[DataLoader]]: A DataLoader object for validation data. + """ + return self.dataloader( + "validation", + shuffle=False, + num_workers=self.num_workers, + persistent_workers=True, + **kwargs, + ) + + def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + """ + Returns the test DataLoader. + + Args: + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments, passed to dataloader(). + + Returns: + Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. 
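+
+        Example:
+            A minimal sketch, assuming `data_module` is an instance of a concrete subclass and
+            `trainer`/`model` are an already configured Lightning Trainer and model:
+
+                trainer.test(model, dataloaders=data_module.test_dataloader())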
+ """ + return self.dataloader("test", shuffle=False, **kwargs) + + def predict_dataloader( + self, *args, **kwargs + ) -> Union[DataLoader, List[DataLoader]]: + """ + Returns the predict DataLoader. + + Args: + *args: Additional positional arguments (unused). + **kwargs: Additional keyword arguments, passed to dataloader(). + + Returns: + Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. + """ + return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) + + def setup(self, **kwargs): + """ + Setup the data module. + + This method checks for the processed data and sets up the data module for training, validation, and testing. + + Args: + **kwargs: Additional keyword arguments. + """ + rank_zero_info(f"Check for processed data in {self.processed_dir}") + rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}") + if any( + not os.path.isfile(os.path.join(self.processed_dir, f)) + for f in self.processed_file_names + ): + self.setup_processed() + + if not ("keep_reader" in kwargs and kwargs["keep_reader"]): + self.reader.on_finish() + + def setup_processed(self): + """ + Setup the processed data. + + This method should be implemented by subclasses to handle the specific setup of processed data. + """ + raise NotImplementedError + + @property + def processed_main_file_names_dict(self) -> dict: + """ + Returns a dictionary mapping processed data file names. + + Returns: + dict: A dictionary mapping dataset key to their respective file names. + For example, {"data": "data.pkl"}. + """ + raise NotImplementedError + + @property + def processed_main_file_names(self) -> List[str]: + """ + Returns a list of file names for processed data (before tokenization). + + Returns: + List[str]: A list of file names corresponding to the processed data. + """ + return list(self.processed_main_file_names_dict.values()) + + @property + def processed_file_names_dict(self) -> dict: + """ + Returns a dictionary for the processed and tokenized data files. + + Returns: + dict: A dictionary mapping dataset keys to their respective file names. + For example, {"data": "data.pt"}. + """ + raise NotImplementedError + + @property + def processed_file_names(self) -> List[str]: + """ + Returns a list of file names for processed data. + + Returns: + List[str]: A list of file names corresponding to the processed data. + """ + return list(self.processed_file_names_dict.values()) + + @property + def raw_file_names(self) -> List[str]: + """ + Returns the list of raw file names. + + Returns: + List[str]: The list of raw file names. + """ + return list(self.raw_file_names_dict.values()) + + @property + def raw_file_names_dict(self) -> dict: + """ + Returns the dictionary of raw file names (i.e., files that are directly obtained from an external source). + + This property should be implemented by subclasses to provide the dictionary of raw file names. + + Returns: + dict: The dictionary of raw file names. + """ + raise NotImplementedError + + @property + def label_number(self) -> int: + """ + Returns the number of labels. + + This property should be implemented by subclasses to provide the number of labels. + + Returns: + int: The number of labels. Returns -1 for seq2seq encoding. + """ + raise NotImplementedError + + +class MergedDataset(XYBaseDataModule): + MERGED = [] + + @property + def _name(self) -> str: + """ + Returns a concatenated name of all subset names. 
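+        For example, two subsets named "a" and "b" (hypothetical names) would yield "a+b".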
+ """ + return "+".join(s._name for s in self.subsets) + + def __init__( + self, + batch_size: int = 1, + train_split: float = 0.85, + reader_kwargs: Union[None, List[dict]] = None, + **kwargs, + ): + """ + Args: + batch_size (int): Batch size for data loaders. + train_split (float): Fraction of data to use for training. + reader_kwargs (Union[None, List[dict]]): Optional arguments for subset readers. + **kwargs: Additional arguments to pass to LightningDataModule. + """ + if reader_kwargs is None: + reader_kwargs = [None for _ in self.MERGED] + self.train_split = train_split + self.batch_size = batch_size + self.subsets = [ + s(train_split=train_split, reader_kwargs=kws) + for s, kws in zip(self.MERGED, reader_kwargs) + ] + self.reader = self.subsets[0].reader + os.makedirs(self.processed_dir, exist_ok=True) + super(pl.LightningDataModule, self).__init__(**kwargs) + + def prepare_data(self): + """ + Placeholder for data preparation logic. + """ + for s in self.subsets: + s.prepare_data() + + def setup(self, **kwargs): + """ + Setup the data module. + + This method checks for the processed data and sets up the data module for training, validation, and testing. + + Args: + **kwargs: Additional keyword arguments. + """ + for s in self.subsets: + s.setup(**kwargs) + + def dataloader(self, kind: str, **kwargs) -> DataLoader: + """ + Creates a DataLoader for a specific subset. + + Args: + kind (str): Kind of data loader ('train', 'validation', or 'test'). + **kwargs: Additional arguments passed to DataLoader. + + Returns: + DataLoader: DataLoader object for the specified subset. + """ + subdatasets = [ + torch.load(os.path.join(s.processed_dir, f"{kind}.pt"), weights_only=False) + for s in self.subsets + ] + dataset = [ + self._process_data(i, d) + for i, (s, lim) in enumerate(zip(subdatasets, self.limits)) + for d in (s if lim is None else s[:lim]) + ] + return DataLoader( + dataset, + collate_fn=self.reader.collator, + batch_size=self.batch_size, + **kwargs, + ) + + def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + """ + Returns the training DataLoader. + """ + return self.dataloader("train", shuffle=True, **kwargs) + + def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + """ + Returns the validation DataLoader. + """ + return self.dataloader("validation", shuffle=False, **kwargs) + + def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: + """ + Returns the test DataLoader. + """ + return self.dataloader("test", shuffle=False, **kwargs) + + def _process_data(self, subset_id: int, data: dict) -> dict: + """ + Processes data from a subset. + + Args: + subset_id (int): Index of the subset. + data (dict): Data from the subset. + + Returns: + dict: Processed data with 'features', 'labels', and 'ident' keys. + """ + return dict( + features=data["features"], labels=data["labels"], ident=data["ident"] + ) + + def setup_processed(self): + """ + Placeholder for setup logic after data processing. + """ + pass + + @property + def processed_file_names(self) -> List[str]: + """ + Returns the list of processed file names. + """ + return ["test.pt", "train.pt", "validation.pt"] + + @property + def label_number(self) -> int: + """ + Returns the number of labels from the first subset. + """ + return self.subsets[0].label_number + + @property + def limits(self): + """ + Returns None, assuming no limits on data slicing. 
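+        Subclasses may override this with one limit per subset, since `dataloader` zips
+        `self.limits` against the loaded subset datasets.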
+ """ + return None + + +class _DynamicDataset(XYBaseDataModule, ABC): + """ + A class for extracting and processing data from the given dataset. + + The processed and transformed data is stored in `data.pkl` and `data.pt` format as a whole respectively, + rather than as separate train, validation, and test splits, with dynamic splitting of data.pt occurring at runtime. + The `_DynamicDataset` class manages data splits by either generating them during execution or retrieving them from + a CSV file. + If no split file path is provided, `_generate_dynamic_splits` creates the training, validation, and test splits + from the encoded/transformed data, storing them in `_dynamic_df_train`, `_dynamic_df_val`, and `_dynamic_df_test`. + When a split file path is provided, `_retrieve_splits_from_csv` loads splits from the CSV file, which must + include 'id' and 'split' columns. + The `dynamic_split_dfs` property ensures that the necessary splits are loaded as required. + + Args: + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + **kwargs: Additional keyword arguments passed to XYBaseDataModule. + + Attributes: + dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + splits_file_path (Optional[str]): Path to the CSV file containing split assignments. + """ + + # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------ + _ID_IDX: int = None + _DATA_REPRESENTATION_IDX: int = None + _LABELS_START_IDX: int = None + + def __init__( + self, + **kwargs, + ): + super(_DynamicDataset, self).__init__(**kwargs) + self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 + # Class variables to store the dynamics splits + self._dynamic_df_train = None + self._dynamic_df_test = None + self._dynamic_df_val = None + # Path of csv file which contains a list of ids & their assignment to a dataset (either train, + # validation or test). + self.splits_file_path = self._validate_splits_file_path( + kwargs.get("splits_file_path", None) + ) + + @staticmethod + def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: + """ + Validates the file in provided splits file path. + + Args: + splits_file_path (Optional[str]): Path to the splits CSV file. + + Returns: + Optional[str]: Validated splits file path if checks pass, None if splits_file_path is None. + + Raises: + FileNotFoundError: If the splits file does not exist. + ValueError: If splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file. + """ + if splits_file_path is None: + return None + + if not os.path.isfile(splits_file_path): + raise FileNotFoundError(f"File {splits_file_path} does not exist") + + file_size = os.path.getsize(splits_file_path) + if file_size == 0: + raise ValueError(f"File {splits_file_path} is empty") + + # Check if the file has a CSV extension + if not splits_file_path.lower().endswith(".csv"): + raise ValueError(f"File {splits_file_path} is not a CSV file") + + # Read the first row of CSV file into a DataFrame + splits_df = pd.read_csv(splits_file_path, nrows=1) + + # Check if 'id' and 'split' columns are in the DataFrame + required_columns = {"id", "split"} + if not required_columns.issubset(splits_df.columns): + raise ValueError( + f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')." 
+ ) + + return splits_file_path + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Prepares the data for the dataset. + + This method checks for the presence of raw data in the specified directory. + If the raw data is missing, it fetches the ontology and creates a dataframe and saves it to a data.pkl file. + + The resulting dataframe/pickle file is expected to contain columns with the following structure: + - Column at index `self._ID_IDX`: ID of data instance + - Column at index `self._DATA_REPRESENTATION_IDX`: Sequence representation of the protein + - Column from index `self._LABELS_START_IDX` onwards: Labels + + Args: + *args: Variable length argument list. + **kwargs: Arbitrary keyword arguments. + + Returns: + None + """ + print("Checking for processed data in", self.processed_dir_main) + + processed_name = self.processed_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + print(f"Missing processed data file (`{processed_name}` file)") + os.makedirs(self.processed_dir_main, exist_ok=True) + data_path = self._download_required_data() + g = self._extract_class_hierarchy(data_path) + data_df = self._graph_to_raw_dataset(g) + self.save_processed(data_df, processed_name) + + @abstractmethod + def _download_required_data(self) -> str: + """ + Downloads the required raw data. + + Returns: + str: Path to the downloaded data. + """ + pass + + @abstractmethod + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from the data. + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from + the term documents. + + Args: + data_path (str): Path to the data. + + Returns: + nx.DiGraph: The class hierarchy graph. + """ + pass + + @abstractmethod + def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + """ + Converts the graph to a raw dataset. + Uses the graph created by `_extract_class_hierarchy` method to extract the + raw data in Dataframe format with additional columns corresponding to each multi-label class. + + Args: + graph (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset. + """ + pass + + @abstractmethod + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + """ + Selects classes from the dataset based on a specified criteria. + + Args: + g (nx.Graph): The graph representing the dataset. + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + + Returns: + List: A sorted list of node IDs that meet the specified criteria. + """ + pass + + def save_processed(self, data: pd.DataFrame, filename: str) -> None: + """ + Save the processed dataset to a pickle file. + + Args: + data (pd.DataFrame): The processed dataset to be saved. + filename (str): The filename for the pickle file. + """ + pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) + + # ------------------------------ Phase: Setup data ----------------------------------- + def setup_processed(self) -> None: + """ + Transforms `data.pkl` into a model input data format (`data.pt`), ensuring that the data is in a format + compatible for input to the model. + The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. 
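+        Illustration (hypothetical values): one entry of the transformed list may look like
+        `{"ident": "P12345", "features": [3, 17, 9], "labels": [True, False], "group": None}`.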
+ + Returns: + None + """ + os.makedirs(self.processed_dir, exist_ok=True) + transformed_file_name = self.processed_file_names_dict["data"] + print( + f"Missing transformed data (`{transformed_file_name}` file). Transforming data.... " + ) + torch.save( + self._load_data_from_file( + os.path.join( + self.processed_dir_main, + self.processed_main_file_names_dict["data"], + ) + ), + os.path.join(self.processed_dir, transformed_file_name), + ) + + @staticmethod + def _get_data_size(input_file_path: str) -> int: + """ + Get the size of the data from a pickled file. + + Args: + input_file_path (str): The path to the file. + + Returns: + int: The size of the data. + """ + with open(input_file_path, "rb") as f: + return len(pd.read_pickle(f)) + + @abstractmethod + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from given pickled file and yields individual dictionaries for each row. + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Generator[Dict[str, Any], None, None]: Generator yielding dictionaries. + + """ + pass + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + @property + def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: + """ + Property to retrieve dynamic train, validation, and test splits. + + This property checks if dynamic data splits (`_dynamic_df_train`, `_dynamic_df_val`, `_dynamic_df_test`) + are already loaded. If any of them is None, it either generates them dynamically or retrieves them + from data file with help of pre-existing split csv file (`splits_file_path`) containing splits assignments. + + Returns: + dict: A dictionary containing the dynamic train, validation, and test DataFrames. + Keys are 'train', 'validation', and 'test'. + """ + if any( + split is None + for split in [ + self._dynamic_df_test, + self._dynamic_df_val, + self._dynamic_df_train, + ] + ): + if self.splits_file_path is None: + # Generate splits based on given seed, create csv file to records the splits + self._generate_dynamic_splits() + else: + # If user has provided splits file path, use it to get the splits from the data + self._retrieve_splits_from_csv() + return { + "train": self._dynamic_df_train, + "validation": self._dynamic_df_val, + "test": self._dynamic_df_test, + } + + def _generate_dynamic_splits(self) -> None: + """ + Generate data splits during runtime and save them in class variables. + + This method loads encoded data and generates train, validation, and test splits based on the loaded data. 
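+        The split assignment is also written to `splits.csv` in `processed_dir_main`, one row
+        per sample, for example (illustrative IDs):
+
+            id,split
+            P12345,train
+            Q67890,validation
+            A11111,test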
+ """ + print("\nGenerate dynamic splits...") + df_train, df_val, df_test = self._get_data_splits() + + # Generate splits.csv file to store ids of each corresponding split + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": df_train["ident"], "split": "train"}), + pd.DataFrame({"id": df_val["ident"], "split": "validation"}), + pd.DataFrame({"id": df_test["ident"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + combined_split_assignment.to_csv( + os.path.join(self.processed_dir_main, "splits.csv"), index=False + ) + + # Store the splits in class variables + self._dynamic_df_train = df_train + self._dynamic_df_val = df_val + self._dynamic_df_test = df_test + + @abstractmethod + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Retrieve the train, validation, and test data splits for the dataset. + + This method returns data splits according to specific criteria implemented + in the subclasses. + + Returns: + tuple: A tuple containing DataFrames for train, validation, and test splits. + """ + pass + + def get_test_split( + self, df: pd.DataFrame, seed: Optional[int] = None + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Split the input DataFrame into training and testing sets based on multilabel stratified sampling. + + This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels + in the training and testing sets is approximately the same. The split is based on the "labels" column + in the DataFrame. + + Args: + df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column + named "labels" with the multilabel data. + seed (int, optional): The random seed to be used for reproducibility. Default is None. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. + + Raises: + ValueError: If the DataFrame does not contain a column named "labels". + """ + print("Get test data split") + + labels_list = df["labels"].tolist() + + test_size = 1 - self.train_split - (1 - self.train_split) ** 2 + + if len(labels_list[0]) > 1: + splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + else: + splitter = StratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + + train_indices, test_indices = next(splitter.split(labels_list, labels_list)) + + df_train = df.iloc[train_indices] + df_test = df.iloc[test_indices] + return df_train, df_test + + def get_train_val_splits_given_test( + self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None + ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: + """ + Split the dataset into train and validation sets, given a test set. + Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap + + Args: + df (pd.DataFrame): The original dataset. + test_df (pd.DataFrame): The test dataset. + seed (int, optional): The random seed to be used for reproducibility. Default is None. + + Returns: + Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and + validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train + and validation DataFrames. The keys are the names of the train and validation sets, and the values + are the corresponding DataFrames. 
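+
+        Example (illustrative; `df` is the full encoded dataset as a DataFrame,
+        mirroring the usage in `_get_data_splits`):
+
+            df_trainval, df_test = self.get_test_split(df, seed=42)
+            df_train, df_val = self.get_train_val_splits_given_test(
+                df_trainval, df_test, seed=42
+            )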
+ """ + print(f"Split dataset into train / val with given test set") + + test_ids = test_df["ident"].tolist() + df_trainval = df[~df["ident"].isin(test_ids)] + labels_list_trainval = df_trainval["labels"].tolist() + + if self.use_inner_cross_validation: + folds = {} + kfold = MultilabelStratifiedKFold( + n_splits=self.inner_k_folds, random_state=seed + ) + for fold, (train_ids, val_ids) in enumerate( + kfold.split( + labels_list_trainval, + labels_list_trainval, + ) + ): + df_validation = df_trainval.iloc[val_ids] + df_train = df_trainval.iloc[train_ids] + folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train + folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( + df_validation + ) + + return folds + + # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) + test_size = ((1 - self.train_split) ** 2) / self.train_split + + if len(labels_list_trainval[0]) > 1: + splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + else: + splitter = StratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=seed + ) + + train_indices, validation_indices = next( + splitter.split(labels_list_trainval, labels_list_trainval) + ) + + df_validation = df_trainval.iloc[validation_indices] + df_train = df_trainval.iloc[train_indices] + return df_train, df_validation + + def _retrieve_splits_from_csv(self) -> None: + """ + Retrieve previously saved data splits from splits.csv file or from provided file path. + + This method loads the splits.csv file located at `self.splits_file_path`. + It then loads the encoded data (`data.pt`) and filters it based on the IDs retrieved from + splits.csv to reconstruct the train, validation, and test splits. + """ + print(f"\nLoading splits from {self.splits_file_path}...") + splits_df = pd.read_csv(self.splits_file_path) + + filename = self.processed_file_names_dict["data"] + data = self.load_processed_data(filename=filename) + df_data = pd.DataFrame(data) + + train_ids = splits_df[splits_df["split"] == "train"]["id"] + validation_ids = splits_df[splits_df["split"] == "validation"]["id"] + test_ids = splits_df[splits_df["split"] == "test"]["id"] + + self._dynamic_df_train = df_data[df_data["ident"].isin(train_ids)] + self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)] + self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)] + + def load_processed_data( + self, kind: Optional[str] = None, filename: Optional[str] = None + ) -> List[Dict[str, Any]]: + """ + Loads processed data from a specified dataset type or file. + + This method retrieves processed data based on the dataset type (`kind`) such as "train", + "val", or "test", or directly from a provided filename. When `kind` is specified, the method + leverages the `dynamic_split_dfs` property to dynamically generate or retrieve the corresponding + data splits if they are not already loaded. If both `kind` and `filename` are provided, `filename` + takes precedence. + + Args: + kind (str, optional): The type of dataset to load ("train", "val", or "test"). + If `filename` is provided, this argument is ignored. Defaults to None. + filename (str, optional): The name of the file to load the dataset from. + If provided, this takes precedence over `kind`. Defaults to None. + + Returns: + List[Dict[str, Any]]: A list of dictionaries, where each dictionary contains + the processed data for an individual data point. 
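+
+        Example (illustrative):
+
+            train_records = self.load_processed_data(kind="train")
+            all_records = self.load_processed_data(filename="data.pt")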
+ + Raises: + ValueError: If both `kind` and `filename` are None, as one of them is required to load the dataset. + KeyError: If the specified `kind` does not exist in the `dynamic_split_dfs` property or + `processed_file_names_dict`, when expected. + FileNotFoundError: If the file corresponding to the provided `filename` does not exist. + """ + if kind is None and filename is None: + raise ValueError( + "Either kind or filename is required to load the correct dataset, both are None" + ) + + # If both kind and filename are given, use filename + if kind is not None and filename is None: + try: + if self.use_inner_cross_validation and kind != "test": + filename = self.processed_file_names_dict[ + f"fold_{self.fold_index}_{kind}" + ] + else: + data_df = self.dynamic_split_dfs[kind] + return data_df.to_dict(orient="records") + except KeyError: + kind = f"{kind}" + + # If filename is provided + try: + return torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) + except FileNotFoundError: + raise FileNotFoundError(f"File {filename} doesn't exist") + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + @abstractmethod + def base_dir(self) -> str: + """ + Returns the base directory path for storing data. + + Returns: + str: The path to the base directory. + """ + pass + + @property + def processed_dir_main(self) -> str: + """ + Returns the main directory path where processed data is stored. + + Returns: + str: The path to the main processed data directory, based on the base directory and the instance's name. + """ + return os.path.join( + self.base_dir, + self._name, + "processed", + ) + + @property + def processed_main_file_names_dict(self) -> dict: + """ + Returns a dictionary mapping processed data file names. + + Returns: + dict: A dictionary mapping dataset key to their respective file names. + For example, {"data": "data.pkl"}. + """ + return {"data": "data.pkl"} + + @property + def raw_file_names(self) -> List[str]: + """ + Returns a list of raw file names. + + Returns: + List[str]: A list of file names corresponding to the raw data. + """ + return list(self.raw_file_names_dict.values()) + + @property + def processed_file_names_dict(self) -> dict: + """ + Returns a dictionary for the processed and tokenized data files. + + Returns: + dict: A dictionary mapping dataset keys to their respective file names. + For example, {"data": "data.pt"}. + """ + return {"data": "data.pt"} diff --git a/chebai/preprocessing/datasets/deepGO/__init__.py b/chebai/preprocessing/datasets/deepGO/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/preprocessing/datasets/deepGO/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py new file mode 100644 index 0000000..1b0eb2a --- /dev/null +++ b/chebai/preprocessing/datasets/deepGO/go_uniprot.py @@ -0,0 +1,1007 @@ +# References for this file : +# Reference 1: +# Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf; +# DeepGO: Predicting protein functions from sequence and interactions +# using a deep ontology-aware classifier, Bioinformatics, 2017. +# https://doi.org/10.1093/bioinformatics/btx624 +# Github: https://github.com/bio-ontology-research-group/deepgo + +# Reference 2: +# https://www.ebi.ac.uk/GOA/downloads +# https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt +# https://www.uniprot.org/uniprotkb + +# Reference 3: +# Kulmanov, M., Guzmรกn-Vega, F.J., Duek Roggli, +# P. et al. 
Protein function prediction as approximate semantic entailment. Nat Mach Intell 6, 220โ€“228 (2024). +# https://doi.org/10.1038/s42256-024-00795-w +# https://github.com/bio-ontology-research-group/deepgo2 + +__all__ = [ + "GOUniProtOver250", + "GOUniProtOver50", + "EXPERIMENTAL_EVIDENCE_CODES", + "AMBIGUOUS_AMINO_ACIDS", + "DeepGO1MigratedData", + "DeepGO2MigratedData", +] + +import gzip +import itertools +import os +import shutil +from abc import ABC, abstractmethod +from collections import OrderedDict +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Generator, List, Optional, Tuple, Union + +import fastobo +import networkx as nx +import pandas as pd +import requests +import torch +import tqdm +from Bio import SwissProt + +from chebai.preprocessing import reader as dr +from chebai.preprocessing.datasets.base import _DynamicDataset + +# https://github.com/bio-ontology-research-group/deepgo/blob/master/utils.py#L15 +EXPERIMENTAL_EVIDENCE_CODES = { + "EXP", + "IDA", + "IPI", + "IMP", + "IGI", + "IEP", + "TAS", + "IC", + # New evidence codes added in latest paper year 2024 Reference number 3 + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L24-L26 + "HTP", + "HDA", + "HMP", + "HGI", + "HEP", +} + +# https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 +# https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L10 +# `X` is now considered as valid amino acid, as per latest paper year 2024 Refernce number 3 +AMBIGUOUS_AMINO_ACIDS = {"B", "O", "J", "U", "Z", "*"} + + +class _GOUniProtDataExtractor(_DynamicDataset, ABC): + """ + A class for extracting and processing data from the Gene Ontology (GO) dataset and the Swiss UniProt dataset. + + Args: + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. + + Attributes: + dynamic_data_split_seed (int): The seed for random data splitting, default is 42. + max_sequence_length (int, optional): Specifies the maximum allowed sequence length for a protein, with a + default of 1002. During data preprocessing, any proteins exceeding this length will be excluded from further + processing. + splits_file_path (Optional[str]): Path to the CSV file containing split assignments. + """ + + _GO_DATA_INIT = "GO" + _SWISS_DATA_INIT = "SWISS" + + # -- Index for columns of processed `data.pkl` (derived from `_get_swiss_to_go_mapping` & `_graph_to_raw_dataset` + # "swiss_id" at row index 0 + # "accession" at row index 1 + # "go_ids" at row index 2 + # "sequence" at row index 3 + # labels starting from row index 4 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 3 # here `sequence` column + _LABELS_START_IDX: int = 4 + + _GO_DATA_URL: str = "https://purl.obolibrary.org/obo/go/go-basic.obo" + _SWISS_DATA_URL: str = ( + "https://ftp.uniprot.org/pub/databases/uniprot/knowledgebase/complete/uniprot_sprot.dat.gz" + ) + + # Gene Ontology (GO) has three major branches, one for biological processes (BP), molecular functions (MF) and + # cellular components (CC). 
The value "all" will take data related to all three branches into account. + _ALL_GO_BRANCHES: str = "all" + _GO_BRANCH_NAMESPACE: Dict[str, str] = { + "BP": "biological_process", + "MF": "molecular_function", + "CC": "cellular_component", + } + + def __init__(self, **kwargs): + self.go_branch: str = self._get_go_branch(**kwargs) + + self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002)) + assert ( + self.max_sequence_length >= 1 + ), "Max sequence length should be greater than or equal to 1." + + super(_GOUniProtDataExtractor, self).__init__(**kwargs) + + if self.reader.n_gram is not None: + assert self.max_sequence_length >= self.reader.n_gram, ( + f"max_sequence_length ({self.max_sequence_length}) must be greater than " + f"or equal to n_gram ({self.reader.n_gram})." + ) + + @classmethod + def _get_go_branch(cls, **kwargs) -> str: + """ + Retrieves the Gene Ontology (GO) branch based on provided keyword arguments. + This method checks if a valid GO branch value is provided in the keyword arguments. + + Args: + **kwargs: Arbitrary keyword arguments. Specifically looks for: + - "go_branch" (str): The desired GO branch. + Returns: + str: The GO branch value. This will be one of the allowed values. + + Raises: + ValueError: If the provided 'go_branch' value is not in the allowed list of values. + """ + + go_branch_value = kwargs.get("go_branch", cls._ALL_GO_BRANCHES) + allowed_values = list(cls._GO_BRANCH_NAMESPACE.keys()) + [cls._ALL_GO_BRANCHES] + if go_branch_value not in allowed_values: + raise ValueError( + f"Invalid value for go_branch: {go_branch_value}, Allowed values: {allowed_values}" + ) + return go_branch_value + + # ------------------------------ Phase: Prepare data ----------------------------------- + def _download_required_data(self) -> str: + """ + Downloads the required raw data related to Gene Ontology (GO) and Swiss-UniProt dataset. + + Returns: + str: Path to the downloaded data. + """ + self._download_swiss_uni_prot_data() + return self._download_gene_ontology_data() + + def _download_gene_ontology_data(self) -> str: + """ + Download the Gene Ontology data `.obo` file. + + Note: + Quote from : https://geneontology.org/docs/download-ontology/ + Three versions of the ontology are available, the one use in this method is described below: + https://purl.obolibrary.org/obo/go/go-basic.obo + The basic version of the GO, filtered such that the graph is guaranteed to be acyclic and annotations + can be propagated up the graph. The relations included are `is a, part of, regulates, negatively` + `regulates` and `positively regulates`. This version excludes relationships that cross the 3 GO + hierarchies. This version should be used with most GO-based annotation tools. + + Returns: + str: The file path of the loaded Gene Ontology data. + """ + go_path = os.path.join(self.raw_dir, self.raw_file_names_dict["GO"]) + os.makedirs(os.path.dirname(go_path), exist_ok=True) + + if not os.path.isfile(go_path): + print("Missing Gene Ontology raw data") + print(f"Downloading Gene Ontology data....") + r = requests.get(self._GO_DATA_URL, allow_redirects=True) + r.raise_for_status() # Check if the request was successful + open(go_path, "wb").write(r.content) + return go_path + + def _download_swiss_uni_prot_data(self) -> Optional[str]: + """ + Download the Swiss-Prot data file from UniProt Knowledgebase. + + Note: + UniProt Knowledgebase is collection of functional information on proteins, with accurate, consistent + and rich annotation. 
+ + Swiss-Prot contains manually-annotated records with information extracted from literature and + curator-evaluated computational analysis. + + Returns: + str: The file path of the loaded Swiss-Prot data file. + """ + uni_prot_file_path = os.path.join( + self.raw_dir, self.raw_file_names_dict["SwissUniProt"] + ) + os.makedirs(os.path.dirname(uni_prot_file_path), exist_ok=True) + + if not os.path.isfile(uni_prot_file_path): + print(f"Downloading Swiss UniProt data....") + + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") + + # Download the file + response = requests.get(self._SWISS_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) + + print(f"Downloaded to {temp_filename}") + + # Unpack the gzipped file + try: + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = uni_prot_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") + + except Exception as e: + print(f"Failed to unpack the file: {e}") + finally: + # Clean up the temporary file + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") + + return uni_prot_file_path + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from the GO ontology. + Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with GO term data. + + Args: + data_path (str): The path to the GO ontology. + + Returns: + nx.DiGraph: A directed graph representing the class hierarchy, where nodes are GO terms and edges + represent parent-child relationships. + """ + print("Extracting class hierarchy...") + elements = [] + for term in fastobo.load(data_path): + if isinstance(term, fastobo.typedef.TypedefFrame): + # ---- To avoid term frame of the below format/structure ---- + # [Typedef] + # id: part_of + # name: part of + # namespace: external + # xref: BFO:0000050 + # is_transitive: true + continue + + if ( + term + and isinstance(term.id, fastobo.id.PrefixedIdent) + and term.id.prefix == self._GO_DATA_INIT + ): + # Consider only terms with id in following format - GO:2001271 + term_dict = self.term_callback(term) + if term_dict: + elements.append(term_dict) + + g = nx.DiGraph() + + # Add GO term nodes to the graph and their hierarchical ontology + for n in elements: + g.add_node(n["go_id"], **n) + g.add_edges_from( + [ + (parent_id, node_id) + for node_id in g.nodes + for parent_id in g.nodes[node_id]["parents"] + if parent_id in g.nodes + ] + ) + + print("Compute transitive closure") + return nx.transitive_closure_dag(g) + + def term_callback(self, term: fastobo.term.TermFrame) -> Union[Dict, bool]: + """ + Extracts information from a Gene Ontology (GO) term document. + + Args: + term: A Gene Ontology term Frame document. + + Returns: + Optional[Dict]: A dictionary containing the extracted information if the term is not obsolete, + otherwise None. The dictionary includes: + - "id" (str): The ID of the GO term. + - "parents" (List[str]): A list of parent term IDs. + - "name" (str): The name of the GO term. 
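+
+        Example of a returned dictionary (values are illustrative):
+
+            {"go_id": 9968, "parents": [9966],
+             "name": "negative regulation of signal transduction"}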
+ """ + parents = [] + name = None + + for clause in term: + if isinstance(clause, fastobo.term.NamespaceClause): + if ( + self.go_branch != self._ALL_GO_BRANCHES + and clause.namespace.escaped + != self._GO_BRANCH_NAMESPACE[self.go_branch] + ): + # if the term document is not related to given go branch (except `all`), skip this document. + return False + + if isinstance(clause, fastobo.term.IsObsoleteClause): + if clause.obsolete: + # if the term document contains clause as obsolete as true, skips this document. + return False + + if isinstance(clause, fastobo.term.IsAClause): + parents.append(self._parse_go_id(clause.term)) + elif isinstance(clause, fastobo.term.NameClause): + name = clause.name + + return { + "go_id": self._parse_go_id(term.id), + "parents": parents, + "name": name, + } + + @staticmethod + def _parse_go_id(go_id: str) -> int: + """ + Helper function to parse and normalize GO term IDs. + + Args: + go_id: The raw GO term ID string. + + Returns: + str: The parsed and normalized GO term ID. + """ + # `is_a` clause has GO id in the following formats: + # GO:0009968 ! negative regulation of signal transduction + # GO:0046780 + return int(str(go_id).split(":")[1].split("!")[0].strip()) + + def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to create a raw dataset in DataFrame format. The dataset includes + Swiss-Prot protein data and their associations with Gene Ontology (GO) terms. + + Note: + - GO classes are used as labels in the dataset. Each GO term is represented as a column, and its value + indicates whether a Swiss-Prot protein is associated with that GO term. + - Swiss-Prot proteins serve as samples. There is no 1-to-1 correspondence between Swiss-Prot proteins + and GO terms. + + Data Format: pd.DataFrame + - Column 0 : swiss_id (Identifier for SwissProt protein) + - Column 1 : Accession of the protein + - Column 2 : GO IDs (associated GO terms) + - Column 3 : Sequence of the protein + - Column 4 to Column "n": Each column corresponding to a class with value True/False indicating whether the + protein is associated with this GO term. + + Args: + g (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + """ + print(f"Processing graph") + + data_df = self._get_swiss_to_go_mapping() + # add ancestors to go ids + data_df["go_ids"] = data_df["go_ids"].apply( + lambda go_ids: sorted( + set( + itertools.chain.from_iterable( + [ + [go_id] + list(g.predecessors(go_id)) + for go_id in go_ids + if go_id in g.nodes + ] + ) + ) + ) + ) + # Initialize the GO term labels/columns to False + selected_classes = self.select_classes(g, data_df=data_df) + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=selected_classes + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + # Set True for the corresponding GO IDs in the DataFrame go labels/columns + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + # This filters the DataFrame to include only the rows where at least one value in the row from 5th column + # onwards is True/non-zero. 
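+        # In other words (illustrative): a protein whose ancestor-expanded `go_ids`
+        # contain none of the selected classes ends up with an all-False label row
+        # and is therefore dropped by the filter below.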
+ # Quote from DeepGo Paper: `For training and testing, we use proteins which have been annotated with at least + # one GO term from the set of the GO terms for the model` + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + def _get_swiss_to_go_mapping(self) -> pd.DataFrame: + """ + Parses the Swiss-Prot data and returns a DataFrame mapping Swiss-Prot records to Gene Ontology (GO) data. + + The DataFrame includes the following columns: + - "swiss_id": The unique identifier for each Swiss-Prot record. + - "sequence": The protein sequence. + - "accessions": Comma-separated list of accession numbers. + - "go_ids": List of GO IDs associated with the Swiss-Prot record. + + Note: + This mapping is necessary because the GO data does not include the protein sequence representation. + We select proteins with annotations having experimental evidence codes, as specified in + `EXPERIMENTAL_EVIDENCE_CODES` and filter the proteins by a maximum length of 1002, ignoring proteins with + ambiguous amino acid codes specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence. + + Check the link below for keyword details: + https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt + + Returns: + pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with its associated GO data. + """ + + print("Parsing swiss uniprot raw data....") + + swiss_ids, sequences, accessions, go_ids_list = [], [], [], [] + + swiss_data = SwissProt.parse( + open( + os.path.join(self.raw_dir, self.raw_file_names_dict["SwissUniProt"]), + "r", + ) + ) + + for record in swiss_data: + if record.data_class != "Reviewed": + # To consider only manually-annotated swiss data + continue + + if not record.sequence or len(record.sequence) > self.max_sequence_length: + # Consider protein with only sequence representation and seq. length not greater than max seq. length + continue + + if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence): + # Skip proteins with ambiguous amino acid codes + continue + + go_ids = [] + + for cross_ref in record.cross_references: + if cross_ref[0] == self._GO_DATA_INIT: + # One swiss data protein can correspond to many GO data instances + + if len(cross_ref) <= 3: + # No evidence code + continue + + # https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L63-L66 + evidence_code = cross_ref[3].split(":")[0] + if evidence_code not in EXPERIMENTAL_EVIDENCE_CODES: + # Skip GO id without the required experimental evidence codes + continue + + go_ids.append(self._parse_go_id(cross_ref[1])) + + if not go_ids: + # Skip Swiss proteins without mapping to GO data + continue + + swiss_ids.append(record.entry_name) + sequences.append(record.sequence) + accessions.append(",".join(record.accessions)) + go_ids.sort() + go_ids_list.append(go_ids) + + data_dict = OrderedDict( + swiss_id=swiss_ids, # swiss_id column at index 0 + accession=accessions, # Accession column at index 1 + go_ids=go_ids_list, # Go_ids (data representation) column at index 2 + sequence=sequences, # Sequence column at index 3 + ) + + return pd.DataFrame(data_dict) + + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. 
+ + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (str): The sequence data from the file. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group + # "group" set to None, by default as no such entity for this data + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Loads encoded data and generates training, validation, and test splits. + + This method attempts to load encoded data from a file named `data.pt`. It then splits this data into + training, validation, and test sets. + + Raises: + FileNotFoundError: If the `data.pt` file does not exist. Ensure that `prepare_data` and/or + `setup` methods are called to generate the necessary dataset files. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set + """ + try: + filename = self.processed_file_names_dict["data"] + data_go = torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. " + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_go_data = pd.DataFrame(data_go) + train_df_go, df_test = self.get_test_split( + df_go_data, seed=self.dynamic_data_split_seed + ) + + # Get all splits + df_train, df_val = self.get_train_val_splits_given_test( + train_df_go, + df_test, + seed=self.dynamic_data_split_seed, + ) + + return df_train, df_val, df_test + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def base_dir(self) -> str: + """ + Returns the base directory path for storing GO-Uniprot data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ + return os.path.join("data", f"GO_UniProt") + + @property + def raw_file_names_dict(self) -> dict: + """ + Returns a dictionary of raw file names used in data processing. + + Returns: + dict: A dictionary mapping dataset names to their respective file names. + For example, {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"}. + """ + return {"GO": "go-basic.obo", "SwissUniProt": "uniprot_sprot.dat"} + + +class _GOUniProtOverX(_GOUniProtDataExtractor, ABC): + """ + A class for extracting data from the Gene Ontology (GO) dataset with a threshold for selecting classes based on + the number of subclasses. 
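+
+    Example (illustrative) of instantiating and preparing a concrete subclass:
+
+        data_module = GOUniProtOver250(go_branch="MF", max_sequence_length=1002)
+        data_module.prepare_data()
+        data_module.setup()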
+ + This class is designed to filter GO classes based on a specified threshold, selecting only those classes + which have a certain number of subclasses in the hierarchy. + + Attributes: + READER (dr.ProteinDataReader): The reader used for reading the dataset. + THRESHOLD (int): The threshold for selecting classes based on the number of subclasses. + """ + + READER: dr.ProteinDataReader = dr.ProteinDataReader + THRESHOLD: int = None + + @property + def _name(self) -> str: + """ + Returns the name of the dataset. + + 'max_sequence_length' in the name indicates that proteins with sequence lengths exceeding are ignored + in the dataset. + + Returns: + str: The dataset name, formatted with the current threshold value and/or given go_branch. + """ + if self.go_branch != self._ALL_GO_BRANCHES: + return f"GO{self.THRESHOLD}_{self.go_branch}_{self.max_sequence_length}" + + return f"GO{self.THRESHOLD}_{self.max_sequence_length}" + + def select_classes( + self, g: nx.DiGraph, *args: Any, **kwargs: Dict[str, Any] + ) -> List[int]: + """ + Selects classes (GO terms) from the Gene Ontology (GO) dataset based on the number of annotations meeting a + specified threshold. + + The selection process is based on the annotations of the GO terms with its ancestors across the dataset. + + Annotations are calculated by counting how many times each GO term, along with its ancestral hierarchy, + is annotated per protein across the dataset. + This means that for each protein, the GO terms associated with it are considered, and the entire hierarchical + structure (ancestors) of each GO term is taken into account. The total count for each GO term and its ancestors + reflects how frequently these terms are annotated across all proteins in the dataset. + + Args: + g (nx.DiGraph): The directed acyclic graph representing the GO dataset, where each node corresponds to a GO term. + *args: Additional positional arguments (not used). + **kwargs: Additional keyword arguments, including: + - data_df (pd.DataFrame): A DataFrame containing the GO annotations for various proteins. + It should include a 'go_ids' column with the GO terms associated with each protein. + + Returns: + List[int]: A sorted list of selected GO term IDs that meet the annotation threshold criteria. + + Side Effects: + - Writes the list of selected GO term IDs to a file named "classes.txt" in the specified processed directory. + + Raises: + AttributeError: If the 'data_df' argument is not provided in kwargs. + + Notes: + - The `THRESHOLD` attribute, which defines the minimum number of annotations required to select a GO term, should be defined in the subclass. + """ + # Retrieve the DataFrame containing GO annotations per protein from the keyword arguments + data_df = kwargs.get("data_df", None) + if data_df is None or not isinstance(data_df, pd.DataFrame) or data_df.empty: + raise AttributeError( + "The 'data_df' argument must be provided and must be a non-empty pandas DataFrame." 
+ ) + + print(f"Selecting GO terms based on given threshold: {self.THRESHOLD} ...") + + # https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L59-L77 + go_term_annot: Dict[int, int] = {} + for idx, row in data_df.iterrows(): + # Count the annotations for each go_id **`per protein`** + for go_id in row["go_ids"]: + if go_id not in go_term_annot: + go_term_annot[go_id] = 0 + go_term_annot[go_id] += 1 + + # Select GO terms that meet or exceed the threshold of annotations + selected_nodes: List[int] = [ + go_id + for go_id in g.nodes + if go_id in go_term_annot and go_term_annot[go_id] >= self.THRESHOLD + ] + + # Sort the selected nodes (optional but often useful for consistent output) + selected_nodes.sort() + + # Write the selected node IDs/classes to the file + filename = "classes.txt" + with open(os.path.join(self.processed_dir_main, filename), "wt") as fout: + fout.writelines(str(node) + "\n" for node in selected_nodes) + + return selected_nodes + + +class GOUniProtOver250(_GOUniProtOverX): + """ + A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 250 for selecting classes. + + Inherits from `_GOUniProtOverX` and sets the threshold for selecting classes to 250. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (250). + """ + + THRESHOLD: int = 250 + + +class GOUniProtOver50(_GOUniProtOverX): + """ + A class for extracting data from the Gene Ontology (GO) dataset with a threshold of 50 for selecting classes. + + Inherits from `_GOUniProtOverX` and sets the threshold for selecting classes to 50. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (50). + """ + + THRESHOLD: int = 50 + + +class _DeepGOMigratedData(_GOUniProtDataExtractor, ABC): + """ + Base class for use of the migrated DeepGO data with common properties, name formatting, and file paths. + + Attributes: + READER (dr.ProteinDataReader): Protein data reader class. + THRESHOLD (Optional[int]): Threshold value for GO class selection, + determined by the GO branch type in derived classes. + """ + + READER: dr.ProteinDataReader = dr.ProteinDataReader + THRESHOLD: Optional[int] = None + + # Mapping from GO branch conventions used in DeepGO to our conventions + GO_BRANCH_MAPPING: dict = { + "cc": "CC", + "mf": "MF", + "bp": "BP", + } + + @property + def _name(self) -> str: + """ + Generates a unique identifier for the migrated data based on the GO + branch and max sequence length, optionally including a threshold. + + Returns: + str: A formatted name string for the data. + """ + threshold_part = f"GO{self.THRESHOLD}_" if self.THRESHOLD is not None else "GO_" + + if self.go_branch != self._ALL_GO_BRANCHES: + return f"{threshold_part}{self.go_branch}_{self.max_sequence_length}" + + return f"{threshold_part}{self.max_sequence_length}" + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Checks for the existence of migrated DeepGO data in the specified directory. + Raises an error if the required data file is not found, prompting + migration from DeepGO to this data structure. + + Args: + *args (Any): Additional positional arguments. + **kwargs (Any): Additional keyword arguments. + + Raises: + FileNotFoundError: If the processed data file does not exist. 
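+
+        For example (illustrative), `DeepGO1MigratedData` expects a migrated file
+        at a path of the form:
+
+            <base_dir>/<_name>/processed/data_deep_go1.pkl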
+ """ + print("Checking for processed data in", self.processed_dir_main) + + processed_name = self.processed_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + raise FileNotFoundError( + f"File {processed_name} not found.\n" + f"You must run the appropriate DeepGO migration script " + f"(chebai/preprocessing/migration/deep_go) before executing this configuration " + f"to migrate data from DeepGO to this data structure." + ) + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + # Selection of GO classes not needed for migrated data + pass + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + @abstractmethod + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining main processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for main processed file names. + """ + pass + + @property + @abstractmethod + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Abstract property for defining additional processed file names. + These files are stored in the same directory as the generated data files + but have distinct names to differentiate them during training. + + Returns: + dict: A dictionary with key-value pairs for processed file names. + """ + pass + + +class DeepGO1MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO1. Sets threshold values according + to the research paper based on the GO branch. + + Note: + Refer reference number 1 at the top of this file for the corresponding research paper. + + Args: + **kwargs: Arbitrary keyword arguments passed to the superclass. + + Raises: + ValueError: If an unsupported GO branch is provided. + """ + + def __init__(self, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1002 + + # Set threshold based on GO branch, as per DeepGO1 paper and its data. + if kwargs.get("go_branch") in ["CC", "MF"]: + self.THRESHOLD = 50 + elif kwargs.get("go_branch") == "BP": + self.THRESHOLD = 250 + else: + raise ValueError( + f"DeepGO1 paper has no defined threshold for branch {self.go_branch}" + ) + + super(_DeepGOMigratedData, self).__init__(**kwargs) + + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with the main data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO1. + + Returns: + dict: Dictionary with data file name for DeepGO1. + """ + return {"data": "data_deep_go1.pt"} + + +class DeepGO2MigratedData(_DeepGOMigratedData): + """ + Migrated data class specific to DeepGO2, inheriting from DeepGO1MigratedData + with different processed file names. + + Note: + Refer reference number 3 at the top of this file for the corresponding research paper. + + Returns: + dict: Dictionary with file names specific to DeepGO2. 
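+
+    Example (illustrative):
+
+        data_module = DeepGO2MigratedData(
+            go_branch="MF", max_sequence_length=1000, use_esm2_embeddings=True
+        )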
+ """ + + _LABELS_START_IDX: int = 5 # additional esm2_embeddings column in the dataframe + _ESM_EMBEDDINGS_COL_IDX: int = 4 + + def __init__(self, use_esm2_embeddings=False, **kwargs): + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + assert int(kwargs.get("max_sequence_length")) == 1000 + self.use_esm2_embeddings: bool = use_esm2_embeddings + super(_DeepGOMigratedData, self).__init__(**kwargs) + + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: + """ + Load and process data from a file into a list of dictionaries containing features and labels. + + This method processes data differently based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, raw dictionaries from `_load_dict` are returned, _load_dict already returns + the numerical features (esm2 embeddings) from the data file, hence no reader is required. + - Otherwise, a reader is used to process the data (generate numerical features). + + Args: + path (str): The path to the input file. + + Returns: + List[Dict[str, Any]]: A list of dictionaries with the following keys: + - `features`: Sequence or embedding data, depending on the context. + - `labels`: A boolean array of labels. + - `ident`: The identifier for the sequence. + """ + lines = self._get_data_size(path) + print(f"Processing {lines} lines...") + + if self.use_esm2_embeddings: + data = [ + d + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + else: + data = [ + self.reader.to_data(d) + for d in tqdm.tqdm(self._load_dict(path), total=lines) + if d["features"] is not None + ] + + # filter for missing features in resulting data + data = [val for val in data if val["features"] is not None] + + return data + + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data at row index `self._ESM2_EMBEDDINGS_COL_IDX`: ESM2 embeddings of the protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + The method adapts based on the `use_esm2_embeddings` flag: + - If `use_esm2_embeddings` is True, features are loaded from the column specified by `self._ESM_EMBEDDINGS_COL_IDX`. + - Otherwise, features are loaded from the column specified by `self._DATA_REPRESENTATION_IDX`. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (Any): Sequence or embedding data for the instance. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. 
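+
+        Example of a yielded dictionary (values are illustrative; with
+        `use_esm2_embeddings=True` the `features` value is the stored ESM2
+        embedding instead of the sequence string):
+
+            {"features": "MSTAV...", "labels": array([False, True, ...]),
+             "ident": "SOME_SWISS_ID"}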
+ """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + + if self.use_esm2_embeddings: + features_idx = self._ESM_EMBEDDINGS_COL_IDX + else: + features_idx = self._DATA_REPRESENTATION_IDX + + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + yield dict( + features=row[features_idx], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def processed_main_file_names_dict(self) -> Dict[str, str]: + """ + Returns main processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with the main data file name for DeepGO2. + """ + return {"data": "data_deep_go2.pkl"} + + @property + def processed_file_names_dict(self) -> Dict[str, str]: + """ + Returns processed file names specific to DeepGO2. + + Returns: + dict: Dictionary with data file name for DeepGO2. + """ + return {"data": "data_deep_go2.pt"} + + @property + def identifier(self) -> tuple: + """Identifier for the dataset.""" + if self.use_esm2_embeddings: + return (dr.ESM2EmbeddingReader.name(),) + return (self.reader.name(),) diff --git a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py new file mode 100644 index 0000000..8f7e9c4 --- /dev/null +++ b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py @@ -0,0 +1,279 @@ +__all__ = ["SwissProteinPretrain"] + +import os +from abc import ABC +from collections import OrderedDict +from typing import Any, Dict, Generator, List, Tuple + +import networkx as nx +import pandas as pd +import torch +from Bio import SwissProt +from sklearn.model_selection import train_test_split + +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.datasets.deepGO.go_uniprot import ( + AMBIGUOUS_AMINO_ACIDS, + EXPERIMENTAL_EVIDENCE_CODES, + GOUniProtOver250, +) +from chebai.preprocessing.reader import ProteinDataReader + + +class _ProteinPretrainingData(_DynamicDataset, ABC): + """ + Data module for pretraining protein sequences, specifically designed for Swiss-UniProt data. It includes methods for + data preparation, loading, and dynamic splitting of protein sequences. + The data is parsed and filtered to only select proteins with no associated `valid` Gene Ontology (GO) labels. + A valid GO label is the one which has one of evidence codes defined in `EXPERIMENTAL_EVIDENCE_CODES`. + """ + + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 1 # Index of `sequence` column + + def __init__(self, **kwargs): + """ + Initializes the data module with any GOUniProt extractor class object. + + Args: + **kwargs: Additional arguments for the superclass initialization. + """ + self._go_uniprot_extractor = GOUniProtOver250() + assert self._go_uniprot_extractor.go_branch == GOUniProtOver250._ALL_GO_BRANCHES + + self.max_sequence_length: int = int(kwargs.get("max_sequence_length", 1002)) + assert ( + self.max_sequence_length >= 1 + ), "Max sequence length should be greater than or equal to 1." + + super(_ProteinPretrainingData, self).__init__(**kwargs) + + if self.reader.n_gram is not None: + assert self.max_sequence_length >= self.reader.n_gram, ( + f"max_sequence_length ({self.max_sequence_length}) must be greater than " + f"or equal to n_gram ({self.reader.n_gram})." 
+ ) + + # ------------------------------ Phase: Prepare data ----------------------------------- + def prepare_data(self, *args: Any, **kwargs: Any) -> None: + """ + Prepares the data by downloading and parsing Swiss-Prot data if not already available. Saves the processed data + for further use. + + Args: + *args: Additional positional arguments. + **kwargs: Additional keyword arguments. + """ + processed_name = self.processed_main_file_names_dict["data"] + if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): + print("Missing processed data file (`data.pkl` file)") + os.makedirs(self.processed_dir_main, exist_ok=True) + self._download_required_data() + protein_df = self._parse_protein_data_for_pretraining() + self.save_processed(protein_df, processed_name) + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + # method not required as no Swiss-UniProt has no ontological data + pass + + def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + # method not required as no Swiss-UniProt has no ontological data + pass + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: + # method not required as no Swiss-UniProt has no ontological data + pass + + def _download_required_data(self) -> str: + """ + Downloads the required Swiss-Prot data using the GOUniProt extractor class. + + Returns: + str: Path to the downloaded data. + """ + return self._go_uniprot_extractor._download_swiss_uni_prot_data() + + def _parse_protein_data_for_pretraining(self) -> pd.DataFrame: + """ + Parses the Swiss-Prot data and returns a DataFrame containing Swiss-Prot proteins which does not have any valid + Gene Ontology(GO) label. A valid GO label is the one which has one of the following evidence codes, as specified in + `EXPERIMENTAL_EVIDENCE_CODES`. + + The DataFrame includes the following columns: + - "swiss_id": The unique identifier for each Swiss-Prot record. + - "sequence": The protein sequence. + + Note: + We ignore proteins with ambiguous amino acid specified in `AMBIGUOUS_AMINO_ACIDS` in their sequence.` + + Returns: + pd.DataFrame: A DataFrame where each row corresponds to a Swiss-Prot record with not associated valid GO. + """ + print("Parsing swiss uniprot raw data....") + + swiss_ids, sequences = [], [] + + swiss_data = SwissProt.parse( + open( + os.path.join( + self._go_uniprot_extractor.raw_dir, + self._go_uniprot_extractor.raw_file_names_dict["SwissUniProt"], + ), + "r", + ) + ) + + for record in swiss_data: + if record.data_class != "Reviewed": + # To consider only manually-annotated swiss data + continue + + if not record.sequence: + # Consider protein with only sequence representation + continue + + if len(record.sequence) > self.max_sequence_length: + # Consider protein with only sequence length not greater than max seq. 
length + continue + + if any(aa in AMBIGUOUS_AMINO_ACIDS for aa in record.sequence): + # Skip proteins with ambiguous amino acid codes + continue + + has_valid_associated_go_label = False + for cross_ref in record.cross_references: + if cross_ref[0] == self._go_uniprot_extractor._GO_DATA_INIT: + + if len(cross_ref) <= 3: + # No evidence code + continue + + # https://github.com/bio-ontology-research-group/deepgo/blob/master/get_functions.py#L63-L66 + evidence_code = cross_ref[3].split(":")[0] + if evidence_code in EXPERIMENTAL_EVIDENCE_CODES: + has_valid_associated_go_label = True + break + + if has_valid_associated_go_label: + # Skip proteins which has at least one associated go label + continue + + swiss_ids.append(record.entry_name) + sequences.append(record.sequence) + + data_dict = OrderedDict( + swiss_id=swiss_ids, # swiss_id column at index 0 + sequence=sequences, # Sequence column at index 1 + ) + + return pd.DataFrame(data_dict) + + # ------------------------------ Phase: Setup data ----------------------------------- + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (str): The sequence data from the file. + - `ident` (Any): The identifier from row index 0. + - `labels`: Set to None + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + ident=row[self._ID_IDX], + labels=None, + ) + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Loads encoded data and generates training, validation, and test splits. + + This method attempts to load encoded data from a file named `data.pt`. It then splits this data into + training, validation, and test sets. + + Raises: + FileNotFoundError: If the `data.pt` file does not exist. Ensure that `prepare_data` and/or + `setup` methods are called to generate the necessary dataset files. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set + """ + try: + filename = self.processed_file_names_dict["data"] + data_go = torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. 
" + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_go_data = pd.DataFrame(data_go) + train_df_go, df_test = train_test_split( + df_go_data, + train_size=self.train_split, + random_state=self.dynamic_data_split_seed, + ) + + # Get all splits + df_train, df_val = train_test_split( + train_df_go, + train_size=self.train_split, + random_state=self.dynamic_data_split_seed, + ) + + return df_train, df_val, df_test + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def base_dir(self) -> str: + """ + str: The base directory for pretraining data storage. + """ + return os.path.join(self._go_uniprot_extractor.base_dir, "Pretraining") + + @property + def raw_dir(self) -> str: + """Name of the directory where the raw data is stored.""" + return self._go_uniprot_extractor.raw_dir + + +class SwissProteinPretrain(_ProteinPretrainingData): + """ + Data module for Swiss-Prot protein pretraining, inheriting from `_ProteinPretrainingData`. + This class is specifically designed to handle data processing and loading for Swiss-Prot-based protein datasets. + + Attributes: + READER (Type): The data reader class used to load and process protein pretraining data. + """ + + READER = ProteinDataReader + + @property + def _name(self) -> str: + """ + The name identifier for this data module. + + Returns: + str: A string identifier, "SwissProteinPretrain", representing the name of this data module. + """ + return f"Swiss_{self.max_sequence_length}" diff --git a/chebai/preprocessing/datasets/scope/__init__.py b/chebai/preprocessing/datasets/scope/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/preprocessing/datasets/scope/scope.py b/chebai/preprocessing/datasets/scope/scope.py new file mode 100644 index 0000000..e9127b2 --- /dev/null +++ b/chebai/preprocessing/datasets/scope/scope.py @@ -0,0 +1,972 @@ +# References for this file : + +# Reference 1: +# John-Marc Chandonia, Naomi K Fox, Steven E Brenner, SCOPe: classification of large macromolecular structures +# in the structural classification of proteinsโ€”extended database, Nucleic Acids Research, Volume 47, +# Issue D1, 08 January 2019, Pages D475โ€“D481, https://doi.org/10.1093/nar/gky1134 +# https://scop.berkeley.edu/about/ver=2.08 + +# Reference 2: +# Murzin AG, Brenner SE, Hubbard TJP, Chothia C. 1995. SCOP: a structural classification of proteins database for +# the investigation of sequences and structures. Journal of Molecular Biology 247:536-540 + +import gzip +import os +import re +import shutil +from abc import ABC, abstractmethod +from tempfile import NamedTemporaryFile +from typing import Any, Dict, Generator, List, Optional, Tuple + +import networkx as nx +import pandas as pd +import requests +import torch +from Bio import SeqIO + +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.reader import ProteinDataReader + + +class _SCOPeDataExtractor(_DynamicDataset, ABC): + """ + A class for extracting and processing data from the SCOPe (Structural Classification of Proteins - extended) dataset. + + This class is designed to handle the parsing, preprocessing, and hierarchical structure extraction from various + SCOPe dataset files, such as classification (CLA), hierarchy (HIE), and description (DES) files. + Additionally, it supports downloading related data like PDB sequence files. + + Args: + scope_version (str): The SCOPe version to use. 
+ scope_version_train (Optional[str]): The training SCOPe version, if different. + dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. + splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. + **kwargs: Additional keyword arguments passed to DynamicDataset and XYBaseDataModule. + """ + + # -- Index for columns of processed `data.pkl` (derived from `_graph_to_raw_dataset`) + # "id" at row index 0 + # "sids" at row index 1 + # "sequence" at row index 2 + # labels starting from row index 3 + _ID_IDX: int = 0 + _DATA_REPRESENTATION_IDX: int = 2 # here `sequence` column + _LABELS_START_IDX: int = 3 + + _SCOPE_GENERAL_URL = "https://scop.berkeley.edu/downloads/parse/dir.{data_type}.scope.{version_number}-stable.txt" + _PDB_SEQUENCE_DATA_URL = ( + "https://files.rcsb.org/pub/pdb/derived_data/pdb_seqres.txt.gz" + ) + + SCOPE_HIERARCHY: Dict[str, str] = { + "cl": "class", + "cf": "fold", + "sf": "superfamily", + "fa": "family", + "dm": "protein", + "sp": "species", + "px": "domain", + } + + def __init__( + self, + scope_version: str, + scope_version_train: Optional[str] = None, + max_sequence_len: int = 1000, + **kwargs, + ): + self.scope_version: str = scope_version + self.scope_version_train: str = scope_version_train + self.max_sequence_len: int = max_sequence_len + + super(_SCOPeDataExtractor, self).__init__(**kwargs) + + if self.scope_version_train is not None: + # Instantiate another same class with "scope_version" as "scope_version_train", if train_version is given + # This is to get the data from respective directory related to "scope_version_train" + _init_kwargs = kwargs + _init_kwargs["scope_version"] = self.scope_version_train + self._scope_version_train_obj = self.__class__( + **_init_kwargs, + ) + + @staticmethod + def _get_scope_url(data_type: str, version_number: str) -> str: + """ + Generates the URL for downloading SCOPe files. + + Args: + data_type (str): The type of data (e.g., 'cla', 'hie', 'des'). + version_number (str): The version of the SCOPe file. + + Returns: + str: The formatted SCOPe file URL. + """ + return _SCOPeDataExtractor._SCOPE_GENERAL_URL.format( + data_type=data_type, version_number=version_number + ) + + # ------------------------------ Phase: Prepare data ----------------------------------- + def _download_required_data(self) -> str: + """ + Downloads the required raw data for SCOPe and PDB sequence datasets. + + Returns: + str: Path to the downloaded data. + """ + self._download_pdb_sequence_data() + return self._download_scope_raw_data() + + def _download_pdb_sequence_data(self) -> None: + """ + Downloads and unzips the PDB sequence dataset from the RCSB PDB repository. + + The file is downloaded as a temporary gzip file, which is then extracted to the + specified directory. 
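+
+        The extracted `pdb_seqres.txt` is a FASTA-style file with one entry per
+        PDB chain, for example (entry shown for illustration only):
+
+            >101m_A mol:protein length:154  MYOGLOBIN
+            MVLSEGEWQLVLHVWAKVE...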
+ """ + pdb_seq_file_path = os.path.join( + self.scope_root_dir, self.raw_file_names_dict["PDB"] + ) + os.makedirs(os.path.dirname(pdb_seq_file_path), exist_ok=True) + + if not os.path.isfile(pdb_seq_file_path): + print(f"Missing PDB raw data, Downloading PDB sequence data....") + + # Create a temporary file + with NamedTemporaryFile(delete=False) as tf: + temp_filename = tf.name + print(f"Downloading to temporary file {temp_filename}") + + # Download the file + response = requests.get(self._PDB_SEQUENCE_DATA_URL, stream=True) + with open(temp_filename, "wb") as temp_file: + shutil.copyfileobj(response.raw, temp_file) + + print(f"Downloaded to {temp_filename}") + + # Unpack the gzipped file + try: + print(f"Unzipping the file....") + with gzip.open(temp_filename, "rb") as f_in: + output_file_path = pdb_seq_file_path + with open(output_file_path, "wb") as f_out: + shutil.copyfileobj(f_in, f_out) + print(f"Unpacked and saved to {output_file_path}") + + except Exception as e: + print(f"Failed to unpack the file: {e}") + finally: + # Clean up the temporary file + os.remove(temp_filename) + print(f"Removed temporary file {temp_filename}") + + def _download_scope_raw_data(self) -> str: + """ + Downloads the raw SCOPe dataset files (CLA, HIE, DES, and COM). + + Each file is downloaded from the SCOPe repository and saved to the specified directory. + Files are only downloaded if they do not already exist. + + Returns: + str: A dummy path to indicate completion (can be extended for custom behavior). + """ + os.makedirs(self.raw_dir, exist_ok=True) + for data_type in ["CLA", "HIE", "DES"]: + data_file_name = self.raw_file_names_dict[data_type] + scope_path = os.path.join(self.raw_dir, data_file_name) + if not os.path.isfile(scope_path): + print(f"Missing Scope: {data_file_name} raw data, Downloading...") + r = requests.get( + self._get_scope_url(data_type.lower(), self.scope_version), + allow_redirects=False, + verify=False, # Disable SSL verification + ) + r.raise_for_status() # Check if the request was successful + open(scope_path, "wb").write(r.content) + return "dummy/path" + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts the class hierarchy from SCOPe data and computes its transitive closure. + + Args: + data_path (str): Path to the processed SCOPe dataset. + + Returns: + nx.DiGraph: A directed acyclic graph representing the SCOPe class hierarchy. 
+ """ + print("Extracting class hierarchy...") + df_scope = self._get_scope_data() + pdb_chain_df = self._parse_pdb_sequence_file() + pdb_id_set = set(pdb_chain_df["pdb_id"]) # Search time complexity - O(1) + + # Initialize sets and dictionaries for storing edges and attributes + parent_node_edges, node_child_edges = set(), set() + node_attrs = {} + px_level_nodes = set() + sequence_nodes = dict() + px_to_seq_edges = set() + required_graph_nodes = set() + + # Create a lookup dictionary for PDB chain sequences + lookup_dict = ( + pdb_chain_df.groupby("pdb_id")[["chain_id", "sequence"]] + .apply(lambda x: dict(zip(x["chain_id"], x["sequence"]))) + .to_dict() + ) + + def add_sequence_nodes_edges(chain_sequence, px_sun_id): + """Adds sequence nodes and edges connecting px-level nodes to sequence nodes.""" + if chain_sequence not in sequence_nodes: + sequence_nodes[chain_sequence] = f"seq_{len(sequence_nodes)}" + px_to_seq_edges.add((px_sun_id, sequence_nodes[chain_sequence])) + + # Step 1: Build the graph structure and store node attributes + for row in df_scope.itertuples(index=False): + if row.level == "px": + + pdb_id, chain_id = row.sid[1:5], row.sid[5] + + if pdb_id not in pdb_id_set or chain_id == "_": + # Don't add domain level nodes that don't have pdb_id in pdb_sequences.txt file + # Also chain_id with "_" which corresponds to no chain + continue + px_level_nodes.add(row.sunid) + + # Add edges between px-level nodes and sequence nodes + if chain_id != ".": + if chain_id not in lookup_dict[pdb_id]: + continue + add_sequence_nodes_edges(lookup_dict[pdb_id][chain_id], row.sunid) + else: + # If chain_id is '.', connect all chains of this PDB ID + for chain, chain_sequence in lookup_dict[pdb_id].items(): + add_sequence_nodes_edges(chain_sequence, row.sunid) + else: + required_graph_nodes.add(row.sunid) + + node_attrs[row.sunid] = {"sid": row.sid, "level": row.level} + + if row.parent_sunid != -1: + parent_node_edges.add((row.parent_sunid, row.sunid)) + + for child_id in row.children_sunids: + node_child_edges.add((row.sunid, child_id)) + + del df_scope, pdb_chain_df, pdb_id_set + + g = nx.DiGraph() + g.add_nodes_from(node_attrs.items()) + # Note - `add_edges` internally create a node, if a node doesn't exist already + g.add_edges_from({(p, c) for p, c in parent_node_edges if p in node_attrs}) + g.add_edges_from({(p, c) for p, c in node_child_edges if c in node_attrs}) + + seq_nodes = set(sequence_nodes.values()) + g.add_nodes_from([(seq_id, {"level": "sequence"}) for seq_id in seq_nodes]) + g.add_edges_from( + { + (px_node, seq_node) + for px_node, seq_node in px_to_seq_edges + if px_node in node_attrs and seq_node in seq_nodes + } + ) + + # Step 2: Count sequence successors for required graph nodes only + for node in required_graph_nodes: + num_seq_successors = sum( + g.nodes[child]["level"] == "sequence" + for child in nx.descendants(g, node) + ) + g.nodes[node]["num_seq_successors"] = num_seq_successors + + # Step 3: Remove nodes which are not required before computing transitive closure for better efficiency + g.remove_nodes_from(px_level_nodes | seq_nodes) + + print("Computing Transitive Closure.........") + # Transitive closure is not needed in `select_classes` method but is required in _SCOPeOverXPartial + return nx.transitive_closure_dag(g) + + def _get_scope_data(self) -> pd.DataFrame: + """ + Merges and preprocesses the SCOPe classification, hierarchy, and description files into a unified DataFrame. 
+ + Returns: + pd.DataFrame: A DataFrame containing combined SCOPe data with classification and hierarchy details. + """ + df_cla = self._get_classification_data() + df_hie = self._get_hierarchy_data() + df_des = self._get_node_description_data() + df_hie_with_cla = pd.merge(df_hie, df_cla, how="left", on="sunid") + df_all = pd.merge( + df_hie_with_cla, + df_des.drop(columns=["sid"], axis=1), + how="left", + on="sunid", + ) + return df_all + + def _get_classification_data(self) -> pd.DataFrame: + """ + Parses and processes the SCOPe CLA (classification) file. + + Returns: + pd.DataFrame: A DataFrame containing classification details, including hierarchy levels. + """ + df_cla = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["CLA"]), + sep="\t", + header=None, + comment="#", + ) + df_cla.columns = [ + "sid", + "PDB_ID", + "description", + "sccs", + "sunid", + "hie_levels", + ] + + # Convert to dict - {cl:46456, cf:46457, sf:46458, fa:46459, dm:46460, sp:116748, px:113449} + df_cla["hie_levels"] = df_cla["hie_levels"].apply( + lambda x: {k: int(v) for k, v in (item.split("=") for item in x.split(","))} + ) + + # Split ancestor_nodes into separate columns and assign values + for key in self.SCOPE_HIERARCHY.keys(): + df_cla[self.SCOPE_HIERARCHY[key]] = df_cla["hie_levels"].apply( + lambda x: x[key] + ) + + df_cla["sunid"] = df_cla["sunid"].astype("int64") + + return df_cla + + def _get_hierarchy_data(self) -> pd.DataFrame: + """ + Parses and processes the SCOPe HIE (hierarchy) file. + + Returns: + pd.DataFrame: A DataFrame containing hierarchy details, including parent-child relationships. + """ + df_hie = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["HIE"]), + sep="\t", + header=None, + comment="#", + low_memory=False, + ) + df_hie.columns = ["sunid", "parent_sunid", "children_sunids"] + + # if not parent id, then insert -1 + df_hie["parent_sunid"] = df_hie["parent_sunid"].replace("-", -1).astype(int) + # convert children ids to list of ids + df_hie["children_sunids"] = df_hie["children_sunids"].apply( + lambda x: list(map(int, x.split(","))) if x != "-" else [] + ) + + # Ensure the 'sunid' column in both DataFrames has the same type + df_hie["sunid"] = df_hie["sunid"].astype("int64") + return df_hie + + def _get_node_description_data(self) -> pd.DataFrame: + """ + Parses and processes the SCOPe DES (description) file. + + Returns: + pd.DataFrame: A DataFrame containing node-level descriptions from the SCOPe dataset. + """ + df_des = pd.read_csv( + os.path.join(self.raw_dir, self.raw_file_names_dict["DES"]), + sep="\t", + header=None, + comment="#", + low_memory=False, + ) + df_des.columns = ["sunid", "level", "scss", "sid", "description"] + df_des.loc[len(df_des)] = {"sunid": 0, "level": "root"} + + # Ensure the 'sunid' column in both DataFrames has the same type + df_des["sunid"] = df_des["sunid"].astype("int64") + return df_des + + def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: + """ + Processes a directed acyclic graph (DAG) to generate a raw dataset in DataFrame format. This dataset includes + chain-level sequences and their corresponding labels based on the hierarchical structure of the associated domains. + + The process: + - Extracts SCOPe domain identifiers (sids) from the graph. + - Retrieves class labels for each domain based on all applicable taxonomy levels. + - Fetches the chain-level sequences from the Protein Data Bank (PDB) for each domain. 
+ - For each sequence, identifies all domains associated with the same chain and assigns their corresponding labels. + + Notes: + - SCOPe hierarchy levels are used as labels, with each level represented by a column. The value in each column + indicates whether a PDB chain is associated with that particular hierarchy level. + - PDB chains are treated as samples. The method considers only domains that are mapped to the selected hierarchy levels. + + Data Format: pd.DataFrame + - Column 0 : id (Unique identifier for each sequence entry) + - Column 1 : sids (List of domain identifiers associated with the sequence) + - Column 2 : sequence (Amino acid sequence of the chain) + - Column 3 to Column "n": Each column corresponds to a SCOPe class hierarchy level with a value + of True/False indicating whether the chain is associated with the corresponding level. + + Args: + graph (nx.DiGraph): The class hierarchy graph. + + Returns: + pd.DataFrame: The raw dataset created from the graph. + + Raises: + RuntimeError: If no sunids are selected. + """ + print(f"Process graph") + + selected_sun_ids_per_lvl = self.select_classes(graph) + + if not selected_sun_ids_per_lvl: + raise RuntimeError("No sunid selected.") + + df_cla = self._get_classification_data() + hierarchy_levels = list(self.SCOPE_HIERARCHY.values()) + hierarchy_levels.remove("domain") + + df_cla = df_cla[["sid", "sunid"] + hierarchy_levels] + + # Initialize selected target columns + df_encoded = df_cla[["sid", "sunid"]].copy() + + # Collect all new columns in a dictionary first (avoids fragmentation) + encoded_df_columns = {} + + lvl_to_target_cols_mapping = {} + # Iterate over only the selected sun_ids (nodes) to one-hot encode them + for level, selected_sun_ids in selected_sun_ids_per_lvl.items(): + level_column = self.SCOPE_HIERARCHY[level] + if level_column in df_cla.columns: + # Create binary encoding for only relevant sun_ids + for sun_id in selected_sun_ids: + col_name = f"{level_column}_{sun_id}" + encoded_df_columns[col_name] = ( + df_cla[level_column] == sun_id + ).astype(bool) + + lvl_to_target_cols_mapping.setdefault(level_column, []).append( + col_name + ) + + # Convert the dictionary into a DataFrame and concatenate at once (prevents fragmentation) + df_encoded = pd.concat([df_encoded, pd.DataFrame(encoded_df_columns)], axis=1) + + encoded_target_columns = [] + for level in hierarchy_levels: + if level in lvl_to_target_cols_mapping: + encoded_target_columns.extend(lvl_to_target_cols_mapping[level]) + + print( + f"{len(encoded_target_columns)} labels has been selected for specified threshold, " + ) + print("Constructing data.pkl file .....") + + df_encoded = df_encoded[["sid", "sunid"] + encoded_target_columns] + + # Filter to select only domains that atleast map to any one selected sunid in any level + df_encoded = df_encoded[df_encoded.iloc[:, 2:].any(axis=1)] + + df_encoded["pdb_id"] = df_encoded["sid"].str[1:5] + df_encoded["chain_id"] = df_encoded["sid"].str[5] + + # "_" (underscore) means it has no chain + df_encoded = df_encoded[df_encoded["chain_id"] != "_"] + + pdb_chain_df = self._parse_pdb_sequence_file() + + # Handle chain_id == "." 
- Multiple chain case + # Split df_encoded into two: One for specific chains, one for "multiple chains" (".") + df_specific_chains = df_encoded[df_encoded["chain_id"] != "."] + df_multiple_chains = df_encoded[df_encoded["chain_id"] == "."].drop( + columns=["chain_id"] + ) + + # Merge specific chains normally + merged_specific = df_specific_chains.merge( + pdb_chain_df, on=["pdb_id", "chain_id"], how="left" + ) + + # Merge all chains case -> Join by pdb_id (not chain_id) + merged_all_chains = df_multiple_chains.merge( + pdb_chain_df, on="pdb_id", how="left" + ) + + # Combine both cases + sequence_hierarchy_df = pd.concat( + [merged_specific, merged_all_chains], ignore_index=True + ).dropna(subset=["sequence"]) + + # Vectorized Aggregation Instead of Row-wise Updates + sequence_hierarchy_df = ( + sequence_hierarchy_df.groupby("sequence", as_index=False) + .agg( + { + "sid": list, # Collect all SIDs per sequence + **{ + col: "max" for col in encoded_target_columns + }, # Max works as Bitwise OR for labels + } + ) + .rename(columns={"sid": "sids"}) + ) # Rename for clarity + + sequence_hierarchy_df = sequence_hierarchy_df.assign( + id=range(1, len(sequence_hierarchy_df) + 1) + )[["id", "sids", "sequence"] + encoded_target_columns] + + # Ensure atleast one label is true for each protein sequence + sequence_hierarchy_df = sequence_hierarchy_df[ + sequence_hierarchy_df.iloc[:, self._LABELS_START_IDX :].any(axis=1) + ] + + with open(os.path.join(self.processed_dir_main, "classes.txt"), "wt") as fout: + fout.writelines(str(sun_id) + "\n" for sun_id in encoded_target_columns) + + return sequence_hierarchy_df + + def _parse_pdb_sequence_file(self) -> pd.DataFrame: + """ + Parses the PDB sequence file and returns a DataFrame containing PDB IDs, chain IDs, and sequences. + + Returns: + pd.DataFrame: A DataFrame with columns ["pdb_id", "chain_id", "sequence"]. + """ + records = [] + valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) + + for record in SeqIO.parse( + os.path.join(self.scope_root_dir, self.raw_file_names_dict["PDB"]), "fasta" + ): + + if not record.seq or len(record.seq) > self.max_sequence_len: + continue + + pdb_id, chain = record.id.split("_") + sequence = re.sub(f"[^{valid_amino_acids}]", "X", str(record.seq)) + + # Store as a dictionary entry (list of dicts -> DataFrame later) + records.append( + { + "pdb_id": pdb_id.lower(), + "chain_id": chain.lower(), + "sequence": sequence, + } + ) + + # Convert list of dictionaries to a DataFrame + pdb_chain_df = pd.DataFrame.from_records(records) + + return pdb_chain_df + + @abstractmethod + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: + # Override the return type of the method from superclass + pass + + # ------------------------------ Phase: Setup data ----------------------------------- + def setup_processed(self) -> None: + """ + Transform and prepare processed data for the SCOPe dataset. + + Main function of this method is to transform `data.pkl` into a model input data format (`data.pt`), + ensuring that the data is in a format compatible for input to the model. + The transformed data must contain the following keys: `ident`, `features`, `labels`, and `group`. + This method uses a subclass of Data Reader to perform the transformation. + + It will transform the data related to `scope_version_train`, if specified. 
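+
+ Typical usage (illustrative; any concrete subclass behaves the same way):
+
+     dm = SCOPeOver2000(scope_version="2.08")
+     dm.prepare_data()  # downloads raw SCOPe/PDB files and builds data.pkl
+     dm.setup()         # encodes the sequences and writes data.pt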
+ """ + super().setup_processed() + + # Transform the data related to "scope_version_train" to encoded data, if it doesn't exist + if self.scope_version_train is not None and not os.path.isfile( + os.path.join( + self._scope_version_train_obj.processed_dir, + self._scope_version_train_obj.processed_file_names_dict["data"], + ) + ): + print( + f"Missing encoded data related to train version: {self.scope_version_train}" + ) + print("Calling the setup method related to it") + self._scope_version_train_obj.setup() + + def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: + """ + Loads data from a pickled file and yields individual dictionaries for each row. + + The pickled file is expected to contain rows with the following structure: + - Data at row index `self._ID_IDX`: ID of go data instance + - Data at row index `self._DATA_REPRESENTATION_IDX`: Sequence representation of protein + - Data from row index `self._LABELS_START_IDX` onwards: Labels + + This method is used by `_load_data_from_file` to generate dictionaries that are then + processed and converted into a list of dictionaries containing the features and labels. + + Args: + input_file_path (str): The path to the pickled input file. + + Yields: + Dict[str, Any]: A dictionary containing: + - `features` (str): The sequence data from the file. + - `labels` (np.ndarray): A boolean array of labels starting from row index 4. + - `ident` (Any): The identifier from row index 0. + """ + with open(input_file_path, "rb") as input_file: + df = pd.read_pickle(input_file) + for row in df.values: + labels = row[self._LABELS_START_IDX :].astype(bool) + # chebai.preprocessing.reader.DataReader only needs features, labels, ident, group + # "group" set to None, by default as no such entity for this data + yield dict( + features=row[self._DATA_REPRESENTATION_IDX], + labels=labels, + ident=row[self._ID_IDX], + ) + + # ------------------------------ Phase: Dynamic Splits ----------------------------------- + def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: + """ + Loads encoded/transformed data and generates training, validation, and test splits. + + This method first loads encoded data from a file named `data.pt`, which is derived from either + `scope_version` or `scope_version_train`. It then splits the data into training, validation, and test sets. + + If `scope_version_train` is provided: + - Loads additional encoded data from `scope_version_train`. + - Splits this data into training and validation sets, while using the test set from `scope_version`. + - Prunes the test set from `scope_version` to include only labels that exist in `scope_version_train`. + + If `scope_version_train` is not provided: + - Splits the data from `scope_version` into training, validation, and test sets without modification. + + Raises: + FileNotFoundError: If the required `data.pt` file(s) do not exist. Ensure that `prepare_data` + and/or `setup` methods have been called to generate the dataset files. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: A tuple containing three DataFrames: + - Training set + - Validation set + - Test set + """ + try: + filename = self.processed_file_names_dict["data"] + data_scope_version = torch.load( + os.path.join(self.processed_dir, filename), weights_only=False + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists. 
" + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_scope_version = pd.DataFrame(data_scope_version) + train_df_scope_ver, df_test_scope_ver = self.get_test_split( + df_scope_version, seed=self.dynamic_data_split_seed + ) + + if self.scope_version_train is not None: + # Load encoded data derived from "scope_version_train" + try: + filename_train = ( + self._scope_version_train_obj.processed_file_names_dict["data"] + ) + data_scope_train_version = torch.load( + os.path.join( + self._scope_version_train_obj.processed_dir, filename_train + ), + weights_only=False, + ) + except FileNotFoundError: + raise FileNotFoundError( + f"File data.pt doesn't exists related to scope_version_train {self.scope_version_train}." + f"Please call 'prepare_data' and/or 'setup' methods to generate the dataset files" + ) + + df_scope_train_version = pd.DataFrame(data_scope_train_version) + # Get train/val split of data based on "scope_version_train", but + # using test set from "scope_version" + df_train, df_val = self.get_train_val_splits_given_test( + df_scope_train_version, + df_test_scope_ver, + seed=self.dynamic_data_split_seed, + ) + # Modify test set from "scope_version" to only include the labels that + # exists in "scope_version_train", all other entries remains same. + df_test = self._setup_pruned_test_set(df_test_scope_ver) + else: + # Get all splits based on "scope_version" + df_train, df_val = self.get_train_val_splits_given_test( + train_df_scope_ver, + df_test_scope_ver, + seed=self.dynamic_data_split_seed, + ) + df_test = df_test_scope_ver + + return df_train, df_val, df_test + + def _setup_pruned_test_set( + self, df_test_scope_version: pd.DataFrame + ) -> pd.DataFrame: + """ + Create a test set with the same leaf nodes, but use only classes that appear in the training set. + + Args: + df_test_scope_version (pd.DataFrame): The test dataset. + + Returns: + pd.DataFrame: The pruned test dataset. 
+ """ + # TODO: find a more efficient way to do this + filename_old = "classes.txt" + # filename_new = f"classes_v{self.scope_version_train}.txt" + # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) + + # Load original classes (from the current SCOPe version - scope_version) + with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: + orig_classes = file.readlines() + + # Load new classes (from the training SCOPe version - scope_version_train) + with open( + os.path.join( + self._scope_version_train_obj.processed_dir_main, filename_old + ), + "r", + ) as file: + new_classes = file.readlines() + + # Create a mapping which give index of a class from scope_version, if the corresponding + # class exists in scope_version_train, Size = Number of classes in scope_version + mapping = [ + None if or_class not in new_classes else new_classes.index(or_class) + for or_class in orig_classes + ] + + # Iterate over each data instance in the test set which is derived from scope_version + for _, row in df_test_scope_version.iterrows(): + # Size = Number of classes in scope_version_train + new_labels = [False for _ in new_classes] + for ind, label in enumerate(row["labels"]): + # If the scope_version class exists in the scope_version_train and has a True label, + # set the corresponding label in new_labels to True + if mapping[ind] is not None and label: + new_labels[mapping[ind]] = label + # Update the labels from test instance from scope_version to the new labels, which are compatible to both versions + row["labels"] = new_labels + + return df_test_scope_version + + # ------------------------------ Phase: Raw Properties ----------------------------------- + @property + def scope_root_dir(self) -> str: + """ + Returns the root directory of scope data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ + return os.path.join("data", "SCOPe") + + @property + def base_dir(self) -> str: + """ + Returns the base directory path for storing SCOPe data. + + Returns: + str: The path to the base directory, which is "data/GO_UniProt". + """ + return os.path.join(self.scope_root_dir, f"version_{self.scope_version}") + + @property + def raw_file_names_dict(self) -> dict: + """ + Returns a dictionary of raw file names used in data processing. + + Returns: + dict: A dictionary mapping dataset names to their respective file names. + """ + return { + "CLA": "cla.txt", + "DES": "des.txt", + "HIE": "hie.txt", + "PDB": "pdb_sequences.txt", + } + + +class _SCOPeOverX(_SCOPeDataExtractor, ABC): + """ + A class for extracting data from the SCOPe dataset with a threshold for selecting classes/labels based on + the number of subclasses. + + This class is designed to filter SCOPe classes/labels based on a specified threshold, selecting only those classes + which have a certain number of subclasses in the hierarchy. + + Attributes: + READER (dr.ProteinDataReader): The reader used for reading the dataset. + THRESHOLD (int): The threshold for selecting classes/labels based on the number of subclasses. + + """ + + READER = ProteinDataReader + THRESHOLD: int = None + + @property + def _name(self) -> str: + """ + Returns the name of the dataset. + + Returns: + str: The dataset name, formatted with the current threshold. + """ + return f"SCOPe{self.THRESHOLD}" + + def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> Dict[str, List[int]]: + """ + Selects classes from the SCOPe dataset based on the number of successors meeting a specified threshold. 
+ + This method iterates over the nodes in the graph, counting the number of successors for each node. + Nodes with a number of successors greater than or equal to the defined threshold are selected. + + Note: + The input graph must be transitive closure of a directed acyclic graph. + + Args: + g (nx.Graph): The graph representing the dataset. + *args: Additional positional arguments (not used). + **kwargs: Additional keyword arguments (not used). + + Returns: + Dict: A dict containing selected nodes at each hierarchy level. + + Notes: + - The `THRESHOLD` attribute should be defined in the subclass of this class. + """ + selected_sunids_for_level = {} + for node, attr_dict in g.nodes(data=True): + if attr_dict["level"] in {"root", "px", "sequence"}: + # Skip nodes with level "root", "px", or "sequence" + continue + + # Check if the number of "sequence"-level successors meets or exceeds the threshold + if g.nodes[node]["num_seq_successors"] >= self.THRESHOLD: + selected_sunids_for_level.setdefault(attr_dict["level"], []).append( + node + ) + return selected_sunids_for_level + + +class _SCOPeOverXPartial(_SCOPeOverX, ABC): + """ + Dataset that doesn't use the full SCOPe dataset, but extracts a part of SCOPe (subclasses of a given top class) + + Attributes: + top_class_sunid (int): The Sun-ID of the top class from which to extract subclasses. + """ + + def __init__(self, top_class_sunid: int, **kwargs): + """ + Initializes the _SCOPeOverXPartial dataset. + + Args: + top_class_sunid (int): The Sun-ID of the top class from which to extract subclasses. + **kwargs: Additional keyword arguments passed to the superclass initializer. + """ + if "top_class_sunid" not in kwargs: + kwargs["top_class_sunid"] = top_class_sunid + + self.top_class_sunid: int = top_class_sunid + super().__init__(**kwargs) + + @property + def processed_dir_main(self) -> str: + """ + Returns the main processed data directory specific to the top class. + + Returns: + str: The processed data directory path. + """ + return os.path.join( + self.base_dir, + self._name, + f"partial_{self.top_class_sunid}", + "processed", + ) + + def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: + """ + Extracts a subset of SCOPe based on subclasses of the top class ID. + + This method calls the superclass method to extract the full class hierarchy, + then extracts the subgraph containing only the descendants of the top class ID, including itself. + + Args: + data_path (str): The file path to the SCOPe ontology file. + + Returns: + nx.DiGraph: The extracted class hierarchy as a directed graph, limited to the + descendants of the top class ID. + """ + g = super()._extract_class_hierarchy(data_path) + g = g.subgraph( + list(g.successors(self.top_class_sunid)) + [self.top_class_sunid] + ) + return g + + +class SCOPeOver2000(_SCOPeOverX): + """ + A class for extracting data from the SCOPe dataset with a threshold of 2000 for selecting classes. + + Inherits from `_SCOPeOverX` and sets the threshold for selecting classes to 2000. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (2000). + """ + + THRESHOLD: int = 2000 + + +class SCOPeOver50(_SCOPeOverX): + + THRESHOLD = 50 + + +class SCOPeOverPartial2000(_SCOPeOverXPartial): + """ + A class for extracting data from the SCOPe dataset with a threshold of 2000 for selecting classes. + + Inherits from `_SCOPeOverXPartial` and sets the threshold for selecting classes to 2000. + + Attributes: + THRESHOLD (int): The threshold for selecting classes (2000). 
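+
+ Example (illustrative; the top-class sun-id shown here is the class-level node 46456):
+
+     data_module = SCOPeOverPartial2000(scope_version="2.08", top_class_sunid=46456)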
+ """ + + THRESHOLD: int = 2000 + + +if __name__ == "__main__": + scope = SCOPeOver50(scope_version="2.08") + + # g = scope._extract_class_hierarchy("dummy/path") + # # Save graph + # import pickle + # with open("graph.gpickle", "wb") as f: + # pickle.dump(g, f) + + # Load graph + import pickle + + with open("graph.gpickle", "rb") as f: + g = pickle.load(f) + + # print(len([node for node in g.nodes() if g.out_degree(node) > 10000])) + scope._graph_to_raw_dataset(g) diff --git a/chebai/preprocessing/migration/__init__.py b/chebai/preprocessing/migration/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/preprocessing/migration/deep_go/__init__.py b/chebai/preprocessing/migration/deep_go/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py new file mode 100644 index 0000000..7d59c69 --- /dev/null +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -0,0 +1,316 @@ +import os +from collections import OrderedDict +from typing import List, Literal, Optional, Tuple + +import pandas as pd +from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit +from jsonargparse import CLI + +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData + + +class DeepGo1DataMigration: + """ + A class to handle data migration and processing for the DeepGO project. + It migrates the DeepGO data to our data structure followed for GO-UniProt data. + + This class handles migration of data from the DeepGO paper below: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660โ€“668 + (https://doi.org/10.1093/bioinformatics/btx624). + """ + + # Max sequence length as per DeepGO1 + _MAXLEN = 1002 + _LABELS_START_IDX = DeepGO1MigratedData._LABELS_START_IDX + + def __init__(self, data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. + """ + valid_go_branches = list(DeepGO1MigratedData.GO_BRANCH_MAPPING.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir: str = rf"{data_dir}" + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + splits_df = self._record_splits() + data_with_labels_df = self._extract_required_data_from_splits() + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." 
+ ) + + self.save_migrated_data(data_with_labels_df, splits_df) + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. + """ + try: + print(f"Loading data files from directory: {self._data_dir}") + self._test_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"test-{self._go_branch}.pkl") + ) + ) + + # DeepGO 1 lacks a validation split, so we will create one by further splitting the training set. + # Although this reduces the training data slightly compared to the original DeepGO setup, + # given the data size, the impact should be minimal. + train_df = pd.DataFrame( + pd.read_pickle( + os.path.join(self._data_dir, f"train-{self._go_branch}.pkl") + ) + ) + + self._train_df, self._validation_df = self._get_train_val_split(train_df) + + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, f"{self._go_branch}.pkl")) + ) + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) + + @staticmethod + def _get_train_val_split( + train_df: pd.DataFrame, + ) -> Tuple[pd.DataFrame, pd.DataFrame]: + """ + Splits the training data into a smaller training set and a validation set. + + Args: + train_df (pd.DataFrame): Original training DataFrame. + + Returns: + Tuple[pd.DataFrame, pd.DataFrame]: Training and validation DataFrames. + """ + labels_list_train = train_df["labels"].tolist() + train_split = 0.85 + test_size = ((1 - train_split) ** 2) / train_split + + splitter = MultilabelStratifiedShuffleSplit( + n_splits=1, test_size=test_size, random_state=42 + ) + + train_indices, validation_indices = next( + splitter.split(labels_list_train, labels_list_train) + ) + + df_validation = train_df.iloc[validation_indices] + df_train = train_df.iloc[train_indices] + return df_train, df_validation + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. + """ + print("Recording data splits for train, validation, and test sets.") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. 
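+
+ Illustrative shape of the result (values are made up; the label columns are the parsed GO ids):
+
+     swiss_id  accession  go_ids        sequence  3674   8150   ...
+     PROT_001  A0A000     [3674, 8150]  MKT...    True   True   ...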
+ """ + print("Combining data splits into a single DataFrame with required columns.") + required_columns = [ + "proteins", + "accessions", + "sequences", + "gos", + "labels", + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df.apply( + lambda row: self.extract_go_id(row["gos"]), axis=1 + ) + + labels_df = self._get_labels_columns(new_df) + + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + ) + ) + + df = pd.concat([data_df, labels_df], axis=1) + + return df + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[int]: List of parsed GO IDs. + """ + return [DeepGO1MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + + def _get_labels_columns(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates columns for labels based on provided selected terms. + + Args: + data_df (pd.DataFrame): DataFrame with GO annotations and labels. + + Returns: + pd.DataFrame: DataFrame with label columns. + """ + print("Generating label columns from provided selected terms.") + parsed_go_ids: pd.Series = self._terms_df["functions"].apply( + lambda gos: DeepGO1MigratedData._parse_go_id(gos) + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + + new_label_columns = pd.DataFrame( + data_df["labels"].tolist(), index=data_df.index, columns=all_go_ids_list + ) + + return new_label_columns + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. + """ + print("Saving transformed data files.") + + deepgo_migr_inst: DeepGO1MigratedData = DeepGO1MigratedData( + go_branch=DeepGO1MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._MAXLEN, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] + ) + print( + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" + ) + + # Save splits file + splits_df.to_csv( + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go1.csv"), + index=False, + ) + print(f"splits_deep_go1.csv saved to {deepgo_migr_inst.processed_dir_main}") + + # Save classes file + classes = sorted(self._classes) + with open( + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go1.txt"), + "wt", + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes_deep_go1.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration process completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGo1DataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + @staticmethod + def migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. 
+ + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + """ + DeepGo1DataMigration(data_dir, go_branch).migrate() + + +if __name__ == "__main__": + # Example: python script_name.py migrate --data_dir="data/deep_go1" --go_branch="mf" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. + CLI( + Main, + description="DeepGo1DataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, + ) diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py new file mode 100644 index 0000000..d23247c --- /dev/null +++ b/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -0,0 +1,366 @@ +import os +import re +from collections import OrderedDict +from typing import List, Literal, Optional + +import pandas as pd +from jsonargparse import CLI + +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData +from chebai.preprocessing.reader import ProteinDataReader + + +class DeepGo2DataMigration: + """ + A class to handle data migration and processing for the DeepGO project. It migrates the data from the DeepGO-SE + data structure to our data structure followed for GO-UniProt data. + + This class handles migration of data from the DeepGO paper below: + Maxat Kulmanov, Mohammed Asif Khan, Robert Hoehndorf, + DeepGO: predicting protein functions from sequence and interactions using a deep ontology-aware classifier, + Bioinformatics, Volume 34, Issue 4, February 2018, Pages 660โ€“668 + (https://doi.org/10.1093/bioinformatics/btx624) + """ + + _LABELS_START_IDX = DeepGO2MigratedData._LABELS_START_IDX + + def __init__( + self, data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ): + """ + Initializes the data migration object with a data directory and GO branch. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use. + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + """ + valid_go_branches = list(DeepGO2MigratedData.GO_BRANCH_MAPPING.keys()) + if go_branch not in valid_go_branches: + raise ValueError(f"go_branch must be one of {valid_go_branches}") + self._go_branch = go_branch + + self._data_dir: str = os.path.join(rf"{data_dir}", go_branch) + self._max_len: int = max_len + + self._train_df: Optional[pd.DataFrame] = None + self._test_df: Optional[pd.DataFrame] = None + self._validation_df: Optional[pd.DataFrame] = None + self._terms_df: Optional[pd.DataFrame] = None + self._classes: Optional[List[str]] = None + + def migrate(self) -> None: + """ + Executes the data migration by loading, processing, and saving the data. + """ + print("Starting the migration process...") + self._load_data() + if not all( + df is not None + for df in [ + self._train_df, + self._validation_df, + self._test_df, + self._terms_df, + ] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." 
+ ) + splits_df = self._record_splits() + + data_df = self._extract_required_data_from_splits() + data_with_labels_df = self._generate_labels(data_df) + + if not all( + var is not None for var in [data_with_labels_df, splits_df, self._classes] + ): + raise Exception( + "Data splits or terms data is not available in instance variables." + ) + + self.save_migrated_data(data_with_labels_df, splits_df) + + def _load_data(self) -> None: + """ + Loads the test, train, validation, and terms data from the pickled files + in the data directory. + """ + + try: + print(f"Loading data from directory: {self._data_dir}......") + + print( + "Pre-processing the data before loading them into instance variables\n" + f"2-Steps preprocessing: \n" + f"\t 1: Truncating every sequence to {self._max_len}\n" + f"\t 2: Replacing every amino acid which is not in {ProteinDataReader.AA_LETTER}" + ) + + self._test_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "test_data.pkl")) + ) + ) + self._train_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "train_data.pkl")) + ) + ) + self._validation_df = self._pre_process_data( + pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "valid_data.pkl")) + ) + ) + + self._terms_df = pd.DataFrame( + pd.read_pickle(os.path.join(self._data_dir, "terms.pkl")) + ) + + except FileNotFoundError as e: + raise FileNotFoundError( + f"Data file not found in directory: {e}. " + "Please ensure all required files are available in the specified directory." + ) + + def _pre_process_data(self, df: pd.DataFrame) -> pd.DataFrame: + """ + Pre-processes the input dataframe by truncating sequences to the maximum + length and replacing invalid amino acids with 'X'. + + Args: + df (pd.DataFrame): The dataframe to preprocess. + + Returns: + pd.DataFrame: The processed dataframe. + """ + df = self._truncate_sequences(df) + df = self._replace_invalid_amino_acids(df) + return df + + def _truncate_sequences( + self, df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Truncate sequences in a specified column of a dataframe to the maximum length. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/train_cnn.py#L206-L217 + + Args: + df (pd.DataFrame): The input dataframe containing the data to be processed. + column (str, optional): The column containing sequences to truncate. + Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with sequences truncated to `self._max_len`. + """ + df[column] = df[column].apply(lambda x: x[: self._max_len]) + return df + + @staticmethod + def _replace_invalid_amino_acids( + df: pd.DataFrame, column: str = "sequences" + ) -> pd.DataFrame: + """ + Replaces invalid amino acids in a sequence with 'X' using regex. + + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L26-L33 + https://github.com/ChEB-AI/python-chebai/pull/64#issuecomment-2517067073 + + Args: + df (pd.DataFrame): The dataframe containing the sequences to be processed. + column (str, optional): The column containing the sequences. Defaults to "sequences". + + Returns: + pd.DataFrame: The dataframe with invalid amino acids replaced by 'X'. 
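+
+ Example (illustrative): with the 20 standard amino acids plus ``X`` treated as valid,
+ an input such as ``"MKTUBZ"`` becomes ``"MKTXXX"``.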
+ """ + valid_amino_acids = "".join(ProteinDataReader.AA_LETTER) + # Replace any character not in the valid set with 'X' + df[column] = df[column].apply( + lambda x: re.sub(f"[^{valid_amino_acids}]", "X", x) + ) + return df + + def _record_splits(self) -> pd.DataFrame: + """ + Creates a DataFrame that stores the IDs and their corresponding data splits. + + Returns: + pd.DataFrame: A combined DataFrame containing split assignments. + """ + print("Recording data splits for train, validation, and test sets.") + split_assignment_list: List[pd.DataFrame] = [ + pd.DataFrame({"id": self._train_df["proteins"], "split": "train"}), + pd.DataFrame( + {"id": self._validation_df["proteins"], "split": "validation"} + ), + pd.DataFrame({"id": self._test_df["proteins"], "split": "test"}), + ] + + combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) + return combined_split_assignment + + def _extract_required_data_from_splits(self) -> pd.DataFrame: + """ + Extracts required columns from the combined data splits. + + Returns: + pd.DataFrame: A DataFrame containing the essential columns for processing. + """ + print("Combining the data splits with required data..... ") + required_columns = [ + "proteins", + "accessions", + "sequences", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/gendata/uni2pandas.py#L60-L69 + "prop_annotations", # Direct and Transitively associated GO ids + "esm2", + ] + + new_df = pd.concat( + [ + self._train_df[required_columns], + self._validation_df[required_columns], + self._test_df[required_columns], + ], + ignore_index=True, + ) + new_df["go_ids"] = new_df["prop_annotations"].apply( + lambda x: self.extract_go_id(x) + ) + + data_df = pd.DataFrame( + OrderedDict( + swiss_id=new_df["proteins"], + accession=new_df["accessions"], + go_ids=new_df["go_ids"], + sequence=new_df["sequences"], + esm2_embeddings=new_df["esm2"], + ) + ) + return data_df + + @staticmethod + def extract_go_id(go_list: List[str]) -> List[int]: + """ + Extracts and parses GO IDs from a list of GO annotations. + + Args: + go_list (List[str]): List of GO annotation strings. + + Returns: + List[str]: List of parsed GO IDs. + """ + return [DeepGO2MigratedData._parse_go_id(go_id_str) for go_id_str in go_list] + + def _generate_labels(self, data_df: pd.DataFrame) -> pd.DataFrame: + """ + Generates label columns for each GO term in the dataset. + + Args: + data_df (pd.DataFrame): DataFrame containing data with GO IDs. + + Returns: + pd.DataFrame: DataFrame with new label columns. + """ + print("Generating labels based on terms.pkl file.......") + parsed_go_ids: pd.Series = self._terms_df["gos"].apply( + DeepGO2MigratedData._parse_go_id + ) + all_go_ids_list = parsed_go_ids.values.tolist() + self._classes = all_go_ids_list + new_label_columns = pd.DataFrame( + False, index=data_df.index, columns=all_go_ids_list + ) + data_df = pd.concat([data_df, new_label_columns], axis=1) + + for index, row in data_df.iterrows(): + for go_id in row["go_ids"]: + if go_id in data_df.columns: + data_df.at[index, go_id] = True + + data_df = data_df[data_df.iloc[:, self._LABELS_START_IDX :].any(axis=1)] + return data_df + + def save_migrated_data( + self, data_df: pd.DataFrame, splits_df: pd.DataFrame + ) -> None: + """ + Saves the processed data and split information. + + Args: + data_df (pd.DataFrame): Data with GO labels. + splits_df (pd.DataFrame): Split assignment DataFrame. 
+ """ + print("Saving transformed data......") + deepgo_migr_inst: DeepGO2MigratedData = DeepGO2MigratedData( + go_branch=DeepGO2MigratedData.GO_BRANCH_MAPPING[self._go_branch], + max_sequence_length=self._max_len, + ) + + # Save data file + deepgo_migr_inst.save_processed( + data_df, deepgo_migr_inst.processed_main_file_names_dict["data"] + ) + print( + f"{deepgo_migr_inst.processed_main_file_names_dict['data']} saved to {deepgo_migr_inst.processed_dir_main}" + ) + + # Save split file + splits_df.to_csv( + os.path.join(deepgo_migr_inst.processed_dir_main, "splits_deep_go2.csv"), + index=False, + ) + print(f"splits_deep_go2.csv saved to {deepgo_migr_inst.processed_dir_main}") + + # Save classes.txt file + classes = sorted(self._classes) + with open( + os.path.join(deepgo_migr_inst.processed_dir_main, "classes_deep_go2.txt"), + "wt", + ) as fout: + fout.writelines(str(node) + "\n" for node in classes) + print(f"classes_deep_go2.txt saved to {deepgo_migr_inst.processed_dir_main}") + + print("Migration completed!") + + +class Main: + """ + Main class to handle the migration process for DeepGoDataMigration. + + Methods: + migrate(data_dir: str, go_branch: Literal["cc", "mf", "bp"]): + Initiates the migration process for the specified data directory and GO branch. + """ + + @staticmethod + def migrate( + data_dir: str, go_branch: Literal["cc", "mf", "bp"], max_len: int = 1000 + ) -> None: + """ + Initiates the migration process by creating a DeepGoDataMigration instance + and invoking its migrate method. + + Args: + data_dir (str): Directory containing the data files. + go_branch (Literal["cc", "mf", "bp"]): GO branch to use + ("cc" for cellular_component, + "mf" for molecular_function, + or "bp" for biological_process). + max_len (int): Used to truncate the sequence to this length. Default is 1000. + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L11 + """ + DeepGo2DataMigration(data_dir, go_branch, max_len).migrate() + + +if __name__ == "__main__": + # Example: python script_name.py migrate --data_dir="data/deep_go_se_training_data" --go_branch="bp" + # --data_dir specifies the directory containing the data files. + # --go_branch specifies the GO branch (cc, mf, or bp) you want to use for the migration. + CLI( + Main, + description="DeepGoDataMigration CLI tool to handle migration of GO data for specified branches (cc, mf, bp).", + as_positional=False, + ) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py new file mode 100644 index 0000000..38060f2 --- /dev/null +++ b/chebai/preprocessing/reader.py @@ -0,0 +1,514 @@ +import os +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple +from urllib.error import HTTPError + +import deepsmiles +import selfies as sf +import torch +from esm import Alphabet +from esm.model.esm2 import ESM2 +from esm.pretrained import ( + _has_regression_weights, + load_model_and_alphabet_core, + load_model_and_alphabet_local, +) +from pysmiles.read_smiles import _tokenize +from transformers import RobertaTokenizerFast + +from chebai.preprocessing.collate import DefaultCollator, RaggedCollator + +EMBEDDING_OFFSET = 10 +PADDING_TOKEN_INDEX = 0 +MASK_TOKEN_INDEX = 1 +CLS_TOKEN = 2 + + +class DataReader: + """ + Base class for reading and preprocessing data. Turns the raw input data (e.g., a SMILES string) into the model + input format (e.g., a list of tokens). + + Args: + collator_kwargs: Optional dictionary of keyword arguments for the collator. 
+ token_path: Optional path for the token file. + kwargs: Additional keyword arguments (not used). + """ + + COLLATOR = DefaultCollator + + def __init__( + self, + collator_kwargs: Optional[Dict[str, Any]] = None, + token_path: Optional[str] = None, + **kwargs, + ): + if collator_kwargs is None: + collator_kwargs = dict() + self.collator = self.COLLATOR(**collator_kwargs) + self.dirname = os.path.dirname(__file__) + self._token_path = token_path + + def _get_raw_data(self, row: Dict[str, Any]) -> Any: + """Get raw data from the row.""" + return row["features"] + + def _get_raw_label(self, row: Dict[str, Any]) -> Any: + """Get raw label from the row.""" + return row["labels"] + + def _get_raw_id(self, row: Dict[str, Any]) -> Any: + """Get raw ID from the row.""" + return row.get("ident", row["features"]) + + def _get_raw_group(self, row: Dict[str, Any]) -> Any: + """Get raw group from the row.""" + return row.get("group", None) + + def _get_additional_kwargs(self, row: Dict[str, Any]) -> Dict[str, Any]: + """Get additional keyword arguments from the row.""" + return row.get("additional_kwargs", dict()) + + def name(cls) -> str: + """Returns the name of the data reader.""" + raise NotImplementedError + + @property + def token_path(self) -> str: + """Get token path, create file if it does not exist yet.""" + if self._token_path is not None: + return self._token_path + token_path = os.path.join(self.dirname, "bin", self.name(), "tokens.txt") + os.makedirs(os.path.join(self.dirname, "bin", self.name()), exist_ok=True) + if not os.path.exists(token_path): + with open(token_path, "x"): + pass + return token_path + + def _read_id(self, raw_data: Any) -> Any: + """Read and return ID from raw data.""" + return raw_data + + def _read_data(self, raw_data: Any) -> Any: + """Read and return data from raw data.""" + return raw_data + + def _read_label(self, raw_label: Any) -> Any: + """Read and return label from raw label.""" + return raw_label + + def _read_group(self, raw: Any) -> Any: + """Read and return group from raw group data.""" + return raw + + def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]: + """Read and return components from the row.""" + return dict( + features=self._get_raw_data(row), + labels=self._get_raw_label(row), + ident=self._get_raw_id(row), + group=self._get_raw_group(row), + additional_kwargs=self._get_additional_kwargs(row), + ) + + def to_data(self, row: Dict[str, Any]) -> Dict[str, Any]: + """Convert raw row data to processed data.""" + d = self._read_components(row) + return dict( + features=self._read_data(d["features"]), + labels=self._read_label(d["labels"]), + ident=self._read_id(d["ident"]), + group=self._read_group(d["group"]), + **d["additional_kwargs"], + ) + + def on_finish(self) -> None: + """Hook to run at the end of preprocessing.""" + return + + +class ProteinDataReader(DataReader): + """ + Data reader for protein sequences using amino acid tokens. This class processes raw protein sequences into a format + suitable for model input by tokenizing them and assigning unique indices to each token. + + Note: + Refer for amino acid sequence: https://en.wikipedia.org/wiki/Protein_primary_structure + + Args: + collator_kwargs (Optional[Dict[str, Any]]): Optional dictionary of keyword arguments for configuring the collator. + token_path (Optional[str]): Path to the token file. If not provided, it will be created automatically. + kwargs: Additional keyword arguments. 
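+
+ Example (illustrative; the sequence and identifier are made up):
+
+     reader = ProteinDataReader()             # single-letter amino acid tokens
+     reader.to_data(dict(features="MKTAY", labels=None, ident="P12345"))
+
+     reader_3g = ProteinDataReader(n_gram=3)  # overlapping 3-grams: "MKT", "KTA", "TAY"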
+ """ + + COLLATOR = RaggedCollator + + # 21 natural amino acid notation + AA_LETTER = [ + "A", + "R", + "N", + "D", + "C", + "Q", + "E", + "G", + "H", + "I", + "L", + "K", + "M", + "F", + "P", + "S", + "T", + "W", + "Y", + "V", + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5 + "X", # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py + ] + + def name(self) -> str: + """ + Returns the name of the data reader. This method identifies the specific type of data reader. + + Returns: + str: The name of the data reader, which is "protein_token". + """ + if self.n_gram is not None: + return f"protein_token_{self.n_gram}_gram" + + return "protein_token" + + def __init__(self, *args, n_gram: Optional[int] = None, **kwargs): + """ + Initializes the ProteinDataReader, loading existing tokens from the specified token file. + + Args: + *args: Additional positional arguments passed to the base class. + **kwargs: Additional keyword arguments passed to the base class. + """ + if n_gram is not None: + assert ( + int(n_gram) >= 2 + ), "Ngrams must be greater than or equal to 2 if provided." + self.n_gram = int(n_gram) + else: + self.n_gram = None + + super().__init__(*args, **kwargs) + + # Load the existing tokens from the token file into a cache + with open(self.token_path, "r") as pk: + self.cache = [x.strip() for x in pk] + + def _get_token_index(self, token: str) -> int: + """ + Returns a unique index for each token (amino acid). If the token is not already in the cache, it is added. + + Args: + token (str): The amino acid token to retrieve or add. + + Returns: + int: The index of the token, offset by the predefined EMBEDDING_OFFSET. + """ + error_str = ( + f"Please ensure that the input only contains valid amino acids " + f"20 Valid natural amino acid notation: {self.AA_LETTER}" + f"Refer to the amino acid sequence details here: " + f"https://en.wikipedia.org/wiki/Protein_primary_structure" + ) + + if self.n_gram is None: + # Single-letter amino acid token check + if str(token) not in self.AA_LETTER: + raise KeyError(f"Invalid token '{token}' encountered. " + error_str) + else: + # n-gram token validation, ensure that each component of the n-gram is valid + for aa in token: + if aa not in self.AA_LETTER: + raise KeyError( + f"Invalid token '{token}' encountered as part of n-gram {self.n_gram}. " + + error_str + ) + + if str(token) not in self.cache: + self.cache.append(str(token)) + return self.cache.index(str(token)) + EMBEDDING_OFFSET + + def _read_data(self, raw_data: str) -> List[int]: + """ + Reads and tokenizes raw protein sequence data into a list of token indices. + + Args: + raw_data (str): The raw protein sequence to be tokenized (e.g., "MKTFF..."). + + Returns: + List[int]: A list of integers representing the indices of the amino acid tokens. + """ + if self.n_gram is not None: + # Tokenize the sequence into n-grams + tokens = [ + raw_data[i : i + self.n_gram] + for i in range(len(raw_data) - self.n_gram + 1) + ] + return [self._get_token_index(gram) for gram in tokens] + + # If n_gram is None, tokenize the sequence at the amino acid level (single-letter representation) + return [self._get_token_index(aa) for aa in raw_data] + + def on_finish(self) -> None: + """ + Saves the current cache of tokens to the token file. This method is called after all data processing is complete. 
+ """ + with open(self.token_path, "w") as pk: + print(f"Saving {len(self.cache)} tokens to {self.token_path}...") + print(f"First 10 tokens: {self.cache[:10]}") + pk.writelines([f"{c}\n" for c in self.cache]) + + +class ESM2EmbeddingReader(DataReader): + """ + A data reader to process protein sequences using the ESM2 model for embeddings. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py + + Note: + For layer availability by model, Please check below link: + https://github.com/facebookresearch/esm?tab=readme-ov-file#pre-trained-models- + + To test this reader, try lighter models: + esm2_t6_8M_UR50D: 6 layers (valid layers: 1โ€“6), (~28 Mb) - A tiny 8M parameter model. + esm2_t12_35M_UR50D: 12 layers (valid layers: 1โ€“12), (~128 Mb) - A slightly larger, 35M parameter model. + These smaller models are good for testing and debugging purposes. + + """ + + # https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L53 + _MODELS_URL = "https://dl.fbaipublicfiles.com/fair-esm/models/{}.pt" + _REGRESSION_URL = ( + "https://dl.fbaipublicfiles.com/fair-esm/regression/{}-contact-regression.pt" + ) + + def __init__( + self, + save_model_dir: str = os.path.join("data", "esm2_reader"), + model_name: str = "esm2_t36_3B_UR50D", + device: Optional[torch.device] = None, + truncation_length: int = 1022, + toks_per_batch: int = 4096, + return_contacts: bool = False, + repr_layer: int = 36, + *args, + **kwargs, + ): + """ + Initialize the ESM2EmbeddingReader class. + + Args: + save_model_dir (str): Directory to save/load the pretrained ESM model. + model_name (str): Name of the pretrained model. Defaults to "esm2_t36_3B_UR50D". + device (torch.device or str, optional): Device for computation (e.g., 'cpu', 'cuda'). + truncation_length (int): Maximum sequence length for truncation. Defaults to 1022. + toks_per_batch (int): Tokens per batch for data processing. Defaults to 4096. + return_contacts (bool): Whether to return contact maps. Defaults to False. + repr_layers (int): Layer number to extract representations from. Defaults to 36. + """ + self.save_model_dir = save_model_dir + if not os.path.exists(self.save_model_dir): + os.makedirs((os.path.dirname(self.save_model_dir)), exist_ok=True) + self.model_name = model_name + self.device = device + self.truncation_length = truncation_length + self.toks_per_batch = toks_per_batch + self.return_contacts = return_contacts + self.repr_layer = repr_layer + + self._model: Optional[ESM2] = None + self._alphabet: Optional[Alphabet] = None + + self._model, self._alphabet = self.load_model_and_alphabet() + self._model.eval() + + if self.device: + self._model = self._model.to(device) + + super().__init__(*args, **kwargs) + + def load_model_and_alphabet(self) -> Tuple[ESM2, Alphabet]: + """ + Load the ESM2 model and its alphabet. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L24-L28 + + Returns: + Tuple[ESM2, Alphabet]: Loaded model and alphabet. + """ + model_location = os.path.join(self.save_model_dir, f"{self.model_name}.pt") + if os.path.exists(model_location): + return load_model_and_alphabet_local(model_location) + else: + return self.load_model_and_alphabet_hub() + + def load_model_and_alphabet_hub(self) -> Tuple[ESM2, Alphabet]: + """ + Load the model and alphabet from the hub URL. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L62-L64 + + Returns: + Tuple[ESM2, Alphabet]: Loaded model and alphabet. 
+ """ + model_url = self._MODELS_URL.format(self.model_name) + model_data = self.load_hub_workaround(model_url) + regression_data = None + if _has_regression_weights(self.model_name): + regression_url = self._REGRESSION_URL.format(self.model_name) + regression_data = self.load_hub_workaround(regression_url) + return load_model_and_alphabet_core( + self.model_name, model_data, regression_data + ) + + def load_hub_workaround(self, url) -> torch.Tensor: + """ + Workaround to load models from the PyTorch Hub. + + References: + https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L31-L43 + + Returns: + torch.Tensor: Loaded model state dictionary. + """ + try: + data = torch.hub.load_state_dict_from_url( + url, self.save_model_dir, progress=True, map_location=self.device + ) + + except RuntimeError: + # Handle PyTorch version issues + fn = Path(url).name + data = torch.load( + f"{torch.hub.get_dir()}/checkpoints/{fn}", + map_location="cpu", + ) + except HTTPError as e: + raise Exception( + f"Could not load {url}. Did you specify the correct model name?" + ) + return data + + @staticmethod + def name() -> str: + """ + Returns the name of the data reader. This method identifies the specific type of data reader. + + Returns: + str: The name of the data reader, which is "protein_token". + """ + return "esm2_embedding" + + @property + def token_path(self) -> None: + """ + Not used as no token file is not created for this reader. + + Returns: + str: Empty string since this method is not implemented. + """ + return + + def _read_data(self, raw_data: str) -> List[int]: + """ + Reads protein sequence data and generates embeddings. + + Args: + raw_data (str): The protein sequence. + + Returns: + List[int]: Embeddings generated for the sequence. + """ + alp_tokens_idx = self._sequence_to_alphabet_tokens_idx(raw_data) + return self._alphabet_tokens_to_esm_embedding(alp_tokens_idx).tolist() + + def _sequence_to_alphabet_tokens_idx(self, sequence: str) -> torch.Tensor: + """ + Converts a protein sequence into ESM alphabet token indices. + + Args: + sequence (str): Protein sequence. + + References: + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L249-L250 + https://github.com/facebookresearch/esm/blob/2b369911bb5b4b0dda914521b9475cad1656b2ac/esm/data.py#L262-L297 + + Returns: + torch.Tensor: Tokenized sequence with special tokens (BOS/EOS) included. + """ + seq_encoded = self._alphabet.encode(sequence) + tokens = [] + + # Add BOS token if configured + if self._alphabet.prepend_bos: + tokens.append(self._alphabet.cls_idx) + + # Add the main sequence + tokens.extend(seq_encoded) + + # Add EOS token if configured + if self._alphabet.append_eos: + tokens.append(self._alphabet.eos_idx) + + # Convert to PyTorch tensor and return + return torch.tensor([tokens], dtype=torch.int64) + + def _alphabet_tokens_to_esm_embedding(self, tokens: torch.Tensor) -> torch.Tensor: + """ + Converts alphabet tokens into ESM embeddings. + + Args: + tokens (torch.Tensor): Tokenized protein sequences. + + References: + https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/extract_esm.py#L82-L107 + + Returns: + torch.Tensor: Protein embedding from the specified representation layer. 
+ """ + if self.device: + tokens = tokens.to(self.device, non_blocking=True) + + with torch.no_grad(): + out = self._model( + tokens, + repr_layers=[ + self.repr_layer, + ], + return_contacts=self.return_contacts, + ) + + # Extract representations and compute the mean embedding for each layer + representations = { + layer: t.to(self.device) for layer, t in out["representations"].items() + } + truncate_len = min(self.truncation_length, tokens.size(1)) + + result = { + "mean_representations": { + layer: t[0, 1 : truncate_len + 1].mean(0).clone() + for layer, t in representations.items() + } + } + return result["mean_representations"][self.repr_layer] + + def on_finish(self) -> None: + """ + Not used here as no token file exists for this reader. + + Returns: + None + """ + pass diff --git a/chebai/preprocessing/structures.py b/chebai/preprocessing/structures.py new file mode 100644 index 0000000..1fb3711 --- /dev/null +++ b/chebai/preprocessing/structures.py @@ -0,0 +1,141 @@ +from typing import Any, Tuple, Union + +import networkx as nx +import torch + + +class XYData(torch.utils.data.Dataset): + """ + A dataset class for handling pairs of data (x, y). + + Args: + x: Input data. + y: Target data. + kwargs: Additional fields to store in the dataset. + """ + + def __init__( + self, x: Union[torch.Tensor, Tuple[Any, ...]], y: torch.Tensor, **kwargs + ): + super().__init__() + self.additional_fields = kwargs + self.x = x + self.y = y + + def __getitem__(self, index: int): + """Returns the data and target at the given index.""" + return self.x[index], self.y[index] + + def __len__(self) -> int: + """Returns the size of the dataset.""" + return len(self.x) + + def to_x(self, device: torch.device) -> Union[torch.Tensor, Tuple[Any, ...]]: + """ + Moves the input data to the specified device. + + Args: + device: The device to move the data to. + + Returns: + The input data on the specified device. + """ + if isinstance(self.x, tuple): + res = [] + for elem in self.x: + if isinstance(elem, dict): + for k, v in elem.items(): + elem[k] = v.to(device) if v is not None else None + else: + elem = elem.to(device) + res.append(elem) + return tuple(res) + return self.x.to(device) + + def to_y(self, device: torch.device) -> torch.Tensor: + """ + Moves the target data to the specified device. + + Args: + device: The device to move the data to. + + Returns: + The target data on the specified device. + """ + return self.y.to(device) + + def _to_if_tensor(self, obj: Any, device: torch.device) -> Any: + """ + Recursively moves the object to the specified device if it is a tensor. + + Args: + obj: The object to move. + device: The device to move the object to. + + Returns: + The object on the specified device. + """ + if isinstance(obj, torch.Tensor): + return obj.to(device) + elif isinstance(obj, dict): + return {k: self._to_if_tensor(v, device) for k, v in obj.items()} + elif isinstance(obj, list): + return [self._to_if_tensor(v, device) for v in obj] + else: + return obj + + def to(self, device: torch.device) -> "XYData": + """ + Moves the dataset to the specified device. + + Args: + device: The device to move the dataset to. + + Returns: + A new dataset on the specified device. + """ + x = self.to_x(device) + if self.y is not None: + y = self.to_y(device) + else: + y = None + return XYData( + x, + y, + **{ + k: self._to_if_tensor(v, device) + for k, v in self.additional_fields.items() + }, + ) + + +class XYMolData(XYData): + """ + A dataset class for handling molecular data represented as NetworkX graphs. 
+ + Args: + x: Input molecular graphs. + y: Target data. + kwargs: Additional fields to store in the dataset. + """ + + def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]: + """ + Moves the node attributes of the molecular graphs to the specified device. + + Args: + device: The device to move the data to. + + Returns: + A tuple of molecular graphs with node attributes on the specified device. + """ + l = [] + for g in self.x: + graph = g.copy() + nx.set_node_attributes( + graph, + {k: v.to(device) for k, v in nx.get_node_attributes(g, "x").items()}, + "x", + ) + l.append(graph) + return tuple(l) diff --git a/chebai/result/__init__.py b/chebai/result/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py new file mode 100644 index 0000000..51a1fb2 --- /dev/null +++ b/chebai/result/analyse_sem.py @@ -0,0 +1,721 @@ +import gc +import sys +import traceback +from datetime import datetime +from typing import List, LiteralString + +from torchmetrics.functional.classification import ( + multilabel_auroc, + multilabel_average_precision, + multilabel_f1_score, +) +from utils import * + +from chebai.loss.semantic import DisjointLoss +from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai.preprocessing.datasets.chebi import ChEBIOver100 +from chebai.preprocessing.datasets.pubchem import PubChemKMeans + +DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") + + +def binary(left, right): + return torch.logical_and(left > 0.5, right > 0.5) + + +def strict(left, right): + return left + right > 1 + + +def weak(left, right): + return left + right > 1.01 + + +def product(left, right): + return left * right + + +def lukasiewicz(left, right): + return torch.relu(left + right - 1) + + +def apply_metric(metric, left, right): + return torch.sum(metric(left, right), dim=0) + + +ALL_CONSISTENCY_METRICS = [product, lukasiewicz, weak, strict, binary] + + +def _filter_to_one_hot(preds, idx_filter): + """Takes list of indices (e.g. [1, 3, 0]) and returns a one-hot filter with these indices + (e.g. 
[[0,1,0,0], [0,0,0,1], [1,0,0,0]])""" + res = torch.zeros((len(idx_filter), preds.shape[1]), dtype=torch.bool) + for i, idx in enumerate(idx_filter): + res[i][idx] = True + return res + + +def _sort_results_by_label(n_labels, results, filter): + by_label = torch.zeros(n_labels, device=DEVICE, dtype=torch.int) + for r, filter_l in zip(results, filter): + by_label[filter_l] += r + return by_label + + +def get_best_epoch(run): + files = run.files() + best_ep = None + best_micro_f1 = 0 + for file in files: + if file.name.startswith("checkpoints/best_epoch"): + micro_f1 = float(file.name.split("=")[-1][:-5]) + if micro_f1 > best_micro_f1 or best_ep is None: + best_ep = int(file.name.split("=")[1].split("_")[0]) + best_micro_f1 = micro_f1 + if best_ep is None: + raise Exception(f"Could not find any 'best' checkpoint for run {run.id}") + else: + print(f"Best epoch for run {run.id}: {best_ep}") + return best_ep + + +def download_model_from_wandb( + run_id, base_dir=os.path.join("logs", "downloaded_ckpts") +): + api = wandb.Api() + run = api.run(f"chebai/chebai/{run_id}") + epoch = get_best_epoch(run) + return ( + get_checkpoint_from_wandb(epoch, run, root=base_dir), + epoch, + ) + + +def load_preds_labels( + ckpt_path: LiteralString, data_module, data_subset_key="test", buffer_dir=None +): + if buffer_dir is None: + buffer_dir = os.path.join( + "results_buffer", + *ckpt_path.split(os.path.sep)[-2:], + f"{data_module.__class__.__name__}_{data_subset_key}", + ) + model = Electra.load_from_checkpoint(ckpt_path, map_location="cuda:0", strict=False) + print( + f"Calculating predictions on {data_module.__class__.__name__} ({data_subset_key})..." + ) + evaluate_model( + model, + data_module, + buffer_dir=buffer_dir, + # for chebi, use kinds, otherwise use file names + filename=( + data_subset_key if not isinstance(buffer_dir, _DynamicDataset) else None + ), + kind=data_subset_key, + skip_existing_preds=True, + batch_size=1, + ) + return load_results_from_buffer(buffer_dir, device=torch.device("cpu")) + + +def get_label_names(data_module): + if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")): + with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin: + return [int(line.strip()) for line in fin] + print( + f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found" + ) + return None + + +def get_chebi_graph(data_module, label_names): + if os.path.exists(os.path.join(data_module.raw_dir, "chebi.obo")): + chebi_graph = data_module.extract_class_hierarchy( + os.path.join(data_module.raw_dir, "chebi.obo") + ) + return chebi_graph.subgraph(label_names) + print( + f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found" + ) + return None + + +def get_disjoint_groups(): + disjoints_owl_file = os.path.join("data", "chebi-disjoints.owl") + with open(disjoints_owl_file, "r") as f: + plaintext = f.read() + segments = plaintext.split("<") + disjoint_pairs = [] + left = None + for seg in segments: + if seg.startswith("rdf:Description ") or seg.startswith("owl:Class"): + left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0]) + elif seg.startswith("owl:disjointWith"): + right = int(seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]) + disjoint_pairs.append([left, right]) + + disjoint_groups = [] + for seg in plaintext.split(""): + if "owl;AllDisjointClasses" in seg: + classes = seg.split('rdf:about="&obo;CHEBI_')[1:] + classes = [int(c.split('"')[0]) for c in classes] + 
disjoint_groups.append(classes) + disjoint_all = disjoint_pairs + disjoint_groups + # one disjointness is commented out in the owl-file + # (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work) + disjoint_all.remove([22729, 51880]) + print(f"Found {len(disjoint_all)} disjoint groups") + return disjoint_all + + +class PredictionSmoother: + """Removes implication and disjointness violations from predictions""" + + def __init__(self, dataset): + self.label_names = get_label_names(dataset) + self.chebi_graph = get_chebi_graph(dataset, self.label_names) + self.disjoint_groups = get_disjoint_groups() + + def __call__(self, preds): + + preds_sum_orig = torch.sum(preds) + print(f"Preds sum: {preds_sum_orig}") + # eliminate implication violations by setting each prediction to maximum of its successors + for i, label in enumerate(self.label_names): + succs = [ + self.label_names.index(p) for p in self.chebi_graph.successors(label) + ] + [i] + if len(succs) > 0: + preds[:, i] = torch.max(preds[:, succs], dim=1).values + print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") + preds_sum_orig = torch.sum(preds) + # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower) + preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49) + for disj_group in self.disjoint_groups: + disj_group = [ + self.label_names.index(g) for g in disj_group if g in self.label_names + ] + if len(disj_group) > 1: + old_preds = preds[:, disj_group] + disj_max = torch.max(preds[:, disj_group], dim=1) + for i, row in enumerate(preds): + for l in range(len(preds[i])): + if l in disj_group and l != disj_group[disj_max.indices[i]]: + preds[i, l] = preds_bounded[i, l] + samples_changed = 0 + for i, row in enumerate(preds[:, disj_group]): + if any(r != o for r, o in zip(row, old_preds[i])): + samples_changed += 1 + if samples_changed != 0: + print( + f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples" + ) + print( + f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}" + ) + preds_sum_orig = torch.sum(preds) + # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors + for i, label in enumerate(self.label_names): + predecessors = [i] + [ + self.label_names.index(p) for p in self.chebi_graph.predecessors(label) + ] + lowest_predecessors = torch.min(preds[:, predecessors], dim=1) + preds[:, i] = lowest_predecessors.values + for idx_idx, idx in enumerate(lowest_predecessors.indices): + if idx > 0: + print( + f"class {label}: changed prediction of sample {idx_idx} to value of class " + f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})" + ) + if torch.sum(preds) != preds_sum_orig: + print( + f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}" + ) + preds_sum_orig = torch.sum(preds) + return preds + + +def _filter_to_dense(filter): + filter_dense = [] + for i in range(filter.shape[0]): + for j in range(filter.shape[1]): + if filter[i, j] > 0: + filter_dense.append([i, j]) + return torch.tensor(filter_dense) + + +def build_prediction_filter(data_module_labeled=None): + if data_module_labeled is None: + data_module_labeled = ChEBIOver100(chebi_version=231) + # prepare filters + print(f"Loading implication / disjointness filters...") + dl = DisjointLoss( + path_to_disjointness=os.path.join("data", 
"disjoint.csv"), + data_extractor=data_module_labeled, + ) + impl = _filter_to_dense(dl.implication_filter_l) + disj = _filter_to_dense(dl.disjoint_filter_l) + + return [ + (impl[:, 0], impl[:, 1], "impl"), + (disj[:, 0], disj[:, 1], "disj"), + ] + + +def run_consistency_metrics( + preds, + consistency_filters, + data_module_labeled=None, # use labels from this dataset for violations + violation_metrics=None, + verbose_violation_output=False, + save_details_to=None, +): + """Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided)""" + if violation_metrics is None: + violation_metrics = ALL_CONSISTENCY_METRICS + if data_module_labeled is None: + data_module_labeled = ChEBIOver100(chebi_version=231) + if save_details_to is not None: + os.makedirs(save_details_to, exist_ok=True) + + preds.to("cpu") + + n_labels = preds.size(1) + print(f"Found {preds.shape[0]} predictions ({n_labels} classes)") + + results = {} + + for dl_filter_l, dl_filter_r, filter_type in consistency_filters: + l_preds = preds[:, dl_filter_l] + r_preds = preds[:, dl_filter_r] + for i, metric in enumerate(violation_metrics): + if metric.__name__ not in results: + results[metric.__name__] = {} + print(f"Calculating metrics {metric.__name__} on {filter_type}") + + metric_results = {} + metric_results["tps"] = torch.sum( + torch.stack( + [ + apply_metric( + metric, + l_preds[i : i + 1000], + ( + r_preds[i : i + 1000] + if filter_type == "impl" + else 1 - r_preds[i : i + 1000] + ), + ) + for i in range(0, r_preds.shape[0], 1000) + ] + ), + dim=0, + ) + metric_results["fns"] = torch.sum( + torch.stack( + [ + apply_metric( + metric, + l_preds[i : i + 1000], + ( + 1 - r_preds[i : i + 1000] + if filter_type == "impl" + else r_preds[i : i + 1000] + ), + ) + for i in range(0, r_preds.shape[0], 1000) + ] + ), + dim=0, + ) + if verbose_violation_output: + label_names = get_label_names(data_module_labeled) + print( + f"Found {torch.sum(metric_results['fns'])} {filter_type}-violations" + ) + # for k, fn_cls in enumerate(metric_results['fns']): + # if fn_cls > 0: + # print(f"\tThereof, {fn_cls.item()} belong to class {label_names[k]}") + if torch.sum(metric_results["fns"]) != 0: + fns = metric( + l_preds, 1 - r_preds if filter_type == "impl" else r_preds + ) + print(fns.shape) + for k, row in enumerate(fns): + if torch.sum(row) != 0: + print(f"{torch.sum(row)} violations for entity {k}") + for j, violation in enumerate(row): + if violation > 0: + print( + f"\tviolated ({label_names[dl_filter_l[j]]} -> {preds[k, dl_filter_l[j]]:.3f}" + f", {label_names[dl_filter_r[j]]} -> {preds[k, dl_filter_r[j]]:.3f})" + ) + + m_l_agg = {} + for key, value in metric_results.items(): + m_l_agg[key] = _sort_results_by_label( + n_labels, + value, + dl_filter_l, + ) + m_r_agg = {} + for key, value in metric_results.items(): + m_r_agg[key] = _sort_results_by_label( + n_labels, + value, + dl_filter_r, + ) + + if save_details_to is not None: + with open( + os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_all.csv" + ), + "w+", + ) as f: + f.write("left,right,tps,fns\n") + for left, right, tps, fns in zip( + dl_filter_l, + dl_filter_r, + metric_results["tps"], + metric_results["fns"], + ): + f.write(f"{left},{right},{tps},{fns}\n") + with open( + os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_l.csv" + ), + "w+", + ) as f: + f.write("left,tps,fns\n") + for left in range(n_labels): + f.write( + f"{left},{m_l_agg['tps'][left].item()},{m_l_agg['fns'][left].item()}\n" + ) + with open( 
+ os.path.join( + save_details_to, f"{metric.__name__}_{filter_type}_r.csv" + ), + "w+", + ) as f: + f.write("right,tps,fns\n") + for right in range(n_labels): + f.write( + f"{right},{m_r_agg['tps'][right].item()},{m_r_agg['fns'][right].item()}\n" + ) + print( + f"Saved unaggregated consistency metrics ({metric.__name__}, {filter_type}) to {save_details_to}" + ) + + fns_sum = torch.sum(metric_results["fns"]).item() + results[metric.__name__][f"micro-fnr-{filter_type}"] = ( + 0 + if fns_sum == 0 + else ( + torch.sum(metric_results["fns"]) + / ( + torch.sum(metric_results[f"tps"]) + + torch.sum(metric_results[f"fns"]) + ) + ).item() + ) + macro_fnr_l = m_l_agg[f"fns"] / (m_l_agg[f"tps"] + m_l_agg[f"fns"]) + results[metric.__name__][f"lmacro-fnr-{filter_type}"] = ( + 0 + if fns_sum == 0 + else torch.mean(macro_fnr_l[~macro_fnr_l.isnan()]).item() + ) + macro_fnr_r = m_r_agg[f"fns"] / (m_r_agg[f"tps"] + m_r_agg[f"fns"]) + results[metric.__name__][f"rmacro-fnr-{filter_type}"] = ( + 0 + if fns_sum == 0 + else torch.mean(macro_fnr_r[~macro_fnr_r.isnan()]).item() + ) + results[metric.__name__][f"fn-sum-{filter_type}"] = torch.sum( + metric_results["fns"] + ).item() + results[metric.__name__][f"tp-sum-{filter_type}"] = torch.sum( + metric_results["tps"] + ).item() + + del metric_results + del m_l_agg + del m_r_agg + + gc.collect() + del l_preds + del r_preds + gc.collect() + + return results + + +def run_supervised_metrics(preds, labels, save_details_to=None): + # calculate supervised metrics + results = {} + if labels is not None: + results["micro-f1"] = multilabel_f1_score( + preds, labels, num_labels=preds.size(1), average="micro" + ).item() + results["macro-f1"] = multilabel_f1_score( + preds, labels, num_labels=preds.size(1), average="macro" + ).item() + results["micro-roc-auc"] = multilabel_auroc( + preds, labels, num_labels=preds.size(1), average="micro" + ).item() + results["macro-roc-auc"] = multilabel_auroc( + preds, labels, num_labels=preds.size(1), average="macro" + ).item() + + results["micro-ap"] = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average="micro" + ).item() + results["macro-ap"] = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average="macro" + ).item() + + if save_details_to is not None: + f1_by_label = multilabel_f1_score( + preds, labels, num_labels=preds.size(1), average=None + ) + roc_by_label = multilabel_auroc( + preds, labels, num_labels=preds.size(1), average=None + ) + ap_by_label = multilabel_average_precision( + preds, labels, num_labels=preds.size(1), average=None + ) + with open(os.path.join(save_details_to, f"supervised.csv"), "w+") as f: + f.write("label,f1,roc-auc,ap\n") + for right in range(preds.size(1)): + f.write( + f"{right},{f1_by_label[right].item()},{roc_by_label[right].item()},{ap_by_label[right].item()}\n" + ) + print(f"Saved class-wise supervised metrics to {save_details_to}") + + del preds + del labels + gc.collect() + return results + + +# run predictions / metrics calculations for semantic loss paper runs (NeSy 2024 submission) +def run_semloss_eval(): + # runs from wandb + non_wandb_runs = [] + api = wandb.Api() + runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"}) + print(f"Found {len(runs)} tagged wandb runs") + ids_wandb = [run.id for run in runs] + + # ids used in the NeSy submission + prod = ["tk15yznc", "uke62a8m", "w0h3zr5s"] + xu19 = ["5ko8knb4", "061fd85t", "r50ioujs"] + prod_mixed = ["hk8555ff", "e0lxw8py", "lig23cmg"] + luka = ["0c0s48nh", "lfg384bp", 
"qeghvubh"] + baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"] + prodk2 = ["ng3usn0p", "rp0wwzjv", "8fma1q7r"] + ids = baseline + prod + prodk2 + xu19 + luka + prod_mixed + # ids = ids_wandb + run_all( + ids, + non_wandb_runs, + prediction_datasets=[(ChEBIOver100(chebi_version=231), "test")], + consistency_metrics=[binary], + ) + + +def run_all( + wandb_ids=None, + local_ckpts: List[Tuple] = None, + consistency_metrics: Optional[List[callable]] = None, + prediction_datasets: List[Tuple] = None, + remove_violations: bool = False, + results_dir="_fuzzy_loss_eval", + check_consistency_on=None, + verbose_violation_output=False, +): + if wandb_ids is None: + wandb_ids = [] + if local_ckpts is None: + local_ckpts = [] + if consistency_metrics is None: + consistency_metrics = ALL_CONSISTENCY_METRICS + if prediction_datasets is None: + prediction_datasets = [ + (ChEBIOver100(chebi_version=231), "test"), + ] + if check_consistency_on is None: + check_consistency_on = ChEBIOver100(chebi_version=231) + + if remove_violations: + smooth_preds = PredictionSmoother(check_consistency_on) + else: + smooth_preds = lambda x: x + + timestamp = datetime.now().strftime("%y%m%d-%H%M%S") + prediction_filters = build_prediction_filter(check_consistency_on) + + results_path_consistency = os.path.join( + results_dir, + f"consistency_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", + ) + consistency_keys = [ + "micro-fnr-impl", + "lmacro-fnr-impl", + "rmacro-fnr-impl", + "fn-sum-impl", + "tp-sum-impl", + "micro-fnr-disj", + "lmacro-fnr-disj", + "rmacro-fnr-disj", + "fn-sum-disj", + "tp-sum-disj", + ] + with open(results_path_consistency, "x") as f: + f.write( + "run-id,epoch,datamodule,data_key,metric," + + ",".join(consistency_keys) + + "\n" + ) + results_path_supervised = os.path.join( + results_dir, + f"supervised_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", + ) + supervised_keys = [ + "micro-f1", + "macro-f1", + "micro-roc-auc", + "macro-roc-auc", + "micro-ap", + "macro-ap", + ] + with open(results_path_supervised, "x") as f: + f.write("run-id,epoch,datamodule,data_key," + ",".join(supervised_keys) + "\n") + + ckpts = [(run_name, ep, None) for run_name, ep in local_ckpts] + [ + (None, None, wandb_id) for wandb_id in wandb_ids + ] + + for run_name, epoch, wandb_id in ckpts: + try: + ckpt_dir = os.path.join("logs", "downloaded_ckpts") + # for wandb runs, use short id as name, otherwise use ckpt dir name + if wandb_id is not None: + run_name = wandb_id + ckpt_path, epoch = download_model_from_wandb(run_name, ckpt_dir) + else: + ckpt_path = None + for file in os.listdir(os.path.join(ckpt_dir, run_name)): + if f"epoch={epoch}_" in file or f"epoch={epoch}." 
in file: + ckpt_path = os.path.join(os.path.join(ckpt_dir, run_name, file)) + assert ( + ckpt_path is not None + ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}" + print(f"Starting run {run_name} (epoch {epoch})") + + for dataset, dataset_key in prediction_datasets: + # copy data from legacy buffer dir if possible + old_buffer_dir = os.path.join( + "results_buffer", + *ckpt_path.split(os.path.sep)[-2:], + f"{dataset.__class__.__name__}_{dataset_key}", + ) + buffer_dir = os.path.join( + "results_buffer", + run_name, + f"epoch={epoch}", + f"{dataset.__class__.__name__}_{dataset_key}", + ) + print("Checking for buffer dir", old_buffer_dir) + if os.path.isdir(old_buffer_dir): + from distutils.dir_util import copy_tree, remove_tree + + os.makedirs(buffer_dir, exist_ok=True) + copy_tree(old_buffer_dir, buffer_dir) + remove_tree(old_buffer_dir, dry_run=True) + print(f"Moved buffer from {old_buffer_dir} to {buffer_dir}") + print(f"Using buffer_dir {buffer_dir}") + preds, labels = load_preds_labels( + ckpt_path, dataset, dataset_key, buffer_dir + ) + # identity function if remove_violations is False + smooth_preds(preds) + + details_path = None # os.path.join( + # results_dir, + # f"{run_name}_ep{epoch}_{dataset.__class__.__name__}_{dataset_key}", + # ) + metrics_dict = run_consistency_metrics( + preds, + prediction_filters, + check_consistency_on, + consistency_metrics, + verbose_violation_output, + save_details_to=details_path, + ) + with open(results_path_consistency, "a") as f: + for metric in metrics_dict: + values = metrics_dict[metric] + f.write( + f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key},{metric}," + f"{','.join([str(values[k]) for k in consistency_keys])}\n" + ) + print( + f"Consistency metrics have been written to {results_path_consistency}" + ) + if labels is not None: + metrics_dict = run_supervised_metrics( + preds, labels, save_details_to=details_path + ) + with open(results_path_supervised, "a") as f: + f.write( + f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key}," + f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" + ) + print( + f"Supervised metrics have been written to {results_path_supervised}" + ) + except Exception as e: + print( + f"Error during run {wandb_id if wandb_id is not None else run_name}: {e}" + ) + print(traceback.format_exc()) + + +# follow-up to NeSy submission +def run_fuzzy_loss(tag="fuzzy_loss", skip_first_n=0): + api = wandb.Api() + runs = api.runs("chebai/chebai", filters={"tags": tag}) + print(f"Found {len(runs)} wandb runs tagged with '{tag}'") + ids = [run.id for run in runs] + chebi100 = ChEBIOver100( + chebi_version=231, + splits_file_path=os.path.join( + "data", "chebi_v231", "ChEBI100", "fuzzy_loss_splits.csv" + ), + ) + local_ckpts = [][skip_first_n:] + pubchem_kmeans = PubChemKMeans() + run_all( + ids[max(0, skip_first_n - len(local_ckpts)) :], # ids, + local_ckpts, + consistency_metrics=[binary], + check_consistency_on=chebi100, + prediction_datasets=[ + (chebi100, "test"), + # (pubchem_kmeans, "cluster1_cutoff2k.pt"), + # (pubchem_kmeans, "cluster2.pt"), + # (pubchem_kmeans, "ten_from_each_cluster.pt"), + # (pubchem_kmeans, "chebi_close.pt"), + ], + ) + + +if __name__ == "__main__": + if len(sys.argv) > 2: + run_fuzzy_loss(sys.argv[1], int(sys.argv[2])) + elif len(sys.argv) > 1: + run_fuzzy_loss(sys.argv[1]) + else: + run_fuzzy_loss() diff --git a/chebai/result/base.py b/chebai/result/base.py new file mode 100644 index 0000000..9d583a0 --- /dev/null +++ 
b/chebai/result/base.py @@ -0,0 +1,105 @@ +import abc +import multiprocessing as mp +from typing import Iterable + +import torch +import tqdm + +from chebai.models.base import ChebaiBaseNet + +PROCESSORS = dict() + + +class ResultProcessor(abc.ABC): + @classmethod + def _identifier(cls) -> str: + raise NotImplementedError + + def start(self): + pass + + def close(self): + pass + + def __init_subclass__(cls, **kwargs): + assert ( + cls._identifier() not in PROCESSORS + ), f"ResultProcessor {cls.__name__} does not have a unique identifier" + PROCESSORS[cls._identifier()] = cls + + def process_prediction(self, proc_id, features, labels, pred, ident): + raise NotImplementedError + + +class ResultFactory(abc.ABC): + def __init__( + self, model: ChebaiBaseNet, dataset, processors: Iterable[ResultProcessor] + ): + self._model = model + self._reader = dataset.reader + self.dataset = dataset + self._processors = processors + + def _process_row(self, row): + return row + + def _generate_predictions(self, data_path, raw=False, **kwargs): + self._model.eval() + collate = self._reader.COLLATOR() + if raw: + data_tuples = [ + (x["features"], x["ident"], self._reader.to_data(self._process_row(x))) + for x in self.dataset._load_dict(data_path) + ] + else: + data_tuples = [ + (x.get("raw_features", x["ident"]), x["ident"], x) + for x in torch.load(data_path, weights_only=False) + ] + + for raw_features, ident, row in tqdm.tqdm(data_tuples): + raw_labels = row.get("labels") + + processable_data = self._model._process_batch(collate([row]), 0) + + model_output = self._model(processable_data) + preds, labels = self._model._get_prediction_and_labels( + processable_data, processable_data["labels"], model_output + ) + d = dict( + model_output=model_output, + preds=preds, + raw_features=raw_features, + ident=ident, + threshold=self._model.thres, + ) + if raw_labels is not None: + d["labels"] = raw_labels + yield d + + def call_procs(self, args): + proc_id, proc_args = args + for proc in self._processors: + try: + proc.process_prediction(proc_id, **proc_args) + except Exception: + print("Could not process results for", proc_args["ident"]) + raise + + def execute(self, data_path, **kwargs): + for proc in self._processors: + proc.start() + try: + with mp.Pool() as pool: + res = map( + self.call_procs, + enumerate(self._generate_predictions(data_path, **kwargs)), + ) + for r in res: + pass + + except: + raise + finally: + for proc in self._processors: + proc.close() diff --git a/chebai/result/classification.py b/chebai/result/classification.py new file mode 100644 index 0000000..bb23dea --- /dev/null +++ b/chebai/result/classification.py @@ -0,0 +1,105 @@ +from typing import List + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +from torch import Tensor +from torchmetrics.classification import ( + MultilabelF1Score, + MultilabelPrecision, + MultilabelRecall, +) + +from chebai.callbacks.epoch_metrics import BalancedAccuracy, MacroF1 +from chebai.result.utils import * + + +def visualise_f1(logs_path: str) -> None: + """ + Visualize F1 scores from metrics.csv and save the plot as f1_plot.png. + + Args: + logs_path: The path to the directory containing metrics.csv. 
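A minimal sketch of the ResultProcessor registry defined in chebai/result/base.py above: any subclass is registered in PROCESSORS under its _identifier() via __init_subclass__ (the CountingProcessor below is a hypothetical example, not part of the package):

    from chebai.result.base import PROCESSORS, ResultProcessor


    class CountingProcessor(ResultProcessor):
        """Hypothetical processor that simply counts processed predictions."""

        @classmethod
        def _identifier(cls) -> str:
            return "count"

        def start(self):
            self.n = 0

        def process_prediction(self, proc_id, features, labels, pred, ident):
            self.n += 1


    print("count" in PROCESSORS)  # True: registered automatically on class creation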
+ """ + df = pd.read_csv(os.path.join(logs_path, "metrics.csv")) + df_loss = df.melt( + id_vars="epoch", + value_vars=[ + "val_ep_macro-f1", + "val_micro-f1", + "train_micro-f1", + "train_ep_macro-f1", + ], + ) + lineplt = sns.lineplot(df_loss, x="epoch", y="value", hue="variable") + plt.savefig(os.path.join(logs_path, "f1_plot.png")) + plt.show() + + +def print_metrics( + preds: Tensor, + labels: Tensor, + device: torch.device, + classes: Optional[List[str]] = None, + top_k: int = 10, + markdown_output: bool = False, +) -> None: + """ + Prints relevant metrics, including micro and macro F1, recall and precision, + best k classes, and worst classes. + + Args: + preds: Predicted labels as a tensor. + labels: True labels as a tensor. + device: The device to perform computations on. + classes: Optional list of class names. + top_k: The number of top classes to display based on F1 score. + markdown_output: If True, print metrics in markdown format. + """ + f1_micro = MultilabelF1Score(preds.shape[1], average="micro").to(device=device) + my_f1_macro = MacroF1(preds.shape[1]).to(device=device) + my_bal_acc = BalancedAccuracy(preds.shape[1]).to(device=device) + + print(f"Macro-F1: {my_f1_macro(preds, labels):3f}") + print(f"Micro-F1: {f1_micro(preds, labels):3f}") + print(f"Balanced Accuracy: {my_bal_acc(preds, labels):3f}") + precision_macro = MultilabelPrecision(preds.shape[1], average="macro").to( + device=device + ) + precision_micro = MultilabelPrecision(preds.shape[1], average="micro").to( + device=device + ) + macro_adjust = 1 + recall_macro = MultilabelRecall(preds.shape[1], average="macro").to(device=device) + recall_micro = MultilabelRecall(preds.shape[1], average="micro").to(device=device) + print(f"Macro-Precision: {precision_macro(preds, labels) * macro_adjust:3f}") + print(f"Micro-Precision: {precision_micro(preds, labels):3f}") + print(f"Macro-Recall: {recall_macro(preds, labels) * macro_adjust:3f}") + print(f"Micro-Recall: {recall_micro(preds, labels):3f}") + if markdown_output: + print( + f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |" + ) + print(f"| --- | --- | --- | --- | --- | --- | --- | --- |") + print( + f"| | {my_f1_macro(preds, labels):3f} | {f1_micro(preds, labels):3f} | {precision_macro(preds, labels):3f} | " + f"{precision_micro(preds, labels):3f} | {recall_macro(preds, labels):3f} | " + f"{recall_micro(preds, labels):3f} | {my_bal_acc(preds, labels):3f} |" + ) + + classwise_f1_fn = MultilabelF1Score(preds.shape[1], average=None).to(device=device) + classwise_f1 = classwise_f1_fn(preds, labels) + best_classwise_f1 = torch.topk(classwise_f1, top_k).indices + print(f"Top {top_k} classes (F1-score):") + for i, best in enumerate(best_classwise_f1): + print( + f"{i + 1}. 
{classes[best] if classes is not None else best} - F1: {classwise_f1[best]:3f}" + ) + + zeros = [] + for i, f1 in enumerate(classwise_f1): + if f1 == 0.0 and torch.sum(labels[:, i]) != 0: + zeros.append(f"{classes[i] if classes is not None else i}") + print( + f'Found {len(zeros)} classes with F1-score == 0 (and non-zero labels): {", ".join(zeros)}' + ) diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py new file mode 100644 index 0000000..355c07c --- /dev/null +++ b/chebai/result/evaluate_predictions.py @@ -0,0 +1,108 @@ +from typing import Tuple + +import numpy as np +import torch +from jsonargparse import CLI +from torchmetrics.functional.classification import multilabel_auroc + +from chebai.callbacks.epoch_metrics import MacroF1 +from chebai.result.utils import load_results_from_buffer + + +class EvaluatePredictions: + def __init__(self, eval_dir: str): + """ + Initializes the EvaluatePredictions class. + + Args: + eval_dir (str): Path to the directory containing evaluation files. + """ + self.eval_dir = eval_dir + self.metrics = [] + self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + self.num_labels = None + + @staticmethod + def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None: + """ + Validates that the number of labels matches the number of predictions, + ensuring that they have the same shape. + + Args: + label_files (torch.Tensor): Tensor containing label data. + pred_files (torch.Tensor): Tensor containing prediction data. + + Raises: + ValueError: If label and prediction tensors are mismatched in shape. + """ + if label_files is None or pred_files is None: + raise ValueError("Both label and prediction tensors must be provided.") + + # Check if the number of labels matches the number of predictions + if label_files.shape[0] != pred_files.shape[0]: + raise ValueError( + "Number of label tensors does not match the number of prediction tensors." + ) + + # Validate that the last dimension matches the expected number of classes + if label_files.shape[1] != pred_files.shape[1]: + raise ValueError( + "Label and prediction tensors must have the same shape in terms of class outputs." + ) + + def evaluate(self) -> None: + """ + Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC and Fmax. + """ + test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device) + self.validate_eval_dir(test_labels, test_preds) + self.num_labels = test_preds.shape[1] + + ml_auroc = multilabel_auroc( + test_preds, test_labels, num_labels=self.num_labels + ).item() + + print("Multilabel AUC-ROC:", ml_auroc) + + fmax, threshold = self.calculate_fmax(test_preds, test_labels) + print(f"F-max : {fmax}, threshold: {threshold}") + + def calculate_fmax( + self, test_preds: torch.Tensor, test_labels: torch.Tensor + ) -> Tuple[float, float]: + """ + Calculates the Fmax metric using the F1 score at various thresholds. + + Args: + test_preds (torch.Tensor): Predicted scores for the labels. + test_labels (torch.Tensor): True labels for the evaluation. + + Returns: + Tuple[float, float]: The maximum F1 score and the corresponding threshold. 
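The loop that follows sweeps 101 thresholds; as a compressed sketch of the same Fmax idea (it assumes MacroF1 accepts a threshold argument, as used below, and toy predictions and labels):

    import numpy as np
    import torch

    from chebai.callbacks.epoch_metrics import MacroF1

    preds = torch.tensor([[0.9, 0.2], [0.4, 0.8], [0.7, 0.6]])
    labels = torch.tensor([[1, 0], [0, 1], [1, 1]])

    best = (0.0, 0.0)  # (Fmax, threshold)
    for t in np.linspace(0, 1, 11):  # coarser grid than the 101 steps used below
        metric = MacroF1(num_labels=2, threshold=t)
        metric.update(preds, labels)
        best = max(best, (metric.compute().item(), t))
    print(best)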
+ """ + # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/metrics.py#L51-L52 + thresholds = np.linspace(0, 1, 101) + fmax = 0.0 + best_threshold = 0.0 + + for t in thresholds: + custom_f1_metric = MacroF1(num_labels=self.num_labels, threshold=t) + custom_f1_metric.update(test_preds, test_labels) + custom_f1_metric_score = custom_f1_metric.compute().item() + + # Check if the current score is the best we've seen + if custom_f1_metric_score > fmax: + fmax = custom_f1_metric_score + best_threshold = t + + return fmax, best_threshold + + +class Main: + def evaluate(self, eval_dir: str): + EvaluatePredictions(eval_dir).evaluate() + + +if __name__ == "__main__": + # evaluate_predictions.py evaluate + CLI(Main) diff --git a/chebai/result/molplot.py b/chebai/result/molplot.py new file mode 100644 index 0000000..8fdbc77 --- /dev/null +++ b/chebai/result/molplot.py @@ -0,0 +1,506 @@ +import abc +from os import makedirs +from tempfile import NamedTemporaryFile + +import networkx as nx +import numpy as np +import pandas as pd +import torch +from matplotlib import cm, colors +from matplotlib import pyplot as plt +from matplotlib import rc +from matplotlib.image import AxesImage, imread +from networkx.algorithms.isomorphism import GraphMatcher +from pysmiles.read_smiles import * +from pysmiles.read_smiles import _tokenize +from rdkit import Chem +from rdkit.Chem.Draw import MolToMPL, rdMolDraw2D + +from chebai.preprocessing.datasets import JCI_500_COLUMNS, JCI_500_COLUMNS_INT +from chebai.result.base import ResultProcessor + + +class AttentionMolPlot: + def draw_attention_molecule(self, smiles, attention): + pmol = self.read_smiles_with_index(smiles) + rdmol = Chem.MolFromSmiles(smiles) + if not rdmol: + raise NoRDMolException + rdmolx = self.mol_to_nx(rdmol) + gm = GraphMatcher(pmol, rdmolx) + iso = next(gm.isomorphisms_iter()) + token_to_node_map = { + pmol.nodes[node]["token_index"]: iso[node] for node in pmol.nodes + } + d = rdMolDraw2D.MolDraw2DCairo(500, 500) + cmap = cm.ScalarMappable(cmap=cm.Greens) + + aggr_attention_colors = cmap.to_rgba( + np.max(attention[2:, :], axis=0), norm=False + ) + cols = { + token_to_node_map[token_index]: tuple( + aggr_attention_colors[token_index].tolist() + ) + for node, token_index in nx.get_node_attributes(pmol, "token_index").items() + } + highlight_atoms = [ + token_to_node_map[token_index] + for node, token_index in nx.get_node_attributes(pmol, "token_index").items() + ] + rdMolDraw2D.PrepareAndDrawMolecule( + d, rdmol, highlightAtoms=highlight_atoms, highlightAtomColors=cols + ) + + d.FinishDrawing() + return d + + def plot_attentions(self, smiles, attention, threshold, labels): + d = self.draw_attention_molecule(smiles, attention) + cmap = cm.ScalarMappable(cmap=cm.Greens) + attention_colors = cmap.to_rgba(attention, norm=False) + num_tokens = sum(1 for _ in _tokenize(smiles)) + + fig = plt.figure(figsize=(15, 15), facecolor="w") + + rc("font", **{"family": "monospace", "monospace": "DejaVu Sans Mono"}) + fig.tight_layout() + + ax2, ax = fig.subplots(2, 1, gridspec_kw={"height_ratios": [10, 1]}) + + with NamedTemporaryFile(mode="wt", suffix=".png") as svg1: + d.WriteDrawingText(svg1.name) + ax2.imshow(imread(svg1.name)) + ax2.axis("off") + ax2.spines["left"].set_position("center") + ax2.spines["bottom"].set_position("zero") + ax2.autoscale(tight=True) + + table = plt.table( + cellText=[ + (["[CLS]"] + [t for _, t in _tokenize(smiles)]) + for _ in range(attention.shape[0]) + ], + cellColours=attention_colors, + cellLoc="center", + ) + 
table.auto_set_column_width(list(range(num_tokens))) + table.scale(1, 4) + table.set_fontsize(26) + + ax.add_table(table) + ax.axis("off") + ax.spines["top"].set_position("zero") + ax.autoscale(tight=True) + + self.counter += 1 + for w, label, predicted in labels: + if predicted: + cat = "p" + else: + cat = "n" + if predicted == label: + cat = "t" + cat + else: + cat = "f" + cat + fig.savefig( + f"/tmp/plots/{w}/{cat}_{self.counter}.png", + transparent=False, + bbox_inches="tight", + pad_inches=0, + ) + plt.close() + + @staticmethod + def mol_to_nx(mol): + G = nx.Graph() + + for atom in mol.GetAtoms(): + G.add_node( + atom.GetIdx(), + atomic_num=atom.GetAtomicNum(), + formal_charge=atom.GetFormalCharge(), + chiral_tag=atom.GetChiralTag(), + hybridization=atom.GetHybridization(), + num_explicit_hs=atom.GetNumExplicitHs(), + is_aromatic=atom.GetIsAromatic(), + ) + for bond in mol.GetBonds(): + G.add_edge( + bond.GetBeginAtomIdx(), + bond.GetEndAtomIdx(), + bond_type=bond.GetBondType(), + ) + return G + + @staticmethod + def read_smiles_with_index( + smiles, + explicit_hydrogen=False, + zero_order_bonds=True, + reinterpret_aromatic=True, + ): + """ + This is just a re-implementation of pysmiles.read_smiles, that stores token indices + """ + bond_to_order = {"-": 1, "=": 2, "#": 3, "$": 4, ":": 1.5, ".": 0} + mol = nx.Graph() + anchor = None + idx = 0 + default_bond = 1 + next_bond = None + branches = [] + ring_nums = {} + for token_index, (tokentype, token) in enumerate(_tokenize(smiles)): + if tokentype == TokenType.ATOM: + mol.add_node(idx, token_index=token_index, **parse_atom(token)) + if anchor is not None: + if next_bond is None: + next_bond = default_bond + if next_bond or zero_order_bonds: + mol.add_edge(anchor, idx, order=next_bond) + next_bond = None + anchor = idx + idx += 1 + elif tokentype == TokenType.BRANCH_START: + branches.append(anchor) + elif tokentype == TokenType.BRANCH_END: + anchor = branches.pop() + elif tokentype == TokenType.BOND_TYPE: + if next_bond is not None: + raise ValueError( + "Previous bond (order {}) not used. " + 'Overwritten by "{}"'.format(next_bond, token) + ) + next_bond = bond_to_order[token] + elif tokentype == TokenType.RING_NUM: + if token in ring_nums: + jdx, order = ring_nums[token] + if next_bond is None and order is None: + next_bond = default_bond + elif order is None: # Note that the check is needed, + next_bond = next_bond # But this could be pass. + elif next_bond is None: + next_bond = order + elif next_bond != order: # Both are not None + raise ValueError( + "Conflicting bond orders for ring " + "between indices {}".format(token) + ) + # idx is the index of the *next* atom we're adding. So: -1. + if mol.has_edge(idx - 1, jdx): + raise ValueError( + "Edge specified by marker {} already " + "exists".format(token) + ) + if idx - 1 == jdx: + raise ValueError( + "Marker {} specifies a bond between an " + "atom and itself".format(token) + ) + if next_bond or zero_order_bonds: + mol.add_edge(idx - 1, jdx, order=next_bond) + next_bond = None + del ring_nums[token] + else: + if idx == 0: + raise ValueError( + "Can't have a marker ({}) before an atom" "".format(token) + ) + # idx is the index of the *next* atom we're adding. So: -1. 
+ ring_nums[token] = (idx - 1, next_bond) + next_bond = None + elif tokentype == TokenType.EZSTEREO: + LOGGER.warning( + 'E/Z stereochemical information, which is specified by "%s", will be discarded', + token, + ) + if ring_nums: + raise KeyError("Unmatched ring indices {}".format(list(ring_nums.keys()))) + + # Time to deal with aromaticity. This is a mess, because it's not super + # clear what aromaticity information has been provided, and what should be + # inferred. In addition, to what extend do we want to provide a "sane" + # molecule, even if this overrides what the SMILES string specifies? + cycles = nx.cycle_basis(mol) + ring_idxs = set() + for cycle in cycles: + ring_idxs.update(cycle) + non_ring_idxs = set(mol.nodes) - ring_idxs + for n_idx in non_ring_idxs: + if mol.nodes[n_idx].get("aromatic", False): + raise ValueError( + "You specified an aromatic atom outside of a" + " ring. This is impossible" + ) + + mark_aromatic_edges(mol) + fill_valence(mol) + if reinterpret_aromatic: + mark_aromatic_atoms(mol) + mark_aromatic_edges(mol) + for idx, jdx in mol.edges: + if ( + not mol.nodes[idx].get("aromatic", False) + or not mol.nodes[jdx].get("aromatic", False) + ) and mol.edges[idx, jdx].get("order", 1) == 1.5: + mol.edges[idx, jdx]["order"] = 1 + + if explicit_hydrogen: + add_explicit_hydrogens(mol) + else: + remove_explicit_hydrogens(mol) + return mol + + +class AttentionOnMoleculesProcessor(AttentionMolPlot, ResultProcessor): + def __init__(self, *args, headers=None, **kwargs): + super().__init__(*args, **kwargs) + self.headers = headers + + def start(self): + self.counter = 0 + + @classmethod + def _identifier(cls): + return "platt" + + def filter(self, l): + return + + def process_prediction( + self, proc_id, preds, raw_features, model_output, labels, **kwargs + ): + atts = torch.stack(model_output["attentions"]).squeeze(1).detach().numpy() + predictions = preds.detach().numpy().squeeze(0) > 0.5 + if self.headers is None: + headers = list(range(len(labels))) + else: + headers = self.headers + + for w in headers: + makedirs(f"/tmp/plots/{w}", exist_ok=True) + + try: + self.plot_attentions( + raw_features, + np.max(np.max(atts, axis=2), axis=1), + 0.4, + [ + (ident, label, predicted) + for label, ident, predicted in zip(labels, headers, predictions) + if (label or predicted) + ], + ) + except StopIteration: + print("Could not match", raw_features) + except NoRDMolException: + pass + + +class LastLayerAttentionProcessor(AttentionMolPlot, ResultProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def start(self): + self.counter = 0 + for w in JCI_500_COLUMNS_INT: + makedirs(f"/tmp/plots/{w}", exist_ok=True) + + @classmethod + def _identifier(cls): + return "platt_last" + + def filter(self, l): + return + + def process_prediction(self, raw_features, raw_labels, features, labels, pred): + atts = torch.stack(pred["attentions"]).squeeze(1).detach().numpy() + last_layer = np.max(atts, axis=2)[-1, :] + if np.any(last_layer > 0.4): + try: + self.plot_attentions( + raw_features, + np.max(np.max(atts, axis=2), axis=1), + 0.4, + [ + ident + for present, ident in zip(labels, JCI_500_COLUMNS_INT) + if present + ], + ) + except StopIteration: + print("Could not match", raw_features) + except NoRDMolException: + pass + + +class SingletonAttentionProcessor(AttentionMolPlot, ResultProcessor): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def start(self): + self.counter = 0 + for w in JCI_500_COLUMNS_INT: + 
makedirs(f"/tmp/plots/{w}", exist_ok=True) + + @classmethod + def _identifier(cls): + return "platt_singles" + + def filter(self, l): + return + + def process_prediction(self, raw_features, raw_labels, features, labels, pred): + atts = torch.stack(pred["attentions"]).squeeze(1).detach().numpy() + if sum(labels) == 1: + try: + predictions = ( + torch.sigmoid(pred["logits"]).detach().numpy().squeeze(0) > 0.5 + ) + self.plot_attentions( + raw_features, + np.max(np.average(atts, axis=2), axis=1), + 0.4, + [ + (ident, label, predicted) + for label, ident, predicted in zip( + labels, JCI_500_COLUMNS_INT, predictions + ) + if (label or predicted) + ], + ) + except StopIteration: + print("Could not match", raw_features) + except NoRDMolException: + pass + + +class AttentionNetwork(ResultProcessor): + def __init__(self, *args, headers=None, **kwargs): + super().__init__(*args, **kwargs) + self.headers = headers + self.i = 0 + + @classmethod + def _identifier(cls): + return "platt_table" + + def start(self): + self.counter = 0 + + def process_prediction( + self, + proc_id, + preds, + raw_features, + model_output, + labels, + ident=None, + threshold=0.5, + **kwargs, + ): + if self.headers is None: + headers = list(range(len(labels))) + else: + headers = self.headers + + for w in headers: + makedirs(f"plots/{w}", exist_ok=True) + + atts = torch.stack(model_output["attentions"]).squeeze(1).detach().numpy() + predictions = preds.detach().numpy().squeeze(0) > 0.5 + plt.rcParams.update({"font.size": 8}) + try: + attentions = atts + tokens = ["[CLS]"] + [s for _, s in _tokenize(raw_features)] + cmap = cm.ScalarMappable(cmap=cm.Greens) + assert len(tokens) == attentions.shape[2] + + rows = int((attentions.shape[1] + 2)) + width = len(tokens) + height = 12 + rdmol = Chem.MolFromSmiles(raw_features) + if rdmol is not None: + fig0 = MolToMPL(rdmol, fitImage=True) + fig0.text( + 0.1, + 0, + "annotated:" + + ", ".join( + str(l) for (l, is_member) in zip(headers, labels) if is_member + ) + + "\n" + + "predicted:" + + ", ".join( + str(l) + for (l, is_member) in zip(headers, predictions) + if is_member + ), + fontdict=dict(fontsize=10), + ) + fig0.savefig( + f"plots/mol_{ident}.png", + bbox_inches="tight", + pad_inches=0, + ) + plt.close(fig0) + fig = plt.figure(figsize=(10 * 12, width // 3)) + l_tokens = {i: str(t) for i, t in enumerate(tokens)} + r_tokens = {(len(tokens) + i): str(t) for i, t in enumerate(tokens)} + labels = dict(list(l_tokens.items()) + list(r_tokens.items())) + edges = [(l, r) for r in r_tokens.keys() for l in l_tokens.keys()] + g = nx.Graph() + g.add_nodes_from(l_tokens, bipartite=0) + g.add_nodes_from(r_tokens, bipartite=1) + g.add_edges_from(edges) + pos = np.array( + [(0, -i) for i in range(len(l_tokens))] + + [(1, -i) for i in range(len(l_tokens))] + ) + + offset = np.array( + [(1, 0) for i in range(len(l_tokens))] + + [(1, 0) for i in range(len(l_tokens))] + ) + # axes = fig.subplots(1, 6 * 8 + 5, subplot_kw=dict(frameon=False)) + + ax = fig.add_subplot(111) + ax.axis("off") + for layer in range(attentions.shape[0]): + for head in range(attentions.shape[1]): + index = 8 * (layer) + head + layer + 1 + + at = np.concatenate([a for a in attentions[layer, head]]) + col = cmap.cmap(at) + col[:, 3] = at + nx.draw_networkx( + g, + pos=pos + (index * offset), + edge_color=col, + ax=ax, + labels=labels, + node_color="none", + node_size=8, + ) + # sns.heatmap(attentions[i,j], linewidth=0.5, ax=ax, cmap=cm.Greens, square=True, vmin=0, vmax=1, xticklabels=tokens, yticklabels=tokens) + 
fig.subplots_adjust() + fig.savefig( + f"plots/att_{ident}.png", + # transparent=True, + bbox_inches="tight", + pad_inches=0, + dpi=100, + ) + + plt.close() + except StopIteration: + print("Could not match", raw_features) + except NoRDMolException: + pass + finally: + plt.close() + + +class NoRDMolException(Exception): + pass diff --git a/chebai/result/prediction_json.py b/chebai/result/prediction_json.py new file mode 100644 index 0000000..924df65 --- /dev/null +++ b/chebai/result/prediction_json.py @@ -0,0 +1,26 @@ +import json + +from chebai.result.base import ResultProcessor + + +class JSONResultProcessor(ResultProcessor): + @classmethod + def _identifier(cls): + return "json" + + def start(self): + self.data = [] + + def close(self): + with open("predictions.json", "w") as fout: + json.dump(self.data, fout) + del self.data + + def process_prediction(self, proc_id, raw_features, labels, preds, ident, **kwargs): + self.data.append( + dict( + ident=ident, + labels=labels if labels is not None else None, + prediction=preds.tolist(), + ) + ) diff --git a/chebai/result/pretraining.py b/chebai/result/pretraining.py new file mode 100644 index 0000000..8d712f2 --- /dev/null +++ b/chebai/result/pretraining.py @@ -0,0 +1,65 @@ +import os + +import matplotlib.pyplot as plt +import pandas as pd +import seaborn as sns +import torch +import tqdm + +import chebai.models.electra as electra +from chebai.loss.pretraining import ElectraPreLoss +from chebai.result.base import ResultProcessor + + +def visualise_loss(logs_path): + df = pd.read_csv(os.path.join(logs_path, "metrics.csv")) + df_loss = df.melt( + id_vars="epoch", value_vars=["val_loss_epoch", "train_loss_epoch"] + ) + lineplt = sns.lineplot(df_loss, x="epoch", y="value", hue="variable") + plt.savefig(os.path.join(logs_path, "f1_plot.png")) + plt.show() + + +# get predictions from model +def evaluate_model(logs_base_path, model_filename, data_module): + model = electra.ElectraPre.load_from_checkpoint( + os.path.join( + logs_base_path, + "best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt", + model_filename, + ) + ) + assert isinstance(model, electra.ElectraPre) + collate = data_module.reader.COLLATOR() + test_file = "test.pt" + data_path = os.path.join(data_module.processed_dir, test_file) + data_list = torch.load(data_path, weights_only=False) + preds_list = [] + labels_list = [] + + for row in tqdm.tqdm(data_list): + processable_data = model._process_batch(collate([row]), 0) + model_output = model(processable_data, **processable_data["model_kwargs"]) + preds, labels = model._get_prediction_and_labels( + processable_data, processable_data["labels"], model_output + ) + preds_list.append(preds) + labels_list.append(labels) + + test_preds = torch.cat(preds_list) + test_labels = torch.cat(labels_list) + print(test_preds.shape) + print(test_labels.shape) + test_loss = ElectraPreLoss() + print(f"Loss on test set: {test_loss(test_preds, test_labels)}") + # f1_macro = MultilabelF1Score(test_preds.shape[1], average='macro') + # f1_micro = MultilabelF1Score(test_preds.shape[1], average='micro') + # print(f'Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}') + # print(f'Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}') + + +class PretrainingResultProcessor(ResultProcessor): + @classmethod + def _identifier(cls) -> str: + return "PretrainingResultProcessor" diff --git a/chebai/result/utils.py b/chebai/result/utils.py new file mode 100644 index 0000000..991960d 
--- /dev/null +++ b/chebai/result/utils.py @@ -0,0 +1,235 @@ +import os +import shutil +from typing import Optional, Tuple, Union + +import torch +import tqdm +import wandb +import wandb.util as wandb_util + +from chebai.models.base import ChebaiBaseNet +from chebai.models.electra import Electra +from chebai.preprocessing.datasets.base import XYBaseDataModule +from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor + + +def get_checkpoint_from_wandb( + epoch: int, + run: wandb.apis.public.Run, + root: str = os.path.join("logs", "downloaded_ckpts"), +): + """ + Gets a wandb checkpoint based on run and epoch, downloads it if necessary. + + Args: + epoch: The epoch number of the checkpoint to retrieve. + run: The wandb run object. + root: The root directory to save the downloaded checkpoint. + + Returns: + The location of the downloaded checkpoint. + """ + api = wandb.Api() + + files = run.files() + for file in files: + if file.name.startswith( + f"checkpoints/per_epoch={epoch}" + ) or file.name.startswith(f"checkpoints/best_epoch={epoch}"): + dest_path = os.path.join( + root, run.id, file.name.split("/")[-1].split("_")[1] + ".ckpt" + ) + # legacy: also look for ckpts in the old format + old_dest_path = os.path.join(root, run.name, file.name.split("/")[-1]) + if not os.path.isfile(dest_path): + if os.path.isfile(old_dest_path): + print(f"Copying checkpoint from {old_dest_path} to {dest_path}") + os.makedirs(os.path.dirname(dest_path), exist_ok=True) + shutil.copy2(old_dest_path, dest_path) + else: + print(f"Downloading checkpoint to {dest_path}") + wandb_util.download_file_from_url(dest_path, file.url, api.api_key) + return dest_path + print(f"No model found for epoch {epoch}") + return None + + +def _run_batch(batch, model, collate): + collated = collate(batch) + collated.x = collated.to_x(model.device) + if collated.y is not None: + collated.y = collated.to_y(model.device) + processable_data = model._process_batch(collated, 0) + del processable_data["loss_kwargs"] + model_output = model(processable_data, **processable_data["model_kwargs"]) + preds, labels = model._get_prediction_and_labels( + processable_data, processable_data["labels"], model_output + ) + return preds, labels + + +def _concat_tuple(l): + if isinstance(l[0], tuple): + print(l[0]) + return tuple([torch.cat([t[i] for t in l]) for i in range(len(l[0]))]) + return torch.cat(l) + + +def evaluate_model( + model: ChebaiBaseNet, + data_module: XYBaseDataModule, + filename: Optional[str] = None, + buffer_dir: Optional[str] = None, + batch_size: int = 32, + skip_existing_preds: bool = False, + kind: str = "test", +) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: + """ + Runs the model on the test set of the data module or on the dataset found in the specified file. + If buffer_dir is set, results will be saved in buffer_dir. + + Note: + No need to provide "filename" parameter for Chebi dataset, "kind" parameter should be provided. + + Args: + model: The model to evaluate. + data_module: The data module containing the dataset. + filename: Optional file name for the dataset. + buffer_dir: Optional directory to save the results. + batch_size: The batch size for evaluation. + skip_existing_preds: Whether to skip evaluation if predictions already exist. + kind: Kind of split of the data to be used for testing the model. Default is `test`. + + Returns: + Tensors with predictions and labels. 
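+
+    Example:
+        A minimal usage sketch; the checkpoint path and the concrete model and
+        data module classes below are placeholders, not prescribed by this
+        function:
+
+            model = Electra.load_from_checkpoint("logs/example_run/best.ckpt")
+            data_module = ChEBIOver100()
+            preds, labels = evaluate_model(
+                model,
+                data_module,
+                buffer_dir=None,  # None: return tensors directly instead of buffering to disk
+                batch_size=32,
+                kind="test",
+            )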
+ """ + model.eval() + collate = data_module.reader.COLLATOR() + + if isinstance(data_module, _ChEBIDataExtractor): + # As the dynamic split change is implemented only for chebi-dataset as of now + data_df = data_module.dynamic_split_dfs[kind] + data_list = data_df.to_dict(orient="records") + else: + data_list = data_module.load_processed_data("test", filename) + data_list = data_list[: data_module.data_limit] + preds_list = [] + labels_list = [] + if buffer_dir is not None: + os.makedirs(buffer_dir, exist_ok=True) + save_ind = 0 + save_batch_size = 128 + n_saved = 1 + + print() + for i in tqdm.tqdm(range(0, len(data_list), batch_size)): + if not ( + skip_existing_preds + and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt")) + ): + preds, labels = _run_batch(data_list[i : i + batch_size], model, collate) + preds_list.append(preds) + labels_list.append(labels) + + if buffer_dir is not None: + if n_saved * batch_size >= save_batch_size: + torch.save( + _concat_tuple(preds_list), + os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), + ) + if labels_list[0] is not None: + torch.save( + _concat_tuple(labels_list), + os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), + ) + preds_list = [] + labels_list = [] + if n_saved * batch_size >= save_batch_size: + save_ind += 1 + n_saved = 0 + n_saved += 1 + + if buffer_dir is None: + test_preds = _concat_tuple(preds_list) + if labels_list and labels_list[0] is not None: + test_labels = _concat_tuple(labels_list) + return test_preds, test_labels + return test_preds, None + elif len(preds_list) > 0: + if len(preds_list) > 0 and preds_list[0] is not None: + torch.save( + _concat_tuple(preds_list), + os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), + ) + if len(labels_list) > 0 and labels_list[0] is not None: + torch.save( + _concat_tuple(labels_list), + os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), + ) + + +def load_results_from_buffer( + buffer_dir: str, device: torch.device +) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: + """ + Load results stored in evaluate_model() from the buffer directory. + + Args: + buffer_dir: The directory containing the buffered results. + device: The device to load the results onto. + + Returns: + Tensors with predictions and labels.
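+
+    Example:
+        A minimal sketch, assuming buffer_dir was previously filled by
+        evaluate_model (the directory name is a placeholder):
+
+            preds, labels = load_results_from_buffer(
+                "results_buffer/example_run", torch.device("cpu")
+            )
+            if preds is not None:
+                print(preds.shape, None if labels is None else labels.shape)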
+ """ + preds_list = [] + labels_list = [] + + i = 0 + filename = f"preds{i:03d}.pt" + while os.path.isfile(os.path.join(buffer_dir, filename)): + preds_list.append( + torch.load( + os.path.join(buffer_dir, filename), + map_location=torch.device(device), + weights_only=False, + ) + ) + i += 1 + filename = f"preds{i:03d}.pt" + + i = 0 + filename = f"labels{i:03d}.pt" + while os.path.isfile(os.path.join(buffer_dir, filename)): + labels_list.append( + torch.load( + os.path.join(buffer_dir, filename), + map_location=torch.device(device), + weights_only=False, + ) + ) + i += 1 + filename = f"labels{i:03d}.pt" + + if len(preds_list) > 0: + test_preds = torch.cat(preds_list) + else: + test_preds = None + if len(labels_list) > 0: + test_labels = torch.cat(labels_list) + else: + test_labels = None + + return test_preds, test_labels + + +if __name__ == "__main__": + import sys + + buffer_dir = os.path.join("results_buffer", sys.argv[1], "ChEBIOver100_train") + buffer_dir_concat = os.path.join( + "results_buffer", "concatenated", sys.argv[1], "ChEBIOver100_train" + ) + os.makedirs(buffer_dir_concat, exist_ok=True) + preds, labels = load_results_from_buffer(buffer_dir, "cpu") + torch.save(preds, os.path.join(buffer_dir_concat, f"preds000.pt")) + torch.save(labels, os.path.join(buffer_dir_concat, f"labels000.pt")) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py new file mode 100644 index 0000000..874d6b3 --- /dev/null +++ b/chebai/trainer/CustomTrainer.py @@ -0,0 +1,149 @@ +import logging +from typing import Any, List, Optional, Tuple + +import pandas as pd +import torch +from lightning import LightningModule, Trainer +from lightning.fabric.utilities.types import _PATH +from lightning.pytorch.loggers import WandbLogger +from torch.nn.utils.rnn import pad_sequence + +from chebai.loggers.custom import CustomLogger +from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader + +log = logging.getLogger(__name__) + + +class CustomTrainer(Trainer): + def __init__(self, *args, **kwargs): + """ + Initializes the CustomTrainer class, logging additional hyperparameters to the custom logger if specified. + + Args: + *args: Positional arguments for the Trainer class. + **kwargs: Keyword arguments for the Trainer class. + """ + self.init_args = args + self.init_kwargs = kwargs + super().__init__(*args, **kwargs) + # instantiation custom logger connector + self._logger_connector.on_trainer_init(self.logger, 1) + # log additional hyperparameters to wandb + if isinstance(self.logger, CustomLogger): + custom_logger = self.logger + assert isinstance(custom_logger, CustomLogger) + if custom_logger.verbose_hyperparameters: + log_kwargs = {} + for key, value in self.init_kwargs.items(): + log_key, log_value = self._resolve_logging_argument(key, value) + log_kwargs[log_key] = log_value + self.logger.log_hyperparams(log_kwargs) + + def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: + """ + Resolves logging arguments, handling nested structures such as lists and complex objects. + + Args: + key: The key of the argument. + value: The value of the argument. + + Returns: + A tuple containing the resolved key and value. 
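+
+        Example:
+            Illustrative behaviour; SomeCallback is a placeholder object whose
+            instance attributes end up in the logged dictionary:
+
+                self._resolve_logging_argument("lr", 1e-3)
+                # -> ("lr", 0.001)
+                self._resolve_logging_argument("callbacks", [SomeCallback(patience=3)])
+                # -> ("callbacks", {"callbacks_0": {"class": SomeCallback, "patience": 3}})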
+ """ + if isinstance(value, list): + key_value_pairs = [ + self._resolve_logging_argument(f"{key}_{i}", v) + for i, v in enumerate(value) + ] + return key, {k: v for k, v in key_value_pairs} + if not ( + isinstance(value, str) + or isinstance(value, float) + or isinstance(value, int) + or value is None + ): + params = {"class": value.__class__} + params.update(value.__dict__) + return key, params + else: + return key, value + + def predict_from_file( + self, + model: LightningModule, + checkpoint_path: _PATH, + input_path: _PATH, + save_to: _PATH = "predictions.csv", + classes_path: Optional[_PATH] = None, + ) -> None: + """ + Loads a model from a checkpoint and makes predictions on input data from a file. + + Args: + model: The model to use for predictions. + checkpoint_path: Path to the model checkpoint. + input_path: Path to the input file containing SMILES strings. + save_to: Path to save the predictions CSV file. + classes_path: Optional path to a file containing class names. + """ + loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) + with open(input_path, "r") as input: + smiles_strings = [inp.strip() for inp in input.readlines()] + loaded_model.eval() + predictions = self._predict_smiles(loaded_model, smiles_strings) + predictions_df = pd.DataFrame(predictions.detach().cpu().numpy()) + if classes_path is not None: + with open(classes_path, "r") as f: + predictions_df.columns = [cls.strip() for cls in f.readlines()] + predictions_df.index = smiles_strings + predictions_df.to_csv(save_to) + + def _predict_smiles( + self, model: LightningModule, smiles: List[str] + ) -> torch.Tensor: + """ + Predicts the output for a list of SMILES strings using the model. + + Args: + model: The model to use for predictions. + smiles: A list of SMILES strings. + + Returns: + A tensor containing the predictions. + """ + reader = ChemDataReader() + parsed_smiles = [reader._read_data(s) for s in smiles] + x = pad_sequence( + [torch.tensor(a, device=model.device) for a in parsed_smiles], + batch_first=True, + ) + cls_tokens = ( + torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) + * CLS_TOKEN + ) + features = torch.cat((cls_tokens, x), dim=1) + model_output = model({"features": features}) + preds = torch.sigmoid(model_output["logits"]) + + print(preds.shape) + return preds + + @property + def log_dir(self) -> Optional[str]: + """ + Returns the logging directory. + + Returns: + The path to the logging directory if available, else the default root directory. 
+ """ + if len(self.loggers) > 0: + logger = self.loggers[0] + if isinstance(logger, WandbLogger): + dirpath = logger.experiment.dir + else: + dirpath = self.loggers[0].log_dir + else: + dirpath = self.default_root_dir + + dirpath = self.strategy.broadcast(dirpath) + return dirpath diff --git a/chebai/trainer/__init__.py b/chebai/trainer/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/configs/data/deepGO/deepgo2_esm2.yml b/configs/data/deepGO/deepgo2_esm2.yml new file mode 100644 index 0000000..5a0436e --- /dev/null +++ b/configs/data/deepGO/deepgo2_esm2.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1000 + use_esm2_embeddings: True diff --git a/configs/data/deepGO/deepgo_1_migrated_data.yml b/configs/data/deepGO/deepgo_1_migrated_data.yml new file mode 100644 index 0000000..0924e02 --- /dev/null +++ b/configs/data/deepGO/deepgo_1_migrated_data.yml @@ -0,0 +1,4 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO1MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1002 diff --git a/configs/data/deepGO/deepgo_2_migrated_data.yml b/configs/data/deepGO/deepgo_2_migrated_data.yml new file mode 100644 index 0000000..5a0436e --- /dev/null +++ b/configs/data/deepGO/deepgo_2_migrated_data.yml @@ -0,0 +1,5 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +init_args: + go_branch: "MF" + max_sequence_length: 1000 + use_esm2_embeddings: True diff --git a/configs/data/deepGO/go250.yml b/configs/data/deepGO/go250.yml new file mode 100644 index 0000000..01e34aa --- /dev/null +++ b/configs/data/deepGO/go250.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver250 +init_args: + go_branch: "BP" diff --git a/configs/data/deepGO/go50.yml b/configs/data/deepGO/go50.yml new file mode 100644 index 0000000..bee4377 --- /dev/null +++ b/configs/data/deepGO/go50.yml @@ -0,0 +1 @@ +class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver50 diff --git a/configs/data/scope/scope2000.yml b/configs/data/scope/scope2000.yml new file mode 100644 index 0000000..d75c807 --- /dev/null +++ b/configs/data/scope/scope2000.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver2000 +init_args: + scope_version: "2.08" diff --git a/configs/data/scope/scope50.yml b/configs/data/scope/scope50.yml new file mode 100644 index 0000000..c65028e --- /dev/null +++ b/configs/data/scope/scope50.yml @@ -0,0 +1,3 @@ +class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver50 +init_args: + scope_version: "2.08" \ No newline at end of file diff --git a/configs/default_prediction_callback.yml b/configs/default_prediction_callback.yml new file mode 100644 index 0000000..152b5d1 --- /dev/null +++ b/configs/default_prediction_callback.yml @@ -0,0 +1,4 @@ +class_path: chebai.callbacks.prediction_callback.PredictionWriter +init_args: + output_dir: pred + write_interval: epoch diff --git a/configs/loss/bce.yml b/configs/loss/bce.yml new file mode 100644 index 0000000..e2fc30b --- /dev/null +++ b/configs/loss/bce.yml @@ -0,0 +1 @@ +class_path: chebai.loss.bce_weighted.BCEWeighted diff --git a/configs/loss/electra_pre_loss.yml b/configs/loss/electra_pre_loss.yml new file mode 100644 index 0000000..06520b2 --- /dev/null +++ b/configs/loss/electra_pre_loss.yml @@ -0,0 +1 @@ +class_path: chebai.loss.pretraining.ElectraPreLoss diff --git a/configs/loss/semantic_loss.yml
b/configs/loss/semantic_loss.yml new file mode 100644 index 0000000..5434084 --- /dev/null +++ b/configs/loss/semantic_loss.yml @@ -0,0 +1,10 @@ +class_path: chebai.loss.semantic.DisjointLoss +init_args: + path_to_disjointness: data/disjoint.csv + base_loss: + class_path: chebai.loss.bce_weighted.BCEWeighted + init_args: + beta: 0.99 + multiply_by_softmax: true + impl_loss_weight: 100 + disjoint_loss_weight: 1000000 diff --git a/configs/metrics/balanced-accuracy.yml b/configs/metrics/balanced-accuracy.yml new file mode 100644 index 0000000..eb079ed --- /dev/null +++ b/configs/metrics/balanced-accuracy.yml @@ -0,0 +1,5 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + balanced-accuracy: + class_path: chebai.callbacks.epoch_metrics.BalancedAccuracy diff --git a/configs/metrics/micro-macro-f1.yml b/configs/metrics/micro-macro-f1.yml new file mode 100644 index 0000000..9cae109 --- /dev/null +++ b/configs/metrics/micro-macro-f1.yml @@ -0,0 +1,9 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + micro-f1: + class_path: torchmetrics.classification.MultilabelF1Score + init_args: + average: micro + macro-f1: + class_path: chebai.callbacks.epoch_metrics.MacroF1 diff --git a/configs/metrics/single-class-f1.yml b/configs/metrics/single-class-f1.yml new file mode 100644 index 0000000..fbcd63d --- /dev/null +++ b/configs/metrics/single-class-f1.yml @@ -0,0 +1,5 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + f1: + class_path: torchmetrics.classification.BinaryF1Score diff --git a/configs/model/electra-for-pretraining.yml b/configs/model/electra-for-pretraining.yml new file mode 100644 index 0000000..80acd9a --- /dev/null +++ b/configs/model/electra-for-pretraining.yml @@ -0,0 +1,20 @@ +class_path: chebai.models.ElectraPre +init_args: + criterion: + class_path: chebai.loss.pretraining.ElectraPreLoss + out_dim: null + optimizer_kwargs: + lr: 1e-4 + config: + generator: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 + discriminator: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 diff --git a/configs/model/electra.yml b/configs/model/electra.yml new file mode 100644 index 0000000..c3cf2fd --- /dev/null +++ b/configs/model/electra.yml @@ -0,0 +1,11 @@ +class_path: chebai.models.Electra +init_args: + optimizer_kwargs: + lr: 1e-3 + config: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 + hidden_size: 256 diff --git a/configs/model/electra_pretraining.yml b/configs/model/electra_pretraining.yml new file mode 100644 index 0000000..f480a79 --- /dev/null +++ b/configs/model/electra_pretraining.yml @@ -0,0 +1,18 @@ +class_path: chebai.models.ElectraPre +init_args: + out_dim: null + optimizer_kwargs: + lr: 1e-4 + config: + generator: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 + discriminator: + vocab_size: 1400 + max_position_embeddings: 1800 + num_attention_heads: 8 + num_hidden_layers: 6 + type_vocab_size: 1 diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml new file mode 100644 index 0000000..ba94a43 --- /dev/null +++ b/configs/model/ffn.yml @@ -0,0 +1,5 @@ +class_path: chebai.models.ffn.FFN +init_args: + optimizer_kwargs: + lr: 1e-3 + input_size: 2560 diff --git a/configs/training/csv_logger.yml b/configs/training/csv_logger.yml new file mode 100644 
index 0000000..86a94ba --- /dev/null +++ b/configs/training/csv_logger.yml @@ -0,0 +1,3 @@ +class_path: lightning.pytorch.loggers.CSVLogger +init_args: + save_dir: logs diff --git a/configs/training/default_callbacks.yml b/configs/training/default_callbacks.yml new file mode 100644 index 0000000..ade7d14 --- /dev/null +++ b/configs/training/default_callbacks.yml @@ -0,0 +1,12 @@ +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_micro-f1 + mode: 'max' + filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' + every_n_epochs: 1 + save_top_k: 3 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' + every_n_epochs: 25 + save_top_k: -1 diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml new file mode 100644 index 0000000..91aa424 --- /dev/null +++ b/configs/training/default_trainer.yml @@ -0,0 +1,5 @@ +min_epochs: 100 +max_epochs: 100 +default_root_dir: &default_root_dir logs +logger: csv_logger.yml +callbacks: default_callbacks.yml diff --git a/configs/training/early_stop_callbacks.yml b/configs/training/early_stop_callbacks.yml new file mode 100644 index 0000000..9113090 --- /dev/null +++ b/configs/training/early_stop_callbacks.yml @@ -0,0 +1,19 @@ +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_micro-f1 + mode: 'max' + filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' + every_n_epochs: 1 + save_top_k: 3 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' + every_n_epochs: 25 + save_top_k: -1 +- class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping + init_args: + monitor: "val_loss_epoch" + min_delta: 0.0 + patience: 3 + verbose: False + mode: "min" diff --git a/configs/training/pretraining_callbacks.yml b/configs/training/pretraining_callbacks.yml new file mode 100644 index 0000000..0862433 --- /dev/null +++ b/configs/training/pretraining_callbacks.yml @@ -0,0 +1,12 @@ +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_loss + mode: 'min' + filename: 'best_{epoch}_{val_loss:.4f}' + every_n_epochs: 1 + save_top_k: 3 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + filename: 'per_{epoch}_{val_loss:.4f}' + every_n_epochs: 25 + save_top_k: -1 diff --git a/configs/training/pretraining_trainer.yml b/configs/training/pretraining_trainer.yml new file mode 100644 index 0000000..6c56870 --- /dev/null +++ b/configs/training/pretraining_trainer.yml @@ -0,0 +1,7 @@ +min_epochs: 100 +max_epochs: 100 + +default_root_dir: &default_root_dir logs +logger: csv_logger.yml + +callbacks: pretraining_callbacks.yml diff --git a/configs/training/single_class_callbacks.yml b/configs/training/single_class_callbacks.yml new file mode 100644 index 0000000..73f4a72 --- /dev/null +++ b/configs/training/single_class_callbacks.yml @@ -0,0 +1,13 @@ +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + monitor: val_f1 + mode: 'max' + filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}' + every_n_epochs: 1 + save_top_k: 3 +- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint + init_args: + filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}' + 
every_n_epochs: 25 + save_top_k: -1 +# difference to default_callbacks.yml: no macro-f1 diff --git a/configs/training/wandb_logger.yml b/configs/training/wandb_logger.yml new file mode 100644 index 0000000..b0dd887 --- /dev/null +++ b/configs/training/wandb_logger.yml @@ -0,0 +1,6 @@ +class_path: chebai.loggers.custom.CustomLogger # Extension of Wandb logger +init_args: + save_dir: logs + project: 'chebai' + entity: 'chebai' + log_model: 'all' diff --git a/docs/source/experiment.rst b/docs/source/experiment.rst new file mode 100644 index 0000000..59aced7 --- /dev/null +++ b/docs/source/experiment.rst @@ -0,0 +1 @@ +.. autoclass:: chebai.experiments.Experiment diff --git a/docs/source/model.rst b/docs/source/model.rst new file mode 100644 index 0000000..59aced7 --- /dev/null +++ b/docs/source/model.rst @@ -0,0 +1 @@ +.. autoclass:: chebai.experiments.Experiment diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000..034dc5b --- /dev/null +++ b/setup.cfg @@ -0,0 +1,7 @@ +[tool:isort] +profile = black +from_first = True +line_length = 79 +known_first_party = chem +default_section = THIRDPARTY +skip = .tox,.eggs,ci/bootstrap.py,ci/templates,build,dist diff --git a/setup.py b/setup.py new file mode 100644 index 0000000..8a6d3e0 --- /dev/null +++ b/setup.py @@ -0,0 +1,57 @@ +from setuptools import find_packages, setup + +packages = find_packages() +print(packages) +setup( + name="chebai", + version="0.0.2.dev0", + packages=packages, + package_data={"": ["**/*.txt", "**/*.json"]}, + include_package_data=True, + url="", + license="", + author="MGlauer", + author_email="martin.glauer@ovgu.de", + description="", + zip_safe=False, + python_requires=">=3.9, <3.13", + install_requires=[ + "certifi", + "idna", + "joblib", + "networkx", + "numpy<2", + "pandas", + "python-dateutil", + "pytz", + "requests", + "scikit-learn", + "scipy", + "six", + "threadpoolctl", + "torch", + "typing-extensions", + "urllib3", + "transformers", + "fastobo", + "pysmiles==1.1.2", + "scikit-network", + "svgutils", + "matplotlib", + "rdkit", + "selfies", + "lightning>=2.5", + "jsonargparse[signatures]>=4.17", + "omegaconf", + "seaborn", + "deepsmiles", + "iterative-stratification", + "wandb", + "chardet", + "pyyaml", + "torchmetrics", + "biopython", + "fair-esm", + ], + extras_require={"dev": ["black", "isort", "pre-commit"]}, +) diff --git a/tests/__init__.py b/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/__init__.py b/tests/unit/__init__.py new file mode 100644 index 0000000..6640a69 --- /dev/null +++ b/tests/unit/__init__.py @@ -0,0 +1,4 @@ +""" +This directory contains unit tests, which focus on individual functions and methods, ensuring they work as +expected in isolation. +""" diff --git a/tests/unit/collators/__init__.py b/tests/unit/collators/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/collators/testDefaultCollator.py b/tests/unit/collators/testDefaultCollator.py new file mode 100644 index 0000000..73f09c7 --- /dev/null +++ b/tests/unit/collators/testDefaultCollator.py @@ -0,0 +1,65 @@ +import unittest +from typing import Dict, List + +from chebai.preprocessing.collate import DefaultCollator +from chebai.preprocessing.structures import XYData + + +class TestDefaultCollator(unittest.TestCase): + """ + Unit tests for the DefaultCollator class. + """ + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the test environment by initializing a DefaultCollator instance. 
+ """ + cls.collator = DefaultCollator() + + def test_call_with_valid_data(self) -> None: + """ + Test the __call__ method with valid data to ensure features and labels are correctly extracted. + """ + data: List[Dict] = [ + {"features": [1.0, 2.0], "labels": [True, False, True]}, + {"features": [3.0, 4.0], "labels": [False, False, True]}, + ] + + result: XYData = self.collator(data) + self.assertIsInstance( + result, XYData, "The result should be an instance of XYData." + ) + + expected_x = ([1.0, 2.0], [3.0, 4.0]) + expected_y = ([True, False, True], [False, False, True]) + + self.assertEqual( + result.x, + expected_x, + "The feature data 'x' does not match the expected output.", + ) + self.assertEqual( + result.y, + expected_y, + "The label data 'y' does not match the expected output.", + ) + + def test_call_with_empty_data(self) -> None: + """ + Test the __call__ method with an empty list to ensure it handles the edge case correctly. + """ + data: List[Dict] = [] + + with self.assertRaises(ValueError) as context: + self.collator(data) + + self.assertEqual( + str(context.exception), + "not enough values to unpack (expected 2, got 0)", + "The exception message for empty data is not as expected.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/collators/testRaggedCollator.py b/tests/unit/collators/testRaggedCollator.py new file mode 100644 index 0000000..d9ab2b1 --- /dev/null +++ b/tests/unit/collators/testRaggedCollator.py @@ -0,0 +1,204 @@ +import unittest +from typing import Dict, List, Tuple + +import torch + +from chebai.preprocessing.collate import RaggedCollator +from chebai.preprocessing.structures import XYData + + +class TestRaggedCollator(unittest.TestCase): + """ + Unit tests for the RaggedCollator class. + """ + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the test environment by initializing a RaggedCollator instance. + """ + cls.collator = RaggedCollator() + + def test_call_with_valid_data(self) -> None: + """ + Test the __call__ method with valid ragged data to ensure features, labels, and masks are correctly handled. 
+ """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [True, False], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": [False, True, True], "ident": "sample2"}, + {"features": [6], "labels": [True], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) + expected_y = torch.tensor( + [[True, False, False], [False, True, True], [True, False, False]] + ) + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) + + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ), + "The mask tensor does not match the expected output.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ), + "The lens tensor does not match the expected output.", + ) + self.assertEqual( + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output.", + ) + + def test_call_with_missing_entire_labels(self) -> None: + """ + Test the __call__ method with data where some samples are missing labels. + """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [True, False], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": None, "ident": "sample2"}, + {"features": [6], "labels": [True], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + # https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) + expected_y = torch.tensor( + [[True, False], [True, False]] + ) # True -> 1, False -> 0 + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) + + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output when labels are missing.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output when labels are missing.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ), + "The mask tensor does not match the expected output when labels are missing.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ), + "The lens tensor does not match the expected output when labels are missing.", + ) + self.assertEqual( + result.additional_fields["loss_kwargs"]["non_null_labels"], + [0, 2], + "The non-null labels list does not match the expected output.", + ) + self.assertEqual( + len(result.additional_fields["loss_kwargs"]["non_null_labels"]), + result.y.shape[1], + "The length of non null labels list must match with target label variable size", + ) + self.assertEqual( + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output when labels are missing.", + ) + + def test_call_with_none_in_labels(self) -> None: + """ + Test the __call__ method with data where one of the elements in the labels is None. 
+ """ + data: List[Dict] = [ + {"features": [1, 2], "labels": [None, True], "ident": "sample1"}, + {"features": [3, 4, 5], "labels": [True, False], "ident": "sample2"}, + {"features": [6], "labels": [True], "ident": "sample3"}, + ] + + result: XYData = self.collator(data) + + expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) + expected_y = torch.tensor( + [[False, True], [True, False], [True, False]] + ) # None -> False + expected_mask_for_x = torch.tensor( + [[True, True, False], [True, True, True], [True, False, False]] + ) + expected_lens_for_x = torch.tensor([2, 3, 1]) + + self.assertTrue( + torch.equal(result.x, expected_x), + "The feature tensor 'x' does not match the expected output when labels contain None.", + ) + self.assertTrue( + torch.equal(result.y, expected_y), + "The label tensor 'y' does not match the expected output when labels contain None.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x + ), + "The mask tensor does not match the expected output when labels contain None.", + ) + self.assertTrue( + torch.equal( + result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x + ), + "The lens tensor does not match the expected output when labels contain None.", + ) + self.assertEqual( + result.additional_fields["idents"], + ("sample1", "sample2", "sample3"), + "The identifiers do not match the expected output when labels contain None.", + ) + + def test_call_with_empty_data(self) -> None: + """ + Test the __call__ method with an empty list to ensure it raises an error. + """ + data: List[Dict] = [] + + with self.assertRaises( + Exception, msg="Expected an Error when no data is provided" + ): + self.collator(data) + + def test_process_label_rows(self) -> None: + """ + Test the process_label_rows method to ensure it pads label sequences correctly. + """ + labels: Tuple = ([True, False], [False, True, True], [True]) + + result: torch.Tensor = self.collator.process_label_rows(labels) + + expected_output = torch.tensor( + [[True, False, False], [False, True, True], [True, False, False]] + ) + + self.assertTrue( + torch.equal(result, expected_output), + "The processed label rows tensor does not match the expected output.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/__init__.py b/tests/unit/dataset_classes/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py new file mode 100644 index 0000000..c884627 --- /dev/null +++ b/tests/unit/dataset_classes/testDynamicDataset.py @@ -0,0 +1,372 @@ +import unittest +from typing import Tuple +from unittest.mock import MagicMock, PropertyMock, patch + +import pandas as pd + +from chebai.preprocessing.datasets.base import _DynamicDataset + + +class TestDynamicDataset(unittest.TestCase): + """ + Test case for _DynamicDataset functionality, ensuring correct data splits and integrity + of train, validation, and test datasets. 
+ """ + + @classmethod + @patch.multiple(_DynamicDataset, __abstractmethods__=frozenset()) + @patch.object(_DynamicDataset, "base_dir", new_callable=PropertyMock) + @patch.object(_DynamicDataset, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass( + cls, + mock_makedirs, + mock_base_dir_property: PropertyMock, + mock_name_property: PropertyMock, + ) -> None: + """ + Set up a base instance of _DynamicDataset for testing with mocked properties. + """ + + # Mocking properties + mock_base_dir_property.return_value = "MockedBaseDirPropertyDynamicDataset" + mock_name_property.return_value = "MockedNamePropertyDynamicDataset" + + # Mock Data Reader + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + _DynamicDataset.READER = ReaderMock + + # Creating an instance of the dataset + cls.dataset: _DynamicDataset = _DynamicDataset() + + # Dataset with a balanced distribution of labels + X = [ + [1, 2], + [3, 4], + [5, 6], + [7, 8], + [9, 10], + [11, 12], + [13, 14], + [15, 16], + [17, 18], + [19, 20], + [21, 22], + [23, 24], + [25, 26], + [27, 28], + [29, 30], + [31, 32], + ] + y = [ + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + [False, False], + [False, True], + [True, False], + [True, True], + ] + cls.data_df = pd.DataFrame( + {"ident": [f"id{i + 1}" for i in range(len(X))], "features": X, "labels": y} + ) + + def test_get_test_split_valid(self) -> None: + """ + Test splitting the dataset into train and test sets and verify balance and non-overlap. + """ + self.dataset.train_split = 0.5 + # Test size will be 0.25 * 16 = 4 + train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) + + # Assert the correct number of rows in train and test sets + self.assertEqual(len(train_df), 12, "Train set should contain 12 samples.") + self.assertEqual(len(test_df), 4, "Test set should contain 4 samples.") + + # Check positive and negative label counts in train and test sets + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( + test_df + ) + + # Ensure that the train and test sets have balanced positives and negatives + self.assertEqual( + train_pos_count, train_neg_count, "Train set labels should be balanced." + ) + self.assertEqual( + test_pos_count, test_neg_count, "Test set labels should be balanced." + ) + + # Assert there is no overlap between train and test sets + train_idents = set(train_df["ident"]) + test_idents = set(test_df["ident"]) + self.assertEqual( + len(train_idents.intersection(test_idents)), + 0, + "Train and test sets should not overlap.", + ) + + def test_get_test_split_missing_labels(self) -> None: + """ + Test the behavior when the 'labels' column is missing in the dataset. + """ + df_missing_labels = pd.DataFrame({"ident": ["id1", "id2"]}) + with self.assertRaises( + KeyError, msg="Expected KeyError when 'labels' column is missing." + ): + self.dataset.get_test_split(df_missing_labels) + + def test_get_test_split_seed_consistency(self) -> None: + """ + Test that splitting the dataset with the same seed produces consistent results. 
+ """ + train_df1, test_df1 = self.dataset.get_test_split(self.data_df, seed=42) + train_df2, test_df2 = self.dataset.get_test_split(self.data_df, seed=42) + + pd.testing.assert_frame_equal( + train_df1, + train_df2, + obj="Train sets should be identical for the same seed.", + ) + pd.testing.assert_frame_equal( + test_df1, test_df2, obj="Test sets should be identical for the same seed." + ) + + def test_get_train_val_splits_given_test(self) -> None: + """ + Test splitting the dataset into train and validation sets and verify balance and non-overlap. + """ + self.dataset.use_inner_cross_validation = False + self.dataset.train_split = 0.5 + df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) + train_df, val_df = self.dataset.get_train_val_splits_given_test( + df_train_main, test_df, seed=42 + ) + + # Ensure there is no overlap between train and test sets + train_idents = set(train_df["ident"]) + test_idents = set(test_df["ident"]) + self.assertEqual( + len(train_idents.intersection(test_idents)), + 0, + "Train and test sets should not overlap.", + ) + + # Ensure there is no overlap between validation and test sets + val_idents = set(val_df["ident"]) + self.assertEqual( + len(val_idents.intersection(test_idents)), + 0, + "Validation and test sets should not overlap.", + ) + + # Ensure there is no overlap between train and validation sets + self.assertEqual( + len(train_idents.intersection(val_idents)), + 0, + "Train and validation sets should not overlap.", + ) + + # Check positive and negative label counts in train and validation sets + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) + + # Ensure that the train and validation sets have balanced positives and negatives + self.assertEqual( + train_pos_count, train_neg_count, "Train set labels should be balanced." + ) + self.assertEqual( + val_pos_count, val_neg_count, "Validation set labels should be balanced." + ) + + def test_get_train_val_splits_given_test_consistency(self) -> None: + """ + Test that splitting the dataset into train and validation sets with the same seed produces consistent results. + """ + test_df = self.data_df.iloc[12:] # Assume rows 12 onward are for testing + train_df1, val_df1 = self.dataset.get_train_val_splits_given_test( + self.data_df, test_df, seed=42 + ) + train_df2, val_df2 = self.dataset.get_train_val_splits_given_test( + self.data_df, test_df, seed=42 + ) + + pd.testing.assert_frame_equal( + train_df1, + train_df2, + obj="Train sets should be identical for the same seed.", + ) + pd.testing.assert_frame_equal( + val_df1, + val_df2, + obj="Validation sets should be identical for the same seed.", + ) + + def test_get_test_split_stratification(self) -> None: + """ + Test that the split into train and test sets maintains the stratification of labels. 
+ """ + self.dataset.train_split = 0.5 + train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) + + number_of_labels = len(self.data_df["labels"][0]) + + # Check the label distribution in the original dataset + original_pos_count, original_neg_count = ( + self.get_positive_negative_labels_counts(self.data_df) + ) + total_count = len(self.data_df) * number_of_labels + + # Calculate the expected proportions + original_pos_proportion = original_pos_count / total_count + original_neg_proportion = original_neg_count / total_count + + # Check the label distribution in the train set + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + train_total_count = len(train_df) * number_of_labels + + # Calculate the train set proportions + train_pos_proportion = train_pos_count / train_total_count + train_neg_proportion = train_neg_count / train_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + train_pos_proportion, + original_pos_proportion, + places=1, + msg="Train set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + train_neg_proportion, + original_neg_proportion, + places=1, + msg="Train set labels should maintain original negative label proportion.", + ) + + # Check the label distribution in the test set + test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( + test_df + ) + test_total_count = len(test_df) * number_of_labels + + # Calculate the test set proportions + test_pos_proportion = test_pos_count / test_total_count + test_neg_proportion = test_neg_count / test_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + test_pos_proportion, + original_pos_proportion, + places=1, + msg="Test set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + test_neg_proportion, + original_neg_proportion, + places=1, + msg="Test set labels should maintain original negative label proportion.", + ) + + def test_get_train_val_splits_given_test_stratification(self) -> None: + """ + Test that the split into train and validation sets maintains the stratification of labels. 
+ """ + self.dataset.use_inner_cross_validation = False + self.dataset.train_split = 0.5 + df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) + train_df, val_df = self.dataset.get_train_val_splits_given_test( + df_train_main, test_df, seed=42 + ) + + number_of_labels = len(self.data_df["labels"][0]) + + # Check the label distribution in the original dataset + original_pos_count, original_neg_count = ( + self.get_positive_negative_labels_counts(self.data_df) + ) + total_count = len(self.data_df) * number_of_labels + + # Calculate the expected proportions + original_pos_proportion = original_pos_count / total_count + original_neg_proportion = original_neg_count / total_count + + # Check the label distribution in the train set + train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( + train_df + ) + train_total_count = len(train_df) * number_of_labels + + # Calculate the train set proportions + train_pos_proportion = train_pos_count / train_total_count + train_neg_proportion = train_neg_count / train_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + train_pos_proportion, + original_pos_proportion, + places=1, + msg="Train set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + train_neg_proportion, + original_neg_proportion, + places=1, + msg="Train set labels should maintain original negative label proportion.", + ) + + # Check the label distribution in the validation set + val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) + val_total_count = len(val_df) * number_of_labels + + # Calculate the validation set proportions + val_pos_proportion = val_pos_count / val_total_count + val_neg_proportion = val_neg_count / val_total_count + + # Assert that the proportions are similar to the original dataset + self.assertAlmostEqual( + val_pos_proportion, + original_pos_proportion, + places=1, + msg="Validation set labels should maintain original positive label proportion.", + ) + self.assertAlmostEqual( + val_neg_proportion, + original_neg_proportion, + places=1, + msg="Validation set labels should maintain original negative label proportion.", + ) + + @staticmethod + def get_positive_negative_labels_counts(df: pd.DataFrame) -> Tuple[int, int]: + """ + Count the number of True and False values within the labels column. + + Args: + df (pd.DataFrame): The DataFrame containing the 'labels' column. + + Returns: + Tuple[int, int]: A tuple containing the counts of True and False values, respectively. 
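+
+        Example:
+            For a DataFrame where df["labels"] is [[True, False], [True, True]],
+            the method returns (3, 1): three True values and one False value.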
+ """ + true_count = sum(sum(label) for label in df["labels"]) + false_count = sum(len(label) - sum(label) for label in df["labels"]) + return true_count, false_count + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py new file mode 100644 index 0000000..96ff9a3 --- /dev/null +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -0,0 +1,229 @@ +import unittest +from collections import OrderedDict +from unittest.mock import PropertyMock, mock_open, patch + +import fastobo +import networkx as nx +import pandas as pd + +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor +from chebai.preprocessing.reader import ProteinDataReader +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + + +class TestGOUniProtDataExtractor(unittest.TestCase): + """ + Unit tests for the _GOUniProtDataExtractor class. + """ + + @classmethod + @patch.multiple(_GOUniProtDataExtractor, __abstractmethods__=frozenset()) + @patch.object(_GOUniProtDataExtractor, "base_dir", new_callable=PropertyMock) + @patch.object(_GOUniProtDataExtractor, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass( + cls, + mock_makedirs, + mock_name_property: PropertyMock, + mock_base_dir_property: PropertyMock, + ) -> None: + """ + Class setup for mocking abstract properties of _GOUniProtDataExtractor. + """ + mock_base_dir_property.return_value = "MockedBaseDirPropGOUniProtDataExtractor" + mock_name_property.return_value = "MockedNamePropGOUniProtDataExtractor" + + _GOUniProtDataExtractor.READER = ProteinDataReader + + cls.extractor = _GOUniProtDataExtractor() + + def test_term_callback(self) -> None: + """ + Test the term_callback method for correct parsing and filtering of GO terms. + """ + self.extractor.go_branch = "all" + term_mapping = {} + for term in fastobo.loads(GOUniProtMockData.get_GO_raw_data()): + if isinstance(term, fastobo.typedef.TypedefFrame): + continue + term_mapping[self.extractor._parse_go_id(term.id)] = term + + # Test individual term callback + term_dict = self.extractor.term_callback(term_mapping[4]) + expected_dict = {"go_id": 4, "parents": [3, 2], "name": "GO_4"} + self.assertEqual( + term_dict, + expected_dict, + "The term_callback did not return the expected dictionary.", + ) + + # Test filtering valid terms + valid_terms_docs = set() + for term_id, term_doc in term_mapping.items(): + if self.extractor.term_callback(term_doc): + valid_terms_docs.add(term_id) + + self.assertEqual( + valid_terms_docs, + set(GOUniProtMockData.get_nodes()), + "The valid terms do not match expected nodes.", + ) + + # Test that obsolete terms are filtered out + self.assertFalse( + any( + self.extractor.term_callback(term_mapping[obs_id]) + for obs_id in GOUniProtMockData.get_obsolete_nodes_ids() + ), + "Obsolete terms should not be present.", + ) + + # Test filtering by GO branch (e.g., BP) + self.extractor.go_branch = "BP" + BP_terms = { + term_id + for term_id, term in term_mapping.items() + if self.extractor.term_callback(term) + } + self.assertEqual( + BP_terms, {2, 4}, "The BP terms do not match the expected set." + ) + + @patch( + "fastobo.load", return_value=fastobo.loads(GOUniProtMockData.get_GO_raw_data()) + ) + def test_extract_class_hierarchy(self, mock_load) -> None: + """ + Test the extraction of the class hierarchy from the ontology. 
+ """ + graph = self.extractor._extract_class_hierarchy("fake_path") + + # Validate the graph structure + self.assertIsInstance( + graph, nx.DiGraph, "The result should be a directed graph." + ) + + # Check nodes + actual_nodes = set(graph.nodes) + self.assertEqual( + set(GOUniProtMockData.get_nodes()), + actual_nodes, + "The graph nodes do not match the expected nodes.", + ) + + # Check edges + actual_edges = set(graph.edges) + self.assertEqual( + GOUniProtMockData.get_edges_of_transitive_closure_graph(), + actual_edges, + "The graph edges do not match the expected edges.", + ) + + # Check number of nodes and edges + self.assertEqual( + GOUniProtMockData.get_number_of_nodes(), + len(actual_nodes), + "The number of nodes should match the actual number of nodes in the graph.", + ) + + self.assertEqual( + GOUniProtMockData.get_number_of_transitive_edges(), + len(actual_edges), + "The number of transitive edges should match the actual number of transitive edges in the graph.", + ) + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + def test_get_swiss_to_go_mapping(self, mock_open) -> None: + """ + Test the extraction of SwissProt to GO term mapping. + """ + mapping_df = self.extractor._get_swiss_to_go_mapping() + expected_df = pd.DataFrame( + OrderedDict( + swiss_id=["Swiss_Prot_1", "Swiss_Prot_2"], + accession=["Q6GZX4", "DCGZX4"], + go_ids=[[2, 3, 5], [2, 5]], + sequence=list(GOUniProtMockData.protein_sequences().values()), + ) + ) + + pd.testing.assert_frame_equal( + mapping_df, + expected_df, + obj="The SwissProt to GO mapping DataFrame does not match the expected DataFrame.", + ) + + @patch( + "fastobo.load", return_value=fastobo.loads(GOUniProtMockData.get_GO_raw_data()) + ) + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + @patch.object( + _GOUniProtDataExtractor, + "select_classes", + return_value=GOUniProtMockData.get_nodes(), + ) + def test_graph_to_raw_dataset( + self, mock_select_classes, mock_open, mock_load + ) -> None: + """ + Test the conversion of the class hierarchy graph to a raw dataset. + """ + graph = self.extractor._extract_class_hierarchy("fake_path") + actual_df = self.extractor._graph_to_raw_dataset(graph) + expected_df = GOUniProtMockData.get_data_in_dataframe() + + pd.testing.assert_frame_equal( + actual_df, + expected_df, + obj="The raw dataset DataFrame does not match the expected DataFrame.", + ) + + @patch("builtins.open", new_callable=mock_open, read_data=b"Mocktestdata") + @patch("pandas.read_pickle") + def test_load_dict( + self, mock_read_pickle: PropertyMock, mock_open: mock_open + ) -> None: + """ + Test the loading of the dictionary from a DataFrame. 
+ """ + mock_df = GOUniProtMockData.get_data_in_dataframe() + mock_read_pickle.return_value = mock_df + + generator = self.extractor._load_dict("data/tests") + result = list(generator) + + # Convert NumPy arrays to lists for comparison + for item in result: + item["labels"] = list(item["labels"]) + + # Expected output for comparison + expected_result = [ + { + "features": mock_df["sequence"][0], + "labels": mock_df.iloc[0, 4:].to_list(), + "ident": mock_df["swiss_id"][0], + }, + { + "features": mock_df["sequence"][1], + "labels": mock_df.iloc[1, 4:].to_list(), + "ident": mock_df["swiss_id"][1], + }, + ] + + self.assertEqual( + result, + expected_result, + "The loaded dictionary does not match the expected structure.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py new file mode 100644 index 0000000..3f329c5 --- /dev/null +++ b/tests/unit/dataset_classes/testGoUniProtOverX.py @@ -0,0 +1,140 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +import networkx as nx +import pandas as pd + +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtOverX +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + + +class TestGOUniProtOverX(unittest.TestCase): + @classmethod + @patch.multiple(_GOUniProtOverX, __abstractmethods__=frozenset()) + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs) -> None: + """ + Set up the class for tests by initializing the extractor, graph, and input DataFrame. + """ + cls.extractor = _GOUniProtOverX() + cls.test_graph: nx.DiGraph = GOUniProtMockData.get_transitively_closed_graph() + cls.input_df: pd.DataFrame = GOUniProtMockData.get_data_in_dataframe().iloc[ + :, :4 + ] + + @patch("builtins.open", new_callable=mock_open) + def test_select_classes(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method to ensure it selects classes based on the threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + # Set threshold for testing + self.extractor.THRESHOLD = 2 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: GO terms 1, 2, and 5 should be selected based on the threshold + expected_selected_classes: List[int] = sorted([1, 2, 5]) + + # Check if the selected classes are as expected + self.assertEqual( + selected_classes, + expected_selected_classes, + msg="The selected classes do not match the expected output for threshold 2.", + ) + + # Expected data as string + expected_lines: str = "\n".join(map(str, expected_selected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines for the given threshold of 2.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_no_classes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method when no nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. 
+ """ + self.extractor.THRESHOLD = 5 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: No classes should meet the threshold of 5 + expected_selected_classes: List[int] = [] + + # Check if the selected classes are as expected + self.assertEqual( + selected_classes, + expected_selected_classes, + msg="The selected classes list should be empty when no nodes meet the threshold of 5.", + ) + + # Expected data as string + expected_lines: str = "" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines when no nodes meet the threshold of 5.", + ) + + @patch("builtins.open", new_callable=mock_open) + def test_all_nodes_meet_threshold(self, mock_open_file: mock_open) -> None: + """ + Test the `select_classes` method when all nodes meet the successor threshold. + + Args: + mock_open_file (mock_open): Mocked open function to intercept file operations. + """ + self.extractor.THRESHOLD = 0 + selected_classes: List[int] = self.extractor.select_classes( + self.test_graph, data_df=self.input_df + ) + + # Expected result: All nodes except those not referenced by any protein (4 and 6) should be selected + expected_classes: List[int] = sorted([1, 2, 3, 5]) + + # Check if the returned selected classes match the expected list + self.assertListEqual( + selected_classes, + expected_classes, + msg="The selected classes do not match the expected output when all nodes meet the threshold of 0.", + ) + + # Expected data as string + expected_lines: str = "\n".join(map(str, expected_classes)) + "\n" + + # Extract the generator passed to writelines + written_generator = mock_open_file().writelines.call_args[0][0] + written_lines: str = "".join(written_generator) + + # Ensure the data matches + self.assertEqual( + written_lines, + expected_lines, + msg="The written lines do not match the expected lines when all nodes meet the threshold of 0.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py new file mode 100644 index 0000000..caac3ea --- /dev/null +++ b/tests/unit/dataset_classes/testProteinPretrainingData.py @@ -0,0 +1,76 @@ +import unittest +from unittest.mock import PropertyMock, mock_open, patch + +from chebai.preprocessing.datasets.deepGO.protein_pretraining import ( + _ProteinPretrainingData, +) +from chebai.preprocessing.reader import ProteinDataReader +from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData + + +class TestProteinPretrainingData(unittest.TestCase): + """ + Unit tests for the _ProteinPretrainingData class. + Tests focus on data parsing and validation checks for protein pretraining. + """ + + @classmethod + @patch.multiple(_ProteinPretrainingData, __abstractmethods__=frozenset()) + @patch.object(_ProteinPretrainingData, "base_dir", new_callable=PropertyMock) + @patch.object(_ProteinPretrainingData, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass( + cls, + mock_makedirs, + mock_name_property: PropertyMock, + mock_base_dir_property: PropertyMock, + ) -> None: + """ + Class setup for mocking abstract properties of _ProteinPretrainingData. 
+ + Mocks the required abstract properties and sets up the data extractor. + """ + mock_base_dir_property.return_value = "MockedBaseDirPropProteinPretrainingData" + mock_name_property.return_value = "MockedNameProp_ProteinPretrainingData" + + # Set the READER class for the pretraining data + _ProteinPretrainingData.READER = ProteinDataReader + + # Initialize the extractor instance + cls.extractor = _ProteinPretrainingData() + + @patch( + "builtins.open", + new_callable=mock_open, + read_data=GOUniProtMockData.get_UniProt_raw_data(), + ) + def test_parse_protein_data_for_pretraining( + self, mock_open_file: mock_open + ) -> None: + """ + Tests the _parse_protein_data_for_pretraining method. + + Verifies that: + - The parsed DataFrame contains the expected protein IDs. + - The protein sequences are not empty. + """ + # Parse the pretraining data + pretrain_df = self.extractor._parse_protein_data_for_pretraining() + list_of_pretrain_swiss_ids = GOUniProtMockData.proteins_for_pretraining() + + # Assert that all expected Swiss-Prot IDs are present in the DataFrame + self.assertEqual( + set(pretrain_df["swiss_id"]), + set(list_of_pretrain_swiss_ids), + msg="The parsed DataFrame does not contain the expected Swiss-Prot IDs for pretraining.", + ) + + # Assert that all sequences are not empty + self.assertTrue( + pretrain_df["sequence"].str.len().gt(0).all(), + msg="Some protein sequences in the pretraining DataFrame are empty.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/dataset_classes/testXYBaseDataModule.py b/tests/unit/dataset_classes/testXYBaseDataModule.py new file mode 100644 index 0000000..64dfbe4 --- /dev/null +++ b/tests/unit/dataset_classes/testXYBaseDataModule.py @@ -0,0 +1,92 @@ +import unittest +from unittest.mock import MagicMock, PropertyMock, patch + +from chebai.preprocessing.datasets.base import XYBaseDataModule + + +class TestXYBaseDataModule(unittest.TestCase): + """ + Unit tests for the methods of the XYBaseDataModule class. + """ + + @classmethod + @patch.object(XYBaseDataModule, "_name", new_callable=PropertyMock) + @patch("os.makedirs", return_value=None) + def setUpClass(cls, mock_makedirs, mock_name_property: PropertyMock) -> None: + """ + Set up a base instance of XYBaseDataModule for testing. + """ + + # Mock the _name property of XYBaseDataModule + mock_name_property.return_value = "MockedNamePropXYBaseDataModule" + + # Assign a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) + # Mock Data Reader + ReaderMock = MagicMock() + ReaderMock.name.return_value = "MockedReader" + XYBaseDataModule.READER = ReaderMock + + # Initialize the module with a label_filter + cls.module = XYBaseDataModule( + label_filter=1, # Provide a label_filter + balance_after_filter=1.0, # Balance ratio + ) + + def test_filter_labels_valid_index(self) -> None: + """ + Test the _filter_labels method with a valid label_filter index. 
+ """ + self.module.label_filter = 1 + row = { + "features": ["feature1", "feature2"], + "labels": [0, 3, 1, 2], # List of labels + } + filtered_row = self.module._filter_labels(row) + expected_labels = [3] # Only the label at index 1 should be kept + + self.assertEqual( + filtered_row["labels"], + expected_labels, + "The filtered labels do not match the expected labels.", + ) + + row = { + "features": ["feature1", "feature2"], + "labels": [True, False, True, True], + } + self.assertEqual( + self.module._filter_labels(row)["labels"], + [False], + "The filtered labels for the boolean case do not match the expected labels.", + ) + + def test_filter_labels_no_filter(self) -> None: + """ + Test the _filter_labels method with no label_filter index. + """ + # Update the module to have no label filter + self.module.label_filter = None + row = {"features": ["feature1", "feature2"], "labels": [False, True]} + # Handle the case where the index is out of bounds + with self.assertRaises( + TypeError, msg="Expected a TypeError when no label filter is provided." + ): + self.module._filter_labels(row) + + def test_filter_labels_invalid_index(self) -> None: + """ + Test the _filter_labels method with an invalid label_filter index. + """ + # Set an invalid label filter index (e.g., greater than the number of labels) + self.module.label_filter = 10 + row = {"features": ["feature1", "feature2"], "labels": [False, True]} + # Handle the case where the index is out of bounds + with self.assertRaises( + IndexError, + msg="Expected an IndexError when the label filter index is out of bounds.", + ): + self.module._filter_labels(row) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/mock_data/__init__.py b/tests/unit/mock_data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py new file mode 100644 index 0000000..87d24bf --- /dev/null +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -0,0 +1,521 @@ +from abc import ABC, abstractmethod +from collections import OrderedDict +from typing import Dict, List, Set, Tuple + +import networkx as nx +import pandas as pd + + +class MockOntologyGraphData(ABC): + """ + Abstract base class for mocking ontology graph data. + + This class provides a set of static methods that must be implemented by subclasses + to return various elements of an ontology graph such as nodes, edges, and dataframes. + """ + + @staticmethod + @abstractmethod + def get_nodes() -> List[int]: + """ + Get a list of node IDs in the ontology graph. + + Returns: + List[int]: A list of node IDs. + """ + pass + + @staticmethod + @abstractmethod + def get_number_of_nodes() -> int: + """ + Get the number of nodes in the ontology graph. + + Returns: + int: The total number of nodes. + """ + pass + + @staticmethod + @abstractmethod + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples where each tuple represents an edge between two nodes. + """ + pass + + @staticmethod + @abstractmethod + def get_number_of_edges() -> int: + """ + Get the number of edges in the ontology graph. + + Returns: + int: The total number of edges. + """ + pass + + @staticmethod + @abstractmethod + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the transitive closure of the ontology graph. 
+ + Returns: + Set[Tuple[int, int]]: A set of tuples representing the transitive closure edges. + """ + pass + + @staticmethod + @abstractmethod + def get_number_of_transitive_edges() -> int: + """ + Get the number of edges in the transitive closure of the ontology graph. + + Returns: + int: The total number of transitive edges. + """ + pass + + @staticmethod + @abstractmethod + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the ontology graph. + + Returns: + Set[int]: A set of obsolete node IDs. + """ + pass + + @staticmethod + @abstractmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Get the transitive closure of the ontology graph. + + Returns: + nx.DiGraph: A directed graph representing the transitive closure of the ontology graph. + """ + pass + + @staticmethod + @abstractmethod + def get_data_in_dataframe() -> pd.DataFrame: + """ + Get the ontology data as a Pandas DataFrame. + + Returns: + pd.DataFrame: A DataFrame containing ontology data. + """ + pass + + +class GOUniProtMockData(MockOntologyGraphData): + """ + A mock ontology representing a simplified version of the Gene Ontology (GO) structure with nodes and edges + representing GO terms and their relationships in a directed acyclic graph (DAG). + + Nodes: + - GO_1 + - GO_2 + - GO_3 + - GO_4 + - GO_5 + - GO_6 + + Edges (Parent-Child Relationships): + - GO_1 -> GO_2 + - GO_1 -> GO_3 + - GO_2 -> GO_4 + - GO_2 -> GO_5 + - GO_3 -> GO_4 + - GO_4 -> GO_6 + + This mock ontology structure is useful for testing methods related to GO hierarchy, graph extraction, and transitive + closure operations. + + The class also includes methods to retrieve nodes, edges, and transitive closure of the graph. + + Visual Representation Graph with Valid Nodes and Edges: + + GO_1 + / \ + GO_2 GO_3 + / \ / + GO_5 GO_4 + \ + GO_6 + + Valid Swiss Proteins with mapping to valid GO ids + Swiss_Prot_1 -> GO_2, GO_3, GO_5 + Swiss_Prot_2 -> GO_2, GO_5 + """ + + @staticmethod + def get_nodes() -> List[int]: + """ + Get a sorted list of node IDs. + + Returns: + List[int]: A sorted list of node IDs in the ontology graph. + """ + return sorted([1, 2, 3, 4, 5, 6]) + + @staticmethod + def get_number_of_nodes() -> int: + """ + Get the total number of nodes in the ontology graph. + + Returns: + int: The number of nodes. + """ + return len(GOUniProtMockData.get_nodes()) + + @staticmethod + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples where each tuple represents an edge between two nodes. + """ + return {(1, 2), (1, 3), (2, 4), (2, 5), (3, 4), (4, 6)} + + @staticmethod + def get_number_of_edges() -> int: + """ + Get the total number of edges in the ontology graph. + + Returns: + int: The number of edges. + """ + return len(GOUniProtMockData.get_edges()) + + @staticmethod + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges in the transitive closure of the ontology graph. + + Returns: + Set[Tuple[int, int]]: A set of tuples representing edges in the transitive closure graph. + """ + return { + (1, 2), + (1, 3), + (1, 4), + (1, 5), + (1, 6), + (2, 4), + (2, 5), + (2, 6), + (3, 4), + (3, 6), + (4, 6), + } + + @staticmethod + def get_number_of_transitive_edges() -> int: + """ + Get the total number of edges in the transitive closure graph. + + Returns: + int: The number of transitive edges. 
+ """ + return len(GOUniProtMockData.get_edges_of_transitive_closure_graph()) + + @staticmethod + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the ontology graph. + + Returns: + Set[int]: A set of node IDs representing obsolete nodes. + """ + return {7, 8} + + @staticmethod + def get_GO_raw_data() -> str: + """ + Get raw data in string format for a basic Gene Ontology (GO) structure. + + This data simulates a basic GO ontology format typically used for testing purposes. + The data will include valid and obsolete GO terms with various relationships between them. + + Scenarios covered: + - Obsolete terms being the parent of valid terms. + - Valid terms being the parent of obsolete terms. + - Both direct and indirect hierarchical relationships between terms. + + The data is designed to help test the proper handling of obsolete and valid GO terms, + ensuring that the ontology parser can correctly manage both cases. + + Returns: + str: The raw GO data in string format, structured as test input. + """ + return """ + [Term] + id: GO:0000001 + name: GO_1 + namespace: molecular_function + def: "OBSOLETE. Assists in the correct assembly of ribosomes or ribosomal subunits in vivo, but is not a component of the assembled ribosome when performing its normal biological function." [GOC:jl, PMID:12150913] + comment: This term was made obsolete because it refers to a class of gene products and a biological process rather than a molecular function. + synonym: "ribosomal chaperone activity" EXACT [] + xref: MetaCyc:BETAGALACTOSID-RXN + xref: Reactome:R-HSA-189062 "lactose + H2O => D-glucose + D-galactose" + xref: Reactome:R-HSA-5658001 "Defective LCT does not hydrolyze Lac" + xref: RHEA:10076 + + [Term] + id: GO:0000002 + name: GO_2 + namespace: biological_process + is_a: GO:0000001 ! hydrolase activity, hydrolyzing O-glycosyl compounds + is_a: GO:0000008 ! hydrolase activity, hydrolyzing O-glycosyl compounds + + [Term] + id: GO:0000003 + name: GO_3 + namespace: cellular_component + is_a: GO:0000001 ! regulation of DNA recombination + + [Term] + id: GO:0000004 + name: GO_4 + namespace: biological_process + is_a: GO:0000003 ! regulation of DNA recombination + is_a: GO:0000002 ! hydrolase activity, hydrolyzing O-glycosyl compounds + + [Term] + id: GO:0000005 + name: GO_5 + namespace: molecular_function + is_a: GO:0000002 ! regulation of DNA recombination + + [Term] + id: GO:0000006 + name: GO_6 + namespace: cellular_component + is_a: GO:0000004 ! glucoside transport + + [Term] + id: GO:0000007 + name: GO_7 + namespace: biological_process + is_a: GO:0000003 ! glucoside transport + is_obsolete: true + + [Term] + id: GO:0000008 + name: GO_8 + namespace: molecular_function + is_obsolete: true + + [Typedef] + id: term_tracker_item + name: term tracker item + namespace: external + xref: IAO:0000233 + is_metadata_tag: true + is_class_level: true + """ + + @staticmethod + def protein_sequences() -> Dict[str, str]: + """ + Get the protein sequences for Swiss-Prot proteins. + + Returns: + Dict[str, str]: A dictionary where keys are Swiss-Prot IDs and values are their respective sequences. + """ + return { + "Swiss_Prot_1": "MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK".replace( + " ", "" + ), + "Swiss_Prot_2": "EKGLIVGHFS GIKYKGEKAQ ASEVDVNKMC CWVSKFKDAM RRYQGIQTCK".replace( + " ", "" + ), + } + + @staticmethod + def proteins_for_pretraining() -> List[str]: + """ + Returns a list of protein IDs which will be used for pretraining based on mock UniProt data. 
+ + Proteins include those with: + - No GO classes or invalid GO classes (missing required evidence codes). + + Returns: + List[str]: A list of protein IDs that do not meet validation criteria. + """ + return [ + "Swiss_Prot_5", # No GO classes associated + "Swiss_Prot_6", # GO class with no evidence code + "Swiss_Prot_7", # GO class with invalid evidence code + ] + + @staticmethod + def get_UniProt_raw_data() -> str: + """ + Get raw data in string format for UniProt proteins. + + This mock data contains eleven Swiss-Prot proteins with different properties: + - **Swiss_Prot_1**: A valid protein with three valid GO classes and one invalid GO class. + - **Swiss_Prot_2**: Another valid protein with two valid GO classes and one invalid. + - **Swiss_Prot_3**: Contains valid GO classes but has a sequence length > 1002. + - **Swiss_Prot_4**: Has valid GO classes but contains an invalid amino acid, 'B'. + - **Swiss_Prot_5**: Has a sequence but no GO classes associated. + - **Swiss_Prot_6**: Has GO classes without any associated evidence codes. + - **Swiss_Prot_7**: Has a GO class with an invalid evidence code. + - **Swiss_Prot_8**: Has a sequence length > 1002 and has only invalid GO class. + - **Swiss_Prot_9**: Has no GO classes but contains an invalid amino acid, 'B', in its sequence. + - **Swiss_Prot_10**: Has a valid GO class but lacks a sequence. + - **Swiss_Prot_11**: Has only Invalid GO class but lacks a sequence. + + Note: + A valid GO label is the one which has one of the following evidence code specified in + go_uniprot.py->`EXPERIMENTAL_EVIDENCE_CODES`. + Invalid amino acids are specified in go_uniprot.py->`AMBIGUOUS_AMINO_ACIDS`. + + Returns: + str: The raw UniProt data in string format. + """ + protein_sq_1 = GOUniProtMockData.protein_sequences()["Swiss_Prot_1"] + protein_sq_2 = GOUniProtMockData.protein_sequences()["Swiss_Prot_2"] + raw_str = ( + # Below protein with 3 valid associated GO class and one invalid GO class + f"ID Swiss_Prot_1 Reviewed; {len(protein_sq_1)} AA. 
\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000002; C:membrane; EXP:UniProtKB-KW.\n" + "DR GO; GO:0000003; C:membrane; IDA:UniProtKB-KW.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IPI:InterPro.\n" + "DR GO; GO:0000004; P:regulation of viral transcription; IEA:SGD.\n" + f"SQ SEQUENCE {len(protein_sq_1)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_1}\n" + "//\n" + # Below protein with 2 valid associated GO class and one invalid GO class + f"ID Swiss_Prot_2 Reviewed; {len(protein_sq_2)} AA.\n" + "AC DCGZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; IMP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IGI:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; IEA:PomBase.\n" + f"SQ SEQUENCE {len(protein_sq_2)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_2}\n" + "//\n" + # Below protein with all valid associated GO class but sequence length greater than 1002 + f"ID Swiss_Prot_3 Reviewed; {len(protein_sq_1 * 25)} AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; IEP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; TAS:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" + f"SQ SEQUENCE {len(protein_sq_1 * 25)} AA; 129118 MW; FE2984658CED53A8 CRC64;\n" + f" {protein_sq_1 * 25}\n" + "//\n" + # Below protein has valid go class association but invalid amino acid `X` in its sequence + "ID Swiss_Prot_4 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "DR GO; GO:0000002; P:regulation of viral transcription; EXP:InterPro.\n" + "DR GO; GO:0000005; P:regulation of viral transcription; IEA:InterPro.\n" + "DR GO; GO:0000006; P:regulation of viral transcription; EXP:PomBase.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with sequence string but has no GO class + "ID Swiss_Prot_5 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR EMBL; AY548484; AAT09660.1; -; Genomic_DNA.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with sequence string and with NO `valid` associated GO class (no evidence code) + "ID Swiss_Prot_6 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000023; P:regulation of viral transcription;\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with sequence string and with NO `valid` associated GO class (invalid evidence code) + "ID Swiss_Prot_7 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000024; P:regulation of viral transcription; IEA:SGD.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " MAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with sequence length greater than 1002 but with `Invalid` associated GO class + f"ID Swiss_Prot_8 Reviewed; {len(protein_sq_2 * 25)} AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000025; P:regulation of viral transcription; IC:Inferred.\n" + f"SQ SEQUENCE {len(protein_sq_2 * 25)} AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + f" {protein_sq_2 * 25}\n" + "//\n" + # Below protein with sequence string but invalid amino acid `X` in its sequence + 
"ID Swiss_Prot_9 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " BAFSAEDVLK EYDRRRRMEA LLLSLYYPND RKLLDYKEWS PPRVQVECPK APVEWNNPPS\n" + "//\n" + # Below protein with a `valid` associated GO class but without sequence string + "ID Swiss_Prot_10 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000027; P:regulation of viral transcription; EXP:InterPro.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " \n" + "//\n" + # Below protein with a `Invalid` associated GO class but without sequence string + "ID Swiss_Prot_11 Reviewed; 60 AA.\n" + "AC Q6GZX4;\n" + "DR GO; GO:0000028; P:regulation of viral transcription; ND:NoData.\n" + "SQ SEQUENCE 60 AA; 29735 MW; B4840739BF7D4121 CRC64;\n" + " \n" + "//\n" + ) + + return raw_str + + @staticmethod + def get_data_in_dataframe() -> pd.DataFrame: + """ + Get a mock DataFrame representing UniProt data. + + The DataFrame contains Swiss-Prot protein data, including identifiers, accessions, GO terms, sequences, + and binary label columns representing whether each protein is associated with certain GO classes. + + Returns: + pd.DataFrame: A DataFrame containing mock UniProt data with columns for 'swiss_id', 'accession', 'go_ids', 'sequence', + and binary labels for GO classes. + """ + expected_data = OrderedDict( + swiss_id=["Swiss_Prot_1", "Swiss_Prot_2"], + accession=["Q6GZX4", "DCGZX4"], + go_ids=[[1, 2, 3, 5], [1, 2, 5]], + sequence=list(GOUniProtMockData.protein_sequences().values()), + **{ + # SP_1, SP_2 + 1: [True, True], + 2: [True, True], + 3: [True, False], + 4: [False, False], + 5: [True, True], + 6: [False, False], + }, + ) + return pd.DataFrame(expected_data) + + @staticmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Get the transitive closure of the ontology graph. + + Returns: + nx.DiGraph: A directed graph representing the transitive closure of the ontology graph. + """ + g = nx.DiGraph() + g.add_nodes_from(node for node in ChebiMockOntology.get_nodes()) + g.add_edges_from(GOUniProtMockData.get_edges_of_transitive_closure_graph()) + return g diff --git a/tests/unit/readers/__init__.py b/tests/unit/readers/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/unit/readers/testDataReader.py b/tests/unit/readers/testDataReader.py new file mode 100644 index 0000000..745c0ac --- /dev/null +++ b/tests/unit/readers/testDataReader.py @@ -0,0 +1,56 @@ +import unittest +from typing import Any, Dict, List + +from chebai.preprocessing.reader import DataReader + + +class TestDataReader(unittest.TestCase): + """ + Unit tests for the DataReader class. + """ + + @classmethod + def setUpClass(cls) -> None: + """ + Set up the test environment by initializing a DataReader instance. + """ + cls.reader = DataReader() + + def test_to_data(self) -> None: + """ + Test the to_data method to ensure it correctly processes the input row + and formats it according to the expected output. + + This method tests the conversion of raw data into a processed format, + including extracting features, labels, ident, group, and additional + keyword arguments. 
+ """ + features_list: List[int] = [10, 20, 30] + labels_list: List[bool] = [True, False, True] + ident_no: int = 123 + + row: Dict[str, Any] = { + "features": features_list, + "labels": labels_list, + "ident": ident_no, + "group": "group_data", + "additional_kwargs": {"extra_key": "extra_value"}, + } + + expected: Dict[str, Any] = { + "features": features_list, + "labels": labels_list, + "ident": ident_no, + "group": "group_data", + "extra_key": "extra_value", + } + + self.assertEqual( + self.reader.to_data(row), + expected, + "The to_data method did not process the input row as expected.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py new file mode 100644 index 0000000..c5bc5e9 --- /dev/null +++ b/tests/unit/readers/testProteinDataReader.py @@ -0,0 +1,139 @@ +import unittest +from typing import List +from unittest.mock import mock_open, patch + +from chebai.preprocessing.reader import EMBEDDING_OFFSET, ProteinDataReader + + +class TestProteinDataReader(unittest.TestCase): + """ + Unit tests for the ProteinDataReader class. + """ + + @classmethod + @patch( + "chebai.preprocessing.reader.open", + new_callable=mock_open, + read_data="M\nK\nT\nF\nR\nN", + ) + def setUpClass(cls, mock_file: mock_open) -> None: + """ + Set up the test environment by initializing a ProteinDataReader instance with a mocked token file. + + Args: + mock_file: Mock object for file operations. + """ + cls.reader = ProteinDataReader(token_path="/mock/path") + # After initializing, cls.reader.cache should now be set to ['M', 'K', 'T', 'F', 'R', 'N'] + assert cls.reader.cache == [ + "M", + "K", + "T", + "F", + "R", + "N", + ], "Cache initialization did not match expected tokens." + + def test_read_data(self) -> None: + """ + Test the _read_data method with a protein sequence to ensure it correctly tokenizes the sequence. + """ + raw_data = "MKTFFRN" + + # Expected output based on the cached tokens + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # M + EMBEDDING_OFFSET + 1, # K + EMBEDDING_OFFSET + 2, # T + EMBEDDING_OFFSET + 3, # F + EMBEDDING_OFFSET + 3, # F (repeated token) + EMBEDDING_OFFSET + 4, # R + EMBEDDING_OFFSET + 5, # N + ] + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not produce the expected tokenized output.", + ) + + def test_read_data_with_new_token(self) -> None: + """ + Test the _read_data method with a protein sequence that includes a new token. + Ensure that the new token is added to the cache and processed correctly. + """ + raw_data = "MKTFY" + + # 'Y' is not in the initial cache and should be added. + expected_output: List[int] = [ + EMBEDDING_OFFSET + 0, # M + EMBEDDING_OFFSET + 1, # K + EMBEDDING_OFFSET + 2, # T + EMBEDDING_OFFSET + 3, # F + EMBEDDING_OFFSET + len(self.reader.cache), # Y (new token) + ] + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not correctly handle a new token.", + ) + + # Verify that 'Y' was added to the cache + self.assertIn( + "Y", self.reader.cache, "The new token 'Y' was not added to the cache." 
+ ) + # Ensure it's at the correct index + self.assertEqual( + self.reader.cache.index("Y"), + len(self.reader.cache) - 1, + "The new token 'Y' was not added at the correct index in the cache.", + ) + + def test_read_data_with_invalid_token(self) -> None: + """ + Test the _read_data method with an invalid amino acid token to ensure it raises a KeyError. + """ + raw_data = "MKTFZ" # 'Z' is not a valid amino acid token + + with self.assertRaises(KeyError) as context: + self.reader._read_data(raw_data) + + self.assertIn( + "Invalid token 'Z' encountered", + str(context.exception), + "The KeyError did not contain the expected message for an invalid token.", + ) + + def test_read_data_with_empty_sequence(self) -> None: + """ + Test the _read_data method with an empty protein sequence to ensure it returns an empty list. + """ + raw_data = "" + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + [], + "The _read_data method did not return an empty list for an empty input sequence.", + ) + + def test_read_data_with_repeated_tokens(self) -> None: + """ + Test the _read_data method with repeated amino acid tokens to ensure it handles them correctly. + """ + raw_data = "MMMMM" + + expected_output: List[int] = [EMBEDDING_OFFSET + 0] * 5 # All tokens are 'M' + + result = self.reader._read_data(raw_data) + self.assertEqual( + result, + expected_output, + "The _read_data method did not correctly handle repeated tokens.", + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/tutorials/data_exploration_go.ipynb b/tutorials/data_exploration_go.ipynb new file mode 100644 index 0000000..6f67c82 --- /dev/null +++ b/tutorials/data_exploration_go.ipynb @@ -0,0 +1,1341 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "da687d32ba48b188", + "metadata": {}, + "source": [ + "# Introduction\n", + "\n", + "This notebook serves as a guide for new developers using the `chebai` package. If you just want to run the experiments, you can refer to the [README.md](https://github.com/ChEB-AI/python-chebai/blob/dev/README.md) and the [wiki](https://github.com/ChEB-AI/python-chebai/wiki) for the basic commands. This notebook explains what happens under the hood for the GO-UniProt dataset. It covers\n", + "- how to instantiate a data class and generate data\n", + "- how the data is processed and stored\n", + "- and how to work with different molecule encodings.\n", + "\n", + "The chebai package simplifies the handling of these datasets by **automatically creating** them as needed. This means that you do not have to input any data manually; the package will generate and organize the data files based on the parameters and encodings selected. This feature ensures that the right data is available and formatted properly. You can however provide your own data files, for instance if you want to replicate a specific experiment.\n", + "\n", + "---\n" + ] + }, + { + "cell_type": "markdown", + "id": "0bd07c91-bb02-48d4-b759-aa35ecb224bd", + "metadata": {}, + "source": [ + "# 1. Instantiation of a Data Class\n", + "\n", + "To start working with `chebai`, you first need to instantiate a GO-UniProt data class. 
This class is responsible for managing, interacting with, and preprocessing the GO and UniProt data" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a4d590fb-9a83-456e-9cb4-303caa8203e8", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Already in the project root directory: G:\\github-aditya0by0\\python-chebai\n" + ] + } + ], + "source": [ + "# To run this notebook, you need to change the working directory of the jupyter notebook to root dir of the project.\n", + "import os\n", + "\n", + "# Root directory name of the project\n", + "expected_root_dir = \"python-chebai\"\n", + "\n", + "# Check if the current directory ends with the expected root directory name\n", + "if not os.getcwd().endswith(expected_root_dir):\n", + " os.chdir(\"..\") # Move up one directory level\n", + " if os.getcwd().endswith(expected_root_dir):\n", + " print(\"Changed to project root directory:\", os.getcwd())\n", + " else:\n", + " print(\"Warning: Directory change unsuccessful. Current directory:\", os.getcwd())\n", + "else:\n", + " print(\"Already in the project root directory:\", os.getcwd())" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "440f203ceaf7e4b7", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-30T21:25:03.920610Z", + "start_time": "2024-09-30T21:25:03.622407Z" + } + }, + "outputs": [], + "source": "from chebai.preprocessing.datasets.go_uniprot import GOUniProtOver250" + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "a648346d81d0dc5e", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-30T21:25:08.863132Z", + "start_time": "2024-09-30T21:25:08.387739Z" + } + }, + "outputs": [], + "source": [ + "go_class = GOUniProtOver250(go_branch=\"BP\")" + ] + }, + { + "cell_type": "markdown", + "id": "64585012b0d7f66f", + "metadata": {}, + "source": [ + "### Inheritance Hierarchy\n", + "\n", + "GO_UniProt data classes inherit from [`_DynamicDataset`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L597), which in turn inherits from [`XYBaseDataModule`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L22). Specifically:\n", + "\n", + "- **`_DynamicDataset`**: This class serves as an intermediate base class that provides additional functionality or customization for datasets that require dynamic behavior. It inherits from `XYBaseDataModule`, which provides the core methods for data loading and processing.\n", + "\n", + "- **`XYBaseDataModule`**: This is the base class for data modules, providing foundational properties and methods for handling and processing datasets, including data splitting, loading, and preprocessing.\n", + "\n", + "In summary, GO_UniProt data classes are designed to manage and preprocess chemical data effectively by leveraging the capabilities provided by `XYBaseDataModule` through the `_DynamicDataset` intermediary.\n", + "\n", + "\n", + "### Configuration Parameters\n", + "\n", + "Data classes related to proteins can be configured using the following main parameters:\n", + "\n", + "- **`go_branch (str)`**: The Gene Ontology (GO) branch. 
The default value is `\"all\"`, which includes all branches of GO in the dataset.\n", + " - **`\"BP\"`**: Biological Process branch.\n", + " - **`\"MF\"`**: Molecular Function branch.\n", + " - **`\"CC\"`**: Cellular Component branch.\n", + "\n", + "- **`max_sequence_length (int)`**: Specifies the maximum allowed sequence length for a protein, with a default of `1002`. During data preprocessing, any proteins exceeding this length will be excluded from further processing.\n", + "\n", + "This allows for more specific datasets focused on a particular aspect of gene function.\n", + "\n", + "- **`splits_file_path (str, optional)`**: Path to a CSV file containing data splits. If not provided, the class will handle splits internally. The default is `None`.\n", + "\n", + "### Additional Input Parameters\n", + "\n", + "To get more control over various aspects of data loading, processing, and splitting, you can refer to documentation of additional parameters in docstrings of the respective classes: [`_GOUniProtDataExtractor`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/go_uniprot.py#L33), [`XYBaseDataModule`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L22), [`_DynamicDataset`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L597), etc.\n", + "\n", + "\n", + "# Available Data Classes\n", + "\n", + "__Note__: Check the code implementation of classes [here](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/go_uniprot.py).\n", + "\n", + "There is a range of available dataset classes for GOUniProt classes. Usually, you want to use `GOUniProtOver250` or `GOUniProtOver50`. Both inherit from `_GOUniProtOverX`. The number indicates the threshold for selecting label classes. The selection process is based on the annotations of the GO terms with its ancestors across the dataset. For instance, GOUniProtOver50 will only select labels which have at least 50 samples in the dataset.\n", + "\n", + "Refer `select_classes` method of `_GOUniProtOverX` for more details on selection process.\n", + "\n", + "If you need a different threshold, you can create your own subclass." + ] + }, + { + "cell_type": "markdown", + "id": "651ab5c39833bd2c", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "a52b4363-7398-44aa-a4cc-8bba14bdd966", + "metadata": {}, + "source": [ + "# 2. Preparation / Setup Methods\n", + "\n", + "Once a GOUniProt data class instance is created, it typically requires preparation before use. This step is to generate the actual dataset." 
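+    ,
+    "\n",
+    "The class-level `THRESHOLD` attribute is what distinguishes the `GOUniProtOver*` variants mentioned above, so a custom threshold only needs a small subclass before the data is prepared. A minimal sketch (the name `GOUniProtOver100` is hypothetical; the unit tests in this patch import the same base class from `chebai.preprocessing.datasets.deepGO.go_uniprot`, so adjust the import path to your installed layout):\n",
+    "\n",
+    "```python\n",
+    "from chebai.preprocessing.datasets.go_uniprot import _GOUniProtOverX\n",
+    "\n",
+    "\n",
+    "class GOUniProtOver100(_GOUniProtOverX):\n",
+    "    # Keep only GO terms annotated (directly or via ancestors) to at least 100 proteins.\n",
+    "    THRESHOLD = 100\n",
+    "```\n",
+    "\n",
+    "An instance of such a subclass is prepared and set up in exactly the same way as `GOUniProtOver250` below."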
+ ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "9f77351090560bc4", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checking for processed data in data\\GO_UniProt\\GO250_BP_1002\\processed\n", + "Missing processed data file (`data.pkl` file)\n", + "Downloading Swiss UniProt data....\n", + "Downloading to temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmp7pp677ik\n", + "Downloaded to C:\\Users\\HP\\AppData\\Local\\Temp\\tmp7pp677ik\n", + "Unzipping the file....\n", + "Unpacked and saved to data\\GO_UniProt\\raw\\uniprot_sprot.dat\n", + "Removed temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmp7pp677ik\n", + "Missing Gene Ontology raw data\n", + "Downloading Gene Ontology data....\n", + "Extracting class hierarchy...\n", + "Compute transitive closure\n", + "Processing graph\n", + "Parsing swiss uniprot raw data....\n", + "Selecting GO terms based on given threshold: 250 ...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Check for processed data in data\\GO_UniProt\\GO250_BP_1002\\processed\\protein_token\n", + "Cross-validation enabled: False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing transformed data (`data.pt` file). Transforming data.... \n", + "Processing 53604 lines...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 53604/53604 [01:18<00:00, 678.84it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving 20 tokens to G:\\github-aditya0by0\\python-chebai\\chebai\\preprocessing\\bin\\protein_token\\tokens.txt...\n", + "First 10 tokens: ['M', 'S', 'I', 'G', 'A', 'T', 'R', 'L', 'Q', 'N']\n" + ] + } + ], + "source": [ + "go_class.prepare_data()\n", + "go_class.setup()" + ] + }, + { + "cell_type": "markdown", + "id": "2328e824c4dafb2d", + "metadata": {}, + "source": [ + "### Automatic Execution: \n", + "These methods are executed automatically within the data class instance. Users do not need to call them explicitly, as the code internally manages the preparation and setup of data, ensuring that it is ready for subsequent use in training and validation processes.\n", + "\n", + "\n", + "### Why is Preparation Needed?\n", + "\n", + "- **Data Availability**: The preparation step ensures that the required GOUniProt data files are downloaded or loaded, which are essential for analysis.\n", + "- **Data Integrity**: It ensures that the data files are transformed into a compatible format required for model input.\n", + "\n", + "### Main Methods for Data Preprocessing\n", + "\n", + "The data preprocessing in a data class involves two main methods:\n", + "\n", + "1. **`prepare_data` Method**:\n", + " - **Purpose**: This method checks for the presence of raw data in the specified directory. If the raw data is missing, it fetches the ontology, creates a dataframe, and saves it to a file (`data.pkl`). The dataframe includes columns such as IDs, data representations, and labels.\n", + " - **Documentation**: [PyTorch Lightning - `prepare_data`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data)\n", + "\n", + "2. **`setup` Method**:\n", + " - **Purpose**: This method sets up the data module for training, validation, and testing. 
It checks for the processed data and, if necessary, performs additional setup to ensure the data is ready for model input. It also handles cross-validation settings if enabled.\n", + " - **Description**: Transforms `data.pkl` into a model input data format (`data.pt`), ensuring that the data is in a format compatible for input to the model. The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. This method uses a subclass of Data Reader to perform the transformation.\n", + " - **Documentation**: [PyTorch Lightning - `setup`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup)\n", + "\n", + "These methods ensure that the data is correctly prepared and set up for subsequent use in training and validation processes." + ] + }, + { + "cell_type": "markdown", + "id": "db5b58f2d96823fc", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ee174b61b36c71aa", + "metadata": {}, + "source": [ + "# 3. Overview of the 3 preprocessing stages\n", + "\n", + "The `chebai` library follows a three-stage preprocessing pipeline, which is reflected in its file structure:\n", + "\n", + "1. **Raw Data Stage**:\n", + " - **File**: `go-basic.obo` and `uniprot_sprot.data`\n", + " - **Description**: This stage contains the raw GO ontology data and raw Swiss-UniProt data, serving as the initial input for further processing.\n", + " - **File Paths**:\n", + " - `data/GO_UniProt/raw/go-basic.obo`\n", + " - `data/GO_UniProt/raw/uniprot_sprot.dat`\n", + "\n", + "2. **Processed Data Stage 1**:\n", + " - **File**: `data.pkl`\n", + " - **Description**: This stage includes the data after initial processing. It contains sequence strings, class columns, and metadata but lacks data splits.\n", + " - **File Path**: `data/GO_UniProt/${dataset_name}/processed/data.pkl`\n", + " - **Additional File**: `classes.txt` - A file listing the relevant ChEBI classes.\n", + "\n", + "3. **Processed Data Stage 2**:\n", + " - **File**: `data.pt`\n", + " - **Description**: This final stage includes the encoded data in a format compatible with PyTorch, ready for model input. This stage also references data splits when available.\n", + " - **File Path**: `data/GO_UniProt/${dataset_name}/processed/${reader_name}/data.pt`\n", + " - **Additional File**: `splits.csv` - Contains saved splits for reproducibility.\n", + "\n", + "**Note**: If `go_branch` is specified, the `dataset_name` will include the branch name in the format `${dataset_name}_${go_branch}`. Otherwise, it will just be `${dataset_name}`.\n", + "\n", + "### Summary of File Paths\n", + "\n", + "- **Raw Data**: `data/GO_UniProt/raw`\n", + "- **Processed Data 1**: `data/GO_UniProt/${dataset_name}/processed`\n", + "- **Processed Data 2**: `data/GO_UniProt/${dataset_name}/processed/${reader_name}`\n", + "\n", + "This structured approach to data management ensures that each stage of data processing is well-organized and documented, from raw data acquisition to the preparation of model-ready inputs. 
It also facilitates reproducibility and traceability across different experiments.\n", + "\n", + "### Data Splits\n", + "\n", + "- **Creation**: Data splits are generated dynamically \"on the fly\" during training and evaluation to ensure flexibility and adaptability to different tasks.\n", + "- **Reproducibility**: To maintain consistency across different runs, splits can be reproduced by comparing hashes with a fixed seed value.\n" + ] + }, + { + "cell_type": "markdown", + "id": "a927ad484c930960", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "3f92b58e460c08fd", + "metadata": {}, + "source": [ + "# 4. Data Files and their structure\n", + "\n", + "`chebai` creates and manages several data files during its operation. These files store various chemical data and metadata essential for different tasks. Letโ€™s explore these files and their content.\n" + ] + }, + { + "cell_type": "markdown", + "id": "cca75d881cb8bade", + "metadata": {}, + "source": [ + "## go-basic.obo File\n", + "\n", + "**Description**: The `go-basic.obo` file is a key resource in the Gene Ontology (GO) dataset, containing the ontology data that defines various biological processes, molecular functions, and cellular components, as well as their relationships. This file is downloaded directly from the Gene Ontology Consortium and serves as the foundational raw data for further processing in GO-based applications.\n", + "\n", + "#### Example of a Term Document\n", + "\n", + "```plaintext\n", + "[Term]\n", + "id: GO:0000032\n", + "name: cell wall mannoprotein biosynthetic process\n", + "namespace: biological_process\n", + "def: \"The chemical reactions and pathways resulting in the formation of cell wall mannoproteins, any cell wall protein that contains covalently bound mannose residues.\" [GOC:ai]\n", + "synonym: \"cell wall mannoprotein anabolism\" EXACT []\n", + "is_a: GO:0006057 ! mannoprotein biosynthetic process\n", + "is_a: GO:0031506 ! cell wall glycoprotein biosynthetic process\n", + "```\n", + "\n", + "**File Path**: `data/GO_UniProt/raw/go-basic.obo`\n", + "\n", + "### Structure of `go-basic.obo`\n", + "\n", + "The `go-basic.obo` file is organized into blocks of text known as \"term documents.\" Each block starts with a `[Term]` header and contains various attributes that describe a specific biological process, molecular function, or cellular component within the GO ontology. These attributes include identifiers, names, relationships to other terms, and more.\n", + "\n", + "\n", + "\n", + "### Breakdown of Attributes\n", + "\n", + "Each term document in the `go-basic.obo` file consists of the following key attributes:\n", + "\n", + "- **`[Term]`**: \n", + " - **Description**: Indicates the beginning of a new term in the ontology. 
Each term represents a distinct biological process, molecular function, or cellular component.\n", + "\n", + "- **`id: GO:0000032`**: \n", + " - **Description**: A unique identifier for the biological term within the GO ontology.\n", + " - **Example**: `GO:0000032` refers to the term \"cell wall mannoprotein biosynthetic process.\"\n", + "\n", + "- **`name: cell wall mannoprotein biosynthetic process`**: \n", + " - **Description**: The name of the biological process, molecular function, or cellular component being described.\n", + " - **Example**: The name \"cell wall mannoprotein biosynthetic process\" is a descriptive label for the GO term with the identifier `GO:0000032`.\n", + "\n", + "- **`namespace: biological_process`**: \n", + " - **Description**: Specifies which ontology the term belongs to. The main namespaces are `biological_process`, `molecular_function`, and `cellular_component`.\n", + "\n", + "- **`is_a: GO:0006057`**: \n", + " - **Description**: Defines hierarchical relationships to other terms within the ontology. The `is_a` attribute indicates that the current term is a subclass or specific instance of the referenced term.\n", + " - **Example**: The term `GO:0000032` (\"cell wall mannoprotein biosynthetic process\") is a subclass of `GO:0006057` and subclass of `GO:0031506`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "87c841de7d80beef", + "metadata": {}, + "source": [ + "## uniprot_sprot.dat File\n", + "\n", + "**Description**: The `uniprot_sprot.dat` file is a key component of the UniProtKB/Swiss-Prot dataset. It contains curated protein sequences with detailed annotations. Each entry in the file corresponds to a reviewed protein sequence, complete with metadata about its biological function, taxonomy, gene name, cross-references to other databases, and more. 
Below is a breakdown of the structure and key attributes in the file, using the provided example.\n", + "\n", + "\n", + "### Example of a Protein Entry\n", + "\n", + "```plaintext\n", + "ID 002L_FRG3G Reviewed; 320 AA.\n", + "AC Q6GZX3;\n", + "DT 28-JUN-2011, integrated into UniProtKB/Swiss-Prot.\n", + "DT 19-JUL-2004, sequence version 1.\n", + "DT 08-NOV-2023, entry version 46.\n", + "DE RecName: Full=Uncharacterized protein 002L;\n", + "GN ORFNames=FV3-002L;\n", + "OS Frog virus 3 (isolate Goorha) (FV-3).\n", + "OC Viruses; Varidnaviria; Bamfordvirae; Nucleocytoviricota; Megaviricetes;\n", + "OX NCBI_TaxID=654924;\n", + "OH NCBI_TaxID=8404; Lithobates pipiens (Northern leopard frog) (Rana pipiens).\n", + "RN [1]\n", + "RP NUCLEOTIDE SEQUENCE [LARGE SCALE GENOMIC DNA].\n", + "RX PubMed=15165820; DOI=10.1016/j.virol.2004.02.019;\n", + "RA Tan W.G., Barkman T.J., Gregory Chinchar V., Essani K.;\n", + "RT \"Comparative genomic analyses of frog virus 3, type species of the genus\n", + "RT Ranavirus (family Iridoviridae).\";\n", + "RL Virology 323:70-84(2004).\n", + "CC -!- SUBCELLULAR LOCATION: Host membrane {ECO:0000305}; Single-pass membrane\n", + "CC protein {ECO:0000305}.\n", + "DR EMBL; AY548484; AAT09661.1; -; Genomic_DNA.\n", + "DR RefSeq; YP_031580.1; NC_005946.1.\n", + "DR GeneID; 2947774; -.\n", + "DR KEGG; vg:2947774; -.\n", + "DR Proteomes; UP000008770; Segment.\n", + "DR GO; GO:0033644; C:host cell membrane; IEA:UniProtKB-SubCell.\n", + "DR GO; GO:0016020; C:membrane; IEA:UniProtKB-KW.\n", + "PE 4: Predicted;\n", + "KW Host membrane; Membrane; Reference proteome; Transmembrane;\n", + "KW Transmembrane helix.\n", + "FT CHAIN 1..320\n", + "FT /note=\"Uncharacterized protein 002L\"\n", + "FT /id=\"PRO_0000410509\"\n", + "SQ SEQUENCE 320 AA; 34642 MW; 9E110808B6E328E0 CRC64;\n", + " MSIIGATRLQ NDKSDTYSAG PCYAGGCSAF TPRGTCGKDW DLGEQTCASG FCTSQPLCAR\n", + " IKKTQVCGLR YSSKGKDPLV SAEWDSRGAP YVRCTYDADL IDTQAQVDQF VSMFGESPSL\n", + " AERYCMRGVK NTAGELVSRV SSDADPAGGW CRKWYSAHRG PDQDAALGSF CIKNPGAADC\n", + " KCINRASDPV YQKVKTLHAY PDQCWYVPCA ADVGELKMGT QRDTPTNCPT QVCQIVFNML\n", + " DDGSVTMDDV KNTINCDFSK YVPPPPPPKP TPPTPPTPPT PPTPPTPPTP PTPRPVHNRK\n", + " VMFFVAGAVL VAILISTVRW\n", + "//\n", + "```\n", + "\n", + "**File Path**: `data/GO_UniProt/raw/uniprot_sprot.dat`\n", + "\n", + "\n", + "## Structure of `uniprot_sprot.dat`\n", + "\n", + "The `uniprot_sprot.dat` file is organized into blocks of text, each representing a single protein entry. These blocks contain specific tags and fields that describe different aspects of the protein, including its sequence, function, taxonomy, and cross-references to external databases.\n", + "\n", + "### Breakdown of Attributes\n", + "\n", + "Each protein entry in the `uniprot_sprot.dat` file is structured with specific tags and sections that describe the protein in detail. 
Here's a breakdown of the key attributes:\n", + "\n", + "- **`ID`**: \n", + " - **Description**: Contains the unique identifier for the protein and its status (e.g., `Reviewed` indicates the sequence has been manually curated).\n", + " - **Example**: `002L_FRG3G` is the identifier for the protein from Frog virus 3.\n", + "\n", + "- **`AC`**: \n", + " - **Description**: Accession number, a unique identifier for the protein sequence.\n", + " - **Example**: `Q6GZX3` is the accession number for this entry.\n", + "\n", + "- **`DR`**: \n", + " - **Description**: Cross-references to other databases like EMBL, RefSeq, KEGG, and GeneID.\n", + " - **Example**: This entry is cross-referenced with the EMBL database, RefSeq, GO, etc.\n", + "\n", + "- **`GO`**: \n", + " - **Description**: Gene Ontology annotations that describe the cellular component, biological process, or molecular function associated with the protein.\n", + " - **Example**: The protein is associated with the GO terms `GO:0033644` (host cell membrane) and `GO:0016020` (membrane).\n", + "\n", + "- **`SQ`**: \n", + " - **Description**: The amino acid sequence of the protein.\n", + " - **Example**: The sequence consists of 320 amino acids.\n", + "\n", + "__Note__: For more detailed information refer [here](https://ftp.uniprot.org/pub/databases/uniprot/current_release/knowledgebase/complete/docs/keywlist.txt\n", + "). \n", + "\n", + "Consider the below line from above example: \n", + "```plaintext\n", + "DR GO; GO:0033644; C:host cell membrane; IEA:UniProtKB-SubCell.\n", + "```\n", + "\n", + "The line contains a **Gene Ontology (GO) annotation** describing the protein's subcellular location. Here's a detailed breakdown:\n", + "\n", + "- **`GO:0033644`**: This is the specific **GO term** identifier for \"host cell membrane,\" which indicates that the protein is associated with or located at the membrane of the host cell.\n", + "\n", + "- **`IEA`**: This stands for **Inferred from Electronic Annotation**, which is part of the **GO Evidence Codes**. **IEA** indicates that the annotation was automatically generated based on computational methods rather than direct experimental evidence. While **IEA** annotations are useful, they are generally considered less reliable than manually curated or experimentally verified evidence codes.\n", + "\n", + "__Note__: For more details on evidence codes check section 5.2" + ] + }, + { + "cell_type": "markdown", + "id": "b7687078-f6b8-4fbf-afa7-dfda89061a5e", + "metadata": {}, + "source": [ + "## data.pkl File\n", + "\n", + "**Description**: This file is generated by the `prepare_data` method and contains the processed GO data in a dataframe format. It includes protein IDs, data representations (such as sequence strings), and class columns with boolean values." + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "b4da7e73e251e1d1", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-30T14:08:33.990378Z", + "start_time": "2024-09-30T14:08:33.959459Z" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "b66fbb9b720d053c", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-30T14:10:12.796911Z", + "start_time": "2024-09-30T14:10:06.052276Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Size of the data (rows x columns): (53604, 902)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
swiss_idaccessiongo_idssequence4175122165209226...1990778200002620001452000146200014720002412000243200114120012332001234
111S1_CARILB5KVH4[3006, 8150, 9791, 10431, 21700, 22414, 32501,...MAKPILLSIYLCLIIVALFNGCLAQSGGRQQHKFGQCQLNRLDALE...FalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
311S2_SESINQ9XHP0[3006, 8150, 10431, 21700, 22414, 32502, 48609]MVAFKFLLALSLSLLVSAAIAQTREPRLTQGQQCRFQRISGAQPSL...FalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
614310_ARATHP48347,Q9LME5[7165, 8150, 9742, 9755, 9987, 43401, 50789, 5...MENEREKQVYLAKLSEQTERYDEMVEAMKKVAQLDVELTVEERNLV...FalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
814331_ARATHP42643,Q945M2,Q9M0S7[8150, 19222, 50789, 65007]MATPGASSARDEFVYMAKLAEQAERYEEMVEFMEKVAKAVDKDELT...FalseFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
914331_CAEELP41932,Q21537[132, 226, 1708, 6611, 6810, 6886, 6913, 6950,...MSDTVEELVQRAKLAEQAERYDDMAAAMKKVTEQGQELSNEERNLL...FalseFalseFalseFalseFalseTrue...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
\n", + "

5 rows ร— 902 columns

\n", + "
" + ], + "text/plain": [ + " swiss_id accession \\\n", + "1 11S1_CARIL B5KVH4 \n", + "3 11S2_SESIN Q9XHP0 \n", + "6 14310_ARATH P48347,Q9LME5 \n", + "8 14331_ARATH P42643,Q945M2,Q9M0S7 \n", + "9 14331_CAEEL P41932,Q21537 \n", + "\n", + " go_ids \\\n", + "1 [3006, 8150, 9791, 10431, 21700, 22414, 32501,... \n", + "3 [3006, 8150, 10431, 21700, 22414, 32502, 48609] \n", + "6 [7165, 8150, 9742, 9755, 9987, 43401, 50789, 5... \n", + "8 [8150, 19222, 50789, 65007] \n", + "9 [132, 226, 1708, 6611, 6810, 6886, 6913, 6950,... \n", + "\n", + " sequence 41 75 122 \\\n", + "1 MAKPILLSIYLCLIIVALFNGCLAQSGGRQQHKFGQCQLNRLDALE... False False False \n", + "3 MVAFKFLLALSLSLLVSAAIAQTREPRLTQGQQCRFQRISGAQPSL... False False False \n", + "6 MENEREKQVYLAKLSEQTERYDEMVEAMKKVAQLDVELTVEERNLV... False False False \n", + "8 MATPGASSARDEFVYMAKLAEQAERYEEMVEFMEKVAKAVDKDELT... False False False \n", + "9 MSDTVEELVQRAKLAEQAERYDDMAAAMKKVTEQGQELSNEERNLL... False False False \n", + "\n", + " 165 209 226 ... 1990778 2000026 2000145 2000146 2000147 \\\n", + "1 False False False ... False False False False False \n", + "3 False False False ... False False False False False \n", + "6 False False False ... False False False False False \n", + "8 False False False ... False False False False False \n", + "9 False False True ... False False False False False \n", + "\n", + " 2000241 2000243 2001141 2001233 2001234 \n", + "1 False False False False False \n", + "3 False False False False False \n", + "6 False False False False False \n", + "8 False False False False False \n", + "9 False False False False False \n", + "\n", + "[5 rows x 902 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pkl_df = pd.DataFrame(\n", + " pd.read_pickle(\n", + " os.path.join(\n", + " go_class.processed_dir_main,\n", + " go_class.processed_dir_main_file_names_dict[\"data\"],\n", + " )\n", + " )\n", + ")\n", + "print(\"Size of the data (rows x columns): \", pkl_df.shape)\n", + "pkl_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "735844f0b2474ad6", + "metadata": {}, + "source": [ + "**File Path**: `data/GO_UniProt/${dataset_name}/processed/data.pkl`\n", + "\n", + "\n", + "### Structure of `data.pkl`\n", + "`data.pkl` as following structure: \n", + "- **Column 0**: Contains the Identifier from Swiss-UniProt Dataset for each Swiss Protein data instance.\n", + "- **Column 1**: Contains the accession of each Protein data instance.\n", + "- **Column 2**: Contains the list of GO-IDs (Identifiers from Gene Ontology) which maps each Swiss Protein to the Gene Ontology instance.\n", + "- **Column 3**: Contains the sequence representation for the Swiss Protein using Amino Acid notation.\n", + "- **Column 4 and onwards**: Contains the labels, starting from column 4.\n", + "\n", + "This structure ensures that the data is organized and ready for further processing, such as further encoding.\n" + ] + }, + { + "cell_type": "markdown", + "id": "2c9b17f6-93bd-4cc3-8967-7ab1d2e06e51", + "metadata": {}, + "source": [ + "## data.pt File\n", + "\n", + "**Description**: Generated by the `setup` method, this file contains encoded data in a format compatible with the PyTorch library. It includes keys such as `ident`, `features`, `labels`, and `group`, making it ready for model input." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "85b097601fb242d6", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-30T14:10:35.034002Z", + "start_time": "2024-09-30T14:10:35.018342Z" + } + }, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "289a54a71dec20fb", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-30T14:11:36.443693Z", + "start_time": "2024-09-30T14:11:34.199285Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of loaded data: \n", + "Content of the data file: \n", + " {'features': [10, 14, 21, 23, 12, 17, 17, 11, 12, 22, 17, 24, 17, 12, 12, 28, 14, 17, 25, 19, 13, 24, 17, 14, 18, 11, 13, 13, 16, 18, 18, 29, 21, 25, 13, 18, 24, 18, 17, 19, 16, 17, 20, 14, 17, 27, 23, 15, 19, 16, 12, 27, 14, 27, 14, 13, 28, 12, 27, 11, 26, 20, 23, 19, 29, 18, 18, 17, 18, 24, 14, 13, 28, 14, 28, 28, 16, 16, 15, 12, 27, 23, 19, 13, 17, 17, 17, 23, 29, 22, 11, 19, 14, 23, 18, 17, 28, 22, 12, 14, 16, 13, 16, 13, 12, 15, 13, 28, 17, 25, 23, 13, 24, 23, 27, 15, 25, 27, 27, 11, 18, 16, 18, 11, 18, 18, 13, 18, 16, 16, 27, 25, 18, 18, 20, 16, 29, 18, 21, 12, 16, 29, 25, 16, 27, 13, 20, 12, 12, 14, 25, 23, 14, 13, 28, 14, 29, 26, 24, 22, 19, 20, 13, 11, 11, 23, 28, 28, 14, 12, 25, 17, 17, 20, 15, 29, 19, 19, 14, 19, 18, 17, 20, 18, 19, 23, 16, 19, 25, 22, 17, 14, 13, 19, 23, 20, 20, 27, 25, 16, 23, 18, 13, 18, 18, 27, 22, 27, 18, 29, 16, 16, 18, 18, 18, 29, 18, 18, 16, 16, 13, 27, 29, 13, 27, 18, 18, 16, 20, 17, 13, 19, 19, 28, 25, 11, 13, 25, 20, 14, 27, 25, 17, 14, 20, 14, 25, 19, 28, 20, 15, 27, 15, 14, 16, 16, 17, 18, 11, 27, 19, 20, 29, 16, 13, 11, 12, 28, 16, 28, 27, 13, 16, 18, 17, 18, 28, 12, 16, 23, 16, 26, 11, 16, 27, 27, 18, 27, 29, 27, 27, 16, 21, 27, 16, 27, 16, 27, 16, 27, 11, 27, 11, 27, 16, 16, 18, 11, 16, 16, 13, 13, 16, 20, 20, 19, 13, 17, 27, 27, 15, 12, 24, 15, 17, 11, 17, 16, 27, 19, 12, 13, 20, 23, 11, 16, 14, 20, 12, 22, 15, 27, 27, 14, 13, 16, 12, 11, 15, 28, 19, 11, 29, 19, 17, 23, 12, 17, 16, 26, 17, 18, 17, 11, 14, 27, 16, 13, 14, 17, 22, 11, 20, 14, 17, 22, 28, 23, 29, 26, 19, 17, 19, 14, 29, 11, 28, 28, 22, 14, 17, 16, 13, 16, 14, 27, 28, 18, 28, 28, 20, 19, 25, 13, 18, 15, 28, 25, 20, 20, 27, 17, 16, 27, 13, 18, 17, 17, 15, 12, 23, 18, 19, 25, 14, 28, 28, 21, 16, 14, 16, 20, 27, 13, 25, 27, 26, 28, 11, 25, 21, 15, 19, 27, 19, 14, 10, 28, 11, 23, 17, 14, 13, 16, 15, 11, 14, 12, 16, 14, 17, 23, 27, 27, 28, 17, 28, 19, 14, 25, 18, 12, 23, 16, 27, 20, 14, 16, 16, 17, 21, 25, 19, 16, 18, 27, 11, 15, 17, 28, 16, 11, 16, 11, 16, 11, 11, 16, 11, 27, 16, 16, 14, 27, 28], 'labels': array([False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, True, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, 
False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, True, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, True, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, True, False, False, False, False, False, True,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, True, True, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, 
False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, True,\n", + " True, False, False, False, False, False, False, False, False,\n", + " True, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False, False, False,\n", + " False, False, False, False, False, False, False]), 'ident': '11S1_CARIL', 'group': None}\n" + ] + } + ], + "source": [ + "data_pt = torch.load(\n", + " os.path.join(go_class.processed_dir, go_class.processed_file_names_dict[\"data\"]),\n", + " weights_only=False,\n", + ")\n", + "print(\"Type of loaded data:\", type(data_pt))\n", + "print(\"Content of the data file: \\n\", data_pt[0])" + ] + }, + { + "cell_type": "markdown", + "id": "2c9f23883c66b48d", + "metadata": {}, + "source": [ + "**File Path**: `data/GO_UniProt/${dataset_name}/processed/${reader_name}/data.pt`\n", + "\n", + "The `data.pt` file is a list where each element is a dictionary with the following keys:\n", + "\n", + "- **`features`**: \n", + " - **Description**: This key holds the input features for the model. 
The features are typically stored as tensors and represent the attributes used by the model for training and evaluation.\n", + "\n", + "- **`labels`**: \n", + " - **Description**: This key contains the labels or target values associated with each instance. Labels are also stored as tensors and are used by the model to learn and make predictions.\n", + "\n", + "- **`ident`**: \n", + " - **Description**: This key holds identifiers for each data instance. These identifiers help track and reference the individual samples in the dataset.\n" + ] + }, + { + "cell_type": "markdown", + "id": "36aed0b8-ab05-428d-8833-2a24deebacc3", + "metadata": {}, + "source": [ + "## classes.txt File\n", + "\n", + "**Description**: This file lists the GO classes that are used as labels. It can be used to match labels in `data.pt` with GO classes: For position `i` in the label-tensor, the GO-ID is in line `i` of `classes.txt`" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "19200f7ff9a6ebba", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-30T21:30:34.344202Z", + "start_time": "2024-09-30T21:30:34.328318Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "41\n", + "75\n", + "122\n", + "165\n", + "209\n" + ] + } + ], + "source": [ + "with open(os.path.join(go_class.processed_dir_main, \"classes.txt\"), \"r\") as file:\n", + " for i in range(5):\n", + " line = file.readline()\n", + " print(line.strip())" + ] + }, + { + "cell_type": "markdown", + "id": "f69012b3540fd1b6", + "metadata": {}, + "source": [ + "**File Path**: `data/GO_UniProt/${dataset_name}/processed/classes.txt`\n", + "\n", + "The `classes.txt` file lists selected GO classes. These classes are chosen based on a specified threshold, which is typically used for filtering or categorizing the dataset. Each line in the file corresponds to a unique Swiss Protein class ID, identifying specific protein from Swiss-UniProt dataset." + ] + }, + { + "cell_type": "markdown", + "id": "b81ea34f-cfa8-4ffa-8b88-b54ca96afd84", + "metadata": {}, + "source": [ + "## splits.csv File\n", + "\n", + "**Description**: This file contains saved data splits from previous runs. During subsequent runs, it is used to reconstruct the train, validation, and test splits by filtering the encoded data (`data.pt`) based on the IDs stored in `splits.csv`." + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "88c3ea8f01ba9fac", + "metadata": { + "ExecuteTime": { + "end_time": "2024-09-30T21:30:41.586616Z", + "start_time": "2024-09-30T21:30:39.318598Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idsplit
014331_ARATHtrain
114331_CAEELtrain
214331_MAIZEtrain
314332_MAIZEtrain
414333_ARATHtrain
\n", + "
" + ], + "text/plain": [ + " id split\n", + "0 14331_ARATH train\n", + "1 14331_CAEEL train\n", + "2 14331_MAIZE train\n", + "3 14332_MAIZE train\n", + "4 14333_ARATH train" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "csv_df = pd.read_csv(os.path.join(go_class.processed_dir_main, \"splits.csv\"))\n", + "csv_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "6661dc11247e9753", + "metadata": {}, + "source": [ + "**File Path**: `data/GO_UniProt/${dataset_name}/processed/splits.csv`\n", + "\n", + "To reuse an existing split, you can use the `splits_file_path` argument. This way, you can reuse the same datasplit across several runs." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "2b02d8b4-c2de-4b8e-b680-ec67b40d9a30", + "metadata": {}, + "outputs": [], + "source": [ + "# You can specify a literal path for the `splits_file_path`, or if another `go_class` instance is already defined,\n", + "# you can use its existing `splits_file_path` attribute for consistency.\n", + "go_class_with_splits = GOUniProtOver250(\n", + " go_branch=\"BP\",\n", + " # splits_file_path=\"data/GO_UniProt/GO250_BP_1002/processed/splits.csv\", # Literal path option\n", + " splits_file_path=go_class.splits_file_path, # Use path from an existing `go_class` instance\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "e6b1f184a5091b83", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "481b8c0271ec9636", + "metadata": {}, + "source": [ + "## 5.1 Protein Representation Using Amino Acid Sequence Notation\n", + "\n", + "Proteins are composed of chains of amino acids, and these sequences can be represented using a one-letter notation for each amino acid. This notation provides a concise way to describe the primary structure of a protein.\n", + "\n", + "### Example Protein Sequence\n", + "\n", + "Protein: **Lysozyme C** from **Gallus gallus** (Chicken). \n", + "[Lysozyme C - UniProtKB P00698](https://www.uniprot.org/uniprotkb/P00698/entry#function)\n", + "\n", + "- **Sequence**: `MRSLLILVLCFLPLAALGKVFGRCELAAAMKRHGLDNYRGYSLGNWVCAAKFESNFNTQATNRNTDGSTDYGILQINSRWWCNDGRTPGSRNLCNIPCSALLSSDITASVNCAKKIVSDGNGMNAWVAWRNRCKGTDVQAWIRGCRL`\n", + "- **Sequence Length**: 147\n", + "\n", + "In this sequence, each letter corresponds to a specific amino acid. This notation is widely used in bioinformatics and molecular biology to represent protein sequences.\n", + "\n", + "### Tokenization and Encoding\n", + "\n", + "To tokenize and numerically encode this protein sequence, the `ProteinDataReader` class is used. This class allows for n-gram tokenization, where the `n_gram` parameter defines the size of the tokenized units. If `n_gram` is not provided (default is `None`), each amino acid letter is treated as a single token.\n", + "\n", + "For more details, you can explore the implementation of the `ProteinDataReader` class in the source code [here](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/reader.py)." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 19, + "id": "e0cf4fb6-2ca4-4b85-a4e7-0cfbac5cd6c1", + "metadata": {}, + "outputs": [], + "source": [ + "from chebai.preprocessing.reader import ProteinDataReader" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "id": "e8343d83-0be3-44df-9224-bba8d5c32336", + "metadata": {}, + "outputs": [], + "source": [ + "protein_dr_3gram = ProteinDataReader(n_gram=3)\n", + "protein_dr = ProteinDataReader()" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "id": "8a18dc27-f308-4dde-b1ae-b03a20fb0d45", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[10, 16, 11, 17, 17, 12, 17, 28, 17, 24, 25, 17, 23, 17, 14, 14, 17, 13, 21]\n", + "[30, 31, 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46]\n" + ] + } + ], + "source": [ + "protein = \"MRSLLILVLCFLPLAALGK\"\n", + "print(protein_dr._read_data(protein))\n", + "print(protein_dr_3gram._read_data(protein))" + ] + }, + { + "cell_type": "markdown", + "id": "7e95738a-0b2d-4c56-ac97-f3b24c1de18f", + "metadata": {}, + "source": [ + "The numbers mentioned above refer to the index of each individual token from the [`tokens.txt`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/bin/protein_token/tokens.txt) file, which is used by the `ProteinDataReader` class. \n", + "\n", + "Each token in the `tokens.txt` file corresponds to a specific amino-acid letter, and these tokens are referenced by their index. Additionally, the index values are offset by the `EMBEDDING_OFFSET`, ensuring that the token embeddings are adjusted appropriately during processing." + ] + }, + { + "cell_type": "markdown", + "id": "fd54ca4a-743c-496e-9e89-cff2d8226eb2", + "metadata": {}, + "source": [ + "### The 20 Amino Acids and Their One-Letter Notations\n", + "\n", + "Here is a list of the 20 standard amino acids, along with their one-letter notations and descriptions:\n", + "\n", + "| One-Letter Notation | Amino Acid Name | Description |\n", + "|---------------------|----------------------|---------------------------------------------------------|\n", + "| **A** | Alanine | Non-polar, aliphatic amino acid. |\n", + "| **C** | Cysteine | Polar, contains a thiol group, forms disulfide bonds. |\n", + "| **D** | Aspartic Acid | Acidic, negatively charged at physiological pH. |\n", + "| **E** | Glutamic Acid | Acidic, negatively charged at physiological pH. |\n", + "| **F** | Phenylalanine | Aromatic, non-polar. |\n", + "| **G** | Glycine | Smallest amino acid, non-polar. |\n", + "| **H** | Histidine | Polar, positively charged, can participate in enzyme active sites. |\n", + "| **I** | Isoleucine | Non-polar, aliphatic. |\n", + "| **K** | Lysine | Basic, positively charged at physiological pH. |\n", + "| **L** | Leucine | Non-polar, aliphatic. |\n", + "| **M** | Methionine | Non-polar, contains sulfur, start codon in mRNA translation. |\n", + "| **N** | Asparagine | Polar, uncharged. |\n", + "| **P** | Proline | Non-polar, introduces kinks in protein chains. |\n", + "| **Q** | Glutamine | Polar, uncharged. |\n", + "| **R** | Arginine | Basic, positively charged, involved in binding phosphate groups. |\n", + "| **S** | Serine | Polar, can be phosphorylated. |\n", + "| **T** | Threonine | Polar, can be phosphorylated. |\n", + "| **V** | Valine | Non-polar, aliphatic. |\n", + "| **W** | Tryptophan | Aromatic, non-polar, largest amino acid. |\n", + "| **Y** | Tyrosine | Aromatic, polar, can be phosphorylated. 
|\n", + "\n", + "### Understanding Protein Sequences\n", + "\n", + "In the example sequence, each letter represents one of the above amino acids. The sequence reflects the specific order of amino acids in the protein, which is critical for its structure and function.\n", + "\n", + "This notation is used extensively in various bioinformatics tools and databases to study protein structure, function, and interactions.\n", + "\n", + "\n", + "_Note_: Refer for amino acid sequence: https://en.wikipedia.org/wiki/Protein_primary_structure" + ] + }, + { + "cell_type": "markdown", + "id": "db6d7f2cc446e6f9", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "7f42b928364e5cd1", + "metadata": {}, + "source": [ + "## 5.2 More on GO Evidence Codes\n", + "\n", + "The **Gene Ontology (GO) Evidence Codes** provide a way to indicate the level of evidence supporting a GO annotation. Here's a list of the GO evidence codes with brief descriptions:\n", + "\n", + "| **Evidence Code** | **Description** |\n", + "|-----------------------|-----------------|\n", + "| **EXP** | [Inferred from Experiment (EXP)](http://wiki.geneontology.org/index.php/Inferred_from_Experiment_(EXP)) |\n", + "| **IDA** | [Inferred from Direct Assay (IDA)](http://wiki.geneontology.org/index.php/Inferred_from_Direct_Assay_(IDA)) |\n", + "| **IPI** | [Inferred from Physical Interaction (IPI)](http://wiki.geneontology.org/index.php/Inferred_from_Physical_Interaction_(IPI)) |\n", + "| **IMP** | [Inferred from Mutant Phenotype (IMP)](http://wiki.geneontology.org/index.php/Inferred_from_Mutant_Phenotype_(IMP)) |\n", + "| **IGI** | [Inferred from Genetic Interaction (IGI)](http://wiki.geneontology.org/index.php/Inferred_from_Genetic_Interaction_(IGI)) |\n", + "| **IEP** | [Inferred from Expression Pattern (IEP)](http://wiki.geneontology.org/index.php/Inferred_from_Expression_Pattern_(IEP)) |\n", + "| **HTP** | [Inferred from High Throughput Experiment (HTP)](http://wiki.geneontology.org/index.php/Inferred_from_High_Throughput_Experiment_(HTP) ) |\n", + "| **HDA** | [Inferred from High Throughput Direct Assay (HDA)](http://wiki.geneontology.org/index.php/Inferred_from_High_Throughput_Direct_Assay_(HDA)) |\n", + "| **HMP** | [Inferred from High Throughput Mutant Phenotype (HMP)](http://wiki.geneontology.org/index.php/Inferred_from_High_Throughput_Mutant_Phenotype_(HMP)) |\n", + "| **HGI** | [Inferred from High Throughput Genetic Interaction (HGI)](http://wiki.geneontology.org/index.php/Inferred_from_High_Throughput_Genetic_Interaction_(HGI)) |\n", + "| **HEP** | [Inferred from High Throughput Expression Pattern (HEP)](http://wiki.geneontology.org/index.php/Inferred_from_High_Throughput_Expression_Pattern_(HEP)) |\n", + "| **IBA** | [Inferred from Biological aspect of Ancestor (IBA)](http://wiki.geneontology.org/index.php/Inferred_from_Biological_aspect_of_Ancestor_(IBA)) |\n", + "| **IBD** | [Inferred from Biological aspect of Descendant (IBD)](http://wiki.geneontology.org/index.php/Inferred_from_Biological_aspect_of_Descendant_(IBD)) |\n", + "| **IKR** | [Inferred from Key Residues (IKR)](http://wiki.geneontology.org/index.php/Inferred_from_Key_Residues_(IKR)) |\n", + "| **IRD** | [Inferred from Rapid Divergence (IRD)](http://wiki.geneontology.org/index.php/Inferred_from_Rapid_Divergence(IRD)) |\n", + "| **ISS** | [Inferred from Sequence or Structural Similarity (ISS)](http://wiki.geneontology.org/index.php/Inferred_from_Sequence_or_structural_Similarity_(ISS)) |\n", + "| **ISO** | [Inferred from Sequence 
Orthology (ISO)](http://wiki.geneontology.org/index.php/Inferred_from_Sequence_Orthology_(ISO)) |\n", + "| **ISA** | [Inferred from Sequence Alignment (ISA)](http://wiki.geneontology.org/index.php/Inferred_from_Sequence_Alignment_(ISA)) |\n", + "| **ISM** | [Inferred from Sequence Model (ISM)](http://wiki.geneontology.org/index.php/Inferred_from_Sequence_Model_(ISM)) |\n", + "| **RCA** | [Inferred from Reviewed Computational Analysis (RCA)](http://wiki.geneontology.org/index.php/Inferred_from_Reviewed_Computational_Analysis_(RCA)) |\n", + "| **IEA** | [Inferred from Electronic Annotation (IEA)](http://wiki.geneontology.org/index.php/Inferred_from_Electronic_Annotation_(IEA)) |\n", + "| **TAS** | [Traceable Author Statement (TAS)](http://wiki.geneontology.org/index.php/Traceable_Author_Statement_(TAS)) |\n", + "| **NAS** | [Non-traceable Author Statement (NAS)](http://wiki.geneontology.org/index.php/Non-traceable_Author_Statement_(NAS)) |\n", + "| **IC** | [Inferred by Curator (IC)](http://wiki.geneontology.org/index.php/Inferred_by_Curator_(IC)) |\n", + "| **ND** | [No Biological Data Available (ND)](http://wiki.geneontology.org/index.php/No_biological_Data_available_(ND)_evidence_code) |\n", + "| **NR** | Not Recorded |\n", + "\n", + "\n", + "### **Grouping of Codes**:\n", + "\n", + "- **Experimental Evidence Codes**:\n", + " - **EXP**, **IDA**, **IPI**, **IMP**, **IGI**, **IEP**\n", + " \n", + "- **High-Throughput Experimental Codes**:\n", + " - **HTP**, **HDA**, **HMP**, **HGI**, **HEP**\n", + "\n", + "- **Phylogenetically-Inferred Codes**:\n", + " - **IBA**, **IBD**, **IKR**, **IRD**\n", + "\n", + "- **Author/Curator Inferred Codes**:\n", + " - **TAS**, **IC**, **NAS**\n", + "\n", + "- **Computational Evidence Codes**:\n", + " - **IEA**, **ISS**, **ISA**, **ISM**, **ISO**, **RCA**\n", + "\n", + "- **Others**:\n", + " - **ND** (No Biological Data Available), **NR** (Not Recorded)\n", + "\n", + "\n", + "These evidence codes ensure transparency and give researchers an understanding of how confident they can be in a particular GO annotation.\n", + "\n", + "__Note__ : For more information on GO evidence codes please check [here](https://geneontology.org/docs/guide-go-evidence-codes/) " + ] + }, + { + "cell_type": "markdown", + "id": "1c11d6f520b02434", + "metadata": {}, + "source": [ + "---" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/tutorials/data_exploration_scope.ipynb b/tutorials/data_exploration_scope.ipynb new file mode 100644 index 0000000..c14046a --- /dev/null +++ b/tutorials/data_exploration_scope.ipynb @@ -0,0 +1,1182 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "0bd757ea-a6a0-43f8-8701-cafb44f20f6b", + "metadata": {}, + "source": [ + "# Introduction\n", + "\n", + "This notebook serves as a guide for new developers using the `chebai` package. If you just want to run the experiments, you can refer to the [README.md](https://github.com/ChEB-AI/python-chebai/blob/dev/README.md) and the [wiki](https://github.com/ChEB-AI/python-chebai/wiki) for the basic commands. This notebook explains what happens under the hood for the SCOPe dataset. 
It covers\n", + "- how to instantiate a data class and generate data\n", + "- how the data is processed and stored\n", + "- and how to work with different molecule encodings.\n", + "\n", + "The `chebai` package simplifies the handling of these datasets by **automatically downloading and processing** them as needed. This means that you do not have to input any data manually; the package will generate and organize the data files based on the parameters and encodings selected. You can however provide your own data files, for instance if you want to replicate a specific experiment.\n", + "\n", + "---\n" + ] + }, + { + "cell_type": "markdown", + "id": "cca637ce-d4ea-4365-acd9-657418e0640f", + "metadata": {}, + "source": [ + "### Overview of SCOPe Data and its Usage in Protein-Related Tasks\n", + "\n", + "#### **What is SCOPe?**\n", + "\n", + "The **Structural Classification of Proteins โ€” extended (SCOPe)** is a comprehensive database that extends the original SCOP (Structural Classification of Proteins) database. SCOPe offers a detailed classification of protein domains based on their structural and evolutionary relationships.\n", + "\n", + "The SCOPe database, like SCOP, organizes proteins into a hierarchy of domains based on structural similarities, which is crucial for understanding evolutionary patterns and functional aspects of proteins. This hierarchical structure is comparable to taxonomy in biology, where species are classified based on shared characteristics.\n", + "\n", + "#### **SCOPe Hierarchy:**\n", + "By analogy with taxonomy, SCOP was created as a hierarchy of several levels where the fundamental unit of classification is a **domain** in the experimentally determined protein structure. Starting at the bottom, the hierarchy of SCOP domains comprises the following levels:\n", + "\n", + "1. **Species**: Representing distinct protein sequences and their naturally occurring or artificially created variants.\n", + "2. **Protein**: Groups together similar sequences with essentially the same functions. These can originate from different biological species or represent isoforms within the same species.\n", + "3. **Family**: Contains proteins with similar sequences but typically distinct functions.\n", + "4. **Superfamily**: Bridges protein families with common functional and structural features, often inferred from a shared evolutionary ancestor.\n", + "5. **Fold**: Groups structurally similar superfamilies. \n", + "6. **Class**: Based on secondary structure content and organization. This level classifies proteins based on their secondary structure properties, such as alpha-helices and beta-sheets.\n", + "\n", + "\n", + "\n", + "For more details, you can refer to the [SCOPe documentation](https://scop.berkeley.edu/help/ver=2.08).\n", + "\n", + "---\n", + "\n", + "#### **Why are We Using SCOPe?**\n", + "\n", + "We are integrating the SCOPe data into our pipeline as part of an ontology pretraining task for protein-related models. SCOPe is a great fit for our goal because it is primarily **structure-based**, unlike other protein-related databases like Gene Ontology (GO), which focuses more on functional classes.\n", + "\n", + "Our primary objective is to reproduce **ontology pretraining** on a protein-related task, and SCOPe provides the structural ontology that we need for this. 
The steps in our pipeline are aligned as follows:\n", + "\n", + "| **Stage** | **Chemistry Task** | **Proteins Task** |\n", + "|--------------------------|-------------------------------------|------------------------------------------------|\n", + "| **Unsupervised Pretraining** | Mask pretraining (ELECTRA) | Mask pretraining (ESM2, optional) |\n", + "| **Ontology Pretraining** | ChEBI | SCOPe |\n", + "| **Finetuning Task** | Toxicity, Solubility, etc. | GO (MF, BP, CC branches) |\n", + "\n", + " \n", + "This integration will allow us to use **SCOPe** for tasks such as **protein classification** and will contribute to the success of **pretraining models** for protein structures. The data will be processed with the same approach as the GO data, with **different labels** corresponding to the SCOPe classification system.\n", + "\n", + "---\n", + "\n", + "#### **Why SCOPe is Suitable for Our Task**\n", + "\n", + "1. **Structure-Based Classification**: SCOPe is primarily concerned with the structural characteristics of proteins, making it ideal for protein structure pretraining tasks. This contrasts with other ontology databases like **GO**, which categorize proteins based on more complex functional relationships.\n", + " \n", + "2. **Manageable Size**: SCOPe contains around **140,000 entries**, making it a manageable dataset for training models. This is similar in size to **ChEBI**, which is used in the chemical domain, and ensures we can work with it effectively for pretraining." + ] + }, + { + "cell_type": "markdown", + "id": "338e452f-426c-493d-bec2-5bd51e24e4aa", + "metadata": {}, + "source": [ + "\n", + "### Protein Data Bank (PDB)\n", + "\n", + "The **Protein Data Bank (PDB)** is a global repository that stores 3D structural data of biological macromolecules like proteins and nucleic acids. It contains information obtained through experimental methods such as **X-ray crystallography**, **NMR spectroscopy**, and **cryo-EM**. The data includes atomic coordinates, secondary structure details, and experimental conditions.\n", + "\n", + "The PDB is an essential resource for **structural biology**, **bioinformatics**, and **drug discovery**, enabling scientists to understand protein functions, interactions, and mechanisms at the molecular level.\n", + "\n", + "For more details, visit the [RCSB PDB website](https://www.rcsb.org/).\n" + ] + }, + { + "cell_type": "markdown", + "id": "f6c25706-251c-438c-9915-e8002647eb94", + "metadata": {}, + "source": [ + "### Understanding [SCOPe](https://scop.berkeley.edu/) and [PDB](https://www.rcsb.org/) \n", + "\n", + "\n", + "1. **Protein domains form chains.** \n", + "2. **Chains form complexes** (protein complexes or structures). \n", + "3. These **complexes are the entries in PDB**, represented by unique identifiers like `\"1A3N\"`. \n", + "\n", + "---\n", + "\n", + "#### **Protein Domain** \n", + "A **protein domain** is a **structural and functional unit** of a protein. \n", + "\n", + "\n", + "##### Key Characteristics:\n", + "- **Domains are part of a protein chain.** \n", + "- A domain can span: \n", + " 1. **The entire chain** (single-domain protein): \n", + " - In this case, the protein domain is equivalent to the chain itself. \n", + " - Example: \n", + " - All chains of the **PDB structure \"1A3N\"** are single-domain proteins. \n", + " - Each chain has a SCOPe domain identifier. \n", + " - For example, Chain **A**: \n", + " - Domain identifier: `d1a3na_` \n", + " - Breakdown of the identifier: \n", + " - `d`: Denotes domain. 
\n", + " - `1a3n`: Refers to the PDB protein structure identifier. \n", + " - `a`: Specifies the chain within the structure. (`_` for None and `.` for multiple chains)\n", + " - `_`: Indicates the domain spans the entire chain (single-domain protein). \n", + " - Example: [PDB Structure 1A3N - Chain A](https://www.rcsb.org/sequence/1A3N#A)\n", + " 2. **A specific portion of the chain** (multi-domain protein): \n", + " - Here, a single chain contains multiple domains. \n", + " - Example: Chain **A** of the **PDB structure \"1PKN\"** contains three domains: `d1pkna1`, `d1pkna2`, `d1pkna3`. \n", + " - Example: [PDB Structure 1PKN - Chain A](https://www.rcsb.org/annotations/1PKN). \n", + "\n", + "---\n", + "\n", + "#### **Protein Chain** \n", + "A **protein chain** refers to the entire **polypeptide chain** observed in a protein's 3D structure (as described in PDB files). \n", + "\n", + "##### Key Points:\n", + "- A chain can consist of **one or multiple domains**:\n", + " - **Single-domain chain**: The chain and domain are identical. \n", + " - Example: Myoglobin. \n", + " - **Multi-domain chain**: Contains several domains, each with distinct structural and functional roles. \n", + "- Chains assemble to form **protein complexes** or **structures**. \n", + "\n", + "\n", + "---\n", + "\n", + "#### **Key Observations About SCOPe** \n", + "- The **fundamental classification unit** in SCOPe is the **protein domain**, not the entire protein. \n", + "- _**The taxonomy in SCOPe is not for the entire protein (i.e., the full-length amino acid sequence as encoded by a gene) but for protein domains, which are smaller, structurally and functionally distinct regions of the protein.**_\n", + "\n", + "\n", + "--- \n", + "\n", + "**SCOPe 2.08 Data Analysis:**\n", + "\n", + "The current SCOPe version (2.08) includes the following statistics based on analysis for relevant data:\n", + "\n", + "- **Classes**: 12\n", + "- **Folds**: 1485\n", + "- **Superfamilies**: 2368\n", + "- **Families**: 5431\n", + "- **Proteins**: 13,514\n", + "- **Species**: 30,294\n", + "- **Domains**: 344,851\n", + "\n", + "For more detailed statistics, please refer to the official SCOPe website:\n", + "\n", + "- [SCOPe 2.08 Statistics](https://scop.berkeley.edu/statistics/ver=2.08)\n", + "- [SCOPe 2.08 Release](https://scop.berkeley.edu/ver=2.08)\n", + "\n", + "---\n", + "\n", + "## SCOPe Labeling \n", + "\n", + "- Use SCOPe labels for protein domains.\n", + "- Map them back to their **protein-chain** sequences (protein sequence label = sum of all domain labels).\n", + "- Train on protein sequences.\n", + "- This pretraining task would be comparable to GO-based training.\n", + "\n", + "--- " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "990cc6f2-6b4a-4fa7-905f-dda183c3ec4c", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Changed to project root directory: G:\\github-aditya0by0\\python-chebai\n" + ] + } + ], + "source": [ + "# To run this notebook, you need to change the working directory of the jupyter notebook to root dir of the project.\n", + "import os\n", + "\n", + "# Root directory name of the project\n", + "expected_root_dir = \"python-chebai\"\n", + "\n", + "# Check if the current directory ends with the expected root directory name\n", + "if not os.getcwd().endswith(expected_root_dir):\n", + " os.chdir(\"..\") # Move up one directory level\n", + " if os.getcwd().endswith(expected_root_dir):\n", + " print(\"Changed to project root directory:\", 
os.getcwd())\n", + " else:\n", + " print(\"Warning: Directory change unsuccessful. Current directory:\", os.getcwd())\n", + "else:\n", + " print(\"Already in the project root directory:\", os.getcwd())" + ] + }, + { + "cell_type": "markdown", + "id": "4550d01fc7af5ae4", + "metadata": {}, + "source": [ + "# 1. Instantiation of a Data Class\n", + "\n", + "To start working with `chebai`, you first need to instantiate a SCOPe data class. This class is responsible for managing, interacting with, and preprocessing the ChEBI chemical data." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "f3a66e07-edc9-4aa2-9cd0-d4ea58914d22", + "metadata": {}, + "outputs": [], + "source": [ + "from chebai.preprocessing.datasets.scope.scope import SCOPeOver50" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a71b7301-6195-4155-a439-f5eb3183d0f3", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:07:26.371796Z", + "start_time": "2024-10-05T21:07:26.058728Z" + } + }, + "outputs": [], + "source": [ + "scope_class = SCOPeOver50(scope_version=\"2.08\")" + ] + }, + { + "cell_type": "markdown", + "id": "b810d7c9-4f7f-4725-9bc2-452ff2c3a89d", + "metadata": {}, + "source": [ + "\n", + "### Inheritance Hierarchy\n", + "\n", + "SCOPe data classes inherit from [`_DynamicDataset`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L598), which in turn inherits from [`XYBaseDataModule`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L23). Specifically:\n", + "\n", + "- **`_DynamicDataset`**: This class serves as an intermediate base class that provides additional functionality or customization for datasets that require dynamic behavior. It inherits from `XYBaseDataModule`, which provides the core methods for data loading and processing.\n", + "\n", + "- **`XYBaseDataModule`**: This is the base class for data modules, providing foundational properties and methods for handling and processing datasets, including data splitting, loading, and preprocessing.\n", + "\n", + "In summary, ChEBI data classes are designed to manage and preprocess chemical data effectively by leveraging the capabilities provided by `XYBaseDataModule` through the `_DynamicDataset` intermediary.\n", + "\n", + "\n", + "### Input parameters\n", + "A SCOPe data class can be configured with a range of parameters, including:\n", + "\n", + "- **scope_version (str)**: Specifies the version of the ChEBI database to be used. Specifying a version ensures the reproducibility of your experiments by using a consistent dataset.\n", + "\n", + "- **scope_version_train (str, optional)**: The version of ChEBI to use specifically for training and validation. If not set, the `scope_version` specified will be used for all data splits, including training, validation, and test. Defaults to `None`.\n", + "\n", + "- **splits_file_path (str, optional)**: Path to a CSV file containing data splits. If not provided, the class will handle splits internally. 
Defaults to `None`.\n", + "\n", + "### Additional Input Parameters\n", + "\n", + "To get more control over various aspects of data loading, processing, and splitting, you can refer to documentation of additional parameters in docstrings of the respective classes: [`_SCOPeDataExtractor`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/scope/scope.py#L31), [`XYBaseDataModule`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L22), [`_DynamicDataset`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/base.py#L597), etc.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8578b7aa-1bd9-4e50-9eee-01bfc6d5464a", + "metadata": {}, + "source": [ + "# Available SCOPe Data Classes\n", + "\n", + "__Note__: Check the code implementation of classes [here](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/datasets/scope/scope.py):\n", + "\n", + "There is a range of available dataset classes for SCOPe. Usually, you want to use `SCOPeOver2000` or `SCOPeOver50`. The number indicates the threshold for selecting label classes: SCOPe classes which have at least 2000 / 50 subclasses will be used as labels.\n", + "\n", + "Both inherit from `SCOPeOverX`. If you need a different threshold, you can create your own subclass. By default, `SCOPeOverX` uses the Protein encoding (see Section 5).\n", + "\n", + "Finally, `SCOPeOver2000Partial` selects extracts a part of SCOPe based on a given top class, with a threshold of 2000 for selecting labels.\n", + "This class inherits from `SCOPEOverXPartial`.\n" + ] + }, + { + "cell_type": "markdown", + "id": "8456b545-88c5-401d-baa5-47e8ae710f04", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ed973fb59df11849", + "metadata": {}, + "source": [ + "# 2. Preparation / Setup Methods\n", + "\n", + "Now we have a SCOPe data class with all the relevant parameters. Next, we need to generate the actual dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "11f2208e-fa40-44c9-bfe7-576ca23ad366", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Checking for processed data in data\\SCOPe\\version_2.08\\SCOPe50\\processed\n", + "Missing processed data file (`data.pkl` file)\n", + "Missing PDB raw data, Downloading PDB sequence data....\n", + "Downloading to temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", + "Downloaded to C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", + "Unzipping the file....\n", + "Unpacked and saved to data\\SCOPe\\pdb_sequences.txt\n", + "Removed temporary file C:\\Users\\HP\\AppData\\Local\\Temp\\tmpsif7r129\n", + "Missing Scope: cla.txt raw data, Downloading...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "G:\\anaconda3\\envs\\env_chebai\\lib\\site-packages\\urllib3\\connectionpool.py:1099: InsecureRequestWarning: Unverified HTTPS request is being made to host 'scop.berkeley.edu'. Adding certificate verification is strongly advised. 
See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#tls-warnings\n", + "warnings.warn(\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing Scope: hie.txt raw data, Downloading...\n", + "Missing Scope: des.txt raw data, Downloading...\n", + "Extracting class hierarchy...\n", + "Computing transitive closure\n", + "Process graph\n", + "101 labels has been selected for specified threshold, \n", + "Constructing data.pkl file .....\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Check for processed data in data\\SCOPe\\version_2.08\\SCOPe50\\processed\\protein_token\n", + "Cross-validation enabled: False\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Missing transformed data (`data.pt` file). Transforming data.... \n", + "Processing 60298 lines...\n" + ] + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "100%|โ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆโ–ˆ| 60298/60298 [00:53<00:00, 1119.10it/s]\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Saving 21 tokens to G:\\github-aditya0by0\\python-chebai\\chebai\\preprocessing\\bin\\protein_token\\tokens.txt...\n", + "First 10 tokens: ['M', 'S', 'I', 'G', 'A', 'T', 'R', 'L', 'Q', 'N']\n" + ] + } + ], + "source": [ + "scope_class.prepare_data()\n", + "scope_class.setup()" + ] + }, + { + "cell_type": "markdown", + "id": "1655d489-25fe-46de-9feb-eeca5d36936f", + "metadata": {}, + "source": [ + "\n", + "### Automatic Execution: \n", + "These methods are executed automatically when using the training command `chebai fit`. Users do not need to call them explicitly, as the code internally manages the preparation and setup of data, ensuring that it is ready for subsequent use in training and validation processes.\n", + "\n", + "### Why is Preparation Needed?\n", + "\n", + "- **Data Availability**: The preparation step ensures that the required SCOPe data files are downloaded or loaded, which are essential for analysis.\n", + "- **Data Integrity**: It ensures that the data files are transformed into a compatible format required for model input.\n", + "\n", + "### Main Methods for Data Preprocessing\n", + "\n", + "The data preprocessing in a data class involves two main methods:\n", + "\n", + "1. **`prepare_data` Method**:\n", + " - **Purpose**: This method checks for the presence of raw data in the specified directory. If the raw data is missing, it fetches the ontology, creates a dataframe, and saves it to a file (`data.pkl`). The dataframe includes columns such as IDs, data representations, and labels. This step is independent of input encodings.\n", + " - **Documentation**: [PyTorch Lightning - `prepare_data`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#prepare-data)\n", + "\n", + "2. **`setup` Method**:\n", + " - **Purpose**: This method sets up the data module for training, validation, and testing. It checks for the processed data and, if necessary, performs additional setup to ensure the data is ready for model input. It also handles cross-validation settings if enabled.\n", + " - **Description**: Transforms `data.pkl` into a model input data format (`data.pt`), tokenizing the input according to the specified encoding. The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. 
This method uses a subclass of Data Reader to perform the tokenization.\n", + " - **Documentation**: [PyTorch Lightning - `setup`](https://lightning.ai/docs/pytorch/stable/data/datamodule.html#setup)\n", + "\n", + "These methods ensure that the data is correctly prepared and set up for subsequent use in training and validation processes." + ] + }, + { + "cell_type": "markdown", + "id": "f5aaa12d-5f01-4b74-8b59-72562af953bf", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "bb6e9a81554368f7", + "metadata": {}, + "source": [ + "# 3. Overview of the 3 preprocessing stages\n", + "\n", + "The `chebai` library follows a three-stage preprocessing pipeline, which is reflected in its file structure:\n", + "\n", + "1. **Raw Data Stage**:\n", + " - **Files**: `cla.txt`, `des.txt` and `hie.txt`. Please find description of each file [here](https://scop.berkeley.edu/help/ver=2.08#parseablefiles-2.08).\n", + " - **Description**: This stage contains the raw SCOPe data in txt format, serving as the initial input for further processing.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/raw/${filename}.txt`\n", + "\n", + "2. **Processed Data Stage 1**:\n", + " - **File**: `data.pkl`\n", + " - **Description**: This stage includes the data after initial processing. It contains protein sequence strings, class columns, and metadata but lacks data splits.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/data.pkl`\n", + " - **Additional File**: `classes.txt` - A file listing the relevant SCOPe classes.\n", + "\n", + "3. **Processed Data Stage 2**:\n", + " - **File**: `data.pt`\n", + " - **Description**: This final stage includes the encoded data in a format compatible with PyTorch, ready for model input. This stage also references data splits when available.\n", + " - **File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/${reader_name}/data.pt`\n", + " - **Additional File**: `splits.csv` - Contains saved splits for reproducibility.\n", + "\n", + "This structured approach to data management ensures that each stage of data processing is well-organized and documented, from raw data acquisition to the preparation of model-ready inputs. It also facilitates reproducibility and traceability across different experiments.\n", + "\n", + "### Data Splits\n", + "\n", + "- **Creation**: Data splits are generated dynamically \"on the fly\" during training and evaluation to ensure flexibility and adaptability to different tasks.\n", + "- **Reproducibility**: To maintain consistency across different runs, splits can be reproduced by comparing hashes with a fixed seed value.\n" + ] + }, + { + "cell_type": "markdown", + "id": "7e172c0d1e8bb93f", + "metadata": {}, + "source": [ + "# 4. Data Files and their structure\n", + "\n", + "`chebai` creates and manages several data files during its operation. These files store various chemical data and metadata essential for different tasks. 
Letโ€™s explore these files and their content.\n" + ] + }, + { + "cell_type": "markdown", + "id": "43329709-5134-4ce5-88e7-edd2176bf84d", + "metadata": {}, + "source": [ + "## raw files\n", + "- cla.txt, des.txt and hie.txt\n", + "\n", + "For detailed description of raw files and their structures, please refer the official website [here](https://scop.berkeley.edu/help/ver=2.08#parseablefiles-2.08).\n" + ] + }, + { + "cell_type": "markdown", + "id": "558295e5a7ded456", + "metadata": {}, + "source": [ + "## data.pkl File\n", + "\n", + "**Description**: Generated by the `prepare_data` method, this file contains processed data in a dataframe format. It includes the ids, sids which are used to label corresponding sequence, protein-chain sequence, and columns for each label with boolean values." + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "fd490270-59b8-4c1c-8b09-204defddf592", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:09:01.622317Z", + "start_time": "2024-10-05T21:09:01.606698Z" + } + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "d7d16247-092c-4e8d-96c2-ab23931cf766", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:11:51.296162Z", + "start_time": "2024-10-05T21:11:44.559304Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Size of the data (rows x columns): (60424, 1035)\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
idsidssequenceclass_46456class_48724class_51349class_53931class_56572class_56835class_56992...species_187294species_56257species_186882species_56690species_161316species_57962species_58067species_267696species_311502species_311501
01[d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ...AAAAAAAAAAFalseTrueFalseFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
12[d7dxhc_]AAAAAAAAAAAAAAAAAAAAAAAFalseFalseFalseFalseFalseTrueFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseFalse
23[d1gkub1, d1gkub2, d1gkub3, d1gkub4]AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF...FalseFalseTrueFalseTrueFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseTrue
34[d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3]AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV...FalseFalseFalseTrueFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseTrue
45[d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2]AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK...FalseFalseTrueFalseFalseFalseFalse...FalseFalseFalseFalseFalseFalseFalseFalseFalseTrue
\n", + "

5 rows ร— 1035 columns

\n", + "
" + ], + "text/plain": [ + " id sids \\\n", + "0 1 [d4oq9a_, d4oq9b_, d4oq9c_, d4oq9d_, d4niaa_, ... \n", + "1 2 [d7dxhc_] \n", + "2 3 [d1gkub1, d1gkub2, d1gkub3, d1gkub4] \n", + "3 4 [d3c9wa2, d3c9wb2, d3c9wa3, d3c9wb3] \n", + "4 5 [d1xwaa1, d1xwab_, d1xwac_, d1xwad_, d1xwaa2] \n", + "\n", + " sequence class_46456 \\\n", + "0 AAAAAAAAAA False \n", + "1 AAAAAAAAAAAAAAAAAAAAAAA False \n", + "2 AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAASLCLFPEDFLLKEF... False \n", + "3 AAAAAAGPEMVRGQVFDVGPRYTNLSYIGEGAYGMVCSAYDNLNKV... False \n", + "4 AAAAAMVYQVKDKADLDGQLTKASGKLVVLDFFATWCGPCKMISPK... False \n", + "\n", + " class_48724 class_51349 class_53931 class_56572 class_56835 \\\n", + "0 True False False False False \n", + "1 False False False False True \n", + "2 False True False True False \n", + "3 False False True False False \n", + "4 False True False False False \n", + "\n", + " class_56992 ... species_187294 species_56257 species_186882 \\\n", + "0 False ... False False False \n", + "1 False ... False False False \n", + "2 False ... False False False \n", + "3 False ... False False False \n", + "4 False ... False False False \n", + "\n", + " species_56690 species_161316 species_57962 species_58067 \\\n", + "0 False False False False \n", + "1 False False False False \n", + "2 False False False False \n", + "3 False False False False \n", + "4 False False False False \n", + "\n", + " species_267696 species_311502 species_311501 \n", + "0 False False False \n", + "1 False False False \n", + "2 False False True \n", + "3 False False True \n", + "4 False False True \n", + "\n", + "[5 rows x 1035 columns]" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "pkl_df = pd.DataFrame(\n", + " pd.read_pickle(\n", + " os.path.join(\n", + " scope_class.processed_dir_main,\n", + " scope_class.processed_main_file_names_dict[\"data\"],\n", + " )\n", + " )\n", + ")\n", + "print(\"Size of the data (rows x columns): \", pkl_df.shape)\n", + "pkl_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "322bc926-69ff-4b93-9e95-5e8b85869c38", + "metadata": {}, + "source": [ + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/data.pkl`\n", + "\n", + "\n", + "### Structure of `data.pkl`\n", + "`data.pkl` as following structure: \n", + "- **Column 0**: Contains the ID of eachdata instance.\n", + "- **Column 1**: Contains the `sids` which are associated with corresponding protein-chain sequence.\n", + "- **Column 2**: Contains the protein-chain sequence.\n", + "- **Column 3 and onwards**: Contains the labels, starting from column 3.\n", + "\n", + "This structure ensures that the data is organized and ready for further processing, such as further encoding.\n" + ] + }, + { + "cell_type": "markdown", + "id": "ba019d2d4324bd0b", + "metadata": {}, + "source": [ + "## data.pt File\n", + "\n", + "\n", + "**Description**: Generated by the `setup` method, this file contains encoded data in a format compatible with the PyTorch library, specifically as a list of dictionaries. Each dictionary in this list includes keys such as `ident`, `features`, `labels`, and `group`, ready for model input." 
+ ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "977ddd83-b469-4b58-ab1a-8574fb8769b4", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:12:49.338943Z", + "start_time": "2024-10-05T21:12:49.323319Z" + } + }, + "outputs": [], + "source": [ + "import torch" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3266ade9-efdc-49fe-ae07-ed52b2eb52d0", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:14:12.892845Z", + "start_time": "2024-10-05T21:13:59.859953Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Type of loaded data: \n" + ] + } + ], + "source": [ + "data_pt = torch.load(\n", + " os.path.join(\n", + " scope_class.processed_dir, scope_class.processed_file_names_dict[\"data\"]\n", + " ),\n", + " weights_only=False,\n", + ")\n", + "print(\"Type of loaded data:\", type(data_pt))" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "84cfa3e6-f60d-47c0-9f82-db3d5673d1e7", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:14:21.185027Z", + "start_time": "2024-10-05T21:14:21.169358Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "{'features': [14, 14, 14, 14, 20, 15, 15, 28, 15, 18, 25, 17, 18, 11, 25, 21, 27, 19, 14, 27, 19, 13, 14, 17, 16, 21, 25, 22, 27, 28, 12, 10, 20, 19, 13, 13, 14, 28, 17, 20, 20, 12, 19, 11, 17, 15, 27, 28, 15, 12, 17, 14, 23, 11, 19, 27, 14, 26, 19, 11, 11, 19, 12, 19, 19, 28, 17, 16, 20, 16, 19, 21, 10, 16, 18, 12, 17, 19, 10, 29, 12, 12, 21, 20, 16, 17, 19, 28, 20, 21, 12, 16, 18, 21, 19, 14, 19, 17, 12, 14, 18, 28, 23, 15, 28, 19, 19, 19, 15, 25, 17, 22, 25, 19, 28, 16, 13, 27, 13, 11, 20, 15, 28, 12, 15, 28, 27, 13, 13, 13, 28, 19, 14, 15, 28, 12, 18, 14, 20, 28, 14, 18, 15, 19, 13, 22, 28, 29, 12, 12, 20, 29, 28, 17, 13, 28, 23, 22, 15, 15, 28, 17, 13, 21, 17, 27, 11, 20, 23, 10, 10, 11, 20, 15, 22, 21, 10, 13, 21, 25, 11, 29, 25, 19, 20, 18, 17, 19, 19, 15, 18, 16, 16, 25, 15, 22, 25, 28, 23, 16, 20, 21, 13, 26, 18, 21, 15, 27, 17, 20, 22, 23, 11, 14, 29, 21, 21, 17, 25, 10, 14, 20, 25, 11, 22, 29, 11, 21, 11, 12, 17, 27, 16, 29, 17, 14, 12, 11, 20, 21, 27, 22, 15, 10, 21, 20, 17, 28, 21, 25, 11, 18, 27, 11, 13, 11, 28, 12, 17, 23, 15, 25, 16, 20, 11, 17, 11, 12, 16, 28, 27, 27, 27, 14, 13, 16, 22, 28, 12, 12, 26, 19, 22, 21, 21, 12, 19, 28, 22, 16, 23, 20, 28, 27, 24, 15, 19, 13, 12, 12, 29, 28, 12, 20, 22, 23, 17, 17, 27, 27, 21, 20, 28, 28, 28, 14, 13, 13, 11, 14, 14, 14, 14, 14], 'labels': array([False, True, False, ..., False, False, False]), 'ident': 6, 'group': None}\n" + ] + } + ], + "source": [ + "for i in range(5, 6):\n", + " print(data_pt[i])" + ] + }, + { + "cell_type": "markdown", + "id": "0d80ffbb-5f1e-4489-9bc8-d688c9be1d07", + "metadata": {}, + "source": [ + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/${reader_name}/data.pt`\n", + "\n", + "\n", + "### Structure of `data.pt`\n", + "\n", + "The `data.pt` file is a list where each element is a dictionary with the following keys:\n", + "\n", + "- **`features`**: \n", + " - **Description**: This key holds the input features for the model. The features are typically stored as tensors and represent the attributes used by the model for training and evaluation.\n", + "\n", + "- **`labels`**: \n", + " - **Description**: This key contains the labels or target values associated with each instance. 
Labels are also stored as tensors and are used by the model to learn and make predictions.\n", + "\n", + "- **`ident`**: \n", + " - **Description**: This key holds identifiers for each data instance. These identifiers help track and reference the individual samples in the dataset.\n" + ] + }, + { + "cell_type": "markdown", + "id": "186ec6f0eed6ecf7", + "metadata": {}, + "source": [ + "## classes.txt File\n", + "\n", + "**Description**: A file containing the list of selected SCOPe **labels** based on the specified threshold. This file is crucial for ensuring that only relevant **labels** are included in the dataset." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "8d1fbe6c-beb8-4038-93d4-c56bc7628716", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:15:19.146285Z", + "start_time": "2024-10-05T21:15:18.503284Z" + } + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "class_48724\n", + "class_53931\n", + "class_310555\n", + "fold_48725\n", + "fold_56111\n", + "fold_56234\n", + "fold_310573\n", + "superfamily_48726\n", + "superfamily_56112\n", + "superfamily_56235\n", + "superfamily_310607\n", + "family_48942\n", + "family_56251\n", + "family_191359\n", + "family_191470\n" + ] + } + ], + "source": [ + "with open(os.path.join(scope_class.processed_dir_main, \"classes.txt\"), \"r\") as file:\n", + " for i in range(15):\n", + " line = file.readline()\n", + " print(line.strip())" + ] + }, + { + "cell_type": "markdown", + "id": "861da1c3-0401-49f0-a22f-109814ed95d5", + "metadata": {}, + "source": [ + "\n", + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/classes.txt`\n", + "\n", + "The `classes.txt` file lists selected SCOPe classes. These classes are chosen based on a specified threshold, which is typically used for filtering or categorizing the dataset. Each line in the file corresponds to a unique SCOPe class ID, identifying specific class withing SCOPe ontology along with the hierarchy level.\n", + "\n", + "This file is essential for organizing the data and ensuring that only relevant classes, as defined by the threshold, are included in subsequent processing and analysis tasks.\n" + ] + }, + { + "cell_type": "markdown", + "id": "fb72be449e52b63f", + "metadata": {}, + "source": [ + "## splits.csv File\n", + "\n", + "**Description**: Contains saved data splits from previous runs. During subsequent runs, this file is used to reconstruct the train, validation, and test splits by filtering the encoded data (`data.pt`) based on the IDs stored in `splits.csv`." + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "3ebdcae4-4344-46bd-8fc0-a82ef5d40da5", + "metadata": { + "ExecuteTime": { + "end_time": "2024-10-05T21:15:54.575116Z", + "start_time": "2024-10-05T21:15:53.945139Z" + } + }, + "outputs": [ + { + "data": { + "text/html": [ + "
\n", + "\n", + "\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "
   id  split
0   1  train
1   3  train
2   4  train
3   6  train
4   9  train
\n", + "
" + ], + "text/plain": [ + " id split\n", + "0 1 train\n", + "1 3 train\n", + "2 4 train\n", + "3 6 train\n", + "4 9 train" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "csv_df = pd.read_csv(os.path.join(scope_class.processed_dir_main, \"splits.csv\"))\n", + "csv_df.head()" + ] + }, + { + "cell_type": "markdown", + "id": "b058714f-e434-4367-89b9-74c129ac727f", + "metadata": {}, + "source": [ + "\n", + "\n", + "**File Path**: `data/SCOPe/version_${scope_version}/${dataset_name}/processed/splits.csv`\n", + "\n", + "The `splits.csv` file contains the saved data splits from previous runs, including the train, validation, and test sets. During subsequent runs, this file is used to reconstruct these splits by filtering the encoded data (`data.pt`) based on the IDs stored in `splits.csv`. This ensures consistency and reproducibility in data splitting, allowing for reliable evaluation and comparison of model performance across different run.\n" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "6dc3fd6c-7cf6-47ef-812f-54319a0cdeb9", + "metadata": {}, + "outputs": [], + "source": [ + "# You can specify a literal path for the `splits_file_path`, or if another `scope_class` instance is already defined,\n", + "# you can use its existing `splits_file_path` attribute for consistency.\n", + "scope_class_with_splits = SCOPeOver2000(\n", + " scope_version=\"2.08\",\n", + " # splits_file_path=\"data/chebi_v231/ChEBI50/processed/splits.csv\", # Literal path option\n", + " splits_file_path=scope_class.splits_file_path, # Use path from an existing `chebi_class` instance\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "a5eb482c-ce5b-4efc-b2ec-85ac7b1a78ee", + "metadata": {}, + "source": [ + "---" + ] + }, + { + "cell_type": "markdown", + "id": "ab110764-216d-4d52-a9d1-4412c8ac8c9d", + "metadata": {}, + "source": [ + "## 5.1 Protein Representation Using Amino Acid Sequence Notation\n", + "\n", + "Proteins are composed of chains of amino acids, and these sequences can be represented using a one-letter notation for each amino acid. This notation provides a concise way to describe the primary structure of a protein.\n", + "\n", + "### Example Protein Sequence\n", + "\n", + "Protein-Chain: PDB ID:**1cph** Chain ID:**B** mol:protein length:30 INSULIN (PH 10)\n", + "
Refer - [1cph_B](https://www.rcsb.org/sequence/1CPH)\n", + "\n", + "- **Sequence**: `FVNQHLCGSHLVEALYLVCGERGFFYTPKA`\n", + "- **Sequence Length**: 30\n", + "\n", + "In this sequence, each letter corresponds to a specific amino acid. This notation is widely used in bioinformatics and molecular biology to represent protein sequences.\n", + "\n", + "### Tokenization and Encoding\n", + "\n", + "To tokenize and numerically encode this protein sequence, the `ProteinDataReader` class is used. This class allows for n-gram tokenization, where the `n_gram` parameter defines the size of the tokenized units. If `n_gram` is not provided (default is `None`), each amino acid letter is treated as a single token.\n", + "\n", + "For more details, you can explore the implementation of the `ProteinDataReader` class in the source code [here](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/reader.py)." + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "da47d47e-4560-46af-b246-235596f27d82", + "metadata": {}, + "outputs": [], + "source": [ + "from chebai.preprocessing.reader import ProteinDataReader" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "8bdbf309-29ec-4aab-a6dc-9e09bc6961a2", + "metadata": {}, + "outputs": [], + "source": [ + "protein_dr_3gram = ProteinDataReader(n_gram=3)\n", + "protein_dr = ProteinDataReader()" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "68e5c87c-79c3-4d5f-91e6-635399a84d3d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "[25, 28, 19, 18, 29, 17, 24, 13, 11, 29, 17, 28, 27, 14, 17, 22, 17, 28, 24, 13, 27, 16, 13, 25, 25, 22, 15, 23, 21, 14]\n", + "[5023, 2218, 3799, 2290, 6139, 2208, 6917, 4674, 484, 439, 2737, 851, 365, 2624, 3240, 4655, 1904, 3737, 1453, 2659, 5160, 3027, 2355, 7163, 4328, 3115, 6207, 1234]\n" + ] + } + ], + "source": [ + "protein = \"FVNQHLCGSHLVEALYLVCGERGFFYTPKA\"\n", + "print(protein_dr._read_data(protein))\n", + "print(protein_dr_3gram._read_data(protein))" + ] + }, + { + "cell_type": "markdown", + "id": "5b7211ee-2ccc-46d3-8e8f-790f344726ba", + "metadata": {}, + "source": [ + "The numbers mentioned above refer to the index of each individual token from the [`tokens.txt`](https://github.com/ChEB-AI/python-chebai/blob/dev/chebai/preprocessing/bin/protein_token/tokens.txt) file, which is used by the `ProteinDataReader` class. \n", + "\n", + "Each token in the `tokens.txt` file corresponds to a specific amino-acid letter, and these tokens are referenced by their index. Additionally, the index values are offset by the `EMBEDDING_OFFSET`, ensuring that the token embeddings are adjusted appropriately during processing." 
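To make the offset scheme concrete, here is a minimal illustrative sketch. It is not the library's implementation (the actual mapping is built inside `ProteinDataReader`), and the offset value of 10 is an assumption, consistent with the `21 amino acids + 10 special tokens` comment used for the Electra `vocab_size` later in this patch series:

# Illustrative sketch of the indexing scheme described above; EMBEDDING_OFFSET = 10
# is an assumed value (positions reserved for special tokens), and the path is
# relative to the repository root.
EMBEDDING_OFFSET = 10
with open("chebai/preprocessing/bin/protein_token/tokens.txt") as f:
    tokens = [line.strip() for line in f if line.strip()]
token_to_index = {tok: i + EMBEDDING_OFFSET for i, tok in enumerate(tokens)}
encoded = [token_to_index[aa] for aa in "FVNQHLCGSHLVEALYLVCGERGFFYTPKA"]
print(encoded)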
+ ] + }, + { + "cell_type": "markdown", + "id": "93e328cf-09f9-4694-b175-28320590937d", + "metadata": {}, + "source": [ + "---" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.14" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} From 7221a9e717f05578d70525df2d845904e4f8f1cc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 14 Apr 2025 19:13:18 +0200 Subject: [PATCH 02/36] remove chebi imports in init.py --- chebai/preprocessing/datasets/__init__.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/chebai/preprocessing/datasets/__init__.py b/chebai/preprocessing/datasets/__init__.py index d09b21c..d5bfb20 100644 --- a/chebai/preprocessing/datasets/__init__.py +++ b/chebai/preprocessing/datasets/__init__.py @@ -1,4 +1 @@ -from .base import XYBaseDataModule -from .chebi import * -from .pubchem import * -from .tox21 import * +from .base import XYBaseDataModule, _DynamicDataset \ No newline at end of file From 2422518024fa430a4a6c20bd2870ff02ffa065f9 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 14 Apr 2025 19:23:41 +0200 Subject: [PATCH 03/36] changes for fix --- chebai/models/base.py | 2 +- chebai/preprocessing/collect_all.py | 3 +- chebai/preprocessing/datasets/__init__.py | 2 +- configs/data/scope/scope50.yml | 2 +- tests/unit/mock_data/ontology_mock_data.py | 292 +++++++++++++++++++++ 5 files changed, 296 insertions(+), 5 deletions(-) diff --git a/chebai/models/base.py b/chebai/models/base.py index 4ba27bb..677640c 100644 --- a/chebai/models/base.py +++ b/chebai/models/base.py @@ -1,5 +1,5 @@ import logging -from typing import Any, Dict, Optional, Union, Iterable +from typing import Any, Dict, Iterable, Optional, Union import torch from lightning.pytorch.core.module import LightningModule diff --git a/chebai/preprocessing/collect_all.py b/chebai/preprocessing/collect_all.py index 62e140f..6e24d83 100644 --- a/chebai/preprocessing/collect_all.py +++ b/chebai/preprocessing/collect_all.py @@ -5,6 +5,7 @@ import pytorch_lightning as pl import torch import torch.nn.functional as F +from data import ClassificationData, JCIClassificationData from pytorch_lightning import loggers as pl_loggers from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.metrics import F1 @@ -13,8 +14,6 @@ from torch_geometric import nn as tgnn from torch_geometric.data import DataLoader -from data import ClassificationData, JCIClassificationData - logging.getLogger("pysmiles").setLevel(logging.CRITICAL) diff --git a/chebai/preprocessing/datasets/__init__.py b/chebai/preprocessing/datasets/__init__.py index d5bfb20..d6cc8de 100644 --- a/chebai/preprocessing/datasets/__init__.py +++ b/chebai/preprocessing/datasets/__init__.py @@ -1 +1 @@ -from .base import XYBaseDataModule, _DynamicDataset \ No newline at end of file +from .base import XYBaseDataModule, _DynamicDataset diff --git a/configs/data/scope/scope50.yml b/configs/data/scope/scope50.yml index c65028e..a5f808d 100644 --- a/configs/data/scope/scope50.yml +++ b/configs/data/scope/scope50.yml @@ -1,3 +1,3 @@ class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver50 init_args: - scope_version: "2.08" \ No newline at end of file + scope_version: "2.08" diff --git 
a/tests/unit/mock_data/ontology_mock_data.py b/tests/unit/mock_data/ontology_mock_data.py index 87d24bf..552d291 100644 --- a/tests/unit/mock_data/ontology_mock_data.py +++ b/tests/unit/mock_data/ontology_mock_data.py @@ -114,6 +114,298 @@ def get_data_in_dataframe() -> pd.DataFrame: pass +class ChebiMockOntology(MockOntologyGraphData): + """ + A mock ontology representing a simplified ChEBI (Chemical Entities of Biological Interest) structure. + This class is used for testing purposes and includes nodes and edges representing chemical compounds + and their relationships in a graph structure. + + Nodes: + - CHEBI:12345 (Compound A) + - CHEBI:54321 (Compound B) + - CHEBI:67890 (Compound C) + - CHEBI:11111 (Compound D) + - CHEBI:22222 (Compound E) + - CHEBI:99999 (Compound F) + - CHEBI:77533 (Compound G, Obsolete node) + - CHEBI:77564 (Compound H, Obsolete node) + - CHEBI:88888 (Compound I) + + Valid Edges: + - CHEBI:54321 -> CHEBI:12345 + - CHEBI:67890 -> CHEBI:12345 + - CHEBI:67890 -> CHEBI:88888 + - CHEBI:11111 -> CHEBI:54321 + - CHEBI:22222 -> CHEBI:67890 + - CHEBI:12345 -> CHEBI:99999 + + The class also includes methods to retrieve nodes, edges, and transitive closure of the graph. + + Visual Representation Graph with Valid Nodes and Edges: + + 22222 + / + 11111 67890 + \\ / \ + 54321 / 88888 + \\ / + 12345 + \ + 99999 + """ + + @staticmethod + def get_nodes() -> List[int]: + """ + Get the set of valid node IDs in the mock ontology. + + Returns: + - Set[int]: A set of integers representing the valid ChEBI node IDs. + """ + return [11111, 12345, 22222, 54321, 67890, 88888, 99999] + + @staticmethod + def get_number_of_nodes() -> int: + """ + Get the number of valid nodes in the mock ontology. + + Returns: + - int: The number of valid nodes. + """ + return len(ChebiMockOntology.get_nodes()) + + @staticmethod + def get_edges() -> Set[Tuple[int, int]]: + """ + Get the set of valid edges in the mock ontology. + + Returns: + - Set[Tuple[int, int]]: A set of tuples representing the directed edges + between ChEBI nodes. + """ + return { + (54321, 12345), + (67890, 12345), + (67890, 88888), + (11111, 54321), + (22222, 67890), + (12345, 99999), + } + + @staticmethod + def get_number_of_edges() -> int: + """ + Get the number of valid edges in the mock ontology. + + Returns: + - int: The number of valid edges. + """ + return len(ChebiMockOntology.get_edges()) + + @staticmethod + def get_edges_of_transitive_closure_graph() -> Set[Tuple[int, int]]: + """ + Get the set of edges derived from the transitive closure of the mock ontology graph. + + Returns: + - Set[Tuple[int, int]]: A set of tuples representing the directed edges + in the transitive closure of the ChEBI graph. + """ + return { + (54321, 12345), + (54321, 99999), + (67890, 12345), + (67890, 99999), + (67890, 88888), + (11111, 54321), + (11111, 12345), + (11111, 99999), + (22222, 67890), + (22222, 12345), + (22222, 99999), + (22222, 88888), + (12345, 99999), + } + + @staticmethod + def get_number_of_transitive_edges() -> int: + """ + Get the number of edges in the transitive closure of the mock ontology graph. + + Returns: + - int: The number of edges in the transitive closure graph. + """ + return len(ChebiMockOntology.get_edges_of_transitive_closure_graph()) + + @staticmethod + def get_obsolete_nodes_ids() -> Set[int]: + """ + Get the set of obsolete node IDs in the mock ontology. + + Returns: + - Set[int]: A set of integers representing the obsolete ChEBI node IDs. 
+ """ + return {77533, 77564} + + @staticmethod + def get_raw_data() -> str: + """ + Get the raw data representing the mock ontology in OBO format. + + Returns: + - str: A string containing the raw OBO data for the mock ChEBI terms. + """ + return """ + [Term] + id: CHEBI:12345 + name: Compound A + subset: 2_STAR + property_value: http://purl.obolibrary.org/obo/chebi/formula "C26H35ClN4O6S" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/charge "0" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/monoisotopicmass "566.19658" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/mass "567.099" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/inchikey "ROXPMFGZZQEKHB-IUKKYPGJSA-N" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1" xsd:string + property_value: http://purl.obolibrary.org/obo/chebi/inchi "InChI=1S/C26H35ClN4O6S/c1-16(2)28-26(34)30(5)14-23-17(3)13-31(18(4)15-32)25(33)21-7-6-8-22(24(21)37-23)29-38(35,36)20-11-9-19(27)10-12-20/h6-12,16-18,23,29,32H,13-15H2,1-5H3,(H,28,34)/t17-,18-,23+/m0/s1" xsd:string + xref: LINCS:LSM-20139 + is_a: CHEBI:54321 + is_a: CHEBI:67890 + + [Term] + id: CHEBI:54321 + name: Compound B + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1O" xsd:string + is_a: CHEBI:11111 + is_a: CHEBI:77564 + + [Term] + id: CHEBI:67890 + name: Compound C + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1N" xsd:string + is_a: CHEBI:22222 + + [Term] + id: CHEBI:11111 + name: Compound D + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1F" xsd:string + + [Term] + id: CHEBI:22222 + name: Compound E + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1Cl" xsd:string + + [Term] + id: CHEBI:99999 + name: Compound F + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1Br" xsd:string + is_a: CHEBI:12345 + + [Term] + id: CHEBI:77533 + name: Compound G + is_a: CHEBI:99999 + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=C1Br" xsd:string + is_obsolete: true + + [Term] + id: CHEBI:77564 + name: Compound H + property_value: http://purl.obolibrary.org/obo/chebi/smiles "CC=C1Br" xsd:string + is_obsolete: true + + [Typedef] + id: has_major_microspecies_at_pH_7_3 + name: has major microspecies at pH 7.3 + is_cyclic: true + is_transitive: false + + [Term] + id: CHEBI:88888 + name: Compound I + property_value: http://purl.obolibrary.org/obo/chebi/smiles "C1=CC=CC=C1[Mg+]" xsd:string + is_a: CHEBI:67890 + """ + + @staticmethod + def get_data_in_dataframe() -> pd.DataFrame: + data = OrderedDict( + id=[ + 12345, + 54321, + 67890, + 11111, + 22222, + 99999, + 88888, + ], + name=[ + "Compound A", + "Compound B", + "Compound C", + "Compound D", + "Compound E", + "Compound F", + "Compound I", + ], + SMILES=[ + "C1=CC=CC=C1", + "C1=CC=CC=C1O", + "C1=CC=CC=C1N", + "C1=CC=CC=C1F", + "C1=CC=CC=C1Cl", + "C1=CC=CC=C1Br", + "C1=CC=CC=C1[Mg+]", + ], + **{ + # -row- [12345, 54321, 67890, 11111, 22222, 99999, 88888] + 11111: [True, True, False, True, False, True, False], + 12345: [True, False, False, False, False, True, False], + 22222: [True, False, True, False, True, True, True], + 54321: [True, True, False, False, False, True, False], + 67890: [True, False, True, False, False, True, True], + 88888: [False, False, False, False, False, False, True], + 99999: [False, False, False, False, False, True, False], + }, + ) + + data_df = pd.DataFrame(data) + + # ------------- Code Approach 
------- + # ancestors_of_nodes = {} + # for parent, child in ChebiMockOntology.get_edges_of_transitive_closure_graph(): + # if child not in ancestors_of_nodes: + # ancestors_of_nodes[child] = set() + # if parent not in ancestors_of_nodes: + # ancestors_of_nodes[parent] = set() + # ancestors_of_nodes[child].add(parent) + # ancestors_of_nodes[child].add(child) + # + # # For each node in the ontology, create a column to check if it's an ancestor of any other node or itself + # for node in ChebiMockOntology.get_nodes(): + # data_df[node] = data_df['id'].apply( + # lambda x: (x == node) or (node in ancestors_of_nodes[x]) + # ) + + return data_df + + @staticmethod + def get_transitively_closed_graph() -> nx.DiGraph: + """ + Create a directed graph, compute its transitive closure, and return it. + + Returns: + g (nx.DiGraph): A transitively closed directed graph. + """ + g = nx.DiGraph() + + for node in ChebiMockOntology.get_nodes(): + g.add_node(node, **{"smiles": "test_smiles_placeholder"}) + + g.add_edges_from(ChebiMockOntology.get_edges_of_transitive_closure_graph()) + + return g + + class GOUniProtMockData(MockOntologyGraphData): """ A mock ontology representing a simplified version of the Gene Ontology (GO) structure with nodes and edges From 64d76231f6bb7563ba7b26026eeff99d398ef31b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 14 Apr 2025 20:35:11 +0200 Subject: [PATCH 04/36] change loss module for protein data --- chebai/loss/bce_weighted.py | 8 ++------ chebai/loss/semantic.py | 20 ++++++++++---------- 2 files changed, 12 insertions(+), 16 deletions(-) diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py index b4fb863..e80bfbc 100644 --- a/chebai/loss/bce_weighted.py +++ b/chebai/loss/bce_weighted.py @@ -1,12 +1,10 @@ import os from typing import Optional -import pandas as pd import torch from chebai.preprocessing.datasets.base import XYBaseDataModule -from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor -from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed +from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor class BCEWeighted(torch.nn.BCEWithLogitsLoss): @@ -29,11 +27,9 @@ def __init__( **kwargs, ): self.beta = beta - if isinstance(data_extractor, LabeledUnlabeledMixed): - data_extractor = data_extractor.labeled self.data_extractor = data_extractor assert ( - isinstance(self.data_extractor, _ChEBIDataExtractor) + isinstance(self.data_extractor, _GOUniProtDataExtractor) or self.data_extractor is None ) super().__init__(**kwargs) diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py index 271c312..248546e 100644 --- a/chebai/loss/semantic.py +++ b/chebai/loss/semantic.py @@ -2,14 +2,16 @@ import math import os import pickle -from typing import List, Literal, Union +from typing import List, Literal, Type, Union import torch from chebai.loss.bce_weighted import BCEWeighted from chebai.preprocessing.datasets import XYBaseDataModule -from chebai.preprocessing.datasets.chebi import ChEBIOver100, _ChEBIDataExtractor -from chebai.preprocessing.datasets.pubchem import LabeledUnlabeledMixed +from chebai.preprocessing.datasets.deepGO.go_uniprot import ( + GOUniProtOver250, + _GOUniProtDataExtractor, +) class ImplicationLoss(torch.nn.Module): @@ -17,7 +19,7 @@ class ImplicationLoss(torch.nn.Module): Implication Loss module. Args: - data_extractor _ChEBIDataExtractor: Data extractor for labels. + data_extractor _GOUniProtDataExtractor: Data extractor for labels. 
base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. fuzzy_implication (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product". impl_loss_weight (float, optional): Weight of implication loss relative to base loss. Defaults to 0.1. @@ -70,9 +72,7 @@ def __init__( ): super().__init__() # automatically choose labeled subset for implication filter in case of mixed dataset - if isinstance(data_extractor, LabeledUnlabeledMixed): - data_extractor = data_extractor.labeled - assert isinstance(data_extractor, _ChEBIDataExtractor) + assert isinstance(data_extractor, _GOUniProtDataExtractor) self.data_extractor = data_extractor # propagate data_extractor to base loss if isinstance(base_loss, BCEWeighted): @@ -329,7 +329,7 @@ class DisjointLoss(ImplicationLoss): Args: path_to_disjointness (str): Path to the disjointness data file (a csv file containing pairs of disjoint classes) - data_extractor (Union[_ChEBIDataExtractor, LabeledUnlabeledMixed]): Data extractor for labels. + data_extractor (_GOUniProtDataExtractor): Data extractor for labels. base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. disjoint_loss_weight (float, optional): Weight of disjointness loss. Defaults to 100. **kwargs: Additional arguments. @@ -338,7 +338,7 @@ class DisjointLoss(ImplicationLoss): def __init__( self, path_to_disjointness: str, - data_extractor: Union[_ChEBIDataExtractor, LabeledUnlabeledMixed], + data_extractor: _GOUniProtDataExtractor, base_loss: torch.nn.Module = None, disjoint_loss_weight: float = 100, **kwargs, @@ -502,7 +502,7 @@ def _build_disjointness_filter( if __name__ == "__main__": loss = DisjointLoss( os.path.join("data", "disjoint.csv"), - ChEBIOver100(chebi_version=231), + GOUniProtOver250(), base_loss=BCEWeighted(), impl_loss_weight=1, disjoint_loss_weight=1, From beaf74e6de7761b155e5684b667523b53d54c469 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 14 Apr 2025 20:37:04 +0200 Subject: [PATCH 05/36] update trainer for protein reader --- chebai/trainer/CustomTrainer.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py index 874d6b3..cb76199 100644 --- a/chebai/trainer/CustomTrainer.py +++ b/chebai/trainer/CustomTrainer.py @@ -9,7 +9,7 @@ from torch.nn.utils.rnn import pad_sequence from chebai.loggers.custom import CustomLogger -from chebai.preprocessing.reader import CLS_TOKEN, ChemDataReader +from chebai.preprocessing.reader import CLS_TOKEN, ProteinDataReader log = logging.getLogger(__name__) @@ -99,22 +99,22 @@ def predict_from_file( predictions_df.to_csv(save_to) def _predict_smiles( - self, model: LightningModule, smiles: List[str] + self, model: LightningModule, sequence: List[str] ) -> torch.Tensor: """ Predicts the output for a list of SMILES strings using the model. Args: model: The model to use for predictions. - smiles: A list of SMILES strings. + sequence: Protein sequence. Returns: A tensor containing the predictions. 
""" - reader = ChemDataReader() - parsed_smiles = [reader._read_data(s) for s in smiles] + reader = ProteinDataReader() + parsed_sequence = [reader._read_data(s) for s in sequence] x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_smiles], + [torch.tensor(a, device=model.device) for a in parsed_sequence], batch_first=True, ) cls_tokens = ( From 9fd19a917656f96af930c4e00feec159cbc5bd36 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 14 Apr 2025 20:38:02 +0200 Subject: [PATCH 06/36] remove chebi imports and libraries --- chebai/preprocessing/reader.py | 4 ---- setup.py | 4 ---- 2 files changed, 8 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 38060f2..1fa5a47 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -3,8 +3,6 @@ from typing import Any, Dict, List, Optional, Tuple from urllib.error import HTTPError -import deepsmiles -import selfies as sf import torch from esm import Alphabet from esm.model.esm2 import ESM2 @@ -13,8 +11,6 @@ load_model_and_alphabet_core, load_model_and_alphabet_local, ) -from pysmiles.read_smiles import _tokenize -from transformers import RobertaTokenizerFast from chebai.preprocessing.collate import DefaultCollator, RaggedCollator diff --git a/setup.py b/setup.py index 8a6d3e0..1abc871 100644 --- a/setup.py +++ b/setup.py @@ -34,17 +34,13 @@ "urllib3", "transformers", "fastobo", - "pysmiles==1.1.2", "scikit-network", "svgutils", "matplotlib", - "rdkit", - "selfies", "lightning>=2.5", "jsonargparse[signatures]>=4.17", "omegaconf", "seaborn", - "deepsmiles", "iterative-stratification", "wandb", "chardet", From 7048cd0fffc0599b8d5d9da6486b1bcd5808e75b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 14 Apr 2025 22:10:40 +0200 Subject: [PATCH 07/36] remove chebi version param from base data class --- chebai/preprocessing/datasets/base.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py index 817bc1d..3308ec9 100644 --- a/chebai/preprocessing/datasets/base.py +++ b/chebai/preprocessing/datasets/base.py @@ -36,7 +36,6 @@ class XYBaseDataModule(LightningDataModule): label_filter (Optional[int]): The index of the label to filter. Default is None. balance_after_filter (Optional[float]): The ratio of negative samples to positive samples after filtering. Default is None. num_workers (int): The number of worker processes for data loading. Default is 1. - chebi_version (int): The version of ChEBI to use. Default is 200. inner_k_folds (int): The number of folds for inner cross-validation. Use -1 to disable inner cross-validation. Default is -1. fold_index (Optional[int]): The index of the fold to use for training and validation. Default is None. base_dir (Optional[str]): The base directory for storing processed and raw data. Default is None. @@ -52,7 +51,6 @@ class XYBaseDataModule(LightningDataModule): label_filter (Optional[int]): The index of the label to filter. balance_after_filter (Optional[float]): The ratio of negative samples to positive samples after filtering. num_workers (int): The number of worker processes for data loading. - chebi_version (int): The version of ChEBI to use. inner_k_folds (int): The number of folds for inner cross-validation. If it is less than to, no cross-validation will be performed. fold_index (Optional[int]): The index of the fold to use for training and validation (only relevant for cross-validation). 
_base_dir (Optional[str]): The base directory for storing processed and raw data. @@ -75,7 +73,6 @@ def __init__( label_filter: Optional[int] = None, balance_after_filter: Optional[float] = None, num_workers: int = 1, - chebi_version: int = 200, inner_k_folds: int = -1, # use inner cross-validation if > 1 fold_index: Optional[int] = None, base_dir: Optional[str] = None, @@ -95,7 +92,6 @@ def __init__( ), "Filter balancing requires a filter" self.balance_after_filter = balance_after_filter self.num_workers = num_workers - self.chebi_version = chebi_version assert type(inner_k_folds) is int self.inner_k_folds = inner_k_folds self.use_inner_cross_validation = ( From 9c8521fb9e236dceefdcc98dea63a7e7f4baa689 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Mon, 14 Apr 2025 22:16:35 +0200 Subject: [PATCH 08/36] electra config: update vocab size & max pos for protein seq --- configs/model/electra.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/configs/model/electra.yml b/configs/model/electra.yml index c3cf2fd..94d1dc6 100644 --- a/configs/model/electra.yml +++ b/configs/model/electra.yml @@ -3,8 +3,8 @@ init_args: optimizer_kwargs: lr: 1e-3 config: - vocab_size: 1400 - max_position_embeddings: 1800 + vocab_size: 31 # 21 amino acids (when n_gram=1) + 10 special tokens of LLM + max_position_embeddings: 1000 # max default sequence length for protein num_attention_heads: 8 num_hidden_layers: 6 type_vocab_size: 1 From b45b26647f4b4de8bad349e1717f527ca9d598a7 Mon Sep 17 00:00:00 2001 From: sfluegel Date: Wed, 16 Apr 2025 10:39:32 +0200 Subject: [PATCH 09/36] add changes from `out_dim` PR (#74 in python-chebai) --- chebai/preprocessing/datasets/deepGO/go_uniprot.py | 2 +- chebai/preprocessing/datasets/deepGO/protein_pretraining.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/chebai/preprocessing/datasets/deepGO/go_uniprot.py b/chebai/preprocessing/datasets/deepGO/go_uniprot.py index 1b0eb2a..9c5d5c0 100644 --- a/chebai/preprocessing/datasets/deepGO/go_uniprot.py +++ b/chebai/preprocessing/datasets/deepGO/go_uniprot.py @@ -770,7 +770,7 @@ def _name(self) -> str: return f"{threshold_part}{self.max_sequence_length}" # ------------------------------ Phase: Prepare data ----------------------------------- - def prepare_data(self, *args: Any, **kwargs: Any) -> None: + def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: """ Checks for the existence of migrated DeepGO data in the specified directory. Raises an error if the required data file is not found, prompting diff --git a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py index 8f7e9c4..4be053a 100644 --- a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py +++ b/chebai/preprocessing/datasets/deepGO/protein_pretraining.py @@ -55,7 +55,7 @@ def __init__(self, **kwargs): ) # ------------------------------ Phase: Prepare data ----------------------------------- - def prepare_data(self, *args: Any, **kwargs: Any) -> None: + def _perform_data_preparation(self, *args: Any, **kwargs: Any) -> None: """ Prepares the data by downloading and parsing Swiss-Prot data if not already available. Saves the processed data for further use. 
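For context on the `vocab_size` update in the Electra config above, a small hypothetical check (not part of any patch in this series) can derive the value from the reader's token file plus the reserved special-token offset mentioned in the config comment; both the helper name and the offset value of 10 are assumptions for illustration:

# Hypothetical helper (not part of this patch series): expected Electra vocab_size
# = number of amino-acid tokens in tokens.txt + offset reserved for special tokens.
EMBEDDING_OFFSET = 10  # assumption, matching the "+ 10 special tokens" comment above

def expected_vocab_size(tokens_path: str) -> int:
    with open(tokens_path) as f:
        n_tokens = sum(1 for line in f if line.strip())
    return n_tokens + EMBEDDING_OFFSET

# protein_token/tokens.txt holds 21 single-letter tokens, giving 21 + 10 = 31
print(expected_vocab_size("chebai/preprocessing/bin/protein_token/tokens.txt"))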
From a2939316074bbb20ba7f7d705ca754123a9a938f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 16:33:10 +0200 Subject: [PATCH 10/36] remove not required files --- chebai/__init__.py | 30 - chebai/__main__.py | 10 - chebai/callbacks.py | 86 -- chebai/callbacks/__init__.py | 0 chebai/callbacks/epoch_metrics.py | 180 --- chebai/callbacks/model_checkpoint.py | 95 -- chebai/callbacks/prediction_callback.py | 55 - chebai/cli.py | 75 -- chebai/loggers/__init__.py | 0 chebai/loggers/custom.py | 127 -- chebai/loss/__init__.py | 0 chebai/loss/bce_weighted.py | 94 -- chebai/loss/mixed.py | 40 - chebai/loss/pretraining.py | 48 - chebai/loss/semantic.py | 532 -------- chebai/models/__init__.py | 2 - chebai/models/base.py | 372 ------ chebai/models/chemberta.py | 77 -- chebai/models/chemyk.py | 63 - chebai/models/electra.py | 535 -------- chebai/models/external/__init__.py | 0 chebai/models/ffn.py | 153 --- chebai/models/lnn_model.py | 40 - chebai/models/lstm.py | 34 - chebai/models/recursive.py | 97 -- chebai/models/strontex.py | 14 - chebai/preprocessing/collate.py | 137 -- chebai/preprocessing/collect_all.py | 225 ---- chebai/preprocessing/datasets/base.py | 1180 ----------------- chebai/preprocessing/structures.py | 141 -- chebai/result/__init__.py | 0 chebai/result/analyse_sem.py | 721 ---------- chebai/result/base.py | 105 -- chebai/result/classification.py | 105 -- chebai/result/evaluate_predictions.py | 108 -- chebai/result/molplot.py | 506 ------- chebai/result/prediction_json.py | 26 - chebai/result/pretraining.py | 65 - chebai/result/utils.py | 235 ---- chebai/trainer/CustomTrainer.py | 149 --- chebai/trainer/__init__.py | 0 configs/default_prediction_callback.yml | 4 - configs/loss/bce.yml | 1 - configs/loss/electra_pre_loss.yml | 1 - configs/loss/semantic_loss.yml | 10 - configs/metrics/balanced-accuracy.yml | 5 - configs/metrics/micro-macro-f1.yml | 9 - configs/metrics/single-class-f1.yml | 5 - configs/model/electra-for-pretraining.yml | 20 - configs/model/electra.yml | 11 - configs/model/electra_pretraining.yml | 18 - configs/model/ffn.yml | 5 - configs/training/csv_logger.yml | 3 - configs/training/default_callbacks.yml | 12 - configs/training/default_trainer.yml | 5 - configs/training/early_stop_callbacks.yml | 19 - configs/training/pretraining_callbacks.yml | 12 - configs/training/pretraining_trainer.yml | 7 - configs/training/single_class_callbacks.yml | 13 - configs/training/wandb_logger.yml | 6 - docs/source/experiment.rst | 1 - docs/source/model.rst | 1 - tests/unit/collators/__init__.py | 0 tests/unit/collators/testDefaultCollator.py | 65 - tests/unit/collators/testRaggedCollator.py | 204 --- .../dataset_classes/testDynamicDataset.py | 372 ------ .../dataset_classes/testXYBaseDataModule.py | 92 -- tests/unit/readers/testDataReader.py | 56 - 68 files changed, 7419 deletions(-) delete mode 100644 chebai/__init__.py delete mode 100644 chebai/__main__.py delete mode 100644 chebai/callbacks.py delete mode 100644 chebai/callbacks/__init__.py delete mode 100644 chebai/callbacks/epoch_metrics.py delete mode 100644 chebai/callbacks/model_checkpoint.py delete mode 100644 chebai/callbacks/prediction_callback.py delete mode 100644 chebai/cli.py delete mode 100644 chebai/loggers/__init__.py delete mode 100644 chebai/loggers/custom.py delete mode 100644 chebai/loss/__init__.py delete mode 100644 chebai/loss/bce_weighted.py delete mode 100644 chebai/loss/mixed.py delete mode 100644 chebai/loss/pretraining.py delete mode 100644 chebai/loss/semantic.py delete mode 100644 
chebai/models/__init__.py delete mode 100644 chebai/models/base.py delete mode 100644 chebai/models/chemberta.py delete mode 100644 chebai/models/chemyk.py delete mode 100644 chebai/models/electra.py delete mode 100644 chebai/models/external/__init__.py delete mode 100644 chebai/models/ffn.py delete mode 100644 chebai/models/lnn_model.py delete mode 100644 chebai/models/lstm.py delete mode 100644 chebai/models/recursive.py delete mode 100644 chebai/models/strontex.py delete mode 100644 chebai/preprocessing/collate.py delete mode 100644 chebai/preprocessing/collect_all.py delete mode 100644 chebai/preprocessing/datasets/base.py delete mode 100644 chebai/preprocessing/structures.py delete mode 100644 chebai/result/__init__.py delete mode 100644 chebai/result/analyse_sem.py delete mode 100644 chebai/result/base.py delete mode 100644 chebai/result/classification.py delete mode 100644 chebai/result/evaluate_predictions.py delete mode 100644 chebai/result/molplot.py delete mode 100644 chebai/result/prediction_json.py delete mode 100644 chebai/result/pretraining.py delete mode 100644 chebai/result/utils.py delete mode 100644 chebai/trainer/CustomTrainer.py delete mode 100644 chebai/trainer/__init__.py delete mode 100644 configs/default_prediction_callback.yml delete mode 100644 configs/loss/bce.yml delete mode 100644 configs/loss/electra_pre_loss.yml delete mode 100644 configs/loss/semantic_loss.yml delete mode 100644 configs/metrics/balanced-accuracy.yml delete mode 100644 configs/metrics/micro-macro-f1.yml delete mode 100644 configs/metrics/single-class-f1.yml delete mode 100644 configs/model/electra-for-pretraining.yml delete mode 100644 configs/model/electra.yml delete mode 100644 configs/model/electra_pretraining.yml delete mode 100644 configs/model/ffn.yml delete mode 100644 configs/training/csv_logger.yml delete mode 100644 configs/training/default_callbacks.yml delete mode 100644 configs/training/default_trainer.yml delete mode 100644 configs/training/early_stop_callbacks.yml delete mode 100644 configs/training/pretraining_callbacks.yml delete mode 100644 configs/training/pretraining_trainer.yml delete mode 100644 configs/training/single_class_callbacks.yml delete mode 100644 configs/training/wandb_logger.yml delete mode 100644 docs/source/experiment.rst delete mode 100644 docs/source/model.rst delete mode 100644 tests/unit/collators/__init__.py delete mode 100644 tests/unit/collators/testDefaultCollator.py delete mode 100644 tests/unit/collators/testRaggedCollator.py delete mode 100644 tests/unit/dataset_classes/testDynamicDataset.py delete mode 100644 tests/unit/dataset_classes/testXYBaseDataModule.py delete mode 100644 tests/unit/readers/testDataReader.py diff --git a/chebai/__init__.py b/chebai/__init__.py deleted file mode 100644 index 9f508aa..0000000 --- a/chebai/__init__.py +++ /dev/null @@ -1,30 +0,0 @@ -import os -from typing import Any - -import torch - -# Get the absolute path of the current file's directory -MODULE_PATH = os.path.abspath(os.path.dirname(__file__)) - - -class CustomTensor(torch.Tensor): - """ - A custom tensor class inheriting from `torch.Tensor`. - - This class allows for the creation of tensors using the provided data. - - Attributes: - data (Any): The data to be converted into a tensor. - """ - - def __new__(cls, data: Any) -> "CustomTensor": - """ - Creates a new instance of CustomTensor. - - Args: - data (Any): The data to be converted into a tensor. - - Returns: - CustomTensor: A tensor containing the provided data. 
- """ - return torch.tensor(data) diff --git a/chebai/__main__.py b/chebai/__main__.py deleted file mode 100644 index 0afee8e..0000000 --- a/chebai/__main__.py +++ /dev/null @@ -1,10 +0,0 @@ -from chebai.cli import cli - -if __name__ == "__main__": - """ - Entry point for the CLI application. - - This script calls the `cli` function from the `chebai.cli` module - when executed as the main program. - """ - cli() diff --git a/chebai/callbacks.py b/chebai/callbacks.py deleted file mode 100644 index 764db44..0000000 --- a/chebai/callbacks.py +++ /dev/null @@ -1,86 +0,0 @@ -import json -import os -from typing import Any, Dict, List, Literal, Union - -import torch -from lightning.pytorch.callbacks import BasePredictionWriter - - -class ChebaiPredictionWriter(BasePredictionWriter): - """ - A custom prediction writer for saving batch and epoch predictions during model training. - - This class inherits from `BasePredictionWriter` and is designed to save predictions - in a specified output directory at specified intervals. - - Args: - output_dir (str): The directory where predictions will be saved. - write_interval (str): The interval at which predictions will be written. - target_file (str): The name of the file where epoch predictions will be saved (default: "predictions.json"). - """ - - def __init__( - self, - output_dir: str, - write_interval: Literal["batch", "epoch", "batch_and_epoch"], - target_file: str = "predictions.json", - ) -> None: - super().__init__(write_interval) - self.output_dir = output_dir - self.target_file = target_file - - def write_on_batch_end( - self, - trainer: Any, - pl_module: Any, - prediction: Union[torch.Tensor, List[torch.Tensor]], - batch_indices: List[int], - batch: Any, - batch_idx: int, - dataloader_idx: int, - ) -> None: - """ - Saves batch predictions at the end of each batch. - - Args: - trainer (Any): The trainer instance. - pl_module (Any): The LightningModule instance. - prediction (Union[torch.Tensor, List[torch.Tensor]]): The prediction output from the model. - batch_indices (List[int]): The indices of the batch. - batch (Any): The current batch. - batch_idx (int): The index of the batch. - dataloader_idx (int): The index of the dataloader. - """ - outpath = os.path.join(self.output_dir, str(dataloader_idx), f"{batch_idx}.pt") - os.makedirs(os.path.dirname(outpath), exist_ok=True) - torch.save(prediction, outpath) - - def write_on_epoch_end( - self, - trainer: Any, - pl_module: Any, - predictions: List[Dict[str, Any]], - batch_indices: List[int], - ) -> None: - """ - Saves all predictions at the end of each epoch in a JSON file. - - Args: - trainer (Any): The trainer instance. - pl_module (Any): The LightningModule instance. - predictions (List[Dict[str, Any]]): The list of prediction outputs from the model. - batch_indices (List[int]): The indices of the batches. 
- """ - pred_list = [] - for p in predictions: - idents = p["data"]["idents"] - labels = p["data"]["labels"] - if labels is not None: - labels = labels.tolist() - else: - labels = [None for _ in idents] - output = torch.sigmoid(p["output"]["logits"]).tolist() - for i, l, o in zip(idents, labels, output): - pred_list.append(dict(ident=i, labels=l, predictions=o)) - with open(os.path.join(self.output_dir, self.target_file), "wt") as fout: - json.dump(pred_list, fout) diff --git a/chebai/callbacks/__init__.py b/chebai/callbacks/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chebai/callbacks/epoch_metrics.py b/chebai/callbacks/epoch_metrics.py deleted file mode 100644 index c1cf7bd..0000000 --- a/chebai/callbacks/epoch_metrics.py +++ /dev/null @@ -1,180 +0,0 @@ -import torch -import torchmetrics - - -def custom_reduce_fx(input: torch.Tensor) -> torch.Tensor: - """ - Custom reduction function for distributed training. - - Args: - input (torch.Tensor): The input tensor to be reduced. - - Returns: - torch.Tensor: The reduced tensor. - """ - print(f"called reduce (device: {input.device})") - return torch.sum(input, dim=0) - - -class MacroF1(torchmetrics.Metric): - """ - Computes the Macro F1 score, which is the unweighted mean of F1 scores for each class. - This implementation differs from torchmetrics.classification.MultilabelF1Score in the behaviour for undefined - values (i.e., classes where TP+FN=0). The torchmetrics implementation sets these classes to a default value. - Here, the mean is only taken over classes which have at least one positive sample. - - Args: - num_labels (int): Number of classes/labels. - dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward - before returning the value at the step. Default: False. - threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions. - Default: 0.5. - """ - - def __init__( - self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5 - ): - super().__init__(dist_sync_on_step=dist_sync_on_step) - - self.add_state( - "true_positives", - default=torch.zeros(num_labels, dtype=torch.int), - dist_reduce_fx="sum", - ) - self.add_state( - "positive_predictions", - default=torch.zeros(num_labels, dtype=torch.int), - dist_reduce_fx="sum", - ) - self.add_state( - "positive_labels", - default=torch.zeros(num_labels, dtype=torch.int), - dist_reduce_fx="sum", - ) - self.threshold = threshold - - def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None: - """ - Update the state (TPs, Positive Predictions, Positive labels) with the current batch of predictions and labels. - - Args: - preds (torch.Tensor): Predictions from the model. - labels (torch.Tensor): Ground truth labels. - """ - tps = torch.sum( - torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0 - ) - self.true_positives += tps - self.positive_predictions += torch.sum(preds > self.threshold, dim=0) - self.positive_labels += torch.sum(labels, dim=0) - - def compute(self) -> torch.Tensor: - """ - Compute the Macro F1 score. - - Returns: - torch.Tensor: The computed Macro F1 score. 
- """ - - # ignore classes without positive labels - # classes with positive labels, but no positive predictions will get a precision of "nan" (0 divided by 0), - # which is propagated to the classwise_f1 and then turned into 0 - mask = self.positive_labels != 0 - precision = self.true_positives[mask] / self.positive_predictions[mask] - recall = self.true_positives[mask] / self.positive_labels[mask] - classwise_f1 = 2 * precision * recall / (precision + recall) - # if (precision and recall are 0) or (precision is nan), set f1 to 0 - classwise_f1 = classwise_f1.nan_to_num() - return torch.mean(classwise_f1) - - -class BalancedAccuracy(torchmetrics.Metric): - """ - Computes the Balanced Accuracy, which is the average of true positive rate (TPR) and true negative rate (TNR). - Useful for imbalanced datasets. - Balanced Accuracy = (TPR + TNR)/2 = (TP/(TP + FN) + (TN)/(TN + FP))/2 - - Args: - num_labels (int): Number of classes/labels. - dist_sync_on_step (bool, optional): Synchronize metric state across processes at each forward - before returning the value at the step. Default: False. - threshold (float, optional): Threshold for converting predicted probabilities to binary (0, 1) predictions. - Default: 0.5. - """ - - def __init__( - self, num_labels: int, dist_sync_on_step: bool = False, threshold: float = 0.5 - ): - super().__init__(dist_sync_on_step=dist_sync_on_step) - - self.add_state( - "true_positives", - default=torch.zeros(num_labels, dtype=torch.int), - dist_reduce_fx="sum", - ) - - self.add_state( - "false_positives", - default=torch.zeros(num_labels, dtype=torch.int), - dist_reduce_fx="sum", - ) - - self.add_state( - "true_negatives", - default=torch.zeros(num_labels, dtype=torch.int), - dist_reduce_fx="sum", - ) - - self.add_state( - "false_negatives", - default=torch.zeros(num_labels, dtype=torch.int), - dist_reduce_fx="sum", - ) - - self.threshold = threshold - - def update(self, preds: torch.Tensor, labels: torch.Tensor) -> None: - """ - Update the state (TPs, TNs, FPs, FNs) with the current batch of predictions and labels. - - Args: - preds (torch.Tensor): Predictions from the model. - labels (torch.Tensor): Ground truth labels. - """ - - # Size: Batch_size x Num_of_Classes; - # summing over 1st dimension (dim=0), gives us the True positives per class - tps = torch.sum( - torch.logical_and(preds > self.threshold, labels.to(torch.bool)), dim=0 - ) - fps = torch.sum( - torch.logical_and(preds > self.threshold, ~labels.to(torch.bool)), dim=0 - ) - tns = torch.sum( - torch.logical_and(preds <= self.threshold, ~labels.to(torch.bool)), dim=0 - ) - fns = torch.sum( - torch.logical_and(preds <= self.threshold, labels.to(torch.bool)), dim=0 - ) - - # Size: Num_of_Classes; - self.true_positives += tps - self.false_positives += fps - self.true_negatives += tns - self.false_negatives += fns - - def compute(self) -> torch.Tensor: - """ - Compute the Balanced Accuracy. - - Returns: - torch.Tensor: The computed Balanced Accuracy. 
- """ - tpr = self.true_positives / (self.true_positives + self.false_negatives) - tnr = self.true_negatives / (self.true_negatives + self.false_positives) - # Convert the nan values to 0 - tpr = tpr.nan_to_num() - tnr = tnr.nan_to_num() - - balanced_acc = (tpr + tnr) / 2 - return torch.mean(balanced_acc) diff --git a/chebai/callbacks/model_checkpoint.py b/chebai/callbacks/model_checkpoint.py deleted file mode 100644 index dbdbab1..0000000 --- a/chebai/callbacks/model_checkpoint.py +++ /dev/null @@ -1,95 +0,0 @@ -import os - -from lightning.fabric.utilities.cloud_io import _is_dir -from lightning.fabric.utilities.types import _PATH -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.callbacks.model_checkpoint import ModelCheckpoint -from lightning.pytorch.loggers import WandbLogger -from lightning.pytorch.utilities.rank_zero import rank_zero_info -from lightning_utilities.core.rank_zero import rank_zero_warn - - -class CustomModelCheckpoint(ModelCheckpoint): - """ - Custom checkpoint class that resolves checkpoint paths to ensure checkpoints are saved in the same directory - as other logs when using CustomLogger. - Inherits from PyTorch Lightning's ModelCheckpoint class. - """ - - def setup(self, trainer: Trainer, pl_module: LightningModule, stage: str) -> None: - """ - Setup the directory path for saving checkpoints. If the directory path is not set, it resolves the checkpoint - directory using the custom logger's directory. - - Note: - Same as in parent class, duplicated to be able to call self.__resolve_ckpt_dir - - Args: - trainer (Trainer): The Trainer instance. - pl_module (LightningModule): The LightningModule instance. - stage (str): The stage of training (e.g., 'fit'). - """ - if self.dirpath is not None: - self.dirpath = None - dirpath = self.__resolve_ckpt_dir(trainer) - dirpath = trainer.strategy.broadcast(dirpath) - self.dirpath = dirpath - if trainer.is_global_zero and stage == "fit": - self.__warn_if_dir_not_empty(self.dirpath) - - def __warn_if_dir_not_empty(self, dirpath: _PATH) -> None: - """ - Warn if the checkpoint directory is not empty. - - Note: - Same as in parent class, duplicated because method in parent class is not accessible - - Args: - dirpath (_PATH): The path to the checkpoint directory. - """ - if ( - self.save_top_k != 0 - and _is_dir(self._fs, dirpath, strict=True) - and len(self._fs.ls(dirpath)) > 0 - ): - rank_zero_warn(f"Checkpoint directory {dirpath} exists and is not empty.") - - def __resolve_ckpt_dir(self, trainer: Trainer) -> _PATH: - """ - Resolve the checkpoint directory path, ensuring compatibility with WandbLogger by saving checkpoints - in the same directory as Wandb logs. - - Note: - Overwritten for compatibility with wandb -> saves checkpoints in same dir as wandb logs - - Args: - trainer (Trainer): The Trainer instance. - - Returns: - _PATH: The resolved checkpoint directory path. 
- """ - rank_zero_info(f"Resolving checkpoint dir (custom)") - if self.dirpath is not None: - # short circuit if dirpath was passed to ModelCheckpoint - return self.dirpath - if len(trainer.loggers) > 0: - if trainer.loggers[0].save_dir is not None: - save_dir = trainer.loggers[0].save_dir - else: - save_dir = trainer.default_root_dir - name = trainer.loggers[0].name - version = trainer.loggers[0].version - version = version if isinstance(version, str) else f"version_{version}" - logger = trainer.loggers[0] - if isinstance(logger, WandbLogger) and isinstance( - logger.experiment.dir, str - ): - ckpt_path = os.path.join(logger.experiment.dir, "checkpoints") - else: - ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints") - else: - # if no loggers, use default_root_dir - ckpt_path = os.path.join(trainer.default_root_dir, "checkpoints") - - rank_zero_info(f"Now using checkpoint path {ckpt_path}") - return ckpt_path diff --git a/chebai/callbacks/prediction_callback.py b/chebai/callbacks/prediction_callback.py deleted file mode 100644 index b36197d..0000000 --- a/chebai/callbacks/prediction_callback.py +++ /dev/null @@ -1,55 +0,0 @@ -import os -import pickle -from typing import Any, Literal, Sequence - -import torch -from lightning.pytorch import LightningModule, Trainer -from lightning.pytorch.callbacks import BasePredictionWriter - - -class PredictionWriter(BasePredictionWriter): - """ - Custom callback for writing predictions to a file at the end of each epoch. - - Args: - output_dir (str): The directory where prediction files will be saved. - write_interval (str): When to write predictions. Options are "batch" or "epoch". - """ - - def __init__( - self, - output_dir: str, - write_interval: Literal["batch", "epoch", "batch_and_epoch"], - ): - super().__init__(write_interval) - self.output_dir = output_dir - self.prediction_file_name = "predictions.pkl" - - def write_on_epoch_end( - self, - trainer: Trainer, - pl_module: LightningModule, - predictions: Sequence[Any], - batch_indices: Sequence[Any], - ) -> None: - """ - Writes the predictions to a file at the end of the epoch. - - Args: - trainer (Trainer): The Trainer instance. - pl_module (LightningModule): The LightningModule instance. - predictions (Sequence[Any]): Any sequence of predictions for the epoch. - batch_indices (Sequence[Any]): Any sequence of batch indices. - """ - results = [ - dict( - ident=row["data"]["idents"][0], - predictions=torch.sigmoid(row["output"]["logits"]).numpy(), - labels=row["labels"][0].numpy() if row["labels"] is not None else None, - ) - for row in predictions - ] - with open( - os.path.join(self.output_dir, self.prediction_file_name), "wb" - ) as fout: - pickle.dump(results, fout) diff --git a/chebai/cli.py b/chebai/cli.py deleted file mode 100644 index b7e78d1..0000000 --- a/chebai/cli.py +++ /dev/null @@ -1,75 +0,0 @@ -from typing import Dict, Set - -from lightning.pytorch.cli import LightningArgumentParser, LightningCLI - -from chebai.trainer.CustomTrainer import CustomTrainer - - -class ChebaiCLI(LightningCLI): - """ - Custom CLI subclass for Chebai project based on PyTorch Lightning's LightningCLI. - - Args: - save_config_kwargs (dict): Keyword arguments for saving configuration. - parser_kwargs (dict): Keyword arguments for parser configuration. - - Attributes: - save_config_kwargs (dict): Configuration options for saving. - parser_kwargs (dict): Configuration options for the argument parser. 
- """ - - def __init__(self, *args, **kwargs): - """ - Initialize ChebaiCLI with custom trainer and configure parser settings. - - Args: - args (list): List of arguments for LightningCLI. - kwargs (dict): Keyword arguments for LightningCLI. - save_config_kwargs (dict): Keyword arguments for saving configuration. - parser_kwargs (dict): Keyword arguments for parser configuration. - """ - super().__init__(trainer_class=CustomTrainer, *args, **kwargs) - - def add_arguments_to_parser(self, parser: LightningArgumentParser): - """ - Link input parameters that are used by different classes (e.g. number of labels) - see https://lightning.ai/docs/pytorch/stable/cli/lightning_cli_expert.html#argument-linking - - Args: - parser (LightningArgumentParser): Argument parser instance. - """ - for kind in ("train", "val", "test"): - for average in ("micro-f1", "macro-f1", "balanced-accuracy"): - parser.link_arguments( - "model.init_args.out_dim", - f"model.init_args.{kind}_metrics.init_args.metrics.{average}.init_args.num_labels", - ) - parser.link_arguments( - "model.init_args.out_dim", "trainer.callbacks.init_args.num_labels" - ) - - @staticmethod - def subcommands() -> Dict[str, Set[str]]: - """ - Defines the list of available subcommands and the arguments to skip. - - Returns: - Dict[str, Set[str]]: Dictionary where keys are subcommands and values are sets of arguments to skip. - """ - return { - "fit": {"model", "train_dataloaders", "val_dataloaders", "datamodule"}, - "validate": {"model", "dataloaders", "datamodule"}, - "test": {"model", "dataloaders", "datamodule"}, - "predict": {"model", "dataloaders", "datamodule"}, - "predict_from_file": {"model"}, - } - - -def cli(): - """ - Main function to instantiate and run the ChebaiCLI. - """ - r = ChebaiCLI( - save_config_kwargs={"config_filename": "lightning_config.yaml"}, - parser_kwargs={"parser_mode": "omegaconf"}, - ) diff --git a/chebai/loggers/__init__.py b/chebai/loggers/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chebai/loggers/custom.py b/chebai/loggers/custom.py deleted file mode 100644 index d1b4282..0000000 --- a/chebai/loggers/custom.py +++ /dev/null @@ -1,127 +0,0 @@ -import os -from datetime import datetime -from typing import List, Literal, Optional, Union - -import wandb -from lightning.fabric.utilities.types import _PATH -from lightning.pytorch.callbacks import ModelCheckpoint -from lightning.pytorch.loggers import WandbLogger - - -class CustomLogger(WandbLogger): - """ - A custom logger that extends WandbLogger to add support for custom naming of runs and cross-validation. - - Args: - save_dir (_PATH): Directory where logs are saved. - name (str): Name of the logging run. - version (Optional[Union[int, str]]): Version of the logging run. - prefix (str): Prefix for logging. - fold (Optional[int]): Cross-validation fold number. - project (Optional[str]): Wandb project name. - entity (Optional[str]): Wandb entity name. - offline (bool): Whether to log offline. - log_model (Union[Literal["all"], bool]): Whether to log the model. - verbose_hyperparameters (bool): Whether to log hyperparameters verbosely. - tags (Optional[List[str]]): List of tags for the run. - **kwargs: Additional keyword arguments for WandbLogger. 
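As a rough illustration of the naming scheme documented above (values are placeholders, not taken from a real run): the run name concatenates name, version and, if set, the fold, and the log directory nests version and fold below save_dir.

    # Sketch of the layout produced by the name/root_dir/log_dir properties below;
    # the default version is a timestamp formatted as %y%m%d-%H%M.
    save_dir, base_name, version, fold = "logs", "go250_run", "250403-1627", 2

    run_name = f"{base_name}_{version}"            # "go250_run_250403-1627"
    if fold is not None:
        run_name += f"_fold{fold}"                 # "go250_run_250403-1627_fold2"

    root_dir = f"{save_dir}/{run_name}"
    log_dir = f"{root_dir}/{version}/fold_{fold}"  # logs/go250_run_250403-1627_fold2/250403-1627/fold_2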
- """ - - def __init__( - self, - save_dir: _PATH, - name: str = "logs", - version: Optional[Union[int, str]] = None, - prefix: str = "", - fold: Optional[int] = None, - project: Optional[str] = None, - entity: Optional[str] = None, - offline: bool = False, - log_model: Union[Literal["all"], bool] = False, - verbose_hyperparameters: bool = False, - tags: Optional[List[str]] = None, - **kwargs, - ): - if version is None: - version = f"{datetime.now():%y%m%d-%H%M}" - self._version = version - self._name = name - self._fold = fold - self.verbose_hyperparameters = verbose_hyperparameters - super().__init__( - name=self.name, - save_dir=save_dir, - version=None, - prefix=prefix, - log_model=log_model, - entity=entity, - project=project, - offline=offline, - **kwargs, - ) - if tags: - self.experiment.tags += tuple(tags) - - @property - def name(self) -> Optional[str]: - """ - Returns the name of the logging run, including the version and fold number if applicable. - """ - name = f"{self._name}_{self.version}" - if self._fold is not None: - name += f"_fold{self._fold}" - return name - - @property - def version(self) -> Optional[str]: - """ - Returns the version of the logging run. - """ - return self._version - - @property - def root_dir(self) -> Optional[str]: - """ - Returns the root directory for saving logs. - """ - return os.path.join(self.save_dir, self.name) - - @property - def log_dir(self) -> str: - """ - Returns the directory for saving logs, including the version and fold number if applicable. - """ - version = ( - self.version if isinstance(self.version, str) else f"version_{self.version}" - ) - if self._fold is None: - return os.path.join(self.root_dir, version) - return os.path.join(self.root_dir, version, f"fold_{self._fold}") - - def set_fold(self, fold: int) -> None: - """ - Sets the fold number and restarts the Wandb experiment with the new fold number. - - Args: - fold (int): Cross-validation fold number. - """ - if fold != self._fold: - self._fold = fold - # Start new experiment - wandb.finish() - self._wandb_init["name"] = self.name - self._experiment = None - _ = self.experiment - - @property - def fold(self) -> Optional[int]: - """ - Returns the current fold number. - """ - return self._fold - - def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None: - """ - Override method to prevent saving checkpoints as Wandb artifacts. - """ - pass diff --git a/chebai/loss/__init__.py b/chebai/loss/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chebai/loss/bce_weighted.py b/chebai/loss/bce_weighted.py deleted file mode 100644 index e80bfbc..0000000 --- a/chebai/loss/bce_weighted.py +++ /dev/null @@ -1,94 +0,0 @@ -import os -from typing import Optional - -import torch - -from chebai.preprocessing.datasets.base import XYBaseDataModule -from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor - - -class BCEWeighted(torch.nn.BCEWithLogitsLoss): - """ - BCEWithLogitsLoss with weights automatically computed according to the beta parameter. - If beta is None or data_extractor is None, the loss is unweighted. - - This class computes weights based on the formula from the paper by Cui et al. (2019): - https://openaccess.thecvf.com/content_CVPR_2019/papers/Cui_Class-Balanced_Loss_Based_on_Effective_Number_of_Samples_CVPR_2019_paper.pdf - - Args: - beta (float, optional): The beta parameter for weight calculation. Default is None. - data_extractor (XYBaseDataModule, optional): The data extractor for loading the dataset. 
Default is None. - """ - - def __init__( - self, - beta: Optional[float] = None, - data_extractor: Optional[XYBaseDataModule] = None, - **kwargs, - ): - self.beta = beta - self.data_extractor = data_extractor - assert ( - isinstance(self.data_extractor, _GOUniProtDataExtractor) - or self.data_extractor is None - ) - super().__init__(**kwargs) - - def set_pos_weight(self, input: torch.Tensor) -> None: - """ - Sets the positive weights for the loss function based on the input tensor. - - Args: - input (torch.Tensor): The input tensor for which to set the positive weights. - """ - if ( - self.beta is not None - and self.data_extractor is not None - and all( - os.path.exists( - os.path.join(self.data_extractor.processed_dir, file_name) - ) - for file_name in self.data_extractor.processed_file_names - ) - and self.pos_weight is None - ): - print( - f"Computing loss-weights based on v{self.data_extractor.chebi_version} dataset (beta={self.beta})" - ) - complete_labels = torch.concat( - [ - torch.stack( - [ - torch.Tensor(row["labels"]) - for row in self.data_extractor.load_processed_data( - filename=file_name - ) - ] - ) - for file_name in self.data_extractor.processed_file_names - ] - ) - value_counts = complete_labels.sum(dim=0) - weights = [ - (1 - self.beta) / (1 - pow(self.beta, value)) for value in value_counts - ] - mean = sum(weights) / len(weights) - self.pos_weight = torch.tensor( - [w / mean for w in weights], device=input.device - ) - - def forward( - self, input: torch.Tensor, target: torch.Tensor, **kwargs - ) -> torch.Tensor: - """ - Forward pass for the loss calculation. - - Args: - input (torch.Tensor): The input tensor (predictions). - target (torch.Tensor): The target tensor (labels). - - Returns: - torch.Tensor: The computed loss. - """ - self.set_pos_weight(input) - return super().forward(input, target) diff --git a/chebai/loss/mixed.py b/chebai/loss/mixed.py deleted file mode 100644 index edfc5cb..0000000 --- a/chebai/loss/mixed.py +++ /dev/null @@ -1,40 +0,0 @@ -import torch -from torch import nn - - -class MixedDataLoss(nn.Module): - """ - A wrapper for applying a base loss function to a subset of input data. - - This class allows for selective application of a loss function based on the provided - non-null labels. - - Args: - base_loss (nn.Module): The base loss function to be applied. - """ - - def __init__(self, base_loss: nn.Module): - super().__init__() - self.base_loss = base_loss - - def forward( - self, input: torch.Tensor, target: torch.Tensor, **kwargs - ) -> torch.Tensor: - """ - Forward pass for applying the base loss function. - - Args: - input (torch.Tensor): The input tensor (predictions). - target (torch.Tensor): The target tensor (labels). - **kwargs: Additional keyword arguments. The 'non_null_labels' key can be used - to specify the indices of the non-null labels. - - Returns: - torch.Tensor: The computed loss. - """ - nnl = kwargs.pop("non_null_labels", None) - if nnl: - inp = input[nnl] - else: - inp = input - return self.base_loss(inp, target, **kwargs) diff --git a/chebai/loss/pretraining.py b/chebai/loss/pretraining.py deleted file mode 100644 index e2f51da..0000000 --- a/chebai/loss/pretraining.py +++ /dev/null @@ -1,48 +0,0 @@ -import torch - - -class ElectraPreLoss(torch.nn.Module): - """ - Custom loss module for pre-training ELECTRA-like models. - - This module computes a combined loss from two CrossEntropyLosses: - one for generator predictions and another for discriminator predictions. 
- - Attributes: - ce (torch.nn.CrossEntropyLoss): Cross entropy loss function. - - Methods: - forward(input, target, **loss_kwargs): - Computes the combined loss for generator and discriminator predictions. - - """ - - def __init__(self): - """ - Initializes the ElectraPreLoss module. - """ - super().__init__() - self.ce = torch.nn.CrossEntropyLoss() - - def forward(self, input, target, **loss_kwargs): - """ - Forward pass for computing the combined loss. - - Args: - input (tuple): A tuple containing generator predictions (gen_pred, disc_pred). - target (tuple): A tuple containing generator targets (gen_tar, disc_tar). - **loss_kwargs: Additional keyword arguments. - - Returns: - torch.Tensor: Combined loss of generator and discriminator predictions. - """ - t, p = input - gen_pred, disc_pred = t - gen_tar, disc_tar = p - - # Compute losses for generator and discriminator - gen_loss = self.ce(target=torch.argmax(gen_tar.int(), dim=-1), input=gen_pred) - disc_loss = self.ce( - target=torch.argmax(disc_tar.int(), dim=-1), input=disc_pred - ) - return gen_loss + disc_loss diff --git a/chebai/loss/semantic.py b/chebai/loss/semantic.py deleted file mode 100644 index 248546e..0000000 --- a/chebai/loss/semantic.py +++ /dev/null @@ -1,532 +0,0 @@ -import csv -import math -import os -import pickle -from typing import List, Literal, Type, Union - -import torch - -from chebai.loss.bce_weighted import BCEWeighted -from chebai.preprocessing.datasets import XYBaseDataModule -from chebai.preprocessing.datasets.deepGO.go_uniprot import ( - GOUniProtOver250, - _GOUniProtDataExtractor, -) - - -class ImplicationLoss(torch.nn.Module): - """ - Implication Loss module. - - Args: - data_extractor _GOUniProtDataExtractor: Data extractor for labels. - base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. - fuzzy_implication (Literal["product", "lukasiewicz", "xu19"], optional): T-norm type. Defaults to "product". - impl_loss_weight (float, optional): Weight of implication loss relative to base loss. Defaults to 0.1. - pos_scalar (int, optional): Positive scalar exponent. Defaults to 1. - pos_epsilon (float, optional): Epsilon value for numerical stability. Defaults to 0.01. - multiply_by_softmax (bool, optional): Whether to multiply by softmax. Defaults to False. - use_sigmoidal_implication (bool, optional): Whether to use the sigmoidal fuzzy implication based on the - specified fuzzy_implication (as defined by van Krieken et al., 2022: Analyzing Differentiable Fuzzy Logic - Operators). Defaults to False. - weight_epoch_dependent (Union[bool, tuple[int, int]], optional): Whether to weight the implication loss - depending on the current epoch with the sigmoid function sigmoid((epoch-c)/s). If True, c=50 and s=10, - otherwise, a tuple of integers (c,s) can be supplied. Defaults to False. - start_at_epoch (int, optional): Epoch at which to start applying the loss. Defaults to 0. - violations_per_cls_aggregator (Literal["sum", "max"], optional): How to aggregate violations for each class. - If a class is involved in several implications / disjointnesses, the loss value for this class will be - aggregated with this method. Defaults to "sum". 
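To make weight_epoch_dependent and violations_per_cls_aggregator concrete, here is a standalone sketch (arbitrary numbers, not part of the original module) of how per-class violations are aggregated and how the epoch-dependent factor behaves:

    import math
    import torch

    # fuzzy violations of one sample for three implications involving the same class
    violations = torch.tensor([0.30, 0.05, 0.10])

    aggregated = {
        "sum": violations.sum().item(),    # 0.45
        "max": violations.max().item(),    # 0.30
        "mean": violations.mean().item(),  # 0.15
    }
    # the "log-..." variants apply -log(1 - x) element-wise before aggregating

    # weight_epoch_dependent: dividing the loss by (1 + exp(-(epoch - c) / s))
    # equals multiplying it by sigmoid((epoch - c) / s); with the defaults
    # c=50, s=10 the factor is ~0.12 at epoch 30 and ~0.88 at epoch 70
    epoch, c, s = 30, 50, 10
    factor = 1.0 / (1.0 + math.exp(-(epoch - c) / s))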
- """ - - def __init__( - self, - data_extractor: XYBaseDataModule, - base_loss: torch.nn.Module = None, - fuzzy_implication: Literal[ - "reichenbach", - "rc", - "lukasiewicz", - "lk", - "xu19", - "kleene_dienes", - "kd", - "goedel", - "g", - "reverse-goedel", - "rg", - "binary", - "b", - ] = "reichenbach", - impl_loss_weight: float = 0.1, - pos_scalar: Union[int, float] = 1, - pos_epsilon: float = 0.01, - multiply_by_softmax: bool = False, - use_sigmoidal_implication: bool = False, - weight_epoch_dependent: Union[bool | tuple[int, int]] = False, - start_at_epoch: int = 0, - violations_per_cls_aggregator: Literal[ - "sum", "max", "mean", "log-sum", "log-max", "log-mean" - ] = "sum", - multiply_with_base_loss: bool = True, - no_grads: bool = False, - ): - super().__init__() - # automatically choose labeled subset for implication filter in case of mixed dataset - assert isinstance(data_extractor, _GOUniProtDataExtractor) - self.data_extractor = data_extractor - # propagate data_extractor to base loss - if isinstance(base_loss, BCEWeighted): - base_loss.data_extractor = self.data_extractor - base_loss.reduction = ( - "none" # needed to multiply fuzzy loss with base loss for each sample - ) - self.base_loss = base_loss - self.implication_cache_file = f"implications_{self.data_extractor.name}.cache" - self.label_names = _load_label_names( - os.path.join(data_extractor.processed_dir_main, "classes.txt") - ) - self.hierarchy = self._load_implications( - os.path.join(data_extractor.raw_dir, "chebi.obo") - ) - implication_filter_dense = _build_dense_filter( - _build_implication_filter(self.label_names, self.hierarchy), - len(self.label_names), - ) - self.implication_filter_l = implication_filter_dense - self.implication_filter_r = self.implication_filter_l.transpose(0, 1) - self.fuzzy_implication = fuzzy_implication - self.impl_weight = impl_loss_weight - self.pos_scalar = pos_scalar - self.eps = pos_epsilon - self.multiply_by_softmax = multiply_by_softmax - self.use_sigmoidal_implication = use_sigmoidal_implication - self.weight_epoch_dependent = weight_epoch_dependent - self.start_at_epoch = start_at_epoch - self.violations_per_cls_aggregator = violations_per_cls_aggregator - self.multiply_with_base_loss = multiply_with_base_loss - self.no_grads = no_grads - - def _calculate_unaggregated_fuzzy_loss( - self, - pred, - target: torch.Tensor, - weight, - filter_l, - filter_r, - mode="impl", - **kwargs, - ): - # for each batch, get all pairwise losses: [a1, a2, a3] -> [[a1*a1, a1*a2, a1*a3],[a2*a1,...],[a3*a1,...]] - preds_expanded1 = pred.unsqueeze(1).expand(-1, pred.shape[1], -1) - preds_expanded2 = pred.unsqueeze(2).expand(-1, -1, pred.shape[1]) - # filter by implication relations and labels - - label_filter = target.unsqueeze(2).expand(-1, -1, pred.shape[1]) - filter_l = filter_l.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) - filter_r = filter_r.to(pred.device).unsqueeze(0).expand(pred.shape[0], -1, -1) - if mode == "impl": - all_implications = self._calculate_implication_loss( - preds_expanded2, preds_expanded1 - ) - else: - all_implications = self._calculate_implication_loss( - preds_expanded2, 1 - preds_expanded1 - ) - loss_impl_l = all_implications * filter_l * (1 - label_filter) - if mode == "impl": - loss_impl_r = all_implications.transpose(1, 2) * filter_r * label_filter - loss_impl_sum = loss_impl_l + loss_impl_r - else: - loss_impl_sum = loss_impl_l - - if self.violations_per_cls_aggregator.startswith("log-"): - loss_impl_sum = -torch.log(1 - loss_impl_sum) - 
violations_per_cls_aggregator = self.violations_per_cls_aggregator[4:] - else: - violations_per_cls_aggregator = self.violations_per_cls_aggregator - if violations_per_cls_aggregator == "sum": - loss_by_cls = loss_impl_sum.sum(dim=-1) - elif violations_per_cls_aggregator == "max": - loss_by_cls = loss_impl_sum.max(dim=-1).values - elif violations_per_cls_aggregator == "mean": - loss_by_cls = loss_impl_sum.mean(dim=-1) - else: - raise NotImplementedError( - f"Unknown violations_per_cls_aggregator {self.violations_per_cls_aggregator}" - ) - - unweighted_mean = loss_by_cls.mean() - implication_loss_weighted = loss_by_cls - if "current_epoch" in kwargs and self.weight_epoch_dependent: - sigmoid_center = ( - self.weight_epoch_dependent[0] - if isinstance(self.weight_epoch_dependent, tuple) - else 50 - ) - sigmoid_spread = ( - self.weight_epoch_dependent[1] - if isinstance(self.weight_epoch_dependent, tuple) - else 10 - ) - # sigmoid function centered around epoch 50 - implication_loss_weighted = implication_loss_weighted / ( - 1 - + math.exp(-(kwargs["current_epoch"] - sigmoid_center) / sigmoid_spread) - ) - implication_loss_weighted *= weight - weighted_mean = implication_loss_weighted.mean() - - return implication_loss_weighted, unweighted_mean, weighted_mean - - def _calculate_unaggregated_base_loss(self, input, target, **kwargs): - nnl = kwargs.pop("non_null_labels", None) - labeled_input = input[nnl] if nnl else input - - if target is not None and self.base_loss is not None: - return self.base_loss(labeled_input, target.float()) - else: - return torch.zeros(input.shape, device=input.device) - - def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: - """ - Forward pass of the implication loss module. - - Args: - input (torch.Tensor): Input tensor. - target (torch.Tensor): Target tensor. - **kwargs: Additional arguments. - - Returns: - tuple: Tuple containing total loss, base loss, and implication loss. - """ - base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) - loss_components = {"base_loss": base_loss.mean()} - - if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: - return base_loss.mean(), loss_components - - pred = torch.sigmoid(input) - fuzzy_loss, unweighted_fuzzy_mean, weighted_fuzzy_mean = ( - self._calculate_unaggregated_fuzzy_loss( - pred, - target, - self.impl_weight, - self.implication_filter_l, - self.implication_filter_r, - **kwargs, - ) - ) - if self.no_grads: - fuzzy_loss = fuzzy_loss.detach() - loss_components["unweighted_fuzzy_loss"] = unweighted_fuzzy_mean - loss_components["weighted_fuzzy_loss"] = weighted_fuzzy_mean - if self.base_loss is None or target is None: - total_loss = self.impl_weight * fuzzy_loss - elif self.multiply_with_base_loss: - total_loss = base_loss * (1 + self.impl_weight * fuzzy_loss) - else: - total_loss = base_loss + self.impl_weight * fuzzy_loss - return total_loss.mean(), loss_components - - def _calculate_implication_loss( - self, l: torch.Tensor, r: torch.Tensor - ) -> torch.Tensor: - """ - Calculate implication loss based on T-norm and other parameters. - - Args: - l (torch.Tensor): Left part of implication. - r (torch.Tensor): Right part of implication. - - Returns: - torch.Tensor: Calculated implication loss. 
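As a quick reference for the branches implemented below, the penalty for a single implication with antecedent truth l and consequent truth r reduces to the following closed forms (standalone sketch with arbitrary example values; pos_scalar=1, no sigmoidal rescaling):

    import torch

    l, r = torch.tensor(0.9), torch.tensor(0.2)  # antecedent likely, consequent unlikely

    reichenbach    = l * (1 - r)                      # 0.72
    xu19           = -torch.log(1 - l * (1 - r))      # -log(0.28) ~= 1.27
    lukasiewicz    = torch.relu(l + (1 - r) - 1)      # 0.70
    kleene_dienes  = torch.min(l, 1 - r)              # 0.80
    goedel         = torch.where(l <= r, 0.0, 1 - r)  # 0.80
    reverse_goedel = torch.where(l <= r, 0.0, l)      # 0.90
    binary         = (l > r).float()                  # 1.00 (0 if l <= r, else 1)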
- """ - assert not l.isnan().any(), ( - f"l contains NaN values - l.shape: {l.shape}, l.isnan().sum(): {l.isnan().sum()}, " - f"l: {l}" - ) - assert not r.isnan().any(), ( - f"r contains NaN values - r.shape: {r.shape}, r.isnan().sum(): {r.isnan().sum()}, " - f"r: {r}" - ) - if self.pos_scalar != 1: - l = ( - torch.pow(l + self.eps, 1 / self.pos_scalar) - - math.pow(self.eps, 1 / self.pos_scalar) - ) / ( - math.pow(1 + self.eps, 1 / self.pos_scalar) - - math.pow(self.eps, 1 / self.pos_scalar) - ) - one_min_r = ( - torch.pow(1 - r + self.eps, 1 / self.pos_scalar) - - math.pow(self.eps, 1 / self.pos_scalar) - ) / ( - math.pow(1 + self.eps, 1 / self.pos_scalar) - - math.pow(self.eps, 1 / self.pos_scalar) - ) - else: - one_min_r = 1 - r - # for each implication I, calculate 1 - I(l, 1-one_min_r) - # for S-implications, this is equivalent to the t-norm - if self.fuzzy_implication in ["reichenbach", "rc"]: - individual_loss = l * one_min_r - # xu19 (from Xu et al., 2019: Semantic loss) is not a fuzzy implication, but behaves similar to the Reichenbach - # implication - elif self.fuzzy_implication == "xu19": - individual_loss = -torch.log(1 - l * one_min_r) - elif self.fuzzy_implication in ["lukasiewicz", "lk"]: - individual_loss = torch.relu(l + one_min_r - 1) - elif self.fuzzy_implication in ["kleene_dienes", "kd"]: - individual_loss = torch.min(l, 1 - r) - elif self.fuzzy_implication in ["goedel", "g"]: - individual_loss = torch.where(l <= r, 0, one_min_r) - elif self.fuzzy_implication in ["reverse-goedel", "rg"]: - individual_loss = torch.where(l <= r, 0, l) - elif self.fuzzy_implication in ["binary", "b"]: - individual_loss = torch.where(l <= r, 0, 1).to(dtype=l.dtype) - else: - raise NotImplementedError( - f"Unknown fuzzy implication {self.fuzzy_implication}" - ) - - if self.use_sigmoidal_implication: - # formula by van Krieken, 2022, applied to fuzzy implication with default parameters: b_0 = 0.5, s = 9 - # parts that only depend on b_0 and s are pre-calculated - implication = 1 - individual_loss - sigmoidal_implication = 0.01123379 * ( - 91.0171 * torch.sigmoid(9 * (implication - 0.5)) - 1 - ) - individual_loss = 1 - sigmoidal_implication - - if self.multiply_by_softmax: - individual_loss = individual_loss * individual_loss.softmax(dim=-1) - - return individual_loss - - def _load_implications(self, path_to_chebi: str) -> dict: - """ - Load class hierarchy implications. - - Args: - path_to_chebi (str): Path to the ChEBI ontology file. - - Returns: - dict: Loaded hierarchy of implications. - """ - if os.path.isfile(self.implication_cache_file): - with open(self.implication_cache_file, "rb") as fin: - hierarchy = pickle.load(fin) - else: - hierarchy = self.data_extractor.extract_class_hierarchy(path_to_chebi) - with open(self.implication_cache_file, "wb") as fout: - pickle.dump(hierarchy, fout) - return hierarchy - - -class DisjointLoss(ImplicationLoss): - """ - Disjoint Loss module, extending ImplicationLoss. - - Args: - path_to_disjointness (str): Path to the disjointness data file (a csv file containing pairs of disjoint classes) - data_extractor (_GOUniProtDataExtractor): Data extractor for labels. - base_loss (torch.nn.Module, optional): Base loss function. Defaults to None. - disjoint_loss_weight (float, optional): Weight of disjointness loss. Defaults to 100. - **kwargs: Additional arguments. 
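Within this subclass the same machinery is reused for disjointness: a disjoint pair (A, B) is scored like the implication "A implies not B", i.e. the consequent is replaced by 1 - p_B before the fuzzy penalty is applied (mode="disj" in _calculate_unaggregated_fuzzy_loss). A minimal sketch with the default Reichenbach penalty:

    import torch

    p_a, p_b = torch.tensor(0.8), torch.tensor(0.7)  # both classes predicted as present

    # Reichenbach penalty l * (1 - r) with r := 1 - p_b collapses to p_a * p_b
    disjointness_penalty = p_a * (1 - (1 - p_b))     # == p_a * p_b == 0.56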
- """ - - def __init__( - self, - path_to_disjointness: str, - data_extractor: _GOUniProtDataExtractor, - base_loss: torch.nn.Module = None, - disjoint_loss_weight: float = 100, - **kwargs, - ): - super().__init__(data_extractor, base_loss, **kwargs) - self.disjoint_filter_l, self.disjoint_filter_r = _build_disjointness_filter( - path_to_disjointness, self.label_names, self.hierarchy - ) - self.disjoint_weight = disjoint_loss_weight - - def forward(self, input: torch.Tensor, target: torch.Tensor, **kwargs) -> tuple: - """ - Forward pass of the disjoint loss module. - - Args: - input (torch.Tensor): Input tensor. - target (torch.Tensor): Target tensor. - **kwargs: Additional arguments. - - Returns: - tuple: Tuple containing total loss, base loss, implication loss, and disjointness loss. - """ - base_loss = self._calculate_unaggregated_base_loss(input, target, **kwargs) - loss_components = {"base_loss": base_loss.mean()} - - if "current_epoch" in kwargs and self.start_at_epoch > kwargs["current_epoch"]: - return base_loss.mean(), loss_components - - pred = torch.sigmoid(input) - impl_loss, unweighted_impl_mean, weighted_impl_mean = ( - self._calculate_unaggregated_fuzzy_loss( - pred, - target, - self.impl_weight, - self.implication_filter_l, - self.implication_filter_r, - **kwargs, - ) - ) - if self.no_grads: - impl_loss = impl_loss.detach() - loss_components["unweighted_implication_loss"] = unweighted_impl_mean - loss_components["weighted_implication_loss"] = weighted_impl_mean - - disj_loss, unweighted_disj_mean, weighted_disj_mean = ( - self._calculate_unaggregated_fuzzy_loss( - pred, - target, - self.disjoint_weight, - self.disjoint_filter_l, - self.disjoint_filter_r, - mode="disj", - **kwargs, - ) - ) - if self.no_grads: - disj_loss = disj_loss.detach() - loss_components["unweighted_disjointness_loss"] = unweighted_disj_mean - loss_components["weighted_disjointness_loss"] = weighted_disj_mean - - if self.base_loss is None or target is None: - total_loss = self.impl_weight * impl_loss + self.disjoint_weight * disj_loss - elif self.multiply_with_base_loss: - total_loss = base_loss * ( - 1 + self.impl_weight * impl_loss + self.disjoint_weight * disj_loss - ) - else: - total_loss = ( - base_loss - + self.impl_weight * impl_loss - + self.disjoint_weight * disj_loss - ) - return total_loss.mean(), loss_components - - -def _load_label_names(path_to_label_names: str) -> List: - """ - Load label names from a file. - - Args: - path_to_label_names (str): Path to the label names file. - - Returns: - list: List of label names. - """ - with open(path_to_label_names) as fin: - label_names = [int(line.strip()) for line in fin] - return label_names - - -def _build_implication_filter(label_names: List, hierarchy: dict) -> torch.Tensor: - """ - Build implication filter based on label names and hierarchy. Results in list of pairs (A,B) for each implication - A->B (including indirect implications). - - Args: - label_names (list): List of label names. - hierarchy (dict): Hierarchy of implications. - - Returns: - torch.Tensor: Tensor representing implication filter. 
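A toy example of the relation between the sparse pair list returned here and the dense boolean mask built by _build_dense_filter (three labels, with class 0 implying class 1 and class 1 implying class 2; purely illustrative):

    import torch

    pairs = torch.tensor([[0, 1], [1, 2]])  # (antecedent, consequent) index pairs

    n_labels = 3
    dense = torch.zeros((n_labels, n_labels), dtype=torch.bool)
    for a, c in pairs:
        dense[a, c] = True
    # dense[i, j] is True iff class i implies class j; the loss also uses its
    # transpose to index the implications that end in a given class.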
- """ - return torch.tensor( - [ - (i1, i2) - for i1, l1 in enumerate(label_names) - for i2, l2 in enumerate(label_names) - if l2 in hierarchy.pred[l1] - ] - ) - - -def _build_dense_filter(sparse_filter: torch.Tensor, n_labels: int) -> torch.Tensor: - res = torch.zeros((n_labels, n_labels), dtype=torch.bool) - for l, r in sparse_filter: - res[l, r] = True - return res - - -def _build_disjointness_filter( - path_to_disjointness: str, label_names: List, hierarchy: dict -) -> tuple: - """ - Build disjointness filter based on disjointness data and hierarchy. - - Args: - path_to_disjointness (str): Path to the disjointness data file. - label_names (list): List of label names. - hierarchy (dict): Hierarchy of implications. - - Returns: - tuple: Tuple containing tensors representing disjointness filter. - """ - disjoints = set() - label_dict = dict(map(reversed, enumerate(label_names))) - - with open(path_to_disjointness, "rt") as fin: - reader = csv.reader(fin) - for l1_raw, r1_raw in reader: - l1 = int(l1_raw) - r1 = int(r1_raw) - if l1 == 36233 and r1 == 63353: - # ignore disaccharide-disaccharide derivative disjointness axiom - continue - disjoints.update( - { - (label_dict[l2], label_dict[r2]) - for r2 in list(hierarchy.succ[r1]) + [r1] - if r2 in label_names - for l2 in list(hierarchy.succ[l1]) + [l1] - if l2 in label_names - } - ) - - dis_filter = torch.tensor(list(disjoints)) - dense = _build_dense_filter(dis_filter, len(label_names)) - dense_r = dense.transpose(0, 1) - return dense, dense_r - - -if __name__ == "__main__": - loss = DisjointLoss( - os.path.join("data", "disjoint.csv"), - GOUniProtOver250(), - base_loss=BCEWeighted(), - impl_loss_weight=1, - disjoint_loss_weight=1, - ) - random_preds = torch.randn(10, 997) - random_labels = torch.randint(0, 2, (10, 997)) - for agg in ["sum", "max", "mean", "log-mean"]: - loss.violations_per_cls_aggregator = agg - l = loss(random_preds, random_labels) - print(f"Loss with {agg} aggregation for random input:", l) - - # simplified example for ontology with 4 classes, A -> B, B -> C, D -> C, B and D disjoint - loss.implication_filter_l = torch.tensor( - [[0, 1, 1, 0], [0, 0, 1, 0], [0, 0, 0, 0], [0, 0, 1, 0]] - ) - loss.implication_filter_r = loss.implication_filter_l.transpose(0, 1) - loss.disjoint_filter_l = torch.tensor( - [[0, 0, 0, 0], [0, 0, 0, 1], [0, 0, 0, 0], [0, 1, 0, 0]] - ) - loss.disjoint_filter_r = loss.disjoint_filter_l.transpose(0, 1) - # expected result: first sample: moderately high loss for B disj D, otherwise low, second sample: high loss for A -> B (applied to A), otherwise low - preds = torch.tensor([[0.1, 0.3, 0.7, 0.4], [0.5, 0.2, 0.9, 0.1]]) - labels = [[0, 1, 1, 0], [0, 0, 1, 1]] - for agg in ["sum", "max", "mean", "log-mean"]: - loss.violations_per_cls_aggregator = agg - l = loss(preds, torch.tensor(labels)) - print(f"Loss with {agg} aggregation for simple input:", l) diff --git a/chebai/models/__init__.py b/chebai/models/__init__.py deleted file mode 100644 index e3122d5..0000000 --- a/chebai/models/__init__.py +++ /dev/null @@ -1,2 +0,0 @@ -from chebai.models.base import * -from chebai.models.electra import * diff --git a/chebai/models/base.py b/chebai/models/base.py deleted file mode 100644 index 677640c..0000000 --- a/chebai/models/base.py +++ /dev/null @@ -1,372 +0,0 @@ -import logging -from typing import Any, Dict, Iterable, Optional, Union - -import torch -from lightning.pytorch.core.module import LightningModule -from torchmetrics import Metric - -from chebai.preprocessing.structures import XYData - 
-logging.getLogger("pysmiles").setLevel(logging.CRITICAL) - -_MODEL_REGISTRY = dict() - - -class ChebaiBaseNet(LightningModule): - """ - Base class for Chebai neural network models inheriting from PyTorch Lightning's LightningModule. - - Args: - criterion (torch.nn.Module, optional): The loss criterion for the model. Defaults to None. - out_dim (int, optional): The output dimension of the model. Defaults to None. - train_metrics (torch.nn.Module, optional): The metrics to be used during training. Defaults to None. - val_metrics (torch.nn.Module, optional): The metrics to be used during validation. Defaults to None. - test_metrics (torch.nn.Module, optional): The metrics to be used during testing. Defaults to None. - pass_loss_kwargs (bool, optional): Whether to pass loss kwargs to the criterion. Defaults to True. - optimizer_kwargs (Dict[str, Any], optional): Additional keyword arguments for the optimizer. Defaults to None. - **kwargs: Additional keyword arguments. - - Attributes: - NAME (str): The name of the model. - """ - - NAME = None - - def __init__( - self, - criterion: torch.nn.Module = None, - out_dim: Optional[int] = None, - train_metrics: Optional[torch.nn.Module] = None, - val_metrics: Optional[torch.nn.Module] = None, - test_metrics: Optional[torch.nn.Module] = None, - pass_loss_kwargs: bool = True, - optimizer_kwargs: Optional[Dict[str, Any]] = None, - exclude_hyperparameter_logging: Optional[Iterable[str]] = None, - **kwargs, - ): - super().__init__() - if exclude_hyperparameter_logging is None: - exclude_hyperparameter_logging = tuple() - self.criterion = criterion - self.save_hyperparameters( - ignore=[ - "criterion", - "train_metrics", - "val_metrics", - "test_metrics", - *exclude_hyperparameter_logging, - ] - ) - self.out_dim = out_dim - if optimizer_kwargs: - self.optimizer_kwargs = optimizer_kwargs - else: - self.optimizer_kwargs = dict() - self.train_metrics = train_metrics - self.validation_metrics = val_metrics - self.test_metrics = test_metrics - self.pass_loss_kwargs = pass_loss_kwargs - - def on_load_checkpoint(self, checkpoint: Dict[str, Any]) -> None: - # avoid errors due to unexpected keys (e.g., if loading checkpoint from a bce model and using it with a - # different loss) - if "criterion.base_loss.pos_weight" in checkpoint["state_dict"]: - del checkpoint["state_dict"]["criterion.base_loss.pos_weight"] - if "criterion.pos_weight" in checkpoint["state_dict"]: - del checkpoint["state_dict"]["criterion.pos_weight"] - - def __init_subclass__(cls, **kwargs): - """ - Automatically registers subclasses in the model registry to prevent duplicates. - - Args: - **kwargs: Additional keyword arguments. - """ - if cls.NAME in _MODEL_REGISTRY: - raise ValueError(f"Model {cls.NAME} does already exist") - else: - _MODEL_REGISTRY[cls.NAME] = cls - - def _get_prediction_and_labels( - self, data: Dict[str, Any], labels: torch.Tensor, output: torch.Tensor - ) -> (torch.Tensor, torch.Tensor): - """ - Gets the predictions and labels from the model output. - - Args: - data (Dict[str, Any]): The processed batch data. - labels (torch.Tensor): The true labels. - output (torch.Tensor): The model output. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Predictions and labels. - """ - return output, labels - - def _process_labels_in_batch(self, batch: XYData) -> torch.Tensor: - """ - Processes the labels in the batch. - - Args: - batch (XYData): The input batch of data. - - Returns: - torch.Tensor: The processed labels. 
- """ - return batch.y.float() - - def _process_batch(self, batch: XYData, batch_idx: int) -> Dict[str, Any]: - """ - Processes the batch data. - - Args: - batch (XYData): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Dict[str, Any]: Processed batch data. - """ - return dict( - features=batch.x, - labels=self._process_labels_in_batch(batch), - model_kwargs=batch.additional_fields["model_kwargs"], - loss_kwargs=batch.additional_fields["loss_kwargs"], - idents=batch.additional_fields["idents"], - ) - - def _process_for_loss( - self, - model_output: torch.Tensor, - labels: torch.Tensor, - loss_kwargs: Dict[str, Any], - ) -> (torch.Tensor, torch.Tensor, Dict[str, Any]): - """ - Processes the data for loss computation. - - Args: - model_output (torch.Tensor): The model output. - labels (torch.Tensor): The true labels. - loss_kwargs (Dict[str, Any]): Additional keyword arguments for the loss function. - - Returns: - Tuple[torch.Tensor, torch.Tensor, Dict[str, Any]]: Model output, labels, and loss kwargs. - """ - return model_output, labels, loss_kwargs - - def training_step( - self, batch: XYData, batch_idx: int - ) -> Dict[str, Union[torch.Tensor, Any]]: - """ - Defines the training step. - - Args: - batch (XYData): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Dict[str, Union[torch.Tensor, Any]]: The result of the training step. - """ - return self._execute( - batch, batch_idx, self.train_metrics, prefix="train_", sync_dist=True - ) - - def validation_step( - self, batch: XYData, batch_idx: int - ) -> Dict[str, Union[torch.Tensor, Any]]: - """ - Defines the validation step. - - Args: - batch (XYData): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Dict[str, Union[torch.Tensor, Any]]: The result of the validation step. - """ - return self._execute( - batch, batch_idx, self.validation_metrics, prefix="val_", sync_dist=True - ) - - def test_step( - self, batch: XYData, batch_idx: int - ) -> Dict[str, Union[torch.Tensor, Any]]: - """ - Defines the test step. - - Args: - batch (XYData): The input batch of data. - batch_idx (int): The index of the current batch. - - Returns: - Dict[str, Union[torch.Tensor, Any]]: The result of the test step. - """ - return self._execute( - batch, batch_idx, self.test_metrics, prefix="test_", sync_dist=True - ) - - def predict_step( - self, batch: XYData, batch_idx: int, **kwargs - ) -> Dict[str, Union[torch.Tensor, Any]]: - """ - Defines the prediction step. - - Args: - batch (XYData): The input batch of data. - batch_idx (int): The index of the current batch. - **kwargs: Additional keyword arguments. - - Returns: - Dict[str, Union[torch.Tensor, Any]]: The result of the prediction step. - """ - return self._execute(batch, batch_idx, self.test_metrics, prefix="", log=False) - - def _execute( - self, - batch: XYData, - batch_idx: int, - metrics: Optional[torch.nn.Module] = None, - prefix: Optional[str] = "", - log: Optional[bool] = True, - sync_dist: Optional[bool] = False, - ) -> Dict[str, Union[torch.Tensor, Any]]: - """ - Executes the model on a batch of data and returns the model output and predictions. - - Args: - batch (XYData): The input batch of data. - batch_idx (int): The index of the current batch. - metrics (torch.nn.Module): A dictionary of metrics to track. - prefix (str, optional): A prefix to add to the metric names. Defaults to "". - log (bool, optional): Whether to log the metrics. Defaults to True. 
- sync_dist (bool, optional): Whether to synchronize distributed training. Defaults to False. - - Returns: - Dict[str, Union[torch.Tensor, Any]]: A dictionary containing the processed data, labels, model output, - predictions, and loss (if applicable). - """ - assert isinstance(batch, XYData) - batch = batch.to(self.device) - data = self._process_batch(batch, batch_idx) - labels = data["labels"] - model_output = self(data, **data.get("model_kwargs", dict())) - pr, tar = self._get_prediction_and_labels(data, labels, model_output) - d = dict(data=data, labels=labels, output=model_output, preds=pr) - if log: - if self.criterion is not None: - loss_data, loss_labels, loss_kwargs_candidates = self._process_for_loss( - model_output, labels, data.get("loss_kwargs", dict()) - ) - loss_kwargs = dict() - if self.pass_loss_kwargs: - loss_kwargs = loss_kwargs_candidates - loss_kwargs["current_epoch"] = self.trainer.current_epoch - loss = self.criterion(loss_data, loss_labels, **loss_kwargs) - if isinstance(loss, tuple): - unnamed_loss_index = 1 - if isinstance(loss[1], dict): - unnamed_loss_index = 2 - for key, value in loss[1].items(): - self.log( - key, - value if isinstance(value, int) else value.item(), - batch_size=len(batch), - on_step=True, - on_epoch=True, - prog_bar=False, - logger=True, - sync_dist=sync_dist, - ) - loss_additional = loss[unnamed_loss_index:] - for i, loss_add in enumerate(loss_additional): - self.log( - f"{prefix}loss_{i}", - loss_add if isinstance(loss_add, int) else loss_add.item(), - batch_size=len(batch), - on_step=True, - on_epoch=True, - prog_bar=False, - logger=True, - sync_dist=sync_dist, - ) - loss = loss[0] - - d["loss"] = loss - self.log( - f"{prefix}loss", - loss.item(), - batch_size=len(batch), - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - sync_dist=sync_dist, - ) - if metrics and labels is not None: - for metric_name, metric in metrics.items(): - metric.update(pr, tar) - self._log_metrics(prefix, metrics, len(batch)) - return d - - def _log_metrics(self, prefix: str, metrics: torch.nn.Module, batch_size: int): - """ - Logs the metrics for the given prefix. - - Args: - prefix (str): The prefix to be added to the metric names. - metrics (torch.nn.Module): A dictionary containing the metrics to be logged. - batch_size (int): The batch size used for logging. - - Returns: - None - """ - # don't use sync_dist=True if the metric is a torchmetrics-metric - # (see https://github.com/Lightning-AI/pytorch-lightning/discussions/6501#discussioncomment-569757) - for metric_name, metric in metrics.items(): - m = None # m = metric.compute() - if isinstance(m, dict): - # todo: is this case needed? it requires logging values directly which does not give accurate results - # with the current metric-setup - for k, m2 in m.items(): - self.log( - f"{prefix}{metric_name}{k}", - m2, - batch_size=batch_size, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - else: - self.log( - f"{prefix}{metric_name}", - metric, - batch_size=batch_size, - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - - def forward(self, x: Dict[str, Any]) -> torch.Tensor: - """ - Defines the forward pass. - - Args: - x (Dict[str, Any]): The input data. - - Returns: - torch.Tensor: The model output. - """ - raise NotImplementedError - - def configure_optimizers(self, **kwargs) -> torch.optim.Optimizer: - """ - Configures the optimizers. - - Args: - **kwargs: Additional keyword arguments. - - Returns: - torch.optim.Optimizer: The optimizer. 
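The optimizer_kwargs passed to the constructor are forwarded unchanged to Adamax, so a standalone equivalent of what this method builds looks like this (parameter and values are arbitrary examples):

    import torch

    params = [torch.nn.Parameter(torch.zeros(1))]    # stand-in for self.parameters()
    optimizer_kwargs = {"lr": 1e-3, "weight_decay": 1e-5}
    optimizer = torch.optim.Adamax(params, **optimizer_kwargs)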
- """ - return torch.optim.Adamax(self.parameters(), **self.optimizer_kwargs) diff --git a/chebai/models/chemberta.py b/chebai/models/chemberta.py deleted file mode 100644 index b601542..0000000 --- a/chebai/models/chemberta.py +++ /dev/null @@ -1,77 +0,0 @@ -import logging -import random -from tempfile import TemporaryDirectory - -import torch -from torch import nn -from torch.nn.functional import one_hot -from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence -from transformers import ( - RobertaConfig, - RobertaForMaskedLM, - RobertaModel, - RobertaTokenizer, -) - -from chebai.models.base import ChebaiBaseNet - -logging.getLogger("pysmiles").setLevel(logging.CRITICAL) -MAX_LEN = 1800 - - -class ChembertaPre(ChebaiBaseNet): - NAME = "ChembertaPre" - - def __init__(self, p=0.2, **kwargs): - super().__init__(**kwargs) - self._p = p - self.config = RobertaConfig(**kwargs["config"]) - self.model = RobertaForMaskedLM(self.config) - - def _process_batch(self, batch, batch_idx): - masked = ( - torch.rand([batch.x.shape[0]], device=self.device) - * torch.tensor(batch.lens, device=self.device) - ).long() - labels = one_hot( - torch.gather(batch.x, 1, masked.unsqueeze(-1)).squeeze(-1), - self.config.vocab_size, - ) - features = 1 + batch.x - features = features * (1 - one_hot(masked, batch.x.shape[-1])) - return features, labels - - def forward(self, data): - x = self.model(data) - return {"logits": torch.sum(x.logits, dim=1)} - - -class Chemberta(ChebaiBaseNet): - NAME = "Chemberta" - - def __init__(self, **kwargs): - # Remove this property in order to prevent it from being stored as a - # hyper parameter - pretrained_checkpoint = ( - kwargs.pop("pretrained_checkpoint") - if "pretrained_checkpoint" in kwargs - else None - ) - super().__init__(**kwargs) - self.config = RobertaConfig( - **kwargs["config"], output_attentions=True, num_labels=self.out_dim - ) - - if pretrained_checkpoint: - elpre = RobertaModel.load_from_checkpoint(pretrained_checkpoint) - with TemporaryDirectory() as td: - elpre.electra.save_pretrained(td) - self.electra = RobertaModel.from_pretrained(td, config=self.config) - in_d = elpre.config.hidden_size - else: - self.electra = RobertaModel(config=self.config) - in_d = self.config.hidden_size - - def forward(self, data): - electra = self.electra(data) - return dict(logits=electra.logits, attentions=electra.attentions) diff --git a/chebai/models/chemyk.py b/chebai/models/chemyk.py deleted file mode 100644 index 13bbea7..0000000 --- a/chebai/models/chemyk.py +++ /dev/null @@ -1,63 +0,0 @@ -import logging -import os -import pickle -import sys - -import networkx as nx -import torch -from torch import nn -from torch.nn import functional as F -from torch.nn.functional import pad - -from chebai.models.base import ChebaiBaseNet - -logging.getLogger("pysmiles").setLevel(logging.CRITICAL) - - -class ChemYK(ChebaiBaseNet): - NAME = "ChemYK" - - def __init__(self, in_d, out_d, num_classes, **kwargs): - super().__init__(num_classes, **kwargs) - d_internal = in_d - self.d_internal = d_internal - self.embedding = nn.Embedding(800, d_internal) - self.s = nn.Linear(d_internal, 1) - self.a_l = nn.Linear(d_internal, 1) - self.a_r = nn.Linear(d_internal, 1) - self.w_l = nn.Linear(d_internal, d_internal) - self.w_r = nn.Linear(d_internal, d_internal) - self.norm = nn.LayerNorm(d_internal) - self.output = nn.Sequential( - nn.Linear(in_d, in_d), - nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(in_d, num_classes), - ) - - def forward(self, data, *args, **kwargs): - m = 
self.embedding(data.x) - max_width = m.shape[1] - h = [m] # torch.zeros(emb.shape[0], max_width, *emb.shape[1:]) - # h[:, 0] = emb - for width in range(1, max_width): - l = torch.stack(tuple(h[i][:, : (max_width - width)] for i in range(width))) - r = torch.stack( - tuple(h[i][:, (width - i) :] for i in range(0, width)) - ).flip(0) - m = self.merge(l, r) - h.append(m) - return self.output(m).squeeze(1) - - def merge(self, l, r): - x = torch.stack([self.a_l(l), self.a_r(r)]) - beta = torch.softmax(x, 0) - return F.leaky_relu( - self.attention( - torch.sum(beta * torch.stack([self.w_l(l), self.w_r(r)]), dim=0) - ) - ) - - def attention(self, parts): - at = torch.softmax(self.s(parts), 1) - return torch.sum(at * parts, dim=0) diff --git a/chebai/models/electra.py b/chebai/models/electra.py deleted file mode 100644 index dc6c719..0000000 --- a/chebai/models/electra.py +++ /dev/null @@ -1,535 +0,0 @@ -import logging -from math import pi -from tempfile import TemporaryDirectory -from typing import Any, Dict, Optional, Tuple - -import torch -from torch import Tensor, nn -from torch.nn.utils.rnn import pad_sequence -from transformers import ( - ElectraConfig, - ElectraForMaskedLM, - ElectraForPreTraining, - ElectraModel, -) - -from chebai.loss.pretraining import ElectraPreLoss # noqa -from chebai.models.base import ChebaiBaseNet -from chebai.preprocessing.reader import CLS_TOKEN, MASK_TOKEN_INDEX - -logging.getLogger("pysmiles").setLevel(logging.CRITICAL) - -from chebai.loss.semantic import DisjointLoss as ElectraChEBIDisjointLoss # noqa - - -class ElectraPre(ChebaiBaseNet): - """ - ElectraPre class represents an Electra model for pre-training inherited from ChebaiBaseNet. - - Args: - config (dict): Configuration parameters for the Electra model. - **kwargs: Additional keyword arguments (passed to parent class). - - Attributes: - NAME (str): Name of the ElectraPre model. - generator_config (ElectraConfig): Configuration for the generator model. - generator (ElectraForMaskedLM): Generator model for masked language modeling. - discriminator_config (ElectraConfig): Configuration for the discriminator model. - discriminator (ElectraForPreTraining): Discriminator model for pre-training. - replace_p (float): Probability of replacing tokens during training. - """ - - NAME = "ElectraPre" - - def __init__(self, config: Dict[str, Any] = None, **kwargs: Any): - super().__init__(config=config, **kwargs) - self.generator_config = ElectraConfig(**config["generator"]) - self.generator = ElectraForMaskedLM(self.generator_config) - self.discriminator_config = ElectraConfig(**config["discriminator"]) - self.discriminator = ElectraForPreTraining(self.discriminator_config) - self.replace_p = 0.1 - - @property - def as_pretrained(self) -> ElectraForPreTraining: - """ - Returns the discriminator model as a pre-trained model. - - Returns: - ElectraForPreTraining: The discriminator model. - """ - return self.discriminator - - def _process_labels_in_batch(self, batch: Dict[str, Any]) -> None: - """ - Processes the labels in the batch. - - Args: - batch (Dict[str, Any]): The input batch of data. - - Returns: - torch.Tensor: The processed labels. - """ - return None - - def forward( - self, data: Dict[str, Any], **kwargs: Any - ) -> Tuple[Tuple[Tensor, Tensor], Tuple[Tensor, Tensor]]: - """ - Forward pass of the ElectraPre model. - - Args: - data (dict): Input data. - **kwargs: Additional keyword arguments. - - Returns: - tuple: A tuple containing the raw generator output and discriminator output. 
- The generator output is a tensor of shape (batch_size, max_seq_len, vocab_size). - The discriminator output is a tensor of shape (batch_size, max_seq_len). - """ - features = data["features"] - features = features.long() - self.batch_size = batch_size = features.shape[0] - max_seq_len = features.shape[1] - - mask = kwargs["mask"] - with torch.no_grad(): - dis_tar = ( - torch.rand((batch_size,), device=self.device) * torch.sum(mask, dim=-1) - ).int() - disc_tar_one_hot = torch.eq( - torch.arange(max_seq_len, device=self.device)[None, :], dis_tar[:, None] - ) - gen_tar = features[disc_tar_one_hot] - gen_tar_one_hot = torch.eq( - torch.arange(self.generator_config.vocab_size, device=self.device)[ - None, : - ], - gen_tar[:, None], - ) - - raw_gen_out = torch.mean( - self.generator( - (features * ~disc_tar_one_hot) + MASK_TOKEN_INDEX * disc_tar_one_hot, - attention_mask=mask, - ).logits, - dim=1, - ) - - with torch.no_grad(): - gen_best_guess = raw_gen_out.argmax(dim=-1) - correct_mask = features[disc_tar_one_hot] == gen_best_guess - random_tokens = torch.randint( - self.generator_config.vocab_size, (batch_size,), device=self.device - ) - replacements = gen_best_guess * ~correct_mask + random_tokens * correct_mask - - disc_out = self.discriminator( - features * ~disc_tar_one_hot + replacements[:, None] * disc_tar_one_hot, - attention_mask=mask, - ).logits - return (raw_gen_out, disc_out), (gen_tar_one_hot, disc_tar_one_hot) - - def _get_prediction_and_labels( - self, batch: Dict[str, Any], labels: Tensor, output: Tensor - ) -> Tuple[Tensor, Tensor]: - """ - Gets the predictions and labels from the model output. - - Args: - data (Dict[str, Any]): The processed batch data. - labels (torch.Tensor): The true labels. - output (torch.Tensor): The model output. - - Returns: - Tuple[torch.Tensor, torch.Tensor]: Predictions and labels. - """ - return torch.softmax(output[0][1], dim=-1), output[1][1].int() - - -def filter_dict(d: Dict[str, Any], filter_key: str) -> Dict[str, Any]: - """ - Filters a dictionary by a given key prefix. - - Args: - d (dict): The dictionary to filter. - filter_key (str): The key prefix to filter by. - - Returns: - dict: A dictionary containing only the key-value pairs where the key starts with the given prefix. - """ - return { - str(k)[len(filter_key) :]: v - for k, v in d.items() - if str(k).startswith(filter_key) - } - - -class Electra(ChebaiBaseNet): - """ - Electra model implementation inherited from ChebaiBaseNet. - - Args: - config (Dict[str, Any], optional): Configuration parameters for the Electra model. Defaults to None. - pretrained_checkpoint (str, optional): Path to the pretrained checkpoint file. Defaults to None. - load_prefix (str, optional): Prefix to filter the state_dict keys from the pretrained checkpoint. Defaults to None. - **kwargs: Additional keyword arguments. - - Attributes: - NAME (str): Name of the Electra model. - """ - - NAME = "Electra" - - def _process_batch(self, batch: Dict[str, Any], batch_idx: int) -> Dict[str, Any]: - """ - Process a batch of data. - - Args: - batch (Dict[str, Any]): The input batch of data. - batch_idx (int): The index of the batch (not used). - - Returns: - dict: A dictionary containing the processed batch, keys are `features`, `labels`, `model_kwargs`, - `loss_kwargs` and `idents`. 
- """ - model_kwargs = dict() - loss_kwargs = batch.additional_fields["loss_kwargs"] - if "lens" in batch.additional_fields["model_kwargs"]: - model_kwargs["attention_mask"] = pad_sequence( - [ - torch.ones(l + 1, device=self.device) - for l in batch.additional_fields["model_kwargs"]["lens"] - ], - batch_first=True, - ) - cls_tokens = ( - torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( - -1 - ) - * CLS_TOKEN - ) - return dict( - features=torch.cat((cls_tokens, batch.x), dim=1), - labels=batch.y, - model_kwargs=model_kwargs, - loss_kwargs=loss_kwargs, - idents=batch.additional_fields["idents"], - ) - - @property - def as_pretrained(self) -> ElectraModel: - """ - Get the pretrained Electra model. - - Returns: - ElectraModel: The pretrained Electra model. - """ - return self.electra.electra - - def __init__( - self, - config: Optional[Dict[str, Any]] = None, - pretrained_checkpoint: Optional[str] = None, - load_prefix: Optional[str] = None, - **kwargs: Any, - ): - # Remove this property in order to prevent it from being stored as a - # hyper parameter - - super().__init__(**kwargs) - if config is None: - config = dict() - if not "num_labels" in config and self.out_dim is not None: - config["num_labels"] = self.out_dim - self.config = ElectraConfig(**config, output_attentions=True) - self.word_dropout = nn.Dropout(config.get("word_dropout", 0)) - - in_d = self.config.hidden_size - self.output = nn.Sequential( - nn.Dropout(self.config.hidden_dropout_prob), - nn.Linear(in_d, in_d), - nn.GELU(), - nn.Dropout(self.config.hidden_dropout_prob), - nn.Linear(in_d, self.config.num_labels), - ) - - # Load pretrained checkpoint if provided - if pretrained_checkpoint: - with open(pretrained_checkpoint, "rb") as fin: - model_dict = torch.load( - fin, map_location=self.device, weights_only=False - ) - if load_prefix: - state_dict = filter_dict(model_dict["state_dict"], load_prefix) - else: - state_dict = model_dict["state_dict"] - self.electra = ElectraModel.from_pretrained( - None, state_dict=state_dict, config=self.config - ) - else: - self.electra = ElectraModel(config=self.config) - - def _process_for_loss( - self, - model_output: Dict[str, Tensor], - labels: Tensor, - loss_kwargs: Dict[str, Any], - ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: - """ - Process the model output for calculating the loss. - - Args: - model_output (Dict[str, Tensor]): The output of the model. - labels (Tensor): The target labels. - loss_kwargs (Dict[str, Any]): Additional loss arguments. - - Returns: - tuple: A tuple containing the processed model output, labels, and loss arguments. - """ - kwargs_copy = dict(loss_kwargs) - if labels is not None: - labels = labels.float() - return model_output["logits"], labels, kwargs_copy - - def _get_prediction_and_labels( - self, data: Dict[str, Any], labels: Tensor, model_output: Dict[str, Tensor] - ) -> Tuple[Tensor, Tensor]: - """ - Get the predictions and labels from the model output. Applies a sigmoid to the model output. - - Args: - data (Dict[str, Any]): The input data. - labels (Tensor): The target labels. - model_output (Dict[str, Tensor]): The output of the model. - - Returns: - tuple: A tuple containing the predictions and labels. 
- """ - d = model_output["logits"] - loss_kwargs = data.get("loss_kwargs", dict()) - if "non_null_labels" in loss_kwargs: - n = loss_kwargs["non_null_labels"] - d = d[n] - return torch.sigmoid(d), labels.int() if labels is not None else None - - def forward(self, data: Dict[str, Tensor], **kwargs: Any) -> Dict[str, Any]: - """ - Forward pass of the Electra model. - - Args: - data (Dict[str, Tensor]): The input data (expects a key `features`). - **kwargs: Additional keyword arguments for `self.electra`. - - Returns: - dict: A dictionary containing the model output (logits and attentions). - """ - self.batch_size = data["features"].shape[0] - try: - inp = self.electra.embeddings.forward(data["features"].int()) - except RuntimeError as e: - print(f"RuntimeError at forward: {e}") - print(f'data[features]: {data["features"]}') - raise e - inp = self.word_dropout(inp) - electra = self.electra(inputs_embeds=inp, **kwargs) - d = electra.last_hidden_state[:, 0, :] - return dict( - logits=self.output(d), - attentions=electra.attentions, - ) - - -class ElectraLegacy(ChebaiBaseNet): - NAME = "ElectraLeg" - - def __init__(self, **kwargs): - super().__init__(**kwargs) - self.config = ElectraConfig(**kwargs["config"], output_attentions=True) - - if "pretrained_checkpoint" in kwargs: - elpre = ElectraPre.load_from_checkpoint(kwargs["pretrained_checkpoint"]) - with TemporaryDirectory() as td: - elpre.electra.save_pretrained(td) - self.electra = ElectraModel.from_pretrained(td, config=self.config) - in_d = elpre.config.hidden_size - else: - self.electra = ElectraModel(config=self.config) - in_d = self.config.hidden_size - - self.output = nn.Sequential( - nn.Linear(in_d, in_d), - nn.ReLU(), - nn.Linear(in_d, in_d), - nn.ReLU(), - nn.Linear(in_d, in_d), - nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(in_d, 500), - ) - - def forward(self, data): - electra = self.electra(data) - d = torch.sum(electra.last_hidden_state, dim=1) - return dict(logits=self.output(d), attentions=electra.attentions) - - -class ConeElectra(ChebaiBaseNet): - NAME = "ConeElectra" - - def _process_batch(self, batch, batch_idx): - mask = pad_sequence( - [torch.ones(l + 1, device=self.device) for l in batch.lens], - batch_first=True, - ) - cls_tokens = ( - torch.ones(batch.x.shape[0], dtype=torch.int, device=self.device).unsqueeze( - -1 - ) - * CLS_TOKEN - ) - return dict( - features=torch.cat((cls_tokens, batch.x), dim=1), - labels=batch.y, - model_kwargs=dict(attention_mask=mask), - ) - - @property - def as_pretrained(self): - return self.electra.electra - - def __init__(self, cone_dimensions=20, **kwargs): - # Remove this property in order to prevent it from being stored as a - # hyper parameter - pretrained_checkpoint = ( - kwargs.pop("pretrained_checkpoint") - if "pretrained_checkpoint" in kwargs - else None - ) - - self.cone_dimensions = cone_dimensions - - super().__init__(**kwargs) - if not "num_labels" in kwargs["config"] and self.out_dim is not None: - kwargs["config"]["num_labels"] = self.out_dim - self.config = ElectraConfig(**kwargs["config"], output_attentions=True) - self.word_dropout = nn.Dropout(kwargs["config"].get("word_dropout", 0)) - model_prefix = kwargs.get("load_prefix", None) - if pretrained_checkpoint: - with open(pretrained_checkpoint, "rb") as fin: - model_dict = torch.load( - fin, map_location=self.device, weights_only=False - ) - if model_prefix: - state_dict = { - str(k)[len(model_prefix) :]: v - for k, v in model_dict["state_dict"].items() - if str(k).startswith(model_prefix) - } - else: - state_dict = 
model_dict["state_dict"] - self.electra = ElectraModel.from_pretrained( - None, state_dict=state_dict, config=self.config - ) - else: - self.electra = ElectraModel(config=self.config) - - in_d = self.config.hidden_size - - self.line_embedding = nn.Sequential( - nn.Dropout(self.config.hidden_dropout_prob), - nn.Linear(in_d, in_d), - nn.GELU(), - nn.Dropout(self.config.hidden_dropout_prob), - nn.Linear(in_d, self.cone_dimensions), - ) - - self.cone_axes = nn.Parameter( - 2 * pi * torch.rand((1, self.config.num_labels, self.cone_dimensions)) - ) - self.cone_arcs = nn.Parameter( - pi * (1 - 2 * torch.rand((1, self.config.num_labels, self.cone_dimensions))) - ) - - def _get_data_for_loss(self, model_output, labels): - d = model_output["predicted_vectors"] - return dict( - input=dict( - predicted_vectors=d, cone_axes=self.cone_axes, cone_arcs=self.cone_arcs - ), - target=labels.float(), - ) - - def _get_prediction_and_labels(self, data, labels, model_output): - d = model_output["predicted_vectors"].unsqueeze(1) - - d = in_cone_parts(d, self.cone_axes, self.cone_arcs) - - return torch.mean(d, dim=-1), labels.int() - - def forward(self, data, **kwargs): - self.batch_size = data["features"].shape[0] - inp = self.electra.embeddings.forward(data["features"]) - inp = self.word_dropout(inp) - electra = self.electra(inputs_embeds=inp, **kwargs) - d = electra.last_hidden_state[:, 0, :] - return dict( - predicted_vectors=self.line_embedding(d), - attentions=electra.attentions, - ) - - -def softabs(x, eps=0.01): - return (x**2 + eps) ** 0.5 - eps**0.5 - - -def anglify(x): - return torch.tanh(x) * pi - - -def turn(vector, angle): - v = vector - angle - return v - (v > pi) * 2 * pi + (v < -pi) * 2 * pi - - -def in_cone_parts(vectors, cone_axes, cone_arcs): - """ - # trap between -pi and pi - cone_ax_ang = anglify(cone_axes) - v = anglify(vectors) - - # trap between 0 and pi - cone_arc_ang = (torch.tanh(cone_arcs)+1)*pi/2 - theta_L = cone_ax_ang - cone_arc_ang/2 - #theta_L = theta_L - (theta_L > 2*pi) * 2 * pi + (theta_L < 0) *2*pi - theta_R = cone_ax_ang + cone_arc_ang/2 - #theta_R = theta_R - (theta_R > 2 * pi) * 2 * pi + (theta_R < 0) * 2 * pi - dis = (torch.abs(turn(v, theta_L)) + torch.abs(turn(v, theta_R)) - cone_arc_ang)/(2*pi-cone_arc_ang) - return dis - """ - a = cone_axes - cone_arcs**2 - b = cone_axes + cone_arcs**2 - bigger_than_a = torch.sigmoid(vectors - a) - smaller_than_b = torch.sigmoid(b - vectors) - return bigger_than_a * smaller_than_b - - -class ConeLoss: - def __init__(self, center_scaling=0.1): - self.center_scaling = center_scaling - - def negate(self, ax, arc): - offset = pi * torch.ones_like(ax) - offset[ax >= 0] *= -1 - return ax + offset, pi - arc - - def __call__(self, target, input): - predicted_vectors = input["predicted_vectors"].unsqueeze(1) - cone_axes = input["cone_axes"] - cone_arcs = input["cone_arcs"] - memberships = (1 - 1e-6) * ( - in_cone_parts(predicted_vectors, cone_axes, cone_arcs) - ) - loss = torch.nn.functional.binary_cross_entropy( - memberships, target.unsqueeze(-1).expand(-1, -1, 20) - ) - return loss diff --git a/chebai/models/external/__init__.py b/chebai/models/external/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chebai/models/ffn.py b/chebai/models/ffn.py deleted file mode 100644 index c9c6f91..0000000 --- a/chebai/models/ffn.py +++ /dev/null @@ -1,153 +0,0 @@ -from typing import Any, Dict, List, Optional, Tuple - -import torch -from torch import Tensor, nn - -from chebai.models import ChebaiBaseNet - - -class 
FFN(ChebaiBaseNet): - # Reference: https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/models.py#L121-L139 - - NAME = "FFN" - - def __init__( - self, - input_size: int, - hidden_layers: List[int] = [ - 1024, - ], - **kwargs - ): - super().__init__(**kwargs) - - layers = [] - current_layer_input_size = input_size - for hidden_dim in hidden_layers: - layers.append(MLPBlock(current_layer_input_size, hidden_dim)) - layers.append(Residual(MLPBlock(hidden_dim, hidden_dim))) - current_layer_input_size = hidden_dim - - layers.append(torch.nn.Linear(current_layer_input_size, self.out_dim)) - layers.append(nn.Sigmoid()) - self.model = nn.Sequential(*layers) - - def _get_prediction_and_labels(self, data, labels, model_output): - d = model_output["logits"] - loss_kwargs = data.get("loss_kwargs", dict()) - if "non_null_labels" in loss_kwargs: - n = loss_kwargs["non_null_labels"] - d = d[n] - return torch.sigmoid(d), labels.int() if labels is not None else None - - def _process_for_loss( - self, - model_output: Dict[str, Tensor], - labels: Tensor, - loss_kwargs: Dict[str, Any], - ) -> Tuple[Tensor, Tensor, Dict[str, Any]]: - """ - Process the model output for calculating the loss. - - Args: - model_output (Dict[str, Tensor]): The output of the model. - labels (Tensor): The target labels. - loss_kwargs (Dict[str, Any]): Additional loss arguments. - - Returns: - tuple: A tuple containing the processed model output, labels, and loss arguments. - """ - kwargs_copy = dict(loss_kwargs) - if labels is not None: - labels = labels.float() - return model_output["logits"], labels, kwargs_copy - - def forward(self, data, **kwargs): - x = data["features"] - return {"logits": self.model(x)} - - -class Residual(nn.Module): - """ - A residual layer that adds the output of a function to its input. - - Args: - fn (nn.Module): The function to be applied to the input. - - References: - https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L6-L35 - """ - - def __init__(self, fn): - """ - Initialize the Residual layer with a given function. - - Args: - fn (nn.Module): The function to be applied to the input. - """ - super().__init__() - self.fn = fn - - def forward(self, x): - """ - Forward pass of the Residual layer. - - Args: - x: Input tensor. - - Returns: - torch.Tensor: The input tensor added to the result of applying the function `fn` to it. - """ - return x + self.fn(x) - - -class MLPBlock(nn.Module): - """ - A basic Multi-Layer Perceptron (MLP) block with one fully connected layer. - - Args: - in_features (int): The number of input features. - output_size (int): The number of output features. - bias (boolean): Add bias to the linear layer - layer_norm (boolean): Apply layer normalization - dropout (float): The dropout value - activation (nn.Module): The activation function to be applied after each fully connected layer. 
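[Editor's aside, not part of the patch] For orientation, a condensed, runnable sketch of the DeepGO-style residual MLP pattern the FFN above is built from. The classes are re-declared locally so the snippet stands alone, and the layer sizes are made up; note that the block takes an activation class (e.g. nn.ReLU) and instantiates it itself.

import torch
from torch import nn

class MLPBlock(nn.Module):
    def __init__(self, in_features, out_features, dropout=0.1, activation=nn.ReLU):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        self.activation = activation()  # activation is passed as a class, not an instance
        self.norm = nn.LayerNorm(out_features)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        return self.dropout(self.norm(self.activation(self.linear(x))))

class Residual(nn.Module):
    def __init__(self, fn):
        super().__init__()
        self.fn = fn

    def forward(self, x):
        return x + self.fn(x)

# FFN-style stack: one block per hidden size, each followed by a residual block
model = nn.Sequential(MLPBlock(64, 1024), Residual(MLPBlock(1024, 1024)), nn.Linear(1024, 10))
print(model(torch.randn(8, 64)).shape)  # torch.Size([8, 10])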
- - References: - https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/base.py#L38-L73 - - Example: - ```python - # Create an MLP block with 2 hidden layers and ReLU activation - mlp_block = MLPBlock(input_size=64, output_size=10, activation=nn.ReLU()) - - # Apply the MLP block to an input tensor - input_tensor = torch.randn(32, 64) - output = mlp_block(input_tensor) - ``` - """ - - def __init__( - self, - in_features, - out_features, - bias=True, - layer_norm=True, - dropout=0.1, - activation=nn.ReLU, - ): - super().__init__() - self.linear = nn.Linear(in_features, out_features, bias) - self.activation = activation() - self.layer_norm: Optional[nn.LayerNorm] = ( - nn.LayerNorm(out_features) if layer_norm else None - ) - self.dropout: Optional[nn.Dropout] = nn.Dropout(dropout) if dropout else None - - def forward(self, x): - x = self.activation(self.linear(x)) - if self.layer_norm: - x = self.layer_norm(x) - if self.dropout: - x = self.dropout(x) - return x diff --git a/chebai/models/lnn_model.py b/chebai/models/lnn_model.py deleted file mode 100644 index 3d61c5a..0000000 --- a/chebai/models/lnn_model.py +++ /dev/null @@ -1,40 +0,0 @@ -import fastobo -import pyhornedowl -import tqdm -from lnn import Implies, Model, Not, Predicate, Variable, World -from owlready2 import get_ontology - - -def get_name(iri: str): - return iri.split("/")[-1] - - -if __name__ == "__main__": - formulae = [] - - # Load disjointness axioms - # onto_dis = pyhornedowl.open_ontology("/data/ontologies/chebi-disjoints.owl") - # print("Process disjointness releation") - # formulae += [Implies(predicates[get_name(c)](x), Not(predicates[get_name(d)](x))) for _, c,d in (ax for ax in onto_dis.get_axioms() if ax[0] == "AxiomKind::SubClassOf" and isinstance(ax[-1], str))] - - model = Model() - x = Variable("x") - y = Variable("y") - - onto = pyhornedowl.open_ontology("/data/ontologies/chebi.owl") - - print("Process classes") - predicates = {get_name(c): Predicate(get_name(c)) for c in onto.get_classes()} - - print("Process subsumption releation") - formulae += [ - Implies(predicates[get_name(c)](x), predicates[get_name(d)](x)) - for _, c, d in ( - ax - for ax in onto.get_axioms() - if ax[0] == "AxiomKind::SubClassOf" and isinstance(ax[-1], str) - ) - ] - - model.add_knowledge(*formulae, world=World.AXIOM) - model.print() diff --git a/chebai/models/lstm.py b/chebai/models/lstm.py deleted file mode 100644 index c706d6a..0000000 --- a/chebai/models/lstm.py +++ /dev/null @@ -1,34 +0,0 @@ -import logging -import sys - -from torch import nn -from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence - -from chebai.models.base import ChebaiBaseNet - -logging.getLogger("pysmiles").setLevel(logging.CRITICAL) - - -class ChemLSTM(ChebaiBaseNet): - NAME = "LSTM" - - def __init__(self, in_d, out_d, num_classes, **kwargs): - super().__init__(num_classes, **kwargs) - self.lstm = nn.LSTM(in_d, out_d, batch_first=True) - self.embedding = nn.Embedding(800, 100) - self.output = nn.Sequential( - nn.Linear(out_d, in_d), - nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(in_d, num_classes), - ) - - def forward(self, data): - x = data.x - x_lens = data.lens - x = self.embedding(x) - x = pack_padded_sequence(x, x_lens, batch_first=True, enforce_sorted=False) - x = self.lstm(x)[1][0] - # = pad_packed_sequence(x, batch_first=True)[0] - x = self.output(x) - return x.squeeze(0) diff --git a/chebai/models/recursive.py b/chebai/models/recursive.py deleted file mode 100644 index fb40803..0000000 --- a/chebai/models/recursive.py 
+++ /dev/null @@ -1,97 +0,0 @@ -import logging - -import networkx as nx -import torch -import torch.nn.functional as F -from torch import exp, nn, tensor - -from chebai.models.base import ChebaiBaseNet - -logging.getLogger("pysmiles").setLevel(logging.CRITICAL) - - -class Recursive(ChebaiBaseNet): - NAME = "REC" - - def __init__(self, in_d, out_d, num_classes, **kwargs): - super().__init__(num_classes, **kwargs) - mem_len = in_d - self.internal_dimension = in_d - self.embedding = nn.Embedding(800, 100) - - self.input_post = nn.Linear(in_d, in_d) - - self.input_attention = nn.MultiheadAttention(in_d, 5) - self.hidden_attention = nn.MultiheadAttention(in_d, 5) - self.merge_attention = nn.MultiheadAttention(in_d, 5) - - self.hidden_post = nn.Linear(in_d, in_d) - - self.merge_post = nn.Linear(in_d, in_d) - - self.post = nn.Linear(in_d, in_d) - - self.children_attention = nn.MultiheadAttention(in_d, 5) - - self.input_norm_1 = nn.LayerNorm(in_d) - self.input_norm_2 = nn.LayerNorm(in_d) - self.hidden_norm_1 = nn.LayerNorm(in_d) - self.merge_norm_1 = nn.LayerNorm(in_d) - self.merge_norm_2 = nn.LayerNorm(in_d) - - self.base = torch.nn.parameter.Parameter(torch.empty((in_d,))) - self.base_memory = torch.nn.parameter.Parameter(torch.empty((mem_len,))) - self.output = nn.Sequential( - nn.Linear(in_d, in_d), - nn.ReLU(), - nn.Dropout(0.2), - nn.Linear(in_d, num_classes), - ) - - def forward(self, batch): - result = [] - for row in batch: - graph = row[0] - c = nx.center(graph)[0] - d = nx.single_source_shortest_path(graph, c) - if graph.edges: - digraph = nx.DiGraph( - (a, b) if d[a] > d[b] else (b, a) for (a, b) in graph.edges - ) - else: - digraph = nx.DiGraph(graph) - child_results = {} - x = None - for node in nx.topological_sort(digraph): - child_values = child_results.pop(node, []) - inp = self.embedding(graph.nodes[node]["x"]) - if not child_values: - hidden_state = self.base_memory - else: - hidden_state = self.merge_childen(child_values, inp) - x = self.input(inp, hidden_state) - for s in digraph.successors(node): - child_results[s] = child_results.get(s, []) + [x] - result.append(self.output(x)) - return torch.stack(result) - - def merge_childen(self, child_values, x): - stack = torch.stack(child_values).unsqueeze(0).transpose(1, 0) - att = self.children_attention( - x.expand(1, stack.shape[1], -1).transpose(1, 0), stack, stack - )[0] - return torch.sum(att.squeeze(0), dim=0) - - def input(self, x0, hidden): - x = x0.unsqueeze(0).unsqueeze(0) - a = self.input_norm_1(x + self.input_attention(x, x, x)[0]) - a = self.input_norm_2(a + F.relu(self.input_post(a))) - - h0 = hidden.unsqueeze(0).unsqueeze(0) - b = self.hidden_norm_1(h0 + self.input_attention(h0, h0, h0)[0]) - # b = self.norm(b + self.hidden_post(b)) - - c = self.merge_norm_1(b + self.merge_attention(a, b, b)[0]) - c = self.merge_norm_2(c + F.relu(self.merge_post(c))) - - return self.post(c).squeeze(0).squeeze(0) diff --git a/chebai/models/strontex.py b/chebai/models/strontex.py deleted file mode 100644 index c22e72c..0000000 --- a/chebai/models/strontex.py +++ /dev/null @@ -1,14 +0,0 @@ -import abc -import typing - -import networkx as nx -import numpy as np -import torch - -FeatureType = typing.TypeVar("FeatureType") -LabelType = typing.TypeVar("LabelType") - - -class StrOntEx(torch.Module): - def __init__(self, computation_graph): - pass diff --git a/chebai/preprocessing/collate.py b/chebai/preprocessing/collate.py deleted file mode 100644 index ecbcb87..0000000 --- a/chebai/preprocessing/collate.py +++ /dev/null @@ -1,137 +0,0 @@ 
-from typing import Dict, List, Tuple, Union - -import torch -from torch.nn.utils.rnn import pad_sequence - -from chebai.preprocessing.structures import XYData - - -class Collator: - """Base class for collating data samples into a batch.""" - - def __init__(self, **kwargs): - pass - - def __call__(self, data: List[Dict]) -> XYData: - """Collate a list of data samples into a batch. - - Args: - data (List[Dict]): List of data samples. - - Returns: - XYData: Batched data. - """ - raise NotImplementedError - - -class DefaultCollator(Collator): - """Default collator that extracts features and labels.""" - - def __call__(self, data: List[Dict]) -> XYData: - """Collate data samples by extracting features and labels. - - Args: - data (List[Dict]): List of data samples. - - Returns: - XYData: Batched data. - """ - x, y = zip(*((d["features"], d["labels"]) for d in data)) - return XYData(x, y) - - -class RaggedCollator(Collator): - """ - Collator for handling ragged data samples, designed to support scenarios where some labels may be missing (None). - - This class is specifically designed for preparing batches of "ragged" data, where the samples may have varying sizes, - such as molecular representations or variable-length protein sequences. Additionally, it supports cases where some - of the data samples might be partially labeled, which is useful for certain loss functions that allow training - with incomplete or fuzzy data (e.g., fuzzy loss). - - During batching, the class pads the data samples to a uniform length, applies appropriate masks to differentiate - between valid and padded elements, and ensures that label misalignment is handled by filtering out unlabelled - data points. The indices of valid labels are stored in the `non_null_labels` field, which can be used later for - metrics computation such as F1-score or MSE, especially in cases where some data points lack labels. - - Reference: https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 - """ - - def __call__(self, data: List[Union[Dict, Tuple]]) -> XYData: - """ - Collate ragged data samples (i.e., samples of unequal size, such as molecular sequences) into a batch. - - Handles both fully and partially labeled data, where some samples may have `None` as their label. The indices - of non-null labels are stored in the `non_null_labels` field, which is used to filter out predictions for - unlabeled data during evaluation (e.g., F1, MSE). For models supporting partially labeled data, this method - ensures alignment between features and labels. - - Args: - data (List[Union[Dict, Tuple]]): List of ragged data samples. Each sample can be a dictionary or tuple - with 'features', 'labels', and 'ident'. - - Returns: - XYData: A batch of padded sequences and labels, including masks for valid positions and indices of - non-null labels for metric computation. - """ - model_kwargs: Dict = dict() - # Indices of non-null labels are stored in key `non_null_labels` of loss_kwargs. 
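[Editor's aside, not part of the patch] A toy illustration, with made-up samples, of what ragged collation with partially missing labels boils down to: pad the features, build a mask from the original lengths, and keep the indices of the samples that actually carry labels.

import torch
from torch.nn.utils.rnn import pad_sequence

samples = [
    {"features": [5, 9, 2], "labels": [1, 0, 1], "ident": "A"},
    {"features": [7, 3], "labels": None, "ident": "B"},        # unlabelled sample
    {"features": [4, 4, 4, 1], "labels": [0, 1, 1], "ident": "C"},
]

x = [torch.tensor(s["features"]) for s in samples]
lens = torch.tensor([len(t) for t in x])
features = pad_sequence(x, batch_first=True)                    # shape (3, 4)
mask = torch.arange(int(lens.max()))[None, :] < lens[:, None]   # valid (non-padded) positions
non_null = [i for i, s in enumerate(samples) if s["labels"] is not None]
labels = torch.tensor([samples[i]["labels"] for i in non_null]) # only the labelled rows

print(features.shape, mask.sum(dim=1).tolist(), non_null)       # torch.Size([3, 4]) [3, 2, 4] [0, 2]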
- loss_kwargs: Dict = dict() - - if isinstance(data[0], tuple): - # For legacy data - x, y, idents = zip(*data) - else: - x, y, idents = zip( - *((d["features"], d["labels"], d.get("ident")) for d in data) - ) - if any(x is not None for x in y): - # If any label is not None: (None, None, `1`, None) - if any(x is None for x in y): - # If any label is None: (`None`, `None`, 1, `None`) - non_null_labels = [i for i, r in enumerate(y) if r is not None] - y = self.process_label_rows( - tuple(ye for i, ye in enumerate(y) if i in non_null_labels) - ) - loss_kwargs["non_null_labels"] = non_null_labels - else: - # If all labels are not None: (`0`, `2`, `1`, `3`) - y = self.process_label_rows(y) - else: - # If all labels are None : (`None`, `None`, `None`, `None`) - y = None - loss_kwargs["non_null_labels"] = [] - - # Calculate the lengths of each sequence, create a binary mask for valid (non-padded) positions - lens = torch.tensor(list(map(len, x))) - model_kwargs["mask"] = torch.arange(max(lens))[None, :] < lens[:, None] - model_kwargs["lens"] = lens - - return XYData( - pad_sequence([torch.tensor(a) for a in x], batch_first=True), - y, - model_kwargs=model_kwargs, - loss_kwargs=loss_kwargs, - idents=idents, - ) - - def process_label_rows(self, labels: Tuple) -> torch.Tensor: - """ - Process label rows by padding sequences to ensure uniform shape across the batch. - - This method pads the label rows, converting sequences of labels of different lengths into a uniform tensor. - It ensures that `None` values in the labels are handled by substituting them with a default value(e.g.,`False`). - - Args: - labels (Tuple): Tuple of label rows. - - Returns: - torch.Tensor: Padded label sequences. - """ - return pad_sequence( - [ - torch.tensor([v if v is not None else False for v in row]) - for row in labels - ], - batch_first=True, - ) diff --git a/chebai/preprocessing/collect_all.py b/chebai/preprocessing/collect_all.py deleted file mode 100644 index 6e24d83..0000000 --- a/chebai/preprocessing/collect_all.py +++ /dev/null @@ -1,225 +0,0 @@ -import logging -import os -import sys - -import pytorch_lightning as pl -import torch -import torch.nn.functional as F -from data import ClassificationData, JCIClassificationData -from pytorch_lightning import loggers as pl_loggers -from pytorch_lightning.callbacks import ModelCheckpoint -from pytorch_lightning.metrics import F1 -from sklearn.metrics import f1_score -from torch import nn -from torch_geometric import nn as tgnn -from torch_geometric.data import DataLoader - -logging.getLogger("pysmiles").setLevel(logging.CRITICAL) - - -class PartOfNet(pl.LightningModule): - def __init__(self, in_length, loops=10): - super().__init__() - self.loops = loops - self.left_graph_net = tgnn.GATConv(in_length, in_length) - self.right_graph_net = tgnn.GATConv(in_length, in_length) - self.attention = nn.Linear(in_length, 1) - self.global_attention = tgnn.GlobalAttention(self.attention) - self.output_net = nn.Sequential( - nn.Linear(2 * in_length, 2 * in_length), - nn.Linear(2 * in_length, in_length), - nn.Linear(in_length, 500), - ) - self.f1 = F1(1, threshold=0.5) - - def _execute(self, batch, batch_idx): - pred = self(batch) - loss = F.binary_cross_entropy_with_logits(pred, batch.label) - f1 = self.f1(batch.label, torch.sigmoid(pred)) - return loss, f1 - - def training_step(self, *args, **kwargs): - loss, f1 = self._execute(*args, **kwargs) - self.log( - "train_loss", - loss.detach().item(), - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) - self.log( - 
"train_f1", - f1.item(), - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return loss - - def validation_step(self, *args, **kwargs): - with torch.no_grad(): - loss, f1 = self._execute(*args, **kwargs) - self.log( - "val_loss", - loss.detach().item(), - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) - self.log( - "val_f1", - f1.item(), - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return loss - - def forward(self, x): - a = self.left_graph_net(x.x_s, x.edge_index_s.long()) - b = self.right_graph_net(x.x_t, x.edge_index_t.long()) - return self.output_net( - torch.cat( - [ - self.global_attention(a, x.x_s_batch), - self.global_attention(b, x.x_t_batch), - ], - dim=1, - ) - ) - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters()) - return optimizer - - -class JCINet(pl.LightningModule): - def __init__(self, in_length, hidden_length, num_classes, loops=10): - super().__init__() - self.loops = loops - - self.node_net = nn.Sequential( - nn.Linear(self.loops * in_length, hidden_length), nn.ReLU() - ) - self.embedding = torch.nn.Embedding(800, in_length) - self.left_graph_net = tgnn.GATConv(in_length, in_length, dropout=0.1) - self.final_graph_net = tgnn.GATConv(in_length, hidden_length, dropout=0.1) - self.attention = nn.Linear(hidden_length, 1) - self.global_attention = tgnn.GlobalAttention(self.attention) - self.output_net = nn.Sequential( - nn.Linear(hidden_length, hidden_length), - nn.Linear(hidden_length, num_classes), - ) - self.f1 = F1(num_classes, threshold=0.5) - - def _execute(self, batch, batch_idx): - pred = self(batch) - labels = batch.label.float() - loss = F.binary_cross_entropy_with_logits(pred, labels) - f1 = f1_score( - labels.cpu() > 0.5, torch.sigmoid(pred).cpu() > 0.5, average="micro" - ) - return loss, f1 - - def training_step(self, *args, **kwargs): - loss, f1 = self._execute(*args, **kwargs) - self.log( - "train_loss", - loss.detach().item(), - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - self.log( - "train_f1", - f1.item(), - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return loss - - def validation_step(self, *args, **kwargs): - with torch.no_grad(): - loss, f1 = self._execute(*args, **kwargs) - self.log( - "val_loss", - loss.detach().item(), - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - self.log( - "val_f1", - f1.item(), - on_step=False, - on_epoch=True, - prog_bar=True, - logger=True, - ) - return loss - - def forward(self, x): - a = self.embedding(x.x) - l = [] - for _ in range(self.loops): - a = self.left_graph_net(a, x.edge_index.long()) - l.append(a) - at = self.global_attention(self.node_net(torch.cat(l, dim=1)), x.x_batch) - return self.output_net(at) - - def configure_optimizers(self): - optimizer = torch.optim.Adam(self.parameters()) - return optimizer - - -def train(train_loader, validation_loader): - if torch.cuda.is_available(): - trainer_kwargs = dict(gpus=-1, accelerator="ddp") - else: - trainer_kwargs = dict(gpus=0) - net = JCINet(100, 100, 500) - tb_logger = pl_loggers.CSVLogger("../../logs/") - checkpoint_callback = ModelCheckpoint( - dirpath=os.path.join(tb_logger.log_dir, "checkpoints"), - filename="{epoch}-{step}-{val_loss:.7f}", - save_top_k=5, - save_last=True, - verbose=True, - monitor="val_loss", - mode="min", - ) - trainer = pl.Trainer( - logger=tb_logger, - callbacks=[checkpoint_callback], - replace_sampler_ddp=False, - **trainer_kwargs - ) - trainer.fit(net, train_loader, 
val_dataloaders=validation_loader) - - -if __name__ == "__main__": - batch_size = int(sys.argv[1]) - # vl = ClassificationData("data/full_chebi", split="validation") - # tr = ClassificationData("data/full_chebi", split="train") - tr = JCIClassificationData("data/JCI_data", split="train") - vl = JCIClassificationData("data/JCI_data", split="validation") - - train_loader = DataLoader( - tr, - shuffle=True, - batch_size=batch_size, - follow_batch=["x", "edge_index", "label"], - ) - validation_loader = DataLoader( - vl, batch_size=batch_size, follow_batch=["x", "edge_index", "label"] - ) - - train(train_loader, validation_loader) diff --git a/chebai/preprocessing/datasets/base.py b/chebai/preprocessing/datasets/base.py deleted file mode 100644 index 3308ec9..0000000 --- a/chebai/preprocessing/datasets/base.py +++ /dev/null @@ -1,1180 +0,0 @@ -import os -import random -from abc import ABC, abstractmethod -from typing import Any, Dict, Generator, List, Optional, Tuple, Union - -import lightning as pl -import networkx as nx -import pandas as pd -import torch -import tqdm -from iterstrat.ml_stratifiers import ( - MultilabelStratifiedKFold, - MultilabelStratifiedShuffleSplit, -) -from lightning.pytorch.core.datamodule import LightningDataModule -from lightning_utilities.core.rank_zero import rank_zero_info -from sklearn.model_selection import StratifiedShuffleSplit -from torch.utils.data import DataLoader - -from chebai.preprocessing import reader as dr - - -class XYBaseDataModule(LightningDataModule): - """ - Base class for data modules. - - This class provides a base implementation for loading and preprocessing datasets. - It inherits from `LightningDataModule` and defines common properties and methods for data loading and processing. - - Args: - batch_size (int): The batch size for data loading. Default is 1. - train_split (float): The ratio of training data to total data and of test data to (validation + test) data. Default is 0.85. - reader_kwargs (dict): Additional keyword arguments to be passed to the data reader. Default is None. - prediction_kind (str): The kind of prediction to be performed (only relevant for the predict_dataloader). Default is "test". - data_limit (Optional[int]): The maximum number of data samples to load. If set to None, the complete dataset will be used. Default is None. - label_filter (Optional[int]): The index of the label to filter. Default is None. - balance_after_filter (Optional[float]): The ratio of negative samples to positive samples after filtering. Default is None. - num_workers (int): The number of worker processes for data loading. Default is 1. - inner_k_folds (int): The number of folds for inner cross-validation. Use -1 to disable inner cross-validation. Default is -1. - fold_index (Optional[int]): The index of the fold to use for training and validation. Default is None. - base_dir (Optional[str]): The base directory for storing processed and raw data. Default is None. - **kwargs: Additional keyword arguments. - - Attributes: - READER (DataReader): The data reader class to use. - reader (DataReader): An instance of the data reader class. - train_split (float): The ratio of training data to total data. - batch_size (int): The batch size for data loading. - prediction_kind (str): The kind of prediction to be performed. - data_limit (Optional[int]): The maximum number of data samples to load. - label_filter (Optional[int]): The index of the label to filter. 
- balance_after_filter (Optional[float]): The ratio of negative samples to positive samples after filtering. - num_workers (int): The number of worker processes for data loading. - inner_k_folds (int): The number of folds for inner cross-validation. If it is less than to, no cross-validation will be performed. - fold_index (Optional[int]): The index of the fold to use for training and validation (only relevant for cross-validation). - _base_dir (Optional[str]): The base directory for storing processed and raw data. - raw_dir (str): The directory for storing raw data. - processed_dir (str): The directory for storing processed data. - fold_dir (str): The name of the directory where the folds from inner cross-validation are stored. - _name (str): The name of the data module. - - """ - - READER = dr.DataReader - - def __init__( - self, - batch_size: int = 1, - train_split: float = 0.85, - reader_kwargs: Optional[dict] = None, - prediction_kind: str = "test", - data_limit: Optional[int] = None, - label_filter: Optional[int] = None, - balance_after_filter: Optional[float] = None, - num_workers: int = 1, - inner_k_folds: int = -1, # use inner cross-validation if > 1 - fold_index: Optional[int] = None, - base_dir: Optional[str] = None, - **kwargs, - ): - super().__init__() - if reader_kwargs is None: - reader_kwargs = dict() - self.reader = self.READER(**reader_kwargs) - self.train_split = train_split - self.batch_size = batch_size - self.prediction_kind = prediction_kind - self.data_limit = data_limit - self.label_filter = label_filter - assert (balance_after_filter is not None) or ( - self.label_filter is None - ), "Filter balancing requires a filter" - self.balance_after_filter = balance_after_filter - self.num_workers = num_workers - assert type(inner_k_folds) is int - self.inner_k_folds = inner_k_folds - self.use_inner_cross_validation = ( - inner_k_folds > 1 - ) # only use cv if there are at least 2 folds - assert ( - fold_index is None or self.use_inner_cross_validation is not None - ), "fold_index can only be set if cross validation is used" - if fold_index is not None and self.inner_k_folds is not None: - assert ( - fold_index < self.inner_k_folds - ), "fold_index can't be larger than the total number of folds" - self.fold_index = fold_index - self._base_dir = base_dir - os.makedirs(self.raw_dir, exist_ok=True) - os.makedirs(self.processed_dir, exist_ok=True) - if self.use_inner_cross_validation: - os.makedirs(os.path.join(self.raw_dir, self.fold_dir), exist_ok=True) - os.makedirs(os.path.join(self.processed_dir, self.fold_dir), exist_ok=True) - self.save_hyperparameters() - - @property - def identifier(self) -> tuple: - """Identifier for the dataset.""" - return (self.reader.name(),) - - @property - def full_identifier(self) -> tuple: - """Full identifier for the dataset.""" - return (self._name, *self.identifier) - - @property - def base_dir(self) -> str: - """Common base directory for processed and raw directories.""" - if self._base_dir is not None: - return self._base_dir - return os.path.join("data", self._name) - - @property - def processed_dir_main(self) -> str: - """Name of the directory where processed (but not tokenized) data is stored.""" - return os.path.join(self.base_dir, "processed") - - @property - def processed_dir(self) -> str: - """Name of the directory where the processed and tokenized data is stored.""" - return os.path.join(self.processed_dir_main, *self.identifier) - - @property - def raw_dir(self) -> str: - """Name of the directory where the raw data is stored.""" 
- return os.path.join(self.base_dir, "raw") - - @property - def fold_dir(self) -> str: - """Name of the directory where the folds from inner cross-validation (i.e., the train and val sets) are stored.""" - return f"cv_{self.inner_k_folds}_fold" - - @property - @abstractmethod - def _name(self) -> str: - """ - Abstract property representing the name of the data module. - - This property should be implemented in subclasses to provide a unique name for the data module. - The name is used to create subdirectories within the base directory or `processed_dir_main` - for storing relevant data associated with this module. - - Returns: - str: The name of the data module. - """ - pass - - def _filter_labels(self, row: dict) -> dict: - """ - Filter labels based on `label_filter`. - This method selects specific labels from the `labels` list within the row dictionary - according to the index or indices provided by the `label_filter` attribute of the class. - - Args: - row (dict): A dictionary containing the row data. - - Returns: - dict: The filtered row data. - """ - row["labels"] = [row["labels"][self.label_filter]] - return row - - def load_processed_data( - self, kind: Optional[str] = None, filename: Optional[str] = None - ) -> List: - """ - Load processed data from a file. Either the kind or the filename has to be provided. If both are provided, the - filename is used. - - Args: - kind (str, optional): The kind of dataset to load such as "train", "val" or "test". Defaults to None. - filename (str, optional): The name of the file to load the dataset from. Defaults to None. - - Returns: - List: The loaded processed data. - - Raises: - ValueError: If both kind and filename are None. - """ - if kind is None and filename is None: - raise ValueError( - "Either kind or filename is required to load the correct dataset, both are None" - ) - # if both kind and filename are given, use filename - if kind is not None and filename is None: - try: - # processed_file_names_dict is only implemented for _ChEBIDataExtractor - if self.use_inner_cross_validation and kind != "test": - filename = self.processed_file_names_dict[ - f"fold_{self.fold_index}_{kind}" - ] - else: - filename = self.processed_file_names_dict[kind] - except NotImplementedError: - filename = f"{kind}.pt" - return torch.load( - os.path.join(self.processed_dir, filename), weights_only=False - ) - - def dataloader(self, kind: str, **kwargs) -> DataLoader: - """ - Returns a DataLoader object for the specified kind (train, val or test) of data. - - Args: - kind (str): The kind indicates whether it is a train, val or test data to load. - **kwargs: Additional keyword arguments. - - Returns: - DataLoader: A DataLoader object. 
- """ - dataset = self.load_processed_data(kind) - if "ids" in kwargs: - ids = kwargs.pop("ids") - _dataset = [] - for i in range(len(dataset)): - if i in ids: - _dataset.append(dataset[i]) - dataset = _dataset - if self.label_filter is not None: - original_len = len(dataset) - dataset = [self._filter_labels(r) for r in dataset] - positives = [r for r in dataset if r["labels"][0]] - negatives = [r for r in dataset if not r["labels"][0]] - if self.balance_after_filter is not None: - negative_length = min( - original_len, int(len(positives) * self.balance_after_filter) - ) - dataset = positives + negatives[:negative_length] - else: - dataset = positives + negatives - random.shuffle(dataset) - if self.data_limit is not None: - dataset = dataset[: self.data_limit] - return DataLoader( - dataset, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, - ) - - @staticmethod - def _load_dict( - input_file_path: str, - ) -> Generator[Dict[str, Any], None, None]: - """ - Load data from a file and return a dictionary. - - Args: - input_file_path (str): The path to the input file. - - Yields: - dict: A dictionary containing the features and labels. - """ - with open(input_file_path, "r") as input_file: - for row in input_file: - smiles, labels = row.split("\t") - yield dict(features=smiles, labels=labels) - - @staticmethod - def _get_data_size(input_file_path: str) -> int: - """ - Get the number of lines in a file. - - Args: - input_file_path (str): The path to the input file. - - Returns: - int: The number of lines in the file. - """ - with open(input_file_path, "r") as f: - return sum(1 for _ in f) - - def _load_data_from_file(self, path: str) -> List[Dict[str, Any]]: - """ - Load data from a file and return a list of dictionaries. - - Args: - path (str): The path to the input file. - - Returns: - List: A list of dictionaries containing the features and labels. - """ - lines = self._get_data_size(path) - print(f"Processing {lines} lines...") - data = [ - self.reader.to_data(d) - for d in tqdm.tqdm(self._load_dict(path), total=lines) - if d["features"] is not None - ] - # filter for missing features in resulting data - data = [val for val in data if val["features"] is not None] - - return data - - def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - """ - Returns the train DataLoader. - - Args: - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - DataLoader: A DataLoader object for training data. - """ - return self.dataloader( - "train", - shuffle=True, - num_workers=self.num_workers, - persistent_workers=True, - **kwargs, - ) - - def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - """ - Returns the validation DataLoader. - - Args: - *args: Additional positional arguments (unused). - **kwargs: Additional keyword arguments, passed to dataloader(). - - Returns: - Union[DataLoader, List[DataLoader]]: A DataLoader object for validation data. - """ - return self.dataloader( - "validation", - shuffle=False, - num_workers=self.num_workers, - persistent_workers=True, - **kwargs, - ) - - def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - """ - Returns the test DataLoader. - - Args: - *args: Additional positional arguments (unused). - **kwargs: Additional keyword arguments, passed to dataloader(). - - Returns: - Union[DataLoader, List[DataLoader]]: A DataLoader object for test data. 
- """ - return self.dataloader("test", shuffle=False, **kwargs) - - def predict_dataloader( - self, *args, **kwargs - ) -> Union[DataLoader, List[DataLoader]]: - """ - Returns the predict DataLoader. - - Args: - *args: Additional positional arguments (unused). - **kwargs: Additional keyword arguments, passed to dataloader(). - - Returns: - Union[DataLoader, List[DataLoader]]: A DataLoader object for prediction data. - """ - return self.dataloader(self.prediction_kind, shuffle=False, **kwargs) - - def setup(self, **kwargs): - """ - Setup the data module. - - This method checks for the processed data and sets up the data module for training, validation, and testing. - - Args: - **kwargs: Additional keyword arguments. - """ - rank_zero_info(f"Check for processed data in {self.processed_dir}") - rank_zero_info(f"Cross-validation enabled: {self.use_inner_cross_validation}") - if any( - not os.path.isfile(os.path.join(self.processed_dir, f)) - for f in self.processed_file_names - ): - self.setup_processed() - - if not ("keep_reader" in kwargs and kwargs["keep_reader"]): - self.reader.on_finish() - - def setup_processed(self): - """ - Setup the processed data. - - This method should be implemented by subclasses to handle the specific setup of processed data. - """ - raise NotImplementedError - - @property - def processed_main_file_names_dict(self) -> dict: - """ - Returns a dictionary mapping processed data file names. - - Returns: - dict: A dictionary mapping dataset key to their respective file names. - For example, {"data": "data.pkl"}. - """ - raise NotImplementedError - - @property - def processed_main_file_names(self) -> List[str]: - """ - Returns a list of file names for processed data (before tokenization). - - Returns: - List[str]: A list of file names corresponding to the processed data. - """ - return list(self.processed_main_file_names_dict.values()) - - @property - def processed_file_names_dict(self) -> dict: - """ - Returns a dictionary for the processed and tokenized data files. - - Returns: - dict: A dictionary mapping dataset keys to their respective file names. - For example, {"data": "data.pt"}. - """ - raise NotImplementedError - - @property - def processed_file_names(self) -> List[str]: - """ - Returns a list of file names for processed data. - - Returns: - List[str]: A list of file names corresponding to the processed data. - """ - return list(self.processed_file_names_dict.values()) - - @property - def raw_file_names(self) -> List[str]: - """ - Returns the list of raw file names. - - Returns: - List[str]: The list of raw file names. - """ - return list(self.raw_file_names_dict.values()) - - @property - def raw_file_names_dict(self) -> dict: - """ - Returns the dictionary of raw file names (i.e., files that are directly obtained from an external source). - - This property should be implemented by subclasses to provide the dictionary of raw file names. - - Returns: - dict: The dictionary of raw file names. - """ - raise NotImplementedError - - @property - def label_number(self) -> int: - """ - Returns the number of labels. - - This property should be implemented by subclasses to provide the number of labels. - - Returns: - int: The number of labels. Returns -1 for seq2seq encoding. - """ - raise NotImplementedError - - -class MergedDataset(XYBaseDataModule): - MERGED = [] - - @property - def _name(self) -> str: - """ - Returns a concatenated name of all subset names. 
- """ - return "+".join(s._name for s in self.subsets) - - def __init__( - self, - batch_size: int = 1, - train_split: float = 0.85, - reader_kwargs: Union[None, List[dict]] = None, - **kwargs, - ): - """ - Args: - batch_size (int): Batch size for data loaders. - train_split (float): Fraction of data to use for training. - reader_kwargs (Union[None, List[dict]]): Optional arguments for subset readers. - **kwargs: Additional arguments to pass to LightningDataModule. - """ - if reader_kwargs is None: - reader_kwargs = [None for _ in self.MERGED] - self.train_split = train_split - self.batch_size = batch_size - self.subsets = [ - s(train_split=train_split, reader_kwargs=kws) - for s, kws in zip(self.MERGED, reader_kwargs) - ] - self.reader = self.subsets[0].reader - os.makedirs(self.processed_dir, exist_ok=True) - super(pl.LightningDataModule, self).__init__(**kwargs) - - def prepare_data(self): - """ - Placeholder for data preparation logic. - """ - for s in self.subsets: - s.prepare_data() - - def setup(self, **kwargs): - """ - Setup the data module. - - This method checks for the processed data and sets up the data module for training, validation, and testing. - - Args: - **kwargs: Additional keyword arguments. - """ - for s in self.subsets: - s.setup(**kwargs) - - def dataloader(self, kind: str, **kwargs) -> DataLoader: - """ - Creates a DataLoader for a specific subset. - - Args: - kind (str): Kind of data loader ('train', 'validation', or 'test'). - **kwargs: Additional arguments passed to DataLoader. - - Returns: - DataLoader: DataLoader object for the specified subset. - """ - subdatasets = [ - torch.load(os.path.join(s.processed_dir, f"{kind}.pt"), weights_only=False) - for s in self.subsets - ] - dataset = [ - self._process_data(i, d) - for i, (s, lim) in enumerate(zip(subdatasets, self.limits)) - for d in (s if lim is None else s[:lim]) - ] - return DataLoader( - dataset, - collate_fn=self.reader.collator, - batch_size=self.batch_size, - **kwargs, - ) - - def train_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - """ - Returns the training DataLoader. - """ - return self.dataloader("train", shuffle=True, **kwargs) - - def val_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - """ - Returns the validation DataLoader. - """ - return self.dataloader("validation", shuffle=False, **kwargs) - - def test_dataloader(self, *args, **kwargs) -> Union[DataLoader, List[DataLoader]]: - """ - Returns the test DataLoader. - """ - return self.dataloader("test", shuffle=False, **kwargs) - - def _process_data(self, subset_id: int, data: dict) -> dict: - """ - Processes data from a subset. - - Args: - subset_id (int): Index of the subset. - data (dict): Data from the subset. - - Returns: - dict: Processed data with 'features', 'labels', and 'ident' keys. - """ - return dict( - features=data["features"], labels=data["labels"], ident=data["ident"] - ) - - def setup_processed(self): - """ - Placeholder for setup logic after data processing. - """ - pass - - @property - def processed_file_names(self) -> List[str]: - """ - Returns the list of processed file names. - """ - return ["test.pt", "train.pt", "validation.pt"] - - @property - def label_number(self) -> int: - """ - Returns the number of labels from the first subset. - """ - return self.subsets[0].label_number - - @property - def limits(self): - """ - Returns None, assuming no limits on data slicing. 
- """ - return None - - -class _DynamicDataset(XYBaseDataModule, ABC): - """ - A class for extracting and processing data from the given dataset. - - The processed and transformed data is stored in `data.pkl` and `data.pt` format as a whole respectively, - rather than as separate train, validation, and test splits, with dynamic splitting of data.pt occurring at runtime. - The `_DynamicDataset` class manages data splits by either generating them during execution or retrieving them from - a CSV file. - If no split file path is provided, `_generate_dynamic_splits` creates the training, validation, and test splits - from the encoded/transformed data, storing them in `_dynamic_df_train`, `_dynamic_df_val`, and `_dynamic_df_test`. - When a split file path is provided, `_retrieve_splits_from_csv` loads splits from the CSV file, which must - include 'id' and 'split' columns. - The `dynamic_split_dfs` property ensures that the necessary splits are loaded as required. - - Args: - dynamic_data_split_seed (int, optional): The seed for random data splitting. Defaults to 42. - splits_file_path (str, optional): Path to the splits CSV file. Defaults to None. - **kwargs: Additional keyword arguments passed to XYBaseDataModule. - - Attributes: - dynamic_data_split_seed (int): The seed for random data splitting, default is 42. - splits_file_path (Optional[str]): Path to the CSV file containing split assignments. - """ - - # ---- Index for columns of processed `data.pkl` (should be derived from `_graph_to_raw_dataset` method) ------ - _ID_IDX: int = None - _DATA_REPRESENTATION_IDX: int = None - _LABELS_START_IDX: int = None - - def __init__( - self, - **kwargs, - ): - super(_DynamicDataset, self).__init__(**kwargs) - self.dynamic_data_split_seed = int(kwargs.get("seed", 42)) # default is 42 - # Class variables to store the dynamics splits - self._dynamic_df_train = None - self._dynamic_df_test = None - self._dynamic_df_val = None - # Path of csv file which contains a list of ids & their assignment to a dataset (either train, - # validation or test). - self.splits_file_path = self._validate_splits_file_path( - kwargs.get("splits_file_path", None) - ) - - @staticmethod - def _validate_splits_file_path(splits_file_path: Optional[str]) -> Optional[str]: - """ - Validates the file in provided splits file path. - - Args: - splits_file_path (Optional[str]): Path to the splits CSV file. - - Returns: - Optional[str]: Validated splits file path if checks pass, None if splits_file_path is None. - - Raises: - FileNotFoundError: If the splits file does not exist. - ValueError: If splits file is empty or missing required columns ('id' and/or 'split'), or not a CSV file. - """ - if splits_file_path is None: - return None - - if not os.path.isfile(splits_file_path): - raise FileNotFoundError(f"File {splits_file_path} does not exist") - - file_size = os.path.getsize(splits_file_path) - if file_size == 0: - raise ValueError(f"File {splits_file_path} is empty") - - # Check if the file has a CSV extension - if not splits_file_path.lower().endswith(".csv"): - raise ValueError(f"File {splits_file_path} is not a CSV file") - - # Read the first row of CSV file into a DataFrame - splits_df = pd.read_csv(splits_file_path, nrows=1) - - # Check if 'id' and 'split' columns are in the DataFrame - required_columns = {"id", "split"} - if not required_columns.issubset(splits_df.columns): - raise ValueError( - f"CSV file {splits_file_path} is missing required columns ('id' and/or 'split')." 
- ) - - return splits_file_path - - # ------------------------------ Phase: Prepare data ----------------------------------- - def prepare_data(self, *args: Any, **kwargs: Any) -> None: - """ - Prepares the data for the dataset. - - This method checks for the presence of raw data in the specified directory. - If the raw data is missing, it fetches the ontology and creates a dataframe and saves it to a data.pkl file. - - The resulting dataframe/pickle file is expected to contain columns with the following structure: - - Column at index `self._ID_IDX`: ID of data instance - - Column at index `self._DATA_REPRESENTATION_IDX`: Sequence representation of the protein - - Column from index `self._LABELS_START_IDX` onwards: Labels - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. - - Returns: - None - """ - print("Checking for processed data in", self.processed_dir_main) - - processed_name = self.processed_main_file_names_dict["data"] - if not os.path.isfile(os.path.join(self.processed_dir_main, processed_name)): - print(f"Missing processed data file (`{processed_name}` file)") - os.makedirs(self.processed_dir_main, exist_ok=True) - data_path = self._download_required_data() - g = self._extract_class_hierarchy(data_path) - data_df = self._graph_to_raw_dataset(g) - self.save_processed(data_df, processed_name) - - @abstractmethod - def _download_required_data(self) -> str: - """ - Downloads the required raw data. - - Returns: - str: Path to the downloaded data. - """ - pass - - @abstractmethod - def _extract_class_hierarchy(self, data_path: str) -> nx.DiGraph: - """ - Extracts the class hierarchy from the data. - Constructs a directed graph (DiGraph) using NetworkX, where nodes are annotated with fields/terms from - the term documents. - - Args: - data_path (str): Path to the data. - - Returns: - nx.DiGraph: The class hierarchy graph. - """ - pass - - @abstractmethod - def _graph_to_raw_dataset(self, graph: nx.DiGraph) -> pd.DataFrame: - """ - Converts the graph to a raw dataset. - Uses the graph created by `_extract_class_hierarchy` method to extract the - raw data in Dataframe format with additional columns corresponding to each multi-label class. - - Args: - graph (nx.DiGraph): The class hierarchy graph. - - Returns: - pd.DataFrame: The raw dataset. - """ - pass - - @abstractmethod - def select_classes(self, g: nx.DiGraph, *args, **kwargs) -> List: - """ - Selects classes from the dataset based on a specified criteria. - - Args: - g (nx.Graph): The graph representing the dataset. - *args: Additional positional arguments. - **kwargs: Additional keyword arguments. - - Returns: - List: A sorted list of node IDs that meet the specified criteria. - """ - pass - - def save_processed(self, data: pd.DataFrame, filename: str) -> None: - """ - Save the processed dataset to a pickle file. - - Args: - data (pd.DataFrame): The processed dataset to be saved. - filename (str): The filename for the pickle file. - """ - pd.to_pickle(data, open(os.path.join(self.processed_dir_main, filename), "wb")) - - # ------------------------------ Phase: Setup data ----------------------------------- - def setup_processed(self) -> None: - """ - Transforms `data.pkl` into a model input data format (`data.pt`), ensuring that the data is in a format - compatible for input to the model. - The transformed data contains the following keys: `ident`, `features`, `labels`, and `group`. - This method uses a subclass of Data Reader to perform the transformation. 
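[Editor's aside, not part of the patch] An illustrative sketch, with made-up identifiers, of the split-assignment CSV that the validation earlier in this class expects: an `id` column plus a `split` column holding train / validation / test, checked by reading only the header row.

import io
import pandas as pd

csv_text = "id,split\nP12345,train\nP67890,train\nQ11111,validation\nQ22222,test\n"
first_row = pd.read_csv(io.StringIO(csv_text), nrows=1)
assert {"id", "split"}.issubset(first_row.columns), "missing required columns ('id' and/or 'split')"
print(first_row)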
- - Returns: - None - """ - os.makedirs(self.processed_dir, exist_ok=True) - transformed_file_name = self.processed_file_names_dict["data"] - print( - f"Missing transformed data (`{transformed_file_name}` file). Transforming data.... " - ) - torch.save( - self._load_data_from_file( - os.path.join( - self.processed_dir_main, - self.processed_main_file_names_dict["data"], - ) - ), - os.path.join(self.processed_dir, transformed_file_name), - ) - - @staticmethod - def _get_data_size(input_file_path: str) -> int: - """ - Get the size of the data from a pickled file. - - Args: - input_file_path (str): The path to the file. - - Returns: - int: The size of the data. - """ - with open(input_file_path, "rb") as f: - return len(pd.read_pickle(f)) - - @abstractmethod - def _load_dict(self, input_file_path: str) -> Generator[Dict[str, Any], None, None]: - """ - Loads data from given pickled file and yields individual dictionaries for each row. - - This method is used by `_load_data_from_file` to generate dictionaries that are then - processed and converted into a list of dictionaries containing the features and labels. - - Args: - input_file_path (str): The path to the pickled input file. - - Yields: - Generator[Dict[str, Any], None, None]: Generator yielding dictionaries. - - """ - pass - - # ------------------------------ Phase: Dynamic Splits ----------------------------------- - @property - def dynamic_split_dfs(self) -> Dict[str, pd.DataFrame]: - """ - Property to retrieve dynamic train, validation, and test splits. - - This property checks if dynamic data splits (`_dynamic_df_train`, `_dynamic_df_val`, `_dynamic_df_test`) - are already loaded. If any of them is None, it either generates them dynamically or retrieves them - from data file with help of pre-existing split csv file (`splits_file_path`) containing splits assignments. - - Returns: - dict: A dictionary containing the dynamic train, validation, and test DataFrames. - Keys are 'train', 'validation', and 'test'. - """ - if any( - split is None - for split in [ - self._dynamic_df_test, - self._dynamic_df_val, - self._dynamic_df_train, - ] - ): - if self.splits_file_path is None: - # Generate splits based on given seed, create csv file to records the splits - self._generate_dynamic_splits() - else: - # If user has provided splits file path, use it to get the splits from the data - self._retrieve_splits_from_csv() - return { - "train": self._dynamic_df_train, - "validation": self._dynamic_df_val, - "test": self._dynamic_df_test, - } - - def _generate_dynamic_splits(self) -> None: - """ - Generate data splits during runtime and save them in class variables. - - This method loads encoded data and generates train, validation, and test splits based on the loaded data. 
- """ - print("\nGenerate dynamic splits...") - df_train, df_val, df_test = self._get_data_splits() - - # Generate splits.csv file to store ids of each corresponding split - split_assignment_list: List[pd.DataFrame] = [ - pd.DataFrame({"id": df_train["ident"], "split": "train"}), - pd.DataFrame({"id": df_val["ident"], "split": "validation"}), - pd.DataFrame({"id": df_test["ident"], "split": "test"}), - ] - - combined_split_assignment = pd.concat(split_assignment_list, ignore_index=True) - combined_split_assignment.to_csv( - os.path.join(self.processed_dir_main, "splits.csv"), index=False - ) - - # Store the splits in class variables - self._dynamic_df_train = df_train - self._dynamic_df_val = df_val - self._dynamic_df_test = df_test - - @abstractmethod - def _get_data_splits(self) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]: - """ - Retrieve the train, validation, and test data splits for the dataset. - - This method returns data splits according to specific criteria implemented - in the subclasses. - - Returns: - tuple: A tuple containing DataFrames for train, validation, and test splits. - """ - pass - - def get_test_split( - self, df: pd.DataFrame, seed: Optional[int] = None - ) -> Tuple[pd.DataFrame, pd.DataFrame]: - """ - Split the input DataFrame into training and testing sets based on multilabel stratified sampling. - - This method uses MultilabelStratifiedShuffleSplit to split the data such that the distribution of labels - in the training and testing sets is approximately the same. The split is based on the "labels" column - in the DataFrame. - - Args: - df (pd.DataFrame): The input DataFrame containing the data to be split. It must contain a column - named "labels" with the multilabel data. - seed (int, optional): The random seed to be used for reproducibility. Default is None. - - Returns: - Tuple[pd.DataFrame, pd.DataFrame]: A tuple containing the training set and testing set DataFrames. - - Raises: - ValueError: If the DataFrame does not contain a column named "labels". - """ - print("Get test data split") - - labels_list = df["labels"].tolist() - - test_size = 1 - self.train_split - (1 - self.train_split) ** 2 - - if len(labels_list[0]) > 1: - splitter = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) - else: - splitter = StratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) - - train_indices, test_indices = next(splitter.split(labels_list, labels_list)) - - df_train = df.iloc[train_indices] - df_test = df.iloc[test_indices] - return df_train, df_test - - def get_train_val_splits_given_test( - self, df: pd.DataFrame, test_df: pd.DataFrame, seed: int = None - ) -> Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: - """ - Split the dataset into train and validation sets, given a test set. - Use test set (e.g., loaded from another source or generated in get_test_split), to avoid overlap - - Args: - df (pd.DataFrame): The original dataset. - test_df (pd.DataFrame): The test dataset. - seed (int, optional): The random seed to be used for reproducibility. Default is None. - - Returns: - Union[Dict[str, pd.DataFrame], Tuple[pd.DataFrame, pd.DataFrame]]: A dictionary containing train and - validation sets if self.use_inner_cross_validation is True, otherwise a tuple containing the train - and validation DataFrames. The keys are the names of the train and validation sets, and the values - are the corresponding DataFrames. 
- """ - print(f"Split dataset into train / val with given test set") - - test_ids = test_df["ident"].tolist() - df_trainval = df[~df["ident"].isin(test_ids)] - labels_list_trainval = df_trainval["labels"].tolist() - - if self.use_inner_cross_validation: - folds = {} - kfold = MultilabelStratifiedKFold( - n_splits=self.inner_k_folds, random_state=seed - ) - for fold, (train_ids, val_ids) in enumerate( - kfold.split( - labels_list_trainval, - labels_list_trainval, - ) - ): - df_validation = df_trainval.iloc[val_ids] - df_train = df_trainval.iloc[train_ids] - folds[self.raw_file_names_dict[f"fold_{fold}_train"]] = df_train - folds[self.raw_file_names_dict[f"fold_{fold}_validation"]] = ( - df_validation - ) - - return folds - - # scale val set size by 1/self.train_split to compensate for (hypothetical) test set size (1-self.train_split) - test_size = ((1 - self.train_split) ** 2) / self.train_split - - if len(labels_list_trainval[0]) > 1: - splitter = MultilabelStratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) - else: - splitter = StratifiedShuffleSplit( - n_splits=1, test_size=test_size, random_state=seed - ) - - train_indices, validation_indices = next( - splitter.split(labels_list_trainval, labels_list_trainval) - ) - - df_validation = df_trainval.iloc[validation_indices] - df_train = df_trainval.iloc[train_indices] - return df_train, df_validation - - def _retrieve_splits_from_csv(self) -> None: - """ - Retrieve previously saved data splits from splits.csv file or from provided file path. - - This method loads the splits.csv file located at `self.splits_file_path`. - It then loads the encoded data (`data.pt`) and filters it based on the IDs retrieved from - splits.csv to reconstruct the train, validation, and test splits. - """ - print(f"\nLoading splits from {self.splits_file_path}...") - splits_df = pd.read_csv(self.splits_file_path) - - filename = self.processed_file_names_dict["data"] - data = self.load_processed_data(filename=filename) - df_data = pd.DataFrame(data) - - train_ids = splits_df[splits_df["split"] == "train"]["id"] - validation_ids = splits_df[splits_df["split"] == "validation"]["id"] - test_ids = splits_df[splits_df["split"] == "test"]["id"] - - self._dynamic_df_train = df_data[df_data["ident"].isin(train_ids)] - self._dynamic_df_val = df_data[df_data["ident"].isin(validation_ids)] - self._dynamic_df_test = df_data[df_data["ident"].isin(test_ids)] - - def load_processed_data( - self, kind: Optional[str] = None, filename: Optional[str] = None - ) -> List[Dict[str, Any]]: - """ - Loads processed data from a specified dataset type or file. - - This method retrieves processed data based on the dataset type (`kind`) such as "train", - "val", or "test", or directly from a provided filename. When `kind` is specified, the method - leverages the `dynamic_split_dfs` property to dynamically generate or retrieve the corresponding - data splits if they are not already loaded. If both `kind` and `filename` are provided, `filename` - takes precedence. - - Args: - kind (str, optional): The type of dataset to load ("train", "val", or "test"). - If `filename` is provided, this argument is ignored. Defaults to None. - filename (str, optional): The name of the file to load the dataset from. - If provided, this takes precedence over `kind`. Defaults to None. - - Returns: - List[Dict[str, Any]]: A list of dictionaries, where each dictionary contains - the processed data for an individual data point. 
- - Raises: - ValueError: If both `kind` and `filename` are None, as one of them is required to load the dataset. - KeyError: If the specified `kind` does not exist in the `dynamic_split_dfs` property or - `processed_file_names_dict`, when expected. - FileNotFoundError: If the file corresponding to the provided `filename` does not exist. - """ - if kind is None and filename is None: - raise ValueError( - "Either kind or filename is required to load the correct dataset, both are None" - ) - - # If both kind and filename are given, use filename - if kind is not None and filename is None: - try: - if self.use_inner_cross_validation and kind != "test": - filename = self.processed_file_names_dict[ - f"fold_{self.fold_index}_{kind}" - ] - else: - data_df = self.dynamic_split_dfs[kind] - return data_df.to_dict(orient="records") - except KeyError: - kind = f"{kind}" - - # If filename is provided - try: - return torch.load( - os.path.join(self.processed_dir, filename), weights_only=False - ) - except FileNotFoundError: - raise FileNotFoundError(f"File {filename} doesn't exist") - - # ------------------------------ Phase: Raw Properties ----------------------------------- - @property - @abstractmethod - def base_dir(self) -> str: - """ - Returns the base directory path for storing data. - - Returns: - str: The path to the base directory. - """ - pass - - @property - def processed_dir_main(self) -> str: - """ - Returns the main directory path where processed data is stored. - - Returns: - str: The path to the main processed data directory, based on the base directory and the instance's name. - """ - return os.path.join( - self.base_dir, - self._name, - "processed", - ) - - @property - def processed_main_file_names_dict(self) -> dict: - """ - Returns a dictionary mapping processed data file names. - - Returns: - dict: A dictionary mapping dataset key to their respective file names. - For example, {"data": "data.pkl"}. - """ - return {"data": "data.pkl"} - - @property - def raw_file_names(self) -> List[str]: - """ - Returns a list of raw file names. - - Returns: - List[str]: A list of file names corresponding to the raw data. - """ - return list(self.raw_file_names_dict.values()) - - @property - def processed_file_names_dict(self) -> dict: - """ - Returns a dictionary for the processed and tokenized data files. - - Returns: - dict: A dictionary mapping dataset keys to their respective file names. - For example, {"data": "data.pt"}. - """ - return {"data": "data.pt"} diff --git a/chebai/preprocessing/structures.py b/chebai/preprocessing/structures.py deleted file mode 100644 index 1fb3711..0000000 --- a/chebai/preprocessing/structures.py +++ /dev/null @@ -1,141 +0,0 @@ -from typing import Any, Tuple, Union - -import networkx as nx -import torch - - -class XYData(torch.utils.data.Dataset): - """ - A dataset class for handling pairs of data (x, y). - - Args: - x: Input data. - y: Target data. - kwargs: Additional fields to store in the dataset. - """ - - def __init__( - self, x: Union[torch.Tensor, Tuple[Any, ...]], y: torch.Tensor, **kwargs - ): - super().__init__() - self.additional_fields = kwargs - self.x = x - self.y = y - - def __getitem__(self, index: int): - """Returns the data and target at the given index.""" - return self.x[index], self.y[index] - - def __len__(self) -> int: - """Returns the size of the dataset.""" - return len(self.x) - - def to_x(self, device: torch.device) -> Union[torch.Tensor, Tuple[Any, ...]]: - """ - Moves the input data to the specified device. 
- - Args: - device: The device to move the data to. - - Returns: - The input data on the specified device. - """ - if isinstance(self.x, tuple): - res = [] - for elem in self.x: - if isinstance(elem, dict): - for k, v in elem.items(): - elem[k] = v.to(device) if v is not None else None - else: - elem = elem.to(device) - res.append(elem) - return tuple(res) - return self.x.to(device) - - def to_y(self, device: torch.device) -> torch.Tensor: - """ - Moves the target data to the specified device. - - Args: - device: The device to move the data to. - - Returns: - The target data on the specified device. - """ - return self.y.to(device) - - def _to_if_tensor(self, obj: Any, device: torch.device) -> Any: - """ - Recursively moves the object to the specified device if it is a tensor. - - Args: - obj: The object to move. - device: The device to move the object to. - - Returns: - The object on the specified device. - """ - if isinstance(obj, torch.Tensor): - return obj.to(device) - elif isinstance(obj, dict): - return {k: self._to_if_tensor(v, device) for k, v in obj.items()} - elif isinstance(obj, list): - return [self._to_if_tensor(v, device) for v in obj] - else: - return obj - - def to(self, device: torch.device) -> "XYData": - """ - Moves the dataset to the specified device. - - Args: - device: The device to move the dataset to. - - Returns: - A new dataset on the specified device. - """ - x = self.to_x(device) - if self.y is not None: - y = self.to_y(device) - else: - y = None - return XYData( - x, - y, - **{ - k: self._to_if_tensor(v, device) - for k, v in self.additional_fields.items() - }, - ) - - -class XYMolData(XYData): - """ - A dataset class for handling molecular data represented as NetworkX graphs. - - Args: - x: Input molecular graphs. - y: Target data. - kwargs: Additional fields to store in the dataset. - """ - - def to_x(self, device: torch.device) -> Tuple[nx.Graph, ...]: - """ - Moves the node attributes of the molecular graphs to the specified device. - - Args: - device: The device to move the data to. - - Returns: - A tuple of molecular graphs with node attributes on the specified device. 
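For illustration (not part of the original file): `XYData` is a thin `torch.utils.data.Dataset` wrapper, and its `to()` returns a copy with every tensor, including tensors nested inside the keyword-argument `additional_fields`, moved to the target device. A minimal usage sketch with made-up shapes:

    import torch

    batch = XYData(
        x=torch.randint(0, 20, (4, 128)),           # e.g. token ids
        y=torch.randint(0, 2, (4, 997)).float(),    # multi-label targets
        model_kwargs={"attention_mask": torch.ones(4, 128)},
    )
    moved = batch.to(torch.device("cuda:0"))        # x, y and additional_fields all follow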
- """ - l = [] - for g in self.x: - graph = g.copy() - nx.set_node_attributes( - graph, - {k: v.to(device) for k, v in nx.get_node_attributes(g, "x").items()}, - "x", - ) - l.append(graph) - return tuple(l) diff --git a/chebai/result/__init__.py b/chebai/result/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/chebai/result/analyse_sem.py b/chebai/result/analyse_sem.py deleted file mode 100644 index 51a1fb2..0000000 --- a/chebai/result/analyse_sem.py +++ /dev/null @@ -1,721 +0,0 @@ -import gc -import sys -import traceback -from datetime import datetime -from typing import List, LiteralString - -from torchmetrics.functional.classification import ( - multilabel_auroc, - multilabel_average_precision, - multilabel_f1_score, -) -from utils import * - -from chebai.loss.semantic import DisjointLoss -from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.datasets.chebi import ChEBIOver100 -from chebai.preprocessing.datasets.pubchem import PubChemKMeans - -DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu") - - -def binary(left, right): - return torch.logical_and(left > 0.5, right > 0.5) - - -def strict(left, right): - return left + right > 1 - - -def weak(left, right): - return left + right > 1.01 - - -def product(left, right): - return left * right - - -def lukasiewicz(left, right): - return torch.relu(left + right - 1) - - -def apply_metric(metric, left, right): - return torch.sum(metric(left, right), dim=0) - - -ALL_CONSISTENCY_METRICS = [product, lukasiewicz, weak, strict, binary] - - -def _filter_to_one_hot(preds, idx_filter): - """Takes list of indices (e.g. [1, 3, 0]) and returns a one-hot filter with these indices - (e.g. [[0,1,0,0], [0,0,0,1], [1,0,0,0]])""" - res = torch.zeros((len(idx_filter), preds.shape[1]), dtype=torch.bool) - for i, idx in enumerate(idx_filter): - res[i][idx] = True - return res - - -def _sort_results_by_label(n_labels, results, filter): - by_label = torch.zeros(n_labels, device=DEVICE, dtype=torch.int) - for r, filter_l in zip(results, filter): - by_label[filter_l] += r - return by_label - - -def get_best_epoch(run): - files = run.files() - best_ep = None - best_micro_f1 = 0 - for file in files: - if file.name.startswith("checkpoints/best_epoch"): - micro_f1 = float(file.name.split("=")[-1][:-5]) - if micro_f1 > best_micro_f1 or best_ep is None: - best_ep = int(file.name.split("=")[1].split("_")[0]) - best_micro_f1 = micro_f1 - if best_ep is None: - raise Exception(f"Could not find any 'best' checkpoint for run {run.id}") - else: - print(f"Best epoch for run {run.id}: {best_ep}") - return best_ep - - -def download_model_from_wandb( - run_id, base_dir=os.path.join("logs", "downloaded_ckpts") -): - api = wandb.Api() - run = api.run(f"chebai/chebai/{run_id}") - epoch = get_best_epoch(run) - return ( - get_checkpoint_from_wandb(epoch, run, root=base_dir), - epoch, - ) - - -def load_preds_labels( - ckpt_path: LiteralString, data_module, data_subset_key="test", buffer_dir=None -): - if buffer_dir is None: - buffer_dir = os.path.join( - "results_buffer", - *ckpt_path.split(os.path.sep)[-2:], - f"{data_module.__class__.__name__}_{data_subset_key}", - ) - model = Electra.load_from_checkpoint(ckpt_path, map_location="cuda:0", strict=False) - print( - f"Calculating predictions on {data_module.__class__.__name__} ({data_subset_key})..." 
- ) - evaluate_model( - model, - data_module, - buffer_dir=buffer_dir, - # for chebi, use kinds, otherwise use file names - filename=( - data_subset_key if not isinstance(buffer_dir, _DynamicDataset) else None - ), - kind=data_subset_key, - skip_existing_preds=True, - batch_size=1, - ) - return load_results_from_buffer(buffer_dir, device=torch.device("cpu")) - - -def get_label_names(data_module): - if os.path.exists(os.path.join(data_module.processed_dir_main, "classes.txt")): - with open(os.path.join(data_module.processed_dir_main, "classes.txt")) as fin: - return [int(line.strip()) for line in fin] - print( - f"Failed to retrieve label names, {os.path.join(data_module.processed_dir_main, 'classes.txt')} not found" - ) - return None - - -def get_chebi_graph(data_module, label_names): - if os.path.exists(os.path.join(data_module.raw_dir, "chebi.obo")): - chebi_graph = data_module.extract_class_hierarchy( - os.path.join(data_module.raw_dir, "chebi.obo") - ) - return chebi_graph.subgraph(label_names) - print( - f"Failed to retrieve ChEBI graph, {os.path.join(data_module.raw_dir, 'chebi.obo')} not found" - ) - return None - - -def get_disjoint_groups(): - disjoints_owl_file = os.path.join("data", "chebi-disjoints.owl") - with open(disjoints_owl_file, "r") as f: - plaintext = f.read() - segments = plaintext.split("<") - disjoint_pairs = [] - left = None - for seg in segments: - if seg.startswith("rdf:Description ") or seg.startswith("owl:Class"): - left = int(seg.split('rdf:about="&obo;CHEBI_')[1].split('"')[0]) - elif seg.startswith("owl:disjointWith"): - right = int(seg.split('rdf:resource="&obo;CHEBI_')[1].split('"')[0]) - disjoint_pairs.append([left, right]) - - disjoint_groups = [] - for seg in plaintext.split(""): - if "owl;AllDisjointClasses" in seg: - classes = seg.split('rdf:about="&obo;CHEBI_')[1:] - classes = [int(c.split('"')[0]) for c in classes] - disjoint_groups.append(classes) - disjoint_all = disjoint_pairs + disjoint_groups - # one disjointness is commented out in the owl-file - # (the correct way would be to parse the owl file and notice the comment symbols, but for this case, it should work) - disjoint_all.remove([22729, 51880]) - print(f"Found {len(disjoint_all)} disjoint groups") - return disjoint_all - - -class PredictionSmoother: - """Removes implication and disjointness violations from predictions""" - - def __init__(self, dataset): - self.label_names = get_label_names(dataset) - self.chebi_graph = get_chebi_graph(dataset, self.label_names) - self.disjoint_groups = get_disjoint_groups() - - def __call__(self, preds): - - preds_sum_orig = torch.sum(preds) - print(f"Preds sum: {preds_sum_orig}") - # eliminate implication violations by setting each prediction to maximum of its successors - for i, label in enumerate(self.label_names): - succs = [ - self.label_names.index(p) for p in self.chebi_graph.successors(label) - ] + [i] - if len(succs) > 0: - preds[:, i] = torch.max(preds[:, succs], dim=1).values - print(f"Preds change (step 1): {torch.sum(preds) - preds_sum_orig}") - preds_sum_orig = torch.sum(preds) - # step 2: eliminate disjointness violations: for group of disjoint classes, set all except max to 0.49 (if it is not already lower) - preds_bounded = torch.min(preds, torch.ones_like(preds) * 0.49) - for disj_group in self.disjoint_groups: - disj_group = [ - self.label_names.index(g) for g in disj_group if g in self.label_names - ] - if len(disj_group) > 1: - old_preds = preds[:, disj_group] - disj_max = torch.max(preds[:, disj_group], dim=1) - for i, row in 
enumerate(preds): - for l in range(len(preds[i])): - if l in disj_group and l != disj_group[disj_max.indices[i]]: - preds[i, l] = preds_bounded[i, l] - samples_changed = 0 - for i, row in enumerate(preds[:, disj_group]): - if any(r != o for r, o in zip(row, old_preds[i])): - samples_changed += 1 - if samples_changed != 0: - print( - f"disjointness group {[self.label_names[d] for d in disj_group]} changed {samples_changed} samples" - ) - print( - f"Preds change after disjointness (step 2): {torch.sum(preds) - preds_sum_orig}" - ) - preds_sum_orig = torch.sum(preds) - # step 3: disjointness violation removal may have caused new implication inconsistencies -> set each prediction to min of predecessors - for i, label in enumerate(self.label_names): - predecessors = [i] + [ - self.label_names.index(p) for p in self.chebi_graph.predecessors(label) - ] - lowest_predecessors = torch.min(preds[:, predecessors], dim=1) - preds[:, i] = lowest_predecessors.values - for idx_idx, idx in enumerate(lowest_predecessors.indices): - if idx > 0: - print( - f"class {label}: changed prediction of sample {idx_idx} to value of class " - f"{self.label_names[predecessors[idx]]} ({preds[idx_idx, i].item():.2f})" - ) - if torch.sum(preds) != preds_sum_orig: - print( - f"Preds change (step 3) for {label}: {torch.sum(preds) - preds_sum_orig}" - ) - preds_sum_orig = torch.sum(preds) - return preds - - -def _filter_to_dense(filter): - filter_dense = [] - for i in range(filter.shape[0]): - for j in range(filter.shape[1]): - if filter[i, j] > 0: - filter_dense.append([i, j]) - return torch.tensor(filter_dense) - - -def build_prediction_filter(data_module_labeled=None): - if data_module_labeled is None: - data_module_labeled = ChEBIOver100(chebi_version=231) - # prepare filters - print(f"Loading implication / disjointness filters...") - dl = DisjointLoss( - path_to_disjointness=os.path.join("data", "disjoint.csv"), - data_extractor=data_module_labeled, - ) - impl = _filter_to_dense(dl.implication_filter_l) - disj = _filter_to_dense(dl.disjoint_filter_l) - - return [ - (impl[:, 0], impl[:, 1], "impl"), - (disj[:, 0], disj[:, 1], "disj"), - ] - - -def run_consistency_metrics( - preds, - consistency_filters, - data_module_labeled=None, # use labels from this dataset for violations - violation_metrics=None, - verbose_violation_output=False, - save_details_to=None, -): - """Calculates all semantic metrics for given predictions (and supervised metrics if labels are provided)""" - if violation_metrics is None: - violation_metrics = ALL_CONSISTENCY_METRICS - if data_module_labeled is None: - data_module_labeled = ChEBIOver100(chebi_version=231) - if save_details_to is not None: - os.makedirs(save_details_to, exist_ok=True) - - preds.to("cpu") - - n_labels = preds.size(1) - print(f"Found {preds.shape[0]} predictions ({n_labels} classes)") - - results = {} - - for dl_filter_l, dl_filter_r, filter_type in consistency_filters: - l_preds = preds[:, dl_filter_l] - r_preds = preds[:, dl_filter_r] - for i, metric in enumerate(violation_metrics): - if metric.__name__ not in results: - results[metric.__name__] = {} - print(f"Calculating metrics {metric.__name__} on {filter_type}") - - metric_results = {} - metric_results["tps"] = torch.sum( - torch.stack( - [ - apply_metric( - metric, - l_preds[i : i + 1000], - ( - r_preds[i : i + 1000] - if filter_type == "impl" - else 1 - r_preds[i : i + 1000] - ), - ) - for i in range(0, r_preds.shape[0], 1000) - ] - ), - dim=0, - ) - metric_results["fns"] = torch.sum( - torch.stack( - [ - apply_metric( - 
metric, - l_preds[i : i + 1000], - ( - 1 - r_preds[i : i + 1000] - if filter_type == "impl" - else r_preds[i : i + 1000] - ), - ) - for i in range(0, r_preds.shape[0], 1000) - ] - ), - dim=0, - ) - if verbose_violation_output: - label_names = get_label_names(data_module_labeled) - print( - f"Found {torch.sum(metric_results['fns'])} {filter_type}-violations" - ) - # for k, fn_cls in enumerate(metric_results['fns']): - # if fn_cls > 0: - # print(f"\tThereof, {fn_cls.item()} belong to class {label_names[k]}") - if torch.sum(metric_results["fns"]) != 0: - fns = metric( - l_preds, 1 - r_preds if filter_type == "impl" else r_preds - ) - print(fns.shape) - for k, row in enumerate(fns): - if torch.sum(row) != 0: - print(f"{torch.sum(row)} violations for entity {k}") - for j, violation in enumerate(row): - if violation > 0: - print( - f"\tviolated ({label_names[dl_filter_l[j]]} -> {preds[k, dl_filter_l[j]]:.3f}" - f", {label_names[dl_filter_r[j]]} -> {preds[k, dl_filter_r[j]]:.3f})" - ) - - m_l_agg = {} - for key, value in metric_results.items(): - m_l_agg[key] = _sort_results_by_label( - n_labels, - value, - dl_filter_l, - ) - m_r_agg = {} - for key, value in metric_results.items(): - m_r_agg[key] = _sort_results_by_label( - n_labels, - value, - dl_filter_r, - ) - - if save_details_to is not None: - with open( - os.path.join( - save_details_to, f"{metric.__name__}_{filter_type}_all.csv" - ), - "w+", - ) as f: - f.write("left,right,tps,fns\n") - for left, right, tps, fns in zip( - dl_filter_l, - dl_filter_r, - metric_results["tps"], - metric_results["fns"], - ): - f.write(f"{left},{right},{tps},{fns}\n") - with open( - os.path.join( - save_details_to, f"{metric.__name__}_{filter_type}_l.csv" - ), - "w+", - ) as f: - f.write("left,tps,fns\n") - for left in range(n_labels): - f.write( - f"{left},{m_l_agg['tps'][left].item()},{m_l_agg['fns'][left].item()}\n" - ) - with open( - os.path.join( - save_details_to, f"{metric.__name__}_{filter_type}_r.csv" - ), - "w+", - ) as f: - f.write("right,tps,fns\n") - for right in range(n_labels): - f.write( - f"{right},{m_r_agg['tps'][right].item()},{m_r_agg['fns'][right].item()}\n" - ) - print( - f"Saved unaggregated consistency metrics ({metric.__name__}, {filter_type}) to {save_details_to}" - ) - - fns_sum = torch.sum(metric_results["fns"]).item() - results[metric.__name__][f"micro-fnr-{filter_type}"] = ( - 0 - if fns_sum == 0 - else ( - torch.sum(metric_results["fns"]) - / ( - torch.sum(metric_results[f"tps"]) - + torch.sum(metric_results[f"fns"]) - ) - ).item() - ) - macro_fnr_l = m_l_agg[f"fns"] / (m_l_agg[f"tps"] + m_l_agg[f"fns"]) - results[metric.__name__][f"lmacro-fnr-{filter_type}"] = ( - 0 - if fns_sum == 0 - else torch.mean(macro_fnr_l[~macro_fnr_l.isnan()]).item() - ) - macro_fnr_r = m_r_agg[f"fns"] / (m_r_agg[f"tps"] + m_r_agg[f"fns"]) - results[metric.__name__][f"rmacro-fnr-{filter_type}"] = ( - 0 - if fns_sum == 0 - else torch.mean(macro_fnr_r[~macro_fnr_r.isnan()]).item() - ) - results[metric.__name__][f"fn-sum-{filter_type}"] = torch.sum( - metric_results["fns"] - ).item() - results[metric.__name__][f"tp-sum-{filter_type}"] = torch.sum( - metric_results["tps"] - ).item() - - del metric_results - del m_l_agg - del m_r_agg - - gc.collect() - del l_preds - del r_preds - gc.collect() - - return results - - -def run_supervised_metrics(preds, labels, save_details_to=None): - # calculate supervised metrics - results = {} - if labels is not None: - results["micro-f1"] = multilabel_f1_score( - preds, labels, num_labels=preds.size(1), average="micro" - 
).item() - results["macro-f1"] = multilabel_f1_score( - preds, labels, num_labels=preds.size(1), average="macro" - ).item() - results["micro-roc-auc"] = multilabel_auroc( - preds, labels, num_labels=preds.size(1), average="micro" - ).item() - results["macro-roc-auc"] = multilabel_auroc( - preds, labels, num_labels=preds.size(1), average="macro" - ).item() - - results["micro-ap"] = multilabel_average_precision( - preds, labels, num_labels=preds.size(1), average="micro" - ).item() - results["macro-ap"] = multilabel_average_precision( - preds, labels, num_labels=preds.size(1), average="macro" - ).item() - - if save_details_to is not None: - f1_by_label = multilabel_f1_score( - preds, labels, num_labels=preds.size(1), average=None - ) - roc_by_label = multilabel_auroc( - preds, labels, num_labels=preds.size(1), average=None - ) - ap_by_label = multilabel_average_precision( - preds, labels, num_labels=preds.size(1), average=None - ) - with open(os.path.join(save_details_to, f"supervised.csv"), "w+") as f: - f.write("label,f1,roc-auc,ap\n") - for right in range(preds.size(1)): - f.write( - f"{right},{f1_by_label[right].item()},{roc_by_label[right].item()},{ap_by_label[right].item()}\n" - ) - print(f"Saved class-wise supervised metrics to {save_details_to}") - - del preds - del labels - gc.collect() - return results - - -# run predictions / metrics calculations for semantic loss paper runs (NeSy 2024 submission) -def run_semloss_eval(): - # runs from wandb - non_wandb_runs = [] - api = wandb.Api() - runs = api.runs("chebai/chebai", filters={"tags": "eval_semloss_paper"}) - print(f"Found {len(runs)} tagged wandb runs") - ids_wandb = [run.id for run in runs] - - # ids used in the NeSy submission - prod = ["tk15yznc", "uke62a8m", "w0h3zr5s"] - xu19 = ["5ko8knb4", "061fd85t", "r50ioujs"] - prod_mixed = ["hk8555ff", "e0lxw8py", "lig23cmg"] - luka = ["0c0s48nh", "lfg384bp", "qeghvubh"] - baseline = ["i4wtz1k4", "zd020wkv", "rc1q3t49"] - prodk2 = ["ng3usn0p", "rp0wwzjv", "8fma1q7r"] - ids = baseline + prod + prodk2 + xu19 + luka + prod_mixed - # ids = ids_wandb - run_all( - ids, - non_wandb_runs, - prediction_datasets=[(ChEBIOver100(chebi_version=231), "test")], - consistency_metrics=[binary], - ) - - -def run_all( - wandb_ids=None, - local_ckpts: List[Tuple] = None, - consistency_metrics: Optional[List[callable]] = None, - prediction_datasets: List[Tuple] = None, - remove_violations: bool = False, - results_dir="_fuzzy_loss_eval", - check_consistency_on=None, - verbose_violation_output=False, -): - if wandb_ids is None: - wandb_ids = [] - if local_ckpts is None: - local_ckpts = [] - if consistency_metrics is None: - consistency_metrics = ALL_CONSISTENCY_METRICS - if prediction_datasets is None: - prediction_datasets = [ - (ChEBIOver100(chebi_version=231), "test"), - ] - if check_consistency_on is None: - check_consistency_on = ChEBIOver100(chebi_version=231) - - if remove_violations: - smooth_preds = PredictionSmoother(check_consistency_on) - else: - smooth_preds = lambda x: x - - timestamp = datetime.now().strftime("%y%m%d-%H%M%S") - prediction_filters = build_prediction_filter(check_consistency_on) - - results_path_consistency = os.path.join( - results_dir, - f"consistency_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", - ) - consistency_keys = [ - "micro-fnr-impl", - "lmacro-fnr-impl", - "rmacro-fnr-impl", - "fn-sum-impl", - "tp-sum-impl", - "micro-fnr-disj", - "lmacro-fnr-disj", - "rmacro-fnr-disj", - "fn-sum-disj", - "tp-sum-disj", - ] - with 
open(results_path_consistency, "x") as f: - f.write( - "run-id,epoch,datamodule,data_key,metric," - + ",".join(consistency_keys) - + "\n" - ) - results_path_supervised = os.path.join( - results_dir, - f"supervised_metrics_{timestamp}{'_violations_removed' if remove_violations else ''}.csv", - ) - supervised_keys = [ - "micro-f1", - "macro-f1", - "micro-roc-auc", - "macro-roc-auc", - "micro-ap", - "macro-ap", - ] - with open(results_path_supervised, "x") as f: - f.write("run-id,epoch,datamodule,data_key," + ",".join(supervised_keys) + "\n") - - ckpts = [(run_name, ep, None) for run_name, ep in local_ckpts] + [ - (None, None, wandb_id) for wandb_id in wandb_ids - ] - - for run_name, epoch, wandb_id in ckpts: - try: - ckpt_dir = os.path.join("logs", "downloaded_ckpts") - # for wandb runs, use short id as name, otherwise use ckpt dir name - if wandb_id is not None: - run_name = wandb_id - ckpt_path, epoch = download_model_from_wandb(run_name, ckpt_dir) - else: - ckpt_path = None - for file in os.listdir(os.path.join(ckpt_dir, run_name)): - if f"epoch={epoch}_" in file or f"epoch={epoch}." in file: - ckpt_path = os.path.join(os.path.join(ckpt_dir, run_name, file)) - assert ( - ckpt_path is not None - ), f"Failed to find checkpoint for epoch {epoch} in {os.path.join(ckpt_dir, run_name)}" - print(f"Starting run {run_name} (epoch {epoch})") - - for dataset, dataset_key in prediction_datasets: - # copy data from legacy buffer dir if possible - old_buffer_dir = os.path.join( - "results_buffer", - *ckpt_path.split(os.path.sep)[-2:], - f"{dataset.__class__.__name__}_{dataset_key}", - ) - buffer_dir = os.path.join( - "results_buffer", - run_name, - f"epoch={epoch}", - f"{dataset.__class__.__name__}_{dataset_key}", - ) - print("Checking for buffer dir", old_buffer_dir) - if os.path.isdir(old_buffer_dir): - from distutils.dir_util import copy_tree, remove_tree - - os.makedirs(buffer_dir, exist_ok=True) - copy_tree(old_buffer_dir, buffer_dir) - remove_tree(old_buffer_dir, dry_run=True) - print(f"Moved buffer from {old_buffer_dir} to {buffer_dir}") - print(f"Using buffer_dir {buffer_dir}") - preds, labels = load_preds_labels( - ckpt_path, dataset, dataset_key, buffer_dir - ) - # identity function if remove_violations is False - smooth_preds(preds) - - details_path = None # os.path.join( - # results_dir, - # f"{run_name}_ep{epoch}_{dataset.__class__.__name__}_{dataset_key}", - # ) - metrics_dict = run_consistency_metrics( - preds, - prediction_filters, - check_consistency_on, - consistency_metrics, - verbose_violation_output, - save_details_to=details_path, - ) - with open(results_path_consistency, "a") as f: - for metric in metrics_dict: - values = metrics_dict[metric] - f.write( - f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key},{metric}," - f"{','.join([str(values[k]) for k in consistency_keys])}\n" - ) - print( - f"Consistency metrics have been written to {results_path_consistency}" - ) - if labels is not None: - metrics_dict = run_supervised_metrics( - preds, labels, save_details_to=details_path - ) - with open(results_path_supervised, "a") as f: - f.write( - f"{run_name},{epoch},{dataset.__class__.__name__},{dataset_key}," - f"{','.join([str(metrics_dict[k]) for k in supervised_keys])}\n" - ) - print( - f"Supervised metrics have been written to {results_path_supervised}" - ) - except Exception as e: - print( - f"Error during run {wandb_id if wandb_id is not None else run_name}: {e}" - ) - print(traceback.format_exc()) - - -# follow-up to NeSy submission -def 
run_fuzzy_loss(tag="fuzzy_loss", skip_first_n=0): - api = wandb.Api() - runs = api.runs("chebai/chebai", filters={"tags": tag}) - print(f"Found {len(runs)} wandb runs tagged with '{tag}'") - ids = [run.id for run in runs] - chebi100 = ChEBIOver100( - chebi_version=231, - splits_file_path=os.path.join( - "data", "chebi_v231", "ChEBI100", "fuzzy_loss_splits.csv" - ), - ) - local_ckpts = [][skip_first_n:] - pubchem_kmeans = PubChemKMeans() - run_all( - ids[max(0, skip_first_n - len(local_ckpts)) :], # ids, - local_ckpts, - consistency_metrics=[binary], - check_consistency_on=chebi100, - prediction_datasets=[ - (chebi100, "test"), - # (pubchem_kmeans, "cluster1_cutoff2k.pt"), - # (pubchem_kmeans, "cluster2.pt"), - # (pubchem_kmeans, "ten_from_each_cluster.pt"), - # (pubchem_kmeans, "chebi_close.pt"), - ], - ) - - -if __name__ == "__main__": - if len(sys.argv) > 2: - run_fuzzy_loss(sys.argv[1], int(sys.argv[2])) - elif len(sys.argv) > 1: - run_fuzzy_loss(sys.argv[1]) - else: - run_fuzzy_loss() diff --git a/chebai/result/base.py b/chebai/result/base.py deleted file mode 100644 index 9d583a0..0000000 --- a/chebai/result/base.py +++ /dev/null @@ -1,105 +0,0 @@ -import abc -import multiprocessing as mp -from typing import Iterable - -import torch -import tqdm - -from chebai.models.base import ChebaiBaseNet - -PROCESSORS = dict() - - -class ResultProcessor(abc.ABC): - @classmethod - def _identifier(cls) -> str: - raise NotImplementedError - - def start(self): - pass - - def close(self): - pass - - def __init_subclass__(cls, **kwargs): - assert ( - cls._identifier() not in PROCESSORS - ), f"ResultProcessor {cls.__name__} does not have a unique identifier" - PROCESSORS[cls._identifier()] = cls - - def process_prediction(self, proc_id, features, labels, pred, ident): - raise NotImplementedError - - -class ResultFactory(abc.ABC): - def __init__( - self, model: ChebaiBaseNet, dataset, processors: Iterable[ResultProcessor] - ): - self._model = model - self._reader = dataset.reader - self.dataset = dataset - self._processors = processors - - def _process_row(self, row): - return row - - def _generate_predictions(self, data_path, raw=False, **kwargs): - self._model.eval() - collate = self._reader.COLLATOR() - if raw: - data_tuples = [ - (x["features"], x["ident"], self._reader.to_data(self._process_row(x))) - for x in self.dataset._load_dict(data_path) - ] - else: - data_tuples = [ - (x.get("raw_features", x["ident"]), x["ident"], x) - for x in torch.load(data_path, weights_only=False) - ] - - for raw_features, ident, row in tqdm.tqdm(data_tuples): - raw_labels = row.get("labels") - - processable_data = self._model._process_batch(collate([row]), 0) - - model_output = self._model(processable_data) - preds, labels = self._model._get_prediction_and_labels( - processable_data, processable_data["labels"], model_output - ) - d = dict( - model_output=model_output, - preds=preds, - raw_features=raw_features, - ident=ident, - threshold=self._model.thres, - ) - if raw_labels is not None: - d["labels"] = raw_labels - yield d - - def call_procs(self, args): - proc_id, proc_args = args - for proc in self._processors: - try: - proc.process_prediction(proc_id, **proc_args) - except Exception: - print("Could not process results for", proc_args["ident"]) - raise - - def execute(self, data_path, **kwargs): - for proc in self._processors: - proc.start() - try: - with mp.Pool() as pool: - res = map( - self.call_procs, - enumerate(self._generate_predictions(data_path, **kwargs)), - ) - for r in res: - pass - - except: - raise - 
finally: - for proc in self._processors: - proc.close() diff --git a/chebai/result/classification.py b/chebai/result/classification.py deleted file mode 100644 index bb23dea..0000000 --- a/chebai/result/classification.py +++ /dev/null @@ -1,105 +0,0 @@ -from typing import List - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -from torch import Tensor -from torchmetrics.classification import ( - MultilabelF1Score, - MultilabelPrecision, - MultilabelRecall, -) - -from chebai.callbacks.epoch_metrics import BalancedAccuracy, MacroF1 -from chebai.result.utils import * - - -def visualise_f1(logs_path: str) -> None: - """ - Visualize F1 scores from metrics.csv and save the plot as f1_plot.png. - - Args: - logs_path: The path to the directory containing metrics.csv. - """ - df = pd.read_csv(os.path.join(logs_path, "metrics.csv")) - df_loss = df.melt( - id_vars="epoch", - value_vars=[ - "val_ep_macro-f1", - "val_micro-f1", - "train_micro-f1", - "train_ep_macro-f1", - ], - ) - lineplt = sns.lineplot(df_loss, x="epoch", y="value", hue="variable") - plt.savefig(os.path.join(logs_path, "f1_plot.png")) - plt.show() - - -def print_metrics( - preds: Tensor, - labels: Tensor, - device: torch.device, - classes: Optional[List[str]] = None, - top_k: int = 10, - markdown_output: bool = False, -) -> None: - """ - Prints relevant metrics, including micro and macro F1, recall and precision, - best k classes, and worst classes. - - Args: - preds: Predicted labels as a tensor. - labels: True labels as a tensor. - device: The device to perform computations on. - classes: Optional list of class names. - top_k: The number of top classes to display based on F1 score. - markdown_output: If True, print metrics in markdown format. - """ - f1_micro = MultilabelF1Score(preds.shape[1], average="micro").to(device=device) - my_f1_macro = MacroF1(preds.shape[1]).to(device=device) - my_bal_acc = BalancedAccuracy(preds.shape[1]).to(device=device) - - print(f"Macro-F1: {my_f1_macro(preds, labels):3f}") - print(f"Micro-F1: {f1_micro(preds, labels):3f}") - print(f"Balanced Accuracy: {my_bal_acc(preds, labels):3f}") - precision_macro = MultilabelPrecision(preds.shape[1], average="macro").to( - device=device - ) - precision_micro = MultilabelPrecision(preds.shape[1], average="micro").to( - device=device - ) - macro_adjust = 1 - recall_macro = MultilabelRecall(preds.shape[1], average="macro").to(device=device) - recall_micro = MultilabelRecall(preds.shape[1], average="micro").to(device=device) - print(f"Macro-Precision: {precision_macro(preds, labels) * macro_adjust:3f}") - print(f"Micro-Precision: {precision_micro(preds, labels):3f}") - print(f"Macro-Recall: {recall_macro(preds, labels) * macro_adjust:3f}") - print(f"Micro-Recall: {recall_micro(preds, labels):3f}") - if markdown_output: - print( - f"| Model | Macro-F1 | Micro-F1 | Macro-Precision | Micro-Precision | Macro-Recall | Micro-Recall | Balanced Accuracy |" - ) - print(f"| --- | --- | --- | --- | --- | --- | --- | --- |") - print( - f"| | {my_f1_macro(preds, labels):3f} | {f1_micro(preds, labels):3f} | {precision_macro(preds, labels):3f} | " - f"{precision_micro(preds, labels):3f} | {recall_macro(preds, labels):3f} | " - f"{recall_micro(preds, labels):3f} | {my_bal_acc(preds, labels):3f} |" - ) - - classwise_f1_fn = MultilabelF1Score(preds.shape[1], average=None).to(device=device) - classwise_f1 = classwise_f1_fn(preds, labels) - best_classwise_f1 = torch.topk(classwise_f1, top_k).indices - print(f"Top {top_k} classes (F1-score):") - for i, best 
in enumerate(best_classwise_f1): - print( - f"{i + 1}. {classes[best] if classes is not None else best} - F1: {classwise_f1[best]:3f}" - ) - - zeros = [] - for i, f1 in enumerate(classwise_f1): - if f1 == 0.0 and torch.sum(labels[:, i]) != 0: - zeros.append(f"{classes[i] if classes is not None else i}") - print( - f'Found {len(zeros)} classes with F1-score == 0 (and non-zero labels): {", ".join(zeros)}' - ) diff --git a/chebai/result/evaluate_predictions.py b/chebai/result/evaluate_predictions.py deleted file mode 100644 index 355c07c..0000000 --- a/chebai/result/evaluate_predictions.py +++ /dev/null @@ -1,108 +0,0 @@ -from typing import Tuple - -import numpy as np -import torch -from jsonargparse import CLI -from torchmetrics.functional.classification import multilabel_auroc - -from chebai.callbacks.epoch_metrics import MacroF1 -from chebai.result.utils import load_results_from_buffer - - -class EvaluatePredictions: - def __init__(self, eval_dir: str): - """ - Initializes the EvaluatePredictions class. - - Args: - eval_dir (str): Path to the directory containing evaluation files. - """ - self.eval_dir = eval_dir - self.metrics = [] - self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") - self.num_labels = None - - @staticmethod - def validate_eval_dir(label_files: torch.Tensor, pred_files: torch.Tensor) -> None: - """ - Validates that the number of labels matches the number of predictions, - ensuring that they have the same shape. - - Args: - label_files (torch.Tensor): Tensor containing label data. - pred_files (torch.Tensor): Tensor containing prediction data. - - Raises: - ValueError: If label and prediction tensors are mismatched in shape. - """ - if label_files is None or pred_files is None: - raise ValueError("Both label and prediction tensors must be provided.") - - # Check if the number of labels matches the number of predictions - if label_files.shape[0] != pred_files.shape[0]: - raise ValueError( - "Number of label tensors does not match the number of prediction tensors." - ) - - # Validate that the last dimension matches the expected number of classes - if label_files.shape[1] != pred_files.shape[1]: - raise ValueError( - "Label and prediction tensors must have the same shape in terms of class outputs." - ) - - def evaluate(self) -> None: - """ - Loads predictions and labels, validates file correspondence, and calculates Multilabel AUROC and Fmax. - """ - test_preds, test_labels = load_results_from_buffer(self.eval_dir, self.device) - self.validate_eval_dir(test_labels, test_preds) - self.num_labels = test_preds.shape[1] - - ml_auroc = multilabel_auroc( - test_preds, test_labels, num_labels=self.num_labels - ).item() - - print("Multilabel AUC-ROC:", ml_auroc) - - fmax, threshold = self.calculate_fmax(test_preds, test_labels) - print(f"F-max : {fmax}, threshold: {threshold}") - - def calculate_fmax( - self, test_preds: torch.Tensor, test_labels: torch.Tensor - ) -> Tuple[float, float]: - """ - Calculates the Fmax metric using the F1 score at various thresholds. - - Args: - test_preds (torch.Tensor): Predicted scores for the labels. - test_labels (torch.Tensor): True labels for the evaluation. - - Returns: - Tuple[float, float]: The maximum F1 score and the corresponding threshold. 
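For illustration (not part of the original file): `calculate_fmax` sweeps 101 evenly spaced thresholds and keeps the best macro-F1. The same idea can be written in a few lines of plain torch; this sketch uses micro-averaged F1 for brevity, so its numbers will not match the `MacroF1`-based implementation exactly:

    import torch

    def fmax_micro(preds: torch.Tensor, labels: torch.Tensor, steps: int = 101):
        best_f1, best_t = 0.0, 0.0
        for t in torch.linspace(0.0, 1.0, steps):
            hard = (preds >= t).float()             # binarize at threshold t
            tp = (hard * labels).sum()
            fp = (hard * (1 - labels)).sum()
            fn = ((1 - hard) * labels).sum()
            f1 = (2 * tp / (2 * tp + fp + fn + 1e-10)).item()
            if f1 > best_f1:
                best_f1, best_t = f1, float(t)
        return best_f1, best_t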
- """ - # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/metrics.py#L51-L52 - thresholds = np.linspace(0, 1, 101) - fmax = 0.0 - best_threshold = 0.0 - - for t in thresholds: - custom_f1_metric = MacroF1(num_labels=self.num_labels, threshold=t) - custom_f1_metric.update(test_preds, test_labels) - custom_f1_metric_score = custom_f1_metric.compute().item() - - # Check if the current score is the best we've seen - if custom_f1_metric_score > fmax: - fmax = custom_f1_metric_score - best_threshold = t - - return fmax, best_threshold - - -class Main: - def evaluate(self, eval_dir: str): - EvaluatePredictions(eval_dir).evaluate() - - -if __name__ == "__main__": - # evaluate_predictions.py evaluate - CLI(Main) diff --git a/chebai/result/molplot.py b/chebai/result/molplot.py deleted file mode 100644 index 8fdbc77..0000000 --- a/chebai/result/molplot.py +++ /dev/null @@ -1,506 +0,0 @@ -import abc -from os import makedirs -from tempfile import NamedTemporaryFile - -import networkx as nx -import numpy as np -import pandas as pd -import torch -from matplotlib import cm, colors -from matplotlib import pyplot as plt -from matplotlib import rc -from matplotlib.image import AxesImage, imread -from networkx.algorithms.isomorphism import GraphMatcher -from pysmiles.read_smiles import * -from pysmiles.read_smiles import _tokenize -from rdkit import Chem -from rdkit.Chem.Draw import MolToMPL, rdMolDraw2D - -from chebai.preprocessing.datasets import JCI_500_COLUMNS, JCI_500_COLUMNS_INT -from chebai.result.base import ResultProcessor - - -class AttentionMolPlot: - def draw_attention_molecule(self, smiles, attention): - pmol = self.read_smiles_with_index(smiles) - rdmol = Chem.MolFromSmiles(smiles) - if not rdmol: - raise NoRDMolException - rdmolx = self.mol_to_nx(rdmol) - gm = GraphMatcher(pmol, rdmolx) - iso = next(gm.isomorphisms_iter()) - token_to_node_map = { - pmol.nodes[node]["token_index"]: iso[node] for node in pmol.nodes - } - d = rdMolDraw2D.MolDraw2DCairo(500, 500) - cmap = cm.ScalarMappable(cmap=cm.Greens) - - aggr_attention_colors = cmap.to_rgba( - np.max(attention[2:, :], axis=0), norm=False - ) - cols = { - token_to_node_map[token_index]: tuple( - aggr_attention_colors[token_index].tolist() - ) - for node, token_index in nx.get_node_attributes(pmol, "token_index").items() - } - highlight_atoms = [ - token_to_node_map[token_index] - for node, token_index in nx.get_node_attributes(pmol, "token_index").items() - ] - rdMolDraw2D.PrepareAndDrawMolecule( - d, rdmol, highlightAtoms=highlight_atoms, highlightAtomColors=cols - ) - - d.FinishDrawing() - return d - - def plot_attentions(self, smiles, attention, threshold, labels): - d = self.draw_attention_molecule(smiles, attention) - cmap = cm.ScalarMappable(cmap=cm.Greens) - attention_colors = cmap.to_rgba(attention, norm=False) - num_tokens = sum(1 for _ in _tokenize(smiles)) - - fig = plt.figure(figsize=(15, 15), facecolor="w") - - rc("font", **{"family": "monospace", "monospace": "DejaVu Sans Mono"}) - fig.tight_layout() - - ax2, ax = fig.subplots(2, 1, gridspec_kw={"height_ratios": [10, 1]}) - - with NamedTemporaryFile(mode="wt", suffix=".png") as svg1: - d.WriteDrawingText(svg1.name) - ax2.imshow(imread(svg1.name)) - ax2.axis("off") - ax2.spines["left"].set_position("center") - ax2.spines["bottom"].set_position("zero") - ax2.autoscale(tight=True) - - table = plt.table( - cellText=[ - (["[CLS]"] + [t for _, t in _tokenize(smiles)]) - for _ in range(attention.shape[0]) - ], - cellColours=attention_colors, - cellLoc="center", - 
) - table.auto_set_column_width(list(range(num_tokens))) - table.scale(1, 4) - table.set_fontsize(26) - - ax.add_table(table) - ax.axis("off") - ax.spines["top"].set_position("zero") - ax.autoscale(tight=True) - - self.counter += 1 - for w, label, predicted in labels: - if predicted: - cat = "p" - else: - cat = "n" - if predicted == label: - cat = "t" + cat - else: - cat = "f" + cat - fig.savefig( - f"/tmp/plots/{w}/{cat}_{self.counter}.png", - transparent=False, - bbox_inches="tight", - pad_inches=0, - ) - plt.close() - - @staticmethod - def mol_to_nx(mol): - G = nx.Graph() - - for atom in mol.GetAtoms(): - G.add_node( - atom.GetIdx(), - atomic_num=atom.GetAtomicNum(), - formal_charge=atom.GetFormalCharge(), - chiral_tag=atom.GetChiralTag(), - hybridization=atom.GetHybridization(), - num_explicit_hs=atom.GetNumExplicitHs(), - is_aromatic=atom.GetIsAromatic(), - ) - for bond in mol.GetBonds(): - G.add_edge( - bond.GetBeginAtomIdx(), - bond.GetEndAtomIdx(), - bond_type=bond.GetBondType(), - ) - return G - - @staticmethod - def read_smiles_with_index( - smiles, - explicit_hydrogen=False, - zero_order_bonds=True, - reinterpret_aromatic=True, - ): - """ - This is just a re-implementation of pysmiles.read_smiles, that stores token indices - """ - bond_to_order = {"-": 1, "=": 2, "#": 3, "$": 4, ":": 1.5, ".": 0} - mol = nx.Graph() - anchor = None - idx = 0 - default_bond = 1 - next_bond = None - branches = [] - ring_nums = {} - for token_index, (tokentype, token) in enumerate(_tokenize(smiles)): - if tokentype == TokenType.ATOM: - mol.add_node(idx, token_index=token_index, **parse_atom(token)) - if anchor is not None: - if next_bond is None: - next_bond = default_bond - if next_bond or zero_order_bonds: - mol.add_edge(anchor, idx, order=next_bond) - next_bond = None - anchor = idx - idx += 1 - elif tokentype == TokenType.BRANCH_START: - branches.append(anchor) - elif tokentype == TokenType.BRANCH_END: - anchor = branches.pop() - elif tokentype == TokenType.BOND_TYPE: - if next_bond is not None: - raise ValueError( - "Previous bond (order {}) not used. " - 'Overwritten by "{}"'.format(next_bond, token) - ) - next_bond = bond_to_order[token] - elif tokentype == TokenType.RING_NUM: - if token in ring_nums: - jdx, order = ring_nums[token] - if next_bond is None and order is None: - next_bond = default_bond - elif order is None: # Note that the check is needed, - next_bond = next_bond # But this could be pass. - elif next_bond is None: - next_bond = order - elif next_bond != order: # Both are not None - raise ValueError( - "Conflicting bond orders for ring " - "between indices {}".format(token) - ) - # idx is the index of the *next* atom we're adding. So: -1. - if mol.has_edge(idx - 1, jdx): - raise ValueError( - "Edge specified by marker {} already " - "exists".format(token) - ) - if idx - 1 == jdx: - raise ValueError( - "Marker {} specifies a bond between an " - "atom and itself".format(token) - ) - if next_bond or zero_order_bonds: - mol.add_edge(idx - 1, jdx, order=next_bond) - next_bond = None - del ring_nums[token] - else: - if idx == 0: - raise ValueError( - "Can't have a marker ({}) before an atom" "".format(token) - ) - # idx is the index of the *next* atom we're adding. So: -1. 
- ring_nums[token] = (idx - 1, next_bond) - next_bond = None - elif tokentype == TokenType.EZSTEREO: - LOGGER.warning( - 'E/Z stereochemical information, which is specified by "%s", will be discarded', - token, - ) - if ring_nums: - raise KeyError("Unmatched ring indices {}".format(list(ring_nums.keys()))) - - # Time to deal with aromaticity. This is a mess, because it's not super - # clear what aromaticity information has been provided, and what should be - # inferred. In addition, to what extend do we want to provide a "sane" - # molecule, even if this overrides what the SMILES string specifies? - cycles = nx.cycle_basis(mol) - ring_idxs = set() - for cycle in cycles: - ring_idxs.update(cycle) - non_ring_idxs = set(mol.nodes) - ring_idxs - for n_idx in non_ring_idxs: - if mol.nodes[n_idx].get("aromatic", False): - raise ValueError( - "You specified an aromatic atom outside of a" - " ring. This is impossible" - ) - - mark_aromatic_edges(mol) - fill_valence(mol) - if reinterpret_aromatic: - mark_aromatic_atoms(mol) - mark_aromatic_edges(mol) - for idx, jdx in mol.edges: - if ( - not mol.nodes[idx].get("aromatic", False) - or not mol.nodes[jdx].get("aromatic", False) - ) and mol.edges[idx, jdx].get("order", 1) == 1.5: - mol.edges[idx, jdx]["order"] = 1 - - if explicit_hydrogen: - add_explicit_hydrogens(mol) - else: - remove_explicit_hydrogens(mol) - return mol - - -class AttentionOnMoleculesProcessor(AttentionMolPlot, ResultProcessor): - def __init__(self, *args, headers=None, **kwargs): - super().__init__(*args, **kwargs) - self.headers = headers - - def start(self): - self.counter = 0 - - @classmethod - def _identifier(cls): - return "platt" - - def filter(self, l): - return - - def process_prediction( - self, proc_id, preds, raw_features, model_output, labels, **kwargs - ): - atts = torch.stack(model_output["attentions"]).squeeze(1).detach().numpy() - predictions = preds.detach().numpy().squeeze(0) > 0.5 - if self.headers is None: - headers = list(range(len(labels))) - else: - headers = self.headers - - for w in headers: - makedirs(f"/tmp/plots/{w}", exist_ok=True) - - try: - self.plot_attentions( - raw_features, - np.max(np.max(atts, axis=2), axis=1), - 0.4, - [ - (ident, label, predicted) - for label, ident, predicted in zip(labels, headers, predictions) - if (label or predicted) - ], - ) - except StopIteration: - print("Could not match", raw_features) - except NoRDMolException: - pass - - -class LastLayerAttentionProcessor(AttentionMolPlot, ResultProcessor): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def start(self): - self.counter = 0 - for w in JCI_500_COLUMNS_INT: - makedirs(f"/tmp/plots/{w}", exist_ok=True) - - @classmethod - def _identifier(cls): - return "platt_last" - - def filter(self, l): - return - - def process_prediction(self, raw_features, raw_labels, features, labels, pred): - atts = torch.stack(pred["attentions"]).squeeze(1).detach().numpy() - last_layer = np.max(atts, axis=2)[-1, :] - if np.any(last_layer > 0.4): - try: - self.plot_attentions( - raw_features, - np.max(np.max(atts, axis=2), axis=1), - 0.4, - [ - ident - for present, ident in zip(labels, JCI_500_COLUMNS_INT) - if present - ], - ) - except StopIteration: - print("Could not match", raw_features) - except NoRDMolException: - pass - - -class SingletonAttentionProcessor(AttentionMolPlot, ResultProcessor): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def start(self): - self.counter = 0 - for w in JCI_500_COLUMNS_INT: - 
makedirs(f"/tmp/plots/{w}", exist_ok=True) - - @classmethod - def _identifier(cls): - return "platt_singles" - - def filter(self, l): - return - - def process_prediction(self, raw_features, raw_labels, features, labels, pred): - atts = torch.stack(pred["attentions"]).squeeze(1).detach().numpy() - if sum(labels) == 1: - try: - predictions = ( - torch.sigmoid(pred["logits"]).detach().numpy().squeeze(0) > 0.5 - ) - self.plot_attentions( - raw_features, - np.max(np.average(atts, axis=2), axis=1), - 0.4, - [ - (ident, label, predicted) - for label, ident, predicted in zip( - labels, JCI_500_COLUMNS_INT, predictions - ) - if (label or predicted) - ], - ) - except StopIteration: - print("Could not match", raw_features) - except NoRDMolException: - pass - - -class AttentionNetwork(ResultProcessor): - def __init__(self, *args, headers=None, **kwargs): - super().__init__(*args, **kwargs) - self.headers = headers - self.i = 0 - - @classmethod - def _identifier(cls): - return "platt_table" - - def start(self): - self.counter = 0 - - def process_prediction( - self, - proc_id, - preds, - raw_features, - model_output, - labels, - ident=None, - threshold=0.5, - **kwargs, - ): - if self.headers is None: - headers = list(range(len(labels))) - else: - headers = self.headers - - for w in headers: - makedirs(f"plots/{w}", exist_ok=True) - - atts = torch.stack(model_output["attentions"]).squeeze(1).detach().numpy() - predictions = preds.detach().numpy().squeeze(0) > 0.5 - plt.rcParams.update({"font.size": 8}) - try: - attentions = atts - tokens = ["[CLS]"] + [s for _, s in _tokenize(raw_features)] - cmap = cm.ScalarMappable(cmap=cm.Greens) - assert len(tokens) == attentions.shape[2] - - rows = int((attentions.shape[1] + 2)) - width = len(tokens) - height = 12 - rdmol = Chem.MolFromSmiles(raw_features) - if rdmol is not None: - fig0 = MolToMPL(rdmol, fitImage=True) - fig0.text( - 0.1, - 0, - "annotated:" - + ", ".join( - str(l) for (l, is_member) in zip(headers, labels) if is_member - ) - + "\n" - + "predicted:" - + ", ".join( - str(l) - for (l, is_member) in zip(headers, predictions) - if is_member - ), - fontdict=dict(fontsize=10), - ) - fig0.savefig( - f"plots/mol_{ident}.png", - bbox_inches="tight", - pad_inches=0, - ) - plt.close(fig0) - fig = plt.figure(figsize=(10 * 12, width // 3)) - l_tokens = {i: str(t) for i, t in enumerate(tokens)} - r_tokens = {(len(tokens) + i): str(t) for i, t in enumerate(tokens)} - labels = dict(list(l_tokens.items()) + list(r_tokens.items())) - edges = [(l, r) for r in r_tokens.keys() for l in l_tokens.keys()] - g = nx.Graph() - g.add_nodes_from(l_tokens, bipartite=0) - g.add_nodes_from(r_tokens, bipartite=1) - g.add_edges_from(edges) - pos = np.array( - [(0, -i) for i in range(len(l_tokens))] - + [(1, -i) for i in range(len(l_tokens))] - ) - - offset = np.array( - [(1, 0) for i in range(len(l_tokens))] - + [(1, 0) for i in range(len(l_tokens))] - ) - # axes = fig.subplots(1, 6 * 8 + 5, subplot_kw=dict(frameon=False)) - - ax = fig.add_subplot(111) - ax.axis("off") - for layer in range(attentions.shape[0]): - for head in range(attentions.shape[1]): - index = 8 * (layer) + head + layer + 1 - - at = np.concatenate([a for a in attentions[layer, head]]) - col = cmap.cmap(at) - col[:, 3] = at - nx.draw_networkx( - g, - pos=pos + (index * offset), - edge_color=col, - ax=ax, - labels=labels, - node_color="none", - node_size=8, - ) - # sns.heatmap(attentions[i,j], linewidth=0.5, ax=ax, cmap=cm.Greens, square=True, vmin=0, vmax=1, xticklabels=tokens, yticklabels=tokens) - 
fig.subplots_adjust() - fig.savefig( - f"plots/att_{ident}.png", - # transparent=True, - bbox_inches="tight", - pad_inches=0, - dpi=100, - ) - - plt.close() - except StopIteration: - print("Could not match", raw_features) - except NoRDMolException: - pass - finally: - plt.close() - - -class NoRDMolException(Exception): - pass diff --git a/chebai/result/prediction_json.py b/chebai/result/prediction_json.py deleted file mode 100644 index 924df65..0000000 --- a/chebai/result/prediction_json.py +++ /dev/null @@ -1,26 +0,0 @@ -import json - -from chebai.result.base import ResultProcessor - - -class JSONResultProcessor(ResultProcessor): - @classmethod - def _identifier(cls): - return "json" - - def start(self): - self.data = [] - - def close(self): - with open("predictions.json", "w") as fout: - json.dump(self.data, fout) - del self.data - - def process_prediction(self, proc_id, raw_features, labels, preds, ident, **kwargs): - self.data.append( - dict( - ident=ident, - labels=labels if labels is not None else None, - prediction=preds.tolist(), - ) - ) diff --git a/chebai/result/pretraining.py b/chebai/result/pretraining.py deleted file mode 100644 index 8d712f2..0000000 --- a/chebai/result/pretraining.py +++ /dev/null @@ -1,65 +0,0 @@ -import os - -import matplotlib.pyplot as plt -import pandas as pd -import seaborn as sns -import torch -import tqdm - -import chebai.models.electra as electra -from chebai.loss.pretraining import ElectraPreLoss -from chebai.result.base import ResultProcessor - - -def visualise_loss(logs_path): - df = pd.read_csv(os.path.join(logs_path, "metrics.csv")) - df_loss = df.melt( - id_vars="epoch", value_vars=["val_loss_epoch", "train_loss_epoch"] - ) - lineplt = sns.lineplot(df_loss, x="epoch", y="value", hue="variable") - plt.savefig(os.path.join(logs_path, "f1_plot.png")) - plt.show() - - -# get predictions from model -def evaluate_model(logs_base_path, model_filename, data_module): - model = electra.ElectraPre.load_from_checkpoint( - os.path.join( - logs_base_path, - "best_epoch=85_val_loss=0.0147_val_micro-f1=0.90.ckpt", - model_filename, - ) - ) - assert isinstance(model, electra.ElectraPre) - collate = data_module.reader.COLLATOR() - test_file = "test.pt" - data_path = os.path.join(data_module.processed_dir, test_file) - data_list = torch.load(data_path, weights_only=False) - preds_list = [] - labels_list = [] - - for row in tqdm.tqdm(data_list): - processable_data = model._process_batch(collate([row]), 0) - model_output = model(processable_data, **processable_data["model_kwargs"]) - preds, labels = model._get_prediction_and_labels( - processable_data, processable_data["labels"], model_output - ) - preds_list.append(preds) - labels_list.append(labels) - - test_preds = torch.cat(preds_list) - test_labels = torch.cat(labels_list) - print(test_preds.shape) - print(test_labels.shape) - test_loss = ElectraPreLoss() - print(f"Loss on test set: {test_loss(test_preds, test_labels)}") - # f1_macro = MultilabelF1Score(test_preds.shape[1], average='macro') - # f1_micro = MultilabelF1Score(test_preds.shape[1], average='micro') - # print(f'Macro-F1 on test set with {test_preds.shape[1]} classes: {f1_macro(test_preds, test_labels):3f}') - # print(f'Micro-F1 on test set with {test_preds.shape[1]} classes: {f1_micro(test_preds, test_labels):3f}') - - -class PretrainingResultProcessor(ResultProcessor): - @classmethod - def _identifier(cls) -> str: - return "PretrainingResultProcessor" diff --git a/chebai/result/utils.py b/chebai/result/utils.py deleted file mode 100644 index 
991960d..0000000 --- a/chebai/result/utils.py +++ /dev/null @@ -1,235 +0,0 @@ -import os -import shutil -from typing import Optional, Tuple, Union - -import torch -import tqdm -import wandb -import wandb.util as wandb_util - -from chebai.models.base import ChebaiBaseNet -from chebai.models.electra import Electra -from chebai.preprocessing.datasets.base import XYBaseDataModule -from chebai.preprocessing.datasets.chebi import _ChEBIDataExtractor - - -def get_checkpoint_from_wandb( - epoch: int, - run: wandb.apis.public.Run, - root: str = os.path.join("logs", "downloaded_ckpts"), -): - """ - Gets a wandb checkpoint based on run and epoch, downloads it if necessary. - - Args: - epoch: The epoch number of the checkpoint to retrieve. - run: The wandb run object. - root: The root directory to save the downloaded checkpoint. - - Returns: - The location of the downloaded checkpoint. - """ - api = wandb.Api() - - files = run.files() - for file in files: - if file.name.startswith( - f"checkpoints/per_epoch={epoch}" - ) or file.name.startswith(f"checkpoints/best_epoch={epoch}"): - dest_path = os.path.join( - root, run.id, file.name.split("/")[-1].split("_")[1] + ".ckpt" - ) - # legacy: also look for ckpts in the old format - old_dest_path = os.path.join(root, run.name, file.name.split("/")[-1]) - if not os.path.isfile(dest_path): - if os.path.isfile(old_dest_path): - print(f"Copying checkpoint from {old_dest_path} to {dest_path}") - os.makedirs(os.path.dirname(dest_path), exist_ok=True) - shutil.copy2(old_dest_path, dest_path) - else: - print(f"Downloading checkpoint to {dest_path}") - wandb_util.download_file_from_url(dest_path, file.url, api.api_key) - return dest_path - print(f"No model found for epoch {epoch}") - return None - - -def _run_batch(batch, model, collate): - collated = collate(batch) - collated.x = collated.to_x(model.device) - if collated.y is not None: - collated.y = collated.to_y(model.device) - processable_data = model._process_batch(collated, 0) - del processable_data["loss_kwargs"] - model_output = model(processable_data, **processable_data["model_kwargs"]) - preds, labels = model._get_prediction_and_labels( - processable_data, processable_data["labels"], model_output - ) - return preds, labels - - -def _concat_tuple(l): - if isinstance(l[0], tuple): - print(l[0]) - return tuple([torch.cat([t[i] for t in l]) for i in range(len(l[0]))]) - return torch.cat(l) - - -def evaluate_model( - model: ChebaiBaseNet, - data_module: XYBaseDataModule, - filename: Optional[str] = None, - buffer_dir: Optional[str] = None, - batch_size: int = 32, - skip_existing_preds: bool = False, - kind: str = "test", -) -> Tuple[torch.Tensor, Optional[torch.Tensor]]: - """ - Runs the model on the test set of the data module or on the dataset found in the specified file. - If buffer_dir is set, results will be saved in buffer_dir. - - Note: - No need to provide "filename" parameter for Chebi dataset, "kind" parameter should be provided. - - Args: - model: The model to evaluate. - data_module: The data module containing the dataset. - filename: Optional file name for the dataset. - buffer_dir: Optional directory to save the results. - batch_size: The batch size for evaluation. - skip_existing_preds: Whether to skip evaluation if predictions already exist. - kind: Kind of split of the data to be used for testing the model. Default is `test`. - - Returns: - Tensors with predictions and labels. 
- """ - model.eval() - collate = data_module.reader.COLLATOR() - - if isinstance(data_module, _ChEBIDataExtractor): - # As the dynamic split change is implemented only for chebi-dataset as of now - data_df = data_module.dynamic_split_dfs[kind] - data_list = data_df.to_dict(orient="records") - else: - data_list = data_module.load_processed_data("test", filename) - data_list = data_list[: data_module.data_limit] - preds_list = [] - labels_list = [] - if buffer_dir is not None: - os.makedirs(buffer_dir, exist_ok=True) - save_ind = 0 - save_batch_size = 128 - n_saved = 1 - - print(f"") - for i in tqdm.tqdm(range(0, len(data_list), batch_size)): - if not ( - skip_existing_preds - and os.path.isfile(os.path.join(buffer_dir, f"preds{save_ind:03d}.pt")) - ): - preds, labels = _run_batch(data_list[i : i + batch_size], model, collate) - preds_list.append(preds) - labels_list.append(labels) - - if buffer_dir is not None: - if n_saved * batch_size >= save_batch_size: - torch.save( - _concat_tuple(preds_list), - os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), - ) - if labels_list[0] is not None: - torch.save( - _concat_tuple(labels_list), - os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), - ) - preds_list = [] - labels_list = [] - if n_saved * batch_size >= save_batch_size: - save_ind += 1 - n_saved = 0 - n_saved += 1 - - if buffer_dir is None: - test_preds = _concat_tuple(preds_list) - if labels_list is not None: - test_labels = _concat_tuple(labels_list) - return test_preds, test_labels - return test_preds, None - elif len(preds_list) < 0: - if len(preds_list) > 0 and preds_list[0] is not None: - torch.save( - _concat_tuple(preds_list), - os.path.join(buffer_dir, f"preds{save_ind:03d}.pt"), - ) - if len(labels_list) > 0 and labels_list[0] is not None: - torch.save( - _concat_tuple(labels_list), - os.path.join(buffer_dir, f"labels{save_ind:03d}.pt"), - ) - - -def load_results_from_buffer( - buffer_dir: str, device: torch.device -) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]: - """ - Load results stored in evaluate_model() from the buffer directory. - - Args: - buffer_dir: The directory containing the buffered results. - device: The device to load the results onto. - - Returns: - Tensors with predictions and labels. 
- """ - preds_list = [] - labels_list = [] - - i = 0 - filename = f"preds{i:03d}.pt" - while os.path.isfile(os.path.join(buffer_dir, filename)): - preds_list.append( - torch.load( - os.path.join(buffer_dir, filename), - map_location=torch.device(device), - weights_only=False, - ) - ) - i += 1 - filename = f"preds{i:03d}.pt" - - i = 0 - filename = f"labels{i:03d}.pt" - while os.path.isfile(os.path.join(buffer_dir, filename)): - labels_list.append( - torch.load( - os.path.join(buffer_dir, filename), - map_location=torch.device(device), - weights_only=False, - ) - ) - i += 1 - filename = f"labels{i:03d}.pt" - - if len(preds_list) > 0: - test_preds = torch.cat(preds_list) - else: - test_preds = None - if len(labels_list) > 0: - test_labels = torch.cat(labels_list) - else: - test_labels = None - - return test_preds, test_labels - - -if __name__ == "__main__": - import sys - - buffer_dir = os.path.join("results_buffer", sys.argv[1], "ChEBIOver100_train") - buffer_dir_concat = os.path.join( - "results_buffer", "concatenated", sys.argv[1], "ChEBIOver100_train" - ) - os.makedirs(buffer_dir_concat, exist_ok=True) - preds, labels = load_results_from_buffer(buffer_dir, "cpu") - torch.save(preds, os.path.join(buffer_dir_concat, f"preds000.pt")) - torch.save(labels, os.path.join(buffer_dir_concat, f"labels000.pt")) diff --git a/chebai/trainer/CustomTrainer.py b/chebai/trainer/CustomTrainer.py deleted file mode 100644 index cb76199..0000000 --- a/chebai/trainer/CustomTrainer.py +++ /dev/null @@ -1,149 +0,0 @@ -import logging -from typing import Any, List, Optional, Tuple - -import pandas as pd -import torch -from lightning import LightningModule, Trainer -from lightning.fabric.utilities.types import _PATH -from lightning.pytorch.loggers import WandbLogger -from torch.nn.utils.rnn import pad_sequence - -from chebai.loggers.custom import CustomLogger -from chebai.preprocessing.reader import CLS_TOKEN, ProteinDataReader - -log = logging.getLogger(__name__) - - -class CustomTrainer(Trainer): - def __init__(self, *args, **kwargs): - """ - Initializes the CustomTrainer class, logging additional hyperparameters to the custom logger if specified. - - Args: - *args: Positional arguments for the Trainer class. - **kwargs: Keyword arguments for the Trainer class. - """ - self.init_args = args - self.init_kwargs = kwargs - super().__init__(*args, **kwargs) - # instantiation custom logger connector - self._logger_connector.on_trainer_init(self.logger, 1) - # log additional hyperparameters to wandb - if isinstance(self.logger, CustomLogger): - custom_logger = self.logger - assert isinstance(custom_logger, CustomLogger) - if custom_logger.verbose_hyperparameters: - log_kwargs = {} - for key, value in self.init_kwargs.items(): - log_key, log_value = self._resolve_logging_argument(key, value) - log_kwargs[log_key] = log_value - self.logger.log_hyperparams(log_kwargs) - - def _resolve_logging_argument(self, key: str, value: Any) -> Tuple[str, Any]: - """ - Resolves logging arguments, handling nested structures such as lists and complex objects. - - Args: - key: The key of the argument. - value: The value of the argument. - - Returns: - A tuple containing the resolved key and value. 
- """ - if isinstance(value, list): - key_value_pairs = [ - self._resolve_logging_argument(f"{key}_{i}", v) - for i, v in enumerate(value) - ] - return key, {k: v for k, v in key_value_pairs} - if not ( - isinstance(value, str) - or isinstance(value, float) - or isinstance(value, int) - or value is None - ): - params = {"class": value.__class__} - params.update(value.__dict__) - return key, params - else: - return key, value - - def predict_from_file( - self, - model: LightningModule, - checkpoint_path: _PATH, - input_path: _PATH, - save_to: _PATH = "predictions.csv", - classes_path: Optional[_PATH] = None, - ) -> None: - """ - Loads a model from a checkpoint and makes predictions on input data from a file. - - Args: - model: The model to use for predictions. - checkpoint_path: Path to the model checkpoint. - input_path: Path to the input file containing SMILES strings. - save_to: Path to save the predictions CSV file. - classes_path: Optional path to a file containing class names. - """ - loaded_model = model.__class__.load_from_checkpoint(checkpoint_path) - with open(input_path, "r") as input: - smiles_strings = [inp.strip() for inp in input.readlines()] - loaded_model.eval() - predictions = self._predict_smiles(loaded_model, smiles_strings) - predictions_df = pd.DataFrame(predictions.detach().cpu().numpy()) - if classes_path is not None: - with open(classes_path, "r") as f: - predictions_df.columns = [cls.strip() for cls in f.readlines()] - predictions_df.index = smiles_strings - predictions_df.to_csv(save_to) - - def _predict_smiles( - self, model: LightningModule, sequence: List[str] - ) -> torch.Tensor: - """ - Predicts the output for a list of SMILES strings using the model. - - Args: - model: The model to use for predictions. - sequence: Protein sequence. - - Returns: - A tensor containing the predictions. - """ - reader = ProteinDataReader() - parsed_sequence = [reader._read_data(s) for s in sequence] - x = pad_sequence( - [torch.tensor(a, device=model.device) for a in parsed_sequence], - batch_first=True, - ) - cls_tokens = ( - torch.ones(x.shape[0], dtype=torch.int, device=model.device).unsqueeze(-1) - * CLS_TOKEN - ) - features = torch.cat((cls_tokens, x), dim=1) - model_output = model({"features": features}) - preds = torch.sigmoid(model_output["logits"]) - - print(preds.shape) - return preds - - @property - def log_dir(self) -> Optional[str]: - """ - Returns the logging directory. - - Returns: - The path to the logging directory if available, else the default root directory. 
- """ - if len(self.loggers) > 0: - logger = self.loggers[0] - if isinstance(logger, WandbLogger): - dirpath = logger.experiment.dir - else: - dirpath = self.loggers[0].log_dir - else: - dirpath = self.default_root_dir - - dirpath = self.strategy.broadcast(dirpath) - return dirpath diff --git a/chebai/trainer/__init__.py b/chebai/trainer/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/configs/default_prediction_callback.yml b/configs/default_prediction_callback.yml deleted file mode 100644 index 152b5d1..0000000 --- a/configs/default_prediction_callback.yml +++ /dev/null @@ -1,4 +0,0 @@ -class_path: chebai.callbacks.prediction_callback.PredictionWriter -init_args: - output_dir: pred - write_interval: epoch diff --git a/configs/loss/bce.yml b/configs/loss/bce.yml deleted file mode 100644 index e2fc30b..0000000 --- a/configs/loss/bce.yml +++ /dev/null @@ -1 +0,0 @@ -class_path: chebai.loss.bce_weighted.BCEWeighted diff --git a/configs/loss/electra_pre_loss.yml b/configs/loss/electra_pre_loss.yml deleted file mode 100644 index 06520b2..0000000 --- a/configs/loss/electra_pre_loss.yml +++ /dev/null @@ -1 +0,0 @@ -class_path: chebai.loss.pretraining.ElectraPreLoss diff --git a/configs/loss/semantic_loss.yml b/configs/loss/semantic_loss.yml deleted file mode 100644 index 5434084..0000000 --- a/configs/loss/semantic_loss.yml +++ /dev/null @@ -1,10 +0,0 @@ -class_path: chebai.loss.semantic.DisjointLoss -init_args: - path_to_disjointness: data/disjoint.csv - base_loss: - class_path: chebai.loss.bce_weighted.BCEWeighted - init_args: - beta: 0.99 - multiply_by_softmax: true - impl_loss_weight: 100 - disjoint_loss_weight: 1000000 diff --git a/configs/metrics/balanced-accuracy.yml b/configs/metrics/balanced-accuracy.yml deleted file mode 100644 index eb079ed..0000000 --- a/configs/metrics/balanced-accuracy.yml +++ /dev/null @@ -1,5 +0,0 @@ -class_path: torchmetrics.MetricCollection -init_args: - metrics: - balanced-accuracy: - class_path: chebai.callbacks.epoch_metrics.BalancedAccuracy diff --git a/configs/metrics/micro-macro-f1.yml b/configs/metrics/micro-macro-f1.yml deleted file mode 100644 index 9cae109..0000000 --- a/configs/metrics/micro-macro-f1.yml +++ /dev/null @@ -1,9 +0,0 @@ -class_path: torchmetrics.MetricCollection -init_args: - metrics: - micro-f1: - class_path: torchmetrics.classification.MultilabelF1Score - init_args: - average: micro - macro-f1: - class_path: chebai.callbacks.epoch_metrics.MacroF1 diff --git a/configs/metrics/single-class-f1.yml b/configs/metrics/single-class-f1.yml deleted file mode 100644 index fbcd63d..0000000 --- a/configs/metrics/single-class-f1.yml +++ /dev/null @@ -1,5 +0,0 @@ -class_path: torchmetrics.MetricCollection -init_args: - metrics: - f1: - class_path: torchmetrics.classification.BinaryF1Score diff --git a/configs/model/electra-for-pretraining.yml b/configs/model/electra-for-pretraining.yml deleted file mode 100644 index 80acd9a..0000000 --- a/configs/model/electra-for-pretraining.yml +++ /dev/null @@ -1,20 +0,0 @@ -class_path: chebai.models.ElectraPre -init_args: - criterion: - class_path: chebai.loss.pretraining.ElectraPreLoss - out_dim: null - optimizer_kwargs: - lr: 1e-4 - config: - generator: - vocab_size: 1400 - max_position_embeddings: 1800 - num_attention_heads: 8 - num_hidden_layers: 6 - type_vocab_size: 1 - discriminator: - vocab_size: 1400 - max_position_embeddings: 1800 - num_attention_heads: 8 - num_hidden_layers: 6 - type_vocab_size: 1 diff --git a/configs/model/electra.yml b/configs/model/electra.yml deleted file 
mode 100644 index 94d1dc6..0000000 --- a/configs/model/electra.yml +++ /dev/null @@ -1,11 +0,0 @@ -class_path: chebai.models.Electra -init_args: - optimizer_kwargs: - lr: 1e-3 - config: - vocab_size: 31 # 21 amino acids (when n_gram=1) + 10 special tokens of LLM - max_position_embeddings: 1000 # max default sequence length for protein - num_attention_heads: 8 - num_hidden_layers: 6 - type_vocab_size: 1 - hidden_size: 256 diff --git a/configs/model/electra_pretraining.yml b/configs/model/electra_pretraining.yml deleted file mode 100644 index f480a79..0000000 --- a/configs/model/electra_pretraining.yml +++ /dev/null @@ -1,18 +0,0 @@ -class_path: chebai.models.ElectraPre -init_args: - out_dim: null - optimizer_kwargs: - lr: 1e-4 - config: - generator: - vocab_size: 1400 - max_position_embeddings: 1800 - num_attention_heads: 8 - num_hidden_layers: 6 - type_vocab_size: 1 - discriminator: - vocab_size: 1400 - max_position_embeddings: 1800 - num_attention_heads: 8 - num_hidden_layers: 6 - type_vocab_size: 1 diff --git a/configs/model/ffn.yml b/configs/model/ffn.yml deleted file mode 100644 index ba94a43..0000000 --- a/configs/model/ffn.yml +++ /dev/null @@ -1,5 +0,0 @@ -class_path: chebai.models.ffn.FFN -init_args: - optimizer_kwargs: - lr: 1e-3 - input_size: 2560 diff --git a/configs/training/csv_logger.yml b/configs/training/csv_logger.yml deleted file mode 100644 index 86a94ba..0000000 --- a/configs/training/csv_logger.yml +++ /dev/null @@ -1,3 +0,0 @@ -class_path: lightning.pytorch.loggers.CSVLogger -init_args: - save_dir: logs diff --git a/configs/training/default_callbacks.yml b/configs/training/default_callbacks.yml deleted file mode 100644 index ade7d14..0000000 --- a/configs/training/default_callbacks.yml +++ /dev/null @@ -1,12 +0,0 @@ -- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint - init_args: - monitor: val_micro-f1 - mode: 'max' - filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' - every_n_epochs: 1 - save_top_k: 3 -- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint - init_args: - filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' - every_n_epochs: 25 - save_top_k: -1 diff --git a/configs/training/default_trainer.yml b/configs/training/default_trainer.yml deleted file mode 100644 index 91aa424..0000000 --- a/configs/training/default_trainer.yml +++ /dev/null @@ -1,5 +0,0 @@ -min_epochs: 100 -max_epochs: 100 -default_root_dir: &default_root_dir logs -logger: csv_logger.yml -callbacks: default_callbacks.yml diff --git a/configs/training/early_stop_callbacks.yml b/configs/training/early_stop_callbacks.yml deleted file mode 100644 index 9113090..0000000 --- a/configs/training/early_stop_callbacks.yml +++ /dev/null @@ -1,19 +0,0 @@ -- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint - init_args: - monitor: val_micro-f1 - mode: 'max' - filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' - every_n_epochs: 1 - save_top_k: 3 -- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint - init_args: - filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_macro-f1:.4f}_{val_micro-f1:.4f}' - every_n_epochs: 25 - save_top_k: -1 -- class_path: lightning.pytorch.callbacks.early_stopping.EarlyStopping - init_args: - monitor: "val_loss_epoch" - min_delta: 0.0 - patience: 3 - verbose: False - mode: "min" diff --git a/configs/training/pretraining_callbacks.yml b/configs/training/pretraining_callbacks.yml deleted file mode 100644 index 
0862433..0000000 --- a/configs/training/pretraining_callbacks.yml +++ /dev/null @@ -1,12 +0,0 @@ -- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint - init_args: - monitor: val_loss - mode: 'min' - filename: 'best_{epoch}_{val_loss:.4f}' - every_n_epochs: 1 - save_top_k: 3 -- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint - init_args: - filename: 'per_{epoch}_{val_loss:.4f}' - every_n_epochs: 25 - save_top_k: -1 diff --git a/configs/training/pretraining_trainer.yml b/configs/training/pretraining_trainer.yml deleted file mode 100644 index 6c56870..0000000 --- a/configs/training/pretraining_trainer.yml +++ /dev/null @@ -1,7 +0,0 @@ -min_epochs: 100 -max_epochs: 100 - -default_root_dir: &default_root_dir logs -logger: csv_logger.yml - -callbacks: pretraining_callbacks.yml diff --git a/configs/training/single_class_callbacks.yml b/configs/training/single_class_callbacks.yml deleted file mode 100644 index 73f4a72..0000000 --- a/configs/training/single_class_callbacks.yml +++ /dev/null @@ -1,13 +0,0 @@ -- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint - init_args: - monitor: val_f1 - mode: 'max' - filename: 'best_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}' - every_n_epochs: 1 - save_top_k: 3 -- class_path: chebai.callbacks.model_checkpoint.CustomModelCheckpoint - init_args: - filename: 'per_{epoch:02d}_{val_loss:.4f}_{val_f1:.4f}' - every_n_epochs: 25 - save_top_k: -1 -# difference to default_callbacks.yml: no macro-f1 diff --git a/configs/training/wandb_logger.yml b/configs/training/wandb_logger.yml deleted file mode 100644 index b0dd887..0000000 --- a/configs/training/wandb_logger.yml +++ /dev/null @@ -1,6 +0,0 @@ -class_path: chebai.loggers.custom.CustomLogger # Extension of Wandb logger -init_args: - save_dir: logs - project: 'chebai' - entity: 'chebai' - log_model: 'all' diff --git a/docs/source/experiment.rst b/docs/source/experiment.rst deleted file mode 100644 index 59aced7..0000000 --- a/docs/source/experiment.rst +++ /dev/null @@ -1 +0,0 @@ -.. autoclass:: chebai.experiments.Experiment diff --git a/docs/source/model.rst b/docs/source/model.rst deleted file mode 100644 index 59aced7..0000000 --- a/docs/source/model.rst +++ /dev/null @@ -1 +0,0 @@ -.. autoclass:: chebai.experiments.Experiment diff --git a/tests/unit/collators/__init__.py b/tests/unit/collators/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/tests/unit/collators/testDefaultCollator.py b/tests/unit/collators/testDefaultCollator.py deleted file mode 100644 index 73f09c7..0000000 --- a/tests/unit/collators/testDefaultCollator.py +++ /dev/null @@ -1,65 +0,0 @@ -import unittest -from typing import Dict, List - -from chebai.preprocessing.collate import DefaultCollator -from chebai.preprocessing.structures import XYData - - -class TestDefaultCollator(unittest.TestCase): - """ - Unit tests for the DefaultCollator class. - """ - - @classmethod - def setUpClass(cls) -> None: - """ - Set up the test environment by initializing a DefaultCollator instance. - """ - cls.collator = DefaultCollator() - - def test_call_with_valid_data(self) -> None: - """ - Test the __call__ method with valid data to ensure features and labels are correctly extracted. - """ - data: List[Dict] = [ - {"features": [1.0, 2.0], "labels": [True, False, True]}, - {"features": [3.0, 4.0], "labels": [False, False, True]}, - ] - - result: XYData = self.collator(data) - self.assertIsInstance( - result, XYData, "The result should be an instance of XYData." 
- ) - - expected_x = ([1.0, 2.0], [3.0, 4.0]) - expected_y = ([True, False, True], [False, False, True]) - - self.assertEqual( - result.x, - expected_x, - "The feature data 'x' does not match the expected output.", - ) - self.assertEqual( - result.y, - expected_y, - "The label data 'y' does not match the expected output.", - ) - - def test_call_with_empty_data(self) -> None: - """ - Test the __call__ method with an empty list to ensure it handles the edge case correctly. - """ - data: List[Dict] = [] - - with self.assertRaises(ValueError) as context: - self.collator(data) - - self.assertEqual( - str(context.exception), - "not enough values to unpack (expected 2, got 0)", - "The exception message for empty data is not as expected.", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/collators/testRaggedCollator.py b/tests/unit/collators/testRaggedCollator.py deleted file mode 100644 index d9ab2b1..0000000 --- a/tests/unit/collators/testRaggedCollator.py +++ /dev/null @@ -1,204 +0,0 @@ -import unittest -from typing import Dict, List, Tuple - -import torch - -from chebai.preprocessing.collate import RaggedCollator -from chebai.preprocessing.structures import XYData - - -class TestRaggedCollator(unittest.TestCase): - """ - Unit tests for the RaggedCollator class. - """ - - @classmethod - def setUpClass(cls) -> None: - """ - Set up the test environment by initializing a RaggedCollator instance. - """ - cls.collator = RaggedCollator() - - def test_call_with_valid_data(self) -> None: - """ - Test the __call__ method with valid ragged data to ensure features, labels, and masks are correctly handled. - """ - data: List[Dict] = [ - {"features": [1, 2], "labels": [True, False], "ident": "sample1"}, - {"features": [3, 4, 5], "labels": [False, True, True], "ident": "sample2"}, - {"features": [6], "labels": [True], "ident": "sample3"}, - ] - - result: XYData = self.collator(data) - - expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) - expected_y = torch.tensor( - [[True, False, False], [False, True, True], [True, False, False]] - ) - expected_mask_for_x = torch.tensor( - [[True, True, False], [True, True, True], [True, False, False]] - ) - expected_lens_for_x = torch.tensor([2, 3, 1]) - - self.assertTrue( - torch.equal(result.x, expected_x), - "The feature tensor 'x' does not match the expected output.", - ) - self.assertTrue( - torch.equal(result.y, expected_y), - "The label tensor 'y' does not match the expected output.", - ) - self.assertTrue( - torch.equal( - result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x - ), - "The mask tensor does not match the expected output.", - ) - self.assertTrue( - torch.equal( - result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x - ), - "The lens tensor does not match the expected output.", - ) - self.assertEqual( - result.additional_fields["idents"], - ("sample1", "sample2", "sample3"), - "The identifiers do not match the expected output.", - ) - - def test_call_with_missing_entire_labels(self) -> None: - """ - Test the __call__ method with data where some samples are missing labels. 
- """ - data: List[Dict] = [ - {"features": [1, 2], "labels": [True, False], "ident": "sample1"}, - {"features": [3, 4, 5], "labels": None, "ident": "sample2"}, - {"features": [6], "labels": [True], "ident": "sample3"}, - ] - - result: XYData = self.collator(data) - - # https://github.com/ChEB-AI/python-chebai/pull/48#issuecomment-2324393829 - expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) - expected_y = torch.tensor( - [[True, False], [True, False]] - ) # True -> 1, False -> 0 - expected_mask_for_x = torch.tensor( - [[True, True, False], [True, True, True], [True, False, False]] - ) - expected_lens_for_x = torch.tensor([2, 3, 1]) - - self.assertTrue( - torch.equal(result.x, expected_x), - "The feature tensor 'x' does not match the expected output when labels are missing.", - ) - self.assertTrue( - torch.equal(result.y, expected_y), - "The label tensor 'y' does not match the expected output when labels are missing.", - ) - self.assertTrue( - torch.equal( - result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x - ), - "The mask tensor does not match the expected output when labels are missing.", - ) - self.assertTrue( - torch.equal( - result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x - ), - "The lens tensor does not match the expected output when labels are missing.", - ) - self.assertEqual( - result.additional_fields["loss_kwargs"]["non_null_labels"], - [0, 2], - "The non-null labels list does not match the expected output.", - ) - self.assertEqual( - len(result.additional_fields["loss_kwargs"]["non_null_labels"]), - result.y.shape[1], - "The length of non null labels list must match with target label variable size", - ) - self.assertEqual( - result.additional_fields["idents"], - ("sample1", "sample2", "sample3"), - "The identifiers do not match the expected output when labels are missing.", - ) - - def test_call_with_none_in_labels(self) -> None: - """ - Test the __call__ method with data where one of the elements in the labels is None. 
- """ - data: List[Dict] = [ - {"features": [1, 2], "labels": [None, True], "ident": "sample1"}, - {"features": [3, 4, 5], "labels": [True, False], "ident": "sample2"}, - {"features": [6], "labels": [True], "ident": "sample3"}, - ] - - result: XYData = self.collator(data) - - expected_x = torch.tensor([[1, 2, 0], [3, 4, 5], [6, 0, 0]]) - expected_y = torch.tensor( - [[False, True], [True, False], [True, False]] - ) # None -> False - expected_mask_for_x = torch.tensor( - [[True, True, False], [True, True, True], [True, False, False]] - ) - expected_lens_for_x = torch.tensor([2, 3, 1]) - - self.assertTrue( - torch.equal(result.x, expected_x), - "The feature tensor 'x' does not match the expected output when labels contain None.", - ) - self.assertTrue( - torch.equal(result.y, expected_y), - "The label tensor 'y' does not match the expected output when labels contain None.", - ) - self.assertTrue( - torch.equal( - result.additional_fields["model_kwargs"]["mask"], expected_mask_for_x - ), - "The mask tensor does not match the expected output when labels contain None.", - ) - self.assertTrue( - torch.equal( - result.additional_fields["model_kwargs"]["lens"], expected_lens_for_x - ), - "The lens tensor does not match the expected output when labels contain None.", - ) - self.assertEqual( - result.additional_fields["idents"], - ("sample1", "sample2", "sample3"), - "The identifiers do not match the expected output when labels contain None.", - ) - - def test_call_with_empty_data(self) -> None: - """ - Test the __call__ method with an empty list to ensure it raises an error. - """ - data: List[Dict] = [] - - with self.assertRaises( - Exception, msg="Expected an Error when no data is provided" - ): - self.collator(data) - - def test_process_label_rows(self) -> None: - """ - Test the process_label_rows method to ensure it pads label sequences correctly. - """ - labels: Tuple = ([True, False], [False, True, True], [True]) - - result: torch.Tensor = self.collator.process_label_rows(labels) - - expected_output = torch.tensor( - [[True, False, False], [False, True, True], [True, False, False]] - ) - - self.assertTrue( - torch.equal(result, expected_output), - "The processed label rows tensor does not match the expected output.", - ) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/dataset_classes/testDynamicDataset.py b/tests/unit/dataset_classes/testDynamicDataset.py deleted file mode 100644 index c884627..0000000 --- a/tests/unit/dataset_classes/testDynamicDataset.py +++ /dev/null @@ -1,372 +0,0 @@ -import unittest -from typing import Tuple -from unittest.mock import MagicMock, PropertyMock, patch - -import pandas as pd - -from chebai.preprocessing.datasets.base import _DynamicDataset - - -class TestDynamicDataset(unittest.TestCase): - """ - Test case for _DynamicDataset functionality, ensuring correct data splits and integrity - of train, validation, and test datasets. - """ - - @classmethod - @patch.multiple(_DynamicDataset, __abstractmethods__=frozenset()) - @patch.object(_DynamicDataset, "base_dir", new_callable=PropertyMock) - @patch.object(_DynamicDataset, "_name", new_callable=PropertyMock) - @patch("os.makedirs", return_value=None) - def setUpClass( - cls, - mock_makedirs, - mock_base_dir_property: PropertyMock, - mock_name_property: PropertyMock, - ) -> None: - """ - Set up a base instance of _DynamicDataset for testing with mocked properties. 
- """ - - # Mocking properties - mock_base_dir_property.return_value = "MockedBaseDirPropertyDynamicDataset" - mock_name_property.return_value = "MockedNamePropertyDynamicDataset" - - # Mock Data Reader - ReaderMock = MagicMock() - ReaderMock.name.return_value = "MockedReader" - _DynamicDataset.READER = ReaderMock - - # Creating an instance of the dataset - cls.dataset: _DynamicDataset = _DynamicDataset() - - # Dataset with a balanced distribution of labels - X = [ - [1, 2], - [3, 4], - [5, 6], - [7, 8], - [9, 10], - [11, 12], - [13, 14], - [15, 16], - [17, 18], - [19, 20], - [21, 22], - [23, 24], - [25, 26], - [27, 28], - [29, 30], - [31, 32], - ] - y = [ - [False, False], - [False, True], - [True, False], - [True, True], - [False, False], - [False, True], - [True, False], - [True, True], - [False, False], - [False, True], - [True, False], - [True, True], - [False, False], - [False, True], - [True, False], - [True, True], - ] - cls.data_df = pd.DataFrame( - {"ident": [f"id{i + 1}" for i in range(len(X))], "features": X, "labels": y} - ) - - def test_get_test_split_valid(self) -> None: - """ - Test splitting the dataset into train and test sets and verify balance and non-overlap. - """ - self.dataset.train_split = 0.5 - # Test size will be 0.25 * 16 = 4 - train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) - - # Assert the correct number of rows in train and test sets - self.assertEqual(len(train_df), 12, "Train set should contain 12 samples.") - self.assertEqual(len(test_df), 4, "Test set should contain 4 samples.") - - # Check positive and negative label counts in train and test sets - train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( - train_df - ) - test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( - test_df - ) - - # Ensure that the train and test sets have balanced positives and negatives - self.assertEqual( - train_pos_count, train_neg_count, "Train set labels should be balanced." - ) - self.assertEqual( - test_pos_count, test_neg_count, "Test set labels should be balanced." - ) - - # Assert there is no overlap between train and test sets - train_idents = set(train_df["ident"]) - test_idents = set(test_df["ident"]) - self.assertEqual( - len(train_idents.intersection(test_idents)), - 0, - "Train and test sets should not overlap.", - ) - - def test_get_test_split_missing_labels(self) -> None: - """ - Test the behavior when the 'labels' column is missing in the dataset. - """ - df_missing_labels = pd.DataFrame({"ident": ["id1", "id2"]}) - with self.assertRaises( - KeyError, msg="Expected KeyError when 'labels' column is missing." - ): - self.dataset.get_test_split(df_missing_labels) - - def test_get_test_split_seed_consistency(self) -> None: - """ - Test that splitting the dataset with the same seed produces consistent results. - """ - train_df1, test_df1 = self.dataset.get_test_split(self.data_df, seed=42) - train_df2, test_df2 = self.dataset.get_test_split(self.data_df, seed=42) - - pd.testing.assert_frame_equal( - train_df1, - train_df2, - obj="Train sets should be identical for the same seed.", - ) - pd.testing.assert_frame_equal( - test_df1, test_df2, obj="Test sets should be identical for the same seed." - ) - - def test_get_train_val_splits_given_test(self) -> None: - """ - Test splitting the dataset into train and validation sets and verify balance and non-overlap. 
- """ - self.dataset.use_inner_cross_validation = False - self.dataset.train_split = 0.5 - df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) - train_df, val_df = self.dataset.get_train_val_splits_given_test( - df_train_main, test_df, seed=42 - ) - - # Ensure there is no overlap between train and test sets - train_idents = set(train_df["ident"]) - test_idents = set(test_df["ident"]) - self.assertEqual( - len(train_idents.intersection(test_idents)), - 0, - "Train and test sets should not overlap.", - ) - - # Ensure there is no overlap between validation and test sets - val_idents = set(val_df["ident"]) - self.assertEqual( - len(val_idents.intersection(test_idents)), - 0, - "Validation and test sets should not overlap.", - ) - - # Ensure there is no overlap between train and validation sets - self.assertEqual( - len(train_idents.intersection(val_idents)), - 0, - "Train and validation sets should not overlap.", - ) - - # Check positive and negative label counts in train and validation sets - train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( - train_df - ) - val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) - - # Ensure that the train and validation sets have balanced positives and negatives - self.assertEqual( - train_pos_count, train_neg_count, "Train set labels should be balanced." - ) - self.assertEqual( - val_pos_count, val_neg_count, "Validation set labels should be balanced." - ) - - def test_get_train_val_splits_given_test_consistency(self) -> None: - """ - Test that splitting the dataset into train and validation sets with the same seed produces consistent results. - """ - test_df = self.data_df.iloc[12:] # Assume rows 12 onward are for testing - train_df1, val_df1 = self.dataset.get_train_val_splits_given_test( - self.data_df, test_df, seed=42 - ) - train_df2, val_df2 = self.dataset.get_train_val_splits_given_test( - self.data_df, test_df, seed=42 - ) - - pd.testing.assert_frame_equal( - train_df1, - train_df2, - obj="Train sets should be identical for the same seed.", - ) - pd.testing.assert_frame_equal( - val_df1, - val_df2, - obj="Validation sets should be identical for the same seed.", - ) - - def test_get_test_split_stratification(self) -> None: - """ - Test that the split into train and test sets maintains the stratification of labels. 
- """ - self.dataset.train_split = 0.5 - train_df, test_df = self.dataset.get_test_split(self.data_df, seed=0) - - number_of_labels = len(self.data_df["labels"][0]) - - # Check the label distribution in the original dataset - original_pos_count, original_neg_count = ( - self.get_positive_negative_labels_counts(self.data_df) - ) - total_count = len(self.data_df) * number_of_labels - - # Calculate the expected proportions - original_pos_proportion = original_pos_count / total_count - original_neg_proportion = original_neg_count / total_count - - # Check the label distribution in the train set - train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( - train_df - ) - train_total_count = len(train_df) * number_of_labels - - # Calculate the train set proportions - train_pos_proportion = train_pos_count / train_total_count - train_neg_proportion = train_neg_count / train_total_count - - # Assert that the proportions are similar to the original dataset - self.assertAlmostEqual( - train_pos_proportion, - original_pos_proportion, - places=1, - msg="Train set labels should maintain original positive label proportion.", - ) - self.assertAlmostEqual( - train_neg_proportion, - original_neg_proportion, - places=1, - msg="Train set labels should maintain original negative label proportion.", - ) - - # Check the label distribution in the test set - test_pos_count, test_neg_count = self.get_positive_negative_labels_counts( - test_df - ) - test_total_count = len(test_df) * number_of_labels - - # Calculate the test set proportions - test_pos_proportion = test_pos_count / test_total_count - test_neg_proportion = test_neg_count / test_total_count - - # Assert that the proportions are similar to the original dataset - self.assertAlmostEqual( - test_pos_proportion, - original_pos_proportion, - places=1, - msg="Test set labels should maintain original positive label proportion.", - ) - self.assertAlmostEqual( - test_neg_proportion, - original_neg_proportion, - places=1, - msg="Test set labels should maintain original negative label proportion.", - ) - - def test_get_train_val_splits_given_test_stratification(self) -> None: - """ - Test that the split into train and validation sets maintains the stratification of labels. 
- """ - self.dataset.use_inner_cross_validation = False - self.dataset.train_split = 0.5 - df_train_main, test_df = self.dataset.get_test_split(self.data_df, seed=0) - train_df, val_df = self.dataset.get_train_val_splits_given_test( - df_train_main, test_df, seed=42 - ) - - number_of_labels = len(self.data_df["labels"][0]) - - # Check the label distribution in the original dataset - original_pos_count, original_neg_count = ( - self.get_positive_negative_labels_counts(self.data_df) - ) - total_count = len(self.data_df) * number_of_labels - - # Calculate the expected proportions - original_pos_proportion = original_pos_count / total_count - original_neg_proportion = original_neg_count / total_count - - # Check the label distribution in the train set - train_pos_count, train_neg_count = self.get_positive_negative_labels_counts( - train_df - ) - train_total_count = len(train_df) * number_of_labels - - # Calculate the train set proportions - train_pos_proportion = train_pos_count / train_total_count - train_neg_proportion = train_neg_count / train_total_count - - # Assert that the proportions are similar to the original dataset - self.assertAlmostEqual( - train_pos_proportion, - original_pos_proportion, - places=1, - msg="Train set labels should maintain original positive label proportion.", - ) - self.assertAlmostEqual( - train_neg_proportion, - original_neg_proportion, - places=1, - msg="Train set labels should maintain original negative label proportion.", - ) - - # Check the label distribution in the validation set - val_pos_count, val_neg_count = self.get_positive_negative_labels_counts(val_df) - val_total_count = len(val_df) * number_of_labels - - # Calculate the validation set proportions - val_pos_proportion = val_pos_count / val_total_count - val_neg_proportion = val_neg_count / val_total_count - - # Assert that the proportions are similar to the original dataset - self.assertAlmostEqual( - val_pos_proportion, - original_pos_proportion, - places=1, - msg="Validation set labels should maintain original positive label proportion.", - ) - self.assertAlmostEqual( - val_neg_proportion, - original_neg_proportion, - places=1, - msg="Validation set labels should maintain original negative label proportion.", - ) - - @staticmethod - def get_positive_negative_labels_counts(df: pd.DataFrame) -> Tuple[int, int]: - """ - Count the number of True and False values within the labels column. - - Args: - df (pd.DataFrame): The DataFrame containing the 'labels' column. - - Returns: - Tuple[int, int]: A tuple containing the counts of True and False values, respectively. - """ - true_count = sum(sum(label) for label in df["labels"]) - false_count = sum(len(label) - sum(label) for label in df["labels"]) - return true_count, false_count - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/dataset_classes/testXYBaseDataModule.py b/tests/unit/dataset_classes/testXYBaseDataModule.py deleted file mode 100644 index 64dfbe4..0000000 --- a/tests/unit/dataset_classes/testXYBaseDataModule.py +++ /dev/null @@ -1,92 +0,0 @@ -import unittest -from unittest.mock import MagicMock, PropertyMock, patch - -from chebai.preprocessing.datasets.base import XYBaseDataModule - - -class TestXYBaseDataModule(unittest.TestCase): - """ - Unit tests for the methods of the XYBaseDataModule class. 
- """ - - @classmethod - @patch.object(XYBaseDataModule, "_name", new_callable=PropertyMock) - @patch("os.makedirs", return_value=None) - def setUpClass(cls, mock_makedirs, mock_name_property: PropertyMock) -> None: - """ - Set up a base instance of XYBaseDataModule for testing. - """ - - # Mock the _name property of XYBaseDataModule - mock_name_property.return_value = "MockedNamePropXYBaseDataModule" - - # Assign a static variable READER with ProteinDataReader (to get rid of default Abstract DataReader) - # Mock Data Reader - ReaderMock = MagicMock() - ReaderMock.name.return_value = "MockedReader" - XYBaseDataModule.READER = ReaderMock - - # Initialize the module with a label_filter - cls.module = XYBaseDataModule( - label_filter=1, # Provide a label_filter - balance_after_filter=1.0, # Balance ratio - ) - - def test_filter_labels_valid_index(self) -> None: - """ - Test the _filter_labels method with a valid label_filter index. - """ - self.module.label_filter = 1 - row = { - "features": ["feature1", "feature2"], - "labels": [0, 3, 1, 2], # List of labels - } - filtered_row = self.module._filter_labels(row) - expected_labels = [3] # Only the label at index 1 should be kept - - self.assertEqual( - filtered_row["labels"], - expected_labels, - "The filtered labels do not match the expected labels.", - ) - - row = { - "features": ["feature1", "feature2"], - "labels": [True, False, True, True], - } - self.assertEqual( - self.module._filter_labels(row)["labels"], - [False], - "The filtered labels for the boolean case do not match the expected labels.", - ) - - def test_filter_labels_no_filter(self) -> None: - """ - Test the _filter_labels method with no label_filter index. - """ - # Update the module to have no label filter - self.module.label_filter = None - row = {"features": ["feature1", "feature2"], "labels": [False, True]} - # Handle the case where the index is out of bounds - with self.assertRaises( - TypeError, msg="Expected a TypeError when no label filter is provided." - ): - self.module._filter_labels(row) - - def test_filter_labels_invalid_index(self) -> None: - """ - Test the _filter_labels method with an invalid label_filter index. - """ - # Set an invalid label filter index (e.g., greater than the number of labels) - self.module.label_filter = 10 - row = {"features": ["feature1", "feature2"], "labels": [False, True]} - # Handle the case where the index is out of bounds - with self.assertRaises( - IndexError, - msg="Expected an IndexError when the label filter index is out of bounds.", - ): - self.module._filter_labels(row) - - -if __name__ == "__main__": - unittest.main() diff --git a/tests/unit/readers/testDataReader.py b/tests/unit/readers/testDataReader.py deleted file mode 100644 index 745c0ac..0000000 --- a/tests/unit/readers/testDataReader.py +++ /dev/null @@ -1,56 +0,0 @@ -import unittest -from typing import Any, Dict, List - -from chebai.preprocessing.reader import DataReader - - -class TestDataReader(unittest.TestCase): - """ - Unit tests for the DataReader class. - """ - - @classmethod - def setUpClass(cls) -> None: - """ - Set up the test environment by initializing a DataReader instance. - """ - cls.reader = DataReader() - - def test_to_data(self) -> None: - """ - Test the to_data method to ensure it correctly processes the input row - and formats it according to the expected output. - - This method tests the conversion of raw data into a processed format, - including extracting features, labels, ident, group, and additional - keyword arguments. 
- """ - features_list: List[int] = [10, 20, 30] - labels_list: List[bool] = [True, False, True] - ident_no: int = 123 - - row: Dict[str, Any] = { - "features": features_list, - "labels": labels_list, - "ident": ident_no, - "group": "group_data", - "additional_kwargs": {"extra_key": "extra_value"}, - } - - expected: Dict[str, Any] = { - "features": features_list, - "labels": labels_list, - "ident": ident_no, - "group": "group_data", - "extra_key": "extra_value", - } - - self.assertEqual( - self.reader.to_data(row), - expected, - "The to_data method did not process the input row as expected.", - ) - - -if __name__ == "__main__": - unittest.main() From 9120538d7758d04e6dcfdb3589f28014b3c73cdc Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 16:33:23 +0200 Subject: [PATCH 11/36] Update .gitignore --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index f9cb175..9b28876 100644 --- a/.gitignore +++ b/.gitignore @@ -167,3 +167,5 @@ cython_debug/ /logs /results_buffer electra_pretrained.ckpt +.jupyter +.virtual_documents From 6d7e6bd3eb46367f1bae1bec5532a88c8f64e08e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 16:40:54 +0200 Subject: [PATCH 12/36] update readers for proteins --- chebai/preprocessing/reader.py | 135 ++------------------------------- 1 file changed, 6 insertions(+), 129 deletions(-) diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 1fa5a47..8c599c0 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -12,115 +12,8 @@ load_model_and_alphabet_local, ) -from chebai.preprocessing.collate import DefaultCollator, RaggedCollator - -EMBEDDING_OFFSET = 10 -PADDING_TOKEN_INDEX = 0 -MASK_TOKEN_INDEX = 1 -CLS_TOKEN = 2 - - -class DataReader: - """ - Base class for reading and preprocessing data. Turns the raw input data (e.g., a SMILES string) into the model - input format (e.g., a list of tokens). - - Args: - collator_kwargs: Optional dictionary of keyword arguments for the collator. - token_path: Optional path for the token file. - kwargs: Additional keyword arguments (not used). 
- """ - - COLLATOR = DefaultCollator - - def __init__( - self, - collator_kwargs: Optional[Dict[str, Any]] = None, - token_path: Optional[str] = None, - **kwargs, - ): - if collator_kwargs is None: - collator_kwargs = dict() - self.collator = self.COLLATOR(**collator_kwargs) - self.dirname = os.path.dirname(__file__) - self._token_path = token_path - - def _get_raw_data(self, row: Dict[str, Any]) -> Any: - """Get raw data from the row.""" - return row["features"] - - def _get_raw_label(self, row: Dict[str, Any]) -> Any: - """Get raw label from the row.""" - return row["labels"] - - def _get_raw_id(self, row: Dict[str, Any]) -> Any: - """Get raw ID from the row.""" - return row.get("ident", row["features"]) - - def _get_raw_group(self, row: Dict[str, Any]) -> Any: - """Get raw group from the row.""" - return row.get("group", None) - - def _get_additional_kwargs(self, row: Dict[str, Any]) -> Dict[str, Any]: - """Get additional keyword arguments from the row.""" - return row.get("additional_kwargs", dict()) - - def name(cls) -> str: - """Returns the name of the data reader.""" - raise NotImplementedError - - @property - def token_path(self) -> str: - """Get token path, create file if it does not exist yet.""" - if self._token_path is not None: - return self._token_path - token_path = os.path.join(self.dirname, "bin", self.name(), "tokens.txt") - os.makedirs(os.path.join(self.dirname, "bin", self.name()), exist_ok=True) - if not os.path.exists(token_path): - with open(token_path, "x"): - pass - return token_path - - def _read_id(self, raw_data: Any) -> Any: - """Read and return ID from raw data.""" - return raw_data - - def _read_data(self, raw_data: Any) -> Any: - """Read and return data from raw data.""" - return raw_data - - def _read_label(self, raw_label: Any) -> Any: - """Read and return label from raw label.""" - return raw_label - - def _read_group(self, raw: Any) -> Any: - """Read and return group from raw group data.""" - return raw - - def _read_components(self, row: Dict[str, Any]) -> Dict[str, Any]: - """Read and return components from the row.""" - return dict( - features=self._get_raw_data(row), - labels=self._get_raw_label(row), - ident=self._get_raw_id(row), - group=self._get_raw_group(row), - additional_kwargs=self._get_additional_kwargs(row), - ) - - def to_data(self, row: Dict[str, Any]) -> Dict[str, Any]: - """Convert raw row data to processed data.""" - d = self._read_components(row) - return dict( - features=self._read_data(d["features"]), - labels=self._read_label(d["labels"]), - ident=self._read_id(d["ident"]), - group=self._read_group(d["group"]), - **d["additional_kwargs"], - ) - - def on_finish(self) -> None: - """Hook to run at the end of preprocessing.""" - return +from chebai.preprocessing.collate import RaggedCollator +from chebai.preprocessing.reader import DataReader class ProteinDataReader(DataReader): @@ -139,31 +32,15 @@ class ProteinDataReader(DataReader): COLLATOR = RaggedCollator + # fmt: off # 21 natural amino acid notation AA_LETTER = [ - "A", - "R", - "N", - "D", - "C", - "Q", - "E", - "G", - "H", - "I", - "L", - "K", - "M", - "F", - "P", - "S", - "T", - "W", - "Y", - "V", + "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", + "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5 "X", # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py ] + # fmt: on def name(self) -> str: """ From 83e334276608d5b81acfa42dbd7613c025e63f4f Mon Sep 17 
00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 16:44:59 +0200 Subject: [PATCH 13/36] import offset constants from chebai + remove its workflow --- .github/workflows/export_constants.py | 22 ----- .github/workflows/verify_constants.yml | 116 ------------------------- chebai/preprocessing/reader.py | 4 +- 3 files changed, 2 insertions(+), 140 deletions(-) delete mode 100644 .github/workflows/export_constants.py delete mode 100644 .github/workflows/verify_constants.yml diff --git a/.github/workflows/export_constants.py b/.github/workflows/export_constants.py deleted file mode 100644 index 6421498..0000000 --- a/.github/workflows/export_constants.py +++ /dev/null @@ -1,22 +0,0 @@ -import json - -from chebai.preprocessing.reader import ( - CLS_TOKEN, - EMBEDDING_OFFSET, - MASK_TOKEN_INDEX, - PADDING_TOKEN_INDEX, -) - -# Define the constants you want to export -# Any changes in the key names here should also follow the same change in verify_constants.yml code -constants = { - "EMBEDDING_OFFSET": EMBEDDING_OFFSET, - "CLS_TOKEN": CLS_TOKEN, - "PADDING_TOKEN_INDEX": PADDING_TOKEN_INDEX, - "MASK_TOKEN_INDEX": MASK_TOKEN_INDEX, -} - -if __name__ == "__main__": - # Write constants to a JSON file - with open("constants.json", "w") as f: - json.dump(constants, f) diff --git a/.github/workflows/verify_constants.yml b/.github/workflows/verify_constants.yml deleted file mode 100644 index 3246f64..0000000 --- a/.github/workflows/verify_constants.yml +++ /dev/null @@ -1,116 +0,0 @@ -name: Verify Constants - -# Define the file paths under `paths` to trigger this check only when specific files are modified. -# This script will then execute checks only on files that have changed, rather than all files listed in `paths`. - -# **Note** : To add a new file for checks, include its path in: -# - `on` -> `push` and `pull_request` sections -# - `jobs` -> `verify-constants` -> `steps` -> Verify constants -> Add a new if else for your file, with check logic inside it. - - -on: - push: - paths: - - "chebai/preprocessing/reader.py" - pull_request: - paths: - - "chebai/preprocessing/reader.py" - -jobs: - verify-constants: - runs-on: ubuntu-latest - strategy: - fail-fast: false - matrix: - python-version: [ -# Only use 3.10 as of now -# "3.9", - "3.10", -# "3.11" - ] - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Set PYTHONPATH - run: echo "PYTHONPATH=$PWD" >> $GITHUB_ENV - - - name: Get list of changed files - id: changed_files - run: | - git fetch origin dev - - # Get the list of changed files compared to origin/dev and save them to a file - git diff --name-only origin/dev > changed_files.txt - - # Print the names of changed files on separate lines - echo "Changed files:" - while read -r line; do - echo "Changed File name : $line" - done < changed_files.txt - - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v5 - with: - python-version: ${{ matrix.python-version }} - - - name: Install dependencies - # Setting a fix version for torch due to an error with latest version (2.5.1) - # ImportError: cannot import name 'T_co' from 'torch.utils.data.dataset' - run: | - python -m pip install --upgrade pip - python -m pip install --upgrade pip setuptools wheel - python -m pip install torch==2.4.1 --index-url https://download.pytorch.org/whl/cpu - python -m pip install -e .
- - - name: Export constants - run: python .github/workflows/export_constants.py - - - name: Load constants into environment variables - id: load_constants - # "E_" is appended as suffix to every constant, to protect overwriting other sys env variables with same name - run: | - constants=$(cat constants.json) - echo "$constants" | jq -r 'to_entries|map("E_\(.key)=\(.value|tostring)")|.[]' >> $GITHUB_ENV - - - name: Print all environment variables - run: printenv - - - name: Verify constants - run: | - file_name="chebai/preprocessing/reader.py" - if grep -q "$file_name" changed_files.txt; then - echo "----------------------- Checking file : $file_name ----------------------- " - - # Define expected values for constants - exp_embedding_offset="10" - exp_cls_token="2" - exp_padding_token_index="0" - exp_mask_token_index="1" - - # Debugging output to check environment variables - echo "Current Environment Variables:" - echo "E_EMBEDDING_OFFSET = $E_EMBEDDING_OFFSET" - echo "Expected: $exp_embedding_offset" - - # Verify constants match expected values - if [ "$E_EMBEDDING_OFFSET" != "$exp_embedding_offset" ]; then - echo "EMBEDDING_OFFSET ($E_EMBEDDING_OFFSET) does not match expected value ($exp_embedding_offset)!" - exit 1 - fi - if [ "$E_CLS_TOKEN" != "$exp_cls_token" ]; then - echo "CLS_TOKEN ($E_CLS_TOKEN) does not match expected value ($exp_cls_token)!" - exit 1 - fi - if [ "$E_PADDING_TOKEN_INDEX" != "$exp_padding_token_index" ]; then - echo "PADDING_TOKEN_INDEX ($E_PADDING_TOKEN_INDEX) does not match expected value ($exp_padding_token_index)!" - exit 1 - fi - if [ "$E_MASK_TOKEN_INDEX" != "$exp_mask_token_index" ]; then - echo "MASK_TOKEN_INDEX ($E_MASK_TOKEN_INDEX) does not match expected value ($exp_mask_token_index)!" - exit 1 - fi - else - echo "$file_name not found in changed_files.txt; skipping check." 
- fi diff --git a/chebai/preprocessing/reader.py b/chebai/preprocessing/reader.py index 8c599c0..e72e9b4 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai/preprocessing/reader.py @@ -1,6 +1,6 @@ import os from pathlib import Path -from typing import Any, Dict, List, Optional, Tuple +from typing import List, Optional, Tuple from urllib.error import HTTPError import torch @@ -13,7 +13,7 @@ ) from chebai.preprocessing.collate import RaggedCollator -from chebai.preprocessing.reader import DataReader +from chebai.preprocessing.reader import EMBEDDING_OFFSET, DataReader class ProteinDataReader(DataReader): From 22815fbe370f9cd423e8c48ed9308ac1dc16a68d Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 17:02:26 +0200 Subject: [PATCH 14/36] rename base folder to chebai_proteins --- chebai/preprocessing/datasets/__init__.py | 1 - .../preprocessing/__init__.py | 0 .../preprocessing/bin/protein_token/tokens.txt | 0 .../bin/protein_token_3_gram/tokens.txt | 0 .../preprocessing/datasets}/__init__.py | 0 .../preprocessing/datasets/deepGO}/__init__.py | 0 .../datasets/deepGO/go_uniprot.py | 18 ++++-------------- .../datasets/deepGO/protein_pretraining.py | 3 +-- .../preprocessing/datasets/scope}/__init__.py | 0 .../preprocessing/datasets/scope/scope.py | 1 - .../preprocessing/migration}/__init__.py | 0 .../migration/deep_go/__init__.py | 0 .../deep_go/migrate_deep_go_1_data.py | 3 +-- .../deep_go/migrate_deep_go_2_data.py | 3 +-- .../preprocessing/reader.py | 5 ++--- 15 files changed, 9 insertions(+), 25 deletions(-) delete mode 100644 chebai/preprocessing/datasets/__init__.py rename {chebai => chebai_proteins}/preprocessing/__init__.py (100%) rename {chebai => chebai_proteins}/preprocessing/bin/protein_token/tokens.txt (100%) rename {chebai => chebai_proteins}/preprocessing/bin/protein_token_3_gram/tokens.txt (100%) rename {chebai/preprocessing/datasets/deepGO => chebai_proteins/preprocessing/datasets}/__init__.py (100%) rename {chebai/preprocessing/datasets/scope => chebai_proteins/preprocessing/datasets/deepGO}/__init__.py (100%) rename {chebai => chebai_proteins}/preprocessing/datasets/deepGO/go_uniprot.py (99%) rename {chebai => chebai_proteins}/preprocessing/datasets/deepGO/protein_pretraining.py (99%) rename {chebai/preprocessing/migration => chebai_proteins/preprocessing/datasets/scope}/__init__.py (100%) rename {chebai => chebai_proteins}/preprocessing/datasets/scope/scope.py (99%) rename {chebai/preprocessing/migration/deep_go => chebai_proteins/preprocessing/migration}/__init__.py (100%) create mode 100644 chebai_proteins/preprocessing/migration/deep_go/__init__.py rename {chebai => chebai_proteins}/preprocessing/migration/deep_go/migrate_deep_go_1_data.py (99%) rename {chebai => chebai_proteins}/preprocessing/migration/deep_go/migrate_deep_go_2_data.py (99%) rename {chebai => chebai_proteins}/preprocessing/reader.py (99%) diff --git a/chebai/preprocessing/datasets/__init__.py b/chebai/preprocessing/datasets/__init__.py deleted file mode 100644 index d6cc8de..0000000 --- a/chebai/preprocessing/datasets/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .base import XYBaseDataModule, _DynamicDataset diff --git a/chebai/preprocessing/__init__.py b/chebai_proteins/preprocessing/__init__.py similarity index 100% rename from chebai/preprocessing/__init__.py rename to chebai_proteins/preprocessing/__init__.py diff --git a/chebai/preprocessing/bin/protein_token/tokens.txt b/chebai_proteins/preprocessing/bin/protein_token/tokens.txt similarity index 100% rename from 
chebai/preprocessing/bin/protein_token/tokens.txt rename to chebai_proteins/preprocessing/bin/protein_token/tokens.txt diff --git a/chebai/preprocessing/bin/protein_token_3_gram/tokens.txt b/chebai_proteins/preprocessing/bin/protein_token_3_gram/tokens.txt similarity index 100% rename from chebai/preprocessing/bin/protein_token_3_gram/tokens.txt rename to chebai_proteins/preprocessing/bin/protein_token_3_gram/tokens.txt diff --git a/chebai/preprocessing/datasets/deepGO/__init__.py b/chebai_proteins/preprocessing/datasets/__init__.py similarity index 100% rename from chebai/preprocessing/datasets/deepGO/__init__.py rename to chebai_proteins/preprocessing/datasets/__init__.py diff --git a/chebai/preprocessing/datasets/scope/__init__.py b/chebai_proteins/preprocessing/datasets/deepGO/__init__.py similarity index 100% rename from chebai/preprocessing/datasets/scope/__init__.py rename to chebai_proteins/preprocessing/datasets/deepGO/__init__.py diff --git a/chebai/preprocessing/datasets/deepGO/go_uniprot.py b/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py similarity index 99% rename from chebai/preprocessing/datasets/deepGO/go_uniprot.py rename to chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py index 9c5d5c0..eb86e86 100644 --- a/chebai/preprocessing/datasets/deepGO/go_uniprot.py +++ b/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py @@ -42,28 +42,18 @@ import torch import tqdm from Bio import SwissProt - from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset +# fmt: off # https://github.com/bio-ontology-research-group/deepgo/blob/master/utils.py#L15 EXPERIMENTAL_EVIDENCE_CODES = { - "EXP", - "IDA", - "IPI", - "IMP", - "IGI", - "IEP", - "TAS", - "IC", + "EXP", "IDA", "IPI", "IMP", "IGI", "IEP", "TAS", "IC", # New evidence codes added in latest paper year 2024 Reference number 3 # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/utils.py#L24-L26 - "HTP", - "HDA", - "HMP", - "HGI", - "HEP", + "HTP", "HDA", "HMP", "HGI", "HEP", } +# fmt: on # https://github.com/bio-ontology-research-group/deepgo/blob/d97447a05c108127fee97982fd2c57929b2cf7eb/aaindex.py#L8 # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L10 diff --git a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py b/chebai_proteins/preprocessing/datasets/deepGO/protein_pretraining.py similarity index 99% rename from chebai/preprocessing/datasets/deepGO/protein_pretraining.py rename to chebai_proteins/preprocessing/datasets/deepGO/protein_pretraining.py index 4be053a..df6c5b3 100644 --- a/chebai/preprocessing/datasets/deepGO/protein_pretraining.py +++ b/chebai_proteins/preprocessing/datasets/deepGO/protein_pretraining.py @@ -9,8 +9,6 @@ import pandas as pd import torch from Bio import SwissProt -from sklearn.model_selection import train_test_split - from chebai.preprocessing.datasets.base import _DynamicDataset from chebai.preprocessing.datasets.deepGO.go_uniprot import ( AMBIGUOUS_AMINO_ACIDS, @@ -18,6 +16,7 @@ GOUniProtOver250, ) from chebai.preprocessing.reader import ProteinDataReader +from sklearn.model_selection import train_test_split class _ProteinPretrainingData(_DynamicDataset, ABC): diff --git a/chebai/preprocessing/migration/__init__.py b/chebai_proteins/preprocessing/datasets/scope/__init__.py similarity index 100% rename from chebai/preprocessing/migration/__init__.py rename to chebai_proteins/preprocessing/datasets/scope/__init__.py diff --git 
a/chebai/preprocessing/datasets/scope/scope.py b/chebai_proteins/preprocessing/datasets/scope/scope.py similarity index 99% rename from chebai/preprocessing/datasets/scope/scope.py rename to chebai_proteins/preprocessing/datasets/scope/scope.py index e9127b2..abdabe8 100644 --- a/chebai/preprocessing/datasets/scope/scope.py +++ b/chebai_proteins/preprocessing/datasets/scope/scope.py @@ -23,7 +23,6 @@ import requests import torch from Bio import SeqIO - from chebai.preprocessing.datasets.base import _DynamicDataset from chebai.preprocessing.reader import ProteinDataReader diff --git a/chebai/preprocessing/migration/deep_go/__init__.py b/chebai_proteins/preprocessing/migration/__init__.py similarity index 100% rename from chebai/preprocessing/migration/deep_go/__init__.py rename to chebai_proteins/preprocessing/migration/__init__.py diff --git a/chebai_proteins/preprocessing/migration/deep_go/__init__.py b/chebai_proteins/preprocessing/migration/deep_go/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_1_data.py similarity index 99% rename from chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py rename to chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index 7d59c69..bce6614 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -3,11 +3,10 @@ from typing import List, Literal, Optional, Tuple import pandas as pd +from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit from jsonargparse import CLI -from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData - class DeepGo1DataMigration: """ diff --git a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_2_data.py similarity index 99% rename from chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py rename to chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index d23247c..27dc063 100644 --- a/chebai/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -4,10 +4,9 @@ from typing import List, Literal, Optional import pandas as pd -from jsonargparse import CLI - from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData from chebai.preprocessing.reader import ProteinDataReader +from jsonargparse import CLI class DeepGo2DataMigration: diff --git a/chebai/preprocessing/reader.py b/chebai_proteins/preprocessing/reader.py similarity index 99% rename from chebai/preprocessing/reader.py rename to chebai_proteins/preprocessing/reader.py index e72e9b4..5117f26 100644 --- a/chebai/preprocessing/reader.py +++ b/chebai_proteins/preprocessing/reader.py @@ -4,6 +4,8 @@ from urllib.error import HTTPError import torch +from chebai.preprocessing.collate import RaggedCollator +from chebai.preprocessing.reader import EMBEDDING_OFFSET, DataReader from esm import Alphabet from esm.model.esm2 import ESM2 from esm.pretrained import ( @@ -12,9 +14,6 @@ load_model_and_alphabet_local, ) -from chebai.preprocessing.collate import RaggedCollator -from chebai.preprocessing.reader import EMBEDDING_OFFSET, DataReader - class ProteinDataReader(DataReader): 
""" From 68d4040aeb65393e69cf745ebc38d38e7f3fd2af Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 17:11:57 +0200 Subject: [PATCH 15/36] update notebook for chebai_proteins root --- tutorials/data_exploration_go.ipynb | 10 ++++++---- tutorials/data_exploration_scope.ipynb | 8 ++++---- 2 files changed, 10 insertions(+), 8 deletions(-) diff --git a/tutorials/data_exploration_go.ipynb b/tutorials/data_exploration_go.ipynb index 6f67c82..1822c2f 100644 --- a/tutorials/data_exploration_go.ipynb +++ b/tutorials/data_exploration_go.ipynb @@ -37,7 +37,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Already in the project root directory: G:\\github-aditya0by0\\python-chebai\n" + "Already in the project root directory: G:\\github-aditya0by0\\python-chebai-proteins\n" ] } ], @@ -46,7 +46,7 @@ "import os\n", "\n", "# Root directory name of the project\n", - "expected_root_dir = \"python-chebai\"\n", + "expected_root_dir = \"python-chebai-proteins\"\n", "\n", "# Check if the current directory ends with the expected root directory name\n", "if not os.getcwd().endswith(expected_root_dir):\n", @@ -70,7 +70,9 @@ } }, "outputs": [], - "source": "from chebai.preprocessing.datasets.go_uniprot import GOUniProtOver250" + "source": [ + "from chebai_proteins.preprocessing.datasets.deepGO.go_uniprot import GOUniProtOver250" + ] }, { "cell_type": "code", @@ -1148,7 +1150,7 @@ "metadata": {}, "outputs": [], "source": [ - "from chebai.preprocessing.reader import ProteinDataReader" + "from chebai_proteins.preprocessing.reader import ProteinDataReader" ] }, { diff --git a/tutorials/data_exploration_scope.ipynb b/tutorials/data_exploration_scope.ipynb index c14046a..c7d17b6 100644 --- a/tutorials/data_exploration_scope.ipynb +++ b/tutorials/data_exploration_scope.ipynb @@ -185,7 +185,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "Changed to project root directory: G:\\github-aditya0by0\\python-chebai\n" + "Changed to project root directory: G:\\github-aditya0by0\\python-chebai-proteins\n" ] } ], @@ -194,7 +194,7 @@ "import os\n", "\n", "# Root directory name of the project\n", - "expected_root_dir = \"python-chebai\"\n", + "expected_root_dir = \"python-chebai-proteins\"\n", "\n", "# Check if the current directory ends with the expected root directory name\n", "if not os.getcwd().endswith(expected_root_dir):\n", @@ -224,7 +224,7 @@ "metadata": {}, "outputs": [], "source": [ - "from chebai.preprocessing.datasets.scope.scope import SCOPeOver50" + "from chebai_proteins.preprocessing.datasets.scope.scope import SCOPeOver50" ] }, { @@ -1104,7 +1104,7 @@ "metadata": {}, "outputs": [], "source": [ - "from chebai.preprocessing.reader import ProteinDataReader" + "from chebai_proteins.preprocessing.reader import ProteinDataReader" ] }, { From 78d79da30080e06aeb7d4026b46e0e1087addcf1 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 17:20:19 +0200 Subject: [PATCH 16/36] add chebai repo to to setup.py --- setup.py | 31 +------------------------------ 1 file changed, 1 insertion(+), 30 deletions(-) diff --git a/setup.py b/setup.py index 1abc871..678d49c 100644 --- a/setup.py +++ b/setup.py @@ -16,36 +16,7 @@ zip_safe=False, python_requires=">=3.9, <3.13", install_requires=[ - "certifi", - "idna", - "joblib", - "networkx", - "numpy<2", - "pandas", - "python-dateutil", - "pytz", - "requests", - "scikit-learn", - "scipy", - "six", - "threadpoolctl", - "torch", - "typing-extensions", - "urllib3", - "transformers", - "fastobo", - "scikit-network", - "svgutils", - "matplotlib", - 
"lightning>=2.5", - "jsonargparse[signatures]>=4.17", - "omegaconf", - "seaborn", - "iterative-stratification", - "wandb", - "chardet", - "pyyaml", - "torchmetrics", + "chebai @ git+https://github.com/ChEB-AI/python-chebai.git", "biopython", "fair-esm", ], From 8dce9cb64e0bd6e7c9f2214ef8e3b3bfcbcf4b8b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 17:23:02 +0200 Subject: [PATCH 17/36] Update setup.py --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 678d49c..27284a0 100644 --- a/setup.py +++ b/setup.py @@ -3,7 +3,7 @@ packages = find_packages() print(packages) setup( - name="chebai", + name="chebai-proteins", version="0.0.2.dev0", packages=packages, package_data={"": ["**/*.txt", "**/*.json"]}, From 71e361edd7e896eb94e21a5a8f89c8498f4ea0ff Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 17:27:28 +0200 Subject: [PATCH 18/36] update unit test --- tests/unit/dataset_classes/testGOUniProDataExtractor.py | 6 ++++-- tests/unit/dataset_classes/testGoUniProtOverX.py | 2 +- tests/unit/dataset_classes/testProteinPretrainingData.py | 4 ++-- tests/unit/readers/testProteinDataReader.py | 2 +- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/unit/dataset_classes/testGOUniProDataExtractor.py b/tests/unit/dataset_classes/testGOUniProDataExtractor.py index 96ff9a3..8cee8f8 100644 --- a/tests/unit/dataset_classes/testGOUniProDataExtractor.py +++ b/tests/unit/dataset_classes/testGOUniProDataExtractor.py @@ -6,8 +6,10 @@ import networkx as nx import pandas as pd -from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtDataExtractor -from chebai.preprocessing.reader import ProteinDataReader +from chebai_proteins.preprocessing.datasets.deepGO.go_uniprot import ( + _GOUniProtDataExtractor, +) +from chebai_proteins.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/dataset_classes/testGoUniProtOverX.py b/tests/unit/dataset_classes/testGoUniProtOverX.py index 3f329c5..ccd2d66 100644 --- a/tests/unit/dataset_classes/testGoUniProtOverX.py +++ b/tests/unit/dataset_classes/testGoUniProtOverX.py @@ -5,7 +5,7 @@ import networkx as nx import pandas as pd -from chebai.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtOverX +from chebai_proteins.preprocessing.datasets.deepGO.go_uniprot import _GOUniProtOverX from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/dataset_classes/testProteinPretrainingData.py b/tests/unit/dataset_classes/testProteinPretrainingData.py index caac3ea..6c5044c 100644 --- a/tests/unit/dataset_classes/testProteinPretrainingData.py +++ b/tests/unit/dataset_classes/testProteinPretrainingData.py @@ -1,10 +1,10 @@ import unittest from unittest.mock import PropertyMock, mock_open, patch -from chebai.preprocessing.datasets.deepGO.protein_pretraining import ( +from chebai_proteins.preprocessing.datasets.deepGO.protein_pretraining import ( _ProteinPretrainingData, ) -from chebai.preprocessing.reader import ProteinDataReader +from chebai_proteins.preprocessing.reader import ProteinDataReader from tests.unit.mock_data.ontology_mock_data import GOUniProtMockData diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py index c5bc5e9..bb5264d 100644 --- a/tests/unit/readers/testProteinDataReader.py +++ b/tests/unit/readers/testProteinDataReader.py @@ -2,7 +2,7 @@ from typing import List from unittest.mock 
import mock_open, patch -from chebai.preprocessing.reader import EMBEDDING_OFFSET, ProteinDataReader +from chebai_proteins.preprocessing.reader import EMBEDDING_OFFSET, ProteinDataReader class TestProteinDataReader(unittest.TestCase): From 3819fd35061b6b7723707fb0520dae3384c29107 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 23 Apr 2025 17:56:38 +0200 Subject: [PATCH 19/36] fix imports from chebai_proteins --- .../preprocessing/datasets/deepGO/go_uniprot.py | 3 ++- .../preprocessing/datasets/deepGO/protein_pretraining.py | 7 ++++--- tests/unit/readers/testProteinDataReader.py | 2 +- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py b/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py index eb86e86..c25d0f4 100644 --- a/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py +++ b/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py @@ -42,9 +42,10 @@ import torch import tqdm from Bio import SwissProt -from chebai.preprocessing import reader as dr from chebai.preprocessing.datasets.base import _DynamicDataset +from chebai_proteins.preprocessing import reader as dr + # fmt: off # https://github.com/bio-ontology-research-group/deepgo/blob/master/utils.py#L15 EXPERIMENTAL_EVIDENCE_CODES = { diff --git a/chebai_proteins/preprocessing/datasets/deepGO/protein_pretraining.py b/chebai_proteins/preprocessing/datasets/deepGO/protein_pretraining.py index df6c5b3..8c39d86 100644 --- a/chebai_proteins/preprocessing/datasets/deepGO/protein_pretraining.py +++ b/chebai_proteins/preprocessing/datasets/deepGO/protein_pretraining.py @@ -10,13 +10,14 @@ import torch from Bio import SwissProt from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.datasets.deepGO.go_uniprot import ( +from sklearn.model_selection import train_test_split + +from chebai_proteins.preprocessing.datasets.deepGO.go_uniprot import ( AMBIGUOUS_AMINO_ACIDS, EXPERIMENTAL_EVIDENCE_CODES, GOUniProtOver250, ) -from chebai.preprocessing.reader import ProteinDataReader -from sklearn.model_selection import train_test_split +from chebai_proteins.preprocessing.reader import ProteinDataReader class _ProteinPretrainingData(_DynamicDataset, ABC): diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py index bb5264d..9dcd575 100644 --- a/tests/unit/readers/testProteinDataReader.py +++ b/tests/unit/readers/testProteinDataReader.py @@ -12,7 +12,7 @@ class TestProteinDataReader(unittest.TestCase): @classmethod @patch( - "chebai.preprocessing.reader.open", + "chebai_proteins.preprocessing.reader.open", new_callable=mock_open, read_data="M\nK\nT\nF\nR\nN", ) From ab9bd1cbea21f5996835ac9042d3e4d1a642b43c Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 24 Apr 2025 12:14:13 +0200 Subject: [PATCH 20/36] BCELoss config for deepgo2 --- configs/loss/BCELoss.yml | 1 + 1 file changed, 1 insertion(+) create mode 100644 configs/loss/BCELoss.yml diff --git a/configs/loss/BCELoss.yml b/configs/loss/BCELoss.yml new file mode 100644 index 0000000..6ee636d --- /dev/null +++ b/configs/loss/BCELoss.yml @@ -0,0 +1 @@ +class_path: torch.nn.BCELoss From dcbd57874d57539edf3e4dc16dc2be29e145c0fe Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 24 Apr 2025 12:15:59 +0200 Subject: [PATCH 21/36] scope esm2 config --- chebai_proteins/preprocessing/datasets/scope/scope.py | 6 +++++- configs/data/scope/scope50_esm.yml | 6 ++++++ 2 files changed, 11 insertions(+), 1 deletion(-) create mode 100644 
configs/data/scope/scope50_esm.yml diff --git a/chebai_proteins/preprocessing/datasets/scope/scope.py b/chebai_proteins/preprocessing/datasets/scope/scope.py index abdabe8..e286fe2 100644 --- a/chebai_proteins/preprocessing/datasets/scope/scope.py +++ b/chebai_proteins/preprocessing/datasets/scope/scope.py @@ -24,7 +24,7 @@ import torch from Bio import SeqIO from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.reader import ProteinDataReader +from chebai.preprocessing.reader import ESM2EmbeddingReader, ProteinDataReader class _SCOPeDataExtractor(_DynamicDataset, ABC): @@ -952,6 +952,10 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial): THRESHOLD: int = 2000 +class SCOPEOver50ESM(SCOPeOver50): + READER = ESM2EmbeddingReader + + if __name__ == "__main__": scope = SCOPeOver50(scope_version="2.08") diff --git a/configs/data/scope/scope50_esm.yml b/configs/data/scope/scope50_esm.yml new file mode 100644 index 0000000..f79174b --- /dev/null +++ b/configs/data/scope/scope50_esm.yml @@ -0,0 +1,6 @@ +class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver50ESM +init_args: + scope_version: "2.08" + reader_kwargs: { + truncation_length: 1000 + } From 1b2856ddeeda79e8f42920c786a2cb1332df2baf Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 24 Apr 2025 12:21:18 +0200 Subject: [PATCH 22/36] MultilabelAUROC for deepgo MLP --- configs/metrics/MultilabelAUROC.yml | 5 +++++ 1 file changed, 5 insertions(+) create mode 100644 configs/metrics/MultilabelAUROC.yml diff --git a/configs/metrics/MultilabelAUROC.yml b/configs/metrics/MultilabelAUROC.yml new file mode 100644 index 0000000..8ee2ae8 --- /dev/null +++ b/configs/metrics/MultilabelAUROC.yml @@ -0,0 +1,5 @@ +class_path: torchmetrics.MetricCollection +init_args: + metrics: + balanced-accuracy: + class_path: torchmetrics.classification.MultilabelAUROC From 31b6f451a56bba9e1574bfbeedf3b4ab7c19928e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 24 Apr 2025 12:43:28 +0200 Subject: [PATCH 23/36] update migration script --- .../migration/deep_go/migrate_deep_go_1_data.py | 3 ++- .../migration/deep_go/migrate_deep_go_2_data.py | 5 +++-- 2 files changed, 5 insertions(+), 3 deletions(-) diff --git a/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_1_data.py b/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_1_data.py index bce6614..fb5beb4 100644 --- a/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_1_data.py +++ b/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_1_data.py @@ -3,10 +3,11 @@ from typing import List, Literal, Optional, Tuple import pandas as pd -from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData from iterstrat.ml_stratifiers import MultilabelStratifiedShuffleSplit from jsonargparse import CLI +from chebai_proteins.preprocessing.datasets.deepGO.go_uniprot import DeepGO1MigratedData + class DeepGo1DataMigration: """ diff --git a/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_2_data.py b/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_2_data.py index 27dc063..01d9b3b 100644 --- a/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_2_data.py +++ b/chebai_proteins/preprocessing/migration/deep_go/migrate_deep_go_2_data.py @@ -4,10 +4,11 @@ from typing import List, Literal, Optional import pandas as pd -from chebai.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData -from chebai.preprocessing.reader import ProteinDataReader from jsonargparse import CLI 
+from chebai_proteins.preprocessing.datasets.deepGO.go_uniprot import DeepGO2MigratedData +from chebai_proteins.preprocessing.reader import ProteinDataReader + class DeepGo2DataMigration: """ From 6c2506db40b55c4f2b643949ee590df1d5531598 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 24 Apr 2025 13:28:24 +0200 Subject: [PATCH 24/36] update configs --- configs/data/deepGO/deepgo2_esm2.yml | 2 +- configs/data/deepGO/deepgo_1_migrated_data.yml | 2 +- configs/data/deepGO/deepgo_2_migrated_data.yml | 2 +- configs/data/deepGO/go250.yml | 2 +- configs/data/deepGO/go50.yml | 2 +- configs/data/scope/scope2000.yml | 2 +- configs/data/scope/scope50.yml | 2 +- configs/data/scope/scope50_esm.yml | 2 +- 8 files changed, 8 insertions(+), 8 deletions(-) diff --git a/configs/data/deepGO/deepgo2_esm2.yml b/configs/data/deepGO/deepgo2_esm2.yml index 5a0436e..4c9d200 100644 --- a/configs/data/deepGO/deepgo2_esm2.yml +++ b/configs/data/deepGO/deepgo2_esm2.yml @@ -1,4 +1,4 @@ -class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +class_path: chebai_proteins.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData init_args: go_branch: "MF" max_sequence_length: 1000 diff --git a/configs/data/deepGO/deepgo_1_migrated_data.yml b/configs/data/deepGO/deepgo_1_migrated_data.yml index 0924e02..5d7d237 100644 --- a/configs/data/deepGO/deepgo_1_migrated_data.yml +++ b/configs/data/deepGO/deepgo_1_migrated_data.yml @@ -1,4 +1,4 @@ -class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO1MigratedData +class_path: chebai_proteins.preprocessing.datasets.deepGO.go_uniprot.DeepGO1MigratedData init_args: go_branch: "MF" max_sequence_length: 1002 diff --git a/configs/data/deepGO/deepgo_2_migrated_data.yml b/configs/data/deepGO/deepgo_2_migrated_data.yml index 5a0436e..4c9d200 100644 --- a/configs/data/deepGO/deepgo_2_migrated_data.yml +++ b/configs/data/deepGO/deepgo_2_migrated_data.yml @@ -1,4 +1,4 @@ -class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData +class_path: chebai_proteins.preprocessing.datasets.deepGO.go_uniprot.DeepGO2MigratedData init_args: go_branch: "MF" max_sequence_length: 1000 diff --git a/configs/data/deepGO/go250.yml b/configs/data/deepGO/go250.yml index 01e34aa..2d694b4 100644 --- a/configs/data/deepGO/go250.yml +++ b/configs/data/deepGO/go250.yml @@ -1,3 +1,3 @@ -class_path: chebai.preprocessing.datasets.go_uniprot.deepGO.GOUniProtOver250 +class_path: chebai_proteins.preprocessing.datasets.go_uniprot.deepGO.GOUniProtOver250 init_args: go_branch: "BP" diff --git a/configs/data/deepGO/go50.yml b/configs/data/deepGO/go50.yml index bee4377..495a923 100644 --- a/configs/data/deepGO/go50.yml +++ b/configs/data/deepGO/go50.yml @@ -1 +1 @@ -class_path: chebai.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver50 +class_path: chebai_proteins.preprocessing.datasets.deepGO.go_uniprot.GOUniProtOver50 diff --git a/configs/data/scope/scope2000.yml b/configs/data/scope/scope2000.yml index d75c807..ca1789b 100644 --- a/configs/data/scope/scope2000.yml +++ b/configs/data/scope/scope2000.yml @@ -1,3 +1,3 @@ -class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver2000 +class_path: chebai_proteins.preprocessing.datasets.scope.scope.SCOPeOver2000 init_args: scope_version: "2.08" diff --git a/configs/data/scope/scope50.yml b/configs/data/scope/scope50.yml index a5f808d..477d71b 100644 --- a/configs/data/scope/scope50.yml +++ b/configs/data/scope/scope50.yml @@ -1,3 +1,3 @@ -class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver50 
+class_path: chebai_proteins.preprocessing.datasets.scope.scope.SCOPeOver50 init_args: scope_version: "2.08" diff --git a/configs/data/scope/scope50_esm.yml b/configs/data/scope/scope50_esm.yml index f79174b..8575b98 100644 --- a/configs/data/scope/scope50_esm.yml +++ b/configs/data/scope/scope50_esm.yml @@ -1,4 +1,4 @@ -class_path: chebai.preprocessing.datasets.scope.scope.SCOPeOver50ESM +class_path: chebai_proteins.preprocessing.datasets.scope.scope.SCOPeOver50ESM init_args: scope_version: "2.08" reader_kwargs: { From 19ab4a7d663d4c5d0a51e3cf9776bc30cfab615b Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Thu, 24 Apr 2025 15:38:25 +0200 Subject: [PATCH 25/36] make python dir --- chebai_proteins/__init__.py | 0 chebai_proteins/preprocessing/datasets/scope/scope.py | 3 ++- 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 chebai_proteins/__init__.py diff --git a/chebai_proteins/__init__.py b/chebai_proteins/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/chebai_proteins/preprocessing/datasets/scope/scope.py b/chebai_proteins/preprocessing/datasets/scope/scope.py index e286fe2..4c6147c 100644 --- a/chebai_proteins/preprocessing/datasets/scope/scope.py +++ b/chebai_proteins/preprocessing/datasets/scope/scope.py @@ -24,7 +24,8 @@ import torch from Bio import SeqIO from chebai.preprocessing.datasets.base import _DynamicDataset -from chebai.preprocessing.reader import ESM2EmbeddingReader, ProteinDataReader + +from chebai_proteins.preprocessing.reader import ESM2EmbeddingReader, ProteinDataReader class _SCOPeDataExtractor(_DynamicDataset, ABC): From add85e32cf469076471e3b0ba00a3f5965f52878 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 4 May 2025 20:19:30 +0200 Subject: [PATCH 26/36] deepgo: raise error if no classes are selected --- chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py | 4 ++++ chebai_proteins/preprocessing/datasets/scope/scope.py | 4 +--- chebai_proteins/preprocessing/reader.py | 1 + 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py b/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py index c25d0f4..dbdf93e 100644 --- a/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py +++ b/chebai_proteins/preprocessing/datasets/deepGO/go_uniprot.py @@ -394,6 +394,10 @@ def _graph_to_raw_dataset(self, g: nx.DiGraph) -> pd.DataFrame: ) # Initialize the GO term labels/columns to False selected_classes = self.select_classes(g, data_df=data_df) + if not selected_classes: + raise ValueError( + f"No classes selected for given threshold {self.THRESHOLD}" + ) new_label_columns = pd.DataFrame( False, index=data_df.index, columns=selected_classes ) diff --git a/chebai_proteins/preprocessing/datasets/scope/scope.py b/chebai_proteins/preprocessing/datasets/scope/scope.py index 4c6147c..6f580ca 100644 --- a/chebai_proteins/preprocessing/datasets/scope/scope.py +++ b/chebai_proteins/preprocessing/datasets/scope/scope.py @@ -728,8 +728,6 @@ def _setup_pruned_test_set( """ # TODO: find a more efficient way to do this filename_old = "classes.txt" - # filename_new = f"classes_v{self.scope_version_train}.txt" - # dataset = torch.load(os.path.join(self.processed_dir, "test.pt")) # Load original classes (from the current SCOPe version - scope_version) with open(os.path.join(self.processed_dir_main, filename_old), "r") as file: @@ -760,7 +758,7 @@ def _setup_pruned_test_set( # set the corresponding label in new_labels to True if mapping[ind] is not None and label: 
new_labels[mapping[ind]] = label - # Update the labels from test instance from scope_version to the new labels, which are compatible to both versions + # Update the labels from test instance of scope_version to new labels, which are compatible to both versions row["labels"] = new_labels return df_test_scope_version diff --git a/chebai_proteins/preprocessing/reader.py b/chebai_proteins/preprocessing/reader.py index 5117f26..fb5cf03 100644 --- a/chebai_proteins/preprocessing/reader.py +++ b/chebai_proteins/preprocessing/reader.py @@ -166,6 +166,7 @@ class ESM2EmbeddingReader(DataReader): def __init__( self, + # --------- Default Parameters as per DeepGO2 ------------ save_model_dir: str = os.path.join("data", "esm2_reader"), model_name: str = "esm2_t36_3B_UR50D", device: Optional[torch.device] = None, From c89f26deb52a3c14e08bcf2cb774004cb576887e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 4 May 2025 20:31:31 +0200 Subject: [PATCH 27/36] rectify consistent naming of scope --- chebai_proteins/preprocessing/datasets/scope/scope.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/chebai_proteins/preprocessing/datasets/scope/scope.py b/chebai_proteins/preprocessing/datasets/scope/scope.py index 6f580ca..bf3540e 100644 --- a/chebai_proteins/preprocessing/datasets/scope/scope.py +++ b/chebai_proteins/preprocessing/datasets/scope/scope.py @@ -951,7 +951,7 @@ class SCOPeOverPartial2000(_SCOPeOverXPartial): THRESHOLD: int = 2000 -class SCOPEOver50ESM(SCOPeOver50): +class SCOPeOver50ESM(SCOPeOver50): READER = ESM2EmbeddingReader From 196d662591e27cf6a1e844c9e6623dc6982a56f5 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 11:47:09 +0200 Subject: [PATCH 28/36] reader: add collator to esm reader --- chebai_proteins/preprocessing/reader.py | 19 +++++++++---------- 1 file changed, 9 insertions(+), 10 deletions(-) diff --git a/chebai_proteins/preprocessing/reader.py b/chebai_proteins/preprocessing/reader.py index fb5cf03..2be4727 100644 --- a/chebai_proteins/preprocessing/reader.py +++ b/chebai_proteins/preprocessing/reader.py @@ -8,11 +8,8 @@ from chebai.preprocessing.reader import EMBEDDING_OFFSET, DataReader from esm import Alphabet from esm.model.esm2 import ESM2 -from esm.pretrained import ( - _has_regression_weights, - load_model_and_alphabet_core, - load_model_and_alphabet_local, -) +from esm.pretrained import _has_regression_weights # noqa +from esm.pretrained import load_model_and_alphabet_core, load_model_and_alphabet_local class ProteinDataReader(DataReader): @@ -24,7 +21,7 @@ class ProteinDataReader(DataReader): Refer for amino acid sequence: https://en.wikipedia.org/wiki/Protein_primary_structure Args: - collator_kwargs (Optional[Dict[str, Any]]): Optional dictionary of keyword arguments for configuring the collator. + collator_kwargs (Optional[Dict[str, Any]]): Optional dict of keyword arguments for configuring the collator. token_path (Optional[str]): Path to the token file. If not provided, it will be created automatically. kwargs: Additional keyword arguments. """ @@ -132,7 +129,7 @@ def _read_data(self, raw_data: str) -> List[int]: def on_finish(self) -> None: """ - Saves the current cache of tokens to the token file. This method is called after all data processing is complete. + Saves the current cache of tokens to the token file.This method is called after all data processing is complete. 
""" with open(self.token_path, "w") as pk: print(f"Saving {len(self.cache)} tokens to {self.token_path}...") @@ -158,6 +155,8 @@ class ESM2EmbeddingReader(DataReader): """ + COLLATOR = RaggedCollator + # https://github.com/facebookresearch/esm/blob/main/esm/pretrained.py#L53 _MODELS_URL = "https://dl.fbaipublicfiles.com/fair-esm/models/{}.pt" _REGRESSION_URL = ( @@ -270,12 +269,12 @@ def load_hub_workaround(self, url) -> torch.Tensor: ) except HTTPError as e: raise Exception( - f"Could not load {url}. Did you specify the correct model name?" + f"Could not load {url}. Did you specify the correct model name? \n Error: {e}" ) return data - @staticmethod - def name() -> str: + @classmethod + def name(cls) -> str: """ Returns the name of the data reader. This method identifies the specific type of data reader. From 5af20c8cc39d6b73456ab26b45567dc293270f53 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Wed, 7 May 2025 15:21:36 +0200 Subject: [PATCH 29/36] set weight_only=False for esm reader --- chebai_proteins/preprocessing/reader.py | 28 +++++++++++++++++++++++-- 1 file changed, 26 insertions(+), 2 deletions(-) diff --git a/chebai_proteins/preprocessing/reader.py b/chebai_proteins/preprocessing/reader.py index 2be4727..4803bf6 100644 --- a/chebai_proteins/preprocessing/reader.py +++ b/chebai_proteins/preprocessing/reader.py @@ -221,10 +221,29 @@ def load_model_and_alphabet(self) -> Tuple[ESM2, Alphabet]: """ model_location = os.path.join(self.save_model_dir, f"{self.model_name}.pt") if os.path.exists(model_location): - return load_model_and_alphabet_local(model_location) + return self.load_model_and_alphabet_local(model_location) else: return self.load_model_and_alphabet_hub() + @staticmethod + def load_model_and_alphabet_local(model_location): + """Load from local path. The regression weights need to be co-located""" + model_location = Path(model_location) + model_data = torch.load( + str(model_location), map_location="cpu", weights_only=False + ) + model_name = model_location.stem + if _has_regression_weights(model_name): + regression_location = ( + str(model_location.with_suffix("")) + "-contact-regression.pt" + ) + regression_data = torch.load( + regression_location, map_location="cpu", weights_only=False + ) + else: + regression_data = None + return load_model_and_alphabet_core(model_name, model_data, regression_data) + def load_model_and_alphabet_hub(self) -> Tuple[ESM2, Alphabet]: """ Load the model and alphabet from the hub URL. 
@@ -257,7 +276,11 @@ def load_hub_workaround(self, url) -> torch.Tensor: """ try: data = torch.hub.load_state_dict_from_url( - url, self.save_model_dir, progress=True, map_location=self.device + url, + self.save_model_dir, + progress=True, + map_location=self.device, + weights_only=False, ) except RuntimeError: @@ -266,6 +289,7 @@ def load_hub_workaround(self, url) -> torch.Tensor: data = torch.load( f"{torch.hub.get_dir()}/checkpoints/{fn}", map_location="cpu", + weights_only=False, ) except HTTPError as e: raise Exception( From d653f52bdef25121e5b25e59c88ed6891744252f Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 10 May 2025 21:42:46 +0200 Subject: [PATCH 30/36] use `TokenIndexerReader` for `ProteinDataReader` --- chebai_proteins/preprocessing/reader.py | 27 ++++++------------------- 1 file changed, 6 insertions(+), 21 deletions(-) diff --git a/chebai_proteins/preprocessing/reader.py b/chebai_proteins/preprocessing/reader.py index 4803bf6..21bdaea 100644 --- a/chebai_proteins/preprocessing/reader.py +++ b/chebai_proteins/preprocessing/reader.py @@ -5,14 +5,14 @@ import torch from chebai.preprocessing.collate import RaggedCollator -from chebai.preprocessing.reader import EMBEDDING_OFFSET, DataReader +from chebai.preprocessing.reader import DataReader, TokenIndexerReader from esm import Alphabet from esm.model.esm2 import ESM2 from esm.pretrained import _has_regression_weights # noqa -from esm.pretrained import load_model_and_alphabet_core, load_model_and_alphabet_local +from esm.pretrained import load_model_and_alphabet_core -class ProteinDataReader(DataReader): +class ProteinDataReader(TokenIndexerReader): """ Data reader for protein sequences using amino acid tokens. This class processes raw protein sequences into a format suitable for model input by tokenizing them and assigning unique indices to each token. @@ -30,12 +30,12 @@ class ProteinDataReader(DataReader): # fmt: off # 21 natural amino acid notation - AA_LETTER = [ + AA_LETTER = { "A", "R", "N", "D", "C", "Q", "E", "G", "H", "I", "L", "K", "M", "F", "P", "S", "T", "W", "Y", "V", # https://github.com/bio-ontology-research-group/deepgo2/blob/main/deepgo/aminoacids.py#L3-L5 "X", # Consider valid in latest paper year 2024 Reference number 3 in go_uniprot.py - ] + } # fmt: on def name(self) -> str: @@ -68,10 +68,6 @@ def __init__(self, *args, n_gram: Optional[int] = None, **kwargs): super().__init__(*args, **kwargs) - # Load the existing tokens from the token file into a cache - with open(self.token_path, "r") as pk: - self.cache = [x.strip() for x in pk] - def _get_token_index(self, token: str) -> int: """ Returns a unique index for each token (amino acid). If the token is not already in the cache, it is added. @@ -102,9 +98,7 @@ def _get_token_index(self, token: str) -> int: + error_str ) - if str(token) not in self.cache: - self.cache.append(str(token)) - return self.cache.index(str(token)) + EMBEDDING_OFFSET + return super()._get_token_index(token) def _read_data(self, raw_data: str) -> List[int]: """ @@ -127,15 +121,6 @@ def _read_data(self, raw_data: str) -> List[int]: # If n_gram is None, tokenize the sequence at the amino acid level (single-letter representation) return [self._get_token_index(aa) for aa in raw_data] - def on_finish(self) -> None: - """ - Saves the current cache of tokens to the token file.This method is called after all data processing is complete. 
- """ - with open(self.token_path, "w") as pk: - print(f"Saving {len(self.cache)} tokens to {self.token_path}...") - print(f"First 10 tokens: {self.cache[:10]}") - pk.writelines([f"{c}\n" for c in self.cache]) - class ESM2EmbeddingReader(DataReader): """ From 71fa9fe4ea1ecf1a3cb9cfcb52093f37230da77e Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 10 May 2025 23:23:58 +0200 Subject: [PATCH 31/36] update test for protein reader for tokenindexer changes --- tests/unit/readers/testProteinDataReader.py | 75 ++++++++++++++++++--- 1 file changed, 65 insertions(+), 10 deletions(-) diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py index 9dcd575..3433370 100644 --- a/tests/unit/readers/testProteinDataReader.py +++ b/tests/unit/readers/testProteinDataReader.py @@ -2,7 +2,9 @@ from typing import List from unittest.mock import mock_open, patch -from chebai_proteins.preprocessing.reader import EMBEDDING_OFFSET, ProteinDataReader +from chebai.preprocessing.reader import EMBEDDING_OFFSET + +from chebai_proteins.preprocessing.reader import ProteinDataReader class TestProteinDataReader(unittest.TestCase): @@ -25,14 +27,16 @@ def setUpClass(cls, mock_file: mock_open) -> None: """ cls.reader = ProteinDataReader(token_path="/mock/path") # After initializing, cls.reader.cache should now be set to ['M', 'K', 'T', 'F', 'R', 'N'] - assert cls.reader.cache == [ - "M", - "K", - "T", - "F", - "R", - "N", - ], "Cache initialization did not match expected tokens." + assert list(cls.reader.cache.items()) == list( + { + "M": 0, + "K": 1, + "T": 2, + "F": 3, + "R": 4, + "N": 5, + }.items() + ), "Initial cache does not match expected values or the order doesn't match." def test_read_data(self) -> None: """ @@ -86,7 +90,7 @@ def test_read_data_with_new_token(self) -> None: ) # Ensure it's at the correct index self.assertEqual( - self.reader.cache.index("Y"), + self.reader.cache["Y"], len(self.reader.cache) - 1, "The new token 'Y' was not added at the correct index in the cache.", ) @@ -134,6 +138,57 @@ def test_read_data_with_repeated_tokens(self) -> None: "The _read_data method did not correctly handle repeated tokens.", ) + @patch("builtins.open", new_callable=mock_open) + def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None: + """ + Test the on_finish method to ensure it appends only the new tokens to the token file in order. + """ + # Simulate that some tokens were already loaded + self.reader._loaded_tokens_count = 6 # 6 tokens already loaded + self.reader.cache = { + "M": 0, + "K": 1, + "T": 2, + "F": 3, + "R": 4, + "N": 5, + "W": 6, # New token 1 + "Y": 7, # New token 2 + "V": 8, # New token 3 + "Q": 9, # New token 4 + "E": 10, # New token 5 + } + + # Run the on_finish method + self.reader.on_finish() + + # Check that the file was opened in append mode ('a') + mock_file.assert_called_with(self.reader.token_path, "a") + + # Verify the new tokens were written in the correct order + mock_file().writelines.assert_called_with( + ["[H-]\n", "Br\n", "Cl\n", "Na\n", "Mg\n"] + ) + + def test_finish_method_no_new_tokens(self) -> None: + """ + Test the on_finish method when no new tokens are added (cache is the same). 
+ """ + self.reader._loaded_tokens_count = 6 # No new tokens + self.reader.cache = { + "M": 0, + "K": 1, + "T": 2, + "F": 3, + "R": 4, + "N": 5, + } + + with patch("builtins.open", new_callable=mock_open) as mock_file: + self.reader.on_finish() + # Check that no new tokens were written + mock_file().writelines.assert_not_called() + if __name__ == "__main__": unittest.main() From 508a47aad97b46f0e77e5ad6084f89c725db9763 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sat, 10 May 2025 23:33:36 +0200 Subject: [PATCH 32/36] fix protein test for mock open --- tests/unit/readers/testProteinDataReader.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tests/unit/readers/testProteinDataReader.py b/tests/unit/readers/testProteinDataReader.py index 3433370..f097aab 100644 --- a/tests/unit/readers/testProteinDataReader.py +++ b/tests/unit/readers/testProteinDataReader.py @@ -14,7 +14,7 @@ class TestProteinDataReader(unittest.TestCase): @classmethod @patch( - "chebai_proteins.preprocessing.reader.open", + "chebai.preprocessing.reader.open", new_callable=mock_open, read_data="M\nK\nT\nF\nR\nN", ) @@ -166,9 +166,7 @@ def test_finish_method_for_new_tokens(self, mock_file: mock_open) -> None: mock_file.assert_called_with(self.reader.token_path, "a") # Verify the new tokens were written in the correct order - mock_file().writelines.assert_called_with( - ["[H-]\n", "Br\n", "Cl\n", "Na\n", "Mg\n"] - ) + mock_file().writelines.assert_called_with(["W\n", "Y\n", "V\n", "Q\n", "E\n"]) def test_finish_method_no_new_tokens(self) -> None: """ From a8823c83bd4c1ab2833cd93081ac9e85c7df6581 Mon Sep 17 00:00:00 2001 From: aditya0by0 Date: Sun, 11 May 2025 11:30:55 +0200 Subject: [PATCH 33/36] add abstract DataReader for proteins repo to override token path --- chebai_proteins/preprocessing/reader.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/chebai_proteins/preprocessing/reader.py b/chebai_proteins/preprocessing/reader.py index 21bdaea..80d6cfd 100644 --- a/chebai_proteins/preprocessing/reader.py +++ b/chebai_proteins/preprocessing/reader.py @@ -1,4 +1,5 @@ import os +from abc import ABC from pathlib import Path from typing import List, Optional, Tuple from urllib.error import HTTPError @@ -12,7 +13,15 @@ from esm.pretrained import load_model_and_alphabet_core -class ProteinDataReader(TokenIndexerReader): +class _ChebaiProteinsDataReader(DataReader, ABC): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + # This to override the token directory path which points to `chebai` repo instead of `chebai-proteins` to + # search for tokens.txt files for readers defined in `chebai-proteins` repository. + self.dirname = os.path.dirname(__file__) + + +class ProteinDataReader(TokenIndexerReader, _ChebaiProteinsDataReader): """ Data reader for protein sequences using amino acid tokens. This class processes raw protein sequences into a format suitable for model input by tokenizing them and assigning unique indices to each token. @@ -122,7 +131,7 @@ def _read_data(self, raw_data: str) -> List[int]: return [self._get_token_index(aa) for aa in raw_data] -class ESM2EmbeddingReader(DataReader): +class ESM2EmbeddingReader(_ChebaiProteinsDataReader): """ A data reader to process protein sequences using the ESM2 model for embeddings. 
From cd92ca514f676260eb72a33412ee518e82ba6f1d Mon Sep 17 00:00:00 2001
From: Aditya Khedekar <65857172+aditya0by0@users.noreply.github.com>
Date: Mon, 12 May 2025 16:06:33 +0200
Subject: [PATCH 34/36] proteins readme

---
 README.md | 148 +++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 146 insertions(+), 2 deletions(-)

diff --git a/README.md b/README.md
index b954a79..f0033da 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,146 @@
-# python-chebai-proteins
-Protein-related extension of the chebai framework
+
+# 🧪 ChEB-AI Proteins
+
+`python-chebai-proteins` is a repository for protein prediction and classification, built on top of the [`python-chebai`](https://github.com/ChEB-AI/python-chebai) codebase.
+
+
+## 🔧 Installation
+
+
+To install, follow these steps:
+
+1. Clone the repository:
+```
+git clone https://github.com/ChEB-AI/python-chebai-proteins.git
+```
+
+2. Install the package:
+
+```
+cd python-chebai-proteins
+pip install .
+```
+
+## 🗂 Recommended Folder Structure
+
+To combine configuration files from both `python-chebai` and `python-chebai-proteins`, structure your project like this:
+
+```
+my_projects/
+├── python-chebai/
+│   ├── chebai/
+│   ├── configs/
+│   └── ...
+└── python-chebai-proteins/
+    ├── chebai_proteins/
+    ├── configs/
+    └── ...
+```
+
+This setup enables shared access to data and model configurations.
+
+
+
+## 🚀 Training & Pretraining Guide
+
+### ⚠️ Important Setup Instructions
+
+Before running any training scripts, ensure the environment is correctly configured:
+
+* Either:
+
+  * Install the `python-chebai` repository as a package using:
+
+    ```bash
+    pip install .
+    ```
+* **OR**
+
+  * Manually set the `PYTHONPATH` environment variable if working across multiple directories (`python-chebai` and `python-chebai-proteins`):
+
+    * If your current working directory is `python-chebai-proteins`, set:
+
+      ```bash
+      export PYTHONPATH=path/to/python-chebai
+      ```
+      or vice versa.
+
+    * If you're working within both repositories simultaneously or facing module not found errors, we **recommend configuring both directories**:
+
+      ```bash
+      # Linux/macOS
+      export PYTHONPATH=path/to/python-chebai:path/to/python-chebai-proteins
+
+      # Windows (use semicolon instead of colon)
+      set PYTHONPATH=path\to\python-chebai;path\to\python-chebai-proteins
+      ```
+
+> 🔎 See the [PYTHONPATH Explained](#-pythonpath-explained) section below for more details.
+
+
+### 📊 SCOPe hierarchy prediction
+
+Assuming your current working directory is `python-chebai-proteins`, run the following command to start training:
+```bash
+python -m chebai fit --trainer=../configs/training/default_trainer.yml --trainer.callbacks=../configs/training/default_callbacks.yml --trainer.logger.init_args.name=scope50 --trainer.accumulate_grad_batches=4 --trainer.logger=../configs/training/wandb_logger.yml --trainer.min_epochs=100 --trainer.max_epochs=100 --data=configs/data/scope/scope50.yml --data.init_args.batch_size=32 --data.init_args.num_workers=10 --model=../configs/model/electra.yml --model.train_metrics=../configs/metrics/micro-macro-f1.yml --model.test_metrics=../configs/metrics/micro-macro-f1.yml --model.val_metrics=../configs/metrics/micro-macro-f1.yml --model.pass_loss_kwargs=false --model.criterion=../configs/loss/bce.yml --model.criterion.init_args.beta=0.99
+```
+
+The same command can be used for **DeepGO** by simply changing the data config path.
+
+
+
+
+
+
+
+## 🧭 PYTHONPATH Explained
+
+### What is `PYTHONPATH`?
+
+`PYTHONPATH` is an environment variable that tells Python where to search for modules that aren't installed via `pip` or not in your current working directory.
+
+### Why You Need It
+
+If your config refers to a custom module like:
+
+```yaml
+class_path: chebai_proteins.preprocessing.datasets.scope.scope.SCOPeOver50
+```
+
+...and you're running the code from `python-chebai`, Python won't know where to find `chebai_proteins` (from another repo like `python-chebai-proteins/`) unless you add it to `PYTHONPATH`.
+
+
+### How Python Finds Modules
+
+Python looks for imports in this order:
+
+1. Current directory
+2. Standard library
+3. Paths in `PYTHONPATH`
+4. Installed packages (`site-packages`)
+
+You can inspect the full search paths:
+
+```bash
+python -c "import sys; print(sys.path)"
+```
+
+
+
+### ✅ Setting `PYTHONPATH`
+
+#### 🐧 Linux / macOS
+
+```bash
+export PYTHONPATH=/path/to/python-chebai-proteins
+echo $PYTHONPATH
+```
+
+#### 🪟 Windows CMD
+
+```cmd
+set PYTHONPATH=C:\path\to\python-chebai-proteins
+echo %PYTHONPATH%
+```
+
+> 💡 Note: This is temporary for your terminal session. To make it permanent, add it to your system environment variables.

From 7cf059c09a95ab966464fbe60c6e0c1285cfd8fb Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Mon, 12 May 2025 20:00:48 +0200
Subject: [PATCH 35/36] Revert "add abstract DataReader for proteins repo to override token path"

This reverts commit a8823c83bd4c1ab2833cd93081ac9e85c7df6581.
---
 chebai_proteins/preprocessing/reader.py | 13 ++-----------
 1 file changed, 2 insertions(+), 11 deletions(-)

diff --git a/chebai_proteins/preprocessing/reader.py b/chebai_proteins/preprocessing/reader.py
index 80d6cfd..21bdaea 100644
--- a/chebai_proteins/preprocessing/reader.py
+++ b/chebai_proteins/preprocessing/reader.py
@@ -1,5 +1,4 @@
 import os
-from abc import ABC
 from pathlib import Path
 from typing import List, Optional, Tuple
 from urllib.error import HTTPError
@@ -13,15 +12,7 @@
 from esm.pretrained import load_model_and_alphabet_core
 
 
-class _ChebaiProteinsDataReader(DataReader, ABC):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        # This to override the token directory path which points to `chebai` repo instead of `chebai-proteins` to
-        # search for tokens.txt files for readers defined in `chebai-proteins` repository.
-        self.dirname = os.path.dirname(__file__)
-
-
-class ProteinDataReader(TokenIndexerReader, _ChebaiProteinsDataReader):
+class ProteinDataReader(TokenIndexerReader):
     """
     Data reader for protein sequences using amino acid tokens. This class processes raw protein sequences into a
     format suitable for model input by tokenizing them and assigning unique indices to each token.
@@ -131,7 +122,7 @@ def _read_data(self, raw_data: str) -> List[int]:
         return [self._get_token_index(aa) for aa in raw_data]
 
 
-class ESM2EmbeddingReader(_ChebaiProteinsDataReader):
+class ESM2EmbeddingReader(DataReader):
     """
     A data reader to process protein sequences using the ESM2 model for embeddings.
 

From 979b4f2d3449e493b053808a159ae957dcfa6151 Mon Sep 17 00:00:00 2001
From: aditya0by0
Date: Mon, 12 May 2025 20:18:41 +0200
Subject: [PATCH 36/36] Update .gitignore

---
 .gitignore | 1 +
 1 file changed, 1 insertion(+)

diff --git a/.gitignore b/.gitignore
index 9b28876..05cdfb7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -169,3 +169,4 @@ cython_debug/
 electra_pretrained.ckpt
 .jupyter
 .virtual_documents
+.isort.cfg