From 68d11665198d36e459660f65ff58caa86f487af2 Mon Sep 17 00:00:00 2001
From: Philipp Schmidt
Date: Tue, 22 Mar 2022 16:42:56 +0100
Subject: [PATCH 1/3] Add ExtraDataFunctor for integration with pasha

---
 extra_data/keydata.py       |  6 ++++
 extra_data/pasha_functor.py | 66 +++++++++++++++++++++++++++++++++++++
 extra_data/reader.py        |  6 ++++
 extra_data/sourcedata.py    |  6 ++++
 4 files changed, 84 insertions(+)
 create mode 100644 extra_data/pasha_functor.py

diff --git a/extra_data/keydata.py b/extra_data/keydata.py
index 7ce1ea02..8860a6fa 100644
--- a/extra_data/keydata.py
+++ b/extra_data/keydata.py
@@ -444,3 +444,9 @@ def trains(self, keep_dims=False):
 
             yield tid, ds[start]
             start += count
+
+    def _pasha_functor_(self):
+        """Integration with pasha for map operations."""
+
+        from .pasha_functor import ExtraDataFunctor
+        return ExtraDataFunctor(self)
diff --git a/extra_data/pasha_functor.py b/extra_data/pasha_functor.py
new file mode 100644
index 00000000..625b1b2c
--- /dev/null
+++ b/extra_data/pasha_functor.py
@@ -0,0 +1,66 @@
+
+from os import getpid
+
+import numpy as np
+from pasha.functor import gen_split_slices
+
+from . import DataCollection, SourceData, KeyData
+
+
+class ExtraDataFunctor:
+    """Pasha functor for EXtra-data objects.
+
+    This functor wraps an EXtra-data DataCollection, SourceData or
+    KeyData and performs the map operation over its trains. The kernel
+    is passed the current train's index in the collection, the train ID
+    and the data mapping (for DataCollection and SourceData) or data
+    entry (for KeyData).
+    """
+
+    def __init__(self, obj):
+        self.obj = obj
+        self.n_trains = len(self.obj.train_ids)
+
+        # Save PID of parent process where the functor is created to
+        # close files appropriately later on, see comment below.
+        self._parent_pid = getpid()
+
+    @classmethod
+    def wrap(cls, value):
+        if isinstance(value, (DataCollection, SourceData, KeyData)):
+            return cls(value)
+
+    def split(self, num_workers):
+        return gen_split_slices(self.n_trains, n_parts=num_workers)
+
+    def iterate(self, share):
+        subobj = self.obj.select_trains(np.s_[share])
+
+        # Older versions of HDF5 (< 1.10.5) are not robust against sharing
+        # a file descriptor across threads or processes. If running in a
+        # different process than the functor was initially created in,
+        # close all file handles inherited from the parent collection to
+        # force re-opening them in each child process.
+        if getpid() != self._parent_pid:
+            for f in subobj.files:
+                f.close()
+
+        index_it = range(*share.indices(self.n_trains))
+
+        if isinstance(subobj, SourceData):
+            # SourceData has no trains() iterator yet, so simulate it
+            # ourselves by reconstructing a DataCollection object and
+            # using its trains() iterator.
+            dc = DataCollection(
+                subobj.files, {subobj.source: subobj}, subobj.train_ids,
+                inc_suspect_trains=subobj.inc_suspect_trains,
+                is_single_run=True)
+            data_it = ((train_id, data[subobj.source])
+                       for train_id, data in dc.trains())
+        else:
+            # Use the regular trains() iterator for DataCollection and
+            # KeyData
+            data_it = subobj.trains()
+
+        for index, (train_id, data) in zip(index_it, data_it):
+            yield index, train_id, data
diff --git a/extra_data/reader.py b/extra_data/reader.py
index 7395b276..e558f921 100644
--- a/extra_data/reader.py
+++ b/extra_data/reader.py
@@ -347,6 +347,12 @@ def trains(self, devices=None, train_range=None, *, require_all=False,
         return iter(TrainIterator(dc, require_all=require_all,
                                   flat_keys=flat_keys, keep_dims=keep_dims))
 
+    def _pasha_functor_(self):
+        """Integration with pasha for map operations."""
+
+        from .pasha_functor import ExtraDataFunctor
+        return ExtraDataFunctor(self)
+
     def train_from_id(
             self, train_id, devices=None, *, flat_keys=False, keep_dims=False):
         """Get train data for specified train ID.
diff --git a/extra_data/sourcedata.py b/extra_data/sourcedata.py
index 90de980d..aa905288 100644
--- a/extra_data/sourcedata.py
+++ b/extra_data/sourcedata.py
@@ -227,3 +227,9 @@ def union(self, *others) -> 'SourceData':
             section=self.section,
             inc_suspect_trains=self.inc_suspect_trains
         )
+
+    def _pasha_functor_(self):
+        """Integration with pasha for map operations."""
+
+        from .pasha_functor import ExtraDataFunctor
+        return ExtraDataFunctor(self)

From 21d9ab7dad55878026ca431e9e1982217f6a4788 Mon Sep 17 00:00:00 2001
From: Thomas Kluyver
Date: Wed, 6 Apr 2022 11:05:54 +0100
Subject: [PATCH 2/3] Use internal split_trains function

---
 extra_data/pasha_functor.py | 5 ++---
 1 file changed, 2 insertions(+), 3 deletions(-)

diff --git a/extra_data/pasha_functor.py b/extra_data/pasha_functor.py
index 625b1b2c..566c317b 100644
--- a/extra_data/pasha_functor.py
+++ b/extra_data/pasha_functor.py
@@ -1,10 +1,9 @@
-
 from os import getpid
 
 import numpy as np
-from pasha.functor import gen_split_slices
 
 from . import DataCollection, SourceData, KeyData
+from .read_machinery import split_trains
 
 
 class ExtraDataFunctor:
@@ -31,7 +30,7 @@ def wrap(cls, value):
             return cls(value)
 
     def split(self, num_workers):
-        return gen_split_slices(self.n_trains, n_parts=num_workers)
+        return split_trains(self.n_trains, parts=num_workers)
 
     def iterate(self, share):
         subobj = self.obj.select_trains(np.s_[share])

From c8d37ef6e89f730568ae36220a5eee4174a8943d Mon Sep 17 00:00:00 2001
From: Philipp Schmidt
Date: Wed, 27 Apr 2022 14:34:26 +0200
Subject: [PATCH 3/3] Remove provisional SourceData iteration support in pasha
 functor

---
 extra_data/pasha_functor.py | 29 +++++++----------------------
 extra_data/sourcedata.py    |  6 ------
 2 files changed, 7 insertions(+), 28 deletions(-)

diff --git a/extra_data/pasha_functor.py b/extra_data/pasha_functor.py
index 566c317b..74df6978 100644
--- a/extra_data/pasha_functor.py
+++ b/extra_data/pasha_functor.py
@@ -2,7 +2,7 @@
 
 import numpy as np
 
-from . import DataCollection, SourceData, KeyData
+from . import DataCollection, KeyData
 from .read_machinery import split_trains
 
 
@@ -12,8 +12,7 @@ class ExtraDataFunctor:
-    This functor wraps an EXtra-data DataCollection, SourceData or
-    KeyData and performs the map operation over its trains. The kernel
-    is passed the current train's index in the collection, the train ID
-    and the data mapping (for DataCollection and SourceData) or data
-    entry (for KeyData).
+    This functor wraps an EXtra-data DataCollection or KeyData and
+    performs the map operation over its trains. The kernel is passed
+    the current train's index in the collection, the train ID and the
+    data mapping (for DataCollection) or data entry (for KeyData).
""" def __init__(self, obj): @@ -26,7 +25,7 @@ def __init__(self, obj): @classmethod def wrap(cls, value): - if isinstance(value, (DataCollection, SourceData, KeyData)): + if isinstance(value, (DataCollection, KeyData)): return cls(value) def split(self, num_workers): @@ -45,21 +44,7 @@ def iterate(self, share): f.close() index_it = range(*share.indices(self.n_trains)) - - if isinstance(subobj, SourceData): - # SourceData has no trains() iterator yet, so simulate it - # ourselves by reconstructing a DataCollection object and - # use its trains() iterator. - dc = DataCollection( - subobj.files, {subobj.source: subobj}, subobj.train_ids, - inc_suspect_trains=subobj.inc_suspect_trains, - is_single_run=True) - data_it = ((train_id, data[subobj.source]) - for train_id, data in dc.trains()) - else: - # Use the regular trains() iterator for DataCollection and - # KeyData - data_it = subobj.trains() + data_it = subobj.trains() for index, (train_id, data) in zip(index_it, data_it): yield index, train_id, data diff --git a/extra_data/sourcedata.py b/extra_data/sourcedata.py index aa905288..90de980d 100644 --- a/extra_data/sourcedata.py +++ b/extra_data/sourcedata.py @@ -227,9 +227,3 @@ def union(self, *others) -> 'SourceData': section=self.section, inc_suspect_trains=self.inc_suspect_trains ) - - def _pasha_functor_(self): - """Integration with pasha for map operations.""" - - from .pasha_functor import ExtraDataFunctor - return ExtraDataFunctor(self)