Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions openfold3/core/data/framework/data_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,7 +152,11 @@ class DataModuleConfig(BaseModel):
datasets: list[SerializeAsAny[BaseModel]]
batch_size: int = 1
num_workers: int = 0
prefetch_factor: int | None = None
num_workers_validation: int = 0
prefetch_factor_validation: int | None = None
persistent_workers: bool = False
multiprocessing_context: str | None = None
data_seed: int = 42
epoch_len: int = 1

Expand All @@ -165,8 +169,14 @@ def __init__(self, data_module_config: DataModuleConfig) -> None:

# Possibly initialize directly from DataModuleConfig
self.batch_size = data_module_config.batch_size

self.num_workers = data_module_config.num_workers
self.prefetch_factor = data_module_config.prefetch_factor
self.num_workers_validation = data_module_config.num_workers_validation
self.prefetch_factor_validation = data_module_config.prefetch_factor_validation
self.persistent_workers = data_module_config.persistent_workers
self.multiprocessing_context = data_module_config.multiprocessing_context

self.data_seed = data_module_config.data_seed
self.next_data_seed = data_module_config.data_seed
self.epoch_len = data_module_config.epoch_len
Expand Down Expand Up @@ -408,17 +418,20 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
Returns:
DataLoader: DataLoader object.
"""

# TODO: Val does not need this many workers. Due to memory leak issue,
# reduce workers here to run with more workers overall in training
# as temporary quick fix.
if (
mode == DatasetMode.validation
and DatasetMode.train in self.multi_dataset_config.modes
):
num_workers = self.num_workers_validation
prefetch_factor = self.prefetch_factor_validation
else:
num_workers = self.num_workers
prefetch_factor = self.prefetch_factor

persistent_workers = self.persistent_workers and num_workers > 0
multiprocessing_context = (
self.multiprocessing_context if num_workers > 0 else None
)

generator = self.generators.get(mode)
if generator is None:
Expand All @@ -445,6 +458,9 @@ def generate_dataloader(self, mode: DatasetMode, sampler: Sampler | None = None)
collate_fn=openfold_batch_collator,
generator=self.generators[mode],
worker_init_fn=worker_init_fn,
persistent_workers=persistent_workers,
prefetch_factor=prefetch_factor,
multiprocessing_context=multiprocessing_context,
)

def train_dataloader(self) -> DataLoader:
Expand Down
31 changes: 24 additions & 7 deletions openfold3/core/data/framework/single_datasets/base_of3.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
SingleDataset,
register_dataset,
)
from openfold3.core.data.framework.single_datasets.dataset_utils import warm_lmdb_cache
from openfold3.core.data.io.dataset_cache import read_datacache
from openfold3.core.data.pipelines.featurization.conformer import (
featurize_reference_conformers_of3,
Expand Down Expand Up @@ -153,15 +154,20 @@ def __init__(self, dataset_config) -> None:
# TODO: rename dataset_cache_file to dataset_cache_path to signal that it can be
# a directory or a file
# TODO: potentially expose the LMDB database encoding types
self.dataset_cache = read_datacache(
dataset_config.dataset_paths.dataset_cache_file
)
self._dataset_cache_file = dataset_config.dataset_paths.dataset_cache_file
self.dataset_cache = read_datacache(self._dataset_cache_file)
self.warm_cache()

self.datapoint_cache = {}

if dataset_config.dataset_paths.template_structures_directory is not None:
self.ccd = pdbx.CIFFile.read(dataset_config.dataset_paths.ccd_file)
else:
self.ccd = None
# Only used if template structures are not preprocessed
# Lazy-loaded so the dataset is picklable (forkserver)
self._ccd = None
self._ccd_file = (
dataset_config.dataset_paths.ccd_file
if dataset_config.dataset_paths.template_structures_directory is not None
else None
)

# Dataset configuration
# n_tokens can be set in the getitem method separately for each sample using
Expand All @@ -174,6 +180,17 @@ def __init__(self, dataset_config) -> None:
self.single_moltype = None
self.debug_mode = dataset_config.debug_mode

def warm_cache(self) -> None:
    """Pre-read the dataset cache into the OS page cache.

    Only LMDB caches (stored as directories) are warmed; a JSON cache is
    a single file and is left untouched.
    """
    cache_path = self._dataset_cache_file
    if not cache_path.is_dir():
        # JSON cache: nothing to warm.
        return
    warm_lmdb_cache(cache_path)

@property
def ccd(self):
    """Chemical Component Dictionary, parsed lazily on first access.

    Lazy loading keeps the dataset picklable for forkserver workers.
    Returns None when no CCD file was configured (i.e. template
    structures are preprocessed).
    """
    if self._ccd is not None:
        return self._ccd
    if self._ccd_file is None:
        return None
    self._ccd = pdbx.CIFFile.read(self._ccd_file)
    return self._ccd

@log_runtime_memory(runtime_dict_key="runtime-create-structure-features")
def create_structure_features(
self,
Expand Down
21 changes: 21 additions & 0 deletions openfold3/core/data/framework/single_datasets/dataset_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
import copy
import logging
import os
import time
from itertools import cycle, islice
from pathlib import Path

import pandas as pd
import torch
Expand All @@ -30,6 +32,7 @@
naive_alignment,
)

logger = logging.getLogger(__name__)
worker_seed_log = logging.getLogger(f"{__name__}.worker_seed")


Expand Down Expand Up @@ -153,3 +156,21 @@ def getitem_debug_log(dataset_name: str = "") -> None:
f"pid={os.getpid()} worker_id={worker_id} wi.seed={wi_seed} "
f"wi.base_seed={wi_base_seed} torch.initial_seed={torch_seed}",
)


def warm_file_cache(file_path: Path) -> None:
    """Sequentially read a file to warm the OS page cache.

    Best-effort: a missing file is logged and skipped rather than raised,
    since cache warming is an optimization, not a requirement.

    Args:
        file_path (Path): File whose contents should be pulled into the
            page cache.
    """
    if not file_path.is_file():
        logger.warning("Cannot warm page cache, %s does not exist", file_path)
        return
    file_size_gb = file_path.stat().st_size / (1024**3)
    # Lazy %-style args: formatting is skipped entirely if INFO is disabled.
    logger.info(
        "Warming page cache for %s (%.1f GB)...", file_path, file_size_gb
    )
    t0 = time.monotonic()
    chunk_size = 8 * 1024 * 1024  # 8 MiB reads keep syscall overhead low
    with file_path.open("rb") as f:
        # read() returns b"" at EOF, which is falsy and ends the loop.
        while f.read(chunk_size):
            pass
    logger.info("Page cache warm complete in %.1fs", time.monotonic() - t0)


def warm_lmdb_cache(lmdb_directory: Path) -> None:
    """Warm the OS page cache by streaming through the LMDB data file."""
    data_file = lmdb_directory / "data.mdb"
    warm_file_cache(data_file)
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/monomer.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,9 @@ def __init__(self, dataset_config: dict) -> None:
# Datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

def create_datapoint_cache(self):
"""Creates the datapoint_cache for uniform sampling.

Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/pdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,9 @@ def __init__(self, dataset_config: dict) -> None:
# Datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

