diff --git a/extra_data/keydata.py b/extra_data/keydata.py index b05a4308..a87854ea 100644 --- a/extra_data/keydata.py +++ b/extra_data/keydata.py @@ -535,6 +535,29 @@ def train_id_coordinates(self): ] return np.concatenate(chunks_trainids) + def train_index_bounds(self, labelled=False): + """Generate first and last indices of trains to use alongside data + from ``.ndarray()``. + + If *labelled* is True, returns a pandas dataframe with columns first + and last. Otherwise, returns a tuple of two NumPy arrays. + """ + + counts = self.data_counts(labelled) + first = counts.copy() + + if labelled: + first.iloc[0] = 0 + first.iloc[1:] = counts.cumsum().iloc[:-1] + last = first + counts + import pandas as pd + return pd.concat( + [first.rename('first'), last.rename('last')], axis=1) + else: + first[0] = 0 + first[1:] = counts.cumsum()[:-1] + return first, first + counts + def xarray(self, extra_dims=None, roi=(), name=None, extra_coords=None): """Load this data as a labelled xarray array or dataset. diff --git a/extra_data/tests/mockdata/xgm.py b/extra_data/tests/mockdata/xgm.py index e946d477..e575ee58 100644 --- a/extra_data/tests/mockdata/xgm.py +++ b/extra_data/tests/mockdata/xgm.py @@ -31,7 +31,7 @@ class XGM(DeviceBase): ('pulseEnergy/crossUsed', 'f4', ()), ('pulseEnergy/gammaUsed', 'f4', ()), ('pulseEnergy/gmdError', 'i4', ()), - ('pulseEnergy/nummberOfBrunches', 'f4', ()), + ('pulseEnergy/nummberOfBunches', 'f4', ()), ('pulseEnergy/photonFlux', 'f4', ()), ('pulseEnergy/pressure', 'f4', ()), ('pulseEnergy/temperature', 'f4', ()), diff --git a/extra_data/tests/test_keydata.py b/extra_data/tests/test_keydata.py index dcdf993c..d91bc53e 100644 --- a/extra_data/tests/test_keydata.py +++ b/extra_data/tests/test_keydata.py @@ -247,6 +247,33 @@ def test_data_counts_missing_train(fxe_run_module_offset): np.testing.assert_array_equal(arr, 128) +@pytest.mark.parametrize('labelled', [True, False]) +def test_train_index_bounds(mock_spb_raw_run, labelled): + run = RunDirectory(mock_spb_raw_run) + + agipd_m0 = run['SPB_DET_AGIPD1M-1/DET/0CH0:xtdf', 'image.pulseId'] + bounds = agipd_m0.train_index_bounds(labelled) + + if labelled: + first, last = bounds['first'], bounds['last'] + else: + first, last = bounds + + np.testing.assert_array_equal(first, np.arange(0, 4032+1, 64)) + np.testing.assert_array_equal(last, first+64) + + xgm = run['SPB_XTD9_XGM/DOOCS/MAIN', 'pulseEnergy.photonFlux'] + bounds = xgm.train_index_bounds(labelled) + + if labelled: + first, last = bounds['first'], bounds['last'] + else: + first, last = bounds + + np.testing.assert_array_equal(first, np.arange(len(first))) + np.testing.assert_array_equal(last, first+1) + + def test_select_by(mock_spb_raw_run): run = RunDirectory(mock_spb_raw_run) am0 = run['SPB_DET_AGIPD1M-1/DET/0CH0:xtdf', 'image.data']