Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
18 commits
Select commit Hold shift + click to select a range
d1eae15
fixes type checking issues-
benoistlaurent Feb 9, 2026
bdfc7f6
upgrade dependencies: urllib3, protobuf
benoistlaurent Feb 9, 2026
3f89b9c
removed streamlit from dev dependencies (already in main dependencies)
benoistlaurent Feb 9, 2026
160b281
fixes minor type hinting issue
benoistlaurent Feb 9, 2026
4f034ef
new guess_resolution
benoistlaurent Feb 12, 2026
a54a0dc
upgrade dependencies
benoistlaurent Feb 12, 2026
b75803b
Merge branch 'develop' into feature/new-guess-resolution
benoistlaurent Feb 12, 2026
c77e751
unit tests for guess resolution coarse grain
benoistlaurent Feb 12, 2026
c5b1d55
new MolecularResolution class: update regression test dataset
benoistlaurent Feb 12, 2026
fc5168f
guess_resolution: takes all atom cutoff distance as argument
benoistlaurent Feb 12, 2026
f16fa68
minor fix: spelling issue: distance_cutoff -> cutoff_distance
benoistlaurent Feb 12, 2026
1634b3e
tests/test_guesser.py: fixes english mistake in comment
benoistlaurent Feb 12, 2026
b6f9f7d
settings.DistanceCutoff.guess: fixes bug introduced with new Molecula…
benoistlaurent Feb 12, 2026
4d72d82
minor fixes
benoistlaurent Feb 12, 2026
58f08c7
Merge branch 'feature/new-guess-resolution' of github.com:MDVerse/gro…
benoistlaurent Feb 12, 2026
5792ae7
guess_resolution: takes all atom cutoff distance as argument
benoistlaurent Feb 12, 2026
74cf360
Merge branch 'feature/new-guess-resolution' of github.com:MDVerse/gro…
benoistlaurent Feb 12, 2026
d576cfd
Merge branch 'main' into feature/new-guess-resolution
benoistlaurent Feb 12, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions src/grodecoder/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from .identifier import identify
from .io import read_universe
from .models import Decoded
from .toputils import guess_resolution
from .guesser import guess_resolution
from .settings import get_settings


Expand All @@ -20,23 +20,23 @@ def decode(universe: UniverseLike) -> Decoded:

settings = get_settings()

resolution = guess_resolution(universe, cutoff_distance=settings.resolution_detection.distance_cutoff)
resolution = guess_resolution(universe, settings.resolution_detection.cutoff_distance)
logger.info(f"Guessed resolution: {resolution}")

# Guesses the chain detection distance cutoff if not provided by the user.
chain_detection_settings = get_settings().chain_detection

if chain_detection_settings.distance_cutoff.is_set():
value = chain_detection_settings.distance_cutoff.get()
if chain_detection_settings.cutoff_distance.is_set():
value = chain_detection_settings.cutoff_distance.get()
logger.debug(f"chain detection: using user-defined value: {value:.2f}")
else:
logger.debug("chain detection: guessing distance cutoff based on resolution")
chain_detection_settings.distance_cutoff.guess(resolution)
chain_detection_settings.cutoff_distance.guess(resolution)

distance_cutoff = chain_detection_settings.distance_cutoff.get()
cutoff_distance = chain_detection_settings.cutoff_distance.get()

return Decoded(
inventory=identify(universe, bond_threshold=distance_cutoff),
inventory=identify(universe, bond_threshold=cutoff_distance),
resolution=resolution,
)

Expand Down
168 changes: 168 additions & 0 deletions src/grodecoder/guesser.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
from enum import StrEnum, auto
from itertools import islice
from typing import Self

import MDAnalysis as MDA
import numpy as np
from loguru import logger
from pydantic import BaseModel, ConfigDict, model_validator

from .toputils import has_bonds


class ResolutionValue(StrEnum):
ALL_ATOM = auto()
COARSE_GRAIN = auto()