def create_datapoint_cache(self) -> None:
"""Creates the datapoint_cache with chain/interface probabilities.

Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/framework/single_datasets/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ def __init__(self, dataset_config: dict, world_size: int | None = None) -> None:
# Dataset/datapoint cache
self.create_datapoint_cache()

# Release so fork/forkserver inherits clean state for the first epoch
self.dataset_cache.release_connections()

# Cropping should be disabled for validation datasets
if self.crop["token_crop"]["enabled"]:
logger.warning(
Expand Down
3 changes: 3 additions & 0 deletions openfold3/core/data/io/dataset_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -192,6 +192,9 @@ def read_datacache(
with lmdb_env.begin() as txn:
dataset_cache_type = json.loads(txn.get(type_key).decode(str_encoding))

# Only one connection can be open at a time, close before creating LMDBDict
lmdb_env.close()

if not dataset_cache_type:
raise ValueError("No type found for this directory.")

Expand Down
67 changes: 38 additions & 29 deletions openfold3/core/data/primitives/caches/format.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

import lmdb

from openfold3.core.data.primitives.caches.lmdb import LMDBDict
from openfold3.core.data.primitives.caches.lmdb import LMDBDict, LMDBEnv
from openfold3.core.data.resources.residues import MoleculeType

K = TypeVar("K")
Expand Down Expand Up @@ -135,6 +135,7 @@ def from_json(cls, file: Path) -> PreprocessingDataCache:
# rerunning preprocessing
elif status == "failed":
release_date = None
experimental_method = None
resolution = None
chains = None
interfaces = None
Expand Down Expand Up @@ -414,7 +415,7 @@ class DatasetCache:
# TODO: update parsers for this base class
@classmethod
def from_json(cls, file: Path) -> DatasetCache:
"""Costructs a datacache from a json.
"""Constructs a datacache from a json.

Args:
file (Path):
Expand All @@ -434,13 +435,15 @@ def from_json(cls, file: Path) -> DatasetCache:
reference_molecule_data=cls._parse_ref_mol_data_json(data),
)

@staticmethod
def _parse_type_json(data: dict) -> None:
# Remove _type field (already an internal private attribute so shouldn't be
# defined as an explicit field)
if "_type" in data:
# This is conditional for legacy compatibility, should be removed after
del data["_type"]

@staticmethod
def _parse_name_json(data: dict) -> str:
return data["name"]

Expand Down Expand Up @@ -479,6 +482,15 @@ def _parse_ref_mol_data_json(cls, data: dict) -> dict:
ref_mol_data[ref_mol_id] = per_ref_mol_data_fmt
return ref_mol_data

def release_connections(self) -> None:
    """
    Drop any live backend handles so forked workers inherit clean state.
    Backends reopen lazily on their next access; plain in-memory dicts
    expose no ``close`` and are skipped.
    """
    for backend in (self.structure_data, self.reference_molecule_data):
        close = getattr(backend, "close", None)
        if close is not None:
            close()

def to_json(self, file: Path) -> None:
"""Write the dataset cache to a JSON file.

