Description
When examining the data loader code recently, I was surprised to learn that XarraySubset does not take the length of samples into account when subsetting its input XarrayDataset. My mental model was that it first subset the dataset in time and then derived all possible samples from that subset, but that is not the case. Instead, it derives all possible samples from the original dataset and then subsets based on sample start times. A sample that starts inside the subset can therefore extend past the end of it, so the boundaries between training and held-out data can blur somewhat, particularly when training with longer rollouts (e.g., up to 5 days with our current stochastic model protocol).
It has probably not had much impact on our work so far, but it could be worth revisiting if no important design requirement depends on the current approach. This is an example test based on the mock_monthly_netcdfs fixture in test_xarray.py that illustrates the issue:
def test_XarraySubset_uniqueness(mock_monthly_netcdfs):
    mock_data: MockData = mock_monthly_netcdfs
    subset_config_a = XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(0, 2))
    subset_config_b = XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(2, 4))
    schedule = IntSchedule(start_value=2, milestones=[])
    names = mock_data.var_names.all_names
    dataset_a, _ = get_xarray_dataset(subset_config_a, names, schedule)
    dataset_b, _ = get_xarray_dataset(subset_config_b, names, schedule)
    assert len(dataset_a) == len(dataset_b)

    # Collect every time stamp touched by any sample in each subset
    dataset_a_unique_times = set()
    dataset_b_unique_times = set()
    for (_, times_a, _, _), (_, times_b, _, _) in zip(dataset_a, dataset_b):
        dataset_a_unique_times.update(times_a.to_numpy().tolist())
        dataset_b_unique_times.update(times_b.to_numpy().tolist())

    # First check that sample start times are unique
    sample_start_times_a = set(dataset_a.sample_start_times)
    sample_start_times_b = set(dataset_b.sample_start_times)
    intersection = sample_start_times_a.intersection(sample_start_times_b)
    assert intersection == set()

    # Then check that all times associated with all samples are unique; this
    # check fails with the current implementation of XarraySubset.
    unique_times_intersection = dataset_a_unique_times.intersection(dataset_b_unique_times)
    assert unique_times_intersection == set()
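
To make the difference between the two subsetting orders concrete, here is a minimal, library-free sketch. It is not the actual XarraySubset code: times are plain integer indices, and the helper names (sample_windows, subset_by_sample_start, subset_by_time_first, times_covered) are hypothetical, introduced only for illustration. It assumes each sample is a contiguous window of n_steps time indices.

```python
def sample_windows(n_times: int, n_steps: int) -> list[range]:
    """All contiguous windows of n_steps indices that fit in range(n_times)."""
    return [range(s, s + n_steps) for s in range(n_times - n_steps + 1)]


def subset_by_sample_start(n_times: int, n_steps: int, start: int, stop: int) -> list[range]:
    """Behavior described in this issue: derive samples from the full
    dataset, then keep those whose sample index lies in [start, stop)."""
    return sample_windows(n_times, n_steps)[start:stop]


def subset_by_time_first(n_times: int, n_steps: int, start: int, stop: int) -> list[range]:
    """Alternative (the mental model above): restrict the time axis to
    [start, stop) first, then derive only samples that fit inside it."""
    stop = min(stop, n_times)
    return [range(s, s + n_steps) for s in range(start, stop - n_steps + 1)]


def times_covered(windows: list[range]) -> set[int]:
    """Every time index touched by any sample in the subset."""
    return {t for window in windows for t in window}


n_times, n_steps = 8, 3

# Subsetting by sample start: the two subsets share time indices.
a = subset_by_sample_start(n_times, n_steps, 0, 2)
b = subset_by_sample_start(n_times, n_steps, 2, 4)
overlap_by_start = times_covered(a) & times_covered(b)

# Subsetting in time first: the two subsets are disjoint in time.
c = subset_by_time_first(n_times, n_steps, 0, 4)
d = subset_by_time_first(n_times, n_steps, 4, 8)
overlap_time_first = times_covered(c) & times_covered(d)

print(sorted(overlap_by_start))    # [2, 3]
print(sorted(overlap_time_first))  # []
```

Note that the slice bounds mean different things in the two helpers (sample indices in the first, time indices in the second), so switching approaches in the real code would change what a given Slice selects.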