
XarraySubset only applies time slice to sample start times #763

@spencerkclark

When examining the data loader code recently, I was surprised to learn that XarraySubset does not take into account the length of samples when subsetting its input XarrayDataset. My mental model for the way XarraySubset worked was that it first subset the dataset in time, and then derived all possible samples from that subset, but that is not the case. Instead it derives all possible samples from the original dataset, and then subsets based on sample start times. A sample that starts inside the subset can therefore extend past the subset's end. Particularly when training with longer rollouts (e.g., up to 5 days with our current stochastic model protocol), this can somewhat blur the boundary between training and held-out data.
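
To make the difference concrete, here is a toy sketch of the two orderings, using integer indices in place of timestamps. The helper names (`sample_windows`, `subset_by_start_time`, `subset_then_sample`) are hypothetical and are not the actual data loader API. The sketch only assumes that a sample is a contiguous window of `n_steps` times.

```python
# Toy model only: integers stand in for timestamps, and all helper names
# here are hypothetical (they are not part of the real data loader).


def sample_windows(times, n_steps):
    """Return every contiguous window of n_steps consecutive times."""
    return [
        tuple(times[i : i + n_steps])
        for i in range(len(times) - n_steps + 1)
    ]


def subset_by_start_time(times, subset, n_steps):
    """Derive windows from the full record, then subset the windows.

    Window i starts at times[i], so slicing the list of windows is the same
    as selecting on sample start times.
    """
    return sample_windows(times, n_steps)[subset]


def subset_then_sample(times, subset, n_steps):
    """Subset the record in time first, then derive windows from the subset."""
    return sample_windows(times[subset], n_steps)


def covered_times(windows):
    """Collect every time touched by any window."""
    return {t for window in windows for t in window}


times = list(range(8))
n_steps = 3
subset_a, subset_b = slice(0, 4), slice(4, 8)

# Ordering described in this issue: start times are disjoint, but windows
# starting near the end of subset_a reach into the times used by subset_b.
overlap_by_start = covered_times(
    subset_by_start_time(times, subset_a, n_steps)
) & covered_times(subset_by_start_time(times, subset_b, n_steps))

# Ordering from the mental model: no window can leave its own subset.
overlap_by_time = covered_times(
    subset_then_sample(times, subset_a, n_steps)
) & covered_times(subset_then_sample(times, subset_b, n_steps))

print(sorted(overlap_by_start))  # [4, 5]
print(sorted(overlap_by_time))   # []
```

In this toy setup the overlap under start-time subsetting spans `n_steps - 1` times, so it grows with the rollout length.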

It has probably not been that impactful on our work up to this point, but it could be worth revisiting if there is not an important design requirement around the current approach. This is an example test based on the mock_monthly_netcdfs fixture in test_xarray.py that illustrates the issue:

def test_XarraySubset_uniqueness(mock_monthly_netcdfs):
    mock_data: MockData = mock_monthly_netcdfs
    subset_config_a = XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(0, 2))
    subset_config_b = XarrayDataConfig(data_path=mock_data.tmpdir, subset=Slice(2, 4))

    schedule = IntSchedule(start_value=2, milestones=[])
    names = mock_data.var_names.all_names
    dataset_a, _ = get_xarray_dataset(subset_config_a, names, schedule)
    dataset_b, _ = get_xarray_dataset(subset_config_b, names, schedule)
    assert len(dataset_a) == len(dataset_b)

    dataset_a_unique_times = set()
    dataset_b_unique_times = set()
    for (_, times_a, _, _), (_, times_b, _, _) in zip(dataset_a, dataset_b):
        dataset_a_unique_times.update(times_a.to_numpy().tolist())
        dataset_b_unique_times.update(times_b.to_numpy().tolist())
    
    # First check that sample start times are unique
    sample_start_times_a = set(dataset_a.sample_start_times)
    sample_start_times_b = set(dataset_b.sample_start_times)
    intersection = sample_start_times_a.intersection(sample_start_times_b)
    assert intersection == set()

    # Then check that all times associated with all samples are unique; this
    # check fails with the current implementation of XarraySubset.
    unique_times_intersection = dataset_a_unique_times.intersection(dataset_b_unique_times)
    assert unique_times_intersection == set()