class AllAtomResolutionReason(StrEnum):
DEFAULT = auto()
HAS_HYDROGENS = auto()
RESIDUES_HAVE_BONDS_WITHIN_CUTOFF = auto()


class CoarseGrainResolutionReason(StrEnum):
HAS_BB_ATOMS = auto()
HAS_ONE_GRAIN_PER_RESIDUE = auto()
HAS_NO_BOND_WITHIN_ALL_ATOM_CUTOFF = auto()
PROTEIN_HAS_NO_HYDROGEN = auto()


class MolecularResolution(BaseModel):
    """Resolution guessed for a molecular system, with the reason behind the guess.

    `value` tells whether the system is all-atom or coarse grain; `reason`
    optionally records which criterion led to that verdict and must belong to
    the reason family matching `value`. Field assignments are validated too
    (`validate_assignment=True`).
    """

    model_config = ConfigDict(validate_assignment=True)

    value: ResolutionValue
    reason: AllAtomResolutionReason | CoarseGrainResolutionReason | None = None

    def is_all_atom(self) -> bool:
        """Return True if the resolution is all-atom."""
        return self.value == ResolutionValue.ALL_ATOM

    def is_coarse_grain(self) -> bool:
        """Return True if the resolution is coarse grain."""
        return self.value == ResolutionValue.COARSE_GRAIN

    @model_validator(mode="after")
    def check_reason(self) -> Self:
        """Reject a `reason` that does not belong to the family matching `value`.

        A missing reason (`None`) is always accepted.

        Raises:
            ValueError: if `reason` is set and is not of the enum type expected for `value`.
        """
        reason = self.reason
        if reason is None:
            return self

        expects_all_atom = self.value == ResolutionValue.ALL_ATOM
        expects_coarse_grain = self.value == ResolutionValue.COARSE_GRAIN

        if expects_all_atom and not isinstance(reason, AllAtomResolutionReason):
            raise ValueError(
                f"reason must be AllAtomResolutionReason when value is 'all-atom', "
                f"got {type(reason).__name__}"
            )
        if expects_coarse_grain and not isinstance(reason, CoarseGrainResolutionReason):
            raise ValueError(
                f"reason must be CoarseGrainResolutionReason when value is 'coarse-grain', "
                f"got {type(reason).__name__}"
            )
        return self

    @classmethod
    def AllAtomWithHydrogen(cls) -> Self:
        """All-atom: the system contains hydrogen atoms."""
        reason = AllAtomResolutionReason.HAS_HYDROGENS
        return cls(value=ResolutionValue.ALL_ATOM, reason=reason)

    @classmethod
    def AllAtomWithStandardBonds(cls) -> Self:
        """All-atom: atoms within residues are at distances typical of all-atom models."""
        reason = AllAtomResolutionReason.RESIDUES_HAVE_BONDS_WITHIN_CUTOFF
        return cls(value=ResolutionValue.ALL_ATOM, reason=reason)

    @classmethod
    def CoarseGrainMartini(cls) -> Self:
        """Coarse grain: the system contains atoms named BB*."""
        reason = CoarseGrainResolutionReason.HAS_BB_ATOMS
        return cls(value=ResolutionValue.COARSE_GRAIN, reason=reason)

    @classmethod
    def CoarseGrainSingleParticle(cls) -> Self:
        """Coarse grain: residues are made of a single particle."""
        reason = CoarseGrainResolutionReason.HAS_ONE_GRAIN_PER_RESIDUE
        return cls(value=ResolutionValue.COARSE_GRAIN, reason=reason)

    @classmethod
    def CoarseGrainOther(cls) -> Self:
        """Coarse grain: any other case, typically with bond lengths between
        particles greater than in standard all-atom models."""
        reason = CoarseGrainResolutionReason.HAS_NO_BOND_WITHIN_ALL_ATOM_CUTOFF
        return cls(value=ResolutionValue.COARSE_GRAIN, reason=reason)

    @classmethod
    def CoarseGrainProteinHasNoHydrogen(cls) -> Self:
        """Coarse grain: a protein is detected and its residues have no hydrogen atoms."""
        reason = CoarseGrainResolutionReason.PROTEIN_HAS_NO_HYDROGEN
        return cls(value=ResolutionValue.COARSE_GRAIN, reason=reason)