Expand Down Expand Up @@ -523,27 +535,29 @@ def from_lmdb(
DatasetCache:
The constructed datacache.
"""

lmdb_env = lmdb.open(
str(lmdb_directory), readonly=True, lock=False, subdir=True
)

with lmdb_env.begin() as transaction:
lmdb_env = LMDBEnv(str(lmdb_directory))
with lmdb_env.get().begin() as transaction:
_ = cls._parse_type_lmdb(transaction, str_encoding)
name = cls._parse_name_lmdb(transaction, str_encoding)
structure_data = cls._parse_structure_data_lmdb(
lmdb_env, str_encoding, structure_data_encoding
)
reference_molecule_data = cls._parse_ref_mol_data_lmdb(
lmdb_env, str_encoding, reference_molecule_data_encoding
)

return cls(
name=name,
structure_data=structure_data,
reference_molecule_data=reference_molecule_data,
)
structure_data = cls._parse_structure_data_lmdb(
lmdb_env=lmdb_env,
str_encoding=str_encoding,
structure_data_encoding=structure_data_encoding,
)
reference_molecule_data = cls._parse_ref_mol_data_lmdb(
lmdb_env=lmdb_env,
str_encoding=str_encoding,
reference_molecule_data_encoding=reference_molecule_data_encoding,
)

return cls(
name=name,
structure_data=structure_data,
reference_molecule_data=reference_molecule_data,
)

@staticmethod
def _parse_type_lmdb(
transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"]
) -> str:
Expand All @@ -555,6 +569,7 @@ def _parse_type_lmdb(

return _type

@staticmethod
def _parse_name_lmdb(
transaction: lmdb.Transaction, str_encoding: Literal["utf-8", "pkl"]
) -> str:
Expand All @@ -566,31 +581,25 @@ def _parse_name_lmdb(

return name

@staticmethod
def _parse_structure_data_lmdb(
lmdb_env: lmdb.Environment,
lmdb_env: LMDBEnv,
str_encoding: Literal["utf-8", "pkl"],
structure_data_encoding: Literal["utf-8", "pkl"],
) -> LMDBDict:
from openfold3.core.data.primitives.caches.lmdb import (
LMDBDict,
)

return LMDBDict(
lmdb_env=lmdb_env,
prefix="structure_data",
key_encoding=str_encoding,
value_encoding=structure_data_encoding,
)

@staticmethod
def _parse_ref_mol_data_lmdb(
lmdb_env: lmdb.Environment,
lmdb_env: LMDBEnv,
str_encoding: Literal["utf-8", "pkl"],
reference_molecule_data_encoding: Literal["utf-8", "pkl"],
) -> LMDBDict:
from openfold3.core.data.primitives.caches.lmdb import (
LMDBDict,
)

return LMDBDict(
lmdb_env=lmdb_env,
prefix="reference_molecule_data",
Expand Down
Loading
Loading