diff --git a/.github/workflows/lint_and_test.yaml b/.github/workflows/lint_and_test.yaml index c1c2d308..4fa9adc3 100644 --- a/.github/workflows/lint_and_test.yaml +++ b/.github/workflows/lint_and_test.yaml @@ -4,7 +4,6 @@ on: push: branches: - main - - dev - production paths: - 'src/**' @@ -14,6 +13,7 @@ on: branches: - main - dev + - staging - production paths: - 'src/**' @@ -29,12 +29,13 @@ on: concurrency: group: ${{ github.workflow }}-${{ github.ref }} cancel-in-progress: true + +permissions: + contents: read jobs: lint: - name: ruff (check code style) - # NOTE: We use an ubuntu runner to not be dependent on possibly limited digs infra - # for this tiny linting job. This means we can have many lint jobs accross repos in parallel. + name: ruff runs-on: ubuntu-latest steps: - uses: actions/checkout@v4 @@ -49,23 +50,18 @@ jobs: run: pip install ruff==${{ env.RUFF_VERSION }} - name: ruff format (check code formatting) run: ruff format --diff - # - name: ruff check (lint code base) - # run: ruff check + - name: ruff check (lint code base) + run: ruff check src tests test_digs: - name: pytest (run tests) + name: pytest (jojo) runs-on: [jojo] timeout-minutes: 30 needs: lint - # ... only run on non-draft PRs to `main` to avoid unnecessary CI runs - # ... 
and only run on changed files in the `atomworks`, `tests`, or `scripts` directories - if: | - (github.event_name == 'pull_request' && !github.event.pull_request.draft) || - (github.event_name == 'pull_request_target' && github.event.action == 'ready_for_review') + if: github.event_name == 'workflow_dispatch' steps: - uses: actions/checkout@v4 - name: Run tests - timeout-minutes: 30 run: | export N_CPU=8 srun --chdir=$PWD -p cpu -c $N_CPU -t 00:30:00 --mem=32G bash ./.github/ci/run_tests.sh @@ -110,7 +106,7 @@ jobs: run: | atomworks setup tests - - name: Run pytest with multiple cores + - name: Run pytest run: | export OPENBLAS_NUM_THREADS=1 export OMP_NUM_THREADS=1 diff --git a/.gitignore b/.gitignore index 809bb4d9..a0e9b213 100644 --- a/.gitignore +++ b/.gitignore @@ -161,4 +161,5 @@ tests/test_outputs dev.py dev.ipynb _version.py -tinker/ \ No newline at end of file +tinker/ +.DS_Store \ No newline at end of file diff --git a/README.md b/README.md index e81a1db1..49824bfe 100644 --- a/README.md +++ b/README.md @@ -25,9 +25,10 @@ AtomWorks is built atop [biotite](https://www.biotite-python.org/): We are grate ## atomworks.io -*A general-purpose Python toolkit for working with biomolecular files* +> *A general-purpose Python toolkit for cleaning up, standardizing, and working with biomolecular files - based on biotite* **atomworks.io** lets you: + - Parse, convert, and clean any common biological file (structure or sequence). 
For example, identifying and removing leaving groups, correcting bond order after nucleophilic addition, fixing charges, parsing covalent geometries, and appropriate treatment of structures with multiple occupancies and ligands at symmetry centers - Transform all data to a consistent `AtomArray` representation for further analysis or machine learning applications, regardless of initial source - Model missing atoms (those implied by the sequence but not represented in the coordinates) and initialize entity- and instance-level annotations (see the [glossary]() for more detail on our composable naming conventions) @@ -38,9 +39,10 @@ We have found `atomworks.io` to be useful to a general bioinformatics and protei ## atomworks.ml -*Modular, component-based library for dataset featurization within biomolecular deep learning workflows* +> *Modular, component-based library for dataset featurization within biomolecular deep learning workflows* **atomworks.ml** provides: + - A library of pre-built, well-tested `Transforms` that can be slotted into novel pipelines - An extensible framework, integrated with `atomworks.io`, to write `Transforms` for arbitrary use cases - Scripts to pre-process the PDB or other databases into dataframes appropriate for network training @@ -69,7 +71,23 @@ For more advanced setup options (including how to run workflows via apptainers) --- -## Quick Start +## Getting started + +### 1. When to use `atomworks.io` vs `atomworks.ml`? + +- Use `atomworks.io` when you: + - Need to parse/clean/convert between biological file formats (mmCIF, PDB, FASTA, etc.) 
+ - Want a unified structural representation to plug into any downstream analysis or modeling + - Need structural operations like adding missing atoms, filtering ligands/solvents, or assembly generation + +- Use `atomworks.ml` when you: + - Need to featurize entire datasets for deep learning + - Want ready-made sampling and batching utilities for training pipelines + - Already use `atomworks.io` and want a seamless bridge to ML-ready feature engineering + +### 2. Quick Start + +To parse a pdb file (parse = load, clean, annotate relevant metadata such as entities, molecules, etc) you can use the `parse` function: ```python @@ -77,33 +95,199 @@ from atomworks.io.parser import parse result = parse(filename="3nez.cif.gz") +asym_unit: AtomArrayStack = result["asym_unit"] +assemblies: dict[str, AtomArrayStack] = result["assemblies"] + for chain_id, info in result["chain_info"].items(): -print(chain_id, info["sequence"]) + print(chain_id, info["sequence"]) ``` -Output includes: +The output of `parse` includes: + - **chain_info** — Sequences/metadata for each chain - **ligand_info** — Ligand annotation & metrics - **asym_unit** — Structure (`AtomArrayStack`) - **assemblies** — Built biological assemblies (each are their own `AtomArrayStack`) - **metadata** — Experimental and source information -See [usage examples](https://baker-laboratory.github.io/atomworks-dev/latest/auto_examples/). +See [usage examples](https://baker-laboratory.github.io/atomworks-dev/latest/auto_examples/) for more details. ---- +If you just want to load a file, you can use the `load_any` function: + +```python +from atomworks.io.utils.io_utils import load_any -## When to use atomworks.io vs atomworks.ml? +atom_array: AtomArray = load_any("3nez.cif.gz", model=1) # model=1 means that we want to load the model 1 (i.e. the first model) rather than a stack of all models in the file +``` + +### 3. Training on the PDB + +> ⚠️ **Disclaimer:** Documentation for this section is currently under construction. 
Please check back soon for updates! + +**Step 1 — Mirror the PDB (mmCIFs)** + To train on the PDB, you first need to make sure you have access to the samples from the PDB. We use `mmCIF` files as the highly recommended format for training. + For convenience, we provide a command to mirror the PDB: + + ```bash + # Full mirror (~100 GB) + atomworks pdb sync /path/to/pdb_mirror # This will create a carbon-copy of the PDB, dated today, in the specified directory. It will download the .mmcif files in the same sharding pattern as the original PDB and keep them gzipped for efficiency. + +# # If, for some reason, you only want to download specific IDs, the CLI also supports this: +# atomworks pdb sync /path/to/pdb_mirror --pdb-id 1A0I --pdb-id 7XYZ # This will only download the specified PDB IDs. +# # or +# atomworks pdb sync /path/to/pdb_mirror --pdb-ids-file /path/to/ids.txt # This will download the PDB IDs listed in the file, one per line. Each line should be a PDB ID (e.g. '6lyz') and separated by a newline. + ``` + + Once the mirror is created, set the environment variable: + + ```bash + export PDB_MIRROR_PATH=/path/to/pdb_mirror + ``` + + To make this setting permanent, you can add it to a `.env` file in your home directory. Here is an [example of a `.env`](.env.sample) file structure that you can copy, rename to `.env` and edit with your own paths. + +**Step 2 — Get PDB metadata (PN units and interfaces)** + To calculate sampling probabilities and filter examples for splits, we pre-process the PDB with metadata for each PDB entry. + To save you the work, we provide pre-computed metadata (dated July 15/2025) for downloading: + + ```bash + atomworks setup metadata /path/to/metadata # This will download the metadata (as .tar.gz) and extract it to the specified directory. + ``` + + This produces parquet files at: + +- `/path/to/metadata/pn_units_df.parquet` — Contains metadata for each *PN unit* in the PDB. 
The term *pn unit* is shorthand for `polymer XOR non-polymer unit` and behaves for almost all purposes like the `chain` in a PDB file. The only difference is that a ligand composed of multiple covalently bonded ligands is considered a single PN unit (whilst it would be multiple chains in a PDB file). Effectively this `.parquet` is a large table of all individual chains, ligands, etc (to be precise, it has one entry per pn unit) in the PDB that includes helpful metadata for filtering and sampling. +- `/path/to/metadata/interfaces_df.parquet` — Contains metadata for each interface in the PDB. This `.parquet` is a large table of all binary interfaces in the PDB. It lists each interface as (pn_unit_1, pn_unit_2) pairs and includes helpful metadata for filtering and sampling. + + Alternatively, you can generate fresher metadata yourself (scripts will be uploaded in the coming weeks). + +**Step 3 — Configure an AF3-style dataset (example: train only on D- and L-polypeptides)** +Next we need to use the metadata to configure a dataset that we would like to sample from. This includes e.g. training cut-off, filters, transforms to apply, etc. +Here's a simple example that: + +- Filters to D-polypeptide and L-polypeptide chains only (`POLYPEPTIDE_D` and `POLYPEPTIDE_L`). To include additional chain types, replace the lists with the appropriate IDs (see the [mapping](./src/atomworks/enums.py#L31-L45) reproduced in the config comments). +- Excludes ligands in the AF3 list of excluded ligands, available at [`atomworks.io.constants.AF3_EXCLUDED_LIGANDS_REGEX`](./src/atomworks/io/constants.py#L350). + +```yaml +# NOTE: The below is a hydra config and the _target_ fields are the hydra syntax for instantiating a class. +# You can use this without hydra, but will then instead need to provide the corresponding arguments for the +# _target_ objects directly. 
+ +# Chain type ids used below (from atomworks.enums.ChainType): +# 0=CyclicPseudoPeptide, 1=OtherPolymer, 2=PeptideNucleicAcid, +# 3=DNA, 4=DNA_RNA_HYBRID, 5=POLYPEPTIDE_D, 6=POLYPEPTIDE_L, 7=RNA, +# 8=NON_POLYMER, 9=WATER, 10=BRANCHED, 11=MACROLIDE + +af3_pdb_dataset: + _target_: atomworks.ml.datasets.datasets.ConcatDatasetWithID + datasets: + # Single PN units + - _target_: atomworks.ml.datasets.datasets.StructuralDatasetWrapper + dataset_parser: + _target_: atomworks.ml.datasets.parsers.PNUnitsDFParser + transform: + _target_: atomworks.ml.pipelines.af3.build_af3_transform_pipeline + is_inference: false + n_recycles: 5 # This means that we will subsample 5 random sets from the MSA for each example. + crop_size: 256 + crop_contiguous_probability: 0.3333333333333333 + crop_spatial_probability: 0.6666666666666666 + diffusion_batch_size: 32 + # Optional templates (if available) + template_lookup_path: ${paths.shared}/template_lookup.csv + template_base_dir: ${paths.shared}/template + # Optional MSAs (see Step 4) + # protein_msa_dirs: + # - { dir: /path/to/msa, extension: .a3m.gz, directory_depth: 2 } + # rna_msa_dirs: + # - { dir: /path/to/msa, extension: .afa, directory_depth: 0 } + dataset: + _target_: atomworks.ml.datasets.datasets.PandasDataset + name: pn_units + id_column: example_id + data: /path/to/metadata/pn_units_df.parquet + filters: + - "deposition_date < '2022-01-01'" + - "resolution < 5.0 and ~method.str.contains('NMR')" + - "num_polymer_pn_units <= 20" + - "cluster.notnull()" + - "method in ['X-RAY_DIFFRACTION', 'ELECTRON_MICROSCOPY']" + # Train only on D-polypeptides: + - "q_pn_unit_type in [5, 6]" # 5 = POLYPEPTIDE_D, 6 = POLYPEPTIDE_L + # Exclude ligands from AF3 excluded set: + - "~(q_pn_unit_non_polymer_res_names.notnull() and q_pn_unit_non_polymer_res_names.str.contains('${af3_excluded_ligands_regex}', regex=True))" + columns_to_load: null + save_failed_examples_to_dir: null + + # Binary interfaces + - _target_: 
atomworks.ml.datasets.datasets.StructuralDatasetWrapper + dataset_parser: + _target_: atomworks.ml.datasets.parsers.InterfacesDFParser + transform: + _target_: atomworks.ml.pipelines.af3.build_af3_transform_pipeline + is_inference: false + n_recycles: 5 + crop_size: 256 + crop_spatial_probability: 1.0 + crop_contiguous_probability: 0.0 + diffusion_batch_size: 32 + template_lookup_path: ${paths.shared}/template_lookup.csv + template_base_dir: ${paths.shared}/template + # Optional MSAs (see Step 4) + # protein_msa_dirs: + # - { dir: /path/to/msa, extension: .a3m.gz, directory_depth: 2 } + # rna_msa_dirs: + # - { dir: /path/to/msa, extension: .afa, directory_depth: 0 } + dataset: + _target_: atomworks.ml.datasets.datasets.PandasDataset + name: interfaces + id_column: example_id + data: /path/to/metadata/interfaces_df.parquet + filters: + - "deposition_date < '2022-01-01'" + - "resolution < 5.0 and ~method.str.contains('NMR')" + - "num_polymer_pn_units <= 20" + - "cluster.notnull()" + - "method in ['X-RAY_DIFFRACTION', 'ELECTRON_MICROSCOPY']" + # Train only on D-polypeptide interfaces: + - "pn_unit_1_type in [5, 6]" # 5 = POLYPEPTIDE_D, 6 = POLYPEPTIDE_L + - "pn_unit_2_type in [5, 6]" # 5 = POLYPEPTIDE_D, 6 = POLYPEPTIDE_L + - "~(pn_unit_1_non_polymer_res_names.notnull() and pn_unit_1_non_polymer_res_names.str.contains('${af3_excluded_ligands_regex}', regex=True))" + - "~(pn_unit_2_non_polymer_res_names.notnull() and pn_unit_2_non_polymer_res_names.str.contains('${af3_excluded_ligands_regex}', regex=True))" + columns_to_load: null + cif_parser_args: + cache_dir: null + save_failed_examples_to_dir: null +``` + +**Step 4 — MSAs (optional)** +We are working on a way to make MSAs accessible to the public, but due to the large storage requirements (multiple TB) we are still working on this. If your organization has interest & capacity to host the MSAs, please contact us. 
In the meantime, if you have MSAs (e.g., from OpenProteinSet) you can configure the pipeline to use them like so: + +```yaml + protein_msa_dirs: + - { dir: /path/to/msa, extension: .a3m.gz, directory_depth: 2 } + rna_msa_dirs: + - { dir: /path/to/msa, extension: .afa, directory_depth: 0 } +``` -- Use **atomworks.io** when you: - - Need to parse/clean/convert between biological file formats (mmCIF, PDB, FASTA, etc.) - - Want a unified structural representation to plug into any downstream analysis or modeling - - Need structural operations like adding missing atoms, filtering ligands/solvents, or assembly generation +Alternatively, you can omit these entries and train without MSAs. -- Use **atomworks.ml** when you: - - Need to featurize entire datasets for deep learning - - Want ready-made sampling and batching utilities for training pipelines - - Already use atomworks.io and want a seamless bridge to ML-ready feature engineering +**Step 5 — Train a model** +You now have a full-fledged dataset that you can use to train models on! If you want to just try this out without having to download the whole PDB and the metadata, you can instead run our tests which have a mini-mockup of the pipeline with real pdb files, metadata, distillation data, templates and MSAs for the example of AF3. You can download all this relevant metadata via the atomworks CLI: + +```bash +atomworks setup tests # This will download the test pack to `tests/data` and unpack it there (~500 MB) +``` + +You will now have a mini PDB at `tests/data/pdb` and a mini custom CCD at `tests/data/ccd`. MSA and template data is in `tests/data/shared` and the distillation and metadata are in `data/ml/af2_distillation`, `data/ml/pdb_pn_units` and `data/ml/pdb_interfaces`. A dataset that uses all of these is [for example here](./tests/ml/conftest.py#L300). 
+ +To run the tests for the various datasets, you can run the following command: + +```bash +# Make sure you have the correct environment activated, and set your paths correctly in the .env file / shell environment variables (see points above) +pytest tests/ml/test_data_loading_pipelines.py +``` --- @@ -116,5 +300,18 @@ Please see the [full documentation](https://baker-laboratory.github.io/atomworks If you make use of AtomWorks in your research, please cite: -* N. Corley, S. Mathis, R. Krishna, M. S. Bauer, T. R. Thompson, W. Ahern, M. W. Kazman, R. I. Brent, K. Didi, A. Kubaney, L. McHugh, A. Nagle, A. Favor, M. Kshirsagar, P. Sturmfels, Y. Li, J. Butcher, B. Qiang, L. L. Schaaf, R. Mitra, K. Campbell, O. Zhang, R. Weissman, I. R. Humphreys, Q. Cong, J. Funk, S. Sonthalia, P. Lio, D. Baker, F. DiMaio, -"Accelerating Biomolecular Modeling with AtomWorks and RF3," bioRxiv, August 2025. doi: [10.1101/2025.08.14.670328](https://doi.org/10.1101/2025.08.14.670328) +> N. Corley\*, S. Mathis\*, R. Krishna\*, M. S. Bauer, T. R. Thompson, W. Ahern, M. W. Kazman, R. I. Brent, K. Didi, A. Kubaney, L. McHugh, A. Nagle, A. Favor, M. Kshirsagar, P. Sturmfels, Y. Li, J. Butcher, B. Qiang, L. L. Schaaf, R. Mitra, K. Campbell, O. Zhang, R. Weissman, I. R. Humphreys, Q. Cong, J. Funk, S. Sonthalia, P. Lio, D. Baker, F. DiMaio, +> "Accelerating Biomolecular Modeling with AtomWorks and RF3," bioRxiv, August 2025. 
doi: [10.1101/2025.08.14.670328](https://doi.org/10.1101/2025.08.14.670328) + +If you use bibtex, here's the GoogleScholar formatted citation: + +```bibtex +@article{corley2025accelerating, + title={Accelerating Biomolecular Modeling with AtomWorks and RF3}, + author={Corley, Nathaniel and Mathis, Simon and Krishna, Rohith and Bauer, Magnus S and Thompson, Tuscan R and Ahern, Woody and Kazman, Maxwell W and Brent, Rafael I and Didi, Kieran and Kubaney, Andrew and others}, + journal={bioRxiv}, + pages={2025--08}, + year={2025}, + publisher={Cold Spring Harbor Laboratory} +} +``` \ No newline at end of file diff --git a/docs/_static/atomworks_glossary.png b/docs/_static/atomworks_glossary.png new file mode 100644 index 00000000..62ba544d Binary files /dev/null and b/docs/_static/atomworks_glossary.png differ diff --git a/docs/_static/examples/annotate_and_save_structures_01.png b/docs/_static/examples/annotate_and_save_structures_01.png new file mode 100644 index 00000000..4310fb9f Binary files /dev/null and b/docs/_static/examples/annotate_and_save_structures_01.png differ diff --git a/docs/_static/examples/annotate_and_save_structures_02.png b/docs/_static/examples/annotate_and_save_structures_02.png new file mode 100644 index 00000000..4bc44692 Binary files /dev/null and b/docs/_static/examples/annotate_and_save_structures_02.png differ diff --git a/docs/_static/examples/load_and_visualize_structures_01.png b/docs/_static/examples/load_and_visualize_structures_01.png new file mode 100644 index 00000000..2bd83234 Binary files /dev/null and b/docs/_static/examples/load_and_visualize_structures_01.png differ diff --git a/docs/_static/examples/pocket_conditioning_transform_01.png b/docs/_static/examples/pocket_conditioning_transform_01.png new file mode 100644 index 00000000..81a5b599 Binary files /dev/null and b/docs/_static/examples/pocket_conditioning_transform_01.png differ diff --git a/docs/examples/annotate_and_save_structures.py 
b/docs/examples/annotate_and_save_structures.py new file mode 100644 index 00000000..4ebc9a5e --- /dev/null +++ b/docs/examples/annotate_and_save_structures.py @@ -0,0 +1,225 @@ +""" +Annotating and Saving Protein Structures +========================================= + +This example walks through how to add custom annotations to AtomArrays, visualize them, and save them for later use. + +**Prerequisites**: Familiarity with :doc:`load_and_visualize_structures` for basic structure loading and exploration. + +.. figure:: /_static/examples/annotate_and_save_structures_01.png + :alt: Heme pocket visualization + :width: 400px + + Visualization of heme-binding pocket atoms (within 6Å of heme ligand) in myoglobin. +""" + +######################################################################## +# Setup and Structure Loading +# ---------------------------- +# +# Let's start by loading a protein structure that we'll annotate. We'll use the same myoglobin structure from the loading example: + +import os +import tempfile + +import biotite.structure as struc +import numpy as np + +from atomworks.io import parse +from atomworks.io.utils.io_utils import to_cif_file +from atomworks.io.utils.testing import get_pdb_path_or_buffer +from atomworks.io.utils.visualize import view + +# sphinx_gallery_thumbnail_path = '_static/examples/annotate_and_save_structures_01.png' + +# Load myoglobin structure with heme +example_pdb_id = "101m" # Myoglobin with heme +pdb_path = get_pdb_path_or_buffer(example_pdb_id) + +# Parse the structure (no need to add missing atoms, since we would just remove them in the following step) +atom_array = parse(pdb_path, add_missing_atoms=False, fix_formal_charges=False)["assemblies"]["1"][0] + +print(f"Loaded structure with {len(atom_array)} atoms") +print(f"Chains: {np.unique(atom_array.chain_id)}") + +# Clean up coordinates (remove any NaN values, if present) +# (NaN coordinates will break our later step when we create a CellList with Biotite) 
+valid_coords_mask = ~np.isnan(atom_array.coord).any(axis=1) +atom_array = atom_array[valid_coords_mask] +print(f"After removing NaN coordinates: {len(atom_array)} atoms") + +######################################################################## +# Adding Custom Annotations +# -------------------------- +# +# Now let's add custom annotations to mark different types of atoms. We'll use pocket identification as an example to demonstrate how to create meaningful structural annotations for many ML and general bioinformatics applications. +# +# Step 1: Identify Structural Features (Pocket Identification) +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Let's efficiently identify the heme-binding pocket using spatial distance cutoffs with Biotite's ``CellList`` class: + +# Find atoms within 6 Angstroms of the heme using a spatial cell list +cell_list = struc.CellList(atom_array.coord, cell_size=6.0) +heme_coords = atom_array.coord[atom_array.res_name == "HEM"] + +print(f"Found {len(heme_coords)} heme atoms") + +# Get all atoms within 6Å of any heme atom +pocket_mask = cell_list.get_atoms(heme_coords, 6.0, as_mask=True) +pocket_mask = np.any(pocket_mask, axis=0) # Combine results for all heme atoms + +print(f"Found {np.sum(pocket_mask)} atoms within 6Å of heme") + +# %% + +# Visualize the pocket region (always a helpful sanity-check, and trivial with AtomWorks) +print("\nVisualizing pocket region (all atoms within 6Å of heme):") +view(atom_array[pocket_mask]) + +######################################################################## +# .. 
figure:: /_static/examples/annotate_and_save_structures_01.png +# :alt: Heme pocket visualization + +######################################################################## +# Step 2: Create Annotations from Identified Features +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# Now we'll convert our pocket identification into an explicit ``AtomArray`` annotation and visualize it: + +# Boolean annotation for pocket residues (excluding heme itself) +is_pocket = pocket_mask & (atom_array.res_name != "HEM") +atom_array.set_annotation("is_hem_pocket", is_pocket.astype(bool)) + +# Boolean annotation for heme atoms +is_heme = atom_array.res_name == "HEM" +atom_array.set_annotation("is_heme", is_heme.astype(bool)) + +print(f" - Pocket atoms: {np.sum(atom_array.is_hem_pocket)}") +print(f" - Heme atoms: {np.sum(atom_array.is_heme)}") + +# %% + +# Visualize just the pocket residues +print("\nVisualizing annotated pocket residues:") +view(atom_array[atom_array.is_hem_pocket]) + +######################################################################## +# .. figure:: /_static/examples/annotate_and_save_structures_02.png +# :alt: Annotated pocket residues visualization + +######################################################################## +# Saving Annotated Structures +# ---------------------------- +# +# Now let's save our annotated structure. In many use cases we may want to save our modified ``AtomArray`` to disk and later load again, preserving our original annotations. +# +# AtomWorks provides two methods to do so: +# +# .. list-table:: +# :header-rows: 0 +# +# * - Saving to CIF, adding extra annotations directly into the file +# * - Standard Python object pickling (which may be sensitive to versions, libraries, etc.) +# +# Saving to CIF Files +# ~~~~~~~~~~~~~~~~~~~ +# +# CIF files are the standard for structural data and allow us to store arbitrary annotations and categories. 
+ +# Create temporary directory for our files +temp_dir = tempfile.mkdtemp() +print(f"Working in temporary directory: {temp_dir}") + +# Save to CIF file with custom annotations specified +cif_path = os.path.join(temp_dir, "annotated_structure.cif") +custom_fields = ["is_hem_pocket", "is_heme"] + +saved_cif_path = to_cif_file( + atom_array, + cif_path, + extra_fields=custom_fields, +) + +print(f"Saved CIF file to: {saved_cif_path}") +print(f"File size: {os.path.getsize(saved_cif_path) / 1024:.1f} KB") + +######################################################################## +# Note on Biological Assemblies and CIF Saving +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# In some cases, you may find that ``to_cif_file`` reports an error when the structure represents a biological assembly containing multiple copies of the asymmetric unit. The reason for this error is that ``AtomWorks`` builds the biological assembly and explicitly represents every atom; we can't then reverse that process since we may be left with ambiguous bond annotations (e.g., no way to distinguish between multiple copies of "Chain A"). The best solution is to either (a) set the ``chain_id`` to the ``chain_iid`` (which resolves the ambiguity) or (b) simply save the object using a pickle. +# +# More rigorous solutions exist; a helpful place for contributions! +# +# Alternative Storage Options +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# For Python-specific workflows, you can also save structures as pickle files to preserve exact data types, though CIF files are recommended for interoperability and long-term storage. + +######################################################################## +# Loading Annotated Structures +# ----------------------------- +# +# When we load pickled ``AtomArray``'s, we should restore our original object out-of-the-box with all annotations preserved. 
+# +# When loading from CIF, however, we may need to grapple with data type issues, since within CIF files all fields are considered strings. +# +# In the future, we would like to automatically detect annotation data types during loading (and/or allow specification of data types) - we would love contributions and a PR! +# +# Loading from CIF Files +# ~~~~~~~~~~~~~~~~~~~~~~ + +from atomworks.io.utils.io_utils import load_any + +# Load from CIF file +loaded_from_cif = load_any(saved_cif_path, extra_fields="all")[0] + +print("Loaded from CIF file:") +print(f" Atoms: {len(loaded_from_cif)}") +print(" Custom annotations:") +for annotation in loaded_from_cif.get_annotation_categories(): + if annotation in custom_fields: + dtype = getattr(loaded_from_cif, annotation).dtype + print(f" ✓ {annotation} ({dtype})") + +######################################################################## +# Handling Data Type Conversions +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# As we can see above, when boolean annotations are saved to CIF files, they become string representations ("True"/"False"). 
# Convert string booleans back to actual boolean type
def fix_boolean_annotation(atom_array: struc.AtomArray, annotation_name: str) -> struc.AtomArray:
    """Convert a boolean annotation that was read back as strings into ``bool`` dtype.

    Every value is rendered as a string and compared against the literal
    ``"True"``. This makes the function idempotent: it gives the same result
    for the ``"True"``/``"False"`` strings read from a CIF file and for an
    annotation that is already boolean. Any value other than ``"True"`` maps
    to ``False``.

    Args:
        atom_array: Structure carrying the annotation; it is modified in place.
        annotation_name: Name of the annotation category to convert.

    Returns:
        The same ``atom_array`` object, so the call can be used in an assignment.

    Raises:
        AttributeError: If ``atom_array`` has no attribute ``annotation_name``.
    """
    raw_values = np.asarray(getattr(atom_array, annotation_name))
    # Render to str before comparing: comparing an already-boolean array with
    # the string "True" would mark every atom as False and silently wipe the
    # annotation on a second call.
    boolean_values = raw_values.astype(str) == "True"
    # NOTE(review): the annotation is deleted before being re-added, presumably
    # so the new bool array is not coerced back to the stored string dtype —
    # confirm against the biotite version in use.
    atom_array.del_annotation(annotation_name)
    atom_array.set_annotation(annotation_name, boolean_values)
    return atom_array
+ +Similar content is covered in :doc:`load_and_visualize_structures`, but here we provide a more concise overview. +""" + +######################################################################## +# Loading from mmCIF +# ================== +# +# We start by loading a structure from an mmCIF file using ``atomworks.io.parse``. This function supports various file formats and allows you to specify options such as which assembly to build, whether to add missing atoms, and more. + +from atomworks.io import parse +from atomworks.io.utils.testing import get_pdb_path_or_buffer +from atomworks.io.utils.visualize import view + +result_dict = parse( + filename=get_pdb_path_or_buffer("6lyz"), + build_assembly=["1"], + add_missing_atoms=True, + remove_waters=True, + hydrogen_policy="remove", + model=1, +) + +print("Keys in parsed result:", list(result_dict.keys())) + +######################################################################## +# Visualizing the asymmetric unit and assembly +# ============================================= +# +# You can visualize the asymmetric unit or any assembly using the built-in viewer from ``atomworks.io``. This is helpful for quickly inspecting the structure and its components. + +asym_unit = result_dict["asym_unit"][0] +asym_unit = asym_unit[asym_unit.occupancy > 0] +view(asym_unit) + +######################################################################## + +assembly = result_dict["assemblies"]["1"][0] +assembly = assembly[assembly.occupancy > 0] +view(assembly) + +######################################################################## +# Inspecting structure metadata +# ============================= +# +# The parsed result contains rich metadata, including chain and ligand information, as well as annotation categories. This information is useful for downstream analysis and filtering. 
+ +print("Chain info:", result_dict["chain_info"]) +print("Ligand info:", result_dict["ligand_info"]) +print("Metadata:", result_dict["metadata"]) +print("Annotation categories:", result_dict["asym_unit"][0].get_annotation_categories()) + +######################################################################## +# Manipulating AtomArray +# ====================== +# +# You can easily extract coordinates for specific atoms or chains, and inspect bond information. This is useful for custom analysis or feature extraction. + +ca = assembly[(assembly.atom_name == "CA") & (assembly.occupancy > 0)] +print("Coordinates of all resolved CA atoms:", ca.coord.shape) + +chain = assembly[assembly.chain_id == "A"] +print("Coordinates of chain A (all heavy atoms):", chain.coord.shape) + +print("Bond array:", assembly.bonds.as_array()) + +######################################################################## +# Distance computations +# ===================== +# +# ``biotite.structure`` provides convenient functions for distance calculations between atoms or sets of atoms. Here we compute distances between C-alpha atoms. + +import biotite.structure as struc + +distance = struc.distance(ca.coord[0], ca.coord[1]) +print(f"Distance between first two C-alpha atoms: {distance:.2f} Å") + +distance = struc.distance(ca[0], ca) +print(f"Distances between first C-alpha atom and all other C-alpha atoms: {distance}") + +######################################################################## +# Efficient neighbor search with CellList +# ======================================== +# +# For efficient spatial queries, use ``CellList`` to find atoms within a certain radius. This is useful for contact analysis and neighborhood queries. 
+ +resolved_atom_array = assembly[assembly.occupancy > 0] +cell_list = struc.CellList(resolved_atom_array, cell_size=5.0) + +near_atoms = cell_list.get_atoms(resolved_atom_array[0].coord, radius=4) +print(f"Number of atoms within 4 Å of the first atom: {near_atoms.shape[0]}") +print(f"Atom indices: {near_atoms}") +print(f"Chain IDs: {resolved_atom_array.chain_id[near_atoms]}") +print(f"Residue IDs: {resolved_atom_array.res_id[near_atoms]}") +print(f"Residue names: {resolved_atom_array.res_name[near_atoms]}") + +######################################################################## +# Next Steps +# ---------- +# +# Now that you've learned the basics of structure manipulation, you can explore more advanced topics: +# +# - :doc:`load_and_visualize_structures` - Comprehensive guide to loading and exploring protein structures +# - :doc:`annotate_and_save_structures` - Learn how to add custom annotations to structures and save them +# - :doc:`pocket_conditioning_transform` - Create custom transforms for ligand pocket identification and ML feature generation diff --git a/docs/examples/load_and_visualize_structures.py b/docs/examples/load_and_visualize_structures.py new file mode 100644 index 00000000..09aa358e --- /dev/null +++ b/docs/examples/load_and_visualize_structures.py @@ -0,0 +1,163 @@ +""" +Loading and Visualizing Protein Structures +=========================================== + +This example demonstrates how to load protein structures from various formats and explore their content using AtomWorks. + +.. figure:: /_static/examples/load_and_visualize_structures_01.png + :alt: Myoglobin structure visualization + :width: 400px + + Interactive 3D visualization of myoglobin structure showing protein chains and heme ligand. 
+""" + +######################################################################## +# Loading Structures +# ================== +# +# AtomWorks provides two main functions for loading structures, each optimized for different scenarios: +# +# - ``parse()``: Full processing pipeline that cleans, validates, and processes structures, typically from the RCSB PDB. Includes imputing missing atoms, inferring bonds, and extensive validation. +# - ``load_any()``: Lightweight loader for structures that do not require as extensive processing, e.g., distillation examples. Much faster when you don't need the full cleaning pipeline or missing atoms imputed. +# +# If you see output like ``Environment variable CCD_MIRROR_PATH`` or ``PDB_MIRROR_PATH`` not set, don't worry - it just means we aren't using local copies of the PDB and/or CCD (we can still load the examples we need with an internet connection). + +import numpy as np + +from atomworks.io import parse +from atomworks.io.utils.io_utils import load_any +from atomworks.io.utils.testing import get_pdb_path_or_buffer +from atomworks.io.utils.visualize import view + +# sphinx_gallery_thumbnail_path = '_static/examples/load_and_visualize_structures_01.png' + +# Load a myoglobin structure (SPERM WHALE MYOGLOBIN F46V N-BUTYL ISOCYANIDE AT PH 9.0) +example_pdb_id = "101m" +pdb_path = get_pdb_path_or_buffer(example_pdb_id) + +######################################################################## +# Using ``parse()`` for Full Processing +# ------------------------------------- +# +# For RCSB structures, we typically load structures with ``parse()`` to get clean data suitable for most downstream tasks. +# +# There are many arguments that control how the structure is processed upon parsing; see the API documentation for more detail. +# A few are: +# - ``remove_waters``: Whether to remove water molecules (True by default) +# - ``remove_ccds``: CCD codes to filter out (Default is a list of common crystallization aids, e.g., GOL, SO4, etc.) 
+# - ``add_missing_atoms``: Whether to add missing (e.g., unresolved) heavy atoms (True by default) +# - ``hydrogen_policy``: How to handle hydrogens (e.g., "keep", "remove", or "infer"). Default is "keep". +# ... and many more! + +# ``parse`` returns a dictionary with several data fields; see the API docs for full details. +# The loaded assembly information is stored in the "assemblies" key, which we use in the example below. + +parse_output = parse(pdb_path) + +print("Available data keys:", list(parse_output.keys())) + +######################################################################## +# Using ``load_any()`` for Lightweight Loading +# ------------------------------------------- + +# For comparison: load_any() for lightweight loading (no extensive processing) +# Useful when you have clean data (e.g., from distillation) and/or want to preserve all annotations + +loaded_structure = load_any(pdb_path, extra_fields="all") # Load with all available fields +print(f"load_any result type: {type(loaded_structure)}") +print(f"Number of models: {len(loaded_structure)}") + +# NOTE: load_any returns an AtomArrayStack directly, while parse returns a dictionary with metadata, chain info, assemblies, etc. + +######################################################################## +# Structure Visualization +# ----------------------- +# +# AtomWorks includes built-in 3D visualization capabilities. Let's extract the biological assembly and explore the structure: + +# Extract the biological assembly (first assembly, first model) +atom_array = parse_output["assemblies"]["1"][0] + +# Explore available annotations +print("Available annotations:") +annotations = atom_array.get_annotation_categories() +for i, annotation in enumerate(annotations): + print(f" {i+1:2d}. {annotation}") + + +# %% + +# Visualize the complete structure within an interactive viewer +view(atom_array) + +######################################################################## +# .. 
figure:: /_static/examples/load_and_visualize_structures_01.png +# :alt: Myoglobin structure visualization + +######################################################################## +# Understanding Assemblies vs Asymmetric Units +# --------------------------------------------- +# +# The RCSB PDB draws a distinction between asymmetric units and biological assemblies; see the `RCSB PDB 101 Guide <https://pdb101.rcsb.org/learn/guide-to-understanding-pdb-data/biological-assemblies>`_ for more information. +# The ``parse()`` function returns both asymmetric units and biological assemblies. Let's explore the difference: + +# Compare asymmetric unit vs assembly +asym_unit = parse_output["asym_unit"][0] # First model of asymmetric unit +assembly = parse_output["assemblies"]["1"][0] # First model of first assembly + +print(f"Asymmetric unit atoms: {len(asym_unit)}") +print(f"Assembly atoms: {len(assembly)}") +print(f"\nFor this structure, they are {'the same' if len(asym_unit) == len(assembly) else 'different'}") + +# Show available assemblies +print(f"\nAvailable assemblies: {list(parse_output['assemblies'].keys())}") + +######################################################################## +# Data Exploration +# ---------------- +# +# Let's now explore the structure composition by examining chains, residues, and other annotations: + +# Examine chain composition +unique_chains = np.unique(atom_array.chain_id) +print(f"Chains present: {unique_chains}") + +# Analyze what each chain contains +for chain in unique_chains: + chain_mask = atom_array.chain_id == chain + unique_residues = np.unique(atom_array.res_name[chain_mask]) + print(f"\nChain {chain}: {len(unique_residues)} unique residue types") + print(f" Examples: {unique_residues[:5]}") # Show first 5 residue types + print(f" Total atoms: {np.sum(chain_mask)}") + +######################################################################## +# Exploring Metadata and Chain Information +# ----------------------------------------- +# +# The ``parse()`` function also extracts rich metadata about the structure 
from the RCSB: + +# Explore metadata +metadata = parse_output["metadata"] +print("Structure metadata:") +for key, value in metadata.items(): + if key != "parse_arguments": # Skip the verbose parse arguments + print(f" {key}: {value}") + +# Explore chain information for Chain A +chain_a_info = parse_output["chain_info"].get("A", {}) +print("\nChain A information:") +for key, value in chain_a_info.items(): + # Show only a preview for long lists or strings + if isinstance(value, str | list): + preview = value[:15] + suffix = "..." if len(value) > 15 else "" + print(f" {key}: '{preview}{suffix}'") + else: + print(f" {key}: {value}") + +######################################################################## +# Related Examples +# ---------------- +# +# - :doc:`annotate_and_save_structures` - Learn how to add custom annotations to structures and save them for later use +# - :doc:`pocket_conditioning_transform` - Create custom transforms for ligand pocket identification and ML feature generation diff --git a/docs/examples/plot_basics.py b/docs/examples/plot_basics.py deleted file mode 100644 index 44689189..00000000 --- a/docs/examples/plot_basics.py +++ /dev/null @@ -1,118 +0,0 @@ -# %% [markdown] -""" -CIFUtils: User Example -====================== - -This example demonstrates how to use the `atomworks.io` package to load, inspect, and manipulate mmCIF structure files. You'll see how to parse a structure, visualize it, and perform basic analysis. -""" - -# %% [markdown] -""" -## 1.1 Loading from mmCIF - -We start by loading a structure from an mmCIF file using `atomworks.io.parse`. This function supports various file formats and allows you to specify options such as which assembly to build, whether to add missing atoms, and more. 
-""" - -import io - -from biotite.database import rcsb - -from atomworks.io import parse -from atomworks.io.utils.testing import get_pdb_path -from atomworks.io.utils.visualize import view - - -def get_example_path_or_buffer(pdb_id: str) -> str | io.StringIO: - try: - # ... if file is locally available - return get_pdb_path(pdb_id) - except FileNotFoundError: - # ... otherwise, fetch the file from RCSB - return rcsb.fetch(pdb_id, format="cif") - - -result_dict = parse( - filename=get_example_path_or_buffer("6lyz"), - build_assembly=["1"], - add_missing_atoms=True, - remove_waters=True, - hydrogen_policy="remove", - model=1, -) - -print("Keys in parsed result:", list(result_dict.keys())) - -# %% [markdown] -""" -## 1.2 Visualizing the asymmetric unit and assembly - -You can visualize the asymmetric unit or any assembly using the built-in viewer from `atomworks.io`. This is helpful for quickly inspecting the structure and its components. -""" - - -asym_unit = result_dict["asym_unit"][0] -asym_unit = asym_unit[asym_unit.occupancy > 0] -view(asym_unit) - -assembly = result_dict["assemblies"]["1"][0] -assembly = assembly[assembly.occupancy > 0] -view(assembly) - -# %% [markdown] -""" -## 1.3 Inspecting structure metadata - -The parsed result contains rich metadata, including chain and ligand information, as well as annotation categories. This information is useful for downstream analysis and filtering. -""" - -print("Chain info:", result_dict["chain_info"]) -print("Ligand info:", result_dict["ligand_info"]) -print("Metadata:", result_dict["metadata"]) -print("Annotation categories:", result_dict["asym_unit"][0].get_annotation_categories()) - -# %% [markdown] -""" -## 1.4 Manipulating AtomArray - -You can easily extract coordinates for specific atoms or chains, and inspect bond information. This is useful for custom analysis or feature extraction. 
-""" - -ca = assembly[(assembly.atom_name == "CA") & (assembly.occupancy > 0)] -print("Coordinates of all resolved CA atoms:", ca.coord.shape) - -chain = assembly[assembly.chain_id == "A"] -print("Coordinates of chain A (all heavy atoms):", chain.coord.shape) - -print("Bond array:", assembly.bonds.as_array()) - -# %% [markdown] -""" -## 1.5 Distance computations - -`biotite.structure` provides convenient functions for distance calculations between atoms or sets of atoms. Here we compute distances between C-alpha atoms. -""" - -import biotite.structure as struc - -distance = struc.distance(ca.coord[0], ca.coord[1]) -print(f"Distance between first two C-alpha atoms: {distance:.2f} Å") - -distance = struc.distance(ca[0], ca) -print(f"Distances between first C-alpha atom and all other C-alpha atoms: {distance}") - -# %% [markdown] -""" -## 1.6 Efficient neighbor search with CellList - -For efficient spatial queries, use `CellList` to find atoms within a certain radius. This is useful for contact analysis and neighborhood queries. -""" - -resolved_atom_array = assembly[assembly.occupancy > 0] -cell_list = struc.CellList(resolved_atom_array, cell_size=5.0) - -near_atoms = cell_list.get_atoms(resolved_atom_array[0].coord, radius=4) -print(f"Number of atoms within 7 Å of the first atom: {near_atoms.shape[0]}") -print(f"Atom indices: {near_atoms}") -print(f"Chain IDs: {resolved_atom_array.chain_id[near_atoms]}") -print(f"Residue IDs: {resolved_atom_array.res_id[near_atoms]}") -print(f"Residue names: {resolved_atom_array.res_name[near_atoms]}") diff --git a/docs/examples/plot_monomer.py b/docs/examples/plot_monomer.py deleted file mode 100644 index 02159747..00000000 --- a/docs/examples/plot_monomer.py +++ /dev/null @@ -1,51 +0,0 @@ -""" -Plot a Monomer Structure -======================== - -This example demonstrates how to parse a monomer PDB file and plot the backbone trace using atomworks.io. - -.. 
figure:: /_static/best_practices_mental_models.png - :alt: Custom thumbnail for this example - :width: 300px - - Custom thumbnail image for this example. - -**Key steps:** -- Load a PDB file -- Visualize the backbone trace - -""" - -# %% [markdown] -# ## Import libraries - -import io - -from biotite.database import rcsb - -from atomworks.io.utils.io_utils import load_any -from atomworks.io.utils.testing import get_pdb_path -from atomworks.io.utils.visualize import view - - -def get_example_path_or_buffer(pdb_id: str) -> io.StringIO | str: - try: - return get_pdb_path(pdb_id) - except FileNotFoundError: - return rcsb.fetch(pdb_id, format="cif") - - -# %% [markdown] -# ## Load and plot the structure - -example = get_example_path_or_buffer("6lyz") # e.g. '/path/to/6lyz.cif' or io.StringIO(rcsb.fetch("6lyz", "cif")) -atom_array = load_any(example, model=1, extra_fields=["charge", "occupancy"]) - -# ... inspect the first 15 atoms -print(f"Structure has {atom_array.array_length()} atoms. First 15 atoms:") -atom_array[:15] - -# %% - -# ... show the structure in a jupyter notebook -view(atom_array[atom_array.chain_id == "A"]) diff --git a/docs/examples/pocket_conditioning_transform.py b/docs/examples/pocket_conditioning_transform.py new file mode 100644 index 00000000..668edcf7 --- /dev/null +++ b/docs/examples/pocket_conditioning_transform.py @@ -0,0 +1,310 @@ +""" +Creating Custom Transforms: Ligand Pocket Conditioning +====================================================== + +This example demonstrates how to create custom Transform classes in AtomWorks using ligand pocket identification as an example. We'll build two transforms that follow AtomWorks conventions. + +**Prerequisites**: Familiarity with :doc:`load_and_visualize_structures` and :doc:`annotate_and_save_structures` for basic structure handling and annotation techniques. + +.. 
figure:: /_static/examples/pocket_conditioning_transform_01.png + :alt: Ligand pocket visualization +""" + +######################################################################## +# Transform Architecture and Design Patterns +# =========================================== +# +# AtomWorks Transform classes follow a standard pattern with one required method - ``forward()`` - and several optional methods/attributes to promote interoperability and pipeline compatibility. +# +# Required Method +# --------------- +# - ``forward()``: The only mandatory method. Takes a state dictionary and returns an updated dictionary. +# +# Optional Methods & Attributes +# ----------------------------- +# - ``check_input()``: Validates input data (annotations, types, etc.), raising informative errors if conditions are violated +# - ``requires_previous_transforms``: List of ``Transforms`` that MUST run within the pipeline prior to this ``Transform`` +# - ``incompatible_previous_transforms``: List of ``Transforms`` that CANNOT have been run within the pipeline prior to this ``Transform`` +# +# Conventions +# ----------- +# **A.** Store information in ``AtomArray`` annotations, not in the state dictionary. +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# This ensures robustness when atoms are added/removed downstream. +# +# For the example below: +# +# - ✅ Add ``is_pocket_atom`` annotation to AtomArray +# - ❌ Store ``pocket_atom_indices`` in dictionary (which creates significant dependencies with operations that delete or re-order atoms) +# +# **B.** Within ``forward()``, call a stand-alone function with the same name as the transform class. +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# We thus maintain an object-oriented and a functional API, making our core logic re-usable and testable outside of the ``Transform`` framework. 
+# +# For the example below: +# +# - ``AnnotateLigandPockets.forward()`` calls ``annotate_ligand_pockets()`` function +# - ``FeaturizePocketAtoms.forward()`` calls ``featurize_pocket_atoms()`` function +# +# Additionally, this function should preserve the input (e.g., not modify the underlying ``AtomArray``) and take as arguments any necessary parameters. +# +# **C.** Each ``Transform`` should follow the single-responsibility-principle; in particular separate Annotation from Featurization ``Transforms`` +# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ +# +# To ensure our ``Transform`` code is maximally forward-compatible and re-usable across disparate pipelines, we adhere to the single responsibility principle - that is, each transform should do *exactly one* action. +# +# For the example below: +# +# - ``AnnotateLigandPockets`` only identifies and annotates pocket atoms +# - ``FeaturizePocketAtoms`` only converts existing annotations to numeric features +# +# Now, if a different model wants to perform an action on small molecule pockets, but with a different featurization scheme, the researchers would simply need to write a different ``Featurize`` Transform leveraging the existing annotations. 
+ +import biotite.structure as struc +import numpy as np +from biotite.structure import AtomArray + +# AtomWorks imports +from atomworks.io import parse +from atomworks.io.utils.testing import get_pdb_path_or_buffer +from atomworks.io.utils.visualize import view +from atomworks.ml.transforms._checks import check_atom_array_annotation +from atomworks.ml.transforms.base import Transform + +# sphinx_gallery_thumbnail_path = '_static/examples/pocket_conditioning_transform_01.png' + +# Load example structure (myoglobin with heme ligand; our recurring test example) +example_pdb_id = "101m" +pdb_path = get_pdb_path_or_buffer(example_pdb_id) + +parse_output = parse(pdb_path) +atom_array = parse_output["assemblies"]["1"][0] + +print(f"Loaded structure: {len(atom_array)} atoms") +print(f"Non-polymer residues: {np.unique(atom_array.res_name[~atom_array.is_polymer])}") +print(f"Heme atoms: {np.sum(atom_array.res_name == 'HEM')}") + +######################################################################## +# Building ``AnnotateLigandPockets`` +# =============================== +# +# Let's create a ``Transform`` that identifies atoms near ligands (non-polymer molecules) of sufficient size. +# +# Observe how we follow the conventions outlined above: +# +# - Stores results as ``AtomArray`` annotation rather than returning indices or masks separately. +# - Does not modify input ``AtomArray`` in place. +# - Function name mimics ``Transform`` class name for clarity. +# - Accepts all parameters as arguments. + + +def annotate_ligand_pockets( + atom_array: AtomArray, + pocket_distance: float = 6.0, + n_min_ligand_atoms: int = 5, + annotation_name: str = "is_ligand_pocket", +) -> AtomArray: + """ + Identify atoms near ligands of sufficient size. 
+ + Args: + atom_array: Input structure + pocket_distance: Distance threshold for pocket identification (Angstroms) + n_min_ligand_atoms: Minimum atoms required for a ligand (across the full pn_unit) to define pockets + annotation_name: Name for the boolean annotation + + Returns: + AtomArray with ligand pocket annotation added + """ + atom_array = atom_array.copy() # By convention, do not modify input in place + + # Find all ligand pn_unit_iids within our structure and their atom counts + # We make use of the pn_unit_iid annotation, which is most applicable for ligands, elegantly + # handling cases of multi-residue or multi-chain small molecules (e.g., many sugars) + # See the Glossary for more information regarding our naming conventions within AtomWorks + ligand_pn_unit_iids, ligand_counts = np.unique(atom_array.pn_unit_iid[~atom_array.is_polymer], return_counts=True) + + # Filter to only ligands with sufficient size + valid_ligand_mask = ligand_counts >= n_min_ligand_atoms + valid_ligand_pn_unit_iids = ligand_pn_unit_iids[valid_ligand_mask] + + # Initialize pocket annotation + pocket_annotation = np.zeros(len(atom_array), dtype=bool) + + if len(valid_ligand_pn_unit_iids) == 0: + # No valid ligands found - store empty annotation and return + atom_array.set_annotation(annotation_name, pocket_annotation) + return atom_array + + # Build CellList for efficient distance computations on CPU + # (Atoms with invalid coordinates would break our distance search) + valid_coords_mask = ~np.isnan(atom_array.coord).any(axis=1) + assert np.any(valid_coords_mask), "No valid coordinates found" + + valid_coords = atom_array.coord[valid_coords_mask] + cell_list = struc.CellList(valid_coords, cell_size=pocket_distance) + + # Get coordinates of all valid ligands + all_valid_ligands_mask = np.isin(atom_array.pn_unit_iid, valid_ligand_pn_unit_iids) + all_ligand_coords = atom_array.coord[all_valid_ligands_mask] + + # Find atoms within distance of any ligand coordinates (all at once) + 
distance_mask = cell_list.get_atoms(all_ligand_coords, pocket_distance, as_mask=True) + near_ligand_valid = np.any(distance_mask, axis=0) + + # Map back to full atom array + near_ligand_full = np.zeros(len(atom_array), dtype=bool) + near_ligand_full[valid_coords_mask] = near_ligand_valid + + # Only polymer atoms can be pocket atoms + pocket_annotation = atom_array.is_polymer & near_ligand_full + + # Store result as annotation (AtomWorks convention) + atom_array.set_annotation(annotation_name, pocket_annotation) + return atom_array + + +class AnnotateLigandPockets(Transform): + """Identify atoms near ligands of sufficient size.""" + + def __init__( + self, pocket_distance: float = 6.0, n_min_ligand_atoms: int = 5, annotation_name: str = "is_ligand_pocket" + ): + self.pocket_distance = pocket_distance + self.n_min_ligand_atoms = n_min_ligand_atoms + self.annotation_name = annotation_name + + def check_input(self, data: dict) -> None: + """Validate input has required annotations. (Optional method)""" + check_atom_array_annotation(data, ["is_polymer", "pn_unit_iid"]) + + def forward(self, data: dict) -> dict: + """Apply ligand pocket annotation. 
(Required method)""" + # Follow forward/function pattern: call standalone function + data["atom_array"] = annotate_ligand_pockets( + data["atom_array"], + pocket_distance=self.pocket_distance, + n_min_ligand_atoms=self.n_min_ligand_atoms, + annotation_name=self.annotation_name, + ) + return data + + +######################################################################## + +# Test the functional version +result_array = annotate_ligand_pockets( + atom_array, pocket_distance=6.0, n_min_ligand_atoms=5, annotation_name="is_ligand_pocket" +) + +# Here, we are using AtomWorks' "query" syntax for convenience, which operates similarly to Pandas DataFrame queries +# Please see the API documentation for more details +view(result_array.query("is_ligand_pocket | (res_name == 'HEM')")) + +######################################################################## +# .. figure:: /_static/examples/pocket_conditioning_transform_01.png +# :alt: Ligand pocket visualization + +######################################################################## +# Building ``FeaturizePocketAtoms`` +# ================================= +# +# Now let's create a model-specific transform that converts derived pocket annotations into numeric features. +# +# Here, we also demonstrate the use of: +# - **``requires_previous_transforms``**: Ensures dependency ordering in pipelines +# - **``check_atom_array_annotation()``**: Validates required annotations using AtomWorks utilities +# +# We can imagine varying this featurization ``Transform`` across models while keeping the original annotation ``Transform`` constant. + + +def featurize_pocket_atoms(atom_array: AtomArray, pocket_annotation_name: str = "is_ligand_pocket") -> dict: + """ + Create one-hot encoded features from pocket annotations. 
+ + Args: + atom_array: Structure with pocket annotations + pocket_annotation_name: Name of the pocket boolean annotation + + Returns: + Dictionary with feature array and metadata + """ + pocket_mask = getattr(atom_array, pocket_annotation_name) + + # Create one-hot encoded feature: 0.0 for non-pocket, 1.0 for pocket atoms + features = pocket_mask.astype(np.float32).reshape(-1, 1) + + return {"features": features, "feature_names": ["is_pocket_atom"], "n_atoms": len(atom_array)} + + +class FeaturizePocketAtoms(Transform): + """Convert pocket annotations into one-hot encoded numeric features.""" + + requires_previous_transforms = ["AnnotateLigandPockets"] # noqa: RUF012 + + def __init__(self, pocket_annotation_name: str = "is_ligand_pocket", feature_key: str = "pocket_features"): + self.pocket_annotation_name = pocket_annotation_name + self.feature_key = feature_key + + def check_input(self, data: dict) -> None: + """Validate input has pocket annotations using AtomWorks utility.""" + check_atom_array_annotation(data, [self.pocket_annotation_name]) + + def forward(self, data: dict) -> dict: + """Generate features following the forward/function pattern.""" + data[self.feature_key] = featurize_pocket_atoms( + data["atom_array"], pocket_annotation_name=self.pocket_annotation_name + ) + return data + + +######################################################################## + +# Test featurization using a proper pipeline +# First apply the annotation transform, then the featurization +annotator = AnnotateLigandPockets(pocket_distance=6.0, n_min_ligand_atoms=5) +featurizer = FeaturizePocketAtoms() + +# Apply both transforms in sequence +data = {"atom_array": atom_array} +annotated_data = annotator(data) +feature_result = featurizer(annotated_data) + +features = feature_result["pocket_features"] +print(f"Generated features: {features['features'].shape}") +print(f"Feature names: {features['feature_names']}") +print(f"Feature type: {type(features['features'])}") 
+print(f"Pocket atoms (sum): {features['features'].sum():.0f}") +print(f"Non-pocket atoms: {len(features['features']) - features['features'].sum():.0f}") + +######################################################################## +# Pipeline Composition +# ==================== +# +# Transform composition allows chaining transforms together with automatic dependency checking: + +from atomworks.ml.transforms.base import Compose + +# Create a complete ligand pocket processing pipeline +ligand_pocket_pipeline = Compose( + [ + AnnotateLigandPockets(pocket_distance=6.0, n_min_ligand_atoms=3), + FeaturizePocketAtoms(feature_key="pocket_features"), + ] +) + +# Apply pipeline to fresh data +fresh_data = {"atom_array": atom_array} +pipeline_result = ligand_pocket_pipeline(fresh_data) + +print("Pipeline Results:") +print(f" Transforms applied: {[t.__class__.__name__ for t in ligand_pocket_pipeline.transforms]}") +print(f" Pocket atoms found: {np.sum(pipeline_result['atom_array'].is_ligand_pocket)}") +print(f" Features shape: {pipeline_result['pocket_features']['features'].shape}") + +# Demonstrate the + operator +alternative_pipeline = AnnotateLigandPockets(n_min_ligand_atoms=8) + FeaturizePocketAtoms() +alt_result = alternative_pipeline({"atom_array": atom_array}) +print(f" Alternative (min 8 atoms): {np.sum(alt_result['atom_array'].is_ligand_pocket)} pocket atoms") diff --git a/docs/glossary.rst b/docs/glossary.rst index 3e8aabc6..5d4bb62e 100644 --- a/docs/glossary.rst +++ b/docs/glossary.rst @@ -47,4 +47,10 @@ Molecules ~~~~~~~~~ - `molecule_id`: 1 (numeric for memory concerns, but can be conceptualized as "A,B,C") - `molecule_iid`: 1 (numeric for memory concerns, but can be conceptualized as "A_1,B_1,C_1"), 2 (e.g., "A_2,B_2,C_2") -- `molecule_entity`: 1 \ No newline at end of file +- `molecule_entity`: 1 + +Visually, we can represent the above example as: + +.. 
image:: _static/atomworks_glossary.png + :alt: Visual representation of AtomWorks combinatorial nomenclature worked example + :align: center \ No newline at end of file diff --git a/docs/sg_execution_times.rst b/docs/sg_execution_times.rst index 4c2f6533..407a5f96 100644 --- a/docs/sg_execution_times.rst +++ b/docs/sg_execution_times.rst @@ -6,7 +6,7 @@ Computation times ================= -**00:05.533** total execution time for 2 files **from all galleries**: +**00:03.491** total execution time for 6 files **from all galleries**: .. container:: @@ -32,9 +32,21 @@ Computation times * - Example - Time - Mem (MB) - * - :ref:`sphx_glr_auto_examples_plot_monomer.py` (``examples/plot_monomer.py``) - - 00:05.076 - - 0.0 * - :ref:`sphx_glr_auto_examples_plot_basics.py` (``examples/plot_basics.py``) - - 00:00.457 + - 00:03.491 + - 0.0 + * - :ref:`sphx_glr_auto_examples_annotate_and_save_structures.py` (``examples/annotate_and_save_structures.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_basics.py` (``examples/basics.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_load_and_visualize_structures.py` (``examples/load_and_visualize_structures.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_mpnn_pipeline.py` (``examples/mpnn_pipeline.py``) + - 00:00.000 + - 0.0 + * - :ref:`sphx_glr_auto_examples_pocket_conditioning_transform.py` (``examples/pocket_conditioning_transform.py``) + - 00:00.000 - 0.0 diff --git a/pyproject.toml b/pyproject.toml index 14e92393..22181121 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,30 +27,27 @@ dependencies = [ "cytoolz>=0.12.3,<1", # Cython-optimized tools for itertools and functional programming "tqdm>=4.65.0,<5", # Fast, extensible progress bar for loops and more # ... CLI & config management - "fire>=0.6.0,<1", # Argument parsing (legacy) "typer>=0.12.5,<1", # Modern CLI framework # ... 
linear algebra, maths & ml - "numpy>=1.25.0,<2", # TODO: Enable numpy 2.x + "numpy>=1.25.0,<3", "scipy>=1.13.1,<2", # ... data tools - "pandas>=2.2,<2.3", # Data manipulation and analysis # TODO: Test upper bound + "pandas>=2.2,<2.4", # Data manipulation and analysis "pyarrow==17.0.0", # Columnar data format for efficient data storage and processing # TODO: Test later versions - "fastparquet==2024.5.0", # Fast Parquet file format implementation # TODO: Test if still needed # ... bioinformatics "py3Dmol>=2.2.1,<3", # Python wrapper for 3Dmol.js "pymol-remote>=0.0.5", # Remote access to PyMOL from Python (has no dependencies) - "biotite==1.3.0", # Biotite is a Python library for bioinformatics # TODO: Test newer versions - "hydride==1.2.3", # Biotite-supported tool for hydrogen addition + "biotite>=1.3.0,<2", # Biotite is a Python library for bioinformatics # TODO: Test newer versions + "hydride>=1.2.3,<2", # Biotite-supported tool for hydrogen addition # ... small molecule libraries - "rdkit>=2024.3.5", - + "rdkit>=2024.3.5,<2025.9", ] [project.optional-dependencies] ml = [ # atomworks-ml dependencies - "torch==2.7.0", - "einops==0.7.0", + "torch>=2.2.0,<2.8", + "einops>=0.7.0,<1", ] openbabel = [ @@ -61,10 +58,6 @@ openbabel = [ dev = [ # Linters & formatters "ruff==0.8.3", - "pre-commit==3.7.1", - # Debugger/interactive - "debugpy>=1.8.5,<2", - "ipykernel>=6.29.4,<7", # Testing tools "pytest>=8.2.0,<9", # testing framework "pytest-testmon>=2.1.1,<3", # run only tests related to changed code @@ -94,13 +87,19 @@ repository = "https://github.com/RosettaCommons/atomworks" documentation = "https://baker-laboratory.github.io/atomworks-dev/latest" [project.scripts] -atomworks = "atomworks.cli.__main__:main" +atomworks = "atomworks_cli.__main__:main" # Build settings ---------------------------------------------------------------------- [build-system] requires = ["hatchling", "hatch-vcs"] build-backend = "hatchling.build" +[tool.hatch.build] +packages = [ + 
"src/atomworks", + "src/atomworks_cli", +] + [tool.hatch.metadata] allow-direct-references = true dynamic = false # optional – this disables dynamic metadata guessing diff --git a/src/atomworks/cli/__main__.py b/src/atomworks/cli/__main__.py deleted file mode 100644 index 3c1a6d61..00000000 --- a/src/atomworks/cli/__main__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""Entry point for the AtomWorks command-line interface.""" - -from __future__ import annotations - -from . import app - - -def main() -> None: - app() - - -if __name__ == "__main__": - main() diff --git a/src/atomworks/io/common.py b/src/atomworks/common.py similarity index 69% rename from src/atomworks/io/common.py rename to src/atomworks/common.py index 57cb92eb..89a8d94b 100644 --- a/src/atomworks/io/common.py +++ b/src/atomworks/common.py @@ -1,9 +1,8 @@ -from __future__ import annotations +"""Common functions used throughout the project.""" import copy import hashlib -from collections import OrderedDict -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable from functools import lru_cache, wraps from typing import Any @@ -12,26 +11,29 @@ def exists(obj: Any) -> bool: + """Check that `obj` is not `None`.""" return obj is not None def default(obj: Any, default: Any) -> Any: + """Return `obj` if not `None`, otherwise return `default`.""" return obj if exists(obj) else default -def deduplicate_iterator(iterator: Iterable) -> Iterator: - """Deduplicate an iterator while preserving order.""" - return iter(OrderedDict.fromkeys(iterator)) - - def to_hashable(element: Any) -> Any: """Convert an element to a hashable type.""" return element if isinstance(element, int | str | np.integer | np.str_) else tuple(element) +def string_to_md5_hash(s: str, truncate: int = 32) -> str: + """Generate an MD5 hash of a string and return the first `truncate` characters.""" + full_hash = hashlib.md5(s.encode("utf-8")).hexdigest() + return full_hash[:truncate] + + def sum_string_arrays(*objs: 
np.ndarray | str) -> np.ndarray: """ - Sum a list of string arrays / strings into a single string array by concatenating them and + Sum a list of string arrays or strings into a single string array by concatenating them and determining the shortest string length to set as dtype. """ return reduce(np.char.add, objs).astype(object).astype(str) @@ -47,6 +49,24 @@ def listmap(func: Callable, *iterables) -> list: return compose(list, map)(func, *iterables) +def as_list(value: Any) -> list: + """Convert a value to a list. + + Handles various types using duck typing: + - Iterable objects (lists, tuples, strings, etc.): converted to list + - Single values: wrapped in a list + """ + try: + # Try to iterate over the value (duck typing approach) + # Exclude strings since they're iterable but we want to treat them as single values + if isinstance(value, str): + return [value] + return list(value) + except TypeError: + # If it's not iterable, wrap it in a list + return [value] + + def immutable_lru_cache(maxsize: int = 128, typed: bool = False, deepcopy: bool = True) -> Callable: """An immutable version of `lru_cache` for caching functions that return mutable objects.""" copy_func = copy.deepcopy if deepcopy else copy.copy @@ -89,9 +109,3 @@ def __call__(self, value: Any) -> int: self.key_to_id[value] = self.next_id self.next_id += 1 return self.key_to_id[value] - - -def md5_hash_string(s: str, length: int = 32) -> str: - """Generate an MD5 hash of a string and return the first `length` characters.""" - full_hash = hashlib.md5(s.encode("utf-8")).hexdigest() - return full_hash[:length] diff --git a/src/atomworks/io/constants.py b/src/atomworks/constants.py similarity index 100% rename from src/atomworks/io/constants.py rename to src/atomworks/constants.py diff --git a/src/atomworks/enums.py b/src/atomworks/enums.py index 135c0972..11ba5bc9 100644 --- a/src/atomworks/enums.py +++ b/src/atomworks/enums.py @@ -7,7 +7,7 @@ import numpy as np from toolz import keymap -from 
atomworks.io.constants import ( +from atomworks.constants import ( AA_LIKE_CHEM_TYPES, DNA_LIKE_CHEM_TYPES, POLYPEPTIDE_D_CHEM_TYPES, diff --git a/src/atomworks/io/parser.py b/src/atomworks/io/parser.py index 40521688..7d8dc3f5 100644 --- a/src/atomworks/io/parser.py +++ b/src/atomworks/io/parser.py @@ -18,9 +18,9 @@ from toolz import keyfilter import atomworks.io.transforms.atom_array as ta +from atomworks.common import exists, string_to_md5_hash +from atomworks.constants import CCD_MIRROR_PATH, CRYSTALLIZATION_AIDS, WATER_LIKE_CCDS from atomworks.io import template -from atomworks.io.common import exists, md5_hash_string -from atomworks.io.constants import CCD_MIRROR_PATH, CRYSTALLIZATION_AIDS, WATER_LIKE_CCDS from atomworks.io.transforms.categories import ( category_to_dict, extract_crystallization_details, @@ -150,7 +150,7 @@ def parse( build_assembly (string, list, or tuple, optional): Specifies which assembly to build, if any. Options are None (e.g., asymmetric unit), "first", "all", or a list or tuple of assembly IDs. Defaults to "all". extra_fields (list, optional): A list of extra fields to include in the AtomArrayStack. Defaults to None. "all" includes all fields. - only support cif files. + Only support mmCIF files. keep_cif_block (bool, optional): Whether to keep the CIF block in the result. Defaults to False. Returns: @@ -218,7 +218,7 @@ def parse( } # Compose args_string from parse_arguments values (in order) args_string = ",".join(str(parse_arguments[k]) for k in parse_arguments) - args_hash = md5_hash_string(args_string, length=8) + args_hash = string_to_md5_hash(args_string, truncate=8) # ... 
generate assembly info assembly_info = ",".join(build_assembly) if isinstance(build_assembly, list | tuple) else build_assembly diff --git a/src/atomworks/io/template.py b/src/atomworks/io/template.py index e12942e3..e0465033 100644 --- a/src/atomworks/io/template.py +++ b/src/atomworks/io/template.py @@ -7,8 +7,8 @@ from biotite.structure import AtomArray, BondList import atomworks.io.transforms.atom_array as ta -from atomworks.io.common import exists, immutable_lru_cache -from atomworks.io.constants import CCD_MIRROR_PATH, DO_NOT_MATCH_CCD +from atomworks.common import exists, immutable_lru_cache +from atomworks.constants import CCD_MIRROR_PATH, DO_NOT_MATCH_CCD from atomworks.io.utils.bonds import ( correct_bond_types_for_nucleophilic_additions, correct_formal_charges_for_specified_atoms, diff --git a/src/atomworks/io/tools/fasta.py b/src/atomworks/io/tools/fasta.py index 0944caf5..4a5ae1a7 100644 --- a/src/atomworks/io/tools/fasta.py +++ b/src/atomworks/io/tools/fasta.py @@ -6,8 +6,8 @@ import os import re +from atomworks.constants import CCD_MIRROR_PATH from atomworks.enums import ChainType -from atomworks.io.constants import CCD_MIRROR_PATH from atomworks.io.utils.ccd import ( check_ccd_codes_are_available, ) diff --git a/src/atomworks/io/tools/inference.py b/src/atomworks/io/tools/inference.py index c7da1cf9..4da19c82 100644 --- a/src/atomworks/io/tools/inference.py +++ b/src/atomworks/io/tools/inference.py @@ -14,16 +14,16 @@ from rdkit.Chem import AllChem import atomworks.io.transforms.atom_array as ta -from atomworks.enums import ChainType, ChainTypeInfo -from atomworks.io import parse -from atomworks.io.common import KeyToIntMapper, exists -from atomworks.io.constants import ( +from atomworks.common import KeyToIntMapper, exists +from atomworks.constants import ( CCD_MIRROR_PATH, STANDARD_AA_ONE_LETTER, STANDARD_DNA_ONE_LETTER, STANDARD_RNA, UNKNOWN_LIGAND, ) +from atomworks.enums import ChainType, ChainTypeInfo +from atomworks.io import parse from 
atomworks.io.parser import DEFAULT_PARSE_KWARGS from atomworks.io.template import build_template_atom_array from atomworks.io.tools.fasta import one_letter_to_ccd_code, split_generalized_fasta_sequence diff --git a/src/atomworks/io/tools/rdkit.py b/src/atomworks/io/tools/rdkit.py index 389e28e4..c7b25d3d 100644 --- a/src/atomworks/io/tools/rdkit.py +++ b/src/atomworks/io/tools/rdkit.py @@ -22,8 +22,8 @@ from rdkit.DataStructs import ExplicitBitVect import atomworks.io.transforms.atom_array as ta -from atomworks.io.common import exists, immutable_lru_cache, not_isin -from atomworks.io.constants import ( +from atomworks.common import exists, immutable_lru_cache, not_isin +from atomworks.constants import ( BIOTITE_DEFAULT_ANNOTATIONS, CCD_MIRROR_PATH, HYDROGEN_LIKE_SYMBOLS, diff --git a/src/atomworks/io/transforms/atom_array.py b/src/atomworks/io/transforms/atom_array.py index 3050addf..7f9c2e5c 100644 --- a/src/atomworks/io/transforms/atom_array.py +++ b/src/atomworks/io/transforms/atom_array.py @@ -12,8 +12,8 @@ import pandas as pd from biotite.structure import AtomArray, AtomArrayStack, stack -from atomworks.io.common import listmap, not_isin, sum_string_arrays -from atomworks.io.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER, HYDROGEN_LIKE_SYMBOLS, WATER_LIKE_CCDS +from atomworks.common import listmap, not_isin, sum_string_arrays +from atomworks.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER, HYDROGEN_LIKE_SYMBOLS, WATER_LIKE_CCDS from atomworks.io.utils.bonds import ( generate_inter_level_bond_hash, get_coarse_graph_as_nodes_and_edges, @@ -332,6 +332,19 @@ def update_nonpoly_seq_ids(atom_array: AtomArray, chain_info_dict: dict) -> Atom return atom_array +def _safe_to_int(x: str | int | None) -> int: + """Robustly convert values to integers: map '.', empty strings, and None to -1; parse numerics otherwise""" + if x is None: + return -1 + s = str(x).strip() + if s in (".", ""): + return -1 + try: + return int(s) + except Exception: + return -1 + + def 
replace_negative_res_ids_with_auth_seq_id(atom_array: AtomArray) -> AtomArray: """ Replaces res_id values of -1 with the corresponding auth_seq_id values. @@ -350,8 +363,7 @@ def replace_negative_res_ids_with_auth_seq_id(atom_array: AtomArray) -> AtomArra # Convert auth_seq_ids to int if they are strings (as they are sometimes from AF-3 predictions) if author_seq_ids.dtype.kind in "UO": # Unicode or Object (string-like) - # Handle '.' values by replacing with -1, then convert to int - author_seq_ids = np.where(author_seq_ids == ".", -1, author_seq_ids).astype(int) + author_seq_ids = np.frompyfunc(_safe_to_int, 1, 1)(author_seq_ids).astype(int) atom_array.res_id[negative_res_id_mask] = author_seq_ids[negative_res_id_mask] diff --git a/src/atomworks/io/transforms/categories.py b/src/atomworks/io/transforms/categories.py index db4ca6bb..f8002088 100644 --- a/src/atomworks/io/transforms/categories.py +++ b/src/atomworks/io/transforms/categories.py @@ -17,9 +17,9 @@ from biotite.structure import AtomArray from biotite.structure.io.pdbx import CIFBlock +from atomworks.common import exists +from atomworks.constants import CCD_MIRROR_PATH from atomworks.enums import ChainType -from atomworks.io.common import deduplicate_iterator, exists -from atomworks.io.constants import CCD_MIRROR_PATH from atomworks.io.utils.selection import get_residue_starts from atomworks.io.utils.sequence import get_1_from_3_letter_code @@ -253,7 +253,9 @@ def load_monomer_sequence_information_from_category( # Build up the chain_info_dict with the sequence information res_starts = get_residue_starts(atom_array) - for chain_id in deduplicate_iterator(struc.get_chains(atom_array)): + # ... 
get the unique chain IDs by order of first appearance in the AtomArray + chain_ids = dict.fromkeys(struc.get_chains(atom_array)) + for chain_id in chain_ids: rcsb_entity = int(chain_info_dict[chain_id]["rcsb_entity"]) if rcsb_entity in polymer_entity_id_to_res_names_and_ids: diff --git a/src/atomworks/io/utils/bonds.py b/src/atomworks/io/utils/bonds.py index 751c486a..ff499eda 100644 --- a/src/atomworks/io/utils/bonds.py +++ b/src/atomworks/io/utils/bonds.py @@ -28,9 +28,8 @@ _get_struct_conn_col_name, ) -from atomworks.enums import ChainType, ChainTypeInfo -from atomworks.io.common import sum_string_arrays, to_hashable -from atomworks.io.constants import ( +from atomworks.common import sum_string_arrays, to_hashable +from atomworks.constants import ( AA_LIKE_CHEM_TYPES, CHEM_TYPE_POLYMERIZATION_ATOMS, DEFAULT_VALENCE, @@ -39,6 +38,7 @@ STRUCT_CONN_BOND_ORDER_TO_INT, STRUCT_CONN_BOND_TYPES, ) +from atomworks.enums import ChainType, ChainTypeInfo from atomworks.io.utils.ccd import get_chem_comp_leaving_atom_names, get_chem_comp_type from atomworks.io.utils.selection import get_annotation, get_residue_starts from atomworks.io.utils.testing import has_ambiguous_annotation_set diff --git a/src/atomworks/io/utils/ccd.py b/src/atomworks/io/utils/ccd.py index 280311eb..c6b2cee9 100644 --- a/src/atomworks/io/utils/ccd.py +++ b/src/atomworks/io/utils/ccd.py @@ -11,9 +11,8 @@ import numpy as np import toolz -from atomworks.enums import ChainType, ChainTypeInfo -from atomworks.io.common import exists, immutable_lru_cache -from atomworks.io.constants import ( +from atomworks.common import exists, immutable_lru_cache +from atomworks.constants import ( AA_LIKE_CHEM_TYPES, CCD_MIRROR_PATH, DNA_LIKE_CHEM_TYPES, @@ -25,6 +24,7 @@ UNKNOWN_LIGAND, UNKNOWN_RNA, ) +from atomworks.enums import ChainType, ChainTypeInfo logger = logging.getLogger(__name__) diff --git a/src/atomworks/io/utils/io_utils.py b/src/atomworks/io/utils/io_utils.py index 02df78c0..1690c018 100644 --- 
a/src/atomworks/io/utils/io_utils.py +++ b/src/atomworks/io/utils/io_utils.py @@ -21,9 +21,9 @@ from biotite.structure.io import mol, pdbx import atomworks.io.transforms.atom_array as ta # to avoid circular import +from atomworks.common import exists +from atomworks.constants import ATOMIC_NUMBER_TO_ELEMENT, STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.enums import ChainType -from atomworks.io.common import exists -from atomworks.io.constants import ATOMIC_NUMBER_TO_ELEMENT, STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.io.template import add_inter_residue_bonds from atomworks.io.transforms.categories import category_to_dict from atomworks.io.utils.selection import get_annotation diff --git a/src/atomworks/io/utils/non_rcsb.py b/src/atomworks/io/utils/non_rcsb.py index ec3fb4e7..7a44a668 100644 --- a/src/atomworks/io/utils/non_rcsb.py +++ b/src/atomworks/io/utils/non_rcsb.py @@ -16,14 +16,14 @@ from biotite.structure import AtomArray from biotite.structure.io.pdbx import CIFCategory -from atomworks.enums import ChainType -from atomworks.io.constants import ( +from atomworks.constants import ( AA_LIKE_CHEM_TYPES, DNA_LIKE_CHEM_TYPES, POLYPEPTIDE_D_CHEM_TYPES, POLYPEPTIDE_L_CHEM_TYPES, RNA_LIKE_CHEM_TYPES, ) +from atomworks.enums import ChainType from atomworks.io.utils.ccd import get_chem_comp_type from atomworks.io.utils.selection import get_residue_starts from atomworks.io.utils.sequence import get_1_from_3_letter_code diff --git a/src/atomworks/io/utils/query.py b/src/atomworks/io/utils/query.py index 23bad3b5..9ce41249 100644 --- a/src/atomworks/io/utils/query.py +++ b/src/atomworks/io/utils/query.py @@ -7,7 +7,7 @@ import numpy as np from biotite.structure import AtomArray, AtomArrayStack -from atomworks.io.common import not_isin +from atomworks.common import not_isin from atomworks.io.transforms.atom_array import is_any_coord_nan diff --git a/src/atomworks/io/utils/sequence.py b/src/atomworks/io/utils/sequence.py index a42b0e57..1c1052a8 
100644 --- a/src/atomworks/io/utils/sequence.py +++ b/src/atomworks/io/utils/sequence.py @@ -11,8 +11,7 @@ import numpy as np import toolz -from atomworks.enums import ChainType -from atomworks.io.constants import ( +from atomworks.constants import ( GAP, GAP_ONE_LETTER, STANDARD_AA, @@ -25,6 +24,7 @@ UNKNOWN_DNA, UNKNOWN_RNA, ) +from atomworks.enums import ChainType from atomworks.io.utils.ccd import ( aa_chem_comps, chem_comp_to_one_letter, diff --git a/src/atomworks/io/utils/testing.py b/src/atomworks/io/utils/testing.py index 2f8f3256..937f6aec 100644 --- a/src/atomworks/io/utils/testing.py +++ b/src/atomworks/io/utils/testing.py @@ -2,15 +2,17 @@ __all__ = ["assert_same_atom_array"] +import io import os from collections.abc import Iterable import biotite.structure as struc import numpy as np +from biotite.database import rcsb from biotite.structure.atoms import AtomArray, AtomArrayStack import atomworks.io.utils.bonds as cb -from atomworks.io.constants import PDB_MIRROR_PATH +from atomworks.constants import PDB_MIRROR_PATH from atomworks.io.utils.scatter import apply_group_wise, apply_segment_wise @@ -38,6 +40,24 @@ def get_pdb_path(pdbid: str, mirror_path: str | os.PathLike = PDB_MIRROR_PATH) - return filename +def get_pdb_path_or_buffer(pdb_id: str) -> str | io.StringIO: + """Returns a local file path or an in-memory buffer for a given PDB ID. + + Args: + pdb_id (str): The PDB identifier of the structure. + + Returns: + str | io.StringIO: The local file path to the structure file if available, + otherwise an in-memory buffer containing the fetched file. + """ + try: + # ... if file is locally available + return get_pdb_path(pdb_id) + except FileNotFoundError: + # ... otherwise, fetch the file from RCSB + return rcsb.fetch(pdb_id, format="cif") + + def is_same_in_segment(segment_start_stop: np.ndarray, data: np.ndarray, raise_if_false: bool = False) -> np.ndarray: """Check if all elements in a segment are the same. 
diff --git a/src/atomworks/io/utils/visualize.py b/src/atomworks/io/utils/visualize.py index 3bcb103e..a9dd6cc1 100644 --- a/src/atomworks/io/utils/visualize.py +++ b/src/atomworks/io/utils/visualize.py @@ -16,7 +16,7 @@ from biotite.structure import AtomArray, AtomArrayStack from biotite.structure.io import mol, pdb, pdbx -from atomworks.io.constants import ATOMIC_NUMBER_TO_ELEMENT, METAL_ELEMENTS +from atomworks.constants import ATOMIC_NUMBER_TO_ELEMENT, METAL_ELEMENTS from atomworks.io.utils.io_utils import read_any, to_cif_string logger = logging.getLogger("atomworks.io") diff --git a/src/atomworks/ml/datasets/datasets.py b/src/atomworks/ml/datasets/datasets.py index 1ddf128d..4bfd2493 100644 --- a/src/atomworks/ml/datasets/datasets.py +++ b/src/atomworks/ml/datasets/datasets.py @@ -13,7 +13,7 @@ import pandas as pd from torch.utils.data import ConcatDataset, Dataset -from atomworks.ml.common import default, exists +from atomworks.common import default, exists from atomworks.ml.datasets import logger from atomworks.ml.datasets.parsers import MetadataRowParser, load_example_from_metadata_row from atomworks.ml.preprocessing.constants import NA_VALUES diff --git a/src/atomworks/ml/datasets/parsers/base.py b/src/atomworks/ml/datasets/parsers/base.py index d14887b6..2125463f 100644 --- a/src/atomworks/ml/datasets/parsers/base.py +++ b/src/atomworks/ml/datasets/parsers/base.py @@ -4,8 +4,8 @@ import pandas as pd +from atomworks.constants import CRYSTALLIZATION_AIDS from atomworks.io import parse -from atomworks.io.constants import CRYSTALLIZATION_AIDS DEFAULT_CIF_PARSER_ARGS = { "add_missing_atoms": True, diff --git a/src/atomworks/ml/datasets/parsers/custom_metadata_row_parsers.py b/src/atomworks/ml/datasets/parsers/custom_metadata_row_parsers.py index d4b482ef..c905ccc2 100644 --- a/src/atomworks/ml/datasets/parsers/custom_metadata_row_parsers.py +++ b/src/atomworks/ml/datasets/parsers/custom_metadata_row_parsers.py @@ -5,7 +5,7 @@ import pandas as pd -from 
atomworks.io.constants import PDB_MIRROR_PATH +from atomworks.constants import PDB_MIRROR_PATH from atomworks.ml.datasets.parsers import MetadataRowParser diff --git a/src/atomworks/ml/datasets/parsers/default_metadata_row_parsers.py b/src/atomworks/ml/datasets/parsers/default_metadata_row_parsers.py index e896d1a6..1718538e 100644 --- a/src/atomworks/ml/datasets/parsers/default_metadata_row_parsers.py +++ b/src/atomworks/ml/datasets/parsers/default_metadata_row_parsers.py @@ -7,8 +7,8 @@ import pandas as pd -from atomworks.io.constants import PDB_MIRROR_PATH -from atomworks.ml.common import as_list +from atomworks.common import as_list +from atomworks.constants import PDB_MIRROR_PATH from atomworks.ml.datasets.parsers import MetadataRowParser diff --git a/src/atomworks/ml/encoding_definitions.py b/src/atomworks/ml/encoding_definitions.py index 9359cb09..7621d223 100644 --- a/src/atomworks/ml/encoding_definitions.py +++ b/src/atomworks/ml/encoding_definitions.py @@ -10,7 +10,8 @@ import biotite.structure as struc import numpy as np -from atomworks.io.constants import ( +from atomworks.common import exists +from atomworks.constants import ( AA_LIKE_CHEM_TYPES, CHEM_COMP_TYPES, DNA_LIKE_CHEM_TYPES, @@ -25,7 +26,6 @@ UNKNOWN_RNA, ) from atomworks.io.utils.ccd import get_chem_comp_type -from atomworks.ml.common import exists logger = getLogger(__name__) diff --git a/src/atomworks/ml/common.py b/src/atomworks/ml/example_id.py similarity index 74% rename from src/atomworks/ml/common.py rename to src/atomworks/ml/example_id.py index 6526908c..5c01dfc2 100644 --- a/src/atomworks/ml/common.py +++ b/src/atomworks/ml/example_id.py @@ -1,9 +1,6 @@ -from __future__ import annotations +"""Functions for generating and parsing example IDs that uniquely identify examples and their corresponding datasets.""" import re -from typing import Any - -from atomworks.io.common import default, exists # noqa: F401 def generate_example_id(dataset_names: list[str], pdb_id: str, assembly_id: 
str, query_pn_unit_iids: list) -> str: @@ -53,21 +50,3 @@ def parse_example_id(example_id: str) -> dict: "assembly_id": assembly_id, "query_pn_unit_iids": query_pn_unit_iids, } - - -def as_list(value: Any) -> list: - """Convert a value to a list. - - Handles various types using duck typing: - - Iterable objects (lists, tuples, strings, etc.): converted to list - - Single values: wrapped in a list - """ - try: - # Try to iterate over the value (duck typing approach) - # Exclude strings since they're iterable but we want to treat them as single values - if isinstance(value, str): - return [value] - return list(value) - except TypeError: - # If it's not iterable, wrap it in a list - return [value] diff --git a/src/atomworks/ml/pipelines/af3.py b/src/atomworks/ml/pipelines/af3.py index 9ff32d00..511f2e7c 100644 --- a/src/atomworks/ml/pipelines/af3.py +++ b/src/atomworks/ml/pipelines/af3.py @@ -4,9 +4,9 @@ import numpy as np import torch +from atomworks.common import exists +from atomworks.constants import AF3_EXCLUDED_LIGANDS, GAP, STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.enums import ChainType -from atomworks.io.constants import AF3_EXCLUDED_LIGANDS, GAP, STANDARD_AA, STANDARD_DNA, STANDARD_RNA -from atomworks.ml.common import exists from atomworks.ml.encoding_definitions import RF2AA_ATOM36_ENCODING, AF3SequenceEncoding from atomworks.ml.transforms.af3_reference_molecule import GetAF3ReferenceMoleculeFeatures from atomworks.ml.transforms.atom_array import ( diff --git a/src/atomworks/ml/pipelines/rf2aa.py b/src/atomworks/ml/pipelines/rf2aa.py index dfdc283f..a9c32a5a 100644 --- a/src/atomworks/ml/pipelines/rf2aa.py +++ b/src/atomworks/ml/pipelines/rf2aa.py @@ -6,8 +6,8 @@ import torch from biotite.structure import AtomArray -from atomworks.io.constants import AF3_EXCLUDED_LIGANDS -from atomworks.ml.common import exists +from atomworks.common import exists +from atomworks.constants import AF3_EXCLUDED_LIGANDS from atomworks.ml.encoding_definitions import 
RF2AA_ATOM36_ENCODING from atomworks.ml.transforms.atom_array import ( AddGlobalAtomIdAnnotation, diff --git a/src/atomworks/ml/preprocessing/get_pn_unit_data_from_structure.py b/src/atomworks/ml/preprocessing/get_pn_unit_data_from_structure.py index 2edd129a..1c63549a 100644 --- a/src/atomworks/ml/preprocessing/get_pn_unit_data_from_structure.py +++ b/src/atomworks/ml/preprocessing/get_pn_unit_data_from_structure.py @@ -19,11 +19,10 @@ from biotite.structure import AtomArray import atomworks.ml.preprocessing.utils.structure_utils as dp # to avoid circular imports +from atomworks.common import exists, not_isin +from atomworks.constants import CRYSTALLIZATION_AIDS, METAL_ELEMENTS from atomworks.enums import ChainType from atomworks.io import parse -from atomworks.io.common import not_isin -from atomworks.io.constants import CRYSTALLIZATION_AIDS, METAL_ELEMENTS -from atomworks.ml.common import exists from atomworks.ml.preprocessing.constants import CELL_SIZE, ClashSeverity from atomworks.ml.utils.misc import hash_sequence diff --git a/src/atomworks/ml/preprocessing/utils/clustering.py b/src/atomworks/ml/preprocessing/utils/clustering.py index 9d91c28d..a8d4cd12 100644 --- a/src/atomworks/ml/preprocessing/utils/clustering.py +++ b/src/atomworks/ml/preprocessing/utils/clustering.py @@ -8,7 +8,7 @@ import pandas as pd -from atomworks.ml.common import exists +from atomworks.common import exists from atomworks.ml.preprocessing.constants import NA_VALUES from atomworks.ml.preprocessing.utils.fasta import create_fasta_file_from_df diff --git a/src/atomworks/ml/preprocessing/utils/structure_utils.py b/src/atomworks/ml/preprocessing/utils/structure_utils.py index 59995d93..43f21f87 100644 --- a/src/atomworks/ml/preprocessing/utils/structure_utils.py +++ b/src/atomworks/ml/preprocessing/utils/structure_utils.py @@ -13,9 +13,8 @@ from biotite.structure import AtomArray, CellList from scipy.spatial.distance import cdist -from atomworks.io.common import not_isin -from 
atomworks.io.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER, METAL_ELEMENTS -from atomworks.ml.common import default +from atomworks.common import default, not_isin +from atomworks.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER, METAL_ELEMENTS from atomworks.ml.preprocessing.constants import ClashSeverity logger = logging.getLogger("preprocess") @@ -92,6 +91,8 @@ def get_atom_mask_from_cell_list( Builds a mask indicating which atoms clash with the query PN unit. If the number of comparisons is too large, the computation is split into manageable chunks along the rows of `coord`. + TODO: Update documentation since this is not specific to PN units or clashes. + Args: coord (ndarray): The coordinates of the query PN unit. Shape is (n, 3). cell_list (CellList): A CellList object that allows efficient vicinity searches. diff --git a/src/atomworks/ml/transforms/af3_reference_molecule.py b/src/atomworks/ml/transforms/af3_reference_molecule.py index b9d6dd80..10acaeef 100644 --- a/src/atomworks/ml/transforms/af3_reference_molecule.py +++ b/src/atomworks/ml/transforms/af3_reference_molecule.py @@ -8,12 +8,12 @@ from biotite.structure import AtomArray from rdkit import Chem +from atomworks.common import exists +from atomworks.constants import CCD_MIRROR_PATH, ELEMENT_NAME_TO_ATOMIC_NUMBER, UNKNOWN_LIGAND from atomworks.enums import GroundTruthConformerPolicy -from atomworks.io.constants import CCD_MIRROR_PATH, ELEMENT_NAME_TO_ATOMIC_NUMBER, UNKNOWN_LIGAND from atomworks.io.tools.rdkit import atom_array_from_rdkit, remove_hydrogens from atomworks.io.utils.ccd import get_available_ccd_codes from atomworks.io.utils.selection import get_residue_starts -from atomworks.ml.common import exists from atomworks.ml.transforms._checks import check_atom_array_annotation, check_contains_keys, check_is_instance from atomworks.ml.transforms.base import Transform from atomworks.ml.transforms.rdkit_utils import ( diff --git a/src/atomworks/ml/transforms/cached_residue_data.py 
b/src/atomworks/ml/transforms/cached_residue_data.py index 40dfa86a..3ddfc268 100644 --- a/src/atomworks/ml/transforms/cached_residue_data.py +++ b/src/atomworks/ml/transforms/cached_residue_data.py @@ -10,7 +10,7 @@ from biotite.structure import AtomArray, residue_iter from toolz import keyfilter -from atomworks.ml.common import exists +from atomworks.common import exists from atomworks.ml.transforms._checks import check_atom_array_annotation, check_contains_keys from atomworks.ml.transforms.base import Transform from atomworks.ml.utils.io import get_sharded_file_path diff --git a/src/atomworks/ml/transforms/covalent_modifications.py b/src/atomworks/ml/transforms/covalent_modifications.py index 5364c3a0..cfbc568e 100644 --- a/src/atomworks/ml/transforms/covalent_modifications.py +++ b/src/atomworks/ml/transforms/covalent_modifications.py @@ -98,6 +98,9 @@ class FlagAndReassignCovalentModifications(Transform): set atomize = true (thus, this transform must be run before the Atomize transform) set is_covalent_modification = true (for the entire pn_unit) ------------------------------------------------------------------------------------------------ + + TODO: Break into two Transforms - one that flags, one that reassigns. Atomizing covalent modifications is a design choice + that may not be desired in all pipelines. Annotating covalent modifications, however, is broadly useful. 
""" incompatible_previous_transforms: ClassVar[list[str | Transform]] = [AtomizeByCCDName, "AddGlobalTokenIdAnnotation"] diff --git a/src/atomworks/ml/transforms/crop.py b/src/atomworks/ml/transforms/crop.py index 934ab44a..c691530a 100644 --- a/src/atomworks/ml/transforms/crop.py +++ b/src/atomworks/ml/transforms/crop.py @@ -6,8 +6,8 @@ from biotite.structure import AtomArray from scipy.spatial import KDTree +from atomworks.common import exists from atomworks.io.transforms.atom_array import is_any_coord_nan -from atomworks.ml.common import exists from atomworks.ml.transforms._checks import ( check_atom_array_annotation, check_contains_keys, diff --git a/src/atomworks/ml/transforms/dna/pad_dna.py b/src/atomworks/ml/transforms/dna/pad_dna.py index b7f213ab..5f11e1b0 100644 --- a/src/atomworks/ml/transforms/dna/pad_dna.py +++ b/src/atomworks/ml/transforms/dna/pad_dna.py @@ -19,7 +19,7 @@ from biotite.structure.filter import filter_nucleotides from biotite.structure.residues import get_residue_masks, get_residue_starts_for -from atomworks.io.constants import STANDARD_DNA +from atomworks.constants import STANDARD_DNA from atomworks.io.transforms.atom_array import remove_nan_coords from atomworks.io.utils.io_utils import load_any from atomworks.io.utils.selection import ResIdxSlice diff --git a/src/atomworks/ml/transforms/encoding.py b/src/atomworks/ml/transforms/encoding.py index a8b70bb1..19de7da4 100644 --- a/src/atomworks/ml/transforms/encoding.py +++ b/src/atomworks/ml/transforms/encoding.py @@ -14,8 +14,8 @@ from biotite.structure import AtomArray from torch.nn import functional as F # noqa: N812 -from atomworks.io.common import KeyToIntMapper, exists -from atomworks.io.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER +from atomworks.common import KeyToIntMapper, exists +from atomworks.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER from atomworks.io.utils.ccd import get_std_to_alt_atom_name_map from atomworks.ml.encoding_definitions import AF3SequenceEncoding, 
TokenEncoding from atomworks.ml.transforms._checks import ( diff --git a/src/atomworks/ml/transforms/featurize_unresolved_residues.py b/src/atomworks/ml/transforms/featurize_unresolved_residues.py index a9d46140..95fba5a9 100644 --- a/src/atomworks/ml/transforms/featurize_unresolved_residues.py +++ b/src/atomworks/ml/transforms/featurize_unresolved_residues.py @@ -9,8 +9,8 @@ import numpy as np from biotite.structure import AtomArray +from atomworks.constants import NUCLEIC_ACID_FRAME_ATOM_NAMES, PROTEIN_FRAME_ATOM_NAMES from atomworks.enums import ChainTypeInfo -from atomworks.io.constants import NUCLEIC_ACID_FRAME_ATOM_NAMES, PROTEIN_FRAME_ATOM_NAMES from atomworks.ml.transforms._checks import check_atom_array_annotation, check_contains_keys, check_is_instance from atomworks.ml.transforms.atom_array import apply_and_spread_residue_wise from atomworks.ml.transforms.base import Transform diff --git a/src/atomworks/ml/transforms/filters.py b/src/atomworks/ml/transforms/filters.py index fc026e35..5f431d65 100644 --- a/src/atomworks/ml/transforms/filters.py +++ b/src/atomworks/ml/transforms/filters.py @@ -10,13 +10,12 @@ import numpy as np from biotite.structure import AtomArray, AtomArrayStack +from atomworks.common import exists, not_isin +from atomworks.constants import HYDROGEN_LIKE_SYMBOLS from atomworks.enums import ChainType, ChainTypeInfo -from atomworks.io.common import not_isin -from atomworks.io.constants import HYDROGEN_LIKE_SYMBOLS from atomworks.io.utils.query import QueryExpression from atomworks.io.utils.selection import get_annotation from atomworks.io.utils.sequence import get_1_from_3_letter_code, get_3_from_1_letter_code -from atomworks.ml.common import exists from atomworks.ml.preprocessing.constants import TRAINING_SUPPORTED_CHAIN_TYPES from atomworks.ml.transforms._checks import ( check_atom_array_annotation, diff --git a/src/atomworks/ml/transforms/openbabel_utils.py b/src/atomworks/ml/transforms/openbabel_utils.py index a7b1558d..78d89e53 
100644 --- a/src/atomworks/ml/transforms/openbabel_utils.py +++ b/src/atomworks/ml/transforms/openbabel_utils.py @@ -19,7 +19,7 @@ from biotite.structure import AtomArray from openbabel import openbabel, pybel -from atomworks.io.constants import ATOMIC_NUMBER_TO_ELEMENT, ELEMENT_NAME_TO_ATOMIC_NUMBER, UNKNOWN_LIGAND +from atomworks.constants import ATOMIC_NUMBER_TO_ELEMENT, ELEMENT_NAME_TO_ATOMIC_NUMBER, UNKNOWN_LIGAND from atomworks.ml.transforms._checks import ( check_atom_array_annotation, check_contains_keys, diff --git a/src/atomworks/ml/transforms/random_atomize_residues.py b/src/atomworks/ml/transforms/random_atomize_residues.py index bb2d48d8..16fb4b72 100644 --- a/src/atomworks/ml/transforms/random_atomize_residues.py +++ b/src/atomworks/ml/transforms/random_atomize_residues.py @@ -4,7 +4,7 @@ import numpy as np from biotite.structure import AtomArray -from atomworks.io.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.io.utils.selection import get_annotation, get_residue_starts from atomworks.ml.transforms._checks import check_atom_array_annotation from atomworks.ml.transforms.base import Transform diff --git a/src/atomworks/ml/transforms/rdkit_utils.py b/src/atomworks/ml/transforms/rdkit_utils.py index 89fc41dc..ab7c7ab0 100644 --- a/src/atomworks/ml/transforms/rdkit_utils.py +++ b/src/atomworks/ml/transforms/rdkit_utils.py @@ -7,6 +7,7 @@ from rdkit import Chem, RDLogger from rdkit.Chem import AllChem, Mol, rdDistGeom +from atomworks.common import default from atomworks.io.tools.rdkit import ( add_hydrogens, atom_array_from_rdkit, @@ -15,7 +16,6 @@ preserve_annotations, remove_hydrogens, ) -from atomworks.ml.common import default from atomworks.ml.transforms._checks import ( check_atom_array_annotation, check_contains_keys, diff --git a/src/atomworks/ml/transforms/template.py b/src/atomworks/ml/transforms/template.py index cfb32270..73970be5 100644 --- 
a/src/atomworks/ml/transforms/template.py +++ b/src/atomworks/ml/transforms/template.py @@ -15,8 +15,8 @@ from biotite.structure import AtomArray from torch.nn.functional import normalize +from atomworks.common import exists from atomworks.enums import ChainType -from atomworks.ml.common import exists from atomworks.ml.encoding_definitions import ( LEGACY_RF2_ATOM14_ENCODING, RF2AA_ATOM36_ENCODING, diff --git a/src/atomworks/ml/utils/debug.py b/src/atomworks/ml/utils/debug.py index 9f813cc9..73ff5672 100644 --- a/src/atomworks/ml/utils/debug.py +++ b/src/atomworks/ml/utils/debug.py @@ -4,7 +4,7 @@ import re from datetime import datetime -from atomworks.ml.common import default +from atomworks.common import default logger = logging.getLogger("atomworks.ml") _USER = default(os.getenv("USER"), "") diff --git a/src/atomworks/ml/utils/geometry.py b/src/atomworks/ml/utils/geometry.py index 79652f46..12a7a13d 100644 --- a/src/atomworks/ml/utils/geometry.py +++ b/src/atomworks/ml/utils/geometry.py @@ -6,7 +6,7 @@ from einops import einsum, rearrange from torch.nn.functional import normalize -from atomworks.ml.common import default +from atomworks.common import default def get_torch_eps(dtype: torch.dtype) -> float: diff --git a/src/atomworks/ml/utils/io.py b/src/atomworks/ml/utils/io.py index 8fea8fd9..eb38cc0a 100644 --- a/src/atomworks/ml/utils/io.py +++ b/src/atomworks/ml/utils/io.py @@ -14,7 +14,7 @@ import pyarrow as pa import pyarrow.parquet as pq -from atomworks.io.constants import ( +from atomworks.constants import ( AA_LIKE_CHEM_TYPES, ATOMIC_NUMBER_TO_ELEMENT, DNA_LIKE_CHEM_TYPES, diff --git a/src/atomworks/ml/utils/misc.py b/src/atomworks/ml/utils/misc.py index 4b72e7b7..bddabcf8 100644 --- a/src/atomworks/ml/utils/misc.py +++ b/src/atomworks/ml/utils/misc.py @@ -10,7 +10,7 @@ import torch from einops import rearrange -from atomworks.ml.common import default +from atomworks.common import default from atomworks.ml.preprocessing.constants import NA_VALUES logger = 
logging.getLogger(__name__) diff --git a/src/atomworks/ml/utils/testing.py b/src/atomworks/ml/utils/testing.py index f8c3dd2c..e20bc5f8 100644 --- a/src/atomworks/ml/utils/testing.py +++ b/src/atomworks/ml/utils/testing.py @@ -3,9 +3,9 @@ import numpy as np from biotite.structure import AtomArray, CellList +from atomworks.common import immutable_lru_cache +from atomworks.constants import PDB_MIRROR_PATH from atomworks.io import parse -from atomworks.io.common import immutable_lru_cache -from atomworks.io.constants import PDB_MIRROR_PATH from atomworks.ml.preprocessing.constants import CELL_SIZE from atomworks.ml.preprocessing.utils.structure_utils import get_atom_mask_from_cell_list diff --git a/src/atomworks/cli/__init__.py b/src/atomworks_cli/__main__.py similarity index 77% rename from src/atomworks/cli/__init__.py rename to src/atomworks_cli/__main__.py index 29e7a3d0..9bb936c1 100644 --- a/src/atomworks/cli/__init__.py +++ b/src/atomworks_cli/__main__.py @@ -1,6 +1,4 @@ -"""AtomWorks command-line interface.""" - -from __future__ import annotations +"""Entry point for the AtomWorks command-line interface.""" import typer @@ -15,3 +13,11 @@ app.add_typer(_ccd.app, name="ccd") app.add_typer(_pdb.app, name="pdb") app.add_typer(_setup.app, name="setup") + + +def main() -> None: + app() + + +if __name__ == "__main__": + main() diff --git a/src/atomworks/cli/ccd.py b/src/atomworks_cli/ccd.py similarity index 100% rename from src/atomworks/cli/ccd.py rename to src/atomworks_cli/ccd.py diff --git a/src/atomworks/cli/pdb.py b/src/atomworks_cli/pdb.py similarity index 100% rename from src/atomworks/cli/pdb.py rename to src/atomworks_cli/pdb.py diff --git a/src/atomworks/cli/setup.py b/src/atomworks_cli/setup.py similarity index 75% rename from src/atomworks/cli/setup.py rename to src/atomworks_cli/setup.py index 4ff0a9c6..cdb616f1 100644 --- a/src/atomworks/cli/setup.py +++ b/src/atomworks_cli/setup.py @@ -15,9 +15,14 @@ from .pdb import PDB_PORT, PDB_REMOTE, 
_collect_pdb_ids, _pdb_id_to_relpath, _rsync_fetch_specific, _run_rsync_list -TEST_PACK_URL = "https://files.ipd.uw.edu/pub/atomworks/test_pack_latest.tar.gz" +IPD_DOWNLOAD_URL = "https://files.ipd.uw.edu/pub/atomworks" + +TEST_PACK_URL = f"{IPD_DOWNLOAD_URL}/test_pack_latest.tar.gz" """The URL for the latest AtomWorks test pack. Should be untared in `tests/data/shared`.""" +METADATA_URL = f"{IPD_DOWNLOAD_URL}/pdb_metadata_latest.tar.gz" +"""The URL for the latest AtomWorks PDB metadata. Should be untarred at the specified location.""" + app = typer.Typer(help="Setup utilities for AtomWorks.") @@ -132,3 +137,40 @@ def setup_tests( typer.secho("Test setup completed successfully!", fg=typer.colors.GREEN) typer.secho("To run tests use: PDB_MIRROR_PATH=tests/data/pdb pytest -n auto tests") + + +@app.command("metadata") +def setup_metadata( + output_dir: Path = typer.Argument( + ..., + help="Directory where the PDB metadata archive should be extracted.", + ), + keep_archive: bool = typer.Option(False, "--keep-archive", help="Keep downloaded metadata archive."), +) -> None: + """Download the latest PDB metadata archive and extract it to the given directory. + + NOTE: It's expected that you run this command from the root of the repository. + + The metadata archive is structured to extract under `shared/` inside the provided directory by default. 
+ + Example: + atomworks setup metadata tests/data + """ + typer.echo("Setting up AtomWorks PDB metadata...") + + output_dir.mkdir(parents=True, exist_ok=True) + with tempfile.TemporaryDirectory() as tmpdir: + archive_path = Path(tmpdir) / "pdb_metadata_latest.tar.gz" + typer.echo(f"Downloading PDB metadata from {METADATA_URL} ...") + _download_file(METADATA_URL, archive_path) + typer.secho("Download complete", fg=typer.colors.GREEN) + + typer.echo(f"Extracting PDB metadata into {output_dir} ...") + _extract_tar_gz(archive_path, output_dir) + typer.secho("Extraction complete", fg=typer.colors.GREEN) + + if keep_archive: + keep_path = output_dir / archive_path.name + keep_path.write_bytes(archive_path.read_bytes()) + + typer.secho("PDB metadata setup completed successfully!", fg=typer.colors.GREEN) diff --git a/tests/io/components/test_caching.py b/tests/io/components/test_caching.py index 7aca24e2..b0f759e8 100644 --- a/tests/io/components/test_caching.py +++ b/tests/io/components/test_caching.py @@ -97,7 +97,7 @@ def different_args_parse(): assert cached_elapsed_time < normal_elapsed_time / 1.5 # Assert that the result with different arguments is similar to the normal elapsed time - assert abs(different_args_elapsed_time - normal_elapsed_time) < normal_elapsed_time * 0.2 + assert abs(different_args_elapsed_time - normal_elapsed_time) < normal_elapsed_time * 0.5 if __name__ == "__main__": diff --git a/tests/io/components/test_entity_annotations.py b/tests/io/components/test_entity_annotations.py index 031f8a7e..00b19462 100644 --- a/tests/io/components/test_entity_annotations.py +++ b/tests/io/components/test_entity_annotations.py @@ -2,7 +2,7 @@ import pytest from biotite.structure import AtomArray -from atomworks.io.common import not_isin +from atomworks.common import not_isin from atomworks.io.parser import parse from atomworks.io.transforms.atom_array import annotate_entities from tests.io.conftest import get_pdb_path diff --git 
a/tests/io/components/test_ignore_residues.py b/tests/io/components/test_ignore_residues.py index 7461aabd..6436d1cd 100644 --- a/tests/io/components/test_ignore_residues.py +++ b/tests/io/components/test_ignore_residues.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from atomworks.io.constants import CRYSTALLIZATION_AIDS +from atomworks.constants import CRYSTALLIZATION_AIDS from atomworks.io.parser import parse from tests.io.conftest import get_pdb_path diff --git a/tests/io/components/test_mse_to_met.py b/tests/io/components/test_mse_to_met.py index f2263a04..37114f67 100644 --- a/tests/io/components/test_mse_to_met.py +++ b/tests/io/components/test_mse_to_met.py @@ -3,8 +3,8 @@ import numpy as np import pytest -from atomworks.io.common import not_isin -from atomworks.io.constants import CCD_MIRROR_PATH, HYDROGEN_LIKE_SYMBOLS +from atomworks.common import not_isin +from atomworks.constants import CCD_MIRROR_PATH, HYDROGEN_LIKE_SYMBOLS from atomworks.io.parser import parse from atomworks.io.transforms.atom_array import mse_to_met from atomworks.io.utils.ccd import atom_array_from_ccd_code diff --git a/tests/io/components/test_non_covalent_bond_parsing.py b/tests/io/components/test_non_covalent_bond_parsing.py index b52b2db0..79c3b4ac 100644 --- a/tests/io/components/test_non_covalent_bond_parsing.py +++ b/tests/io/components/test_non_covalent_bond_parsing.py @@ -3,7 +3,7 @@ import biotite.structure as struc import pytest -from atomworks.io.constants import STRUCT_CONN_BOND_TYPES +from atomworks.constants import STRUCT_CONN_BOND_TYPES from atomworks.io.parser import parse from tests.io.conftest import get_pdb_path diff --git a/tests/io/components/test_regression.py b/tests/io/components/test_regression.py index d7e5d196..bd1e9d92 100644 --- a/tests/io/components/test_regression.py +++ b/tests/io/components/test_regression.py @@ -6,7 +6,7 @@ import numpy as np import pytest -from atomworks.io.constants import CRYSTALLIZATION_AIDS +from atomworks.constants 
import CRYSTALLIZATION_AIDS from atomworks.io.parser import parse from atomworks.io.transforms import atom_array as ta from atomworks.io.utils.io_utils import to_cif_file # noqa: F401 @@ -73,7 +73,7 @@ def test_regression_against_stored_result(pdb_id: str): # ) # ## FOR DEBUGGING REGRESSION TESTS UNCOMMENT: - from atomworks.io.common import sum_string_arrays + from atomworks.common import sum_string_arrays def get_atom_identifiers(atom_array): return sum_string_arrays( diff --git a/tests/io/speed/test_parse_speed.py b/tests/io/speed/test_parse_speed.py index 3d76bd77..e9f22b64 100644 --- a/tests/io/speed/test_parse_speed.py +++ b/tests/io/speed/test_parse_speed.py @@ -1,6 +1,6 @@ import pytest -from atomworks.io.constants import CCD_MIRROR_PATH +from atomworks.constants import CCD_MIRROR_PATH from atomworks.io.parser import parse from tests.io.conftest import get_pdb_path diff --git a/tests/io/tools/test_rdkit.py b/tests/io/tools/test_rdkit.py index 88814ab9..e33736ff 100644 --- a/tests/io/tools/test_rdkit.py +++ b/tests/io/tools/test_rdkit.py @@ -4,7 +4,7 @@ from biotite.structure import AtomArray from rdkit import Chem -from atomworks.io.constants import STANDARD_AA +from atomworks.constants import STANDARD_AA from atomworks.io.tools.inference import components_to_atom_array from atomworks.io.tools.rdkit import ( atom_array_from_rdkit, diff --git a/tests/io/utils/test_ccd.py b/tests/io/utils/test_ccd.py index 92292178..882994a1 100644 --- a/tests/io/utils/test_ccd.py +++ b/tests/io/utils/test_ccd.py @@ -3,7 +3,7 @@ import numpy as np import pytest -from atomworks.io.constants import CCD_MIRROR_PATH +from atomworks.constants import CCD_MIRROR_PATH from atomworks.io.utils.ccd import ( atom_array_from_ccd_code, get_ccd_component_from_mirror, diff --git a/tests/ml/conftest.py b/tests/ml/conftest.py index 6c87a014..7005cbe9 100644 --- a/tests/ml/conftest.py +++ b/tests/ml/conftest.py @@ -7,7 +7,7 @@ import pytest from dotenv import load_dotenv -from 
atomworks.io.constants import AF3_EXCLUDED_LIGANDS_REGEX, _load_env_var +from atomworks.constants import AF3_EXCLUDED_LIGANDS_REGEX, _load_env_var from atomworks.io.tools.inference import SequenceComponent from atomworks.ml.datasets.datasets import ConcatDatasetWithID, PandasDataset, StructuralDatasetWrapper from atomworks.ml.datasets.parsers import ( diff --git a/tests/ml/pipelines/test_pipeline_regression.py b/tests/ml/pipelines/test_pipeline_regression.py index be0267a4..b2242091 100644 --- a/tests/ml/pipelines/test_pipeline_regression.py +++ b/tests/ml/pipelines/test_pipeline_regression.py @@ -7,12 +7,12 @@ import pytest import torch -from atomworks.enums import ChainType -from atomworks.io import parse -from atomworks.io.constants import ( +from atomworks.constants import ( AF3_EXCLUDED_LIGANDS, GAP, ) +from atomworks.enums import ChainType +from atomworks.io import parse from atomworks.io.utils.testing import assert_same_atom_array from atomworks.ml.datasets.parsers.base import DEFAULT_CIF_PARSER_ARGS from atomworks.ml.pipelines.af3 import build_af3_transform_pipeline diff --git a/tests/ml/transforms/msa/test_featurize_msa.py b/tests/ml/transforms/msa/test_featurize_msa.py index eab42343..3678c0c7 100644 --- a/tests/ml/transforms/msa/test_featurize_msa.py +++ b/tests/ml/transforms/msa/test_featurize_msa.py @@ -7,7 +7,7 @@ import pytest import torch -from atomworks.io.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA, UNKNOWN_AA, UNKNOWN_DNA, UNKNOWN_RNA +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA, UNKNOWN_AA, UNKNOWN_DNA, UNKNOWN_RNA from atomworks.ml.encoding_definitions import RF2AA_ATOM36_ENCODING, TokenEncoding from atomworks.ml.transforms.atom_array import ( AddWithinPolyResIdxAnnotation, diff --git a/tests/ml/transforms/symmetry/test_automorphisms_networkx.py b/tests/ml/transforms/symmetry/test_automorphisms_networkx.py index a86aba1e..1579ff1d 100644 --- a/tests/ml/transforms/symmetry/test_automorphisms_networkx.py 
+++ b/tests/ml/transforms/symmetry/test_automorphisms_networkx.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from atomworks.io.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER +from atomworks.constants import ELEMENT_NAME_TO_ATOMIC_NUMBER from atomworks.ml.encoding_definitions import AF3_TOKENS from atomworks.ml.transforms.atomize import AtomizeByCCDName from atomworks.ml.transforms.symmetry import ( diff --git a/tests/ml/transforms/test_af3_reference_molecule.py b/tests/ml/transforms/test_af3_reference_molecule.py index 8f15c9ca..2e4153a0 100644 --- a/tests/ml/transforms/test_af3_reference_molecule.py +++ b/tests/ml/transforms/test_af3_reference_molecule.py @@ -5,8 +5,8 @@ import pytest import torch +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.enums import ChainType, GroundTruthConformerPolicy -from atomworks.io.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.io.tools.inference import components_to_atom_array from atomworks.io.tools.rdkit import atom_array_from_rdkit from atomworks.io.utils.selection import get_residue_starts diff --git a/tests/ml/transforms/test_atomize.py b/tests/ml/transforms/test_atomize.py index 72fedf93..a32c60a0 100644 --- a/tests/ml/transforms/test_atomize.py +++ b/tests/ml/transforms/test_atomize.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from atomworks.io.constants import STANDARD_AA +from atomworks.constants import STANDARD_AA from atomworks.ml.transforms.atomize import AtomizeByCCDName from atomworks.ml.utils.testing import cached_parse diff --git a/tests/ml/transforms/test_bonds.py b/tests/ml/transforms/test_bonds.py index a5a9754c..b4614214 100644 --- a/tests/ml/transforms/test_bonds.py +++ b/tests/ml/transforms/test_bonds.py @@ -4,7 +4,7 @@ import pytest import torch -from atomworks.io.constants import STANDARD_AA +from atomworks.constants import STANDARD_AA from atomworks.ml.encoding_definitions import AF3SequenceEncoding from 
atomworks.ml.transforms.atom_array import AddWithinChainInstanceResIdx, AddWithinPolyResIdxAnnotation from atomworks.ml.transforms.atomize import AtomizeByCCDName diff --git a/tests/ml/transforms/test_confidence_transforms.py b/tests/ml/transforms/test_confidence_transforms.py index 3c7cce6c..e4d0434f 100644 --- a/tests/ml/transforms/test_confidence_transforms.py +++ b/tests/ml/transforms/test_confidence_transforms.py @@ -3,7 +3,7 @@ import pytest import torch -from atomworks.io.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.ml.encoding_definitions import RF2AA_ATOM36_ENCODING, AF3SequenceEncoding from atomworks.ml.transforms.atom_array import ( AddGlobalAtomIdAnnotation, diff --git a/tests/ml/transforms/test_filters.py b/tests/ml/transforms/test_filters.py index 7080e7da..68b4573b 100644 --- a/tests/ml/transforms/test_filters.py +++ b/tests/ml/transforms/test_filters.py @@ -4,7 +4,7 @@ import pytest from biotite.structure import AtomArray -from atomworks.io.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.ml.datasets.parsers import PNUnitsDFParser, load_example_from_metadata_row from atomworks.ml.preprocessing.constants import TRAINING_SUPPORTED_CHAIN_TYPES, ChainType from atomworks.ml.transforms.atomize import AtomizeByCCDName diff --git a/tests/ml/transforms/test_token_utils.py b/tests/ml/transforms/test_token_utils.py index c7dad581..2db6e355 100644 --- a/tests/ml/transforms/test_token_utils.py +++ b/tests/ml/transforms/test_token_utils.py @@ -2,7 +2,7 @@ import numpy as np import pytest -from atomworks.io.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA +from atomworks.constants import STANDARD_AA, STANDARD_DNA, STANDARD_RNA from atomworks.io.utils.sequence import STANDARD_PURINE_RESIDUES, STANDARD_PYRIMIDINE_RESIDUES from atomworks.io.utils.testing import 
assert_same_atom_array from atomworks.ml.encoding_definitions import RF2AA_ATOM36_ENCODING diff --git a/tests/ml/utils/test_io.py b/tests/ml/utils/test_io.py index d86c9f8b..c7a31add 100644 --- a/tests/ml/utils/test_io.py +++ b/tests/ml/utils/test_io.py @@ -4,7 +4,7 @@ import pytest from biotite.structure import AtomArrayStack -from atomworks.io.constants import ATOMIC_NUMBER_TO_ELEMENT +from atomworks.constants import ATOMIC_NUMBER_TO_ELEMENT from atomworks.ml.utils.io import convert_af3_model_output_to_atom_array_stack from tests.ml.conftest import TEST_DATA_ML