def _has_hydrogen(model: MDA.AtomGroup) -> bool:
    """Tell whether at least one atom of `model` has the type "H"."""
    atom_types = model.atoms.types
    return "H" in atom_types


def _is_martini(model: MDA.AtomGroup) -> bool:
    """Tell whether at least one atom of `model` has a name starting with "BB"."""
    # Names are cast to a unicode array so that `np.char.startswith` can process them.
    atom_names = model.atoms.names.astype("U")
    starts_with_bb = np.char.startswith(atom_names, "BB")
    return bool(np.any(starts_with_bb))


def _has_bonds_within_all_atom_cutoff(model: MDA.AtomGroup, cutoff_distance: float) -> bool:
    """Tell whether `has_bonds(residue, cutoff_distance)` holds for any residue of `model`.

    Residues are visited in order and the scan stops at the first match.
    """
    return any(has_bonds(residue, cutoff_distance) for residue in model.residues)


def _has_protein(model: MDA.AtomGroup) -> bool:
    """Tell whether the "protein" selection on `model` matches at least one atom."""
    protein_atoms = model.select_atoms("protein").atoms
    return len(protein_atoms) > 0


def _protein_has_hydrogen(model: MDA.AtomGroup) -> bool:
    """Tell whether the "protein" selection on `model` contains hydrogen atoms."""
    protein = model.select_atoms("protein")
    return _has_hydrogen(protein)


def guess_resolution(universe, all_atom_cutoff_distance: float = 1.6) -> MolecularResolution:
    """Guess a system resolution (all-atom or coarse-grain).

    Checks are applied in the following order; the first one that matches wins:

    1. the system holds exactly one atom: all-atom (default);
    2. no residue has more than one atom: coarse grain;
    3. hydrogen atoms among the first five multi-atom residues: all-atom;
    4. atoms named BB* among those same residues: coarse grain (Martini);
    5. the system contains a protein that has no hydrogen atom: coarse grain;
    6. `has_bonds` reports bonds within `all_atom_cutoff_distance` in one of the
       sampled residues: all-atom;
    7. otherwise: coarse grain.

    Args:
        universe: system to inspect; must expose `atoms`, `residues` and `select_atoms`.
        all_atom_cutoff_distance: cutoff distance forwarded to `has_bonds` for step 6.

    Returns:
        The guessed resolution together with the reason for the guess.
    """
    # Only one atom in the system: defaulting to all-atom.
    if len(universe.atoms) == 1:
        return MolecularResolution(value=ResolutionValue.ALL_ATOM, reason=AllAtomResolutionReason.DEFAULT)

    # Samples the first five residues with at least two atoms: most checks only need
    # a handful of residues, which keeps them cheap on large systems.
    multi_atom_resindexes = (residue.resindex for residue in universe.residues if len(residue.atoms) > 1)
    resindexes = list(islice(multi_atom_resindexes, 5))

    # If no residue with more than one atom, definitely coarse grain.
    if not resindexes:
        logger.debug("No residues with more than one atom: resolution is coarse grain")
        return MolecularResolution.CoarseGrainSingleParticle()

    small_u: MDA.AtomGroup = universe.select_atoms(f"resindex {' '.join(str(i) for i in resindexes)}")

    # If we find any hydrogen atom, it's all-atom.
    if _has_hydrogen(small_u):
        logger.debug("Found hydrogen atoms: resolution is all-atom")
        return MolecularResolution.AllAtomWithHydrogen()

    # If we find any atom named "BB*", it's Martini (coarse grain).
    if _is_martini(small_u):
        logger.debug("Found atoms named BB*: resolution is coarse grain")
        return MolecularResolution.CoarseGrainMartini()

    # This check runs on the whole system, so the protein selection is computed only once.
    protein = universe.select_atoms("protein")
    if len(protein.atoms) > 0 and not _has_hydrogen(protein):
        logger.debug("Found protein without hydrogen: resolution is coarse grain")
        return MolecularResolution.CoarseGrainProteinHasNoHydrogen()

    # Last chance: if we find any bond within a given distance, it's all-atom.
    # If we reach this point, it means that, for some reason, no hydrogen atom was detected before.
    if _has_bonds_within_all_atom_cutoff(small_u, all_atom_cutoff_distance):
        logger.debug("Found bonds within all-atom distance cutoff: resolution is all-atom")
        return MolecularResolution.AllAtomWithStandardBonds()

    # Coarse grain not detected before, no bonds within cutoff distance, it's coarse grain.
    logger.debug("No bonds found within all-atom distance cutoff: resolution is coarse grain")
    return MolecularResolution.CoarseGrainOther()
