diff --git a/extra_data/keydata.py b/extra_data/keydata.py index 7ce1ea02..8860a6fa 100644 --- a/extra_data/keydata.py +++ b/extra_data/keydata.py @@ -444,3 +444,9 @@ def trains(self, keep_dims=False): yield tid, ds[start] start += count + + def _pasha_functor_(self): + """Integration with pasha for map operations.""" + + from .pasha_functor import ExtraDataFunctor + return ExtraDataFunctor(self) diff --git a/extra_data/pasha_functor.py b/extra_data/pasha_functor.py new file mode 100644 index 00000000..74df6978 --- /dev/null +++ b/extra_data/pasha_functor.py @@ -0,0 +1,50 @@ +from os import getpid + +import numpy as np + +from . import DataCollection, KeyData +from .read_machinery import split_trains + + +class ExtraDataFunctor: + """Pasha functor for EXtra-data objects. + + This functor wraps an EXtra-data DataCollection, SourceData or + KeyData and performs the map operation over its trains. The kernel + is passed the current train's index in the collection, the train ID + and the data mapping (for DataCollection) or data entry (for KeyData). + """ + + def __init__(self, obj): + self.obj = obj + self.n_trains = len(self.obj.train_ids) + + # Save PID of parent process where the functor is created to + # close files as appropriate later on, see comment below. + self._parent_pid = getpid() + + @classmethod + def wrap(cls, value): + if isinstance(value, (DataCollection, KeyData)): + return cls(value) + + def split(self, num_workers): + return split_trains(self.n_trains, parts=num_workers) + + def iterate(self, share): + subobj = self.obj.select_trains(np.s_[share]) + + # Older versions of HDF5 < 1.10.5 are not robust against sharing + # a file descriptor across threads or processes. If running in a + # different process than the functor was initially created in, + # close all file handles inherited from the parent collection to + # force re-opening them again in each child process.
+ if getpid() != self._parent_pid: + for f in subobj.files: + f.close() + + index_it = range(*share.indices(self.n_trains)) + data_it = subobj.trains() + + for index, (train_id, data) in zip(index_it, data_it): + yield index, train_id, data diff --git a/extra_data/reader.py b/extra_data/reader.py index 7395b276..e558f921 100644 --- a/extra_data/reader.py +++ b/extra_data/reader.py @@ -347,6 +347,12 @@ def trains(self, devices=None, train_range=None, *, require_all=False, return iter(TrainIterator(dc, require_all=require_all, flat_keys=flat_keys, keep_dims=keep_dims)) + def _pasha_functor_(self): + """Integration with pasha for map operations.""" + + from .pasha_functor import ExtraDataFunctor + return ExtraDataFunctor(self) + def train_from_id( self, train_id, devices=None, *, flat_keys=False, keep_dims=False): """Get train data for specified train ID.