2 changes: 1 addition & 1 deletion src/grodecoder/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def main(args: "CliArgs"):

# Storing cli arguments into settings.
settings = get_settings()
settings.chain_detection.distance_cutoff = args.bond_threshold
settings.chain_detection.cutoff_distance = args.bond_threshold
settings.output.atom_ids = not args.no_atom_ids

logger.info(f"Processing structure file: {structure_path}")
Expand Down
6 changes: 1 addition & 5 deletions src/grodecoder/models.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from __future__ import annotations
from .guesser import MolecularResolution

from enum import StrEnum
from typing import Protocol
Expand All @@ -20,11 +21,6 @@
from .settings import Settings


class MolecularResolution(StrEnum):
COARSE_GRAINED = "coarse-grained"
ALL_ATOM = "all-atom"


class MolecularType(StrEnum):
PROTEIN = "protein"
NUCLEIC = "nucleic_acid"
Expand Down
62 changes: 31 additions & 31 deletions src/grodecoder/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,49 +15,49 @@

@dataclass(init=False)
class DistanceCutoff:
default_distance_cutoff_all_atom: ClassVar[float] = 5.0
default_distance_cutoff_coarse_grain: ClassVar[float] = 6.0
_user_distance_cutoff: float | None = None
_guessed_distance_cutoff: float | None = None
default_cutoff_distance_all_atom: ClassVar[float] = 5.0
default_cutoff_distance_coarse_grain: ClassVar[float] = 6.0
_user_cutoff_distance: float | None = None
_guessed_cutoff_distance: float | None = None

def __init__(self, user_value: float | None = None):
if user_value is not None:
self.set(user_value)

def is_defined(self) -> bool:
"""Returns True if the distance cutoff has been set or guessed."""
return any((self._user_distance_cutoff, self._guessed_distance_cutoff))
return self._user_cutoff_distance is not None or self._guessed_cutoff_distance is not None

def is_set(self) -> bool:
"""Returns True if the distance cutoff has been set."""
return self._user_distance_cutoff is not None
return self._user_cutoff_distance is not None

def is_guessed(self) -> bool:
"""Returns True if the distance cutoff has been guessed."""
return self._guessed_distance_cutoff is not None
return self._guessed_cutoff_distance is not None

def get(self) -> float:
if not self.is_defined():
raise ValueError("`distance_cutoff` must be set or guessed before it is used.")
return self._user_distance_cutoff or self._guessed_distance_cutoff # ty: ignore[invalid-return-type]
raise ValueError("`cutoff_distance` must be set or guessed before it is used.")
return self._user_cutoff_distance or self._guessed_cutoff_distance # ty: ignore[invalid-return-type]

def set(self, value: float):
if self.is_guessed():
self._guessed_distance_cutoff = None
self._user_distance_cutoff = value
self._guessed_cutoff_distance = None
self._user_cutoff_distance = value

def guess(self, resolution: "MolecularResolution"):
if resolution == "ALL_ATOM":
distance_cutoff = self.default_distance_cutoff_all_atom
if resolution.is_all_atom():
cutoff_distance = self.default_cutoff_distance_all_atom
logger.debug(
f"chain detection: using default distance cutoff for all atom structures: {distance_cutoff:.2f}"
f"chain detection: using default distance cutoff for all atom structures: {cutoff_distance:.2f}"
)
else:
distance_cutoff = self.default_distance_cutoff_coarse_grain
cutoff_distance = self.default_cutoff_distance_coarse_grain
logger.debug(
f"chain detection: using default distance cutoff for coarse grain structures: {distance_cutoff:.2f}"
f"chain detection: using default distance cutoff for coarse grain structures: {cutoff_distance:.2f}"
)
self._guessed_distance_cutoff = distance_cutoff
self._guessed_cutoff_distance = cutoff_distance


class _DistanceCutoffPydanticAnnotation:
Expand All @@ -67,36 +67,36 @@ class _DistanceCutoffPydanticAnnotation:
Examples:
>>> from grodecoder.settings import ChainDetectionSettings
>>> cds = ChainDetectionSettings()
>>> cds.distance_cutoff
DistanceCutoff(_user_distance_cutoff=None, _guessed_distance_cutoff=None)
>>> cds.cutoff_distance
DistanceCutoff(_user_cutoff_distance=None, _guessed_cutoff_distance=None)

>>> # Float assignment
>>> cds.distance_cutoff = 12
>>> cds.distance_cutoff
DistanceCutoff(_user_distance_cutoff=12.0, _guessed_distance_cutoff=None)
>>> cds.cutoff_distance = 12
>>> cds.cutoff_distance
DistanceCutoff(_user_cutoff_distance=12.0, _guessed_cutoff_distance=None)

>>> # None assignment
>>> cds.distance_cutoff = None
>>> cds.distance_cutoff
DistanceCutoff(_user_distance_cutoff=None, _guessed_distance_cutoff=None)
>>> cds.cutoff_distance = None
>>> cds.cutoff_distance
DistanceCutoff(_user_cutoff_distance=None, _guessed_cutoff_distance=None)

>>> # Serialization
>>> cds.distance_cutoff = 12
>>> cds.cutoff_distance = 12
>>> cds.model_dump()
{'distance_cutoff': 12.0}
{'cutoff_distance': 12.0}

>>> # Validation
>>> as_json = cds.model_dump()
>>> ChainDetectionSettings.model_validate(as_json)
ChainDetectionSettings(distance_cutoff=DistanceCutoff(_user_distance_cutoff=12.0, _guessed_distance_cutoff=None))
ChainDetectionSettings(cutoff_distance=DistanceCutoff(_user_cutoff_distance=12.0, _guessed_cutoff_distance=None))
"""

@classmethod
def __get_pydantic_core_schema__(cls, _source_type, _handler) -> core_schema.CoreSchema:
"""
We return a pydantic_core.CoreSchema that behaves in the following ways:

* floats will be parsed as `DistanceCutoff` instances with the float as the `_user_distance_cutoff` attribute
* floats will be parsed as `DistanceCutoff` instances with the float as the `_user_cutoff_distance` attribute
* `DistanceCutoff` instances will be parsed as `DistanceCutoff` instances without any changes
* Nothing else will pass validation
* Serialization will always return just a float
Expand Down Expand Up @@ -148,11 +148,11 @@ def __get_pydantic_json_schema__(

class ChainDetectionSettings(BaseSettings):
model_config = ConfigDict(validate_assignment=True)
distance_cutoff: PydanticDistanceCutoff = Field(default_factory=DistanceCutoff)
cutoff_distance: PydanticDistanceCutoff = Field(default_factory=DistanceCutoff)


class ResolutionDetectionSettings(BaseSettings):
distance_cutoff: float = 1.6
cutoff_distance: float = 1.6


class OutputSettings(BaseSettings):
Expand Down
Loading